# Tune Model on MNIST dataset using PyTorchJob and FSDP

This Notebook will tune a small model on the MNIST dataset using FSDP.

This Notebook will use **4** GPUs to train the model on 2 nodes. This example is based on [the official PyTorch FSDP tutorial](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html).

## FSDP with multi-node multi-worker training

This Notebook demonstrates multi-node, multi-worker distributed training with Fully Sharded Data Parallel (FSDP) and PyTorchJob.

When a model is trained with FSDP, the GPU memory footprint is smaller compared to Distributed Data Parallel (DDP),
as the model parameters are sharded across GPU devices.

This enables training of very large models that would otherwise be impossible to fit on a single GPU device.

Check this guide to learn more about PyTorch FSDP: https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html


## Install the required packages

Install the Kubeflow Training Python SDK.

In [None]:
# Install the Kubeflow Training Python SDK from the training-operator GitHub
# repository (the `sdk/python` subdirectory). The leading `!` is IPython syntax
# that runs the line as a shell command from the notebook.
# TODO (andreyvelich): Use the release version of SDK.
!pip install git+https://github.com/kubeflow/training-operator.git#subdirectory=sdk/python

## Create a script to train on MNIST using FSDP

We need to wrap our training script in a function to create a Kubeflow PyTorchJob.

In [12]:
def train_function(parameters):
    """Train a small CNN on MNIST with FSDP in one PyTorchJob process.

    Every import lives inside the function so the function is self-contained
    when it is handed to ``TrainingClient().create_job(train_func=...)``.

    Args:
        parameters: Hyperparameter dict. Keys read here: ``"seed"``,
            ``"batch-size"``, ``"test-batch-size"``, ``"lr"``, ``"gamma"``,
            ``"epochs"`` and ``"save-model"``.

    The process group is created with ``dist.init_process_group("nccl")`` and
    ``LOCAL_RANK`` is read from the environment (assumed to be provided by the
    PyTorchJob launcher). Each process uses the GPU at index ``LOCAL_RANK``,
    and the process group is destroyed before returning.
    """
    import os
    import functools

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    import torch.distributed as dist
    from torch.optim.lr_scheduler import StepLR
    from torch.utils.data.distributed import DistributedSampler
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
    from torchvision import datasets, transforms

    class Net(nn.Module):
        """Two conv layers followed by two linear layers, log-softmax over 10 classes."""

        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x):
            x = self.conv1(x)
            x = F.relu(x)
            x = self.conv2(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.dropout2(x)
            x = self.fc2(x)
            output = F.log_softmax(x, dim=1)
            return output

    def train(model, device, rank, train_loader, optimizer, epoch, sampler=None):
        """Run one training epoch and print the job-wide mean loss on global rank 0.

        ``device`` is the local CUDA device index used for tensors. ``rank`` is
        the global rank, so the summary is printed once per job, not once per pod.
        """
        model.train()
        # [sum of per-sample losses, number of samples], reduced across ranks below.
        ddp_loss = torch.zeros(2).to(device)
        if sampler:
            # Give the DistributedSampler a different shuffle each epoch.
            sampler.set_epoch(epoch)
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target, reduction='sum')
            loss.backward()
            optimizer.step()
            ddp_loss[0] += loss.item()
            ddp_loss[1] += len(data)

        dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
        if rank == 0:
            print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))

    def test(model, device, rank, test_loader):
        """Evaluate on this rank's test shard; print job-wide loss/accuracy on global rank 0."""
        model.eval()
        # [sum of losses, number of correct predictions, number of samples].
        ddp_loss = torch.zeros(3).to(device)
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
                ddp_loss[2] += len(data)

        dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)

        if rank == 0:
            test_loss = ddp_loss[0] / ddp_loss[2]
            print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
                test_loss, int(ddp_loss[1]), int(ddp_loss[2]),
                100. * ddp_loss[1] / ddp_loss[2]))

    # [1] Setup PyTorch distributed and get the distributed parameters.
    torch.manual_seed(parameters["seed"])
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Local rank identifies the GPU number inside the pod.
    torch.cuda.set_device(local_rank)

    print(
        f"FSDP Training for WORLD_SIZE: {world_size}, RANK: {rank}, LOCAL_RANK: {local_rank}"
    )

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # [2] Several processes in a pod share the same "../data" directory, so only
    # local rank 0 downloads MNIST. The other ranks wait at the barrier instead
    # of writing the same files concurrently.
    # NOTE(review): if "../data" is a volume shared across pods, a download on
    # global rank 0 only would be needed instead.
    if local_rank == 0:
        datasets.MNIST('../data', train=True, download=True)
    dist.barrier()

    dataset1 = datasets.MNIST('../data', train=True, download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                              transform=transform)

    # [3] Each rank reads its own shard. Shuffling is done by the sampler, so
    # the DataLoader itself must not shuffle.
    sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)

    loader_kwargs = {'num_workers': 2, 'pin_memory': True, 'shuffle': False}
    train_loader = torch.utils.data.DataLoader(
        dataset1, batch_size=parameters["batch-size"], sampler=sampler1, **loader_kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        dataset2, batch_size=parameters["test-batch-size"], sampler=sampler2, **loader_kwargs
    )

    # [4] Wrap the model with FSDP. The size-based policy wraps submodules that
    # hold at least `min_num_params` parameters as separate FSDP units, instead
    # of wrapping the whole model as a single unit.
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=100
    )
    model = Net().to(local_rank)
    model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)

    optimizer = optim.Adadelta(model.parameters(), lr=parameters["lr"])
    scheduler = StepLR(optimizer, step_size=1, gamma=parameters["gamma"])

    # [5] Train and evaluate. CUDA events time the whole epoch loop.
    train_start_event = torch.cuda.Event(enable_timing=True)
    train_end_event = torch.cuda.Event(enable_timing=True)
    train_start_event.record()
    for epoch in range(1, parameters["epochs"] + 1):
        train(model, local_rank, rank, train_loader, optimizer, epoch, sampler=sampler1)
        test(model, local_rank, rank, test_loader)
        scheduler.step()
    train_end_event.record()

    if rank == 0:
        # The end event must complete before its elapsed time can be read.
        train_end_event.synchronize()
        print(f"CUDA event elapsed time: {train_start_event.elapsed_time(train_end_event) / 1000}sec")
        print(f"{model}")

    if parameters["save-model"]:
        # use a barrier to make sure training is done on all ranks
        dist.barrier()
        # Every rank calls state_dict() (FSDP gathers the sharded parameters
        # collectively); only global rank 0 writes the file.
        states = model.state_dict()
        if rank == 0:
            torch.save(states, "mnist_cnn.pt")

    # [6] Tear down the process group created at the start of this function.
    dist.destroy_process_group()


## Create Kubeflow PyTorchJob to train on MNIST with FSDP

Use `TrainingClient()` to create a PyTorchJob that trains on **2 workers** using **2 GPUs** per worker.

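If you don't have enough GPU resources, you can decrease the number of workers or the number of GPUs per worker. If you change the number of GPUs per worker, set `num_procs_per_worker` to the same value, because each training process uses one GPU.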

In [19]:
from kubeflow.training import TrainingClient

# Name of the PyTorchJob resource; every later cell refers to the job by it.
job_name = "mnist-training"

# Hyperparameters handed to train_function() through create_job(parameters=...).
# The keys contain hyphens, so the dict is assembled from explicit pairs.
parameters = dict(
    [
        ("batch-size", 64),         # DataLoader batch size for training, per process
        ("test-batch-size", 1000),  # DataLoader batch size for evaluation, per process
        ("epochs", 10),             # number of train/test epochs
        ("lr", 1.0),                # Adadelta learning rate
        ("gamma", 0.7),             # StepLR decay factor, applied after every epoch
        ("seed", 1),                # passed to torch.manual_seed
        ("save-model", False),      # when True, global rank 0 saves mnist_cnn.pt
    ]
)


In [20]:
# Create the PyTorchJob.
# Each training process pins itself to the GPU at index LOCAL_RANK
# (torch.cuda.set_device in train_function), so the number of processes per
# worker should match the number of GPUs per worker. Both values come from a
# single setting so they cannot drift apart.
# You can modify the number of workers or the number of GPUs per worker.
num_workers = 2
gpus_per_worker = 2
TrainingClient().create_job(
    name=job_name,
    train_func=train_function,
    parameters=parameters,
    num_workers=num_workers,
    num_procs_per_worker=gpus_per_worker,
    resources_per_worker={"gpu": gpus_per_worker},
)

### Check the PyTorchJob conditions

Use `TrainingClient()` APIs to get information about the created PyTorchJob.

In [None]:
# Show the job's current conditions between a header and a separator line.
print("PyTorchJob Conditions")
print(TrainingClient().get_job_conditions(job_name), "-" * 40, sep="\n")

# Block until the PyTorchJob reports the Running condition.
job = TrainingClient().wait_for_job_conditions(
    job_name, expected_conditions={"Running"}
)
print("PyTorchJob is running")

### Get the PyTorchJob pod names

Since we define 2 workers, the PyTorchJob creates 1 master pod and 1 worker pod to run FSDP training.

In [21]:
# List the names of the pods that belong to the job; as the last expression in
# the cell, the returned list is displayed as the cell output.
TrainingClient().get_job_pod_names(job_name)


['mnist-training-master-0', 'mnist-training-worker-0']

### Get the PyTorchJob training logs

Model parameters are sharded across all workers and GPU devices.

In [None]:
# Fetch the job's training logs. The call returns a pair; the first element is
# kept as `logs` and the second is discarded.
# NOTE(review): follow=True presumably keeps streaming until the pods finish —
# confirm against the Kubeflow Training SDK documentation.
logs, _ = TrainingClient().get_job_logs(job_name, follow=True)


## Delete the PyTorchJob

You can delete the created PyTorchJob.

In [14]:
# Delete the PyTorchJob named `job_name`.
TrainingClient().delete_job(name=job_name)