# Distributed Mixed-precision Training with PyTorch and NVIDIA `Apex`

This tutorial goes over the most important parts of distributed training in PyTorch.

For more details, please refer to: https://github.com/richardkxu/distributed-pytorch

For the full ImageNet training script please refer to `imagenet_ddp_apex.py` in the above Git repo.

## What is `Apex`?
A PyTorch extension with NVIDIA-maintained utilities that streamline mixed-precision and distributed training. It provides the features of the built-in PyTorch Distributed Data Parallel (DDP) package, integrates closely with NVIDIA GPUs, and accelerates training through mixed precision.

`Apex` uses the NVIDIA Collective Communications Library (NCCL) as the distributed backend. NCCL handles communication between GPUs within a node and across multiple nodes. It currently has the best performance and integration with NVIDIA GPUs.

Most deep learning frameworks, including PyTorch, train using 32-bit floating point (FP32) arithmetic by default. However, using FP32 for all operations is not essential to achieve full accuracy for many state-of-the-art deep neural networks (DNNs). In mixed-precision training, the majority of the network runs in FP16 arithmetic, while potentially unstable operations are automatically cast to FP32.

Key points:
- Ensuring that weight updates are carried out in FP32.
- Loss scaling to prevent underflowing gradients.
- A few operations (e.g. large reductions) left in FP32.
- Everything else (the majority of the network) executed in FP16.
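
The first two key points can be seen with a few numbers. The following is a minimal standard-library sketch that simulates FP16 rounding with `struct`; it does not use `Apex`, and the helper `to_fp16`, the gradient value, and the loss scale of 65536 are assumptions chosen for illustration.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


# Loss scaling: a tiny gradient underflows in FP16 unless it is scaled first.
tiny_grad = 1e-8
loss_scale = 65536.0  # hypothetical scale factor

unscaled = to_fp16(tiny_grad)               # too small for FP16, becomes 0.0
scaled = to_fp16(tiny_grad * loss_scale)    # representable in FP16
recovered = scaled / loss_scale             # unscale in higher precision

# FP32 weight updates: a small update is lost when added to a weight in FP16.
weight, update = 1.0, 1e-4
fp16_sum = to_fp16(weight + to_fp16(update))  # rounds back to 1.0
full_precision_sum = weight + update          # keeps the update

print(unscaled, scaled, recovered)
print(fp16_sum, full_precision_sum)
```

Scaling the loss multiplies every gradient by the same factor, so dividing by that factor before the weight update restores the original magnitude.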

## Why `Apex`?

- comes with all the distributed training features of the built-in PyTorch DDP
- better performance than built-in DDP
- reduces memory storage and bandwidth demands by up to 2x, since FP16 values take half the space of FP32 values
- allows larger batch sizes
- takes advantage of NVIDIA Tensor Cores for matrix multiplications and convolutions
- does not require you to explicitly convert your model, or the input data, to half()

## How to use `Apex`?

Let's say we are using 2 compute nodes, each with 4 GPUs. Then:
* world size = 8
* on each node, the local rank of each GPU will be 0-3
* the global rank of each GPU will be 0-7
* we will use PyTorch's `torch.distributed.launch` module to spawn 8 processes, one for each GPU. More on this later.

In [None]:
# this will be 0-3 if you have 4 GPUs on curr node
args.gpu = args.local_rank
torch.cuda.set_device(args.gpu)
torch.distributed.init_process_group(backend='nccl',
                                     init_method='env://')
# this is the total # of GPUs across all nodes
# if using 2 nodes with 4 GPUs each, world size is 8
args.world_size = torch.distributed.get_world_size()
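
The relationship between node rank, local rank, and global rank can be sketched in plain Python. This is an illustration only, assuming every node has the same number of GPUs; the function name `global_rank` is hypothetical and not part of PyTorch.

```python
def global_rank(node_rank: int, local_rank: int, gpus_per_node: int) -> int:
    """Global rank of a process, given its node and its GPU index on that node."""
    return node_rank * gpus_per_node + local_rank


nnodes, gpus_per_node = 2, 4
world_size = nnodes * gpus_per_node

# Map (node_rank, local_rank) -> global rank for every process.
ranks = {
    (node, local): global_rank(node, local, gpus_per_node)
    for node in range(nnodes)
    for local in range(gpus_per_node)
}

for (node, local), rank in sorted(ranks.items()):
    print(f"node {node}, local rank {local} -> global rank {rank}")
```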

Load your model and initialize your optimizer as you usually do in PyTorch. However, you need to scale the learning rate to compensate for the larger global batch size used in distributed training. If your global batch size is 4 times your normal batch size, then you need to multiply your learning rate by 4 to "speed up" training in the early stages. The code below uses 256 as the reference batch size.

In [None]:
model = models.resnet50(pretrained=True)
model.cuda()
# Scale init learning rate based on global batch size
args.lr = args.lr * float(args.batch_size*args.world_size)/256.
optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
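
As a worked example of the scaling rule above, here is a small sketch. The base learning rate of 0.1 and the function name `scale_lr` are assumptions for illustration; the per-GPU batch size of 208 and the world size of 8 match the setup used in this tutorial.

```python
def scale_lr(base_lr: float, per_gpu_batch_size: int, world_size: int,
             reference_batch_size: int = 256) -> float:
    """Scale the learning rate linearly with the global batch size."""
    global_batch_size = per_gpu_batch_size * world_size
    return base_lr * float(global_batch_size) / reference_batch_size


# 8 GPUs x 208 samples each = global batch size 1664, i.e. 6.5x the reference.
scaled_lr = scale_lr(base_lr=0.1, per_gpu_batch_size=208, world_size=8)

# A global batch size equal to the reference leaves the learning rate unchanged.
unchanged_lr = scale_lr(base_lr=0.1, per_gpu_batch_size=32, world_size=8)

print(scaled_lr, unchanged_lr)
```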

`apex.amp` is a tool to enable mixed precision training when using `apex`. 

The optimization level ranges from "O0" to "O3". Typically we use either "O1" or "O2". More details: https://nvidia.github.io/apex/amp.html#opt-levels
* O0: pure FP32 training
* O1: GEMMs and convolutions run in FP16; model weights and softmax stay in FP32
* O2: "Almost FP16" mixed precision. O2 casts the model weights to FP16, patches the model's forward method to cast input data to FP16, keeps batchnorms in FP32, maintains FP32 master weights, and updates the optimizer's param_groups so that optimizer.step() acts directly on the FP32 weights
* O3: pure FP16 training, which is not stable. Only use it as a speedy baseline.

In [None]:
# Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
# for convenient interoperation with argparse.
model, optimizer = amp.initialize(model, optimizer,
                                  opt_level=args.opt_level,
                                  keep_batchnorm_fp32=args.keep_batchnorm_fp32,)
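
The `args.opt_level` and `args.keep_batchnorm_fp32` values used above have to come from somewhere. A minimal `argparse` sketch is shown below; the flag names and defaults are assumptions for illustration and may differ from the full training script. Leaving `--keep-batchnorm-fp32` as `None` lets the chosen opt level decide.

```python
import argparse

parser = argparse.ArgumentParser(description="Amp-related options (illustrative)")
parser.add_argument("--opt-level", type=str, default="O1",
                    choices=["O0", "O1", "O2", "O3"],
                    help="Amp optimization level")
parser.add_argument("--keep-batchnorm-fp32", type=str, default=None,
                    help='optional override, passed to Amp as the string "True" or "False"')

# Parse an explicit argument list so the sketch runs without a command line.
args = parser.parse_args(["--opt-level", "O2", "--keep-batchnorm-fp32", "True"])
defaults = parser.parse_args([])

print(args.opt_level, args.keep_batchnorm_fp32)
print(defaults.opt_level, defaults.keep_batchnorm_fp32)
```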

Wrap your model with `apex.parallel.DistributedDataParallel`. It is similar to `torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training, optimized for NVIDIA's NCCL communication library.

Define your loss function like before.

In [None]:
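from apex.parallel import DistributedDataParallel as DDP

# delay_allreduce=True delays all communication to the end of the backward pass
model = DDP(model, delay_allreduce=True)
criterion = nn.CrossEntropyLoss().cuda()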

Define your training and test datasets like before. However, you need to use `torch.utils.data.distributed.DistributedSampler` to make sure that each process gets a different slice of the training data during distributed training.

In [None]:
train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.RandomResizedCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        # transforms.ToTensor(), Too slow
        # normalize,
    ]))
val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
        transforms.Resize(val_size),
        transforms.CenterCrop(crop_size),
    ]))
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
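
To make the idea of "a different slice per process" concrete, here is a simplified pure-Python sketch of how indices can be split across ranks without shuffling. It is an illustration of the partitioning idea, not the library's implementation; the function name `shard_indices` is hypothetical.

```python
import math


def shard_indices(dataset_len: int, world_size: int, rank: int) -> list:
    """Return the sample indices assigned to one rank (no shuffling)."""
    num_samples = math.ceil(dataset_len / world_size)
    total_size = num_samples * world_size
    indices = list(range(dataset_len))
    # Pad by wrapping around so that every rank gets the same number of samples.
    indices += indices[: total_size - dataset_len]
    # Each rank takes every world_size-th index, starting at its own rank.
    return indices[rank:total_size:world_size]


shards = [shard_indices(dataset_len=10, world_size=4, rank=r) for r in range(4)]
for r, shard in enumerate(shards):
    print(f"rank {r}: {shard}")
```

When the sampler shuffles, calling `train_sampler.set_epoch(epoch)` at the start of each epoch gives a different ordering in every epoch.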

For the dataloaders, `args.batch_size` is the per-GPU batch size. Notice that `shuffle` evaluates to `False` because we pass a distributed sampler, which handles shuffling instead. `args.workers` is the number of data loading subprocesses per GPU. **`args.workers` * number of GPUs per node** should be <= **the number of CPU threads available on the node**.

In [None]:
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    num_workers=args.workers, sampler=train_sampler)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers,
    sampler=val_sampler)
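
The worker budget described above can be checked with a short calculation. This is a rough sketch; the function name `max_workers_per_gpu` and the example thread counts are assumptions for illustration.

```python
import os


def max_workers_per_gpu(cpu_threads: int, gpus_per_node: int) -> int:
    """Largest per-GPU worker count such that workers * GPUs <= CPU threads."""
    return cpu_threads // gpus_per_node


# Example: a node with 80 CPU threads and 4 GPUs.
example_budget = max_workers_per_gpu(cpu_threads=80, gpus_per_node=4)

# The same calculation for the machine this runs on, assuming 4 GPUs.
local_threads = os.cpu_count() or 1
local_budget = max_workers_per_gpu(local_threads, gpus_per_node=4)

print(example_budget, local_budget)
```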

We want to perform "learning rate warmup" to stabilize training in the early stages. The following is a regular step learning rate schedule with warmup.

In [None]:
def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """LR schedule that should yield 76% converged accuracy with batch size 256"""
    factor = epoch // 30

    if epoch >= 80:
        factor = factor + 1

    lr = args.lr*(0.1**factor)

    # Warmup: ramp the learning rate linearly up to its full value over the first 5 epochs
    if epoch < 5:
        lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr
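
To see what this schedule produces, the same arithmetic can be evaluated without an optimizer. The sketch below mirrors the function above as a pure function; the base learning rate of 0.1 and the 100 steps per epoch are assumptions for illustration.

```python
def lr_at(base_lr: float, epoch: int, step: int, len_epoch: int) -> float:
    """Learning rate at a given epoch and step, same arithmetic as above."""
    factor = epoch // 30
    if epoch >= 80:
        factor = factor + 1
    lr = base_lr * (0.1 ** factor)
    # Linear warmup over the first 5 epochs.
    if epoch < 5:
        lr = lr * float(1 + step + epoch * len_epoch) / (5.0 * len_epoch)
    return lr


base_lr, len_epoch = 0.1, 100
for epoch, step in [(0, 0), (4, 99), (5, 0), (30, 0), (60, 0), (80, 0), (90, 0)]:
    print(f"epoch {epoch:2d}, step {step:2d}: lr = {lr_at(base_lr, epoch, step, len_epoch):.6f}")
```

The learning rate starts at 1/500 of its base value, reaches the base value at the end of epoch 4, and then drops by a factor of 10 at epochs 30, 60, 80, and 90.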

In plain PyTorch, a training step in your train function usually looks like this:

In [None]:
optimizer.zero_grad()  # clear gradients left over from the previous step
output = model(input)
loss = criterion(output, target)
loss.backward()
optimizer.step()

With `apex.amp`, you should now do:

In [None]:
optimizer.zero_grad()  # clear gradients left over from the previous step
output = model(input)
loss = criterion(output, target)
# Mixed-precision training requires that the loss is scaled in order
# to prevent the gradients from underflowing
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
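
The loss scale does not have to be a fixed number. A common strategy is dynamic loss scaling: lower the scale and skip the step when gradients overflow, and raise it again after a run of good steps. The sketch below is a simplified pure-Python illustration of that idea, not `Apex`'s implementation; the class name and all default values are assumptions.

```python
import math


class DynamicLossScaler:
    """Simplified dynamic loss scaling (illustrative only)."""

    def __init__(self, init_scale=2.0 ** 16, scale_factor=2.0, scale_window=2000):
        self.scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.good_steps = 0

    def update(self, grads) -> bool:
        """Inspect gradients; return True if the optimizer step should run."""
        overflow = any(math.isinf(g) or math.isnan(g) for g in grads)
        if overflow:
            # Gradients are unusable: shrink the scale and skip this step.
            self.scale /= self.scale_factor
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps == self.scale_window:
            # A long run without overflow: try a larger scale again.
            self.scale *= self.scale_factor
            self.good_steps = 0
        return True


scaler = DynamicLossScaler(scale_window=2)
print(scaler.update([0.5, float("inf")]), scaler.scale)  # overflow: step skipped
print(scaler.update([0.5]), scaler.scale)
print(scaler.update([0.25]), scaler.scale)               # scale grows back
```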

## Bells and Whistles

### How to prevent race conditions when multiple devices try to do `Tensorboard` logging or print to an output file?
If you train on 2 nodes with 4 GPUs each, 8 processes run the same .py file, one process for each GPU. If two or more processes try to log or write to disk at the same time, a race condition can occur. To prevent this, we usually allow only the process with global rank 0 to do all the logging and printing. Simply check `torch.distributed.get_rank() == 0` before you log or write to `Tensorboard`.

In [None]:
# only allow the process with global rank 0 to log training stats, to prevent duplicate logging
if torch.distributed.get_rank() == 0:
    writer.add_scalar('Loss/train', train_losses, epoch + 1)
    writer.add_scalar('Loss/val', val_losses, epoch + 1)
    writer.add_scalar('Top1/train', train_top1, epoch + 1)
    writer.add_scalar('Top1/val', val_top1, epoch + 1)
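
The same check can be wrapped in a small helper so that it also works before the process group is initialized. This is a sketch that assumes the processes were started by `torch.distributed.launch`, which exports a `RANK` environment variable to each process; the helper names are hypothetical.

```python
import os


def is_main_process(env=None) -> bool:
    """True for the process with global rank 0 (or when not running distributed)."""
    env = os.environ if env is None else env
    return int(env.get("RANK", "0")) == 0


def log_if_main(message: str, env=None) -> bool:
    """Print only from the main process; return whether the message was printed."""
    if is_main_process(env):
        print(message)
        return True
    return False


# Simulate two processes by passing explicit environments.
log_if_main("epoch 1 finished", env={"RANK": "0"})  # printed
log_if_main("epoch 1 finished", env={"RANK": "3"})  # suppressed
```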

### How to save model checkpoints?
Following the same idea as above, we only allow the process with global rank 0 to save model checkpoints, so that several processes never write the same file at once.

In [None]:
if torch.distributed.get_rank() == 0:
    is_best = val_top1 > best_prec1
    best_prec1 = max(val_top1, best_prec1)
    save_checkpoint({
        'epoch': epoch + 1,
        'arch': args.arch,
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
        'optimizer': optimizer.state_dict(),
    }, is_best, writer.log_dir)
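
The `save_checkpoint` helper is not shown above. A minimal sketch of what it might look like follows; the file names and the optional `save_fn` parameter are assumptions for illustration, and the version in the full training script may differ. By default it uses `torch.save`.

```python
import os
import shutil


def save_checkpoint(state, is_best, log_dir,
                    filename="checkpoint.pth.tar", save_fn=None):
    """Write the latest checkpoint and keep a copy of the best one so far."""
    if save_fn is None:
        # Imported lazily so the helper can be exercised without PyTorch installed.
        import torch
        save_fn = torch.save
    os.makedirs(log_dir, exist_ok=True)
    path = os.path.join(log_dir, filename)
    save_fn(state, path)
    if is_best:
        # Keep a separate copy that later checkpoints will not overwrite.
        shutil.copyfile(path, os.path.join(log_dir, "model_best.pth.tar"))
    return path
```

Writing the latest checkpoint on every call and copying it only when `is_best` is true keeps both a resumable state and the best model on disk.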

## How to run the program?
To run your program on 2 nodes with 4 GPUs each, open a terminal on each node and run a slightly different command on each one. Only `--node_rank` changes between the two commands.

Node 0:

In [None]:
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="192.168.100.11" --master_port=8888 imagenet_ddp_apex.py -a resnet50 --b 208 --workers 20 --opt-level O2 /home/shared/imagenet/raw/

- torch distributed launch module: https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py
- --nproc_per_node: number of processes to launch on the current node, one per GPU; each process is bound to a single GPU
- --nnodes: total number of nodes used for training
- --node_rank: rank of the current node, should be an int between 0 and --nnodes - 1
- --master_addr: IP address of the master node of your choice. type str
- --master_port: open port number on the master node. type int. If you don't know, use 8888
- --workers: # of data loading workers per GPU process. These are different from the processes that run the program on each GPU. On each node, the total # of processes = (# of data loading workers * # of GPUs) + # of GPUs (one process to run each GPU)
- -b: per-GPU batch size. For a 16 GB GPU, 224 is the max batch size. It needs to be a multiple of 8 to make use of Tensor Cores. If you are using Tensorboard logging, you need to assign a slightly smaller batch size!

Node 1:

In [None]:
python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr="192.168.100.11" --master_port=8888 imagenet_ddp_apex.py -a resnet50 --b 208 --workers 20 --opt-level O2 /home/shared/imagenet/raw/
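
Because the two commands differ only in `--node_rank`, they can be generated from one template. The sketch below builds the argument list for each node using the values from the commands above; the function name `launch_command` is hypothetical, and the sketch only prints the commands rather than running them.

```python
def launch_command(node_rank, nnodes=2, nproc_per_node=4,
                   master_addr="192.168.100.11", master_port=8888):
    """Build the launch command for one node as a list of arguments."""
    launcher = [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}",
        f"--nnodes={nnodes}",
        f"--node_rank={node_rank}",
        f"--master_addr={master_addr}",
        f"--master_port={master_port}",
    ]
    script = [
        "imagenet_ddp_apex.py", "-a", "resnet50", "--b", "208",
        "--workers", "20", "--opt-level", "O2", "/home/shared/imagenet/raw/",
    ]
    return launcher + script


commands = [launch_command(node_rank) for node_rank in range(2)]
for node_rank, cmd in enumerate(commands):
    print(f"Node {node_rank}: {' '.join(cmd)}")

# Collect the arguments that differ between the two nodes.
differences = [(a, b) for a, b in zip(commands[0], commands[1]) if a != b]
print(differences)
```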