In [None]:
import numpy as np

import json

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

#import sepia

import jak
from jak.transformer import XFormer
import matplotlib.pyplot as plt

import rdkit
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole

from IPython.display import clear_output

In [None]:
# Lookup tables for the flat layout of the inhibition values:
# flat index = measurement_dict[type] * len(enzyme_dict) + enzyme_dict[kinase]
measurement_dict = {"pIC50": 0 , "pKi": 1}
enzyme_dict = {"JAK1": 0, "JAK2": 1, "JAK3": 2, "TYK2": 3}

def find_values(smile, df, verbose=True):
    """Collect the measured inhibition values recorded for one molecule.

    Parameters
    ----------
    smile : str
        SMILES string, compared for exact equality against ``df["SMILES"]``.
    df : pandas.DataFrame
        Table with the columns "SMILES", "measurement_type" (keys of
        ``measurement_dict``), "Kinase_name" (keys of ``enzyme_dict``) and
        "measurement_value".
    verbose : bool, optional
        When True (the default) print the first matching rows.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x_pic50, y_pic50, x_pki, y_pki)``. The ``x_*`` arrays hold 1-based
        kinase positions (JAK1 -> 1 ... TYK2 -> 4); the ``y_*`` arrays hold the
        matching measurement values. All four arrays are empty when the
        molecule does not occur in ``df``.

    Raises
    ------
    KeyError
        If a matching row has a measurement type or kinase name that is not a
        key of ``measurement_dict`` / ``enzyme_dict``.
    """
    df_smile = df.loc[df["SMILES"] == smile]

    if verbose:
        print(df_smile.head())

    # Translate each column once instead of rebuilding the lists per row.
    measurement_idx = np.array(
        [measurement_dict[name] for name in df_smile["measurement_type"]], dtype=int
    )
    enzyme_idx = np.array(
        [enzyme_dict[name] for name in df_smile["Kinase_name"]], dtype=int
    )
    y = df_smile["measurement_value"].to_numpy()

    # 1-based positions, so kinase k sits at x = k + 1 for both measurement types
    x = enzyme_idx + 1

    is_pic50 = measurement_idx == measurement_dict["pIC50"]
    is_pki = measurement_idx == measurement_dict["pKi"]

    return (x[is_pic50], y[is_pic50], x[is_pki], y[is_pki])
    
    

In [None]:
# Load the training measurements and pick a handful of molecules to offer
# as ready-made examples in the interactive prompt.
df = pd.read_csv("../data/train_JAK.csv")
unique_smiles = df["SMILES"].unique()

# Sample WITHOUT replacement so the example menu never lists the same
# molecule twice; cap the size so a small dataset does not raise.
n_examples = min(16, len(unique_smiles))
example_indices = np.random.choice(len(unique_smiles), size=n_examples, replace=False)
smiles_examples = unique_smiles[example_indices]

In [None]:
# Feature extractor: a pre-trained autoregressive transformer whose weights
# are restored from a checkpoint on disk.

parameters_fp = "../parameters/xformer_x003_seed42/tag_xformer_x003_seed42_epoch299.pt"
smiles_vocab = "#()+-1234567=BCFHINOPS[]cilnors"

seq_length = 100
token_dim = 33

# defined here but not passed to XFormer in this cell
encoder_size = 1
decoder_size = 1

# NOTE(review): these settings presumably have to match the ones the
# checkpoint was trained with -- confirm against the training config.
xformer = XFormer(
    vocab=smiles_vocab,
    token_dim=token_dim,
    seq_length=seq_length,
    lr=1e-3,
    device="cpu",
    tag="inference",
)

# map the stored tensors onto whichever device the model reports
model_state_dict = torch.load(parameters_fp, map_location=xformer.my_device)
xformer.load_state_dict(model_state_dict)

In [None]:
# Regression model: an ensemble ("cohort") of MLPs, rebuilt from the saved
# experiment settings and then loaded with trained weights.

kwarg_filepath = "../parameters/ensemble_004_seed13/exp_tag_ensemble_004_seed13.json"
parameters_filepath = "../parameters/ensemble_004_seed13/tag_ensemble_004_seed13_epoch4999.pt"

# experiment settings stored as JSON next to the checkpoint
with open(kwarg_filepath, "r") as kwarg_file:
    kwargs = json.load(kwarg_file)

# NOTE(review): 3300 equals seq_length * token_dim (100 * 33) from the
# feature-extractor cell -- presumably the flattened encoding size; confirm.
in_dim = 3300
# one output per (measurement type, kinase) pair: 2 * 4
out_dim = 8

ensemble = jak.mlp.MLPCohort(
    in_dim,
    out_dim,
    cohort_size=kwargs["cohort_size"],
    depth=kwargs["depth"],
    lr=kwargs["lr"],
)

ensemble_state_dict = torch.load(parameters_filepath, map_location=ensemble.my_device)
ensemble.load_state_dict(ensemble_state_dict)
                             

In [None]:
# Interactive loop: ask for a molecule, estimate its kinase inhibition with
# the ensemble, and plot the estimates next to any measured values in `df`.
do_estimate = True

# shared y-axis limits of both inhibition panels (original variable names kept)
min_x = 5.5
max_x = 13.5

# boxplot positions and their kinase labels; constant across iterations
ax_ticks = [1,2,3,4]
ax_ticklabels = ["JAK1", "JAK2", "JAK3", "TYK2"]
n_kinases = len(ax_ticklabels)

while do_estimate:

    msg = "example small molecules in SMILES format:\n\n"
    for ii, example in enumerate(smiles_examples):
        msg += f"  {ii}    {example}\n"

    print(msg)
    entry = input("enter a small molecule in SMILES format, or enter the index of an example above").strip()

    # Only an in-range number counts as an example index. Anything else is
    # taken as SMILES, so short molecules such as "C" or "CC" are accepted.
    if entry.isdecimal() and int(entry) < len(smiles_examples):
        smile = smiles_examples[int(entry)]
    else:
        smile = entry

    # MolFromSmiles returns None when it cannot parse the string; validate
    # before encoding so a typo re-prompts instead of ending the loop.
    mol = Chem.MolFromSmiles(smile) if smile else None
    if mol is None:
        print(f"\n could not interpret {entry!r} as an example index or a valid SMILES string, try again\n")
        continue

    print(f"\n SMILE format molecule: {smile}")

    ensemble.eval()
    # inference only: no autograd graph is needed, and detaching ensures the
    # conversion to numpy cannot fail on a tensor that requires grad
    with torch.no_grad():
        encoded = xformer.encode(smile).reshape(1, -1)
        estimates = ensemble.forward_ensemble(encoded).squeeze()

    # inhibition estimates in order pIC50 JAK1, JAK2, JAK3, TYK2; pKI JAK1, JAK2, JAK3, TYK2
    x = estimates.detach().cpu().numpy()

    x_pic50, y_pic50, x_pki, y_pki = find_values(smile, df)

    fig, ax = plt.subplots(2,2, figsize=(10,10))
    fig.suptitle(f"Kinase inhibition for {smile}", fontsize=17)

    # (axes, ensemble estimates, quantity label, measured x, measured y)
    panels = (
        (ax[1,0], x[:, :n_kinases], "pIC50", x_pic50, y_pic50),
        (ax[1,1], x[:, n_kinases:], "pKI", x_pki, y_pki),
    )
    for panel, panel_estimates, label, x_measured, y_measured in panels:
        panel.boxplot(panel_estimates)
        panel.set_xticks(ax_ticks)
        panel.set_xticklabels(ax_ticklabels)
        panel.set_ylabel(label)
        panel.set_title(f"{label} estimates")
        panel.axis([0,5, min_x, max_x])
        panel.scatter(x_measured, y_measured, label="measured value")
        panel.legend()

    mol_image = rdkit.Chem.Draw.MolsToImage([mol])
    ax[0,0].imshow(mol_image)
    # pixel ticks carry no meaning for the drawing; the top-right slot is unused
    ax[0,0].axis("off")
    ax[0,1].axis("off")

    plt.tight_layout()

    plt.show()
    # release the figure so repeated queries do not accumulate open figures
    plt.close(fig)

    do_estimate_string = input("Estimate for another molecule? (y)es/(n)o")

    # continue only on an explicit yes ("y", "Yes", ...); a plain substring
    # test would also accept answers such as "no way"
    do_estimate = do_estimate_string.strip().lower().startswith("y")
    clear_output(wait=True)