In [None]:
import torch
import numpy as np
import os
import pandas as pd

In [None]:
# Run from the project root so the relative data/ and baselines/ paths used below resolve.
# Set the PROJECT_ROOT environment variable instead of editing this placeholder path.
os.chdir(os.environ.get("PROJECT_ROOT", "/your/project/root"))

In [None]:
# Reload edited project modules automatically so changes to the helpers below are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Project-local dataset wrapper and scFID metric used throughout this notebook.
from baselines.scIDPMs.data_utils import SingleCellDatasetBaselines
from countdiff.utils.metrics import scFID, compute_resampled_eval

In [None]:
# Heart test split, loaded with the covariates that later feed scFID's categorical_covariates,
# together with the heart_dropout_0.5 MCAR mask file.
heart_covariates = ["age", "gender", "batch", "cell_type"]
test_set = SingleCellDatasetBaselines(
    "data/dnadiff/filtered_heart_data.hdf5",
    "test",
    heart_covariates,
    missing_ratio=None,
    imputation_mask="data/dnadiff/random_masks/MCAR_masks/heart_dropout_0.5.npy",
)

In [None]:
# scIDPMs imputations for the heart test set (0.5 dropout run), loaded as a NumPy array.
heart_imputed_path = "baselines/scIDPMs/imputed_data/scid_heart_0.5.npy"
imputed_data = np.load(heart_imputed_path)

In [None]:
# Quick look at the loaded heart imputations (NumPy summarizes large arrays in its repr).
imputed_data

In [None]:
# scIDPMs imputations for the fetus (HCA) test set, stored as CSV.
# NOTE(review): read with pandas defaults (first row = header, no index column) — confirm the
# CSV has no index column, otherwise `.values` below would include it.
fetus_imputed_csv = "baselines/scIDPMs/imputed_data/fetus_data_imputed.csv"
data_2 = pd.read_csv(fetus_imputed_csv)

In [None]:
# Fetus (HCA) test split without covariates, paired with the fetus_dropout_0.5 MCAR mask file.
orig_test_data = SingleCellDatasetBaselines(
    "data/dnadiff/filtered_hca_data.hdf5",
    "test",
    None,
    missing_ratio=None,
    imputation_mask="data/dnadiff/random_masks/MCAR_masks/fetus_dropout_0.5.npy",
)

In [None]:
# Fetus dropout mask, cast to bool so it can be combined with the observed mask below.
fetus_mask_path = "data/dnadiff/random_masks/MCAR_masks/fetus_dropout_0.5.npy"
imp_mask = np.load(fetus_mask_path).astype(bool)

In [None]:
# Evaluation positions: elementwise product of the dropout mask and the dataset's observed mask,
# i.e. entries set in both.
# NOTE(review): the merge cell applies `~target_mask`, which is a logical NOT only if this product
# is boolean — assumes orig_test_data.observed_mask is boolean; confirm.
target_mask = imp_mask * orig_test_data.observed_mask

In [None]:
# Fetus merge: rounded scIDPMs values at target positions, original counts elsewhere
# (assuming target_mask is boolean).
# WARNING: this reassigns `imputed_data`, replacing the heart imputations loaded earlier.
fetus_imputed_part = np.round(data_2.values * target_mask)
fetus_observed_part = ~target_mask * orig_test_data.counts.numpy()
imputed_data = fetus_imputed_part + fetus_observed_part

In [None]:
# Persist the merged fetus result for downstream comparison with other methods.
fetus_output_path = "data/dnadiff/imputed_data/scidpm_fetus_results_MCAR_dropout_0.5.npy"
np.save(fetus_output_path, imputed_data)

In [None]:
# NOTE(review): re-reads the same fetus CSV as the earlier cell; `data_2` is not used again in the
# cells below, so this cell looks redundant — confirm before deleting.
fetus_imputed_csv = "baselines/scIDPMs/imputed_data/fetus_data_imputed.csv"
data_2 = pd.read_csv(fetus_imputed_csv)

In [None]:
# Heart test split without covariates, paired with the heart_dropout_0.5 MCAR mask file; used to
# rebuild the heart target mask.
orig_test_data = SingleCellDatasetBaselines(
    "data/dnadiff/filtered_heart_data.hdf5",
    "test",
    None,
    missing_ratio=None,
    imputation_mask="data/dnadiff/random_masks/MCAR_masks/heart_dropout_0.5.npy",
)

In [None]:
# Heart dropout mask, cast to bool to match the fetus handling above.
heart_mask_path = "data/dnadiff/random_masks/MCAR_masks/heart_dropout_0.5.npy"
imp_mask = np.load(heart_mask_path).astype(bool)

In [None]:
# Heart merge: rounded scIDPMs values at target positions, original heart test counts elsewhere
# (assuming target_mask is boolean).
# The fetus cells above reassigned both `target_mask` and `imputed_data` with fetus arrays, so
# both are rebuilt here from the heart mask/dataset and the heart imputation file. Later cells
# (subsetting, scFID, site metrics) therefore operate on heart data.
target_mask = imp_mask * orig_test_data.observed_mask
imputed_data = np.load("baselines/scIDPMs/imputed_data/scid_heart_0.5.npy")
imputed_vals = np.round(imputed_data) * target_mask + ~target_mask * test_set.counts.numpy()

In [None]:
# Sanity check: largest value in the merged heart matrix (compare with the raw maximum below).
np.max(imputed_vals)

In [None]:
# Sanity check: largest raw count in the heart test set.
torch.max(test_set.counts)

In [None]:
# Sanity check: largest signed difference between merged values and raw counts.
# NOTE(review): signed max only reveals over-prediction; use torch.abs(...) to also catch
# under-prediction. Positions outside target_mask should contribute 0 if the mask is boolean.
torch.max(torch.Tensor(imputed_vals)-test_set.counts)

In [None]:
# scFID metric backed by a pretrained scVI feature model, conditioned on the heart covariates.
model_path = "data/dnadiff/2024-02-12-scvi-homo-sapiens/scvi.model"
heart_gene_names = test_set.gene_names
unique_covariates = test_set.get_obs_dict(unique=True)
scfid = scFID(
    heart_gene_names,
    categorical_covariates=unique_covariates,
    feature_model_path=model_path,
)

In [None]:
# Per-cell covariates truncated to the number of imputed rows.
# `imputed_data` is a NumPy array at this point (np.load / NumPy arithmetic), which has no
# `.values` attribute — the previous version raised AttributeError. len() gives the row count
# for both arrays and DataFrames. The obs dict is also fetched once instead of once per key.
# NOTE(review): impute_dict is not used by any later cell in this notebook.
n_imputed_rows = len(imputed_data)
impute_dict = {
    key: value[:n_imputed_rows]
    for key, value in test_set.get_obs_dict(unique=False).items()
}

In [None]:
# Random subset of heart test cells for the scFID comparisons below.
# A seeded generator makes the subset — and every metric computed from it — reproducible under
# Restart & Run All; the previous version drew from the unseeded global RNG.
SAMPLE_SEED = 0
total_size = test_set.counts.shape[0]
n_samples = 70430  # subset size; slicing below silently caps it at total_size
sample_generator = torch.Generator().manual_seed(SAMPLE_SEED)
sample_indices = torch.randperm(total_size, generator=sample_generator)[:n_samples]

In [None]:
# Slice imputed values, raw counts, mask and covariates down to the sampled cells.
sample_idx_np = sample_indices.numpy()
imputed_subset = imputed_data[sample_indices]
raw_subset = test_set.counts[sample_indices].numpy()
mask_subset = target_mask[sample_indices]
covariates_dict = test_set.get_obs_dict()
covariates_subset = (
    {name: np.array(column)[sample_idx_np].tolist() for name, column in covariates_dict.items()}
    if covariates_dict
    else None
)


In [None]:
# Clear any state accumulated in the metric before the comparisons below.
scfid.reset()

In [None]:
# scFID of the sampled raw subset against the full raw test set (a baseline/noise floor).
# NOTE(review): the third argument presumably flags the reference ("real") side — confirm
# against scFID.update.
scfid.update(raw_subset, covariates_subset, False)
scfid.update(test_set.counts.numpy(), test_set.get_obs_dict(), True)
print(scfid.compute())
scfid.reset()

In [None]:
# scFID of the sampled imputed subset against the full imputed matrix (imputed-side baseline).
# NOTE(review): uses `imputed_data` (model output), not the merged `imputed_vals` — confirm which
# one is intended.
scfid.update(imputed_subset, covariates_subset, False)
scfid.update(imputed_data, test_set.get_obs_dict(), True)
print(scfid.compute())
scfid.reset()

In [None]:
# scFID with the raw subset on the flag-True side and the full imputed matrix on the
# flag-False side.
scfid.update(raw_subset, covariates_subset, True)
scfid.update(imputed_data, test_set.get_obs_dict(), False)
print(scfid.compute())
scfid.reset()

In [None]:
# scFID of the full imputed matrix (flag False) against the full raw test set (flag True).
scfid.update(imputed_data, test_set.get_obs_dict(), False)
scfid.update(test_set.counts.numpy(), test_set.get_obs_dict(), True)
print(scfid.compute())
scfid.reset()

In [None]:
# Flatten the values at target positions only, so the metrics below score imputed entries.
site_mask = torch.Tensor(target_mask).bool()
imputed_sites = torch.masked_select(torch.Tensor(imputed_vals), site_mask)
actual_sites = torch.masked_select(test_set.counts, site_mask)


In [None]:
# Mean signed error at target sites; positive means imputations exceed the true counts on average.
raw_bias = (imputed_sites - actual_sites).mean().item()
raw_bias

In [None]:
# NOTE(review): exact duplicate of the previous cell; safe to delete.
raw_bias = torch.mean(imputed_sites - actual_sites).item()
raw_bias

In [None]:
# Mean absolute error at target sites.
mae = (imputed_sites - actual_sites).abs().mean().item()
mae

In [None]:
# Root mean squared error at target sites.
rmse = (imputed_sites - actual_sites).pow(2).mean().sqrt().item()
rmse

In [None]:
# Coefficient of determination (R^2) of the imputations at target sites.
# Undefined (inf/nan) if every true value at the target sites is identical.
residuals = actual_sites - imputed_sites
ss_res = residuals.pow(2).sum()
ss_tot = (actual_sites - actual_sites.mean()).pow(2).sum()
r2 = (1 - ss_res / ss_tot).item()
r2