In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
from torch.utils.data import random_split
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import copy
from thermostability.thermo_dataset import ThermostabilityDataset
from util.telegram import TelegramBot

cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.cuda.empty_cache() 
    
cpu = torch.device("cpu")

torch.cuda.list_gpu_processes()

telegramBot = TelegramBot()
telegramBot.enabled = True

In [2]:
trainSet = ThermostabilityDataset("data/train_sequences.fasta", max_seq_len=700, once_occuring_seq_only=True)
valSet = ThermostabilityDataset("data/eval_sequences.fasta", max_seq_len=700, once_occuring_seq_only=True)

dataloaders = {
    "train": torch.utils.data.DataLoader(trainSet, batch_size=2, shuffle=False, num_workers=4),
    "val": torch.utils.data.DataLoader(valSet, batch_size=2, shuffle=False, num_workers=4)
}

dataset_sizes = {"train": len(trainSet),"val": len(valSet)}
print(dataset_sizes)

Loading data from cache file:  data/train_sequences.fasta_v2_cache.p
Loading data from cache file:  data/eval_sequences.fasta_v2_cache.p
{'train': 15685, 'val': 995}


In [3]:
import esm
import pickle
import os
esmfold = esm.pretrained.esmfold_v1()
esmfold.to(device)



ESMFold(
  (esm): ESM2(
    (embed_tokens): Embedding(33, 2560, padding_idx=1)
    (layers): ModuleList(
      (0): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (out_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rot_emb): RotaryEmbedding()
        )
        (self_attn_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=2560, out_features=10240, bias=True)
        (fc2): Linear(in_features=10240, out_features=2560, bias=True)
        (final_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
      )
      (1): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
         

In [6]:
import time
import datetime
rootDir = "data/s_s/"


def generateRepresenations(set):
    timeStart = time.time()
    telegramBot.send_telegram(f"Generating s_s representations for set {set}...")
    setDir= os.path.join(rootDir, set)
    os.makedirs(setDir, exist_ok=True)
    response = telegramBot.send_telegram(f"Generating first batch...")
    messageId = response["result"]["message_id"]

    labelsFilePath = os.path.join(setDir, "labels.csv")
    if not os.path.exists(labelsFilePath):
        with open(labelsFilePath,"w") as csv:
            csv.write(f"filename, thermostability, sequence\n") 
    batchesPredicted = 0
    for index, (inputs, labels) in enumerate(dataloaders[set]):
        
        batch_size = len(inputs)
        print(f"At batch {index}/{dataset_sizes[set]}")
        with torch.no_grad():
            alreadyPredicted = True
            for i, seq in enumerate(inputs):
                file = str(index*batch_size+i)+".pt"
                filePath = os.path.join(setDir, file)
                
                if not os.path.exists(filePath):
                    alreadyPredicted = False
                    break
            
            if not alreadyPredicted:
                print(f"Predicting")
                esm_output = esmfold.infer(sequences=inputs)
                s_s = esm_output["s_s"]
                batchesPredicted +=1
                #s_z = esm_output["s_z"]
                with open(labelsFilePath,"a") as csv:
                    for s, data in enumerate(s_s):
                        file = str(index*batch_size+s)+".pt"

                        if not os.path.exists(file):
                            with open(os.path.join(setDir, file), "wb") as f:
                                torch.save(data.cpu(),f)
                            csv.write(f"{file}, {labels[s][0].item()}, {inputs[s]}\n") 
        if index %5 == 0:
            secsSpent = time.time()- timeStart  
            secsToGo = (secsSpent/(batchesPredicted+1))*(dataset_sizes[set]/batch_size-index-1)
            hoursToGo = secsToGo/(60*60)
            now = datetime.datetime.now()
            telegramBot.edit_text_message(messageId, f"Done with {index}/{len(dataloaders[set])} batches for {set} (hours to go: {int(hoursToGo)}) [last update: {now.hour}:{now.minute}]")
        


try:
    for set in ["train", "val"]:
        generateRepresenations(set)
except Exception as e:
    print("Exception raised: ", e)
    telegramBot.send_telegram("Generation of representations failed with error message: "+str(e))

telegramBot.send_telegram(f"Doneinger!")

At batch 0/15685
At batch 1/15685
At batch 2/15685
At batch 3/15685
Predicting
At batch 4/15685
Predicting


KeyboardInterrupt: 