# Training a model with PyTorch Lightning

This tutorial demonstrates training a simple Logistic Regression model with [PyTorch Lightning], using the `tiledbsoma_ml.ExperimentDataset` class, on data from the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/). It is intended for demonstration purposes only, not as an example of how to train a biologically useful model.

For more information on these APIs, please refer to the [`tutorial_pytorch` notebook](tutorial_pytorch.ipynb).

**Prerequisites**

Install [`tiledbsoma_ml`], [`scikit-learn`], and [`pytorch-lightning`]:

```bash
pip install tiledbsoma_ml scikit-learn pytorch-lightning
```

[PyTorch Lightning]: https://lightning.ai/docs/pytorch/stable/
[`tiledbsoma_ml`]: https://github.com/single-cell-data/TileDB-SOMA-ML/
[`scikit-learn`]: https://pypi.org/project/scikit-learn/
[`pytorch-lightning`]: https://pypi.org/project/pytorch-lightning/

## Initialize a SOMA Experiment query as training data

In [None]:
import pytorch_lightning as pl
import tiledbsoma as soma
import torch
from sklearn.preprocessing import LabelEncoder

import tiledbsoma_ml as soma_ml

# S3 location of the Homo sapiens SOMA Experiment used as the training data source.
CZI_Census_Homo_Sapiens_URL = "s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/"

# Read from the us-west-2 bucket with unsigned requests.
s3_context = soma.SOMATileDBContext(
    tiledb_config={"vfs.s3.region": "us-west-2", "vfs.s3.no_sign_request": "true"}
)
experiment = soma.open(CZI_Census_Homo_Sapiens_URL, context=s3_context)

# Select only primary-data cells whose general tissue is tongue.
obs_value_filter = "tissue_general == 'tongue' and is_primary_data == True"
obs_query = soma.AxisQuery(value_filter=obs_value_filter)

with experiment.axis_query(measurement_name="RNA", obs_query=obs_query) as query:
    # Load the cell_type column of the selected cells and fit the label
    # encoder on its distinct values.
    obs_df = query.obs(column_names=["cell_type"]).concat().to_pandas()
    cell_type_encoder = LabelEncoder()
    cell_type_encoder.fit(obs_df["cell_type"].unique())

    # Dataset over the "raw" layer, yielding shuffled batches of 128 along
    # with each cell's cell_type label.
    experiment_dataset = soma_ml.ExperimentDataset(
        query,
        layer_name="raw",
        obs_column_names=["cell_type"],
        batch_size=128,
        shuffle=True,
    )

## Define the Lightning module

In [None]:
class LogisticRegressionLightning(pl.LightningModule):
    """Multinomial logistic regression: one linear layer trained with cross-entropy.

    Args:
        input_dim: Number of input features in each row of ``X``.
        output_dim: Number of classes to predict.
        cell_type_encoder: Fitted encoder whose ``transform`` maps the
            ``cell_type`` labels of a batch to integer class indices.
        learning_rate: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, input_dim: int, output_dim: int, cell_type_encoder, learning_rate: float = 1e-5):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)
        self.cell_type_encoder = cell_type_encoder
        self.learning_rate = learning_rate
        # CrossEntropyLoss applies log-softmax internally, so it must be fed
        # raw (unnormalized) logits rather than squashed activations.
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the unnormalized class scores (logits) for ``x``.

        No sigmoid/softmax is applied here: the loss function normalizes the
        scores itself, and squashing them first would bound the logits to
        (0, 1) and weaken the training signal. Apply ``softmax`` to the result
        if class probabilities are needed.
        """
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the loss and accuracy for one training batch.

        Args:
            batch: ``(X_batch, y_batch)`` pair; ``X_batch`` is a NumPy array
                and ``y_batch`` is indexable by the ``"cell_type"`` column.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The cross-entropy loss for the batch.
        """
        X_batch, y_batch = batch
        X_batch = torch.from_numpy(X_batch).float().to(self.device)

        # Perform prediction
        logits = self(X_batch)

        # The predicted label is the highest-scoring class; softmax is
        # monotonic, so it is not needed to find the argmax.
        predictions = torch.argmax(logits, dim=1)

        # Encode the string labels as integer class indices and compute loss
        y_batch = torch.from_numpy(
            self.cell_type_encoder.transform(y_batch["cell_type"])
        ).to(self.device)
        loss = self.loss_fn(logits, y_batch.long())

        # Fraction of correct predictions, kept as a tensor so logging does
        # not force a device-to-host synchronization on every step.
        train_accuracy = (predictions == y_batch).float().mean()

        # Log loss and accuracy
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_accuracy", train_accuracy, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

## Train the model

In [None]:
# Wrap the dataset in a DataLoader for use by the Trainer
dataloader = soma_ml.experiment_dataloader(experiment_dataset)

# Input width: the second dimension of the dataset's shape (the number of genes)
input_dim = experiment_dataset.shape[1]

# Output width: one class per distinct cell_type value seen by the encoder
output_dim = len(cell_type_encoder.classes_)

# Build the Lightning module
model = LogisticRegressionLightning(
    input_dim=input_dim,
    output_dim=output_dim,
    cell_type_encoder=cell_type_encoder,
)

# Build the Trainer; training runs for at most 20 epochs
trainer = pl.Trainer(max_epochs=20)

# Set the internal precision PyTorch uses for float32 matrix multiplications
torch.set_float32_matmul_precision("high")

In [None]:
%%time
# Train the model
trainer.fit(model, train_dataloaders=dataloader)