# Predicting Stephenson Status using MixMIL

# Install and set up

In [15]:
"""
conda create -n mixmil
conda init mixmil
pip install mixmil
"""
#Should install all required packages

'\nconda create -n mixmil\nconda init mixmil\npip install mixmil\n'

In [16]:
import json
from pathlib import Path

from mixmil.paths import DATA
import pandas as pd
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
import numpy as np
import anndata as ad
from mixmil import MixMIL
import torch
import matplotlib.pyplot as plt
import scipy.stats as st
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, log_loss, classification_report
from scipy.special import softmax  # only used if your model returns logits
import scanpy as sc
from sklearn.model_selection import train_test_split

## Utility functions 

In [17]:
def to_device(el, device):
    """
    Move a nested structure of elements (dict, list, tuple, torch.Tensor, torch.nn.Module) to the specified device.

    Containers are rebuilt recursively: a dict comes back as a dict, a list as
    a list and a tuple as a tuple (namedtuples keep their own type), so the
    returned structure really mirrors the input. Anything that is neither a
    container nor a tensor/module is returned untouched.

    Parameters:
    - el: Element or nested structure of elements to be moved to the device.
    - device (torch.device): The target device, such as 'cuda' for GPU or 'cpu' for CPU.

    Returns:
    - Transferred element(s) in the same structure: Elements moved to the specified device.
    """
    if isinstance(el, dict):
        return {k: to_device(v, device) for k, v in el.items()}
    if isinstance(el, list):
        return [to_device(x, device) for x in el]
    if isinstance(el, tuple):
        moved = [to_device(x, device) for x in el]
        # namedtuples take their fields as positional arguments, plain tuples
        # take a single iterable
        if hasattr(el, "_fields"):
            return type(el)(*moved)
        return tuple(moved)
    if isinstance(el, (torch.Tensor, torch.nn.Module)):
        return el.to(device)
    return el
    

def adata_from_mixmil(adata_full, split_adata, H, *, bag_col="bag_name", cols_to_keep=None):
    """
    Wrap per-bag embeddings H into an AnnData with one row per bag.

    Parameters:
    - adata_full: AnnData whose .obs holds the per-cell metadata; it is grouped by `bag_col`.
    - split_adata: AnnData of one split; the order in which bags first appear in its
      .obs[bag_col] defines the row order of the result.
    - H: torch.Tensor or array-like with one row per bag, in that same order.
    - bag_col (str): Name of the .obs column identifying the bag.
    - cols_to_keep (list or None): Extra .obs columns to aggregate per bag, in addition
      to "bag" and "label".

    Returns:
    - AnnData with X = H (float32), obs = per-bag metadata (most frequent value of every
      kept column, or the first value when no mode exists, plus "n_cells"), and var
      named "mixmil_q0", "mixmil_q1", ...
    """
    base_cols = ["bag", "label"]
    columns = base_cols if cols_to_keep is None else base_cols + cols_to_keep

    # Bags in the order they first appear in the split object
    ordered_bags = split_adata.obs[bag_col].drop_duplicates().tolist()

    # Per-bag aggregation is done on the full object
    grouped = adata_full.obs.groupby(bag_col, sort=False)

    def _mode_or_first(series):
        modes = series.mode()
        if modes.empty:
            return series.iloc[0]
        return modes.iat[0]

    per_bag = {}
    for col in columns:
        per_bag[col] = grouped[col].agg(_mode_or_first)
    bag_meta = pd.DataFrame(per_bag).reindex(ordered_bags)
    bag_meta["n_cells"] = grouped.size().reindex(ordered_bags)

    # Embeddings as a float32 numpy array
    if torch.is_tensor(H):
        emb = H.detach().cpu().numpy()
    else:
        emb = np.asarray(H)
    emb = emb.astype("float32")
    assert emb.shape[0] == len(bag_meta), f"H rows ({emb.shape[0]}) must equal #bags ({len(bag_meta)})"

    feature_names = [f"mixmil_q{i}" for i in range(emb.shape[1])]
    return ad.AnnData(X=emb, obs=bag_meta, var=pd.DataFrame(index=feature_names))


def make_bag_splits(
    adata,
    layer="X_pca",
    *,
    sample_id_col="sample_id",  # bag column
    label_col="Status",
    covars=("Site",),           # tuple/list of covariate columns to copy over
    dtype="float32"
):
    """
    Builds train/test AnnData objects from adata.obsm[layer] following the
    predefined split stored in adata.obs["split"] ("train" / "val"), scales
    train->test, maps bag & label to ints, and copies selected covariates
    onto the split objects.

    The split itself is NOT created here: it is read from adata.obs["split"].
    Bags that contain cells with a missing label are moved to "val" as a
    whole, because the model cannot be trained on missing labels and a bag
    must never end up in both splits.

    Parameters:
        adata          AnnData with obsm[layer] and obs columns `sample_id_col`, "split"
        layer          key in adata.obsm holding the per-cell features
        sample_id_col  obs column identifying the bag (sample)
        label_col      obs column with the label; None or an absent column skips label handling
        covars         obs columns copied onto the split objects (may be empty/None)
        dtype          dtype the feature matrix is cast to

    Returns:
        adata        (modified in-place: obs['bag'], ['label','label_name'], ['split'], uns maps)
        train_adata
        test_adata   (cells of the "val" bags; its obs['split'] is "test")
        bags_tr      (np.ndarray of bag names as strings, in order)
        bags_te      (np.ndarray of bag names as strings, in order)
        scaler       (StandardScaler fitted on train X)

    Raises:
        KeyError     if adata.obs has no "split" column
    """
    if "split" not in adata.obs.columns:
        raise KeyError('adata.obs must contain a "split" column with values "train" / "val"')

    X = np.asarray(adata.obsm[layer], dtype=dtype)
    obs = adata.obs.copy()

    # ---- bag ids ----
    bags = pd.Series(obs[sample_id_col].astype(str), index=obs.index, name="bag_name")
    uniq_bags = bags.unique()
    bag_name_to_id = {name: i for i, name in enumerate(uniq_bags)}
    adata.obs["bag"] = bags.map(bag_name_to_id).astype("int32")
    adata.uns["bag_index_to_name"] = {int(i): name for name, i in bag_name_to_id.items()}

    # ---- labels (optional) ----
    has_labels = label_col is not None and label_col in obs.columns
    if has_labels:
        cats = pd.Categorical(obs[label_col])
        label_name_to_id = {name: i for i, name in enumerate(cats.categories)}
        obs["label_name"] = cats.astype(object)
        obs["label"] = cats.codes.astype("int32")     # -1 for NaN
        adata.obs["label"] = obs["label"]  # Make sure labels are mapped to original adata as well
        adata.obs["label_name"] = obs["label_name"]
        adata.uns["label_index_to_name"] = {int(i): str(name) for name, i in label_name_to_id.items()}

        # As the model cannot be trained on missing labels, put every bag that
        # has unlabeled cells into the test set (whole bag, to keep splits disjoint)
        unlabeled = obs[label_col].isna().to_numpy()
        if unlabeled.any():
            unlabeled_bags = bags[unlabeled].unique()
            adata.obs.loc[bags.isin(unlabeled_bags).to_numpy(), "split"] = "val"

    # ---- bags per split (string ids, same representation as `bags`) ----
    split = adata.obs["split"]
    bags_tr = bags[(split == "train").to_numpy()].unique()
    bags_te = bags[(split == "val").to_numpy()].unique()

    n_bags = len(bags_tr) + len(bags_te)
    print("Bags in train:", len(bags_tr))
    print("Bags in test:", len(bags_te))
    print("Ratio of test:", len(bags_te) / n_bags if n_bags else float("nan"))

    # ---- masks ----
    mask_tr = bags.isin(bags_tr).to_numpy()
    mask_te = bags.isin(bags_te).to_numpy()

    # ---- scale (statistics come from the training cells only) ----
    scaler = StandardScaler()
    X_tr = scaler.fit_transform(X[mask_tr])
    X_te = scaler.transform(X[mask_te])

    # ---- build obs for outputs ----
    def _make_obs(_mask, split_name):
        o = obs.loc[_mask, [sample_id_col] + (["label", "label_name"] if "label" in obs.columns else [])].copy()
        o.rename(columns={sample_id_col: "bag_name"}, inplace=True)
        o["bag"] = o["bag_name"].map(bag_name_to_id).astype("int32")
        o["split"] = split_name
        cols = ["bag", "bag_name"] + (["label", "label_name"] if "label" in o.columns else []) + ["split"]
        return o[cols]

    train_obs = _make_obs(mask_tr, "train")
    test_obs  = _make_obs(mask_te,  "test")

    train_adata = ad.AnnData(X_tr, obs=train_obs)
    test_adata  = ad.AnnData(X_te, obs=test_obs)

    # ---- copy covariates (positional masks, so duplicate obs names are harmless) ----
    if covars:
        for c in covars:
            train_adata.obs[c] = adata.obs.loc[mask_tr, c].to_numpy()
            test_adata.obs[c]  = adata.obs.loc[mask_te, c].to_numpy()

    return adata, train_adata, test_adata, bags_tr, bags_te, scaler


def get_supervised_task(dataset_config):
    """
    Return the .obs column name of the supervised task declared in the dataset config.

    The "supervised_task" entry is searched for "classification", then "regression",
    then "ordinal_regression"; the first element listed under the first key found
    is returned.

    Raises:
    - ValueError: if none of the three task types is present.
    """
    tasks = dataset_config["supervised_task"]
    for task_kind in ("classification", "regression", "ordinal_regression"):
        if task_kind in tasks:
            return tasks[task_kind][0]
    raise ValueError("No supervised task found in dataset config")

## Data Loading

In [19]:
adata

AnnData object with n_obs × n_vars = 1687127 × 3000
    obs: 'suspension_type', 'donor_id', 'is_primary_data', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'tissue_ontology_term_id', 'organism_ontology_term_id', 'sex_ontology_term_id', "3'_or_5'", 'BMI', 'age_or_mean_of_age_range', 'age_range', 'anatomical_region_ccf_score', 'ann_coarse_for_GWAS_and_modeling', 'ann_finest_level', 'ann_level_1', 'ann_level_2', 'ann_level_3', 'ann_level_4', 'ann_level_5', 'cause_of_death', 'core_or_extension', 'dataset', 'fresh_or_frozen', 'log10_total_counts', 'lung_condition', 'mixed_ancestry', 'original_ann_level_1', 'original_ann_level_2', 'original_ann_level_3', 'original_ann_level_4', 'original_ann_level_5', 'original_ann_nonharmonized', 'reannotation_type', 'sample', 'scanvi_label', 'sequencing_platform', 'smoking_status', 'study', 'subject_type', 'tissue_coarse_unharmonized', '

In [20]:
metadata

Unnamed: 0_level_0,suspension_type,BMI,age_or_mean_of_age_range,age_range,anatomical_region_ccf_score,cause_of_death,core_or_extension,fresh_or_frozen,lung_condition,sequencing_platform,...,cell_type_lung neuroendocrine cell,cell_type_bronchial goblet cell,cell_type_pulmonary artery endothelial cell,cell_type_lung macrophage,cell_type_bronchus fibroblast of lung,cell_type_alveolar type 1 fibroblast cell,cell_type_alveolar adventitial fibroblast,cell_type_respiratory hillock cell,cell_type_unknown,split
donor_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
homosapiens_None_2023_None_sikkemalisa_002_d10_1101_2022_03_10_483747Donor_02,cell,,55.0,,0.97,intracranial hemorrhage,core,fresh,Healthy,Illumina HiSeq 4000,...,0.000000,0.000000,1.096033,0.991649,0.078288,0.652401,0.052192,0.000000,0.000000,val
homosapiens_None_2023_None_sikkemalisa_002_d10_1101_2022_03_10_483747cc05p,cell,,,,,,extension,,Healthy,,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.355661,39.063426,val
homosapiens_None_2023_None_sikkemalisa_002_d10_1101_2022_03_10_483747VUHD68,cell,23.5,41.0,,0.97,,core,fresh,Healthy,Illumina NovaSeq 6000 S1,...,0.000000,0.012739,3.006369,0.840764,0.025478,0.764331,0.101911,0.000000,0.000000,val
homosapiens_None_2023_None_sikkemalisa_002_d10_1101_2022_03_10_483747D062,nucleus,,0.0,,,,extension,,Healthy,,...,0.020610,0.000000,0.412201,1.277824,1.483924,13.025556,0.103050,0.000000,38.582028,val
homosapiens_None_2023_None_sikkemalisa_002_d10_1101_2022_03_10_483747donor 1,cell,24.6,75.0,,0.97,,core,fresh,Healthy (tumor adjacent),Illumina NovaSeq 6000,...,0.026585,0.000000,3.588994,0.837432,0.026585,0.717799,1.821082,0.199389,0.000000,val
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
homosapiens_None_2023_None_sikkemalisa_002_d10_1101_2022_03_10_483747244C,cell,,,,,,extension,,Healthy,,...,0.000000,0.000000,0.000000,0.000000,0.606061,0.606061,1.212121,0.000000,16.969697,train
homosapiens_None_2023_None_sikkemalisa_002_d10_1101_2022_03_10_483747LuT_580,cell,,76.0,,,,extension,,Healthy (tumor adjacent),,...,0.000000,0.000000,0.000000,0.250627,0.000000,0.000000,0.000000,0.000000,37.593985,train
homosapiens_None_2023_None_sikkemalisa_002_d10_1101_2022_03_10_483747VUHD105,cell,,58.0,,0.97,,core,fresh,Healthy,Illumina NovaSeq 6000 S4,...,0.000000,1.428571,7.857143,0.000000,0.000000,0.714286,3.571429,0.000000,0.000000,train
homosapiens_None_2023_None_sikkemalisa_002_d10_1101_2022_03_10_483747BAL036,cell,,,,,,extension,,Pneumonia non-COVID,,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,13.043478,train


In [None]:
@torch.inference_mode()
def mixmil_bag_embed_q(model, Xs, F=None):
    """
    Returns a Q-dim bag embedding per bag (Q being the input dim.).

    H_q is the attention-pooled instance average per class, averaged across
    classes with the predicted class probabilities as weights.

    Xs: list of [n_i, Q] tensors or a single [N, I, Q] tensor
    F:  [N, K] bag covariates (pass if you used them; else None)

    Returns (H_bar, logits): the [N, Q] embeddings and the [N, P] logits used
    to weight the classes.
    """
    # 1) Class probabilities per bag
    u = model.predict(Xs)                    # (N, P) bag effects
    if F is not None:
        logits = u + F @ model.alpha         # add fixed effects if used during train
    else:
        logits = u

    # IF NOT WANTED WEIGHTED, REMOVE COMMENTS HERE AND USE UNIFORMLY DISTR. PROBAS
    #P = logits.shape[1]
    #proba = torch.full_like(logits, 1.0 / P)

    proba = torch.softmax(logits, dim=1)     # (N, P)

    # 2) Attention weights per instance & class
    w, _ = model.get_weights(Xs, ravel=False)

    # 3) Attention-pooled feature means per class: H ∈ (N, P, Q)
    if torch.is_tensor(Xs):
        H = torch.einsum("nip,niq->npq", w, Xs)
    else:                                    # list-of-bags
        H = torch.stack([torch.einsum("ip,iq->pq", wb, Xb) for wb, Xb in zip(w, Xs)])

    # 4) Probability-weight across classes → single Q-dim embedding
    H_bar = (proba.unsqueeze(-1) * H).sum(dim=1)   # (N, Q)
    return H_bar, logits


datasets = ["hlca", "stephenson", "combat"]  # onek1k is not used as it has continuous labels, which are not supported by MixMIL

for dataset in datasets:
    print("Working on dataset:", dataset)

    with open(f"../configs/datasets/{dataset}.json", "r", encoding="utf-8") as f:
        dataset_config = json.load(f)

    task = get_supervised_task(dataset_config)
    batch_variable = dataset_config["batch_variable"]
    sample_id_col = dataset_config["sample_id_col"]

    adata = sc.read_h5ad(dataset_config["data_path"])
    metadata = pd.read_csv(dataset_config["metadata_path"], index_col=0)
    metadata = metadata[~metadata.index.duplicated(keep="first")]

    print("Loaded adata:", adata.shape)
    print("Loaded metadata:", metadata.shape)

    # Add split column to adata. A left join keeps exactly one row per cell (an
    # inner join would silently drop cells of samples missing from the metadata
    # and break the obs/X alignment); a pre-existing "split" column is dropped
    # first so the merge cannot produce "split_x"/"split_y".
    adata.obs = adata.obs.drop(columns="split", errors="ignore").merge(
        metadata[["split"]], left_on=sample_id_col, right_index=True, how="left"
    )
    n_without_split = int(adata.obs["split"].isna().sum())
    if n_without_split:
        print("Cells without a split in the metadata (used in neither train nor test):", n_without_split)

    ## Create training split and prepare data for training, important to create bags (i.e. indexing of samples for MIL)

    adata, train_adata, test_adata, bags_tr, bags_te, scaler = make_bag_splits(
        adata,
        sample_id_col=sample_id_col,
        label_col=task,
        covars=[batch_variable],
    )

    print("Train label distribution")
    print(train_adata.obs[["bag", "label"]].drop_duplicates().groupby("label").count())

    print("Test label distribution")
    print(test_adata.obs[["bag", "label"]].drop_duplicates().groupby("label").count())

    ## Initialize bags of observations as tensors. Bags = samples

    # prepare train data
    train_bags = train_adata.obs["bag"].unique().tolist()
    Xs = [torch.Tensor(train_adata[train_adata.obs["bag"] == bag].X) for bag in train_bags]
    Y = train_adata.obs[["bag", "label"]].drop_duplicates().set_index("bag")
    Y = Y[~Y.index.duplicated(keep="first")]
    Y = Y.loc[train_bags].values
    Y = torch.Tensor(Y)
    # prepare test data, following official train-test split
    test_bags = test_adata.obs["bag"].unique().tolist()
    test_Xs = [torch.Tensor(test_adata[test_adata.obs["bag"] == bag].X) for bag in test_bags]

    # one label per bag, de-duplicated the same way as for the train labels
    test_Y = test_adata.obs[["bag", "label"]].drop_duplicates().set_index("bag")
    test_Y = test_Y[~test_Y.index.duplicated(keep="first")]
    test_Y = torch.Tensor(test_Y.loc[test_bags].values)

    ### Define covariate matrix (set to F = torch.ones((len(train_bags), 1)) if no covariates needed)

    # Use the batch variable as covariate in modeling
    F_tr_df = (train_adata.obs[["bag", batch_variable]]
            .drop_duplicates("bag")
            .set_index("bag"))
    # full one-hot encoding: one indicator column per level, no level is dropped
    F_tr_df = pd.get_dummies(F_tr_df, columns=[batch_variable])
    F_tr = torch.tensor(F_tr_df.loc[train_bags].to_numpy(), dtype=torch.float32)


    F_te_df = (test_adata.obs[["bag", batch_variable]]
            .drop_duplicates("bag")
            .set_index("bag"))
    F_te_df = pd.get_dummies(F_te_df, columns=[batch_variable])
    F_te_df = F_te_df.reindex(columns=F_tr_df.columns, fill_value=0)      # match train cols
    F_te = torch.tensor(F_te_df.loc[test_bags].to_numpy(), dtype=torch.float32)

    ## Training
    # Initialize MixMIL with a categorical GLMM and use it for prediction (changed likelihood to categorical due to multi-class problem!
    #  For binary use binomial with n_trials as described in paper)

    # initialize model with mean model
    model = MixMIL.init_with_mean_model(Xs, F_tr, Y, likelihood="categorical", n_trials=None)

    # test bags whose label is known (-1 marks a missing label)
    not_na = (test_Y != -1).flatten().cpu().numpy()

    # to_device moves the model and every tensor, so no further .to(device) calls are needed
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, Xs, F, Y, test_Xs, test_Y = to_device((model, Xs, F_tr, Y, test_Xs, test_Y), device)
    model.train(Xs, F, Y, n_epochs=1000)

    with torch.no_grad():
        # Prefer to keep logits as a Tensor
        logits = model.predict(test_Xs)
        proba = torch.softmax(logits, dim=1).cpu().numpy()

    y_pred = proba.argmax(axis=1)
    y_true = test_Y.cpu().numpy().astype(int)

    ## Generate embeddings
    ### H_q refers to embedding of attention-pooled instance averages across classes by class probas (if needed this can be adjusted in code, see comments!)
    # Alternatively, use logits or probas to generate embeddings. These embeddings H_q here are the middle part in Fig. 1 b in the paper.

    # Generate embeddings
    F_tr, F_te = F_tr.to(device), F_te.to(device)
    H_q_train, logits_train = mixmil_bag_embed_q(model, Xs, F_tr) # Set F to None if not used/ wanted
    H_q_test, logits_test = mixmil_bag_embed_q(model, test_Xs, F_te)

    cols_to_keep = [batch_variable, task, sample_id_col]

    adata_bag_train = adata_from_mixmil(adata, train_adata, H_q_train, bag_col="bag", cols_to_keep=cols_to_keep)
    adata_bag_test  = adata_from_mixmil(adata, test_adata, H_q_test,  bag_col="bag", cols_to_keep=cols_to_keep)
    adata_mix_mil = ad.concat(
        {"train": adata_bag_train, "test": adata_bag_test},
        label="split",
        join="outer"
    )

    print("Saving the outputs")
    sample_reps_df = pd.DataFrame(adata_mix_mil.X, index=adata_mix_mil.obs[sample_id_col])

    results_dir = Path(f"../results/{dataset_config['dataset_name']}/mixmil/")
    results_dir.mkdir(parents=True, exist_ok=True)
    sample_reps_df.to_csv(results_dir / "sample_reps.csv")

    # rows of adata_bag_test follow the order of test_bags, i.e. the order of y_pred
    val_set_predictions = pd.DataFrame(
        y_pred,
        index=adata_bag_test.obs[sample_id_col],
        columns=[f"predicted_{task}"]
    )
    val_set_predictions.to_csv(results_dir / "val_set_predictions.csv")