In [None]:
import h5py 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scienceplots
import matplotlib
import itertools
from adjustText import adjust_text
from scipy.stats import pearsonr
from pyhere import here
plt.style.use('nature')

from tqdm import tqdm

In [None]:
def _empty_codon_lists(conditions, codons):
    '''Build {condition: {codon: []}} with an independent list per (condition, codon) pair.'''
    return {condition: {codon: [] for codon in codons} for condition in conditions}


def _mean_per_codon(codon_values):
    '''Reduce {condition: {codon: [values]}} to {condition: {codon: mean}}.

    Codons with no collected values map to NaN directly, avoiding the
    "mean of empty slice" RuntimeWarning that np.mean([]) emits.
    '''
    return {
        condition: {codon: (float(np.mean(values)) if values else np.nan) for codon, values in per_codon.items()}
        for condition, per_codon in codon_values.items()
    }


def _collect_stall(files, id_to_codon, codon_stall):
    '''Append max-normalised ribosome counts to codon_stall[condition][codon].

    Per sample, counts are divided by their nanmax; NaN and zero positions are
    skipped, as are codons that are not keys of codon_stall (the stop codons).
    '''
    for f in files:
        for i in tqdm(range(len(f['condition']))):
            condition = f['condition'][i].decode('utf-8')
            # CTRL samples read 'y_true_full'; every other condition reads 'y_true_dd'.
            counts = f['y_true_full'][i] if condition == 'CTRL' else f['y_true_dd'][i]
            # Token 0 is dropped; presumably the condition token (ids 64-69) - TODO confirm.
            tokens = f['x_input'][i][1:]
            counts_norm = counts / np.nanmax(counts)
            per_codon = codon_stall[condition]
            for j in range(len(counts_norm)):
                # Reject NaN/zero positions before decoding their token.
                if np.isnan(counts[j]) or counts[j] == 0.0:
                    continue
                codon = id_to_codon[int(tokens[j])]
                if codon in per_codon:
                    per_codon[codon].append(counts_norm[j])


def _collect_attr(files, id_to_codon, codon_attr, a_site_mode, window_size, num_peaks):
    '''Append windowed LIG attributions to codon_attr[condition][codon].

    For each chosen A-site a, row a of the (num_codons x num_codons) attribution
    matrix is sliced to positions [a - window_size, a + window_size]. Windows
    that would run past either end of the sequence are skipped. Each window is
    scaled by its max |value| and every entry is credited to the codon at that
    position.

    a_site_mode 'full' uses every position as an A-site; 'peaks' uses the
    num_peaks positions with the highest 'y_true_full' counts.
    '''
    for f in files:
        for i in tqdm(range(len(f['condition']))):
            condition = f['condition'][i].decode('utf-8')
            tokens = f['x_input'][i][1:]
            num_codons = len(tokens)
            counts = f['y_true_full'][i]
            attr_key = 'lig_ctrl' if condition == 'CTRL' else 'lig_dd'
            attr = f[attr_key][i].reshape(num_codons, num_codons)
            if a_site_mode == 'full':
                a_sites = range(len(counts))
            else:
                # argsort of the negated counts gives descending order (NaNs sort last).
                a_sites = np.argsort(-counts)[:num_peaks]
            per_codon = codon_attr[condition]
            for a_site in a_sites:
                start = a_site - window_size
                end = a_site + window_size + 1
                if start < 0 or end > num_codons:
                    continue  # incomplete window at a sequence edge
                window = attr[a_site][start:end]
                scale = np.max(np.abs(window))
                if not np.isfinite(scale) or scale == 0:
                    # Dividing would put NaN into every codon mean of this condition.
                    continue
                window_norm = window / scale
                for token, value in zip(tokens[start:end], window_norm):
                    codon = id_to_codon[int(token)]
                    if codon in per_codon:
                        per_codon[codon].append(value)


def globStalling(axs, mode='peaks', window_size=20, num_peaks=10):
    '''Scatter per-codon mean ribosome counts against mean attribution, one panel per condition.

    Reads the val/test/train interpretability HDF5 files under
    data/results/interpretability and data/genetic_code.csv. For every
    (condition, sense codon) it aggregates:
      * x: mean of per-sample max-normalised, non-zero ribosome counts;
      * y: mean max-|value|-normalised LIG attribution the codon receives
        inside +/- window_size codon windows around A-sites.
    Each panel title reports the Pearson correlation of x and y, and a
    least-squares line is overlaid. The figure is displayed with plt.show().

    Parameters
    ----------
    axs : sequence of matplotlib Axes
        At least six axes, filled in the order CTRL, ILE, LEU, VAL, LEU_ILE,
        LEU_ILE_VAL. The legend is attached to the sixth one.
    mode : {'peaks', 'full'}
        'peaks' uses the num_peaks highest-count positions of each sample as
        A-sites; 'full' uses every position.
    window_size : int
        Half-width, in codons, of the attribution window (default 20).
    num_peaks : int
        Number of A-sites taken per sample in 'peaks' mode (default 10).

    Raises
    ------
    ValueError
        If mode is not 'peaks' or 'full'. This is checked before any file is read.
    '''
    if mode not in ('peaks', 'full'):
        raise ValueError(f"mode must be 'peaks' or 'full', got {mode!r}")

    results_dir = ('data', 'results', 'interpretability')
    conditions = ['CTRL', 'ILE', 'LEU', 'LEU_ILE', 'LEU_ILE_VAL', 'VAL']
    # Codon ids 0-63 enumerate all triplets over A, T, C, G in itertools.product order.
    id_to_codon = {idx: ''.join(triplet) for idx, triplet in enumerate(itertools.product('ATCG', repeat=3))}
    stop_codons = {'TAA', 'TAG', 'TGA'}
    sense_codons = [codon for codon in id_to_codon.values() if codon not in stop_codons]

    codon_stall = _empty_codon_lists(conditions, sense_codons)
    codon_attr = _empty_codon_lists(conditions, sense_codons)
    # Context managers close the HDF5 handles even if aggregation raises.
    with h5py.File(here(*results_dir, '241029_RDHPLG_val_int.h5'), 'r') as f_val, \
            h5py.File(here(*results_dir, '241001_RDHPLG_test_int.h5'), 'r') as f_test, \
            h5py.File(here(*results_dir, '241022_RDHPLG_train_int.h5'), 'r') as f_train:
        splits = [f_val, f_test, f_train]
        _collect_stall(splits, id_to_codon, codon_stall)
        # Only the attribution pass needed for the requested mode is computed.
        _collect_attr(splits, id_to_codon, codon_attr, mode, window_size, num_peaks)

    stall_mean = _mean_per_codon(codon_stall)
    # Descending by mean stall value; this only affects the drawing order of points.
    stall_mean_sorted = {
        condition: dict(sorted(per_codon.items(), key=lambda item: item[1], reverse=True))
        for condition, per_codon in stall_mean.items()
    }
    attr_mean = _mean_per_codon(codon_attr)

    genetic_code = pd.read_csv(here('data', 'genetic_code.csv'))
    # Panel order and the 'AminoAcid' values whose codons count as deprived in each panel.
    panel_amino_acids = {'CTRL': ['CTRL'], 'ILE': ['Ile'], 'LEU': ['Leu'], 'VAL': ['Val'],
                         'LEU_ILE': [], 'LEU_ILE_VAL': []}
    depr_codons = {
        panel: genetic_code.loc[genetic_code['AminoAcid'].isin(amino_acids), 'Codon'].tolist()
        for panel, amino_acids in panel_amino_acids.items()
    }
    codons_to_depr = {
        codon: [panel for panel, codons in depr_codons.items() if codon in codons]
        for codon in id_to_codon.values()
    }

    colors_depr = {'CTRL': '#6EC207', 'VAL': '#FF204E', 'LEU': '#3498db', 'ILE': '#9b59b6'}
    ctrl_tagged_codons = ['GAC', 'GAA', 'GAT', 'GAG', 'GGA']
    ile_codons = ['ATC', 'ATT']
    val_codons = ['GTC', 'GTT', 'GTA', 'GTG']
    # Deprivation-codon points that receive a text label in each panel.
    depr_labelled = {'ILE': ile_codons, 'LEU_ILE': ile_codons, 'VAL': val_codons, 'LEU_ILE_VAL': val_codons}
    panel_titles = {'LEU_ILE': 'LEU + ILE', 'LEU_ILE_VAL': 'LEU + ILE + VAL'}
    legend_fontsize = 15 if mode == 'peaks' else 12

    with plt.style.context(['science', 'nature', 'grid', 'bright', 'no-latex']):
        for i, condition in enumerate(depr_codons):
            ax = axs[i]
            texts, xs, ys = [], [], []
            for codon, x in stall_mean_sorted[condition].items():
                y = attr_mean[condition][codon]
                if not (np.isfinite(x) and np.isfinite(y)):
                    continue  # codon never observed in this condition: nothing to draw or fit
                xs.append(x)
                ys.append(y)
                if codons_to_depr[codon]:
                    ax.scatter(x, y, label=codon, color=colors_depr[codons_to_depr[codon][0]], s=40)
                    annotate = codon in depr_labelled.get(condition, ())
                elif codon in ctrl_tagged_codons:
                    ax.scatter(x, y, label=codon, color=colors_depr['CTRL'], alpha=1, s=40)
                    annotate = condition == 'CTRL'
                else:
                    # Drawn at the true x (no offset) so the points agree with the fitted line.
                    ax.scatter(x, y, label=codon, color='black', alpha=0.5, s=10)
                    annotate = False
                if annotate:
                    texts.append(ax.text(x, y, codon, fontsize=20))

            if texts:
                adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle='->', color='red'),
                            expand_points=(1.2, 1.2), expand_text=(1.2, 1.2), force_text=(0.5, 0.5))

            corr, _ = pearsonr(xs, ys)
            fit = np.poly1d(np.polyfit(xs, ys, 1))
            # Plain '--' format: combining 'r--' with color='black' makes matplotlib warn.
            ax.plot(xs, fit(xs), '--', color='black', alpha=0.5)

            ax.set_xlabel('Mean Ribosome Counts', fontsize=15)
            ax.set_ylabel('Mean Attribution Value', fontsize=15)
            ax.set_title(panel_titles.get(condition, condition) + " (PCC: {:.2f})".format(corr), fontsize=20)
            ax.tick_params(axis='both', which='major', labelsize=15)

        legend_handles = [
            matplotlib.lines.Line2D([0], [0], marker='o', color='w', label=label,
                                    markerfacecolor=color, markersize=10)
            for label, color in [
                ('CTRL Tagged Codons', colors_depr['CTRL']),
                ('ILE Codons', colors_depr['ILE']),
                ('LEU Codons', colors_depr['LEU']),
                ('VAL Codons', colors_depr['VAL']),
                ('Other Codons', 'black'),
            ]
        ]
        # Attach the legend to the last panel explicitly rather than to pyplot's "current" Axes.
        axs[len(depr_codons) - 1].legend(handles=legend_handles, loc='upper left',
                                         bbox_to_anchor=(1.1, 1.1), fontsize=legend_fontsize)
        plt.show()
        

In [None]:
# Six panels (one per condition) on a 2x3 grid. The Axes come straight from
# plt.subplots, so no stray full-figure Axes is left underneath the grid
# (plt.subplot stopped removing overlapping Axes in Matplotlib 3.6).
with plt.style.context(['science', 'nature', 'grid', 'bright', 'no-latex']):
    fig, axs = plt.subplots(2, 3, figsize=(20, 20))
    globStalling(axs.ravel(), mode='full')