# Imports

In [1]:
from collections import Counter

import torch
import torch.nn as nn
from torch_geometric.datasets import ZINC
from torch_geometric.nn import RGCNConv
from tqdm import tqdm

# RealNVP

## Fully connected neural network for the base network

In [2]:
class FCNN(nn.Module):
    """Three-layer fully connected network with Tanh activations.

    Layout: Linear(in_dim, hidden_dim) -> Tanh -> Linear(hidden_dim, hidden_dim)
    -> Tanh -> Linear(hidden_dim, out_dim). No activation follows the last layer.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, in_dim, out_dim, hidden_dim):
        super().__init__()
        widths = [in_dim, hidden_dim, hidden_dim, out_dim]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.Tanh())
        # The output layer is left linear, so drop the trailing activation.
        stack.pop()
        self.network = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        output = self.network(x)
        return output

## Flow

In [3]:
class RealNVP(nn.Module):
    """Two affine coupling steps in the style of RealNVP.

    The input is split along dimension 1 into a ``lower`` part (the first
    ``dim // 2`` columns) and an ``upper`` part (the remaining columns).
    Step 1 rescales and shifts ``upper`` conditioned on ``lower``; step 2
    rescales and shifts ``lower`` conditioned on the updated ``upper``.

    The sub-networks are sized from the two actual half widths, so an odd
    ``dim`` (where the halves differ by one column) is supported; for even
    ``dim`` the layer shapes are the same as a plain ``dim // 2`` split.

    Args:
        dim: Number of columns of the 2-D input ``(batch, dim)``.
        hidden_dim: Hidden width passed to every sub-network.
        base_network: Callable ``(in_dim, out_dim, hidden_dim) -> nn.Module``
            used to build the shift (``t``) and log-scale (``s``) networks.
    """

    def __init__(self, dim, hidden_dim=8, base_network=FCNN):
        super().__init__()
        self.dim = dim
        # Widths of the two halves; they differ by one when ``dim`` is odd.
        self._lower_dim = dim // 2
        self._upper_dim = dim - self._lower_dim
        # Step 1: conditioned on ``lower``, transforms ``upper``.
        self.t1 = base_network(self._lower_dim, self._upper_dim, hidden_dim)
        self.s1 = base_network(self._lower_dim, self._upper_dim, hidden_dim)
        # Step 2: conditioned on ``upper``, transforms ``lower``.
        self.t2 = base_network(self._upper_dim, self._lower_dim, hidden_dim)
        self.s2 = base_network(self._upper_dim, self._lower_dim, hidden_dim)

    def _split(self, v):
        """Return the ``(lower, upper)`` column slices of a 2-D tensor."""
        return v[:, :self._lower_dim], v[:, self._lower_dim:]

    def forward(self, x):
        """Map ``x`` to ``z``.

        Returns:
            A tuple ``(z, log_det)`` where ``z`` has the same shape as ``x``
            and ``log_det`` has shape ``(batch,)``: the sum of all log-scale
            outputs, i.e. the log-determinant of the Jacobian of the map.
        """
        lower, upper = self._split(x)
        t1_transformed = self.t1(lower)
        s1_transformed = self.s1(lower)
        upper = t1_transformed + upper * torch.exp(s1_transformed)
        # Step 2 is conditioned on the already-updated ``upper``.
        t2_transformed = self.t2(upper)
        s2_transformed = self.s2(upper)
        lower = t2_transformed + lower * torch.exp(s2_transformed)
        z = torch.cat([lower, upper], dim=1)
        log_det = torch.sum(s1_transformed, dim=1) + torch.sum(s2_transformed, dim=1)
        return z, log_det

    def inverse(self, z):
        """Map ``z`` back to ``x`` by undoing the two steps in reverse order.

        Returns:
            A tuple ``(x, log_det)`` where ``log_det`` has shape ``(batch,)``
            and is the negated sum of the log-scale outputs.
        """
        lower, upper = self._split(z)
        # Undo step 2 first: ``upper`` is still the value step 2 was conditioned on.
        t2_transformed = self.t2(upper)
        s2_transformed = self.s2(upper)
        lower = (lower - t2_transformed) * torch.exp(-s2_transformed)
        # ``lower`` is now the original input half, so step 1 can be undone.
        t1_transformed = self.t1(lower)
        s1_transformed = self.s1(lower)
        upper = (upper - t1_transformed) * torch.exp(-s1_transformed)
        x = torch.cat([lower, upper], dim=1)
        log_det = torch.sum(-s1_transformed, dim=1) + torch.sum(-s2_transformed, dim=1)
        return x, log_det

# GAN

## Generator

In [4]:
class Generator(nn.Module):
    """Map latent samples to a relaxed graph (adjacency tensor + node features).

    A ``RealNVP`` flow first transforms the latent sample; two linear heads then
    produce logits that are reshaped and passed through ``gumbel_softmax``
    (called with its default arguments).

    Args:
        input_dim: Number of columns of the latent input.
        num_nodes: Number of nodes in every generated graph.
        num_features: Number of feature channels per node.
        num_edge_types: Number of channels per node pair in the adjacency tensor.
    """

    def __init__(
        self,
        input_dim,
        num_nodes,
        num_features,
        num_edge_types,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.num_nodes = num_nodes
        self.num_features = num_features
        self.num_edge_types = num_edge_types

        # Sub-modules are created in the order flow -> adj -> feat.
        self.flow_model = RealNVP(input_dim)
        adj_width = num_nodes * num_nodes * num_edge_types
        feat_width = num_nodes * num_features
        self.adj = nn.Linear(input_dim, adj_width)
        self.feat = nn.Linear(input_dim, feat_width)

    def forward(self, x):
        """Generate one graph per row of ``x``.

        Returns:
            A tuple ``(adj, feat)`` with shapes
            ``(batch, num_nodes, num_nodes, num_edge_types)`` and
            ``(batch, num_nodes, num_features)``.
        """
        batch_size = x.shape[0]
        # The flow's log-determinant is not needed here.
        latent, _ = self.flow_model(x)

        adj_logits = self.adj(latent).view(
            batch_size, self.num_nodes, self.num_nodes, self.num_edge_types
        )
        feat_logits = self.feat(latent).view(
            batch_size, self.num_nodes, self.num_features
        )

        # Gumbel noise is drawn for the adjacency first, then for the features.
        adj = torch.nn.functional.gumbel_softmax(adj_logits)
        feat = torch.nn.functional.gumbel_softmax(feat_logits)
        return adj, feat

## Discriminator

In [5]:
class Discriminator(nn.Module):
    """Single relational graph-convolution layer followed by a sigmoid.

    Args:
        num_features: Number of input feature channels per node.
        num_edge_types: Number of relation (edge) types passed to ``RGCNConv``.
    """

    def __init__(self, num_features, num_edge_types):
        super().__init__()
        # Use Relational GCN to support multiple edge types - same layer as MolGAN paper
        self.gcn = RGCNConv(num_features, 1, num_edge_types)
        self.activation = nn.Sigmoid()

    def forward(self, data):
        """Score a graph.

        Args:
            data: Mapping-like object exposing the keys ``'x'``, ``'edge_index'``
                and ``'edge_type'``.

        Returns:
            The sigmoid of the convolution output.
            NOTE(review): with a single output channel this is presumably one
            score per node rather than one per graph - confirm before using it
            in a loss.
        """
        node_features = data['x']
        connectivity = data['edge_index']
        relation_ids = data['edge_type']
        logits = self.gcn(node_features, connectivity, relation_ids)
        return self.activation(logits)

# Train

## Load dataset

In [6]:
# Load the ZINC molecule dataset, using ../dataset as its root directory.
dataset = ZINC(root='../dataset')

### For now, filter molecules to have the same number of atoms

In [7]:
# Histogram: molecule size (number of atoms) -> how many molecules have that size.
num_nodes = Counter(mol.num_nodes for mol in dataset)

# Select size which has the most samples in the dataset
# (on a tie, the size that was seen first wins).
most_samples = num_nodes.most_common(1)[0][0]

# Filter: keep only molecules of that size. The result is a plain list, so
# later in-place edits of the molecules persist.
dataset = [mol for mol in dataset if mol.num_nodes == most_samples]

### Convert atom type to one-hot vector

In [8]:
# Find out how many types of atoms we have in the dataset
num_node_types = 0
for mol in dataset:
    num_node_types = max(int(mol.x.max()), num_node_types)
num_node_types += 1

for mol in dataset:
    new_x = torch.zeros(mol.num_nodes, num_node_types)
    for i, atom_type in enumerate(mol.x):
        new_x[i] = torch.eye(num_node_types)[atom_type]
    mol.x = new_x

## Initialize models

In [9]:
# Fixed hyper-parameters.
# Number of columns of the latent vectors fed to the generator.
DISTRIBUTION_DIM = 64
# Number of edge types the generator and discriminator are built for.
NUM_EDGE_TYPES = 3

# Dataset-derived sizes, read from the first molecule.
NUM_FEATURES = dataset[0].num_node_features
MOL_SIZE = dataset[0].num_nodes

In [10]:
# Gaussian distribution with mean=0 and std=1 as prior distribution
prior = torch.distributions.MultivariateNormal(torch.zeros(DISTRIBUTION_DIM), torch.eye(DISTRIBUTION_DIM))

# Generator
generator = Generator(DISTRIBUTION_DIM, MOL_SIZE, NUM_FEATURES, NUM_EDGE_TYPES)

# Discriminator
discriminator = Discriminator(NUM_FEATURES, NUM_EDGE_TYPES)

In [11]:
# Smoke test: draw 10 latent vectors from the prior and generate one graph each.
latent_batch = prior.sample((10,))
adj, feat = generator(latent_batch)

# TODO: Convert adj to edge_index without breaking the computation graph