# Imports

In [1]:
import torch.nn as nn
from torch_geometric.nn import RGCNConv
from torch_geometric.datasets import ZINC
from torch_geometric.utils import to_dense_adj
import torch
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# RealNVP

## Fully connected neural network for the base network

In [2]:
class FCNN(nn.Module):
    """Multilayer perceptron: in_dim -> hidden_dim -> hidden_dim -> out_dim.

    ReLU activations sit between consecutive linear layers; the final layer is
    linear, so outputs are unbounded.
    """

    def __init__(self, in_dim, out_dim, hidden_dim):
        super().__init__()
        widths = [in_dim, hidden_dim, hidden_dim, out_dim]
        layers = []
        for layer_idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if layer_idx > 0:
                # Non-linearity between linear layers only (none after the last one).
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        # Attribute name ``network`` keeps state_dict keys as "network.<i>.*".
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``, whose last dimension must equal ``in_dim``."""
        return self.network(x)

## Flow

In [3]:
class RealNVP(nn.Module):
    """Invertible RealNVP flow made of two affine coupling layers.

    The input is split along dimension 1 into a ``lower`` part (the first
    ``dim // 2`` features) and an ``upper`` part (the remaining
    ``dim - dim // 2`` features). The first coupling transforms ``upper``
    conditioned on ``lower``; the second transforms ``lower`` conditioned on the
    new ``upper``. Odd ``dim`` is supported: the upper part gets the extra
    feature. For even ``dim`` the networks are shaped exactly as before.

    Args:
        dim: Number of features in the input (size of dimension 1).
        hidden_dim: Hidden width passed to every base network.
        base_network: Callable ``(in_dim, out_dim, hidden_dim) -> nn.Module``
            used to build the shift (``t``) and log-scale (``s``) networks.
    """

    def __init__(self, dim, hidden_dim=8, base_network=FCNN):
        super().__init__()
        self.dim = dim
        # Split point between the halves; also the size of the lower half.
        self.split_dim = dim // 2
        lower_dim, upper_dim = self.split_dim, dim - self.split_dim
        # t1/s1: lower half -> shift/log-scale for the upper half.
        self.t1 = base_network(lower_dim, upper_dim, hidden_dim)
        self.s1 = base_network(lower_dim, upper_dim, hidden_dim)
        # t2/s2: transformed upper half -> shift/log-scale for the lower half.
        self.t2 = base_network(upper_dim, lower_dim, hidden_dim)
        self.s2 = base_network(upper_dim, lower_dim, hidden_dim)

    def forward(self, x):
        """Map ``x`` (shape ``(batch, dim)``) to the latent ``z``.

        Returns:
            ``(z, log_det)``: ``z`` has the same shape as ``x``; ``log_det`` has
            shape ``(batch,)`` and is the log-determinant of dz/dx (the sum of
            both couplings' log-scales).
        """
        lower, upper = x[:, :self.split_dim], x[:, self.split_dim:]
        t1 = self.t1(lower)
        s1 = self.s1(lower)
        upper = t1 + upper * torch.exp(s1)
        t2 = self.t2(upper)
        s2 = self.s2(upper)
        lower = t2 + lower * torch.exp(s2)
        z = torch.cat([lower, upper], dim=1)
        log_det = torch.sum(s1, dim=1) + torch.sum(s2, dim=1)
        return z, log_det

    def inverse(self, z):
        """Map latent ``z`` back to data space (exact inverse of ``forward``).

        Returns:
            ``(x, log_det)``: ``log_det`` is the log-determinant of dx/dz, i.e.
            the negation of the forward log-determinant.
        """
        lower, upper = z[:, :self.split_dim], z[:, self.split_dim:]
        # Undo the second coupling first: it was conditioned on the final upper half.
        t2 = self.t2(upper)
        s2 = self.s2(upper)
        lower = (lower - t2) * torch.exp(-s2)
        # Then undo the first coupling, conditioned on the recovered lower half.
        t1 = self.t1(lower)
        s1 = self.s1(lower)
        upper = (upper - t1) * torch.exp(-s1)
        x = torch.cat([lower, upper], dim=1)
        log_det = torch.sum(-s1, dim=1) + torch.sum(-s2, dim=1)
        return x, log_det

# GAN

## Generator

In [101]:
class Generator(nn.Module):
    """Turn latent vectors into soft graphs (edge-channel and node-feature distributions).

    A RealNVP flow first transforms the latent input. Two linear heads then
    produce per-atom-pair edge logits and per-atom feature logits, and each is
    normalised with a softmax over its last dimension.
    """

    def __init__(self, input_dim, num_nodes, num_features, num_edge_types):
        super().__init__()
        self.input_dim = input_dim
        self.flow_model = RealNVP(input_dim)
        self.num_nodes = num_nodes
        self.num_features = num_features
        self.num_edge_types = num_edge_types

        # num_edge_types + 1 edge channels per ordered atom pair.
        edge_outputs = num_nodes * num_nodes * (num_edge_types + 1)
        feature_outputs = num_nodes * num_features
        self.adj = nn.Linear(input_dim, edge_outputs)
        self.feat = nn.Linear(input_dim, feature_outputs)

    def forward(self, x):
        """Generate a batch of graphs from latent samples ``x`` of shape (batch, input_dim).

        Returns:
            ``(adj, feat)``:
                ``adj`` has shape (batch, num_nodes, num_nodes, num_edge_types + 1).
                ``feat`` has shape (batch, num_nodes, num_features).
                Both are softmax-normalised along the last axis.
        """
        batch_size = x.shape[0]
        z = self.flow_model(x)[0]

        edge_logits = self.adj(z).view(
            batch_size, self.num_nodes, self.num_nodes, self.num_edge_types + 1
        )
        edge_probs = torch.nn.functional.softmax(edge_logits, -1)

        feature_logits = self.feat(z).view(batch_size, self.num_nodes, self.num_features)
        feature_probs = torch.nn.functional.softmax(feature_logits, -1)

        return edge_probs, feature_probs

## Discriminator

### GCN

In [195]:
class GCNConv(nn.Module):
    """Dense, batched graph convolution with symmetric normalisation.

    For each graph in the batch computes ``D^{-1/2} (A + I) D^{-1/2} X W``,
    where ``D`` is the diagonal matrix of row sums (degrees) of ``A + I``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Uniform [0, 1) initialisation, kept from the original implementation.
        self.W = nn.Parameter(torch.rand(in_channels, out_channels, requires_grad=True))

    def forward(self, A, X):
        """Apply the convolution.

        Args:
            A: Batched adjacency of shape (batch, n, n). Entries are assumed
                non-negative, so every degree (including the self-loop) is positive.
            X: Batched node features of shape (batch, n, in_channels).

        Returns:
            Tensor of shape (batch, n, out_channels).
        """
        # Self-loops, created on A's device/dtype so the layer also works off-CPU.
        eye = torch.eye(A.size(1), device=A.device, dtype=A.dtype)
        A_hat = A + eye
        # D is diagonal, so D^{-1/2} is an element-wise power of the degrees;
        # scaling rows and columns avoids a general matrix inverse and two bmm calls.
        deg_inv_sqrt = A_hat.sum(2).pow(-0.5)
        A_norm = deg_inv_sqrt.unsqueeze(2) * A_hat * deg_inv_sqrt.unsqueeze(1)
        return A_norm.bmm(X).matmul(self.W)

### Model

In [196]:
class Discriminator(nn.Module):
    """Score graphs with one GCN per bond type and return a value in (0, 1).

    ``adj`` carries ``num_edge_types + 1`` channels on its last axis. This
    matches both the generator output and the real-data tensors, whose channels
    1..3 are the type-specific adjacencies. Channel 0 is not a bond type and is
    not convolved. GCN ``k`` convolves bond-type channel ``k``. The per-graph
    mean of each GCN's node outputs is summed over bond types and passed through
    a sigmoid.
    """

    def __init__(self, num_features, num_edge_types):
        super().__init__()
        # One dense GCN per bond type, in the spirit of MolGAN's relational GCN.
        self.gcns = nn.ModuleList(
            [GCNConv(num_features, 1) for _ in range(num_edge_types)]
        )
        self.activation = torch.nn.Sigmoid()

    def forward(self, adj, x):
        """Return a (batch, 1) tensor of sigmoid scores, one per graph.

        Args:
            adj: Adjacency channels of shape (batch, n, n, num_edge_types + 1).
            x: Node features of shape (batch, n, num_features).
        """
        # Allocate on adj's device/dtype so the model also runs off-CPU.
        out = adj.new_zeros(adj.size(0), 1)
        # Bond type k lives in channel k; channel 0 is skipped so that all
        # num_edge_types bond channels (including the last one) are used.
        for edge_type, gcn in enumerate(self.gcns, start=1):
            out = out + gcn(adj[:, :, :, edge_type], x).mean(1)
        return self.activation(out)

# Train

## Load dataset

In [7]:
# Load the ZINC molecular-graph dataset stored under ../dataset.
# NOTE(review): no `subset`/`split` arguments are passed, so PyG's defaults apply — confirm that is the intended split.
dataset = ZINC('../dataset')

### For now, filter molecules to have the same number of atoms

In [8]:
# Histogram: atom count -> number of molecules with that many atoms.
num_nodes = {}
for mol in dataset:
    num_nodes[mol.num_nodes] = num_nodes.get(mol.num_nodes, 0) + 1

# Most common molecule size (ties go to the size seen first).
most_samples = max(num_nodes, key=num_nodes.get)

# Keep only molecules of that size so they can be stacked into fixed-size tensors.
dataset = [mol for mol in dataset if mol.num_nodes == most_samples]

### Convert atom types to one-hot vectors

In [9]:
# Number of atom types = largest atom-type index in the dataset + 1
# (an empty dataset yields 1, as before).
num_node_types = max((int(mol.x.max()) for mol in dataset), default=0) + 1

# Replace each molecule's integer atom types with one-hot rows of length
# num_node_types. F.one_hot builds the whole matrix in one call instead of
# allocating an identity matrix per atom in a Python loop.
for mol in dataset:
    mol.x = torch.nn.functional.one_hot(mol.x.view(-1).long(), num_node_types).float()

### Convert PyTorch Geometric data objects to dense torch tensors

In [10]:
# Bond types encoded in ``edge_attr`` take the values 1.._NUM_BOND_TYPES.
_NUM_BOND_TYPES = 3


def _no_bond_adjacency(mol):
    """Return an (n, n) float matrix that is 1 where atoms i and j are NOT bonded."""
    bonded = to_dense_adj(mol.edge_index, max_num_nodes=mol.num_nodes)[0]
    return (bonded == 0).float()


def _bond_type_adjacency(mol, edge_type):
    """Return the (n, n) dense adjacency of the bonds whose ``edge_attr`` equals ``edge_type``."""
    edge_mask = mol.edge_attr.view(-1) == edge_type
    if not bool(edge_mask.any()):
        # No bond of this type: build the zero matrix directly rather than
        # passing an empty edge list to to_dense_adj (which may fail on it).
        return torch.zeros(mol.num_nodes, mol.num_nodes)
    return to_dense_adj(mol.edge_index[:, edge_mask], max_num_nodes=mol.num_nodes)[0]


# Channel 0 ("no bond") and one adjacency matrix per bond type. Assuming every
# bonded pair carries exactly one edge_attr in 1..3, each atom pair is then
# one-hot over {no bond, type 1, type 2, type 3}. This matches the generator's
# softmax over num_edge_types + 1 channels.
no_bond_matrices = []
real_type_specific_connections = {
    edge_type: [] for edge_type in range(1, _NUM_BOND_TYPES + 1)
}
for mol in tqdm(dataset):
    no_bond_matrices.append(_no_bond_adjacency(mol))
    for edge_type, matrices in real_type_specific_connections.items():
        matrices.append(_bond_type_adjacency(mol, edge_type))

# torch.stack requires every molecule to have the same number of atoms.
real_connections = torch.stack(no_bond_matrices)
real_type_specific_connections = {
    edge_type: torch.stack(matrices)
    for edge_type, matrices in real_type_specific_connections.items()
}

100%|██████████| 20444/20444 [01:36<00:00, 210.84it/s]


In [11]:
# Node features for every molecule, stacked along a new leading batch axis.
real_feats = torch.stack([molecule.x for molecule in dataset])

# Adjacency channels on the last axis: channel 0 is real_connections,
# channels 1..3 are the per-bond-type adjacencies.
real_adj = torch.stack(
    [real_connections]
    + [real_type_specific_connections[bond_type] for bond_type in range(1, 3 + 1)],
    dim=-1,
)

real_dataset = TensorDataset(real_adj, real_feats)

## Initialize models

In [197]:
# Dimensionality of the latent (prior) space fed to the generator.
DISTRIBUTION_DIM = 64
# Atoms per molecule; all molecules share this size after the size filter.
MOL_SIZE = dataset[0].num_nodes
# Length of each node-feature vector (the one-hot atom-type width after conversion).
NUM_FEATURES = dataset[0].num_node_features
# Number of bond types; must agree with the 1..3 edge-type range used to build real_adj.
NUM_EDGE_TYPES = 3

In [198]:
# Prior: Gaussian with zero mean and identity covariance over the
# DISTRIBUTION_DIM-dimensional latent space; generator inputs are sampled from it.
prior = torch.distributions.MultivariateNormal(torch.zeros(DISTRIBUTION_DIM), torch.eye(DISTRIBUTION_DIM))

# Generator: latent vector -> (adjacency channels, node features) for a MOL_SIZE-atom graph.
generator = Generator(DISTRIBUTION_DIM, MOL_SIZE, NUM_FEATURES, NUM_EDGE_TYPES)

# Discriminator: (adjacency channels, node features) -> score in (0, 1) per graph.
discriminator = Discriminator(NUM_FEATURES, NUM_EDGE_TYPES)

### Example: the generator maps 10 Gaussian samples to 10 graphs, and the discriminator scores how likely each graph is to be real

In [210]:
# Sample 10 latent vectors, generate 10 soft graphs, and score each with the discriminator.
adj, feat = generator(prior.sample((10,)))
discriminator(adj, feat)

tensor([[0.8338],
        [0.8289],
        [0.8296],
        [0.8323],
        [0.8387],
        [0.8325],
        [0.8276],
        [0.8383],
        [0.8356],
        [0.8314]], grad_fn=<SigmoidBackward0>)

In [194]:
# Edge-channel distribution (softmax over num_edge_types + 1 channels) for atom pair (0, 5) of the first generated graph.
adj[0, 0, 5, :]

tensor([1.0487e-10, 7.8916e-13, 1.0000e+00, 8.5668e-09],
       grad_fn=<SliceBackward0>)

## Training loop

In [181]:
generator_optimizer = torch.optim.RMSprop(generator.parameters())
discriminator_optimizer = torch.optim.RMSprop(discriminator.parameters())

BATCH_SIZE = 1


def _endless(loader):
    """Yield batches from ``loader`` forever, starting a fresh (reshuffled) pass when one ends."""
    while True:
        yield from loader


# Infinite stream of real batches, so ``next`` never raises StopIteration
# however many steps the loop below runs.
real_dataloader = _endless(DataLoader(real_dataset, shuffle=True, batch_size=BATCH_SIZE))

# Each iteration is one generator step plus one discriminator step on a single
# real batch; it is not a full pass over the dataset.
for epoch in range(100):
    ############ Generator ############
    generator.train()
    generator_optimizer.zero_grad()

    # Generate fake samples; the generator wants them to be scored as real (1).
    fake_adj, fake_feat = generator(prior.sample((BATCH_SIZE,)))
    out = discriminator(fake_adj, fake_feat)
    generator_loss = torch.nn.functional.l1_loss(out, torch.ones_like(out))

    # This backward also fills the discriminator's .grad fields; they are
    # cleared by discriminator_optimizer.zero_grad() before its update.
    generator_loss.backward()
    generator_optimizer.step()

    ############ Discriminator ############
    discriminator.train()
    discriminator_optimizer.zero_grad()

    # Real samples should be scored 1. Targets follow ``out``'s shape, so a
    # short final batch cannot silently broadcast against a BATCH_SIZE target.
    real_adj, real_feat = next(real_dataloader)
    out = discriminator(real_adj, real_feat)
    real_loss = torch.nn.functional.l1_loss(out, torch.ones_like(out))

    # Fake samples should be scored 0. They are generated without autograd
    # tracking: the discriminator update must not backpropagate into the generator.
    with torch.no_grad():
        fake_adj, fake_feat = generator(prior.sample((BATCH_SIZE,)))
    out = discriminator(fake_adj, fake_feat)
    fake_loss = torch.nn.functional.l1_loss(out, torch.zeros_like(out))

    discriminator_loss = real_loss + fake_loss
    discriminator_loss.backward()
    discriminator_optimizer.step()

    print(f'Epoch {epoch} - Generator loss: {float(generator_loss):.3f}\tDiscriminator loss: {float(discriminator_loss):.3f}')

Epoch 0 - Generator loss: 0.479	Discriminator loss: 1.110
Epoch 1 - Generator loss: 0.608	Discriminator loss: 1.018
Epoch 2 - Generator loss: 0.673	Discriminator loss: 0.915
Epoch 3 - Generator loss: 0.626	Discriminator loss: 0.878
Epoch 4 - Generator loss: 0.626	Discriminator loss: 0.767
Epoch 5 - Generator loss: 0.460	Discriminator loss: 0.963
Epoch 6 - Generator loss: 0.676	Discriminator loss: 0.742
Epoch 7 - Generator loss: 0.674	Discriminator loss: 0.992
Epoch 8 - Generator loss: 0.501	Discriminator loss: 0.814
Epoch 9 - Generator loss: 0.311	Discriminator loss: 1.109
Epoch 10 - Generator loss: 0.340	Discriminator loss: 0.783
Epoch 11 - Generator loss: 0.358	Discriminator loss: 1.017
Epoch 12 - Generator loss: 0.369	Discriminator loss: 1.010
Epoch 13 - Generator loss: 0.417	Discriminator loss: 0.760
Epoch 14 - Generator loss: 0.430	Discriminator loss: 0.979
Epoch 15 - Generator loss: 0.435	Discriminator loss: 0.723
Epoch 16 - Generator loss: 0.472	Discriminator loss: 0.904
Epoch 1