In [24]:
from models.gan_model import *
import torch
from torch_geometric.utils import to_dense_adj
from torch_geometric.datasets import TUDataset
from torch.utils.data import TensorDataset, DataLoader

In [25]:
# Load the ENZYMES graph-classification dataset (TU collection) with root
# directory './enzymes'.
root = './enzymes'
name = 'ENZYMES'
pyg_dataset = TUDataset(root, name)

# ENZYMES should contain 600 graphs; fail loudly if the loaded data differs.
assert len(pyg_dataset) == 600, f"expected 600 ENZYMES graphs, got {len(pyg_dataset)}"
pyg_dataset

ENZYMES(600)


In [26]:
def fix_size(graph, size=42):
    """Return the graph's dense adjacency matrix zero-padded to ``size`` nodes.

    Args:
        graph: a PyG ``Data`` object providing ``edge_index`` and ``num_nodes``.
        size: number of rows/columns of the returned matrix (default 42).

    Returns:
        Tensor of shape ``(1, size, size)``. It is nonzero at ``[0, i, j]`` where
        edge ``i -> j`` appears in ``edge_index``; rows and columns beyond the
        graph's nodes are zero.

    Raises:
        ValueError: if the graph has more than ``size`` nodes, since it cannot be
            padded down to a smaller matrix.
    """
    num_nodes = graph.num_nodes
    if num_nodes is not None and num_nodes > size:
        raise ValueError(
            f"graph has {num_nodes} nodes, more than the padded size {size}"
        )
    # max_num_nodes builds the matrix at the target size directly. Without it,
    # to_dense_adj infers the size from edge_index, which needed manual padding
    # (including for graphs whose highest-index nodes have no edges).
    return to_dense_adj(graph.edge_index, max_num_nodes=size)

In [27]:
# Keep only graphs with 38, 40 or 42 nodes (all fit the 42-node padded size).
graphs_relevance = [g for g in pyg_dataset if g.num_nodes in [38, 40, 42]]

# Stack the padded (1, 42, 42) adjacency matrices into one (N, 1, 42, 42)
# tensor. N is the actual number of matching graphs, not a hard-coded count,
# so there are no leftover all-zero entries and no out-of-range writes.
graphs = torch.stack([fix_size(g) for g in graphs_relevance])
assert graphs.shape[1:] == (1, 42, 42), graphs.shape

In [28]:
batch_size = 8
# Each dataset item is a 1-tuple (graphs[i],), so batches are 1-element lists.
graphs_dataset = TensorDataset(graphs)
# shuffle=True draws fresh random minibatches every epoch instead of feeding
# the same fixed, ordered slices of the dataset each time.
graphs_dataloader = DataLoader(graphs_dataset, batch_size=batch_size, shuffle=True)

In [29]:
# Adam settings shared by the generator and discriminator optimizers:
# learning rate and the (beta1, beta2) moment decay rates.
lr, beta1, beta2 = 1e-3, 0.9, 0.999

In [30]:
device = torch.device(type='cpu')  # all training tensors are placed on this device

def real_loss(D_out, smooth=False):
    """Binary cross-entropy of discriminator outputs against "real" targets.

    Args:
        D_out: discriminator probabilities in [0, 1], batch on the first
            dimension and one score per sample, e.g. shape (B,) or (B, 1).
        smooth: if True, use label smoothing with targets drawn uniformly
            from [0.9, 1); otherwise every target is exactly 1.

    Returns:
        Scalar mean BCE loss.
    """
    batch_size = D_out.size(0)
    # reshape instead of squeeze(): squeeze() turns a batch of one into a 0-d
    # tensor, which no longer matches the shape-(1,) labels and makes BCE raise.
    probs = D_out.reshape(batch_size)

    # Labels share probs' dtype and device, so no separate device transfer is needed.
    if smooth:
        labels = torch.empty_like(probs).uniform_(0.9, 1)
    else:
        labels = torch.ones_like(probs)

    return torch.nn.functional.binary_cross_entropy(probs, labels)

def fake_loss(D_out):
    """Binary cross-entropy of discriminator outputs against noisy "fake" targets.

    Targets are drawn uniformly from [0, 0.1) rather than set to exactly 0
    (label noise on the fake class).

    Args:
        D_out: discriminator probabilities in [0, 1], batch on the first
            dimension and one score per sample, e.g. shape (B,) or (B, 1).

    Returns:
        Scalar mean BCE loss.
    """
    batch_size = D_out.size(0)
    # reshape instead of squeeze(): squeeze() turns a batch of one into a 0-d
    # tensor, which no longer matches the shape-(1,) labels and makes BCE raise.
    probs = D_out.reshape(batch_size)
    # Labels share probs' dtype and device, so no separate device transfer is needed.
    labels = torch.empty_like(probs).uniform_(0, 0.1)
    return torch.nn.functional.binary_cross_entropy(probs, labels)

In [31]:
# NOTE(review): Generator(100, 42) presumably means (noise size, node count).
# 100 matches the noise width sampled in train(); confirm against models.gan_model.
# Both networks are moved to `device` so they sit with the batches and labels
# used in training.
G = Generator(100, 42).to(device)
D = Discriminator(42).to(device)
# Separate Adam optimizers with identical hyperparameters.
optimizer_g = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_d = torch.optim.Adam(D.parameters(), lr=lr, betas=(beta1, beta2))

In [32]:
def train(dataloader, num_epochs=10, print_every=1, latent_dim=100):
    """Alternately update the discriminator D and the generator G.

    Uses the notebook-level G, D, optimizer_g, optimizer_d, device, real_loss
    and fake_loss.

    Args:
        dataloader: iterable of batches whose first element is a tensor of real
            samples with the batch on the first dimension; must support len().
        num_epochs: number of passes over ``dataloader``.
        print_every: print mean losses when ``epoch % print_every == 0``
            (0-based), i.e. log epochs 1, 1 + print_every, ... (1-based).
        latent_dim: width of the uniform(-1, 1) noise fed to G; must match the
            generator's expected input size.
    """
    # Guard against an empty loader so the mean-loss division cannot fail.
    num_batches = max(len(dataloader), 1)
    for epoch in range(num_epochs):
        g_l = 0.0
        d_l = 0.0
        for batch in dataloader:
            real_images = batch[0].to(device)
            # Size the noise from the actual batch. The last batch can be smaller
            # than the loader's batch size, and real/fake counts should match.
            n = real_images.size(0)

            # Discriminator step: push D(real) -> real, D(fake) -> fake.
            optimizer_d.zero_grad()
            d_real_loss = real_loss(D(real_images))
            z = torch.empty(n, latent_dim, device=device).uniform_(-1, 1)
            # detach(): this step only updates D, so don't backprop through G.
            fake_images = G(z).detach()
            d_fake_loss = fake_loss(D(fake_images))
            d_loss = d_real_loss + d_fake_loss
            d_loss.backward()
            optimizer_d.step()
            d_l += d_loss.item()

            # Generator step: make D classify fresh fakes as real.
            optimizer_g.zero_grad()
            z = torch.empty(n, latent_dim, device=device).uniform_(-1, 1)
            g_loss = real_loss(D(G(z)))
            g_loss.backward()
            optimizer_g.step()
            g_l += g_loss.item()

        if epoch % print_every == 0:
            print(f"Epoch: {epoch + 1}/{num_epochs}"
                  f"\td_loss:{round(d_l / num_batches, 4)}"
                  f"\tg_loss:{round(g_l / num_batches, 4)}")

In [33]:
# Train for 10 epochs, logging the mean D/G losses after every epoch.
train(graphs_dataloader, num_epochs=10, print_every=1)

Epoch: 1/10	d_loss:0.9597	g_loss:2.7789
Epoch: 2/10	d_loss:0.7108	g_loss:2.5206
Epoch: 3/10	d_loss:0.5557	g_loss:2.3791
Epoch: 4/10	d_loss:0.5555	g_loss:2.4459
Epoch: 5/10	d_loss:0.5592	g_loss:2.6503
Epoch: 6/10	d_loss:0.36	g_loss:3.1079
Epoch: 7/10	d_loss:0.2555	g_loss:3.6575
Epoch: 8/10	d_loss:0.2638	g_loss:3.8778
Epoch: 9/10	d_loss:0.3194	g_loss:4.6125
Epoch: 10/10	d_loss:0.4773	g_loss:5.1916
