In [4]:
%load_ext autoreload
%autoreload 2

from torch_geometric.datasets import QM9
import torch_geometric.transforms as T
import torch
from torch_geometric.loader import DataLoader
from data_utils import SelectQM9TargetProperties, create_qm9_data_split, SelectQM9NodeFeatures
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = T.Compose([
    SelectQM9TargetProperties(properties=["homo", "lumo"]),
    SelectQM9NodeFeatures(features=["atom_type"]),
    T.ToDevice(device=device)
])

dataset = QM9(root="./data", transform=transform)

train_dataset, val_dataset, test_dataset = create_qm9_data_split(dataset=dataset)

num_node_features = dataset.num_node_features

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
batch_size = 128

dataloaders = {
    "train_single": DataLoader(train_dataset[:1], batch_size=batch_size, shuffle=True),
    "train_tiny": DataLoader(train_dataset[:16], batch_size=batch_size, shuffle=True),
    "train_small": DataLoader(train_dataset[:4096], batch_size=batch_size, shuffle=True),
    "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True),

    "val_small": DataLoader(val_dataset[:512], batch_size=batch_size, shuffle=False),
    "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
}

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class Encoder(torch.nn.Module):
    def __init__(self, num_node_features: int, num_targets: int):
        super().__init__()
        
        conv_features = 32

        self.conv1 = GCNConv(num_node_features, conv_features)
        self.conv2 = GCNConv(conv_features, conv_features)
        self.conv3 = GCNConv(conv_features, conv_features)
        self.fc1 = nn.Linear(conv_features, conv_features)
        self.fc2 = nn.Linear(conv_features, num_targets)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        # x = global_mean_pool(x, batch)
        # x = self.fc1(x)
        # x = F.relu(x)
        # x = self.fc2(x)
        return x

In [13]:
from torch_geometric.nn.models import GAE
from data_utils import create_tensorboard_writer

# The largest molecule in the QM9 dataset contains 29 atoms
latent_size = 29
# TODO: use a VAE instead of a AE
gvae_model = GAE(encoder=Encoder(num_node_features=num_node_features, num_targets=latent_size)).to(device)

learning_rate = 2e-2
epochs = 1000

optimizer = torch.optim.Adam(gvae_model.parameters(), lr=learning_rate)

writer = create_tensorboard_writer(experiment_name="graph-vae")

train_loader = dataloaders["train_single"]

for epoch in range(epochs):
    # Training
    gvae_model.train()
    for batch_index, train_batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1} Training")):
        optimizer.zero_grad()
        z = gvae_model(train_batch)
        train_loss = gvae_model.recon_loss(z=z, pos_edge_index=train_batch.edge_index)
        train_loss.backward()
        optimizer.step()

        iteration = len(train_loader) * epoch + batch_index
        writer.add_scalars("Loss", {"Training": train_loss.item()}, iteration)

Epoch 1 Training: 100%|██████████| 1/1 [00:00<00:00, 57.96it/s]
Epoch 2 Training:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 2 Training: 100%|██████████| 1/1 [00:00<00:00, 69.35it/s]
Epoch 3 Training: 100%|██████████| 1/1 [00:00<00:00, 66.17it/s]
Epoch 4 Training: 100%|██████████| 1/1 [00:00<00:00, 67.79it/s]
Epoch 5 Training: 100%|██████████| 1/1 [00:00<00:00, 64.60it/s]
Epoch 6 Training: 100%|██████████| 1/1 [00:00<00:00, 64.68it/s]
Epoch 7 Training: 100%|██████████| 1/1 [00:00<00:00, 70.68it/s]
Epoch 8 Training: 100%|██████████| 1/1 [00:00<00:00, 76.82it/s]
Epoch 9 Training: 100%|██████████| 1/1 [00:00<00:00, 80.05it/s]
Epoch 10 Training: 100%|██████████| 1/1 [00:00<00:00, 78.23it/s]
Epoch 11 Training: 100%|██████████| 1/1 [00:00<00:00, 90.79it/s]
Epoch 12 Training: 100%|██████████| 1/1 [00:00<00:00, 84.61it/s]
Epoch 13 Training: 100%|██████████| 1/1 [00:00<00:00, 93.22it/s]
Epoch 14 Training: 100%|██████████| 1/1 [00:00<00:00, 81.58it/s]
Epoch 15 Training: 100%|██████████| 1/1 [00:00<00:00, 89.24it/s]
Epoch 16 Training: 100%|██████████| 1/1 [00:00<00:00, 85.58it/s]
Epoch 17 Training: 100%|█████████

In [16]:
for batch_index, train_batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1} Training")):
    gvae_model.eval()
    z = gvae_model(train_batch)
    print(train_batch)
    print(z)
    print(z.shape)
    break

#gvae_model.decode(torch.randn(size=(32))

Epoch 10000 Training:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 10000 Training:   0%|          | 0/1 [00:00<?, ?it/s]

DataBatch(x=[18, 5], edge_index=[2, 38], edge_attr=[38, 4], y=[1, 2], pos=[18, 3], z=[18], smiles=[1], name=[1], idx=[1], batch=[18], ptr=[2])
tensor([[ 2.3708e-01, -4.8608e-01, -4.0586e-01, -3.9573e-01, -4.4528e-01,
         -1.2985e+00,  2.8601e-01,  3.0678e-01,  4.1070e-01,  1.1347e-01,
          2.1962e-01,  3.2730e-01, -3.8460e-01,  5.0672e-02,  2.1393e-01,
          4.3009e-01,  3.5713e-01,  2.9876e-01, -2.7696e-01,  7.7895e-03,
         -3.8089e-01, -3.2763e-01, -8.2730e-01, -6.6511e-01,  5.4423e-02,
          1.2898e+00, -6.5748e-01,  3.8531e-01,  4.2064e-01, -2.4082e-01,
          2.3006e-01, -2.3438e-01],
        [ 4.2737e-01,  1.8127e+00,  1.3518e-04,  4.2073e+00, -4.5588e-01,
         -5.6491e+00, -1.7241e-01, -5.0779e+00, -3.0670e+00,  2.8359e+00,
          2.3903e+00,  6.5121e-01,  5.4190e+00,  1.2863e+00,  2.1112e-01,
          1.7446e+00,  1.5870e+00, -2.1043e+00, -1.1925e+00, -2.4880e+00,
          5.3405e+00,  4.9642e+00, -2.4916e+00, -1.2628e+00, -8.5201e-02,
       


