## Load the Raygun model and the ESM-2 embedding model

In [1]:
# Loaders for the two pretrained checkpoints used throughout this notebook.
from raygun.pretrained import raygun_2_2mil_800M
from esm.pretrained import esm2_t33_650M_UR50D

# Raygun model, placed on device index 0.
raymodel = raygun_2_2mil_800M().to(0)

# ESM-2 model together with its alphabet; the model is placed on device index 0 as well.
model, alph = esm2_t33_650M_UR50D()
model = model.to(0)

## Initialize the dataloader

In [2]:
from tqdm import tqdm
from torch.utils.data import DataLoader
from raygun.modelv2.loader import RaygunData

# Dataset built from the FASTA file; it is handed the ESM-2 alphabet and model loaded above.
fasta_path = "../data/fastas/egfp_mcherry.fasta"
preddata = RaygunData(fasta_path, alph, model, device=0)

# Batches have to be assembled with the collate function that RaygunData itself provides.
predloader = DataLoader(
    preddata,
    batch_size=2,
    shuffle=False,
    collate_fn=preddata.collatefn,
)

# Size of the dataset, shown as the cell output.
len(preddata)

2

## Suppose the goal is to resize EGFP and mCherry, two fluorescent proteins (FPs) each longer than 235 residues, to 225 residues each

In [3]:
import torch

# One target length per sequence in the batch; both proteins get the same value.
target_len = 225
targetlength = torch.tensor([target_len, target_len], dtype=torch.long).to(0)

In [4]:
# Take the first batch from the loader (the dataset holds two sequences and batch_size is 2).
first_batch = next(iter(predloader))
toks, embs, mask, bat = first_batch

# Move `toks`, `embs` and `mask` to device index 0; `bat` is kept exactly as returned.
toks, embs, mask = (tensor.to(0) for tensor in (toks, embs, mask))

## Run Raygun with an appropriate `noise` value

In [5]:
# Put Raygun in evaluation mode before generating.
raymodel.eval()

# This is pure inference, so disable gradient tracking: without it the forward
# pass records an autograd graph and keeps activations alive for no benefit.
# NOTE(review): `noise = 0.05` suggests the output is stochastic; if reproducible
# sequences are needed, seed torch before this call — confirm how Raygun draws the noise.
with torch.no_grad():
    results = raymodel(
        embs,
        mask=mask,
        noise=0.05,
        return_logits_and_seqs=True,
        target_lengths=targetlength,
    )

## Get the resulting sequences

In [8]:
# Pair each generated sequence with its length so the requested size can be checked at a glance.
generated_seqs = results["generated-sequences"]
[(len(generated), generated) for generated in generated_seqs]

[(225,
  'MVSKGEEDNMAIIKEFMRFKVHMEGSVNGHEFEIEGEGEGRPYEGTQTAKLKVTKGGPLPFAWDILSPFYGGKYAVHPADPYLLKEFPEGFKWERVMNFEDGGVVTVTQDSSLQDGEFIYKVKLRGTNFPSDGPVMQKKTMGWEASRMYPEGALKREKQRLKKDGHYDAEVKTTYKAKKPVQLPGAYNVNIKLDITSHNEDYTIVEQYERAEGRHSTGGMWELYK'),
 (225,
  'MVSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFIWTTGKLPVPWPTLVTTLTYGQCFRRPHHKQHFFKSAPEGYQQRTIFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYHYNSNIIMADKKKGIKIFKRHNIDGVVLDAHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITLGMEEYYK')]