# Protein melting point prediction
For proteins with an existing representation (those listed in [sequences.csv](../data/s_s_avg/sequences.csv)), this notebook can be run on a GPU with 9 GB or more of memory.
For novel proteins shorter than 700 amino acids, it has been tested on a GPU with roughly 40 GB of memory (Nvidia A40).
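Before running the notebook, you may want to check whether your GPU is large enough. The sketch below encodes the two figures stated above (9 GB when a precomputed representation exists, roughly 40 GB for novel sequences) in a small helper. The function and constant names are hypothetical, and treating the quoted figures as hard thresholds is an assumption.

```python
# Hypothetical helpers; the thresholds are the approximate figures quoted above.
CACHED_REPRESENTATION_GB = 9  # protein already listed in sequences.csv
NOVEL_SEQUENCE_GB = 40        # novel protein shorter than 700 residues (tested on an Nvidia A40)


def required_gpu_gb(has_cached_representation: bool) -> int:
    """Return the approximate GPU memory (in decimal GB) the notebook needs."""
    return CACHED_REPRESENTATION_GB if has_cached_representation else NOVEL_SEQUENCE_GB


def has_enough_memory(total_memory_bytes: int, has_cached_representation: bool) -> bool:
    """Compare a device's total memory in bytes against the documented requirement."""
    total_gb = total_memory_bytes / 1e9  # decimal gigabytes, as GPU vendors advertise them
    return total_gb >= required_gpu_gb(has_cached_representation)
```

On a CUDA machine, `torch.cuda.get_device_properties(0).total_memory` returns the device's total memory in bytes, which can be passed as `total_memory_bytes`.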

## Imports and CUDA setup

In [None]:
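import os
import sys

import torch
import torch.backends.cudnn as cudnn
from torch import nn
from ipywidgets import widgets

# Add the parent directory (assumed to be the repository root) to the import path
# so that the local `thermostability` package can be imported.
sys.path.append(os.path.dirname(os.getcwd()))
from thermostability.hotinfer import HotInferModel

# Let cuDNN benchmark and pick the fastest kernels for the input shapes it sees.
cudnn.benchmark = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    # Release unused memory held by PyTorch's CUDA caching allocator.
    torch.cuda.empty_cache()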

## Load models

In [None]:
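# map_location lets checkpoints saved on a GPU also load on a CPU-only machine.
# Note: PyTorch >= 2.6 defaults to weights_only=True; loading these fully pickled
# modules may then require weights_only=False (only do this for trusted files).
thermo_module_esm = torch.load("../data/pretrained/s_s_avg/model.pt", map_location=device)
thermo_module_esm.eval()  # make sure dropout/normalization layers run in inference mode
model_esm = HotInferModel(
    "s_s_avg",
    thermo_module=thermo_module_esm,
    pad_representations=False,
    model_parallel=False,
)

thermo_module_prott5 = torch.load("../data/pretrained/prott5_avg/model.pt", map_location=device)
thermo_module_prott5.eval()
model_prott5 = HotInferModel(
    "prott5_avg",
    thermo_module=thermo_module_prott5,
    pad_representations=False,
    model_parallel=False,
)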

## Run prediction

In [None]:
def infer_therostability(sequence):
    # Remove whitespace (e.g. spaces or line breaks from a pasted sequence) and normalize case.
    sequence_str = "".join(sequence.value.split()).upper()
    if len(sequence_str) == 0:
        print("Please provide a protein sequence for which to predict the thermostability.")
        return
    if len(sequence_str) > 700:
        print(
            "Inference on sequences longer than 700 amino acids can be inaccurate, "
            "since the models were not trained on such sequences.\n"
            "A CUDA out-of-memory error may also occur if the sequence is too long "
            "for the available GPU memory."
        )
    # Gradients are not needed for inference; disabling them saves GPU memory.
    with torch.no_grad():
        prediction_esm = model_esm([sequence_str]).item()
        prediction_prott5 = model_prott5([sequence_str]).item()
    print(
        f"Predicted melting point of {sequence_str}:\n"
        f"  Prediction (ESM): {prediction_esm:.4f}\n"
        f"  Prediction (ProtT5): {prediction_prott5:.4f}"
    )
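The input handling in `infer_therostability` could be made stricter. The following is a minimal sketch, assuming the models accept the 20 standard amino acids plus the ambiguity and rare codes X, B, Z, U and O (an assumption the project does not confirm). It removes whitespace left over from copy-pasting, rejects unexpected characters, and warns about sequences beyond the 700-residue training range. `normalize_sequence` is a hypothetical helper, not part of `thermostability`.

```python
import warnings

# Assumed alphabet: 20 standard amino acids plus X, B, Z, U and O.
VALID_RESIDUES = set("ACDEFGHIKLMNPQRSTVWYXBZUO")
MAX_TRAINED_LENGTH = 700  # length limit mentioned in the warning above


def normalize_sequence(raw: str) -> str:
    """Strip whitespace, uppercase, and validate a pasted protein sequence."""
    sequence = "".join(raw.split()).upper()
    if not sequence:
        raise ValueError("Please provide a protein sequence.")
    invalid = sorted(set(sequence) - VALID_RESIDUES)
    if invalid:
        raise ValueError(f"Unexpected characters in sequence: {''.join(invalid)}")
    if len(sequence) > MAX_TRAINED_LENGTH:
        warnings.warn(
            f"Sequence has {len(sequence)} residues; the models were not "
            f"trained on sequences longer than {MAX_TRAINED_LENGTH}."
        )
    return sequence
```

For example, `normalize_sequence("mkt ayiak\nqr")` returns `"MKTAYIAKQR"`. Raising an exception instead of printing lets callers decide whether to skip or report invalid input.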

You can try the inference on an example protein, UniProt entry Q6ZWK4: copy its sequence from the link below into the text box and run the last cell.

Link: https://www.uniprot.org/uniprotkb/Q6ZWK4/entry#sequences
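Instead of copying the sequence by hand, it can also be downloaded. The sketch below assumes UniProt's REST endpoint `https://rest.uniprot.org/uniprotkb/{accession}.fasta`, which returns an entry in FASTA format, and reads the sequence of the first record. `fetch_uniprot_sequence` and `parse_fasta` are hypothetical helpers, and the download requires network access.

```python
from urllib.request import urlopen

# Assumed UniProt REST endpoint returning a single entry as FASTA.
UNIPROT_FASTA_URL = "https://rest.uniprot.org/uniprotkb/{accession}.fasta"


def parse_fasta(text: str) -> str:
    """Return the sequence of the first record in a FASTA string."""
    lines = text.strip().splitlines()
    if not lines or not lines[0].startswith(">"):
        raise ValueError("Input is not in FASTA format.")
    residues = []
    for line in lines[1:]:
        if line.startswith(">"):  # stop at the next record
            break
        residues.append(line.strip())
    return "".join(residues)


def fetch_uniprot_sequence(accession: str) -> str:
    """Download a UniProtKB entry as FASTA and return its sequence."""
    url = UNIPROT_FASTA_URL.format(accession=accession)
    with urlopen(url, timeout=30) as response:
        return parse_fasta(response.read().decode("utf-8"))
```

Once the text box below has been created, `sequence.value = fetch_uniprot_sequence("Q6ZWK4")` would fill it with the example protein.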

In [None]:
from IPython.display import display

lbl1 = widgets.Label("Sequence to predict")
display(lbl1)
sequence = widgets.Text(placeholder="Paste an amino acid sequence")
display(sequence)

In [None]:
infer_therostability(sequence)