<a href="https://colab.research.google.com/github/viyx/arc-nn/blob/tpu/colab_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!rm -rf arc-nn
!git clone -q https://github.com/viyx/arc-nn -b tpu
%cd arc-nn
!git log --pretty=oneline
!wget -q https://storage.googleapis.com/viy_data/pickle/median/ds_test_median.pickle
!wget -q https://storage.googleapis.com/viy_data/pickle/median/ds_train_median.pickle

In [None]:
import pickle
from datasets import ARCDataset, ColorPermutation, GPTDataset
from mingpt.utils import set_seed

# flags = {'seed': 19}

# set_seed(flags['seed'])

In [None]:
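# Create a mini dataset for testing (disabled: map_fn loads the pickled datasets downloaded above)

# ds = ARCDataset(data_folder='./data')  # the working directory is already arc-nn after %cd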
# perm = ColorPermutation(max_colors=10, max_permutations=100000)
# ds_train = ARCDataset(tasks=ds.tasks[:10], augs=[perm])
# ds_test = ARCDataset(tasks=ds.tasks[10:39])

# with open('train_dataset.pickle', 'wb') as f:
#   pickle.dump(ds_train, f)
    
# with open('test_dataset.pickle', 'wb') as f:
#   pickle.dump(ds_test, f)
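The commented-out cell above augments the training tasks with `ColorPermutation(max_colors=10, max_permutations=100000)`. ARC grids use 10 colors, so there are 10! = 3,628,800 possible relabelings; `max_permutations=100000` presumably caps how many of them are used. The plain-Python sketch below illustrates the idea under one assumption: a single random bijection of the colors is applied consistently to every input and output grid of a task, so the transformation rule survives the relabeling. The helper names are hypothetical, and this is not the repo's `ColorPermutation` code (which may, for example, keep the background color fixed).

```python
import math
import random

# Hypothetical helpers illustrating color-permutation augmentation;
# this is NOT the repo's ColorPermutation implementation.
N_COLORS = 10  # ARC grids use color indices 0..9


def random_color_permutation(n_colors=N_COLORS, rng=None):
    """Return a list `perm` where color c is relabeled as perm[c]."""
    if rng is None:
        rng = random.Random()
    perm = list(range(n_colors))
    rng.shuffle(perm)
    return perm


def apply_permutation(grid, perm):
    """Relabel every cell of a 2-D grid (a list of lists of ints)."""
    return [[perm[c] for c in row] for row in grid]


def permute_task(pairs, perm):
    """Apply the SAME relabeling to the input and output of every pair,
    so the input-to-output rule of the task is preserved."""
    return [(apply_permutation(x, perm), apply_permutation(y, perm)) for x, y in pairs]


# Number of distinct relabelings of 10 colors
print(math.factorial(N_COLORS))  # 3628800
```

Drawing a fresh `perm` per task (or per epoch) is what turns the few hundred ARC training tasks into many distinct training sequences.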

In [None]:
import os
assert 'COLAB_TPU_ADDR' in os.environ, 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

### Installing PyTorch/XLA

Run the following cell (or copy it into your own notebook) to install the Cloud TPU client and PyTorch/XLA 1.6. It takes a couple of minutes to run.

In [None]:
# Installs the Cloud TPU client and the PyTorch/XLA 1.6 wheel (built for Python 3.6)
# Copy this cell into your own notebooks to use PyTorch on Cloud TPUs
# Warning: this may take a couple of minutes to run
!pip install -q cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.6-cp36-cp36m-linux_x86_64.whl

In [None]:
import torch
import torch_xla.distributed.parallel_loader as pl
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import time
from tqdm import tqdm
from mingpt.model import GPT, GPTConfig
from mingpt.trainer import TrainerConfig

def map_fn(index, flags):
  ## Setup 

  # Sets a common random seed - both for initialization and ensuring graph is the same
  set_seed(flags['seed'])

  # Acquires the (unique) Cloud TPU core corresponding to this process's index
  device = xm.xla_device()  

  if not xm.is_master_ordinal():
    xm.rendezvous('download_only_once')

  with open('ds_train_median.pickle', 'rb') as f:
    train_dataset = pickle.load(f)
        
  with open('ds_test_median.pickle', 'rb') as f:
    test_dataset = pickle.load(f)

  train_dataset = GPTDataset(train_dataset, n_colors=10, n_context=2048, padding=True)
  test_dataset = GPTDataset(test_dataset,  n_colors=10, n_context=2048, padding=True)

  print(len(train_dataset), len(test_dataset))
  mconf = GPTConfig(train_dataset.vocab_size, block_size=train_dataset.n_context,
                    masked_length = 30 ** 2 + 30 + 1, padding_idx=13,
                    embd_pdrop=0.0, resid_pdrop=0.1, attn_pdrop=0.1,
                    n_layer=2, n_head=8, n_embd=8)


  if xm.is_master_ordinal():
    xm.rendezvous('download_only_once')
  
  # Creates the (distributed) train sampler, which let this process only access
  # its portion of the training dataset.
  train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)
  
  test_sampler = torch.utils.data.distributed.DistributedSampler(
    test_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=False)
  
  # Creates dataloaders, which load data in batches
  # Note: the test loader is not shuffled, but its sampler still restricts each process to its own shard
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=flags['batch_size'],
      sampler=train_sampler,
      num_workers=flags['num_workers'],
      drop_last=True)

  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=flags['batch_size'],
      sampler=test_sampler,
      shuffle=False,
      num_workers=flags['num_workers'],
      drop_last=True)
  

  model = GPT(mconf).to(device).train()

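  # Rough token count per epoch, assuming ~576 tokens per example (the mean x length)
  tokens_per_epoch = len(train_dataset) * 576
  train_epochs = flags['num_epochs']

  # Only the optimizer settings in tconf matter here: the loop below does not
  # apply lr_decay, warmup_tokens or final_tokens
  tconf = TrainerConfig(max_epochs=train_epochs, batch_size=flags['batch_size'],
                        learning_rate=3e-3, betas=(0.9, 0.95), weight_decay=0,
                        lr_decay=True, warmup_tokens=tokens_per_epoch,
                        final_tokens=train_epochs * tokens_per_epoch,
                        ckpt_path='model.pt',
                        num_workers=flags['num_workers'], early_stopping=1000)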

  optimizer = model.configure_optimizers(tconf)

  ## Trains
  train_start = time.time()
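  for epoch in range(flags['num_epochs']):
    # Reshuffles the training shards differently on each epoch
    train_sampler.set_epoch(epoch)
    para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
    # One batch per step per process; the progress bar is shown on the master process only
    pbar = tqdm(enumerate(para_train_loader), total=len(train_loader),
                disable=not xm.is_master_ordinal())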

    for it, batch in pbar:
      x, y = batch

      logits, loss = model(x, y)
      loss = loss.mean()

      # Updates model
      optimizer.zero_grad()
      loss.backward()

      xm.optimizer_step(optimizer)  # Note: barrier=True not needed when using ParallelLoader 
      pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}.")

  elapsed_train_time = time.time() - train_start
  print("Process", index, "finished training. Train time was:", elapsed_train_time) 


  ## Evaluation
  # Sets the model to eval mode and disables gradient tracking
  model.eval()
  eval_start = time.time()
  total_loss, n_batches = 0.0, 0
  with torch.no_grad():
    para_test_loader = pl.ParallelLoader(test_loader, [device]).per_device_loader(device)
    for x, y in para_test_loader:
      logits, loss = model(x, y)
      # Accumulates on the device; calling .item() here would force a TPU sync every batch
      total_loss = total_loss + loss.mean()
      n_batches += 1

  elapsed_eval_time = time.time() - eval_start
  mean_loss = float(total_loss) / n_batches if n_batches else float('nan')
  print("Process", index, "finished evaluation. Evaluation time was:", elapsed_eval_time)
  print("Process", index, "evaluated", n_batches, "batches; mean test loss:", mean_loss)

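Each of the eight processes builds its own `DistributedSampler` with `num_replicas=xm.xrt_world_size()` and `rank=xm.get_ordinal()`. Roughly, the sampler pads the index list by repeating indices until it divides evenly, then gives rank r every `num_replicas`-th index starting at r. The plain-Python sketch below mimics that partitioning; it uses `random.Random(seed + epoch)` in place of PyTorch's seeded `torch.randperm`, so the shuffled order will not match PyTorch's, only the sharding scheme. The `epoch` argument plays the role of `DistributedSampler.set_epoch`.

```python
import math
import random

# Plain-Python approximation of DistributedSampler's sharding (not PyTorch's code).
def shard_indices(dataset_len, num_replicas, rank, shuffle=True, seed=0, epoch=0):
    """Return the dataset indices that process `rank` would iterate over."""
    indices = list(range(dataset_len))
    if shuffle:
        # Same permutation on every rank; changing the epoch changes the order
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating indices so every rank gets the same number of samples
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]
    # Rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(rank, shard_indices(10, num_replicas=4, rank=rank, shuffle=False))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

The padding is why ranks 2 and 3 see indices 0 and 1 a second time: equal shard sizes keep all processes running the same number of steps.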
In [None]:
# Configures training (and evaluation) parameters
flags = {}
flags['batch_size'] = 8
flags['num_workers'] = 8
flags['num_epochs'] = 1
flags['seed'] = 1234

xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')
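With these flags every core loads `batch_size=8` examples per step, and `xm.optimizer_step` reduces gradients across all 8 processes, so one synchronized optimizer step consumes 8 × 8 = 64 examples. Combining the sampler's per-process share, `ceil(N / 8)`, with `drop_last=True` gives the number of training steps per epoch. The helpers below (hypothetical names) do that arithmetic for an illustrative dataset size; the real `len(train_dataset)` is printed by `map_fn`.

```python
import math

def global_batch_size(world_size=8, per_core_batch=8):
    """Examples consumed by one synchronized optimizer step across all cores."""
    return world_size * per_core_batch


def steps_per_epoch(dataset_len, world_size=8, per_core_batch=8):
    """Training steps each process runs per epoch (DistributedSampler + drop_last=True)."""
    per_process_samples = math.ceil(dataset_len / world_size)
    return per_process_samples // per_core_batch


# Hypothetical dataset size, for illustration only
print(global_batch_size())      # 64
print(steps_per_epoch(10_000))  # 156
```

Spawning more processes with the same per-core `batch_size` grows the effective batch proportionally, which is worth keeping in mind when retuning the learning rate.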

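`TrainerConfig` is given `lr_decay=True`, `warmup_tokens` and `final_tokens`, but the custom loop in `map_fn` never changes the learning rate, so training runs at a constant 3e-3. If a schedule is wanted, the sketch below reproduces the one in Karpathy's original minGPT trainer (linear warmup over `warmup_tokens`, then cosine decay to 10% of the base rate at `final_tokens`); it assumes this repo's `mingpt.trainer` keeps that schedule, and it clamps progress at 1 so the rate stays at the floor after `final_tokens`, a small deviation from the original. Note that with a single epoch, `warmup_tokens` equals `final_tokens`, so the whole run would be spent in warmup.

```python
import math

# Sketch of the original minGPT warmup + cosine schedule
# (assumed to match this repo's trainer; not verified against it).
def mingpt_lr(tokens_seen, base_lr, warmup_tokens, final_tokens):
    """Learning rate after `tokens_seen` training tokens."""
    if tokens_seen < warmup_tokens:
        # Linear warmup from 0 to base_lr
        lr_mult = tokens_seen / max(1, warmup_tokens)
    else:
        # Cosine decay from base_lr, floored at 10% of base_lr
        progress = (tokens_seen - warmup_tokens) / max(1, final_tokens - warmup_tokens)
        progress = min(progress, 1.0)  # clamp: stay at the floor after final_tokens
        lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return base_lr * lr_mult

# In a training loop, after adding the batch's target-token count to tokens_seen,
# the new rate would be written into every parameter group:
#     for group in optimizer.param_groups:
#         group['lr'] = mingpt_lr(tokens_seen, 3e-3, warmup_tokens, final_tokens)
```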
## What's Next?

This notebook trains a small minGPT model on the ARC dataset using an entire Cloud TPU (one process per core, eight cores). Its structure follows the PyTorch/XLA example that trains AlexNet on Fashion MNIST across all cores; a [previous notebook](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/single-core-alexnet-fashion-mnist.ipynb) in that series trains AlexNet on Fashion MNIST using only a single Cloud TPU core and can be a helpful point of comparison.

In particular, this notebook showed us how to:

- Define a "map function" that runs in parallel on one process per Cloud TPU core. 
- Run the map function using `spawn`.
- Understand the Cloud TPU context: its benefits, such as automatic cross-process coordination, and its limits, such as requiring every process to perform the same Cloud TPU operations.
- Load and sample the datasets.
- Train and evaluate the network.
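Each process evaluates only its own shard of the test set, so each reports its own figure, and the sampler's padding means a few test examples may be counted twice. One way to report a single number is to gather per-process `(loss_sum, num_batches)` pairs and weight them by batch count; torch_xla's `xm.mesh_reduce(tag, data, reduce_fn)` can collect such values from all processes and pass the list to `reduce_fn`. The function below is a plain-Python sketch with a hypothetical name, and the numbers in the example are illustrative, not real results.

```python
# Hypothetical reduce function for per-process evaluation results.
def combine_eval_results(results):
    """Merge per-process (loss_sum, num_batches) pairs into one mean loss.

    Weighting by batch count keeps processes with fewer batches from
    skewing the average, unlike a plain mean of per-process means.
    """
    total_loss = sum(loss_sum for loss_sum, _ in results)
    total_batches = sum(n for _, n in results)
    return total_loss / total_batches if total_batches else float('nan')


# Illustrative values for two processes (not real results):
# the weighted mean is 2.5, whereas averaging the means 2.0 and 4.0 would give 3.0
print(combine_eval_results([(6.0, 3), (4.0, 1)]))  # 2.5
```

Inside `map_fn` this could be called as `xm.mesh_reduce('eval_loss', (loss_sum, num_batches), combine_eval_results)` once each process has finished its evaluation loop.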

Additional notebooks demonstrating how to run PyTorch on Cloud TPUs can be found [here](https://github.com/pytorch/xla). While Colab provides a free Cloud TPU, training is even faster on [Google Cloud Platform](https://github.com/pytorch/xla#Cloud), especially when using multiple Cloud TPUs in a Cloud TPU pod. Scaling from a single Cloud TPU, like in this notebook, to many Cloud TPUs in a pod is easy, too. You use the same code as this notebook and just spawn more processes.

In [None]:
!ls -lh