In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
from torch.utils.data import random_split
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import copy
from thermostability.thermo_dataset import ThermostabilityDataset
from util.telegram import TelegramBot

cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.cuda.empty_cache() 
    
cpu = torch.device("cpu")

torch.cuda.list_gpu_processes()

telegramBot = TelegramBot()
telegramBot.enabled = True

In [2]:
trainSet = ThermostabilityDataset("data/train_sequences.fasta", max_seq_len=250)
valSet = ThermostabilityDataset("data/eval_sequences.fasta", max_seq_len=250)

dataloaders = {
    "train": torch.utils.data.DataLoader(trainSet, batch_size=16, shuffle=False, num_workers=4),
    "val": torch.utils.data.DataLoader(valSet, batch_size=16, shuffle=False, num_workers=4)
}

dataset_sizes = {"train": len(trainSet),"val": len(valSet)}
print(dataset_sizes)

Loading data from cache file:  data/train_sequences.fasta_cache.p
Loading data from cache file:  data/eval_sequences.fasta_cache.p
{'train': 34234, 'val': 3356}


In [3]:
import esm
import pickle
import os
esmfold = esm.pretrained.esmfold_v1()
esmfold.to(device)

rootDir = "data/s_s/"

def generateRepresenations(set):
    telegramBot.send_telegram(f"Generating s_s representations for set {set}...")
    setDir= os.path.join(rootDir, set)
    os.makedirs(setDir, exist_ok=True)
    response = telegramBot.send_telegram(f"Generating first batch...")
    messageId = response["result"]["message_id"]
    with open(os.path.join(setDir, "labels.csv"),"w") as csv:
        csv.write(f"filename, thermostability, sequence\n") 
        for index, (inputs, labels) in enumerate(dataloaders[set]):
            batch_size = len(inputs)
            with torch.no_grad():
                esm_output = esmfold.infer(sequences=inputs)
                s_s = esm_output["s_s"]
                #s_z = esm_output["s_z"]
                for s, data in enumerate(s_s):
                    file = str(index*batch_size+s)+".pt"
                    with open(os.path.join(setDir, file), "wb") as f:
                        torch.save(data.cpu(),f)
                        csv.write(f"{file}, {labels[s][0].item()}, {inputs[s]}\n") 
                    break

            if index %10 == 0:
                telegramBot.edit_text_message(messageId, f"Done with {index}/{len(dataloaders[set])} batches for {set}")
            break


try:
    for set in ["train", "val"]:
        generateRepresenations(set)
except Exception as e:
    telegramBot.send_telegram("Generation of representations failed with error message: "+str(e))

telegramBot.send_telegram(f"Doneinger!")

{'ok': True,
 'result': {'message_id': 356,
  'from': {'id': 5956605174,
   'is_bot': True,
   'first_name': 'hotprotbot',
   'username': 'hotprotbot'},
  'chat': {'id': -813132580,
   'title': 'HotProt Training',
   'type': 'group',
   'all_members_are_administrators': True},
  'date': 1673448228,
  'text': 'Doneinger!'}}