In [1]:
import numpy as np
import scanpy as sc
import torch

from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error

In [2]:
PATH = "../"

import sys 
sys.path.append(PATH)

from src.data import sc_Dataset
from src.model import CrossmodalNet, load_model, load_hparams
from src.utils import test_to_tensor

%load_ext autoreload
%autoreload 2

In [3]:
# Training inputs (X) and targets (the *_norm Y file); time_key names the obs column holding the time point.
train_x_path, train_y_path = "cite_train_x.h5ad", "cite_train_y_norm.h5ad"
dataset = sc_Dataset(train_x_path, train_y_path, time_key="week")

# Mapping from each time point to the one-hot tensor the model receives as T.
print(dataset.day_dict)

X to use: (13043, 14708)
Transform counts by None
Y to use: (13043, 99)
{1: tensor([1., 0., 0.]), 2: tensor([0., 1., 0.]), 3: tensor([0., 0., 1.])}


In [4]:
# Test inputs (X), ground-truth targets (Y), and one-hot time encodings for every test observation.

ada_test_x = sc.read_h5ad("cite_test_x.h5ad")
# Subset/align the test features to the feature names used by the training dataset.
ada_test_x = ada_test_x[:, dataset.var_names_X]
counts_test_x = test_to_tensor(ada_test_x)
print(counts_test_x.shape)

# np.load on an .npz returns an archive object that keeps the file open; the `with`
# block closes it once the array has been read into memory.
with np.load("cite_test_y_norm.npz") as npz:
    true_test_y = npz["arr_0"].T
print(true_test_y.shape)

# Reuse the training time -> one-hot mapping; a week unseen in training raises KeyError.
test_y_day = torch.stack([dataset.day_dict[t] for t in ada_test_x.obs["week"]])
print(test_y_day.shape)

assert counts_test_x.shape[0] == true_test_y.shape[0] == test_y_day.shape[0], (
    "test X, Y and time encodings disagree on the number of observations"
)

torch.Size([2299, 14708])
(2299, 99)
torch.Size([2299, 3])


In [5]:
hparams_load = load_hparams("params_0322_data2.json")
# Layer sizes are taken from the data rather than hard-coded (14708 inputs / 99 outputs
# in this run), so they cannot silently drift from what is actually fed to the model.
model = load_model("CrossmodalNet_0322_data2.th",
                   n_input=counts_test_x.shape[1],
                   n_output=true_test_y.shape[1],
                   time_p=dataset.unique_day,
                   hparams_dict=hparams_load,
                   )

In [6]:
model.eval()

# Inference only: disabling autograd avoids building a computation graph over the
# whole test matrix, which otherwise costs memory for no benefit.
with torch.no_grad():
    pred_cite_y = model(counts_test_x, T=test_y_day)
pred_cite_y_np = pred_cite_y.cpu().numpy()

print(pred_cite_y_np.shape)

(2299, 99)


In [8]:
# Both arrays are (n_observations, n_outputs); the metrics assume they are row/column aligned.
assert true_test_y.shape == pred_cite_y_np.shape, (true_test_y.shape, pred_cite_y_np.shape)

# Pearson correlation per row: one value per test observation, computed across the outputs.
# pearsonr returns NaN for a constant row, so the summary uses nanmean and reports how many occurred.
corr = np.array([pearsonr(y_true, y_pred)[0] for y_true, y_pred in zip(true_test_y, pred_cite_y_np)])
# Mean squared error per column: one value per output feature, computed across observations.
mse = mean_squared_error(true_test_y, pred_cite_y_np, multioutput="raw_values")

print("corr:", np.nanmean(corr), f"(NaN rows: {int(np.isnan(corr).sum())})")
print("mse:", np.mean(mse))