In [None]:
# libraries
import itertools
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import XLNetConfig, XLNetForTokenClassification

from utils import GWSDatasetFromPandas
# suppress warnings
warnings.filterwarnings("ignore")

# run on the GPU when one is visible to torch, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
stop_codons = ['TAA', 'TAG', 'TGA']

# global codon vocabulary: ids 0..63 enumerate every A/T/C/G triplet in itertools.product order
id_to_codon = dict(enumerate(''.join(triplet) for triplet in itertools.product('ATCG', repeat=3)))
codon_to_id = {codon: idx for idx, codon in id_to_codon.items()}

# ids of the sense codons (stop codons excluded), in ascending id order;
# these are the substitution candidates used by the mutation search
codonid_list = [idx for idx, codon in id_to_codon.items() if codon not in stop_codons]

print('Number of codons:', len(codonid_list))

In [None]:
# model parameters
# (of these, d_model_val / n_layers_val / n_heads_val / dropout_val feed the XLNetConfig below;
#  the remaining values are informational in this notebook)
annot_thresh = 0.3
longZerosThresh_val = 20
percNansThresh_val = 0.05
d_model_val, n_layers_val, n_heads_val = 512, 6, 4
dropout_val = 0.1
lr_val, batch_size_val = 1e-4, 2
loss_fun_name = '4L'  # alternative value: '5L'

# directory holding the trained checkpoint (the weights live in its best_model/ subfolder)
model_loc = '../../../checkpoints/XLNet-PLabelDH_S2/'

# condition token id -> condition name, plus the reverse lookup name -> token id
condition_dict_values = {64: 'CTRL', 65: 'ILE', 66: 'LEU', 67: 'LEU_ILE', 68: 'LEU_ILE_VAL', 69: 'VAL'}
condition_dict = dict(zip(condition_dict_values.values(), condition_dict_values.keys()))

class XLNetDH(XLNetForTokenClassification):
    """XLNet token classifier whose output layer emits two values per token ("double head").

    The stock ``classifier`` (sized by ``config.num_labels``) is replaced with a
    ``Linear(config.d_model, 2)``. AsiteDensity reads output 0 (after ReLU) for the
    CTRL condition and output 1 for every other condition.
    """

    def __init__(self, config):
        super().__init__(config)
        # Size the head from the config rather than the d_model_val global:
        # from_pretrained builds the model from the checkpoint's saved config, so
        # this guarantees the head matches the backbone width it is attached to.
        self.classifier = torch.nn.Linear(config.d_model, 2, bias=True)

# vocab: 64 codon ids (0-63) + 6 condition tokens (64-69) + 1 padding token (70) = 71
config = XLNetConfig(vocab_size=71, pad_token_id=70, d_model=d_model_val, n_layer=n_layers_val, n_head=n_heads_val, d_inner=d_model_val, num_labels=1, dropout=dropout_val)

# from_pretrained is a classmethod: it builds its own instance from the checkpoint's saved
# config and weights, so constructing (and randomly initialising) an XLNetDH(config) first
# was wasted work. `config` above is kept for reference only; it is not used for loading.
# rstrip('/') avoids the '//' that model_loc's trailing slash would otherwise produce.
model = XLNetDH.from_pretrained(model_loc.rstrip('/') + '/best_model')
model.to(device)
# set model to evaluation mode (inference behaviour for dropout layers)
model.eval()

In [None]:
# read the three data splits as pandas DataFrames
test_dataset = pd.read_csv('../../../data/orig/test.csv')
train_dataset = pd.read_csv('../../../data/orig/train.csv')
val_dataset = pd.read_csv('../../../data/orig/val.csv')

# merge all splits into one DataFrame (row index reset)
merged_df = pd.concat([train_dataset, val_dataset, test_dataset], ignore_index=True)

# wrap the merged DataFrame in the project's dataset class; a separate name keeps the
# DataFrame available for inspection instead of overwriting it
merged_dataset = GWSDatasetFromPandas(merged_df)

print("samples in merged dataset: ", len(merged_dataset))

In [None]:
# run settings:
#   num_windows - windows kept per condition after random selection
#   model_bs    - sequences per model forward pass in AsiteDensity
#   window_size - codon positions per window; the A-site sits at index window_size // 2
num_windows, model_bs, window_size = 1000, 256, 21

In [None]:
# index of the A-site inside a window (the centre position for odd window_size)
a_pos = window_size // 2
# transcripts whose density profile is longer than this are skipped
max_seq_len = 500
conditions_windows = {'CTRL': [], 'ILE': [], 'LEU': [], 'LEU_ILE': [], 'LEU_ILE_VAL': [], 'VAL': []}
conditions_windows_fin = {'CTRL': [], 'ILE': [], 'LEU': [], 'LEU_ILE': [], 'LEU_ILE_VAL': [], 'VAL': []}

# Collect, per condition, every window whose observed A-site density is "low":
# below (mean - std) of the transcript's density profile, NaNs ignored.
for i in tqdm(range(len(merged_dataset))):
    # index the dataset once per sample; each merged_dataset[i] call re-runs __getitem__
    sample = merged_dataset[i]
    # element 0 of the input sequence is the condition token (64-69)
    sample_condition = sample[0][0].item()
    g = sample[3]
    t = sample[4]
    y_true_sample = sample[1].numpy()
    x_input_sample = sample[0].numpy()
    if len(y_true_sample) > max_seq_len:
        continue
    non_peak = np.nanmean(y_true_sample) - np.nanstd(y_true_sample)
    for j in range(window_size, len(y_true_sample) - window_size):
        if len(x_input_sample[j:j + window_size]) != window_size:
            continue
        # Threshold the observed *density* at the A-site, not the codon id stored in
        # x_input_sample (comparing a token id with a density threshold is meaningless).
        # NOTE(review): assumes y_true_sample is indexed like x_input_sample; if y omits
        # the leading condition token this should be j + a_pos - 1 — TODO confirm.
        if y_true_sample[j + a_pos] < non_peak:
            conditions_windows[condition_dict_values[sample_condition]].append((g, t, x_input_sample, j))

In [None]:
# Choose num_windows windows per condition uniformly at random, with a fixed seed so the
# selection is reproducible. Sampling through a permutation of indices (instead of
# shuffling in place) leaves conditions_windows untouched, so this cell can be re-run.
sampling_seed = 42
rng = np.random.default_rng(sampling_seed)
for condition, windows in conditions_windows.items():
    chosen = rng.permutation(len(windows))[:num_windows]
    conditions_windows_fin[condition] = [windows[k] for k in chosen]

In [None]:
def AsiteDensity(windows, condition, start):
    """Return the model's A-site prediction for one or more token sequences.

    Parameters
    ----------
    windows : array-like of int
        A single 1-D token sequence or a 2-D batch of equal-length sequences.
        Positions are indexed the same way as ``start``. The condition token is
        prepended here before the forward pass.
        NOTE(review): the sampling cell's sequences already start with a condition
        token (their element 0 is read as the condition), so the model sees it twice;
        the index arithmetic below relies on that — confirm it matches training.
    condition : str
        Condition name, a key of ``condition_dict`` (e.g. ``'CTRL'``).
    start : int
        Index of the first window position; the A-site is ``start + a_pos``.

    Returns
    -------
    torch.Tensor
        1-D tensor with one value per input sequence: ReLU of output 0 when the
        condition token is 64 (CTRL), otherwise the raw output 1.

    Raises
    ------
    ValueError
        If ``windows`` contains no sequences.
    """
    condition_val = condition_dict[condition]
    # np.insert(..., axis=1) needs a 2-D batch; a bare 1-D sequence used to raise AxisError
    batch = np.atleast_2d(np.asarray(windows))
    if batch.shape[0] == 0:
        raise ValueError("AsiteDensity needs at least one input sequence")
    batch = np.insert(batch, 0, condition_val, axis=1)
    batch = torch.as_tensor(batch, dtype=torch.long, device=device)

    # logits position 1 + p corresponds to windows[:, p] because one token was prepended
    site = 1 + start + a_pos
    site_logits = []
    with torch.no_grad():
        for i in range(0, batch.shape[0], model_bs):
            logits = model(batch[i:i + model_bs])["logits"]
            # keep only the A-site outputs instead of the whole sequence
            site_logits.append(logits[:, site, :])
    site_logits = torch.cat(site_logits, dim=0)

    if condition_val == 64:
        return torch.relu(site_logits[:, 0])
    return site_logits[:, 1]

def getTopXMutants(full_inp, start, condition, X, mutant_pos1=-1, c_pos1=None, mutant_pos2=-1, c_pos2=None):
    """Score every single-codon substitution in a window and return the X best.

    Up to two substitutions can be held fixed (``mutant_pos1``/``c_pos1`` and
    ``mutant_pos2``/``c_pos2``; a position of ``-1`` means "not set"). Every other
    position in ``start .. start + window_size - 1`` (the A-site included) is then
    replaced, one at a time, by each codon id in ``codonid_list``, and the whole batch
    is scored with ``AsiteDensity``.
    NOTE(review): a candidate may re-insert the codon already present at a position,
    i.e. a no-op "mutation"; this mirrors the original enumeration.

    Parameters
    ----------
    full_inp : array-like of int
        Full token sequence, indexed like ``start`` and the mutation positions.
        It is not modified.
    start : int
        Index of the first window position.
    condition : str
        Condition name passed through to ``AsiteDensity``.
    X : int
        Number of highest-scoring substitutions to keep.
    mutant_pos1, c_pos1 : int, optional
        First fixed substitution (position, codon id).
    mutant_pos2, c_pos2 : int, optional
        Second fixed substitution; only valid together with the first.

    Returns
    -------
    dict
        Up to X entries in descending score order (ties keep enumeration order),
        values are Python floats. Keys record every substitution applied:
        ``(pos, c)`` with no fixed mutation, ``(mutant_pos1, c_pos1, pos, c)`` with
        one, and ``(mutant_pos2, c_pos2, mutant_pos1, c_pos1, pos, c)`` with two.

    Raises
    ------
    ValueError
        If ``mutant_pos2`` is given without ``mutant_pos1`` (this combination used to
        fall through every branch and silently return an empty dict).
    """
    if mutant_pos1 == -1 and mutant_pos2 != -1:
        raise ValueError("mutant_pos2 can only be used together with mutant_pos1")

    # substitutions held fixed for this round, and the key prefix that records them
    if mutant_pos1 == -1:
        fixed, key_prefix = [], ()
    elif mutant_pos2 == -1:
        fixed, key_prefix = [(mutant_pos1, c_pos1)], (mutant_pos1, c_pos1)
    else:
        fixed = [(mutant_pos1, c_pos1), (mutant_pos2, c_pos2)]
        key_prefix = (mutant_pos2, c_pos2, mutant_pos1, c_pos1)
    fixed_positions = {pos for pos, _ in fixed}

    # copy of the input with the fixed substitutions applied (later ones win on a clash)
    base = np.array(full_inp)
    for pos, codon in fixed:
        base[pos] = codon

    # every (position, codon) substitution to try; fixed positions are not mutated again
    candidates = [
        (start + k, c)
        for k in range(window_size)
        if start + k not in fixed_positions
        for c in codonid_list
    ]
    if not candidates:
        return {}

    # one row per candidate: the base sequence with that single substitution applied
    inputs_all = np.repeat(base[np.newaxis, :], len(candidates), axis=0)
    positions, codons = zip(*candidates)
    inputs_all[np.arange(len(candidates)), list(positions)] = list(codons)

    preds = AsiteDensity(inputs_all, condition, start).tolist()
    scored = {key_prefix + cand: score for cand, score in zip(candidates, preds)}

    # stable descending sort, then keep the first X
    ranked = sorted(scored.items(), key=lambda item: item[1], reverse=True)
    return dict(itertools.islice(ranked, X))

In [None]:
# For every selected window, greedily explore up to three codon substitutions:
#   generation 1: the top num_mutants single substitutions,
#   generation 2: for each of those, the top num_mutants second substitutions,
#   generation 3: for each pair, the top num_mutants third substitutions.
# All scored mutants of all generations are kept per sample, keyed by their substitutions.
mutations_everything = {}
num_mutants = 5
for condition, samples in conditions_windows_fin.items():
    for gene, transcript, x_input, start in tqdm(samples, desc=condition):
        sample_mutations = {}
        window = x_input[start:start + window_size]
        # AsiteDensity works on a batch; wrap the single unmutated sequence as a batch of one
        # (passing the bare 1-D array made np.insert(..., axis=1) raise AxisError)
        original_density = AsiteDensity(x_input[np.newaxis, :], condition, start).item()

        mutants_one = getTopXMutants(x_input, start, condition, num_mutants)
        sample_mutations.update(mutants_one)
        # generation-1 keys are (position, codon)
        for mutant_pos1, c_pos1 in mutants_one:
            mutants_two = getTopXMutants(x_input, start, condition, num_mutants, mutant_pos1, c_pos1)
            sample_mutations.update(mutants_two)
            # generation-2 keys are (mutant_pos1, c_pos1, position, codon)
            for mutant2 in mutants_two:
                mutant_pos2, c_pos2 = mutant2[2], mutant2[3]
                mutants_three = getTopXMutants(x_input, start, condition, num_mutants, mutant_pos1, c_pos1, mutant_pos2, c_pos2)
                sample_mutations.update(mutants_three)

        mutations_everything[(gene, transcript, start, str(window), condition, original_density)] = sample_mutations


In [None]:
# Save the mutations to a file. np.savez stores the dict as a 0-d object array, so reload with
# np.load(path, allow_pickle=True)['mutations_everything'].item().
motifs_dir = Path('../../../data/motifs')
# np.savez does not create missing directories
motifs_dir.mkdir(parents=True, exist_ok=True)
np.savez(motifs_dir / ('motifs_' + str(window_size) + '_' + str(num_windows) + '.npz'), mutations_everything=mutations_everything)