In [None]:
# Install cloud-tpu-client 0.10 and the prebuilt torch_xla 1.9 wheel that the
# `torch_xla` imports in the next cell depend on.
# NOTE(review): the wheel filename is tagged cp37 / linux_x86_64 — presumably this
# matches the runtime's interpreter; confirm before reusing elsewhere.
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

In [None]:
import torch_xla
import torch_xla.core.xla_model as xm

import torch
import numpy as np
from torch import nn
from tqdm import tqdm
from torchvision import transforms, datasets, models
from torch.utils.data import random_split, DataLoader

In [2]:
# Training hyperparameters.
n_epochs = 1
batch_size = 64
lr = 3e-4

# Input side length (images are resized to a square) and number of output classes.
img_size, num_classes = 224, 10

# Acquire a single XLA device (one core) to run everything on.
device = xm.xla_device()

# Per-image preprocessing: resize to img_size x img_size, then convert to a tensor.
T = transforms.Compose([
    transforms.Resize((img_size,) * 2),
    transforms.ToTensor(),
])

print(device)

xla:1


In [3]:
# Download (if needed) both CIFAR-10 splits and apply the shared preprocessing `T`.
data = datasets.CIFAR10("data/", train=True, download=True, transform=T)
test_data = datasets.CIFAR10("data/", train=False, download=True, transform=T)

# Hold out 30% of the official training set for validation. The partition is
# drawn from an explicitly seeded generator so that re-running the notebook
# produces the same train/val split instead of a different one on every run.
split_seed = 42
val_len = int(0.3 * len(data))
train_data, val_data = random_split(
    data,
    [len(data) - val_len, val_len],
    generator=torch.Generator().manual_seed(split_seed),
)

# Settings shared by all three loaders; only the training loader shuffles.
# NOTE(review): pin_memory targets CUDA host-to-device copies and presumably
# has no effect on an XLA device — confirm before relying on it.
loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_data, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

# Peek at one training batch to sanity-check dataset sizes and tensor shapes.
x, y = next(iter(train_loader))
print(len(train_data), len(val_data), len(test_data), x.shape, y.shape)

Files already downloaded and verified
Files already downloaded and verified
35000 15000 10000 torch.Size([64, 3, 224, 224]) torch.Size([64])


In [4]:
# ResNet-18 without pretrained weights; swap the final fully connected layer
# so the network emits `num_classes` outputs, then move it to the device.
net = models.resnet18(pretrained=False)
net.fc = nn.Linear(net.fc.in_features, num_classes)
net.to(device)

# Smoke-test the forward pass on a single random input of the configured size.
# Run it in eval mode with autograd disabled so the noise batch neither updates
# the BatchNorm running statistics nor builds a gradient graph, then restore
# training mode (the module's state before this check).
net.eval()
with torch.no_grad():
    inp = torch.randn(1, 3, img_size, img_size).to(device)
    out = net(inp)
net.train()
print(out.shape)

torch.Size([1, 10])


In [5]:
# Objective: cross-entropy computed from the network outputs and integer labels.
loss_fn = nn.CrossEntropyLoss()
# Adam over every parameter of `net`, using the notebook-level learning rate `lr`.
optimizer = torch.optim.Adam(params=net.parameters(), lr=lr)
def get_accuracy(preds, y):
    """Return the fraction of rows in `preds` whose arg-max equals the label in `y`.

    Args:
        preds: Tensor of per-class scores; the predicted class is the index of
            the largest value along dimension 1.
        y: Tensor of target class indices, one per row of `preds`.

    Returns:
        A Python float: number of matching predictions divided by `y.shape[0]`.
    """
    matches = preds.argmax(dim=1).eq(y)
    # .item() pulls the count back to the host as a plain Python number.
    return matches.sum().item() / y.shape[0]

In [6]:
def loop(net, loader, is_train, epoch=None):
    """Run one full pass over `loader`, either training or evaluating `net`.

    Relies on the notebook-level `device`, `loss_fn`, `optimizer` and
    `get_accuracy` defined in other cells.

    Args:
        net: Model to run; switched to train/eval mode via `net.train(is_train)`.
        loader: Iterable of `(x, y)` batches; must support `len()` for the
            progress bar total.
        is_train: If True, gradients are enabled and the optimizer is stepped
            after every batch; otherwise the pass is forward-only.
        epoch: Optional epoch index. When given, the progress bar description
            is prefixed with the split name and the epoch.

    Returns:
        None. The running mean loss and accuracy are shown on the tqdm bar.
    """
    net.train(is_train)
    losses = []
    accs = []
    split = 'train' if is_train else ' val '

    pbar = tqdm(loader, total=len(loader))
    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)

        with torch.set_grad_enabled(is_train):
            preds = net(x)
            loss = loss_fn(preds, y)

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            xm.optimizer_step(optimizer, barrier=True) # tpu-specific code
        else:
            # No optimizer step on this path, so explicitly flush the pending
            # XLA graph once before reading any values from it.
            xm.mark_step()

        # Read scalar metrics only after the step barrier: fetching a value
        # from a lazy XLA tensor earlier forces a separate device execution
        # for each `.item()` call. `preds` and `loss` were computed before the
        # parameter update, so the reported numbers are unchanged.
        with torch.no_grad():
            accs.append(get_accuracy(preds, y))
            losses.append(loss.item())

        # Identity check: `epoch=0` is a valid epoch and must not be treated as absent.
        if epoch is not None:
            pbar.set_description(f'{split}: epoch={epoch}, loss={np.mean(losses):.4f}, acc={np.mean(accs):.4f}')
        else:
            pbar.set_description(f'loss={np.mean(losses):.4f}, acc={np.mean(accs):.4f}')

In [7]:
# Each epoch: one optimisation pass over the training split, then a
# forward-only pass over the validation split.
for epoch in range(n_epochs):
    loop(net, train_loader, is_train=True, epoch=epoch)
    loop(net, val_loader, is_train=False, epoch=epoch)

train: epoch=0, loss=1.3215, acc=0.5198: 100%|██████████| 547/547 [05:18<00:00,  1.72it/s]
 val : epoch=0, loss=1.1623, acc=0.5922: 100%|██████████| 235/235 [02:11<00:00,  1.79it/s]


In [8]:
# Forward-only pass over the held-out test split (no epoch prefix on the bar).
loop(net, test_loader, is_train=False)

loss=1.1700, acc=0.5891: 100%|██████████| 157/157 [01:24<00:00,  1.86it/s]
