In [None]:
import h5py 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scienceplots
from pyhere import here
from tqdm import tqdm

In [None]:
# HDF5 files holding per-sample attributions for the train, validation and
# test splits, located under data/results/interpretability. All three splits
# are pooled into a single histogram.
_INT_DIR = ('data', 'results', 'interpretability')
_INT_FILES = (
    '241022_RDHPLG_train_int.h5',
    '241029_RDHPLG_val_int.h5',
    '241001_RDHPLG_test_int.h5',
)


def _top_k_offsets(lig_key, top_k=5):
    """Collect codon offsets of the top-k attributed codons across all splits.

    For every sample ``i``, ``f[lig_key][i]`` is reshaped into an ``(n, n)``
    matrix with ``n = len(f['x_input'][i]) - 1``. Row ``j`` is treated as the
    codon of interest; the ``top_k`` columns with the largest absolute
    attribution in that row are selected and their offsets ``column - j`` are
    collected.

    Args:
        lig_key: Name of the attribution dataset (``'lig_ctrl'`` or ``'lig_dd'``).
        top_k: Number of highest-attribution codons taken per row. A row with
            fewer than ``top_k`` codons contributes all of its codons.

    Returns:
        1-D integer array of offsets pooled over all rows, samples and splits.
    """
    chunks = []
    for name in _INT_FILES:
        # `with` closes the file even if a sample fails part-way through.
        with h5py.File(here(*_INT_DIR, name), 'r') as f:
            x_input = f['x_input']
            lig = f[lig_key]
            for i in tqdm(range(len(x_input))):
                # NOTE(review): the "- 1" presumably drops a non-codon token
                # from x_input; confirm against the model's input format.
                n = len(x_input[i]) - 1
                attr = np.abs(lig[i].reshape(n, n))
                # Normalising a row by a positive constant cannot change the
                # ranking of its |attribution| values, so no normalisation is
                # applied. Rows whose total is zero (or NaN) have no
                # meaningful top-k and are excluded instead of producing
                # NaN-ranked, arbitrary offsets.
                valid = attr.sum(axis=1) > 0
                # Guard against samples shorter than top_k codons.
                k = min(top_k, n)
                # Column indices of the k largest |attribution| values per row
                # (slice from n - k so that k == 0 selects nothing).
                top = np.argsort(attr, axis=1)[:, n - k:]
                offsets = top - np.arange(n)[:, None]
                chunks.append(offsets[valid].ravel())
    if not chunks:
        return np.empty(0, dtype=int)
    return np.concatenate(chunks)


def _plot_offset_hist(ax, offsets, color, title):
    """Draw the histogram of codon offsets in [-10, 10] onto ``ax``.

    Args:
        ax: Matplotlib Axes to draw on.
        offsets: Integer codon offsets, as returned by ``_top_k_offsets``.
        color: Bar fill colour.
        title: Axes title.
    """
    # Unit-width bins with half-integer edges: every integer offset in
    # [-10, 10] gets its own bar centred exactly on its tick, and with
    # density=True each bar's height is the fraction of in-range offsets at
    # that distance.
    edges = np.arange(-10.5, 11.0)
    ax.hist(offsets, bins=edges, color=color, edgecolor='#ffffff', linewidth=1, density=True)
    ax.axvline(0, color='black', linestyle='--', linewidth=1)
    # Offsets -2, -1 and 0 are labelled as the E, P and A sites.
    ax.set_xticks([-10, -5, -2, -1, 0, 5, 10], [-10, -5, 'E', 'P', 'A', 5, 10])
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.set_title(title, fontsize=16)
    ax.set_xlabel('Codon Distance from A-site', fontsize=14)
    ax.set_ylabel('Frequency', fontsize=14)


def globCTRLPlot(ax):
    """Plot where the CTRL head's top-5 attributed codons lie relative to each codon.

    Pools the ``lig_ctrl`` attributions from the train, validation and test
    interpretability files and draws the offset histogram on ``ax``.

    Args:
        ax: Matplotlib Axes to draw on.
    """
    _plot_offset_hist(ax, _top_k_offsets('lig_ctrl'), '#2ecc71', 'CTRL Head')


def globDDPlot(ax):
    """Plot where the Difference head's top-5 attributed codons lie relative to each codon.

    Pools the ``lig_dd`` attributions from the train, validation and test
    interpretability files and draws the offset histogram on ``ax``.

    Args:
        ax: Matplotlib Axes to draw on.
    """
    _plot_offset_hist(ax, _top_k_offsets('lig_dd'), '#e74c3c', 'Difference Head')