In [2]:
%load_ext autoreload
%autoreload 2

import tensorflow as tf
import torchvision
import torch
import lightning.pytorch as pl

from sklearn.decomposition import PCA
from src.model.full_model import SubCellProtModel
from src.utils.data_handling_utils import initialize_datasets, Retrieval_Data
from src.utils.batch_run_utils import batch_call, get_cell_lines_of_interest, get_isoforms_of_interest, get_proteoform_data
from src.analysis.shapely_analysis import shapley_analysis_sliding_kernel_explainer
import numpy as np
from matplotlib import pyplot as plt
from src.dataset.dataset import CLASSES
from tqdm import tqdm
import pandas as pd
import os
from enum import Enum
import pdb
import pickle as pk
import math
from numpy import savetxt



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


0it [00:00, ?it/s]

## Get image IDs and splice-isoform IDs for a joint-embedding investigation

In [4]:
# Dataset collections: the main one (training + holdout 1) and the random one (holdout 2).
COLLECTION_NAME = "splice_isoform_dataset_cell_line_and_gene_split_full"
RANDOM_COLLECTION_NAME = "random_splice_isoform_dataset"

# Trained checkpoint used for every prediction and SHAP run in this notebook.
MODEL_CHECKPOINT = (
    'checkpoints/splice_isoform_dataset_cell_line_and_gene_split_full'
    '-epoch=01-val_combined_loss=0.18.ckpt'
)

# --- SHAP settings ---
# Residues per attribution window; also used as bar width / x spacing when plotting.
SLIDING_KERNEL_SIZE = 1
# `nsamples` handed to the SHAP kernel explainer (samples per baseline protein).
NSAMPLES = 10_000
# Cap on individual baseline proteins (only used when baselines are not averaged).
NUM_BASELINE_PROTEINS = 300
# Number of isoforms requested when gathering baseline candidates.
TOTAL_SAMPLED_PROTEOFORMS = 10_000

In [5]:
# Main collection (genes in training / holdout 1).
(
    train_dataset,
    val_dataset,
    test_dataset,
    get_data,
) = initialize_datasets(COLLECTION_NAME, if_alphabetical=True)

# Random collection (genes in holdout 2).
(
    random_train_dataset,
    random_val_dataset,
    random_test_dataset,
    random_get_data,
) = initialize_datasets(RANDOM_COLLECTION_NAME, if_alphabetical=False)

In [6]:
# `load_from_checkpoint` is a classmethod that builds and returns a new model, so call it on the
# class. Calling it on a throwaway `SubCellProtModel()` instance needlessly constructs an extra
# model, and recent Lightning versions raise a TypeError for instance calls.
loaded_model = SubCellProtModel.load_from_checkpoint(
    MODEL_CHECKPOINT,
    collection_name=COLLECTION_NAME,
    batch_size=32,
)
# This notebook only runs inference (direct predict_step calls and SHAP), so switch off
# train-only behaviour such as dropout. The trailing `;` keeps Jupyter from printing the module repr.
loaded_model.eval();


## How to look up specific genes:
1. First find the gene you want from the HPA website.
2. Copy the ID from the URL e.g. https://www.proteinatlas.org/ENSG00000124608-AARS2
3. Figure out which of our two datasets contains the gene, then use the matching settings:

For genes in the training or holdout 1 dataset:
* use_old_hpa_client=True 
* and use get_data()    

Then for genes in the holdout 2 dataset:
* use_old_hpa_client=False 
* and use random_get_data()

In [None]:
# Genes to explain, as (metadata, data_source) pairs. Each gene is looked up in the collection
# that holds it: `get_data` for training / holdout 1, `random_get_data` for holdout 2.
proteoforms_4_shap = [
    # --- Mitochondria dataset samples ---

    # AARS2 (ENSG00000124608)
    (
        get_proteoform_data(
            COLLECTION_NAME, get_data,
            gene_id="ENSG00000124608", use_old_hpa_client=True,
        ),
        get_data,
    ),
    # N4BP2
    (
        get_proteoform_data(
            RANDOM_COLLECTION_NAME, random_get_data,
            gene_id='ENSG00000078177', use_old_hpa_client=False,
        ),
        random_get_data,
    ),
    # DDIT3
    (
        get_proteoform_data(
            COLLECTION_NAME, get_data,
            gene_id='ENSG00000175197', use_old_hpa_client=True,
        ),
        get_data,
    ),
]

# Attach each isoform's sequence length and loader, then order isoforms shortest-first.
proteoforms_4_shap_updated = []
for metadata, data_source in proteoforms_4_shap:
    isoform_id = metadata['_id']
    _X_investigation, x_len_investigation = data_source(
        isoform_id, retrieval_data=Retrieval_Data.PROTEIN_SEQ
    )
    proteoforms_4_shap_updated.append((isoform_id, x_len_investigation, data_source))
proteoforms_4_shap_updated.sort(key=lambda entry: entry[1])

# Parallel lists: isoform ids and the loader each one must be fetched with.
proteoforms_4_shap = [isoform_id for isoform_id, _length, _source in proteoforms_4_shap_updated]
data_sources = [source for _id, _length, source in proteoforms_4_shap_updated]
proteoforms_4_shap


If you only want to re-plot existing results, comment out the baseline-gathering cell below and the loop that calls `average_shap_analysis`.

In [None]:
# NOTE(review): only `get_isoforms_of_interest` is imported at the top of this notebook, so
# `get_isoforms_of_interest_new` raised NameError on a fresh kernel; it only worked from earlier
# session state. It is assumed to live next to its sibling in batch_run_utils — confirm.
from src.utils.batch_run_utils import get_isoforms_of_interest_new

# Each returned entry is later split on spaces and field [1] is used as the isoform id
# (see average_shap_analysis). Isoforms shorter than the SHAP target are dropped there.
print("gather baseline isoforms")
baseline_isoforms = get_isoforms_of_interest_new(
    collection_name=COLLECTION_NAME,
    total_investigated_isoforms=TOTAL_SAMPLED_PROTEOFORMS,
    get_data=get_data,
    seed=0,
    use_old_hpa_client=True,
)

Global seed set to 0


gather baseline isoforms


 17%|██████████▊                                                      | 1570/9472 [00:53<04:36, 28.57it/s]

In [45]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
import matplotlib

# Importing scienceplots registers its 'science' styles with matplotlib (SciencePlots >= 2.0).
try:
    import scienceplots  # noqa: F401
except ImportError:
    print("scienceplots is not installed. Please install it using `pip install SciencePlots`.")

# Previously a missing SciencePlots still reached plt.style.use and crashed the cell;
# fall back to matplotlib's default style instead.
try:
    plt.style.use(['science', 'no-latex'])
except OSError:
    print("SciencePlots styles unavailable; using the default matplotlib style.")

def _centered_convolve(values, kernel):
    """Convolve ``values`` with ``kernel``, returning ``len(values)`` points aligned with the input.

    The ``mode="full"`` result is sliced from ``(len(kernel) - 1) // 2``. For an odd-length kernel,
    output ``i`` is then centred on input ``i`` (exactly; approximately for even lengths). The
    length stays ``len(values)`` even when the kernel is longer than the signal.
    """
    full = np.convolve(values, kernel, mode="full")
    start = (len(kernel) - 1) // 2
    return full[start:start + len(values)]


def _savgol_smooth(values, window_length=51, polyorder=2):
    """Savitzky-Golay smoothing that still works for inputs shorter than ``window_length``.

    ``savgol_filter``'s default ``mode="interp"`` rejects windows longer than the input, so the
    window is shrunk to the largest odd length that fits. If that window cannot exceed
    ``polyorder``, the values are returned unsmoothed.
    """
    window = min(window_length, len(values))
    if window % 2 == 0:
        window -= 1
    if window <= polyorder:
        return np.asarray(values, dtype=float)
    return savgol_filter(values, window_length=window, polyorder=polyorder)


def viz_shapely_helper(
    TARGET_ISOFORM_FOR_SHAP,
    compartment_name,
    x_len_investigation,
    ax,
    sigma=100,
    x_axis_title="Residue Index",
    aggregate_type="savgol_filter",
):
    """Plot the saved per-residue SHAP values of one compartment onto ``ax``.

    Reads the CSV written by ``perform_shap_analysis`` for this isoform/compartment. It then draws
    the raw values as bars with a smoothed curve (labelled ``compartment_name``) on top.

    Args:
        TARGET_ISOFORM_FOR_SHAP: isoform id used in the CSV filename.
        compartment_name: compartment to plot; must be an entry of ``CLASSES[0]``.
        x_len_investigation: sequence length; at most ``x_len / SLIDING_KERNEL_SIZE`` values
            are plotted.
        ax: matplotlib Axes to draw into.
        sigma: Gaussian std-dev, in array elements, for ``aggregate_type="sliding_gaussian"``.
        x_axis_title: x-axis label.
        aggregate_type: smoothing method: ``"savgol_filter"`` (default), ``"sliding_gaussian"``
            or ``"sliding_ones"`` (20-point moving average).

    Raises:
        ValueError: if ``compartment_name`` is not in ``CLASSES[0]`` or ``aggregate_type`` is
            not one of the supported methods.
    """
    # Model outputs are ordered like sorted(CLASSES[0]), while perform_shap_analysis names each
    # CSV with CLASSES[0][output_index] (unsorted). Rebuild that same mismatched name so we open
    # the file that actually holds this compartment's values.
    fixed_compartment_idx = sorted(CLASSES[0]).index(compartment_name)
    bad_compartment_name = CLASSES[0][fixed_compartment_idx]
    filename = f"high_fidelity_joint_background_{TARGET_ISOFORM_FOR_SHAP}_{bad_compartment_name}_({fixed_compartment_idx}).csv"
    print(filename)
    shap_values = pd.read_csv(filename, header=None).to_numpy()[0]

    if aggregate_type == "sliding_gaussian":
        radius = int(np.ceil(sigma)) * 3
        offsets = np.arange(-radius, radius + 1, dtype=np.float32)
        kernel = np.exp(-(offsets ** 2) / (2.0 * sigma ** 2))
        # Centred convolution: mode="full" alone shifted the curve right by `radius` residues.
        sliding_window = _centered_convolve(shap_values, kernel / np.sum(kernel))
    elif aggregate_type == "sliding_ones":
        sliding_window = _centered_convolve(shap_values, np.ones(20) / 20)
    elif aggregate_type == "savgol_filter":
        sliding_window = _savgol_smooth(shap_values, window_length=51, polyorder=2)
    else:
        raise ValueError(
            f"Unknown aggregate_type {aggregate_type!r}; expected "
            "'savgol_filter', 'sliding_gaussian' or 'sliding_ones'"
        )

    # Never plot past the saved values: bar heights and line must match x_vals in length.
    target_x_len = min(int(x_len_investigation / SLIDING_KERNEL_SIZE), len(shap_values))
    x_vals = [SLIDING_KERNEL_SIZE * x for x in range(target_x_len)]

    # (bar colour, line colour) per compartment; anything else is drawn in blue.
    background_color, foreground_color = {
        "nucleoplasm": ('pink', 'red'),
        "cytosol": ('peachpuff', 'orange'),
    }.get(compartment_name.lower(), ('lightblue', 'blue'))

    # Raw per-window SHAP values as bars...
    ax.bar(x_vals, shap_values[:target_x_len], width=SLIDING_KERNEL_SIZE, color=background_color, alpha=0.7)
    # ...and the smoothed trend on top.
    ax.plot(
        x_vals,
        sliding_window[:target_x_len],
        color=foreground_color,
        linewidth=2.5,
        label=compartment_name,
    )

    ax.set_title(compartment_name)
    ax.set_xlabel(x_axis_title)

def viz_shapely(
    TARGET_ISOFORM_FOR_SHAP,
    X_investigation,
    x_len_investigation,
    X_landmark_stains,
    compartments=CLASSES[0],
    num_cols=5,
    get_data=get_data
):
    """Overlay the saved SHAP profiles of ``compartments`` for one isoform and save them as a PDF.

    Also runs ``loaded_model`` on the isoform so the title can show true vs. predicted locations.

    Args:
        TARGET_ISOFORM_FOR_SHAP: isoform id, used for the CSV lookup and the metadata lookup.
        X_investigation: the isoform's encoded sequence for one sample (a batch dim is added here).
        x_len_investigation: sequence length of the isoform.
        X_landmark_stains: cell-image input for the same sample (a batch dim is added here).
        compartments: compartment names to overlay on the single Axes.
        num_cols: unused; kept for backward compatibility.
        get_data: loader for the collection the isoform belongs to (used for metadata).
    """
    fig, ax = plt.subplots(figsize=(4, 1.5))
    fig.tight_layout()  # Adjust spacing between subplots

    for compartment in compartments:
        viz_shapely_helper(
            TARGET_ISOFORM_FOR_SHAP=TARGET_ISOFORM_FOR_SHAP,
            compartment_name=compartment,
            x_len_investigation=x_len_investigation,
            ax=ax,
            sigma=0.5,
            x_axis_title="Residue Index",
        )
    # Every compartment shares this one Axes, so the helper's per-compartment title would only
    # show the last one. Name them all, and add a legend when the lines carry labels.
    ax.set_title(", ".join(compartments))
    handles, labels = ax.get_legend_handles_labels()
    if labels:
        ax.legend(handles, labels, fontsize="small")

    # predict_step is called directly (not through a Trainer), so set eval mode and disable
    # autograd ourselves.
    loaded_model.eval()
    with torch.no_grad():
        (
            _y_pred_antibody_stain,
            y_pred_multilabel,
            _y_pred_multilabel_raw,
        ) = loaded_model.predict_step(
            (
                X_investigation.unsqueeze(0),
                torch.Tensor([x_len_investigation]),
                torch.Tensor(X_landmark_stains).unsqueeze(0),
                None,
                None,
            ),
            batch_idx=0,
        )
    # Model outputs are ordered like sorted(CLASSES[0]).
    assert len(sorted(CLASSES[0])) == len(y_pred_multilabel[0])
    pred_location_labels = [
        compartment
        for compartment, y_pred in zip(sorted(CLASSES[0]), y_pred_multilabel[0])
        if y_pred
    ]
    metadata = get_data(TARGET_ISOFORM_FOR_SHAP, retrieval_data=Retrieval_Data.METADATA)

    fig.suptitle(
        f"Shapely Analysis [Averaged] for {metadata['splice_isoform_id']} \n(True: {metadata['location_labels']}, Pred: {pred_location_labels})",
        y=1.98,
    )
    fig.subplots_adjust(top=1.5, left=0.10)  # Increase the top spacing

    # top=1.5 and y=1.98 place the axes and the title above the figure's nominal top edge. A
    # "tight" bounding box keeps them in the saved PDF instead of clipping them off.
    fig.savefig(
        f"Shapely Analysis [Averaged] for {metadata['splice_isoform_id']} (True: {metadata['location_labels']}, Pred: {pred_location_labels})"+'.pdf',
        bbox_inches="tight",
    )


In [13]:
def perform_shap_analysis(
    TARGET_ISOFORM_FOR_SHAP,
    model,
    baseline_Xs,
    target_X,
    target_x_lens,
    target_X_landmark_stains,
    nsamples=200,
    kernel_size=1,
):
    """Run the sliding-kernel SHAP explainer for one isoform and save one CSV per compartment.

    Files are named
    ``high_fidelity_joint_background_{isoform}_{CLASSES[0][i]}_({i}).csv``, where ``i`` indexes the
    explainer's per-output SHAP arrays. If every such file already exists, nothing is recomputed.

    Args:
        TARGET_ISOFORM_FOR_SHAP: isoform id used in the output filenames.
        model: model to explain; switched to eval mode here.
        baseline_Xs: background inputs for the explainer.
        target_X: input to explain; must have length 1 (a single protein).
        target_x_lens: sequence length(s) of ``target_X``.
        target_X_landmark_stains: cell-image input paired with ``target_X``.
        nsamples: ``nsamples`` forwarded to the explainer.
        kernel_size: sliding-window width forwarded to the explainer.

    Returns:
        The explainer's expected values (``res[1]``), or ``None`` when cached CSVs were reused.
    """
    assert len(target_X) == 1, (
        "Can only compute single protein target"
    )
    model.eval()

    def _savename(compartment_idx):
        # Must stay in sync with the filename viz_shapely_helper reads.
        return f"high_fidelity_joint_background_{TARGET_ISOFORM_FOR_SHAP}_{CLASSES[0][compartment_idx]}_({compartment_idx}).csv"

    # Files are written one at a time, so only skip when *all* of them exist; checking just the
    # first one would treat an interrupted run as finished. Assumes one SHAP array per entry of
    # CLASSES[0] (viz_shapely asserts the model has that many outputs).
    if all(os.path.exists(_savename(idx)) for idx in range(len(CLASSES[0]))):
        return None

    res = shapley_analysis_sliding_kernel_explainer(
        model=model,
        baseline_Xs=baseline_Xs,
        target_X=target_X,
        target_x_lens=target_x_lens,
        target_X_landmark_stains=target_X_landmark_stains,
        kernel_size=kernel_size,
        nsamples=nsamples,
    )

    compartment_shap_vals, expected_vals = res[0], res[1]
    print("expected vals for the different compartments: ", expected_vals)
    for compartment_idx, comp_shap in enumerate(compartment_shap_vals):
        savetxt(_savename(compartment_idx), comp_shap, delimiter=",")
    return expected_vals


In [14]:
def average_shap_analysis(
    TARGET_ISOFORM_FOR_SHAP, 
    TARGET_CELL_IMAGE_FOR_SHAP, 
    baseline_isoforms, 
    data_source=None,
    run_just_one_average_background=True,
):
    """Compute (or reuse cached) SHAP values for one isoform against a baseline, then plot them.

    Each entry of ``baseline_isoforms`` is a space-separated string whose second field is an
    isoform id. Only baseline isoforms at least as long as the target are kept. By default they
    are averaged into a single background; otherwise up to ``NUM_BASELINE_PROTEINS`` of them are
    used individually.

    Args:
        TARGET_ISOFORM_FOR_SHAP: id of the isoform whose sequence is explained.
        TARGET_CELL_IMAGE_FOR_SHAP: id used to fetch the paired cell image.
        baseline_isoforms: candidate baseline entries (see above).
        data_source: loader for the target's collection; defaults to the global ``get_data``.
        run_just_one_average_background: average all baselines into one background if True.

    Returns:
        Expected values from ``perform_shap_analysis`` (``None`` when cached CSVs were reused).

    Raises:
        ValueError: if no baseline isoform passes the length filter.
    """
    if data_source is None:
        data_source = get_data

    X_investigation, x_len_investigation = data_source(
        TARGET_ISOFORM_FOR_SHAP, retrieval_data=Retrieval_Data.PROTEIN_SEQ
    )
    X_landmark_stains = data_source(
        TARGET_CELL_IMAGE_FOR_SHAP, retrieval_data=Retrieval_Data.CELL_IMAGE
    )

    filtered_baseline_isoforms = []
    full_baseline_isoforms = []
    averaged_baseline_X = torch.zeros(X_investigation.shape)

    for isoform in tqdm(baseline_isoforms):
        fields = isoform.split(" ")
        if len(fields) < 2:
            continue
        isoform_id = fields[1]
        # Baselines are gathered from the main collection, so they are always fetched with the
        # global `get_data`, whatever `data_source` the target uses.
        X, x_len = get_data(isoform_id, retrieval_data=Retrieval_Data.PROTEIN_SEQ)
        if x_len < x_len_investigation:
            # Skipping all isoforms which are smaller than our target isoform in our baseline
            # Because otherwise we can't "replace" the amino acid at end of our target sequence
            # with a residue from the baseline isoform.
            continue
        averaged_baseline_X += X
        filtered_baseline_isoforms.append(isoform_id)
        if not run_just_one_average_background:
            full_baseline_isoforms.append(X)

    # Without this check the average divides by zero (NaN background) or np.stack([]) fails.
    if not filtered_baseline_isoforms:
        raise ValueError(
            f"No baseline isoform is at least as long as {TARGET_ISOFORM_FOR_SHAP} "
            f"({x_len_investigation} residues)"
        )

    if not run_just_one_average_background:
        full_baseline_isoforms = np.stack(full_baseline_isoforms)[:NUM_BASELINE_PROTEINS]
        print(f"total number of isoforms: {len(full_baseline_isoforms)}")

    print(
        f"total samples in our baseline (that pass the length check): {len(filtered_baseline_isoforms)}"
    )
    print(filtered_baseline_isoforms)

    if run_just_one_average_background:
        averaged_baseline_X /= len(filtered_baseline_isoforms)
        full_baseline_isoforms = np.stack([averaged_baseline_X])

    expected_vals = perform_shap_analysis(
        TARGET_ISOFORM_FOR_SHAP=TARGET_ISOFORM_FOR_SHAP,
        model=loaded_model,
        baseline_Xs=full_baseline_isoforms,
        target_X=X_investigation,
        target_x_lens=[x_len_investigation],
        target_X_landmark_stains=X_landmark_stains,
        nsamples=NSAMPLES,
        kernel_size=SLIDING_KERNEL_SIZE,
    )

    # Metadata must come from the same collection as the target, not always the main one.
    viz_shapely(
        TARGET_ISOFORM_FOR_SHAP=TARGET_ISOFORM_FOR_SHAP,
        X_investigation=X_investigation,
        x_len_investigation=x_len_investigation,
        X_landmark_stains=X_landmark_stains,
        compartments=["Nucleoplasm", "Cytosol", "Mitochondria"],
        get_data=data_source,
    )
    return expected_vals



In [None]:
# Run (or reuse cached) SHAP for every proteoform, loading each from the collection it came
# from. `data_sources` is parallel to `proteoforms_4_shap`.
for proteoform_id, data_source in zip(proteoforms_4_shap, data_sources):
    average_shap_analysis(
        proteoform_id, proteoform_id, baseline_isoforms, data_source=data_source
    )

In [None]:
# Re-plot saved SHAP results for every proteoform. Inputs and metadata are fetched with the
# loader recorded for each one in `data_sources`, so holdout-2 genes are not looked up in the
# main collection.
for proteoform_id, data_source in zip(proteoforms_4_shap, data_sources):
    X_investigation, x_len_investigation = data_source(
        proteoform_id, retrieval_data=Retrieval_Data.PROTEIN_SEQ
    )
    X_landmark_stains = data_source(
        proteoform_id, retrieval_data=Retrieval_Data.CELL_IMAGE
    )

    viz_shapely(
        TARGET_ISOFORM_FOR_SHAP=proteoform_id,
        X_investigation=X_investigation,
        x_len_investigation=x_len_investigation,
        X_landmark_stains=X_landmark_stains,
        compartments=["Nucleoplasm", "Cytosol", "Mitochondria"],
        get_data=data_source,
    )

In [None]:
# Re-plot only the holdout-2 proteoforms, i.e. those loaded through `random_get_data`. The other
# genes were looked up in the main collection with `get_data` and are skipped here.
for proteoform_id, data_source in zip(proteoforms_4_shap, data_sources):
    if data_source is not random_get_data:
        continue
    X_investigation, x_len_investigation = random_get_data(
        proteoform_id, retrieval_data=Retrieval_Data.PROTEIN_SEQ
    )
    X_landmark_stains = random_get_data(
        proteoform_id, retrieval_data=Retrieval_Data.CELL_IMAGE
    )

    viz_shapely(
        TARGET_ISOFORM_FOR_SHAP=proteoform_id,
        X_investigation=X_investigation,
        x_len_investigation=x_len_investigation,
        X_landmark_stains=X_landmark_stains,
        compartments=["Nucleoplasm", "Cytosol", "Mitochondria"],
        get_data=random_get_data,
    )

In [47]:
# NOTE(review): leftover inspection of a global last assigned by whichever earlier cell ran most
# recently (hidden state); its value is not reproducible under Restart & Run All — consider removing.
x_len_investigation

1772.0