## Load a checkpoint

In [1]:
import torch
from raygun.pretrained import raygun_2_2mil_800M
# Build the Raygun model via `raygun.pretrained`. With `return_lightning_module=True` the
# call returns the Lightning wrapper (displayed as `RaygunLightning` further down), which is
# the object the finetuning section later hands to `L.Trainer.fit`.
raymodel = raygun_2_2mil_800M(return_lightning_module=True) ## must set return_lightning_module to True

In [2]:
from esm.pretrained import esm2_t33_650M_UR50D
# ESM-2 (t33, 650M) model together with its alphabet; both are passed to `RaygunData`
# in the cells below (presumably to embed the sequences -- confirm against RaygunData).
model, alph = esm2_t33_650M_UR50D()
# Move the ESM-2 model to device index 0.
model       = model.to(0)

In [3]:
# Move the Raygun Lightning module to device index 0. As the last expression of the
# cell, the returned module is displayed, which prints the architecture below.
raymodel.to(0)

RaygunLightning(
  (model): Raygun(
    (encoder): RaygunEncoder(
      (encoders): ModuleList(
        (0-11): 12 x Block(
          (encoder): TransformerLayer(
            (self_attn): MultiheadAttention(
              (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
              (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
              (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
              (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
              (rot_emb): RotaryEmbedding()
            )
            (self_attn_layer_norm): ESM1LayerNorm()
            (fc1): Linear(in_features=1280, out_features=2560, bias=True)
            (fc2): Linear(in_features=2560, out_features=1280, bias=True)
            (final_layer_norm): ESM1LayerNorm()
          )
          (convblock): ConvBlock(
            (c1): ConvMasked(
              (conv): Conv1d(1280, 640, kernel_size=(7,), stride=(1,), padding=valid)
       

## Before finetuning

In [4]:
from raygun.modelv2.training import training
from Bio import SeqIO
import random
from io import StringIO
trainfasta = "../data/fastas/human-mouse.sprot.fasta"
N_SAMPLES  = 100   # number of proteins used for the finetuning demo
SEED       = 42    # fixed seed so Restart & Run All selects the same proteins
# Keep only sequences strictly longer than 50 and strictly shorter than 1000 residues.
recs       = [r for r in SeqIO.parse(trainfasta, "fasta") if 50 < len(r.seq) < 1000]
# Sample with a dedicated, seeded RNG (does not disturb the global `random` state).
# `min(...)` avoids the ValueError `sample` raises when fewer than N_SAMPLES records remain.
recs       = random.Random(SEED).sample(recs, min(N_SAMPLES, len(recs)))

In [5]:
from raygun.modelv2.loader import RaygunData
from torch.utils.data import DataLoader
from tqdm import tqdm 
# Serialize the sampled records into a single in-memory FASTA text:
# one ">id" header line followed by one sequence line per record.
recseq     = "".join(f">{rec.id}\n{str(rec.seq)}\n" for rec in recs)
fastafile  = StringIO(recseq)
preddata   = RaygunData(fastafile, alph, model, device = 0)
## use the collatefn provided in RaygunData
predloader = DataLoader(preddata,
                        batch_size = 3,
                        shuffle    = True,
                        collate_fn = preddata.collatefn)

In [6]:
true_seqs = []
pred_seqs = []
# Inference only: run under `torch.no_grad()` so no autograd graph is built or kept
# for these forward passes.
with torch.no_grad():
    # `tok` is yielded by the collate function but is not needed here, so it is
    # left where it is instead of being copied to the GPU.
    for tok, emb, mask, bat in tqdm(predloader, desc = "Before finetuning"):
        emb  = emb.to(0)
        mask = mask.to(0)
        # each batch entry is a pair; the second element is the input sequence
        _, ts = zip(*bat)
        true_seqs += ts
        ## set `return_logits_and_seqs` to true for the model to return `generated-sequences`.
        ## use `noise` to determine the amount of noise to be added while generating.
        results = raymodel.model(emb, mask=mask, noise = 0.,
                                 return_logits_and_seqs = True)
        pred_seqs += results["generated-sequences"]

Before finetuning: 100%|██████████| 34/34 [00:18<00:00,  1.84it/s]


In [7]:
import numpy as np
def compute_seq_id(s1, s2):
    """Fraction of aligned positions at which `s1` and `s2` carry the same symbol.

    The two sequences are compared position by position without alignment. If
    their lengths differ, only the first `min(len(s1), len(s2))` positions are
    compared (the extra tail of the longer sequence is ignored).

    Args:
        s1: first sequence (any iterable of symbols, e.g. a string).
        s2: second sequence.

    Returns:
        A value in [0, 1]. Returns 0.0 when there is nothing to compare (either
        sequence is empty) instead of `nan` with a RuntimeWarning.
    """
    matches = [a == b for a, b in zip(s1, s2)]
    if not matches:
        return 0.0
    return np.mean(matches)
# Collect (input sequence, generated sequence, identity) for every pair, then
# report the mean identity over all pairs.
seqdata = []
for true_seq, pred_seq in zip(true_seqs, pred_seqs):
    seqdata.append((true_seq, pred_seq, compute_seq_id(true_seq, pred_seq)))
_, _, seqids = zip(*seqdata)
np.mean(seqids)

0.9148910088045968

## Perform finetuning

In [8]:
from raygun.modelv2.ltraygun import RaygunLightning
from torch.utils.data import DataLoader
from raygun.modelv2.loader import RaygunData
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from pathlib import Path
import os
def training_(ltmodel, esmmodel, esmalphabet, 
             trainfasta, validfasta, outfld, 
             devices=1, clip=0.001, lr=1e-4, 
             epoch=5, batchsize=2, finetune=True,
             delete_checkpoint_after_loading = True):
    """Finetune a Raygun Lightning module and return the best checkpoint's state dict.

    Args:
        ltmodel: Lightning module to train. It is modified in place: `lr`,
            `finetune`, `epoch`, `traininglog` and `log_wandb` are set on it, and
            it is moved to device index 0 and trained by `trainer.fit`.
        esmmodel: ESM model passed to `RaygunData`.
        esmalphabet: ESM alphabet passed to `RaygunData`.
        trainfasta: FASTA path or file-like object with the training sequences.
        validfasta: FASTA path or file-like object with the validation sequences.
        outfld: output folder for the training log and checkpoints (created,
            including missing parents, if it does not exist).
        devices: number of devices used outside notebooks (notebooks always use 1).
        clip: gradient clipping value (clipping algorithm is "value").
        lr: learning rate stored on `ltmodel`.
        epoch: maximum number of epochs.
        batchsize: batch size of both data loaders.
        finetune: stored on `ltmodel.finetune`.
        delete_checkpoint_after_loading: remove the checkpoint file once its
            state dict has been read.

    Returns:
        The "state_dict" entry of the checkpoint with the highest
        `val_blosum_ratio` saved during this run.

    Raises:
        FileNotFoundError: if no checkpoint can be found after training.
    """
    def is_notebook():
        """
        Return True when running inside a Jupyter kernel (ZMQInteractiveShell).
        The "ddp" strategy is only requested when this returns False.
        """
        try:
            shell = get_ipython().__class__.__name__
            return shell == "ZMQInteractiveShell"
        except NameError:
            return False

    def make_loader(fasta):
        """Build an unshuffled DataLoader over `fasta` using RaygunData's own collate function."""
        data = RaygunData(fastafile    = fasta,
                          alphabet     = esmalphabet,
                          model        = esmmodel,
                          device       = 0)
        return DataLoader(data,
                          shuffle    = False,
                          batch_size = batchsize,
                          collate_fn = data.collatefn)

    outdir = Path(outfld)
    # parents=True so a nested output folder does not fail with FileNotFoundError
    outdir.mkdir(parents=True, exist_ok=True)
    ltmodel.lr          = lr
    ltmodel.finetune    = finetune
    ## starting epoch
    ltmodel.epoch       = 0
    ltmodel.traininglog = f"{outfld}/traininglog.txt"
    ltmodel.log_wandb   = False

    trainloader = make_loader(trainfasta)
    validloader = make_loader(validfasta)

    # keep only the single best checkpoint (weights only), ranked by val_blosum_ratio
    chk_callback = ModelCheckpoint(
                        monitor           = "val_blosum_ratio",
                        mode              = "max",
                        save_top_k        = 1, 
                        save_weights_only = True, 
                        dirpath           = outfld,
                        filename          = "model-{epoch:02d}-{step:06d}-{val_blosum_ratio:.4f}",
                        save_on_train_epoch_end = True)

    # options shared by the notebook and the script (ddp) trainer
    trainer_kwargs = dict(accumulate_grad_batches = 2,
                          callbacks               = [chk_callback],
                          accelerator             = "gpu",
                          max_epochs              = epoch,
                          gradient_clip_val       = clip,
                          gradient_clip_algorithm = "value")
    if is_notebook():
        trainer = L.Trainer(devices = 1, **trainer_kwargs)
    else:
        trainer = L.Trainer(devices  = devices,
                            strategy = "ddp",
                            **trainer_kwargs)
    trainer.fit(ltmodel.to(0), 
                trainloader, 
                validloader)

    # Prefer the checkpoint written by THIS run. Picking an arbitrary *.ckpt from the
    # folder could load a stale file left over from an earlier run in the same folder.
    best_path = chk_callback.best_model_path
    if best_path and Path(best_path).is_file():
        chkptloc = Path(best_path)
    else:
        # fallback: most recently modified checkpoint in the output folder
        candidates = sorted((p for p in outdir.iterdir() if p.suffix == ".ckpt"),
                            key=lambda p: p.stat().st_mtime)
        if not candidates:
            raise FileNotFoundError(f"No .ckpt file found in {outfld} after training")
        chkptloc = candidates[-1]

    new_checkpoint = torch.load(chkptloc, weights_only=True)["state_dict"]
    if delete_checkpoint_after_loading:
        os.remove(chkptloc)

    return new_checkpoint

In [9]:
# NOTE: the same in-memory FASTA text (`recseq`) is supplied as both the training and the
# validation set (each through its own StringIO handle), so the monitored validation
# metric is computed on the training sequences. All other options use training_'s defaults.
new_checkpoint = training_(raymodel, model, alph,
                           StringIO(recseq), StringIO(recseq),
                           "finetuned-output")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/hpc/group/singhlab/user/kd312/minimamba/envs/molfeat/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /hpc/group/singhlab/user/kd312/projects/raygunv2/src/raygun-new-publication/raygun/notebooks/finetuned-output exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/hpc/group/singhlab/user/kd312/minimamba/envs/molfeat/lib/python3.11/site-packages/lightning/pytorch/core/optimizer.py:258: Found unsupported keys in the lr scheduler dict: {'freq'}. HINT:

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/hpc/group/singhlab/user/kd312/minimamba/envs/molfeat/lib/python3.11/site-packages/lightning/pytorch/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 2. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


## After finetuning

In [11]:
# Load the state dict returned by `training_` (the "state_dict" entry of the saved
# checkpoint) into the Lightning module.
raymodel.load_state_dict(new_checkpoint)

<All keys matched successfully>

In [12]:
from tqdm import tqdm
from torch.utils.data import DataLoader
from raygun.modelv2.loader import RaygunData

# Fresh in-memory handle over the same FASTA text that was evaluated before finetuning.
fastafile      = StringIO(recseq)
preddata       = RaygunData(fastafile, alph, model, device = 0)
## use the collatefn provided in RaygunData
loader_options = dict(batch_size = 3, shuffle = True, collate_fn = preddata.collatefn)
predloader     = DataLoader(preddata, **loader_options)

In [13]:
# Make sure the module with the freshly loaded weights is on device index 0 before inference.
raymodel = raymodel.to(0)

In [14]:
true_seqs = []
pred_seqs = []
# Inference only: run under `torch.no_grad()` so no autograd graph is built or kept
# for these forward passes.
with torch.no_grad():
    # `tok` is yielded by the collate function but is not needed here, so it is
    # left where it is instead of being copied to the GPU.
    for tok, emb, mask, bat in tqdm(predloader, desc = "After finetuning"):
        emb  = emb.to(0)
        mask = mask.to(0)
        # each batch entry is a pair; the second element is the input sequence
        _, ts = zip(*bat)
        true_seqs += ts
        ## set `return_logits_and_seqs` to true for the model to return `generated-sequences`.
        ## use `noise` to determine the amount of noise to be added while generating.
        ## NOTE(review): the "Before finetuning" loop uses noise = 0., so the two
        ## sequence-identity numbers are not measured under identical settings.
        results = raymodel.model(emb, mask=mask, noise = 0.1,
                                 return_logits_and_seqs = True)
        pred_seqs += results["generated-sequences"]

After finetuning: 100%|██████████| 34/34 [00:17<00:00,  1.96it/s]


In [17]:
import numpy as np
# Per-pair identity between each input sequence and its generated counterpart,
# followed by the mean over all pairs.
seqids = []
for true_seq, pred_seq in zip(true_seqs, pred_seqs):
    seqids.append(compute_seq_id(true_seq, pred_seq))
np.mean(seqids)

0.963639790354421

In [15]:
# Persist the sampled sequences (the same FASTA text used above) for the example shell scripts.
with open("../example-sh/ALL.fasta", "w") as fasta_out:
    fasta_out.write(recseq)