In [65]:
import os
os.environ['XRT_TPU_CONFIG']='tpu_worker;0;10.77.227.146:8470'

In [1]:
%matplotlib inline
RESULT_IMG_PATH = '/tmp/test_result.png'
from matplotlib.pyplot import imshow
from matplotlib import pyplot as plt
def plot_results(images):
  #inv_norm = transforms.Normalize((-0.1307/0.3081,), (1/0.3081,))

  num_images = images.shape[0]
  fig, axes = plt.subplots(4, 6, figsize=(11, 9))

  for i, ax in enumerate(fig.axes):
    ax.axis('off')
    if i >= num_images:
      continue
    img = images[i]
    #img = inv_norm(img)
    img = img.squeeze() # [1,Y,X] -> [Y,X]    
    ax.imshow(img)
  plt.savefig(RESULT_IMG_PATH, transparent=True)

In [27]:
import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
from torch.optim import Adam

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
import torch.nn.functional as F


In [3]:
# Define Parameters
FLAGS = {}
FLAGS['datadir'] = "/tmp/mnist"
FLAGS['batch_size'] = 64
FLAGS['num_workers'] = 4
FLAGS['learning_rate'] = 0.001
FLAGS['momentum'] = 0.5
FLAGS['num_epochs'] = 100
FLAGS['num_cores'] = 8
FLAGS['log_steps'] = 20
FLAGS['metrics_debug'] = False

In [4]:
def mnist_data():
    compose = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
    out_dir = '{}/dataset'.format(FLAGS['datadir'])
    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)

In [55]:
class DiscriminativeNet(torch.nn.Module):
    
    def __init__(self):
        super(DiscriminativeNet, self).__init__()
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1, out_channels=128, kernel_size=4, 
                stride=2, padding=1, bias=False
            ),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=128, out_channels=256, kernel_size=4,
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_channels=256, out_channels=512, kernel_size=4,
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(
                in_channels=512, out_channels=1024, kernel_size=4,
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.out = nn.Sequential(
            nn.Linear(1024,1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Convolutional layers
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Flatten and apply sigmoid
        x = x.view(-1, 1024)
        x = self.out(x)
        return x

In [59]:
class GenerativeNet(torch.nn.Module):
    
    def __init__(self):
        super(GenerativeNet, self).__init__()
        
        self.linear = torch.nn.Linear(100, 1024*4*4)
        
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=1024, out_channels=512, kernel_size=4,
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=512, out_channels=256, kernel_size=4,
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=256, out_channels=128, kernel_size=4,
                stride=2, padding=1, bias=False
            ),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.conv4 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=128, out_channels=1, kernel_size=4,
                stride=2, padding=1, bias=False
            )
        )
        self.out = torch.nn.Tanh()

    def forward(self, x):
        # Project and reshape
        print("Input",x.size())

        x = self.linear(x)
        x = x.view(x.shape[0], 1024, 4, 4)
        # Convolutional layers
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Apply Tanh
        print("Output",x.size())
        return self.out(x)
    


In [60]:
def init_weights(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('BatchNorm') != -1:
        m.weight.data.normal_(0.00, 0.02)

In [61]:
def real_data_target(size, device):
    '''
    Tensor containing ones, with shape = size
    '''
    data = Variable(torch.ones(size, 1))
    return data.to(device)

def fake_data_target(size, device):
    '''
    Tensor containing zeros, with shape = size
    '''
    data = Variable(torch.zeros(size, 1))
    return data.to(device)

In [62]:
# Noise
def noise(size, device):
    n = Variable(torch.randn(size, 100))
    return n.to(device)

In [63]:
def train_gan(rank):
    torch.manual_seed(1)
    
    if not xm.is_master_ordinal():
        # Barrier: Wait until master is done downloading
        xm.rendezvous('download_only_once')
    # Dataset
    data = mnist_data()
    if xm.is_master_ordinal():
        # Master is done, other workers can proceed now
        xm.rendezvous('download_only_once')
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        data,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    

    # Create loader with data, so that we can iterate over it
    #train_loader = torch.utils.data.DataLoader(data, batch_size=100, shuffle=True)
    train_loader = torch.utils.data.DataLoader(
      data,
      batch_size=FLAGS['batch_size'],
      sampler=train_sampler,
      num_workers=FLAGS['num_workers'],
      drop_last=True)

    # Num batches
    num_batches = len(train_loader)
    
    device = xm.xla_device()
    
    generator = GenerativeNet().to(device)
    #generator.apply(init_weights)

    discriminator = DiscriminativeNet().to(device)
    #discriminator.apply(init_weights)
    
    # Optimizers
    d_optimizer = Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    g_optimizer = Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # Number of epochs
    num_epochs = FLAGS['num_epochs'] 
    # Loss function
    loss = nn.BCELoss()
    
    num_test_samples = 16
    test_noise = noise(num_test_samples, device)
    
    def train_step_discriminator(optimizer, real_data, fake_data, device):
        # Reset gradients
        optimizer.zero_grad()

        # 1. Train on Real Data
        prediction_real = discriminator(real_data)
        # Calculate error and backpropagate
        error_real = loss(prediction_real, real_data_target(real_data.size(0), device))
        error_real.backward()

        # 2. Train on Fake Data
        prediction_fake = discriminator(fake_data)
        # Calculate error and backpropagate
        error_fake = loss(prediction_fake, fake_data_target(real_data.size(0), device))
        error_fake.backward()

        # Update weights with gradients
        xm.optimizer_step(optimizer)

        return error_real + error_fake, prediction_real, prediction_fake
        #return (0, 0, 0)

    def train_step_generator(optimizer, fake_data, device):
        # Reset gradients
        optimizer.zero_grad()
        # Sample noise and generate fake data
        prediction = discriminator(fake_data)
        # Calculate error and backpropagate
        error = loss(prediction, real_data_target(prediction.size(0), device))
        error.backward()
        # Update weights with gradients
        xm.optimizer_step(optimizer)

        # Return error
        return error


    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        for n_batch, (real_batch,_) in enumerate(loader):
            # Train Step Descriminator
            real_data = Variable(real_batch).to(device)
            fake_data = generator(noise(real_data.size(0), device)).detach()
            d_error, d_pred_real, d_pred_fake = train_step_discriminator(d_optimizer,
                                                                real_data, fake_data, device)
            #Train Step Generator
            fake_data = generator(noise(real_batch.size(0), device))
            g_error = train_step_generator(g_optimizer, fake_data, device)
        print(f'D_ERROR: {d_error}, G_ERROR: {g_error}')
        return d_error, g_error


            # Display Test Images
            # Save Model Checkpoints

    for epoch in range(1, FLAGS['num_epochs'] +1):
        para_loader = pl.ParallelLoader(train_loader, [device])
        d_error, g_error = train_loop_fn (para_loader.per_device_loader(device))
        xm.master_print("Finished training epoch {}: D_error:{}, G_error".format(epoch, d_error, g_error))
        if rank == 0 :
            # Retrieve tensors that are on TPU core 0 and plot.
            plot_results(vectors_to_images(generator(test_noise).detach()).cpu())

# Start training processes
def _mp_fn(rank, flags):
    global FLAGS
    FLAGS = flags
    torch.set_default_tensor_type('torch.FloatTensor')
    train_gan(rank)
    #if rank == 0:
      # Retrieve tensors that are on TPU core 0 and plot.
      # plot_results(images.cpu())

xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'],
          start_method='fork')



Input torch.Size([64, 100])
Output torch.Size([64, 1, 64, 64])
torch.Size([64, 1])
torch.Size([64, 1]) torch.Size([64, 1])
Input torch.Size([64, 100])
Output torch.Size([64, 1, 64, 64])
torch.Size([64, 1])
torch.Size([1024, 1])


  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


torch.Size([64, 1]) torch.Size([64, 1])
Input torch.Size([64, 100])


Exception in device=TPU:0: Target and input must have the same number of elements. target nelement (64) != input nelement (1024)


Output torch.Size([64, 1, 64, 64])
torch.Size([1024, 1])


Traceback (most recent call last):


torch.Size([64, 1])


  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  File "/home/sivaibhav/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 119, in _start_fn
    fn(gindex, *args)
Exception in device=TPU:7: Target and input must have the same number of elements. target nelement (64) != input nelement (1024)
  File "<ipython-input-63-67079464aba3>", line 120, in _mp_fn
    train_gan(rank)
Traceback (most recent call last):
  File "<ipython-input-63-67079464aba3>", line 109, in train_gan
    d_error, g_error = train_loop_fn (para_loader.per_device_loader(device))
  File "/home/sivaibhav/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 119, in _start_fn
    fn(gindex, *args)


torch.Size([64, 1]) torch.Size([64, 1])


  File "<ipython-input-63-67079464aba3>", line 96, in train_loop_fn
    real_data, fake_data, device)
  File "<ipython-input-63-67079464aba3>", line 120, in _mp_fn
    train_gan(rank)


torch.Size([1024, 1])


  File "<ipython-input-63-67079464aba3>", line 65, in train_step_discriminator
    error_fake = loss(prediction_fake, fake_data_target(real_data.size(0), device))
  File "<ipython-input-63-67079464aba3>", line 109, in train_gan
    d_error, g_error = train_loop_fn (para_loader.per_device_loader(device))
  File "/home/sivaibhav/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 558, in __call__
    result = self.forward(*input, **kwargs)
  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
Exception in device=TPU:4: Target and input must have the same number of elements. target nelement (64) != input nelement (1024)
  File "/home/sivaibhav/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/loss.py", line 520, in forward
    return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  File "<ipython-input-63-67079464aba3>", line 96, in train_loop_fn
    real_data, fa

Input torch.Size([64, 100])
Output torch.Size([64, 1, 64, 64])
torch.Size([64, 1])
torch.Size([64, 1]) torch.Size([64, 1])


Exception: process 0 terminated with exit code 17

In [None]:
from IPython.display import Image
Image(filename=RESULT_IMG_PATH)