Manually reproduce what train.py would do

In [1]:
import os

# Remember the directory the kernel started in, exactly once. Without this
# anchor the relative chdir below is not idempotent: every re-run of the cell
# would climb another four levels from wherever the previous run landed.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()

# Move four levels up from the notebook's directory so that the relative paths
# used in later cells resolve from there (presumably the repository root --
# confirm against the project layout).
os.chdir(os.path.join(NOTEBOOK_DIR, "..", "..", "..", ".."))

In [2]:
# Turn off RDKit logging: set the RDKit logger's level to CRITICAL so that
# lower-severity messages do not clutter the cell outputs below.
from rdkit import RDLogger

lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

In [3]:
from argparse import Namespace

# Hand-written stand-in for the command-line arguments of the training script.
# The dict is wrapped in a Namespace below so downstream code can read the
# values as attributes (args.lr, args.batch_cost, ...).
args_dict = {
    # General args
    # Plain string (no f-prefix needed: there is nothing to interpolate).
    # Relative path -- depends on the working directory set earlier.
    "data_path": "experiments/data/complex/plinder_15A",
    # NOTE(review): the dataset is named 'crossdock' while data_path points at
    # a 'plinder_15A' directory -- confirm this combination is intended.
    "dataset": "crossdock",
    "is_pseudo_complex": False,
    "trial_run": False,
    "complex_debug": False,
    "model_checkpoint": None,
    "num_workers": 8,
    "num_gpus": 1,

    # Model args
    "d_model": 384,
    "n_layers": 12,
    "d_message": 128,
    "d_edge": 128,
    "n_coord_sets": 64,
    "n_attn_heads": 32,
    "d_message_hidden": 128,
    "coord_norm": "length",
    "size_emb": 64,
    "max_atoms": 256,
    "arch": "semla",

    # Protein model args
    "pocket_n_layers": 4,
    "fixed_equi": False,
    "pocket_d_inv": 256,

    # Training args
    "epochs": 10000,
    "val_check_epochs": 10,
    "lr": 0.0003,
    "batch_cost": 500,
    "acc_batches": 1,
    "gradient_clip_val": 1.0,
    "dist_loss_weight": 0.0,
    "type_loss_weight": 0.2,
    "bond_loss_weight": 1.0,
    "charge_loss_weight": 1.0,
    "conf_coord_strategy": "gaussian",
    "categorical_strategy": "auto-regressive",
    "lr_schedule": "constant",
    "warm_up_steps": 10000,
    "bucket_cost_scale": "linear",
    "use_ema": True,
    "self_condition": True,

    # Flow matching and sampling args
    "n_validation_mols": 2000,
    # Infinity, i.e. no finite cap expressed by this value.
    "n_training_mols": float("inf"),
    "num_inference_steps": 100,
    "cat_sampling_noise_level": 1,
    "coord_noise_std_dev": 0.2,
    "type_dist_temp": 1.0,
    "time_alpha": 2.0,
    "time_beta": 1.0,
    # NOTE(review): this is the string "None", not the None object --
    # confirm that the consumer expects the string form.
    "optimal_transport": "None",

    # Autoregressive args
    "t_per_ar_action": 0.3,
    "max_interp_time": 0.4,
    "decomposition_strategy": "reaction",
    "ordering_strategy": "connected",
    "max_action_t": 0.6,
    "max_num_cuts": 2,
    "min_group_size": 5,

    # Logging args
    "monitor": "val-validity",
    "monitor_mode": "max",
    "use_complex_metrics": False,
}


# Attribute-style view of the settings above.
args = Namespace(**args_dict)

In [4]:
import cgflow.scriptutil as util
from cgflow.buildutil import build_dm, build_model

In [5]:
# trainer = build_trainer(args)
# print("Arguments:")
# print(args)

In [6]:
# Build the vocabulary object that the datamodule and model builders in the
# following cells take as an argument.
print("Building model vocab...")
vocab = util.build_vocab()
print("Vocab complete.")

Building model vocab...
Vocab complete.


In [7]:
# Override the configured batch cost (500 in the args cell) with 8 before the
# datamodule is built. `args` is the same Namespace that is later handed to
# build_model, so the override applies there as well.
# NOTE(review): presumably a small value for interactive debugging -- confirm.
args.batch_cost = 8

print("Loading datamodule...")
dm = build_dm(args, vocab)
print("Datamodule complete.")

Loading datamodule...
Using type ARGeometricComplexInterpolant for training
Datamodule complete.


In [8]:
# Fetch a single batch from the training dataloader for inspection.
# next() raises StopIteration if the loader yields nothing, instead of
# silently leaving `batch` unbound (or stale from a previous run of this cell)
# as a `for ...: break` loop would.
batch = next(iter(dm.train_dataloader()))

In [31]:
# Build the model from the argument namespace, the datamodule and the vocab
# created in the earlier cells. `args` is used in its current state, i.e.
# including any attribute overrides applied since it was created.
print("Building equinv model...")
model = build_model(args, dm, vocab)
print("Model complete.")

Building equinv model...
Total training steps 332040000
Using model class LigandGenerator
Using CFM class ARMolecularCFM
Model complete.


In [17]:
# NOTE(review): despite its name, `test_dl` is the *training* dataloader.
test_dl = dm.train_dataloader()

# Take the first batch. next() raises StopIteration on an empty loader rather
# than silently skipping the unpacking, which would leave the nine names below
# unbound or holding stale values from an earlier run.
batch = next(iter(test_dl))

# A batch is a 9-tuple; unpack it into named components for inspection.
(
    prior,
    data,
    interpolated,
    masked_data,
    pocket,
    pocket_raw,
    times,
    rel_times,
    gen_times,
) = batch

In [None]:
# Fit the model on the datamodule using `trainer`.
# NOTE(review): `trainer` is not defined in any visible cell. The only
# construction, `trainer = build_trainer(args)`, is commented out, and
# `build_trainer` is not among the visible imports, so on a freshly restarted
# kernel this cell raises NameError. Restore the trainer construction (and its
# import) before relying on this cell.
# NOTE(review): the recorded output ends with SIGTERM / KeyboardInterrupt, so
# the captured run was interrupted rather than completed.
print("Fitting datamodule to model...")
trainer.fit(model, datamodule=dm)
print("Training complete.")

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Fitting datamodule to model...



  | Name              | Type             | Params | Mode
--------------------------------------------------------------
0 | gen               | LigandGenerator  | 29.8 M | eval
1 | ema_gen           | AveragedModel    | 29.8 M | eval
2 | stability_metrics | MetricCollection | 0      | eval
3 | gen_metrics       | MetricCollection | 0      | eval
4 | complex_metrics   | MetricCollection | 0      | eval
5 | conf_metrics      | MetricCollection | 0      | eval
--------------------------------------------------------------
59.6 M    Trainable params
0         Non-trainable params
59.6 M    Total params
238.325   Total estimated model params size (MB)
0         Modules in train mode
1558      Modules in eval mode


Epoch 0:   0%|          | 0/33204 [01:47<?, ?it/s]
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]

[rank: 0] Received SIGTERM: 15

Detected KeyboardInterrupt, attempting graceful shutdown ...
