# Training a PyTorch Model

This tutorial shows how to train a Logistic Regression model in PyTorch using the `tiledbsoma.ml.ExperimentAxisQueryDataPipe` class, and the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/) dataset. This is intended only to demonstrate the use of the `ExperimentAxisQueryDataPipe`, and not as an example of how to train a biologically useful model.

This tutorial assumes a basic familiarity with PyTorch and the Census API.

**Prerequisites**

Install `tiledbsoma` with the optional `ml` dependencies, for example:

> pip install tiledbsoma[ml]


**Contents**

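* [Create an ExperimentAxisQueryDataPipe](#Create-an-ExperimentAxisQueryDataPipe)
* [Split the dataset](#Split-the-dataset)
* [Create the DataLoader](#Create-the-DataLoader)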
* [Define the model](#Define-the-model)
* [Train the model](#Train-the-model)
* [Make predictions with the model](#Make-predictions-with-the-model)


## Create an ExperimentAxisQueryDataPipe

To train a model in PyTorch on Census data, first open a SOMA Experiment and create an `ExperimentAxisQueryDataPipe`. This example uses a recent CZI Census release, accessed directly from S3.

We also create an encoder for the `obs` labels at the same time and fit it on the `cell_type` labels. In this example we use the `LabelEncoder` from `scikit-learn`.

In [1]:
from sklearn.preprocessing import LabelEncoder

import tiledbsoma as soma
import tiledbsoma.ml as soma_ml

CZI_Census_Homo_Sapiens_URL = (
    "s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/"
)

experiment = soma.open(
    CZI_Census_Homo_Sapiens_URL, context=soma.SOMATileDBContext(tiledb_config={"vfs.s3.region": "us-west-2"})
)
obs_value_filter = "tissue_general == 'tongue' and is_primary_data == True"
obs_query = soma.AxisQuery(value_filter=obs_value_filter)

experiment_dataset = soma_ml.ExperimentAxisQueryDataPipe(
    experiment,
    measurement_name="RNA",
    X_name="raw",
    obs_query=obs_query,
    obs_column_names=["cell_type"],
    batch_size=128,
    shuffle=True,
)

with experiment.axis_query(measurement_name="RNA", obs_query=obs_query) as query:
    obs_df = query.obs(column_names=['cell_type']).concat().to_pandas()
    cell_type_encoder = LabelEncoder().fit(obs_df['cell_type'].unique())

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################
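
To make the role of the encoder concrete, the sketch below mimics in plain Python what a label encoder does: it maps each distinct label to an integer index and back. It assumes classes are indexed in sorted order, as scikit-learn's `LabelEncoder` does. The helper names `fit_classes`, `encode`, and `decode` are hypothetical and exist only for this illustration.

```python
# Hypothetical stand-in for a label encoder; not the scikit-learn implementation.
def fit_classes(labels):
    """Return the sorted list of distinct labels (list index = encoded value)."""
    return sorted(set(labels))


def encode(classes, labels):
    """Map each label to its integer index in `classes`."""
    index = {label: i for i, label in enumerate(classes)}
    return [index[label] for label in labels]


def decode(classes, codes):
    """Map integer codes back to the original label strings."""
    return [classes[code] for code in codes]


labels = ["leukocyte", "basal cell", "keratinocyte", "basal cell"]
classes = fit_classes(labels)
codes = encode(classes, labels)

print(classes)                  # ['basal cell', 'keratinocyte', 'leukocyte']
print(codes)                    # [2, 0, 1, 0]
print(decode(classes, codes))   # round-trips to the original labels
```

The number of distinct classes is what later determines the output dimension of the model.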



### `ExperimentAxisQueryDataPipe` class explained

This class provides an implementation of PyTorch's `torchdata` [IterDataPipe interface](https://pytorch.org/data/main/torchdata.datapipes.iter.html), which defines a common mechanism for wrapping and accessing training data from any underlying source. The `ExperimentAxisQueryDataPipe` class encapsulates the details of querying and retrieving Census data from a single SOMA `Experiment` and returning it to the caller as a NumPy `ndarray` and a Pandas `DataFrame`. Most importantly, it retrieves the data lazily from the Census in batches, which avoids loading the entire training dataset into memory at once.
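
The idea of lazy, batched retrieval can be illustrated with a plain Python generator. This is only a conceptual sketch and not how `ExperimentAxisQueryDataPipe` is implemented; `iter_batches` and `row_source` are hypothetical names, and the integer rows stand in for real `X` data.

```python
from itertools import islice


def iter_batches(rows, batch_size):
    """Yield lists of up to `batch_size` rows, pulling from `rows` lazily."""
    iterator = iter(rows)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


def row_source(n_rows):
    """Hypothetical data source that produces one row at a time."""
    for i in range(n_rows):
        yield i


# Only one batch is materialized in memory at any moment.
batch_sizes = [len(batch) for batch in iter_batches(row_source(10), batch_size=4)]
print(batch_sizes)  # [4, 4, 2]
```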

### `ExperimentAxisQueryDataPipe` parameters explained

The constructor only requires a single parameter, `experiment`, which is a `soma.Experiment` containing the data of the organism to be used for training.

To retrieve a subset of the Experiment's data, along either the `obs` or `var` axes, you may specify query filters via the `obs_query` and `var_query` parameters, which are both `soma.AxisQuery` objects.

The values for the prediction label(s) that you intend to use for training are specified via the `obs_column_names` array.

The `batch_size` parameter specifies the number of `obs` rows (cells) returned in each batch. You may omit this parameter if you want single rows (`batch_size=1`).

The `shuffle` flag allows you to randomize the ordering of the training data for each training epoch. Note:
* You should use this flag instead of the `DataLoader` `shuffle` flag, primarily for performance reasons.
* PyTorch's TorchData library provides a [Shuffler](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.Shuffler.html) `DataPipe`, which is an alternate mechanism for shuffling an `IterableDataset`. However, the `Shuffler` will not "globally" randomize the training data, as it only "locally" randomizes the ordering of the training data within fixed-size "windows". Due to the layout of Census data, a given "window" of Census data may be highly homogeneous in terms of its `obs` axis attribute values, and so this shuffling strategy may not provide sufficient randomization for certain types of models.
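
The difference between "local" and "global" shuffling can be seen with a toy example. The sketch below is a simplified model of the windowed behavior described above, not the actual `Shuffler` implementation; `windowed_shuffle` is a hypothetical helper and the labels are made up.

```python
import random


def windowed_shuffle(items, window_size, rng):
    """Shuffle only within consecutive fixed-size windows of `items`."""
    shuffled = []
    for start in range(0, len(items), window_size):
        window = items[start:start + window_size]
        rng.shuffle(window)
        shuffled.extend(window)
    return shuffled


# Data laid out homogeneously: 8 cells of type "A" followed by 8 of type "B".
labels = ["A"] * 8 + ["B"] * 8
rng = random.Random(0)

local = windowed_shuffle(labels, window_size=4, rng=rng)
global_ = labels[:]
rng.shuffle(global_)

# Every window contained a single label, so local shuffling changed nothing:
print(local[:4])   # ['A', 'A', 'A', 'A']
print(global_)     # labels are mixed across the whole sequence
```

With homogeneous windows, the first batch drawn from the locally shuffled data still contains a single cell type, which is the situation the `shuffle` flag is meant to avoid.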

You can inspect the shape of the full dataset, without causing the full dataset to be loaded:

In [2]:
experiment_dataset.shape

(15020, 60530)
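
As a quick sanity check, the shape gives a rough idea of how many batches a full pass would produce. The arithmetic below assumes iteration over the full, unsplit dataset and assumes that a final partial batch is yielded rather than dropped; neither assumption is confirmed by the output above.

```python
import math

n_rows = 15020      # first element of experiment_dataset.shape shown above
batch_size = 128    # value passed to the datapipe constructor

n_batches = math.ceil(n_rows / batch_size)
last_batch_rows = n_rows - (n_batches - 1) * batch_size

print(n_batches, last_batch_rows)  # 118 44
```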

## Split the dataset

You may split the overall dataset into the typical training, validation, and test sets by using the PyTorch [RandomSplitter](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.RandomSplitter.html#torchdata.datapipes.iter.RandomSplitter) `DataPipe`. Using PyTorch's functional form for chaining `DataPipe`s, this is done as follows:

In [3]:
train_dataset, test_dataset = experiment_dataset.random_split(weights={"train": 0.8, "test": 0.2}, seed=1)
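
Conceptually, a seeded, weighted random split assigns each item to one of the named splits with probability proportional to its weight. The sketch below illustrates that idea on plain indices. It is not the `RandomSplitter` algorithm, and the `random_split` function defined here is a hypothetical helper that merely shares its name with the functional form used above.

```python
import random


def random_split(n_items, weights, seed):
    """Assign each index in range(n_items) to a split, weighted by `weights`."""
    rng = random.Random(seed)
    names = list(weights)
    probabilities = [weights[name] for name in names]
    splits = {name: [] for name in names}
    for i in range(n_items):
        chosen = rng.choices(names, weights=probabilities)[0]
        splits[chosen].append(i)
    return splits


splits = random_split(1000, {"train": 0.8, "test": 0.2}, seed=1)
print(len(splits["train"]), len(splits["test"]))

# The same seed reproduces the same assignment.
assert splits == random_split(1000, {"train": 0.8, "test": 0.2}, seed=1)
```

Fixing the seed keeps the split stable across epochs and runs, so the same cells always land in the same split.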

## Create the DataLoader

With the full set of DataPipe operations chained together, we can now instantiate a PyTorch [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) on the training data. 

In [4]:
experiment_dataloader = soma_ml.experiment_dataloader(train_dataset)

Alternately, you can instantiate a `DataLoader` object directly via its constructor. However, many of the parameters are not usable with iterable-style Datasets, which is the case for `ExperimentAxisQueryDataPipe`. In particular, the `shuffle`, `batch_size`, `sampler`, `batch_sampler`, `collate_fn` parameters should not be specified. Using `experiment_dataloader` helps enforce correct usage.

## Define the model

With the training data retrieval code now in place, we can move on to defining a simple logistic regression model, using PyTorch's `torch.nn.Linear` class:

In [5]:
import torch


class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()  # noqa: UP008
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        outputs = torch.sigmoid(self.linear(x))
        return outputs

Next, we define a function to train the model for a single epoch:

In [6]:
def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
    model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0

    for X_batch, y_batch in train_dataloader:
        optimizer.zero_grad()

        X_batch = torch.from_numpy(X_batch).float().to(device)

        # Perform prediction
        outputs = model(X_batch)

        # Determine the predicted label
        probabilities = torch.nn.functional.softmax(outputs, 1)
        predictions = torch.argmax(probabilities, axis=1)

        # Compute the loss and perform back propagation

        y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)

        train_correct += (predictions == y_batch).sum().item()
        train_total += len(predictions)

        loss = loss_fn(outputs, y_batch.long())
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

    train_loss /= train_total
    train_accuracy = train_correct / train_total
    return train_loss, train_accuracy

Note the loop header, `for X_batch, y_batch in train_dataloader:`. Since the datapipe was configured with `batch_size=128`, `X_batch` will be of rank 2 once it is converted to a tensor. The `X_batch` tensor will appear, for example, as:

```
tensor([[0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 2.,  ..., 0., 3., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 8.]])
```

For `batch_size=1`, the tensors will be of rank 1. The `X_batch` tensor will appear, for example, as:

```
tensor([0., 0., 0.,  ..., 1., 0., 0.])
```
    
`y_batch` contains the user-specified `obs` training labels as a Pandas `DataFrame`, with one column for each entry in `obs_column_names` specified when creating the datapipe (in this case, only `cell_type`). After encoding with `cell_type_encoder` and converting to a tensor, the labels will look like this:

```
tensor([1, 1, 3, ..., 2, 1, 4])
```
Note that the cell type values are integer-encoded and can be decoded using `cell_type_encoder` (more on this below).


## Train the model

Finally, we are ready to train the model. Here we instantiate the model, a loss function, and an optimization method and then iterate through the desired number of training epochs. Note how `experiment_dataloader` is passed into `train_epoch` as its `train_dataloader` argument, where for each epoch it will provide a new iterator through the training dataset.

In [7]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The size of the input dimension is the number of genes
input_dim = experiment_dataset.shape[1]

# The size of the output dimension is the number of distinct cell_type values
output_dim = len(cell_type_encoder.classes_)

model = LogisticRegression(input_dim, output_dim).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)

for epoch in range(10):
    train_loss, train_accuracy = train_epoch(model, experiment_dataloader, loss_fn, optimizer, device)
    print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}")

Epoch 1: Train Loss: 0.0162193 Accuracy 0.3910
Epoch 2: Train Loss: 0.0148544 Accuracy 0.5580
Epoch 3: Train Loss: 0.0143674 Accuracy 0.5863
Epoch 4: Train Loss: 0.0140639 Accuracy 0.7015
Epoch 5: Train Loss: 0.0138526 Accuracy 0.7417
Epoch 6: Train Loss: 0.0137076 Accuracy 0.7998
Epoch 7: Train Loss: 0.0136148 Accuracy 0.8783
Epoch 8: Train Loss: 0.0135245 Accuracy 0.8984
Epoch 9: Train Loss: 0.0134525 Accuracy 0.9051
Epoch 10: Train Loss: 0.0133846 Accuracy 0.9122


## Make predictions with the model

To make predictions with the model, we first create a new `DataLoader` using the `test_dataset`, which provides the "test" split of the original dataset. For this example, we will only make predictions on a single batch of data from the test split.

In [8]:
experiment_dataloader = soma_ml.experiment_dataloader(test_dataset)
X_batch, y_batch = next(iter(experiment_dataloader))
X_batch = torch.from_numpy(X_batch)
y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type']))

Next, we invoke the model on the `X_batch` input data and extract the predictions:

In [9]:
model.eval()

model.to(device)
outputs = model(X_batch.to(device))

probabilities = torch.nn.functional.softmax(outputs, 1)
predictions = torch.argmax(probabilities, axis=1)

display(predictions)

tensor([ 1,  1,  1,  8,  1,  7,  1,  5,  1,  1,  8,  1,  8,  8,  1,  8,  7,  5,
         1,  1,  1,  5,  1,  7,  7,  5,  1,  5,  7,  5,  8,  1,  5,  1,  7,  1,
         5,  7,  1,  1,  7,  8,  5,  8,  7,  8,  1,  7,  1,  8,  5,  1,  1,  5,
         1,  1,  7,  7,  1,  1,  1,  7,  1,  1,  1,  1,  1,  7,  8,  1,  7,  8,
         8,  1,  5,  1,  6,  1,  5,  1,  7,  1,  1,  1,  1,  7,  1,  7,  1,  1,
         1,  1,  1,  5,  1,  5, 11,  1,  1,  5,  5,  1,  1,  1,  1,  1,  1,  5,
         5,  1,  8,  8,  1,  9,  1,  1,  8,  8,  5,  5,  5,  5,  1,  7,  7,  1,
         1,  1])
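
The two steps in the cell above, softmax followed by argmax, can be traced by hand on a tiny example. The sketch below uses plain Python and made-up scores for two cells and three classes. The `softmax` and `argmax` helpers are written out for illustration and are not the PyTorch functions.

```python
import math


def softmax(scores):
    """Convert raw scores into probabilities that sum to 1."""
    largest = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(score - largest) for score in scores]
    total = sum(exps)
    return [value / total for value in exps]


def argmax(values):
    """Return the index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)


# Made-up model outputs: one row per cell, one column per class.
outputs = [[0.1, 0.9, 0.3], [0.7, 0.2, 0.6]]

probabilities = [softmax(row) for row in outputs]
predictions = [argmax(row) for row in probabilities]
print(predictions)  # [1, 0]
```

Because softmax preserves the ordering of the scores, taking the argmax of the raw outputs yields the same predicted classes. The softmax step matters only when the probabilities themselves are of interest.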

The predictions are returned as the encoded values of the `cell_type` label. To recover the original cell type labels as strings, we decode using the same `LabelEncoder` used for training.

At inference time, if the model inputs are not obtained via an `ExperimentAxisQueryDataPipe`, one could pickle the encoder at training time and save it along with the model. Then, at inference time it can be unpickled and used as shown below.

In [10]:
predicted_cell_types = cell_type_encoder.inverse_transform(predictions.cpu())

display(predicted_cell_types)

array(['basal cell', 'basal cell', 'basal cell', 'leukocyte',
       'basal cell', 'keratinocyte', 'basal cell', 'epithelial cell',
       'basal cell', 'basal cell', 'leukocyte', 'basal cell', 'leukocyte',
       'leukocyte', 'basal cell', 'leukocyte', 'keratinocyte',
       'epithelial cell', 'basal cell', 'basal cell', 'basal cell',
       'epithelial cell', 'basal cell', 'keratinocyte', 'keratinocyte',
       'epithelial cell', 'basal cell', 'epithelial cell', 'keratinocyte',
       'epithelial cell', 'leukocyte', 'basal cell', 'epithelial cell',
       'basal cell', 'keratinocyte', 'basal cell', 'epithelial cell',
       'keratinocyte', 'basal cell', 'basal cell', 'keratinocyte',
       'leukocyte', 'epithelial cell', 'leukocyte', 'keratinocyte',
       'leukocyte', 'basal cell', 'keratinocyte', 'basal cell',
       'leukocyte', 'epithelial cell', 'basal cell', 'basal cell',
       'epithelial cell', 'basal cell', 'basal cell', 'keratinocyte',
       'keratinocyte', 'basal cell', 
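
The pickling approach mentioned earlier can be sketched with the standard library alone. To keep the example self-contained, a plain list of class names stands in for the fitted encoder, and an in-memory buffer stands in for a file saved next to the model. Both are assumptions made for illustration.

```python
import io
import pickle

# Stand-in for the fitted encoder: the ordered class list (index = encoded value).
classes = ["basal cell", "epithelial cell", "keratinocyte", "leukocyte"]

# Training time: serialize the encoder. A real workflow would write to a file
# opened in "wb" mode and store it alongside the model weights.
buffer = io.BytesIO()
pickle.dump(classes, buffer)

# Inference time: load it back and decode integer predictions.
buffer.seek(0)
restored = pickle.load(buffer)
decoded = [restored[code] for code in [0, 3, 2]]
print(decoded)  # ['basal cell', 'leukocyte', 'keratinocyte']
```

Only unpickle files from a trusted source, since loading a pickle can execute arbitrary code.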

Finally, we create a Pandas DataFrame to examine the predictions:

In [11]:
import pandas as pd

display(
    pd.DataFrame(
        {
            "actual cell type": cell_type_encoder.inverse_transform(y_batch.ravel().numpy()),
            "predicted cell type": predicted_cell_types,
        }
    )
)

| | actual cell type | predicted cell type |
|---|---|---|
| 0 | basal cell | basal cell |
| 1 | basal cell | basal cell |
| 2 | basal cell | basal cell |
| 3 | leukocyte | leukocyte |
| 4 | basal cell | basal cell |
| ... | ... | ... |
| 123 | keratinocyte | keratinocyte |
| 124 | keratinocyte | keratinocyte |
| 125 | basal cell | basal cell |
| 126 | basal cell | basal cell |
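
A comparison like the one above can be summarized as a single accuracy figure for the batch. The sketch below uses short, made-up label lists rather than the actual notebook output. In the notebook, the two columns of the DataFrame would take their place.

```python
# Made-up labels for illustration; not taken from the results above.
actual = ["basal cell", "basal cell", "leukocyte", "keratinocyte"]
predicted = ["basal cell", "basal cell", "leukocyte", "basal cell"]

# Count positions where the prediction matches the actual label.
correct = sum(a == p for a, p in zip(actual, predicted))
accuracy = correct / len(actual)
print(f"Batch accuracy: {accuracy:.2f}")  # Batch accuracy: 0.75
```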
