In [1]:
import math
import tqdm
import torch
import gpytorch
from matplotlib import pyplot as plt

Make Data

In [4]:
%load_ext autoreload
%autoreload 2

import malt
import torch
import seaborn as sns
from malt.molecule import Molecule

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
from malt.data.collections import esol

data = esol()
data.shuffle()
data_tr, data_te = data.split([8, 2])

Processing dgl graphs from scratch...
Processing molecule 1000/1128


In [68]:
train_x, train_y = data_tr.batch()
test_x, test_y = data_te.batch()

if torch.cuda.is_available():
    train_x, train_y, test_x, test_y = train_x.to('cuda:0'), train_y.cuda().ravel(), test_x.to('cuda:0'), test_y.cuda().ravel()

In [69]:
# initialize
feature_extractor = malt.models.representation.DGLRepresentation(
        out_features=128,
).cuda()

# feature_extractor = LargeFeatureExtractor()

In [70]:
class GPRTest(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(GPRTest, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(ard_num_dims=128)
        )

    def forward(self, projected_x):
        mean_x = self.mean_module(projected_x)
        covar_x = self.covar_module(projected_x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

In [71]:
class ModularModel(gpytorch.models.GP):
    def __init__(self, representation, regressor, likelihood, train_y):
        """
            Args:
                - gp (gpytorch.models.ExactGP): A GP that expects to operate on features extracted by a GCN.
                - gcn (torch.nn.Module): Some PyTorch module that extracts graph features.
                - train_x (any input to gcn): The training data as expected by the GCN
                - train_y (torch.Tensor): Training labels
        """
        super(ModularModel, self).__init__()
        self.representation = representation
        self.regressor = regressor
        self.likelihood = likelihood

        self.train_targets = train_y
    
    def forward(self, train_x):
        train_h = self.representation(train_x)
        if self.training:
            self.regressor.set_train_data(train_h, self.train_targets, strict=False)
        return self.regressor(train_h)

In [72]:
class IntegratedModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(IntegratedModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = (
            # gpytorch.kernels.GridInterpolationKernel(
            gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
            # num_dims=128, grid_size=100
        )
        self.feature_extractor = feature_extractor

        # This module will scale the NN features so that they're nice values
        # self.scale_to_bounds = gpytorch.utils.grid.ScaleToBounds(-1., 1.)

    def forward(self, x):
        # We're first putting our data through a deep net (feature extractor)
        projected_x = self.feature_extractor(x)
        # projected_x = self.scale_to_bounds(projected_x)  # Make the NN values "nice"

        mean_x = self.mean_module(projected_x)
        covar_x = self.covar_module(projected_x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

In [73]:
def make_model_and_likelihood(architecture='modular'):
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    
    if architecture == 'modular':
        gprt = GPRTest(torch.ones(3), train_y, likelihood)
        # the requirement to give train_y is unnecessary with a smarter training loop
        model = ModularModel(feature_extractor, gprt, likelihood, train_y)
    else:
        model = IntegratedModel(train_x, train_y, likelihood)

    if torch.cuda.is_available():
        model = model.cuda()
        likelihood = likelihood.cuda()
    
    return model, likelihood

In [74]:
likelihood = gpytorch.likelihoods.GaussianLikelihood()
gprt = GPRTest(torch.ones(3), train_y, likelihood)

In [80]:
def train_model(model, likelihood, train_x, train_y, smoke_test=False):
    def _train():
        # iterator = tqdm.notebook.tqdm(range(training_iterations))
        for i in range(training_iterations):
            # Zero backprop gradients
            optimizer.zero_grad()
            # Get output from model
            output = model(train_x)
            # Calc loss and backprop derivatives
            loss = -mll(output, train_y).mean()
            loss.backward()
            optimizer.step()
        return model

    smoke_test = False
    training_iterations = 2 if smoke_test else 200

    # Find optimal model hyperparameters
    model.train()
    likelihood.train()

    # Use the adam optimizer
    optimizer = torch.optim.Adam([
        {'params': model.parameters()},
    ], lr=1e-3)

    # "Loss" for GPs - the marginal log likelihood
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    # train
    model = _train()
    
    return model

def test(model):
    from scipy.stats import pearsonr
    model.eval()
    with torch.no_grad(), gpytorch.settings.use_toeplitz(False), gpytorch.settings.fast_pred_var():
        preds = model(test_x)
    mae = torch.mean(torch.abs(preds.mean - test_y)).item()
    corr, _ = pearsonr(preds.mean.tolist(), test_y.tolist())
    return mae, corr

results = []
for architecture in ['modular']:
    for i in tqdm.tqdm(range(4)):
        model, likelihood = make_model_and_likelihood(architecture=architecture)
        model = train_model(model, likelihood, train_x, train_y, smoke_test=False)
        mae, corr = test(model)
        results.append({'architecture': architecture, 'mae': mae, 'corr': corr})

100%|█████████████████████████████████████████████| 4/4 [01:28<00:00, 22.22s/it]


In [81]:
import pandas as pd
pd.DataFrame(results).groupby('architecture').agg(('mean', 'std'))

Unnamed: 0_level_0,mae,mae,corr,corr
Unnamed: 0_level_1,mean,std,mean,std
architecture,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
modular,0.520256,0.047595,0.937639,0.003729
