# Tutorial 2: Understanding the Dataset and Model

In this notebook, we'll explore:
1. The FashionMNIST dataset
2. Data partitioning for federated learning
3. The CNN model architecture
4. Data preprocessing and loading
5. Training and testing functions, with a quick sanity check

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize

print("All imports successful!")

All imports successful!


## 1. Exploring FashionMNIST Dataset

FashionMNIST is a dataset of fashion product images:
- **70,000 grayscale images** (60,000 training + 10,000 test)
- **28x28 pixels** per image
- **10 classes**: T-shirt/top, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot

It was designed as a drop-in replacement for MNIST (same image size, format, and train/test split sizes), but it is harder to classify.

Let's load and visualize some samples:

In [None]:
# Load FashionMNIST dataset
dataset = load_dataset("zalando-datasets/fashion_mnist", split="train")

print(f"Dataset size: {len(dataset)} images")
print(f"Dataset features: {dataset.features}")
print(f"\nFirst sample keys: {dataset[0].keys()}")


In [None]:
# Class names for FashionMNIST
class_names = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]

# Visualize some images
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
fig.suptitle('Sample Images from FashionMNIST', fontsize=16)

for idx, ax in enumerate(axes.flat):
    img = dataset[idx]['image']
    label = dataset[idx]['label']
    ax.imshow(img, cmap='gray')
    ax.set_title(f'{class_names[label]}')
    ax.axis('off')

plt.tight_layout()
plt.show()

## 2. Data Partitioning for Federated Learning

In federated learning, data is distributed across multiple clients. We'll explore two partitioning strategies:

### IID (Independent and Identically Distributed)
- Each client gets a random subset of data
- Class distribution is similar across clients
- **Easier scenario** for federated learning
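
Conceptually, IID partitioning amounts to shuffling the sample indices and cutting them into equally sized shards. The sketch below illustrates that idea in plain Python; `iid_partition` is a hypothetical helper written for illustration, not Flower's `IidPartitioner` implementation.

```python
import random


def iid_partition(num_samples: int, num_partitions: int, seed: int = 42) -> list[list[int]]:
    """Shuffle sample indices and split them into near-equal shards (illustrative sketch)."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # random order -> each shard mirrors the global mix
    base, remainder = divmod(num_samples, num_partitions)
    shards, start = [], 0
    for pid in range(num_partitions):
        # The first `remainder` shards receive one extra sample.
        size = base + (1 if pid < remainder else 0)
        shards.append(indices[start:start + size])
        start += size
    return shards


shards = iid_partition(num_samples=60_000, num_partitions=5)
print([len(s) for s in shards])  # [12000, 12000, 12000, 12000, 12000]
```

Because every shard is a uniform random sample of the same pool, each client's class proportions end up close to the global ones.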

### Non-IID (not Independent and Identically Distributed)
- Class distributions differ from client to client
- **More realistic** but more challenging scenario
- Example: one client holds mostly T-shirts, another mostly trousers
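
A common way to simulate this kind of label skew is to draw each client's class proportions from a Dirichlet distribution: a small concentration parameter `alpha` yields highly skewed clients, while a large one approaches the uniform (IID-like) case. The sketch below samples such proportions with Python's standard library, using the fact that normalizing independent Gamma(alpha, 1) draws gives a Dirichlet(alpha) sample. `dirichlet_proportions` is a hypothetical helper; the exact scheme used by Flower's `DirichletPartitioner` may differ in detail.

```python
import random


def dirichlet_proportions(num_classes: int, alpha: float, rng: random.Random) -> list[float]:
    """Sample class proportions from a symmetric Dirichlet(alpha) via normalized Gamma draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_classes)]
    total = sum(draws)
    return [d / total for d in draws]


rng = random.Random(0)
for alpha in (0.1, 100.0):
    props = dirichlet_proportions(num_classes=10, alpha=alpha, rng=rng)
    # Small alpha: a few classes tend to dominate. Large alpha: shares stay close to 0.1 each.
    print(f"alpha={alpha}: largest class share = {max(props):.2f}")
```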

For this tutorial, we'll use **IID partitioning** to keep things simple.

In [None]:
# Create federated dataset with IID partitioning
num_partitions = 5  # Simulate 5 clients
partitioner = IidPartitioner(num_partitions=num_partitions)

fds = FederatedDataset(
    dataset="zalando-datasets/fashion_mnist",
    partitioners={"train": partitioner},
)

print(f"Created federated dataset with {num_partitions} partitions")
print(f"Each partition will have approximately {len(dataset) // num_partitions} training images")

In [None]:
# Let's examine the class distribution for each partition
def get_class_distribution(partition):
    """Count the number of samples per class in a partition."""
    # Read only the label column so the images don't have to be decoded
    labels = list(partition["label"])
    return np.bincount(labels, minlength=10)

# Plot class distribution for each client
fig, axes = plt.subplots(1, num_partitions, figsize=(20, 4))
fig.suptitle('Class Distribution Across Clients (IID Partitioning)', fontsize=16)

for client_id in range(num_partitions):
    partition = fds.load_partition(client_id)
    distribution = get_class_distribution(partition)
    
    axes[client_id].bar(range(10), distribution, color='steelblue')
    axes[client_id].set_title(f'Client {client_id}')
    axes[client_id].set_xlabel('Class')
    axes[client_id].set_ylabel('Count')
    axes[client_id].set_xticks(range(10))
    axes[client_id].set_xticklabels(range(10))
    axes[client_id].set_ylim(0, 1500)

plt.tight_layout()
plt.show()

print("\nNotice how the distribution is roughly uniform across all clients.")
print("This is characteristic of IID partitioning.")
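
To put a number on "roughly uniform", one option is the total variation distance between a client's label distribution and the global one: 0 means identical class proportions, 1 means the two share no classes at all. A minimal sketch, where `label_tv_distance` is a hypothetical helper that takes plain lists of integer labels:

```python
from collections import Counter


def label_tv_distance(client_labels: list[int], global_labels: list[int]) -> float:
    """Total variation distance between two empirical label distributions."""
    client_counts, global_counts = Counter(client_labels), Counter(global_labels)
    n_client, n_global = len(client_labels), len(global_labels)
    classes = set(client_counts) | set(global_counts)
    # Counter returns 0 for missing classes, so absent labels count as zero share.
    return 0.5 * sum(
        abs(client_counts[c] / n_client - global_counts[c] / n_global) for c in classes
    )


print(label_tv_distance([0, 1, 2, 3], [0, 1, 2, 3] * 10))  # 0.0
print(label_tv_distance([0, 0, 1, 1], [2, 3]))             # 1.0
```

In the plotting loop you could compute, for example, `label_tv_distance(list(partition["label"]), list(dataset["label"]))` per client; with IID partitioning the values should be close to 0.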

## 3. The CNN Model Architecture

We'll use a simple Convolutional Neural Network (CNN) for image classification:

```
Input (28x28x1) - Grayscale image
    ↓
Conv2D (6 filters, 5x5) → ReLU → MaxPool (2x2)
    ↓
Conv2D (16 filters, 5x5) → ReLU → MaxPool (2x2)
    ↓
Flatten
    ↓
FC (120) → ReLU
    ↓
FC (84) → ReLU
    ↓
FC (10) → Output
```

This is a LeNet-5-style architecture (two convolution/pooling blocks followed by three fully connected layers), sized for 28x28 grayscale inputs.
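
The spatial sizes in the diagram, and the `16 * 4 * 4` input size of `fc1`, follow from the standard output-size formula for convolution and pooling layers with dilation 1: `out = (in + 2 * padding - kernel) // stride + 1`. A small pure-Python sketch tracing the shapes, where `out_size` is a helper introduced only for illustration:

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d/MaxPool2d layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28
size = out_size(size, kernel=5)            # conv1: 28 -> 24
size = out_size(size, kernel=2, stride=2)  # pool:  24 -> 12
size = out_size(size, kernel=5)            # conv2: 12 -> 8
size = out_size(size, kernel=2, stride=2)  # pool:   8 -> 4
print(size, 16 * size * size)  # 4 256
```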

In [None]:
class Net(nn.Module):
    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        # Convolutional layers (for 28x28 grayscale images)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        
        # Fully connected layers
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # First conv block: 28x28x1 -> 24x24x6 -> 12x12x6
        x = self.pool(F.relu(self.conv1(x)))
        # Second conv block: 12x12x6 -> 8x8x16 -> 4x4x16
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten all dimensions except the batch: 16 * 4 * 4 = 256
        x = torch.flatten(x, 1)
        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Create model and count parameters
model = Net()
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Model Architecture:")
print(model)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
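
The counts printed by the cell above follow from two formulas: a `Conv2d` layer (with `groups=1`) has `out_channels * in_channels * kernel_h * kernel_w` weights plus `out_channels` biases, and a `Linear` layer has `in_features * out_features` weights plus `out_features` biases. A sketch applying them to `conv1` and `fc3` (the helper names are hypothetical; the remaining layers are left for Exercise 2):

```python
def conv2d_params(in_channels: int, out_channels: int, kernel_size: int, bias: bool = True) -> int:
    """Learnable parameters in a square-kernel Conv2d layer (groups=1)."""
    return out_channels * in_channels * kernel_size * kernel_size + (out_channels if bias else 0)


def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Learnable parameters in a Linear layer."""
    return in_features * out_features + (out_features if bias else 0)


print(conv2d_params(1, 6, 5))   # conv1: 6*1*5*5 + 6 = 156
print(linear_params(84, 10))    # fc3:   84*10 + 10 = 850
```

You can cross-check against PyTorch with, e.g., `model.conv1.weight.numel() + model.conv1.bias.numel()`.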

## 4. Data Preprocessing and Loading

Before training, we need to:
1. Convert images to tensors
2. Normalize pixel values
3. Create data loaders for batching
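
`Normalize(mean, std)` computes `(x - mean) / std` per channel. Since `ToTensor()` produces pixel values in [0, 1], using `mean = std = 0.5` maps 0 to -1 and 1 to 1; the `denormalize` function used for plotting later applies the inverse, `x * 0.5 + 0.5`. A plain-Python sketch of that round trip (the function names are illustrative only):

```python
def normalize_pixel(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Single-channel, per-pixel equivalent of torchvision's Normalize."""
    return (x - mean) / std


def denormalize_pixel(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Inverse transform, used only to display normalized images."""
    return x * std + mean


for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize_pixel(pixel), "->", denormalize_pixel(normalize_pixel(pixel)))
# 0.0 -> -1.0 -> 0.0
# 0.5 -> 0.0 -> 0.5
# 1.0 -> 1.0 -> 1.0
```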

In [None]:
# Define transforms for grayscale images
pytorch_transforms = Compose([
    ToTensor(),  # Convert PIL image to a float tensor of shape (1, 28, 28) with values in [0, 1]
    Normalize((0.5,), (0.5,))  # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1] (single grayscale channel)
])

def apply_transforms(batch):
    """Apply transforms to a batch of images."""
    batch["image"] = [pytorch_transforms(img) for img in batch["image"]]
    return batch

# Load data for one client
client_id = 0
partition = fds.load_partition(client_id)

# Split into train/test
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
partition_train_test = partition_train_test.with_transform(apply_transforms)

# Create dataloaders
batch_size = 32
trainloader = DataLoader(partition_train_test["train"], batch_size=batch_size, shuffle=True)
testloader = DataLoader(partition_train_test["test"], batch_size=batch_size)

print(f"Client {client_id} data:")
print(f"  Training samples: {len(partition_train_test['train'])}")
print(f"  Test samples: {len(partition_train_test['test'])}")
print(f"  Training batches: {len(trainloader)}")
print(f"  Test batches: {len(testloader)}")
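
With 5 IID partitions of the 60,000 training images, each client holds 12,000 samples, so an 80/20 split with `batch_size = 32` should give about 9,600 training and 2,400 test samples. A `DataLoader` with the default `drop_last=False` yields `ceil(n / batch_size)` batches. A quick check of that arithmetic, assuming the test size is rounded up (how Hugging Face `train_test_split` rounds is an assumption here; `expected_counts` is a hypothetical helper):

```python
import math


def expected_counts(partition_size: int, test_fraction: float, batch_size: int) -> dict[str, int]:
    """Approximate split sizes and DataLoader batch counts (drop_last=False)."""
    n_test = math.ceil(partition_size * test_fraction)
    n_train = partition_size - n_test
    return {
        "train_samples": n_train,
        "test_samples": n_test,
        "train_batches": math.ceil(n_train / batch_size),
        "test_batches": math.ceil(n_test / batch_size),
    }


print(expected_counts(partition_size=60_000 // 5, test_fraction=0.2, batch_size=32))
# {'train_samples': 9600, 'test_samples': 2400, 'train_batches': 300, 'test_batches': 75}
```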

In [None]:
# Visualize a batch after preprocessing
batch = next(iter(trainloader))
images = batch['image'][:8]
labels = batch['label'][:8]

# Denormalize for visualization
def denormalize(img):
    img = img * 0.5 + 0.5  # Reverse normalization
    return torch.clamp(img, 0, 1)

fig, axes = plt.subplots(2, 4, figsize=(12, 6))
fig.suptitle('Preprocessed Images (from a batch)', fontsize=16)

for idx, ax in enumerate(axes.flat):
    img = denormalize(images[idx]).squeeze().numpy()
    ax.imshow(img, cmap='gray')
    ax.set_title(f'{class_names[labels[idx]]}')
    ax.axis('off')

plt.tight_layout()
plt.show()

## 5. Training and Testing Functions

Let's implement the training and testing logic:

In [None]:
def train(net, trainloader, epochs, lr, device):
    """Train the model on the training set."""
    net.to(device)
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    net.train()
    
    running_loss = 0.0
    for epoch in range(epochs):
        epoch_loss = 0.0
        for batch in trainloader:
            images = batch["image"].to(device)
            labels = batch["label"].to(device)
            
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
        
        running_loss += epoch_loss / len(trainloader)
        print(f"  Epoch {epoch+1}/{epochs}: Loss = {epoch_loss/len(trainloader):.4f}")
    
    avg_trainloss = running_loss / epochs
    return avg_trainloss

def test(net, testloader, device):
    """Validate the model on the test set."""
    net.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    net.eval()
    
    correct, loss = 0, 0.0
    with torch.no_grad():
        for batch in testloader:
            images = batch["image"].to(device)
            labels = batch["label"].to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    
    accuracy = correct / len(testloader.dataset)
    loss = loss / len(testloader)
    return loss, accuracy

print("Training and testing functions defined!")
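
`train` uses `torch.optim.SGD` with `momentum=0.9`. With the defaults (`dampening=0`, no Nesterov, no weight decay), PyTorch's documented update keeps a velocity buffer `v` and applies `v = momentum * v + grad`, then `param = param - lr * v`. The scalar sketch below reproduces that rule on the toy loss `f(w) = w**2` (gradient `2 * w`); it is an illustration of the update rule, not PyTorch's implementation.

```python
def sgd_momentum(w: float, lr: float, momentum: float, steps: int) -> list[float]:
    """Minimize f(w) = w**2 with SGD + momentum (dampening=0, no Nesterov)."""
    velocity = 0.0
    history = [w]
    for _ in range(steps):
        grad = 2.0 * w                          # df/dw for f(w) = w**2
        velocity = momentum * velocity + grad   # accumulate a decaying sum of past gradients
        w = w - lr * velocity                   # parameter update
        history.append(w)
    return history


trajectory = sgd_momentum(w=1.0, lr=0.01, momentum=0.9, steps=200)
print(f"w after 200 steps: {trajectory[-1]:.6f}")  # approaches the minimum at w = 0
```

Momentum lets consistent gradient directions build up speed, which usually converges faster than plain SGD at the same learning rate.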

## 6. Quick Training Test

Let's do a quick test to ensure everything works:

In [None]:
# Test training for 2 epochs
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}\n")

model = Net()
print("Training for 2 epochs...")
train_loss = train(model, trainloader, epochs=2, lr=0.01, device=device)

print(f"\nTesting model...")
test_loss, test_acc = test(model, testloader, device)

print(f"\nResults:")
print(f"  Average training loss: {train_loss:.4f}")
print(f"  Test loss: {test_loss:.4f}")
print(f"  Test accuracy: {test_acc*100:.2f}%")
print("\nThe pipeline runs end to end. Test accuracy well above 10% (random guessing) means the model is learning.")

## 7. Visualize Model Predictions

Let's see what the model predicts on some test images:

In [None]:
# Get a batch of test images
model.eval()
batch = next(iter(testloader))
images = batch['image'][:8].to(device)
labels = batch['label'][:8]

# Get predictions
with torch.no_grad():
    outputs = model(images)
    _, predicted = torch.max(outputs, 1)

# Visualize
images = images.cpu()
predicted = predicted.cpu()

fig, axes = plt.subplots(2, 4, figsize=(12, 6))
fig.suptitle('Model Predictions', fontsize=16)

for idx, ax in enumerate(axes.flat):
    img = denormalize(images[idx]).squeeze().numpy()
    ax.imshow(img, cmap='gray')
    
    true_label = class_names[labels[idx]]
    pred_label = class_names[predicted[idx]]
    color = 'green' if labels[idx] == predicted[idx] else 'red'
    
    ax.set_title(f'True: {true_label}\nPred: {pred_label}', color=color, fontsize=9)
    ax.axis('off')

plt.tight_layout()
plt.show()

print("Green titles = correct predictions")
print("Red titles = incorrect predictions")
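
The predicted label is the index of the largest logit (`torch.max(outputs, 1)` returns both the maximum values and their indices). The network outputs raw logits; applying softmax turns them into probabilities, which can show how confident each prediction is. A standard-library sketch with made-up logits (the values are illustrative only, not model output):

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
logits = [2.0, -1.0, 0.5, 0.0, 1.0, -2.0, 1.5, -0.5, 0.0, -1.5]  # made-up example values
probs = softmax(logits)
pred = max(range(len(probs)), key=probs.__getitem__)  # argmax over class indices
print(f"Predicted: {class_names[pred]} (p = {probs[pred]:.2f})")  # Predicted: T-shirt/top (p = 0.38)
```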

## Summary

In this notebook, we:
1. ✅ Explored the FashionMNIST dataset
2. ✅ Understood IID data partitioning
3. ✅ Built a CNN model architecture
4. ✅ Implemented training and testing functions
5. ✅ Verified everything works with a quick test
6. ✅ Visualized model predictions

**Next Steps**: In Notebook 3, we'll learn how to wrap this code into a Flower client for federated learning!

## Exercises for Students

**Exercise 1**: Modify the code to use Non-IID partitioning. (Hint: Look at `flwr_datasets.partitioner.DirichletPartitioner`)

**Exercise 2**: Count the number of parameters in each layer of the CNN. Which layer has the most parameters?

**Exercise 3**: What happens if you train for more epochs (e.g., 5)? Does accuracy improve?

**Exercise 4**: Modify the model to add one more convolutional layer. How does this affect the number of parameters?