In [None]:
# Name of the project folder; it is joined onto the Drive projects path to form PROJECT_HOME.
PROJECT_NAME = "reverse-gene-finder"

In [None]:
import os
# Root of the project on Google Drive; the token dictionary, data, models and
# results used by this notebook are all resolved relative to this folder.
PROJECT_HOME = os.path.join("/content/drive/My Drive/Projects", PROJECT_NAME)

import sys
# Put the project root on the import path so the project's `libs` package can be imported.
sys.path.append(PROJECT_HOME)

In [None]:
# Google Drive storage setup
from google.colab import drive
# Mount Drive under /content/drive, the prefix that PROJECT_HOME is built on.
drive.mount('/content/drive', force_remount=True)

In [None]:
%pip install -U tdigest anndata scanpy loompy > /dev/null 2> /dev/null
%pip install -U transformers[torch] datasets > /dev/null 2> /dev/null

In [None]:
import pickle
import warnings

import torch
import numpy as np
import pandas as pd

from tqdm.auto import trange
from collections import defaultdict

from transformers import BertForSequenceClassification
from datasets import load_from_disk

from libs.classifier import Classifier
from libs.causal_trace import trace_important_states

In [None]:
# Silence RuntimeWarning and UserWarning messages for the rest of the notebook.
for _ignored_category in (RuntimeWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)

In [None]:
# https://huggingface.co/ctheodoris/Geneformer (Apache License 2.0)

def pad_tensor(tensor, pad_token_id, max_len):
    """Right-pad ``tensor`` with ``pad_token_id`` so it holds ``max_len`` elements.

    The number of appended values is ``max_len - tensor.numel()`` and they are
    added at the end of the last dimension, so this is meant for 1-D tensors.

    Args:
        tensor: tensor to pad.
        pad_token_id: constant value used for the added positions.
        max_len: target number of elements.

    Returns:
        The padded tensor.
    """
    n_missing = max_len - tensor.numel()
    return torch.nn.functional.pad(
        tensor, pad=(0, n_missing), mode="constant", value=pad_token_id
    )

def pad_2d_tensor(tensor, pad_token_id, max_len, dim):
    """Pad one dimension of a 2-D tensor with ``pad_token_id`` up to ``max_len``.

    Padding is appended at the end of the chosen dimension; the other
    dimension is left untouched.

    Args:
        tensor: 2-D tensor to pad.
        pad_token_id: constant value used for the added positions.
        max_len: target size of dimension ``dim``.
        dim: dimension to pad, either 0 or 1.

    Returns:
        The padded tensor.

    Raises:
        ValueError: if ``dim`` is neither 0 nor 1.
    """
    # Validate first: previously an unsupported dim left `pad` unassigned and the
    # function failed with an opaque UnboundLocalError.
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for a 2D tensor, got %r" % (dim,))

    n_missing = max_len - tensor.size()[dim]
    # torch.nn.functional.pad lists pad widths starting from the LAST dimension:
    # (last_left, last_right, first_left, first_right).
    if dim == 0:
        pad = (0, 0, 0, n_missing)
    else:
        pad = (0, n_missing, 0, 0)
    return torch.nn.functional.pad(
        tensor, pad=pad, mode="constant", value=pad_token_id
    )

def pad_3d_tensor(tensor, pad_token_id, max_len, dim):
    """Pad dimension 1 or 2 of a 3-D tensor with ``pad_token_id`` up to ``max_len``.

    Padding is appended at the end of the chosen dimension; the other
    dimensions are left untouched.

    Args:
        tensor: 3-D tensor to pad.
        pad_token_id: constant value used for the added positions.
        max_len: target size of dimension ``dim``.
        dim: dimension to pad, either 1 or 2.

    Returns:
        The padded tensor.

    Raises:
        Exception: if ``dim`` is 0, which this helper refuses to pad.
        ValueError: if ``dim`` is any other unsupported value.
    """
    if dim == 0:
        raise Exception("dim 0 usually does not need to be padded.")
    # Previously any other unsupported dim left `pad` unassigned and the function
    # failed with an opaque UnboundLocalError.
    if dim not in (1, 2):
        raise ValueError("dim must be 1 or 2 for a 3D tensor, got %r" % (dim,))

    n_missing = max_len - tensor.size()[dim]
    # torch.nn.functional.pad lists pad widths starting from the LAST dimension,
    # so the first pair addresses dim 2 and the second pair addresses dim 1.
    if dim == 1:
        pad = (0, 0, 0, n_missing)
    else:
        pad = (0, n_missing, 0, 0)
    return torch.nn.functional.pad(
        tensor, pad=pad, mode="constant", value=pad_token_id
    )

# pad list of tensors and convert to tensor
def pad_tensor_list(
    tensor_list,
    dynamic_or_constant,
    pad_token_id,
    model_input_size,
    dim=None,
    padding_func=None,
):
    """Pad every tensor in ``tensor_list`` to a common length and combine them.

    Args:
        tensor_list: iterable of tensors to pad.
        dynamic_or_constant: ``"dynamic"`` to use the largest squeezed element
            count found in ``tensor_list``, an ``int`` to use that fixed length,
            or anything else to fall back to ``model_input_size``.
        pad_token_id: constant value used for the added positions.
        model_input_size: target length used when ``dynamic_or_constant`` is
            neither ``"dynamic"`` nor an ``int``.
        dim: dimension handed to ``padding_func``; when ``None`` the tensors
            are padded with ``pad_tensor`` instead.
        padding_func: callable ``(tensor, pad_token_id, max_len, dim)`` used
            when ``dim`` is given.

    Returns:
        The padded tensors concatenated along dimension 0 when ``padding_func``
        is ``pad_3d_tensor``, otherwise stacked along a new leading dimension.
    """
    # Resolve the length every tensor is padded to.
    if dynamic_or_constant == "dynamic":
        target_len = max(t.squeeze().numel() for t in tensor_list)
    elif isinstance(dynamic_or_constant, int):
        target_len = dynamic_or_constant
    else:
        target_len = model_input_size

    # Pad each tensor with the helper that matches the requested mode.
    if dim is None:
        padded = [pad_tensor(t, pad_token_id, target_len) for t in tensor_list]
    else:
        padded = [padding_func(t, pad_token_id, target_len, dim) for t in tensor_list]

    # Tensors padded by pad_3d_tensor are joined along dimension 0;
    # everything else gets a new leading dimension.
    if padding_func == pad_3d_tensor:
        return torch.cat(padded, 0)
    return torch.stack(padded)

In [None]:
# Load gene token dictionary

token_dictionary_file = os.path.join(PROJECT_HOME, "libs", "token_dictionary.pkl")
# NOTE(review): pickle.load can execute arbitrary code; only load a dictionary file
# that comes from a trusted source.
with open(token_dictionary_file, "rb") as f:
    gene_token_dict = pickle.load(f)

# Reverse lookup: token id -> dictionary key.
token_gene_dict = dict(zip(gene_token_dict.values(), gene_token_dict.keys()))
# .get() yields None if the dictionary has no "<pad>" entry.
pad_token_id = gene_token_dict.get("<pad>")

In [None]:
# Load gene information

gene_info = pd.read_csv(os.path.join(PROJECT_HOME, "data", "gene_info.csv"))
# Two-way lookup between the 'gene_id' and 'gene_name' columns, filled in file
# row order (a later row overrides an earlier one that has the same key).
gene_id_to_name = dict(zip(gene_info['gene_id'], gene_info['gene_name']))
gene_name_to_id = dict(zip(gene_info['gene_name'], gene_info['gene_id']))

In [None]:
# Known AD-related genes

# http://www.alzgene.org/TopResults.asp
selected_genes = ["APOE", "BIN1", "CLU", "ABCA7", "CR1", "PICALM", "MS4A6A", "CD33", "MS4A4E", "CD2AP"]

# Resolve gene name -> gene id -> token id and keep only genes that resolve all
# the way. A gene missing from gene_info or from the token dictionary is reported
# instead of being dropped silently or stored as None, which would inflate the
# count printed below.
selected_token_ids = set()
unmapped_genes = []
for gene_name in selected_genes:
    gene_id = gene_name_to_id.get(gene_name)
    token_id = gene_token_dict.get(gene_id) if gene_id is not None else None
    if token_id is None:
        unmapped_genes.append(gene_name)
    else:
        selected_token_ids.add(token_id)

if unmapped_genes:
    print("Genes without a token id (skipped): %s" % ", ".join(unmapped_genes))
print("# of known AD genes: %d" % len(selected_token_ids))

In [None]:
# Causal tracing settings

max_seq_len = 256 # maximum sequence length
n_samples = 10 # 1 original sample plus 9 corrupted samples
noise = 1.0 # noise level when perturbing the input sample

# Identify the effects of input perturbations on the model's output

# One entry per retained test sample, accumulated over all CV folds.
# The three lists are appended together so they always stay index-aligned.
indirect_effects_list = []
attention_weights_list = []
input_ids_list = []

for CV_FOLD in range(5):

    # Locate the fine-tuned classifier and the held-out test set of this fold.
    model_dir = os.path.join(PROJECT_HOME, "models", "finetuned_models", "cv_%d" % CV_FOLD)
    model_prefix = "ad_cell_classifier"
    model_directory = f"{model_dir}/geneformer_cellClassifier_{model_prefix}/ksplit1/"
    test_data_file = f"{model_dir}/{model_prefix}_labeled_test.dataset"

    input_data = load_from_disk(test_data_file)

    model = BertForSequenceClassification.from_pretrained(
        model_directory,
        num_labels=len(['nonAD', 'earlyAD']),
        output_hidden_states=True,
        output_attentions=True,
        attn_implementation="eager"
    ).to("cuda")
    # Read the depth from the loaded checkpoint instead of hard-coding 12, so the
    # per-layer loop below always matches the number of returned attention maps.
    num_layers = model.config.num_hidden_layers

    total_batch_length = len(input_data)
    forward_batch_size = 1

    for i in trange(0, total_batch_length, forward_batch_size, leave=True, desc="CV Fold %s" % CV_FOLD):

        max_range = min(i + forward_batch_size, total_batch_length)
        minibatch = input_data.select(list(range(i, max_range)))
        minibatch.set_format(type="torch")

        input_data_minibatch = minibatch["input_ids"]
        input_data_minibatch = input_data_minibatch.to("cuda")
        input_ids = input_data_minibatch[0]
        if len(input_ids) > max_seq_len:
            # keep the high-rank and low-rank genes
            num_samples_one_side = max_seq_len // 2
            input_ids = torch.cat((input_ids[:num_samples_one_side], input_ids[-num_samples_one_side:]))
            input_data_minibatch = torch.stack([input_ids])

        # Pad to exactly max_seq_len tokens. The integer second argument selects the
        # target length; model_input_size is only a fallback and is given the same value.
        input_data_minibatch = pad_tensor_list(
            input_data_minibatch,
            max_seq_len,
            pad_token_id,
            model_input_size=max_seq_len,
        )

        # Positions in the (padded) sequence that hold a known AD gene token.
        input_ids = input_data_minibatch[0].detach().cpu()
        e_range = [
            token_idx
            for token_idx, token_id in enumerate(input_ids)
            if int(token_id) in selected_token_ids
        ]
        if len(e_range) == 0:
            continue # Skip sample with no known AD genes

        # Plain inference: no autograd graph is needed for the attention maps.
        # NOTE(review): the model is called without an attention_mask, so padded
        # positions are not masked out — confirm this is intended.
        with torch.no_grad():
            attention_weights = model(input_data_minibatch).attentions
        average_attention_weights_list = []
        for layer in range(num_layers):
            # Average attention weights across all heads
            # Shape: (seq_length, seq_length)
            average_attention_weights_at_one_layer = attention_weights[layer].mean(dim=1).squeeze(0)
            average_attention_weights_list.append(average_attention_weights_at_one_layer)
        average_attention_weights = torch.stack(average_attention_weights_list)
        average_attention_weights = average_attention_weights.detach().cpu()

        # Repeat the sample n_samples times along the batch dimension for tracing.
        inp = torch.cat([input_data_minibatch[:] for _ in range(n_samples)])
        # Per the original note the result has shape (seq_length, num_layers) —
        # TODO confirm against libs.causal_trace.
        indirect_effects = trace_important_states(model, num_layers, inp, e_range=e_range, noise=noise)
        indirect_effects = indirect_effects.detach().cpu()

        input_ids_list.append(input_ids)
        indirect_effects_list.append(indirect_effects)
        attention_weights_list.append(average_attention_weights)

# torch.stack fails with an unhelpful message on an empty list; explain the cause instead.
if not indirect_effects_list:
    raise RuntimeError(
        "No test sample contained any of the selected AD gene tokens; there is nothing to stack."
    )

all_indirect_effects = torch.stack(indirect_effects_list)
all_attention_weights = torch.stack(attention_weights_list)
all_input_ids = torch.stack(input_ids_list)

In [None]:
# Persist the stacked results as .npy files under PROJECT_HOME/results.
results_dir = os.path.join(PROJECT_HOME, "results")
# Create the folder if it is missing so the save cannot fail on a non-existent
# directory after the long tracing run.
os.makedirs(results_dir, exist_ok=True)
np.save(os.path.join(results_dir, "indirect_effects.npy"), all_indirect_effects.numpy())
np.save(os.path.join(results_dir, "attention_weights.npy"), all_attention_weights.numpy())
np.save(os.path.join(results_dir, "input_ids.npy"), all_input_ids.numpy())

In [None]:
from google.colab import runtime
# Ask Colab to unassign this runtime once the notebook reaches this point;
# presumably this frees the GPU instance after the run — in-memory state will not
# be available afterwards.
runtime.unassign()