<a href="https://colab.research.google.com/github/sourcecode369/Kaggle-Notebooks/blob/master/Tutorials/tpu/pytorch/PyTorch_on_Cloud_TPUs_MultiCore_Training_AlexNet_on_Fashion_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

VERSION = "20200325"  #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  4139  100  4139    0     0  54460      0 --:--:-- --:--:-- --:--:-- 54460
Updating TPU and VM. This may take around 2 minutes.
Updating TPU runtime to pytorch-dev20200325 ...
Collecting cloud-tpu-client
  Downloading https://files.pythonhosted.org/packages/56/9f/7b1958c2886db06feb5de5b2c191096f9e619914b6c31fdf93999fdbbd8b/cloud_tpu_client-0.10-py3-none-any.whl
Collecting google-api-python-client==1.8.0
[?25l  Downloading https://files.pythonhosted.org/packages/9a/b4/a955f393b838bc47cbb6ae4643b9d0f90333d3b4db4dc1e819f36aad18cc/google_api_python_client-1.8.0-py3-none-any.whl (57kB)
[K     |████████████████████████████████| 61kB 2.6MB/s 
Uninstalling torch-1.5.1+cu101:
Installing collected packages: google-api-python-client, cloud-tpu-client
  Found existing installation: google-api-python-client 1.7.12
    Uninstalling google

In [2]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_env_vars as xla_env_vars
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla import os as xla_os

In [3]:
print(torch_xla.__version__)

1.6+e788e5b


In [4]:
def simple_map_fn(index, flags):
  torch.manual_seed(1234)
  device = xm.xla_device()  
  t = torch.randn((2, 2), device=device)
  print("Process", index ,"is using", xm.xla_real_devices([str(device)])[0])
  xm.rendezvous('init')

flags = {}
xmp.spawn(simple_map_fn, args=(flags,), nprocs=8, start_method='fork')

Process 0 is using TPU:0
Process 6 is using TPU:6
Process 2 is using TPU:2
Process 5 is using TPU:5
Process 7 is using TPU:7
Process 4 is using TPU:4
Process 1 is using TPU:1
Process 3 is using TPU:3


In [5]:
def simple_map_fn(index, flags):
  torch.manual_seed(1234)
  device = xm.xla_device()  

  t = torch.randn((2, 2), device=device)  # Common Cloud TPU computation
  out = str(t)  # Each process uses the XLA tensors the same way

  if xm.is_master_ordinal():  # Divergent CPU-only computation (no XLA tensors beyond this point!)
    print(out)

  # Barrier to prevent master from exiting before workers connect.
  xm.rendezvous('init')


xmp.spawn(simple_map_fn, args=(flags,), nprocs=8, start_method='fork')

tensor([[-0.3916,  0.4222],
        [ 1.0496, -0.4849]], device='xla:1')


In [6]:
import torchvision 
from torchvision import datasets
import torchvision.transforms as transforms
import torch_xla.distributed.data_parallel as data_parallel
import torch_xla.distributed.parallel_loader as parallel_loader
import time

In [12]:
def map_fn(index, flags):
    torch.manual_seed(flags['seed'])
    device = xm.xla_device()
    
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])
    to_rgb = transforms.Lambda(lambda image: image.convert('RGB'))
    resize = transforms.Resize((224, 224))
    my_transform = transforms.Compose([resize, to_rgb, transforms.ToTensor(), normalize])

    if not xm.is_master_ordinal():
        xm.rendezvous('download_only_once')
    
    train_dataset = datasets.FashionMNIST('/tmp/fashionmnist',
                                          train=True,
                                          download=True,
                                          transform = my_transform)
    test_dataset = datasets.FashionMNIST('/tmp/fashionmnist',
                                         train=False,
                                         download=True,
                                         transform=my_transform)
    
    if xm.is_master_ordinal():
        xm.rendezvous('download_only_once')
    
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                    num_replicas=xm.xrt_world_size(),
                                                                    rank=xm.get_ordinal(),
                                                                    shuffle=True
                                                                    )
    
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset,
                                                                    num_replicas=xm.xrt_world_size(),
                                                                    rank=xm.get_ordinal(),
                                                                    shuffle=False
                                                                    )

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=flags['batch_size'],
                                               sampler=train_sampler,
                                               num_workers=flags['num_workers'],
                                               drop_last=True
                                               )
    
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                               batch_size=flags['batch_size'],
                                               sampler=test_sampler,
                                               num_workers=flags['num_workers'],
                                               drop_last=True
                                               )
    
    net = torchvision.models.alexnet(num_classes=10).to(device).train()
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    train_start = time.time()
    for epoch in range(flags['num_epochs']):
        para_train_loader = parallel_loader.ParallelLoader(train_loader, devices=[device]).per_device_loader(device)
        for batch_num, batch in enumerate(para_train_loader):
            data, targets = batch

            output = net(data)
            loss = loss_fn(output, targets)
            optimizer.zero_grad()
            loss.backward()
            xm.optimizer_step(optimizer=optimizer)
    elapsed_train_time = time.time() - train_start
    print("Process", index, "finished training. Train time was:", elapsed_train_time) 

    net.eval()
    eval_start = time.time()
    with torch.no_grad():
        num_correct = 0
        total_guesses = 0
        para_train_loader = pl.ParallelLoader(test_loader, [device]).per_device_loader(device)
        for batch_num, batch in enumerate(para_train_loader):
            data, targets = batch

            output = net(data)
            best_guesses = torch.argmax(output, 1)

            num_correct += torch.eq(targets, best_guesses).sum().item()
            total_guesses += flags['batch_size']

    elapsed_eval_time = time.time() - eval_start
    print("Process", index, "finished evaluation. Evaluation time was:", elapsed_eval_time)
    print("Process", index, "guessed", num_correct, "of", total_guesses, "correctly for", num_correct/total_guesses * 100, "% accuracy.")

In [13]:
flags['batch_size'] = 32
flags['num_workers'] = 8
flags['num_epochs'] = 1
flags['seed'] = 1234

xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')

Process 1 finished training. Train time was: 328.7308475971222
Process 0 finished training. Train time was: 327.9923770427704
Process 3 finished training. Train time was: 328.0882017612457


Exception in device=TPU:0: name 'pl' is not defined


Process 5 finished training. Train time was: 327.4206635951996


Exception in device=TPU:3: name 'pl' is not defined
Exception in device=TPU:5: name 'pl' is not defined
Exception in device=TPU:1: name 'pl' is not defined
Traceback (most recent call last):


Process 4 finished training. Train time was: 327.65316939353943


Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):


Process 7 finished training. Train time was: 333.198184967041
Process 6 finished training. Train time was: 332.0990159511566


  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 119, in _start_fn
    fn(gindex, *args)


Process 2 finished training. Train time was: 333.13902592658997


  File "<ipython-input-12-4a3924915d56>", line 75, in map_fn
    para_train_loader = pl.ParallelLoader(test_loader, [device]).per_device_loader(device)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 119, in _start_fn
    fn(gindex, *args)
Exception in device=TPU:4: name 'pl' is not defined
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 119, in _start_fn
    fn(gindex, *args)
Exception in device=TPU:6: name 'pl' is not defined
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 119, in _start_fn
    fn(gindex, *args)
Exception in device=TPU:7: name 'pl' is not defined
NameError: name 'pl' is not defined
  File "<ipython-input-12-4a3924915d56>", line 75, in map_fn
    para_train_loader = pl.ParallelLoader(test_loader, [device]).per_device_loader(device)
  File "<ipython-input-12-4a3924915d56>", line 75, in map_fn
    para_train_loader = pl

Exception: ignored