# Annotation Voting

Each method used to predict cell type annotations takes part in the vote as follows:

Let $w_i$ be the weight of method $m_i \in M$ and $k \in K$ a cell type annotation:

$$P(k) = \sum_{i=1}^{M} w_i \, P(k \mid m_i)$$


Each annotation method $m_i \in M$ provides a predicted label $k_i \in K \cup \{\text{unknown}\}$ and an associated confidence score $c_i \in [0,1]$.

We define a normalized weight $w_i$ for each method that reflects its overall confidence and its concordance with the other methods.

Then, for a given cell, the ensemble assigns a probability to each possible cell type $k \in K$ as:

$$P(k) = \frac{1}{Z} \sum_{i=1}^{M} w_i \, P(k \mid m_i)$$

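where $Z$ is a normalization constant ensuring $\sum_k P(k)=1$: $Z=\sum_{i=1}^{M} w_i$ when each $P(\cdot \mid m_i)$ is a proper distribution over $K$, and $Z=\sum_{i=1}^{M} w_i c_i$ when a method contributes only its confidence $c_i$ to its predicted label.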
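
A minimal sketch of the ensemble distribution for one cell, assuming each method contributes its confidence only to its predicted label, $P(k \mid m_i) = c_i \, [k_i = k]$ (the same per-vote weight $w_i c_i$ that `_weighted_majority_vote` accumulates), and normalizing by $Z = \sum_i w_i c_i$. The function name, model names and labels are illustrative.

```python
from collections import defaultdict


def ensemble_probabilities(predictions, weights):
    """Weighted ensemble distribution P(k) for a single cell.

    predictions: dict model -> (label, confidence); label None means "no prediction".
    weights: dict model -> global weight w_i.
    Assumes P(k | m_i) = c_i if m_i predicted k, else 0.
    """
    scores = defaultdict(float)
    for model, (label, conf) in predictions.items():
        if label is None:
            continue
        scores[label] += weights.get(model, 0.0) * conf  # w_i * c_i
    z = sum(scores.values())  # Z = sum_i w_i * c_i over models that voted
    if z == 0:
        return {}
    return {label: s / z for label, s in scores.items()}


preds = {
    "model_a": ("Astrocyte", 0.9),
    "model_b": ("Astrocyte", 0.6),
    "model_c": ("Microglia", 0.8),
}
w = {"model_a": 0.4, "model_b": 0.35, "model_c": 0.25}
print(ensemble_probabilities(preds, w))
```

Here two moderately weighted votes for `Astrocyte` outweigh one confident vote for `Microglia`, giving roughly 0.74 versus 0.26.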


## Final ensemble decision

The ensemble’s predicted label is:

$$\hat{k} = \arg\max_{k \in K} P(k)$$

and uncertainty (entropy) can be quantified as:

$$H = -\sum_{k \in K} P(k) \log_2 P(k)$$
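
Given an ensemble distribution for one cell (here a plain dict from label to probability, an illustrative input format), the decision rule and its entropy take a few lines. Base-2 logarithms match the entropy helpers used later in the notebook.

```python
import math


def ensemble_decision(probs):
    """Return (argmax label, entropy in bits) for a dict label -> P(k)."""
    if not probs:
        return None, float("nan")
    k_hat = max(probs, key=probs.get)  # ties resolved by dict order
    # Terms with P(k) = 0 contribute nothing (0 * log 0 := 0)
    h = -sum(p * math.log2(p) for p in probs.values() if p > 0)
    return k_hat, h


print(ensemble_decision({"Astrocyte": 0.5, "Microglia": 0.25, "Oligodendrocyte": 0.25}))
# → ('Astrocyte', 1.5)
```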


In [16]:
path_to_predictions = "../data/ann_integration/all_predictions_with_unknown.csv"
path_to_adata = "/media/raqcoss/USB KGST/ratopin_counts/out_datasets/final_nmr_crossAnn9.h5ad"

In [17]:
import pandas as pd
import numpy as np
from itertools import combinations
from collections import Counter
from math import log2
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score
from statsmodels.stats.inter_rater import fleiss_kappa
from scipy.stats import entropy
from sklearn.preprocessing import LabelEncoder
#import scanpy as sc


## Metrics for Predictions

This notebook implements **Shannon entropy-based model congruence** to evaluate agreement among multiple annotation methods (scANVI, CellTypist, CellAnnotator, etc.).

### Core Concept: Shannon Entropy-Based Congruence

Model congruence measures how strongly the models agree on each cell at different hierarchical levels (specific cell type, group, region, supergroup).

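**Per-cell congruence score**:
$$C = 1 - \frac{H_{\text{total}}}{H_{\text{max}}}$$

where:
- $H_L = - \sum_{y} p_L(y) \log_2 p_L(y)$ is the Shannon entropy at level $L$ (specific, group, region, super)
- $p_L(y)$ = proportion of models predicting label $y$ at level $L$
- $H_{\text{total}}$ = weighted combination of the level entropies $H_L$ (see the hierarchical framework below)
- $H_{\text{max}} = \log_2(\max(N_{\text{unique}}, 2))$ = maximum possible entropy, where $N_{\text{unique}}$ is the number of distinct labels predicted for the cell (pooled across levels)

$C$ is clipped to $[0, 1]$.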

**Interpretation**:
- $C = 1$: Perfect agreement (all models predict same label)
- $C ≈ 0.5$: Moderate disagreement
- $C ≈ 0$: Maximal disagreement (uniform distribution across predictions)
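
For a single level with unit weight, the score reduces to $C = 1 - H / H_{\text{max}}$. The sketch below (a hypothetical helper, not one of the notebook's functions) computes it from one cell's raw labels, using the same floor of $H_{\text{max}} \ge 1$ bit as the notebook's code.

```python
import math
from collections import Counter


def flat_congruence(labels):
    """Single-level congruence C = 1 - H / H_max for one cell's model labels."""
    labels = [lab for lab in labels if lab is not None]
    if not labels:
        return float("nan")
    counts = Counter(labels)
    n = len(labels)
    h = -sum((c / n) * math.log2(c / n) for c in counts.values())
    h_max = math.log2(max(len(counts), 2))  # at least 1 bit, as in the notebook
    return max(0.0, min(1.0, 1 - h / h_max))


for cell in (["A", "A", "A"], ["A", "A", "B"], ["A", "B", "C"]):
    print(cell, round(flat_congruence(cell), 3))
```

Note that two-out-of-three agreement scores only about 0.08: with two distinct labels the entropy (about 0.918 bits) is normalized by just 1 bit.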

---

### Supporting Metrics

#### - Fleiss' Kappa
Measures agreement among several categorical raters (here, the annotation models), corrected for the agreement expected by chance.

>Let $N$ be the number of cells, $M$ the number of models, and $K$ the number of possible labels.
For each cell $j$, let $n_{jk}$ be the number of models that assigned label $k$.
Then the agreement for that cell is:
>$$P_j = \frac{1}{M(M-1)} \sum_{k=1}^{K} n_{jk}(n_{jk} - 1)$$
>and the mean agreement across all cells:
>$$\bar{P} = \frac{1}{N} \sum_{j=1}^{N} P_j$$
>Let the expected agreement by chance be:
>$$\bar{P}_e = \sum_{k=1}^{K} p_k^2, \quad \text{where} \quad p_k = \frac{1}{N M} \sum_{j=1}^{N} n_{jk}$$
>Then Fleiss' Kappa is given by:
>$$\kappa = \frac{\bar{P} - \bar{P}_e}{1 - \bar{P}_e}$$
>A value of $\kappa = 1$ indicates perfect agreement, $\kappa = 0$ corresponds to random labeling, and negative values indicate systematic disagreement.
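
The formulas above translate directly into NumPy. This sketch assumes every cell is labelled by the same $M$ models (true here if `unknown` is kept as a label) and that at least two categories occur; otherwise $\bar{P}_e = 1$ and $\kappa$ is undefined. The `fleiss_kappa` from statsmodels imported earlier expects the same $N \times K$ count table, which `statsmodels.stats.inter_rater.aggregate_raters` can build from a cells-by-models array.

```python
import numpy as np


def count_table(labels):
    """Build the N x K table n_jk from per-cell lists of model labels."""
    categories = sorted({lab for row in labels for lab in row})
    col = {c: k for k, c in enumerate(categories)}
    table = np.zeros((len(labels), len(categories)))
    for j, row in enumerate(labels):
        for lab in row:
            table[j, col[lab]] += 1
    return table, categories


def fleiss_kappa_np(table):
    """Fleiss' kappa from an N x K count table (constant M raters per row)."""
    table = np.asarray(table, dtype=float)
    n_cells = table.shape[0]
    m = table[0].sum()                                        # models per cell
    p_j = (table * (table - 1)).sum(axis=1) / (m * (m - 1))   # per-cell agreement P_j
    p_bar = p_j.mean()
    p_k = table.sum(axis=0) / (n_cells * m)                   # overall label proportions
    p_e = np.sum(p_k ** 2)                                    # chance agreement
    return (p_bar - p_e) / (1 - p_e)


cells = [["A", "A", "B"], ["A", "B", "B"], ["C", "C", "C"]]
table, cats = count_table(cells)
print(cats, round(float(fleiss_kappa_np(table)), 3))
# → ['A', 'B', 'C'] 0.333
```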


#### - Pairwise Agreement (ARI & NMI)

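Quantifies agreement between pairs of models. The simplest pairwise score is the fraction of cells on which two models assign the same label, defined below. Because the models used here have different label vocabularies, the code uses the Adjusted Rand Index (ARI) and Normalized Mutual Information (NMI) in place of $A_{ab}$: both compare how two models partition the cells, regardless of label names, and are then averaged over model pairs in the same way.

>For models $m_a, m_b \in M$, with $k_{a,j}$ the label that $m_a$ assigns to cell $j$: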
>$$A_{ab} = \frac{1}{N} \sum_{j=1}^{N} [k_{a,j} = k_{b,j}]$$
>The overall pairwise agreement is the mean over all unique model pairs:
>$$\bar{A} = \frac{2}{M(M-1)} \sum_{a < b} A_{ab}$$
>This metric can be computed globally, per cluster, or per label to identify systematic disagreements between models.
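
A sketch of $A_{ab}$ and $\bar{A}$ as written above, for labels stored as equal-length lists per model (an assumed input format). Unlike ARI and NMI, this exact-match rate is only meaningful when the models share a label vocabulary, for example after mapping every model's output onto the same hierarchy level.

```python
from itertools import combinations


def mean_pairwise_match(labels_by_model):
    """Mean fraction of cells on which each pair of models gives the same label."""
    models = list(labels_by_model)
    if len(models) < 2:
        return float("nan")
    rates = []
    for a, b in combinations(models, 2):
        la, lb = labels_by_model[a], labels_by_model[b]
        rates.append(sum(x == y for x, y in zip(la, lb)) / len(la))  # A_ab
    return sum(rates) / len(rates)  # = 2 / (M (M - 1)) * sum over a < b


labels = {
    "m1": ["A", "B", "A"],
    "m2": ["A", "B", "B"],
    "m3": ["A", "B", "A"],
}
print(round(mean_pairwise_match(labels), 4))
```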


#### - Per-Cell Entropy (Shannon Entropy)

Measures the uncertainty of the ensemble prediction for each cell, given the distribution of predicted labels across models.

>For a cell $j$, let $P_j(k)$ be the ensemble probability of label $k$:
>$$P_j(k) = \frac{1}{Z_j} \sum_{i=1}^{M} w_i \cdot c_{ij} \cdot [k_{ij} = k]$$
>where $w_i$ is the model's weight, $c_{ij}$ its confidence score for that cell, and $Z_j$ is the normalization constant.
>
>Then the entropy for cell $j$ is:
>$$H_j = - \sum_{k=1}^{K} P_j(k) \log_2 P_j(k)$$
>
>Averaging across all cells gives the global ensemble entropy:
>$$\bar{H} = \frac{1}{N} \sum_{j=1}^{N} H_j$$
>
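>Low entropy indicates consistent predictions (high agreement), while high entropy reflects uncertainty or label conflict among models. The `_ensemble_entropy` helper below computes the unweighted case ($w_i = c_{ij} = 1$), i.e. the entropy of the raw label counts.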
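
The weighted $H_j$ can be sketched for one cell with hypothetical inputs: parallel lists of labels, model weights $w_i$ and confidences $c_{ij}$, normalized by $Z_j = \sum_i w_i c_{ij}$.

```python
import math
from collections import defaultdict


def weighted_cell_entropy(labels, weights, confidences):
    """H_j in bits, with P_j(k) proportional to sum_i w_i * c_ij * [k_ij = k]."""
    mass = defaultdict(float)
    for lab, w, c in zip(labels, weights, confidences):
        if lab is not None:
            mass[lab] += w * c
    z = sum(mass.values())  # Z_j
    if z == 0:
        return float("nan")
    return -sum((m / z) * math.log2(m / z) for m in mass.values() if m > 0)


# Mass splits evenly between A (two light models) and B (one heavier model)
print(weighted_cell_entropy(["A", "B", "A"], [0.25, 0.5, 0.25], [1.0, 1.0, 1.0]))
# → 1.0
```

With equal weights and confidences this reduces to the count-based entropy that `_ensemble_entropy` returns.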

### Hierarchical Entropy Framework (Advanced)

For hierarchical cell-type annotations (specific → group → region → supergroup):

**Weighted hierarchical entropy**:
$$H_{\text{total}} = w_{\text{specific}} H_{\text{specific}} + w_{\text{group}} H_{\text{group}} + w_{\text{super}} H_{\text{super}} + \lambda H_{S,G}$$

where:
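- $w_{\text{specific}}, w_{\text{group}}, w_{\text{super}}$ are level-specific weights (defaults in `hierarchical_entropy`: $1/9$, $3/9$, $5/9$); the region level is also computed but gets weight 0 unless `weights` includes it
- $H_{S,G}$ is the joint entropy of the (specific, group) label pairs, and $\lambda$ (`lambda_joint`, default 0.5) controls how much this coupling term contributes
- Models predicting only partial hierarchies are handled by computing each level's entropy only over the predictions available at that level; labels missing from the hierarchy mapping contribute at the specific level only. With the default weights this makes disagreement among unmapped labels cheap: three mutually different unmapped labels give $H_{\text{total}} = \tfrac{1}{9}\log_2 3$ and $C = 1 - 1/9 \approx 0.889$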
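
The weighted total can be sketched for one cell, assuming the per-level labels and (specific, group) pairs have already been extracted from the hierarchy mapping. The helper names are hypothetical; the default weights, $\lambda = 0.5$ and the pooled-label $H_{\text{max}}$ follow `hierarchical_entropy`.

```python
import math
from collections import Counter

DEFAULT_WEIGHTS = {"specific": 1 / 9, "group": 3 / 9, "super": 5 / 9}


def entropy_bits(items):
    """Shannon entropy (bits) of the empirical distribution of items."""
    counts = Counter(items)
    n = sum(counts.values())
    return -sum((c / n) * math.log2(c / n) for c in counts.values()) if n else 0.0


def hierarchical_congruence(level_labels, joint_pairs=(), weights=DEFAULT_WEIGHTS, lam=0.5):
    """C = 1 - H_total / H_max, with H_total = sum_L w_L H_L + lam * H_{S,G}."""
    h_total = sum(w * entropy_bits(level_labels.get(level, [])) for level, w in weights.items())
    h_total += lam * entropy_bits(joint_pairs)
    pooled = {lab for labs in level_labels.values() for lab in labs}  # labels across levels
    h_max = math.log2(max(len(pooled), 2))
    return max(0.0, min(1.0, 1 - h_total / h_max))


# Three models with mutually different labels that are absent from the mapping:
# only the 'specific' level is populated.
print(round(hierarchical_congruence({"specific": ["x", "y", "z"]}), 4))
```

For three mutually different unmapped labels the result is $8/9 \approx 0.889$, the same congruence reported for such cells in the results further down (entropy 1.585 bits, consensus 0.333), so a high overall score is worth reading together with `congruence_specific` and `consensus`.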

**Advantages**:
- Computed per cell, so ambiguous annotations can be located individually (Fleiss' kappa, by contrast, yields a single value for the whole dataset)
- Handles incomplete predictions across hierarchical levels
- Information-theoretic interpretation: congruence is one minus the entropy normalized by its maximum possible value

In [18]:
def entropy_from_counts(counter):
    """Compute Shannon entropy (base-2) from a collections.Counter of items.
    Returns 0.0 when counter is empty or only one unique item.
    """
    total = sum(counter.values())
    if total == 0:
        return 0.0
    probs = np.array(list(counter.values()), dtype=float) / float(total)
    return float(entropy(probs, base=2))

def _ensemble_entropy(pred_df):
    """
    Compute per-cell entropy and consensus across model predictions (flat, non-hierarchical).
    pred_df: DataFrame with one column per model of categorical labels.
    Returns two arrays: entropy (bits) and consensus (max probability).
    """
    ents = []
    cons = []
    for idx, row in pred_df.iterrows():
        vals = [v for v in row.values if pd.notna(v)]
        if not vals:
            ents.append(np.nan)
            cons.append(np.nan)
            continue
        ctr = Counter(vals)
        H = entropy_from_counts(ctr)
        ents.append(H)
        total = sum(ctr.values())
        cons.append(max(ctr.values())/total if total else np.nan)
    return np.array(ents), np.array(cons)

def _pairwise_agreement(pred_df):
    """
    Compute mean ARI and NMI across all pairs of model prediction columns.
    Returns (mean_ARI, mean_NMI).
    """
    cols = list(pred_df.columns)
    if len(cols) < 2:
        return np.nan, np.nan
    aris = []
    nmis = []
    for a, b in combinations(cols, 2):
        sub = pred_df[[a, b]].dropna()
        if len(sub) < 5:
            continue
        aris.append(adjusted_rand_score(sub[a], sub[b]))
        nmis.append(normalized_mutual_info_score(sub[a], sub[b]))
    return (np.nanmean(aris) if aris else np.nan, np.nanmean(nmis) if nmis else np.nan)

def hierarchical_entropy(predictions, level_mapping=None, weights=None, lambda_joint=0.5):
    """
    Compute per-cell entropy and per-level congruence using Shannon entropy for hierarchical predictions.

    Returns a dict containing entropies per level and per-level congruence scores and overall congruence.
    """
    if weights is None:
        weights = {"specific": 1/9, "group": 3/9, "super": 5/9}

    # Initialize level predictions
    level_preds = {"specific": [], "group": [], "region": [], "super": []}
    joint_pairs = []  # (specific, group) tuples

    for pred in predictions:
        if pd.isna(pred) or pred is None:
            continue
        if level_mapping is not None and pred in level_mapping:
            lv = level_mapping[pred]
            if isinstance(lv, dict):
                # Missing entries (None or NaN cells in the glossary) are skipped;
                # the specific level falls back to the raw predicted label
                spec = lv.get('specific')
                spec = pred if spec is None or pd.isna(spec) else spec
                level_preds["specific"].append(spec)
                for key in ("group", "region", "super"):
                    if pd.notna(lv.get(key)):
                        level_preds[key].append(lv[key])
                if pd.notna(lv.get('group')):
                    joint_pairs.append((spec, lv['group']))
            else:
                level_preds["specific"].append(pred)
        else:
            # treat as specific only
            level_preds["specific"].append(pred)

    entropies = {}
    per_level_congruence = {}
    for L, preds_list in level_preds.items():
        uniq = set([p for p in preds_list if p is not None])
        if len(uniq) > 1:
            H_L = entropy_from_counts(Counter(preds_list))
            entropies[L] = H_L
            H_max_L = log2(max(len(uniq), 2))
            per_level_congruence[L] = 1 - (H_L / H_max_L) if H_max_L > 0 else 1.0
        elif len(uniq) == 1:
            entropies[L] = 0.0
            per_level_congruence[L] = 1.0
        else:
            entropies[L] = 0.0
            per_level_congruence[L] = np.nan

    # Joint entropy between specific and group
    if len(set(joint_pairs)) > 1:
        H_joint = entropy_from_counts(Counter(joint_pairs))
    else:
        H_joint = 0.0

    # Weighted total entropy (use only levels with predictions)
    H_total = 0.0
    for L in ['specific','group','region','super']:
        w = (weights or {}).get(L, 0.0)
        H_total += w * entropies.get(L, 0.0)
    H_total += lambda_joint * H_joint

    all_preds = sum(level_preds.values(), [])
    n_unique = len(set([p for p in all_preds if p is not None]))
    H_max = log2(max(n_unique, 2)) if n_unique > 0 else 1.0
    congruence = 1 - (H_total / H_max) if H_max > 0 else 1.0
    congruence = max(min(congruence, 1), 0)

    return {
        'entropies': entropies,
        'per_level_congruence': per_level_congruence,
        'joint_entropy': H_joint,
        'H_total': H_total,
        'H_max': H_max,
        'congruence': congruence
    }

def compute_per_cell_congruence_df(df, model_pred_cols=None, level_mapping=None, weights=None, lambda_joint=0.5):
    """
    Compute per-cell congruence and per-level congruence, returning a DataFrame with columns:
    ['congruence', 'congruence_specific', 'congruence_group', 'congruence_region', 'congruence_super']
    """
    if model_pred_cols is None:
        model_pred_cols = [c for c in df.columns if c.endswith('_pred') or c.endswith('_label')]

    records = []
    for idx, row in df.iterrows():
        preds = [row[col] for col in model_pred_cols if pd.notna(row[col])]
        res = hierarchical_entropy(preds, level_mapping=level_mapping, weights=weights, lambda_joint=lambda_joint)
        rec = {
            'congruence': res['congruence'],
            'congruence_specific': res['per_level_congruence'].get('specific', np.nan),
            'congruence_group': res['per_level_congruence'].get('group', np.nan),
            'congruence_region': res['per_level_congruence'].get('region', np.nan),
            'congruence_super': res['per_level_congruence'].get('super', np.nan),
            'entropies': res['entropies']
        }
        records.append(rec)
    out = pd.DataFrame(records, index=df.index)
    # expand entropies into columns if needed
    out['entropy_specific'] = out['entropies'].apply(lambda d: d.get('specific', 0) if isinstance(d, dict) else 0)
    out['entropy_group'] = out['entropies'].apply(lambda d: d.get('group', 0) if isinstance(d, dict) else 0)
    out['entropy_region'] = out['entropies'].apply(lambda d: d.get('region', 0) if isinstance(d, dict) else 0)
    out['entropy_super'] = out['entropies'].apply(lambda d: d.get('super', 0) if isinstance(d, dict) else 0)
    out = out.drop(columns=['entropies'])
    return out

### Model's Weight Calculation

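Each model's weight combines its mean prediction score (confidence) with its mean agreement with the other models, measured as the average ARI against each of them, and is normalized so that the weights sum to 1:

$$w_m = \frac{\bar{c}_m \, \bar{A}_m}{\sum_{m' \in M} \bar{c}_{m'} \, \bar{A}_{m'}}$$

where $\bar{c}_m$ is the mean confidence score of model $m$ and $\bar{A}_m$ its mean ARI against the other models. A missing (NaN) factor is replaced by 1.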
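
The normalization can be checked by hand. This sketch takes each model's mean confidence and mean agreement as already-computed numbers; the model names and values are made up for illustration. As in `_compute_model_weights`, a missing (NaN) factor is treated as 1.0 and a non-positive total falls back to uniform weights.

```python
import math


def normalized_weights(mean_conf, mean_agreement):
    """w_m = conf_m * agree_m / sum over models; NaN factors default to 1.0."""
    raw = {}
    for m in mean_conf:
        conf = mean_conf[m]
        agree = mean_agreement.get(m, float("nan"))
        conf = 1.0 if math.isnan(conf) else conf
        agree = 1.0 if math.isnan(agree) else agree
        raw[m] = conf * agree
    if not raw:
        return {}
    total = sum(raw.values())
    if total <= 0:
        return {m: 1.0 / len(raw) for m in raw}  # uniform fallback
    return {m: w / total for m, w in raw.items()}


w = normalized_weights(
    {"model_a": 0.8, "model_b": 0.5, "model_c": float("nan")},
    {"model_a": 0.4, "model_b": 0.4, "model_c": 0.2},
)
print({m: round(v, 3) for m, v in w.items()})
```

With a NaN confidence, `model_c` competes on agreement alone, which is what happens to every model when the confidence columns are not found.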

In [19]:
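def _compute_model_weights(df, model_names, pred_suffix='_label', score_suffix='_pred_score'):
    """
    Compute per-model weights based on confidence and mutual agreement.
    
    Weight formula: w_m = (mean_confidence_m × mean_agreement_m) / sum(weights)
    
    Parameters
    ----------
    df : pd.DataFrame
        Prediction table; per-cell confidences are read from '{model}{score_suffix}'
        columns when present (a missing column counts as confidence 1.0)
    model_names : list
        List of model names (e.g., ['scvi', 'celltypist', 'CrossAnn'])
    pred_suffix : str
        Suffix used for prediction columns (default '_label')
    score_suffix : str
        Suffix used for confidence columns (default '_pred_score'). The CSV loaded in
        this notebook stores confidences as '{model}_conf_score', so with the default
        the weights reduce to the normalized mean-ARI term.
    
    Returns
    -------
    dict
        Normalized weights for each model (sum to 1)
    """
    # Mean confidence per model (NaN when the score column is missing)
    confs = {
        m: df.get(f"{m}{score_suffix}", pd.Series(dtype=float)).mean(skipna=True)
        for m in model_names
    }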
    
    # Agreement: average ARI vs others (on predicted labels)
    aris = {}
    for m1 in model_names:
        others = [m2 for m2 in model_names if m2 != m1 and f"{m2}{pred_suffix}" in df.columns]
        ari_vals = []
        for m2 in others:
            c1, c2 = f"{m1}{pred_suffix}", f"{m2}{pred_suffix}"
            if c1 in df.columns and c2 in df.columns:
                common = df[[c1, c2]].dropna()
                if len(common) > 5:
                    ari_vals.append(adjusted_rand_score(common[c1], common[c2]))
        aris[m1] = np.nanmean(ari_vals) if ari_vals else np.nan
    
    # Combine confidence and agreement
    weights = {}
    for m in model_names:
        conf = confs.get(m, 1.0)
        ari = aris.get(m, 1.0)
        if np.isnan(conf):
            conf = 1.0
        if np.isnan(ari):
            ari = 1.0
        w = conf * ari
        weights[m] = w if not np.isnan(w) else 1.0
    
    # Normalize to sum to 1
    total = np.sum(list(weights.values()))
    if total > 0:
        return {m: float(w) / float(total) for m, w in weights.items()}
    else:
        return {m: 1.0/len(weights) for m in weights}


def _weighted_majority_vote(df, model_names, weights, pred_suffix = '_label', score_suffix = '_pred_score', level_mapping=None, levels=['specific','group','region','super']):
    """
    Compute weighted ensemble label per hierarchical level based on model confidence and agreement.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame with predictions
    model_names : list
        List of model names
    weights : dict
        Model weights (from _compute_model_weights)
    pred_suffix : str
        Suffix for predicted label columns (default '_label')
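    score_suffix : str
        Suffix for per-cell model score columns (default '_pred_score'). If the column is
        missing, every vote from that model uses confidence 1.0; the CSV loaded in this
        notebook stores scores as '{model}_conf_score'.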
    level_mapping : dict or None
        Mapping from specific labels to hierarchical levels e.g. {spec: {'group':'G1','super':'SG1','region':'R1'}}
    levels : list
        List of levels to produce (order: most-specific first)

    Returns
    -------
    pd.DataFrame
        DataFrame with one column per level containing the ensemble label for that level
    """
    pred_cols = [f"{m}{pred_suffix}" for m in model_names if f"{m}{pred_suffix}" in df.columns]
    # For each level we will accumulate weighted votes
    level_votes = {L: [] for L in levels}

    for i, row in df.iterrows():
        # collect votes per model for this cell
        votes_per_level = {L: {} for L in levels}
        for m in model_names:
            col = f"{m}{pred_suffix}"
            score_col = f"{m}{score_suffix}"
            if col not in df.columns:
                continue
            val = row.get(col, None)
            if pd.isna(val) or val is None:
                continue
            conf = row.get(score_col, 1.0)
            w = weights.get(m, 1.0)
            # Use product of model weight and per-cell confidence
            vote_weight = float(w) * (float(conf) if not pd.isna(conf) else 1.0)
            # Determine labels at each level using level_mapping if provided
            if level_mapping and val in level_mapping:
                mapped = level_mapping[val]
                # mapped may contain keys like 'specific','group','region','super'
                for L in levels:
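                    lab = mapped.get(L) if isinstance(mapped, dict) else None
                    # Missing values in the mapping (None, or NaN from the glossary) are skipped;
                    # at the 'specific' level we fall back to the original label instead
                    if L == 'specific' and (lab is None or pd.isna(lab)):
                        lab = val
                    if lab is None or pd.isna(lab):
                        continue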
                    votes_per_level[L][lab] = votes_per_level[L].get(lab, 0.0) + vote_weight
            else:
                # no mapping provided — assign vote only to 'specific' level using the raw value
                votes_per_level['specific'][val] = votes_per_level['specific'].get(val, 0.0) + vote_weight
        # pick max for each level
        for L in levels:
            if votes_per_level[L]:
                chosen = max(votes_per_level[L].items(), key=lambda x: x[1])[0]
            else:
                chosen = None
            level_votes[L].append(chosen)
    # Return DataFrame with columns per level
    res = pd.DataFrame({L: pd.Series(level_votes[L], index=df.index) for L in levels})
    return res

In [20]:
def evaluate_ensemble(df, embedding=None, cluster_col="cluster", level_mapping=None, pred_suffix = '_label', levels=['specific','group','region','super'], congruence_thresholds=None):
    """
    Comprehensive evaluation that produces per-level ensemble predictions and per-level congruence.
    """
    # determine model names
    model_names = sorted({c.split(pred_suffix)[0] for c in df.columns if c.endswith(pred_suffix)})
    pred_cols = [f"{m}{pred_suffix}" for m in model_names if f"{m}{pred_suffix}" in df.columns]
    if len(pred_cols) == 0:
        raise ValueError('No prediction columns found. Ensure your DataFrame has columns ending with the specified pred_suffix.')

    # compute weights
    weights = _compute_model_weights(df, model_names, pred_suffix=pred_suffix)

    # per-level ensemble columns via weighted votes
    per_level_df = _weighted_majority_vote(df, model_names, weights, pred_suffix=pred_suffix, level_mapping=level_mapping, levels=levels)
    # attach per-level ensemble columns to df copy
    df_out = df.copy()
    for L in levels:
        colname = f'ensemble_{L}'
        df_out[colname] = per_level_df[L]

    # compute per-cell entropies, consensus
    ent, consensus = _ensemble_entropy(df[pred_cols])
    # per-cell per-level congruence df
    per_level_cong = compute_per_cell_congruence_df(df, model_pred_cols=pred_cols, level_mapping=level_mapping)

    # determine transfer policy for each level based on congruence thresholds
    if congruence_thresholds is None:
        congruence_thresholds = 0.7

    for L in levels:
        thr = congruence_thresholds[L] if isinstance(congruence_thresholds, dict) else float(congruence_thresholds)
        ensemble_col = f'ensemble_{L}'
        transfer_col = f'transferred_{L}'
        # transfer if per-level congruence >= threshold
        df_out[transfer_col] = df_out[ensemble_col].where(per_level_cong[f'congruence_{L}'] >= thr, None)
        # also add per-level congruence column
        df_out[f'congruence_{L}'] = per_level_cong[f'congruence_{L}']
        if f'entropy_{L}' in per_level_cong.columns:
            df_out[f'entropy_{L}'] = per_level_cong[f'entropy_{L}']

    # overall congruence
    df_out['congruence'] = per_level_cong['congruence']
    # add entropy and consensus arrays from _ensemble_entropy
    df_out['entropy'] = ent
    df_out['consensus'] = consensus

    # compute global metrics
    ari, nmi = _pairwise_agreement(df[pred_cols])
    global_metrics = pd.Series({
        'mean_ARI': ari,
        'mean_NMI': nmi,
        'mean_entropy': np.nanmean(ent),
        'entropy_std': np.nanstd(ent),
        'mean_congruence': np.nanmean(per_level_cong['congruence']),
        'n_cells': len(df),
        'n_models': len(model_names)
    })
    if embedding is not None:
        try:
            labels_int = LabelEncoder().fit_transform(df_out[f'ensemble_{levels[0]}'].fillna('unknown'))
            global_metrics['silhouette_score'] = silhouette_score(embedding, labels_int)
        except Exception:
            global_metrics['silhouette_score'] = np.nan

    # per-cell
    per_cell = pd.DataFrame({
        'entropy': ent,
        'consensus': consensus,
        'congruence': per_level_cong['congruence']
    }, index=df.index)
    for L in levels:
        per_cell[f'transferred_{L}'] = df_out[f'transferred_{L}'].values
        per_cell[f'congruence_{L}'] = df_out[f'congruence_{L}'].values

    # per-cluster
    per_cluster = []
    if cluster_col in df.columns:
        for cid, subset in df.groupby(cluster_col):
            preds_sub = subset[pred_cols]
            ent_sub, cons_sub = _ensemble_entropy(preds_sub)
            cong_sub = compute_per_cell_congruence_df(subset, model_pred_cols=pred_cols, level_mapping=level_mapping)['congruence']
            ari_sub, nmi_sub = _pairwise_agreement(preds_sub)
            per_cluster.append({
                cluster_col: cid,
                'mean_entropy': np.nanmean(ent_sub),
                'mean_congruence': np.nanmean(cong_sub),
                'mean_ARI': ari_sub,
                'mean_NMI': nmi_sub,
                'n_cells': len(subset)
            })
    per_cluster = pd.DataFrame(per_cluster)

    # per transferred label
    per_label = []
    for L in levels:
        for label, subset in df_out.groupby(f'transferred_{L}'):
            if pd.isna(label):
                continue
            preds_sub = subset[pred_cols]
            ent_sub, cons_sub = _ensemble_entropy(preds_sub)
            cong_sub = compute_per_cell_congruence_df(subset, model_pred_cols=pred_cols, level_mapping=level_mapping)['congruence']
            ari_sub, nmi_sub = _pairwise_agreement(preds_sub)
            per_label.append({
                'level': L,
                'transferred_label': label,
                'mean_entropy': np.nanmean(ent_sub),
                'mean_congruence': np.nanmean(cong_sub),
                'mean_ARI': ari_sub,
                'mean_NMI': nmi_sub,
                'n_cells': len(subset)
            })
    per_label = pd.DataFrame(per_label)

    return {
        'global_metrics': global_metrics,
        'model_weights': pd.Series(weights),
        'per_cell': per_cell,
        'per_cluster': per_cluster,
        'per_label': per_label,
        'df_out': df_out
    }

## Congruence Score Interpretation

**Congruence** is the primary metric for model agreement:

| Score Range | Interpretation | Action |
|---|---|---|
| 0.9–1.0 | Perfect/near-perfect agreement | ✓ High confidence, use ensemble label |
| 0.7–0.9 | Good agreement | ✓ Usable, but check outlier models |
| 0.5–0.7 | Moderate disagreement | ⚠ Review individual model predictions |
| 0.0–0.5 | Strong disagreement | ✗ Cell is ambiguous; flag for manual review |

**How to use**:
- Filter cells by congruence threshold: `results['df_out'].query('congruence > 0.7')` for high-confidence labels
- Identify problematic cells: `results['df_out'].query('congruence < 0.5')` for manual annotation
- Track model performance via weights (`results['model_weights']`): models with higher mean confidence and higher agreement with the other models receive larger weights
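
The table can be turned into a lookup. This sketch assumes each range includes its lower bound (so 0.9 counts as near-perfect and 0.7 as good), a choice the table leaves open. NaN scores, which the notebook produces for levels with no predictions, are labelled "No predictions" here, a hypothetical extra category.

```python
import math

# (lower bound, interpretation) from the table above, highest tier first
TIERS = [
    (0.9, "Perfect/near-perfect agreement"),
    (0.7, "Good agreement"),
    (0.5, "Moderate disagreement"),
    (0.0, "Strong disagreement"),
]


def congruence_tier(score):
    """Map a congruence score in [0, 1] to its interpretation row."""
    if score is None or math.isnan(score):
        return "No predictions"
    for lower, label in TIERS:
        if score >= lower:
            return label
    return TIERS[-1][1]  # negative values cannot occur after clipping


for s in (1.0, 0.889, 0.6, 0.2, float("nan")):
    print(s, "->", congruence_tier(s))
```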

## Data Import & Ensemble Evaluation

Load predictions from multiple annotation methods and evaluate their agreement using Shannon entropy-based congruence.

In [21]:
# Label hierarchy glossary (Spanish column names): Etiqueta = label, Campo_mayor = supergroup,
# Region = region, Grupo = group, Especifico = specific
whole_mouse_hierarchy = pd.read_excel('../data/ann_integration/Glosario_CellTypist_Mouse_Whole_Brain_sin_acentos.xlsx', sheet_name='Glosario_CellTypist_Mouse_Whole')
whole_mouse_hierarchy = whole_mouse_hierarchy[['Etiqueta', 'Campo_mayor', 'Region', 'Grupo', 'Especifico']]
whole_mouse_hierarchy.set_index('Etiqueta', inplace=True)
whole_mouse_mapping = {}
for label, row in whole_mouse_hierarchy.iterrows():
    whole_mouse_mapping[label] = {
        "specific": row['Especifico'],
        "group": row['Grupo'],
        "region":  row['Region'],
        "super": row['Campo_mayor']
    }
print('Loaded hierarchy with', len(whole_mouse_mapping), 'entries')


Loaded hierarchy with 29 entries


In [None]:
df = pd.read_csv(path_to_predictions, index_col=0)
df


| Unnamed: 0 | Mouse_Whole_Brain_label | Mouse_Whole_Brain_majority_voting | Mouse_Whole_Brain_conf_score | Developing_Mouse_Hippocampus_label | Developing_Mouse_Hippocampus_voting | Developing_Mouse_Hippocampus_conf_score | Developing_Mouse_Brain_label | Developing_Mouse_Brain_voting | Developing_Mouse_Brain_conf_score | predicted_annotation |
|---|---|---|---|---|---|---|---|---|---|---|
| AAACCCAAGAAGCGAA-1-NMR1 | 054 STR Prox1 Lhx6 Gaba | 264 PRNc Otp Gly-Gaba | 0.178184 | ExciteNeuron | ExciteNeuron | 0.013145 | Blood: Erythroid progenitor | Blood: Erythroid progenitor | 0.999674 | Amygdala excitatory |
| AAACCCAAGAGGCGGA-1-NMR4 | 062 STR D2 Gaba | 062 STR D2 Gaba | 0.123459 | InhibNeuron | InhibNeuron | 0.998999 | unknown | | | Medium spiny neuron |
| AAACCCAAGAGGTTTA-1-NMR2 | unknown | | | unknown | | | unknown | | | unknown |
| AAACCCAAGATTCGAA-1-NMR1 | unknown | | | unknown | | | unknown | | | Bergmann glia |
| AAACCCAAGCATGAAT-1-NMR2 | 069 LSX Nkx2-1 Gaba | 264 PRNc Otp Gly-Gaba | 0.794896 | InhibNeuron | InhibNeuron | 0.998664 | Blood: Erythrocyte | Blood: Erythroid progenitor | 0.999711 | Midbrain-derived inhibitory |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| TTTGTTGTCCTACGGG-1-NMR6 | unknown | | | unknown | | | unknown | | | Splatter |
| TTTGTTGTCGTCACCT-1-NMR1 | 264 PRNc Otp Gly-Gaba | 053 Sst Gaba | 0.008710 | InhibNeuron | InhibNeuron | 0.999430 | Blood: Erythroid progenitor | Blood: Erythrocyte | 0.999863 | MGE interneuron |
| TTTGTTGTCGTCCATC-1-NMR1 | 053 Sst Gaba | 264 PRNc Otp Gly-Gaba | 0.018367 | ExciteNeuron | ExciteNeuron | 0.022017 | Blood: Erythroid progenitor | Blood: Erythroid progenitor | 0.999999 | unknown |
| TTTGTTGTCTATGCCC-1-NMR5 | unknown | | | unknown | | | unknown | | | Splatter |


In [25]:
df[['Mouse_Whole_Brain_label', 'Developing_Mouse_Hippocampus_label', 'Developing_Mouse_Brain_label', 'predicted_annotation']]



| Unnamed: 0 | Mouse_Whole_Brain_label | Developing_Mouse_Hippocampus_label | Developing_Mouse_Brain_label | predicted_annotation |
|---|---|---|---|---|
| AAACCCAAGAAGCGAA-1-NMR1 | 054 STR Prox1 Lhx6 Gaba | ExciteNeuron | Blood: Erythroid progenitor | Amygdala excitatory |
| AAACCCAAGAGGCGGA-1-NMR4 | 062 STR D2 Gaba | InhibNeuron | unknown | Medium spiny neuron |
| AAACCCAAGAGGTTTA-1-NMR2 | unknown | unknown | unknown | unknown |
| AAACCCAAGATTCGAA-1-NMR1 | unknown | unknown | unknown | Bergmann glia |
| AAACCCAAGCATGAAT-1-NMR2 | 069 LSX Nkx2-1 Gaba | InhibNeuron | Blood: Erythrocyte | Midbrain-derived inhibitory |
| ... | ... | ... | ... | ... |
| TTTGTTGTCCTACGGG-1-NMR6 | unknown | unknown | unknown | Splatter |
| TTTGTTGTCGTCACCT-1-NMR1 | 264 PRNc Otp Gly-Gaba | InhibNeuron | Blood: Erythroid progenitor | MGE interneuron |
| TTTGTTGTCGTCCATC-1-NMR1 | 053 Sst Gaba | ExciteNeuron | Blood: Erythroid progenitor | unknown |
| TTTGTTGTCTATGCCC-1-NMR5 | unknown | unknown | unknown | Splatter |


In [23]:
#import scanpy as sc
#adata = sc.read(path_to_adata)


In [24]:
results = evaluate_ensemble(df, embedding=None, cluster_col='cluster', level_mapping=whole_mouse_mapping, pred_suffix='_label', levels=['specific','group','region','super'], congruence_thresholds={'specific':0.85, 'group':0.75, 'region':0.7, 'super':0.7})


In [26]:
results

{'global_metrics': mean_ARI               0.376727
 mean_NMI               0.252141
 mean_entropy           1.035644
 entropy_std            0.754254
 mean_congruence        0.927398
 n_cells            57190.000000
 n_models               3.000000
 dtype: float64,
 'model_weights': Developing_Mouse_Brain          0.262018
 Developing_Mouse_Hippocampus    0.351302
 Mouse_Whole_Brain               0.386679
 dtype: float64,
 'per_cell':                           entropy  consensus  congruence transferred_specific  \
 AAACCCAAGAAGCGAA-1-NMR1  1.584963   0.333333    0.888889                 None   
 AAACCCAAGAGGCGGA-1-NMR4  1.584963   0.333333    0.888889                 None   
 AAACCCAAGAGGTTTA-1-NMR2  0.000000   1.000000    1.000000              unknown   
 AAACCCAAGATTCGAA-1-NMR1  0.000000   1.000000    1.000000              unknown   
 AAACCCAAGCATGAAT-1-NMR2  1.584963   0.333333    0.888889                 None   
 ...                           ...        ...         ...             

In [None]:
# GLOBAL METRICS & EXPORTS
print('='*60)
print('GLOBAL ENSEMBLE METRICS')
print('='*60)
print(results['global_metrics'])

print('\n' + '='*60)
print('MODEL WEIGHTS (by confidence × agreement)')
print('='*60)
print(results['model_weights'])

print('\n' + '='*60)
print('PER-CLUSTER METRICS')
print('='*60)
print(results['per_cluster'])
results['per_cluster'].to_csv('ensemble_per_cluster.csv', index=False)

print('\n' + '='*60)
print('PER-LABEL METRICS (transferred labels per level)')
print('='*60)
print(results['per_label'].sort_values(['level','mean_congruence'], ascending=[True, False]))
results['per_label'].to_csv('ensemble_per_label.csv', index=False)

# Add per-cell metrics and labels to adata.obs (if adata is available).
# The plotting cells below read congruence, entropy, consensus and ensemble_specific.
df_out = results['df_out']
level_names = ['specific', 'group', 'region', 'super']
label_cols = [f'{p}_{L}' for p in ('ensemble', 'transferred') for L in level_names]
metric_cols = ['congruence', 'entropy', 'consensus'] + [f'{p}_{L}' for p in ('congruence', 'entropy') for L in level_names]
if 'adata' in globals():
    for col in label_cols + metric_cols:
        if col in df_out.columns:
            values = df_out[col].reindex(adata.obs_names)
            # Labels stay object dtype; metrics stay float so they can be plotted
            adata.obs[col] = values.astype(object) if col in label_cols else values.astype(float)
    print('\n✓ Added per-level ensemble/transferred labels, congruence, entropy and consensus to adata.obs')
else:
    print('\nNo AnnData object found; saved output is available in results["df_out"].')

# Export per-cell summary
per_cell_summary = df_out[[c for c in df_out.columns if c.startswith(('transferred_', 'congruence_', 'entropy_'))]].copy()
per_cell_summary['cell_barcode'] = df_out.index
per_cell_summary.to_csv('ensemble_per_cell_transferred_levels.csv', index=False)
print('\n✓ Saved: ensemble_per_cell_transferred_levels.csv')


## Visualization & Quality Assessment

Visualize model congruence, entropy, and ensemble confidence across cells and cell types.


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (14, 10)

if 'adata' in globals() and {'congruence', 'entropy', 'consensus'}.issubset(adata.obs.columns):
    # 1. Congruence distribution
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Histogram of congruence scores
    axes[0, 0].hist(adata.obs['congruence'].dropna(), bins=50, edgecolor='black', alpha=0.7, color='steelblue')
    axes[0, 0].axvline(adata.obs['congruence'].mean(), color='red', linestyle='--', label=f'Mean: {adata.obs["congruence"].mean():.3f}')
    axes[0, 0].set_xlabel('Congruence Score')
    axes[0, 0].set_ylabel('Number of Cells')
    axes[0, 0].set_title('Distribution of Model Congruence')
    axes[0, 0].legend()

    # Entropy vs Consensus
    scatter = axes[0, 1].scatter(adata.obs['entropy'], adata.obs['consensus'], 
                                  c=adata.obs['congruence'], cmap='RdYlGn', s=30, alpha=0.6)
    axes[0, 1].set_xlabel('Shannon Entropy (bits)')
    axes[0, 1].set_ylabel('Max Consensus (probability)')
    axes[0, 1].set_title('Entropy vs Consensus (colored by congruence)')
    plt.colorbar(scatter, ax=axes[0, 1], label='Congruence')

    # Congruence by ensemble label (use specific level if present)
    label_col = 'ensemble_specific' if 'ensemble_specific' in adata.obs.columns else None
    if label_col is not None:
        per_label_cong = adata.obs.groupby(label_col)['congruence'].agg(['mean', 'std', 'count']).sort_values('mean', ascending=False)
        per_label_cong['mean'].plot(kind='barh', ax=axes[1, 0], xerr=per_label_cong['std'])
        axes[1, 0].set_xlabel('Mean Congruence ± Std')
        axes[1, 0].set_title('Congruence by Ensemble Label')
        axes[1, 0].set_xlim([0, 1])
    else:
        axes[1, 0].axis('off')

    # Cell counts per label
    if label_col is not None:
        label_counts = adata.obs[label_col].value_counts()
        axes[1, 1].barh(label_counts.index, label_counts.values, color='skyblue', edgecolor='black')
        axes[1, 1].set_xlabel('Number of Cells')
        axes[1, 1].set_title('Cell Counts per Ensemble Label')
    else:
        axes[1, 1].axis('off')

    plt.tight_layout()
    plt.savefig('ensemble_congruence_overview.png', dpi=300, bbox_inches='tight')
    plt.show()

    print('✓ Saved: ensemble_congruence_overview.png')
else:
    print('Skipping visualization: adata object with congruence not available.')
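The three `obs` columns plotted above (`entropy`, `consensus`, `congruence`) are produced earlier in the notebook. As a reminder of how they relate to the voting formulas, here is a minimal single-cell sketch. It is not the notebook's own implementation: the function `ensemble_scores` and the method names are hypothetical, each prediction is treated as a hard vote (so $P(k \mid m_i) \in \{0, 1\}$, ignoring confidence scores $c_i$), `unknown` votes are assumed to carry no mass, and $H_{\max} = \log_2 M$ is assumed (the entropy when all $M$ methods disagree).

```python
import math
from collections import defaultdict


def ensemble_scores(predictions, weights, n_methods=None):
    """Weighted-vote scores for one cell (illustrative sketch).

    predictions: dict method -> predicted label (or "unknown")
    weights:     dict method -> non-negative weight w_i
    Returns (label, probs, entropy_bits, consensus, congruence),
    or None when no method gives a usable label.
    """
    # Accumulate w_i * P(k | m_i), treating each prediction as a hard vote
    votes = defaultdict(float)
    for method, label in predictions.items():
        if label is None or label == "unknown":
            continue  # assumption: 'unknown' carries no probability mass
        votes[label] += weights.get(method, 0.0)

    z = sum(votes.values())  # normalization constant Z
    if z == 0:
        return None
    probs = {k: v / z for k, v in votes.items()}

    # Shannon entropy in bits; 0 when all weight falls on a single label
    h = -sum(p * math.log2(p) for p in probs.values() if p > 0)

    # Assumption: H_max = log2(M), the entropy when every method disagrees
    m = n_methods if n_methods is not None else len(predictions)
    h_max = math.log2(m) if m > 1 else 1.0

    label = max(probs, key=probs.get)  # arg max_k P(k)
    consensus = probs[label]           # max_k P(k)
    congruence = 1.0 - h / h_max       # C = 1 - H / H_max
    return label, probs, h, consensus, congruence


# Usage: three equally weighted methods, two agree
cell = {"scanvi": "T cell", "celltypist": "T cell", "cellannotator": "B cell"}
w = {"scanvi": 1.0, "celltypist": 1.0, "cellannotator": 1.0}
label, probs, h, consensus, congruence = ensemble_scores(cell, w)
print(label, round(consensus, 3), round(h, 3), round(congruence, 3))  # T cell 0.667 0.918 0.421
```

A 2-vs-1 split therefore already drops congruence well below the default `CONGRUENCE_THRESHOLD` of 0.7 used in the QC cell, while a unanimous vote gives zero entropy and a congruence of 1.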


In [None]:
# UMAP/t-SNE visualization of congruence
if 'adata' in globals() and len(adata.obsm) > 0 and {'congruence', 'entropy'}.issubset(adata.obs.columns):
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Use the first available reduction in order of preference (UMAP, t-SNE, PCA);
    # otherwise fall back to whatever embedding is stored in adata.obsm
    available = [r for r in ('X_umap', 'X_tsne', 'X_pca') if r in adata.obsm]
    if available:
        reduction = available[0]
    else:
        reduction = list(adata.obsm.keys())[0]
        print(f"Warning: no X_umap/X_tsne/X_pca found; using {reduction}. Available reductions:", list(adata.obsm.keys()))

    coords = adata.obsm[reduction][:, :2]

    # Plot 1: Colored by congruence
    scatter1 = axes[0].scatter(coords[:, 0], coords[:, 1], c=adata.obs['congruence'], 
                               cmap='RdYlGn', s=20, alpha=0.6)
    axes[0].set_xlabel(f'{reduction.upper()} 1')
    axes[0].set_ylabel(f'{reduction.upper()} 2')
    axes[0].set_title('Cells colored by Congruence Score')
    plt.colorbar(scatter1, ax=axes[0], label='Congruence')

    # Plot 2: Colored by ensemble label (specific level if available)
    label_col = 'ensemble_specific' if 'ensemble_specific' in adata.obs.columns else None
    if label_col is not None:
        le = LabelEncoder().fit(adata.obs[label_col].astype(str))
        label_codes = le.transform(adata.obs[label_col].astype(str))
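        # tab20 has only 20 colours, so some labels will share a colour when there are more than 20 categories
        scatter2 = axes[1].scatter(coords[:, 0], coords[:, 1], c=label_codes,
                                   cmap='tab20', s=20, alpha=0.6)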
        axes[1].set_xlabel(f'{reduction.upper()} 1')
        axes[1].set_ylabel(f'{reduction.upper()} 2')
        axes[1].set_title('Cells colored by Ensemble Label')
    else:
        axes[1].axis('off')

    # Plot 3: Colored by entropy (inversely related to congruence, since C = 1 - H / H_max)
    scatter3 = axes[2].scatter(coords[:, 0], coords[:, 1], c=adata.obs['entropy'], 
                               cmap='YlOrRd', s=20, alpha=0.6)
    axes[2].set_xlabel(f'{reduction.upper()} 1')
    axes[2].set_ylabel(f'{reduction.upper()} 2')
    axes[2].set_title('Cells colored by Shannon Entropy')
    plt.colorbar(scatter3, ax=axes[2], label='Entropy (bits)')

    plt.tight_layout()
    plt.savefig('ensemble_umap_congruence.png', dpi=300, bbox_inches='tight')
    plt.show()

    print('✓ Saved: ensemble_umap_congruence.png')
else:
    print('Skipping UMAP/TSNE visualization: adata with embeddings or congruence not available.')


## Quality Control & Filtering

In [None]:
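# Congruence thresholds (category edges); adjust based on your analysis
CONGRUENCE_THRESHOLD = 0.7        # minimum congruence for cells kept for downstream analysis
HIGH_CONFIDENCE_THRESHOLD = 0.85  # lower edge of the 'High' category
LOW_CONFIDENCE_THRESHOLD = 0.5    # cells below this are 'Ambiguous' and flagged for review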

if 'adata' in globals() and {'congruence', 'entropy', 'consensus'}.issubset(adata.obs.columns):
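    # Categorize cells by congruence. Bins are left-closed ([a, b)) so the category edges
    # agree with the >= / < filters below; open outer edges keep scores of exactly 0 and 1.
    adata.obs['confidence_category'] = pd.cut(
        adata.obs['congruence'],
        bins=[-np.inf, LOW_CONFIDENCE_THRESHOLD, CONGRUENCE_THRESHOLD, HIGH_CONFIDENCE_THRESHOLD, np.inf],
        labels=['Ambiguous', 'Moderate', 'Good', 'High'],
        right=False,
    )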

    # Summary statistics
    print("="*60)
    print("CELL CLASSIFICATION BY CONGRUENCE")
    print("="*60)
    confidence_counts = adata.obs['confidence_category'].value_counts()
    for cat in ['High', 'Good', 'Moderate', 'Ambiguous']:
        count = confidence_counts.get(cat, 0)
        pct = 100 * count / len(adata)
        print(f"{cat:12s}: {count:6d} cells ({pct:5.1f}%)")

    # Cells at or above the congruence threshold, kept for downstream analysis.
    # .copy() materializes the subset instead of keeping an AnnData view.
    high_conf_cells = adata[adata.obs['congruence'] >= CONGRUENCE_THRESHOLD].copy()
    print(f"\n✓ Cells passing congruence threshold (≥{CONGRUENCE_THRESHOLD}): {len(high_conf_cells)} ({100*len(high_conf_cells)/len(adata):.1f}%)")

    # Cells below the low-confidence threshold, exported for manual review
    low_conf_cells = adata[adata.obs['congruence'] < LOW_CONFIDENCE_THRESHOLD].copy()
    print(f"⚠ Low-confidence cells (<{LOW_CONFIDENCE_THRESHOLD}): {len(low_conf_cells)} ({100*len(low_conf_cells)/len(adata):.1f}%)")

    # Save filtered datasets
    high_conf_cells.write_h5ad('adata_high_confidence.h5ad')
    low_conf_cells.write_h5ad('adata_low_confidence.h5ad')

    print("\n✓ Saved: adata_high_confidence.h5ad (for downstream analysis)")
    print("✓ Saved: adata_low_confidence.h5ad (for manual review)")

    # Export summary table for review
    # prefer specific level if present
    label_col = 'ensemble_specific' if 'ensemble_specific' in adata.obs.columns else None
    cols = ['congruence', 'entropy', 'consensus']
    if label_col is not None:
        cols = [label_col] + cols
    review_table = adata.obs[cols].copy()
    review_table['cell_barcode'] = adata.obs_names
    review_table.to_csv('ensemble_annotation_summary.csv', index=False)
    print("✓ Saved: ensemble_annotation_summary.csv")
else:
    print('Skipping QC/Filtering: adata object with congruence not available.')
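The confidence categories only line up with the `>=` / `<` filters in the cell above when the bins are left-closed, so that a score sitting exactly on a threshold moves up a category. A minimal pure-Python sketch of that rule using `bisect` is shown below; the helper `confidence_category` is hypothetical, and the edges are copied from the threshold constants above.

```python
from bisect import bisect_right

# Edges mirror LOW_CONFIDENCE_THRESHOLD, CONGRUENCE_THRESHOLD, HIGH_CONFIDENCE_THRESHOLD
EDGES = (0.5, 0.7, 0.85)
LABELS = ("Ambiguous", "Moderate", "Good", "High")


def confidence_category(congruence, edges=EDGES, labels=LABELS):
    """Left-closed binning: a score equal to an edge goes to the higher category."""
    if congruence != congruence:  # NaN is the only value not equal to itself
        return None
    # bisect_right counts how many edges are <= the score
    return labels[bisect_right(edges, congruence)]


for score in (0.0, 0.49, 0.5, 0.7, 0.85, 1.0):
    print(score, confidence_category(score))
```

This matches `pd.cut(..., right=False)` with open outer edges. With pandas' default `right=True` and a lowest edge of 0, a score of exactly 0.7 would land in 'Moderate' even though it passes the `>= CONGRUENCE_THRESHOLD` filter, and a score of exactly 0 would fall outside every bin and become NaN.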
