## Loading the model

In [1]:
from raygun.pretrained import raygun_2_2mil_800M

In [2]:
# Build the Raygun model from `raygun.pretrained` and move it to device index 0
# in a single expression.
raymodel = raygun_2_2mil_800M().to(0)

Using cache found in /hpc/home/kd312/.cache/torch/hub/rohitsinghlab_raygun_main


## Loading the ESM-2 model for generating the initial embedding

In [3]:
from esm.pretrained import esm2_t33_650M_UR50D
# The ESM-2 factory returns the model together with its alphabet.
model, alph = esm2_t33_650M_UR50D()
# Move the model to device index 0 and switch it to evaluation mode, as the ESM
# usage examples do before extracting embeddings, so that train-only layers
# (e.g. dropout) cannot perturb the representations.
model       = model.to(0).eval()

## Loading the dataset/dataloader

### Using 933 fluorescent protein sequences for demonstration

In [4]:
from tqdm import tqdm
from torch.utils.data import DataLoader
from raygun.modelv2.loader import RaygunData

# RaygunData is given the FASTA path plus the ESM-2 alphabet and model loaded above.
preddata = RaygunData(
    "../data/fastas/all_gfp_seqs.fasta",
    alph,
    model,
    device=0,
)
# Batches have to be assembled with the collate function that RaygunData provides.
# NOTE(review): shuffling is probably unnecessary for inference; kept as in the original.
predloader = DataLoader(
    dataset=preddata,
    shuffle=True,
    batch_size=3,
    collate_fn=preddata.collatefn,
)
# Number of entries in the dataset (displayed as the cell output).
len(preddata)

933

## Running the Raygun model to reconstruct the ESM-2 embeddings 

In [5]:
import torch  # only needed for the no_grad context below

true_seqs = []
pred_seqs = []
# This cell only runs inference, so autograd is switched off: no computation
# graph is recorded for any batch.
with torch.no_grad():
    for tok, emb, mask, bat in tqdm(predloader, desc = "Running FP sequences"):
        # Items yielded by the loader (as described by the notebook author):
        #   tok  -> normal ESM-2 tokens. Shape [batch_size, max_len]
        #   emb  -> ESM-2 embedding representation. Shape [batch_size, max_len, 1280]
        #   mask -> if batch_size > 1, marks the length of each individual sequence
        #           in the batch. Shape [batch_size, max_len], where
        #           max_len = max(len(seq_i)), i = 1 .. batch_size.
        #           Note that the sequences are left-padded.
        #   bat  -> the actual sequence information: a tuple [(seq-name, seq), ...]
        tok        = tok.to(0)   # not consumed by the model call below
        emb        = emb.to(0)
        mask       = mask.to(0)
        # Drop the sequence names; keep only the original sequences.
        _, ts      = zip(*bat)
        true_seqs += ts
        ## set `return_logits_and_seqs` to true for the model to return `generated-sequences`.
        ## use `noise` to set the amount of noise added while generating.
        results = raymodel(emb, mask=mask, noise = 0.1,
                           return_logits_and_seqs = True)
        pred_seqs += results["generated-sequences"]

Running FP sequences: 100%|██████████| 311/311 [01:42<00:00,  3.02it/s]


## Compute the sequence identity of the reconstructed sequences, given the original sequences

In [7]:
import numpy as np
def compute_seq_id(tr, orig):
    """Return the fraction of positions at which `tr` and `orig` agree.

    The sequences are compared position by position, without any alignment.
    `zip` stops at the shorter input, so trailing elements of the longer
    sequence are ignored. If nothing is compared, the mean of an empty list
    (NaN) is returned.
    """
    matches = []
    for left, right in zip(tr, orig):
        # 1 marks an identical position, 0 a mismatch.
        matches.append(1 if left == right else 0)
    return np.mean(matches)

# Score every (original, generated) pair; `map` walks both lists in lockstep
# and stops at the shorter one, exactly like `zip`.
seqids = list(map(compute_seq_id, true_seqs, pred_seqs))
# Median identity across all sequences (displayed as the cell output).
np.median(seqids)

0.9957627118644068