# End-to-End Tutorial: Training a Neural Network with PyTorch and Xbatcher

This tutorial demonstrates how to use xarray, xbatcher, and PyTorch to train a simple neural network on the FashionMNIST dataset.

## Step 1: Setup

Import the necessary libraries and load the dataset from a public Zarr store on S3.

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import xarray as xr

import xbatcher as xb
import xbatcher.loaders.torch

In [None]:
ds = xr.open_dataset(
    's3://carbonplan-share/xbatcher/fashion-mnist-train.zarr',
    engine='zarr',
    chunks={},
    backend_kwargs={'storage_options': {'anon': True}},
)
ds

In [None]:
ds.sel(sample=1).images.plot(cmap='gray');

## Step 2: Create batch generator and data loader

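We use `xbatcher` to create batch generators for the images (`X_bgen`) and labels (`y_bgen`). Each batch spans 2,000 samples; for the images, it also covers the full channel, height, and width of every sample.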

In [None]:
# Define batch generators
X_bgen = xb.BatchGenerator(
    ds['images'],
    input_dims={'sample': 2000, 'channel': 1, 'height': 28, 'width': 28},
    preload_batch=False,
)
y_bgen = xb.BatchGenerator(
    ds['labels'], input_dims={'sample': 2000}, preload_batch=False
)
X_bgen[0]
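
To get a feel for how many batches `input_dims` produces, the windowing along the `sample` dimension can be sketched in plain Python. This is not xbatcher's implementation: `batch_slices` is a hypothetical helper, and the sketch assumes the store holds the standard 60,000-image FashionMNIST training split and that a trailing window that does not fit is dropped.

```python
def batch_slices(dim_size: int, window: int) -> list[slice]:
    """Non-overlapping windows of length `window` along one dimension.

    Assumption: a trailing window that does not fit completely is dropped.
    """
    return [
        slice(start, start + window)
        for start in range(0, dim_size - window + 1, window)
    ]


# 60,000 training images (assumed) split into windows of 2,000 samples
slices = batch_slices(60_000, 2_000)
print(len(slices))  # 30
print(slices[0])    # slice(0, 2000, None)
```

Under these assumptions, `len(X_bgen)` and `len(y_bgen)` would both be 30, which is why the two generators can be paired index by index in the next cell.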

In [None]:
# Map batches to a PyTorch-compatible dataset
dataset = xbatcher.loaders.torch.MapDataset(X_bgen, y_bgen)
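
In PyTorch terms, a map-style dataset is any object that implements `__len__` and `__getitem__`. The idea of pairing the i-th feature batch with the i-th label batch can be sketched as follows. `PairedBatches` is a hypothetical stand-in written for illustration, not the actual `MapDataset` code, and it uses plain lists where the real class works with xbatcher generators and returns tensors.

```python
class PairedBatches:
    """Hypothetical stand-in for a map-style dataset over two batch sources."""

    def __init__(self, x_batches, y_batches):
        if len(x_batches) != len(y_batches):
            raise ValueError('X and y must provide the same number of batches')
        self.x_batches = x_batches
        self.y_batches = y_batches

    def __len__(self):
        return len(self.x_batches)

    def __getitem__(self, idx):
        # Index i returns the i-th feature batch together with the i-th label batch
        return self.x_batches[idx], self.y_batches[idx]


paired = PairedBatches([[0.1, 0.2], [0.3, 0.4]], [[0, 1], [1, 0]])
print(len(paired))  # 2
print(paired[1])    # ([0.3, 0.4], [1, 0])
```

Because each item is already a full batch, the `DataLoader` that consumes such a dataset does not need to do any batching of its own.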

In [None]:
# Create a DataLoader
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=None,  # Using batches defined by the dataset itself (via xbatcher)
    prefetch_factor=3,  # Each worker prefetches up to 3 batches in advance to reduce data loading latency
    num_workers=4,  # Use 4 parallel worker processes to load data concurrently
    persistent_workers=True,  # Keep workers alive between epochs for faster subsequent epochs
    multiprocessing_context='forkserver',  # Start workers via "forkserver" instead of forking the (possibly multi-threaded) parent process
)
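
The worker-related options above only make sense for multi-process loading. To our knowledge, `DataLoader` rejects `prefetch_factor`, `persistent_workers=True`, and `multiprocessing_context` when `num_workers=0`, so a single-process run for debugging needs a reduced argument set. A minimal sketch, where `loader_kwargs` is a hypothetical helper and not part of PyTorch or xbatcher:

```python
def loader_kwargs(num_workers: int) -> dict:
    """Build DataLoader keyword arguments for a given worker count."""
    # batch_size=None: the dataset already yields complete batches
    kwargs = {'batch_size': None, 'num_workers': num_workers}
    if num_workers > 0:
        # These options only apply to multi-process loading
        kwargs.update(
            prefetch_factor=3,
            persistent_workers=True,
            multiprocessing_context='forkserver',
        )
    return kwargs


print(loader_kwargs(0))  # {'batch_size': None, 'num_workers': 0}
print(loader_kwargs(4)['prefetch_factor'])  # 3
```

The result can be passed as `torch.utils.data.DataLoader(dataset, **loader_kwargs(0))` to load data in the main process, which is often easier to debug.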

In [None]:
train_features, train_labels = next(iter(train_dataloader))

In [None]:
print(f'Feature batch shape: {train_features.size()}')
print(f'Labels batch shape: {train_labels.size()}')
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap='gray')
plt.show()
print(f'Label: {label}')

## Step 3: Define the Neural Network

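We define a simple feedforward neural network for classification: each 28 x 28 image is flattened, passed through one 128-unit hidden layer with ReLU, and mapped to 10 class scores (logits).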

In [None]:
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [None]:
# Instantiate the model
model = SimpleNN()
model
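
The size of this model can be checked by hand: a fully connected layer has one weight per input-output pair plus one bias per output. The helper `linear_params` below is hypothetical and only does that arithmetic for the two layers defined above.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


fc1_params = linear_params(28 * 28, 128)  # 100480
fc2_params = linear_params(128, 10)       # 1290
total_params = fc1_params + fc2_params
print(total_params)  # 101770
```

The same number should come out of `sum(p.numel() for p in model.parameters())` on the instantiated model.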

## Step 4: Define Loss Function and Optimizer
We use Cross-Entropy Loss and the Adam optimizer.

In [None]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
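
`nn.CrossEntropyLoss` takes the raw logits from the model together with integer class labels, and by default averages the per-sample loss over the batch. For a single sample the loss is the negative log of the softmax probability assigned to the true class. The plain-Python sketch below (with a hypothetical `cross_entropy` function) works through that formula without PyTorch.

```python
import math


def cross_entropy(logits: list[float], target: int) -> float:
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# Ten equal logits mean every class gets probability 1/10
uniform_loss = cross_entropy([0.0] * 10, target=3)
print(round(uniform_loss, 4))  # 2.3026
```

An untrained model that spreads probability roughly evenly over the 10 classes should therefore start with a loss of about ln(10) ≈ 2.30, which is a useful sanity check for the first batches printed during training.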

## Step 5: Train the Model
We train the model using the data loader.

In [None]:
%%time

epochs = 5

for epoch in range(epochs):
    print(f'Epoch {epoch+1}/{epochs}')
    for batch, (X, y) in enumerate(train_dataloader):
        # Forward pass
        predictions = model(X)
        loss = loss_fn(predictions, y)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 10 == 0:
            print(f'Batch {batch}: Loss = {loss.item():.4f}')

print('Training completed!')

## Step 6: Evaluate the Model
A proper evaluation would use a held-out test set. Here we simply visualize the model's prediction for one sample from the training batch loaded in Step 2.

In [None]:
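# Visualize a sample prediction (taken from the training batch loaded in Step 2)
img = train_features[0].squeeze()
label = train_labels[0].item()  # Convert the single-element tensor to a plain Python number
with torch.no_grad():  # Gradients are not needed for inference
    predicted_label = torch.argmax(model(train_features[0:1]), dim=1).item()

plt.imshow(img, cmap='gray')
plt.title(f'True Label: {label}, Predicted: {predicted_label}')
plt.show()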
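
A single image says little about overall quality. A common summary is accuracy, the fraction of samples whose highest-scoring class matches the label. The sketch below uses a hypothetical helper, `batch_accuracy`, and toy tensors so it runs without the dataset. It is one possible way to score a batch, not a prescribed evaluation procedure.

```python
import torch


def batch_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    predicted = logits.argmax(dim=1)
    return (predicted == labels).float().mean().item()


# Toy logits for 4 samples and 3 classes; the last prediction is wrong
toy_logits = torch.tensor(
    [[2.0, 0.1, 0.0], [0.0, 3.0, 0.1], [0.2, 0.1, 1.5], [1.0, 0.9, 0.0]]
)
toy_labels = torch.tensor([0, 1, 2, 1])
print(batch_accuracy(toy_logits, toy_labels))  # 0.75
```

On the trained model, the same helper could be called as `batch_accuracy(model(X), y)` inside a `torch.no_grad()` block for each batch of a test data loader. This tutorial only loads the training store, so the test data would have to be supplied separately.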

## Key Highlights

- **Data Handling**: We use Xbatcher to create efficient, chunked data pipelines from Xarray datasets.
- **Integration**: The `xbatcher.loaders.torch.MapDataset` class enables direct compatibility with PyTorch's `DataLoader`.
- **Training**: PyTorch simplifies the model training loop while leveraging the custom data pipeline.
