In [1]:
%load_ext autoreload
%autoreload 2

# model based on: https://arxiv.org/abs/1611.07308

from torch_geometric.datasets import QM9
import torch_geometric.transforms as T
import torch
from torch_geometric.loader import DataLoader
from data_utils import SelectQM9TargetProperties, create_qm9_data_split, SelectQM9NodeFeatures
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = T.Compose([
    SelectQM9TargetProperties(properties=["homo", "lumo"]),
    SelectQM9NodeFeatures(features=["atom_type"]),
    T.ToDevice(device=device)
])

dataset = QM9(root="./data", transform=transform)

train_dataset, val_dataset, test_dataset = create_qm9_data_split(dataset=dataset)

num_node_features = dataset.num_node_features

In [2]:
batch_size = 128

dataloaders = {
    "train_single": DataLoader(train_dataset[:1], batch_size=batch_size, shuffle=True),
    "train_tiny": DataLoader(train_dataset[:16], batch_size=batch_size, shuffle=True),
    "train_small": DataLoader(train_dataset[:4096], batch_size=batch_size, shuffle=True),
    "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True),

    "val_small": DataLoader(val_dataset[:512], batch_size=batch_size, shuffle=False),
    "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
}

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class Encoder(torch.nn.Module):
    def __init__(self, num_node_features: int, num_targets: int):
        super().__init__()
        
        conv_features = 64

        self.conv1 = GCNConv(num_node_features, conv_features)
        self.conv2 = GCNConv(conv_features, conv_features)
        self.conv3 = GCNConv(conv_features, conv_features)
        self.fc1 = nn.Linear(conv_features, conv_features)
        self.fc2 = nn.Linear(conv_features, num_targets)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        # x = global_mean_pool(x, batch)
        # x = self.fc1(x)
        # x = F.relu(x)
        # x = self.fc2(x)
        return x

In [4]:
from torch_geometric.nn.models import GAE
from data_utils import create_tensorboard_writer

# The largest molecule in the QM9 dataset contains 29 atoms
latent_size = 29
# TODO: use a VAE instead of a AE
gvae_model = GAE(encoder=Encoder(num_node_features=num_node_features, num_targets=latent_size)).to(device)

learning_rate = 5e-3
epochs = 1000

optimizer = torch.optim.Adam(gvae_model.parameters(), lr=learning_rate)

writer = create_tensorboard_writer(experiment_name="graph-vae")

train_loader = dataloaders["train_single"]

for epoch in range(epochs):
    # Training
    gvae_model.train()
    for batch_index, train_batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1} Training")):
        optimizer.zero_grad()
        z = gvae_model(train_batch)
        train_loss = gvae_model.recon_loss(z=z, pos_edge_index=train_batch.edge_index)
        train_loss.backward()
        optimizer.step()

        iteration = len(train_loader) * epoch + batch_index
        writer.add_scalars("Loss", {"Training": train_loss.item()}, iteration)

Epoch 1 Training:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1 Training: 100%|██████████| 1/1 [00:00<00:00,  1.61it/s]
Epoch 2 Training: 100%|██████████| 1/1 [00:00<00:00, 114.98it/s]
Epoch 3 Training: 100%|██████████| 1/1 [00:00<00:00, 123.89it/s]
Epoch 4 Training: 100%|██████████| 1/1 [00:00<00:00, 124.06it/s]
Epoch 5 Training: 100%|██████████| 1/1 [00:00<00:00, 112.15it/s]
Epoch 6 Training: 100%|██████████| 1/1 [00:00<00:00, 106.61it/s]
Epoch 7 Training: 100%|██████████| 1/1 [00:00<00:00, 92.85it/s]
Epoch 8 Training: 100%|██████████| 1/1 [00:00<00:00, 96.91it/s]
Epoch 9 Training: 100%|██████████| 1/1 [00:00<00:00, 101.65it/s]
Epoch 10 Training: 100%|██████████| 1/1 [00:00<00:00, 100.45it/s]
Epoch 11 Training: 100%|██████████| 1/1 [00:00<00:00, 86.13it/s]
Epoch 12 Training: 100%|██████████| 1/1 [00:00<00:00, 94.26it/s]
Epoch 13 Training: 100%|██████████| 1/1 [00:00<00:00, 93.39it/s]
Epoch 14 Training: 100%|██████████| 1/1 [00:00<00:00, 91.25it/s]
Epoch 15 Training: 100%|██████████| 1/1 [00:00<00:00, 97.21it/s]
Epoch 16 Training: 100%|███

In [10]:
from torch_geometric.utils import to_dense_adj, add_self_loops

decoded_adj_mat = torch.Tensor
decoded_adj_list = []

orig_adj_mat: torch.Tensor

for batch_index, train_batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1} Training")):
    gvae_model.eval()
    z = gvae_model(train_batch)
    edge_index_with_loops = add_self_loops(train_batch.edge_index)
    orig_adj_mat = to_dense_adj(edge_index_with_loops).int()
    print(z.shape)
    model_output = gvae_model.decoder.forward_all(z)

    decoded_adj_mat = torch.where(model_output < 0.5, 0, 1)
    break

print("Original adjacency matrix:")
print(orig_adj_mat)
print()
print("Generated adjacency matrix:")
print(decoded_adj_mat)

Epoch 1000 Training:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1000 Training:   0%|          | 0/1 [00:00<?, ?it/s]


AttributeError: 'tuple' object has no attribute 'numel'