## Multi-Core Training of AlexNet on Fashion MNIST

A single Cloud TPU contains 8 cores, so using only one of them leaves most of its capacity idle. Let's see how to put all 8 cores to work.


* Reference: https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/multi-core-alexnet-fashion-mnist.ipynb




### Dataset & Model

Dataset: Fashion MNIST

Model: AlexNet

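### Using Multiple Cloud TPU Cores

Training on multiple cores differs from training on a single core. In particular, it requires multiprocessing: each Cloud TPU core is driven by its own process.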

In [1]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

In [2]:
# "Map function": acquires a corresponding Cloud TPU core, creates a tensor on it,
# and prints its core
def simple_map_fn(index, flags):
    """
    index: index of process
    """
    # Sets a common random seed - both for initialization and ensuring graph is the same
    torch.manual_seed(1234)

    # Acquires the (unique) Cloud TPU core corresponding to this process's index
    device = xm.xla_device()  # no need to name a TPU core explicitly; xmp.spawn() assigns one per process

    # Creates a tensor on this process's device
    t = torch.randn((2, 2), device=device)

    print("Process", index, "is using", xm.xla_real_devices([str(device)])[0])

    # Barrier to prevent master from exiting before workers connect.
    xm.rendezvous('init')

# Spawns eight of the map functions, one for each of the eight cores on
# the Cloud TPU
flags = {}
# Note: Colab only supports start_method='fork'
xmp.spawn(simple_map_fn, args=(flags,), nprocs=8, start_method='fork')

Process 0 is using TPU:0
Process 6 is using TPU:6
Process 4 is using TPU:4
Process 3 is using TPU:3
Process 1 is using TPU:1
Process 7 is using TPU:7
Process 5 is using TPU:5
Process 2 is using TPU:2


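See the [`spawn()` documentation](http://pytorch.org/xla/#torch_xla.distributed.xla_multiprocessing.spawn). `spawn()` takes a (map) function, a tuple of arguments for that function, the number of processes to create (`nprocs`), and the process start method (`fork` or `spawn`).

Here `xmp.spawn()` creates 8 processes, one per Cloud TPU core, and calls `simple_map_fn()` in each of them.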
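To make the calling convention concrete, here is a minimal standard-library sketch of the same pattern: start `nprocs` processes and call `fn(index, *args)` in each. It is only an analogy built on `multiprocessing`, not how `torch_xla` implements `xmp.spawn()`. The names `spawn_sketch` and `toy_map_fn` and the extra `queue` argument are hypothetical, and the `fork` start method is assumed to be available (POSIX systems).

```python
import multiprocessing as mp


def toy_map_fn(index, flags, queue):
    # Every process runs the same function; only `index` differs.
    queue.put((index, flags["greeting"]))


def spawn_sketch(fn, args=(), nprocs=1, start_method="fork"):
    """Start `nprocs` processes that each call fn(index, *args, queue)."""
    ctx = mp.get_context(start_method)
    queue = ctx.Queue()  # hypothetical extra, used only to collect results
    procs = [
        ctx.Process(target=fn, args=(index, *args, queue))
        for index in range(nprocs)
    ]
    for proc in procs:
        proc.start()
    # Read exactly one message per process before joining.
    results = [queue.get(timeout=30) for _ in procs]
    for proc in procs:
        proc.join()
    return sorted(results)


results = spawn_sketch(toy_map_fn, args=({"greeting": "hello"},), nprocs=4)
print(results)
```

As with `xmp.spawn()`, the process index is passed as the first positional argument and the user-supplied `args` tuple follows it.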




### An Aside on Context

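How did each process above know which Cloud TPU core it had been given? The answer is the context.

Cloud TPUs manage operators and computations through an implicit, stateful context. `xmp.spawn()` creates a multiprocess context that every child process can access.

Note that once you have used a multiprocess context, you cannot create a single-process context in the same session. The two kinds of context conflict and must not be mixed.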

In [3]:
# Don't mix these!
# Only one type of context per Colab!
# Warning: uncommenting the below and running this cell will cause a runtime error!

# device = xm.xla_device()  # Requires a single process context

# xmp.spawn(simple_map_fn, args=(flags,), nprocs=8, start_method='fork')  # Requires a multiprocess context

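The second thing to watch for: every process must perform the same Cloud TPU computation. You cannot give different processes different computations inside `simple_map_fn`.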



In [4]:
# Don't perform different computations on different processes!
# Warning: uncommenting the below and running this cell will likely hang your Colab!
# def simple_map_fn(index, flags):
#   torch.manual_seed(1234)
#   device = xm.xla_device()  

#   if xm.is_master_ordinal():
#     t = torch.randn((2, 2), device=device)  # Divergent Cloud TPU computation!


# xmp.spawn(simple_map_fn, args=(flags,), nprocs=8, start_method='fork')

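The context can only coordinate the cores correctly when the computation on every Cloud TPU core is identical. CPU-only work, however, is free to differ from process to process.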
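One way to picture why divergent work can hang is a barrier that every participant must reach. The sketch below is an analogy using `threading.Barrier` from the standard library, not the actual Cloud TPU synchronization mechanism. The helper `run_workers` and its parameters are hypothetical, and a timeout stands in for what would otherwise be an indefinite wait.

```python
import threading


def run_workers(num_workers, skip_index=None, timeout=0.5):
    """Return each worker's outcome after one shared synchronization point."""
    barrier = threading.Barrier(num_workers)
    outcomes = {}

    def worker(index):
        if index == skip_index:
            outcomes[index] = "skipped"  # divergent path: never joins the others
            return
        try:
            barrier.wait(timeout=timeout)
            outcomes[index] = "ok"
        except threading.BrokenBarrierError:
            outcomes[index] = "stuck"  # would wait forever without the timeout

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_workers)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return outcomes


same_work = run_workers(4)                # everyone reaches the barrier
divergent = run_workers(4, skip_index=0)  # worker 0 takes a different path
print(same_work)
print(divergent)
```

When all workers follow the same path, each one passes the barrier. When one worker skips it, the remaining workers never get released, which is the situation the hang warning describes.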


In [5]:
# Common Cloud TPU computation but different CPU computation is OK
def simple_map_fn(index, flags):
  torch.manual_seed(1234)
  device = xm.xla_device()  

  t = torch.randn((2, 2), device=device)  # Common Cloud TPU computation
  out = str(t)  # Each process uses the XLA tensors the same way

  if xm.is_master_ordinal():  # Divergent CPU-only computation (no XLA tensors beyond this point!)
    print(out)

  # Barrier to prevent master from exiting before workers connect.
  xm.rendezvous('init')


xmp.spawn(simple_map_fn, args=(flags,), nprocs=8, start_method='fork')

tensor([[-0.6989, -0.0987],
        [ 0.7337, -0.9071]], device='xla:1')


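### Multi-Core Training

Next we define a function that trains AlexNet on all 8 Cloud TPU cores:

- **Setup**: every process uses the same random seed.
- **Dataloading**: every process has its own copy of the dataset, but the sampler gives each process a non-overlapping subset.
- **Network creation**: every process has its own copy of the model. Because the random seed is the same in every process, the initial weights are identical too.
- **Training** and **Evaluation**: Training and evaluation occur as usual but use a ParallelLoader.

In effect this is data parallelism, comparable to `DistributedDataParallel`.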
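The non-overlapping sampling can be pictured as a strided split of the index list. The following is a simplified sketch of that idea without shuffling, not the `DistributedSampler` source. The function name `shard_indices` is hypothetical, and the padding step (repeating leading samples when the dataset size is not divisible by the number of replicas) is an assumption about the default behaviour.

```python
import math


def shard_indices(dataset_size, num_replicas, rank):
    """Indices that replica `rank` visits in one unshuffled epoch."""
    per_replica = math.ceil(dataset_size / num_replicas)
    total = per_replica * num_replicas
    indices = list(range(dataset_size))
    # Pad by repeating leading samples; a no-op when the size divides evenly.
    indices += indices[: total - dataset_size]
    return indices[rank:total:num_replicas]


# Small illustration: 8 samples split across 4 replicas.
shards = [shard_indices(8, num_replicas=4, rank=r) for r in range(4)]
for rank, shard in enumerate(shards):
    print(rank, shard)

# The Fashion MNIST training set (60,000 images) across 8 cores.
full = [shard_indices(60_000, num_replicas=8, rank=r) for r in range(8)]
print([len(s) for s in full])
```

Because 60,000 is divisible by 8, each core gets 7,500 distinct training images and no sample is repeated across cores.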

In [6]:
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
import torch_xla.distributed.parallel_loader as pl
import time

def map_fn(index, flags):
    ## Setup 

    # Sets a common random seed - both for initialization and ensuring graph is the same
    torch.manual_seed(flags['seed'])

    # Acquires the (unique) Cloud TPU core corresponding to this process's index
    device = xm.xla_device()  


    ## Dataloader construction

    # Creates the transform for the raw Torchvision data
    # See https://pytorch.org/docs/stable/torchvision/models.html for normalization
    # Pre-trained TorchVision models expect RGB (3 x H x W) images
    # H and W should be >= 224
    # Loaded into [0, 1] and normalized as follows:
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]
                                    )
    to_rgb = transforms.Lambda(lambda image: image.convert('RGB'))
    resize = transforms.Resize((224, 224))
    my_transform = transforms.Compose([resize, to_rgb, transforms.ToTensor(), normalize])

    # Downloads train and test datasets
    # Note: master goes first and downloads the dataset only once (xm.rendezvous)
    #   all the other workers wait for the master to be done downloading.
    if not xm.is_master_ordinal():
        xm.rendezvous('download_only_once')
    
    # Every process builds the datasets; only the master actually downloads,
    # because the files already exist when the other processes get here
    train_dataset = datasets.FashionMNIST(
        "/tmp/fashionmnist",
        train=True,
        download=True,
        transform=my_transform
    )

    test_dataset = datasets.FashionMNIST(
        "/tmp/fashionmnist",
        train=False,
        download=True,
        transform=my_transform
    )

    if xm.is_master_ordinal():
        xm.rendezvous('download_only_once')

    # Creates the (distributed) train sampler, which let this process only access
    # its portion of the training dataset.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )

    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )

    # Creates dataloaders, which load data in batches
    # Note: the test loader is sharded by test_sampler but not shuffled
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=flags['batch_size'],
        sampler=train_sampler,
        num_workers=flags['num_workers'],
        drop_last=True
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=flags['batch_size'],
        sampler=test_sampler,
        shuffle=False,
        num_workers=flags['num_workers'],
        drop_last=True
    )

    ## Network, optimizer, and loss function creation

    # Creates AlexNet for 10 classes
    # Note: each process has its own identical copy of the model
    #  Even though each model is created independently, they're also
    #  created in the same way.
    net = torchvision.models.alexnet(num_classes=10).to(device).train()

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())

    ## Trains
    train_start = time.time()
    for epoch in range(flags['num_epochs']):
        # ParallelLoader moves each batch to this process's device in the background
        para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
        for batch_num, batch in enumerate(para_train_loader):
            # Batches arrive already on `device`, so no explicit .to(device) is needed
            data, targets = batch

            # Acquires the network's best guesses at each class
            output = net(data)

            # Computes loss
            loss = loss_fn(output, targets)

            # Updates model
            optimizer.zero_grad()
            loss.backward()

            # Note: optimizer_step uses the implicit Cloud TPU context to
            #  coordinate and synchronize gradient updates across processes.
            #  This means that each process's network has the same weights after
            #  this is called.
            # Warning: this coordination requires the actions performed in each 
            #  process are the same. In more technical terms, the graph that
            #  PyTorch/XLA generates must be the same across processes. 
            xm.optimizer_step(optimizer)  # Note: barrier=True not needed when using ParallelLoader 

    elapsed_train_time = time.time() - train_start
    print("Process", index, "finished training. Train time was:", elapsed_train_time) 


    ## Evaluation
    # Sets net to eval and no grad context 
    net.eval()
    eval_start = time.time()
    with torch.no_grad():
        num_correct = 0
        total_guesses = 0

        para_test_loader = pl.ParallelLoader(test_loader, [device]).per_device_loader(device)
        for batch_num, batch in enumerate(para_test_loader):
            data, targets = batch

            # Acquires the network's best guesses at each class
            output = net(data)
            best_guesses = torch.argmax(output, 1)

            # Updates running statistics
            num_correct += torch.eq(targets, best_guesses).sum().item()
            total_guesses += flags['batch_size']

    elapsed_eval_time = time.time() - eval_start
    print("Process", index, "finished evaluation. Evaluation time was:", elapsed_eval_time)
    print("Process", index, "guessed", num_correct, "of", total_guesses, "correctly for", num_correct/total_guesses * 100, "% accuracy.")
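The comments around `xm.optimizer_step` say that every process ends each step with the same weights. A toy sketch of why that holds in data-parallel training: each replica computes a gradient on its own shard, the gradients are averaged across replicas, and every replica applies the same update. This is plain Python with a one-parameter model, not the `torch_xla` implementation. The helper names are hypothetical, and averaging (rather than summing) the gradients is an assumption.

```python
def local_gradient(weight, shard):
    """Gradient of mean squared error for the model y = weight * x on one shard."""
    return sum(2 * (weight * x - y) * x for x, y in shard) / len(shard)


def data_parallel_step(weights, shards, lr=0.01):
    """One synchronized SGD step: average the per-replica gradients, then update."""
    grads = [local_gradient(w, shard) for w, shard in zip(weights, shards)]
    mean_grad = sum(grads) / len(grads)  # stands in for the cross-replica reduction
    return [w - lr * mean_grad for w in weights]


# Two replicas start from the same weight (same seed) but see different data.
# The data follows y = 2 * x, so the weight should approach 2.
shards = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (4.0, 8.0)]]
weights = [0.0, 0.0]
for _ in range(200):
    weights = data_parallel_step(weights, shards)
print(weights)
```

Since all replicas start from identical weights and apply the identical averaged gradient, their weights never drift apart, even though each replica only ever sees its own shard.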

In [None]:
# Configures training (and evaluation) parameters
flags['batch_size'] = 32  # per-core batch size; the global batch is 8 * 32 = 256
flags['num_workers'] = 8
flags['num_epochs'] = 1
flags['seed'] = 1234

xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')

[The eight processes printed concurrently, so their messages interleaved. Each of processes 0-7 reported "finished training. Train time was:"; the times, in the order printed, were 43.09199523925781, 42.829161643981934, 43.08830666542053, 43.09568405151367, 43.05691909790039, 43.10002088546753, 44.97375178337097 and 43.09987211227417 seconds.]







Process 0 finished evaluation. Evaluation time was: 7.351853132247925
Process 0 guessed 1080 of 1248 correctly for 86.53846153846155 % accuracy.


2022-07-08 06:17:09.358847: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:157] RPC failed with status = "UNAVAILABLE: Connection reset by peer" and grpc_error_string = "{"created":"@1657261029.358613665","description":"Error received from peer ipv4:127.0.0.1:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Connection reset by peer","grpc_status":14}", maybe retrying the RPC

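Training on all 8 cores is much faster than training on one. Each core processes its own batch of 32 at the same time, so the effective batch size per step is 8 times larger (256).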
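The batch and step counts for this run can be checked with a few lines of arithmetic. The sketch assumes the standard Fashion MNIST split sizes (60,000 training and 10,000 test images) and that the sampler hands each core `ceil(N / 8)` samples, with `drop_last=True` discarding the final incomplete batch.

```python
import math

NUM_CORES = 8
PER_CORE_BATCH = 32                      # flags['batch_size']
TRAIN_SIZE, TEST_SIZE = 60_000, 10_000   # Fashion MNIST split sizes

# Samples processed per training step across all cores.
global_batch = NUM_CORES * PER_CORE_BATCH

# Each core sees ceil(N / cores) samples; drop_last=True drops the remainder.
train_steps_per_core = math.ceil(TRAIN_SIZE / NUM_CORES) // PER_CORE_BATCH
test_batches_per_core = math.ceil(TEST_SIZE / NUM_CORES) // PER_CORE_BATCH
test_guesses_per_core = test_batches_per_core * PER_CORE_BATCH

print(global_batch, train_steps_per_core, test_guesses_per_core)
# prints: 256 234 1248
```

The last figure agrees with the evaluation output above, where process 0 reports 1248 total guesses.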