In [None]:
!pip install pytorch-lightning cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

In [2]:
import torch
from torch import nn
import pytorch_lightning as pl
import torch_xla.core.xla_model as xm
from torchmetrics.functional import accuracy
from torchvision import transforms, datasets, models
from torch.utils.data import random_split, DataLoader



In [3]:
pl.seed_everything(42)
max_epochs = 2
img_size = 224
val_pct = 0.2
batch_size = 64
lr = 3e-4

Global seed set to 42


In [4]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, data_dir, img_size, val_pct, batch_size):
        super().__init__()
        self.data_dir = data_dir
        self.T = transforms.Compose(
                    [
                    transforms.Resize((img_size, img_size)),
                    transforms.ToTensor()
                    ]
                )
        self.val_pct = val_pct
        self.batch_size = batch_size

    def prepare_data(self):
        datasets.CIFAR10(self.data_dir, train=True, download=False)
        datasets.CIFAR10(self.data_dir, train=False, download=False)

    def setup(self, stage=None):
        if stage in ('fit', 'validate'):
            data = datasets.CIFAR10(self.data_dir, train=True, transform=self.T)
            val_len = int(self.val_pct * len(data))
            self.train_data, self.val_data = random_split(data, [len(data) - val_len, val_len])
        elif stage in ('test', 'predict'):    
            self.test_data = datasets.CIFAR10(self.data_dir, train=False, transform=self.T)

    def get_dataloader(self, data, split):
        if split == 'train':
            shuffle = True
        else:
            shuffle = False
        sampler = torch.utils.data.distributed.DistributedSampler(data, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=shuffle)
        return DataLoader(data, batch_size=self.batch_size, sampler=sampler)

    def train_dataloader(self):
        return self.get_dataloader(self.train_data, 'train')

    def val_dataloader(self):
        return self.get_dataloader(self.val_data, 'val')

    def test_dataloader(self):
        return self.get_dataloader(self.test_data, 'test')

In [5]:
class Model(pl.LightningModule):
    def __init__(self, lr):
        super().__init__()
        self.resnet = models.resnet18(pretrained=False)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, 10)

        self.lr = lr
        self.loss_fn = nn.CrossEntropyLoss()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, x):
        return self.resnet(x)
        
    def shared_step(self, batch, split):
        x, y = batch
        preds = self(x)
        loss = self.loss_fn(preds, y)
        acc = accuracy(preds, y)
        self.log(f'{split}_loss', loss, on_epoch=True, prog_bar=True)
        self.log(f'{split}_acc', acc, on_epoch=True, prog_bar=True)
        if split == 'train':
            return loss

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        self.shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        self.shared_step(batch, 'test')

In [6]:
cifar10_dm = CIFAR10DataModule('data/', img_size, val_pct, batch_size)
model = Model(lr)

In [7]:
# to use a single core, set tpu_cores=1
trainer = pl.Trainer(max_epochs=max_epochs, tpu_cores=8, precision=16)
trainer.fit(model, datamodule=cifar10_dm)

GPU available: False, used: False
TPU available: True, using: 8 TPU cores
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42

  | Name    | Type             | Params
---------------------------------------------
0 | resnet  | ResNet           | 11.2 M
1 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
22.363    Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

  "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2"


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




  rank_zero_warn("cleaning up ddp environment...")


In [8]:
trainer.test(model, datamodule=cifar10_dm)

Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42
Global seed set to 42


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.3768046796321869, 'test_loss': 1.7515374422073364}
--------------------------------------------------------------------------------


  rank_zero_warn("cleaning up ddp environment...")


[{'test_acc': 0.3768046796321869, 'test_loss': 1.7515374422073364}]

In [12]:
torch.save(model.resnet.state_dict(), 'weights.pth')

In [13]:
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Linear(resnet.fc.in_features, 10)
resnet.load_state_dict(torch.load('weights.pth'))

<All keys matched successfully>