# Multi-GPU PyTorch

In [40]:
!nvidia-smi

Sun Jun  7 18:09:23 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.82       Driver Version: 440.82       CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla T4            Off  | 00000000:00:1B.0 Off |                    0 |
| N/A   33C    P8     9W /  70W |     11MiB / 15109MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla T4            Off  | 00000000:00:1C.0 Off |                    0 |
| N/A   34C    P8     9W /  70W |     11MiB / 15109MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  Tesla T4            Off  | 00000000:00:1D.0 Off |                    0 |
| N/A   

Note that workspaces running on a multi-GPU instance only record metrics for the *first* GPU device in the details graph; it would be better to show metrics for every GPU.

For reference, here is the complete basic distributed example (copied over from the `scratch` workspace):

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
import os

def init_process(rank, size, backend='gloo'):
    """Set up the default process group for a single-machine run.

    The rendezvous point is fixed at 127.0.0.1:29500, so every process
    spawned on this machine joins the same group.
    """
    os.environ.update({'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29500'})
    dist.init_process_group(backend, rank=rank, world_size=size)

def example(rank, world_size):
    """Run a single DDP training step on a toy linear model in process ``rank``.

    Args:
        rank: Index of this process within the process group (supplied by ``mp.spawn``).
        world_size: Total number of processes in the group.
    """
    init_process(rank, world_size)

    # Map the process rank onto an available GPU instead of overwriting ``rank``.
    # With at least ``world_size`` GPUs each process gets its own device; on a
    # single-GPU machine every process shares device 0. Keeping ``rank`` intact
    # means the log line below reports the true rank.
    device = rank % torch.cuda.device_count()

    # create local model
    model = nn.Linear(10, 10).to(device)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[device])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # forward pass
    outputs = ddp_model(torch.randn(20, 10).to(device))
    labels = torch.randn(20, 10).to(device)
    # backward pass
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()

    print(f"Finished process {rank}/{world_size}.")

    # Release the process group's resources before the worker exits.
    dist.destroy_process_group()

def main():
    """Spawn one ``example`` worker per process and wait for all of them to exit."""
    world_size = 2
    mp.spawn(
        example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )


if __name__ == "__main__":
    main()
```

In [6]:
import torch.nn as nn

# Scratch 10 -> 10 linear layer; TmpModule wraps it by default.
tmp = nn.Linear(in_features=10, out_features=10)

In [15]:
class TmpModule(nn.Module):
    """Minimal module that wraps and applies a single layer.

    Args:
        layer: The module to wrap. When omitted, the notebook-level ``tmp``
            linear layer is used, which preserves the behaviour of ``TmpModule()``.
    """

    def __init__(self, layer=None):
        super().__init__()
        # Fall back to the global ``tmp`` only when no layer is supplied, so the
        # class can also be used without depending on notebook state.
        self.ff1 = tmp if layer is None else layer

    def forward(self, X):
        # ``nn.Module.__call__`` invokes ``forward`` as a bound method, so it
        # must accept ``self`` before the input.
        return self.ff1(X)

In [49]:
%%writefile ../models/2_pytorch_distributed_model.py
import torch
import torchvision
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
import PIL
import torch.nn as nn
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter

# NEW
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler

# NEW
def init_process(rank, size, backend='gloo'):
    """Join this process to the default process group.

    All workers are expected to run on this machine, so the rendezvous address
    is the loopback interface on a fixed port (29500).
    """
    rendezvous = {'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '29500'}
    for key, value in rendezvous.items():
        os.environ[key] = value
    dist.init_process_group(backend, rank=rank, world_size=size)

# VOCSegmentation returns a raw dataset: images are not resized and are in PIL format. To transform
# them into something suitable for input to PyTorch, we wrap the output in our own dataset class.
class PascalVOCSegmentationDataset(Dataset):
    """Adapt the raw ``VOCSegmentation`` dataset to model-ready numpy arrays.

    Each item is resized to 256x256 and returned as an ``(img, segmap)`` pair:
    ``img`` is a float32 array divided by 255 with the channel axis moved to
    the front, and ``segmap`` is an int32 class-index map.
    """

    def __init__(self, raw):
        super().__init__()
        self._dataset = raw
        # Bilinear for the photo; nearest-neighbour for the label map so that
        # resizing never blends neighbouring labels into new class ids.
        self.resize_img = torchvision.transforms.Resize((256, 256), interpolation=PIL.Image.BILINEAR)
        self.resize_segmap = torchvision.transforms.Resize((256, 256), interpolation=PIL.Image.NEAREST)

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        raw_img, raw_segmap = self._dataset[idx]

        img = np.array(self.resize_img(raw_img))
        segmap = np.array(self.resize_segmap(raw_segmap))

        # Scale by 1/255 and move the channel axis first (HWC -> CHW).
        img = (img / 255).astype('float32').transpose(2, 0, 1)
        segmap = segmap.astype('int32')

        # The raw PASCAL VOC label maps mark the border around each object with
        # the value 255. CrossEntropyLoss needs class ids in 0..n_classes - 1,
        # so the border is relabelled as its own class, 21.
        segmap[segmap == 255] = 21

        return img, segmap

def get_dataloader(rank, world_size):
    """Build the training DataLoader for one process of a distributed run.

    The PASCAL VOC 2012 ``train`` split is fetched (``download=True``) into
    ``/mnt/pascal_voc_segmentation/`` and wrapped in
    ``PascalVOCSegmentationDataset``. A ``DistributedSampler`` gives each rank
    its own shard of the data.
    """
    raw_dataset = torchvision.datasets.VOCSegmentation(
        '/mnt/pascal_voc_segmentation/',
        year='2012',
        image_set='train',
        download=True,
        transform=None,
        target_transform=None,
        transforms=None,
    )
    dataset = PascalVOCSegmentationDataset(raw_dataset)

    # NEW
    # The DistributedSampler handles ordering and sharding, so shuffling is
    # left off on the DataLoader itself.
    sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size)
    return DataLoader(dataset, batch_size=8, shuffle=False, sampler=sampler)

def get_model():
    """Return a randomly initialised DeepLabV3 model with a ResNet-101 backbone.

    ``num_classes`` is 22: PASCAL VOC has 20 classes of interest, 1 background
    class, and the special border class (raw value 255, relabelled to 21 by
    ``PascalVOCSegmentationDataset``). 20 + 1 + 1 = 22.
    """
    return torchvision.models.segmentation.deeplabv3_resnet101(
        pretrained=False,
        progress=True,
        num_classes=22,
        aux_loss=None,
    )

def train(rank, num_epochs, world_size):
    """Train DeepLabV3 on PASCAL VOC in one process of a single-machine DDP run.

    Intended to be launched via ``mp.spawn``, which passes this process's
    ``rank`` as the first argument.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        num_epochs: Number of passes over this rank's shard of the data.
        world_size: Total number of processes (one per GPU).
    """
    # NEW
    init_process(rank, world_size)
    print(f"Rank {rank}/{world_size} training process initialized.\n")

    # NEW
    # Since this is a single-instance multi-GPU training script, it's important that only one
    # process handle downloading of the data, to:
    #
    # * Avoid race conditions implicit in having multiple processes attempt to write to the same
    #   file simultaneously.
    # * Avoid downloading the data in multiple processes simultaneously.
    #
    # Since the data is cached on disk, we can construct and discard the dataloader and model in
    # the master process only to get the data. The other processes are held back by the barrier.
    if rank == 0:
        get_dataloader(rank, world_size)
        get_model()
    dist.barrier()
    print(f"Rank {rank}/{world_size} training process passed data download barrier.\n")

    # DDP expects the wrapped module to already live on this process's device,
    # so move it to the GPU *before* wrapping it.
    model = get_model().cuda(rank)

    # NEW
    model = DistributedDataParallel(model, device_ids=[rank])
    model.train()

    dataloader = get_dataloader(rank, world_size)
    steps_per_epoch = len(dataloader)

    # since the background class doesn't matter nearly as much as the classes of interest to the
    # overall task a more selective loss would be more appropriate, however this training script
    # is merely a benchmark so we'll just use simple cross-entropy loss
    criterion = nn.CrossEntropyLoss()

    # NEW
    # Since we are computing the average of several batches at once (an effective batch size of
    # world_size * batch_size) we scale the learning rate to match.
    optimizer = Adam(model.parameters(), lr=1e-3 * world_size)

    # Only rank 0 logs to TensorBoard; creating a writer on every rank would
    # leave one extra (empty) event file per process in the log directory.
    writer = SummaryWriter('/spell/tensorboards/model_2') if rank == 0 else None

    for epoch in range(1, num_epochs + 1):
        # DistributedSampler seeds its shuffle with the epoch number; without
        # set_epoch every epoch would iterate the shard in the same order.
        dataloader.sampler.set_epoch(epoch)
        losses = []

        for i, (batch, segmap) in enumerate(dataloader):
            optimizer.zero_grad()

            batch = batch.cuda(rank)
            segmap = segmap.cuda(rank)

            output = model(batch)['out']
            loss = criterion(output, segmap.type(torch.int64))
            loss.backward()
            optimizer.step()

            curr_loss = loss.item()
            if i % 10 == 0:
                print(
                    f'Finished epoch {epoch}, rank {rank}/{world_size}, batch {i}. '
                    f'Loss: {curr_loss:.3f}.\n'
                )
            if writer is not None:
                # Without a global_step every point is recorded at step 0.
                global_step = (epoch - 1) * steps_per_epoch + i
                writer.add_scalar('training loss', curr_loss, global_step)
            losses.append(curr_loss)

        print(
            f'Finished epoch {epoch}, rank {rank}/{world_size}. '
            f'Avg Loss: {np.mean(losses)}; Median Loss: {np.median(losses)}.\n'
        )

        if rank == 0:
            os.makedirs('/spell/checkpoints/', exist_ok=True)
            # This is the DDP wrapper's state dict, so parameter keys carry a
            # "module." prefix.
            torch.save(model.state_dict(), f'/spell/checkpoints/model_{epoch}.pth')

    if writer is not None:
        writer.close()
    dist.destroy_process_group()

# NEW
NUM_EPOCHS = 20
WORLD_SIZE = torch.cuda.device_count()
def main():
    """Spawn one ``train`` worker per visible GPU and wait for them to finish.

    Raises:
        RuntimeError: If no CUDA devices are visible. ``mp.spawn`` with
            ``nprocs=0`` would otherwise start no workers at all.
    """
    if WORLD_SIZE < 1:
        raise RuntimeError(
            'No CUDA devices found; this training script needs at least one GPU.'
        )
    mp.spawn(train,
        args=(NUM_EPOCHS, WORLD_SIZE),
        nprocs=WORLD_SIZE,
        join=True)

if __name__=="__main__":
    main()


Overwriting ../models/2_pytorch_distributed_model.py


In [29]:
# import torch.multiprocessing as mp
# mp.spawn?

In [48]:
import torch
# Number of visible CUDA devices. The training script above launches one worker
# process per device (WORLD_SIZE = torch.cuda.device_count()).
torch.cuda.device_count()

4

In [44]:
!python ../models/2_pytorch_distributed_model.py

Rank 3/4 training process initialized.

Rank 0/4 training process initialized.

Rank 1/4 training process initialized.

Rank 2/4 training process initialized.

Using downloaded and verified file: /mnt/pascal_voc_segmentation/VOCtrainval_11-May-2012.tar
Using downloaded and verified file: /mnt/pascal_voc_segmentation/VOCtrainval_11-May-2012.tarUsing downloaded and verified file: /mnt/pascal_voc_segmentation/VOCtrainval_11-May-2012.tar
Using downloaded and verified file: /mnt/pascal_voc_segmentation/VOCtrainval_11-May-2012.tar

Finished epoch 1, rank 0/4, batch 0. Loss: 3.202.

Finished epoch 1, rank 1/4, batch 0. Loss: 3.182.

Finished epoch 1, rank 3/4, batch 0. Loss: 3.134.

Finished epoch 1, rank 2/4, batch 0. Loss: 3.235.

Finished epoch 1, rank 0/4, batch 10. Loss: 1.991.

Finished epoch 1, rank 1/4, batch 10. Loss: 2.106.

Finished epoch 1, rank 3/4, batch 10. Loss: 1.620.

Finished epoch 1, rank 2/4, batch 10. Loss: 1.957.

Finished epoch 1, rank 0/4, batch 20. Loss: 1.036.

Fini

In [47]:
%ls ../checkpoints/

model_1.pth  model_2.pth  model_3.pth  model_4.pth  model_5.pth
