In [2]:
%%writefile run.py

import torch
import torch.nn as nn
import torchvision.datasets
import torchvision.transforms as T
import torchvision.models as models
import torch.profiler

import torch.multiprocessing as mp
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD

import os


# Rendezvous endpoint for torch.distributed; process_fn calls
# init_process_group without an explicit init_method, so these are what the
# workers use to find each other.
os.environ.update({
    'MASTER_ADDR': 'localhost',
    'MASTER_PORT': '29500',
})

# Profiler schedule parameters, passed to torch.profiler.schedule and also
# used to bound the number of training steps in process_fn.
wait, warmup, active, repeat = 1, 1, 3, 2

# Number of worker processes launched by run() (the DDP world size).
size = 2

# Preprocessing applied to every CIFAR-10 image: resize to 224, convert to a
# tensor, then normalize each of the three channels with mean 0.5 / std 0.5.
transform = T.Compose([
    T.Resize(224),
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 training split, downloaded into ./data when not already present.
train_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=32,
    shuffle=True,
)

# Model and loss both live on the first CUDA device.
device = torch.device("cuda:0")
model = torchvision.models.resnet18().to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)


def process_fn(rank, size):
    """Run one DDP worker that profiles a few PowerSGD training steps.

    Joins the NCCL process group, wraps the module-level ``model`` in
    DistributedDataParallel, registers the PowerSGD communication hook, and
    trains on ``train_loader`` under ``torch.profiler`` for
    ``(wait + warmup + active) * repeat`` steps. Traces are written to
    ``./log/resnet18-<rank>``.

    Args:
        rank: Index of this process within the process group.
        size: Total number of processes in the group (world size).
    """
    torch.distributed.init_process_group('nccl', rank=rank, world_size=size)
    try:
        # NOTE(review): every rank uses the module-level `device` (cuda:0) and
        # iterates the full `train_loader` (no DistributedSampler) -- confirm
        # this is intended for a multi-process run.
        ddp_model = nn.parallel.DistributedDataParallel(module=model)

        # process_group=None is passed through unchanged to PowerSGDState.
        state = powerSGD.PowerSGDState(process_group=None)
        ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)

        optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001, momentum=0.9)
        ddp_model.train()

        # Stop once every profiler cycle (wait + warmup + active) has run
        # `repeat` times; further steps would not be recorded.
        total_steps = (wait + warmup + active) * repeat

        with torch.profiler.profile(
            schedule=torch.profiler.schedule(
                wait=wait, warmup=warmup, active=active, repeat=repeat
            ),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                f'./log/resnet18-{rank}'
            ),
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        ) as prof:
            for step, batch_data in enumerate(train_loader):
                if step >= total_steps:
                    break

                # Fixed: the original read from an undefined name `data`.
                inputs = batch_data[0].to(device=device)
                labels = batch_data[1].to(device=device)

                outputs = ddp_model(inputs)
                loss = criterion(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Advance the profiler schedule by one training step.
                prof.step()
    finally:
        # Tear down the process group even if training raised.
        torch.distributed.destroy_process_group()


def run():
    """Launch ``size`` worker processes running ``process_fn`` and wait for them.

    Uses the "spawn" start method; each worker receives its rank and the
    world size as arguments.
    """
    mp.set_start_method("spawn")

    workers = []
    for worker_rank in range(size):
        worker = mp.Process(target=process_fn, args=(worker_rank, size))
        worker.start()
        workers.append(worker)

    # Block until every worker has exited.
    for worker in workers:
        worker.join()

# Entry-point guard: run() selects the "spawn" start method, under which child
# processes import this module again; the guard keeps them from calling run().
if __name__ == '__main__':
  run()

Writing run.py


In [None]:
!python3 run.py

## Analyze

In [None]:
!pip install torch_tb_profiler

In [None]:
%load_ext tensorboard
%tensorboard --logdir=log