In [1]:
import torch
import esm

from lib.data.datasets.GB1 import get_GB1_dataset
from lib.utils.file import save_pt_file

In [3]:
import itertools
from lib.data.datasets.GB1 import prepare_sequences

In [12]:
# The 20 one-letter amino-acid codes; every ordered combination of four of them
# (20**4 strings) is one candidate variant.
symbols = 'ARNDCQEGHILKMFPSTWYV'
variants = list(map(''.join, itertools.product(symbols, repeat=4)))

In [13]:
# Turn the 4-letter variant strings into model input sequences using the project helper.
# NOTE(review): presumably this inserts each variant into the GB1 wild-type sequence —
# confirm in lib.data.datasets.GB1.
sequences = prepare_sequences(variants)

In [2]:
# Alternative input: one sequence in which four positions (three consecutive ones
# and one further along) are replaced by the tokenizer's "<mask>" token.
# Running this cell overwrites `sequences`/`variants` set by any other input cell.
# NOTE(review): the masked positions are presumably the four mutated GB1 sites — confirm.
sequences = ["MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNG<mask><mask><mask>EWTYDDATKTFT<mask>TE"]
variants = ["<mask><mask><mask><mask>"]

In [7]:
# Alternative input: each of the 20 amino-acid letters as its own one-residue sequence.
# `variants` is deliberately the very same list object, so each label equals its sequence.
sequences = list('ARNDCQEGHILKMFPSTWYV')
variants = sequences

In [2]:
# Checkpoints are downloaded to the torch hub cache in the user's home directory
# (".cache/torch/hub/checkpoints/"); the exact absolute location is machine-specific.
# Other checkpoints that were tried here: esm2_t6_8M_UR50D, esm2_t33_650M_UR50D, esm1b_t33_650M_UR50S
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()

# The batch converter maps (label, sequence) pairs to (labels, strings, token tensor).
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

# Use the GPU when torch can see one, otherwise stay on the CPU.
# NOTE: inputs passed to `model` must live on this same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [6]:
# Shortcut: reload the token tensor and variant labels written by the save cells
# further down instead of re-tokenizing.
# NOTE(review): torch.load unpickles the file contents — only load files produced locally.
tokenized_sequences = torch.load("./../../data/GB1/esm-1b_sequences_complete.pt")
variants = torch.load("./../../data/GB1/esm-1b_variants_complete.pt")

In [3]:
# Load the GB1 dataset through the project loader, without shuffling, and ask for
# the variant labels in addition to the sequences and fitness values.
sequences, fitness, variants = get_GB1_dataset(
    return_variants=True,
    shuffle=False,
)

  warn(msg)


In [66]:
# Pair every sequence with its variant label as (label, sequence) tuples, the input
# format of the batch converter, then tokenize everything in one call.
# Note that `variants` and `sequences` are rebound to the converter's outputs.
data = [(variants[idx], sequences[idx]) for idx in range(len(sequences))]
variants, sequences, tokenized_sequences = batch_converter(data)

#tokenized_sequences = tokenized_sequences.to(device)

In [62]:
# Tokenizer sanity check on three edge cases (see the recorded output):
#   "A C"       -> the space yields no token and the row is filled up with index 1,
#   "A<mask>C"  -> "<mask>" becomes a single token (index 32),
#   "ABC"       -> "B" is accepted and gets its own token (index 25).
# Every row starts with index 0 and the sequence is closed by index 2
# (presumably the begin/end-of-sequence tokens — confirm against the alphabet).
batch_converter([(None, "A C"), (None, "A<mask>C"), (None, "ABC")])

([None, None, None],
 ['A C', 'A<mask>C', 'ABC'],
 tensor([[ 0,  5, 23,  2,  1],
         [ 0,  5, 32, 23,  2],
         [ 0,  5, 25, 23,  2]]))

In [12]:
# Trial forward pass on the first 100 sequences, requesting the layer-33 representations.
# The tokens are moved to the device the model lives on: the model is sent to the GPU
# when one is available, and feeding it CPU tensors would raise a device-mismatch error.
# `.to()` is a no-op when the tensor is already on that device.
with torch.no_grad():
    batch_tokens = tokenized_sequences[0:100].to(next(model.parameters()).device)
    output = model(batch_tokens, repr_layers=[33])

In [8]:
# Extract one fixed-size embedding per sequence by mean-pooling the layer-33
# token representations over the sequence positions.
SEQUENCE_LENGTH = len(tokenized_sequences[0])  # tokens per row, first and last token included
EMBEDDING_LENGTH = 1280  # must equal the width of the layer-33 representation

# Run inference on whatever device the model's weights are on. The model is moved to
# the GPU when available, so CPU token tensors have to follow it.
model_device = next(model.parameters()).device

# The result buffer stays on the CPU so its size is not limited by GPU memory.
embeddings = torch.empty(
    (len(tokenized_sequences), EMBEDDING_LENGTH), dtype=torch.float32
)
with torch.no_grad():
    for s in range(len(tokenized_sequences)):
        tokens = tokenized_sequences[s : s + 1, :].to(model_device)
        # Only the representations are needed; contact prediction is not requested
        # because its output was never used and only adds compute.
        results = model(tokens, repr_layers=[33])
        representation = results["representations"][33]
        # Drop the first and last token position (presumably the begin/end-of-sequence
        # tokens — confirm) and average the remaining positions.
        representation = representation[0, 1 : SEQUENCE_LENGTH - 1].mean(0)
        embeddings[s, :] = representation.cpu()

embeddings.size()

torch.Size([160000, 1280])

In [9]:
# Persist the embedding matrix for the complete variant set.
# Per the logged output, with absolute_path=False the helper prefixes the target
# with a project base directory before saving.
save_pt_file(
    embeddings,
    save_to="./../../data/GB1/esm-1b_embedding_complete.pt",
    absolute_path=False,
)

saving to file "s:\Documents\master\code\llm./../../data/GB1/esm-1b_embedding_complete.pt"


In [14]:
# Persist the embedding of the masked input.
# NOTE: this writes whatever `embeddings` currently holds — run it only after the
# extraction cell has been executed on the masked sequence.
save_pt_file(
    embeddings,
    save_to="./../../data/GB1/esm-1b_embedding_masked.pt",
    absolute_path=False,
)

saving to file "s:\Documents\master\code\llm./../../data/GB1/esm-1b_embedding_masked.pt"


In [10]:
# Persist the token tensor of the masked input.
# NOTE(review): the recorded output of this cell names a different file
# (esm-1b_amino_acids.pt), i.e. the path was edited after the last run — re-run to refresh.
save_pt_file(
    tokenized_sequences,
    save_to="./../../data/GB1/esm-1b_sequences_masked.pt",
    absolute_path=False,
)

saving to file "s:\Documents\master\code\llm./../../data/esm-1b_amino_acids.pt"


In [5]:
import pandas as pd

# Collect labels, fitness values, tokens and embeddings in one table, one row per variant.
# `fitness_norm` is not created in this cell: it has to exist in the kernel already,
# either from the "Normalized Fitness" cells further down or from the disabled load below.
#fitness_norm = torch.load("./../../data/GB1/progen2_fitness_norm.pt")

# Split the 2-D tensors into per-row CPU tensors so that each DataFrame cell holds
# one token vector / one embedding vector.
sequences_cpu = tokenized_sequences.to("cpu")
embeddings_cpu = embeddings.to("cpu")

data = {
    "Variants": variants,
    "Fitness": fitness,
    "Fitness_norm": fitness_norm,
    "Sequences": list(sequences_cpu),
    "Embedding": list(embeddings_cpu),
}
df = pd.DataFrame(data)

In [6]:
# Persist the assembled DataFrame (labels, fitness, tokens, embeddings) in one file.
save_pt_file(
    df,
    save_to="./../../data/GB1/esm-1b_dataframe.pt",
    absolute_path=False,
)

saving to file "s:\Documents\master\code\llm./../../data/GB1/esm-1b_dataframe.pt"


Normalized Fitness

In [42]:
# Recover the raw fitness values from the DataFrame column as a torch tensor
# (torch.from_numpy shares memory with the underlying NumPy array).
fitness_values = df["Fitness"].to_numpy()
fitness = torch.from_numpy(fitness_values)

In [8]:
# Min-max normalise the fitness values to the range [0, 1]:
#     fitness_norm = (fitness - min) / (max - min)
# The divisor must be the value range (max - min). Dividing the shifted values by the
# un-shifted maximum only gives the same result when the minimum is exactly 0.
min_fitness, _ = torch.min(fitness, dim=0)
max_fitness, _ = torch.max(fitness, dim=0)
fitness_range = max_fitness - min_fitness

# Guard against a zero range (all values identical): divide by 1 instead, which
# maps every value to 0 rather than producing NaN.
safe_range = torch.where(
    fitness_range > 0, fitness_range, torch.ones_like(fitness_range)
)
fitness_norm = (fitness - min_fitness) / safe_range

In [9]:
# Persist the raw and the normalised fitness tensors, in that order.
for fitness_tensor, target_path in (
    (fitness, "./../../data/GB1/esm-1b_fitness.pt"),
    (fitness_norm, "./../../data/GB1/esm-1b_fitness_norm.pt"),
):
    save_pt_file(
        fitness_tensor,
        save_to=target_path,
        absolute_path=False,
    )

saving to file "s:\Documents\master\code\llm./../../data/GB1/esm-1b_fitness.pt"
saving to file "s:\Documents\master\code\llm./../../data/GB1/esm-1b_fitness_norm.pt"


In [26]:
# Recover the variant labels from the DataFrame column as an array
# (rebinds `variants`, which other cells set to a list).
variants = df["Variants"].values

In [12]:
# Persist the variant labels that belong to the complete embedding/sequence files.
# NOTE(review): the recorded output of this cell names a different file
# (amino_acids.pt), i.e. the path was edited after the last run — re-run to refresh.
save_pt_file(
    variants,
    save_to="./../../data/GB1/esm-1b_variants_complete.pt",
    absolute_path=False,
)

saving to file "s:\Documents\master\code\llm./../../data/amino_acids.pt"


In [48]:
# Write the (re)computed normalised fitness back into the existing DataFrame.
# NOTE: this mutates `df` in place; the file saved from `df` earlier is not updated
# until the DataFrame save cell is run again.
df["Fitness_norm"] = fitness_norm

Sandbox

In [2]:
# Sandbox: reload all previously saved artefacts from disk.
# Note the singular name `embedding` here versus `embeddings` in the extraction cells,
# and that `sequences` now holds whatever esm-1b_sequences.pt contains.
# NOTE(review): torch.load unpickles the file contents — only load files produced locally.
embedding = torch.load("./../../data/GB1/esm-1b_embedding.pt")
fitness = torch.load("./../../data/GB1/esm-1b_fitness.pt")
fitness_norm = torch.load("./../../data/GB1/esm-1b_fitness_norm.pt")
sequences = torch.load("./../../data/GB1/esm-1b_sequences.pt")
variants = torch.load("./../../data/GB1/esm-1b_variants.pt")

In [3]:
# Reload only the two fitness tensors (the same files the previous cell also loads).
fitness = torch.load("./../../data/GB1/esm-1b_fitness.pt")
fitness_norm = torch.load("./../../data/GB1/esm-1b_fitness_norm.pt")

In [2]:
import pandas

In [None]:
# Export the raw and normalised fitness tensors as CSV files without an index column.
# NOTE: this cell has no execution count, i.e. it was not run in the saved notebook.
fitness_df = pandas.DataFrame(fitness.cpu().numpy())
fitness_df.to_csv("./../../data/esm-1b_fitness.csv", index=False)
fitness_norm_df = pandas.DataFrame(fitness_norm.cpu().numpy())
fitness_norm_df.to_csv("./../../data/esm-1b_fitness_norm.csv", index=False)

In [12]:
# Export the embedding matrix as CSV, one row per sequence, without an index column.
# NOTE: reads `embeddings` from the extraction cell, not the `embedding` tensor that
# the sandbox load cell reads from disk.
embedding_df = pandas.DataFrame(embeddings.cpu().numpy())
embedding_df.to_csv("./../../data/esm-1b_embedding_complete.csv", index=False)

In [18]:
# Export the token tensor as CSV, one row per sequence and one column per token position.
# NOTE: the target file name refers to the single-amino-acid input, so
# `tokenized_sequences` should hold that input when this cell is run.
sequences_df = pandas.DataFrame(tokenized_sequences.cpu().numpy())
sequences_df.to_csv("./../../data/esm-1b_amino_acids.csv", index=False)

In [15]:
# Export the variant labels as a one-column CSV without an index column.
variants_df = pandas.DataFrame(variants)
variants_df.to_csv("./../../data/amino_acids.csv", index=False)

In [31]:
# Read a token CSV back in and display it to check the export round-trip.
# Per the recorded output: 58 token columns (0..57) and indices up to 159998.
loaded_df = pandas.read_csv("./../../data/GB1/esm-1b_sequences.csv")
loaded_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,48,49,50,51,52,53,54,55,56,57
0,0,20,16,19,15,4,12,4,17,6,...,5,11,15,11,18,11,5,11,9,2
1,0,20,16,19,15,4,12,4,17,6,...,5,11,15,11,18,11,10,11,9,2
2,0,20,16,19,15,4,12,4,17,6,...,5,11,15,11,18,11,17,11,9,2
3,0,20,16,19,15,4,12,4,17,6,...,5,11,15,11,18,11,13,11,9,2
4,0,20,16,19,15,4,12,4,17,6,...,5,11,15,11,18,11,23,11,9,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
159995,0,20,16,19,15,4,12,4,17,6,...,5,11,15,11,18,11,8,11,9,2
159996,0,20,16,19,15,4,12,4,17,6,...,5,11,15,11,18,11,11,11,9,2
159997,0,20,16,19,15,4,12,4,17,6,...,5,11,15,11,18,11,22,11,9,2
159998,0,20,16,19,15,4,12,4,17,6,...,5,11,15,11,18,11,19,11,9,2


In [3]:
# Reload the DataFrame written by the DataFrame save cell (rebinds `df`).
# NOTE(review): torch.load unpickles the file contents — only load files produced locally.
df = torch.load("./../../data/GB1/esm-1b_dataframe.pt")

In [7]:
# Show the dimensions of the embedding tensor that was loaded from disk.
embedding.shape

torch.Size([149361, 1280])

In [51]:
# Look up the row(s) belonging to one specific variant via a boolean mask.
is_selected_variant = df["Variants"] == "ADGV"
df[is_selected_variant]

Unnamed: 0,Variants,Fitness,Fitness_norm,Sequences,Embedding
122551,ADGV,0.06191,0.007066,"[tensor(0), tensor(20), tensor(16), tensor(19)...","[tensor(-0.0198), tensor(0.1364), tensor(-0.01..."
