# `ProteinWorkshop` Tutorial, Part 1 - Training a New Model
![Models](../docs/source/_static/box_models.png)

## Train a new model using the `ProteinWorkshop`

In [1]:
%load_ext autoreload
%autoreload 2
# %load_ext blackcellmagic

Welcome to the tutorial series for the `ProteinWorkshop`! 

In the `ProteinWorkshop`, we implement numerous [featurisation](https://www.proteins.sh/configs/features) schemes, [datasets](https://www.proteins.sh/configs/dataset) for [self-supervised pre-training](https://proteins.sh/quickstart_component/pretrain.html) and [downstream evaluation](https://proteins.sh/quickstart_component/downstream.html), [pre-training](https://proteins.sh/configs/task) tasks, and [auxiliary tasks](https://proteins.sh/configs/task.html#auxiliary-tasks).

[Processed datasets](https://drive.google.com/drive/folders/18i8rLST6ZICTBu6Q67ClT0KqN9AHeqoW?usp=sharing) and [pre-trained weights](https://drive.google.com/drive/folders/1zK1r8FpmGaqV_QwUJuvDacwSL0RW-Vw9?usp=sharing) are available for download. Downloading the datasets manually is not required: on first run, each dataset is downloaded and processed from its original source.

The `ProteinWorkshop` ships several models together with pre-trained weights, so you can use them right away.

In this tutorial, we show how to use the components already available in the `ProteinWorkshop` to train and use models for specific tasks. The `ProteinWorkshop` is a highly modular package, so we will cover how to change its individual parts, such as the model, training task, dataset, and featurization scheme.

Beyond the options we provide, the modular design of the `ProteinWorkshop` lets you add your own models, datasets, featurization schemes, and training tasks. The next tutorials show how to do this.

To train a new model, follow this 3-step procedure:

1. Choose the parts you want to consider: model, training task, dataset, featurization scheme and auxiliary tasks
2. Validate the designed training config
3. Use the designed config to train a new model

### 1. Choose the parts you want to consider: model, training task, dataset, featurization scheme and auxiliary tasks

You can switch out any of these for another available option by replacing the corresponding argument's value in `overrides`:

`cfg = hydra.compose("template", overrides=["encoder=schnet", "task=inverse_folding", "dataset=afdb_swissprot_v4", "features=ca_base", "+aux_task=none"], return_hydra_config=True)`
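The `overrides` list can also be assembled programmatically, which helps when switching components often. The following is a minimal sketch: the helper `build_overrides` and its two-dictionary layout are illustrative assumptions, not part of the `ProteinWorkshop` API. The component names and values are the ones from the example above.

```python
from typing import Dict, List


def build_overrides(groups: Dict[str, str], extra: Dict[str, str]) -> List[str]:
    """Turn component choices into Hydra-style override strings.

    `groups` holds keys that already exist in the base config (``key=value``).
    `extra` holds keys that are appended to the config, which Hydra marks
    with a leading ``+``.
    """
    overrides = [f"{key}={value}" for key, value in groups.items()]
    overrides += [f"+{key}={value}" for key, value in extra.items()]
    return overrides


overrides = build_overrides(
    groups={
        "encoder": "schnet",
        "task": "inverse_folding",
        "dataset": "afdb_swissprot_v4",
        "features": "ca_base",
    },
    extra={"aux_task": "none"},
)
print(overrides)
```

The resulting list can be passed as the `overrides` argument of `hydra.compose`.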

In [2]:
# Misc. tools
import os

# Hydra tools
import hydra

from hydra.core.global_hydra import GlobalHydra  # canonical location of GlobalHydra
from hydra.core.hydra_config import HydraConfig

from proteinworkshop.constants import HYDRA_CONFIG_PATH
from proteinworkshop.utils.notebook import init_hydra_singleton

version_base = "1.2"  # Note: Need to update whenever Hydra is upgraded
init_hydra_singleton(reload=True, version_base=version_base)

path = HYDRA_CONFIG_PATH
rel_path = os.path.relpath(path, start=".")
# print(rel_path)
GlobalHydra.instance().clear()
hydra.initialize(rel_path, version_base=version_base)

cfg = hydra.compose(
    config_name="train",
    overrides=[
        "encoder=pronet",
        "encoder.level='aminoacid'",
        "encoder.num_blocks=4",
        "encoder.hidden_channels=128",
        "encoder.out_channels=1195",
        "encoder.mid_emb=64",
        "encoder.num_radial=6",
        "encoder.num_spherical=2",
        "encoder.cutoff=10.0",
        "encoder.max_num_neighbors=32",
        "encoder.int_emb_layers=3",
        "encoder.out_layers=2",
        "encoder.num_pos_emb=16",
        "encoder.dropout=0.3",
        "encoder.data_augment_eachlayer=True",
        "encoder.euler_noise=False",
        "encoder.pretraining=False",
        "encoder.node_embedding=False",

        "decoder.graph_label.dummy=True",

        "task=multiclass_graph_classification",
        "dataset=fold_family",
        "dataset.datamodule.batch_size=32",
        "features=ca_base", 
        "+aux_task=none",
        
        "trainer.max_epochs=400",
        "optimiser=adam",
        "optimiser.optimizer.lr=5e-4",
        "callbacks.early_stopping.patience=200",
        "test=True",
        "scheduler=steplr",
        "+ckpt_path=/home/zhang/Projects/3d/ProteinWorkshop/notebooks/outputs/checkpoints/last.ckpt",
        ## for test ONLY
        # "task_name=test",  # here
        # "ckpt_path_test=/home/zhang/Projects/3d/proteinworkshop_checkpoints/outputs_pronet_fold_400epochs/checkpoints/epoch_273.ckpt", # here
        # "optimizer.weight_decay=0.5"
    ],
    return_hydra_config=True,
)

# Note: Customize as needed e.g., when running a sweep
cfg.hydra.job.num = 0
cfg.hydra.job.id = 0
cfg.hydra.hydra_help.hydra_help = False
cfg.hydra.runtime.output_dir = "outputs"
HydraConfig.instance().set_config(cfg)

### 2. Validate the designed training config

This step is not strictly necessary, but validating the config before training is a good idea. Validation checks that all the arguments you provided are valid and that the config is complete.

In [3]:
from proteinworkshop.configs import config

cfg = config.validate_config(cfg)
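
As a rough illustration of the kind of check a validation step can perform, the sketch below verifies that a config-like mapping contains a set of required top-level keys. This is not the implementation of `config.validate_config`: the helper `find_missing_keys` and the choice of required keys (a subset of the top-level keys this config exposes) are assumptions made for illustration.

```python
from typing import Iterable, List, Mapping

# Assumed set of keys a training config should define (illustrative only).
REQUIRED_KEYS = (
    "dataset",
    "features",
    "encoder",
    "decoder",
    "task",
    "trainer",
    "optimiser",
)


def find_missing_keys(
    cfg: Mapping, required: Iterable[str] = REQUIRED_KEYS
) -> List[str]:
    """Return the required keys that are absent from `cfg` or set to None."""
    return [key for key in required if key not in cfg or cfg[key] is None]


# A toy config that lacks a decoder and leaves the optimiser unset.
toy_cfg = {
    "dataset": {"datamodule": {"batch_size": 32}},
    "features": "ca_base",
    "encoder": "pronet",
    "task": "multiclass_graph_classification",
    "trainer": {"max_epochs": 400},
    "optimiser": None,
}
missing = find_missing_keys(toy_cfg)
print(missing)
```

Running a cheap check like this before training surfaces an incomplete config immediately rather than after data loading has started.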

In [4]:
print(cfg.keys())
for key in cfg.keys():
    print(key)
    print(cfg[key])

dict_keys(['hydra', 'env', 'dataset', 'features', 'encoder', 'decoder', 'transforms', 'callbacks', 'optimiser', 'scheduler', 'trainer', 'extras', 'metrics', 'task', 'logger', 'name', 'seed', 'num_workers', 'task_name', 'ckpt_path_test', 'test', 'aux_task', 'ckpt_path'])
hydra
env
{'paths': {'root_dir': '${oc.env:ROOT_DIR}', 'data': '${oc.env:DATA_PATH}', 'output_dir': '${hydra:runtime.output_dir}', 'work_dir': '${hydra:runtime.cwd}', 'log_dir': '${oc.env:RUNS_PATH}', 'runs': '${oc.env:RUNS_PATH}', 'run_dir': '${env.paths.runs}/${name}/${env.init_time}'}, 'python': {'version': '${python_version:micro}'}, 'init_time': '${now:%y-%m-%d_%H:%M:%S}'}
dataset
{'datamodule': {'_target_': 'proteinworkshop.datasets.fold_classification.FoldClassificationDataModule', 'path': '${env.paths.data}/FoldClassification/', 'split': 'family', 'batch_size': 32, 'pin_memory': True, 'num_workers': 4, 'dataset_fraction': 1.0, 'shuffle_labels': False, 'transforms': '${transforms}', 'overwrite': False, 'in_memory
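
The `env` entry printed above resolves several paths from environment variables through `${oc.env:...}` interpolations. A minimal sketch for checking that those variables are set before composing the config follows. The helper name `unset_env_vars` is hypothetical; the variable names are the ones visible in the printed config.

```python
import os
from typing import Iterable, List, Mapping

# Environment variables referenced by the printed `env.paths` entries.
ENV_VARS = ("ROOT_DIR", "DATA_PATH", "RUNS_PATH")


def unset_env_vars(
    names: Iterable[str] = ENV_VARS,
    environ: Mapping[str, str] = os.environ,
) -> List[str]:
    """Return the names that are missing or empty in `environ`."""
    return [name for name in names if not environ.get(name)]


missing = unset_env_vars()
if missing:
    print(f"Set these environment variables first: {', '.join(missing)}")
else:
    print("All required environment variables are set.")
```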

### 3. Use the designed config to train a new model

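With the config in place, you can now train a new model. Because this config sets `ckpt_path`, training resumes from that checkpoint instead of starting from scratch. You can also use the `ProteinWorkshop` to evaluate the model on a downstream task.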

In [5]:
# Disabled cell: removes the encoder's output-layer weights from a pretrained checkpoint.
# import torch
#
# ckpt_path = '/home/zhang/Projects/3d/proteinworkshop_checkpoints/outputs_pronet_pretraining_best@2/checkpoints/epoch_002.ckpt'
# # Load the checkpoint once and reuse it
# checkpoint = torch.load(ckpt_path)
# print(checkpoint.keys())
# pretrained_dict = checkpoint["state_dict"]

# # Create a new state_dict that excludes the final layer parameters
# excluded_prefixes = ('encoder.lin_out.', 'encoder.lins_out.')
# filtered_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith(excluded_prefixes)}
# print(filtered_dict.keys())

# # Copy the remaining checkpoint entries unchanged
# new_checkpoint = {
#     "epoch": checkpoint["epoch"],
#     "global_step": checkpoint["global_step"],
#     "pytorch-lightning_version": checkpoint["pytorch-lightning_version"],
#     "state_dict": filtered_dict,
#     "loops": checkpoint["loops"],
#     "callbacks": checkpoint["callbacks"],
#     "optimizer_states": checkpoint["optimizer_states"],
#     "lr_schedulers": checkpoint["lr_schedulers"],
#     "hparams_name": checkpoint["hparams_name"],
#     "hyper_parameters": checkpoint["hyper_parameters"]
# }

# torch.save(new_checkpoint, '/home/zhang/Projects/3d/proteinworkshop_checkpoints/outputs_pronet_pretraining_best@2/checkpoints/epoch_002_filtered.ckpt')
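
The filtering step in the commented-out cell can be tried in isolation. The following is a minimal sketch that uses a plain dictionary in place of a real checkpoint's `state_dict`, so it runs without PyTorch. The helper `drop_prefixed` and the toy keys (other than the two excluded prefixes) are illustrative assumptions.

```python
from typing import Dict, Tuple, TypeVar

V = TypeVar("V")


def drop_prefixed(state_dict: Dict[str, V], prefixes: Tuple[str, ...]) -> Dict[str, V]:
    """Return a copy of `state_dict` without keys that start with any prefix."""
    # str.startswith accepts a tuple, so one pass covers all prefixes.
    return {k: v for k, v in state_dict.items() if not k.startswith(prefixes)}


# Toy stand-in for checkpoint["state_dict"]; the values would be tensors.
toy_state = {
    "encoder.embedding.weight": 1,
    "encoder.lin_out.weight": 2,
    "encoder.lins_out.0.bias": 3,
    "decoder.graph_label.weight": 4,
}
filtered = drop_prefixed(toy_state, ("encoder.lin_out.", "encoder.lins_out."))
print(sorted(filtered))
```

When a filtered `state_dict` like this is loaded into a PyTorch module, `load_state_dict(..., strict=False)` is typically needed, since the removed keys would otherwise be reported as missing.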

In [6]:
cfg.get("ckpt_path")

'/home/zhang/Projects/3d/ProteinWorkshop/notebooks/outputs/checkpoints/last.ckpt'
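
Because this path is machine-specific, it may help to confirm that the file exists before resuming from it. The following is a minimal sketch. The helper `resolve_checkpoint` is hypothetical and simply returns `None` when there is nothing to resume from; the relative path in the usage line is a placeholder.

```python
from pathlib import Path
from typing import Optional


def resolve_checkpoint(ckpt_path: Optional[str]) -> Optional[Path]:
    """Return the checkpoint path if it points to an existing file, else None."""
    if not ckpt_path:
        return None
    path = Path(ckpt_path).expanduser()
    return path if path.is_file() else None


# Placeholder path; substitute the value of cfg.get("ckpt_path").
print(resolve_checkpoint("outputs/checkpoints/last.ckpt"))
```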

In [7]:
from proteinworkshop.finetune import train_model

train_model(cfg)

Seed set to 52


Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


100%|██████████| 736/736 [00:00<00:00, 2098.66it/s]


You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


100%|██████████| 12312/12312 [00:06<00:00, 2008.13it/s]



Checkpoint directory /home/zhang/Projects/3d/ProteinWorkshop/notebooks/outputs/checkpoints exists and is not empty.

Restoring states from the checkpoint path at /home/zhang/Projects/3d/ProteinWorkshop/notebooks/outputs/checkpoints/last.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



You have overridden `on_after_batch_transfer` in `LightningModule` but have passed in a `LightningDataModule`. It will use the implementation from `LightningModule` instance.



Restored all states from the checkpoint at /home/zhang/Projects/3d/ProteinWorkshop/notebooks/outputs/checkpoints/last.ckpt


Output()

Metric train/loss/total improved by 0.000 >= min_delta = 0.0. New best score: 0.000
Epoch 194, global step 75075: 'val/graph_label/accuracy' was not in top 1


Epoch 195, global step 75460: 'val/graph_label/accuracy' was not in top 1


Epoch 196, global step 75845: 'val/graph_label/accuracy' was not in top 1


Metric train/loss/total improved by 0.000 >= min_delta = 0.0. New best score: 0.000
Epoch 197, global step 76230: 'val/graph_label/accuracy' was not in top 1


Epoch 198, global step 76615: 'val/graph_label/accuracy' was not in top 1


Epoch 199, global step 77000: 'val/graph_label/accuracy' was not in top 1


Epoch 200, global step 77385: 'val/graph_label/accuracy' was not in top 1


Epoch 201, global step 77770: 'val/graph_label/accuracy' was not in top 1


Epoch 202, global step 78155: 'val/graph_label/accuracy' was not in top 1


Epoch 203, global step 78540: 'val/graph_label/accuracy' was not in top 1


Epoch 204, global step 78925: 'val/graph_label/accuracy' was not in top 1


Epoch 205, global step 79310: 'val/graph_label/accuracy' was not in top 1


Epoch 206, global step 79695: 'val/graph_label/accuracy' was not in top 1


Epoch 207, global step 80080: 'val/graph_label/accuracy' was not in top 1


Epoch 208, global step 80465: 'val/graph_label/accuracy' was not in top 1


Epoch 209, global step 80850: 'val/graph_label/accuracy' was not in top 1


Epoch 210, global step 81235: 'val/graph_label/accuracy' was not in top 1


Epoch 211, global step 81620: 'val/graph_label/accuracy' was not in top 1


Epoch 212, global step 82005: 'val/graph_label/accuracy' was not in top 1


Epoch 213, global step 82390: 'val/graph_label/accuracy' was not in top 1


Epoch 214, global step 82775: 'val/graph_label/accuracy' was not in top 1


Epoch 215, global step 83160: 'val/graph_label/accuracy' was not in top 1


Epoch 216, global step 83545: 'val/graph_label/accuracy' was not in top 1


Epoch 217, global step 83930: 'val/graph_label/accuracy' was not in top 1


Epoch 218, global step 84315: 'val/graph_label/accuracy' was not in top 1


Epoch 219, global step 84700: 'val/graph_label/accuracy' was not in top 1


Epoch 220, global step 85085: 'val/graph_label/accuracy' was not in top 1


Epoch 221, global step 85470: 'val/graph_label/accuracy' was not in top 1


Metric val/graph_label/accuracy improved by 0.003 >= min_delta = 0.0. New best score: 0.596
Epoch 222, global step 85855: 'val/graph_label/accuracy' reached 0.59572 (best 0.59572), saving model to '/home/zhang/Projects/3d/ProteinWorkshop/notebooks/outputs/checkpoints/epoch_222.ckpt' as top 1


Epoch 223, global step 86240: 'val/graph_label/accuracy' was not in top 1


Epoch 224, global step 86625: 'val/graph_label/accuracy' was not in top 1


Epoch 225, global step 87010: 'val/graph_label/accuracy' was not in top 1


Epoch 226, global step 87395: 'val/graph_label/accuracy' was not in top 1


Epoch 227, global step 87780: 'val/graph_label/accuracy' was not in top 1


Epoch 228, global step 88165: 'val/graph_label/accuracy' was not in top 1


Epoch 229, global step 88550: 'val/graph_label/accuracy' was not in top 1


Epoch 230, global step 88935: 'val/graph_label/accuracy' was not in top 1


Epoch 231, global step 89320: 'val/graph_label/accuracy' was not in top 1


Epoch 232, global step 89705: 'val/graph_label/accuracy' was not in top 1


Epoch 233, global step 90090: 'val/graph_label/accuracy' was not in top 1


Epoch 234, global step 90475: 'val/graph_label/accuracy' was not in top 1


Epoch 235, global step 90860: 'val/graph_label/accuracy' was not in top 1


Metric train/loss/total improved by 0.000 >= min_delta = 0.0. New best score: 0.000
Epoch 236, global step 91245: 'val/graph_label/accuracy' was not in top 1


Epoch 237, global step 91630: 'val/graph_label/accuracy' was not in top 1


Epoch 238, global step 92015: 'val/graph_label/accuracy' was not in top 1


Epoch 239, global step 92400: 'val/graph_label/accuracy' was not in top 1


Epoch 240, global step 92785: 'val/graph_label/accuracy' was not in top 1


Epoch 241, global step 93170: 'val/graph_label/accuracy' was not in top 1


Epoch 242, global step 93555: 'val/graph_label/accuracy' was not in top 1


Epoch 243, global step 93940: 'val/graph_label/accuracy' was not in top 1


Epoch 244, global step 94325: 'val/graph_label/accuracy' was not in top 1


Epoch 245, global step 94710: 'val/graph_label/accuracy' was not in top 1


Epoch 246, global step 95095: 'val/graph_label/accuracy' was not in top 1


Epoch 247, global step 95480: 'val/graph_label/accuracy' was not in top 1


Epoch 248, global step 95865: 'val/graph_label/accuracy' was not in top 1


Epoch 249, global step 96250: 'val/graph_label/accuracy' was not in top 1


Epoch 250, global step 96635: 'val/graph_label/accuracy' was not in top 1


Epoch 251, global step 97020: 'val/graph_label/accuracy' was not in top 1


Epoch 252, global step 97405: 'val/graph_label/accuracy' was not in top 1


Epoch 253, global step 97790: 'val/graph_label/accuracy' was not in top 1


Epoch 254, global step 98175: 'val/graph_label/accuracy' was not in top 1


Epoch 255, global step 98560: 'val/graph_label/accuracy' was not in top 1


Epoch 256, global step 98945: 'val/graph_label/accuracy' was not in top 1


Epoch 257, global step 99330: 'val/graph_label/accuracy' was not in top 1


Epoch 258, global step 99715: 'val/graph_label/accuracy' was not in top 1


Epoch 259, global step 100100: 'val/graph_label/accuracy' was not in top 1


Epoch 260, global step 100485: 'val/graph_label/accuracy' was not in top 1


Epoch 261, global step 100870: 'val/graph_label/accuracy' was not in top 1


Epoch 262, global step 101255: 'val/graph_label/accuracy' was not in top 1


Epoch 263, global step 101640: 'val/graph_label/accuracy' was not in top 1


Epoch 264, global step 102025: 'val/graph_label/accuracy' was not in top 1


Epoch 265, global step 102410: 'val/graph_label/accuracy' was not in top 1


Epoch 266, global step 102795: 'val/graph_label/accuracy' was not in top 1


Metric val/graph_label/accuracy improved by 0.000 >= min_delta = 0.0. New best score: 0.596
Epoch 267, global step 103180: 'val/graph_label/accuracy' reached 0.59589 (best 0.59589), saving model to '/home/zhang/Projects/3d/ProteinWorkshop/notebooks/outputs/checkpoints/epoch_267.ckpt' as top 1


Epoch 268, global step 103565: 'val/graph_label/accuracy' was not in top 1


Epoch 269, global step 103950: 'val/graph_label/accuracy' was not in top 1


Epoch 270, global step 104335: 'val/graph_label/accuracy' was not in top 1


Epoch 271, global step 104720: 'val/graph_label/accuracy' was not in top 1


Epoch 272, global step 105105: 'val/graph_label/accuracy' was not in top 1


Epoch 273, global step 105490: 'val/graph_label/accuracy' was not in top 1


Epoch 274, global step 105875: 'val/graph_label/accuracy' was not in top 1


Metric val/graph_label/accuracy improved by 0.001 >= min_delta = 0.0. New best score: 0.596
Epoch 275, global step 106260: 'val/graph_label/accuracy' reached 0.59646 (best 0.59646), saving model to '/home/zhang/Projects/3d/ProteinWorkshop/notebooks/outputs/checkpoints/epoch_275.ckpt' as top 1


Epoch 276, global step 106645: 'val/graph_label/accuracy' was not in top 1


Epoch 277, global step 107030: 'val/graph_label/accuracy' was not in top 1


Epoch 278, global step 107415: 'val/graph_label/accuracy' was not in top 1


Epoch 279, global step 107800: 'val/graph_label/accuracy' was not in top 1


Epoch 280, global step 108185: 'val/graph_label/accuracy' was not in top 1


Epoch 281, global step 108570: 'val/graph_label/accuracy' was not in top 1


Epoch 282, global step 108955: 'val/graph_label/accuracy' was not in top 1


Epoch 283, global step 109340: 'val/graph_label/accuracy' was not in top 1


Epoch 284, global step 109725: 'val/graph_label/accuracy' was not in top 1


Epoch 285, global step 110110: 'val/graph_label/accuracy' was not in top 1


Epoch 286, global step 110495: 'val/graph_label/accuracy' was not in top 1


Epoch 287, global step 110880: 'val/graph_label/accuracy' was not in top 1


Epoch 288, global step 111265: 'val/graph_label/accuracy' was not in top 1


Epoch 289, global step 111650: 'val/graph_label/accuracy' was not in top 1


Epoch 290, global step 112035: 'val/graph_label/accuracy' was not in top 1


Epoch 291, global step 112420: 'val/graph_label/accuracy' was not in top 1


Epoch 292, global step 112805: 'val/graph_label/accuracy' was not in top 1


### 4. Wrapping up

Have any additional questions about using the components provided in the `ProteinWorkshop`? [Create a new issue](https://github.com/a-r-j/ProteinWorkshop/issues/new/choose) on our [GitHub repository](https://github.com/a-r-j/ProteinWorkshop). We would be happy to work with you to leverage the full power of the repository!
