In [None]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr,zscore
import hickle as hkl
from sklearn.metrics import r2_score
from numpy.linalg import solve, svd, norm
from tqdm import tqdm

# Load Datasets

In [None]:
cell_embedding = hkl.load("embeddings/final_X_tcga_processed.hkl")
# Scale each row of the embedding to unit L2 norm (a row whose norm is 0 would
# become NaN/inf). Presumably rows are cells and columns are features — the
# later cells use .values / .columns, so this is assumed to be a DataFrame.
cell_embedding = cell_embedding / norm(cell_embedding, axis=1)[:, None]
gene_effects_df = hkl.load("datasets/2023/CRISPRGeneEffect_processed.hkl")

# Train Models

In [None]:
def train_individual_rfm_cell(device="cuda:3", bandwidth=1, reg=1e-5,
                              embedding=None, gene_effects=None):
    """Fit a Laplace-kernel ridge regression from cell embeddings to gene effects.

    Solves ``(K + reg * I) @ sol = y`` with ``K[i, j] = exp(-bandwidth * ||x_i - x_j||)``
    over the rows of the embedding, and ``y`` the gene-effect matrix. One linear
    solve handles every gene-effect column at once.

    Args:
        device: torch device used for the computation.
        bandwidth: multiplier on the distance inside the exponent.
        reg: ridge term added to the kernel diagonal.
        embedding: DataFrame of embeddings; defaults to the global ``cell_embedding``.
        gene_effects: DataFrame of targets; defaults to the global ``gene_effects_df``.
            Rows are assumed to be in the same (cell) order as ``embedding`` — TODO confirm.

    Returns:
        torch.Tensor of shape (n_rows, n_gene_effect_columns) holding the solution.

    Raises:
        ValueError: if the embedding and gene-effect frames have different row counts.
    """
    if embedding is None:
        embedding = cell_embedding
    if gene_effects is None:
        gene_effects = gene_effects_df

    X = torch.tensor(embedding.values, dtype=torch.float32, device=device)
    y = torch.tensor(gene_effects.values, dtype=torch.float32, device=device)
    if X.shape[0] != y.shape[0]:
        raise ValueError(
            f"embedding has {X.shape[0]} rows but gene_effects has {y.shape[0]}"
        )

    # torch.cdist returns non-negative Euclidean distances directly; taking the
    # sqrt of an unclamped squared distance can yield NaN from tiny negatives.
    dist = torch.cdist(X, X)
    dist.fill_diagonal_(0)

    K = torch.exp(-bandwidth * dist)
    K = K + reg * torch.eye(X.shape[0], device=device)
    return torch.linalg.solve(K, y)

In [None]:
# Fit the kernel model on the loaded data. `sol` is the solution of the
# (kernel + ridge) linear system: one column per column of gene_effects_df.
sol = train_individual_rfm_cell()

# Compute Per-Feature Gradients

In [None]:
def euclidean_distances(samples, centers, M=None, squared=True, diag_only=False):
    '''Pairwise (optionally M-weighted) Euclidean distances between rows.

    Computes ||s - c||_M for every pair of rows, where ||v||_M^2 = v^T M v for a
    full matrix M, or sum(M * v**2) when M is given as a diagonal vector
    (diag_only=True). With M=None this is the ordinary Euclidean distance.

    Args:
        samples: tensor of shape (n_sample, n_feature).
        centers: tensor of shape (n_center, n_feature), same dtype as samples.
        M: optional weighting: an (n_feature, n_feature) matrix, or a
            length-n_feature vector holding its diagonal when diag_only=True.
            It is cast to samples' dtype/device so the result keeps samples' dtype.
        squared: if True return squared distances, otherwise their square root.
        diag_only: treat M as the diagonal of the weighting matrix.
    Returns:
        tensor of shape (n_sample, n_center) of non-negative distances.
    '''
    if M is not None:
        M = M.to(dtype=samples.dtype, device=samples.device)

    def _apply_m(x):
        # x @ M, with the diagonal case as a broadcast multiply.
        if M is None:
            return x
        return x * M if diag_only else x @ M

    samples_m = _apply_m(samples)
    samples_norm = torch.sum(samples_m * samples, dim=1, keepdim=True)
    if samples is centers:
        centers_norm = samples_norm
    else:
        centers_norm = torch.sum(_apply_m(centers) * centers, dim=1, keepdim=True)
    centers_norm = torch.reshape(centers_norm, (1, -1))

    # ||s - c||_M^2 = s^T M s - 2 s^T M c + c^T M c: the cross term is M-weighted too.
    distances = samples_m.mm(torch.t(centers))
    distances.mul_(-2)
    distances.add_(samples_norm)
    distances.add_(centers_norm)
    # Cancellation can make (near-)zero squared distances slightly negative.
    distances.clamp_(min=0)
    if not squared:
        distances.sqrt_()

    return distances

def laplace_kernel(samples, centers, bandwidth, M=None, diag_only=False):
    '''Laplace kernel exp(-||s - c|| / bandwidth) between the rows of two tensors.

    Args:
        samples: tensor of shape (n_sample, n_feature).
        centers: tensor of shape (n_center, n_feature).
        bandwidth: positive kernel bandwidth; distances are divided by it.
        M: optional distance weighting, forwarded to ``euclidean_distances``.
        diag_only: forwarded to ``euclidean_distances`` (M holds only a diagonal).
    Returns:
        tensor of shape (n_sample, n_center) with the kernel values.
    '''
    assert bandwidth > 0
    scale = -1. / bandwidth
    dist = euclidean_distances(samples, centers, M=M, squared=False, diag_only=diag_only)
    # All three operations are in place, so the distance tensor becomes the kernel.
    return dist.clamp_(min=0).mul_(scale).exp_()

In [None]:
def get_grads(X, sol, P, L=1, centering=False, diag_only=True):
    """Mean squared input-gradient of a Laplace-kernel predictor, per feature and output.

    Assumes row i of ``sol`` holds the coefficients of the predictor
    f_i(x) = sum_j sol[i, j] * K(x, X[j]) with K(x, z) = exp(-||x - z||_P / L).
    Its gradient at a training row x_k is
        -(1/L) * sum_j sol[i, j] * K(x_k, X[j]) / ||x_k - X[j]||_P * P (x_k - X[j]),
    with P applied as a diagonal (diag_only=True) or as a full matrix. Pairs at
    (near-)zero distance contribute nothing.

    Args:
        X: tensor (n, d) of inputs.
        sol: tensor (num_outputs, n) of coefficients, one row per output.
        P: distance weighting: a length-d vector if diag_only else a (d, d) matrix.
        L: kernel bandwidth, used both in the kernel and in the gradient scale.
        centering: unused; kept for backward compatibility.
        diag_only: whether P holds only the diagonal.
    Returns:
        tensor (d, num_outputs): for each feature and output, the mean over the
        n rows of the squared partial derivative.
    Raises:
        ValueError: if sol's second dimension does not match the number of rows of X.
    """
    n, d = X.shape
    num_kos, n_sol = sol.shape
    if n_sol != n:
        raise ValueError(f"sol has {n_sol} columns but X has {n} rows")

    # The kernel must use the same bandwidth L that scales the gradient below.
    K = laplace_kernel(X, X, bandwidth=L, M=P, diag_only=diag_only)
    dist = euclidean_distances(X, X, M=P, squared=False, diag_only=diag_only)
    # K / dist, with zero-distance pairs (including the diagonal) set to 0 explicitly;
    # np.errstate has no effect on torch tensors, so avoid dividing by zero at all.
    near_zero = dist < 1e-10
    K = torch.where(near_zero, torch.zeros_like(K), K / dist.masked_fill(near_zero, 1.0))

    # d/dx ||x - z||_P brings in P (x - z); cast P so the matmuls keep X's dtype.
    p_weights = P.to(dtype=X.dtype, device=X.device)
    XP = X * p_weights if diag_only else X @ p_weights

    # Column i equals K^T sol[i]: the per-row weight totals for output i, computed once.
    weight_sums = K.t() @ sol.t()

    grads = torch.zeros((d, num_kos), device=X.device)
    for i in tqdm(range(num_kos)):
        weight = sol[i, :].reshape((-1, 1))
        G = (K @ (weight * XP) - weight_sums[:, i:i + 1] * XP) * (-1 / L)
        grads[:, i] = torch.sum(G ** 2, dim=0) / n

    return grads

In [None]:
device = "cuda:3"
X = torch.tensor(cell_embedding.values).float().to(device)
n, d = X.shape[0], X.shape[1]
# Uniform diagonal weighting: every feature counts equally in the kernel distance.
P = torch.ones(d, dtype=torch.float64, device=device)
# get_grads expects one row of coefficients per knockout, hence the transpose.
grads = get_grads(X, sol.T, P, L=1, centering=False, diag_only=True)

# Weight Gradients by Pearson Correlation (PCC)

In [None]:
def _zero_exp_outliers(df, threshold=3):
    """Return a copy of df with entries of "*_exp" columns whose |z-score| >= threshold set to 0."""
    exp_cols = [c for c in df.columns if c.split("_")[-1] == "exp"]
    out = df.copy()
    if not exp_cols:
        return out
    exp_vals = df[exp_cols]
    std_val = exp_vals.std(axis=0).replace(0, 1)  # constant columns: avoid 0/0
    z = (exp_vals - exp_vals.mean(axis=0)) / std_val
    out[exp_cols] = exp_vals * (np.abs(z) < threshold).astype(int)
    return out


def _column_pearson(a, b):
    """Pearson correlation of every column of ``a`` with every column of ``b``.

    Rows are matched by label via DataFrame ``@`` (pandas raises if the labels differ).
    Constant columns yield NaN (0 / 0).
    """
    a_c = a - a.mean(axis=0)
    b_c = b - b.mean(axis=0)
    denom = np.sqrt(np.outer((a_c ** 2).sum(axis=0).values, (b_c ** 2).sum(axis=0).values))
    return (a_c.T @ b_c) / denom


def get_pcc(embedding=None, gene_effects=None,
            embedding_path="embeddings/final_X_tcga_processed.hkl",
            out_path="datasets/pcc.hkl"):
    """Pearson correlation between every embedding feature and every gene-effect column.

    The embedding rows are scaled to unit L2 norm. In "*_exp" columns, values with
    |z-score| >= 3 are then zeroed before correlating.

    Args:
        embedding: raw embedding DataFrame; loaded from ``embedding_path`` when None.
            The passed object is not modified.
        gene_effects: gene-effect DataFrame; defaults to the global ``gene_effects_df``.
            Its row labels must match the embedding's.
        embedding_path: hickle file read when ``embedding`` is None.
        out_path: where to hickle-dump the result; None skips writing.

    Returns:
        DataFrame indexed by embedding feature, with one column per gene-effect column.
    """
    if embedding is None:
        embedding = hkl.load(embedding_path)
    if gene_effects is None:
        gene_effects = gene_effects_df

    # Build a new frame rather than dividing in place, so the input stays untouched.
    features = embedding / norm(embedding, axis=1).reshape(-1, 1)
    features = _zero_exp_outliers(features)

    pcc = _column_pearson(features, gene_effects)
    if out_path is not None:
        hkl.dump(pcc, out_path)
    return pcc

In [None]:
# get_grads returns an unlabelled (features x knockouts) torch tensor, possibly on
# GPU. Its rows follow the columns of cell_embedding (X was built from .values), and
# its columns follow gene_effects_df's columns (sol's columns, transposed).
# Label it so it can be aligned with the PCC frame by name.
grads_df = pd.DataFrame(
    grads.detach().cpu().numpy(),
    index=cell_embedding.columns,
    columns=gene_effects_df.columns,
)

pcc = get_pcc().fillna(0)
pcc = pcc.loc[grads_df.index]

# Non-expression features: keep only negative correlations, flipped to positive.
mut = [x for x in pcc.index if x.split("_")[-1] != "exp"]
pcc.loc[mut] = -(pcc.loc[mut].clip(upper=0))

# Expression features: use the correlation magnitude regardless of sign.
exp = [x for x in pcc.index if x.split("_")[-1] == "exp"]
pcc.loc[exp] = pcc.loc[exp].abs()

feature_importance_df = grads_df * pcc
hkl.dump(feature_importance_df, "datasets/feature_importances.hkl")