In [1]:
import os

if not "experiments" in os.getcwd():
    %cd -q ../../project/parallel_synthesis/experiments

In [2]:
from copy import copy

import torch
from protera_stability.config.lazy import LazyCall as L
from protera_stability.config.common.mlp import mlp_esm

In [3]:
from protera_stability.train import get_cfg, setup_diversity, setup_data, DefaultTrainer

## Setup experiement, data and training cfg

In [4]:
exp_params = {
    "diversity_cutoff": 0.866,
    "random_percent": 0.15,
    "sampling_method": "diversity",
    "experiment_name": "example",
}

cfg = get_cfg(args={})
cfg = setup_diversity(cfg, **exp_params)
mlp_esm.n_units = 2048
mlp_esm.act = L(torch.nn.GELU)()
cfg.model = mlp_esm
cfg.model

{'n_in': 1280, 'n_units': 2048, 'n_layers': 3, 'act': {'_target_': <class 'torch.nn.modules.activation.GELU'>}, 'drop_p': 0.7, 'last_drop': False, '_target_': <class 'protera_stability.models.ProteinMLP'>}

In [5]:
cfg = setup_data(cfg)

In [6]:
cfg.keys()

dict_keys(['trainer_params', 'output_dir', 'random_split', 'experiment', 'model', 'dataloader'])

## Run Training

### Add specific callbacks

In [7]:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

stop_r2_reached = L(EarlyStopping)(
    monitor="valid/r2",
    patience=1,
    check_on_train_epoch_end=False,
    stopping_threshold=0.72,
    mode="max",
)
cfg.trainer_params["callbacks"] = [stop_r2_reached]

### Build Trainer, Model, Optimizer, Scheduler and Lightning Modules

In [8]:
from protera_stability.engine.default import DefaultTrainer

trainer = DefaultTrainer(cfg)
train_dl = trainer.data_module.train_dataloader()
print(f"=== USING {cfg.experiment.sampling_method} as Sampling Method ===")
print(f"=== USING {len(train_dl.sampler)} out of {len(train_dl.dataset)} samples ===")

if cfg.experiment.sampling_method == "diversity":
    print(f"=== SIZE WAS DETERMINED BY {train_dl.sampler.stopped_by} ===")

elif cfg.experiment.sampling_method == "random":
    print(
        f"=== SIZE WAS DETERMINED BY RANDOM PERCENT OF {cfg.experiment.random_percent} ==="
    )

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(f"No correct seed found, seed set to {seed}")
Global seed set to 4059305300


=== USING diversity as Sampling Method ===
=== USING 6123 out of 8204 samples ===
=== SIZE WAS DETERMINED BY CUTOFF ===


In [9]:
cfg["experiment"]

{'sampling_method': 'diversity', 'diversity_cutoff': 0.866, 'random_percent': 0.15, 'random_split': 0.8, 'name': 'example_all-data'}

In [10]:
trainer.fit()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name     | Type       | Params
----------------------------------------
0 | model    | ProteinMLP | 4.7 M 
1 | train_r2 | R2Score    | 0     
2 | valid_r2 | R2Score    | 0     
3 | test_r2  | R2Score    | 0     
----------------------------------------
4.7 M     Trainable params
0         Non-trainable params
4.7 M     Total params
18.891    Total estimated model params size (MB)


                                                                      

Global seed set to 4059305300




Global seed set to 4059305300


Epoch 11: 100%|██████████| 28/28 [00:01<00:00, 19.35it/s, loss=0.245, v_num=48, train/r2=0.727, train/loss=0.254, valid/r2=0.725, valid/loss=0.289]


<protera_stability.trainer.default.DefaultTrainer at 0x7fac86fba760>

### Run as fn

In [11]:
# cfg, trainer_dict = do_train(cfg)

In [12]:
# cfg, trainer_dict = do_test(cfg, trainer_dict)