<a href="https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/logreg_tpu_pytorch_lightning_bolts.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Logistic regression on Iris and CIFAR-10 using TPUs and PyTorch Lightning Bolts

The code is adapted from the Lightning Bolts introduction guide:
https://lightning-bolts.readthedocs.io/en/latest/introduction_guide.html#logistic-regression




# Setup TPU

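Before running any cell, select the TPU hardware accelerator for this notebook (Edit > Notebook settings > Hardware accelerator, or Runtime > Change runtime type).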

For background, see the PyTorch/XLA getting-started Colab:
https://colab.sandbox.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb#scrollTo=3P6b3uqfzpDI


See also the PyTorch Lightning CIFAR-10 baseline example:
https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/cifar10-baseline.html#


In [1]:
import matplotlib.pyplot as plt
import numpy as np

In [3]:
import os

# Fail fast with an actionable message when the Colab runtime has no TPU attached.
# `.get` is used instead of indexing so that a missing variable triggers the
# assertion message below rather than an opaque KeyError.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

In [4]:

# Install the Cloud TPU client (0.10) and the PyTorch/XLA 1.9 wheel.
# The wheel file name targets CPython 3.7 on linux x86_64 (cp37 / linux_x86_64).
# NOTE: the recorded output shows this replaces google-api-python-client 1.12.8
# with 1.8.0, which pip reports as incompatible with the preinstalled earthengine-api.
# Disabled variant using the torch_xla 1.8 wheel:
#!pip install -q cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl

!pip install -q cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

Collecting cloud-tpu-client==0.10
  Downloading https://files.pythonhosted.org/packages/56/9f/7b1958c2886db06feb5de5b2c191096f9e619914b6c31fdf93999fdbbd8b/cloud_tpu_client-0.10-py3-none-any.whl
Collecting torch-xla==1.9
[?25l  Downloading https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl (149.9MB)
[K     |████████████████████████████████| 149.9MB 84kB/s 
Collecting google-api-python-client==1.8.0
[?25l  Downloading https://files.pythonhosted.org/packages/9a/b4/a955f393b838bc47cbb6ae4643b9d0f90333d3b4db4dc1e819f36aad18cc/google_api_python_client-1.8.0-py3-none-any.whl (57kB)
[K     |████████████████████████████████| 61kB 2.7MB/s 
[31mERROR: earthengine-api 0.1.269 has requirement google-api-python-client<2,>=1.12.1, but you'll have google-api-python-client 1.8.0 which is incompatible.[0m
Installing collected packages: google-api-python-client, cloud-tpu-client, torch-xla
  Found existing installation: google-api-python-client 1.12.8
    Un

In [None]:
# Disabled alternative: download the env-setup script from the pytorch/xla
# repository and run it with the selected VERSION. All three lines are
# commented out, so this cell currently does nothing.
# NOTE(review): presumably superseded by the pinned `pip install` of the
# torch_xla wheel -- confirm before re-enabling.
# VERSION = "20200325"  #@param ["1.5" , "20200325", "nightly"]
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version $VERSION

In [5]:
# imports pytorch
import torch

# imports the torch_xla package
import torch_xla
import torch_xla.core.xla_model as xm



In [6]:
# Allocate a 3x3 tensor of ones on the default XLA device
# (the recorded output shows it placed on 'xla:1').
dev = xm.xla_device()
t1 = torch.ones((3, 3), device=dev)
print(t1)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], device='xla:1')


In [7]:
# Allocate a 3x3 tensor of zeros on TPU device ordinal 2
# (the recorded output shows it placed on 'xla:2').
second_dev = xm.xla_device(devkind='TPU', n=2)
t2 = torch.zeros((3, 3), device=second_dev)
print(t2)

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], device='xla:2')


# Setup lightning

In [8]:
# Install Lightning Bolts together with torchmetrics, torchvision, torch and pytorch-lightning.
# NOTE(review): none of these packages is version-pinned, so a fresh run may resolve
# to different releases than the ones that produced the recorded outputs.
#!pip install -q lightning-bolts
!pip install --quiet torchmetrics lightning-bolts torchvision torch pytorch-lightning
# Disabled variants with version constraints.
# NOTE(review): if re-enabled, each specifier needs quoting (e.g. "torchmetrics>=0.3");
# unquoted, the shell treats `>` and `<` as redirections instead of passing them to pip.
#!pip install --quiet torchmetrics>=0.3 lightning-bolts torchvision torch>=1.6, <1.9 pytorch-lightning>=1.3
#!pip install --quiet torchmetrics>=0.3 lightning-bolts torchvision torch>=1.6 pytorch-lightning>=1.3

[K     |████████████████████████████████| 276kB 4.9MB/s 
[K     |████████████████████████████████| 256kB 33.4MB/s 
[K     |████████████████████████████████| 819kB 37.4MB/s 
[K     |████████████████████████████████| 645kB 41.8MB/s 
[K     |████████████████████████████████| 122kB 42.9MB/s 
[K     |████████████████████████████████| 10.6MB 35.4MB/s 
[K     |████████████████████████████████| 829kB 25.7MB/s 
[K     |████████████████████████████████| 1.3MB 35.5MB/s 
[K     |████████████████████████████████| 143kB 37.3MB/s 
[K     |████████████████████████████████| 296kB 39.2MB/s 
[?25h  Building wheel for future (setup.py) ... [?25l[?25hdone
[31mERROR: tensorflow 2.5.0 has requirement tensorboard~=2.5, but you'll have tensorboard 2.4.1 which is incompatible.[0m
[31mERROR: earthengine-api 0.1.269 has requirement google-api-python-client<2,>=1.12.1, but you'll have google-api-python-client 1.8.0 which is incompatible.[0m


In [9]:
from pl_bolts.models.regression import LogisticRegression
import pytorch_lightning as pl

from pl_bolts.datamodules import MNISTDataModule, FashionMNISTDataModule, CIFAR10DataModule, ImagenetDataModule

  "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package"


# Iris

In [10]:
from sklearn.datasets import load_iris
from pl_bolts.datamodules import SklearnDataModule
import pytorch_lightning as pl

# Load the Iris features and labels as arrays. Any numpy or sklearn dataset
# with the same (X, y) layout can be substituted here.
X, y = load_iris(return_X_y=True)
dm = SklearnDataModule(X, y, batch_size=12)

# Build the model. The input width and class count are derived from the data
# (4 features and 3 classes for Iris) so that swapping in another dataset above
# does not leave stale hard-coded dimensions behind.
# Assumes the labels are the integers 0..K-1, as they are for Iris.
model = LogisticRegression(input_dim=X.shape[1], num_classes=len(set(y.tolist())))



In [12]:

# Fit the Iris model on 8 TPU cores, passing the train/validation loaders explicitly.
# NOTE(review): the recorded output for this cell ends in a ProcessExitedException
# ("Cannot replicate if number of devices (1) is different from 8"). Presumably the
# XLA device was already initialised in this process by the earlier
# `xm.xla_device()` cells before the 8 workers were spawned -- confirm.
trainer = pl.Trainer(tpu_cores=8)
trainer.fit(
    model,
    train_dataloader=dm.train_dataloader(),
    val_dataloaders=dm.val_dataloader(),
)


INFO:pytorch_lightning.utilities.distributed:GPU available: False, used: False
INFO:pytorch_lightning.utilities.distributed:TPU available: True, using: 8 TPU cores
Exception in device=TPU:0: Cannot replicate if number of devices (1) is different from 8
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
Exception in device=TPU:1: Cannot replicate if number of devices (1) is different from 8
Traceback (most recent call last):
Exception in device=TPU:2: Cannot replicate if number of devices (1) is different from 8
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 322, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.7/dist-p

ProcessExitedException: ignored

In [None]:
# Evaluate on the data module's test loader. No model is passed explicitly;
# presumably the trainer falls back to the model it was last fitted with -- confirm.
trainer.test(
    test_dataloaders=dm.test_dataloader(),
)

# CIFAR

In [7]:
# Build the CIFAR-10 data module, storing files under ./data and loading in the
# main process (num_workers=0). MNISTDataModule accepts the same two arguments
# if MNIST is wanted instead.
# Note: this rebinds `dm`, replacing the Iris data module from the previous section.
dm = CIFAR10DataModule(data_dir='data', num_workers=0)

# Trigger the download immediately rather than waiting for the trainer to do it.
dm.prepare_data()


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified


In [8]:
# Report the per-example shape and class count, then derive the flattened
# input width and the resulting number of weights (inputs x classes).
print(dm.size())
nclasses = dm.num_classes
print(nclasses)
ndims = np.prod(dm.size())
print([ndims, nclasses, ndims * nclasses])



(3, 32, 32)
10
[3072, 10, 30720]


In [9]:
# Logistic-regression model sized from `ndims` / `nclasses` computed in the
# previous cell; the printed summary shows it is a single Linear layer.
model = LogisticRegression(
    input_dim=ndims,
    num_classes=nclasses,
    learning_rate=0.001,
)
print(model)

LogisticRegression(
  (linear): Linear(in_features=3072, out_features=10, bias=True)
)


In [13]:
# Train for two epochs on 8 TPU cores, letting the trainer pull its loaders
# from the data module. Drop `tpu_cores=8` to train without the TPU.
# NOTE(review): the recorded output for this cell is a MisconfigurationException;
# the cause is not visible in the saved output -- confirm the cell runs on a
# fresh TPU runtime.
trainer = pl.Trainer(max_epochs=2, tpu_cores=8)
trainer.fit(model, datamodule=dm)


MisconfigurationException: ignored

In [12]:
# Run the trainer's test loop on the model.
# NOTE(review): this evaluates on the validation loader rather than
# `dm.test_dataloader()` -- confirm that is intended.
trainer.test(
    model,
    test_dataloaders=dm.val_dataloader(),
)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


[{}]