In [1]:
import torch
import numpy as np
from PIL import Image
from copy import deepcopy
from time import time
import os
import pandas as pd
import matplotlib.pyplot as plt
import tqdm
from tqdm import tqdm
import anndata

In [None]:
# Load everything needed to evaluate one held-out slide:
#   * the list of training slide IDs (with the held-out slide removed),
#   * the list of selected genes,
#   * the log2(count + 1) ground-truth expression of the held-out slide,
#   * the generated samples for the held-out slide.
data_path = "./hest1k_datasets/PRAD/processed_data/"  # processed data path
ori_st_path = "./hest1k_datasets/PRAD/st/" # original ST data path

slice_out = "MEND145"

# load training slide IDs
# np.atleast_1d: genfromtxt returns a 0-d array for a single-entry file,
# which list() cannot iterate.
slicename_lst = list(np.atleast_1d(np.genfromtxt(data_path + "all_slide_lst.txt", dtype=str)))
if slice_out not in slicename_lst:
    raise ValueError(
        f"Held-out slide {slice_out!r} is not listed in {data_path}all_slide_lst.txt"
    )
slicename_lst.remove(slice_out)
print(slice_out, " is held out for testing ", (slice_out not in slicename_lst))

# load selected gene list
selected_genes = list(np.atleast_1d(np.genfromtxt(data_path + "selected_gene_list.txt", dtype=str)))
print("Selected gene list loaded. len of selected genes: ", len(selected_genes))

# load test slide
test_adata = anndata.read_h5ad(ori_st_path + slice_out + ".h5ad")
# X may be stored sparse or dense; only sparse matrices expose .toarray().
test_counts_dense = (
    test_adata.X.toarray() if hasattr(test_adata.X, "toarray") else np.asarray(test_adata.X)
)
test_count_mtx_df = pd.DataFrame(
    test_counts_dense,
    columns=test_adata.var_names,
    index=test_adata.obs_names,
).loc[:, selected_genes]
# transform count
test_count_mtx_selected_genes = np.log2(test_count_mtx_df + 1)
print("Test count mtx shape: ", test_count_mtx_selected_genes.shape)

# load generated samples for test slide
sample_path = "./PRAD_results/runs/000/samples/generated_samples_0300000_20sample.pt" # (example)
# map_location="cpu": the samples may have been saved from a GPU; evaluation
# below runs on CPU/NumPy, so load them onto the CPU regardless.
pred = torch.load(sample_path, map_location="cpu")
pred = pred.squeeze(1)  # drops axis 1 only if it has size 1
print("Generated samples shape: ", pred.shape)

In [12]:
# Average the generated samples of each spot into one prediction per spot.
# The indexing below assumes `pred` stores the `num_rep` samples of spot i in
# the consecutive rows i*num_rep ... (i+1)*num_rep - 1.
num_rep = 20       # number of samples generated for one image patch
num_selected = 20  # take average of (a subset / all) samples and use it as prediction
sample_seed = 0    # fixed seed so the chosen subset is reproducible across runs

random_selected_index = np.random.default_rng(sample_seed).choice(
    num_rep, size=num_selected, replace=False
)

n_spots, n_genes = test_count_mtx_selected_genes.shape
# Fail loudly on a layout mismatch instead of silently pairing samples with
# the wrong spots.
if pred.ndim != 2 or pred.shape[0] != n_spots * num_rep or pred.shape[1] != n_genes:
    raise ValueError(
        f"pred has shape {tuple(pred.shape)}, expected "
        f"({n_spots * num_rep}, {n_genes}) = (n_spots * num_rep, n_genes)"
    )

# (n_spots * num_rep, n_genes) -> (n_spots, num_rep, n_genes), then average the
# selected replicates of every spot in one vectorized step.
selected_index_t = torch.as_tensor(random_selected_index, dtype=torch.long, device=pred.device)
pred_avg = (
    pred.reshape(n_spots, num_rep, n_genes)
    .index_select(1, selected_index_t)
    .mean(dim=1)
    .to(torch.float32)
)

pred_avg = pred_avg.cpu().detach().numpy()

In [None]:
# calculate correlation
# Per-gene Pearson correlation between ground truth and averaged prediction.
# `all_corr[i]` belongs to gene column i; it is NaN when either vector is
# constant, because the correlation is undefined for zero variance.

all_corr = []
for i in range(test_count_mtx_selected_genes.shape[1]):
    x = test_count_mtx_selected_genes.iloc[:, i].values
    y = pred_avg[:, i]
    if np.std(x) == 0 or np.std(y) == 0:
        # np.corrcoef would return NaN here and emit a RuntimeWarning.
        cor = np.nan
    else:
        cor = np.corrcoef(x, y)[0][1]
    all_corr.append(cor)

# Histogram binning fails on NaN values, so plot only the defined correlations.
finite_corr = [c for c in all_corr if np.isfinite(c)]
fig, ax = plt.subplots()
ax.hist(finite_corr, bins=50)
ax.set_title("Per-gene Pearson correlation")
ax.set_xlabel("Pearson correlation")
ax.set_ylabel("number of genes")
plt.show()

Evaluation metrics

In [None]:
# evaluation metrics

# PCC
# Mean of the k highest per-gene correlations. Undefined (NaN) correlations are
# dropped first: NaN breaks comparison-based sorting, so leaving them in would
# make the "top k" arbitrary.
corr_values = np.asarray(all_corr, dtype=float)
corr_descending = np.sort(corr_values[np.isfinite(corr_values)])[::-1]
for top_k in (10, 50, 200):
    print(f"PCC-{top_k}: ", np.mean(corr_descending[:top_k]))
# MSE, MAE
residual = test_count_mtx_selected_genes.values - pred_avg
print("MSE: ", np.mean(residual**2))
print("MAE: ", np.mean(np.abs(residual)))
# RVD
# Relative variance distance: mean over genes of (pred_var - gt_var)^2 / gt_var^2.
pred_var = np.var(pred_avg, axis=0)
gt_var = np.var(test_count_mtx_selected_genes.values, axis=0)
# Genes with zero ground-truth variance would divide by zero; leave them out.
has_variance = gt_var > 0
print("RVD: ", np.mean((pred_var[has_variance] - gt_var[has_variance])**2 / gt_var[has_variance]**2))

In [None]:
# gene variation curve
# 2x2 figure comparing per-gene statistics of ground truth vs. prediction:
#   left column  = mean, right column = variance,
#   top row      = normalized to sum to 1, bottom row = absolute values.
# In every panel genes are ordered by the ground-truth statistic.


def _plot_sorted_gene_stat(ax, gt_stat, pred_stat, title, xlabel, ylabel, normalize=False):
    """Draw one panel: ground-truth curve plus predicted values, sorted by ground truth.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    gt_stat, pred_stat : 1-D array-likes with one value per gene, in the same gene order.
    title, xlabel, ylabel : text for the panel.
    normalize : if True, divide each vector by its own sum before plotting.
    """
    gt_stat = np.asarray(gt_stat, dtype=float)
    pred_stat = np.asarray(pred_stat, dtype=float)
    if normalize:
        gt_stat = gt_stat / np.sum(gt_stat)
        pred_stat = pred_stat / np.sum(pred_stat)
    # Apply the ground-truth ordering to both vectors so genes stay aligned.
    order = np.argsort(gt_stat)
    gene_rank = np.arange(len(order))
    ax.plot(gene_rank, gt_stat[order], label="Ground Truth", c="b")
    ax.scatter(gene_rank, pred_stat[order], s=5, label="Predicted", c="orange")
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend()  # the labels above are only shown once a legend is drawn


fig, axs = plt.subplots(2, 2, figsize=(8, 8))

pred_mean = np.mean(pred_avg, axis=0)
gt_mean = np.mean(test_count_mtx_selected_genes, axis=0)
pred_var = np.var(pred_avg, axis=0)
gt_var = np.var(test_count_mtx_selected_genes, axis=0)

_plot_sorted_gene_stat(axs[0, 0], gt_mean, pred_mean, "Normalized Mean",
                       "gene index ordered by mean", "normalized mean", normalize=True)
_plot_sorted_gene_stat(axs[1, 0], gt_mean, pred_mean, "Absolute Mean",
                       "gene index ordered by mean", "absolute mean")
_plot_sorted_gene_stat(axs[0, 1], gt_var, pred_var, "Normalized Variance",
                       "gene index ordered by var", "normalized variance", normalize=True)
_plot_sorted_gene_stat(axs[1, 1], gt_var, pred_var, "Absolute Variance",
                       "gene index ordered by var", "absolute variance")

plt.tight_layout()
plt.show()