In [None]:
import torch
from transformers import MT5ForConditionalGeneration, MT5Config, MT5EncoderModel, MT5Tokenizer, Trainer, TrainingArguments

from transformers_custom import MT5ForConditionalGenerationWithLatentSpace
from progeny_tokenizer import TAPETokenizer
import numpy as np
import math
import random
import scipy
import time
import pandas as pd
from torch.utils.data import DataLoader, RandomSampler, Dataset, BatchSampler
import typing
from pathlib import Path
import argparse

from tqdm import tqdm, trange
import shutil

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

import os

import matplotlib.pyplot as plt

In [None]:
# Seed handed to numpy, random and torch when the RNGs are initialised.
seed = 30
# Directory holding the ddG pickle files (e.g. train_ddG.pkl) read by the datasets.
data_dir = "data/gen_train_data/top_half_ddG"
# Checkpoint directory passed to from_pretrained() for both the config and the model.
pretrained_dir = "./congen/v1/clspool_waeDeterencStart84kstep1024dim_cyccon1Start84kstep_lre-04_24ep/"

In [None]:
# src_json = '/export/share/bkrause/progen/progeny/t5_base_uniref_bfd50/config.json'
# shutil.copy(src_json, pretrained_dir)

In [None]:
# Seed every RNG used by the notebook (numpy, stdlib random, torch), in that order.
for _seed_rng in (np.random.seed, random.seed, torch.manual_seed):
    _seed_rng(seed)

# Analysis artefacts for this checkpoint go here; create the directory if needed.
output_dir = Path("./congen/v1/analysis/clspool_waeDeterencStart84kstep1024dim_cyccon1Start84kstep_lre-04_24ep/")
output_dir.mkdir(parents=True, exist_ok=True)

# Tokenizer shared by the train and eval datasets.
tokenizer = TAPETokenizer(vocab="progeny")

# First CUDA device.
device = torch.device('cuda:0')

# Model configuration stored alongside the checkpoint.
t5config = MT5Config.from_pretrained(pretrained_dir)


In [None]:
# Latent-space options forwarded as keyword arguments to from_pretrained() below;
# 'do_mi' and 'latent_space_type' are also read back when calling evaluate().
latent_space_args = dict(
    latent_pooler='cls',
    pool_enc_hidden_states_for_dec=True,
    latent_space_type='wae',
    mask_non_target_z_vector=False,
    separate_targetattr_head=False,
    z_tar_vector_dim=1,
    do_mi=False,
    latent_size=1024,
    wae_z_enc_type='deterministic',
    separate_latent_enc=False,
    separate_latent_dec=False,
)

model = MT5ForConditionalGenerationWithLatentSpace.from_pretrained(
    pretrained_dir,
    **latent_space_args,
)

# NOTE(review): parallelize() is deprecated in recent transformers releases —
# confirm it is still available in the pinned version.
model.parallelize()

# Evaluate Generator's ddG predictions

In [None]:
# Fraction of train_ddG.pkl rows assigned to the 'train' split; the remainder is 'valid'.
train_ratio=0.9
# Batch sizes used by the train and eval DataLoaders respectively.
per_device_train_batch_size = 16
per_device_eval_batch_size = 64

In [None]:
class PKLDFDatasetForGen(Dataset):
    """Dataset over the rows of a pickled pandas DataFrame.

    Each item is a dict with keys ``'ddG'``, ``'input_ids'``, ``'labels'`` (the
    last two both taken from the ``MT_seq`` column) and ``'id'`` (the positional
    index as a string).

    Args:
        data_file (typing.Union[str, Path]): Path to pkl df file.
        in_memory (bool, optional): When True, items are cached after their
            first access. Default: False.
        split (str, optional): ``'train'`` keeps the leading ``train_ratio``
            fraction of the index-sorted rows, ``'valid'`` keeps the rest; any
            other value keeps every row. Only consulted when
            ``train_ratio != 1``.
        train_ratio (float, optional): Fraction of rows assigned to the
            ``'train'`` split. Default: 1 (no splitting).
        train_data_file (str, optional): Not used by this class.
        data_subset (optional): ``'full'`` keeps every selected row; otherwise
            a fraction, and only that fraction of rows with the lowest ``ddG``
            is kept.

    Raises:
        FileNotFoundError: if ``data_file`` does not exist.
    """

    def __init__(self,
                data_file: typing.Union[str, Path],
                in_memory: bool = False,
                split: str = 'train',
                train_ratio: float = 1,
                train_data_file: str = '250K_ddG_split/train_ddG.pkl',
                data_subset='full'
                ):

        data_file = Path(data_file)
        if not data_file.exists():
            raise FileNotFoundError(data_file)

        frame = pd.read_pickle(data_file)

        # Default: keep the frame exactly as loaded.
        selected = frame
        if train_ratio != 1:
            # Deterministic split: order rows by index, then cut at the ratio.
            ordered = frame.sort_index()
            cutoff = int(len(ordered) * train_ratio)
            if split == 'train':
                selected = ordered.iloc[:cutoff]
            elif split == 'valid':
                selected = ordered.iloc[cutoff:]

        # Optionally keep only the lowest-ddG fraction of the selected rows.
        if data_subset != 'full':
            by_ddG = selected.sort_values(by='ddG', ascending=True)
            keep = int(data_subset * len(by_ddG))
            selected = by_ddG.iloc[:keep]

        print("split: ", split)
        print("data_file: ", data_file)
        print("len(final_df): ", len(selected))

        self.df = selected
        self._num_examples = len(selected)
        self._in_memory = in_memory
        if in_memory:
            # One slot per row, filled lazily by __getitem__.
            self._cache = [None] * self._num_examples

    def __len__(self) -> int:
        return self._num_examples

    def __getitem__(self, index: int):
        if index < 0 or index >= self._num_examples:
            raise IndexError(index)

        if self._in_memory:
            cached = self._cache[index]
            if cached is not None:
                return cached

        row = self.df.iloc[index]
        item = {
            'ddG': row['ddG'],
            'input_ids': row['MT_seq'],
            'labels': row['MT_seq'],
            'id': str(index),
        }
        if self._in_memory:
            self._cache[index] = item
        return item

def pad_sequences(sequences: typing.Sequence, constant_value=0, dtype=None) -> np.ndarray:
    """Stack variable-sized arrays into one batch, padding with ``constant_value``.

    Args:
        sequences: Non-empty sequence of numpy arrays or torch tensors that all
            have the same number of dimensions. If the first element is neither
            a numpy array nor a torch tensor (e.g. plain lists), every element
            is converted with ``np.asarray`` first.
        constant_value: Fill value for the padded positions. Default: 0.
        dtype: dtype of the result; defaults to the dtype of the first element.

    Returns:
        An array of shape ``[len(sequences), *elementwise_max_shape]``. It is a
        ``torch.Tensor`` when the first element is a tensor and a numpy array
        otherwise. Each input occupies the leading corner of its row.

    Raises:
        ValueError: if ``sequences`` is empty.
    """
    if len(sequences) == 0:
        raise ValueError("pad_sequences() requires at least one sequence")

    # Previously any other element type fell through both isinstance checks
    # and left the output array undefined; promote such inputs to numpy.
    if not isinstance(sequences[0], (np.ndarray, torch.Tensor)):
        sequences = [np.asarray(seq) for seq in sequences]

    batch_size = len(sequences)
    # Elementwise maximum over all shapes gives the padded per-item shape.
    shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist()

    if dtype is None:
        dtype = sequences[0].dtype

    if isinstance(sequences[0], torch.Tensor):
        array = torch.full(shape, constant_value, dtype=dtype)
    else:
        array = np.full(shape, constant_value, dtype=dtype)

    # Write through the batch array (not a row view) so 0-d elements also work.
    for row, seq in enumerate(sequences):
        region = (row,) + tuple(slice(dim) for dim in seq.shape)
        array[region] = seq

    return array

class CustomStabilityDatasetForGenLatentSpace(Dataset):
    """Tokenized ``(input_ids, labels, ddG)`` dataset read from ``{split}_ddG.pkl``.

    Args:
        data_path: Directory containing the pickle files.
        split: Split name. ``'valid'`` reads ``train_ddG.pkl`` (the validation
            rows are carved out of the training file via ``train_ratio``); any
            other value reads ``{split}_ddG.pkl``.
        tokenizer: A ``TAPETokenizer`` instance, or a vocabulary name from
            which one is constructed.
        in_memory: Forwarded to ``PKLDFDatasetForGen``.
        train_ratio: Forwarded to ``PKLDFDatasetForGen``.
        normalize_targets: Accepted for interface compatibility; not used here.
        data_subset: Forwarded to ``PKLDFDatasetForGen`` (``'full'`` or a
            fraction of lowest-ddG rows to keep).
    """

    def __init__(self,
                data_path: typing.Union[str, Path],
                split: str,
                tokenizer: typing.Union[str, TAPETokenizer] = 'iupac',
                in_memory: bool = False,
                train_ratio: float = 1,
                normalize_targets: bool = False,
                data_subset='full'):

        # if split not in ('train', 'valid', 'test'):
        #     raise ValueError(f"Unrecognized split: {split}. "
        #                     f"Must be one of ['train', 'valid', 'test']")
        if isinstance(tokenizer, str):
            tokenizer = TAPETokenizer(vocab=tokenizer)
        self.tokenizer = tokenizer

        # The validation split lives inside the training file.
        file_prefix = 'train' if split == 'valid' else split

        data_path = Path(data_path)
        data_file = f'{file_prefix}_ddG.pkl'

        # Forward the caller's data_subset; it used to be hard-coded to 'full',
        # which silently ignored the constructor argument.
        self.data = PKLDFDatasetForGen(data_path / data_file, in_memory, split, train_ratio,
                                       data_subset=data_subset)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        """Return ``(input_ids, labels, ddG)`` with both sequences tokenized."""
        item = self.data[index]
        input_ids = self.tokenizer.encode(item['input_ids'])
        labels = self.tokenizer.encode(item['labels'])
        ddG = item['ddG']
        return input_ids, labels, ddG

    def collate_fn(self, batch: typing.List[typing.Tuple[typing.Any, ...]]) -> typing.Dict[str, torch.Tensor]:
        """Pad a list of ``__getitem__`` results into batch tensors.

        ``input_ids`` and ``labels`` are padded with 0 to the longest sequence
        in the batch; ``ddG`` becomes a 1-D tensor.
        """
        input_ids, labels, ddG = tuple(zip(*batch))
        input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
        labels = torch.from_numpy(pad_sequences(labels, 0))
        ddG = torch.Tensor(ddG)

        return {'input_ids': input_ids,
                'labels': labels,
                'ddG': ddG}

In [None]:
# Both splits come from train_ddG.pkl: 'train' is the leading train_ratio
# fraction of the rows, 'valid' is the remainder.
train_dataset = CustomStabilityDatasetForGenLatentSpace(data_dir, 'train', train_ratio=train_ratio, tokenizer=tokenizer)
eval_dataset = CustomStabilityDatasetForGenLatentSpace(data_dir, 'valid', train_ratio=train_ratio, tokenizer=tokenizer)

# Train data set-up
train_loader = DataLoader(train_dataset, batch_size=per_device_train_batch_size, shuffle=True,
                        num_workers=0, collate_fn=train_dataset.collate_fn)
train_loader = tqdm(train_loader)

# Eval data set-up: collate with the eval dataset's own collate_fn rather than
# borrowing the train dataset's, so each loader depends only on its dataset.
eval_loader = DataLoader(eval_dataset, batch_size=per_device_eval_batch_size, shuffle=False,
                        num_workers=0, collate_fn=eval_dataset.collate_fn)
eval_loader = tqdm(eval_loader)

In [None]:
def spearmanr(target, prediction):
    """Return the Spearman rank correlation between ``target`` and ``prediction``.

    Both arguments are converted with ``np.asarray``; their shapes are printed
    before the correlation is computed.
    """
    # A bare `import scipy` does not guarantee that the `stats` subpackage is
    # loaded on every SciPy version, so import it explicitly here.
    from scipy import stats as scipy_stats

    target_array = np.asarray(target)
    prediction_array = np.asarray(prediction)
    print("target_array.shape: ", target_array.shape)
    print("prediction_array.shape: ", prediction_array.shape)
    return scipy_stats.spearmanr(target_array, prediction_array).correlation

In [None]:
   
def evaluate(model, eval_iterator, do_mi=False, do_ddG_spearmanr=True, latent_space_type='plain', return_pred=False):
    """Run ``model`` over ``eval_iterator`` without gradients and aggregate metrics.

    Args:
        model: Object exposing ``eval()`` and ``device`` and callable as
            ``model(input_ids, labels=..., contrast_targets=...)``. The result
            is indexed positionally: ``[0]`` must have a ``.loss`` attribute,
            ``[1]`` is the contrastive loss, ``[2]`` the predicted value,
            ``[3]`` the MI-head loss (read only when ``do_mi``) and ``[-1]``
            the KL loss (read only when ``latent_space_type == 'vae'``).
        eval_iterator: Iterable of batches, each a mapping with tensor entries
            ``'input_ids'``, ``'labels'`` and ``'ddG'``.
        do_mi: Also accumulate and report the MI-head loss.
        do_ddG_spearmanr: Also report the Spearman correlation between the
            ddG targets and the predicted values.
        latent_space_type: ``'vae'`` additionally accumulates the KL loss.
        return_pred: Include the per-example predictions and targets.

    Returns:
        dict with ``'lm_loss'`` and ``'contrastive_loss'`` (means over
        batches), plus ``'mi_head_loss'``, ``'kl_loss'``, ``'spearmanr'``,
        ``'ddG_preds'`` and ``'ddG_targs'`` when the matching option is on.
    """
    track_kl = latent_space_type == 'vae'

    eval_contrastive_loss_total = 0
    eval_lm_loss_total = 0
    eval_mi_head_loss_total = 0
    eval_kl_loss_total = 0
    num_eval_batch = 0

    ddG_preds = []
    ddG_targs = []

    model.eval()
    with torch.no_grad():
        for batch in eval_iterator:
            input_ids = batch['input_ids'].to(model.device)
            labels = batch['labels'].to(model.device)
            ddG_targets = batch['ddG'].to(model.device)

            # The forward call is the same whether or not do_mi is set; only
            # the number of outputs that are read differs.
            model_outputs = model(input_ids, labels=labels, contrast_targets=ddG_targets)
            outputs, contrastive_loss, contrastive_value = model_outputs[0], model_outputs[1], model_outputs[2]

            if do_mi:
                eval_mi_head_loss_total = eval_mi_head_loss_total + model_outputs[3]
            if track_kl:
                eval_kl_loss_total = eval_kl_loss_total + model_outputs[-1]

            # Keep the batch axis explicit. A bare .squeeze() collapses a
            # batch of size 1 to a 0-d array, which cannot be iterated and
            # crashed on a final single-example batch.
            batch_preds = contrastive_value.reshape(input_ids.shape[0], -1).squeeze(-1)
            for pred, target in zip(batch_preds.cpu().numpy(), ddG_targets.cpu().numpy()):
                ddG_targs.append(target)
                ddG_preds.append(pred)

            eval_contrastive_loss_total = eval_contrastive_loss_total + contrastive_loss
            eval_lm_loss_total = eval_lm_loss_total + outputs.loss

            num_eval_batch += 1

    eval_output = {
                "lm_loss": eval_lm_loss_total / num_eval_batch,
                "contrastive_loss": eval_contrastive_loss_total / num_eval_batch,
                  }

    if do_mi:
        eval_output['mi_head_loss'] = eval_mi_head_loss_total / num_eval_batch

    if track_kl:
        eval_output['kl_loss'] = eval_kl_loss_total / num_eval_batch

    if do_ddG_spearmanr:
        spearmanr_value = spearmanr(ddG_targs, ddG_preds)
        print("spearmanr_value: ", spearmanr_value)
        eval_output['spearmanr'] = spearmanr_value

    if return_pred:
        eval_output['ddG_preds'] = ddG_preds
        eval_output['ddG_targs'] = ddG_targs

    return eval_output

In [None]:
# def evaluate(model, eval_iterator, do_mi=False, do_ddG_spearmanr=True, return_pred=False):
#     eval_contrastive_loss_total = 0
#     eval_lm_loss_total = 0
#     if do_mi:
#         eval_mi_head_loss_total = 0
#     model.eval()
#     num_eval_batch = 0
    
#     ddG_preds=[]
#     ddG_targs = []

#     with torch.no_grad():
#         for step, batch in enumerate(eval_iterator):
            
#             input_ids = batch['input_ids'].to(model.device)
#             labels = batch['labels'].to(model.device)
#             ddG_targets = batch['ddG'].to(model.device)
            
#             if do_mi:
#                 outputs, contrastive_loss, contrastive_value, mi_head_loss = model(input_ids, labels=labels, contrast_targets=ddG_targets)
#                 eval_mi_head_loss_total = eval_mi_head_loss_total + mi_head_loss
#             else:
#                 outputs, contrastive_loss, contrastive_value = model(input_ids, labels=labels, contrast_targets=ddG_targets)
            
#             for pred, target in zip(contrastive_value.squeeze().cpu().numpy(), ddG_targets.cpu().numpy()):
# #                 print("target: ", target)
# #                 print("pred: ", pred)
#                 ddG_targs.append(target)
#                 ddG_preds.append(pred)
                
# #             ddG_targs.append(torch.flatten(ddG_targets).cpu().numpy())
# #             ddG_preds.append(torch.flatten(contrastive_value).squeeze().cpu().numpy())

# #             print("ddG_targets.shape: ", ddG_targets.cpu().numpy().shape)
# #             print("contrastive_value.shape: ", contrastive_value.cpu().numpy().shape)
#             lm_loss = outputs.loss
            
#             eval_contrastive_loss_total = eval_contrastive_loss_total + contrastive_loss
#             eval_lm_loss_total = eval_lm_loss_total + lm_loss
#             # eval_contrastive_losses.append(contrastive_loss)
#             # eval_lm_losses.append(lm_loss)

#             num_eval_batch += 1

# #             if step == 5:
# #                 break

#     # eval_contrastive_loss = torch.mean(eval_contrastive_losses)
#     # eval_lm_loss = torch.mean(eval_lm_losses)
#     eval_lm_loss = eval_lm_loss_total / num_eval_batch
#     eval_contrastive_loss = eval_contrastive_loss_total / num_eval_batch
#     eval_output = {
#                 "lm_loss": eval_lm_loss,
#                 "contrastive_loss": eval_contrastive_loss,
#                   }

#     if do_mi:
#         eval_mi_head_loss_total = eval_mi_head_loss_total / num_eval_batch
#         eval_output['mi_head_loss'] = eval_mi_head_loss_total

#     if do_ddG_spearmanr:
#         spearmanr_value = spearmanr(ddG_targs, ddG_preds)
#         print("spearmanr_value: ", spearmanr_value)
#         eval_output['spearmanr'] = spearmanr_value
        
#     if return_pred:
#         eval_output['ddG_preds'] = ddG_preds
#         eval_output['ddG_targs'] = ddG_targs
        
    
#     # print("eval_contrastive_loss: ", eval_contrastive_loss)
#     # print("eval_lm_loss: ", eval_lm_loss)
#     return eval_output


In [None]:
# Score the validation loader, keeping per-example predictions for the analysis below.
eval_output = evaluate(
    model,
    eval_loader,
    do_mi=latent_space_args['do_mi'],
    return_pred=True,
    latent_space_type=latent_space_args['latent_space_type'],
)

In [None]:
# Pull the aggregate metrics out of the evaluate() result and echo them.
eval_lm_loss = eval_output['lm_loss']
eval_contrastive_loss = eval_output['contrastive_loss']
eval_spearmanr_value = eval_output['spearmanr']

print("eval_lm_loss: ", eval_lm_loss)
print("eval_contrastive_loss: ", eval_contrastive_loss)
print("eval_spearmanr_value: ", eval_spearmanr_value)

In [None]:
# Per-example predicted values and ddG targets, appended in the same order by evaluate().
ddG_preds = eval_output['ddG_preds']
ddG_targs = eval_output['ddG_targs']

print("len(ddG_preds): ", len(ddG_preds))
print("len(ddG_targs): ", len(ddG_targs))

In [None]:
# Summary statistics of the predicted values on the eval split.
print("stats of ddG_preds")
for _stat_label, _stat_fn in (
    ("min: ", np.min),
    ("mean: ", np.mean),
    ("median: ", np.median),
    ("max: ", np.max),
    ("std: ", np.std),
):
    print(_stat_label, _stat_fn(ddG_preds))

In [None]:
# Overlay the predicted-value and target ddG densities on shared unit-width bins.
_eval_hist_bins = list(range(-20, 10))

plt.figure(figsize=(8,6))
for _hist_values, _hist_label in ((ddG_preds, 'value_pred'), (ddG_targs, 'ddG')):
    plt.hist(_hist_values, density=True, label=_hist_label, bins=_eval_hist_bins, alpha=0.4)

plt.xlabel("ddG", size=14)
plt.ylabel("Density", size=14)
plt.title("Eval set, Controlled generation clspool_waeDeterencStart84kstep1024dim_cyccon1Start84kstep_lre-04_24ep")
plt.legend(loc='upper left')

In [None]:
# Empty frame; its columns are assigned in the following cell.
eval_df = pd.DataFrame()

In [None]:
# One row per eval example: the model's predicted value next to the ddG target.
eval_df['value_pred'] = ddG_preds
eval_df['ddG'] = ddG_targs

In [None]:
# Order rows from lowest to highest predicted value, so a leading slice selects
# the examples with the lowest predictions.
eval_df = eval_df.sort_values(by='value_pred', ascending=True)

In [None]:
# topK_list = [10, 100, 1000, 10000]
percentile_list = [95, 90, 85, 80, 75]
# For percentile p, keep the leading (100 - p)% of the rows (floor division).
_num_eval_rows = len(eval_df)
topK_list = [_num_eval_rows * (100 - pct_rank) // 100 for pct_rank in percentile_list]
print(topK_list)

In [None]:
# Alias of the sorted eval frame used by the report below (same object, not a copy).
ddG_df = eval_df
# Target ddG of every eval row; used as the reference distribution for the percentile thresholds.
all_ddG_list = eval_df['ddG']

In [None]:
def _print_ddG_report(header, subset_ddG, reference_ddG, percentiles):
    """Print size, summary statistics and PCI fractions for one ddG subset.

    Args:
        header: Text printed before the subset size.
        subset_ddG: ddG values being summarised.
        reference_ddG: ddG values the percentile thresholds are taken from.
        percentiles: For each ``p``, the threshold is the ``(100 - p)``-th
            percentile of ``reference_ddG`` and the printed ``PCI_{p}pct`` is
            the fraction of ``subset_ddG`` strictly below it.
    """
    print(header, len(subset_ddG))
    print("max: ", np.max(subset_ddG))
    print("min: ", np.min(subset_ddG))
    print("mean: ", np.mean(subset_ddG))
    print("median: ", np.median(subset_ddG))

    for percentile in percentiles:
        threshold = np.percentile(reference_ddG, 100-percentile)
        fraction_below = np.sum(subset_ddG < threshold) / len(subset_ddG)
        print("PCI_{}pct: ".format(percentile), fraction_below)

    # Fraction with ddG below 0.
    print("PCI_WT: ", np.sum(subset_ddG < 0) / len(subset_ddG))

    print("_"*20)


# Rows of ddG_df are sorted by predicted value, so each leading slice holds the
# examples with the lowest predictions.
for topK in topK_list:
    topK_df = ddG_df[:topK]
    _print_ddG_report("top K: ", topK_df['ddG'], all_ddG_list, percentile_list)

tophalf_df = ddG_df[:len(ddG_df)//2]
_print_ddG_report("top half: ", tophalf_df['ddG'], all_ddG_list, percentile_list)

# Reference distribution.
# NOTE(review): the printed label says "train dataset", but all_ddG_list is
# taken from eval_df['ddG'] in this notebook — confirm which label is intended.
_print_ddG_report("train dataset: ", all_ddG_list, all_ddG_list, percentile_list)

# Get value_pred of gen input samples

In [None]:
# Load the full training pickle from data_dir.
# NOTE(review): read_pickle can execute arbitrary code; only use trusted files.
input_data_path = Path(data_dir)
input_data_file = input_data_path / 'train_ddG.pkl'
input_data_df = pd.read_pickle(input_data_file)

In [None]:
# Display the loaded input frame.
input_data_df

In [None]:
# Number of lowest-ddG rows of the input data kept as generation inputs.
topk_as_input = 12500

In [None]:
# Summary of the ddG column of the full input data.
_input_ddG = input_data_df['ddG']
print("ddG stats of input data")
print("min: ", np.min(_input_ddG))
print("mean: ", np.mean(_input_ddG))
print("median: ", np.median(_input_ddG))
print("max: ", np.max(_input_ddG))

# Sort ascending by ddG and keep the first topk_as_input rows.
ddG_sorted_input_df = input_data_df.sort_values(by='ddG', ascending=True)
gen_input_df = ddG_sorted_input_df.iloc[:topk_as_input]

In [None]:
# Display the rows selected as generation inputs.
gen_input_df

# Get value_pred of train data

In [None]:
# Same evaluation as for the eval split, run over the training loader.
train_eval_output = evaluate(
    model,
    train_loader,
    do_mi=latent_space_args['do_mi'],
    return_pred=True,
    latent_space_type=latent_space_args['latent_space_type'],
)
# eval_output = evaluate(model, eval_loader, do_mi=latent_space_args['do_mi'], return_pred=True, latent_space_type=latent_space_args['latent_space_type'])

In [None]:
# Pull the aggregate train-split metrics out of the evaluate() result and echo them.
train_lm_loss = train_eval_output['lm_loss']
train_contrastive_loss = train_eval_output['contrastive_loss']
train_spearmanr_value = train_eval_output['spearmanr']

print("train_lm_loss: ", train_lm_loss)
print("train_contrastive_loss: ", train_contrastive_loss)
print("train_spearmanr_value: ", train_spearmanr_value)

In [None]:
# Per-example predicted values and ddG targets for the train split.
train_ddG_preds = train_eval_output['ddG_preds']
train_ddG_targs = train_eval_output['ddG_targs']

print("len(train_ddG_preds): ", len(train_ddG_preds))
print("len(train_ddG_targs): ", len(train_ddG_targs))

In [None]:
# Summary statistics of the predicted values on the train split.
print("stats of ddG_preds, train set")
for _stat_label, _stat_fn in (
    ("min: ", np.min),
    ("mean: ", np.mean),
    ("median: ", np.median),
    ("max: ", np.max),
    ("std: ", np.std),
):
    print(_stat_label, _stat_fn(train_ddG_preds))

In [None]:
# Overlay the predicted-value and target ddG densities for the train split.
_train_hist_bins = list(range(-20, 10))

plt.figure(figsize=(8,6))
for _hist_values, _hist_label in ((train_ddG_preds, 'value_pred'), (train_ddG_targs, 'ddG')):
    plt.hist(_hist_values, density=True, label=_hist_label, bins=_train_hist_bins, alpha=0.4)

plt.xlabel("ddG", size=14)
plt.ylabel("Density", size=14)
plt.title("Train set, Controlled generation clspool_waeDeterencStart84kstep1024dim_cyccon1Start84kstep_lre-04_24ep")
plt.legend(loc='upper left')