Copyright 2019 Google LLC.
SPDX-License-Identifier: Apache-2.0

In [0]:
!pip install \
  http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch-1.0.0a0+1d94a2b-cp36-cp36m-linux_x86_64.whl  \
  http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch_xla-0.1+5622d42-cp36-cp36m-linux_x86_64.whl

Collecting torch==1.0.0a0+1d94a2b from http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch-1.0.0a0+1d94a2b-cp36-cp36m-linux_x86_64.whl
[?25l  Downloading http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch-1.0.0a0+1d94a2b-cp36-cp36m-linux_x86_64.whl (266.4MB)
[K    100% |████████████████████████████████| 266.4MB 120.8MB/s 
[?25hCollecting torch-xla==0.1+5622d42 from http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch_xla-0.1+5622d42-cp36-cp36m-linux_x86_64.whl
[?25l  Downloading http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch_xla-0.1+5622d42-cp36-cp36m-linux_x86_64.whl (57.9MB)
[K    100% |████████████████████████████████| 57.9MB 71.2MB/s 
[31mtorchvision 0.2.1 has requirement pillow>=4.1.1, but you'll have pillow 4.0.0 which is incompatible.[0m
[31mfastai 1.0.46 has requirement numpy>=1.15, but you'll have numpy 1.14.6 which is incompatible.[0m
[31mfastai 1.0.46 has requirement torch>=1.0.0, but you'll have torch 1.0.0a

In [0]:
import torch
import torch.nn as nn
import torch_xla

class XlaMulAdd(nn.Module):                                                                                             
  def forward(self, x, y):                                                                                            
    return x * y + y                                                                                                

# Inputs and output to/from XLA models are always in replicated mode. The shapes
# are [NUM_REPLICAS][NUM_VALUES]. A non replicated, single core, execution will
# has NUM_REPLICAS == 1, but retain the same shape rank.                                                                                                                               
x = torch.rand(3, 5)                                                                                                    
y = torch.rand(3, 5)                                                                                                    
model = XlaMulAdd()                                                                                                     
traced_model = torch.jit.trace(model, (x, y))                                                                             
xla_model = torch_xla._XLAC.XlaModule(traced_model)                                                             
output_xla = xla_model((torch_xla._XLAC.XLATensor(x), torch_xla._XLAC.XLATensor(y)))                                               
expected = model(x, y)
print(output_xla[0][0].to_tensor().data)
print(expected.data)


In [0]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch_xla
import torch_xla_py.utils as xu
import torch_xla_py.xla_model as xm

datadir = '/tmp/mnist-data'
num_workers = 4

class MNIST(nn.Module):

  def __init__(self):
    super(MNIST, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.bn1 = nn.BatchNorm2d(10)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.bn2 = nn.BatchNorm2d(20)
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = self.bn1(x)
    x = F.relu(F.max_pool2d(self.conv2(x), 2))
    x = self.bn2(x)
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

def train_mnist():
  torch.manual_seed(1)
  # Training settings
  lr = 0.01
  momentum = 0.5
  log_interval = 5
  batch_size = 512
  num_epochs = 10

  train_loader = torch.utils.data.DataLoader(
      datasets.MNIST(
          datadir,
          train=True,
          download=True,
          transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
          ])),
      batch_size=batch_size,
      shuffle=True,
      num_workers=num_workers)
  test_loader = torch.utils.data.DataLoader(
      datasets.MNIST(
          datadir,
          train=False,
          transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
          ])),
      batch_size=batch_size,
      shuffle=True,
      num_workers=num_workers)

  model = MNIST()

  inputs = torch.zeros(batch_size, 1, 28, 28)
  xla_model = xm.XlaModel(model, [inputs])
  optimizer = optim.SGD(xla_model.parameters_list(), lr=lr, momentum=momentum)
  loss_fn = nn.NLLLoss()
  accuracy = None
  for epoch in range(1, num_epochs + 1):
    # Training loop for epoch.
    start_time = time.time()
    processed = 0
    for batch_idx, (data, target) in enumerate(train_loader):
      if data.size()[0] != batch_size:
        break
      optimizer.zero_grad()
      y = xla_model(data)
      y[0].requires_grad = True
      loss = loss_fn(y[0], target)
      loss.backward()
      xla_model.backward(y)
      optimizer.step()
      processed += batch_size
      if batch_idx % log_interval == 0:
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\t'
              'Loss: {:.6f}\tSamples/sec: {:.1f}'.format(
                  epoch, processed,
                  len(train_loader) * batch_size,
                  100. * batch_idx / len(train_loader), loss,
                  processed / (time.time() - start_time)))

    # Eval loop for epoch.
    start_time = time.time()
    correct_count = 0
    test_loss = 0
    count = 0
    for batch_idx, (data, target) in enumerate(test_loader):
      if data.size()[0] != batch_size:
        break
      y = xla_model(data)
      test_loss += loss_fn(y[0], target).sum().item()
      pred = y[0].max(1, keepdim=True)[1]
      correct_count += pred.eq(target.view_as(pred)).sum().item()
      count += batch_size

    test_loss /= count
    accuracy = 100.0 * correct_count / count
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%), '
          'Samples/sec: {:.1f}\n'.format(test_loss, correct_count, count,
                                         accuracy,
                                         count / (time.time() - start_time)))
    
    print(torch_xla._XLAC._xla_metrics_report())

torch.set_default_tensor_type('torch.FloatTensor')
train_mnist()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!

Test set: Average loss: 0.0020, Accuracy: 7870/9728 (80.90%), Samples/sec: 8332.3

Metric: CompileTime
  TotalSamples: 119
  Counter: 04s112ms741.607us
  ValueRate: 403ms961.310us / second
  Rate: 11.6623 / second
  Percentiles: 1%=003ms946.099us; 5%=003ms116.911us; 10%=003ms232.231us; 20%=003ms426.525us; 50%=004ms370.243us; 80%=009ms633.349us; 90%=010ms854.089us; 95%=014ms166.692us; 99%=399ms373.376us
Metric: ExecuteParallelTime
  TotalSamples: 117
  Counter: 418ms643.801us
  ValueRate: 062ms445.492us / second
  Rate: 17.4937 / second
  Percentiles: 1%=003ms593.794us; 5%=003ms730.280us; 10%=003ms814.589us; 20%=003ms952.300us; 50%=003ms235.367us; 80%=003ms494.505us; 90%=004ms880.297us