# Lobster-Gemini tutorial notebook
* Adapted from [here](https://code.roche.com/liny82/gemini_ranking_demo)

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
# (Keep comments on their own lines: text after a line magic is passed to it.)
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import pandas as pd
import numpy as np
import os
import math
from math import sqrt
import torch.nn.functional as F
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import seaborn as sns

from lobster.data import GeminiDataFrameLightningDataModule
from lobster.model import GeminiModel

from torchvision.transforms import Resize

# Allow TF32 for CUDA matmuls (see the PyTorch CUDA notes for the
# precision/speed trade-off); the flag is simply set, whatever device is used.
torch.backends.cuda.matmul.allow_tf32 = True

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Bare expression: displays the chosen device as the cell output.
device

In [None]:
# Apply the seaborn style and palette FIRST. `sns.set_style("ticks")` rewrites
# a set of matplotlib rcParams, including "axes.grid" and "grid.linestyle",
# so grid overrides made before it are silently discarded.
sns.set_style("ticks")  # Set default seaborn style
sns.set_palette("colorblind")  # Set default color palette

# Notebook-wide matplotlib defaults, layered on top of the seaborn style.
plt.rcParams.update(
    {
        "figure.figsize": (4, 3),  # Set default figure size
        "figure.dpi": 150,  # Set default figure dpi
        "font.size": 12,  # Set default font size
        "lines.linewidth": 1.5,  # Set default line width
        "axes.linewidth": 1.5,  # Set default axes line width
        "axes.grid": True,  # Show grid by default
        "grid.linestyle": "--",  # Set grid line style
    }
)

### Data prep
* Read a table containing `fv_heavy`, `fv_light`, and `pKD` columns into a dataframe, then apply whatever pre-processing and cleaning you'd like.

In [None]:
# Load the raw table and report how many rows it has.
df_COSMO = pd.read_csv('s3://prescient-data-dev/sandbox/liny82/C1/262_cosmo_clean.csv')
print(len(df_COSMO))

# Drop rows with a missing pKD value; the raw frame is left untouched.
df = df_COSMO.dropna(subset=['pKD'])
print(len(df))

# Only referenced by the (currently disabled) Chi2 criterion below.
FIT_QUALITY = 20.  # 20 -> only keep very good quality sensograms. Change it to 1 to be closer to what scientists accept as a binder

# Boolean row mask for the quality criteria. Named `quality_mask` rather than
# `filter` so the Python builtin `filter` is not shadowed for the rest of the
# kernel session. The commented-out terms are optional, stricter criteria.
quality_mask = (
    (df.Rmax >= 20)
    # & (df.Chi2 < df.Rmax/FIT_QUALITY)
    & (1e-11 < df.kinetics_kd)
    # & (df.kinetics_kd < 1e-4)
)

# Apply the mask once and report the number of rows that survive.
df = df[quality_mask]
print(len(df))

### Instantiate a datamodule from the dataframe

In [None]:
# Build the datamodule from the cleaned frame `df`, not the raw `df_COSMO`.
# Otherwise the NaN-pKD drop and the quality filter computed above are ignored
# and unfiltered rows reach the model.
dm = GeminiDataFrameLightningDataModule(
    data=df,
    batch_size=128,
)
dm.setup()

### Instantiate a Gemini model with a pre-trained encoder

In [None]:
# Instantiate the Gemini model, selecting the encoder by its checkpoint name.
model = GeminiModel(
    model_name='esm2_t6_8M_UR50D',
)

### Train the model

In [None]:
# Lightning trainer on the device selected earlier; a single epoch keeps the
# tutorial run short.
trainer = pl.Trainer(accelerator=device, max_epochs=1)

In [None]:
# Train `model` using the dataloaders supplied by the datamodule `dm`.
trainer.fit(model, datamodule=dm)

### Inference

In [None]:
# Re-run datamodule setup for the 'predict' stage before requesting
# predictions (what setup prepares for this stage is defined by the datamodule).
dm.setup(stage='predict')

In [None]:
# Run prediction over the datamodule through the trainer. The result gets its
# own name because `preds` is rebound to a single-batch result by the manual
# inference cell further down, which would otherwise discard this output.
all_batch_preds = trainer.predict(model, datamodule=dm)

In [None]:
# Grab the prediction dataloader for manual, batch-level inference.
predict_dataloader = dm.predict_dataloader()

# Put the model in evaluation mode before calling it by hand.
model.eval()

In [None]:
# Run the model by hand on the FIRST prediction batch only (index 0).
# Fetching it explicitly replaces a loop that always broke after one iteration,
# and raises StopIteration if the dataloader is empty rather than leaving
# `preds`/`targets` stale or undefined.
batch = next(iter(predict_dataloader))
with torch.inference_mode():
    preds, targets = model.predict_step(batch, 0)

In [None]:
# Points along the identity line y = x, drawn as a visual reference.
scale = np.linspace(-3, 3, 100)

# Scatter of predictions against ground truth on a square canvas.
fig, ax = plt.subplots(figsize=(3, 3))
sns.scatterplot(x=targets, y=preds, s=2, ax=ax)
ax.plot(scale, scale, "r")
ax.set_title("delta pKD")
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)
ax.set_xlabel("delta pKD (ground truth)")
ax.set_ylabel("Gemini output [delta pKD]")
plt.show()