Copyright 2019 Google LLC.
SPDX-License-Identifier: Apache-2.0

In [1]:
!pip install \
  http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch-1.0.0a0+1d94a2b-cp36-cp36m-linux_x86_64.whl  \
  http://storage.googleapis.com/pytorch-tpu-releases/tf-1.13/torch_xla-0.1+5622d42-cp36-cp36m-linux_x86_64.whl



In [0]:
import torch
import torch.nn as nn
import torch_xla

class XlaMulAdd(nn.Module):                                                                                             
  def forward(self, x, y):                                                                                            
    return x * y + y                                                                                                

# Inputs and output to/from XLA models are always in replicated mode. The shapes
# are [NUM_REPLICAS][NUM_VALUES]. A non replicated, single core, execution will
# has NUM_REPLICAS == 1, but retain the same shape rank.                                                                                                                               
x = torch.rand(3, 5)                                                                                                    
y = torch.rand(3, 5)                                                                                                    
model = XlaMulAdd()                                                                                                     
traced_model = torch.jit.trace(model, (x, y))                                                                             
xla_model = torch_xla._XLAC.XlaModule(traced_model)                                                             
output_xla = xla_model((torch_xla._XLAC.XLATensor(x), torch_xla._XLAC.XLATensor(y)))                                               
expected = model(x, y)
print(output_xla[0][0].to_tensor().data)
print(expected.data)


In [3]:
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch_xla
import torch_xla_py.utils as xu
import torch_xla_py.xla_model as xm
import unittest
import sys
import os
import argparse

parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--datadir', type=str, default='/tmp/mnist-data')
parser.add_argument('--logdir', type=str, default='/tmp/logs')
parser.add_argument('--num_cores', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=512)
parser.add_argument('--num_epochs', type=int, default=10)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--target_accuracy', type=float, default=98.0)
parser.add_argument('--fake_data', action='store_true')
parser.add_argument('--tidy', action='store_true')
parser.add_argument('--metrics_debug', action='store_true')

FLAGS, leftovers = parser.parse_known_args()
sys.argv = [sys.argv[0]] + leftovers
# Setup import folders.
xla_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
print('xla folder', xla_folder)
sys.path.append(os.path.join(os.path.dirname(xla_folder), 'test'))
sys.path.insert(0, xla_folder)

class MNIST(nn.Module):

  def __init__(self):
    super(MNIST, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.bn1 = nn.BatchNorm2d(10)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.bn2 = nn.BatchNorm2d(20)
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = self.bn1(x)
    x = F.relu(F.max_pool2d(self.conv2(x), 2))
    x = self.bn2(x)
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

print ("mnist defined")
def train_mnist():
  assert FLAGS.num_cores == 1
  torch.manual_seed(1)
  # Training settings
  lr = 0.01
  momentum = 0.5
  log_interval = 5

  if FLAGS.fake_data:
    print('using fake data')
    train_loader = xu.SampleGenerator(
        data=torch.zeros(FLAGS.batch_size, 1, 28, 28),
        target=torch.zeros(FLAGS.batch_size, dtype=torch.int64),
        sample_count=60000 // FLAGS.batch_size)
    test_loader = xu.SampleGenerator(
        data=torch.zeros(FLAGS.batch_size, 1, 28, 28),
        target=torch.zeros(FLAGS.batch_size, dtype=torch.int64),
        sample_count=10000 // FLAGS.batch_size)
  else:
    print('using real data', FLAGS.datadir)
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            FLAGS.datadir,
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            FLAGS.datadir,
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)

  model = MNIST()

  inputs = torch.zeros(FLAGS.batch_size, 1, 28, 28)
  xla_model = xm.XlaModel(model, [inputs])
  optimizer = optim.SGD(xla_model.parameters_list(), lr=lr, momentum=momentum)
  loss_fn = nn.NLLLoss()
  accuracy = None
  for epoch in range(1, FLAGS.num_epochs + 1):
    # Training loop for epoch.
    start_time = time.time()
    processed = 0
    for batch_idx, (data, target) in enumerate(train_loader):
      if data.size()[0] != FLAGS.batch_size:
        break
      optimizer.zero_grad()
      y = xla_model(data)
      y[0].requires_grad = True
      loss = loss_fn(y[0], target)
      loss.backward()
      xla_model.backward(y)
      optimizer.step()
      processed += FLAGS.batch_size
      if batch_idx % log_interval == 0:
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\t'
              'Loss: {:.6f}\tSamples/sec: {:.1f}'.format(
                  epoch, processed,
                  len(train_loader) * FLAGS.batch_size,
                  100. * batch_idx / len(train_loader), loss,
                  processed / (time.time() - start_time)))

    # Eval loop for epoch.
    start_time = time.time()
    correct_count = 0
    test_loss = 0
    count = 0
    for batch_idx, (data, target) in enumerate(test_loader):
      if data.size()[0] != FLAGS.batch_size:
        break
      y = xla_model(data)
      test_loss += loss_fn(y[0], target).sum().item()
      pred = y[0].max(1, keepdim=True)[1]
      correct_count += pred.eq(target.view_as(pred)).sum().item()
      count += FLAGS.batch_size

    test_loss /= count
    accuracy = 100.0 * correct_count / count
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%), '
          'Samples/sec: {:.1f}\n'.format(test_loss, correct_count, count,
                                         accuracy,
                                         count / (time.time() - start_time)))
    # Debug metric dumping.
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy

torch.set_default_tensor_type('torch.FloatTensor')
train_mnist()

xla folder /usr/local/lib/python3.6
mnist defined
using real data /tmp/mnist-data
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!

Test set: Average loss: 0.0020, Accuracy: 7870/9728 (80.90%), Samples/sec: 8403.3


Test set: Average loss: 0.0008, Accuracy: 8958/9728 (92.08%), Samples/sec: 8365.1


Test set: Average loss: 0.0004, Accuracy: 9225/9728 (94.83%), Samples/sec: 8260.0


Test set: Average loss: 0.0003, Accuracy: 9331/9728 (95.92%), Samples/sec: 8331.2


Test set: Average loss: 0.0003, Accuracy: 9396/9728 (96.59%), Samples/sec: 8272.5


Test set: Average loss: 0.0002, Accuracy: 9429/9728 (96.93%), Samples/sec: 8472.5


Test set: Average loss: 0.0002, Accuracy: 9463/9728 (97.28%), Samples/sec: 8186.4


Test set: Average loss:

97.74876644736842