# Protein melting point prediction

## Imports and CUDA setup

In [None]:
import torch
from torch import nn as nn
import torch.backends.cudnn as cudnn
from thermostability.hotinfer import HotInferModel
from ipywidgets import widgets

# Let cuDNN benchmark several convolution algorithms and pick the fastest one;
# this pays off when input shapes stay the same between calls.
cudnn.benchmark = True

# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    # Release cached-but-unused GPU memory held by PyTorch's caching allocator
    # (e.g. left over from an earlier run in this kernel).
    torch.cuda.empty_cache()
else:
    device = torch.device("cpu")


## Load model

In [None]:
# Path to the serialized thermostability head — replace the placeholder before
# running. Presumably a whole pickled module (it is passed as `thermo_module`)
# rather than a state_dict; TODO confirm.
thermo_module_path = 'path'

# map_location: lets a checkpoint that was saved on a GPU load on a CPU-only
# machine (and vice versa) by remapping its tensors onto `device`.
# weights_only=False: PyTorch >= 2.6 defaults to weights_only=True, which
# refuses to unpickle arbitrary objects such as a full nn.Module. Unpickling
# can execute arbitrary code, so only load checkpoints from a trusted source.
thermo_module = torch.load(thermo_module_path, map_location=device, weights_only=False)
model = HotInferModel("s_s_avg", thermo_module=thermo_module, pad_representations=False, model_parallel=False).to(device)

# Inference only: disable training-time behaviour such as dropout.
# Assumes HotInferModel is a torch nn.Module (it supports .to(device)) — confirm.
model.eval()

## Sequence input

In [None]:
# Build a caption and a single-line text box; the prediction cell reads the
# entered sequence from `sequence.value`.
lbl1 = widgets.Label('Sequence to predict')
sequence = widgets.Text()

# Render the caption first, then the input box beneath it.
display(lbl1, sequence)

## Run prediction

In [None]:
# Longest sequence length (in residues) the model was trained on; longer
# inputs are outside the training distribution.
MAX_TRAINED_LENGTH = 700

# Normalise the raw widget text: drop all whitespace (e.g. from a pasted,
# wrapped sequence) and upper-case the one-letter residue codes.
sequence_str = "".join(sequence.value.split()).upper()

if not sequence_str:
    print('Please enter a sequence to predict')
else:
    if len(sequence_str) > MAX_TRAINED_LENGTH:
        print(f'Inference on sequences of a length of more than {MAX_TRAINED_LENGTH} amino acids can be inaccurate, since the model did not train on such sequences')
    # Feed the cleaned string (not the Text widget object) to the model as a
    # batch of one; no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        prediction = model([sequence_str]).item()
    # `.4f` = four decimal places (the old `4f` meant minimum width 4).
    print(f"Predicted melting point of {sequence_str}: \n{prediction:.4f}")