In [None]:
import scvi

import sys
import warnings
warnings.filterwarnings('ignore')

from itertools import cycle, product
from scipy import stats

import multiprocessing

In [None]:
sys.path.append("../")

import data_processing

import numpy as np
import pandas as pd

from scipy import stats
import scanpy as sc
import copy
import matplotlib.pyplot as plt
import time
import os
from anndata import AnnData
import anndata as ad

In [None]:
# Load the scRNA-seq ("_sc") and spatial ("_st") AnnData files; low-quality genes and
# cells/spots were already filtered out upstream.
data_id = "ctx_hipp_hvg"
data_root = os.path.join("data", data_id)

seq_data = sc.read_h5ad(f"{data_root}_sc.h5ad")
spatial_data = sc.read_h5ad(f"{data_root}_st.h5ad")

# Match gene symbols case-insensitively: lower-case every name, then de-duplicate any
# collisions that lower-casing may have created.
spatial_data.var_names = spatial_data.var_names.str.lower()
seq_data.var_names = seq_data.var_names.str.lower()
spatial_data.var_names_make_unique()
seq_data.var_names_make_unique()

# Keep only genes present in both datasets (np.intersect1d returns them sorted), in the
# same column order for each dataset.
gene_names = np.intersect1d(spatial_data.var_names, seq_data.var_names)
seq_data = seq_data[:, gene_names].copy()
spatial_data = spatial_data[:, gene_names].copy()

In [None]:
def get_kendall_tau(tuple):
    """Return Kendall's tau statistic for one ``(x, y)`` pair of 1-D sequences.

    The pair is passed packed as a single argument so the function can be handed
    directly to ``Pool.map`` over an iterable of pairs.
    """
    first, second = tuple
    tau_result = stats.kendalltau(x=first, y=second)
    return tau_result.statistic

def calculate_kendall_tau(adata1, adata2, gene_names, processes=None):
    """Kendall's tau between every observation of ``adata1`` and every observation of ``adata2``.

    Each row of ``adata1[:, gene_names].X`` (one cell/spot) is paired with each row of
    ``adata2[:, gene_names].X``, and the pairs are scored in parallel with ``get_kendall_tau``.

    Args:
        adata1, adata2: AnnData-like objects supporting ``obj[:, gene_names].X``.
        gene_names: genes (names or positional indices) to compare over.
        processes (int | None): number of worker processes for ``multiprocessing.Pool``.
            ``None`` (default) keeps multiprocessing's own default worker count.

    Returns:
        list: one tau per (row of adata1, row of adata2) pair, with the ``adata2`` row
        varying fastest, so reshaping to ``(n_rows1, n_rows2)`` gives the cross matrix.
    """
    mat1 = adata1[:, gene_names].X
    mat2 = adata2[:, gene_names].X
    # Densify sparse matrices so iteration yields plain 1-D rows.
    mat1 = np.asarray(mat1.toarray() if hasattr(mat1, "toarray") else mat1)
    mat2 = np.asarray(mat2.toarray() if hasattr(mat2, "toarray") else mat2)

    # Iterating a 2-D array yields its ROWS, so this is every (row of mat1, row of mat2) pair.
    pairs = list(product(mat1, mat2))

    # The context manager shuts the worker processes down once map() has returned;
    # the previous version never closed the pool and leaked workers on every call.
    with multiprocessing.Pool(processes=processes) as pool:
        res = pool.map(get_kendall_tau, pairs)

    return res
    
from tqdm import tqdm
from sklearn.model_selection import LeaveOneOut
from sklearn.model_selection import KFold

def cv_data_gen(genelist, cv_mode="CV", kfold=10):
    """Generate (train, test) gene splits for gene-level cross-validation.

    Args:
        genelist (array-like): genes shared by the scRNA-seq and spatial data; may be
            gene names or integer column indices.
        cv_mode (str): Optional. ``"CV"`` for K-fold (default) or ``"loo"`` for
            leave-one-out.
        kfold (int): Optional. Number of folds when ``cv_mode == "CV"``; ignored for
            ``"loo"``. Default is 10.

    Yields:
        tuple: (list of train_genes, list of test_genes) for each split. KFold is used
        without shuffling, so shuffle ``genelist`` beforehand if random folds are wanted.

    Raises:
        ValueError: on the first iteration, if ``cv_mode`` is not ``"CV"`` or ``"loo"``.
    """
    genes_array = np.array(genelist)

    if cv_mode == "loo":
        cv = LeaveOneOut()
    elif cv_mode == "CV":
        cv = KFold(n_splits=kfold)
    else:
        # Previously fell through and failed with an UnboundLocalError on `cv`.
        raise ValueError(f"Unsupported cv_mode {cv_mode!r}; expected 'CV' or 'loo'.")

    for train_idx, test_idx in cv.split(genes_array):
        yield list(genes_array[train_idx]), list(genes_array[test_idx])

In [None]:
np.shape(seq_data.X)

In [None]:
np.shape(spatial_data.X)

In [None]:
def process_data(adata):
    """Normalise ``adata`` for library size, then log1p-transform and z-score it, in place.

    Effects on ``adata``:
      * ``obs["total_counts"]`` (plus the other QC columns) via ``sc.pp.calculate_qc_metrics``.
      * ``obs["size_factor"]``: total counts divided by the median total counts, so an
        observation with the median library size has a factor of 1 and, after division,
        every cell/spot is brought to the median library size.
      * ``layers["raw_counts"]``: ``X`` as it was on entry.
      * ``X``: counts divided by the size factor, then ``sc.pp.log1p`` and ``sc.pp.scale``.

    Args:
        adata: AnnData whose ``X`` holds raw counts. Assumed to be a dense ndarray —
            TODO confirm; the row-wise division below is written for dense input.

    Raises:
        ValueError: if any size factor is zero or non-finite (e.g. an observation with no
            counts), which would otherwise turn that row into inf/NaN.
    """
    # Per-observation total counts.
    sc.pp.calculate_qc_metrics(adata, percent_top=None, log1p=False, inplace=True)

    adata.obs["size_factor"] = adata.obs["total_counts"] / np.median(adata.obs["total_counts"])
    size_factor = adata.obs["size_factor"].to_numpy()
    if not np.all(np.isfinite(size_factor) & (size_factor > 0)):
        raise ValueError(
            "process_data: every cell/spot needs a positive, finite size factor "
            "(check for observations with zero total counts)."
        )

    # X is replaced by a new array below (not modified in place), so this layer keeps
    # the original counts.
    adata.layers["raw_counts"] = adata.X
    # Divide each row by its size factor. Same result as inv(diag(size_factor)) @ X, but
    # without building and inverting an n_obs x n_obs matrix.
    adata.X = adata.X / size_factor[:, np.newaxis]

    # Log-transform, then z-score each gene.
    sc.pp.log1p(adata)
    sc.pp.scale(adata)

In [None]:
# process_data transforms its argument in place and stores the current X as
# layers["raw_counts"]; running it twice would overwrite the raw counts with
# log-transformed, z-scored values. sc.pp.log1p records uns["log1p"], so use that
# marker to skip datasets that were already processed and keep this cell re-runnable.
if "log1p" not in seq_data.uns:
    process_data(seq_data)
if "log1p" not in spatial_data.uns:
    process_data(spatial_data)

In [None]:
# Restrict both datasets to the shared genes so column j is the same gene in each.
seq_data = seq_data[:, gene_names].copy()
spatial_data = spatial_data[:, gene_names].copy()

seq_gene_names = seq_data.var_names
n_genes = seq_data.n_vars

# Randomly permute the shared genes. KFold in cv_data_gen is unshuffled, so this
# permutation is what makes each fold a random subset of genes.
np.random.seed(seed=0)
rand_gene_idx = np.random.choice(range(n_genes), n_genes, replace=False)

fold = 5    # number of gene folds
topK = 50   # number of most-correlated scRNA-seq cells averaged per spot

test_gene_list = []  # held-out gene indices, one list per fold
ST_imputed = []      # imputed (n_spots, n_test_genes) block, one per fold

# Loop-invariant inputs for mapping imputations back to counts. sc.pp.scale records
# each gene's pre-scaling (log1p-space) mean and std in .var (requires a scanpy version
# that stores var['mean'] / var['std']); the spot size factors come from process_data.
seq_gene_mean = seq_data.var["mean"].to_numpy()
seq_gene_std = seq_data.var["std"].to_numpy()
spot_size_factor = spatial_data.obs["size_factor"].to_numpy().reshape(-1, 1)
n_cells = seq_data.n_obs

for train_genes, test_genes in tqdm(
    cv_data_gen(rand_gene_idx, kfold=fold), total=fold
):
    # Correlate every cell with every spot over the training genes only. np.corrcoef
    # stacks cells then spots as rows, so the off-diagonal block is (n_cells, n_spots).
    corr_sc_st = np.corrcoef(seq_data[:, train_genes].X, spatial_data[:, train_genes].X, rowvar=True)
    corr_cross = corr_sc_st[:n_cells, n_cells:]
    # np.argsort sorts NaN last, i.e. would treat undefined correlations as the best
    # matches; push them to the bottom of the ranking instead.
    corr_cross = np.where(np.isnan(corr_cross), -np.inf, corr_cross)

    # "AVERAGE" model: for each spot (column), the indices of its topK most correlated cells.
    topK_ind_map = np.argsort(corr_cross, axis=0)[-topK:, :]

    # Slice the held-out genes once per fold instead of once per spot.
    sc_test_expr = seq_data[:, test_genes].X
    # np.stack always yields (n_spots, n_test_genes); np.squeeze collapsed a one-gene
    # fold to 1-D, which then broadcast against the size factors to (n_spots, n_spots).
    impute_st = np.stack([sc_test_expr[ind, :].mean(axis=0) for ind in topK_ind_map.T])

    # Undo process_data: z-score -> log1p space (scRNA-seq gene stats, since the values
    # were averaged from scRNA-seq cells), log1p -> expm1, then apply each spot's library
    # size. The former np.exp(z) inverted neither the scaling nor log1p.
    test_idx = np.asarray(test_genes)
    impute_log = impute_st * seq_gene_std[test_idx] + seq_gene_mean[test_idx]
    impute_st_raw = np.expm1(impute_log) * spot_size_factor

    test_gene_list.append(test_genes)
    ST_imputed.append(impute_st_raw)

In [None]:
# Stitch the per-fold (n_spots, n_test_genes) blocks into one matrix whose columns follow
# the order of test_gene_list. Guarded so re-running the cell is harmless: np.hstack on
# the already-stacked 2-D array would silently flatten it into a 1-D vector.
if isinstance(ST_imputed, list):
    ST_imputed = np.hstack(ST_imputed)

In [None]:
# Held-out gene indices, fold after fold, i.e. the same column order as ST_imputed.
test_gene_ind = np.hstack(test_gene_list)
# Work on a copy so the metric columns added later do not touch spatial_data.
spatial_copy = spatial_data[:, test_gene_ind].copy()

In [None]:
from sklearn.metrics import mean_squared_error

observed_counts = spatial_copy.layers["raw_counts"]
# zip() below pairs genes by column and silently stops at the shorter input, so make
# sure the imputed matrix lines up with the held-out genes before scoring.
assert observed_counts.shape == ST_imputed.shape, (observed_counts.shape, ST_imputed.shape)

pearson = []
spearman = []
kendalltau = []
RMSE = []

# Per-gene agreement between the observed spatial counts and the imputation.
for v1, v2 in zip(observed_counts.T, ST_imputed.T):
    v1 = v1.reshape(-1)
    v2 = v2.reshape(-1)
    pearson.append(stats.pearsonr(v1, v2)[0])
    spearman.append(stats.spearmanr(v1, v2)[0])
    kendalltau.append(stats.kendalltau(v1, v2)[0])
    # mean_squared_error(squared=False) was deprecated in scikit-learn 1.4 and removed
    # in 1.6; taking the square root explicitly works on every version.
    RMSE.append(np.sqrt(mean_squared_error(v1, v2)))

# RMSE after z-scoring each gene, so genes with different expression levels are comparable.
norm_raw = stats.zscore(observed_counts, axis=0)
norm_imputed = stats.zscore(ST_imputed, axis=0)
norm_rmse = [np.sqrt(mean_squared_error(v1, v2)) for v1, v2 in zip(norm_raw.T, norm_imputed.T)]

df_sc = pd.DataFrame({"Pearson": pearson, "Spearman": spearman, "Kendall_tau": kendalltau, "norm_RMSE": norm_rmse, "RMSE": RMSE})
df_sc.mean()

In [None]:
# Attach the per-gene evaluation metrics to the held-out genes.
spatial_copy.var = spatial_copy.var.assign(
    Pearson=pearson,
    Spearman=spearman,
    Kendall_tau=kendalltau,
    norm_RMSE=norm_rmse,
    RMSE=RMSE,
)
# Keep the imputed matrix next to the observed data for later inspection.
spatial_copy.obsm["imputed"] = ST_imputed

In [None]:
# Create the output directory so the write does not fail on a fresh checkout.
os.makedirs("./results", exist_ok=True)
# Encode the actual number of folds rather than a hard-coded "5fold".
spatial_copy.write_h5ad(f"./results/naive_base{topK}_ctx_{fold}fold.h5ad")