## a-AlphaBio homework 
### Mark Thompson
### Started April 29, 2024 

In [None]:
%load_ext autoreload

In [None]:
%autoreload
# import libraries
import numpy as np
import pickle as pk
import pandas as pd
import math
import os
import matplotlib.pyplot as plt
%matplotlib inline


In [None]:
import torch
import torch.nn as nn

# Some plotting functions
#
def plot_preds_hist(preds_file_path):
    """Plot a 100-bin histogram of the predictions stored in a pickle file.

    Prints the number of loaded entries and the first ten extracted values,
    then draws and shows the histogram.

    Args:
        preds_file_path: Path to a pickle file holding a sequence of
            predictions. Element 0 of every entry is the value plotted.

    Raises:
        OSError: If the file cannot be opened (e.g. the path does not exist).
    """
    # NOTE: pickle.load can execute arbitrary code; only point this at
    # result files produced by this project.
    # The context manager guarantees the file handle is closed.
    with open(preds_file_path, 'rb') as preds_file:
        preds = pk.load(preds_file)
    print('len(preds):', len(preds))
    # Each entry is indexable; keep only its first element.
    preds = [p[0] for p in preds]
    print('preds[0:10]:', preds[0:10])

    # Histogram of predicted values
    plt.hist(preds, bins=100)
    plt.xlabel('pred Kd (nm)')
    plt.ylabel('count')
    plt.title('Distribution of pred values set')
    plt.show()


def plot_pred_vs_true(preds_file_path, true_file_path, xlim=(0,5), ylim=(0,5)):
    """Scatter-plot predicted values against true values loaded from pickles.

    Prints the length of both loaded sequences, then draws a blue scatter
    plot with the true values on the x-axis and the predictions on the
    y-axis.

    Args:
        preds_file_path: Path to a pickle file holding a sequence of
            predictions. Element 0 of every entry is plotted.
        true_file_path: Path to a pickle file holding the matching sequence
            of true values. Element 0 of every entry is plotted.
        xlim: x-axis limits, passed to ``plt.xlim``.
        ylim: y-axis limits, passed to ``plt.ylim``.

    Raises:
        OSError: If either file cannot be opened.
    """
    # NOTE: pickle.load can execute arbitrary code; only point this at
    # result files produced by this project.
    # Context managers guarantee both file handles are closed.
    with open(preds_file_path, 'rb') as preds_file:
        preds = pk.load(preds_file)
    with open(true_file_path, 'rb') as true_file:
        y = pk.load(true_file)
    print('len(preds):', len(preds), ', len(y):', len(y))
    # Each entry is indexable; keep only its first element.
    preds = [p[0] for p in preds]
    y = [a[0] for a in y]

    # scatter plot of true vs pred
    plt.scatter(y, preds, c ="blue")
    plt.xlabel('experimental Kd (nm)')
    plt.ylabel('predicted Kd (nm)')
    plt.title('true vs predicted Kd on validation set')
    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.show()



---------
## Holdout dataset and predictions

In [None]:
# Load the hold-out CSV and report its row count, its column names and a
# summary of the 'sequence_a' column.
data_file = './data/alphaseq_data_hold_out.csv'
df = pd.read_csv(data_file)
rows1 = len(df)
print(f'holdout dataframe has {rows1} rows')
print(list(df.columns))
print(df.loc[:, 'sequence_a'].describe())

In [None]:
# Compare the hold-out predictions ('pred_Kd' column) of the two
# transformer models: print a summary of each CSV, then overlay their
# histograms.

# tform_mlp version 1 predictions
data_file = './inference_results/tform_mlp_model/cleaned-4-data/preds_tform_mlp_1715104590.5575511.csv'
df = pd.read_csv(data_file)
rows1 = df.shape[0]
print('holdout predictions has', rows1, 'rows')
print(df.columns.tolist())
print(df.describe())
preds = df['pred_Kd'].values

# tform_mlp_v2 version 2 predictions
data_file = './inference_results/tform_mlp_model_v2/addendum/cleaned-4b-data/preds_tform_mlp_1715280172.2843447.csv'
df = pd.read_csv(data_file)
rows1 = df.shape[0]
print('holdout predictions has', rows1, 'rows')
print(df.columns.tolist())
print(df.describe())
preds_v2 = df['pred_Kd'].values

# Histogram of predicted values.
# Both series share one set of 100 bin edges so their bar heights are
# directly comparable, and are drawn semi-transparent so the series plotted
# first is not hidden behind the second.
hist_bins = np.histogram_bin_edges(np.concatenate([preds, preds_v2]), bins=100)
plt.hist(preds, bins=hist_bins, alpha=0.6, label='Transformer v1', color='b')
plt.hist(preds_v2, bins=hist_bins, alpha=0.6, label='Transformer v2', color='r')
plt.legend(loc='upper left')

plt.xlabel('pred Kd (nm)')
plt.ylabel('count')
plt.title('Distribution of pred Kd values on the holdout set')
plt.xlim((-1,5))
plt.show()


----
## alphaseq_data_train dataset  (not cleaned)

In [None]:
# Raw (uncleaned) training data: print its row count, column names and
# summary statistics, then plot the distribution of its 'Kd' column.
data_file = './data/alphaseq_data_train.csv'
df = pd.read_csv(data_file)
rows1 = len(df)
print(f'dataset has {rows1} rows')
print(list(df.columns))
print(df.describe())

# Kept in a module-level name so other cells can reuse the raw values.
raw_Kds = df['Kd'].values

# 100-bin histogram on a 3x3 inch figure, x-axis limited to [-1, 5].
plt.figure(figsize=(3, 3))
plt.hist(raw_Kds, bins=100)
plt.gca().set(
    xlabel='Experimental Kd (nm)',
    ylabel='count',
    title='Distribution of Kd values in the raw training data set',
    xlim=(-1, 5),
)
plt.show()


----
## Kd distribution for the clean-4 dataset (training set only)

In [None]:
# Cleaned training data: print its row count, column names and summary
# statistics, then overlay its 'Kd' distribution on the raw training data.
#
# Requires `raw_Kds`, which is defined by the raw training data cell
# earlier in this notebook.
# The earlier cleaned file is './data/q_cleaned_4_train_set.csv'; the '4b'
# file is the one loaded here.
data_file = './data/q_cleaned_4b_train_set.csv'
df = pd.read_csv(data_file)
rows1 = df.shape[0]
print('dataset has', rows1, 'rows')
print(df.columns.tolist())
print(df.describe())

clean4_Kds = df['Kd'].values

# Histogram of Kd values.
# Both series share one set of 100 bin edges so their bar heights are
# directly comparable, and are drawn semi-transparent so the series plotted
# first is not hidden behind the second.
hist_bins = np.histogram_bin_edges(np.concatenate([raw_Kds, clean4_Kds]), bins=100)
plt.hist(raw_Kds, bins=hist_bins, alpha=0.6, label='raw data', color='b')
plt.hist(clean4_Kds, bins=hist_bins, alpha=0.6, label='clean-4 data', color='r')
plt.legend(loc='upper left')
plt.xlabel('Experimental Kd (nm)')
plt.ylabel('count')
plt.title('Distribution of experimental Kd values')
plt.xlim((-1,5))
plt.show()


----
### MLP model  Clean-3b dataset

In [None]:
# No prediction pickle is configured for this histogram. Set the path below
# to enable it; with an empty path the plot is skipped, because open('')
# would raise and stop a Run All.
pred_file_path = ''
if pred_file_path:
    plot_preds_hist(pred_file_path)
else:
    print('pred_file_path is not set; skipping the prediction histogram')

In [None]:
# True vs. predicted scatter for the MLP model results; both axes are
# limited to [0, 3.5].
pred_file_path = './inference_results/mlp_model/cleaned-3/test_no_cls_token/preds_mlp_1714982629.5918856.pkl'
true_file_path = './inference_results/mlp_model/cleaned-3/test_no_cls_token/y_mlp_1714982629.5921109.pkl'
plot_pred_vs_true(
    preds_file_path=pred_file_path,
    true_file_path=true_file_path,
    xlim=(0, 3.5),
    ylim=(0, 3.5),
)

In [None]:
# Pearson correlation coefficient between predicted and true values
# (MLP model test results).
#
# Paths are relative to the notebook's working directory, like the other
# './test_results/...' paths in this notebook, instead of being tied to one
# user's home directory.
pred_file_path = './test_results/mlp_model/cleaned-3b-data/preds_mlp_1715105115.4793816.pkl'
true_file_path = './test_results/mlp_model/cleaned-3b-data/y_mlp_1715105115.4793816.pkl'

# NOTE: pickle.load can execute arbitrary code; only load trusted files.
# Context managers make sure both file handles are closed.
with open(pred_file_path, 'rb') as pred_file:
    pred = torch.tensor(pk.load(pred_file)).squeeze()
with open(true_file_path, 'rb') as true_file:
    true = torch.tensor(pk.load(true_file)).squeeze()

print(pred.shape)
print(true.shape)

# Row 0 holds the predictions, row 1 the true values; torch.corrcoef treats
# each row as one variable.
c = torch.stack((pred, true), dim=0)
print(c.shape)

# Correlation matrix; assuming pred and true are 1-D after squeeze(), it is
# 2x2 and the off-diagonal entries are the Pearson coefficient.
p = torch.corrcoef(c)
print(p)


-------
### Vision Transformer Model (ViT)  1-channel Clean-3b Dataset

In [None]:
# No prediction pickle is configured for this histogram. Set the path below
# to enable it; with an empty path the plot is skipped, because open('')
# would raise and stop a Run All.
pred_file_path = ''
if pred_file_path:
    plot_preds_hist(pred_file_path)
else:
    print('pred_file_path is not set; skipping the prediction histogram')

In [None]:
# Actual vs. predicted scatter from inference on the validation set
# (ViT results under the 'BW' directory); both axes are limited to [0, 3.5].
pred_file_path = './inference_results/vit_model/cleaned-3b/BW/test_no_cls_token/preds_vit_1715016846.793841.pkl'
true_file_path = './inference_results/vit_model/cleaned-3b/BW/test_no_cls_token/y_vit_1715016846.7940617.pkl'
plot_pred_vs_true(
    preds_file_path=pred_file_path,
    true_file_path=true_file_path,
    xlim=(0, 3.5),
    ylim=(0, 3.5),
)

In [None]:
# Pearson correlation coefficient between predicted and true values
# (ViT model, 1-channel test results).
#
# Paths are relative to the notebook's working directory, like the other
# './test_results/...' paths in this notebook, instead of being tied to one
# user's home directory.
pred_file_path = './test_results/vit_model/cleaned-3b-data/1-channel/preds_vit_1715105283.6443539.pkl'
true_file_path = './test_results/vit_model/cleaned-3b-data/1-channel/y_vit_1715105283.6443539.pkl'

# NOTE: pickle.load can execute arbitrary code; only load trusted files.
# Context managers make sure both file handles are closed.
with open(pred_file_path, 'rb') as pred_file:
    pred = torch.tensor(pk.load(pred_file)).squeeze()
with open(true_file_path, 'rb') as true_file:
    true = torch.tensor(pk.load(true_file)).squeeze()

print(pred.shape)
print(true.shape)

# Row 0 holds the predictions, row 1 the true values; torch.corrcoef treats
# each row as one variable.
c = torch.stack((pred, true), dim=0)
print(c.shape)

# Correlation matrix; assuming pred and true are 1-D after squeeze(), it is
# 2x2 and the off-diagonal entries are the Pearson coefficient.
p = torch.corrcoef(c)
print(p)


----
### Vision Transformer Model (ViT)  3-channel, Clean-3b Dataset

In [None]:
# Actual vs. predicted scatter from inference on the validation set
# (ViT results under the 'BGR' directory); both axes are limited to [0, 3.5].
pred_file_path = './inference_results/vit_model/cleaned-3b/BGR/test_no_cls_token/preds_vit_1715020986.6452327.pkl'
true_file_path = './inference_results/vit_model/cleaned-3b/BGR/test_no_cls_token/y_vit_1715020986.6455376.pkl'
plot_pred_vs_true(
    preds_file_path=pred_file_path,
    true_file_path=true_file_path,
    xlim=(0, 3.5),
    ylim=(0, 3.5),
)

In [None]:
# Pearson correlation coefficient between predicted and true values
# (ViT model, 3-channel test results).
#
# Paths are relative to the notebook's working directory, like the other
# './test_results/...' paths in this notebook, instead of being tied to one
# user's home directory.
pred_file_path = './test_results/vit_model/cleaned-3b-data/3-channel/preds_vit_1715105421.8529866.pkl'
true_file_path = './test_results/vit_model/cleaned-3b-data/3-channel/y_vit_1715105421.8529866.pkl'

# NOTE: pickle.load can execute arbitrary code; only load trusted files.
# Context managers make sure both file handles are closed.
with open(pred_file_path, 'rb') as pred_file:
    pred = torch.tensor(pk.load(pred_file)).squeeze()
with open(true_file_path, 'rb') as true_file:
    true = torch.tensor(pk.load(true_file)).squeeze()

print(pred.shape)
print(true.shape)

# Row 0 holds the predictions, row 1 the true values; torch.corrcoef treats
# each row as one variable.
c = torch.stack((pred, true), dim=0)
print(c.shape)

# Correlation matrix; assuming pred and true are 1-D after squeeze(), it is
# 2x2 and the off-diagonal entries are the Pearson coefficient.
p = torch.corrcoef(c)
print(p)


----
### TFormMLP Clean-4 Dataset 

In [None]:
# True vs. predicted scatter for the tform_mlp model test results; both
# axes are limited to [0, 3.5].
pred_file_path = './test_results/tform_mlp_model/cleaned-4-data/preds_tform_mlp_1715105469.0702462.pkl'
true_file_path = './test_results/tform_mlp_model/cleaned-4-data/y_tform_mlp_1715105469.0702462.pkl'
plot_pred_vs_true(
    preds_file_path=pred_file_path,
    true_file_path=true_file_path,
    xlim=(0, 3.5),
    ylim=(0, 3.5),
)

In [None]:
# Pearson correlation coefficient between predicted and true values
# (tform_mlp model test results).
#
# Paths are relative to the notebook's working directory, like the other
# './test_results/...' paths in this notebook, instead of being tied to one
# user's home directory.
pred_file_path = './test_results/tform_mlp_model/cleaned-4-data/preds_tform_mlp_1715105469.0702462.pkl'
true_file_path = './test_results/tform_mlp_model/cleaned-4-data/y_tform_mlp_1715105469.0702462.pkl'

# NOTE: pickle.load can execute arbitrary code; only load trusted files.
# Context managers make sure both file handles are closed.
with open(pred_file_path, 'rb') as pred_file:
    pred = torch.tensor(pk.load(pred_file)).squeeze()
with open(true_file_path, 'rb') as true_file:
    true = torch.tensor(pk.load(true_file)).squeeze()

print(pred.shape)
print(true.shape)

# Row 0 holds the predictions, row 1 the true values; torch.corrcoef treats
# each row as one variable.
c = torch.stack((pred, true), dim=0)
print(c.shape)

# Correlation matrix; assuming pred and true are 1-D after squeeze(), it is
# 2x2 and the off-diagonal entries are the Pearson coefficient.
p = torch.corrcoef(c)
print(p)


----
### TFormMLP Clean-4 Dataset:  pretrained, then fine-tuned

In [None]:
# True vs. predicted scatter for the fine-tuned tform_mlp test results.
# The axis limits here are x in [-1.0, 0.5] and y in [-0.2, 0].
pred_file_path = './test_results/tform_mlp_model/finetune/cleaned-4b-data/preds_tform_mlp_1715572603.424339.pkl'
true_file_path = './test_results/tform_mlp_model/finetune/cleaned-4b-data/y_tform_mlp_1715572603.424339.pkl'
plot_pred_vs_true(pred_file_path, true_file_path, xlim=(-1.0, 0.5), ylim=(-0.2,0))

# Inspect the raw predictions. Only the count and the first few entries are
# printed so the cell output stays readable.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open(pred_file_path, 'rb') as pred_file:
    a = pk.load(pred_file)
print('len(a):', len(a))
print(a[:10])