# Multi-process training

Multi-process usage of `tiledbsoma_ml.ExperimentDataset` includes both:
* using the [`torch.utils.data.DataLoader`] with 1 or more workers (i.e., with an argument of `num_workers=1` or greater)
* using a multi-process training configuration, such as [`DistributedDataParallel`]

In these configurations, `ExperimentDataset` will automatically partition data across workers. However, when using `shuffle=True`, keep two things in mind:

1. All worker processes must share the same random number generator `seed`, ensuring that all workers shuffle and partition the data in the same way.
2. To ensure that each epoch returns a _different_ shuffle, the caller must set the epoch, using the `set_epoch` API. This is identical to the behavior of [`torch.utils.data.distributed.DistributedSampler`].
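
The two requirements above can be illustrated with a small standard-library sketch. It is a conceptual model only, not the `ExperimentDataset` implementation: the helper `worker_indices`, the way the seed and epoch are combined, and the strided partitioning are all assumptions made for illustration.

```python
import random


def worker_indices(n_rows, seed, epoch, worker_id, n_workers):
    """Hypothetical helper: row indices one worker would read in one epoch."""
    # Every worker derives the same permutation, because the RNG state depends
    # only on the shared seed and the epoch number (an assumed scheme).
    rng = random.Random(f"{seed}-{epoch}")
    permutation = list(range(n_rows))
    rng.shuffle(permutation)
    # Each worker takes a strided slice of that permutation, so slices are disjoint.
    return permutation[worker_id::n_workers]


n_rows, n_workers, seed = 10, 2, 42
epoch0 = [worker_indices(n_rows, seed, 0, w, n_workers) for w in range(n_workers)]
epoch1 = [worker_indices(n_rows, seed, 1, w, n_workers) for w in range(n_workers)]

# With a shared seed, the workers jointly cover every row exactly once.
covered = sorted(epoch0[0] + epoch0[1])
print(covered == list(range(n_rows)))  # True

# Changing the epoch changes the permutation that the workers partition.
print(epoch0 != epoch1)

# With mismatched seeds, each worker slices a different permutation,
# so rows may be duplicated or dropped.
mismatched = worker_indices(n_rows, 1, 0, 0, n_workers) + worker_indices(
    n_rows, 2, 0, 1, n_workers
)
print(len(set(mismatched)), "distinct rows out of", n_rows)
```

In this model, a worker that used a different seed would slice a different permutation, which is why the seed has to be shared by all processes.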

[DataLoader]: https://pytorch.org/docs/stable/data.html
[`torch.utils.data.DataLoader`]: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
[`torch.utils.data.distributed.DistributedSampler`]: https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
[`DistributedDataParallel`]: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html

[Papermill] parameters:

[Papermill]: https://papermill.readthedocs.io/

In [1]:
import os

tissue = "tongue"
n_epochs = 20
census_version = "2024-07-01"
batch_size = 128
learning_rate = 1e-5
num_workers = 2

In [2]:
from tiledbsoma import AxisQuery, Experiment, SOMATileDBContext
from sklearn.preprocessing import LabelEncoder

from tiledbsoma_ml import ExperimentDataset

CZI_Census_Homo_Sapiens_URL = f"s3://cellxgene-census-public-us-west-2/cell-census/{census_version}/soma/census_data/homo_sapiens/"

experiment = Experiment.open(
    CZI_Census_Homo_Sapiens_URL,
    context=SOMATileDBContext(tiledb_config={"vfs.s3.region": "us-west-2", "vfs.s3.no_sign_request": "true"}),
)
obs_value_filter = f"tissue_general == '{tissue}' and is_primary_data == True"

with experiment.axis_query(
    measurement_name="RNA", obs_query=AxisQuery(value_filter=obs_value_filter)
) as query:
    obs_df = query.obs(column_names=["cell_type"]).concat().to_pandas()
    cell_type_encoder = LabelEncoder().fit(obs_df["cell_type"].unique())

experiment_dataset = ExperimentDataset(
    query,
    layer_name="raw",
    obs_column_names=["cell_type"],
    batch_size=batch_size,
    shuffle=True,
)

In [3]:
import torch

class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        outputs = torch.sigmoid(self.linear(x))
        return outputs
    

def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
    model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0

    for X_batch, obs_batch in train_dataloader:
        optimizer.zero_grad()

        X_batch = torch.from_numpy(X_batch).float().to(device)

        # Perform prediction
        outputs = model(X_batch)

        # Determine the predicted label
        probabilities = torch.nn.functional.softmax(outputs, 1)
        predictions = torch.argmax(probabilities, axis=1)

        # Compute the loss and perform back propagation
        obs_batch = torch.from_numpy(cell_type_encoder.transform(obs_batch['cell_type'])).to(device)
        train_correct += (predictions == obs_batch).sum().item()
        train_total += len(predictions)

        loss = loss_fn(outputs, obs_batch.long())
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

    train_loss /= train_total
    train_accuracy = train_correct / train_total
    return train_loss, train_accuracy

## Multi-worker DataLoader

If you use a multi-worker data loader (i.e., `num_workers` set to a value other than `0`) with `shuffle=True`, remember to call `set_epoch` at the start of each epoch, _before_ the iterator is created.

Take the same approach for parallel training, e.g., when using `DistributedDataParallel` (DDP) or `DataParallel` (DP).

*Tip*: when running with `num_workers=0`, i.e., using the data loader in-process, the `ExperimentDataset` will automatically increment the epoch count each time the iterator completes.
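
The toy class below is a hedged sketch of why the call order matters. `ToyShuffledDataset` is a hypothetical stand-in, not the real `ExperimentDataset`; it assumes only that the shuffle order is fixed when the iterator is created, as the text above describes.

```python
import random


class ToyShuffledDataset:
    """Hypothetical stand-in for a shuffled iterable dataset (not the real class)."""

    def __init__(self, n_rows, seed):
        self.n_rows = n_rows
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        # The shuffle order is decided here, when the iterator is created,
        # from the seed and whatever epoch is set at that moment.
        rng = random.Random(f"{self.seed}-{self.epoch}")
        order = list(range(self.n_rows))
        rng.shuffle(order)
        return iter(order)


dataset = ToyShuffledDataset(n_rows=8, seed=0)

dataset.set_epoch(0)
first = list(dataset)

# Forgetting to call set_epoch: the next pass repeats the same order.
repeat = list(dataset)
print(first == repeat)  # True

# Calling set_epoch after the iterator exists is too late for that pass.
it = iter(dataset)
dataset.set_epoch(1)
late = list(it)
print(late == first)  # True
```

Both comparisons print `True`: the epoch only influences the order of passes whose iterator is created after `set_epoch` is called.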

In [4]:
from tiledbsoma_ml import experiment_dataloader

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The size of the input dimension is the number of genes
input_dim = experiment_dataset.shape[1]

# The size of the output dimension is the number of distinct cell_type values
output_dim = len(cell_type_encoder.classes_)

model = LogisticRegression(input_dim, output_dim).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Define a two-worker data loader. The dataset is shuffled, so call `set_epoch` to ensure
# that a different shuffle is applied on each epoch.
dataloader = experiment_dataloader(
    experiment_dataset, num_workers=num_workers, persistent_workers=True
)

switching torch multiprocessing start method from "fork" to "spawn"


In [5]:
%%time
for epoch in range(n_epochs):
    experiment_dataset.set_epoch(epoch)
    train_loss, train_accuracy = train_epoch(
        model, dataloader, loss_fn, optimizer, device
    )
    print(
        f"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}"
    )

Epoch 1: Train Loss: 0.0165012 Accuracy 0.3866
Epoch 2: Train Loss: 0.0148111 Accuracy 0.4217
Epoch 3: Train Loss: 0.0144168 Accuracy 0.6109
Epoch 4: Train Loss: 0.0141248 Accuracy 0.8374
Epoch 5: Train Loss: 0.0138151 Accuracy 0.9001
Epoch 6: Train Loss: 0.0136300 Accuracy 0.9123
Epoch 7: Train Loss: 0.0135218 Accuracy 0.9234
Epoch 8: Train Loss: 0.0134472 Accuracy 0.9324
Epoch 9: Train Loss: 0.0133907 Accuracy 0.9375
Epoch 10: Train Loss: 0.0133443 Accuracy 0.9419
Epoch 11: Train Loss: 0.0132998 Accuracy 0.9456
Epoch 12: Train Loss: 0.0132594 Accuracy 0.9489
Epoch 13: Train Loss: 0.0132298 Accuracy 0.9524
Epoch 14: Train Loss: 0.0132037 Accuracy 0.9549
Epoch 15: Train Loss: 0.0131809 Accuracy 0.9568
Epoch 16: Train Loss: 0.0131603 Accuracy 0.9585
Epoch 17: Train Loss: 0.0131425 Accuracy 0.9601
Epoch 18: Train Loss: 0.0131270 Accuracy 0.9613
Epoch 19: Train Loss: 0.0131112 Accuracy 0.9630
Epoch 20: Train Loss: 0.0130966 Accuracy 0.9639
CPU times: user 1min 6s, sys: 1min 58s, total: 3m