In [None]:
import numpy as np 
from tqdm import tqdm 
import pandas as pd
import itertools
from xlnet_plabel_utils import GWSDatasetFromPandas 
from scipy.stats import chi2_contingency
import matplotlib.pyplot as plt
import seaborn as sns
import scienceplots    
from permetrics.regression import RegressionMetric
from itables import show, init_notebook_mode
import marsilea as ma
import matplotlib

In [None]:
# Global lookup tables shared by the rest of the notebook.

# Codon id (0..63) -> three-letter codon, enumerated in itertools.product order
# over the nucleotide list below; codon_to_id is the exact inverse.
id_to_codon = dict(
    enumerate(''.join(triplet) for triplet in itertools.product(['A', 'T', 'C', 'G'], repeat=3))
)
codon_to_id = {codon: codon_id for codon_id, codon in id_to_codon.items()}

stop_codons = ['TAA', 'TAG', 'TGA']

# Ids of every codon that is not a stop codon, in ascending id order.
codonid_list = [codon_id for codon_id in range(64) if id_to_codon[codon_id] not in stop_codons]

print('Number of codons:', len(codonid_list))

# Condition token id -> condition label, plus the inverse label -> token id map.
condition_dict_values = {64: 'CTRL', 65: 'ILE', 66: 'LEU', 67: 'LEU_ILE', 68: 'LEU_ILE_VAL', 69: 'VAL'}
condition_dict = {label: token for token, label in condition_dict_values.items()}

In [None]:
# Open the .npz archive holding the motif/mutation results. At this point the
# name is bound to the archive object returned by np.load, not to the data.
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects, which can
# execute code -- only load archives from a trusted source.
# NOTE(review): the archive handle is never explicitly closed in this notebook.
mutations_everything = np.load('bms/motifswAF_addStall_1000.npz', allow_pickle=True)

In [None]:
# Replace the archive handle with the object stored under its
# 'mutations_everything' entry (a 0-d object array unwrapped via .item()).
# The isinstance guard makes the cell safe to re-run: once the name already
# holds the unpacked dict there is nothing left to do, whereas the unguarded
# version would index the dict with the archive key and fail.
# The `with` block closes the underlying .npz file once the data is extracted.
if not isinstance(mutations_everything, dict):
    with mutations_everything as npz_archive:
        mutations_everything = npz_archive['mutations_everything'].item()

In [None]:
# Materialise the record keys once so later cells iterate them in a fixed order.
keys = [*mutations_everything]
print(keys)

In [None]:
# Sanity check: every record is expected to hold exactly 155 motif entries.
# The message names the offending key and its actual size so a failure can be
# diagnosed without re-running the loop by hand.
for key in keys:
    n_motifs = len(mutations_everything[key])
    assert n_motifs == 155, f'{key!r}: expected 155 motifs, found {n_motifs}'

In [None]:
# Parallel accumulators: index i across all six lists describes the same motif.
motif_str, motif_len, condition = [], [], []
perc_increase, orig_density_list, new_density_list = [], [], []

In [None]:
# Flatten `mutations_everything` into six parallel lists, one entry per motif.
#
# The lists are (re)created here so that re-running this cell does not append a
# second copy of every record on top of the previous run.
motif_str = []
motif_len = []
condition = []
perc_increase = []
orig_density_list = []
new_density_list = []

for k in tqdm(keys):
    # Fields of the record key read here: k[2] is used as the start position,
    # k[4] as the condition label and k[5] as the original density.
    start = k[2]
    orig_density = k[5]
    for mo, new_density in mutations_everything[k].items():
        condition.append(k[4])
        orig_density_list.append(orig_density)
        new_density_list.append(new_density)

        # Magnitude of the relative density change. A zero original density is
        # recorded as 0 explicitly: the previous bare `except:` hid every kind
        # of error, and for numpy scalars division by zero does not raise at
        # all (it silently produced inf/nan instead of the intended 0).
        if orig_density == 0:
            perc_increase.append(0)
        else:
            perc_increase.append(np.abs((new_density - orig_density) / orig_density))

        # `mo` is a flat sequence of alternating (position, value) entries.
        motif_len.append(len(mo) // 2)

        # Re-express each position relative to `start + 10`.
        # NOTE(review): the +10 presumably centres a 21-wide window -- confirm.
        motif_sample_dict = {mo[i] - (start + 10): mo[i + 1] for i in range(0, len(mo), 2)}

        # Serialise as 'pos_value_pos_value_' in ascending position order. The
        # trailing underscore is kept on purpose: later cells parse the string
        # with split('_')[:-1].
        motif_str.append(
            ''.join(f'{offset}_{value}_' for offset, value in sorted(motif_sample_dict.items()))
        )

In [None]:
# Assemble the per-motif records into one table (one row per motif, columns in
# the order listed below).
motif_columns = {
    'motif': motif_str,
    'motif_len': motif_len,
    'perc_increase': perc_increase,
    'condition': condition,
    'orig_density': orig_density_list,
    'new_density': new_density_list,
}
df = pd.DataFrame(motif_columns)

In [None]:
# Collect every relative position mentioned in any motif string.
# A motif string looks like 'pos_value_pos_value_'; dropping the empty piece
# after the final underscore leaves alternating position/value tokens.
pos = []
for encoded_motif in motif_str:
    fields = [int(token) for token in encoded_motif.split('_')[:-1]]
    # even slots hold the positions, odd slots the values
    pos.extend(fields[::2])

In [None]:
# Display the sorted distinct relative positions that occur in the motifs.
# NOTE(review): presumably a sanity check that all positions fall inside the
# -10..10 window the later cells index into -- confirm.
np.unique(pos)

In [None]:
# For each condition, build a table with one row per distinct motif:
#   motif        - the encoded motif string ('pos_codonid_pos_codonid_')
#   perc_counts  - fraction of that condition's rows carrying the motif
#                  (value_counts(normalize=True), so values are in 0..1)
#   21 position columns - the codon at each relative position, '-' where the
#                  motif does not touch that position.
# Column labels for relative positions -10..10; -2, -1 and 0 are shown as E, P, A.
POSITION_LABELS = ['-10', '-9', '-8', '-7', '-6', '-5', '-4', '-3', 'E', 'P', 'A',
                   '1', '2', '3', '4', '5', '6', '7', '8', '9', '10']

condition_wise_dfs = {}

for c in df['condition'].unique():
    # one row per unique motif together with its relative frequency
    df_c = df.loc[df['condition'] == c, 'motif'].value_counts(normalize=True).reset_index()
    df_c.columns = ['motif', 'perc_counts']

    codon_rows = []
    for m in tqdm(df_c['motif']):
        m_s = m.split('_')[:-1]
        # start from an all-gap row and fill in only the positions in the motif
        row = ['-'] * len(POSITION_LABELS)
        for pos_token, codon_token in zip(m_s[::2], m_s[1::2]):
            offset = int(pos_token)
            # Fail loudly on positions outside the window: a negative list index
            # would otherwise wrap around and write into the wrong column.
            if not -10 <= offset <= 10:
                raise ValueError(
                    f'motif {m!r} has position {offset} outside the -10..10 window'
                )
            row[offset + 10] = id_to_codon[int(codon_token)]
        codon_rows.append(row)

    # Attach the codon columns under their final labels directly, rather than
    # adding temporary names and renaming all columns by position afterwards.
    codon_cols = pd.DataFrame(codon_rows, columns=POSITION_LABELS, index=df_c.index)
    condition_wise_dfs[c] = pd.concat([df_c, codon_cols], axis=1)

In [None]:
# Relative path to the genetic-code CSV. The plotting cell reads it with pandas
# and uses its 'Codon' and 'AminoAcid' columns.
GC_FPATH = '../data/genetic_code.csv'

In [None]:
# Per-condition heatmap of codon usage at each motif position.
#
# For every condition this cell:
#   1. turns the motif table into a 64 x 21 matrix of codon percentages per
#      position (gap entries '-' are excluded from the denominator),
#   2. keeps only codons whose total across positions exceeds the mean total,
#   3. renormalises each position over the kept codons and draws a marsilea
#      heatmap (rows = positions, columns = codons grouped by amino acid),
#   4. adds a side strip with the log number of motifs per position, and
#   5. saves the figure to plots/ as a PDF.

# Colour palettes. The global seed keeps the shuffled amino-acid colours
# identical from run to run.
colors_full = plt.cm.tab20c(np.linspace(0, 1, 64))
colors_aa = np.array(colors_full).repeat(2, 0)
np.random.seed(42)
np.random.shuffle(colors_aa)
# Upper bound of the motif-count colour scale (applied as log(max_motif_val)).
# NOTE(review): hard-coded; presumably the largest motif count over all
# conditions -- confirm and derive from the data if so.
max_motif_val = 28383

# Three-letter -> one-letter amino-acid codes. Stop uses '*': it previously
# shared 'S' with serine, which merged stop codons into the Ser group and gave
# them the Ser colour.
AA = ['Val', 'Ile', 'Leu', 'Lys', 'Asn', 'Thr', 'Arg', 'Ser', 'Met', 'Gln', 'His', 'Pro', 'Glu', 'Asp', 'Ala', 'Gly', 'Tyr', 'Cys', 'Trp', 'Phe', 'Stp']
AA_1 = ['V', 'I', 'L', 'K', 'N', 'T', 'R', 'S', 'M', 'Q', 'H', 'P', 'E', 'D', 'A', 'G', 'Y', 'C', 'W', 'F', '*']
aa_3_1 = dict(zip(AA, AA_1))

pos_labels_list = ['-10', '-9', '-8', '-7', '-6', '-5', '-4', '-3', 'E', 'P', 'A', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10']

# The genetic-code table does not depend on the condition: read it once.
genetic_code_full = pd.read_csv(GC_FPATH, index_col=0).set_index('Codon')
genetic_code_full['AminoAcid_1'] = [aa_3_1[aa] for aa in genetic_code_full['AminoAcid']]

with plt.style.context(['science','nature','grid','bright','no-latex']):
    for c in ['CTRL', 'ILE', 'LEU', 'VAL', 'LEU_ILE', 'LEU_ILE_VAL']:
        # keep only the 21 position columns
        df_c = condition_wise_dfs[c].drop(columns=['motif', 'perc_counts'])
        df_c_mat = df_c.to_numpy()

        # Percentage of each codon at each position, ignoring '-' entries.
        df_c_mat_perc = np.zeros((64, 21))
        non_dash_counts = np.zeros(21)
        for i in range(21):
            codon_counts = df_c_mat[:, i]
            non_dash_counts[i] = np.sum(codon_counts != '-')
            if non_dash_counts[i] == 0:
                # no motif touches this position: leave the column at 0 rather
                # than dividing 0/0 and cleaning up the NaNs afterwards
                continue
            for j in range(64):
                df_c_mat_perc[j, i] = (np.sum(codon_counts == id_to_codon[j]) / non_dash_counts[i]) * 100

        stack_data = pd.DataFrame(df_c_mat_perc, index=[id_to_codon[i] for i in range(64)])

        # order codons by their percentages, position by position, descending
        stack_data = stack_data.sort_values(by=stack_data.columns.tolist(), ascending=False)
        # keep only codons whose total across positions is above the mean total
        row_totals = stack_data.sum(axis=1)
        stack_data_t = stack_data[row_totals > row_totals.mean()]

        # Renormalise every position over the retained codons.
        # NOTE(review): a position whose retained codons sum to 0 becomes NaN
        # here (0/0); this matches the previous behaviour.
        stack_data_thresh = stack_data_t / stack_data_t.sum(axis=0)
        h = ma.Heatmap(stack_data_thresh.T, linewidth=0.5, width=5, height=5, cmap='Blues', label='Frequency', vmin=0, vmax=1)

        # amino-acid annotation for the retained codons, in heatmap column order
        genetic_code = genetic_code_full.loc[stack_data_thresh.index]

        # Amino-acid groups in a fixed order (the order of AA_1). Iterating a
        # set of strings here made the column order depend on Python's per-
        # process hash seed, so figures differed between kernel restarts.
        aa_present = set(genetic_code.AminoAcid_1)
        aa_order = [aa for aa in AA_1 if aa in aa_present]
        colors_aa_c = [colors_aa[AA_1.index(aa)] for aa in aa_order]

        h.group_cols(group=genetic_code.AminoAcid_1, order=aa_order, spacing=0.002)
        h.add_top(ma.plotter.Chunk(aa_order, colors_aa_c, fontsize=15), pad=0.025)
        h.add_bottom(ma.plotter.Labels(list(stack_data_thresh.index), fontsize=10, rotation=45), name='Codon')

        # Motifs per position that carry one of the retained codons: retained
        # share of the position (percent / 100) times its non-gap motif count,
        # then log(x + 1) for display.
        num_motifs_list = (stack_data_t.sum(axis=0).values / 100) * non_dash_counts
        num_motifs_list = np.log(num_motifs_list + 1)

        # colour strip to the right of the heatmap, one cell per position
        cm = ma.plotter.ColorMesh(num_motifs_list.reshape(1, -1), cmap='Reds', vmin=0, vmax=np.log(max_motif_val), label='Num. Motifs (log)', label_props={'color': 'white', 'fontsize': 0})
        h.add_right(cm, pad=0.1, size=0.2)

        # human-readable condition name for the title
        c_text = {'LEU_ILE': 'LEU + ILE', 'LEU_ILE_VAL': 'LEU + ILE + VAL'}.get(c, c)
        h.add_title('Codon frequencies in motifs for ' + c_text, fontsize=20, pad=0.1)

        h.add_left(ma.plotter.Labels(list(pos_labels_list), fontsize=15))
        h.add_legends(pad=0.025)
        h.render()

        plt.savefig('plots/motifswAF_addStall_1000_HeatMap_' + c + '.pdf', dpi=400, transparent=True)
        plt.show()
        # release the figure so six large canvases do not accumulate in memory
        plt.close()