# Reorganize representations

In [1]:
# filename, thermostability, sequence
import os
seqToFile = {}

def addToEntries(set):
    with open(f"data/s_s/{set}/labels.csv", "r") as f:
        firstLine = True

        for line in f:
            if firstLine:
                firstLine = False
                continue
            filename, thermostability, sequence = line.split(", ")
            sequence = sequence.replace("\n", "")
            seqToFile[sequence] = f"{set}/{filename}"
addToEntries("train")
addToEntries("val")

In [2]:
next(iter(seqToFile.items()))

('MSGEEEKAADFYVRYYVGHKGKFGHEFLEFEFRPNGSLRYANNSNYKNDTMIRKEATVSESVLSELKRIIEDSEIMQEDDDNWPEPDKIGRQELEILYKNEHISFTTGKIGALADVNNSKDPDGLRSFYYLVQDLKCLVFSLIGLHFKIKPI',
 'train/0.pt')

# Merge val and train set

In [None]:
import shutil
with open("data/s_s/sequences.csv", "w") as f:
    f.write("sequence, filename\n")
    for index, (seq, filename) in enumerate(seqToFile.items()):
        sourcePath = os.path.join("data/s_s", filename)
        targetFilename = f"{index}.pt"
        targetPath = f"data/s_s/{targetFilename}"

        shutil.copyfile(sourcePath, targetPath)
        f.write(f"{seq}, {targetFilename}\n")
        print(f"Done with {index}/{len(seqToFile.items())}", end="\r")

# Create train/eval metadata files

In [None]:
esm_eval_ids = set()
with open('data/uniref201803_ur50_valid_headers.txt') as txt_file:
    for line in txt_file:
        id = line.split('_')[1].replace('\n','')
        esm_eval_ids.add(id)

def readFasta(filepath='data/full_dataset_sequences.fasta'):
    first = True
    max =0
    dataset = []
    with open(filepath) as fasta:
        for line in fasta:
            if line[0] == '>':
                if first:
                    first = False
                else:
                    dataset.append(entry)
                entry = {}
                header_tokens = line.split(' ')
                entry['id'] = header_tokens[0].replace('>','').split('_')[0]
                entry['header'] = line.replace('\n', '')
                entry['temp'] = float(header_tokens[1].split('=')[1].replace('\n',''))
                entry['sequence'] = ''
            else:
                entry['sequence'] = entry['sequence'] + line.replace('\n','')
                max = len(entry['sequence']) if len(entry['sequence'])> max else max
    
    return dataset


evalDs = []
trainUnfilteredDs = []
trainIds = set()
allIds = set()
dataset = readFasta()
for entry in dataset:
    seq = entry['sequence']
    id = entry["id"]
    allIds.add(id)
    if id in esm_eval_ids:
        evalDs.append(entry)
    else: 
        trainUnfilteredDs.append(entry)
        trainIds.add(id)

# Filter train ds
clusters = {}
with open("data/meltome_PIDE20_clusters.tsv", "r") as f:
    firstLine = True
    for line in f:
        if firstLine:
          firstLine = False
          continue   
        clusterId, proteinId = line.replace("\n", "").split("\t")
        proteinId = proteinId.split('_')[0]
        
        if proteinId in allIds:
            if clusterId in clusters:
                clusters[clusterId].add(proteinId)
            else: 
                clusters[clusterId] = set([proteinId])


for clusterId, proteinids in clusters.items():
  numTrain = 0
  numEval = 0
  for proteinId in proteinids:
    if proteinId in esm_eval_ids:
      numEval += 1
    else: 
      numTrain += 1
    
    if numEval > 0 and numTrain>0:
        for proteinId in proteinids:
            if proteinId in trainIds:
                trainIds.remove(proteinId)

trainDs = [item for item in trainUnfilteredDs if item["id"] in trainIds]


def storeMetadata(ds,set: str):
    with open(f"data/{set}.csv", "w") as f:
        f.write("sequence, melting point\n")
        for entry in ds:
            f.write(f'{entry["sequence"]}, {entry["temp"]}\n')

storeMetadata(trainDs, "train")
storeMetadata(evalDs, "val")

# Generate missing representations
## Find missing sequences

In [3]:
maxSeqLen = 700

requiredSequences = set()
def addSequencesFromMetadataFile(set):
    with open(f"data/{set}.csv", "r") as f:
        firstLine = True
        for line in f:
            if firstLine:
                firstLine = False
                continue   
            sequence, _ = line.replace("\n", "").split(", ")
            if len(sequence) <=maxSeqLen:
                requiredSequences.add(sequence)


addSequencesFromMetadataFile("val")
addSequencesFromMetadataFile("train")

generatedSequences = set()

with open(f"data/s_s/sequences.csv", "r") as f:
    firstLine = True
    for line in f:
        if firstLine:
            firstLine = False
            continue   
        sequence, _ = line.replace("\n", "").split(", ")
        
        generatedSequences.add(sequence)

print("Required sequence sample", next(iter(requiredSequences)))
print("Generated sequence sample", next(iter(generatedSequences)))
print("Required sequences", len(requiredSequences))
print("Generated sequences", len(generatedSequences))
print("Remaining sequences to be generated: ", len(requiredSequences.difference(generatedSequences)))
print("Generated but not required sequences: ", len(generatedSequences.difference(requiredSequences)))


remainingSequences = requiredSequences.difference(generatedSequences)

Required sequence sample MIKLFSLKQQKKDEESAGGPRAGGGGKKASAAQLRIQKDINELNLPKTCEIVFPDQDDLLNFKLIISPDEGFYKGGKFVFSFKGRVIRTTPPKSSVKRWFITPTLTWRETSV
Generated sequence sample MTPSTPPRSRGTRYLAQPSGNTSSSALMQGQKTPQKPSQNLVPVTPSTTKSFKNAPLLAPPNSNMGMTSPFNGLTSPQRSPFPKSSVKRTLFQFESHDNGTVREEQEPLGRVNRILFPTQQNVDIDAAEEEEEGEVLLPPSRPTSARQLHLSLERDEFDQTHRKKIIKDVPGTPSDKVITFELAKNWNNNSPKNDARSQESEDEEDIIINPVRVGKNPFASDELVTQEIRNERKRAMLRENPDIEDVITYVNKKGEVVEKRRLTDEEKRRFKPKALFQSRDQEH
Required sequences 21713
Generated sequences 16674
Remaining sequences to be generated:  8277
Generated but not required sequences:  3238


In [4]:
max([len(seq) for seq in remainingSequences])

700

## Generate missing sequences
### Imports

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
from torch.utils.data import random_split
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import copy
from thermostability.thermo_dataset import ThermostabilityDataset
from util.telegram import TelegramBot
import esm
import pickle
cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.cuda.empty_cache() 
    
cpu = torch.device("cpu")

torch.cuda.list_gpu_processes()

telegramBot = TelegramBot()
telegramBot.enabled = True

In [6]:
esmfold = esm.pretrained.esmfold_v1()
esmfold.to(device)

ESMFold(
  (esm): ESM2(
    (embed_tokens): Embedding(33, 2560, padding_idx=1)
    (layers): ModuleList(
      (0): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (out_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rot_emb): RotaryEmbedding()
        )
        (self_attn_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=2560, out_features=10240, bias=True)
        (fc2): Linear(in_features=10240, out_features=2560, bias=True)
        (final_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
      )
      (1): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
         

### Define sequences DS for convenience


In [7]:
from torch.utils.data import Dataset
class SequencesDataset(Dataset):
    def __init__(self, sequences: "set[str]") -> None:
        super().__init__()
        self.sequences = list(sequences)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, index):
        return self.sequences[index]

### Run generation

In [8]:
import time
import datetime
rootDir = "data/s_s/"

mergedDsSeqToFile = {}
labelsFilePath = "data/s_s/sequences.csv"
with open(labelsFilePath, "r") as f:
    firstLine = True
    for line in f:
        if firstLine:
            firstLine = False
            continue
        sequence, filename = line.replace("\n", "").split(", ")
        mergedDsSeqToFile[sequence] = filename

def generateRepresenations(sequences):
    ds = SequencesDataset(sequences)
    loader = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False, num_workers=0)
    timeStart = time.time()
    telegramBot.send_telegram(f"Generating remaining s_s representations for {len(sequences)} sequences")
    response = telegramBot.send_telegram(f"Generating first batch...")
    messageId = response["result"]["message_id"]

    
    if not os.path.exists(labelsFilePath):
        with open(labelsFilePath,"w") as csv:
            csv.write(f"sequence, filename\n") 

    maxFilePrefix = 0
    for seq, file in mergedDsSeqToFile.items():
        prefix = int(file.split(".")[0])
        if prefix > maxFilePrefix:
            maxFilePrefix = prefix

    print(f"Starting with maxFilePrefix {maxFilePrefix}")
    batchesPredicted = 0

    for index, (inputs) in enumerate(loader):
        
        batch_size = len(inputs)
        numBatches = int(len(sequences) / batch_size)
        print(f"At batch {index}/{numBatches}")
        with torch.no_grad():
            alreadyPredicted = True
            for i, seq in enumerate(inputs):
                if seq not in mergedDsSeqToFile:
                    alreadyPredicted = False
                    break
            
            if not alreadyPredicted:
                print(f"Predicting")
                esm_output = esmfold.infer(sequences=inputs)
                s_s = esm_output["s_s"]
                batchesPredicted +=1
                #s_z = esm_output["s_z"]
                with open(labelsFilePath,"a") as csv:
                    for s, data in enumerate(s_s):
                        maxFilePrefix+=1
                        file = str(maxFilePrefix)+".pt"

                        if not os.path.exists(file):
                            with open(os.path.join("data/s_s", file), "wb") as f:
                                torch.save(data.cpu(),f)
                            csv.write(f"{inputs[s]}, {file}\n") 
        if index %5 == 0:
            secsSpent = time.time()- timeStart  
            secsToGo = (secsSpent/(batchesPredicted+1))*(numBatches-index-1)
            hoursToGo = secsToGo/(60*60)
            now = datetime.datetime.now()
            telegramBot.edit_text_message(messageId, f"Done with {index}/{numBatches} batches (hours to go: {int(hoursToGo)}) [last update: {now.hour}:{now.minute}]")
        


try:
    generateRepresenations(remainingSequences)
except Exception as e:
    print("Exception raised: ", e)
    telegramBot.send_telegram("Generation of representations failed with error message: "+str(e))

telegramBot.send_telegram(f"Doneinger!")

Starting with maxFilePrefix 16673
At batch 0/4138
Predicting
At batch 1/4138
Predicting
At batch 2/4138
Predicting


KeyboardInterrupt: 

In [6]:
import csv
import os
import torch
repr_root = "../data/s_s"

with open(os.path.join(repr_root,"sequences.csv"), "r") as f:
    reader = csv.reader(f, delimiter=',', skipinitialspace=True)
    seqToFile = dict([(seq, os.path.join(repr_root,file_name)) for i, (seq, file_name) in enumerate(reader) if i!=0])

avg_repr_root = "../data/s_s_avg"
if not os.path.exists(avg_repr_root):
    os.mkdir(avg_repr_root)
with open(os.path.join(avg_repr_root,"sequences.csv"), "w") as f:
    f.write("sequence, filename\n")
    for (seq, repr_file_path) in seqToFile.items():
        s_s = torch.load(repr_file_path)
        avg_repr_f_name = os.path.basename(repr_file_path)
        torch.save(torch.mean(s_s, 0), os.path.join(avg_repr_root,avg_repr_f_name))
        f.write(f"{seq}, {avg_repr_f_name}\n")