In [2]:
%load_ext autoreload
%autoreload 2

from torch_geometric.datasets import QM9
import torch_geometric.transforms as T
import torch
from torch_geometric.loader import DataLoader
from data_utils import SelectQM9TargetProperties, create_qm9_data_split, SelectQM9NodeFeatures

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Transform pipeline handed to the dataset. The two Select* helpers come from
# the project's data_utils module; presumably they restrict data.y to the
# HOMO/LUMO targets and data.x to the atom-type features (confirm there).
# The last step moves the sample to `device`.
transform = T.Compose([
    SelectQM9TargetProperties(properties=["homo", "lumo"]),
    SelectQM9NodeFeatures(features=["atom_type"]),
    T.ToDevice(device=device)
])

dataset = QM9(root="./data", transform=transform)

train_dataset, val_dataset, test_dataset = create_qm9_data_split(dataset=dataset)

batch_size = 128

# Loaders over the full train/val splits plus truncated subsets of each
# (1, 16 and 4096 training samples; 4 and 512 validation samples).
# Training loaders shuffle; validation loaders keep a fixed order.
dataloaders = {
    "train_single": DataLoader(train_dataset[:1], batch_size=batch_size, shuffle=True),
    "train_tiny": DataLoader(train_dataset[:16], batch_size=batch_size, shuffle=True),
    "train_small": DataLoader(train_dataset[:4096], batch_size=batch_size, shuffle=True),
    "train": DataLoader(train_dataset, batch_size=batch_size, shuffle=True),

    "val_tiny": DataLoader(val_dataset[:4], batch_size=batch_size, shuffle=False),
    "val_small": DataLoader(val_dataset[:512], batch_size=batch_size, shuffle=False),
    "val": DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
}
num_node_feature = dataset.num_node_features
# Read the number of regression targets from one transformed sample instead of
# `dataset.num_classes`. That property is a classification helper: with a
# transform set it has to go through every sample, and for a single column of
# floating-point labels it reports the number of distinct label values rather
# than the number of targets.
# Assumes the per-graph `y` has shape (1, num_targets) -- TODO confirm against
# SelectQM9TargetProperties.
num_targets = dataset[0].y.shape[-1]

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


ImportError: attempted relative import with no known parent package

In [63]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class GCN(torch.nn.Module):
    """Graph-level regressor: three GCN layers, mean pooling and a two-layer head.

    Args:
        num_node_features: Size of each input node feature vector.
        num_targets: Number of values predicted per graph.
        conv_features: Width shared by the three convolution layers and the
            first fully connected layer. Defaults to 32.
    """

    def __init__(self, num_node_features: int, num_targets: int, conv_features: int = 32):
        super().__init__()

        self.conv1 = GCNConv(num_node_features, conv_features)
        self.conv2 = GCNConv(conv_features, conv_features)
        self.conv3 = GCNConv(conv_features, conv_features)
        self.fc1 = nn.Linear(conv_features, conv_features)
        self.fc2 = nn.Linear(conv_features, num_targets)

    def forward(self, data):
        """Predict the targets for every graph in ``data``.

        Args:
            data: Object exposing ``x`` (node features), ``edge_index`` and
                ``batch`` (node-to-graph assignment passed to the pooling).

        Returns:
            Output of the final linear layer, with ``num_targets`` values in
            its last dimension.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        # No activation after the last convolution: its raw output is pooled.
        x = self.conv3(x, edge_index)
        # Collapse node embeddings into one embedding per graph.
        x = global_mean_pool(x, batch)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

In [68]:
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torchmetrics import MeanAbsoluteError

# Train the GCN with Adam on an MSE objective, validating every
# `val_interval` epochs and logging losses and the MAE metric to TensorBoard.
model = GCN(num_node_features=num_node_feature, num_targets=num_targets).to(device=device)

learning_rate = 5e-4
epochs = 100

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_function = nn.MSELoss()
metric_function = MeanAbsoluteError().to(device=device)

writer = SummaryWriter()

train_loader = dataloaders["train"]
val_loader = dataloaders["val"]

val_interval = 2

# try/finally so the event file is closed even when the run is interrupted
# from the notebook (KeyboardInterrupt) or fails part-way through.
try:
    for epoch in range(epochs):
        # Training
        model.train()
        epoch_train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} Training"):
            optimizer.zero_grad()
            model_output = model(batch)
            loss = loss_function(model_output, batch.y)
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by the batch size so a short final
            # batch does not count as much as a full one.
            # Assumes dim 0 of batch.y indexes graphs -- TODO confirm.
            epoch_train_loss += loss.item() * batch.y.size(0)

        epoch_train_loss /= len(train_loader.dataset)

        # Validation
        if epoch % val_interval == 0:
            model.eval()
            epoch_val_loss = 0.0
            # Start from a clean metric state so earlier validation passes do
            # not leak into this epoch's value.
            metric_function.reset()
            with torch.no_grad():
                for batch in tqdm(val_loader, desc=f"Epoch {epoch + 1} Validation"):
                    model_output = model(batch)
                    loss = loss_function(model_output, batch.y)

                    epoch_val_loss += loss.item() * batch.y.size(0)
                    # Accumulate instead of averaging per-batch values: the
                    # metric then covers the whole validation set exactly.
                    metric_function.update(model_output, batch.y)

            epoch_val_loss /= len(val_loader.dataset)
            epoch_metric = metric_function.compute().item()

            writer.add_scalar("Metric", epoch_metric, epoch)
            writer.add_scalars("Loss", {"Validation": epoch_val_loss}, epoch)

        writer.add_scalars("Loss", {"Training": epoch_train_loss}, epoch)
finally:
    writer.close()

# TODO: validate every n iterations with random validation subset
# TODO: add baseline mean prediction

Epoch 1 Training: 100%|██████████| 818/818 [00:42<00:00, 19.40it/s]
Epoch 1 Validation: 100%|██████████| 103/103 [00:05<00:00, 17.47it/s]
Epoch 2 Training: 100%|██████████| 818/818 [00:30<00:00, 26.62it/s]
Epoch 3 Training: 100%|██████████| 818/818 [00:30<00:00, 26.72it/s]
Epoch 3 Validation: 100%|██████████| 103/103 [00:03<00:00, 29.20it/s]
Epoch 4 Training: 100%|██████████| 818/818 [00:31<00:00, 25.98it/s]
Epoch 5 Training: 100%|██████████| 818/818 [00:30<00:00, 26.75it/s]
Epoch 5 Validation: 100%|██████████| 103/103 [00:03<00:00, 28.74it/s]
Epoch 6 Training: 100%|██████████| 818/818 [00:31<00:00, 25.94it/s]
Epoch 7 Training: 100%|██████████| 818/818 [00:30<00:00, 26.87it/s]
Epoch 7 Validation: 100%|██████████| 103/103 [00:03<00:00, 29.05it/s]
Epoch 8 Training: 100%|██████████| 818/818 [00:30<00:00, 26.58it/s]
Epoch 9 Training: 100%|██████████| 818/818 [00:30<00:00, 26.52it/s]
Epoch 9 Validation: 100%|██████████| 103/103 [00:03<00:00, 29.39it/s]
Epoch 10 Training: 100%|██████████| 81

KeyboardInterrupt: 