In [None]:
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

In [1]:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as dpl
import torch_xla.distributed.xla_multiprocessing as xmp

import torch
import numpy as np
from torch import nn
from tqdm import tqdm
from torchvision import transforms, datasets, models
from torch.utils.data import random_split, DataLoader



In [2]:
# use brain float16 datatype
!export XLA_USE_BF16=1

In [3]:
serial_exec = xmp.MpSerialExecutor()
flags = {
    'n_epochs': 1,
    'batch_size': 64, # batch_size will be scaled by num_cores times
    'lr': 3e-4,
    'num_cores': 8
}

In [4]:
def get_data():
    T = transforms.Compose(
        [
        transforms.Resize((224, 224)),
        transforms.ToTensor()
        ]
    )
    data = datasets.CIFAR10("data/", train=True, download=False, transform=T)
    test_data = datasets.CIFAR10("data/", train=False, download=False, transform=T)
    val_len = int(0.3 * len(data))
    train_data, val_data = random_split(data, [len(data) - val_len, val_len])
    return train_data, val_data, test_data

In [5]:
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 10)
# model = xmp.MpModelWrapper(model) # not necessary

In [6]:
loss_fn = nn.CrossEntropyLoss()
def get_accuracy(preds, y):
    preds = preds.argmax(1)
    num_correct = (preds == y).sum().item()
    acc = num_correct / y.shape[0]
    return acc

In [7]:
def loop(net, loader, is_train, optimizer=None):
    net.train(is_train)
    losses = []
    accs = []
    # if is_train:
    #     split = 'train'
    # else:
    #     split = ' val '

    # pbar = tqdm(loader, total=len(loader)) # tqdm bar is kindof glitchy because the data is split across multiple cores
    for x, y in loader:
        with torch.set_grad_enabled(is_train):
            preds = net(x)
            loss = loss_fn(preds, y)
            acc = get_accuracy(preds, y)
            losses.append(loss.item())
            accs.append(acc)

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            xm.optimizer_step(optimizer) # tpu-specific code
        
        # if epoch != None:
        #     pbar.set_description(f'{split}: epoch={epoch}, loss={np.mean(losses):.4f}, acc={np.mean(accs):.4f}')
        # else:
        #     pbar.set_description(f'loss={np.mean(losses):.4f}, acc={np.mean(accs):.4f}')
    
    return np.mean(losses), np.mean(accs)

In [8]:
def run(flags):
    xm.master_print('grabbing the data...') # usual print prints individual outputs from all the cores, master_print prints only one output which is gathered from all the cores
    train_data, val_data, test_data = serial_exec.run(get_data)

    xm.master_print('creating the dataloaders...')
    # xm.get_ordinal(): current core, xm.xrt_world_size(): num cores
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True)
    train_loader = DataLoader(train_data, batch_size=flags['batch_size'], sampler=train_sampler, drop_last=True)

    val_sampler = torch.utils.data.distributed.DistributedSampler(val_data, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=False)
    val_loader = DataLoader(val_data, batch_size=flags['batch_size'], sampler=val_sampler)

    test_sampler = torch.utils.data.distributed.DistributedSampler(test_data, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=False)
    test_loader = DataLoader(test_data, batch_size=flags['batch_size'], sampler=test_sampler)

    device = xm.xla_device()
    xm.master_print('creating the model...')
    curr_model = model.to(device)
    new_lr = flags['lr'] * xm.xrt_world_size() # coz the batch_size is scaled xm.xrt_world_size() times
    optimizer = torch.optim.Adam(curr_model.parameters(), lr=new_lr)

    xm.master_print('starting the training...')
    for epoch in range(flags['n_epochs']):
        train_para_loader = dpl.ParallelLoader(train_loader, [device])
        train_loss, train_acc = loop(curr_model, train_para_loader.per_device_loader(device), True, optimizer)

        val_para_loader = dpl.ParallelLoader(val_loader, [device])
        val_loss, val_acc = loop(curr_model, val_para_loader.per_device_loader(device), False)

        xm.master_print(f'epoch={epoch}, train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, val_loss={val_loss:.4f}, val_acc={val_acc:.4f}')

    test_para_loader = dpl.ParallelLoader(test_loader, [device])
    test_loss, test_acc = loop(curr_model, test_para_loader.per_device_loader(device), False)
    xm.master_print(f'test_loss={test_loss:.4f}, test_acc={test_acc:.4f}')

    xm.master_print('saving the model weights...')
    xm.save(curr_model.state_dict(), 'weights.pth')

In [9]:
def map_fn(rank, flags):
    '''
    rank: current tpu core
    flags: training args
    '''
    torch.set_default_tensor_type('torch.FloatTensor')
    run(flags)

In [10]:
xmp.spawn(map_fn, args=(flags,), nprocs=flags['num_cores'], start_method='fork')

grabbing the data...
creating the dataloaders...
creating the model...
starting the training...
epoch=0, train_loss=1.8077, train_acc=0.3297, val_loss=1.9125, val_acc=0.3456
test_loss=1.8780, test_acc=0.3368
saving the model weights...


In [11]:
model.load_state_dict(torch.load('weights.pth'))

<All keys matched successfully>