#### Importing all the required **Python** libraries

In [None]:
import pandas as pd
import scanpy as sc
import warnings
import scarches as sca
warnings.filterwarnings("ignore")

import decoupler as dc

import sys
sys.path.append('../scripts')
%load_ext autoreload
%autoreload 2
#%load_ext lab_black

In [None]:
# Global scanpy plotting defaults for this notebook: figures are configured
# with `frameon=False`.
sc.set_figure_params(frameon=False)
# Value assigned to scanpy's `figdir` setting. This is a machine-specific
# absolute path and has to be adapted when running elsewhere.
sc.settings.figdir = '/home/daniele/Code/scmouse_atlas/reports/figures/'

#### Read and bin

In [None]:
import numpy as np
from scipy.sparse import issparse, csr_matrix

def bin_data(adata, binning, key_to_process=None, result_binned_key="binned_data"):
    """
    Bin each row's non-zero values into discrete levels using per-row quantiles.

    Zeros always stay in bin 0. For every row, the non-zero values are compared
    against ``binning - 1`` quantile thresholds, evenly spaced between the row's
    smallest and largest non-zero value, and receive a bin index from 1 to
    ``binning - 1`` (for ``binning >= 2``).

    The function works in place and stores two results on ``adata``:

    * ``adata.layers[result_binned_key]``: CSR matrix of int64 bin indices with
      the same shape as the input matrix.
    * ``adata.obsm["bin_edges"]``: float array of shape ``(n_rows, binning)``
      holding ``[0, *thresholds]`` for each row (all zeros for an all-zero row).

    Parameters:
        adata (AnnData): The input data object; modified in place.
        binning (int): Number of bins, including the zero bin (must be >= 1).
        key_to_process (str | None): Key in `adata.layers` to process. When
            None, `adata.X` is used.
        result_binned_key (str): Key in `adata.layers` to store the binned results.

    Returns:
        None

    Raises:
        ValueError: If `binning` is not a positive integer or the data contains
            negative values.
    """
    if not isinstance(binning, int):
        raise ValueError(f"Binning must be an integer, but got {binning}.")
    if binning < 1:
        raise ValueError(f"Binning must be at least 1, but got {binning}.")

    layer_data = adata.layers[key_to_process] if key_to_process is not None else adata.X
    if issparse(layer_data):
        # `.toarray()` is available on every scipy sparse container, unlike the
        # deprecated `.A` shorthand that newer SciPy releases no longer provide.
        layer_data = layer_data.toarray()

    # `min()` is undefined for an empty matrix, so only validate real content.
    if layer_data.size:
        min_value = layer_data.min()
        if min_value < 0:
            raise ValueError(f"Expecting non-negative data, but got min value {min_value}.")

    # Preallocate the outputs: rows without non-zero values keep their zeros,
    # and no per-row list plus stacked copy has to be held in memory at once.
    binned = np.zeros(layer_data.shape, dtype=np.int64)
    bin_edges = np.zeros((layer_data.shape[0], binning), dtype=np.float64)

    # The quantile levels are identical for every row, so compute them once.
    quantile_levels = np.linspace(0, 1, binning - 1)

    for row_idx, row in enumerate(layer_data):
        non_zero_ids = row.nonzero()[0]
        if non_zero_ids.size == 0:
            continue

        non_zero_row = row[non_zero_ids]

        # Thresholds come from the non-zero values only, so the amount of
        # zeros in a row does not shift its bins.
        bins = np.quantile(non_zero_row, quantile_levels)

        # digitize() returns i such that bins[i-1] <= value < bins[i]; since
        # bins[0] is the row minimum, every non-zero value lands in bin >= 1.
        binned[row_idx, non_zero_ids] = np.digitize(non_zero_row, bins)

        # Column 0 stays 0 (the edge of the zero bin); the rest are thresholds.
        bin_edges[row_idx, 1:] = bins

    # Convert binned data back to sparse format
    adata.layers[result_binned_key] = csr_matrix(binned)
    adata.obsm["bin_edges"] = bin_edges


In [None]:
adata_source = sc.read_h5ad('/mnt/storage/Daniele/atlases/mouse/03_mouse_larry_barcoded_annotated.h5ad')
adata_target = sc.read_h5ad('/mnt/storage/Daniele/atlases/mouse/02_mouse_no_larry_qced.h5ad')

In [None]:
gene_common = list(set(adata_source.var_names).intersection(adata_target.var_names))

In [None]:
manual_genes_human = pd.read_csv('../../../supplementary_data/human/human_manual_genes.csv')

In [None]:
manual_genes_human.columns = ['genesymbol','manual']
manual_genes_human['pathway'] = '_' #dummy for decoupler
manual_genes_human = manual_genes_human[manual_genes_human['manual']]

In [None]:
mouse_manual_genes = dc.translate_net(manual_genes_human, target_organism='mouse')

In [None]:
man_genes = list(set(mouse_manual_genes['genesymbol'].values).intersection(gene_common))

#### Reference

In [None]:
# Restrict the reference to the genes shared with the target, bin the values of
# `adata_source.X` row by row into 50 levels (stored in the "binned_data"
# layer), then keep only the curated genes for model training.
adata_source = adata_source[:, gene_common].copy()
bin_data(adata_source, 50, key_to_process = None, result_binned_key="binned_data")
source_manual = adata_source[:, man_genes].copy()

In [None]:
# `obs` columns passed to the model setup below as batch covariate and labels.
batch_key = 'donor_id'
celltype_key = 'Level_1'

In [None]:
# Register the binned layer (not `.X`) as the model input, together with the
# batch and label columns chosen above.
sca.models.SCVI.setup_anndata(source_manual, layer='binned_data', batch_key=batch_key, labels_key=celltype_key)


In [None]:
# Reference SCVI model on the curated genes. Layer normalisation is requested
# for "both" and batch normalisation for "none".
# NOTE(review): these settings look like the usual scArches surgery
# configuration — confirm they are required by `load_query_data` further down.
vae = sca.models.SCVI(
    source_manual,
    n_layers=2,
    encode_covariates=True,
    deeply_inject_covariates=False,
    use_layer_norm="both",
    use_batch_norm="none",
)

In [None]:
# Train with the library's default settings.
vae.train()

In [None]:
# Build a SCANVI model from the trained SCVI model; "Unknown" is passed as the
# category name for unlabeled cells. Then train it with default settings.
scanvae = sca.models.SCANVI.from_scvi_model(vae, unlabeled_category = "Unknown")
scanvae.train()

In [None]:
# Predict labels for the reference cells and print the fraction that matches
# the annotation. The comparison uses `celltype_key`, the same column the model
# was set up with, so changing the key cannot leave this check pointing at a
# stale column.
source_manual.obs['predictions'] = scanvae.predict()
print("Acc: {}".format(np.mean(source_manual.obs['predictions'] == source_manual.obs[celltype_key])))

#### Target

In [None]:
# Prepare the target exactly like the reference: shared genes, 50-level binning
# of `.X` into the "binned_data" layer, then the curated gene subset.
adata_target = adata_target[:, gene_common].copy()
bin_data(adata_target, 50, key_to_process = None, result_binned_key="binned_data")
target_manual = adata_target[:, man_genes].copy()



In [None]:
# Create a query model for the target data from the trained reference SCANVI
# model, with `freeze_dropout` enabled.
model_surgery = sca.models.SCANVI.load_query_data(
    target_manual,
    scanvae,
    freeze_dropout = True,
)

In [None]:
# Overwrite the model's private index attributes: every row of the target is
# listed as unlabeled and none as labeled.
# NOTE(review): these are private attributes of the model class — confirm they
# still exist under these names in the installed scArches/scvi-tools version.
model_surgery._unlabeled_indices = np.arange(target_manual.n_obs)
model_surgery._labeled_indices = []
print("Labelled Indices: ", len(model_surgery._labeled_indices))
print("Unlabelled Indices: ", len(model_surgery._unlabeled_indices))

In [None]:
# Train the query model for at most 20 epochs with weight decay set to 0 and a
# validation check every 2 epochs.
model_surgery.train(
    max_epochs=20,
    plan_kwargs=dict(weight_decay=0.0),
    check_val_every_n_epoch=2,
)

In [None]:
# Stack reference and target (curated genes only), then store the predicted
# labels and the latent representation from the query model for all cells.
adata_full = source_manual.concatenate(target_manual)
adata_full.obs['Level_1_label_transfer'] = model_surgery.predict(adata_full)
adata_full.obsm['X_scANVI'] = model_surgery.get_latent_representation(adata_full)

In [None]:
# Display the combined object.
adata_full

In [None]:
# The curated-gene subsets are no longer needed; drop them and trigger a
# garbage-collection pass before building the larger all-gene object.
del source_manual, target_manual
import gc
gc.collect()


In [None]:
# Rebuild the combined object with all shared genes and put its rows in the
# same order as `adata_full`.
# NOTE(review): indexing by `adata_full.obs_names` only works if both
# `concatenate` calls generate identical obs_names — confirm, especially if
# either call is later replaced by `anndata.concat`.
adata_full_all_genes = adata_source.concatenate(adata_target)
adata_full_all_genes = adata_full_all_genes[adata_full.obs_names].copy()

In [None]:
# Display the all-gene object.
adata_full_all_genes

In [None]:
# Copy the transferred labels and the latent embedding onto the all-gene object.
adata_full_all_genes.obs['Level_1_label_transfer'] = adata_full.obs['Level_1_label_transfer']
adata_full_all_genes.obsm['X_scANVI'] = adata_full.obsm['X_scANVI']

In [None]:
# Write the integrated atlas to disk (machine-specific absolute path).
adata_full_all_genes.write_h5ad('/mnt/storage/Daniele/atlases/mouse/06_mouse_inhouse_integrated_scanvi.h5ad')