## Loading the model

In [1]:
import torch

In [27]:
from raygun.pretrained import raygun_2_2mil_800M
# Build the pretrained Raygun model with the imported `raygun_2_2mil_800M`
# loader and place it on device index 0 in a single step.
raymodel = raygun_2_2mil_800M().to(0)

In [2]:
from esm.pretrained import esm2_t33_650M_UR50D
# Load ESM-2 (650M): the loader returns the model and its alphabet object.
model, alph = esm2_t33_650M_UR50D()
# Move the model to device index 0 and switch it to inference mode. The ESM
# quick-start calls `model.eval()` so that dropout is disabled and the
# embeddings are deterministic; `eval()` returns the module itself.
model = model.to(0).eval()

## Dataset and dataloader

In [8]:
from tqdm import tqdm
from torch.utils.data import DataLoader
from raygun.modelv2.loader import RaygunData
from Bio import SeqIO
# Stream the FASTA file at `path` and keep only the records whose sequence
# length lies strictly between 50 and 1000.
path = "../data/fastas/human-mouse.sprot.fasta"
recs = [rec for rec in SeqIO.parse(path, "fasta") if 50 < len(rec.seq) < 1000]

## Selecting seqs with 50 < length < 1000 and repackaging them as an in-memory FASTA

In [9]:
import random
from io import StringIO
# Serialise the filtered records back into FASTA text: a ">id" header line
# followed by the sequence, each newline-terminated. `str.join` assembles the
# text in one pass instead of re-copying a growing string with `+=` for every
# record. The result is wrapped in a file-like object, which is what the
# following cell passes to `RaygunData`.
recstr = "".join(f">{r.id}\n{str(r.seq)}\n" for r in recs)
recstream = StringIO(recstr)

In [10]:
# Wrap the in-memory FASTA stream in Raygun's dataset class; `alph` and `model`
# are the ESM-2 alphabet and model loaded earlier in the notebook.
preddata = RaygunData(recstream, alph, model, device = 0)
# This loader is only used for evaluation, so keep the dataset order
# (shuffle = False): repeated runs then visit the sequences in the same order.
# Use the collate function provided by RaygunData.
predloader = DataLoader(preddata, shuffle = False,
                        batch_size = 1, collate_fn = preddata.collatefn)
len(preddata)

33048

## Running Raygun

In [11]:
true_seqs = []
pred_seqs = []
# Every batch from `predloader` unpacks into (tok, emb, mask, bat). The notes
# below are carried over from the original notebook; they cannot be verified
# from this cell alone, so confirm them against RaygunData before relying on them:
#   tok  -> normal ESM-2 tokens. Shape [batch_size, no_sequences]
#   emb  -> ESM-2 embedding representation. Shape [batch_size, no_sequences, 1280]
#   mask -> if batch_size > 1, gives the length of each individual sequence in
#           the batch. Shape [batch_size, no_sequences], where
#           `no_sequences` = max(length(seq_i)), i = 1 to batch_size.
#           The sequences are noted as left-padded.
#   bat  -> actual sequence information, a tuple [(seq-name, seq), ...]
#
# Inference only: `torch.no_grad()` stops autograd from recording the forward
# pass, so no graph is kept alive while iterating over the whole dataset.
with torch.no_grad():
    for tok, emb, mask, bat in tqdm(predloader, desc = "Running SPROT sequences"):
        # `tok` is not consumed below, so only the tensors the model needs are
        # transferred to device 0.
        emb  = emb.to(0)
        mask = mask.to(0)
        # Keep the sequence of each (name, sequence) pair; a generator is used
        # so that `_` is not overwritten at notebook scope.
        true_seqs.extend(seq for _name, seq in bat)
        ## set `return_logits_and_seqs` to true for the model to return `generated-sequences`.
        ## NOTE(review): the original note says `error_c` sets the generation noise, while
        ## the FP cell further down passes `noise=` -- confirm the current keyword.
        results = raymodel(emb, mask=mask,
                           return_logits_and_seqs = True)
        pred_seqs.extend(results["generated-sequences"])

Running SPROT sequences: 100%|██████████| 33048/33048 [1:50:50<00:00,  4.97it/s]  


## Getting sequence identity

In [15]:
import numpy as np
def compute_seq_id(s1, s2):
    """Return the position-wise identity between two sequences.

    Items are compared index by index; no alignment is performed. The number
    of matching positions is divided by the length of the LONGER input. For
    equal-length inputs this is the plain fraction of identical positions.
    When the lengths differ, the unmatched tail counts as mismatches instead
    of being silently dropped by `zip`, which would overstate the identity.

    Parameters
    ----------
    s1, s2 : iterable
        Sequences to compare, e.g. amino-acid strings.

    Returns
    -------
    float
        Identity in [0, 1], or NaN when both inputs are empty (the value the
        mean of an empty list would give, but without the runtime warning).
    """
    # Materialise once so that arbitrary iterables are accepted and can be sized.
    s1, s2 = list(s1), list(s2)
    longest = max(len(s1), len(s2))
    if longest == 0:
        return float("nan")
    matches = sum(1 for x, y in zip(s1, s2) if x == y)
    return matches / longest

# Pair every input sequence with its Raygun reconstruction and score the pair.
seqdata = [
    (true_seq, pred_seq, compute_seq_id(true_seq, pred_seq))
    for true_seq, pred_seq in zip(true_seqs, pred_seqs)
]
# Transpose the (true, predicted, identity) triples; only the identity column
# is needed for the summary statistic.
_, _, seqids = zip(*seqdata)
np.median(seqids)

0.9572649572649573

## Compute sequence identity on FPs

In [28]:
# Path to the FP FASTA file (a plain string: there is nothing to interpolate,
# so no f-string is needed).
fpseqs   = "../data/fastas/all_gfp_seqs.fasta"
fpdata   = RaygunData(fpseqs, alph, model, device = 0)
# This loader is only used for evaluation, so keep the dataset order
# (shuffle = False): with batch_size = 3 this also keeps the composition of
# each batch identical from run to run.
# Use the collate function provided by RaygunData.
fploader = DataLoader(fpdata, shuffle = False,
                      batch_size = 3, collate_fn = fpdata.collatefn)
len(fpdata)

933

In [35]:
true_fps = []
pred_fps = []
# Inference only: `torch.no_grad()` stops autograd from recording the forward
# pass, so no graph is kept alive while iterating over the batches.
with torch.no_grad():
    for tok, emb, mask, bat in tqdm(fploader, desc = "Running FP sequences"):
        # `tok` is not consumed below, so only the tensors the model needs are
        # transferred to device 0.
        emb  = emb.to(0)
        mask = mask.to(0)
        # `bat` is unpacked as (name, sequence) pairs; keep the sequences. A
        # generator is used so that `_` is not overwritten at notebook scope.
        true_fps.extend(seq for _name, seq in bat)
        # `noise=0.` is passed explicitly here; `return_logits_and_seqs = True`
        # makes the result include the "generated-sequences" entry read below.
        results = raymodel(emb, mask=mask, noise=0.,
                           return_logits_and_seqs = True)
        pred_fps.extend(results["generated-sequences"])

Running FP sequences: 100%|██████████| 311/311 [01:44<00:00,  2.98it/s]


In [37]:
# Score each FP sequence against its Raygun reconstruction, then report the
# median identity.
fpseqids = [
    compute_seq_id(true_fp, pred_fp)
    for true_fp, pred_fp in zip(true_fps, pred_fps)
]
np.median(fpseqids)

0.9957627118644068