# Working with `gpytorch`

Imports.

In [1]:
%load_ext autoreload
%autoreload 2

import malt
import torch
import seaborn as sns
from malt.molecule import Molecule

Using backend: pytorch


Create the fluorescence dataset. This step is skipped for now: the cell below is fully commented out, and the ESOL dataset is loaded instead.

In [2]:
# NOTE: this whole cell is intentionally commented out and is not executed;
# the ESOL dataset loaded in the next cell is used instead.
# The disabled code reads a CSV from a path under the user's home directory,
# so that path has to be adjusted before this cell is re-enabled.

# def read_data():
#     from pathlib import Path
#     f = f'{Path.home()}/dev/choderalab/data/data/moonshot_fluorescence_titration_curves.csv'
#     import pandas as pd
#     df = pd.read_csv(f, index_col=0).dropna()
#     df = df.rename({'concentration': 'c', 'inhibition': 'y'}, axis=1)
#     return df

# def parse_graph(smiles):
#     from dgllife.utils import (
#         smiles_to_bigraph, CanonicalAtomFeaturizer, CanonicalBondFeaturizer
#     )

#     return smiles_to_bigraph(
#         smiles = smiles,
#         node_featurizer = CanonicalAtomFeaturizer(),
#         edge_featurizer = CanonicalBondFeaturizer()
#     )

# def make_dataset():
#     from malt import Dataset, AssayedMolecule
#     from tqdm import tqdm
    
#     df = read_data()

#     molecules = []
#     for smiles, mol_metadata in tqdm(df.groupby('SMILES')):
#         molecule = AssayedMolecule(
#             smiles = smiles,
#             # g = parse_graph(smiles),
#             metadata = {'fluorescence': mol_metadata.drop('SMILES', axis=1).to_dict('records')}
#         )
#         molecules.append(molecule)

#     # create dataset
#     data = Dataset(molecules)
#     return data

# data = make_dataset()

In [3]:
from malt.data.collections import esol

# Load the ESOL collection that ships with malt.
data = esol()
# Batch the whole dataset by its 'g' and 'y' fields. Later cells use `g` as the
# model input and `y` as the regression target.
# NOTE(review): `g` is presumably a batched DGL graph and `y` a tensor -- confirm.
g, y = data.batch(by=['g', 'y'])

Processing dgl graphs from scratch...
Processing molecule 1000/1128


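Make the model: an exact GP regressor, plus a wrapper that feeds it features produced by a graph representation.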

In [4]:
import torch
import gpytorch

class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP regressor with a linear mean and a scaled RBF kernel.

    Args:
        train_x: Training inputs, forwarded to ``gpytorch.models.ExactGP``.
        train_y: Training targets, forwarded to ``gpytorch.models.ExactGP``.
        likelihood: Likelihood, forwarded to ``gpytorch.models.ExactGP``.
        in_features (int, optional): Input dimensionality of the linear mean.
            When ``None`` (the default) it is read from the last dimension of
            the stored training inputs; if there are no training inputs it
            falls back to 128, the value that used to be hard-coded here.
    """

    def __init__(self, train_x, train_y, likelihood, in_features=None):
        super().__init__(train_x, train_y, likelihood)
        if in_features is None:
            # ExactGP keeps the training inputs as a tuple of tensors (or
            # None), so the feature size can follow the data instead of
            # having to match a constant defined elsewhere.
            stored_inputs = self.train_inputs
            in_features = stored_inputs[0].shape[-1] if stored_inputs else 128
        self.mean_module = gpytorch.means.LinearMean(in_features)
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        """Return the multivariate normal given by the mean and kernel at ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
    
class GPyTorchSupervisedModel(gpytorch.models.GP):
    """Wrapper that runs a representation and feeds its features to a GP regressor.

    On every forward pass the training inputs are re-featurized and pushed
    into the regressor, so the regressor always sees features produced by the
    current state of the representation.
    """

    def __init__(self, representation, regressor, likelihood, train_x, train_y):
        """
            Args:
                - representation (torch.nn.Module): Module called on the inputs to extract features.
                - regressor (gpytorch.models.ExactGP): A GP that operates on the extracted features.
                - likelihood: Likelihood stored on the model as ``self.likelihood``.
                - train_x (any input to representation): The training data as expected by the representation.
                - train_y (torch.Tensor): Training labels
        """
        super().__init__()
        self.representation = representation
        self.regressor = regressor
        self.likelihood = likelihood

        self.train_x = train_x
        self.train_y = train_y

    def forward(self, g):
        """Featurize ``g`` and return the regressor's output for those features.

        Args:
            g: Input in the same format as ``train_x``.

        Returns:
            Whatever ``self.regressor`` returns when called on the features.
        """
        # Refresh the regressor's training data with features from the
        # current representation weights.
        train_h = self.representation(self.train_x)
        self.regressor.set_train_data(train_h, self.train_y)

        # When the query *is* the training input (the training-loop case),
        # reuse the features just computed rather than running the
        # representation a second time on the same object.
        if g is self.train_x:
            h = train_h
        else:
            h = self.representation(g)
        return self.regressor(h)

In [5]:
import malt

# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
x_tr, y_tr = g.to(device), y.to(device)

# initialize the representation; its out_features is the feature size the GP sees
representation = malt.models.representation.DGLRepresentation(
        out_features=128,
).to(device)

# featurize the training inputs once to seed the regressor's training data
h = representation(x_tr)

# initialize likelihood
likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)

# instantiate regressor
regressor = ExactGPModel(h, y_tr, likelihood).to(device)

# initialize model (its submodules have already been moved to `device`)
model = GPyTorchSupervisedModel(
    representation, regressor, likelihood,
    train_x=x_tr, train_y=y_tr
)

# sns.displot(model(x_tr).loc.tolist())

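Train. Note that the loop below stops on its second iteration (`i == 1`), right after the backward pass and before the optimizer step, so the gradients can be inspected in the next section.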

In [6]:
# Upper bound on the number of iterations; the loop below exits early (see the `break`).
training_iter = 1000

# Find optimal model hyperparameters
# Put the likelihood and the model into training mode.
likelihood.train()
model.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # Includes GaussianLikelihood parameters
# Marginal log likelihood objective; its negative is the loss minimized below.
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iter):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Calc loss and backprop gradients
    # NOTE: despite its name, `y_target` is the model output for the training
    # inputs, not the labels; the labels are `y_tr`.
    y_target = model(x_tr)
    loss = -mll(y_target, y_tr).sum()
    loss.backward()
    # Debugging short-circuit: leave the loop on the second iteration, after
    # backward() but before optimizer.step(). Exactly one optimizer step has
    # been taken at that point, and the gradients of the second backward pass
    # stay on the parameters for inspection.
    if i == 1:
        break
    # print('Iter %d/%d - Loss: %.3f   rep_param: %.3f   lengthscale: %.3f   noise: %.3f' % (
    #     i + 1, training_iter, loss.item(),
    #     next(model.representation.parameters())[0, 0].item(),
    #     model.regressor.covar_module.base_kernel.lengthscale.item(),
    #     model.regressor.likelihood.noise.item()
    # ))
    optimizer.step()

Debug gradients: mean absolute gradient of each parameter after the last backward pass.

In [7]:
# Mean absolute gradient per named parameter. Parameters that have no gradient
# yet (p.grad is None) are skipped instead of raising AttributeError.
{
    n: p.grad.abs().mean().item()
    for n, p in model.named_parameters()
    if p.grad is not None
}

{'representation.embedding_in.0.weight': 11.14016342163086,
 'representation.embedding_in.0.bias': 151.15235900878906,
 'representation.gn0.weight': 16.048030853271484,
 'representation.gn0.bias': 288.2734375,
 'representation.gn1.weight': 15.439170837402344,
 'representation.gn1.bias': 558.8292236328125,
 'representation.gn2.weight': 15.742813110351562,
 'representation.gn2.bias': 1154.57763671875,
 'representation.embedding_out.0.weight': 27.92206382751465,
 'representation.embedding_out.0.bias': 3839.130859375,
 'representation.ff.0.weight': 288.83917236328125,
 'representation.ff.0.bias': 181.6907958984375,
 'regressor.likelihood.noise_covar.raw_noise': 166.24801635742188,
 'regressor.mean_module.weights': 192.01132202148438,
 'regressor.mean_module.bias': 226.61843872070312,
 'regressor.covar_module.raw_outputscale': 1673.2318115234375,
 'regressor.covar_module.base_kernel.raw_lengthscale': 1317.110595703125}