# Training a model with PyTorch Lightning

This tutorial demonstrates training a simple logistic regression model with [PyTorch Lightning], using the `tiledbsoma_ml.ExperimentDataset` class, on data from the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/). It is intended for demonstration purposes only, not as an example of how to train a biologically useful model.

For more information on these APIs, please refer to the [`tutorial_pytorch` notebook](tutorial_pytorch.ipynb).

**Prerequisites**

Install [`tiledbsoma_ml`], [`scikit-learn`], and [`pytorch-lightning`]:

```bash
pip install tiledbsoma_ml scikit-learn pytorch-lightning
```

[PyTorch Lightning]: https://lightning.ai/docs/pytorch/stable/
[`tiledbsoma_ml`]: https://github.com/single-cell-data/TileDB-SOMA-ML/
[`scikit-learn`]: https://pypi.org/project/scikit-learn/
[`pytorch-lightning`]: https://pypi.org/project/pytorch-lightning/

[Papermill] parameters:

[Papermill]: https://papermill.readthedocs.io/

In [1]:
import os

tissue = "tongue"
n_epochs = 20
census_version = "2024-07-01"
batch_size = 128
learning_rate = 1e-5
progress_bar = not bool(os.environ.get('PAPERMILL'))  # Defaults to True, unless env var $PAPERMILL is set
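
The `progress_bar` flag relies on string truthiness: `os.environ.get` returns `None` when the variable is absent, and any non-empty string (including `"0"`) counts as set. The sketch below illustrates this with a hypothetical helper, `progress_bar_enabled`, which is not part of the tutorial and takes the environment as a plain dict so it can be tried without changing the real environment.

```python
def progress_bar_enabled(env):
    """Mirror `not bool(os.environ.get("PAPERMILL"))` for a dict-like environment.

    Hypothetical helper, for illustration only.
    """
    return not bool(env.get("PAPERMILL"))


print(progress_bar_enabled({}))                  # variable absent: bar enabled
print(progress_bar_enabled({"PAPERMILL": ""}))   # empty string behaves like unset
print(progress_bar_enabled({"PAPERMILL": "1"}))  # any non-empty value disables the bar
print(progress_bar_enabled({"PAPERMILL": "0"}))  # "0" is still a non-empty string
```

In other words, the value of `$PAPERMILL` is never parsed; only whether it is non-empty matters.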

## Initialize SOMA Experiment query as training data

In [2]:
from tiledbsoma import AxisQuery, Experiment, SOMATileDBContext
from sklearn.preprocessing import LabelEncoder

from tiledbsoma_ml import ExperimentDataset

CZI_Census_Homo_Sapiens_URL = f"s3://cellxgene-census-public-us-west-2/cell-census/{census_version}/soma/census_data/homo_sapiens/"

experiment = Experiment.open(
    CZI_Census_Homo_Sapiens_URL,
    # Read the public Census bucket anonymously (unsigned S3 requests)
    context=SOMATileDBContext(tiledb_config={"vfs.s3.region": "us-west-2", "vfs.s3.no_sign_request": "true"}),
)
obs_value_filter = f"tissue_general == '{tissue}' and is_primary_data == True"

with experiment.axis_query(
    measurement_name="RNA", obs_query=AxisQuery(value_filter=obs_value_filter)
) as query:
    obs_df = query.obs(column_names=["cell_type"]).concat().to_pandas()
    cell_type_encoder = LabelEncoder().fit(obs_df["cell_type"].unique())

experiment_dataset = ExperimentDataset(
    query,
    layer_name="raw",
    obs_column_names=["cell_type"],
    batch_size=batch_size,
    shuffle=True,
)
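
`LabelEncoder` maps each distinct `cell_type` string to an integer class index, which is the form the loss function needs during training. A minimal sketch on made-up labels (the cell type names below are illustrative and are not taken from the Census query):

```python
from sklearn.preprocessing import LabelEncoder

# Illustrative labels only; the tutorial fits on the unique `cell_type` values of the query
labels = ["T cell", "B cell", "T cell", "epithelial cell"]

encoder = LabelEncoder().fit(labels)

# classes_ holds the sorted unique labels; a label's position is its integer code
classes = list(encoder.classes_)
codes = encoder.transform(labels).tolist()
decoded = encoder.inverse_transform(codes).tolist()

print(classes)
print(codes)
print(decoded)
```

The number of classes, `len(encoder.classes_)`, is what the tutorial later uses as the model's output dimension.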

## Define the Lightning module

In [3]:
import torch
import pytorch_lightning as pl

class LogisticRegressionLightning(pl.LightningModule):
    def __init__(self, input_dim, output_dim, cell_type_encoder, learning_rate=learning_rate):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)
        self.cell_type_encoder = cell_type_encoder
        self.learning_rate = learning_rate
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        # Return raw logits: CrossEntropyLoss applies log-softmax itself,
        # so an extra sigmoid here would distort the loss
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        X_batch, y_batch = batch
        # Convert the NumPy batch to a float32 tensor on the module's device
        X_batch = torch.from_numpy(X_batch).float().to(self.device)

        # Perform prediction
        outputs = self(X_batch)

        # Determine the predicted label
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        predictions = torch.argmax(probabilities, dim=1)

        # Compute loss
        y_batch = torch.from_numpy(
            self.cell_type_encoder.transform(y_batch["cell_type"])
        ).to(self.device)
        loss = self.loss_fn(outputs, y_batch.long())

        # Compute accuracy
        train_correct = (predictions == y_batch).sum().item()
        train_accuracy = train_correct / len(predictions)

        # Log loss and accuracy
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_accuracy", train_accuracy, prog_bar=True)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
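
`torch.nn.CrossEntropyLoss` combines log-softmax with negative log-likelihood, so it is normally fed unnormalized scores (logits). Softmax is also monotonic within each row, so taking `argmax` of the logits or of the probabilities gives the same predicted class. The following sketch checks both points on a toy batch with arbitrary values; it is an illustration, not part of the tutorial's model.

```python
import torch

torch.manual_seed(0)

# Toy batch: 4 samples, 3 classes (arbitrary values for illustration)
logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 2])

# Cross-entropy computed directly on raw logits ...
ce = torch.nn.CrossEntropyLoss()(logits, targets)

# ... matches negative log-likelihood computed on log-softmax outputs
nll = torch.nn.functional.nll_loss(torch.log_softmax(logits, dim=1), targets)

# Softmax preserves the ordering within each row, so predictions agree
pred_from_logits = torch.argmax(logits, dim=1)
pred_from_probs = torch.argmax(torch.softmax(logits, dim=1), dim=1)

# Fraction of samples whose predicted class equals the target
accuracy = (pred_from_logits == targets).float().mean().item()
print(ce.item(), nll.item(), accuracy)
```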

## Train the model

In [4]:
from tiledbsoma_ml import experiment_dataloader

dataloader = experiment_dataloader(experiment_dataset)

# The size of the input dimension is the number of genes
input_dim = experiment_dataset.shape[1]

# The size of the output dimension is the number of distinct cell_type values
output_dim = len(cell_type_encoder.classes_)

# Initialize the PyTorch Lightning model
model = LogisticRegressionLightning(
    input_dim, output_dim, cell_type_encoder=cell_type_encoder
)

# Define the PyTorch Lightning Trainer
trainer = pl.Trainer(max_epochs=n_epochs, enable_progress_bar=progress_bar)

# Allow faster, slightly lower-precision float32 matrix multiplications on supported GPUs
torch.set_float32_matmul_precision("high")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [5]:
%%time
# Train the model
trainer.fit(model, train_dataloaders=dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | linear  | Linear           | 726 K  | train
1 | loss_fn | CrossEntropyLoss | 0      | train
-----------------------------------------------------
726 K     Trainable params
0         Non-trainable params
726 K     Total params
2.905     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode


CPU times: user 3min 30s, sys: 1min 25s, total: 4min 55s
Wall time: 2min 14s
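
The parameter count in the summary follows from the shape of the single linear layer: `Linear(input_dim, output_dim)` holds `input_dim * output_dim` weights plus `output_dim` biases, and the size estimate corresponds to 4 bytes per 32-bit parameter. A small sketch of that arithmetic, using hypothetical helpers (`linear_param_count`, `params_to_megabytes`) and made-up dimensions rather than the ones from this run:

```python
def linear_param_count(input_dim, output_dim):
    """Weights plus biases of a fully connected layer (hypothetical helper)."""
    return input_dim * output_dim + output_dim


def params_to_megabytes(n_params, bytes_per_param=4):
    """Approximate size in MB, assuming float32 parameters (4 bytes each)."""
    return n_params * bytes_per_param / 1e6


# Made-up dimensions, not the ones from this run
print(linear_param_count(1000, 10))  # 10010

# Roughly 726 K float32 parameters, the rounded figure from the summary
print(params_to_megabytes(726_000))  # 2.904, close to the reported 2.905 MB
```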
