Single Core: Training AlexNet on Fashion-MNIST with One Cloud TPU Core

In [None]:
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.6-cp36-cp36m-linux_x86_64.whl

In [None]:
import os
import torch
import torchvision
import torchvision.datasets as datasets

In [None]:
# Fashion-MNIST label index -> human-readable class name.
class_map = dict(enumerate([
    "t-shirt", "trouser", "pullover", "dress", "coat",
    "sandal", "shirt", "sneaker", "bag", "ankle boot",
]))

# Untransformed training split (no `transform=`), used only to eyeball a
# sample image in the next cell.
raw_dataset = datasets.FashionMNIST("/tmp/fashionmnist", train=True, download=True)

In [None]:
# Show one raw example, upscaled so it is easy to see, with its class name.
img_index = 0
sample_image, sample_label = raw_dataset[img_index]
display(sample_image.resize((224, 224)))
print(class_map[sample_label])

In [None]:
import torch_xla
import torch_xla.core.xla_model as xm

In [None]:
# AlexNet with a 10-way classifier head; no pre-trained weights are requested,
# so it starts from random initialization. It is moved onto the XLA (TPU) device.
device = xm.xla_device()
net = torchvision.models.alexnet(num_classes=10).to(device)

In [None]:
import torchvision.transforms as transforms

# Preprocessing pipeline, following the input conventions of TorchVision's
# pre-trained models (https://pytorch.org/docs/stable/torchvision/models.html):
# RGB (3 x H x W) images with H and W >= 224, scaled into [0, 1] by ToTensor
# and then normalized with the ImageNet mean/std below. The AlexNet in this
# notebook is trained from scratch, but the same conventions are reused.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
to_rgb = transforms.Lambda(lambda img: img.convert('RGB'))
resize = transforms.Resize((224, 224))
my_transform = transforms.Compose([
    resize,                 # upscale to 224 x 224
    to_rgb,                 # grayscale -> 3 channels
    transforms.ToTensor(),  # PIL -> float tensor in [0, 1]
    normalize,
])

In [None]:
def _fashion_mnist_split(train):
  """Returns one FashionMNIST split (downloading if needed) with my_transform applied."""
  return datasets.FashionMNIST(
    "/tmp/fashionmnist",
    train=train,
    download=True,
    transform=my_transform)

train_dataset = _fashion_mnist_split(train=True)
test_dataset = _fashion_mnist_split(train=False)

In [None]:
# Both splits are visited in random order. For the test split the order does
# not change the accuracy, only which examples end up in the last batch.
train_sampler, test_sampler = (
  torch.utils.data.RandomSampler(split) for split in (train_dataset, test_dataset))

In [None]:
batch_size = 8

# One DataLoader per split, each paired with its sampler from the previous cell.
train_loader, test_loader = (
  torch.utils.data.DataLoader(split, batch_size=batch_size, sampler=order)
  for split, order in ((train_dataset, train_sampler), (test_dataset, test_sampler)))

In [None]:
import time
from google.colab import widgets

# Converts one (3, H, W) image tensor produced by my_transform back into a
# grayscale PIL image for display. Those tensors are normalized, so their values
# lie outside [0, 1]. ToPILImage multiplies float tensors by 255 and casts to
# uint8, which garbles such values. Undo the normalization
# (x = std * y + mean) and clamp to [0, 1] before converting.
_unnormalize = transforms.Normalize(
  mean=[-m / s for m, s in zip(normalize.mean, normalize.std)],
  std=[1 / s for s in normalize.std])
t_to_img = transforms.Compose([
  _unnormalize,
  transforms.Lambda(lambda t: t.clamp(0, 1)),
  transforms.ToPILImage(),
  transforms.Grayscale()])

def _show_sample_batch(data, targets, best_guesses, rows=2, cols=4):
  """Displays up to rows*cols examples in a widget grid with target and guess labels."""
  grid = widgets.Grid(rows, cols)
  shown = rows * cols  # never index past the grid, whatever the batch size
  samples = zip(data[:shown], targets[:shown], best_guesses[:shown])
  for i, (image, target, guess) in enumerate(samples):
    row, col = divmod(i, cols)  # fill the grid row by row
    with grid.output_to(row, col):
      display(t_to_img(image.cpu()))
      print("Target: ", class_map[target.item()])
      print("Guess: ", class_map[guess.item()])


def eval_network(net, test_loader):
  """Evaluates `net` on every batch of `test_loader` and prints the results.

  Counts how many argmax predictions match the targets, then prints the
  percentage correct, the elapsed time, and the last batch as a sample grid.

  Args:
    net: model to evaluate. It is switched to eval mode and left there.
    test_loader: iterable of (data, targets) batches.

  Relies on the notebook-level `device`, `class_map`, `t_to_img` and `widgets`.
  """
  start_time = time.time()
  num_correct = 0
  total_guesses = 0
  data = targets = best_guesses = None

  # Eval mode plus no-grad context for inference.
  net.eval()
  with torch.no_grad():
    for data, targets in test_loader:
      data = data.to(device)
      targets = targets.to(device)

      # The network's best guess for each example.
      results = net(data)
      best_guesses = torch.argmax(results, 1)

      # .item() forces a device sync each step, which keeps the XLA graph small.
      num_correct += torch.eq(targets, best_guesses).sum().item()
      # Use the real batch length: the final batch may be shorter than the
      # loader's batch_size.
      total_guesses += targets.size(0)

  elapsed_time = time.time() - start_time
  if total_guesses == 0:
    print("No batches to evaluate")
    return

  print("Correctly guessed ", num_correct/total_guesses*100, "% of the dataset")
  print("Evaluated in ", elapsed_time, " seconds")
  print("Sample batch:")

  # The last batch serves as the visual sample.
  _show_sample_batch(data, targets, best_guesses)

In [None]:
# Baseline before any training: the AlexNet above was built without pre-trained
# weights, so accuracy here is presumably close to chance for 10 classes.
eval_network(net, test_loader)

In [None]:
# Note: this will take 5-10 minutes to run.
num_epochs = 1
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters())

# eval_network leaves the net in eval mode; switch back before training.
net.train()

start_time = time.time()
for _ in range(num_epochs):
  for data, targets in train_loader:
    data, targets = data.to(device), targets.to(device)

    # Clear old gradients, run forward and backward, then step.
    optimizer.zero_grad()
    loss = loss_fn(net(data), targets)
    loss.backward()
    # Cloud TPU-specific: barrier=True makes the optimizer step also execute
    # the pending XLA graph, since no ParallelLoader does it for us here.
    xm.optimizer_step(optimizer, barrier=True)

elapsed_time = time.time() - start_time
print("Spent ", elapsed_time, " seconds training for ", num_epochs, " epoch(s) on a single core.")

In [None]:
# Re-evaluate on the test set after training; compare with the baseline
# accuracy printed before the training cell.
eval_network(net, test_loader)