In [1]:
%load_ext autoreload
%autoreload 2

## `test_supervised_model`: `def test_forward():`

In [7]:
import torch
import malt

net = malt.models.supervised_model.SimpleSupervisedModel(
    representation=malt.models.representation.DGLRepresentation(
        out_features=128
    ),
    regressor=malt.models.regressor.NeuralNetworkRegressor(
        in_features=128, out_features=1
    ),
    likelihood=malt.models.likelihood.HomoschedasticGaussianLikelihood(),
)

point = malt.Molecule(smiles="C").featurize()
distribution = net.condition(point.g)

assert isinstance(distribution, torch.distributions.Distribution)
assert distribution.batch_shape == torch.Size([1, 1])

net.loss(point.g, torch.Tensor([0.0])).backward()

## `test_gp`: `def test_gp_shape():`

In [18]:
import torch
import dgl
import malt

net = malt.models.supervised_model.GaussianProcessSupervisedModel(
    representation=malt.models.representation.DGLRepresentation(
        out_features=128
    ),
    regressor=malt.models.regressor.ExactGaussianProcessRegressor(
        in_features=128,
        out_features=2,
    ),
    likelihood=malt.models.likelihood.HeteroschedasticGaussianLikelihood(),
)

if torch.cuda.is_available():
    net.cuda()

dataset = malt.data.collections.linear_alkanes(10)
dataset_loader = dataset.view(batch_size=len(dataset))
g, y = next(iter(dataset_loader))

loss = net.loss(g, y)

y_hat = net.condition(g)
assert y_hat.mean.shape[0] == 10
assert len(y_hat.mean.shape) == 1

## `test_gp`: `def test_gp_integrate():`

In [3]:
player.model.cuda()

GaussianProcessSupervisedModel(
  (representation): DGLRepresentation(
    (embedding_in): Sequential(
      (0): Linear(in_features=74, out_features=128, bias=True)
      (1): SiLU()
    )
    (gn0): GraphConv(in=128, out=128, normalization=both, activation=None)
    (gn1): GraphConv(in=128, out=128, normalization=both, activation=None)
    (gn2): GraphConv(in=128, out=128, normalization=both, activation=None)
    (embedding_out): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
    )
    (ff): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
    )
    (activation): SiLU()
  )
  (regressor): ExactGaussianProcessRegressor(
    (kernel): RBF()
  )
  (likelihood): HeteroschedasticGaussianLikelihood()
)

In [4]:
import malt
import torch
from malt.agents.player import SequentialModelBasedPlayer

dataset = malt.data.collections.linear_alkanes(10)
model = malt.models.supervised_model.GaussianProcessSupervisedModel(
   representation=malt.models.representation.DGLRepresentation(
       out_features=128
   ),
   regressor=malt.models.regressor.ExactGaussianProcessRegressor(
       in_features=128, out_features=2,
   ),
   likelihood=malt.models.likelihood.HeteroschedasticGaussianLikelihood(),
)
if torch.cuda.is_available():
    model.cuda()

player = SequentialModelBasedPlayer(
   model = model,
   policy=malt.policy.Greedy(),
   trainer=malt.trainer.get_default_trainer(),
   merchant=malt.agents.merchant.DatasetMerchant(dataset),
   assayer=malt.agents.assayer.DatasetAssayer(dataset),
)

while True:
    if player.step() is None:
        break