# dataset

In [1]:
# Misc. tools
import os

# Hydra tools
import hydra

from hydra.compose import GlobalHydra
from hydra.core.hydra_config import HydraConfig

from proteinworkshop.constants import HYDRA_CONFIG_PATH
from proteinworkshop.utils.notebook import init_hydra_singleton

version_base = "1.2"  # Note: Need to update whenever Hydra is upgraded
init_hydra_singleton(reload=True, version_base=version_base)

path = HYDRA_CONFIG_PATH
rel_path = os.path.relpath(path, start=".")

GlobalHydra.instance().clear()
hydra.initialize(rel_path, version_base=version_base)

cfg = hydra.compose(
    config_name="train",
    overrides=[
        "encoder=pronet",
        "task=subgraph_distance_prediction",   
        # "task=structure_denoising",    
        "dataset=afdb_swissprot_v4",
        # "dataset.datamodule.num_workers=1",
        "features=fe_subgraph",
        "encoder.node_embedding=False",
        "encoder.pretraining=True",
        # "+aux_task=none",
    ],
    return_hydra_config=True,
)

# Note: Customize as needed e.g., when running a sweep
cfg.hydra.job.num = 0
cfg.hydra.job.id = 0
cfg.hydra.hydra_help.hydra_help = False
cfg.hydra.runtime.output_dir = "outputs"

HydraConfig.instance().set_config(cfg)

In [2]:
from proteinworkshop.configs import config

cfg = config.validate_config(cfg)

In [3]:
print(cfg.keys())
for key in cfg.keys():
    print(key)
    print(cfg[key])

dict_keys(['hydra', 'env', 'dataset', 'features', 'encoder', 'decoder', 'transforms', 'callbacks', 'optimiser', 'scheduler', 'trainer', 'extras', 'metrics', 'task', 'logger', 'name', 'seed', 'num_workers', 'task_name', 'test'])
hydra
env
{'paths': {'root_dir': '${oc.env:ROOT_DIR}', 'data': '${oc.env:DATA_PATH}', 'output_dir': '${hydra:runtime.output_dir}', 'work_dir': '${hydra:runtime.cwd}', 'log_dir': '${oc.env:RUNS_PATH}', 'runs': '${oc.env:RUNS_PATH}', 'run_dir': '${env.paths.runs}/${name}/${env.init_time}'}, 'python': {'version': '${python_version:micro}'}, 'init_time': '${now:%y-%m-%d_%H:%M:%S}'}
dataset
{'datamodule': {'_target_': 'graphein.ml.datasets.foldcomp_dataset.FoldCompLightningDataModule', 'data_dir': '${env.paths.data}/afdb_swissprot_v4/', 'database': 'afdb_swissprot_v4', 'batch_size': 32, 'num_workers': 32, 'train_split': 0.8, 'val_split': 0.1, 'test_split': 0.1, 'pin_memory': True, 'use_graphein': True, 'transform': '${transforms}'}, 'dataset_name': 'afdb_swissprot_v4

In [4]:
from omegaconf import OmegaConf

In [5]:
from proteinworkshop.datasets.atom3d_datamodule import ATOM3DDataModule

In [6]:
from proteinworkshop.configs import config

cfg = config.validate_config(cfg)
# print("Original config:\n", OmegaConf.to_yaml(cfg))
mutable_cfg = OmegaConf.to_container(cfg.dataset.datamodule, resolve=True)
mutable_cfg = OmegaConf.create(mutable_cfg)
# print("Cloned config:\n", OmegaConf.to_yaml(mutable_cfg))
# Instantiate the datamodule with the mutable configuration
datamodule = hydra.utils.instantiate(mutable_cfg)
datamodule.setup("fit")
dl = datamodule.train_dataloader()

for i in dl:
    print(i)
    break

100%|██████████| 542378/542378 [00:00<00:00, 4760535.83it/s]
Processing...
Done!
100%|██████████| 542378/542378 [00:00<00:00, 4793033.28it/s]
Processing...
Done!
100%|██████████| 542378/542378 [00:00<00:00, 4802301.44it/s]
Processing...
Done!
100%|██████████| 542378/542378 [00:00<00:00, 4785360.59it/s]
Processing...
Done!

This DataLoader will create 32 worker processes in total. Our suggested max number of worker in current system is 28, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.



ProteinBatch(fill_value=[32], atom_list=[32], coords=[8139, 37, 3], residues=[32], residue_id=[32], chains=[8139], residue_type=[8139], b_factor=[8139], id=[32], x=[8139], seq_pos=[8139, 1], batch=[8139], ptr=[33])


In [7]:
print(i.x)

tensor[8139] 32Kb [38;2;127;127;127mall_zeros[0m


In [8]:
from proteinworkshop.models.base import BenchMarkModel
import lightning as L

model: L.LightningModule = BenchMarkModel(cfg)

In [9]:
import torch

for i in dl:
    batch = model.featurise(i)
    print(batch)
    break
batch_data = batch
# z, pos, _ = torch.squeeze(batch_data.x.long()), batch_data.pos, batch_data.batch
# print(z.shape, pos.shape) # torch.Size([5762, 16]) torch.Size([5762, 3]) torch.Size([5762])
# from proteinworkshop.models.graph_encoders.pronet import ProNet
# mutable_cfg = OmegaConf.to_containexr(cfg.encoder, resolve=True)
# mutable_cfg = OmegaConf.create(mutable_cfg)
# print("Cloned config:\n", OmegaConf.to_yaml(mutable_cfg))
# Instantiate the datamodule with the mutable configuration
# encoder = hydra.utils.instantiate(mutable_cfg)
# encoder = ProNet()
# print(model)
# print(encoder)
out = model(batch)
# out = encoder(batch)
print(out)
print(batch.device)


This DataLoader will create 32 worker processes in total. Our suggested max number of worker in current system is 28, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.



ProteinBatch(fill_value=[32], atom_list=[32], coords=[11302, 37, 3], residues=[32], residue_id=[32], chains=[11302], residue_type=[11302], b_factor=[11302], id=[32], x=[11302, 23], seq_pos=[11302, 1], batch=[11302], ptr=[33], pos=[11302, 3], edge_index=[2, 203101], subgraphs=[1116, 137], subgraph_distances=[1116], subgraph_lengths=[1116])
{'fused_repr': tensor[1116, 256] n=285696 (1.1Mb) x∈[-0.479, 0.407] μ=0.007 σ=0.079 grad StackBackward0, 'subgraph_distances': tensor[1116, 1] 4.4Kb x∈[-0.031, 0.020] μ=-0.005 σ=0.010 grad AddmmBackward0}


ProteinBatch(fill_value=[32], atom_list=[32], coords=[10916, 37, 3], residues=[32], residue_id=[32], chains=[10916], residue_type=[10916], b_factor=[10916], id=[32], x=[10916, 23], seq_pos=[10916, 1], batch=[10916], ptr=[33], pos=[10916, 3], edge_index=[2, 153666], subgraphs=[1076, 149], dist=[1076], subgraph_lengths=[1076])
b ProteinBatch(fill_value=[32], atom_list=[32], coords=[10916, 37, 3], residues=[32], residue_id=[32], chains=[10916], residue_type=[10916], b_factor=[10916], id=[32], x=[10916, 23], seq_pos=[10916, 1], batch=[10916], ptr=[33], pos=[10916, 3], edge_index=[2, 153666], subgraphs=[1076, 149], dist=[1076], subgraph_lengths=[1076])
o {'fused_repr': tensor[1076, 256] n=275456 (1.1Mb) x∈[-0.880, 0.835] μ=0.012 σ=0.141 grad StackBackward0}
h subgraph_distance
O {'fused_repr': tensor[1076, 256] n=275456 (1.1Mb) x∈[-0.880, 0.835] μ=0.012 σ=0.141 grad StackBackward0, 'subgraph_distance': tensor[1076, 1] 4.2Kb x∈[-0.010, 0.043] μ=0.014 σ=0.014 grad AddmmBackward0}
{'fused_repr': tensor[1076, 256] n=275456 (1.1Mb) x∈[-0.880, 0.835] μ=0.012 σ=0.141 grad StackBackward0, 'subgraph_distance': tensor[1076, 1] 4.2Kb x∈[-0.010, 0.043] μ=0.014 σ=0.014 grad AddmmBackward0}

In [13]:
print(batch_data.x)
print(batch_data.edge_index)
print(batch_data.atom_list)
print(batch_data.coords)
print(batch_data.residues)
print(batch_data.id)
print(batch_data.residue_id)
print(batch_data.residue_type)
print(batch_data.chains)
print(batch_data.graph_y)
print(batch_data.amino_acid_one_hot)
print(batch_data.batch)
print(batch_data.ptr)
print(batch_data.pos)
print(batch_data.edge_index)
print(batch_data.edge_type)
print(batch_data.edge_attr)

tensor[5436, 23] n=125028 (0.5Mb) x∈[0., 1.000] μ=0.043 σ=0.204
tensor[2, 86976] i64 n=173952 (1.3Mb) x∈[0, 5435] μ=2.717e+03 σ=1.569e+03
['N', 'CA', 'C', 'O', 'CB', 'OG', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OD1', 'ND2', 'CG1', 'CG2', 'CD', 'CE', 'NZ', 'OD2', 'OE1', 'NE2', 'OE2', 'OH', 'NE', 'NH1', 'NH2', 'OG1', 'SD', 'ND1', 'SG', 'NE1', 'CE3', 'CZ2', 'CZ3', 'CH2', 'OXT']
tensor[5436, 37, 3] n=603396 (2.3Mb) x∈[-79.883, 167.761] μ=5.206 σ=16.581
[['THR', 'LEU', 'THR', 'ILE', 'ASP', 'ASP', 'GLY', 'ASN', 'ILE', 'GLU', 'ILE', 'VAL', 'GLY', 'THR', 'GLY', 'VAL', 'LYS', 'GLY', 'LYS', 'LEU', 'PRO', 'THR', 'VAL', 'TRP', 'LEU', 'GLN', 'TYR', 'GLY', 'GLN', 'VAL', 'ASN', 'LEU', 'LYS', 'ALA', 'SER', 'GLY', 'GLY', 'ASN', 'GLY', 'LYS', 'TYR', 'THR', 'TRP', 'ARG', 'SER', 'ALA', 'ASN', 'PRO', 'ALA', 'ILE', 'ALA', 'SER', 'VAL', 'ASP', 'ALA', 'SER', 'SER', 'GLY', 'GLN', 'VAL', 'THR', 'LEU', 'LYS', 'GLU', 'LYS', 'GLY', 'THR', 'THR', 'THR', 'ILE', 'SER', 'VAL', 'ILE', 'SER', 'SER', 'ASP', 'ASN', 'GLN

In [10]:
import torch

# Assuming `dl` is your data loader and `model.featurise(i)` returns a batch with `.x`, `.pos`, and `.batch` attributes
unique_values = set()

for i in dl:
    batch = model.featurise(i)
    z, pos, _ = torch.squeeze(batch.x.long()), batch.pos, batch.batch
    
    # Reshape z to a flat tensor to extract unique values
    z_flat = z.view(-1)
    unique_values.update(z_flat.unique().tolist())

print(unique_values)


KeyboardInterrupt: 

afdb_swissprot_v4

ProteinBatch(fill_value=[32], atom_list=[32], coords=[12984, 37, 3], residues=[32], residue_id=[32], chains=[12984], residue_type=[12984], b_factor=[12984], id=[32], x=[12984], seq_pos=[12984, 1], batch=[12984], ptr=[33])



fold_fold

DataProteinBatch(fill_value=1e-05, atom_list=[37], coords=[6159, 37, 3], residues=[32], id=[32], residue_id=[32], residue_type=[6159], chains=[6159], graph_y=[32], x=[6159], amino_acid_one_hot=[6159, 23], seq_pos=[6159, 1], batch=[6159], ptr=[33])

No feature:

DataProteinBatch(fill_value=1e-05, atom_list=[37], coords=[4221, 37, 3], residues=[32], id=[32], residue_id=[32], residue_type=[4221], chains=[4221], graph_y=[32], x=[4221, 16], amino_acid_one_hot=[4221, 23], seq_pos=[4221, 1], batch=[4221], ptr=[33], pos=[4221, 3], edge_index=[2, 67536], edge_type=[1, 67536], num_relation=1, edge_attr=[67536, 1])

ca_base:

DataProteinBatch(fill_value=1e-05, atom_list=[37], coords=[5388, 37, 3], residues=[32], id=[32], residue_id=[32], residue_type=[5388], chains=[5388], graph_y=[32], x=[5388, 23], amino_acid_one_hot=[5388, 23], seq_pos=[5388, 1], batch=[5388], ptr=[33], pos=[5388, 3], edge_index=[2, 86208], edge_type=[1, 86208], num_relation=1, edge_attr=[86208, 1])
torch.Size([5388, 23]) torch.Size([5388, 3])

# train

In [2]:
import copy
import sys
from typing import List, Optional

import graphein
import hydra
import lightning as L
import lovely_tensors as lt
import torch
import torch.nn as nn
import torch_geometric
from graphein.protein.tensor.dataloader import ProteinDataLoader
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import Logger
from loguru import logger as log
from omegaconf import DictConfig

from proteinworkshop import (
    constants,
    register_custom_omegaconf_resolvers,
    utils,
)
from proteinworkshop.configs import config
from proteinworkshop.models.base import BenchMarkModel

In [6]:
import pytorch_lightning as pl

In [3]:
from proteinworkshop.models.base import BenchMarkModel
from proteinworkshop.models.graph_encoders.pronet import ProNet
model: L.LightningModule = BenchMarkModel(cfg)
datamodule: L.LightningDataModule = hydra.utils.instantiate(
        cfg.dataset.datamodule
    )
callbacks: List[Callback] = utils.callbacks.instantiate_callbacks(
        cfg.get("callbacks")
    )
logger: List[Logger] = utils.loggers.instantiate_loggers(cfg.get("logger"))
trainer: L.Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )
# trainer.connectors.logger_connector.result.ResultCollection.extract_batch_size = 32
with torch.no_grad():
    datamodule.setup(stage="lazy_init")  # type: ignore
    batch = next(iter(datamodule.val_dataloader()))
    log.info(f"Unfeaturized batch: {batch}")
    batch = model.featurise(batch)
    log.info(f"Featurized batch: {batch}")
    log.info(f"Example labels: {model.get_labels(batch)}")
    # Check batch has required attributes
    for attr in model.encoder.required_batch_attributes:  # type: ignore
        if not hasattr(batch, attr):
            raise AttributeError(
                f"Batch {batch} does not have required attribute: {attr} ({model.encoder.required_batch_attributes})"
            )
    out = model(batch)
    log.info(f"Model output: {out}")
    del batch, out



Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.



Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


100%|██████████| 736/736 [00:00<00:00, 4782.98it/s]


In [4]:
# dl = datamodule.train_dataloader()
# for i in dl:
#     batch = model.featurise(i)
    

In [4]:
if cfg.get("compile"):
    log.info("Compiling model!")
    model = torch_geometric.compile(model, dynamic=True)

: 

In [5]:

trainer.fit(
    model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")
)

You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


100%|██████████| 12312/12312 [00:02<00:00, 4348.64it/s]



Checkpoint directory /home/zhang/Projects/3d/ProteinWorkshop/notebooks/outputs/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



You have overridden `on_after_batch_transfer` in `LightningModule` but have passed in a `LightningDataModule`. It will use the implementation from `LightningModule` instance.



Output()

Metric val/graph_label/accuracy improved. New best score: 0.037
Metric train/loss/total improved. New best score: 6.056
Epoch 0, global step 385: 'val/graph_label/accuracy' reached 0.03668 (best 0.03668), saving model to '/home/zhang/Projects/3d/ProteinWorkshop/notebooks/outputs/checkpoints/epoch_000-v5.ckpt' as top 1
