## PyTorch/TPU ResNet18/CIFAR10 Demo

This colab example is TF/XRT 1.15 compatible.

<h3>  &nbsp;&nbsp;Use Colab Cloud TPU&nbsp;&nbsp; <a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a></h3>

* On the main menu, click Runtime and select **Change runtime type**. Set "TPU" as the hardware accelerator.
* The cell below makes sure you have access to a TPU on Colab.


In [0]:
import os
assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'
import collections
from datetime import datetime, timedelta
import requests
import threading

In [0]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


### [RUNME] Install Colab TPU compatible PyTorch/TPU wheels and dependencies
This may take up to ~2 minutes



In [0]:
import collections
from datetime import datetime, timedelta
import os
import requests
import threading

_VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
VERSION = "xrt==1.15.0"  #@param ["xrt==1.15.0", "torch_xla==nightly"]
CONFIG = {
    'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
    'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
        (datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
}[VERSION]
DIST_BUCKET = 'gs://tpu-pytorch/wheels'
TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)

# Update TPU XRT version
def update_server_xrt():
  print('Updating server-side XRT to {} ...'.format(CONFIG.server))
  url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
      TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
      XRT_VERSION=CONFIG.server,
  )
  print('Done updating server-side XRT: {}'.format(requests.post(url)))

update = threading.Thread(target=update_server_xrt)
update.start()

# Install Colab TPU compat PyTorch/TPU wheels and dependencies
!pip uninstall -y torch torchvision
!gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
!pip install "$TORCH_WHEEL"
!pip install "$TORCH_XLA_WHEEL"
!pip install "$TORCHVISION_WHEEL"
!sudo apt-get install libomp5
update.join()

Updating server-side XRT to 1.15.0 ...
Uninstalling torch-1.3.1:
Done updating server-side XRT: <Response [200]>
  Successfully uninstalled torch-1.3.1
Uninstalling torchvision-0.4.2:
  Successfully uninstalled torchvision-0.4.2
Copying gs://tpu-pytorch/wheels/torch-1.15-cp36-cp36m-linux_x86_64.whl...
\ [1 files][ 77.8 MiB/ 77.8 MiB]                                                
Operation completed over 1 objects/77.8 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-1.15-cp36-cp36m-linux_x86_64.whl...
- [1 files][109.8 MiB/109.8 MiB]                                                
Operation completed over 1 objects/109.8 MiB.                                    
Copying gs://tpu-pytorch/wheels/torchvision-1.15-cp36-cp36m-linux_x86_64.whl...
/ [1 files][  2.1 MiB/  2.1 MiB]                                                
Operation completed over 1 objects/2.1 MiB.                                      
Processing ./torch-1.15-cp36-cp36m-linux_x86_64.wh

### Define Parameters



In [0]:
# Result Visualization Helper
from matplotlib import pyplot as plt
from torch.optim import lr_scheduler

M, N = 4, 6
RESULT_IMG_PATH = '/tmp/test_result.jpg'
CIFAR10_LABELS = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                 'dog', 'frog', 'horse', 'ship', 'truck']

def plot_results(images, labels, preds):
  images, labels, preds = images[:M*N], labels[:M*N], preds[:M*N]
  inv_norm = transforms.Normalize(
      mean=(-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010),
      std=(1/0.2023, 1/0.1994, 1/0.2010))

  num_images = images.shape[0]
  fig, axes = plt.subplots(M, N, figsize=(16, 9))
  fig.suptitle('Correct / Predicted Labels (Red text for incorrect ones)')

  for i, ax in enumerate(fig.axes):
    ax.axis('off')
    if i >= num_images:
      continue
    img, label, prediction = images[i], labels[i], preds[i]
    img = inv_norm(img)
    img = img.permute(1, 2, 0) # (C, M, N) -> (M, N, C)
    label, prediction = label.item(), prediction.item()
    if label == prediction:
      ax.set_title(u'\u2713', color='blue', fontsize=22)
    else:
      ax.set_title(
          'X {}/{}'.format(CIFAR10_LABELS[label],
                          CIFAR10_LABELS[prediction]), color='red')
    ax.imshow(img)
  plt.savefig(RESULT_IMG_PATH, transparent=True)

In [0]:
# Define Parameters
FLAGS = {}
FLAGS['data_dir'] = "/tmp/cifar"
FLAGS['batch_size'] = 16
FLAGS['num_workers'] = 8
FLAGS['learning_rate'] = 0.1  
FLAGS['momentum'] = 0.9
FLAGS['num_epochs'] = 200
FLAGS['num_cores'] = 8
FLAGS['log_steps'] = 20
FLAGS['metrics_debug'] = True

In [0]:
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
import torchvision
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms

In [0]:
def train_resnet50():
  torch.manual_seed(1)

  # Get and shard dataset into dataloaders
  # norm = transforms.Normalize(
  #     mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
  transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        #transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  ])
  
  # transforms.Compose([
  #     transforms.RandomCrop(32, padding=4),
  #     transforms.RandomHorizontalFlip(),
  #     transforms.ToTensor(),
  #     norm,
  # ])
  transform_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  ])
  # transforms.Compose([
  #     transforms.ToTensor(),
  #     norm,
  # ])
  train_dataset = datasets.CIFAR10(
      root=os.path.join(FLAGS['data_dir'], str(xm.get_ordinal())),
      train=True,
      download=True,
      transform=transform_train)
  test_dataset = datasets.CIFAR10(
      root=os.path.join(FLAGS['data_dir'], str(xm.get_ordinal())),
      train=False,
      download=True,
      transform=transform_test)
  train_sampler = torch.utils.data.distributed.DistributedSampler(
      train_dataset,
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True)
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=FLAGS['batch_size'],
      sampler=train_sampler,
      num_workers=FLAGS['num_workers'],
      drop_last=True)
  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=FLAGS['batch_size'],
      shuffle=False,
      num_workers=FLAGS['num_workers'],
      drop_last=True)

  # Scale learning rate to num cores
  learning_rate = FLAGS['learning_rate'] * xm.xrt_world_size()

  # Get loss function, optimizer, and model
  device = xm.xla_device()
  model = models.resnet50().to(device)
  optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                        momentum=FLAGS['momentum'], weight_decay=5e-4)
  scheduler = lr_scheduler.StepLR(optimizer, step_size=90, gamma=0.1)
  #loss_fn = nn.NLLLoss()
  loss_fn = nn.CrossEntropyLoss()

  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS['batch_size'])
      #scheduler.step()
      #if x % FLAGS['log_steps'] == 0:
      #  print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
      #      xm.get_ordinal(), x, loss.item(), tracker.rate(),
      #      tracker.global_rate(), time.asctime()), flush=True)

  def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    model.eval()
    data, pred, target = None, None, None
    for data, target in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    #print('[xla:{}] Accuracy={:.2f}%'.format(
    #    xm.get_ordinal(), accuracy), flush=True)
    return accuracy, data, pred, target

  # Train and eval loops
  accuracy = 0.0
  data, pred, target = None, None, None
  for epoch in range(1, FLAGS['num_epochs'] + 1):
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))
    scheduler.step()
    xm.master_print("Finished training epoch {}".format(epoch))
    print("Finishtime: ", datetime.now())
    if epoch % 10 == 0:
      torch.save(model.state_dict(), '/content/drive/My Drive/tpu-resnet50/model.pt')
    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy, data, pred, target  = test_loop_fn(para_loader.per_device_loader(device))
    #if FLAGS['metrics_debug']:
      #xm.master_print(met.metrics_report())

  return accuracy, data, pred, target, model

In [0]:
# Start training processes
print("Starttime: ", datetime.now())
def _mp_fn(rank, flags):
  global FLAGS
  FLAGS = flags
  torch.set_default_tensor_type('torch.FloatTensor')
  accuracy, data, pred, target, model = train_resnet50()
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'], start_method='fork')

Starttime:  2019-12-12 02:16:22.884370
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar/0/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar/3/cifar-10-python.tar.gz
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar/6/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar/2/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar/1/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar/5/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar/7/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar/4/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting /tmp/cifar/3/cifar-10-python.tar.gz to /tmp/cifar/3
Extracting /tmp/cifar/1/cifar-10-python.tar.gz to /tmp/cifar/1
Extracting /tmp/cifar/2/cifar-10-python.tar.gz to /tmp/cifar/2
Extracting /tmp/cifar/0/cifar-10-python.tar.gz to /tmp/cifar/0
Extracting /tmp/cifar/6/cifar-10-python.tar.gz to /tmp/cifar/6
Extracting /tmp/cifar/5/cifar-10-python.tar.gz to /tmp/cifar/5
Extracting /tmp/cifar/7/cifar-10-python.tar.gz to /tmp/cifar/7
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Extracting /tmp/cifar/4/cifar-10-python.tar.gz to /tmp/cifar/4
Files already downloaded and verified
Finishtime:  2019-12-12 02:18:32.664870
Finishtime:  2019-12-12 02:18:32.666317
Finished training epoch 1
Finishtime:  2019-12-12 02:18:32.666974
Finishtime:  2019-12-12 02:18:32.667074
Finish

KeyboardInterrupt: ignored

In [0]:
device = xm.xla_device()
model = models.resnet50().to(device)
model.load_state_dict(torch.load('/content/model.pt'))
model.eval()

FileNotFoundError: ignored

In [0]:
from PIL import Image
from torchvision import transforms

def eval_image(filepath):
    input_image = Image.open(filepath)
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

    # move the input and model to GPU for speed if available
    device = xm.xla_device()
    input_batch = input_batch.to(device)

    with torch.no_grad():
        output = model(input_batch)
    # The output has unnormalized scores. To get probabilities, you can run a softmax on it.
    #print(torch.nn.functional.softmax(output[0], dim=0))
    _, preds = torch.max(output, 1)
    return preds[0]

from os import listdir
classes_new = ('airplane', 'automobile', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
image_paths = [
    listdir("/content/DSF_HW5_wild_images/%d_%s" % ( num, c_name ))
    for num, c_name in enumerate(classes_new)
]

In [0]:
correct = 0.0
total = 0.0
for actual_class,files in enumerate(image_paths):
    for image_filepath in files:
        fpath = "/content/drive/My Drive/DSF_HW5_wild_images/%d_%s/" % (actual_class, classes_new[actual_class])
        pred_label = eval_image(fpath + image_filepath)
        if pred_label == actual_class:
            correct += 1.0
        total += 1.0
print( "Wild Accuracy: ", correct / total )

Wild Accuracy:  0.0


In [0]:
os.system("unzip /content/DSF_HW5_wild_images.zip")

0