In [None]:
# IPython automagic: show the working directory; the relative paths used below resolve against it.
pwd

In [None]:
import os
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown
# Configure scanpy figures in ONE call. `set_figure_params` re-applies its defaults
# (dpi=80, frameon=True, ...) for every argument that is not passed, so the former
# three separate calls ended with dpi=80 and frameon=True instead of 200/False.
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
# Compact tensor printing when inspecting model parameters.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

In [None]:
import pandas as pd
pd.options.display.max_columns = None  # show every column when displaying wide obs tables

# Load the core atlas

In [None]:
# Relative chdir: not idempotent — re-running this cell resolves the relative path again
# from the new directory. Later sections use an absolute path instead.
os.chdir('../PDAC_Final/Downstream/')

In [None]:
# Core (reference) atlas; its var_names define the gene space of the query below.
adata = sc.read_h5ad('final_scanVI_2.0/final_object.h5ad')

In [None]:
adata

# Load the pretrained model

In [None]:
# Reference scANVI model, bound to the core atlas.
model = sca.models.SCANVI.load(dir_path="final_scanVI_2.0/pretrained_scanvi_mg_L4_binned_model/", adata=adata)

In [None]:
model

# Load the extension datasets

In [None]:
# All extension (query) datasets combined in one AnnData.
extension = sc.read_h5ad('Extension/Extension_Datasets_Combined.h5ad')

In [None]:
# Restrict the query to the reference genes, in reference order.
# NOTE(review): presumably raises KeyError if any reference gene is absent from `extension` — confirm.
extension_adata = extension[:, adata.var_names].copy()

In [None]:
extension_adata

In [None]:
# Cells per extension dataset.
extension_adata.obs.groupby('Dataset').size()

# Binning

In [None]:
import numpy as np
from scipy.sparse import issparse
import logging

def bin_data(adata, binning, key_to_process = None, result_binned_key="binned_data"):
    """
    Bin each cell's values into per-cell quantile bins.

    For every row of the selected matrix, zero entries stay in bin 0. Non-zero
    entries are assigned to bins ``1 .. binning - 1`` with ``np.digitize``
    against ``binning - 1`` quantile edges. The edges are computed from that
    row's non-zero values only.

    Parameters:
        adata (AnnData): The input data object; modified in place.
        binning (int): Total number of bins including the zero bin (>= 2).
        key_to_process (str, optional): Key in ``adata.layers`` to bin. When
            ``None`` (default), ``adata.X`` is used.
        result_binned_key (str): Key in ``adata.layers`` that receives the
            binned matrix.

    Writes:
        ``adata.layers[result_binned_key]``: dense int64 array with the same
            shape as the input matrix.
        ``adata.obsm["bin_edges"]``: float array of shape ``(n_obs, binning)``.
            Column 0 is 0 and the remaining columns are the row's quantile
            edges. All-zero rows get all-zero edges.

    Raises:
        ValueError: If ``binning`` is not an integer >= 2 or the data contains
            negative values.
    """
    # bool is a subclass of int, so exclude it explicitly.
    if not isinstance(binning, int) or isinstance(binning, bool):
        raise ValueError(f"Binning must be an integer, but got {binning}.")
    if binning < 2:
        raise ValueError(f"Binning must be at least 2, but got {binning}.")

    layer_data = adata.layers[key_to_process] if key_to_process is not None else adata.X
    # Densify sparse input so each row can be processed as a 1-D array.
    layer_data = layer_data.toarray() if issparse(layer_data) else np.asarray(layer_data)

    if layer_data.size and layer_data.min() < 0:
        raise ValueError(f"Expecting non-negative data, but got min value {layer_data.min()}.")

    # Preallocate outputs (this also makes a 0-row input work, unlike np.stack([])).
    binned = np.zeros(layer_data.shape, dtype=np.int64)
    edges = np.zeros((layer_data.shape[0], binning), dtype=np.float64)
    # Quantile levels are the same for every row; compute them once.
    quantile_levels = np.linspace(0, 1, binning - 1)
    n_zero_rows = 0

    for i, row in enumerate(layer_data):
        non_zero_ids = np.flatnonzero(row)
        if non_zero_ids.size == 0:
            # All-zero row: bins and edges stay 0.
            n_zero_rows += 1
            continue
        non_zero_row = row[non_zero_ids]
        bins = np.quantile(non_zero_row, quantile_levels)
        # np.digitize maps each value to the index of its quantile interval (1-based here).
        binned[i, non_zero_ids] = np.digitize(non_zero_row, bins)
        edges[i, 1:] = bins

    if n_zero_rows:
        # Use a module logger: no `logger` global is defined in this notebook, so the
        # previous per-row `logger.warning` raised NameError. Warn once with a count.
        logging.getLogger(__name__).warning(
            "%d row(s) contain all zeros. Consider filtering such rows.", n_zero_rows
        )

    # Store the binned data and bin edges
    adata.layers[result_binned_key] = binned
    adata.obsm["bin_edges"] = edges

## Check the raw counts 

In [None]:
# extension_adata.X = extension_adata.layers['raw'].copy()

In [None]:
# Sanity check before binning: inspect a small random subsample of each extension
# dataset and report whether X looks like raw integer counts.
for dataset in extension_adata.obs.Dataset.unique():
    print(f'{dataset}')
    adata_temp = extension_adata[extension_adata.obs.Dataset == dataset]
    # subsample(fraction=0.01) keeps int(0.01 * n_obs) cells, i.e. none for datasets
    # with < 100 cells, and .min() on an empty array raises. Keep at least one cell
    # (the sample is unchanged for larger datasets).
    n_sample = max(1, int(0.01 * adata_temp.n_obs))
    subset = sc.pp.subsample(adata_temp, n_obs=n_sample, copy=True)
    X = subset.X.toarray() if hasattr(subset.X, "toarray") else subset.X

    print(f"Min: {X.min()}, Max: {X.max()}, Mean: {X.mean()}")

    # Check if all values are integers (raw count hint)
    is_integer = np.allclose(X, X.astype(int))
    print(f"All values are integers: {is_integer}")
    print('-'*50)

In [None]:
# Per-cell quantile binning of X into 50 bins (bin 0 = zero values).
# NOTE(review): presumably must match the binning used for the pretrained "binned" reference model — confirm.
bin_data(extension_adata, binning=50)

In [None]:
# Replace X with the binned matrix; X is what the query model is set up on below.
extension_adata.X = extension_adata.layers['binned_data'].copy()

# Create the ID batch covariate column

In [None]:
extension_adata.obs = extension_adata.obs.astype(str)
extension_adata.obs.replace("nan", np.nan, inplace=True)

In [None]:
extension_adata.obs['ID_batch_covariate'] = extension_adata.obs['ID_batch_covariate'].fillna(extension_adata.obs['Donor_ID'])

In [None]:
extension_adata.obs['ID_batch_covariate'] = extension_adata.obs['ID_batch_covariate'].fillna(extension_adata.obs['ID'])

In [None]:
extension_adata.obs.groupby(['Dataset', 'ID_batch_covariate']).size().unstack().sum(axis=1)

In [None]:
extension_adata.obs['ID_batch_covariate'] = extension_adata.obs['ID_batch_covariate'].astype('category')

In [None]:
extension_adata.obs.groupby(['Dataset', 'ID_batch_covariate']).size().unstack()

In [None]:
extension_adata.obs.groupby('Dataset')['ID_batch_covariate'].unique()

In [None]:
extension_adata.obs = extension_adata.obs.astype(str)
extension_adata.obs.replace("nan", np.nan, inplace=True)

In [None]:
extension_adata.obs['ID_batch_covariate'] = extension_adata.obs['ID_batch_covariate'].fillna(extension_adata.obs['ID'])

In [None]:
extension_adata.obs.groupby('Dataset')['ID_batch_covariate'].unique()

In [None]:
extension_adata.obs['ID_batch_covariate'] = extension_adata.obs['ID_batch_covariate'].astype('category')

In [None]:
extension_adata

# Extend the atlas

In [None]:
# Build a query model from the saved reference (scArches). This rebinds `model`,
# replacing the reference model loaded earlier.
model = sca.models.SCANVI.load_query_data(extension_adata, 'final_scanVI_2.0/pretrained_scanvi_mg_L4_binned_model', freeze_dropout = True)
# Mark every query cell as unlabelled; no query labels are used.
model._unlabeled_indices = np.arange(extension_adata.n_obs)
model._labeled_indices = []
print("Labelled Indices: ", len(model._labeled_indices))
print("Unlabelled Indices: ", len(model._unlabeled_indices))

In [None]:
# NOTE(review): no random seed is set anywhere in this notebook, so training may not be reproducible.
model.train(max_epochs=100, plan_kwargs=dict(weight_decay=0.0), check_val_every_n_epoch=10)

In [None]:
model.save('final_scanVI_2.0/query_model', overwrite=True)

In [None]:
# Query-cell embedding and Level_4 predictions. model.predict() without an adata
# argument presumably predicts on the model's registered data (extension_adata) — confirm.
extension_adata.obsm['scanvi_L4_emb'] = model.get_latent_representation(adata=extension_adata)
extension_adata.obs['Level_4_predictions'] = model.predict()
# NOTE: the filename is spelled 'Atlas_Extentsion'; left unchanged in case other code reads it.
extension_adata.write('final_scanVI_2.0/Atlas_Extentsion.h5ad')

In [None]:
# Unused: earlier column aliases, kept commented out.
# extension_adata.obs['cell_type'] = extension_adata.obs['Level_4'].tolist()
# extension_adata.obs['batch'] = extension_adata.obs['ID_batch_covariate'].tolist()

In [None]:
# kNN graph (cosine) on the query embedding, Leiden clusters in 'Global_Leiden', then UMAP.
# NOTE(review): these clusters cover extension cells only; if the core atlas also has a
# 'Global_Leiden' column, the concatenation below puts two unrelated clusterings in one column.
sc.pp.neighbors(extension_adata, use_rep='scanvi_L4_emb', n_neighbors=100,  metric='cosine')
sc.tl.leiden(extension_adata, resolution=0.5, key_added='Global_Leiden')
sc.tl.umap(extension_adata, min_dist=0.75)

In [None]:
sc.pl.umap(extension_adata, color=['Dataset', 'Technology', 'Level_4_predictions'], frameon=False, legend_fontsize=5, ncols=1)

In [None]:
pwd

In [None]:
# Duplicates of the save/write above, commented out.
# model.save('final_scanVI_2.0/query_model', overwrite=True)
# extension_adata.write('final_scanVI_2.0/Atlas_Extentsion.h5ad')

In [None]:
extension_adata.obs.Level_4_predictions.value_counts()

# Concatenate core and extension

In [None]:
# Core + extension in one object. AnnData.concatenate (deprecated in favour of anndata.concat)
# appends '-0'/'-1' to obs_names and writes a 'batch' column; presumably it replaces any
# existing 'batch' column — confirm.
adata_full = adata.concatenate(extension_adata)

In [None]:
# Embed all cells (reference + query) with the query model.
full_latent = model.get_latent_representation(adata=adata_full)
adata_full.obsm['scanvi_extended_atlas_emb'] = full_latent

In [None]:
sc.pp.neighbors(adata_full, use_rep='scanvi_extended_atlas_emb', n_neighbors=100,  metric='cosine')
sc.tl.umap(adata_full, min_dist=0.95)

In [None]:
pwd

In [None]:
# Stringify obs (NaN -> 'nan'), then map 'nan' back to a real NaN.
adata_full.obs = adata_full.obs.astype(str)
adata_full.obs.replace("nan", np.nan, inplace=True)

In [None]:
# NOTE(review): this turns "Unknown" into NaN in EVERY obs column, not only Level_4.
# After the astype(str) before writing, those entries become the string 'nan', and the
# Location fix further down maps 'nan' to 'Pancreas'. Confirm the broad replacement is intended.
adata_full.obs.replace("Unknown", np.nan, inplace=True)

In [None]:
# Keep the curated Level_4 where present; otherwise use the query model's prediction.
adata_full.obs['Level_4_All'] = adata_full.obs['Level_4'].fillna(adata_full.obs['Level_4_predictions'])

In [None]:
sc.pl.umap(adata_full, color=['Dataset', 'Technology', 'Level_4_All'], frameon=False, legend_fontsize=5, ncols=1)

In [None]:
# Stringify all obs columns (NaN -> 'nan') before writing.
adata_full.obs = adata_full.obs.astype(str)

In [None]:
adata_full.write('final_scanVI_2.0/Core_Extension_MG.h5ad')

# Reload

In [None]:
# Reload the object written above (AnnData typically stores string obs columns as categoricals).
adata_full = sc.read_h5ad('final_scanVI_2.0/Core_Extension_MG.h5ad')

In [None]:
adata_full

In [None]:
sc.pl.umap(adata_full, color=['Level_4_All'], frameon=False, legend_fontsize=5, ncols=1)

In [None]:
adata_full.obs.ID_batch_covariate.unique()

In [None]:
adata_full.obs.head()

In [None]:
adata_full.obs.groupby(['Dataset', 'Level_4_All']).size().unstack()

# Save Prediction Probabilities

In [None]:
# Per-class probabilities from the query model. Without an adata argument this presumably
# covers only the model's registered (extension) cells, not the core — confirm.
prediction_df = model.predict(soft=True)

In [None]:
prediction_df.to_csv('final_scanVI_2.0/prediction_prob.csv')

In [None]:
prediction_df

In [None]:
sc.pl.umap(adata_full, color=['Location', 'Condition'], frameon=False, wspace=0.75)

# Fix Condition/Location

## Condition

In [None]:
# Harmonise dataset-specific Condition spellings onto one vocabulary (applied in this order).
CONDITION_RENAMES = {
    'Tumour': 'Primary Tumour',
    'Normal': 'Healthy',
    'PDAC': 'Primary Tumour',
    'Primary Tumor': 'Primary Tumour',
}
for raw_label, harmonised_label in CONDITION_RENAMES.items():
    adata_full.obs.Condition = adata_full.obs.Condition.replace(raw_label, harmonised_label)

In [None]:
sc.pl.umap(adata_full, color=['Location', 'Condition'], frameon=False, wspace=0.75)

In [None]:
# Cells per Condition (rows) x Level_4 label (columns).
adata_full.obs.groupby(['Level_4_All', 'Condition']).size().unstack().T.style.set_sticky('index')

## Location

In [None]:
sc.pl.umap(adata_full, color=['Location', 'Tissue'], frameon=False, wspace=0.75)

In [None]:
# Cells per Dataset x Tissue, used to decide the Location fixes below.
adata_full.obs.groupby(['Dataset', 'Tissue']).size().unstack().style.set_sticky('index')

In [None]:
# Every cell from these two datasets gets Location 'Liver'; other cells keep their value.
liver_met_datasets = ['Lin_MET_GSE154778', 'Simeone_MET_GSE205013']
adata_full.obs.Location = np.where(adata_full.obs.Dataset.isin(liver_met_datasets), 'Liver', adata_full.obs.Location)

In [None]:
print(adata_full.obs['Location'].unique())

In [None]:
# Cells with no Location ('nan' after the earlier stringification) are set to 'Pancreas'.
# NOTE(review): this also catches values that were 'Unknown' before the global Unknown -> NaN step.
adata_full.obs['Location'] = adata_full.obs['Location'].replace('nan', 'Pancreas')

In [None]:
adata_full.obs.groupby(['Dataset', 'Location']).size().unstack().style.set_sticky('index')

# Kick Out Misannotated Metastatic Cells

In [None]:
for cell_type in adata_full.obs.Level_4_All.unique():
    if 'Ductal' in cell_type or 'Acinar' in cell_type:
        print(cell_type)
        adata_temp = adata_full[adata_full.obs.Level_4_All == cell_type]
        print(adata_temp.obs.Condition.value_counts())
        print('_'*100)

In [None]:
mask = (adata_full.obs.Level_4_All.isin(['Ductal Cell (atypical)', 'Ductal Cell',  'Acinar (REG+) Cell', 'Acinar Idling Cell', 'Acinar Cell'])) & (adata_full.obs.Condition == 'Metastatic Lesion')
mask.value_counts()

In [None]:
mask.value_counts()

In [None]:
adata_full_filtered = adata_full[~mask]

In [None]:
adata_full_filtered.obs.Location.value_counts()

In [None]:
adata_full_filtered.obs.Condition.value_counts()

In [None]:
adata_full_filtered.obs.rename(columns={'Level_4_All':'Level_4_Final'}, inplace=True)

In [None]:
sc.pl.umap(adata_full_filtered, color=['Dataset', 'Technology', 'Level_4_Final', 'Location', 'Condition'], frameon=False, wspace=0.75, ncols=2, legend_fontsize=4)

# Label cells as core or extension

In [None]:
extension_dataset = ['EGAS00001002543''GSE15835','GSE194247','GSE211644','GSE229413','Lee_MET_GSE156405','Lin_MET_GSE154778','Simeone_MET_GSE205013','Zhang_GSE197177','phs001840_v1_p1']

In [None]:
adata_full_filtered.obs['Is_Core'] = np.where(adata_full_filtered.obs.Dataset.isin(extension_dataset), 'Extension', 'Core')

In [None]:
sc.pl.umap(adata_full_filtered, color=['Is_Core'], frameon=False, wspace=0.75)

# Add Other Levels

In [None]:
df_map = pd.read_csv('Level_4_to_Level_1.csv', index_col=None, sep=';')

In [None]:
df_map

In [None]:
obs = adata_full_filtered.obs.copy()

In [None]:
level_keys = ['Level_1', 'Level_2', 'Level_3', 'Level_4_Final']

In [None]:
obs['Level_1'] = obs.Level_4_Final.map(dict(zip(df_map.Level_4,df_map.Level_1)))
obs['Level_2'] = obs.Level_4_Final.map(dict(zip(df_map.Level_4,df_map.Level_2)))
obs['Level_3'] = obs.Level_4_Final.map(dict(zip(df_map.Level_4,df_map.Level_3)))

In [None]:
obs[obs.Is_Core == 'Extension'].head()

In [None]:
obs_old = adata_full_filtered.obs.copy()
adata_full_filtered.obs = obs.copy()

In [None]:
sc.pl.umap(adata_full_filtered, color=level_keys, frameon=False, wspace=0.75, ncols=2, legend_fontsize=8)

In [None]:
adata_full_filtered

In [None]:
adata_full_filtered.write('final_scanVI_2.0/Core_Extension_MG.h5ad')

# Redo UMAP

In [None]:
# Keep the earlier min_dist=0.95 layout under its own key, then compute a min_dist=0.85 layout.
# NOTE(review): these UMAPs reuse the neighbour graph built on adata_full before the
# metastatic-cell filter (subset along with the cells) — confirm that is acceptable.
adata_full_filtered.obsm['UMAP_0.95'] = adata_full_filtered.obsm['X_umap'].copy()
sc.tl.umap(adata_full_filtered, min_dist=0.85, key_added='UMAP_0.85')

In [None]:
# sc.pl.umap plots obsm['X_umap']; point it at the 0.85 layout (same array object, not a copy).
adata_full_filtered.obsm['X_umap'] = adata_full_filtered.obsm['UMAP_0.85']
sc.pl.umap(adata_full_filtered, color=level_keys, frameon=False, wspace=0.75, ncols=2, legend_fontsize=8)

In [None]:
adata_full_filtered.obsm

In [None]:
sc.tl.umap(adata_full_filtered, min_dist=0.75, key_added='UMAP_0.75')

In [None]:
# Final X_umap is the min_dist=0.75 layout.
adata_full_filtered.obsm['X_umap'] = adata_full_filtered.obsm['UMAP_0.75']
sc.pl.umap(adata_full_filtered, color=level_keys, frameon=False, wspace=0.75, ncols=2, legend_fontsize=8)

# Compress

In [None]:
# import numpy as np

In [None]:
# import scanpy as sc
# adata_full_filtered = sc.read_h5ad('final_scanVI_2.0/Core_Extension_MG.h5ad')

In [None]:
adata_full_filtered.X

In [None]:
# Integer check on a 1% subsample of X (presumably the binned values here — confirm).
subset = sc.pp.subsample(adata_full_filtered, fraction=0.01, copy=True)
X = subset.X.toarray() if hasattr(subset.X, "toarray") else subset.X

print(f"Min: {X.min()}, Max: {X.max()}, Mean: {X.mean()}")

# Check if all values are integers (raw count hint)
is_integer = np.allclose(X, X.astype(int))
print(f"All values are integers: {is_integer}")

In [None]:
# Store X as a CSR sparse matrix to reduce size on disk and in memory.
from scipy.sparse import csr_matrix
adata_full_filtered.X = csr_matrix(adata_full_filtered.X)

In [None]:
adata_full_filtered.X 

In [None]:
adata_full_filtered.layers['raw']

In [None]:
# Same integer check on the subsample's 'raw' layer (raw counts should be whole numbers).
raw_sample = subset.layers['raw']
# The dense fallback must read the raw layer itself; it previously read subset.X (binned data).
X = raw_sample.toarray() if hasattr(raw_sample, "toarray") else np.asarray(raw_sample)

print(f"Min: {X.min()}, Max: {X.max()}, Mean: {X.mean()}")

# Check if all values are integers (raw count hint)
is_integer = np.allclose(X, X.astype(int))
print(f"All values are integers: {is_integer}")

In [None]:
# Overwrite the MG file written earlier, this time gzip-compressed.
adata_full_filtered.write('final_scanVI_2.0/Core_Extension_MG.h5ad', compression='gzip')

In [None]:
pwd

In [None]:
adata_full_filtered

# Transfer annotations to the AnnData with all genes

In [None]:
import scanpy as sc
import anndata as ad
import pandas as pd
import numpy as np
import os

In [None]:
# NOTE(review): hard-coded absolute path (the top of the notebook used a relative chdir).
os.chdir('/lustre/groups/ml01/workspace/shrey.parikh/PDAC_Work_Dir/PDAC_Final/Downstream/')
# All-genes core atlas, all-genes extension, and the reduced-gene ('MG') object written above.
# NOTE(review): the core file here differs from final_object.h5ad used at the top; cell names
# only match the MG object if barcodes and concatenation order (core first) agree.
adata_core_all_genes = sc.read_h5ad('2025_05_20_refined_annotation.h5ad')
adata_extension_all_genes = sc.read_h5ad('Extension/Extension_Datasets_Combined.h5ad')
adata_mg = sc.read_h5ad('final_scanVI_2.0/Core_Extension_MG.h5ad')

In [None]:
adata_core_all_genes

In [None]:
adata_extension_all_genes

In [None]:
# Ensure unique obs_names before concatenating and matching cells by name.
adata_core_all_genes.obs_names_make_unique()
adata_extension_all_genes.obs_names_make_unique()
adata_mg.obs_names_make_unique()

In [None]:
# Outer join keeps the union of genes and appends '-0'/'-1' to obs_names. Per the anndata docs,
# genes missing on one side are filled with 0 (sparse X) or NaN (dense X) — confirm for this version.
adata_combined = adata_core_all_genes.concatenate(adata_extension_all_genes, join='outer')

In [None]:
len(set(adata_combined.obs_names) & set(adata_mg.obs_names))

In [None]:
(set(adata_mg.obs_names)- set(adata_combined.obs_names))

In [None]:
adata_mg_filtered = adata_mg[~adata_mg.obs_names.isin(['CGGGTGTTCGTCGCTT-1-1-2', 'TGAGGGAGTAGATTAG-1-1-2'])]

In [None]:
adata_mg_filtered

In [None]:
adata_combined_subset = adata_combined[adata_combined.obs_names.isin(adata_mg_filtered.obs_names)]

In [None]:
adata_combined_subset

In [None]:
adata_combined_subset.obs = adata_mg_filtered.obs.copy()
adata_combined_subset.obsm = adata_mg_filtered.obsm.copy()
adata_combined_subset.obsp = adata_mg_filtered.obsp.copy()
adata_combined_subset.uns = adata_mg_filtered.uns.copy()

In [None]:
# Per-dataset check that the all-genes X holds raw integer counts (small random subsample).
for dataset in adata_combined_subset.obs.Dataset.unique():
    print(f'{dataset}')
    adata_temp = adata_combined_subset[adata_combined_subset.obs.Dataset == dataset]
    # subsample(fraction=0.01) keeps int(0.01 * n_obs) cells, i.e. none for datasets
    # with < 100 cells, and .min() on an empty array raises. Keep at least one cell.
    n_sample = max(1, int(0.01 * adata_temp.n_obs))
    subset = sc.pp.subsample(adata_temp, n_obs=n_sample, copy=True)
    X = subset.X.toarray() if hasattr(subset.X, "toarray") else subset.X
    print(f"Min: {X.min()}, Max: {X.max()}, Mean: {X.mean()}")
    # Check if all values are integers (raw count hint)
    is_integer = np.allclose(X, X.astype(int))
    print(f"All values are integers: {is_integer}")
    print('-'*50)

In [None]:
# Stringify var columns (presumably to make outer-join, mixed/NaN columns writable) — confirm.
adata_combined_subset.var = adata_combined_subset.var.astype(str)

In [None]:
adata_combined_subset.obs.head()

In [None]:
adata_combined_subset.obs.Level_4_Final.unique()

# Fix NA levels: harmonise Level_4 labels and re-derive Levels 1–3

In [None]:
import scanpy as sc
import numpy as np
import pandas as pd

In [None]:
import os
os.chdir('/lustre/groups/ml01/workspace/shrey.parikh/PDAC_Work_Dir/PDAC_Final/Downstream/')

In [None]:
# Uncomment to resume from the files written at the end of this notebook instead of re-running above.
# adata_mg_filtered = sc.read_h5ad('final_scanVI_2.0/Core_Extension_MG.h5ad')

In [None]:
# adata_combined_subset = sc.read_h5ad('final_scanVI_2.0/Core_Extension_All_Genes.h5ad')

In [None]:
adata_combined_subset

In [None]:
# Level_4 -> Level_1..Level_3 mapping table (semicolon-separated).
df_map = pd.read_csv('Level_4_to_Level_1.csv', index_col=None, sep=';')

In [None]:
adata_mg_filtered.obs.Level_4_Final = adata_mg_filtered.obs.Level_4_Final.replace('Malignant Cell - Invasive', 'Malignant Cell - Highly Invasive')
adata_combined_subset.obs.Level_4_Final = adata_combined_subset.obs.Level_4_Final.replace('Malignant Cell - Invasive', 'Malignant Cell - Highly Invasive')

In [None]:
adata_mg_filtered.obs.Level_4_Final = adata_mg_filtered.obs.Level_4_Final.replace('CD4+ Central Memory T Cell', 'CD4+ Memory T Cell')
adata_mg_filtered.obs.Level_4_Final = adata_mg_filtered.obs.Level_4_Final.replace('CD4+ Naive Cell', 'CD4+ Naive T Cell')
adata_combined_subset.obs.Level_4_Final = adata_combined_subset.obs.Level_4_Final.replace('CD4+ Central Memory T Cell', 'CD4+ Memory T Cell')
adata_combined_subset.obs.Level_4_Final = adata_combined_subset.obs.Level_4_Final.replace('CD4+ Naive Cell', 'CD4+ Naive T Cell')

In [None]:
df_map.Level_4 = df_map.Level_4.replace('Adypocyte', 'Adipocyte')

In [None]:
set(adata_mg_filtered.obs.Level_4_Final.unique()) - set(df_map.Level_4.unique())

In [None]:
set(adata_combined_subset.obs.Level_4_Final.unique()) - set(df_map.Level_4.unique())

In [None]:
obs = adata_mg_filtered.obs.copy()

In [None]:
level_keys = ['Level_1', 'Level_2', 'Level_3', 'Level_4_Final']

In [None]:
obs['Level_1'] = obs.Level_4_Final.map(dict(zip(df_map.Level_4,df_map.Level_1)))
obs['Level_2'] = obs.Level_4_Final.map(dict(zip(df_map.Level_4,df_map.Level_2)))
obs['Level_3'] = obs.Level_4_Final.map(dict(zip(df_map.Level_4,df_map.Level_3)))

In [None]:
# Keep a fixed set of obs columns in a fixed order; pandas raises KeyError if any listed column is absent.
# NOTE(review): 'batch' here is presumably the core/extension flag written by AnnData.concatenate — confirm.
obs = obs[['Barcode', 'Dataset', 'ID_batch_covariate', 'Unique_ID', 'Technology',
       'n_genes', 'n_counts', 'log_counts', 'mt_frac', 'n_genes_by_counts',
       'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts',
       'total_counts_mito', 'log1p_total_counts_mito', 'pct_counts_mito',
       'leiden', 'batch', 'leiden_0.2', 'leiden_0.2_annotation',
       'leiden_subcluster', 'level0_leiden_subcluster', 'leiden_0.5',
       'is_outlier_total_counts', 'outlier', 'infercnv_score_malignant',
       'infercnv_score_malignant_refined', 'cnv_score_abs', 'treatment_status',
       'Level_0', 'MALAT1_lognorm', 'empty_droplet', 'ID_harmonised',
       'Dataset_unique', 'Tissue', 'Age', 'Sex', 'Diabetes', 'Treatment',
       'Global_Leiden', 'Treatment_Harmonized', 'Treatment_Category',
       'Myeloid_leiden_0.75', 'Fibroblast_leiden_0.75', 'Lymphoid_leiden_0.75',
       'Endothelial_Cell_leiden_0.75', 'Malignant_leiden_0.75',
       'Ductal_Cell_leiden_0.75', 'Schwann_Cell_leiden_0.75',
       'Adipocyte_leiden_0.75', 'Endocrine_Cell_leiden_0.75',
       'Acinar_Cell_leiden_0.75', 'Pericyte_leiden_0.75',
       'Smooth_Muscle_Cell_leiden_0.75', 'NK_Cell_leiden_0.75', 'Condition',
       'combo', 'EMT category', 'EMT score', 'EMT_score_DL',
       'Suspicious_Normal', '_scvi_batch', '_scvi_labels', 'Donor_ID', 'Location', 'TreatmentType',
       'ID', 'Atlas_Extension_CellType', 'Level_4_predictions',
       'Level_1', 'Level_2', 'Level_3', 'Level_4','Level_4_Final', 'Is_Core']]

In [None]:
obs.head()

In [None]:
# Give both objects the same obs. AnnData assigns obs positionally (index replaced, rows not
# realigned), so this relies on adata_combined_subset having adata_mg_filtered's row order.
adata_mg_filtered.obs = obs.copy()
adata_combined_subset.obs = obs.copy()

In [None]:
level_keys = ['Level_1', 'Level_2', 'Level_3', 'Level_4_Final']
sc.pl.umap(adata_mg_filtered, color=level_keys, frameon=False, wspace=0.75, ncols=2, legend_fontsize=8)

In [None]:
level_keys = ['Level_1', 'Level_2', 'Level_3', 'Level_4_Final']
sc.pl.umap(adata_combined_subset, color=level_keys, frameon=False, wspace=0.75, ncols=2, legend_fontsize=8)

In [None]:
adata_combined_subset.write('final_scanVI_2.0/Core_Extension_All_Genes.h5ad',  compression='gzip')

In [None]:
# Overwrites the MG file written earlier in the notebook with the filtered, relabelled object.
adata_mg_filtered.write('final_scanVI_2.0/Core_Extension_MG.h5ad', compression='gzip')

In [None]:
pwd