# Training a PyTorch Model

This tutorial shows how to train a simple logistic regression model in PyTorch using the `tiledbsoma_ml.ExperimentAxisQueryIterDataPipe` class and the [CZI CELLxGENE Census](https://chanzuckerberg.github.io/cellxgene-census/) dataset. It is intended only to demonstrate the use of `ExperimentAxisQueryIterDataPipe`, not as an example of how to train a biologically useful model.

This tutorial assumes a basic familiarity with PyTorch and the Census API.

**Prerequisites**

Install `tiledbsoma_ml` (and `scikit-learn` for convenience). For example:

```bash
pip install tiledbsoma_ml scikit-learn
```

**Contents**

* [Create an ExperimentAxisQueryIterDataPipe](#data-pipe)
* [Split the dataset](#split)
* [Create the DataLoader](#data-loader)
* [Define the model](#model)
* [Train the model](#train)
* [Make predictions with the model](#predict)

## Create an ExperimentAxisQueryIterDataPipe <a id="data-pipe"></a>

To train a PyTorch model on a SOMA [Experiment]:
1. Open the Experiment.
2. Select the desired `obs` rows and `var` columns with an [ExperimentAxisQuery].
3. Create an `ExperimentAxisQueryIterDataPipe`.

The example below uses a recent CZI Census release, accessed directly from S3. We also encode the `obs` `cell_type` labels using a `scikit-learn` [LabelEncoder].

[Experiment]: https://tiledbsoma.readthedocs.io/en/stable/_autosummary/tiledbsoma.Experiment.html#tiledbsoma.Experiment
[ExperimentAxisQuery]: https://tiledbsoma.readthedocs.io/en/stable/_autosummary/tiledbsoma.ExperimentAxisQuery.html
[LabelEncoder]: https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html

In [1]:
import tiledbsoma as soma
from sklearn.preprocessing import LabelEncoder

import tiledbsoma_ml as soma_ml

CZI_Census_Homo_Sapiens_URL = "s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/"

experiment = soma.open(
    CZI_Census_Homo_Sapiens_URL,
    context=soma.SOMATileDBContext(tiledb_config={"vfs.s3.region": "us-west-2", "vfs.s3.no_sign_request": "true"}),
)
obs_value_filter = "tissue_general == 'tongue' and is_primary_data == True"

with experiment.axis_query(
    measurement_name="RNA", obs_query=soma.AxisQuery(value_filter=obs_value_filter)
) as query:
    obs_df = query.obs(column_names=["cell_type"]).concat().to_pandas()
    cell_type_encoder = LabelEncoder().fit(obs_df["cell_type"].unique())

    experiment_dataset = soma_ml.ExperimentAxisQueryIterDataPipe(
        query,
        layer_name="raw",
        obs_column_names=["cell_type"],
        batch_size=128,
        shuffle=True,
    )

### `ExperimentAxisQueryIterDataPipe` class explained

This class provides an implementation of PyTorch's [`torchdata` IterDataPipe interface][IterDataPipe], which defines a common mechanism for wrapping and accessing training data from any underlying source. The `ExperimentAxisQueryIterDataPipe` class encapsulates the details of querying a SOMA `Experiment` and returning a series of "batches," each consisting of a NumPy `ndarray` and a Pandas `DataFrame`. Most importantly, it retrieves data lazily, avoiding loading the entire training dataset into memory at once.
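
To make the idea of lazy retrieval concrete, here is a minimal sketch using a plain Python generator. The `lazy_batches` helper is hypothetical and only yields row indices; it illustrates the access pattern, not how `ExperimentAxisQueryIterDataPipe` is implemented.

```python
from typing import Iterator, List


def lazy_batches(n_rows: int, batch_size: int) -> Iterator[List[int]]:
    """Yield lists of row indices one batch at a time (hypothetical helper)."""
    for start in range(0, n_rows, batch_size):
        # Each batch is built only when the consumer asks for it.
        yield list(range(start, min(start + batch_size, n_rows)))


batches = lazy_batches(n_rows=10, batch_size=4)
first = next(batches)      # only the first batch has been built so far
remaining = list(batches)  # consuming the iterator builds the rest
print(first)      # [0, 1, 2, 3]
print(remaining)  # [[4, 5, 6, 7], [8, 9]]
```

Because batches are produced on demand, peak memory depends on the batch size rather than on the size of the full dataset.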

### `ExperimentAxisQueryIterDataPipe` parameters explained

The constructor only requires a single parameter, `query`, which is an [`ExperimentAxisQuery`] containing the data to be used for training. This is obtained by querying an [`Experiment`], along the `obs` and/or `var` axes (see above, or [the TileDB-SOMA docs][tdbs docs], for examples).

The values for the prediction label(s) that you intend to use for training are specified via the `obs_column_names` array.

The `batch_size` parameter specifies the number of `obs` rows (i.e., cells) returned in each batch (default: `1`).

The `shuffle` flag supports randomizing the ordering of the training data for each training epoch (default: `True`). Note:
* You should use this flag instead of [`DataLoader`]'s `shuffle` flag, primarily for performance reasons.
* [TorchData] also provides a [Shuffler] `DataPipe`, which is another way to shuffle an `IterDataPipe`. However, `Shuffler` does not "globally" randomize training data; it only "locally" shuffles (within fixed-size "windows"). This is problematic for atlas-style datasets such as [CZI Census], where `obs` axis attributes tend to be homogeneous within contiguous "windows," so this shuffling strategy may not provide sufficient randomization for certain types of models.

[IterDataPipe]: https://pytorch.org/data/main/torchdata.datapipes.iter.html
[`ExperimentAxisQuery`]: https://tiledbsoma.readthedocs.io/en/stable/_autosummary/tiledbsoma.ExperimentAxisQuery.html
[`Experiment`]: https://github.com/single-cell-data/TileDB-SOMA/blob/1.14.3/apis/python/src/tiledbsoma/_experiment.py#L80
[tdbs docs]: https://tiledbsoma.readthedocs.io/en/stable/
[`DataLoader`]: https://pytorch.org/docs/stable/data.html
[TorchData]: https://pytorch.org/data/beta/index.html
[Shuffler]: https://pytorch.org/data/main/generated/torchdata.datapipes.iter.Shuffler.html
[CZI Census]: https://chanzuckerberg.github.io/cellxgene-census/
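
The difference between local and global shuffling can be seen on a toy example. The sketch below uses only the standard library; `window_shuffle` is a simplified, hypothetical stand-in for window-based shuffling and is not the `Shuffler` implementation. The labels are made up and stored contiguously by cell type, as they might be in an atlas-style dataset.

```python
import random

# Hypothetical atlas-style ordering: 12 cells stored contiguously by cell type.
labels = ["A"] * 4 + ["B"] * 4 + ["C"] * 4


def window_shuffle(items, window, seed):
    """Shuffle only within consecutive fixed-size windows (local shuffle)."""
    rng = random.Random(seed)
    out = []
    for start in range(0, len(items), window):
        chunk = items[start:start + window]
        rng.shuffle(chunk)
        out.extend(chunk)
    return out


def global_shuffle(items, seed):
    """Shuffle the whole sequence at once (global shuffle)."""
    rng = random.Random(seed)
    out = list(items)
    rng.shuffle(out)
    return out


local = window_shuffle(labels, window=4, seed=0)
everywhere = global_shuffle(labels, seed=0)

# Each window here contains a single cell type, so the local shuffle
# leaves the label sequence unchanged.
print(local)
print(everywhere)
```

In this extreme case every batch drawn after the local shuffle still contains one cell type, whereas the global shuffle can mix cell types within a batch.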

You can inspect the shape of the full dataset without loading it. The `shape` property returns the number of batches in the first dimension and the number of `var` columns (genes) in the second:

In [2]:
experiment_dataset.shape

(118, 60530)
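
As a worked example of how the batch count relates to the number of cells, the sketch below assumes that the final, possibly smaller, batch is kept, so the count is a ceiling division. The `n_batches` helper is hypothetical; the numbers 118 and 128 come from the shape and `batch_size` shown above.

```python
import math


def n_batches(n_obs: int, batch_size: int) -> int:
    """Number of batches, assuming the final partial batch is kept."""
    return math.ceil(n_obs / batch_size)


batch_size = 128
reported_batches = 118  # first element of the shape shown above

# Smallest and largest cell counts consistent with 118 batches of 128.
min_cells = (reported_batches - 1) * batch_size + 1
max_cells = reported_batches * batch_size
print(min_cells, max_cells)  # 14977 15104
```

Under that assumption, the query selected between 14,977 and 15,104 cells.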

## Split the dataset <a id="split"></a>

You may split the overall dataset into the typical training, validation, and test sets by using the PyTorch [RandomSplitter](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.RandomSplitter.html#torchdata.datapipes.iter.RandomSplitter) `DataPipe`. Using PyTorch's functional form for chaining `DataPipe`s, an 80/20 train/test split is done as follows:

In [3]:
train_dataset, test_dataset = experiment_dataset.random_split(weights={"train": 0.8, "test": 0.2}, seed=1)
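
Conceptually, a weighted random split assigns each item to one named split with probability proportional to its weight, using a seeded random number generator so the assignment is reproducible. The sketch below illustrates that idea with the standard library only; the `random_split` function here is a hypothetical illustration and is not how `RandomSplitter` is implemented.

```python
import random


def random_split(items, weights, seed):
    """Assign each item to one named split with probability proportional to its weight."""
    rng = random.Random(seed)
    names = list(weights)
    probs = [weights[name] for name in names]
    splits = {name: [] for name in names}
    for item in items:
        name = rng.choices(names, weights=probs)[0]
        splits[name].append(item)
    return splits


items = list(range(1000))  # stand-ins for the items flowing through the pipe
splits = random_split(items, {"train": 0.8, "test": 0.2}, seed=1)
print(len(splits["train"]), len(splits["test"]))
```

The split sizes are only approximately 80/20, since each item is assigned independently, but repeating the call with the same seed gives the same partition.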

## Create the DataLoader <a id="data-loader"></a>

With the full set of DataPipe operations chained together, you can instantiate a PyTorch [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) on the training data.

In [4]:
train_dataloader = soma_ml.experiment_dataloader(train_dataset)

Instantiating a `DataLoader` object directly is not recommended, as several of its parameters interfere with iterable-style DataPipes like `ExperimentAxisQueryIterDataPipe`. Using `experiment_dataloader` helps enforce correct usage.
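
The sketch below shows two ways in which default `DataLoader` parameters can interfere with an iterable-style dataset that already yields whole batches. `PreBatched` is a hypothetical toy dataset, and the example makes no claim about what `experiment_dataloader` does internally.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class PreBatched(IterableDataset):
    """Hypothetical iterable dataset that already yields whole batches."""

    def __iter__(self):
        for start in range(0, 6, 3):
            yield torch.arange(start, start + 3)


# batch_size=None turns off DataLoader's own batching, so batches pass through as-is.
passthrough = [b.tolist() for b in DataLoader(PreBatched(), batch_size=None)]
print(passthrough)  # [[0, 1, 2], [3, 4, 5]]

# The default batch_size=1 wraps each pre-made batch in an extra leading dimension.
rebatched = next(iter(DataLoader(PreBatched())))
print(rebatched.shape)  # torch.Size([1, 3])

# shuffle=True is rejected for iterable-style datasets.
try:
    DataLoader(PreBatched(), shuffle=True)
    shuffle_error = None
except ValueError as err:
    shuffle_error = err
print(type(shuffle_error).__name__)  # ValueError
```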

## Define the model <a id="model"></a>

With the training data retrieval code now in place, we can move on to defining a simple logistic regression model, using PyTorch's `torch.nn.Linear` class:

In [5]:
import torch

class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        outputs = torch.sigmoid(self.linear(x))
        return outputs
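
To see what this forward pass computes and which shapes are involved, here is a small NumPy sketch of the same arithmetic: a linear map followed by an element-wise sigmoid. The sizes and random values are hypothetical; the weight matrix uses the `(output_dim, input_dim)` layout of `torch.nn.Linear`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_genes, n_classes = 4, 10, 3  # hypothetical small sizes

X = rng.poisson(1.0, size=(n_cells, n_genes)).astype(np.float32)
W = rng.normal(size=(n_classes, n_genes)).astype(np.float32)  # same layout as Linear.weight
b = np.zeros(n_classes, dtype=np.float32)

logits = X @ W.T + b                     # the linear layer
outputs = 1.0 / (1.0 + np.exp(-logits))  # element-wise sigmoid
print(outputs.shape)  # (4, 3)
```

Each row of `outputs` holds one value per cell type for one cell, which is why `output_dim` must equal the number of distinct labels.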

Next, we define a function to train the model for a single epoch:

In [6]:
def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
    model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0

    for X_batch, y_batch in train_dataloader:
        optimizer.zero_grad()

        X_batch = torch.from_numpy(X_batch).float().to(device)

        # Perform prediction
        outputs = model(X_batch)

        # Determine the predicted label
        probabilities = torch.nn.functional.softmax(outputs, 1)
        predictions = torch.argmax(probabilities, axis=1)

        # Compute the loss and perform back propagation
        y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type'])).to(device)
        train_correct += (predictions == y_batch).sum().item()
        train_total += len(predictions)

        loss = loss_fn(outputs, y_batch.long())
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

    train_loss /= train_total
    train_accuracy = train_correct / train_total
    return train_loss, train_accuracy

Note the loop header, `for X_batch, y_batch in train_dataloader:`. Since the datapipe was created with `batch_size=128`, `X_batch` holds a rank-2 array with one row per cell, which `train_epoch` converts to a tensor. The `X_batch` tensor will appear, for example, as:

```
tensor([[0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 2.,  ..., 0., 3., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 8.]])
      
```

For `batch_size=1`, the tensors will be of rank 1. The `X_batch` tensor will appear, for example, as:

```
tensor([0., 0., 0.,  ..., 1., 0., 0.])
```
    
`y_batch` contains the user-specified `obs` `cell_type` training labels. The datapipe returns them as a Pandas `DataFrame` with one column for each entry in `obs_column_names` (in this case, only `cell_type`). `train_epoch` encodes that column with `cell_type_encoder` and converts the result to a tensor, which will look like this:

```
tensor([1, 1, 3, ..., 2, 1, 4])
```

Note that the cell type values are now integer-encoded, and can be decoded using `cell_type_encoder` (more on this below).
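
The encode and decode round trip can be tried on a toy example. The sketch below fits a `LabelEncoder` on three made-up labels and applies it to a small `DataFrame` shaped like `y_batch`; the data are hypothetical.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# LabelEncoder sorts the distinct labels and numbers them from 0.
encoder = LabelEncoder().fit(["leukocyte", "basal cell", "keratinocyte"])
print(list(encoder.classes_))  # ['basal cell', 'keratinocyte', 'leukocyte']

# A toy stand-in for the labels DataFrame of one batch.
y_batch = pd.DataFrame({"cell_type": ["leukocyte", "basal cell", "leukocyte"]})
codes = encoder.transform(y_batch["cell_type"])
print(codes.tolist())  # [2, 0, 2]

decoded = list(encoder.inverse_transform(codes))
print(decoded)  # ['leukocyte', 'basal cell', 'leukocyte']
```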

## Train the model <a id="train"></a>

Finally, we are ready to train the model. Here we instantiate the model, a loss function, and an optimization method, then iterate through the desired number of training epochs. Note how the `train_dataloader` is passed into `train_epoch`, where for each epoch it will provide a new iterator through the training dataset.

In [7]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The size of the input dimension is the number of genes
input_dim = experiment_dataset.shape[1]

# The size of the output dimension is the number of distinct cell_type values
output_dim = len(cell_type_encoder.classes_)

model = LogisticRegression(input_dim, output_dim).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)

In [8]:
%%time
n_epochs = 20
for epoch in range(n_epochs):
    train_loss, train_accuracy = train_epoch(model, train_dataloader, loss_fn, optimizer, device)
    print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}")

Epoch 1: Train Loss: 0.0168204 Accuracy 0.3381
Epoch 2: Train Loss: 0.0147634 Accuracy 0.4705
Epoch 3: Train Loss: 0.0143201 Accuracy 0.5093
Epoch 4: Train Loss: 0.0140467 Accuracy 0.5263
Epoch 5: Train Loss: 0.0138597 Accuracy 0.5695
Epoch 6: Train Loss: 0.0137500 Accuracy 0.6357
Epoch 7: Train Loss: 0.0136695 Accuracy 0.6987
Epoch 8: Train Loss: 0.0135382 Accuracy 0.8216
Epoch 9: Train Loss: 0.0134330 Accuracy 0.9219
Epoch 10: Train Loss: 0.0133367 Accuracy 0.9339
Epoch 11: Train Loss: 0.0132720 Accuracy 0.9377
Epoch 12: Train Loss: 0.0132345 Accuracy 0.9423
Epoch 13: Train Loss: 0.0131933 Accuracy 0.9456
Epoch 14: Train Loss: 0.0131710 Accuracy 0.9481
Epoch 15: Train Loss: 0.0131489 Accuracy 0.9514
Epoch 16: Train Loss: 0.0131206 Accuracy 0.9534
Epoch 17: Train Loss: 0.0131053 Accuracy 0.9559
Epoch 18: Train Loss: 0.0130829 Accuracy 0.9584
Epoch 19: Train Loss: 0.0130765 Accuracy 0.9594
Epoch 20: Train Loss: 0.0130524 Accuracy 0.9614
CPU times: user 3min 34s, sys: 1min 52s, total: 5
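
The epoch loop above relies on the dataloader handing out a fresh iterator each time it is used in a `for` loop. The sketch below demonstrates that protocol with a hypothetical `CountingBatches` class; it is a plain Python iterable, not a `DataLoader`.

```python
class CountingBatches:
    """Hypothetical iterable that records how many times iteration restarts."""

    def __init__(self, batches):
        self.batches = batches
        self.iterators_created = 0

    def __iter__(self):
        self.iterators_created += 1
        return iter(self.batches)


loader = CountingBatches([[1, 2], [3, 4], [5]])
seen_per_epoch = []
for epoch in range(3):
    # A `for` loop calls iter(loader), so each epoch starts from the beginning.
    seen_per_epoch.append(sum(len(batch) for batch in loader))

print(loader.iterators_created)  # 3
print(seen_per_epoch)            # [5, 5, 5]
```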

## Make predictions with the model <a id="predict"></a>

To make predictions with the model, we first create a new `DataLoader` using the `test_dataset`, which provides the "test" split of the original dataset. For this example, we will only make predictions on a single batch of data from the test split.

In [9]:
test_dataloader = soma_ml.experiment_dataloader(test_dataset)
X_batch, y_batch = next(iter(test_dataloader))
X_batch = torch.from_numpy(X_batch)
y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch['cell_type']))

Next, we invoke the model on the `X_batch` input data and extract the predictions:

In [10]:
model.eval()
model.to(device)
outputs = model(X_batch.to(device))

In [11]:
cell_type_encoder

In [12]:
X_batch.shape, y_batch.shape

(torch.Size([128, 60530]), torch.Size([128]))

In [13]:
outputs

tensor([[7.1457e-17, 1.0307e-03, 4.7891e-11,  ..., 1.6927e-12, 5.7567e-20,
         4.4607e-10],
        [0.0000e+00, 1.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [9.6133e-05, 1.3468e-02, 3.3928e-04,  ..., 1.2045e-03, 9.4754e-06,
         1.6436e-03],
        ...,
        [6.4198e-16, 9.9984e-01, 1.2217e-09,  ..., 7.0942e-11, 1.5370e-19,
         3.9809e-10],
        [1.0738e-07, 5.7099e-01, 1.3891e-04,  ..., 7.0405e-05, 3.6205e-09,
         3.8026e-04],
        [1.6748e-07, 9.9937e-01, 1.1633e-05,  ..., 5.0808e-05, 1.5206e-13,
         1.4158e-04]], device='cuda:0', grad_fn=<SigmoidBackward0>)

In [14]:
probabilities = torch.nn.functional.softmax(outputs, 1)
predictions = torch.argmax(probabilities, axis=1)
predictions

tensor([ 5,  1,  8,  7,  5, 11,  5,  1,  5,  7,  7,  7,  1, 11,  8,  1,  1,  8,
         1,  7,  8,  1,  1,  7,  7,  8,  7,  5,  1,  5,  1,  1,  7,  1,  8,  7,
         7,  7,  8,  5,  8,  1,  7,  1,  1,  1,  1,  8,  1,  5,  5,  1,  1,  5,
         1,  8,  1,  5,  8,  6,  7,  5,  9,  6,  5,  8,  7,  1,  5,  1,  1,  1,
         5,  7,  1,  1,  7,  7,  8,  8,  5,  1,  5,  9,  1,  1,  1,  1,  6,  1,
        11,  1,  5,  7,  2,  7,  7,  8,  1,  7,  7,  1,  1,  1,  1,  1,  7,  1,
         1,  8,  2,  1,  7,  1,  1, 11,  1,  8,  1,  1,  1,  7,  8,  5,  6,  1,
         7,  1], device='cuda:0')
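
The softmax and argmax steps can be reproduced on a small array. The NumPy sketch below uses made-up output values and a hand-written `softmax` helper (an assumption for illustration, standing in for `torch.nn.functional.softmax`). Because softmax preserves the ordering of values within each row, the predicted class is the same whether argmax is applied before or after it.

```python
import numpy as np


def softmax(x, axis=1):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


# Hypothetical model outputs for two cells and three cell types.
outputs = np.array([[0.10, 0.90, 0.30],
                    [0.70, 0.20, 0.05]])
probabilities = softmax(outputs, axis=1)
predictions = probabilities.argmax(axis=1)
print(predictions.tolist())  # [1, 0]
```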

The predictions are returned as the encoded values of the `cell_type` label. To recover the original cell type labels as strings, we decode using the same `LabelEncoder` used for training.

At inference time, if the model inputs are not obtained via an `ExperimentAxisQueryIterDataPipe`, one could pickle the encoder at training time and save it along with the model. Then, at inference time it can be unpickled and used as shown below.
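
A sketch of that pickling round trip, using a toy encoder rather than the one fitted earlier in this tutorial. It serializes to bytes in memory; in practice those bytes would be written to a file stored alongside the model.

```python
import pickle

from sklearn.preprocessing import LabelEncoder

# Toy encoder fitted on made-up labels.
encoder = LabelEncoder().fit(["basal cell", "keratinocyte", "leukocyte"])

blob = pickle.dumps(encoder)   # at training time: save these bytes with the model
restored = pickle.loads(blob)  # at inference time: load the encoder back

decoded = list(restored.inverse_transform([2, 0]))
print(decoded)  # ['leukocyte', 'basal cell']
```

As with any pickle, the file should only be loaded from a trusted source, and ideally with the same `scikit-learn` version that wrote it.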

In [15]:
predicted_cell_types = cell_type_encoder.inverse_transform(predictions.cpu())
predicted_cell_types

array(['epithelial cell', 'basal cell', 'leukocyte', 'keratinocyte',
       'epithelial cell', 'vein endothelial cell', 'epithelial cell',
       'basal cell', 'epithelial cell', 'keratinocyte', 'keratinocyte',
       'keratinocyte', 'basal cell', 'vein endothelial cell', 'leukocyte',
       'basal cell', 'basal cell', 'leukocyte', 'basal cell',
       'keratinocyte', 'leukocyte', 'basal cell', 'basal cell',
       'keratinocyte', 'keratinocyte', 'leukocyte', 'keratinocyte',
       'epithelial cell', 'basal cell', 'epithelial cell', 'basal cell',
       'basal cell', 'keratinocyte', 'basal cell', 'leukocyte',
       'keratinocyte', 'keratinocyte', 'keratinocyte', 'leukocyte',
       'epithelial cell', 'leukocyte', 'basal cell', 'keratinocyte',
       'basal cell', 'basal cell', 'basal cell', 'basal cell',
       'leukocyte', 'basal cell', 'epithelial cell', 'epithelial cell',
       'basal cell', 'basal cell', 'epithelial cell', 'basal cell',
       'leukocyte', 'basal cell', 'epitheli

Finally, we create a Pandas DataFrame to examine the predictions:

In [16]:
import pandas as pd

cmp_df = pd.DataFrame(
    {
        "actual cell type": cell_type_encoder.inverse_transform(y_batch.ravel().numpy()),
        "predicted cell type": predicted_cell_types,
    }
)
cmp_df

| | actual cell type | predicted cell type |
|---|---|---|
| 0 | epithelial cell | epithelial cell |
| 1 | basal cell | basal cell |
| 2 | leukocyte | leukocyte |
| 3 | keratinocyte | keratinocyte |
| 4 | epithelial cell | epithelial cell |
| ... | ... | ... |
| 123 | epithelial cell | epithelial cell |
| 124 | fibroblast | fibroblast |
| 125 | basal cell | basal cell |
| 126 | keratinocyte | keratinocyte |


In [None]:
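# Count matches directly; value_counts() orders by frequency and omits absent values
correct = cmp_df['actual cell type'] == cmp_df['predicted cell type']
right, wrong = int(correct.sum()), int((~correct).sum())
print('Accuracy: %.1f%% (%d correct, %d incorrect)' % (100 * right / len(cmp_df), right, wrong))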

pd.crosstab(cmp_df['actual cell type'], cmp_df['predicted cell type']).replace(0, '')