In [1]:
import os

import numpy as np

# Root directory of the preprocessed complex data. The default is the original
# cluster location; export CGFLOW_DATA_PATH to point the notebook at another
# copy of the data without editing this cell.
DATA_PATH = os.environ.get(
    "CGFLOW_DATA_PATH",
    "/projects/jlab/to.shen/cgflow-dev/experiments/data/complex/plinder_15A",
)

# Dataset name
DATASET = "plinder"

# Number of molecules to evaluate
# (np.inf presumably means "no limit" -- TODO confirm against build_dm)
NUM_EVAL_MOLS = np.inf

# Number of inference steps
NUM_INFERENCE_STEPS = 100

# Whether the data involves protein-ligand complexes
IS_COMPLEX = DATASET in ("plinder", "crossdock", "zinc15m")

# Stand-in for the namespace that command line argument parsing would produce
class Args:
    """Empty attribute container; settings are attached to the instance after creation."""


args = Args()

# Populate `args` with every setting, each assigned exactly once.
# (Earlier revisions of this cell assigned the loss weights, `monitor`,
# `monitor_mode` and `val_check_epochs` twice with identical values.)

# Set required arguments
args.data_path = DATA_PATH
args.dataset = DATASET
args.n_validation_mols = NUM_EVAL_MOLS
args.num_inference_steps = NUM_INFERENCE_STEPS
args.num_gpus = 1
args.is_pseudo_complex = False
args.batch_cost = 8
args.use_complex_metrics = IS_COMPLEX
args.sampling_strategy = "linear"
args.num_workers = 0

# Model architecture parameters - these should match the trained model
# NOTE(review): the standalone `d_message`, `d_message_hidden` and
# `self_condition` variables defined further down in this notebook use
# 128 / 128 / False, which differ from the values set on `args` here
# (64 / 96 / True) -- confirm which set is the intended one.
args.d_model = 384
args.n_layers = 12
args.d_message = 64
args.d_edge = 128
args.n_coord_sets = 64
args.n_attn_heads = 32
args.d_message_hidden = 96
args.coord_norm = "length"
args.size_emb = 64
args.max_atoms = 256
args.pocket_n_layers = 4
args.pocket_d_inv = 256
args.fixed_equi = False

# Flow matching parameters
args.categorical_strategy = "auto-regressive"
args.conf_coord_strategy = "gaussian"
args.optimal_transport = None
args.cat_sampling_noise_level = 1
args.coord_noise_std_dev = 0.2
args.type_dist_temp = 1.0
args.time_alpha = 1.0
args.time_beta = 1.0

# Loss weights (all zero in this notebook)
args.dist_loss_weight = 0.0
args.type_loss_weight = 0.0
args.bond_loss_weight = 0.0
args.charge_loss_weight = 0.0

# Monitoring / validation cadence
args.monitor = "val-strain"
args.monitor_mode = "min"
args.val_check_epochs = 1

# Autoregressive parameters (only needed if model was trained with AR)
args.t_per_ar_action = 0.33
args.max_interp_time = 1.0
args.decomposition_strategy = "reaction"
args.ordering_strategy = "connected"
args.max_action_t = 0.66
args.max_num_cuts = 2
args.min_group_size = 5

# Model loading defaults
args.arch = "semla"
args.trial_run = False
args.use_ema = True
args.self_condition = True
args.lr = 0.0003
args.lr_schedule = "constant"
args.warm_up_steps = 10000
args.bucket_cost_scale = "linear"
args.epochs = 1
args.acc_batches = 1
args.gradient_clip_val = 1.0

args.n_training_mols = np.inf

import numpy as np

# Pocket-encoder settings (fixed / default values)
fixed_equi = False
if fixed_equi:
    pocket_d_equi = 1
else:
    pocket_d_equi = 64
pocket_d_inv = 256
pocket_n_layers = 4

# Ligand-model hyperparameters
# NOTE(review): d_message, d_message_hidden and self_condition below differ
# from the corresponding `args.*` values set earlier in this notebook
# (64 / 96 / True) -- confirm which set is intended.
d_model = 384
n_layers = 12
d_message = 128
d_edge = 128
n_coord_sets = 64
n_attn_heads = 32
d_message_hidden = 128
self_condition = False

# Feature / vocabulary sizes
n_extra_atom_feats = 1
n_res_types = 21

# PLINDER dataset constants
PLINDER_STD_DEV = 2.2693647416252976
PLINDER_BUCKET_LIMITS = [
    96, 125, 149, 166, 179,
    189, 199, 208, 216, 223,
    231, 239, 248, 258, 269,
    283, 300, 324, 377, 978,
]

In [2]:
import cgflow.scriptutil as util
from cgflow.buildutil import build_dm

# Build the vocabulary first, then the `dm` object from the notebook-level
# `args` (its `val_dataloader()` is used further down in this notebook).
vocab = util.build_vocab()
dm = build_dm(args, vocab)

Using type ARGeometricComplexInterpolant for training


In [None]:
import pickle

# Cache the built `dm` object on disk. The context manager closes the file
# handle (flushing buffered bytes) even if pickling fails part-way.
with open("dm.pkl", "wb") as cache_file:
    pickle.dump(dm, cache_file)

In [None]:
# Imported here as well so this cell works on a fresh kernel without first
# running the cell that writes the cache.
import pickle

# Restore the cached `dm` object. Only load pickle files you created yourself:
# unpickling can execute arbitrary code contained in the file.
with open("dm.pkl", "rb") as cache_file:
    dm = pickle.load(cache_file)

In [3]:
# import copy
# from cgflow.models.pocket import _PairwiseMessages, _InvariantEmbedding, SemlaLayer


# class PocketEncoder(torch.nn.Module):

#     def __init__(
#         self,
#         d_equi,
#         d_inv,
#         d_message,
#         n_layers,
#         n_attn_heads,
#         d_message_ff,
#         d_edge,
#         n_atom_names,
#         n_bond_types,
#         n_res_types,
#         n_charge_types=7,
#         emb_size=64,
#         fixed_equi=False,
#         eps=1e-6,
#     ):
#         super().__init__()

#         if fixed_equi and d_equi != 1:
#             raise ValueError(
#                 f"If fixed_equi is True d_equi must be 1, got {d_equi}")

#         self.d_equi = d_equi
#         self.d_inv = d_inv
#         self.d_message = d_message
#         self.n_layers = n_layers
#         self.n_attn_heads = n_attn_heads
#         self.d_message_ff = d_message_ff
#         self.d_edge = d_edge
#         self.emb_size = emb_size
#         self.fixed_equi = fixed_equi
#         self.eps = eps

#         # Embedding and encoding modules
#         self.inv_emb = _InvariantEmbedding(d_inv,
#                                            n_atom_names,
#                                            n_bond_types,
#                                            emb_size,
#                                            n_charge_types=n_charge_types,
#                                            n_res_types=n_res_types)
#         self.bond_emb = _PairwiseMessages(d_equi, d_inv, d_inv, d_message,
#                                           d_edge, d_message_ff, emb_size)

#         if fixed_equi is not None:
#             self.coord_emb = torch.nn.Linear(1, d_equi, bias=False)

#         # Create a stack of encoder layers
#         layer = SemlaLayer(
#             d_equi,
#             d_inv,
#             d_message,
#             n_attn_heads,
#             d_message_ff,
#             d_self_edge_in=d_edge,
#             fixed_equi=fixed_equi,
#             zero_com=False,
#             eps=eps,
#         )

#         layers = self._get_clones(layer, n_layers)
#         self.layers = torch.nn.ModuleList(layers)

#     @property
#     def hparams(self):
#         return {
#             "d_equi": self.d_equi,
#             "d_inv": self.d_inv,
#             "d_message": self.d_message,
#             "n_layers": self.n_layers,
#             "n_attn_heads": self.n_attn_heads,
#             "d_message_ff": self.d_message_ff,
#             "d_edge": self.d_edge,
#             "emb_size": self.emb_size,
#             "fixed_equi": self.fixed_equi,
#             "eps": self.eps,
#         }

#     def forward(self,
#                 coords,
#                 atom_names,
#                 atom_charges,
#                 res_types,
#                 bond_types,
#                 atom_mask=None):
#         """Encode the protein pocket into a learnable representation

#         Args:
#             coords (torch.Tensor): Coordinate tensor, shape [B, N, 3]
#             atom_names (torch.Tensor): Atom name indices, shape [B, N]
#             atom_charges (torch.Tensor): Atom charge indices, shape [B, N]
#             res_types (torch.Tensor): Residue type indices for each atom, shape [B, N]
#             bond_types (torch.Tensor): Bond type indices for each pair, shape [B, N, N]
#             atom_mask (torch.Tensor): Mask for atoms, shape [B, N], 1 for real atom, 0 otherwise

#         Returns:
#             tuple[torch.Tensor, torch.Tensor]: Equivariant and invariant features, [B, N, 3, d_equi] and [B, N, d_inv]
#         """

#         atom_mask = torch.ones_like(
#             coords[..., 0]) if atom_mask is None else atom_mask
#         # adj_matrix = smolF.adj_from_node_mask(atom_mask, self_connect=True)
#         adj_matrix = smolF.edges_from_nodes(coords,
#                                             k=None,
#                                             node_mask=atom_mask,
#                                             edge_format="adjacency",
#                                             self_connect=True)
        
#         coords = coords.unsqueeze(-1)
#         equis = coords if self.fixed_equi else self.coord_emb(coords)

#         invs, edges = self.inv_emb(atom_names,
#                                    bond_types,
#                                    atom_mask,
#                                    atom_charges=atom_charges,
#                                    res_types=res_types)
#         edges = self.bond_emb(equis, invs, equis, invs, edges)
#         edges = edges * adj_matrix.unsqueeze(-1)

#         for layer in self.layers:
#             equis, invs, _, _ = layer(equis, invs, edges, adj_matrix,
#                                       atom_mask)

#         return equis, invs

#     def _get_clones(self, module, n):
#         return [copy.deepcopy(module) for _ in range(n)]


In [4]:
import torch

from cgflow.models.pocket import LigandGenerator, PocketEncoder

# Use the GPU when one is visible and fall back to the CPU otherwise, so the
# model ends up on the same device the input tensors are moved to later in
# this notebook (a bare `.cuda()` fails on machines without CUDA).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): hard-coded bond vocabulary size -- confirm it matches the
# bond encoding produced by the dataset.
n_bond_types = 5

# Pocket encoder; it is handed to the ligand generator below.
pocket_enc = PocketEncoder(
    d_equi=pocket_d_equi,
    d_inv=pocket_d_inv,
    d_message=d_message,
    n_layers=pocket_n_layers,
    n_attn_heads=n_attn_heads,
    d_message_ff=d_message_hidden,
    d_edge=d_edge,
    n_atom_names=vocab.size,
    n_bond_types=n_bond_types,
    n_res_types=n_res_types,
    fixed_equi=fixed_equi,
)

# Ligand generator, moved to `device` together with the pocket encoder it
# receives as `pocket_enc`.
egnn_gen = LigandGenerator(
    d_equi=n_coord_sets,
    d_inv=d_model,
    d_message=d_message,
    n_layers=n_layers,
    n_attn_heads=n_attn_heads,
    d_message_ff=d_message_hidden,
    d_edge=d_edge,
    n_atom_types=vocab.size,
    n_bond_types=n_bond_types,
    n_extra_atom_feats=n_extra_atom_feats,
    self_cond=self_condition,
    pocket_enc=pocket_enc,
).to(device)

In [5]:
# Fetch exactly one batch from the validation dataloader.
test_dl = dm.val_dataloader()
batch = next(iter(test_dl), None)
if batch is None:
    # Fail here with a clear message instead of a NameError further down.
    raise RuntimeError("Validation dataloader yielded no batches")

# A batch is a 9-tuple; only `data`, `pockets` and `t` are used below.
(
    prior,
    data,
    interpolated,
    masked_data,
    pockets,
    pocket_raw,
    t,
    rel_times,
    gen_times,
) = batch

# Show the shape of every tensor in the ligand part of the batch.
for k, v in data.items():
    print(k, v.shape)

coords torch.Size([8, 28, 3])
atomics torch.Size([8, 28])
bonds torch.Size([8, 28, 28])
charges torch.Size([8, 28])
residues torch.Size([8, 28])
mask torch.Size([8, 28])


[08:08:13] Explicit valence for atom # 0 O, 6, is greater than permitted
[08:08:13] Explicit valence for atom # 0 O, 6, is greater than permitted
[08:08:13] Explicit valence for atom # 0 O, 6, is greater than permitted


In [6]:
import torch

# Prefer the GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def to_device(*tensors):
    """Move every given tensor to the notebook-level ``device``.

    Args:
        *tensors: objects exposing ``.to(device)``.

    Returns:
        list: the moved objects, in the order they were passed.
    """
    moved = []
    for tensor in tensors:
        moved.append(tensor.to(device))
    return moved

# Move the ligand tensors of the batch onto the target device.
coords, atom_types, bond_types, mask = to_device(
    data["coords"], data["atomics"], data["bonds"], data["mask"]
)

# Same for the pocket tensors.
(
    pocket_coords,
    pocket_atom_names,
    pocket_atom_charges,
    pocket_res_types,
    pocket_bond_types,
    pocket_atom_mask,
) = to_device(
    pockets["coords"], pockets["atomics"], pockets["charges"],
    pockets["residues"], pockets["bonds"], pockets["mask"],
)

# Repeat `t` along the atom axis: the result has shape
# (len(t), coords.shape[1], 1) and is passed to the model as `extra_feats`.
ligand_times = t.view(-1, 1, 1).expand(-1, coords.shape[1], -1).to(device)

# Forward pass only: disable autograd so intermediate activations are not
# retained for a backward pass that never happens.
# NOTE(review): the recorded run of this cell failed with CUDA out-of-memory
# on a single ~105 GiB allocation; no_grad lowers overall memory use but may
# not be sufficient on its own -- consider a smaller batch as well.
with torch.no_grad():
    output = egnn_gen(
        coords,
        atom_types,
        bond_types,
        atom_mask=mask,
        extra_feats=ligand_times,
        pocket_coords=pocket_coords,
        pocket_atom_names=pocket_atom_names,
        pocket_atom_charges=pocket_atom_charges,
        pocket_res_types=pocket_res_types,
        pocket_bond_types=pocket_bond_types,
        pocket_atom_mask=pocket_atom_mask,
    )


OutOfMemoryError: CUDA out of memory. Tried to allocate 105.38 GiB. GPU 0 has a total capacity of 44.53 GiB of which 36.40 GiB is free. Including non-PyTorch memory, this process has 8.12 GiB memory in use. Of the allocated memory 7.61 GiB is allocated by PyTorch, and 18.58 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)