In [1]:
import numpy as np
import wandb
import torch
from torch_geometric.loader import DataLoader

import XAIChem

In [15]:
# Log in to Weights & Biases; the call's return value is shown as the cell output.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mxwieme[0m ([33mmlchem[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Load data

In [3]:
# Load the train / validation / test splits of the ESOL data set from ../data
train_data, val_data, test_data = (
    XAIChem.Dataset(root="../data", name="ESOL", tag=split)
    for split in ("train", "val", "test")
)

In [4]:
# One loader per split, all with the same batch size.
# NOTE(review): the training loader is created without shuffle=True, so the
# batches are presumably identical in every epoch -- confirm this is intended.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=256)
    for split in (train_data, val_data, test_data)
)

## Model setup

In [5]:
def train(loader, model, criterion, optimzer):
    """
    Perform one epoch of training: one optimisation step per batch.

    Args:
        loader: iterable of batches. Each batch must provide ``x``,
            ``edge_index``, ``edge_type``, ``batch`` and ``y`` and a
            ``to(device)`` method.
        model: network called as ``model(x, edge_index, edge_type, batch)``.
            It is put in training mode and updated in place.
        criterion: loss called as ``criterion(prediction, target)``; the
            target is reshaped to a single column first.
        optimzer: optimizer that performs the parameter update. The
            misspelled parameter name is kept on purpose so that existing
            keyword calls keep working.

    Returns:
        None.
    """
    model.train()

    # Iterate through the batches
    for data in loader:
        # Move the batch to the module-level ``device``; rebinding also covers
        # ``to`` implementations that return the moved object
        data = data.to(device)

        # Start every step from clean gradients, so that leftovers from any
        # earlier backward pass cannot leak into this update
        optimzer.zero_grad()

        out = model(data.x, data.edge_index, data.edge_type, data.batch)
        loss = criterion(out, data.y.view(-1, 1))
        loss.backward()

        # Use the optimizer passed by the caller, not a global named
        # ``optimizer`` that happens to exist in the notebook
        optimzer.step()

In [6]:
def evaluate(loader, model, criterion):
    """
    Evaluate ``model`` on every batch of ``loader`` without tracking gradients.

    Args:
        loader: sized iterable of batches with a ``dataset`` attribute whose
            length is the total number of samples. Each batch must provide
            ``x``, ``edge_index``, ``edge_type``, ``batch`` and ``y`` and a
            ``to(device)`` method.
        model: network called as ``model(x, edge_index, edge_type, batch)``.
            It is put in evaluation mode.
        criterion: loss called as ``criterion(prediction, target)``; the
            target is reshaped to a single column first.

    Returns:
        Tuple ``(pearson_corr, mean_loss)`` of 0-dim tensors: the Pearson
        correlation between all predictions and all labels, and the
        unweighted mean of the per-batch losses.
    """
    model.eval()

    # Log average loss of batches
    losses = torch.zeros(len(loader))

    # Save all predictions and labels to compute the
    # Pearson correlation of the whole data set
    predictions = torch.zeros(len(loader.dataset))
    labels = torch.zeros(len(loader.dataset))

    index = 0
    # Evaluation needs no autograd graph: skipping it saves memory and keeps
    # the returned metrics detached from the model parameters
    with torch.no_grad():
        for i, data in enumerate(loader):
            # Move the batch to the module-level ``device``
            data = data.to(device)

            pred = model(data.x, data.edge_index, data.edge_type, data.batch)
            losses[i] = criterion(pred, data.y.view(-1, 1))

            # Flatten so the values fit the 1-D buffers; the number of
            # predictions determines how far the write position advances
            flat_pred = pred.reshape(-1)
            new_index = index + flat_pred.numel()
            predictions[index:new_index] = flat_pred
            labels[index:new_index] = data.y.reshape(-1)
            index = new_index

    # Pearson correlation: covariance of the centred values divided by the
    # product of their norms
    centered_predictions = predictions - torch.mean(predictions)
    centered_labels = labels - torch.mean(labels)

    pearson_corr = torch.sum(centered_predictions * centered_labels) / (
        torch.sqrt(torch.sum(centered_predictions**2)) *
        torch.sqrt(torch.sum(centered_labels**2))
    )

    return pearson_corr, torch.mean(losses)

In [7]:
# Mean squared error between prediction and target
criterion = torch.nn.MSELoss()

# Build the RGCN for 35 input node features, then move it to the selected device
model = XAIChem.RGCN(num_node_features=35)
model = model.to(device)

# Adam over all model parameters
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.003)

## Train model

In [None]:
# Train for 500 epochs; after every epoch evaluate all three splits
for epoch in range(500):

    train(train_loader, model, criterion, optimizer)

    train_pearson_corr, train_loss = evaluate(train_loader, model, criterion)
    val_pearson_corr, val_loss = evaluate(val_loader, model, criterion)
    test_pearson_corr, test_loss = evaluate(test_loader, model, criterion)

    # Report every metric that was computed, labelled with the epoch number,
    # as plain floats rather than tensor reprs
    print(
        f"epoch {epoch + 1:3d} | "
        f"train loss {train_loss.item():.4f} r {train_pearson_corr.item():.4f} | "
        f"val loss {val_loss.item():.4f} r {val_pearson_corr.item():.4f} | "
        f"test loss {test_loss.item():.4f} r {test_pearson_corr.item():.4f}"
    )

## Substructure mask explanation