# **Cell Type Annotation (scANVI)**

In [17]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.sparse
from scipy import stats
import warnings
warnings.filterwarnings('ignore')
# scVI/scANVI imports
import scvi
from scvi.model import SCVI, SCANVI
# Set random seeds for reproducibility.
# np.random.seed covers the NumPy-based seed-cell sampling further down;
# scvi.settings.seed additionally seeds the RNGs scvi-tools uses during
# scVI/scANVI training, which np.random.seed alone leaves unseeded.
np.random.seed(42)
scvi.settings.seed = 42

# **Load CellMarker Database and Process Marker Genes**

In [None]:
# Load the "seq" sheet of the CellMarker workbook. The code below reads the
# columns species, tissue_class, tissue_type, cancer_type and cell_name.
cell_marker_df = pd.read_excel('Cell_marker_Seq.xlsx', sheet_name='seq')

# ===== USER-DEFINED FILTERING PARAMETERS =====
# Set these parameters before running

# 1. Species (required)
species = "Human"  # Options: "Human" or "Mouse"

# 2. Tissue classes (optional - can be string, list, or None for all)
tissue_class = "Intestine"  # Options:
# - Single tissue: "Blood"
# - Multiple tissues: ["Blood", "Brain", "Intestine"]
# - None: All tissue classes

# 3. Tissue types (optional - can be string, list, or None for all)
tissue_type = None  # Options:
# - Single type: "Peripheral blood"
# - Multiple types: ["Cortex", "Hippocampus"]
# - None: All tissue types within selected tissue_class(es)

# 4. Cancer types (optional - can be string, list, or None for all)
cancer_type = None  # Options:
# - "Normal" (for normal tissues)
# - Specific cancer type: "Bladder Cancer"
# - None: All (both normal and cancer)


def _filter_by_values(df, column, allowed):
    """Return the rows of *df* whose *column* value is one of *allowed*.

    *allowed* may be None (no filtering), a single value, or any iterable of
    values (list, tuple, set, ...). Rows with a missing value in *column*
    never match.
    """
    if allowed is None:
        return df
    # A lone string (or other scalar) is treated as a one-element selection.
    if isinstance(allowed, str) or not hasattr(allowed, '__iter__'):
        allowed = [allowed]
    return df[df[column].isin(list(allowed))]


# ===== APPLY FILTERS =====
print("Applying filters to CellMarker database...")

# Filter by species (required)
cell_marker_df = cell_marker_df[cell_marker_df['species'] == species]
print(f"After species filter: {len(cell_marker_df)} rows")

# Optional filters share one code path so that every accepted input form
# (string, list, tuple, set, ...) is really applied before the row count is
# reported; previously anything other than str/list was silently ignored.
for column, allowed in (
    ('tissue_class', tissue_class),
    ('tissue_type', tissue_type),
    ('cancer_type', cancer_type),
):
    if allowed is None:
        continue
    cell_marker_df = _filter_by_values(cell_marker_df, column, allowed)
    print(f"After {column} filter: {len(cell_marker_df)} rows")

if cell_marker_df.empty:
    print("Warning: no rows left after filtering - check the filter values above")

# Display summary
print("\n=== FILTERING SUMMARY ===")
print(f"Species: {species}")
print(f"Tissue class(es): {'All' if tissue_class is None else tissue_class}")
print(f"Tissue type(s): {'All' if tissue_type is None else tissue_type}")
print(f"Cancer type(s): {'All' if cancer_type is None else cancer_type}")
print(f"Total rows after filtering: {len(cell_marker_df)}")
print(f"Unique tissue classes: {cell_marker_df['tissue_class'].unique()}")
# Only the first ten tissue types are shown to keep the summary short.
print(f"Unique tissue types: {cell_marker_df['tissue_type'].unique()[:10]}")
print(f"Unique cell types: {len(cell_marker_df['cell_name'].unique())}")

# Display first few rows for verification
print("\nFirst few rows of data:")
display(cell_marker_df.head())

# **Select Cell Types to Annotate**

In [None]:
# ===== USER-DEFINED CELL TYPE SELECTION =====
# Option 1: Specific cell types (list)
selected_cell_types = None  # Options:
# - None: Use all cell types from filtered data
# - List: ["CD4+ T cell", "CD8+ T cell", "B cell", "Monocyte"]

# Option 2: Minimum markers required per cell type
min_markers_per_cell_type = 4  # Only include cell types with at least this many markers

# Option 3: Maximum cell types to include (top N by number of markers)
max_cell_types = None  # Set to None to include all

# ===== PROCESS CELL TYPES =====
print("Processing cell types...")

# cell type -> marker genes present in the dataset, and the matching counts
cell_type_markers = {}
cell_type_counts = {}

# Build the lookup once instead of probing adata.var_names for every marker.
dataset_genes = set(adata.var_names)

# sort=False keeps cell types in order of first appearance; rows with a
# missing cell_name are dropped by groupby.
for cell_type, marker_column in cell_marker_df.groupby('cell_name', sort=False)['marker']:
    # Keep each marker spelled exactly as it occurs in adata.var_names.
    # Upper-casing it here (as before) made every later
    # `gene in adata.var_names` lookup fail for symbols that contain
    # lower-case letters.
    markers_in_data = [
        gene
        for gene in map(str, marker_column.dropna().unique())
        if gene in dataset_genes
    ]

    # Apply minimum markers filter
    if len(markers_in_data) >= min_markers_per_cell_type:
        cell_type_markers[cell_type] = markers_in_data
        cell_type_counts[cell_type] = len(markers_in_data)

print(f"\nFound {len(cell_type_markers)} cell types with at least {min_markers_per_cell_type} markers")

# Apply cell type selection
if selected_cell_types is None:
    # Use all cell types from filtered data
    selected_markers = cell_type_markers.copy()
    print(f"Using ALL {len(selected_markers)} cell types")
else:
    # Use only specified cell types, reporting any that did not survive filtering
    selected_markers = {}
    for ct in selected_cell_types:
        if ct in cell_type_markers:
            selected_markers[ct] = cell_type_markers[ct]
        else:
            print(f"Warning: {ct} not found in filtered marker database")
    print(f"Using {len(selected_markers)} specified cell types")

# Limit to top N cell types by marker count if requested
if max_cell_types is not None and len(selected_markers) > max_cell_types:
    # Sort cell types by number of markers (descending) and keep the first N
    sorted_cell_types = sorted(
        selected_markers.items(),
        key=lambda item: len(item[1]),
        reverse=True,
    )[:max_cell_types]
    selected_markers = dict(sorted_cell_types)
    print(f"Limited to top {max_cell_types} cell types by marker count")

# Display selected cell types
print("\n=== SELECTED CELL TYPES ===")
for i, (ct, markers) in enumerate(selected_markers.items()):
    print(f"{i+1}. {ct}: {len(markers)} markers")
    if i < 5:  # Show first 5 markers for first few cell types
        print(f"   Sample markers: {', '.join(markers[:5])}")
    elif i == 5:
        print("   ... (truncated)")

# Get all unique markers for downstream processing
all_markers = set()
for markers in selected_markers.values():
    all_markers.update(markers)
print(f"\nTotal unique markers for annotation: {len(all_markers)}")

# **Marker–Expression Alignment**

In [None]:
# First, let's check what we have
print("Checking available data...")
print(f"Total genes in adata: {adata.n_vars}")
print(f"Genes in adata.X: {adata.X.shape[1]}")
if hasattr(adata, 'layers'):
    print(f"Layers available: {list(adata.layers.keys())}")

# Get all marker genes from our curated list
all_markers = set()
for markers in selected_markers.values():
    all_markers.update(markers)

print(f"\nTotal curated marker genes: {len(all_markers)}")

# Check which marker genes are in our dataset.
# sorted(): iteration order of a set of strings can differ between
# interpreter runs, which would make the gene order of the subset built from
# this list non-reproducible.
available_markers = sorted(gene for gene in all_markers if gene in adata.var_names)
print(f"Marker genes found in adata.var_names: {len(available_markers)}")

if len(available_markers) == 0:
    print("ERROR: No marker genes found in adata.var_names")
    print("Check gene naming conventions (case sensitivity, symbols vs. IDs)")

    # Try case-insensitive matching
    print("\nTrying case-insensitive matching...")
    # Upper-cased name -> name as spelled in adata. The first occurrence
    # wins, which is what the previous list.index() lookup returned.
    dataset_name_by_upper = {}
    for dataset_gene in adata.var_names:
        dataset_name_by_upper.setdefault(dataset_gene.upper(), dataset_gene)

    # A set removes duplicates that arise when two curated markers differ
    # only by case and therefore map to the same dataset gene.
    available_markers_case_insensitive = sorted({
        dataset_name_by_upper[gene.upper()]
        for gene in all_markers
        if gene.upper() in dataset_name_by_upper
    })

    print(f"Found {len(available_markers_case_insensitive)} marker genes with case-insensitive matching")
    available_markers = available_markers_case_insensitive

    # Re-spell the per-cell-type marker lists the same way. The scoring
    # helpers look genes up by exact name, so leaving the old spelling in
    # place would make every marker score zero.
    selected_markers = {
        cell_type: [
            dataset_name_by_upper[gene.upper()]
            for gene in genes
            if gene.upper() in dataset_name_by_upper
        ]
        for cell_type, genes in selected_markers.items()
    }

    if not available_markers:
        # Nothing to annotate with: stop here rather than continue with an
        # empty marker subset.
        raise ValueError(
            "None of the curated marker genes match adata.var_names, even "
            "ignoring case. Check whether the dataset uses gene symbols or IDs."
        )

# Build adata_subset: the same cells as adata, restricted to the marker genes.
print(f"\nCreating adata_subset with {len(available_markers)} marker genes...")
adata_subset = adata[:, available_markers].copy()

# Keep a second copy of the matrix in a named layer for later reference.
print("Using scaled expression matrix from adata.X...")
adata_subset.layers["scaled"] = adata_subset.X.copy()

print(f"Created adata_subset with shape: {adata_subset.shape}")
print("Scaled expression data stored in: adata_subset.X and adata_subset.layers['scaled']")

# Report mean/std of the first few genes so the user can judge by eye
# whether the matrix looks scaled (mean~0, std~1).
print("\nVerifying scaling...")
sample_genes = min(5, adata_subset.n_vars)
for column_idx, gene_name in enumerate(adata_subset.var_names[:sample_genes]):
    column = adata_subset.X[:, column_idx]
    # Anything that is not a NumPy array is treated as sparse and densified.
    values = column if isinstance(column, np.ndarray) else column.toarray().flatten()
    print(f"Gene {gene_name}: mean={values.mean():.2f}, std={values.std():.2f}")

# Carry the cell-level metadata, embeddings and unstructured data over from adata.
print("\nCopying metadata...")
adata_subset.obs = adata.obs.copy()
adata_subset.obsm = {key: value.copy() for key, value in adata.obsm.items()}
adata_subset.uns = adata.uns.copy()

# Record which genes / marker sets this subset was built from.
adata_subset.uns['marker_genes'] = available_markers
adata_subset.uns['selected_markers'] = selected_markers

print("\n=== adata_subset SUMMARY ===")
print(f"Shape: {adata_subset.shape}")
print(f"Cells: {adata_subset.n_obs}")
print(f"Marker genes: {adata_subset.n_vars}")
print(f"Layers: {list(adata_subset.layers.keys())}")
print(f"Obs columns: {list(adata_subset.obs.columns[:5])}...")
print("\nFirst 10 marker genes:")
for position, gene in enumerate(adata_subset.var_names[:10], start=1):
    print(f"  {position}. {gene}")

# Optional diagnostic figure: per-gene histograms on the left and, when a
# 'leiden' clustering exists, a gene x cluster heatmap of means on the right.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
hist_ax, heat_ax = axes

# Left panel: overlay the value distributions of the first few marker genes.
sample_genes = min(5, len(available_markers))
for column_idx, gene in enumerate(adata_subset.var_names[:sample_genes]):
    expression = adata_subset.X[:, column_idx]
    # Convert sparse to dense if needed
    if not isinstance(expression, np.ndarray):
        expression = expression.toarray().flatten()
    hist_ax.hist(expression, bins=50, alpha=0.5, label=gene)
hist_ax.set_title('Scaled Expression Distribution\n(sample marker genes)')
hist_ax.set_xlabel('Expression (scaled)')
hist_ax.set_ylabel('Frequency')
hist_ax.legend()

# Right panel: mean expression of the first markers within each cluster.
if 'leiden' in adata_subset.obs.columns:
    n_markers_to_show = min(20, len(available_markers))
    top_markers = adata_subset.var_names[:n_markers_to_show]
    clusters = sorted(adata_subset.obs['leiden'].unique())

    per_cluster_rows = []
    for cluster in clusters:
        in_cluster = adata_subset.obs['leiden'] == cluster
        mean_profile = adata_subset[in_cluster, top_markers].X.mean(axis=0)
        # The mean may come back as a sparse object, a matrix, or a plain
        # array; reduce the first two to a flat array.
        if hasattr(mean_profile, 'toarray'):
            mean_profile = mean_profile.toarray().flatten()
        elif hasattr(mean_profile, 'A1'):
            mean_profile = mean_profile.A1
        per_cluster_rows.append(mean_profile)

    cluster_means = np.array(per_cluster_rows)

    # Transpose so genes run along the y axis and clusters along the x axis.
    im = heat_ax.imshow(cluster_means.T, aspect='auto', cmap='viridis')
    heat_ax.set_title('Mean Scaled Expression\nper Cluster')
    heat_ax.set_xlabel('Cluster')
    heat_ax.set_ylabel('Marker Gene')
    heat_ax.set_xticks(range(len(clusters)))
    heat_ax.set_xticklabels(clusters)
    heat_ax.set_yticks(range(len(top_markers)))
    heat_ax.set_yticklabels(top_markers, fontsize=8)
    plt.colorbar(im, ax=heat_ax, label='Mean Expression')

plt.tight_layout()
plt.show()

print("\nadata_subset is ready for seed labeling!")

# **Define Helper Functions for Seed Labeling**

In [23]:
def get_score_per_cluster(adata_subset, cluster_id, gene_set):
    """Return the mean marker expression of every cell in one cluster.

    Parameters
    ----------
    adata_subset
        Annotated data object with a 'leiden' column in ``.obs``; its ``.X``
        may be dense or sparse.
    cluster_id
        Value of ``obs['leiden']`` selecting the cluster.
    gene_set
        Iterable of gene names. Names absent from ``var_names`` are ignored.

    Returns
    -------
    numpy.ndarray
        One value per cell of the cluster: the sum of that cell's expression
        over the genes found, divided by the number of genes found. All
        zeros when none of the genes is present.
    """
    # Get cells in this cluster
    cluster_mask = adata_subset.obs['leiden'] == cluster_id
    cluster_cells = adata_subset[cluster_mask]

    score = np.zeros(cluster_cells.n_obs)
    genes_found = 0
    for gene in gene_set:
        if gene not in cluster_cells.var_names:
            continue
        expression = cluster_cells[:, gene].X
        # np.array() on a sparse matrix wraps it in a 0-d object array rather
        # than densifying it, so sparse data has to be converted explicitly.
        if hasattr(expression, "toarray"):
            expression = expression.toarray()
        score += np.asarray(expression, dtype=float).ravel()
        genes_found += 1

    # Normalize by number of genes found
    if genes_found > 0:
        score = score / genes_found

    return score

def assign_label_to_cluster(adata_subset, cluster_id, selected_markers, seed_cell_mask):
    """Pick the best-scoring cell type for the seed cells of one cluster.

    Parameters
    ----------
    adata_subset
        Annotated data object with a 'leiden' column in ``.obs``.
    cluster_id
        Value of ``obs['leiden']`` selecting the cluster.
    selected_markers
        Mapping of cell type name to its list of marker genes.
    seed_cell_mask
        Boolean mask over all cells marking the seed cells.

    Returns
    -------
    str or None
        ``None`` when the cluster contains no seed cell. Otherwise the cell
        type whose marker score, averaged over the cluster's seed cells, is
        highest, or ``"Unknown"`` when that best average is not above zero.
    """
    cluster_mask = adata_subset.obs['leiden'] == cluster_id
    seed_in_cluster = cluster_mask & seed_cell_mask

    # Which rows of the per-cluster score vector belong to seed cells. This
    # depends only on the two masks, so it is computed once instead of once
    # per cell type.
    seed_within_cluster = np.asarray(seed_in_cluster, dtype=bool)[
        np.asarray(cluster_mask, dtype=bool)
    ]
    if not seed_within_cluster.any():
        return None

    best_score = -np.inf
    best_label = "Unknown"

    for cell_type, markers in selected_markers.items():
        # Scores cover every cell of the cluster, in cluster order.
        scores = get_score_per_cluster(adata_subset, cluster_id, markers)
        avg_score = np.mean(scores[seed_within_cluster])

        # Strict comparison: on a tie the earlier cell type is kept.
        if avg_score > best_score:
            best_score = avg_score
            best_label = cell_type

    return best_label if best_score > 0 else "Unknown"

def create_seed_labels_per_cluster(adata_subset, seed_frac=0.3):
    """Randomly mark a fraction of the cells of every leiden cluster as seeds.

    Each cluster contributes ``int(cluster_size * seed_frac)`` cells, but
    never fewer than one. Sampling uses NumPy's global random state.

    Returns a pair ``(seed_labels, seed_cell_mask)`` of Series indexed by cell
    name: ``seed_labels`` is "Unknown" for every cell, ``seed_cell_mask`` is
    True for the sampled cells.
    """
    cell_names = adata_subset.obs_names
    cluster_of_cell = adata_subset.obs['leiden']

    # Labels start out as "Unknown" everywhere; this function only picks cells.
    seed_labels = pd.Series("Unknown", index=cell_names)
    seed_cell_mask = pd.Series(False, index=cell_names)

    for cluster in sorted(cluster_of_cell.unique()):
        members = cell_names[cluster_of_cell == cluster]
        n_members = len(members)
        n_picked = max(1, int(n_members * seed_frac))

        # Sample without replacement so no cell is drawn twice.
        picked = np.random.choice(members, size=n_picked, replace=False)
        seed_cell_mask.loc[picked] = True

        print(f"Cluster {cluster}: {n_picked} seed cells out of {n_members} total")

    return seed_labels, seed_cell_mask

# **Cluster-wise Seed Label Selection**

In [None]:
# Sample 40% of the cells of every cluster as seed cells.
seed_frac = 0.4
seed_labels, seed_cell_mask = create_seed_labels_per_cluster(adata_subset, seed_frac=seed_frac)

n_cells_total = adata_subset.n_obs
n_seed = seed_cell_mask.sum()
n_unknown = (~seed_cell_mask).sum()
print(f"\nTotal cells: {n_cells_total}")
print(f"Seed cells: {n_seed} ({n_seed/n_cells_total*100:.1f}%)")
print(f"Unknown cells: {n_unknown} ({n_unknown/n_cells_total*100:.1f}%)")

# Attach both Series to the cell metadata. .values is used, so they are
# written positionally.
adata_subset.obs['seed_cell_mask'] = seed_cell_mask.values
adata_subset.obs['seed_labels'] = seed_labels.values

# Two-panel overview of where the seed cells ended up.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
umap_ax, bar_ax = axes

# Left: clusters on the UMAP with the seed cells drawn over them in red.
sc.pl.umap(adata_subset, color='leiden', show=False, ax=umap_ax)
is_seed = adata_subset.obs['seed_cell_mask']
seed_cells_umap = adata_subset.obsm['X_umap'][is_seed]
umap_ax.scatter(
    seed_cells_umap[:, 0],
    seed_cells_umap[:, 1],
    s=10,
    c='red',
    alpha=0.5,
    label='Seed cells',
)
umap_ax.set_title('Seed Cells Distribution')
umap_ax.legend()

# Right: achieved seed percentage per cluster against the requested fraction.
cluster_counts = adata_subset.obs.groupby('leiden').size()
seed_counts = adata_subset.obs[is_seed].groupby('leiden').size()
# Clusters missing from seed_counts become NaN in the division; show them as 0.
seed_percent = seed_counts.div(cluster_counts).mul(100).fillna(0)

bar_ax.bar(seed_percent.index, seed_percent.values)
bar_ax.axhline(y=seed_frac*100, color='r', linestyle='--', alpha=0.5, label=f'Target ({seed_frac*100}%)')
bar_ax.set_title('Seed Cells per Cluster')
bar_ax.set_xlabel('Leiden Cluster')
bar_ax.set_ylabel('Percentage of Seed Cells (%)')
bar_ax.legend()

plt.tight_layout()
plt.show()

# **Marker Gene Scoring and Initial Label Assignment**

In [None]:
# Score every cluster's seed cells against the marker sets and keep the winner.
clusters = sorted(adata_subset.obs['leiden'].unique())
print(f"Processing {len(clusters)} clusters...\n")

cluster_labels = {}
for cluster in clusters:
    assigned_label = assign_label_to_cluster(
        adata_subset,
        cluster,
        selected_markers,
        adata_subset.obs['seed_cell_mask'],
    )
    cluster_labels[cluster] = assigned_label
    print(f"Cluster {cluster}: Assigned label = {assigned_label}")

# Write each cluster's label onto its seed cells only; all other cells keep
# whatever seed_labels held before.
final_seed_labels = seed_labels.copy()
for cluster, label in cluster_labels.items():
    in_cluster = adata_subset.obs['leiden'] == cluster
    cluster_mask = in_cluster & adata_subset.obs['seed_cell_mask']
    final_seed_labels[cluster_mask] = label

# Update in adata_subset
adata_subset.obs['seed_labels'] = final_seed_labels.values

# Mirror the labels onto the full dataset used for scVI training; every cell
# starts as "Unknown" and is then overwritten by cell name.
adata.obs['seed_labels'] = "Unknown"
adata.obs.loc[adata_subset.obs_names, 'seed_labels'] = final_seed_labels.values

# Summary of how many cells received each real label
print("\n=== LABEL ASSIGNMENT SUMMARY ===")
is_labelled = final_seed_labels != "Unknown"
unique_labels = final_seed_labels[is_labelled].unique()
print(f"Assigned {len(unique_labels)} unique labels to seed cells:")
for label in unique_labels:
    n_cells = (final_seed_labels == label).sum()
    print(f"  {label}: {n_cells} cells")

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
umap_ax, bar_ax = axes

# Left: the assigned labels on the UMAP, restricted to seed cells.
seed_adata = adata_subset[adata_subset.obs['seed_cell_mask']].copy()
if len(seed_adata) > 0:
    sc.pl.umap(seed_adata, color='seed_labels', palette='tab20', show=False, ax=umap_ax)
    umap_ax.set_title('Assigned Labels (Seed Cells Only)')

# Right: stacked bars, one bar per label, split by the cluster the cells came from.
seed_obs = adata_subset.obs[adata_subset.obs['seed_cell_mask']]
label_cluster_table = pd.crosstab(seed_obs['seed_labels'], seed_obs['leiden'])
label_cluster_table.plot(kind='bar', stacked=True, ax=bar_ax, colormap='tab20')
bar_ax.set_title('Label Distribution Across Clusters')
bar_ax.set_xlabel('Assigned Label')
bar_ax.set_ylabel('Number of Seed Cells')
bar_ax.legend(title='Cluster', bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()
plt.show()

# **Prepare Full Data for scVI Training**

In [None]:
# Report how many cells of the full dataset carry a seed label.
is_unknown = adata.obs['seed_labels'] == 'Unknown'
print(f"Cells with seed labels: {(~is_unknown).sum()}")
print(f"Cells marked as Unknown: {is_unknown.sum()}")

# Informational only: say whether a 'scale.data' layer exists on adata.
if 'scale.data' in adata.layers:
    print("scale.data available for all genes")
else:
    print("Warning: scale.data not found in original adata")
    print("Using normalized data from layers['data']")
    # This is acceptable since scVI works with raw counts or normalized data

# Verify data structure
print(f"\nOriginal adata shape: {adata.shape}")
print(f"Seed labels unique values: {adata.obs['seed_labels'].unique()}")

# **Train scVI Model**

In [None]:
# Setup the AnnData for scVI
# We'll use the full adata
# Note: scVI works best with raw counts, but can work with normalized data
# Since we have layers["counts"], we can use that

if "counts" in adata.layers:
    print("Using raw counts from layers['counts']")
    # Replace X with a private copy of the counts so the layer stays untouched.
    counts_matrix = adata.layers["counts"]
    # Always hand scVI a CSR matrix: scvi-tools recommends CSR for training
    # speed, whereas the previous code produced CSC for dense input and left
    # sparse input in whatever format it arrived in.
    if scipy.sparse.issparse(counts_matrix):
        counts_matrix = counts_matrix.tocsr(copy=True)
    else:
        counts_matrix = scipy.sparse.csr_matrix(counts_matrix)
    adata.X = counts_matrix
else:
    print("Warning: counts layer not found. Using current adata.X")
    print("Note: scVI works best with raw UMI counts")

# Setup scVI
try:
    scvi.model.SCVI.setup_anndata(adata, batch_key=None, labels_key="seed_labels")
    print("Successfully setup AnnData for scVI")
except Exception as e:
    # Deliberate best-effort: report the problem and register the data
    # without labels so the unsupervised scVI model can still be trained.
    print(f"Error setting up AnnData: {e}")
    # Try without labels_key if there's an issue
    scvi.model.SCVI.setup_anndata(adata, batch_key=None)

# Train scVI model
print("\nTraining scVI model...")
scvi_model = scvi.model.SCVI(adata, n_latent=30, n_layers=2)
# Without a GPU, add accelerator="cpu", devices=1 to the train() call.
scvi_model.train(max_epochs=100, plan_kwargs={'lr': 1e-3})

# Get latent representation
adata.obsm["X_scVI"] = scvi_model.get_latent_representation()

print("scVI training completed!")
print(f"Latent representation shape: {adata.obsm['X_scVI'].shape}")

In [None]:
# Save the trained scVI model to the "scvi_model/" directory; overwrite=True
# is passed so an earlier save at that path does not block the write.
scvi_model.save("scvi_model/", overwrite=True)

# To reload the saved model later instead of retraining, run:
# scvi_model = scvi.model.SCVI.load("scvi_model/", adata=adata)

# **Train scANVI Model for Label Transfer**

In [None]:
# Build a scANVI model on top of the trained scVI model. Cells whose
# seed_labels value is "Unknown" are passed as the unlabeled category.
print("Initializing scANVI from scVI model...")
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    scvi_model,
    labels_key="seed_labels",
    unlabeled_category="Unknown",
)

# Train scANVI
print("Training scANVI model...")
# Without a GPU, add accelerator="cpu", devices=1 to the train() call.
scanvi_model.train(max_epochs=25, plan_kwargs={'lr': 1e-3})

# Predict a label for every cell and fetch the scANVI latent space, then
# store both on adata.
print("Predicting labels for all cells...")
predicted_labels = scanvi_model.predict(adata)
scanvi_latent = scanvi_model.get_latent_representation(adata)
adata.obs["C_scANVI"] = predicted_labels
adata.obsm["X_scANVI"] = scanvi_latent

print("scANVI training and prediction completed!")
print(f"Predicted labels: {adata.obs['C_scANVI'].unique()}")

# **Visualize Results**

In [None]:
# Recompute the neighbour graph and UMAP from the scANVI latent space.
sc.pp.neighbors(adata, use_rep="X_scANVI")
sc.tl.umap(adata)

# Side-by-side UMAPs: the seed labels that went in, the predictions that came out.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
panels = [
    ("seed_labels", "Seed Labels"),
    ("C_scANVI", "scANVI Predicted Labels"),
]
for panel_ax, (obs_key, panel_title) in zip(axes, panels):
    sc.pl.umap(adata, color=obs_key, ax=panel_ax,
               title=panel_title, show=False, frameon=False)

plt.tight_layout()
plt.show()

# Number of cells per predicted label
print("\nFinal label distribution:")
print(adata.obs["C_scANVI"].value_counts())

In [None]:
# Plot the UMAP twice, coloured by leiden cluster and by scANVI prediction,
# so the two groupings can be compared.
sc.pl.umap(adata, color=["leiden", "C_scANVI"], frameon = False, wspace=0.3)

# **Save Results**

In [None]:
# Save the annotated dataset
# Write the full adata object to 'annotated_dataset.h5ad' in the working directory.
adata.write('annotated_dataset.h5ad')

In [None]:
print("Step 12: Saving Results")

# Per-cell annotation table: cluster, seed label and scANVI prediction,
# written with the cell names as the CSV index.
annotation_columns = ['leiden', 'seed_labels', 'C_scANVI']
annotation_results = adata.obs[annotation_columns].copy()
annotation_results.to_csv('cell_annotations.csv')
print("Annotation results saved to: cell_annotations.csv")

# Long-format marker table: one row per (cell type, marker gene) pair.
marker_rows = []
for ct, markers in selected_markers.items():
    for marker in markers:
        marker_rows.append({'cell_type': ct, 'marker': marker})
marker_info = pd.DataFrame(marker_rows)
marker_info.to_csv('used_marker_genes.csv', index=False)
print("Marker gene information saved to: used_marker_genes.csv")

print("\n=== PIPELINE COMPLETED SUCCESSFULLY ===")