# Multi-process training

Multi-process usage of `tiledbsoma_ml.ExperimentDataset` includes both:
* using [`torch.utils.data.DataLoader`] with one or more workers (i.e., with `num_workers=1` or greater)
* using a multi-process training configuration, such as [`DistributedDataParallel`]
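Both configurations run several copies of the dataset at the same time, one per (process, worker) pair, and each copy has to read a different share of the rows. The sketch below shows one way such a split can be computed. It is a hypothetical illustration, not the `tiledbsoma_ml` implementation, and the helper names `flat_consumer` and `partition_rows` are invented for this example. In real code the rank and world size would typically come from `torch.distributed.get_rank()` and `torch.distributed.get_world_size()`, and the worker id and count from `torch.utils.data.get_worker_info()`.

```python
def flat_consumer(rank: int, world_size: int, worker_id: int, num_workers: int) -> tuple[int, int]:
    """Hypothetical helper: flatten a (DDP rank, DataLoader worker) pair into one index.

    Returns (consumer_index, total_consumers).
    """
    return rank * num_workers + worker_id, world_size * num_workers


def partition_rows(n_rows: int, consumer: int, n_consumers: int) -> range:
    """Hypothetical helper: contiguous, near-equal split of row indices 0..n_rows-1."""
    base, extra = divmod(n_rows, n_consumers)
    # The first `extra` consumers each receive one additional row.
    start = consumer * base + min(consumer, extra)
    stop = start + base + (1 if consumer < extra else 0)
    return range(start, stop)


if __name__ == "__main__":
    # Two DDP ranks, each with a two-worker DataLoader, sharing 10 rows.
    for rank in range(2):
        for worker_id in range(2):
            consumer, total = flat_consumer(rank, 2, worker_id, 2)
            print(rank, worker_id, list(partition_rows(10, consumer, total)))
```

With 2 ranks and 2 workers there are 4 consumers, which receive rows `[0, 1, 2]`, `[3, 4, 5]`, `[6, 7]` and `[8, 9]`: every row is read exactly once, and no copy of the dataset needs to talk to the others.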

In these configurations, `ExperimentDataset` automatically partitions data across workers. However, when using `shuffle=True`, keep the following in mind:

1. All worker processes must share the same random number generator `seed`, so that every worker shuffles and partitions the data in the same way.
2. To get a _different_ shuffle on each epoch, the caller must set the epoch with the `set_epoch` API. This mirrors the behavior of [`torch.utils.data.distributed.DistributedSampler`].
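To see why the seed and the epoch both matter, consider a minimal sketch in which every consumer derives the same permutation from `(seed, epoch)` and then keeps its own strided slice of it. This is an assumed scheme for illustration only: the seed-mixing formula and the helper names `epoch_permutation` and `consumer_rows` are hypothetical, not how `tiledbsoma_ml` shuffles internally.

```python
import random


def epoch_permutation(n_rows: int, seed: int, epoch: int) -> list[int]:
    """Shuffled row order that depends only on (seed, epoch)."""
    # Hypothetical seed-mixing scheme: any deterministic function of (seed, epoch) works.
    rng = random.Random(seed * 1_000_003 + epoch)
    order = list(range(n_rows))
    rng.shuffle(order)
    return order


def consumer_rows(n_rows: int, seed: int, epoch: int, consumer: int, n_consumers: int) -> list[int]:
    """Rows assigned to one consumer: a strided slice of the shared permutation."""
    return epoch_permutation(n_rows, seed, epoch)[consumer::n_consumers]


if __name__ == "__main__":
    # Same seed everywhere: the four consumers together cover every row exactly once.
    shards = [consumer_rows(12, seed=42, epoch=0, consumer=c, n_consumers=4) for c in range(4)]
    print(sorted(r for shard in shards for r in shard) == list(range(12)))
    # Same seed, different epoch: a different order for the next pass.
    print(epoch_permutation(12, 42, 0) != epoch_permutation(12, 42, 1))
```

If each process instead drew its own seed, the slices would come from different permutations and could overlap or skip rows; and if the epoch never changed, every pass would repeat the same order.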

[`torch.utils.data.DataLoader`]: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
[`torch.utils.data.distributed.DistributedSampler`]: https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
[`DistributedDataParallel`]: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html

```python
import tiledbsoma as soma
import torch
from sklearn.preprocessing import LabelEncoder

import tiledbsoma_ml as soma_ml

CZI_Census_Homo_Sapiens_URL = "s3://cellxgene-census-public-us-west-2/cell-census/2024-07-01/soma/census_data/homo_sapiens/"

# Open the public CELLxGENE Census experiment with unsigned (anonymous) S3 requests.
experiment = soma.open(
    CZI_Census_Homo_Sapiens_URL,
    context=soma.SOMATileDBContext(tiledb_config={"vfs.s3.region": "us-west-2", "vfs.s3.no_sign_request": "true"}),
)
# Select primary-data cells from tongue tissue.
obs_value_filter = "tissue_general == 'tongue' and is_primary_data == True"

with experiment.axis_query(
    measurement_name="RNA", obs_query=soma.AxisQuery(value_filter=obs_value_filter)
) as query:
    # Fit a label encoder over every cell_type present in the query results.
    obs_df = query.obs(column_names=["cell_type"]).concat().to_pandas()
    cell_type_encoder = LabelEncoder().fit(obs_df["cell_type"].unique())

    # Shuffled dataset yielding mini-batches of up to 128 cells from the "raw" X layer,
    # together with the matching cell_type labels.
    experiment_dataset = soma_ml.ExperimentDataset(
        query,
        layer_name="raw",
        obs_column_names=["cell_type"],
        batch_size=128,
        shuffle=True,
    )
```


```python
class LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        # CrossEntropyLoss (used below) applies log-softmax internally and expects raw
        # logits; the sigmoid here squashes those logits into (0, 1) first.
        outputs = torch.sigmoid(self.linear(x))
        return outputs


def train_epoch(model, train_dataloader, loss_fn, optimizer, device):
    model.train()
    train_loss = 0
    train_correct = 0
    train_total = 0

    # Each batch is an (X, obs) pair: a NumPy array of expression values and a
    # pandas DataFrame holding the requested obs columns.
    for X_batch, y_batch in train_dataloader:
        optimizer.zero_grad()

        X_batch = torch.from_numpy(X_batch).float().to(device)

        # Perform prediction
        outputs = model(X_batch)

        # Determine the predicted label
        probabilities = torch.nn.functional.softmax(outputs, 1)
        predictions = torch.argmax(probabilities, dim=1)

        # Encode the string labels as integers (uses the module-level cell_type_encoder)
        y_batch = torch.from_numpy(cell_type_encoder.transform(y_batch["cell_type"])).to(device)
        train_correct += (predictions == y_batch).sum().item()
        train_total += len(predictions)

        # Compute the loss and perform back propagation
        loss = loss_fn(outputs, y_batch.long())
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

    # loss.item() is a per-batch mean, so the reported value is roughly the mean
    # loss divided by the batch size.
    train_loss /= train_total
    train_accuracy = train_correct / train_total
    return train_loss, train_accuracy
```

## Multi-worker DataLoader

If you use a multi-worker data loader (i.e., `num_workers` greater than `0`) with `shuffle=True`, call `set_epoch` at the start of each epoch, _before_ the iterator is created.

The same applies to multi-process training, e.g., when using DDP.

*Tip*: when running with `num_workers=0`, i.e., using the data loader in-process, `ExperimentDataset` automatically increments its epoch count each time an iteration completes.
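The epoch contract described above (an explicit `set_epoch` before each pass, plus an automatic increment when an in-process pass finishes) can be mimicked with a small stand-in class. `ToyShuffledDataset` is hypothetical and only illustrates that contract; it is not the `tiledbsoma_ml` source, and its seed-mixing formula is an assumption.

```python
import random
from collections.abc import Iterator


class ToyShuffledDataset:
    """Hypothetical stand-in for a shuffled, epoch-aware iterable dataset."""

    def __init__(self, n_rows: int, seed: int) -> None:
        self.n_rows = n_rows
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Call before starting a pass; the epoch is read when iteration begins.
        self.epoch = epoch

    def __iter__(self) -> Iterator[int]:
        rng = random.Random(self.seed * 1_000_003 + self.epoch)
        order = list(range(self.n_rows))
        rng.shuffle(order)
        yield from order
        # Only reached once the pass has been fully consumed: advance the epoch so the
        # next in-process pass is reshuffled even if set_epoch is never called.
        self.epoch += 1


if __name__ == "__main__":
    ds = ToyShuffledDataset(n_rows=8, seed=0)
    first = list(ds)   # epoch 0, then auto-increments to 1
    second = list(ds)  # epoch 1, a different order
    ds.set_epoch(0)
    replay = list(ds)  # explicitly back to epoch 0: same order as `first`
    print(ds.epoch, replay == first)  # prints "1 True"
```

Calling `set_epoch(epoch)` at the top of every epoch, as the training loop below does, makes the shuffle depend only on the epoch number rather than on how many passes a particular copy of the dataset has already completed.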

```python
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The size of the input dimension is the number of genes
input_dim = experiment_dataset.shape[1]

# The size of the output dimension is the number of distinct cell_type values
output_dim = len(cell_type_encoder.classes_)

model = LogisticRegression(input_dim, output_dim).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-05)

# Define a two-worker data loader. The dataset is shuffled, so call `set_epoch`
# at the start of each epoch to apply a different shuffle every epoch.
experiment_dataloader = soma_ml.experiment_dataloader(
    experiment_dataset, num_workers=2, persistent_workers=True
)
```

switching torch multiprocessing start method from "fork" to "spawn"


```python
%%time
for epoch in range(20):
    experiment_dataset.set_epoch(epoch)
    train_loss, train_accuracy = train_epoch(
        model, experiment_dataloader, loss_fn, optimizer, device
    )
    print(
        f"Epoch {epoch + 1}: Train Loss: {train_loss:.7f} Accuracy {train_accuracy:.4f}"
    )
```

Epoch 1: Train Loss: 0.0163469 Accuracy 0.3040
Epoch 2: Train Loss: 0.0147538 Accuracy 0.4680
Epoch 3: Train Loss: 0.0143231 Accuracy 0.5575
Epoch 4: Train Loss: 0.0140254 Accuracy 0.6792
Epoch 5: Train Loss: 0.0138131 Accuracy 0.7943
Epoch 6: Train Loss: 0.0136806 Accuracy 0.8565
Epoch 7: Train Loss: 0.0135832 Accuracy 0.8925
Epoch 8: Train Loss: 0.0134801 Accuracy 0.9130
Epoch 9: Train Loss: 0.0134058 Accuracy 0.9217
Epoch 10: Train Loss: 0.0133559 Accuracy 0.9272
Epoch 11: Train Loss: 0.0133088 Accuracy 0.9336
Epoch 12: Train Loss: 0.0132692 Accuracy 0.9406
Epoch 13: Train Loss: 0.0132350 Accuracy 0.9451
Epoch 14: Train Loss: 0.0132063 Accuracy 0.9495
Epoch 15: Train Loss: 0.0131819 Accuracy 0.9535
Epoch 16: Train Loss: 0.0131599 Accuracy 0.9556
Epoch 17: Train Loss: 0.0131419 Accuracy 0.9585
Epoch 18: Train Loss: 0.0131233 Accuracy 0.9605
Epoch 19: Train Loss: 0.0131053 Accuracy 0.9625
Epoch 20: Train Loss: 0.0130904 Accuracy 0.9638
CPU times: user 1min 13s, sys: 1min 11s, total: 2