In [3]:
import os
from pathlib import Path
# Directory holding the DESRES CLN025 trajectory chunks (.dcd) and the topology (.pdb).
# Override with the CLN025_DATA_DIR environment variable on machines where the
# data lives elsewhere; the original location is kept as the default.
data_path = Path(os.environ.get(
    "CLN025_DATA_DIR",
    "/media/data/pnovelli/md_datasets/DESRES_folding_trajs/DESRES-Trajectory_CLN025-0-protein/CLN025-0-protein",
))

# Path.glob yields files in arbitrary (filesystem) order, while mdtraj.load joins a
# list of files in list order. Sorting keeps the frames in temporal order, which the
# time-lagged pairing further down depends on.
# NOTE(review): assumes chunk file names are zero-padded so lexicographic == numeric order.
trajectory_files = sorted(str(p) for p in data_path.glob("*.dcd"))
if not trajectory_files:
    raise FileNotFoundError(f"No .dcd trajectory files found in {data_path}")

# Topology file used to interpret the .dcd coordinates; sorted so the choice is
# deterministic if several .pdb files are present.
pdb_files = sorted(data_path.glob("*.pdb"))
if not pdb_files:
    raise FileNotFoundError(f"No .pdb topology file found in {data_path}")
top = str(pdb_files[0])

Data loading

In [23]:
import mdtraj
from mlcolvar.utils.io import _configures_from_trajectory, _z_table_from_top, _names_from_top
from mlcolvar.data.graph.utils import create_dataset_from_configurations

# Keep only every 100th frame to reduce the number of samples.
stride = 100
# MDTraj selection string, presumably selecting heavy atoms only (hydrogens excluded);
# it is applied later when configurations are extracted from the trajectory.
system_selection = "all and not type H"
traj = mdtraj.load(
    trajectory_files,
    top=top,
    stride=stride,
)

In [7]:
# Re-attach the topology parsed straight from the PDB file.
# NOTE(review): mdtraj.load(..., top=top) above already attaches this topology, so
# this cell looks redundant (and was executed out of order) — confirm before removing.
traj.top = mdtraj.load_topology(top)

In [14]:
# Per-frame configurations for the atoms matched by `system_selection`
# (presumably heavy atoms only — confirm against the helper).
configs = _configures_from_trajectory(
    traj,
    system_selection=system_selection,
)

In [24]:
# Atomic-number table and atom names, both read from the trajectory topology
# (the z_table is later handed to the GNN as its atomic numbers).
z_table, atom_names = _z_table_from_top([traj.top]), _names_from_top([traj.top])

In [26]:
# Build the graph dataset from the extracted configurations. The arguments are
# passed positionally, in the original order.
show_progress = False
# NOTE(review): presumably the neighbour cutoff radius later exposed as
# metadata['cutoff'] — confirm its meaning and units against the helper.
graph_cutoff = 6.0
dataset = create_dataset_from_configurations(
    configs,
    z_table,
    graph_cutoff,
    0.0,  # NOTE(review): meaning of this float is not visible here — confirm
    atom_names,
    True,  # NOTE(review): meaning of this flag is not visible here — confirm
    show_progress
)

[TYR1-N, TYR1-H, TYR1-H2, TYR1-H3, TYR1-CA, TYR1-HA, TYR1-CB, TYR1-HB3, TYR1-HB2, TYR1-CG, TYR1-CD1, TYR1-HD1, TYR1-CE1, TYR1-HE1, TYR1-CZ, TYR1-OH, TYR1-HH, TYR1-CD2, TYR1-HD2, TYR1-CE2, TYR1-HE2, TYR1-C, TYR1-O, TYR2-N, TYR2-H, TYR2-CA, TYR2-HA, TYR2-CB, TYR2-HB3, TYR2-HB2, TYR2-CG, TYR2-CD1, TYR2-HD1, TYR2-CE1, TYR2-HE1, TYR2-CZ, TYR2-OH, TYR2-HH, TYR2-CD2, TYR2-HD2, TYR2-CE2, TYR2-HE2, TYR2-C, TYR2-O, ASP3-N, ASP3-H, ASP3-CA, ASP3-HA, ASP3-CB, ASP3-HB3, ASP3-HB2, ASP3-CG, ASP3-OD1, ASP3-OD2, ASP3-C, ASP3-O, PRO4-N, PRO4-CD, PRO4-HD3, PRO4-HD2, PRO4-CA, PRO4-HA, PRO4-CB, PRO4-HB3, PRO4-HB2, PRO4-CG, PRO4-HG3, PRO4-HG2, PRO4-C, PRO4-O, GLU5-N, GLU5-H, GLU5-CA, GLU5-HA, GLU5-CB, GLU5-HB3, GLU5-HB2, GLU5-CG, GLU5-HG3, GLU5-HG2, GLU5-CD, GLU5-OE1, GLU5-OE2, GLU5-C, GLU5-O, THR6-N, THR6-H, THR6-CA, THR6-HA, THR6-CB, THR6-HB, THR6-OG1, THR6-HG1, THR6-CG2, THR6-HG21, THR6-HG22, THR6-HG23, THR6-C, THR6-O, GLY7-N, GLY7-H, GLY7-CA, GLY7-HA3, GLY7-HA2, GLY7-C, GLY7-O, THR8-N, THR8-H, THR8-CA, 

In [27]:
from mlcolvar.utils.timelagged import create_timelagged_dataset

# Pair each graph with a time-lagged counterpart (batches later expose
# 'data_list' and 'data_list_lag'). No lag argument is passed, so the helper's
# default lag is used.
timelagged_dataset = create_timelagged_dataset(
    dataset,
    progress_bar=True,
)

100%|██████████| 5345/5345 [00:00<00:00, 5808.12it/s]


Model initialization

In [42]:
from linear_operator_learning.nn import SimNorm
from mlcolvar.core.nn.graph.schnet import SchNetModel
import torch

# Width of the learned feature map (the model outputs one n_features vector per graph).
n_features = 512

# Seed torch's RNG before constructing the networks so the random weight
# initialisation is reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

# SchNet graph encoder; cutoff and atomic numbers come from the dataset metadata so
# the model and the graphs it consumes agree.
gnn_model = SchNetModel(
    n_out=n_features,
    cutoff=timelagged_dataset.metadata['cutoff'],
    atomic_numbers=timelagged_dataset.metadata['z_table'],
    n_bases=6,
    n_layers=2,
    n_filters=32,
    n_hidden_channels=32,
)

# Full feature map: GNN embedding followed by SimNorm.
# NOTE(review): SimNorm presumably normalises groups of features onto simplices —
# confirm its group size divides n_features.
model = torch.nn.Sequential(
    gnn_model,
    SimNorm(),
)

Training

In [43]:
from lightning import Trainer
from mlcolvar.data import DictModule


# Lightning trainer for a short exploratory run (no logger, no checkpointing).
trainer = Trainer(
    logger=False,
    enable_checkpointing=False,
    # "auto" selects a GPU when one is available and falls back to CPU otherwise;
    # the previous hard-coded 'cuda' failed on CPU-only machines. devices=1 makes
    # explicit the single-device run Lightning already falls back to in notebooks.
    accelerator="auto",
    devices=1,
    max_epochs=10,
    enable_model_summary=True,
)

# Data module over the time-lagged dataset; each batch holds 8 graphs plus their
# 8 lagged counterparts.
datamodule = DictModule(timelagged_dataset, batch_size=8)

Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [44]:
# NOTE(review): presumably builds the train/validation splits that
# datamodule.train_dataloader() draws from below — confirm against DictModule.
datamodule.setup()

In [40]:
# Grab a single training batch to inspect its structure. Unlike the previous
# for/break loop, this fails loudly (StopIteration) on an empty loader instead of
# silently leaving `batch` undefined for the cells below.
batch = next(iter(datamodule.train_dataloader()))
print(batch)

def _setup_graph_data(train_batch, key : str='data_list'):
    data = train_batch[key]
    data['positions'].requires_grad_(True)
    data['node_attrs'].requires_grad_(True)
    return data

{'data_list': DataBatch(edge_index=[2, 35752], shifts=[35752, 3], unit_shifts=[35752, 3], positions=[744, 3], cell=[24, 3], node_attrs=[744, 3], graph_labels=[8, 1], n_system=[8, 1], n_env=[8, 1], weight=[8], names_idx=[744], batch=[744], ptr=[9]), 'data_list_lag': DataBatch(edge_index=[2, 36702], shifts=[36702, 3], unit_shifts=[36702, 3], positions=[744, 3], cell=[24, 3], node_attrs=[744, 3], graph_labels=[8, 1], n_system=[8, 1], n_env=[8, 1], weight=[8], names_idx=[744], batch=[744], ptr=[9])}


In [46]:
# Sanity check: forward the non-lagged graphs of the sampled batch and show the
# output shape (expected: one n_features vector per graph).
batch_features = model(_setup_graph_data(batch))
batch_features.shape

torch.Size([8, 512])

In [32]:
# NOTE(review): the saved output of this cell comes from an earlier execution
# (In [32] ran before the In [40] cell that samples `batch`), so its edge counts
# do not match the batch printed there — re-run or drop this cell.
batch

{'data_list': DataBatch(edge_index=[2, 39152], shifts=[39152, 3], unit_shifts=[39152, 3], positions=[744, 3], cell=[24, 3], node_attrs=[744, 3], graph_labels=[8, 1], n_system=[8, 1], n_env=[8, 1], weight=[8], names_idx=[744], batch=[744], ptr=[9]),
 'data_list_lag': DataBatch(edge_index=[2, 39388], shifts=[39388, 3], unit_shifts=[39388, 3], positions=[744, 3], cell=[24, 3], node_attrs=[744, 3], graph_labels=[8, 1], n_system=[8, 1], n_env=[8, 1], weight=[8], names_idx=[744], batch=[744], ptr=[9])}

In [None]:
from src.loss import RegSpectralLoss