Manually reproduce what train.py would do

In [1]:
import os

# Move the working directory two levels up from the notebook's folder
# (presumably the repository root — TODO confirm) so the relative paths used
# in later cells resolve.
#
# A bare relative `os.chdir('../..')` climbs two *more* levels every time this
# cell is re-run in the same kernel. Remembering the starting directory on the
# first run and always navigating from it makes the cell idempotent.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, '..', '..'))

In [2]:
from argparse import Namespace

# Run configuration written out by hand instead of being parsed from a command
# line. The dict is wrapped in a Namespace below so that values are read as
# attributes (e.g. `args.lr`).
args_dict = {
    # General args
    # Plain string literal: there is nothing to interpolate, so no f-prefix.
    "data_path": 'data/complex/crossdock-no-litpcba/smol',
    "dataset": 'crossdock',
    "is_pseudo_complex": False,
    "trial_run": False,
    "complex_debug": False,
    "model_checkpoint": None,
    "num_workers": None,
    "num_gpus": 1,

    # Model args
    "d_model": 384,
    "n_layers": 12,
    "d_message": 128,
    "d_edge": 128,
    "n_coord_sets": 64,
    "n_attn_heads": 32,
    "d_message_hidden": 128,
    "coord_norm": "length",
    "size_emb": 64,
    "max_atoms": 256,
    "arch": "semla",

    # Protein model args
    "pocket_n_layers": 4,
    "fixed_equi": False,
    "pocket_d_inv": 256,

    # Training args
    "epochs": 10000,
    "val_check_epochs": 10,
    "lr": 0.0003,
    # NOTE(review): the run recorded at the end of this notebook stopped with
    # a CUDA out-of-memory error on a ~44.5 GiB GPU. If batch_cost governs the
    # per-batch size (unconfirmed from here), lowering it is the first knob to
    # try.
    "batch_cost": 500,
    "acc_batches": 1,
    "gradient_clip_val": 1.0,
    "dist_loss_weight": 0.0,
    "type_loss_weight": 0.2,
    "bond_loss_weight": 1.0,
    "charge_loss_weight": 1.0,
    "conf_coord_strategy": "gaussian",
    "categorical_strategy": "uniform-sample",
    "lr_schedule": "constant",
    "warm_up_steps": 10000,
    "bucket_cost_scale": "linear",
    "use_ema": True,
    "self_condition": True,

    # Flow matching and sampling args
    "n_validation_mols": 2000,
    "n_training_mols": float("inf"),
    "num_inference_steps": 100,
    "cat_sampling_noise_level": 1,
    "coord_noise_std_dev": 0.2,
    "type_dist_temp": 1.0,
    "time_alpha": 2.0,
    "time_beta": 1.0,
    "optimal_transport": "equivariant",

    # Autoregressive args
    "t_per_ar_action": 0.2,
    "max_interp_time": 1.0,
    "decomposition_strategy": "atom",
    "ordering_strategy": "connected",
    "max_action_t": 0.8,
    "max_num_cuts": None,
    "min_group_size": 5,

    # Logging args
    "monitor": "val-validity",
    "monitor_mode": "max",
    "use_complex_metrics": False,
}


# Attribute-style view of the same key/value pairs.
args = Namespace(**args_dict)

In [3]:
# Enable IPython's autoreload extension in mode 2, which re-imports edited
# modules before each cell runs, so source changes are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2
# Project helpers used by the cells below: `util` supplies build_vocab, and
# buildutil supplies the datamodule, model and trainer builders.
import cgflow.scriptutil as util
from cgflow.buildutil import build_dm, build_model, build_trainer

In [4]:
# Build the trainer from the hand-written arguments first, then echo the full
# argument set so the configuration is captured in this cell's output.
trainer = build_trainer(args)
print("Arguments:", args, sep="\n")

Using precision '32'


[34m[1mwandb[0m: Currently logged in as: [33mtsa87[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


/home/to.shen/.conda/envs/cgflow/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/to.shen/.conda/envs/cgflow/lib/python3.11/site ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Arguments:
Namespace(data_path='data/complex/crossdock-no-litpcba/smol', dataset='crossdock', is_pseudo_complex=False, trial_run=False, complex_debug=False, model_checkpoint=None, num_workers=None, num_gpus=1, d_model=384, n_layers=12, d_message=128, d_edge=128, n_coord_sets=64, n_attn_heads=32, d_message_hidden=128, coord_norm='length', size_emb=64, max_atoms=256, arch='semla', pocket_n_layers=4, fixed_equi=False, pocket_d_inv=256, epochs=10000, val_check_epochs=10, lr=0.0003, batch_cost=500, acc_batches=1, gradient_clip_val=1.0, dist_loss_weight=0.0, type_loss_weight=0.2, bond_loss_weight=1.0, charge_loss_weight=1.0, conf_coord_strategy='gaussian', categorical_strategy='uniform-sample', lr_schedule='constant', warm_up_steps=10000, bucket_cost_scale='linear', use_ema=True, self_condition=True, n_validation_mols=2000, n_training_mols=inf, num_inference_steps=100, cat_sampling_noise_level=1, coord_noise_std_dev=0.2, type_dist_temp=1.0, time_alpha=2.0, time_beta=1.0, optimal_transport='e

In [5]:
# Build the vocabulary object that is passed to the datamodule and model
# builders in the following cells. The prints bracket the call so progress is
# visible in the cell output.
print("Building model vocab...")
vocab = util.build_vocab()
print("Vocab complete.")

Building model vocab...
Vocab complete.


In [6]:
# Create the datamodule from the run arguments and the vocabulary built above.
# `dm` is later handed to build_model and to trainer.fit.
print("Loading datamodule...")
dm = build_dm(args, vocab)
print("Datamodule complete.")

Loading datamodule...
Using type GeometricComplexInterpolant for training
Datamodule complete.


In [7]:
# Build the model. The builder receives the datamodule and the vocabulary in
# addition to the run arguments.
print("Building equinv model...")
model = build_model(args, dm, vocab)
print("Model complete.")

Building equinv model...
Total training steps 332040000
Using model class LigandGenerator
Creating RDKit mols from training SMILES...
Initialising novelty metric...
Novelty metric complete.
Using CFM class MolecularCFM
Model complete.


In [8]:
# Pull a single batch from the training dataloader for interactive inspection.
# NOTE: despite its name, `test_dl` is the *training* loader; the name is kept
# so that any other cell referring to it keeps working.
test_dl = dm.train_dataloader()

# next(iter(...)) fetches the first item directly instead of a for-loop that
# breaks immediately. An empty loader now fails here with a clear message
# rather than silently leaving `batch`, `prior`, ... undefined for later cells.
try:
    batch = next(iter(test_dl))
except StopIteration:
    raise RuntimeError("Training dataloader yielded no batches.") from None

# Each batch unpacks into six items; the fifth is not used in this notebook.
prior, data, interpolated, pockets, _, t = batch

In [9]:
# Start training: fit the model using the batches supplied by the datamodule.
# NOTE(review): in the output recorded below, this call stopped with a CUDA
# OutOfMemoryError a few steps into epoch 0, so "Training complete." was never
# printed. The run needs a smaller memory footprint before it can finish on
# this GPU — which setting to reduce is to be confirmed.
print("Fitting datamodule to model...")
trainer.fit(model, datamodule=dm)
print("Training complete.")

You are using a CUDA device ('NVIDIA L40S') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Fitting datamodule to model...



  | Name              | Type             | Params | Mode 
---------------------------------------------------------------
0 | gen               | LigandGenerator  | 29.8 M | train
1 | ema_gen           | AveragedModel    | 29.8 M | train
2 | stability_metrics | MetricCollection | 0      | train
3 | gen_metrics       | MetricCollection | 0      | train
4 | complex_metrics   | MetricCollection | 0      | train
---------------------------------------------------------------
59.6 M    Trainable params
0         Non-trainable params
59.6 M    Total params
238.322   Total estimated model params size (MB)
1555      Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:14<00:14,  0.07it/s]

[23:11:37] 

****
Post-condition Violation
Element '<PAD>' not found
Violation occurred on line 93 in file /project/build/temp.linux-x86_64-cpython-311/rdkit/Code/GraphMol/PeriodicTable.h
Failed Expression: anum > -1
----------
Stacktrace:
----------
****

[23:11:37] 

****
Post-condition Violation
Element '<PAD>' not found
Violation occurred on line 93 in file /project/build/temp.linux-x86_64-cpython-311/rdkit/Code/GraphMol/PeriodicTable.h
Failed Expression: anum > -1
----------
Stacktrace:
----------
****

[23:11:37] 

****
Post-condition Violation
Element '<PAD>' not found
Violation occurred on line 93 in file /project/build/temp.linux-x86_64-cpython-311/rdkit/Code/GraphMol/PeriodicTable.h
Failed Expression: anum > -1
----------
Stacktrace:
----------
****



                                                                           

[23:11:47] 

****
Post-condition Violation
Element '<PAD>' not found
Violation occurred on line 93 in file /project/build/temp.linux-x86_64-cpython-311/rdkit/Code/GraphMol/PeriodicTable.h
Failed Expression: anum > -1
----------
Stacktrace:
----------
****

[23:11:47] 

****
Post-condition Violation
Element '<PAD>' not found
Violation occurred on line 93 in file /project/build/temp.linux-x86_64-cpython-311/rdkit/Code/GraphMol/PeriodicTable.h
Failed Expression: anum > -1
----------
Stacktrace:
----------
****

[23:11:47] 

****
Post-condition Violation
Element '<PAD>' not found
Violation occurred on line 93 in file /project/build/temp.linux-x86_64-cpython-311/rdkit/Code/GraphMol/PeriodicTable.h
Failed Expression: anum > -1
----------
Stacktrace:
----------
****



Epoch 0:   0%|          | 11/33204 [00:05<4:16:10,  2.16it/s, v_num=upnq, train-loss=2.530]

OutOfMemoryError: CUDA out of memory. Tried to allocate 3.98 GiB. GPU 0 has a total capacity of 44.53 GiB of which 3.81 GiB is free. Including non-PyTorch memory, this process has 40.70 GiB memory in use. Of the allocated memory 35.06 GiB is allocated by PyTorch, and 5.15 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)