<a href="https://colab.research.google.com/github/rlandingin/fastai_xla_extensions/blob/multi-core-impl/nbs/03_multi_core.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#default_exp multi_core

# Multi Core XLA extensions

## Setup torch XLA


This is the official way to install PyTorch-XLA 1.7 (see the [instructions here](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb#scrollTo=CHzziBW5AoZH)).

In [2]:
#colab
!pip install -Uqq cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl

[K     |████████████████████████████████| 133.6MB 31kB/s 
[K     |████████████████████████████████| 61kB 3.6MB/s 
[?25h

## Install fastai

Use latest fastai and fastcore versions

In [3]:
#colab
!pip install -Uqq git+https://github.com/fastai/fastai.git 

[?25l[K     |██████▏                         | 10kB 20.3MB/s eta 0:00:01[K     |████████████▍                   | 20kB 12.0MB/s eta 0:00:01[K     |██████████████████▌             | 30kB 9.9MB/s eta 0:00:01[K     |████████████████████████▊       | 40kB 8.4MB/s eta 0:00:01[K     |██████████████████████████████▉ | 51kB 5.1MB/s eta 0:00:01[K     |████████████████████████████████| 61kB 3.5MB/s 
[?25h  Building wheel for fastai (setup.py) ... [?25l[?25hdone


In [4]:
#hide
#colab
!curl -s https://course19.fast.ai/setup/colab | bash

Updating fastai...
Done.


In [5]:
#hide
!pip freeze | grep torch
!pip freeze | grep fast

torch==1.7.0+cu101
torch-xla==1.7
torchsummary==1.5.1
torchtext==0.3.1
torchvision==0.8.1+cu101
fastai==2.2.5
fastcore==1.3.18
fastdtw==0.3.4
fastprogress==1.0.0
fastrlock==0.5


## Patching BaseOptimizer to be Picklable
Patch the `__getstate__` and `__setstate__` methods of `_BaseOptimizer`, which are used when pickling the optimizer.
This should fix the bug seen when running the learner on multiple TPU cores in XLA, where `def _fetch_gradients(optimizer)` fails
at `for param_group in optimizer.__getstate__()['param_groups']:`. The patch fixes the "copy constructor" so that it includes the `param_groups`.

In [6]:
#export
from fastcore.basics import patch_to
from fastai.optimizer import _BaseOptimizer

@patch_to(_BaseOptimizer)
def __getstate__(self):
    "Return the optimizer state to pickle: `state`, `param_groups` and, when the optimizer has them, `defaults`"
    pickled = dict(state=self.state_dict(), param_groups=self.param_groups)
    # `defaults` is optional: only export it when this optimizer actually carries it
    if hasattr(self, 'defaults'):
        pickled['defaults'] = self.defaults
    return pickled

@patch_to(_BaseOptimizer)
def __setstate__(self, data):
    "Restore the optimizer from a dict produced by `__getstate__` (`defaults` is optional)"
    has_defaults = 'defaults' in data
    if has_defaults:
        self.defaults = data['defaults']
    # the state dict is loaded first, then the param groups are assigned
    saved_state = data['state']
    self.load_state_dict(saved_state)
    saved_groups = data['param_groups']
    self.param_groups = saved_groups

In [7]:
#exporti
import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu
import torchvision
from torchvision import datasets, transforms
import torch.utils.data as th_data
from fastcore.foundation import L
from pathlib import Path
from fastcore.xtras import *
from fastcore.transform import Pipeline
from fastai.data.core import DataLoaders
from functools import partial
import torch.utils.data.distributed as torch_distrib

from pathlib import Path
import fastcore.xtras
import math
from fastcore.basics import store_attr
from operator import attrgetter
from fastai.data.load import _FakeLoader
from fastai.data.core import TfmdDL
from fastai.torch_core import find_bs, TensorBase
import random
import torch
from fastai.data.load import _loaders
from fastai.torch_core import to_device
from fastcore.basics import first


In [18]:
#export
def _recast2tensor(o):
    "Turn a fastai `TensorBase` into a plain `torch.Tensor`; anything else is returned untouched"
    if not isinstance(o, TensorBase):
        return o
    # pl.parallelloader doesn't seem to work with tensor subclasses,
    # so rebuild a plain tensor from the numpy view of the data
    as_array = o.numpy()
    return torch.tensor(as_array)

def _round_to_multiple(number,multiple): 
    return int(math.ceil(number/multiple)*multiple)

class TPUDistributedDL(TfmdDL):
    """A `TfmdDL` wrapper that gives each TPU core an equal-sized share of the wrapped loader's indices.

    `get_idxs` pads the index list of the wrapped `dl` (by repeating indices) to a
    multiple of `world_size` and returns the contiguous slice that belongs to `rank`.
    Batch creation and the before/after hooks are forwarded to the wrapped `dl`.
    """
    # attribute lookups that fail on this object fall back to `self.dl`
    # (presumably through fastcore's `GetAttr` machinery inherited from `TfmdDL` -- TODO confirm)
    _default = 'dl'
    def __init__(self,dl,rank,world_size, seed=42):
        "Wrap `dl` for the process `rank` out of `world_size`; `seed` drives the shuffling RNG"
        # presumably stores dl, rank, world_size and seed as attributes (they are read as self.* below)
        store_attr()
        # mirror the wrapped loader's settings on this object
        self.bs,self.device,self.num_workers,self.drop_last,self.dataset,self.offs,fake = \
            attrgetter('bs','device','num_workers','drop_last','dataset','offs','fake_l')(dl)
        # build our own fake loader, configured like the wrapped loader's one
        self.fake_l = _FakeLoader(self, fake.pin_memory, fake.num_workers, fake.timeout, 
                                  persistent_workers=fake.persistent_workers)
        self.epoch = 0
        # NOTE(review): this reseeds the process-global `random` module, not a private generator
        random.seed(self.seed)
        # the wrapped loader gets a generator seeded from the first draw after seeding with `seed`
        self.dl.rng = random.Random(random.randint(0,2**32-1))
        self.reset_rng()

    def reset_rng(self):
        "Recreate `self.rng` from `seed + epoch`"
        # NOTE(review): also reseeds the process-global `random` module
        random.seed(self.seed + self.epoch)
        self.rng = random.Random(random.randint(0,2**32-1))

    def __len__(self): 
        "Length per rank: `len(self.dl)` rounded up to a multiple of `world_size`, divided by `world_size`"
        return _round_to_multiple(len(self.dl),self.world_size)//self.world_size

    def set_epoch(self, epoch):
        "Record the current epoch; it is used by `reset_rng` to derive the seed"
        self.epoch = epoch

    def get_idxs(self):
        "Return the slice of (optionally shuffled, padded) indices assigned to this rank"
        idxs = self.dl.get_idxs()
        # do your own shuffling which factors in self.epoch + self.seed in
        # generating a random sequence (underlying self.dl does not)
        # (`shuffle` / `shuffle_fn` are not defined in this class -- presumably inherited or delegated)
        if self.shuffle: 
            idxs = self.shuffle_fn(idxs)
        self.n = len(idxs)              
        # we assumed n was dl.n but we really care about number of idxs
        # add extra samples to make it evenly divisible
        self.n_padded = _round_to_multiple(self.n,self.world_size)
        # NOTE(review): divides by self.n, so an empty index list raises ZeroDivisionError
        idxs += (idxs * (self.n_padded//self.n))[:self.n_padded-self.n] 
        # idx needs to be repeated when n_padded>>n
        # slice padded idxs so that each rank gets self.n_padded//self.world_size tensors
        start_pos = self.rank*self.n_padded//self.world_size
        end_pos = (self.rank+1)*self.n_padded//self.world_size
        return idxs[start_pos:end_pos]

    def before_iter(self):
        "Forward to the wrapped loader's `before_iter`"
        self.dl.before_iter()

    def randomize(self): 
        "Reset this loader's RNG, then let the wrapped loader randomize itself"
        self.reset_rng()
        self.dl.randomize()

    def after_batch(self,b):
        "Run the wrapped loader's `after_batch`, then return the batch as a tuple of plain tensors"
        b = self.dl.after_batch(b)
        # recast tensor subclasses to plain tensors
        # undoing work of self.retain()
        tb = [_recast2tensor(o) for o in b]
        b = tuple(tb)
        return b

    def after_iter(self): 
        "Forward to the wrapped loader's `after_iter`"
        self.dl.after_iter()

    def create_batches(self,samps): 
        "Forward batch creation for `samps` to the wrapped loader"
        return self.dl.create_batches(samps)

    def to(self, device):
        "Set `device` on both the wrapped loader and this loader; returns `self`"
        self.dl.device = device
        self.device = device
        return self



In [19]:
#export
def make_distributed_dataloaders(dls, rank, world_size):
    "Wrap every loader of `dls` in a `TPUDistributedDL` and return them as a new `DataLoaders`"
    def _shard_for(position):
        # Only the first loader is split across ranks.
        # for now, in validation, use all samples since only rank 0 computes
        # the valid loss and metrics
        # TODO: figure out a way to consolidate valid loss and metrics across
        # ranks to make distribute batches across multi cores (and reduce number
        # of batches per rank -- which should speed up validation)
        return (rank, world_size) if position == 0 else (0, 1)

    wrapped = []
    for position, loader in enumerate(dls.loaders):
        use_rank, use_size = _shard_for(position)
        wrapped.append(TPUDistributedDL(loader, rank=use_rank, world_size=use_size))
    return DataLoaders(*wrapped, path=dls.path, device=dls.device)

In [20]:
#exporti
from fastai.torch_core import default_device, apply
import torch 
from fastcore.xtras import is_listy
import torch
import torch.utils.hooks
from fastcore.basics import patch
from fastai.torch_core import TensorBase
from collections import OrderedDict

In [22]:
#export
def wrap_parallel_loader(loader, device):
    "Wrap `loader` in a `pl.ParallelLoader` for `device` and return that device's per-device loader"
    return pl.ParallelLoader(loader, [device]).per_device_loader(device)

In [23]:
#exporti
from fastai.callback.core import TrainEvalCallback
from fastai.learner import Recorder
from fastai.torch_core import one_param
import torch
from fastai.callback.core import Callback
from fastai.learner import CancelTrainException, CancelValidException, CancelStepException
from fastai.torch_core import tensor, TensorCategory

In [27]:
#export
class XLATrainingCallback(Callback):
    "Training-loop callback for XLA devices; installed by `to_xla` in place of `TrainEvalCallback`"
    order = -10 # same as TrainEvalCallback (since this replaces TrainEvalCallback)
    run_before = Recorder
    run_valid = False

    def __init__(self, device, rank=0):
        self.pdevice, self.rank = device, rank

    def after_create(self):
        "Give the learner an initial `n_epoch` of 1"
        learner = self.learn
        learner.n_epoch = 1

    def before_fit(self):
        "Set the iter and epoch counters to 0, put the model on the right device"
        learner = self.learn
        learner.epoch = 0
        learner.loss = tensor(0.)
        learner.train_iter = 0
        learner.pct_train = 0.
        if hasattr(self.dls, 'device'):
            self.model.to(self.dls.device)
        if hasattr(self.model, 'reset'):
            self.model.reset()
        xm.master_print(' ')

    def before_epoch(self):
        "Tell the training loader (or its sampler) which epoch is about to start"
        # set the epoch on train only to make sure shuffle produces same seq
        # across all ranks
        train_dl = self.learn.dls.train
        if hasattr(train_dl, 'sampler'):
            # a loader with a sampler is only updated through that sampler
            if hasattr(train_dl.sampler, 'set_epoch'):
                train_dl.sampler.set_epoch(self.learn.epoch)
        elif hasattr(train_dl, 'set_epoch'):
            train_dl.set_epoch(self.learn.epoch)

    def before_train(self):
        "Set the model in training mode"
        learner = self.learn
        learner.pct_train = self.epoch/self.n_epoch
        self.model.train()
        learner.training = True
        learner.dl = wrap_parallel_loader(self.dls.train, self.pdevice)

    def before_validate(self):
        "Set the model in validation mode"
        is_master = self.rank == 0
        if not is_master: # no need to compute valid loss/ metric if not master
            raise CancelValidException()
        learner = self.learn
        self.model.eval()
        learner.training = False
        learner.dl = wrap_parallel_loader(self.dls.valid, self.pdevice)

    def before_step(self):
        "Cancel the regular step; the optimizer is stepped in `after_cancel_step` instead"
        raise CancelStepException()

    def after_cancel_step(self):
        "Step the optimizer through `xm.optimizer_step`"
        xm.optimizer_step(self.learn.opt)

    def after_batch(self):
        "Update the iter counter (in training mode)"
        learner = self.learn
        learner.pct_train += 1./(self.n_iter*self.n_epoch)
        learner.train_iter += 1


In [28]:
#exporti
from fastcore.imports import noop
from fastcore.basics import patch
from fastai.learner import Learner
from fastai.callback.progress import ProgressCallback
from fastcore.xtras import join_path_file
from fastai.torch_core import get_model


In [None]:
#export

@patch
def save(self:Learner, file, **kwargs):
    """Save the model (and optionally the optimizer state) to `self.path/self.model_dir/file`.

    Uses `xm.save` instead of `torch.save`. The `with_opt` keyword of the stock
    `Learner.save` is honoured (default `True`); the optimizer state is only
    included when an optimizer exists. Other keyword arguments are accepted for
    signature compatibility and ignored. Returns the path of the saved file.
    """
    file = join_path_file(file, self.path/self.model_dir, ext='.pth')
    # respect an explicit `with_opt=False`; never try to save a missing optimizer
    with_opt = kwargs.get('with_opt', True) and self.opt is not None
    state = self.model.state_dict()
    if with_opt:
        # same {'model', 'opt'} layout that fastai expects when loading
        state = {'model': state, 'opt': self.opt.state_dict()}
    xm.save(state, file) # use xm.save instead of torch.save
    return file


In [30]:
#export 

@patch
def to_xla(self:Learner,device, rank):
    "Prepare the learner for XLA training on `device` as process `rank`"
    already_attached = 'xla_training' in self.cbs.attrgot('name')
    if already_attached:
        # reuse the existing callback, just point it at the new device/rank
        xla_cb = self.xla_training
        xla_cb.pdevice = device
        xla_cb.rank = rank
    else:
        self.dls.device = None
        self.add_cbs(XLATrainingCallback(device, rank))

    # replace TrainEval with XLATraining
    self.remove_cbs(TrainEvalCallback)

    # the progress callback is only kept on rank 0
    is_master = rank == 0
    if not is_master:
        self.remove_cbs(ProgressCallback)
    self.logger = xm.master_print

In [31]:
#export

# def DataBlock.dataloaders(self, source, path='.', verbose=False, **kwargs):
def build_dataloaders(datablock, source, rank, world_size, device=None, path='.', verbose=False,**kwargs):
    """Build `DataLoaders` from `datablock` and wrap them for distributed TPU use.

    `source`, `path`, `device`, `verbose` and any extra `kwargs` are forwarded to
    `datablock.dataloaders`; the result is then passed, together with `rank` and
    `world_size`, to `make_distributed_dataloaders`, whose return value is returned.
    """
    # forward `verbose` too -- it was previously accepted but silently dropped
    dls = datablock.dataloaders(source=source, path=path, device=device, verbose=verbose, **kwargs)
    return make_distributed_dataloaders(dls, rank, world_size)


In [32]:
#exporti
from fastcore.basics import store_attr

In [33]:
#export
class ExtendedModel:
    "Plain holder for the settings a CNN model was built with: `arch`, `normalize`, `n_out`, `pretrained`"
    def __init__(self, arch, normalize, n_out, pretrained):
        # explicit assignments (store_attr() is deliberately not used here)
        self.arch, self.normalize = arch, normalize
        self.n_out, self.pretrained = n_out, pretrained

In [34]:
#exporti
from fastai.data.transforms import get_c
from fastai.vision.learner import create_cnn_model

In [35]:
#export
def xla_cnn_model(arch,
                  n_out,
                  normalize=True,
                  pretrained=True,
                **kwargs):
    "Create a CNN model for `arch` without concat pooling; returns `(ExtendedModel, model)`"
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    # concat_pool is always forced to False because AdaptiveConcatPool not supported in XLA,
    # so any caller-supplied value is discarded
    cnn_kwargs = {key: value for key, value in kwargs.items() if key != 'concat_pool'}
    model = create_cnn_model(arch, n_out, pretrained=pretrained, concat_pool=False, **cnn_kwargs)
    return ExtendedModel(arch, normalize, n_out, pretrained), model
    

In [36]:
#exporti
from fastai.optimizer import Adam
from fastai.learner import defaults
from fastai.vision.learner import model_meta, _add_norm, _default_meta
from fastcore.basics import ifnone

In [37]:
#export
def xla_cnn_learner(dls, 
                    ext_model, 
                    model,
                    loss_func=None, 
                    opt_func=Adam, 
                    lr=defaults.lr, 
                    splitter=None, 
                    cbs=None, 
                    metrics=None, 
                    path=None,
                    model_dir='models', 
                    wd=None, 
                    wd_bn_bias=False, 
                    train_bn=True, 
                    moms=(0.95,0.85,0.95),
                    # other model args
                    **kwargs):
    "Build a convnet style learner from `dls`, the settings in `ext_model` and an already created `model`"
    arch_meta = model_meta.get(ext_model.arch, _default_meta)
    if ext_model.normalize:
        _add_norm(dls, arch_meta, ext_model.pretrained)

    assert ext_model.n_out is not None, "`n_out` is not defined please pass `n_out`"

    # fall back to the architecture's own splitter when none is given
    learner_kwargs = dict(
        dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr,
        splitter=ifnone(splitter, arch_meta['split']), cbs=cbs, metrics=metrics,
        path=path, model_dir=model_dir, wd=wd, wd_bn_bias=wd_bn_bias,
        train_bn=train_bn, moms=moms,
    )
    learn = Learner(**learner_kwargs)
    if ext_model.pretrained:
        learn.freeze()

    # keep track of args for loggers
    for attr_name in ('arch', 'normalize', 'n_out', 'pretrained'):
        setattr(learn, attr_name, getattr(ext_model, attr_name))
    return learn


## Test out the code


In [38]:
#hide
from functools import partial
from fastai.metrics import accuracy
from fastai.optimizer import SGD, Adam

from fastcore.basics import first
from fastai.callback.schedule import *
from fastai.test_utils import VerboseCallback


In [42]:
def train_learner(rank):
    "Per-process training entry: build the data and learner on this process's XLA device, fit, then save"
    torch.manual_seed(1)

    # Scale learning rate to num cores
    scaled_lr = FLAGS['learning_rate'] * xm.xrt_world_size()

    # Move the shared model onto this process's XLA device
    xla_device = xm.xla_device()
    core_model = WRAPPED_MODEL.to(xla_device)

    # Data loaders sharded for this rank
    dls = build_dataloaders(DATA, PATH, rank, xm.xrt_world_size(), bs=FLAGS['batch_size'])

    momentum = FLAGS['momentum']
    learner = xla_cnn_learner(dls,
                              EXT_MODEL,
                              core_model,
                              loss_func=LOSS_FUNC,
                              opt_func=OPT_FUNC,
                              metrics=accuracy,
                              moms=(momentum, momentum, momentum),
                              wd=5e-4)

    learner.to_xla(xla_device, rank=xm.get_ordinal())

    # train all layers with the one-cycle schedule
    learner.unfreeze()
    learner.fit_one_cycle(FLAGS['num_epochs'], lr_max=slice(scaled_lr/10))

    learner.save('stage-1')


In [43]:
# Start training processes
def _mp_fn(rank, flags):
    "Process entry point for `xmp.spawn`: publish `flags` as the module-level `FLAGS`, then train"
    # rebinding through globals() is equivalent to a `global FLAGS` assignment
    globals()['FLAGS'] = flags
    train_learner(rank)


In [44]:
#hide
import torch
from fastcore.transform import DisplayedTransform, Transform
from fastcore.basics import store_attr
from fastai.vision.core import PILImage, PILBase, image2tensor
from fastai.data.block import TransformBlock

In [45]:
#hide
from fastai.data.transforms import get_c
# from fastai.vision.all import *
from fastai.data.block import DataBlock, CategoryBlock
from fastai.vision.data import ImageBlock
from fastai.data.transforms import get_image_files, parent_label, GrandparentSplitter
from fastai.vision.augment import Resize, aug_transforms
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import Normalize
from fastai.vision.core import imagenet_stats

In [46]:
# Loss function shared by every learner built in this notebook
LOSS_FUNC = nn.CrossEntropyLoss()

In [47]:
# Optimizer factory passed as `opt_func` to the learners below
OPT_FUNC = Adam

In [48]:
# Data pipeline for image classification; used both by the training
# processes (`train_learner`) and for the single-process loaders below.
DATA = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    # label comes from the parent folder (e.g. .../training/3/12933.png -> 3)
    get_y=parent_label,
    # train/valid split by the `training` / `testing` folder names
    splitter=GrandparentSplitter(valid_name='testing', train_name='training'),
    item_tfms=[Resize(28),],
    # no batch-level transforms
    batch_tfms=[]
)

In [49]:
#hide
# Use fastai resnet18
from fastai.vision.learner import create_cnn_model
from fastai.vision.models import resnet18

In [50]:
#hide
from pathlib import Path
from fastcore.xtras import *


In [105]:
import os
# Define Parameters
FLAGS = {
    'batch_size': 128,  # alternative tried: 1024
    'num_workers': 4,
    'learning_rate': 2e-3,
    'momentum': 0.9,
    'num_epochs': 5,
    # 8 cores when TPU_NAME is set in the environment, otherwise 1 (set to 1 to force single core)
    'num_cores': 8 if os.environ.get('TPU_NAME', None) else 1,
}

In [106]:
# Download/extract the dataset and remember where it lives
PATH = untar_data(URLs.MNIST)
# PATH = untar_data(URLs.MNIST_TINY)
# Single-process loaders; used below for `get_c` and for the `mlearner` validation check
mdls = DATA.dataloaders(PATH, bs=FLAGS['batch_size'])

In [107]:
DATA.summary(PATH)

Setting-up type transforms pipelines
Collecting items from /root/.fastai/data/mnist_png
Found 70000 items
2 datasets of sizes 60000,10000
Setting up Pipeline: PILBase.create
Setting up Pipeline: parent_label -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}

Building one sample
  Pipeline: PILBase.create
    starting from
      /root/.fastai/data/mnist_png/training/3/12933.png
    applying PILBase.create gives
      PILImage mode=RGB size=28x28
  Pipeline: parent_label -> Categorize -- {'vocab': None, 'sort': True, 'add_na': False}
    starting from
      /root/.fastai/data/mnist_png/training/3/12933.png
    applying parent_label gives
      3
    applying Categorize -- {'vocab': None, 'sort': True, 'add_na': False} gives
      TensorCategory(3)

Final sample: (PILImage mode=RGB size=28x28, TensorCategory(3))


Collecting items from /root/.fastai/data/mnist_png
Found 70000 items
2 datasets of sizes 60000,10000
Setting up Pipeline: PILBase.create
Setting up Pipeline: paren

In [108]:
# `xla_cnn_model` returns the settings holder and the model itself;
# the number of outputs is taken from the single-process loaders.
EXT_MODEL, custom_model = xla_cnn_model(resnet18,
                                        n_out=get_c(mdls), 
                                        pretrained=True, 
                                        normalize=False)

In [109]:

# custom_model = create_cnn_model(resnet18, get_c(mdls), 
#                                 pretrained=True,
#                                 concat_pool=False)


In [110]:
# Only instantiate model weights once in memory.
# Each training process moves this wrapper to its own device in `train_learner`.
WRAPPED_MODEL = xmp.MpModelWrapper(custom_model)

In [111]:
# NOTE(review): not referenced anywhere else in the visible notebook -- confirm it is still needed
SERIAL_EXEC = xmp.MpSerialExecutor()

In [112]:
#colab
%%time
!rm -f /content/models/stage-1.pth
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS['num_cores'],
          start_method='fork')

 


epoch,train_loss,valid_loss,accuracy,time
0,1.341644,0.330447,0.8989,00:26
1,0.422134,0.05414,0.9822,00:17
2,0.171844,0.040194,0.9859,00:17
3,0.078365,0.036932,0.9875,00:17
4,0.044823,0.037975,0.9877,00:13


CPU times: user 468 ms, sys: 325 ms, total: 793 ms
Wall time: 1min 47s


In [113]:
# mlearner = Learner(mdls, custom_model, 
#                     loss_func=LOSS_FUNC, 
#                     opt_func=OPT_FUNC, 
#                     metrics=accuracy, 
#                     wd=5e-4,
#                     moms=(FLAGS['momentum'],FLAGS['momentum'],FLAGS['momentum']))
# Single-process learner built on the non-distributed loaders (`mdls`),
# used below to load the saved 'stage-1' weights and validate them.
mlearner = xla_cnn_learner(mdls, 
                      EXT_MODEL, 
                      custom_model, 
                      loss_func=LOSS_FUNC, 
                      opt_func=OPT_FUNC, 
                      metrics=accuracy, 
                      moms=(FLAGS['momentum'],FLAGS['momentum'],FLAGS['momentum']),
                      wd=5e-4)


<fastai.learner.Learner at 0x7f1d793143c8>

In [None]:
#colab                    
mlearner.load('stage-1')

In [114]:
mlearner.dls.device

device(type='cpu')

In [115]:
from fastai.torch_core import one_param

In [116]:
one_param(mlearner.model).device

device(type='cpu')

In [117]:
#colab
%%time
valid_metrics = mlearner.validate();print(valid_metrics)

[0.03807949647307396, 0.9876999855041504]
CPU times: user 1min 38s, sys: 3.28 s, total: 1min 41s
Wall time: 5.69 s


In [118]:
# master_device = xm.xla_device()

In [119]:
# mlearner.dls.device = master_device
# mlearner.model.to(master_device)
# mlearner.opt = None
# mlearner.create_opt()

In [120]:
# %%time
# valid_metrics = mlearner.validate(); valid_metrics

In [121]:
# mlearner.dls.device

In [122]:
# one_param(mlearner.model).device