## Setup

In [1]:
# Install the libraries imported further down (torch, torchvision, torchmetrics,
# pytorch-lightning) plus the IPython notebook extras. torch is capped below 1.9.
! pip install --quiet "torchvision" "torch>=1.6, <1.9" "torchmetrics>=0.3" "ipython[notebook]" "pytorch-lightning>=1.3"

[K     |████████████████████████████████| 804.1 MB 2.4 kB/s 
[K     |████████████████████████████████| 397 kB 51.3 MB/s 
[K     |████████████████████████████████| 527 kB 48.0 MB/s 
[K     |████████████████████████████████| 134 kB 55.9 MB/s 
[K     |████████████████████████████████| 952 kB 52.4 MB/s 
[K     |████████████████████████████████| 596 kB 51.1 MB/s 
[K     |████████████████████████████████| 829 kB 58.2 MB/s 
[K     |████████████████████████████████| 1.1 MB 38.7 MB/s 
[K     |████████████████████████████████| 21.0 MB 8.8 MB/s 
[K     |████████████████████████████████| 23.2 MB 1.5 MB/s 
[K     |████████████████████████████████| 23.3 MB 1.5 MB/s 
[K     |████████████████████████████████| 23.3 MB 1.5 MB/s 
[K     |████████████████████████████████| 22.1 MB 1.4 MB/s 
[K     |████████████████████████████████| 22.1 MB 12.3 MB/s 
[K     |████████████████████████████████| 17.4 MB 546 kB/s 
[K     |████████████████████████████████| 271 kB 39.8 MB/s 
[K     |█████████████

In [None]:
# NOTE(review): pins a tf-estimator nightly build, yet no cell visible here imports
# TensorFlow - presumably a Colab dependency workaround; confirm it is still needed.
! pip install tf-estimator-nightly==2.8.0.dev2021122109


## Install Colab TPU-compatible PyTorch/TPU wheels and dependencies

In [2]:
# Install the Cloud TPU client (0.10) and the torch_xla 1.8 wheel. The wheel filename
# is tagged cp37 / linux_x86_64, i.e. it targets a CPython 3.7 Linux runtime.
! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl

Collecting torch-xla==1.8
  Downloading https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl (144.6 MB)
[K     |████████████████████████████████| 144.6 MB 88 kB/s 
[?25hCollecting cloud-tpu-client==0.10
  Downloading cloud_tpu_client-0.10-py3-none-any.whl (7.4 kB)
Collecting google-api-python-client==1.8.0
  Downloading google_api_python_client-1.8.0-py3-none-any.whl (57 kB)
[K     |████████████████████████████████| 57 kB 2.3 MB/s 
Installing collected packages: google-api-python-client, torch-xla, cloud-tpu-client
  Attempting uninstall: google-api-python-client
    Found existing installation: google-api-python-client 1.12.10
    Uninstalling google-api-python-client-1.12.10:
      Successfully uninstalled google-api-python-client-1.12.10
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
earthengine-api 0.1.301 req

In [3]:
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
from torchvision import transforms

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST

# Batch size shared by the train, validation and test DataLoaders defined below.
BATCH_SIZE = 1024

ImportError: ignored

### Define dataloader

In [None]:
class MNISTDataModule(LightningDataModule):
    """DataModule that downloads MNIST and serves train/val/test DataLoaders.

    Args:
        data_dir: Directory the MNIST files are downloaded to and read from.
        seed: Seed for the generator that draws the train/validation split, so
            the same 55000/5000 partition is produced on every run and by
            every call to ``setup``.
    """

    def __init__(self, data_dir: str = "./", seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.seed = seed
        # ToTensor followed by Normalize(mean=0.1307, std=0.3081)
        # (presumably the MNIST training-set statistics).
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Download both MNIST splits into ``data_dir``; no state is assigned here."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ("fit", "test", or None for both)."""
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            # A dedicated, seeded generator keeps the split reproducible and
            # independent of the global RNG state. Without it, each call (and
            # each process that runs setup) could draw a different partition,
            # letting validation samples leak into training.
            split_rng = torch.Generator().manual_seed(self.seed)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000], generator=split_rng)

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return a DataLoader over the training subset created in ``setup``."""
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        """Return a DataLoader over the validation subset created in ``setup``."""
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        """Return a DataLoader over the MNIST test set created in ``setup``."""
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)

### Define model

In [None]:
class LitModel(LightningModule):
    """Fully connected classifier with two hidden layers, ReLU and 10% dropout.

    Args:
        channels, width, height: Input dimensions; their product is the size of
            the flattened input fed to the first linear layer.
        num_classes: Number of outputs of the final linear layer.
        hidden_size: Width of both hidden linear layers.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        super().__init__()
        # Must run before any argument is rebound: it records the constructor
        # arguments so they can be read back through self.hparams.
        self.save_hyperparameters()

        flat_features = channels * width * height
        layers = [
            nn.Flatten(),
            nn.Linear(flat_features, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return log-softmax (over dim 1) of the network output for ``x``."""
        return F.log_softmax(self.model(x), dim=1)

    def training_step(self, batch, batch_idx):
        """Compute, log ("train_loss") and return the NLL loss for one batch."""
        inputs, targets = batch
        train_loss = F.nll_loss(self(inputs), targets)
        self.log("train_loss", train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Log "val_loss" and "val_acc" on the progress bar; return the loss."""
        inputs, targets = batch
        log_probs = self(inputs)
        val_loss = F.nll_loss(log_probs, targets)
        # Predicted class = index of the largest log-probability along dim 1.
        val_acc = accuracy(log_probs.argmax(dim=1), targets)
        self.log("val_loss", val_loss, prog_bar=True)
        self.log("val_acc", val_acc, prog_bar=True)
        return val_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at the stored learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

### TPU Training

In [None]:
# Init DataModule (uses its default data_dir, "./")
dm = MNISTDataModule()
# Init model from datamodule's attributes: dm.size() is unpacked into the
# channels/width/height positional arguments and dm.num_classes is the fourth.
model = LitModel(*dm.size(), dm.num_classes)
# Init trainer: stop after 3 epochs. Per the Lightning Trainer docs, a one-element
# list for tpu_cores selects that specific TPU core index (here 5), and
# progress_bar_refresh_rate sets how often (in batches) the bar refreshes.
trainer = Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=[5])
# Train: the datamodule is passed as the second argument to fit(), so the trainer
# obtains its data from dm.
trainer.fit(model, dm)