In [1]:
%load_ext autoreload
%autoreload 2

import torch
import esm

from lib.data.datasets.GB1 import get_GB1_dataset
from lib.utils.file import save_pt_file

In [2]:
import itertools
from lib.data.datasets.GB1 import prepare_sequences

In [3]:
symbols = 'ARNDCQEGHILKMFPSTWYV'
# Every 4-letter combination of the 20 amino-acid letters above (20**4 = 160,000),
# in itertools.product order: 'AAAA', 'AAAR', ..., 'VVVV'.
variants = list(map(''.join, itertools.product(symbols, repeat=4)))

In [16]:
# Wild-type PhoQ sequence (486 residues); variants are placed at the positions below.
phoq_wt_sequence = "MKKLLRLFFPLSLRVRFLLATAAVVLVLSLAYGMVALIGYSVSFDKTTFRLLRGESNLFYTLAKWENNKLHVELPENIDKQSPTMTLIYDENGQLLWAQRDVPWLMKMIQPDWLKSNGFHEIEADVNDTSLLLSGDHSIQQQLQEVREDDDDAEMTHSVAVNVYPATSRMPKLTIVVVDTIPVELKSSYMVWSWFIYVLSANLLLVIPLLWVAAWWSLRPIEALAKEVRELEEHNRELLNPATTRELTSLVRNLNRLLKSERERYDKYRTTLTDLTHSLKTPLAVLQSTLRSLRSEKMSVSDAEPVMLEQISRISQQIGYYLHRASMRGGTLLSRELHPVAPLLDNLTSALNKVYQRKGVNISLDISPEISFVGEQNDFVEVMGNVLDNACKYCLEFVEISARQTDEHLYIVVEDDGPGIPLSKREVIFDRGQRVDTLRPGQGVGLAVAREITEQYEGKIVAGESMLGGARMEVIFGRQHSAPKDE"
# 1-based residues 284, 285, 288, 289 converted to 0-based string indices
# (the wild-type letters there are A, V, S, T).
phoq_mutation_positions = [residue - 1 for residue in (284, 285, 288, 289)]

In [4]:
# Build one full-length PhoQ sequence per variant with the project helper.
# NOTE(review): presumably substitutes each 4-letter variant into phoq_wt_sequence at the
# 0-based phoq_mutation_positions — confirm against prepare_sequences.
sequences = prepare_sequences(variants, phoq_wt_sequence, phoq_mutation_positions)

In [2]:
# Scratch input: one sequence (appears to be GB1) with <mask> tokens at the four
# variable sites, plus a matching all-mask variant label.
# NOTE(review): this overwrites the `sequences`/`variants` built earlier in the notebook.
variants = ["<mask>" * 4]
sequences = ["MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNG<mask><mask><mask>EWTYDDATKTFT<mask>TE"]

In [7]:
# Scratch input: each amino-acid letter as its own one-residue "sequence".
# `variants` is bound to the very same list object.
sequences = list("ARNDCQEGHILKMFPSTWYV")
variants = sequences

In [23]:
# Choose the compute device up front; the model is moved there once loaded.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained checkpoints are cached by torch hub (e.g. ~/.cache/torch/hub/checkpoints/).
# Other options: esm2_t6_8M_UR50D, esm2_t33_650M_UR50D, esm1b_t33_650M_UR50S.
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
batch_converter = alphabet.get_batch_converter()

# eval() disables dropout for deterministic results; eval() and to() both return the module.
model = model.eval().to(device)

In [6]:
# Reload previously tokenized GB1 artifacts.
# NOTE(review): these are GB1 files; `variants` defined earlier is overwritten.
gb1_data_dir = "./../../data/GB1/"
tokenized_sequences = torch.load(gb1_data_dir + "esm-1b_sequences_complete.pt")
variants = torch.load(gb1_data_dir + "esm-1b_variants_complete.pt")

In [17]:
# Load PhoQ dataset
# Load PhoQ dataset
# The GB1 loader is reused for PhoQ by passing the PhoQ wild type and mutation positions;
# with return_variants=True it is unpacked into (sequences, fitness, variants).
# NOTE(review): file_path climbs five directories, unlike the "./../../data/..." paths used
# elsewhere in this notebook — presumably resolved against a different base; confirm.
sequences, fitness, variants = get_GB1_dataset(
    return_variants=True,
    shuffle=False,
    file_path="../../../../../data/PhoQ/PhoQ.xlsx",
    wt_sequence=phoq_wt_sequence,
    mutation_positions=phoq_mutation_positions,
)

In [21]:
# Inspect the variant labels returned by the dataset loader.
variants

array(['WIPY', 'SSGD', 'FGGK', ..., 'NYWS', 'NYWP', 'NYWQ'], dtype=object)

In [24]:
# batch_converter expects (label, sequence) tuples; use the variant string as the label.
data = [(variants[i], sequences[i]) for i in range(len(sequences))]
# Returns (labels, sequence strings, token tensor); the first two echo the input pairs.
variants, sequences, tokenized_sequences = batch_converter(data)

# Tokens stay on the CPU here; the embedding cell moves them to `device`.

In [8]:
# Persist the batch-converter outputs. The "_complete" suffix matches the filenames the
# Sandbox cell reloads (and the paths this cell reported in its recorded output); the
# previous "esm-1b_variants.pt" / "esm-1b_sequences.pt" names were never read back.
save_pt_file(
    variants,
    save_to="./../../data/PhoQ/esm-1b_variants_complete.pt",
    absolute_path=False,
)
save_pt_file(
    tokenized_sequences,
    save_to="./../../data/PhoQ/esm-1b_sequences_complete.pt",
    absolute_path=False,
)

saving to file "s:\Documents\master\code\llm./../../data/PhoQ/esm-1b_variants_complete.pt"
saving to file "s:\Documents\master\code\llm./../../data/PhoQ/esm-1b_sequences_complete.pt"


In [62]:
# Probe the tokenizer on tiny inputs. Per the recorded output: every row starts with 0 and
# ends with 2 (presumably BOS/EOS), "<mask>" becomes a single token (32), the space in
# "A C" produces no token so that row ends with a trailing 1 (presumably padding), and
# "B" maps to its own token (25).
batch_converter([(None, "A C"), (None, "A<mask>C"), (None, "ABC")])

([None, None, None],
 ['A C', 'A<mask>C', 'ABC'],
 tensor([[ 0,  5, 23,  2,  1],
         [ 0,  5, 32, 23,  2],
         [ 0,  5, 25, 23,  2]]))

In [12]:
# Smoke-test the model on the first 100 sequences. The input tokens must be on the same
# device as the model; tokenized_sequences is only moved to `device` in a later cell,
# so move this slice explicitly (a no-op if it is already there).
with torch.no_grad():
    output = model(tokenized_sequences[0:100].to(device), repr_layers=[33])

In [None]:
# Extract sequence embeddings
SEQUENCE_LENGTH = len(tokenized_sequences[0])
EMBEDDING_LENGTH = 1280

embeddings = torch.empty(
    (len(tokenized_sequences), EMBEDDING_LENGTH), dtype=torch.float32
)
tokenized_sequences = tokenized_sequences.to(device)
with torch.no_grad():
    for s in range(141662, len(tokenized_sequences)):
        if s % 1000 == 0:
            print(s)
        results = model(
            tokenized_sequences[s : s + 1, :], repr_layers=[33], return_contacts=True
        )
        representation = results["representations"][33]
        representation = representation[0,1:SEQUENCE_LENGTH-1].mean(0)
        embeddings[s, :] = representation

embeddings.size()

In [35]:
# Persist the embeddings under the filename the Sandbox cell reloads for PhoQ.
embedding_save_path = "./../../data/PhoQ/esm-1b_embedding_complete.pt"
save_pt_file(embeddings, save_to=embedding_save_path, absolute_path=False)

saving to file "s:\Documents\master\code\llm./../../data/PhoQ/esm-1b_embedding_complete.pt"


In [None]:
import pandas as pd

# Fitness_norm is produced by the "Normalized Fitness" cell further down. Reuse it if that
# cell already ran in this kernel; otherwise load the PhoQ copy it saves. This way the cell
# no longer depends silently on leftover kernel state.
# NOTE(review): the file exists only after the normalization cell has been run once.
if "fitness_norm" not in globals():
    fitness_norm = torch.load("./../../data/PhoQ/esm-1b_fitness_norm.pt")

data = {
    "Variants": variants,
    "Fitness": fitness,
    "Fitness_norm": fitness_norm,
    # One CPU tensor per row; the tokens/embeddings may have been moved to `device`.
    "Sequences": list(tokenized_sequences.cpu()),
    "Embedding": list(embeddings.cpu()),
}
df = pd.DataFrame(data)

In [None]:
# Persist the combined PhoQ DataFrame (reloaded elsewhere with torch.load, so it is
# presumably pickled — only reload it from trusted files).
# NOTE(review): the recorded output below shows a GB1 path, so it predates this PhoQ path.
save_pt_file(
    df,
    save_to="./../../data/PhoQ/esm-1b_dataframe.pt",
    absolute_path=False,
)

saving to file "s:\Documents\master\code\llm./../../data/GB1/esm-1b_dataframe.pt"


Normalized Fitness

In [25]:
import pandas

# Raw string: the Windows path contains "\D", "\m", "\d", "\P", which are invalid escape
# sequences in a normal literal (a DeprecationWarning/SyntaxWarning in recent Pythons).
# Named phoq_raw_df so it does not clobber `df`, the combined DataFrame built above.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
phoq_raw_df = pandas.read_excel(r"S:\Documents\master\data\PhoQ\PhoQ.xlsx")
fitness = torch.from_numpy(phoq_raw_df["Fitness"].values)

In [26]:
# Min-max scale fitness to [0, 1]: subtract the minimum and divide by the range.
# (Dividing by the maximum alone only gives [0, 1] when the minimum happens to be 0.)
min_fitness = fitness.min()
max_fitness = fitness.max()
fitness_range = max_fitness - min_fitness
assert fitness_range > 0, "fitness is constant; cannot min-max normalize"
fitness_norm = (fitness - min_fitness) / fitness_range

In [27]:
# Persist raw and normalized fitness next to the other PhoQ ESM-1b artifacts.
for tensor_to_save, target_path in (
    (fitness, "./../../data/PhoQ/esm-1b_fitness.pt"),
    (fitness_norm, "./../../data/PhoQ/esm-1b_fitness_norm.pt"),
):
    save_pt_file(tensor_to_save, save_to=target_path, absolute_path=False)

saving to file "s:\Documents\master\code\llm./../../data/PhoQ/esm-1b_fitness.pt"
saving to file "s:\Documents\master\code\llm./../../data/PhoQ/esm-1b_fitness_norm.pt"


Sandbox

In [7]:
dataset = "PhoQ" # GB1, PhoQ
# All ESM-1b artifacts for the chosen dataset live in one directory.
dataset_dir = "./../../data/" + dataset + "/"
embedding = torch.load(dataset_dir + "esm-1b_embedding_complete.pt")
fitness = torch.load(dataset_dir + "esm-1b_fitness.pt")
fitness_norm = torch.load(dataset_dir + "esm-1b_fitness_norm.pt")
sequences = torch.load(dataset_dir + "esm-1b_sequences_complete.pt")
variants = torch.load(dataset_dir + "esm-1b_variants_complete.pt")

In [8]:
import pandas

In [9]:
# Export raw and normalized fitness as single-column CSVs without an index column.
fitness_df, fitness_norm_df = (
    pandas.DataFrame(values.cpu().numpy()) for values in (fitness, fitness_norm)
)
fitness_df.to_csv("./../../data/" + dataset + "/esm-1b_fitness.csv", index=False)
fitness_norm_df.to_csv("./../../data/" + dataset + "/esm-1b_fitness_norm.csv", index=False)

In [11]:
# One row per sequence, one column per embedding dimension; no index column.
embedding_csv_path = "./../../data/" + dataset + "/esm-1b_embedding_complete.csv"
embedding_df = pandas.DataFrame(embedding.cpu().numpy())
embedding_df.to_csv(embedding_csv_path, index=False)

In [12]:
# One row per sequence, one column per token id; no index column.
sequences_csv_path = "./../../data/" + dataset + "/esm-1b_sequences_complete.csv"
sequences_df = pandas.DataFrame(sequences.cpu().numpy())
sequences_df.to_csv(sequences_csv_path, index=False)

In [20]:
# One variant label per row; no index column.
# NOTE(review): unlike the other exports this file has no "_complete" suffix, although the
# variants were loaded from "esm-1b_variants_complete.pt" — confirm the intended name.
variants_csv_path = "./../../data/" + dataset + "/esm-1b_variants.csv"
variants_df = pandas.DataFrame(variants)
variants_df.to_csv(variants_csv_path, index=False)

In [14]:
# Read the exported token CSV back to check the round trip.
sequences_csv_path = "./../../data/" + dataset + "/esm-1b_sequences_complete.csv"
loaded_df = pandas.read_csv(sequences_csv_path)
loaded_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,478,479,480,481,482,483,484,485,486,487
0,0,20,15,15,4,4,10,4,18,18,...,10,16,21,8,5,14,15,13,9,2
1,0,20,15,15,4,4,10,4,18,18,...,10,16,21,8,5,14,15,13,9,2
2,0,20,15,15,4,4,10,4,18,18,...,10,16,21,8,5,14,15,13,9,2
3,0,20,15,15,4,4,10,4,18,18,...,10,16,21,8,5,14,15,13,9,2
4,0,20,15,15,4,4,10,4,18,18,...,10,16,21,8,5,14,15,13,9,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
159995,0,20,15,15,4,4,10,4,18,18,...,10,16,21,8,5,14,15,13,9,2
159996,0,20,15,15,4,4,10,4,18,18,...,10,16,21,8,5,14,15,13,9,2
159997,0,20,15,15,4,4,10,4,18,18,...,10,16,21,8,5,14,15,13,9,2
159998,0,20,15,15,4,4,10,4,18,18,...,10,16,21,8,5,14,15,13,9,2


In [3]:
# Reload the combined GB1 DataFrame. It is a pickled pandas object rather than tensors, so
# it needs weights_only=False: that argument's default became True in PyTorch 2.6, and
# with True this load is rejected. Unpickling can execute code — only load trusted,
# locally produced files.
df = torch.load("./../../data/GB1/esm-1b_dataframe.pt", weights_only=False)

In [7]:
# Shape of the loaded embedding matrix (rows = sequences, columns = embedding width).
embedding.shape

torch.Size([149361, 1280])

In [51]:
# Look up the row(s) for a single variant by its 4-letter label.
df[df["Variants"] == "ADGV"]

Unnamed: 0,Variants,Fitness,Fitness_norm,Sequences,Embedding
122551,ADGV,0.06191,0.007066,"[tensor(0), tensor(20), tensor(16), tensor(19)...","[tensor(-0.0198), tensor(0.1364), tensor(-0.01..."
