# Train a SchNet on the JClinic dataset

> Train a simple SchNet model to predict the JClinic labels

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
import logging
import os
import random
import time
import math

import numpy as np
import torch
import torch.nn as nn

from prody import parsePDB

from torch_geometric.loader import DataLoader

from jclinic.data import create_raw_dataset, JClinicDataset, make_train_val_split_clustering_by_rmsd
from jclinic.models import SchNet
from jclinic.pairwise_rmsd import make_rmsds_matrix


# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

Use the `logging` module to write output to a log file for later reference, while also printing it to the screen.

In [None]:
# Log every message both to the console and to a timestamped file in `log_dir`.
log_dir = './tmp_out/jclinic/log'
os.makedirs(log_dir, exist_ok=True)
logger = logging.getLogger('Training a SchNet model with vanilla PyTorch')
logger.propagate = False
logger.setLevel(logging.DEBUG)

# `logging.getLogger` returns the same logger object on every call, so
# re-running this cell would otherwise stack another console/file handler pair
# each time and emit every message several times. Drop (and close) handlers
# left over from a previous run before attaching fresh ones.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
    stale_handler.close()

console_handler = logging.StreamHandler()
# One log file per run of this cell, named after the local start time.
timeticks = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
file_handler = logging.FileHandler(
    os.path.join(log_dir, f'{timeticks}.log'))
logger.addHandler(console_handler)
logger.addHandler(file_handler)

Seed the random number generators of every library used (Python's `random`, NumPy and PyTorch) so that re-running the notebook reproduces the same results.

In [None]:
def seed_everything(seed):
    """Seed all random number generators used in this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and CUDA) with
    ``seed``, exports it as the ``PYTHONHASHSEED`` environment variable, and
    puts cuDNN in deterministic mode with benchmarking turned off.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Python-level and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Favour reproducibility over speed when cuDNN picks its kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_everything(42)

In [None]:
# The two positional arguments given to `jclinic.fix_pdb` below
# (presumably the input and output directories -- TODO confirm).
structures_dir = "../data/structures"
fixed_structures_dir = "../data/structures_fixed"
# Name of the pretrained ESM-2 model; it is also embedded in the names of the
# embeddings and dataset directories so that different models do not collide.
pretrained_esm_model = "esm2_t12_35M_UR50D"
embeddings_dir = f"../data/embeddings_fixed_{pretrained_esm_model}"
# Labels file handed to `create_raw_dataset`.
labels_path = "../data/labels.txt"

# Directory handed to `create_raw_dataset` and `JClinicDataset`.
dataset_dir = f"../data/pyg_dataset_{pretrained_esm_model}"

## Fix the PDB files using the scripts in `jclinic.fix_pdb`

In [None]:
!python -m jclinic.fix_pdb $structures_dir $fixed_structures_dir --rename_chains

## Extract the AA sequences and compute their ESM-2 embeddings

In [None]:
!python -m jclinic.sequence_embeddings $fixed_structures_dir $embeddings_dir --pretrained_esm_model=$pretrained_esm_model

## Create an `InMemoryDataset` instance for PyTorch Geometric

The dataset contains the 3D coordinates of all Cα atoms (`pos`), the per-residue ESM-2 embeddings (`esm_embeddings`), and the target labels for prediction (`y`).

In [None]:
# Create the raw dataset from the fixed structures, their embeddings and the
# labels file (what exactly is written to `dataset_dir` is defined by
# `jclinic.data.create_raw_dataset`), then load it.
create_raw_dataset(
    fixed_structures_dir,
    embeddings_dir,
    labels_path,
    dataset_dir
)
dataset = JClinicDataset(dataset_dir)

# The size of the last axis of the stored embeddings is the embedding
# dimension; it is passed to the model as `esm_embedding_dim` further down.
esm_embedding_dim = dataset.esm_embeddings.shape[-1]
print(f"Dimension of ESM-2 embeddings = {esm_embedding_dim}")

## Create a "hard" train-validation split to avoid data leakage

Structures in the validation set should not be too close in 3D structure (and sequence) to those in the training set, to avoid inflating results.

In [None]:
# Parse every fixed PDB file, keyed by the structure name stored in the dataset.
parsed_structures_prody = {}
for data in dataset:
    pdb_path = f"{fixed_structures_dir}/{data.name}.pdb"
    parsed_structures_prody[data.name] = parsePDB(pdb_path)

rmsds_matrix = make_rmsds_matrix(parsed_structures_prody)

# Work on a copy in which missing RMSDs are replaced by a finite placeholder:
# twice the maximum of the matrix.
rmsds_matrix_finite = rmsds_matrix.copy()
largest_rmsd = rmsds_matrix_finite.max(axis=None)
rmsds_matrix_finite[rmsds_matrix_finite.isna()] = 2 * largest_rmsd

In practice, we cluster structures by complete linkage on their pairwise RMSDs, using a distance cutoff; pairs for which no RMSD is available are assigned twice the largest RMSD in the matrix. Each cluster is then placed either entirely in the training set or entirely in the validation set.

In [None]:
# Distance cutoff for the RMSD-based clustering
# (presumably in the same units as the RMSD matrix -- TODO confirm).
clustering_cutoff = 8
# Requested training fraction; how closely it is met is up to the split helper,
# since whole clusters are assigned to one side.
train_frac = 0.8

train_idxs, val_idxs = make_train_val_split_clustering_by_rmsd(
    rmsds_matrix_finite, clustering_cutoff, train_frac=train_frac
)

In [None]:
# Each split indexes its own copy of the full dataset.
train_dataset = dataset.copy()[train_idxs]
val_dataset = dataset.copy()[val_idxs]

n_train = len(train_dataset)
n_val = len(val_dataset)
print(f"Training dataset size: {n_train}")
print(f"Validation dataset size: {n_val}")

## Define the training function

In [None]:
def _batch_loss(data, model, loss_fn):
    """Run ``model`` on one mini-batch and return ``(loss, num_graphs)``.

    Every tensor passed to the model, including ``data.batch``, is first moved
    to the module-level ``device`` so that model and inputs share a device.
    """
    pos = data.pos.to(device)
    esm_embeddings = data.esm_embeddings.to(device)
    batch = data.batch.to(device)
    y = data.y.to(device)

    pred = model(esm_embeddings, pos, batch)
    # `data.num_graphs` is the number of samples in the mini-batch; `len(data)`
    # on a PyTorch Geometric batch counts its stored attributes instead.
    return loss_fn(pred, y), data.num_graphs


def train(
    train_loader, validation_loader, model, loss_fn, optimizer, epochs=50
):
    """Train ``model`` and log training/validation errors after every epoch.

    Args:
        train_loader: Iterable of mini-batches exposing ``pos``,
            ``esm_embeddings``, ``batch``, ``y`` and ``num_graphs``. Its
            ``dataset`` attribute must support ``len()``.
        validation_loader: Same interface as ``train_loader``; used for
            evaluation only (no gradient computation, no optimizer step).
        model: Module called as ``model(esm_embeddings, pos, batch)``.
        loss_fn: Called as ``loss_fn(pred, y)`` and expected to return a
            scalar tensor. Each batch value is weighted by the batch's sample
            count, so a per-sample mean (e.g. the default ``nn.MSELoss``) is
            assumed.
        optimizer: Optimizer updating ``model``'s parameters.
        epochs: Number of passes over ``train_loader``.

    Returns:
        None. For each epoch, the square root of the sample-weighted mean
        training and validation loss is logged through the module-level
        ``logger`` (i.e. the RMSE when ``loss_fn`` is a mean-reduced MSE).
    """
    size_train = len(train_loader.dataset)
    size_val = len(validation_loader.dataset)

    for epoch in range(epochs):
        # Train
        model.train()
        train_loss = 0
        for data in train_loader:
            loss, n_samples = _batch_loss(data, model, loss_fn)
            train_loss += loss.item() * n_samples

            # Backpropagate
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss /= size_train

        # Validate
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for data in validation_loader:
                loss, n_samples = _batch_loss(data, model, loss_fn)
                val_loss += loss.item() * n_samples
        val_loss /= size_val

        logger.info(
            f"Epoch {epoch}: Training error = {math.sqrt(train_loss):.3f}, "
            f"Validation error = {math.sqrt(val_loss):.3f}"
        )

## Configure and train

Given the small dataset size, we set hyperparameters using the following considerations:
1. The number of hidden channels in the `SchNet` should be smaller than the default of 128 to avoid overparametrization and overfitting
2. The same applies to the number of filters in the `SchNet`
3. The same possibly applies to the number of Gaussians

In [None]:
# Hyperparameters.
batch_size = 1

hidden_channels = 1
num_filters = 2
num_interactions = 2
num_gaussians = 50
atom_distance_cutoff = 20
readout = "sum"  # This is crucial

# Only the training data is shuffled; validation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

# Build the model, then move it to the selected device.
model = SchNet(
    esm_embedding_dim=esm_embedding_dim,
    hidden_channels=hidden_channels,
    num_filters=num_filters,
    num_interactions=num_interactions,
    num_gaussians=num_gaussians,
    cutoff=atom_distance_cutoff,
    readout=readout,
)
model = model.to(device)

# Report the model size as the total number of parameter elements.
total_params = sum(param.numel() for param in model.parameters())
print(f"Number of model parameters: {total_params}")

loss_fn = nn.MSELoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

In [None]:
# Train for 1000 epochs; `train` logs the training and validation errors after each one.
train(train_loader, val_loader, model, loss_fn, optimizer, epochs=1000)