# dataset

In [1]:
# Misc. tools
import os

# Hydra tools
import hydra

from hydra.compose import GlobalHydra
from hydra.core.hydra_config import HydraConfig

from proteinworkshop.constants import HYDRA_CONFIG_PATH
from proteinworkshop.utils.notebook import init_hydra_singleton

# Hydra compatibility level, passed to both initialisation calls below.
version_base = "1.2"  # Note: Need to update whenever Hydra is upgraded
init_hydra_singleton(reload=True, version_base=version_base)

# Hydra is initialised with the config directory expressed relative to the
# current working directory.
path = HYDRA_CONFIG_PATH
rel_path = os.path.relpath(path, start=".")
# print(rel_path)

# Clear the global Hydra instance before initialising it again.
GlobalHydra.instance().clear()
hydra.initialize(rel_path, version_base=version_base)

# Overrides are grouped by theme and concatenated in the order listed here.
encoder_overrides = [
    "encoder=pronet",
    "encoder.level='aminoacid'",
    "encoder.num_blocks=4",
    "encoder.hidden_channels=128",
    "encoder.out_channels=384",
    "encoder.mid_emb=64",
    "encoder.num_radial=6",
    "encoder.num_spherical=2",
    "encoder.cutoff=10.0",
    "encoder.max_num_neighbors=32",
    "encoder.int_emb_layers=3",
    "encoder.out_layers=2",
    "encoder.num_pos_emb=16",
    "encoder.dropout=0.3",
    "encoder.data_augment_eachlayer=True",
    "encoder.euler_noise=False",
    "encoder.pretraining=False",
    "encoder.node_embedding=False",
]
decoder_overrides = [
    "decoder.graph_label.dummy=True",
]
task_and_data_overrides = [
    "task=multiclass_graph_classification",
    "dataset=ec_reaction",
    "dataset.datamodule.batch_size=32",
    "features=ca_base",
    "+aux_task=none",
]
training_overrides = [
    "trainer.max_epochs=400",
    "optimiser=adam",
    "optimiser.optimizer.lr=5e-4",
    "callbacks.early_stopping.patience=200",
    "test=False",
    "scheduler=steplr",
]
# Test-only overrides, kept for reference (append to the list below to use):
#   "task_name=test",
#   "ckpt_path_test=/home/zhang/Projects/3d/ProteinWorkshop/notebooks/outputs/checkpoints/epoch_275.ckpt",
#   "optimizer.weight_decay=0.5"
# NOTE(review): the active overrides address the optimiser through
# `optimiser.optimizer.*`; the `optimizer.weight_decay` key above probably
# needs the same prefix -- confirm before enabling it.

cfg = hydra.compose(
    config_name="train",
    overrides=[
        *encoder_overrides,
        *decoder_overrides,
        *task_and_data_overrides,
        *training_overrides,
    ],
    return_hydra_config=True,
)

# Note: Customize as needed e.g., when running a sweep
cfg.hydra.job.num = 0
cfg.hydra.job.id = 0
cfg.hydra.hydra_help.hydra_help = False
cfg.hydra.runtime.output_dir = "outputs"

# Register the composed config (including its `hydra` node) with HydraConfig.
HydraConfig.instance().set_config(cfg)

In [2]:
from proteinworkshop.configs import config

# Run the project's config validation and rebind `cfg` to whatever it returns.
cfg = config.validate_config(cfg)

In [3]:
# Show the top-level config keys, then each key followed by its value.
top_level_keys = cfg.keys()
print(top_level_keys)
for section_name in top_level_keys:
    print(section_name)
    print(cfg[section_name])

dict_keys(['hydra', 'env', 'dataset', 'features', 'encoder', 'decoder', 'transforms', 'callbacks', 'optimiser', 'scheduler', 'trainer', 'extras', 'metrics', 'task', 'logger', 'name', 'seed', 'num_workers', 'task_name', 'ckpt_path_test', 'test', 'aux_task'])
hydra
env
{'paths': {'root_dir': '${oc.env:ROOT_DIR}', 'data': '${oc.env:DATA_PATH}', 'output_dir': '${hydra:runtime.output_dir}', 'work_dir': '${hydra:runtime.cwd}', 'log_dir': '${oc.env:RUNS_PATH}', 'runs': '${oc.env:RUNS_PATH}', 'run_dir': '${env.paths.runs}/${name}/${env.init_time}'}, 'python': {'version': '${python_version:micro}'}, 'init_time': '${now:%y-%m-%d_%H:%M:%S}'}
dataset
{'datamodule': {'_target_': 'proteinworkshop.datasets.ec_reaction.EnzymeCommissionReactionDataset', 'path': '${env.paths.data}/ECReaction/', 'pdb_dir': '${env.paths.data}/pdb/', 'format': 'pdb', 'batch_size': 32, 'pin_memory': True, 'num_workers': 8, 'dataset_fraction': 1.0, 'shuffle_labels': False, 'transforms': '${transforms}', 'overwrite': False, '

In [4]:
from omegaconf import OmegaConf

In [5]:
from proteinworkshop.datasets.atom3d_datamodule import ATOM3DDataModule

In [6]:
from proteinworkshop.configs import config

cfg = config.validate_config(cfg)
# print("Original config:\n", OmegaConf.to_yaml(cfg))

# Convert the datamodule sub-config to a plain container with interpolations
# resolved, then wrap it back into a standalone OmegaConf object.
mutable_cfg = OmegaConf.create(
    OmegaConf.to_container(cfg.dataset.datamodule, resolve=True)
)
# print("Cloned config:\n", OmegaConf.to_yaml(mutable_cfg))

# Instantiate the datamodule from the standalone config and prepare it for fitting.
datamodule = hydra.utils.instantiate(mutable_cfg)
datamodule.setup("fit")
dl = datamodule.train_dataloader()

# Print the first training batch only; `i` stays bound to it for later cells.
for i in dl:
    print(i)
    break

100%|██████████| 29161/29161 [00:09<00:00, 3108.70it/s]


100%|██████████| 2558/2558 [00:00<00:00, 3199.91it/s]


100%|██████████| 29161/29161 [00:08<00:00, 3243.87it/s]


DataProteinBatch(fill_value=1e-05, atom_list=[37], coords=[10513, 37, 3], residues=[32], id=[32], residue_id=[32], residue_type=[10513], chains=[10513], graph_y=[32], x=[10513], amino_acid_one_hot=[10513, 23], seq_pos=[10513, 1], batch=[10513], ptr=[33])


In [7]:
# Node-feature tensor `x` of the batch bound to `i` by the preceding dataloader
# loop (not yet passed through `model.featurise`); the recorded output below
# reports it as all zeros.
print(i.x)

tensor[10513] 41Kb [38;2;127;127;127mall_zeros[0m


In [8]:
from proteinworkshop.models.base import BenchMarkModel
import lightning as L

# Construct the benchmark model directly from the validated config.
model: L.LightningModule = BenchMarkModel(cfg)

In [11]:
import torch

# Take one batch from the training dataloader and run it through the model's
# featuriser; the loop exits after the first batch.
for i in dl:
    batch = model.featurise(i)
    # print(batch)
    break
# Second name for the same featurised batch object (no copy is made).
batch_data = batch

# Optional experiment, disabled: inspect the encoder inputs and run the encoder
# on its own instead of the full model.
# z, pos, _ = torch.squeeze(batch_data.x.long()), batch_data.pos, batch_data.batch
# print(z.shape, pos.shape)
# from proteinworkshop.models.graph_encoders.pronet import ProNet
# mutable_cfg = OmegaConf.to_container(cfg.encoder, resolve=True)
# mutable_cfg = OmegaConf.create(mutable_cfg)
# print("Cloned config:\n", OmegaConf.to_yaml(mutable_cfg))
# Instantiate the encoder from the standalone configuration:
# encoder = hydra.utils.instantiate(mutable_cfg)
# encoder = ProNet()
# print(model)
# print(encoder)

# Forward pass of the full model on the featurised batch.
out = model(batch)
# out = encoder(batch)
print(out)
# print(batch.shape)

{'graph_embedding': tensor[32, 384] n=12288 (48Kb) x∈[-0.218, 0.278] μ=0.001 σ=0.040 grad AddmmBackward0, 'graph_label': tensor[32, 384] n=12288 (48Kb) x∈[-0.218, 0.278] μ=0.001 σ=0.040 grad AddmmBackward0}


ProteinBatch(fill_value=[32], atom_list=[32], coords=[10916, 37, 3], residues=[32], residue_id=[32], chains=[10916], residue_type=[10916], b_factor=[10916], id=[32], x=[10916, 23], seq_pos=[10916, 1], batch=[10916], ptr=[33], pos=[10916, 3], edge_index=[2, 153666], subgraphs=[1076, 149], dist=[1076], subgraph_lengths=[1076])
b ProteinBatch(fill_value=[32], atom_list=[32], coords=[10916, 37, 3], residues=[32], residue_id=[32], chains=[10916], residue_type=[10916], b_factor=[10916], id=[32], x=[10916, 23], seq_pos=[10916, 1], batch=[10916], ptr=[33], pos=[10916, 3], edge_index=[2, 153666], subgraphs=[1076, 149], dist=[1076], subgraph_lengths=[1076])
o {'fused_repr': tensor[1076, 256] n=275456 (1.1Mb) x∈[-0.880, 0.835] μ=0.012 σ=0.141 grad StackBackward0}
h subgraph_distance
O {'fused_repr': tensor[1076, 256] n=275456 (1.1Mb) x∈[-0.880, 0.835] μ=0.012 σ=0.141 grad StackBackward0, 'subgraph_distance': tensor[1076, 1] 4.2Kb x∈[-0.010, 0.043] μ=0.014 σ=0.014 grad AddmmBackward0}
{'fused_repr': tensor[1076, 256] n=275456 (1.1Mb) x∈[-0.880, 0.835] μ=0.012 σ=0.141 grad StackBackward0, 'subgraph_distance': tensor[1076, 1] 4.2Kb x∈[-0.010, 0.043] μ=0.014 σ=0.014 grad AddmmBackward0}

In [12]:
# Attributes of the featurised batch to print, in display order.
# `edge_index` appears twice, so it is printed twice.
inspected_attributes = (
    "x",
    "edge_index",
    "atom_list",
    "coords",
    "residues",
    "id",
    "residue_id",
    "residue_type",
    "chains",
    "graph_y",
    "amino_acid_one_hot",
    "batch",
    "ptr",
    "pos",
    "edge_index",
    "edge_type",
    "edge_attr",
)
for attribute_name in inspected_attributes:
    print(getattr(batch_data, attribute_name))

tensor[8177, 23] n=188071 (0.7Mb) x∈[0., 1.000] μ=0.043 σ=0.204
tensor[2, 130832] i64 n=261664 (2.0Mb) x∈[0, 8176] μ=4.088e+03 σ=2.360e+03
['N', 'CA', 'C', 'O', 'CB', 'OG', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OD1', 'ND2', 'CG1', 'CG2', 'CD', 'CE', 'NZ', 'OD2', 'OE1', 'NE2', 'OE2', 'OH', 'NE', 'NH1', 'NH2', 'OG1', 'SD', 'ND1', 'SG', 'NE1', 'CE3', 'CZ2', 'CZ3', 'CH2', 'OXT']
tensor[8177, 37, 3] n=907647 (3.5Mb) x∈[-97.240, 160.044] μ=3.037 σ=19.055
[['MET', 'PRO', 'VAL', 'ASN', 'LEU', 'LYS', 'GLY', 'ARG', 'SER', 'LEU', 'ASP', 'SER', 'LEU', 'LEU', 'ASN', 'PHE', 'THR', 'THR', 'GLU', 'GLU', 'VAL', 'GLN', 'HIS', 'LEU', 'ILE', 'ASP', 'LEU', 'SER', 'ILE', 'ASP', 'LEU', 'LYS', 'LYS', 'ALA', 'LYS', 'TYR', 'GLN', 'GLY', 'LEU', 'HIS', 'ILE', 'ASN', 'ASN', 'ARG', 'PRO', 'LEU', 'VAL', 'GLY', 'LYS', 'ASN', 'ILE', 'ALA', 'ILE', 'LEU', 'PHE', 'GLN', 'LYS', 'ASP', 'SER', 'THR', 'ARG', 'THR', 'ARG', 'CYS', 'ALA', 'PHE', 'GLU', 'VAL', 'ALA', 'ALA', 'SER', 'ASP', 'LEU', 'GLY', 'ALA', 'GLY', 'VAL', 'TH

In [10]:
# import torch

# # Assuming `dl` is your data loader and `model.featurise(i)` returns a batch with `.x`, `.pos`, and `.batch` attributes
# unique_values = set()

# for i in dl:
#     batch = model.featurise(i)
#     z, pos, _ = torch.squeeze(batch.x.long()), batch.pos, batch.batch
    
#     # Reshape z to a flat tensor to extract unique values
#     z_flat = z.view(-1)
#     unique_values.update(z_flat.unique().tolist())

# print(unique_values)


KeyboardInterrupt: 

afdb_swissprot_v4

ProteinBatch(fill_value=[32], atom_list=[32], coords=[12984, 37, 3], residues=[32], residue_id=[32], chains=[12984], residue_type=[12984], b_factor=[12984], id=[32], x=[12984], seq_pos=[12984, 1], batch=[12984], ptr=[33])



fold_fold

DataProteinBatch(fill_value=1e-05, atom_list=[37], coords=[6159, 37, 3], residues=[32], id=[32], residue_id=[32], residue_type=[6159], chains=[6159], graph_y=[32], x=[6159], amino_acid_one_hot=[6159, 23], seq_pos=[6159, 1], batch=[6159], ptr=[33])

No feature:

DataProteinBatch(fill_value=1e-05, atom_list=[37], coords=[4221, 37, 3], residues=[32], id=[32], residue_id=[32], residue_type=[4221], chains=[4221], graph_y=[32], x=[4221, 16], amino_acid_one_hot=[4221, 23], seq_pos=[4221, 1], batch=[4221], ptr=[33], pos=[4221, 3], edge_index=[2, 67536], edge_type=[1, 67536], num_relation=1, edge_attr=[67536, 1])

ca_base:

DataProteinBatch(fill_value=1e-05, atom_list=[37], coords=[5388, 37, 3], residues=[32], id=[32], residue_id=[32], residue_type=[5388], chains=[5388], graph_y=[32], x=[5388, 23], amino_acid_one_hot=[5388, 23], seq_pos=[5388, 1], batch=[5388], ptr=[33], pos=[5388, 3], edge_index=[2, 86208], edge_type=[1, 86208], num_relation=1, edge_attr=[86208, 1])
torch.Size([5388, 23]) torch.Size([5388, 3])

# train

In [13]:
import copy
import sys
from typing import List, Optional

import graphein
import hydra
import lightning as L
import lovely_tensors as lt
import torch
import torch.nn as nn
import torch_geometric
from graphein.protein.tensor.dataloader import ProteinDataLoader
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import Logger
from loguru import logger as log
from omegaconf import DictConfig

from proteinworkshop import (
    constants,
    register_custom_omegaconf_resolvers,
    utils,
)
from proteinworkshop.configs import config
from proteinworkshop.models.base import BenchMarkModel

In [14]:
import pytorch_lightning as pl

In [15]:
from proteinworkshop.models.base import BenchMarkModel
from proteinworkshop.models.graph_encoders.pronet import ProNet

# Build every training component from the composed config.
model: L.LightningModule = BenchMarkModel(cfg)
datamodule: L.LightningDataModule = hydra.utils.instantiate(cfg.dataset.datamodule)
callbacks: List[Callback] = utils.callbacks.instantiate_callbacks(cfg.get("callbacks"))
logger: List[Logger] = utils.loggers.instantiate_loggers(cfg.get("logger"))
trainer: L.Trainer = hydra.utils.instantiate(
    cfg.trainer,
    callbacks=callbacks,
    logger=logger,
)
# trainer.connectors.logger_connector.result.ResultCollection.extract_batch_size = 32

# Smoke test without gradient tracking: featurise one validation batch, check
# that it carries everything the encoder requires, and run one forward pass.
# NOTE(review): presumably this pass also materialises lazily-initialised
# layers (the setup stage is named "lazy_init") -- confirm.
with torch.no_grad():
    datamodule.setup(stage="lazy_init")  # type: ignore
    batch = next(iter(datamodule.val_dataloader()))
    log.info(f"Unfeaturized batch: {batch}")

    batch = model.featurise(batch)
    log.info(f"Featurized batch: {batch}")
    log.info(f"Example labels: {model.get_labels(batch)}")

    # Fail on the first attribute the encoder requires that the batch lacks.
    for attr in model.encoder.required_batch_attributes:  # type: ignore
        if hasattr(batch, attr):
            continue
        raise AttributeError(
            f"Batch {batch} does not have required attribute: {attr} ({model.encoder.required_batch_attributes})"
        )

    out = model(batch)
    log.info(f"Model output: {out}")

    # Drop the references to the smoke-test batch and output.
    del batch, out


Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


100%|██████████| 2558/2558 [00:00<00:00, 3183.97it/s]


100%|██████████| 2558/2558 [00:00<00:00, 3280.95it/s]


In [4]:
# dl = datamodule.train_dataloader()
# for i in dl:
#     batch = model.featurise(i)
    

In [4]:
# Compile the model only when the config holds a truthy `compile` entry. The
# top-level key list printed earlier in this notebook contains no `compile`
# key, so with that config this branch is skipped.
# NOTE(review): recent PyTorch Geometric releases deprecate
# `torch_geometric.compile` in favour of `torch.compile` -- confirm against the
# installed version.
if cfg.get("compile"):
    log.info("Compiling model!")
    model = torch_geometric.compile(model, dynamic=True)

: 

In [16]:

# Start training. `ckpt_path` is looked up with `.get`; the top-level key list
# printed earlier in this notebook has `ckpt_path_test` but no `ckpt_path`, so
# with that config the lookup yields None.
# NOTE(review): the recorded run of this cell ended in
# `RecursionError: maximum recursion depth exceeded` (see the output below);
# the cause is not visible in this notebook.
trainer.fit(
    model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")
)

You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


100%|██████████| 29161/29161 [00:10<00:00, 2869.94it/s]


100%|██████████| 2558/2558 [00:00<00:00, 3095.63it/s]

Checkpoint directory /home/zhang/Projects/3d/ProteinWorkshop/notebooks/outputs/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



You have overridden `on_after_batch_transfer` in `LightningModule` but have passed in a `LightningDataModule`. It will use the implementation from `LightningModule` instance.



Output()

RecursionError: maximum recursion depth exceeded while calling a Python object

: 