# `ProteinWorkshop` Tutorial, Part 1 - Training a New Model
![Models](../docs/source/_static/box_models.png)

## Train a new model using the `ProteinWorkshop`

In [1]:
%load_ext autoreload
%autoreload 2
# %load_ext blackcellmagic

Welcome to the tutorial series for the `ProteinWorkshop`! 

In the `ProteinWorkshop`, we implement numerous [featurisation](https://www.proteins.sh/configs/features) schemes, [datasets](https://www.proteins.sh/configs/dataset) for [self-supervised pre-training](https://proteins.sh/quickstart_component/pretrain.html) and [downstream evaluation](https://proteins.sh/quickstart_component/downstream.html), [pre-training](https://proteins.sh/configs/task) tasks, and [auxiliary tasks](https://proteins.sh/configs/task.html#auxiliary-tasks).

[Processed datasets](https://drive.google.com/drive/folders/18i8rLST6ZICTBu6Q67ClT0KqN9AHeqoW?usp=sharing) and [pre-trained weights](https://drive.google.com/drive/folders/1zK1r8FpmGaqV_QwUJuvDacwSL0RW-Vw9?usp=sharing) are available for download. Downloading the datasets manually is not required: on first run, each dataset is downloaded and processed from its respective source.

The `ProteinWorkshop` includes several models along with pre-trained weights for them, so you can use them right away.

In this tutorial, we show how to use the components already available in the `ProteinWorkshop` to train and use models for specific tasks. The `ProteinWorkshop` is a highly modular package, so we cover how to swap out its individual parts: the model, training task, dataset, featurization scheme, and so on.

Besides using all the different options we provide, you can make use of the modular nature of the `ProteinWorkshop` to add your own models, datasets, featurization schemes, and training tasks. We will show you how to do this in the next tutorials.

To train a new model, follow this 3-step procedure:

1. Choose the parts you want to consider: model, training task, dataset, featurization scheme and auxiliary tasks
2. Validate the designed training config
3. Use the designed config to train a new model

### 1. Choose the parts you want to consider: model, training task, dataset, featurization scheme and auxiliary tasks

You can switch out any of these for another available option by replacing the corresponding argument's value in `overrides`:

`cfg = hydra.compose("template", overrides=["encoder=schnet", "task=inverse_folding", "dataset=afdb_swissprot_v4", "features=ca_base", "+aux_task=none"], return_hydra_config=True)`
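Since every component is selected by a `group=option` string, the override list can also be assembled programmatically. The following is a minimal sketch of that idea; `build_overrides` is a hypothetical helper written for this tutorial, not a function provided by the `ProteinWorkshop` or Hydra, and the option names are the ones used elsewhere in this notebook.

```python
from typing import Iterable, List


def build_overrides(
    encoder: str,
    task: str,
    dataset: str,
    features: str,
    aux_task: str = "none",
    extra: Iterable[str] = (),
) -> List[str]:
    """Assemble a Hydra-style override list from component choices.

    Hypothetical helper for illustration; not part of the ProteinWorkshop.
    """
    overrides = [
        f"encoder={encoder}",
        f"task={task}",
        f"dataset={dataset}",
        f"features={features}",
        f"+aux_task={aux_task}",  # leading `+` appends a new entry to the config
    ]
    overrides.extend(extra)  # e.g. additional value overrides such as "seed=52"
    return overrides


# The example from the text above
baseline = build_overrides("schnet", "inverse_folding", "afdb_swissprot_v4", "ca_base")

# Same encoder, different task, dataset, and features
fold = build_overrides(
    "schnet",
    "multiclass_graph_classification",
    "fold_family",
    "ca_bb",
    extra=["seed=52"],
)
print(baseline)
print(fold)
```

The resulting list would then be passed as the `overrides` argument of `hydra.compose`, so switching a component amounts to changing a single argument.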

In [2]:
import os

# Set the environment variable to make only GPU 0 visible
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Now import the necessary libraries
import torch

# Verify GPUs visible to PyTorch
torch_visible_gpus = torch.cuda.device_count()
print("Visible GPUs for PyTorch:", torch_visible_gpus)

Visible GPUs for PyTorch: 1


In [3]:
# Misc. tools
import os

# Hydra tools
import hydra

from hydra.core.global_hydra import GlobalHydra
from hydra.core.hydra_config import HydraConfig

from proteinworkshop.constants import HYDRA_CONFIG_PATH
from proteinworkshop.utils.notebook import init_hydra_singleton

version_base = "1.2"  # Note: Need to update whenever Hydra is upgraded
init_hydra_singleton(reload=True, version_base=version_base)

path = HYDRA_CONFIG_PATH
rel_path = os.path.relpath(path, start=".")
# print(rel_path)
GlobalHydra.instance().clear()
hydra.initialize(rel_path, version_base=version_base)

cfg = hydra.compose(
    config_name="train",
    overrides=[
        "encoder=schnet",
        "encoder.hidden_channels=512", # Number of channels in the hidden layers
        "encoder.out_dim=32", # Output dimension of the model
        "encoder.num_layers=6", # Number of convolutional layers in the model
        "encoder.num_filters=128", # Number of filters used in convolutional layers
        "encoder.num_gaussians=50", # Number of Gaussian functions used for radial filters
        "encoder.cutoff=10.0", # Cutoff distance for interactions
        "encoder.max_num_neighbors=32", # Maximum number of neighboring atoms to consider
        "encoder.readout=add", # Global pooling method to be used
        "encoder.dipole=False",
        "encoder.mean=null",
        "encoder.std=null",
        "encoder.atomref=null",
        "encoder.pretraining=False",

        # "decoder.graph_label.dummy=True",

        "task=multiclass_graph_classification",
        "dataset=fold_family",
        "dataset.datamodule.batch_size=32",
        "features=ca_bb", 
        "+aux_task=none",
        
        "trainer.max_epochs=150",
        "optimiser=adam",
        "optimiser.optimizer.lr=3e-4",
        "callbacks.early_stopping.patience=10",
        "test=True",
        "scheduler=plateau", # 5 epochs - 0.6 default
        # "+ckpt_path=/home/zhang/Projects/3d/ProteinWorkshop/notebooks/outputs/checkpoints/last.ckpt", # continue training
        ## for test ONLY
        # "task_name=test",  # here
        # "ckpt_path_test=/home/zhang/Projects/3d/proteinworkshop_checkpoints/outputs_pronet_fold_400epochs/checkpoints/epoch_273.ckpt", # here
        # "optimizer.weight_decay=0.5"
        "seed=52",
    ],
    return_hydra_config=True,
)

# Note: Customize as needed e.g., when running a sweep
cfg.hydra.job.num = 0
cfg.hydra.job.id = 0
cfg.hydra.hydra_help.hydra_help = False
cfg.hydra.runtime.output_dir = "outputs"
HydraConfig.instance().set_config(cfg)
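To make the dotted overrides in the cell above more concrete, here is a simplified sketch of how a string such as `encoder.hidden_channels=512` can be applied to a nested configuration. It is an illustration only and not Hydra's actual parser: it works on a plain `dict`, does not model config-group selection (e.g. `encoder=schnet`), and the starting values in `config` are made-up placeholders. `parse_value` and `apply_override` are hypothetical names.

```python
from typing import Any


def parse_value(text: str) -> Any:
    """Convert an override value string into a Python value (simplified)."""
    lowered = text.lower()
    if lowered == "null":
        return None
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            continue
    return text  # fall back to a plain string, e.g. "add"


def apply_override(config: dict, override: str) -> None:
    """Apply one `a.b.c=value` override to a nested dict, in place."""
    dotted_key, _, raw_value = override.partition("=")
    must_be_new = dotted_key.startswith("+")
    *parents, leaf = dotted_key.lstrip("+").split(".")
    node = config
    for name in parents:
        node = node[name]  # KeyError if an intermediate section is missing
    if must_be_new and leaf in node:
        raise KeyError(f"'{leaf}' already exists; drop the leading '+'")
    if not must_be_new and leaf not in node:
        raise KeyError(f"'{leaf}' does not exist; use '+{leaf}=...' to add it")
    node[leaf] = parse_value(raw_value)


# Placeholder starting values (assumptions, not the library defaults)
config = {
    "encoder": {"hidden_channels": 128, "cutoff": 5.0, "readout": "mean",
                "dipole": True, "mean": 0.0},
    "optimiser": {"optimizer": {"lr": 0.001}},
    "seed": 0,
}

for item in [
    "encoder.hidden_channels=512",
    "encoder.cutoff=10.0",
    "encoder.readout=add",
    "encoder.dipole=False",
    "encoder.mean=null",
    "optimiser.optimizer.lr=3e-4",
    "seed=52",
    "+test=True",
]:
    apply_override(config, item)

print(config)
```

In this sketch, a misspelled key raises an error instead of silently creating a new entry, which is also why validating the composed config before training is worthwhile.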

### 2. Validate the designed training config

This is not strictly necessary, but it is a good idea to validate the config before training. This will check that all the arguments you have provided are valid and that the config is complete.

In [4]:
from proteinworkshop.configs import config

cfg = config.validate_config(cfg)
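As a rough illustration of a completeness check, the sketch below reports which top-level sections are absent from a config. It is not what `validate_config` does internally; `find_missing_sections` and `REQUIRED_SECTIONS` are hypothetical names, and the chosen sections are an illustrative subset of the top-level keys printed in the next cell.

```python
from typing import Iterable, List, Mapping

# Assumption: an illustrative subset of top-level sections, not an official list
REQUIRED_SECTIONS = (
    "dataset",
    "features",
    "encoder",
    "decoder",
    "task",
    "trainer",
    "optimiser",
)


def find_missing_sections(
    cfg: Mapping, required: Iterable[str] = REQUIRED_SECTIONS
) -> List[str]:
    """Return the required sections that are absent or set to None."""
    return [name for name in required if name not in cfg or cfg[name] is None]


complete = {name: {} for name in REQUIRED_SECTIONS}
incomplete = {"dataset": {}, "encoder": {}, "task": None}

print(find_missing_sections(complete))
print(find_missing_sections(incomplete))
```

A check of this kind fails fast, before any dataset is downloaded or any GPU memory is allocated.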

In [5]:
print(cfg.keys())
for key in cfg.keys():
    print(key)
    print(cfg[key])

dict_keys(['hydra', 'env', 'dataset', 'features', 'encoder', 'decoder', 'transforms', 'callbacks', 'optimiser', 'scheduler', 'trainer', 'extras', 'metrics', 'task', 'logger', 'hparams', 'name', 'seed', 'num_workers', 'task_name', 'ckpt_path_test', 'test', 'aux_task'])
hydra
env
{'paths': {'root_dir': '${oc.env:ROOT_DIR}', 'data': '${oc.env:DATA_PATH}', 'output_dir': '${hydra:runtime.output_dir}', 'work_dir': '${hydra:runtime.cwd}', 'log_dir': '${oc.env:RUNS_PATH}', 'runs': '${oc.env:RUNS_PATH}', 'run_dir': '${env.paths.runs}/${name}/${env.init_time}'}, 'python': {'version': '${python_version:micro}'}, 'init_time': '${now:%y-%m-%d_%H:%M:%S}'}
dataset
{'datamodule': {'_target_': 'proteinworkshop.datasets.fold_classification.FoldClassificationDataModule', 'path': '${env.paths.data}/FoldClassification/', 'split': 'family', 'batch_size': 32, 'pin_memory': True, 'num_workers': 4, 'dataset_fraction': 1.0, 'shuffle_labels': False, 'transforms': '${transforms}', 'overwrite': False, 'in_memory':

### 3. Use the designed config to train a new model

Now with the config you have designed, you can train a new model. You can also use the `ProteinWorkshop` to evaluate the model on a downstream task.

In [6]:
# import torch
# import torch.nn as nn
# ckpt_path = '/home/zhang/Projects/3d/proteinworkshop_checkpoints/outputs_pronet_pretraining_best@2/checkpoints/epoch_002.ckpt'
# # Assuming `model` is your model and `encoder_weights` is the state_dict of pretrained weights
# print(torch.load(ckpt_path).keys())
# # Load the pretrained state_dict
# checkpoint = torch.load(ckpt_path)
# pretrained_dict = torch.load(ckpt_path)["state_dict"]

# # Create a new state_dict that excludes the final layer parameters
# filtered_dict = {k: v for k, v in pretrained_dict.items() if not k.startswith('encoder.lin_out.')}
# filtered_dict = {k: v for k, v in filtered_dict.items() if not k.startswith('encoder.lins_out.')}
# print(filtered_dict.keys())
# new_checkpoint = {
#     "epoch": checkpoint["epoch"],
#     "global_step": checkpoint["global_step"],
#     "pytorch-lightning_version": checkpoint["pytorch-lightning_version"],
#     "state_dict": filtered_dict,
#     "loops": checkpoint["loops"],
#     "callbacks": checkpoint["callbacks"],
#     "optimizer_states": checkpoint["optimizer_states"],
#     "lr_schedulers": checkpoint["lr_schedulers"],
#     "hparams_name": checkpoint["hparams_name"],
#     "hyper_parameters": checkpoint["hyper_parameters"]
# }

# torch.save(new_checkpoint, '/home/zhang/Projects/3d/proteinworkshop_checkpoints/outputs_pronet_pretraining_best@2/checkpoints/epoch_002_filtered.ckpt')
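The commented-out cell above removes the encoder's output-layer weights (keys starting with `encoder.lin_out.` or `encoder.lins_out.`) from a pre-trained checkpoint before saving a filtered copy. The key-filtering step can be sketched on a plain dictionary. In this sketch, `drop_prefixed_keys` is a hypothetical helper, and the toy keys and list values merely stand in for real parameter names and tensors.

```python
from typing import Any, Dict, Tuple


def drop_prefixed_keys(
    state_dict: Dict[str, Any], prefixes: Tuple[str, ...]
) -> Dict[str, Any]:
    """Return a copy of `state_dict` without keys starting with any prefix."""
    # str.startswith accepts a tuple, so one pass covers all prefixes
    return {k: v for k, v in state_dict.items() if not k.startswith(prefixes)}


# Toy stand-in for checkpoint["state_dict"]; names and values are made up
toy_state_dict = {
    "encoder.embedding.weight": [0.1, 0.2],
    "encoder.lin_out.weight": [0.3],
    "encoder.lins_out.0.bias": [0.4],
    "decoder.graph_label.weight": [0.5],
}

filtered = drop_prefixed_keys(
    toy_state_dict, ("encoder.lin_out.", "encoder.lins_out.")
)
print(sorted(filtered))
```

With a real checkpoint, the same function would be applied to the loaded `state_dict` before writing the new checkpoint file, as in the commented-out cell.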

In [7]:
cfg.get("ckpt_path")

In [8]:
from proteinworkshop.finetune import train_model

train_model(cfg)

Seed set to 52


Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



Lazy modules are a new feature under heavy development so changes to the API or functionality can happen at any moment.



100%|██████████| 736/736 [00:00<00:00, 4869.51it/s]




You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


100%|██████████| 12312/12312 [00:02<00:00, 4109.13it/s]



Checkpoint directory /home/zhang/Projects/3d/ProteinWorkshop/notebooks/outputs/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



You have overridden `on_after_batch_transfer` in `LightningModule` but have passed in a `LightningDataModule`. It will use the implementation from `LightningModule` instance.



Output()

### 4. Wrapping up

Have any additional questions about using the components provided in the `ProteinWorkshop`? [Create a new issue](https://github.com/a-r-j/ProteinWorkshop/issues/new/choose) on our [GitHub repository](https://github.com/a-r-j/ProteinWorkshop). We would be happy to work with you to leverage the full power of the repository!
