In [None]:
#| default_exp gan
from nbdev.showdoc import *
import numpy as np
from tqdm.notebook import trange, tqdm
import matplotlib.pyplot as plt
import torch
import torch_geometric
%load_ext autoreload
%autoreload 2

# Directed Graph Embedding by GAN
> Learning a flow field that, when sampled, matches a class of directed graphs

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from FEARFRED.generator import *
from FEARFRED.discriminator import *
from FEARFRED.graph_builder import *
class FRED_A_GAN(nn.Module):
    """GAN over directed graphs with a critic-style ("witness function") objective.

    A `FlowGenerator` maps latent samples to embedded points, flow vectors and
    per-node features. `flashlight_kernel` turns points + flows into a directed
    adjacency matrix, and a `ScatteringDiscriminator` scores (adjacency,
    node features) pairs. `train_critic` and `train_generator` only *compute*
    the respective losses; the caller is responsible for `backward()` and the
    optimizer steps.
    """
    def __init__(self,
                 intrinsic_dimension, # intrinsic dimension of data
                 n_nodes, # number of nodes in directed subgraphs
                 n_features, # number of features per node
                ):
        super().__init__()
        self.intrinsic_dimension = intrinsic_dimension
        self.n_nodes = n_nodes
        # NOTE: self.n_features is one larger than the `n_features` argument,
        # because a constant column of 1s is prepended before feeding the discriminator
        self.n_features = n_features + 1
        # initialize GAN machinery
        self.generator = FlowGenerator(self.intrinsic_dimension,n_features) # generates the raw number of features (without the 1s column)
        self.discriminator = ScatteringDiscriminator(self.n_nodes,self.n_features)

    def _with_ones_column(self, features):
        """Return an (n_nodes, self.n_features) float32 tensor: a column of 1s followed by `features`."""
        node_features = torch.ones(self.n_nodes,self.n_features).float()
        # column 0 stays all ones; the remaining columns are filled from `features`
        node_features[:,1:] = features
        return node_features

    def generate_fake(self):
        """Sample a fake directed graph from the generator.

        Returns a tuple `(A, node_features)`: the thresholded adjacency matrix
        produced by `flashlight_kernel`, and the generated node features with a
        leading column of 1s.
        """
        # 1. sample latent coordinates
        # NOTE(review): torch.randn draws from a standard normal, not uniformly from the
        # unit hypercube (that would be torch.rand) -- confirm which one is intended.
        # TODO: Could adjust this to sample from localized regions to get localized subgraphs
        samples = torch.randn(self.n_nodes,self.intrinsic_dimension)
        # 2. Translate to a sample in the embedding space, and take flows
        points, flows, features = self.generator(samples)
        # 3. construct a directed graph based off of these points and flows
        A = flashlight_kernel(points,flows,kernel_type='fixed', sigma=0.7)
        # Simplify the graph by zeroing weak edges. This is done out-of-place: an in-place
        # write into the kernel output can invalidate tensors autograd saved for the
        # backward pass and make the generator's backward() raise.
        # TODO: Might have to revise that.
        A = torch.where(A < 0.01, torch.zeros_like(A), A)
        # 4. create the node features fed to the discriminator
        node_features = self._with_ones_column(features)
        return A, node_features

    def train_critic(self, A, features):
        """Compute the critic loss for one real graph.

        `A` is the real adjacency matrix and `features` the real node features
        (without the 1s column). Returns `witness(fake) - witness(real)`;
        minimizing it pushes the witness up on real data and down on fakes.
        """
        # The fake sample is a constant as far as the critic is concerned, so skip
        # building the generator's backprop graph altogether.
        with torch.no_grad():
            fakeA, fake_features = self.generate_fake()
        witness_of_fake = self.discriminator(fakeA,fake_features)

        # Score the real data, augmented with the same column of 1s as the fakes
        node_features = self._with_ones_column(features)
        witness_of_real = self.discriminator(A,node_features)
        # The critic minimizes this difference, i.e. maximizes witness(real) - witness(fake)
        loss = witness_of_fake - witness_of_real
        return loss

    def train_generator(self,A,features):
        """Compute the generator loss.

        `A` and `features` are accepted for symmetry with `train_critic` but are
        not used. Returns `-witness(fake)`; minimizing it pushes the witness
        value of generated graphs up.
        """
        # Generate a fake graph (gradients flow back into the generator) ...
        fakeA, fake_features = self.generate_fake()
        # ... and score it with the discriminator
        witness_of_fake = self.discriminator(fakeA,fake_features)
        # generator wants to maximize the witness function on its own samples
        loss = - witness_of_fake
        return loss
        

Smoke test: check that the generator and critic losses can be computed on random (fake) data.

In [None]:
# Random stand-ins for one real sample: a 10x10 adjacency matrix
# and 2 features for each of the 10 nodes.
A = torch.rand(10, 10)
features = torch.rand(10, 2)

# Small model whose node count and feature count match the tensors above.
fredtest = FRED_A_GAN(intrinsic_dimension=2, n_nodes=10, n_features=2)

In [None]:
# Generator loss on a freshly generated graph (train_generator accepts A and features but does not use them)
l_g = fredtest.train_generator(A,features)

In [None]:
# Critic loss: witness value of a generated graph minus witness value of the random "real" graph
l_d = fredtest.train_critic(A,features)

# Training

In [None]:
# --- Hyperparameters ---
n_epochs = 500
n_critic = 5  # critic updates performed per generator update
weight_clipping_value = 0.01  # critic weights are clamped to [-value, value] after every critic step

# --- Model: one scalar feature per node, 128 nodes per generated graph ---
fred = FRED_A_GAN(intrinsic_dimension=2, n_nodes=128, n_features=1)

# --- One Adam optimizer (default settings) per sub-network ---
opt_gen = torch.optim.Adam(fred.generator.parameters())
opt_discrim = torch.optim.Adam(fred.discriminator.parameters())

In [None]:
from FEARFRED.datasets.manifolds import DirectedCircle
from torch.utils.data import DataLoader
# Wrap the DirectedCircle dataset in a loader; batch_size=1 so every batch holds a single graph
d = DirectedCircle()
dataloader = DataLoader(d, batch_size=1)

In [None]:
for e in trange(n_epochs):
    # Count batches from 1 so the generator is updated on every n_critic-th batch.
    # leave=False removes the finished inner bar instead of stacking one per epoch.
    for i, (A, features) in enumerate(tqdm(dataloader, leave=False), start=1):
        # shape wrangling: presently each batch has but a single matrix and list of features
        A = A[0].float()
        features = features[0][:,None].float() # reshape to n_nodes x n_features

        # --- Critic step ---
        opt_discrim.zero_grad()
        loss = fred.train_critic(A, features)
        loss.backward()
        opt_discrim.step()

        # Weight clipping: keep every critic parameter inside
        # [-weight_clipping_value, weight_clipping_value]
        with torch.no_grad():
            for p in fred.discriminator.parameters():
                p.clamp_(-weight_clipping_value, weight_clipping_value)

        if i % n_critic == 0:
            # --- Generator step ---
            # (critic gradients accumulated by this backward pass are discarded
            # by opt_discrim.zero_grad() at the start of the next critic step)
            opt_gen.zero_grad()
            loss = fred.train_generator(A, features)
            loss.backward()
            opt_gen.step()

  0%|          | 0/500 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

[E thread_pool.cpp:113] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:113] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:113] Exception in thread pool task: mutex lock failed: Invalid argument


KeyboardInterrupt: 

In [None]:
# Scratch tensors; not referenced by any other cell visible in this notebook
U = torch.rand(10,2)
u = torch.rand(10)

In [None]:
# Debug: inspect the dtype of the most recently assigned `features` tensor
features.dtype

In [None]:
# Debug: display the most recently assigned adjacency matrix `A`
A