In [1]:
from pathlib import Path
import os

# Directory holding the DESRES CLN025 folding trajectory (.dcd frames + .pdb
# topology). Override with the CLN025_DATA_DIR environment variable on machines
# where the data lives elsewhere; the default is the original location.
data_path = Path(os.environ.get(
    "CLN025_DATA_DIR",
    "/media/data/pnovelli/md_datasets/DESRES_folding_trajs/DESRES-Trajectory_CLN025-0-protein/CLN025-0-protein",
))

# Path.glob yields entries in arbitrary, filesystem-dependent order; sort so
# that trajectory_files[0] (used below) is the same file on every run/machine.
trajectory_files = sorted(data_path.glob("*.dcd"))
if not trajectory_files:
    raise FileNotFoundError(f"No .dcd trajectory files found in {data_path}")

# First (sorted) .pdb is the topology; fail with a clear message instead of a
# bare StopIteration when none is present.
top = next(iter(sorted(data_path.glob("*.pdb"))), None)
if top is None:
    raise FileNotFoundError(f"No .pdb topology file found in {data_path}")

In [None]:
from mlcolvar.utils.io import create_dataset_from_trajectories

# Build the graph dataset from the first trajectory file only.
# - system_selection drops hydrogens (the atom list printed below has no H).
# - cutoff is the graph radius passed to mlcolvar; presumably in Angstrom — confirm.
dataset = create_dataset_from_trajectories(
    trajectories=[str(trajectory_files[0])],
    top=[str(top)],
    system_selection="all and not type H",
    cutoff=8.0,
    create_labels=True,
    show_progress=False,
)



[TYR1-N, TYR1-CA, TYR1-CB, TYR1-CG, TYR1-CD1, TYR1-CE1, TYR1-CZ, TYR1-OH, TYR1-CD2, TYR1-CE2, TYR1-C, TYR1-O, TYR2-N, TYR2-CA, TYR2-CB, TYR2-CG, TYR2-CD1, TYR2-CE1, TYR2-CZ, TYR2-OH, TYR2-CD2, TYR2-CE2, TYR2-C, TYR2-O, ASP3-N, ASP3-CA, ASP3-CB, ASP3-CG, ASP3-OD1, ASP3-OD2, ASP3-C, ASP3-O, PRO4-N, PRO4-CD, PRO4-CA, PRO4-CB, PRO4-CG, PRO4-C, PRO4-O, GLU5-N, GLU5-CA, GLU5-CB, GLU5-CG, GLU5-CD, GLU5-OE1, GLU5-OE2, GLU5-C, GLU5-O, THR6-N, THR6-CA, THR6-CB, THR6-OG1, THR6-CG2, THR6-C, THR6-O, GLY7-N, GLY7-CA, GLY7-C, GLY7-O, THR8-N, THR8-CA, THR8-CB, THR8-OG1, THR8-CG2, THR8-C, THR8-O, TRP9-N, TRP9-CA, TRP9-CB, TRP9-CG, TRP9-CD1, TRP9-NE1, TRP9-CE2, TRP9-CD2, TRP9-CE3, TRP9-CZ3, TRP9-CZ2, TRP9-CH2, TRP9-C, TRP9-O, TYR10-C, TYR10-O, TYR10-OXT, TYR10-N, TYR10-CA, TYR10-CB, TYR10-CG, TYR10-CD1, TYR10-CE1, TYR10-CZ, TYR10-OH, TYR10-CD2, TYR10-CE2]


In [11]:
from mlcolvar.utils.timelagged import create_timelagged_dataset

# Pair every frame's graph with a time-lagged one (batches later expose
# 'data_list' and 'data_list_lag'); the lag is left at the library default —
# TODO confirm its value.
timelagged_dataset = create_timelagged_dataset(
    dataset,
    progress_bar=True,
)

100%|██████████| 9997/9997 [00:01<00:00, 6735.36it/s]


Model initialization: a SchNet graph encoder followed by SimNorm.

In [42]:
from linear_operator_learning.nn import SimNorm
from mlcolvar.core.nn.graph.schnet import SchNetModel
import torch

# Seed torch's global RNG before building the model so the randomly initialised
# weights are reproducible across Restart & Run All. Later shuffling may also
# become reproducible if it draws from the same global generator.
SEED = 0
torch.manual_seed(SEED)

# Width of the encoder output, i.e. the last dimension of model(...) below.
n_features = 512

# SchNet graph encoder. Cutoff and atom-type table are read from the dataset
# metadata so they match how the graphs were built.
gnn_model = SchNetModel(
    n_out=n_features,
    cutoff=timelagged_dataset.metadata['cutoff'],
    atomic_numbers=timelagged_dataset.metadata['z_table'],
    n_bases=6,
    n_layers=2,
    n_filters=32,
    n_hidden_channels=32,
)

# Apply SimNorm (from linear_operator_learning) to the encoder output.
model = torch.nn.Sequential(
    gnn_model,
    SimNorm(),
)

Training

In [43]:
from lightning import Trainer
from mlcolvar.data import DictModule


# accelerator="auto" still selects CUDA when a GPU is present, but lets the
# notebook run on CPU-only machines instead of crashing. devices=1 makes
# explicit the single-device limit Lightning already applies in notebook
# sessions, which silences the "only 1 of N GPUs" warning.
trainer = Trainer(
    logger=False,
    enable_checkpointing=False,
    accelerator="auto",
    devices=1,
    max_epochs=10,
    enable_model_summary=True,
)

# Wrap the time-lagged dataset so dataloaders (batches of 8 pairs) can be drawn from it.
datamodule = DictModule(timelagged_dataset, batch_size=8)

Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [44]:
# Prepare the datamodule before requesting dataloaders in the next cell
# (presumably splits timelagged_dataset into train/validation subsets —
# confirm against mlcolvar DictModule).
datamodule.setup()

In [40]:
# Take a single training batch for inspection. next(iter(...)) is the idiomatic
# "first item" form and, unlike for/print/break, raises (StopIteration) if the
# loader is empty, instead of silently leaving `batch` undefined.
batch = next(iter(datamodule.train_dataloader()))
print(batch)

def _setup_graph_data(train_batch, key : str='data_list'):
    data = train_batch[key]
    data['positions'].requires_grad_(True)
    data['node_attrs'].requires_grad_(True)
    return data

{'data_list': DataBatch(edge_index=[2, 35752], shifts=[35752, 3], unit_shifts=[35752, 3], positions=[744, 3], cell=[24, 3], node_attrs=[744, 3], graph_labels=[8, 1], n_system=[8, 1], n_env=[8, 1], weight=[8], names_idx=[744], batch=[744], ptr=[9]), 'data_list_lag': DataBatch(edge_index=[2, 36702], shifts=[36702, 3], unit_shifts=[36702, 3], positions=[744, 3], cell=[24, 3], node_attrs=[744, 3], graph_labels=[8, 1], n_system=[8, 1], n_env=[8, 1], weight=[8], names_idx=[744], batch=[744], ptr=[9])}


In [46]:
# Smoke test: one forward pass on the current-time graphs of `batch`.
# _setup_graph_data sets requires_grad in place on batch['data_list'] tensors.
# The result is (batch_size, n_features), i.e. (8, 512) in the saved output.
model(_setup_graph_data(batch)).shape

torch.Size([8, 512])

In [32]:
# NOTE(review): re-displays `batch`, which the first-batch cell already prints.
# Its execution count (32) is lower than that cell's (40), and the saved edge
# counts differ, so this output came from an earlier batch. Consider deleting
# this cell.
batch

{'data_list': DataBatch(edge_index=[2, 39152], shifts=[39152, 3], unit_shifts=[39152, 3], positions=[744, 3], cell=[24, 3], node_attrs=[744, 3], graph_labels=[8, 1], n_system=[8, 1], n_env=[8, 1], weight=[8], names_idx=[744], batch=[744], ptr=[9]),
 'data_list_lag': DataBatch(edge_index=[2, 39388], shifts=[39388, 3], unit_shifts=[39388, 3], positions=[744, 3], cell=[24, 3], node_attrs=[744, 3], graph_labels=[8, 1], n_system=[8, 1], n_env=[8, 1], weight=[8], names_idx=[744], batch=[744], ptr=[9])}

In [None]:
from src.loss import RegSpectralLoss