The goal of this notebook is to show how to transform datapoints and use them for inference.

In [2]:
import numpy as np

# Pocket encoder settings (fixed / default values).
# With a fixed equivariant representation the pocket encoder is given a
# single equivariant channel; otherwise it gets 64.
fixed_equi = False
pocket_d_equi = 64 if not fixed_equi else 1
pocket_d_inv = 256
pocket_n_layers = 4

# Ligand generator hyperparameters.
d_model = 384
n_layers = 12
n_coord_sets = 64
n_attn_heads = 32
d_edge = 128
d_message = 128
d_message_hidden = 128
self_condition = False

# Sizes of the additional input features.
# The final cell passes one extra per-atom value (the time) to the generator.
n_extra_atom_feats = 1
n_res_types = 21

# Used as `coord_std` by the transforms: coordinates are multiplied by
# 1 / PLINDER_STD_DEV before being fed to the model.
PLINDER_STD_DEV = 2.2693647416252976

# Passed to the datamodule as `bucket_limits`.
# NOTE(review): presumably upper bounds on system size per batching bucket -- confirm.
PLINDER_BUCKET_LIMITS = [
    96, 125, 149, 166, 179,
    189, 199, 208, 216, 223,
    231, 239, 248, 258, 269,
    283, 300, 324, 377, 978,
]

: 

In [None]:
%load_ext autoreload
%autoreload 2

import torch
import cgflow.util.rdkit as smolRD
import cgflow.util.functional as smolF

def mol_transform(molecule, vocab, n_bonds, coord_std, shift_com_std=0.0):
    """Prepare a single molecule for the model.

    The molecule is scaled by ``1 / coord_std``, rotated by three random
    angles drawn uniformly from ``[0, 2*pi)``, and passed through
    ``zero_com()``. If ``shift_com_std`` is positive, a Gaussian offset with
    that standard deviation is then added via ``shift``. Finally the atom
    types and charges are converted to index form.

    Args:
        molecule: Object exposing ``scale``, ``rotate``, ``zero_com``,
            ``shift``, ``_copy_with``, ``atomics`` and ``charges``.
        vocab: Vocabulary handed to ``smolF.atomics_to_index``.
        n_bonds: Accepted for signature compatibility; not used here.
        coord_std: Coordinates are multiplied by ``1 / coord_std``.
        shift_com_std: Standard deviation of the optional random shift.

    Returns:
        A copy of the transformed molecule with indexed ``atomics`` and
        ``charges``.
    """
    angles = np.random.rand(3) * np.pi * 2
    scaled = molecule.scale(1.0 / coord_std)
    posed = scaled.rotate(tuple(angles)).zero_com()

    if shift_com_std > 0.0:
        offset = np.random.normal(0, shift_com_std, 3)
        posed = posed.shift(offset)

    atom_indices = smolF.atomics_to_index(posed.atomics, vocab)
    charge_indices = smolF.charge_to_index(posed.charges)
    return posed._copy_with(atomics=atom_indices, charges=charge_indices)


def complex_transform(complex, vocab, n_bonds, coord_std, fix_pos=False):
    """Prepare a pocket-ligand complex for the model.

    Unless ``fix_pos`` is set, the whole complex is scaled by
    ``1 / coord_std``, rotated by three random angles drawn uniformly from
    ``[0, 2*pi)``, and passed through ``zero_holo_com()``. The ligand and the
    geometric-molecule view of the holo structure then get their atom types
    and charges converted to index form. The holo structure itself is copied
    and multiplied by ``coord_std``.

    Args:
        complex: Object exposing ``scale``, ``rotate``, ``zero_holo_com``,
            ``copy_with``, ``ligand`` and ``holo``.
        vocab: Vocabulary handed to ``smolF.atomics_to_index``.
        n_bonds: Accepted for signature compatibility; not used here.
        coord_std: Scale factor for the coordinates (see above).
        fix_pos: When true, skip the scale / rotate / re-centre step.

    Returns:
        ``complex.copy_with(...)`` carrying the indexed ligand as ``ligand``,
        the indexed holo molecule as ``holo_mol`` and the rescaled holo as
        ``holo``.
    """

    def _with_indexed_features(mol):
        # Swap raw atom types / charges for their index encodings.
        return mol._copy_with(
            atomics=smolF.atomics_to_index(mol.atomics, vocab),
            charges=smolF.charge_to_index(mol.charges),
        )

    if not fix_pos:
        angles = np.random.rand(3) * np.pi * 2
        normalised = complex.scale(1.0 / coord_std)
        complex = normalised.rotate(tuple(angles)).zero_holo_com()

    ligand_indexed = _with_indexed_features(complex.ligand)
    holo_mol_indexed = _with_indexed_features(complex.holo.to_geometric_mol())

    # NOTE(review): the holo is multiplied by coord_std even when fix_pos=True,
    # i.e. when this function never divided by coord_std -- confirm intended.
    holo_rescaled = complex.holo.copy().scale(coord_std)

    return complex.copy_with(
        ligand=ligand_indexed,
        holo_mol=holo_mol_indexed,
        holo=holo_rescaled,
    )

In [4]:
from pathlib import Path
from functools import partial
import cgflow.scriptutil as util
from cgflow.data.datasets import PocketComplexDataset

# Atom vocabulary and number of bond classes used throughout the notebook.
vocab = util.build_vocab()
n_bond_types = util.get_n_bond_types("uniform-sample")

# Directory holding the preprocessed complexes (relative to the working dir).
data_path = Path("data/complex/crossdock-no-litpcba/smol")

# complex_transform with its settings bound; handed to the dataset as its transform.
transform = partial(
    complex_transform,
    vocab=vocab,
    n_bonds=n_bond_types,
    coord_std=PLINDER_STD_DEV,
)
dataset = PocketComplexDataset.load(data_path / "val.smol", transform=transform)

In [5]:
from cgflow.data.interpolate import GeometricNoiseSampler 
from cgflow.data.datamodules import GeometricInterpolantDM
from cgflow.data.interpolate import GeometricComplexInterpolant

# Prior (noise) settings for coordinates, atom types and bonds.
coord_noise = "gaussian"
type_noise = "uniform-sample"
bond_noise = "uniform-sample"
type_mask_index = None
bond_mask_index = None
scale_ot = False
zero_com = True

# Sampler for the prior, configured from the settings above.
prior_sampler = GeometricNoiseSampler(
    vocab_size=vocab.size,
    n_bond_types=n_bond_types,
    coord_noise=coord_noise,
    type_noise=type_noise,
    bond_noise=bond_noise,
    type_mask_index=type_mask_index,
    bond_mask_index=bond_mask_index,
    scale_ot=scale_ot,
    zero_com=zero_com,
)

# Interpolant used at evaluation time: linear in coordinates, atom types and
# bonds left unchanged, no optimal-transport matching.
eval_interpolant = GeometricComplexInterpolant(
    prior_sampler,
    coord_interpolation="linear",
    type_interpolation="no-change",
    bond_interpolation="no-change",
    equivariant_ot=False,
    batch_ot=False,
)

# Datamodule wrapping the validation dataset loaded earlier.
# NOTE(review): the two leading None arguments are presumably the train / val
# datasets and 1000 the batch size (or batch cost) -- confirm against
# GeometricInterpolantDM's signature.
dm = GeometricInterpolantDM(
    None,
    None,
    dataset,
    1000,
    test_interpolant=eval_interpolant,
    bucket_limits=PLINDER_BUCKET_LIMITS,
    bucket_cost_scale="linear",
    pad_to_bucket=False,
)

In [None]:
from cgflow.models.pocket import LigandGenerator, PocketEncoder

# Pocket encoder, built from the pocket_* settings in the configuration cell.
pocket_enc = PocketEncoder(
    d_equi=pocket_d_equi,
    d_inv=pocket_d_inv,
    d_message=d_message,
    n_layers=pocket_n_layers,
    n_attn_heads=n_attn_heads,
    d_message_ff=d_message_hidden,
    d_edge=d_edge,
    n_atom_names=vocab.size,
    n_bond_types=n_bond_types,
    n_res_types=n_res_types,
    fixed_equi=fixed_equi
)

# Ligand generator; it receives the pocket encoder as a sub-module.
# The model is placed on the GPU when one is available and stays on the CPU
# otherwise, matching how the input tensors are placed in the final cell
# (an unconditional .cuda() fails on CPU-only machines).
egnn_gen = LigandGenerator(
    d_equi=n_coord_sets,
    d_inv=d_model,
    d_message=d_message,
    n_layers=n_layers,
    n_attn_heads=n_attn_heads,
    d_message_ff=d_message_hidden,
    d_edge=d_edge,
    n_atom_types=vocab.size,
    n_bond_types=n_bond_types,
    n_extra_atom_feats=n_extra_atom_feats,
    self_cond=self_condition,
    pocket_enc=pocket_enc
).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

In [7]:
# Pull exactly one batch from the test loader. next(iter(...)) raises
# immediately if the loader is empty, instead of leaving the names below
# undefined (or stale from an earlier run of this cell).
test_dl = dm.test_dataloader()
batch = next(iter(test_dl))
prior, data, interpolated, pockets, _, t = batch

# Show the shape of every tensor in the ligand data dict.
for key, tensor in data.items():
    print(key, tensor.shape)

coords torch.Size([10, 31, 3])
atomics torch.Size([10, 31])
bonds torch.Size([10, 31, 31])
charges torch.Size([10, 31])
residues torch.Size([10, 31])
mask torch.Size([10, 31])


In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def to_device(*tensors):
    """Return the given tensors moved to ``device``, as a list in the same order."""
    return [tensor.to(device) for tensor in tensors]


# Ligand tensors taken from the batch's `data` dict.
coords, atom_types, bond_types, mask = to_device(
    data['coords'], data['atomics'], data['bonds'], data['mask']
)

# Pocket tensors taken from the batch's `pockets` dict.
pocket_coords, pocket_atom_names, pocket_atom_charges, pocket_res_types, pocket_bond_types, pocket_atom_mask = to_device(
    pockets['coords'], pockets['atomics'], pockets['charges'],
    pockets['residues'], pockets['bonds'], pockets['mask']
)

# Repeat each sample's time along the atom axis so it can be passed as a
# per-atom extra feature with a trailing dimension of 1.
ligand_times = t.view(-1, 1, 1).expand(-1, coords.shape[1], -1).to(device)

# Inference only: switch the model to evaluation mode and skip autograd
# bookkeeping, so no graph is built or kept in memory for the forward pass.
egnn_gen.eval()
with torch.no_grad():
    output = egnn_gen(
        coords, atom_types, bond_types, atom_mask=mask, extra_feats=ligand_times,
        pocket_coords=pocket_coords, pocket_atom_names=pocket_atom_names,
        pocket_atom_charges=pocket_atom_charges, pocket_res_types=pocket_res_types,
        pocket_bond_types=pocket_bond_types, pocket_atom_mask=pocket_atom_mask
    )
