# Tutorial: Using the `MILDataset` and PyTorch `DataLoader` in `cellink`

This tutorial shows how to use the `MILDataset` together with PyTorch's `DataLoader` to prepare multimodal input for model training in the `cellink` framework. This pipeline is designed to support **donor-level learning** using single-cell transcriptomic measurements.

We will use the `OneK1K` dataset and demonstrate how to:

- Filter for CD8 Naive T cells,
- Package the data using `DonorData`,
- Wrap it into a `MILDataset`,
- Load it with a PyTorch `DataLoader`,
- Train a `DonorMILModel`.


## Setup and Configuration

We start by importing the modules used throughout this tutorial: the `OneK1K` dataset loader, `MILDataset` and its collate function, the `DonorMILModel`, and PyTorch Lightning for training.

In [1]:
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from cellink._core import DAnn, GAnn
from cellink.datasets import get_onek1k
from cellink.ml.dataset import MILDataset, mil_collate_fn
from cellink.ml.model import DonorMILModel

## Load Genotype Data (`gdata`)

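We load the example dataset using `get_onek1k()`. The returned `DonorData` object contains a `.G` attribute (`gdata`) that stores genotype information at the variant level, and a `.C` attribute that stores the single-cell expression data. We then fetch gene coordinates from Ensembl BioMart so that genes can be selected by chromosome. To keep the notebook fast, we restrict the data to chromosome 22 and to CD8 Naive T cells in the following steps.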

In [2]:
dd = get_onek1k(config_path='../../src/cellink/resources/config/onek1k.yaml')
dd

/Users/larnoldt/cellink_sample_data/onek1k/onek1k_cellxgene.h5ad already exists. Verifying checksum.
/Users/larnoldt/cellink_sample_data/onek1k/OneK1K.noGP.vcf.gz already exists. Verifying checksum.
/Users/larnoldt/cellink_sample_data/onek1k/OneK1K.noGP.vcf.gz.csi already exists. Verifying checksum.
/Users/larnoldt/cellink_sample_data/onek1k/gene_counts_Ensembl_105_phenotype_metadata.tsv.gz already exists. Verifying checksum.


In [3]:
def _get_ensembl_gene_id_start_end_chr():
    """Fetch gene start, end, and chromosome from Ensembl BioMart, indexed by Ensembl gene ID."""
    from pybiomart import Server

    server = Server(host='http://www.ensembl.org')
    dataset = server.marts['ENSEMBL_MART_ENSEMBL'].datasets['hsapiens_gene_ensembl']
    ensembl_gene_id_start_end_chr = dataset.query(
        attributes=['ensembl_gene_id', 'start_position', 'end_position', 'chromosome_name']
    )
    ensembl_gene_id_start_end_chr = ensembl_gene_id_start_end_chr.set_index("Gene stable ID")
    # Rename the BioMart display names to the column names defined in GAnn.
    ensembl_gene_id_start_end_chr = ensembl_gene_id_start_end_chr.rename(columns={
        "Gene start (bp)": GAnn.start,
        "Gene end (bp)": GAnn.end,
        "Chromosome/scaffold name": GAnn.chrom,
    })
    return ensembl_gene_id_start_end_chr

ensembl_gene_id_start_end_chr = _get_ensembl_gene_id_start_end_chr()
ensembl_gene_id_start_end_chr

# Attach the gene coordinates to the expression variables (genes).
dd.C.var = dd.C.var.join(ensembl_gene_id_start_end_chr)

In [4]:
chrom = 22
dd = dd.sel(G_var=dd.G.var.chrom == str(chrom), C_var=dd.C.var.chrom == str(chrom)).copy()
dd



In [5]:
cell_type = "CD8 Naive"
celltype_key = "predicted.celltype.l2"
dd = dd[..., dd.C.obs[celltype_key] == cell_type, :].copy()
dd



## Wrap the Data with `MILDataset`
We use `MILDataset` to create a wrapper that returns a bag of cells for each donor, which enables multiple instance learning (MIL) with donor-level labels. `MILDataset` automatically packages the labels, the categorical and continuous covariates, and the data matrices when they are available, and the keys can be adjusted. Note that, for demonstration purposes, the donor labels used here are randomly generated.

In [6]:
dd.G.obs["donor_id"] = dd.G.obs.index
dd.G.obs["donor_labels"] = np.random.randint(2, size=len(dd.G.obs))
dd.C.obs["pool_number"] = dd.C.obs["pool_number"].astype("float")
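
Because the placeholder labels above are drawn without a seed, they change on every run. The following is a minimal sketch of how reproducible placeholder labels could be drawn with a seeded NumPy `Generator`. The value of `n_donors` is a hypothetical stand-in for `len(dd.G.obs)`, and the seed `0` is arbitrary.

```python
import numpy as np

# Hypothetical stand-in for len(dd.G.obs); the real value comes from the dataset.
n_donors = 8

# A seeded Generator makes the placeholder labels reproducible across runs.
rng = np.random.default_rng(seed=0)
donor_labels = rng.integers(0, 2, size=n_donors)  # binary labels in {0, 1}

# Re-creating the generator with the same seed yields the same labels.
repeat = np.random.default_rng(seed=0).integers(0, 2, size=n_donors)
```

In the notebook, the resulting array could be assigned to `dd.G.obs["donor_labels"]` in place of the unseeded call.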

In [7]:
dataset = MILDataset(
    dd,
    donor_labels_key="donor_labels",
    cell_batch_key="pool_number",  
    #split_donors=["OneK1K_1", "OneK1K_10", "OneK1K_1000"],
    split_indices=list(range(10))
)
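
To make the idea of a "bag of cells per donor" concrete, here is a small NumPy-only sketch of the grouping step. It is an illustration of the MIL data layout, not the implementation of `MILDataset`. The function `build_bags` and all toy values are hypothetical.

```python
import numpy as np

# Toy data: 6 cells x 3 genes, with a donor ID per cell (all values hypothetical).
cell_features = np.arange(18, dtype=float).reshape(6, 3)
cell_donor = np.array(["d1", "d2", "d1", "d3", "d2", "d1"])
donor_label = {"d1": 0, "d2": 1, "d3": 1}


def build_bags(features, donors, labels):
    """Group cell rows by donor into (donor_id, bag, label) tuples."""
    bags = []
    for donor in sorted(set(donors.tolist())):
        mask = donors == donor  # boolean mask selecting this donor's cells
        bags.append((donor, features[mask], labels[donor]))
    return bags


bags = build_bags(cell_features, cell_donor, donor_label)
# Each bag has a different number of cells but the same number of genes.
bag_sizes = {donor: bag.shape[0] for donor, bag, _ in bags}
```

Each donor contributes one training example whose input is a variable-size matrix of cells and whose target is a single donor-level label.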

## Create a PyTorch `DataLoader`
We now build a PyTorch `DataLoader` that will feed batches of donors (each with a variable number of cells) into the model. We use the `mil_collate_fn` to handle the collation of variable-length inputs.

In [8]:
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=mil_collate_fn
)
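
A custom collate function is needed because donors have different numbers of cells, so their bags cannot be stacked directly into one tensor. One common strategy is to pad every bag to the largest bag in the batch and keep a mask of the real cells. The sketch below shows that strategy in NumPy. It is an assumption for illustration only: how `mil_collate_fn` actually combines bags is not described here, and `pad_bags` is a hypothetical helper.

```python
import numpy as np


def pad_bags(bags):
    """Stack variable-length bags into one padded array plus a validity mask."""
    max_cells = max(bag.shape[0] for bag in bags)
    n_features = bags[0].shape[1]
    padded = np.zeros((len(bags), max_cells, n_features))
    mask = np.zeros((len(bags), max_cells), dtype=bool)
    for i, bag in enumerate(bags):
        n_cells = bag.shape[0]
        padded[i, :n_cells] = bag   # real cells first
        mask[i, :n_cells] = True    # remaining positions stay False (padding)
    return padded, mask


# Two donors with 3 and 5 cells, 4 features each (toy values).
bags = [np.ones((3, 4)), np.ones((5, 4))]
padded, mask = pad_bags(bags)
```

The mask lets downstream pooling ignore the padded rows, so padding does not influence the donor-level representation.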

## Initialize and Train a Model
We now create a `DonorMILModel`, specifying the dimensionality of the donor and cell input features. We then use PyTorch Lightning to train for two epochs (`max_epochs=2`).

In [9]:
model = DonorMILModel(
    n_input_donor=dd.G.n_vars,
    n_input_cell=dd.C.n_vars
)
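
MIL models commonly summarize a bag by attention pooling: each cell embedding receives a score, the scores are normalized with a softmax, and the donor embedding is the weighted sum of the cell embeddings. The NumPy sketch below illustrates this generic technique. It is not the code of `DonorMILModel`, whose internals are not shown in this tutorial. The names `attention_pool`, `V`, and `w` and all shapes are hypothetical.

```python
import numpy as np


def attention_pool(cell_embeddings, V, w):
    """Attention-weighted sum of the cell embeddings of a single bag."""
    scores = np.tanh(cell_embeddings @ V) @ w   # one score per cell
    scores = scores - scores.max()              # shift for numerical stability
    weights = np.exp(scores) / np.exp(scores).sum()  # softmax over cells
    return weights @ cell_embeddings, weights


rng = np.random.default_rng(0)
H = rng.normal(size=(5, 8))   # 5 cells, 8-dimensional embeddings (toy values)
V = rng.normal(size=(8, 4))   # projection to a 4-dimensional attention space
w = rng.normal(size=4)        # scoring vector
donor_embedding, weights = attention_pool(H, V, w)
```

The output has a fixed size regardless of how many cells the bag contains, which is what allows a single classifier to operate on donors with different cell counts.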

In [10]:
trainer = pl.Trainer(max_epochs=2)
trainer.fit(model, dataloader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (mps), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/opt/miniconda3/envs/single_cell_base3/lib/python3.12/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type       | Params | Mode 
-----------------------------------------------------
0 | donor_encoder | Linear     | 17.5 M | train
1 | cell_encoder  | Sequential | 112 K  | train
2 | attention     | Sequential | 8.3 K  | train
3 | classifier    | Sequential | 257    | train
-----------------------------------------------------
17.6 M    Trainable params
0         Non-trainable params
17.6 M    Total params
70.515    Total estimated model params size (MB)
11        Modules in tra

Training: |          | 0/? [00:00<?, ?it/s]

  loss = F.mse_loss(y_hat, y)
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.
