In [None]:
# List the devices JAX can see in this kernel.
import jax

available_devices = jax.devices()
print(available_devices)

In [None]:
import warnings

# Suppress every warning for the rest of the session.
# NOTE(review): this also hides pandas/anndata/scvi deprecation warnings;
# consider narrowing it to specific categories or modules.
warnings.filterwarnings(action="ignore")

In [None]:
import scanpy as sc
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import traceback
import os
import anndata as ad

In [None]:
import numpy as np
from scipy.sparse import issparse
import logging

logger = logging.getLogger(__name__)


def bin_data(adata, binning, key_to_process = None, result_binned_key="binned_data"):
    """
    Bin each row's non-zero values into quantile-based integer levels.

    Zeros stay at level 0. The non-zero values of each row are digitized
    against that row's own quantiles, which gives levels 1..binning-1. The
    binned matrix (int64) is written to ``adata.layers[result_binned_key]``.
    The per-row bin edges are written to ``adata.obsm["bin_edges"]``; each
    edge row is a leading 0 followed by ``binning - 1`` quantile thresholds.
    All-zero rows get all-zero edges.

    Parameters:
        adata (AnnData): The input data object; modified in place.
        binning (int): Total number of levels including the zero level
            (must be an integer >= 2).
        key_to_process (str or None): Key in `adata.layers` to process;
            ``None`` processes ``adata.X``.
        result_binned_key (str): Layer key under which the binned matrix is stored.

    Raises:
        ValueError: If `binning` is not an integer or is smaller than 2,
            or if the data contains negative values.
    """
    if not isinstance(binning, int):
        raise ValueError(f"Binning must be an integer, but got {binning}.")
    if binning < 2:
        # With fewer than 2 levels np.linspace yields no quantiles and every
        # non-zero value would silently collapse into level 0.
        raise ValueError(f"Binning must be at least 2, but got {binning}.")

    layer_data = adata.layers[key_to_process] if key_to_process is not None else adata.X
    # Quantiles/digitize need a plain dense ndarray (np.asarray also turns an
    # np.matrix into an ndarray so iterating yields 1-D rows).
    layer_data = layer_data.toarray() if issparse(layer_data) else np.asarray(layer_data)

    data_min = layer_data.min()
    if data_min < 0:
        raise ValueError(f"Expecting non-negative data, but got min value {data_min}.")

    # The quantile positions are the same for every row, so compute them once.
    quantile_levels = np.linspace(0, 1, binning - 1)

    binned_rows = []
    bin_edges = []
    n_all_zero = 0

    for row in layer_data:
        binned_row = np.zeros_like(row, dtype=np.int64)

        if row.max() == 0:
            # Nothing to bin: keep the row at level 0 with all-zero edges.
            n_all_zero += 1
            binned_rows.append(binned_row)
            bin_edges.append(np.zeros(binning, dtype=int))
            continue

        non_zero_ids = row.nonzero()
        non_zero_row = row[non_zero_ids]

        # binning - 1 thresholds taken from this row's non-zero distribution;
        # the first threshold is the row minimum and the last is the maximum.
        bins = np.quantile(non_zero_row, quantile_levels)

        # Every non-zero value is >= bins[0], so np.digitize returns 1..binning-1
        # and non-zero values never share level 0 with true zeros.
        binned_row[non_zero_ids] = np.digitize(non_zero_row, bins)

        binned_rows.append(binned_row)
        bin_edges.append(np.concatenate([[0], bins]))

    if n_all_zero:
        # Log once with a count instead of once per row.
        logger.warning(
            "%d row(s) contain all zeros. Consider filtering such rows.", n_all_zero
        )

    # Store the binned data and bin edges
    adata.layers[result_binned_key] = np.stack(binned_rows)
    adata.obsm["bin_edges"] = np.stack(bin_edges)

In [None]:
def train_scanvi_add_embedding(
    adata,
    batch_key='ID_batch_covariate',
    cell_type_key='Level_1_refined',
    n_latent=10,
    n_layers=2,
    max_epochs_vae=150,
    max_epochs_scanvae=80,
    ref_path='scanvae/scanvae_model',
    full_adata_path='scanvae/scanvae_adata_with_embedding.h5ad',
    unlabeled_category="Unknown",
    leiden_resolution=0.25,
    raise_on_error=False,
):
    """
    Train scVI -> scANVI on a copy of `adata` and return that copy with the embedding.

    The caller's object is not modified. The function works on
    ``remove_sparsity(adata.copy())``. It pretrains an scVI model,
    converts it to scANVI and fine-tunes it, then saves the model to
    `ref_path`. The scANVI latent space is used for neighbors, Leiden
    clustering and UMAP, and the updated copy is written to `full_adata_path`.

    Parameters:
        adata (AnnData): Annotated data matrix with labels for supervised training.
        batch_key (str): Key for batch information in adata.obs.
        cell_type_key (str): Key for cell type labels in adata.obs.
        n_latent (int): Number of latent dimensions for scANVI.
        n_layers (int): Number of layers for the scANVI model.
        max_epochs_vae (int): Epochs for pretraining the scVI model.
        max_epochs_scanvae (int): Epochs for training the scANVI model.
        ref_path (str): Path to save the scANVI model.
        full_adata_path (str): Path to save the updated AnnData object; its
            parent directory is created if missing.
        unlabeled_category (str): Value of `cell_type_key` that scANVI treats
            as unlabeled.
        leiden_resolution (float): Resolution passed to `sc.tl.leiden`.
        raise_on_error (bool): If True, re-raise any failure after printing the
            traceback instead of returning ``(None, None)``.

    Returns:
        adata (AnnData or None): The updated copy, with `obsm['X_scanvi']`,
            `obsm['X_emb']`, `obsm['X_umap']`, `obsm['X_umap_scanvi']` and
            `uns['output_type']` set, or None on failure.
        scanvae_model (SCANVI or None): Trained scANVI model, or None on failure.
    """
    try:
        print("Setting up scVI/scANVI model...")
        # Remove sparsity for compatibility; work on a copy so the caller's object is untouched.
        adata = remove_sparsity(adata.copy())
        sca.models.SCVI.setup_anndata(adata, batch_key=batch_key, labels_key=cell_type_key)
        print("Training scVI model...")
        vae = sca.models.SCVI(
            adata,
            n_layers=n_layers,
            n_latent=n_latent,
            encode_covariates=True,
            deeply_inject_covariates=False,
            use_layer_norm="both",
            use_batch_norm="none"
        )
        vae.train(max_epochs=max_epochs_vae)

        print("Initializing and training scANVI model...")
        scanvae = sca.models.SCANVI.from_scvi_model(vae, unlabeled_category=unlabeled_category)
        scanvae.train(max_epochs=max_epochs_scanvae)
        scanvae.save(ref_path, overwrite=True)

        print("Adding latent representation to AnnData...")
        adata.obsm['X_scanvi'] = scanvae.get_latent_representation(adata=adata)

        print("Performing clustering and UMAP visualization...")
        sc.pp.neighbors(adata, use_rep='X_scanvi')
        sc.tl.leiden(adata, resolution=leiden_resolution)
        sc.tl.umap(adata)

        # Keys for downstream metrics (presumably scIB-style embedding benchmarks).
        adata.uns['output_type'] = 'embed'
        adata.obsm['X_emb'] = adata.obsm['X_scanvi']
        adata.obsm['X_umap_scanvi'] = adata.obsm['X_umap']

        # Make sure the output folder exists before writing the .h5ad.
        out_dir = os.path.dirname(full_adata_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        adata.write(full_adata_path)
        print("scANVI training and embedding generation completed successfully.")
        return adata, scanvae

    except Exception as e:
        print(f"scANVI training failed: {e}")
        traceback.print_exc()
        if raise_on_error:
            raise
        return None, None

In [None]:
# Load the refined-annotation object that the rest of the notebook works on.
input_h5ad = '2025_05_20_refined_annotation.h5ad'
adata = sc.read_h5ad(input_h5ad)

In [None]:
# Keep only genes flagged in var['Manual_Genes'] (assumed to be a boolean column);
# .copy() materialises the view into an independent AnnData.
manual_gene_mask = adata.var['Manual_Genes']
adata = adata[:, manual_gene_mask].copy()

In [None]:
# Show the AnnData summary after restricting to the manual gene set.
adata

In [None]:
# Existing UMAP from the loaded file, coloured by the original Level_5 labels.
sc.pl.umap(
    adata,
    color='Level_5',
    legend_fontsize=5,
    frameon=False,
)

In [None]:
# Quantile-bin every cell into 50 levels; the result lands in layers['binned_data'].
bin_data(adata, 50)

In [None]:
# Peek at the per-cell metadata columns (e.g. Level_5, ID_batch_covariate, Dataset).
adata.obs.head()

In [None]:
# Translation table from the Level_5 labels in the loaded file to the label
# names used for training. Pairs are (source label, target label) in the same
# order as before. The trailing notes refer to a reference list `cell_types`
# that is not defined in this notebook.
# NOTE(review): the name suggests mouse -> human naming; confirm intent.
_level5_label_pairs = (
    ('cDC2', 'Dendritic Cell - cDC2'),
    ('Malignant Cell - Mesenchymal', 'Malignant Cell - Mesenchymal'),
    ('Double Positive CD4+CD8+ T Cell', 'Double Positive CD4+CD8+ T Cell'),
    ('T-reg', 'T-reg'),
    ('CD8+ Tissue-Resident Memory T Cell', 'CD8+ Tissue-Resident Memory T Cell'),
    ('CD8+ Exhausted T Cell', 'CD8+ Exhausted T Cell'),
    ('M2-like TAM', 'Macrophage - M2-like TAM'),
    ('CD4+ Central Memory T Cell', 'CD4+ Memory T Cell'),  # closest match
    ('NK cell', 'NK Cell'),
    ('angiogenic TAM', 'Macrophage - angiogenic TAM'),
    ('CD8+ Memory T Cell', 'CD8+ Memory T Cell'),
    ('CD8+ Effector T Cell', 'CD8+ Effector T Cell'),
    ('Mast', 'Mast Cell'),  # not in cell_types
    ('CD8+ Terminal Effector T Cell', 'CD8+ Terminal Effector T Cell'),
    ('N1', 'Neutrophil - N1'),
    ('CD8+ Naive T Cell', 'CD8+ Naive T Cell'),
    ('Vascular Endothelial Cell', 'Vascular Endothelial Cell'),  # closest match
    ('CD4+ Th17 Cell', 'CD4+ Th17 Cell'),
    ('γδ T Cell (Vδ1)', 'γδ T Cell (Vδ1)'),
    ('CD4+ Th1 Cell', 'CD4+ Th1 Cell'),
    ('Malignant Cell - Hypoxia', 'Malignant Cell - Hypoxia'),
    ('CD4+ Th2 Cell', 'CD4+ Th2 Cell'),
    ('myCAF', 'myCAF'),
    ('CD4+ Th22 Cell', 'CD4+ Th22 Cell'),
    ('CD4+ Naive Cell', 'CD4+ Naive T Cell'),
    ('N2', 'Neutrophil - N2'),
    ('Plasmablast', 'Plasmablast'),  # not in cell_types
    ('M1-like TAM', 'Macrophage - M1-like TAM'),
    ('lipid processing TAM', 'Macrophage - lipid processing TAM'),
    ('Macrophage - CD3+ TAM', 'Macrophage - CD3+ TAM'),
    ('B Cell - Activated', 'B Cell - Activated'),  # closest match
    ('pDC', 'Dendritic Cell - pDC'),
    ('B Cell - Germinal Center', 'B Cell - Germinal Center'),  # closest match
    ('Malignant Cell - Highly Proliferative', 'Malignant Cell - Highly Proliferative'),
    ('Malignant Cell - Epithelial', 'Malignant Cell - Epithelial'),
    ('Tumor-Associated Endothelial Cell', 'Tumor-Associated Endothelial Cell'),  # closest match
    ('Malignant Cell - Apoptotic', 'Malignant Cell - Apoptotic'),
    ('B Cell - Memory', 'B Cell - Memory'),  # closest match
    ('Malignant Cell - Senescence', 'Malignant Cell - Senescence'),
    ('iCAF', 'iCAF'),
    ('B-reg', 'B-reg'),  # closest match
    ('cDC1', 'Dendritic Cell - cDC1'),
    ('Monocyte', 'Monocyte'),
    ('Malignant Cell - EMT', 'Malignant Cell - EMT'),
    ('Schwann Cell', 'Schwann Cell'),  # not in cell_types
    ('Malignant Cell - Invasive', 'Malignant Cell - Highly Invasive'),  # probable typo match
    ('Lymphatic Endothelial Cell', 'Lymphatic Endothelial Cell'),  # closest match
    ('Plasma Cell', 'Plasma Cell'),
    ('Adipocytes', 'Adypocyte'),  # target deliberately matches a typo in cell_types
    ('B Cell - Naive', 'B Cell - Naive'),
    ('Ductal (atypical)', 'Ductal Cell (atypical)'),  # closest match
    ('Beta Cell', 'Beta Cell'),
    ('Malignant Cell - Acinar-like', 'Malignant Cell - Acinar-like'),
    ('Malignant Cell - Pit Like', 'Malignant Cell - Pit Like'),
    ('Other Endocrine', 'Other Endocrine'),  # not in cell_types
    ('Ductal Cell', 'Ductal Cell'),
    ('N0', 'Neutrophil - N0'),
    ('Activated DC', 'Dendritic Cell - Activated'),  # closest match
    ('Acinar (REG+)', 'Acinar (REG+) Cell'),
    ('Acinar idlling', 'Acinar Idling Cell'),  # source label has a typo
    ('Alpha Cell', 'Alpha Cell'),
    ('Acinar Cell', 'Acinar Cell'),
    ('Pericyte', 'Pericyte'),  # not in cell_types
    ('Smooth Muscle Cell', 'Smooth Muscle Cell'),  # not in cell_types
)
mouse_to_human = dict(_level5_label_pairs)

In [None]:
# Translate Level_5 labels through `mouse_to_human`. Series.map turns any label
# missing from the table into NaN. Checking for missing labels first surfaces
# them explicitly. It also stops an accidental re-run from mapping
# already-translated labels (which are not keys of the table) to NaN.
unmapped_level5 = sorted(set(adata.obs['Level_5'].dropna().unique()) - set(mouse_to_human))
if unmapped_level5:
    raise ValueError(f"Level_5 labels missing from mouse_to_human: {unmapped_level5}")
adata.obs['Level_5'] = adata.obs['Level_5'].map(mouse_to_human)

In [None]:
# Same UMAP, now coloured by the translated Level_5 labels.
sc.pl.umap(
    adata,
    color='Level_5',
    legend_fontsize=5,
    frameon=False,
)

In [None]:
# Number of cells whose Level_5 label is missing after the mapping.
adata.obs['Level_5'].isna().sum()

In [None]:
# Sanity-check the value range of X on a 1% random subsample of cells.
subset = sc.pp.subsample(adata, fraction=0.01, copy=True)
for label, reducer in (("Min X:", np.min), ("Max X:", np.max)):
    print(label, reducer(subset.X))

In [None]:
# Train on the quantile-binned values: swap them into X.
# NOTE: this overwrites adata.X in place; the un-binned matrix is no longer on
# `adata` unless it is also stored in one of its layers.
adata.X = adata.layers['binned_data'].copy()
adata_updated, scanvae_model = train_scanvi_add_embedding(
    adata=adata,
    batch_key='ID_batch_covariate',
    cell_type_key='Level_5',
    n_latent=10,
    n_layers=2,
    max_epochs_vae=100,
    max_epochs_scanvae=80,
    ref_path='final_scanVI/pretrained_scanvi_mg_L5_binned_model',
    full_adata_path='final_scanVI/final_object.h5ad'
)
# The helper prints the traceback and returns (None, None) on failure. Stop here
# instead of failing later with a confusing error on a None object.
if adata_updated is None:
    raise RuntimeError("scANVI training failed; see the traceback printed above.")

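- **Label fix needed:** the mapping entry `'Vascular Endothelial Cell': 'Endothelial Cell'` must be corrected so that vascular endothelial cells keep the label `'Vascular Endothelial Cell'`. The next cell renames any remaining `'Endothelial Cell'` labels accordingly.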


In [None]:
# Keep the Level_5 scANVI embedding under an explicit name.
adata_updated.obsm['scanvi_L5_emb'] = adata_updated.obsm['X_scanvi']
# NOTE(review): 'scanvi_emb' is not produced by train_scanvi_add_embedding;
# presumably it is a Level_1 embedding carried over from the input file — confirm.
adata_updated.obsm['scanvi_L1_emb'] = adata_updated.obsm['scanvi_emb']
# Rename any leftover 'Endothelial Cell' labels. After adata.write, Level_5 is
# categorical (anndata converts string columns on write). Series.replace on
# categorical data is deprecated in pandas >= 2.2, so replace on plain values
# and convert back to a categorical.
adata_updated.obs['Level_5'] = (
    adata_updated.obs['Level_5']
    .astype(object)
    .replace('Endothelial Cell', 'Vascular Endothelial Cell')
    .astype('category')
)

In [None]:
# UMAP computed on the scANVI latent space inside train_scanvi_add_embedding.
umap_panels = ['Dataset', 'Technology', 'Level_5']
sc.pl.umap(adata_updated, color=umap_panels, ncols=1, legend_fontsize=5, frameon=False)

In [None]:
# Rebuild the kNN graph on the scANVI latent space with cosine distance and a
# larger neighbourhood. Then re-embed with a larger UMAP min_dist.
# NOTE(review): this replaces the graph the earlier Leiden clustering was
# computed on, so obs['leiden'] no longer corresponds to the current graph.
sc.pp.neighbors(
    adata_updated,
    n_neighbors=100,
    metric='cosine',
    use_rep='X_scanvi',
)
sc.tl.umap(adata_updated, min_dist=0.75)

In [None]:
# Re-plot on the recomputed UMAP and save the figure (scanpy writes it under
# sc.settings.figdir, with the plot name prefixed to 'final_object').
sc.pl.umap(
    adata_updated,
    color=['Dataset', 'Technology', 'Level_5'],
    ncols=1,
    legend_fontsize=5,
    frameon=False,
    save='final_object',
)

In [None]:
# Overwrite the object written by train_scanvi_add_embedding with the relabelled,
# re-embedded version.
final_h5ad_path = 'final_scanVI/final_object.h5ad'
adata_updated.write(final_h5ad_path)

In [None]:
import math
import matplotlib.pyplot as plt

# One UMAP panel per dataset, each showing only that dataset's cells.
datasets = adata_updated.obs['Dataset'].unique()
n = len(datasets)

# Four panels per row.
ncols = 4
nrows = math.ceil(n / ncols)

fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows))
axes = axes.flatten()  # always an array here because ncols > 1

for ax, dataset in zip(axes, datasets):
    in_dataset = adata_updated.obs['Dataset'] == dataset
    sc.pl.umap(
        adata_updated[in_dataset],
        color='Dataset',
        title=str(dataset),
        size=3,
        legend_fontsize=5,
        frameon=False,
        ax=ax,
        show=False,
    )

# Hide the unused axes in the last row.
for ax in axes[n:]:
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Count the distinct Level_5 labels (a missing label counts as one extra value).
level5_values = adata_updated.obs['Level_5']
cell_types = level5_values.unique()
n = len(cell_types)
n

In [None]:
# One panel per Level_5 label, highlighting that label on the full UMAP.
# dropna(): unmapped labels are NaN, and sorting a mix of str and float NaN
# raises TypeError.
cell_types = sorted(adata_updated.obs['Level_5'].dropna().unique())
n = len(cell_types)

# Define subplot grid size: four panels per row.
ncols = 4
nrows = math.ceil(n / ncols)

# Create figure and axes (always an array here because ncols > 1).
fig, axes = plt.subplots(nrows, ncols, figsize=(4*ncols, 4*nrows))
axes = axes.flatten()

# Loop through each cell type and plot it on its own axis.
for ax, cell_type in zip(axes, cell_types):
    sc.pl.umap(
        adata_updated,
        color='Level_5',
        groups=cell_type,
        frameon=False,
        legend_fontsize=5,
        size=3,
        show=False,
        ax=ax,
        title=str(cell_type)
    )

# Turn off any extra axes in the last row.
for ax in axes[n:]:
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Current working directory (relative output paths such as final_scanVI/ resolve here).
os.getcwd()