<a href="https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/logreg_tpu_pytorch_lightning_bolts.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Logistic regression on Iris and CIFAR-10 using TPUs and PyTorch Lightning Bolts

The code is adapted from the logistic-regression example in the Lightning Bolts introduction guide:
https://lightning-bolts.readthedocs.io/en/latest/introduction_guide.html#logistic-regression




# Set up the TPU

Be sure to select a TPU runtime (Runtime > Change runtime type > Hardware accelerator > TPU)!

See the PyTorch/XLA getting-started notebook:
https://colab.sandbox.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb#scrollTo=3P6b3uqfzpDI


See also the PyTorch Lightning CIFAR-10 baseline tutorial:
https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/cifar10-baseline.html#


In [2]:
import matplotlib.pyplot as plt
import numpy as np

In [3]:
import os

# Fail fast with a readable message when the notebook is not on a Colab TPU runtime.
# `.get` returns None for an unset variable instead of raising a bare KeyError,
# so the user actually sees the assertion message below.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

In [4]:

#!pip install -q cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl

!pip install -q cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

[K     |████████████████████████████████| 149.9MB 74kB/s 
[K     |████████████████████████████████| 61kB 3.0MB/s 
[31mERROR: earthengine-api 0.1.269 has requirement google-api-python-client<2,>=1.12.1, but you'll have google-api-python-client 1.8.0 which is incompatible.[0m
[?25h

# Set up PyTorch Lightning and Lightning Bolts

In [5]:
#!pip install -q lightning-bolts
!pip install --quiet torchmetrics lightning-bolts torchvision torch pytorch-lightning

[K     |████████████████████████████████| 276kB 5.5MB/s 
[K     |████████████████████████████████| 256kB 9.5MB/s 
[K     |████████████████████████████████| 819kB 10.5MB/s 
[K     |████████████████████████████████| 10.6MB 20.1MB/s 
[K     |████████████████████████████████| 829kB 38.6MB/s 
[K     |████████████████████████████████| 122kB 41.8MB/s 
[K     |████████████████████████████████| 645kB 33.7MB/s 
[K     |████████████████████████████████| 1.3MB 38.2MB/s 
[K     |████████████████████████████████| 143kB 40.3MB/s 
[K     |████████████████████████████████| 296kB 41.2MB/s 
[?25h  Building wheel for future (setup.py) ... [?25l[?25hdone
[31mERROR: tensorflow 2.5.0 has requirement tensorboard~=2.5, but you'll have tensorboard 2.4.1 which is incompatible.[0m
[31mERROR: earthengine-api 0.1.269 has requirement google-api-python-client<2,>=1.12.1, but you'll have google-api-python-client 1.8.0 which is incompatible.[0m


In [6]:
from pl_bolts.models.regression import LogisticRegression
import pytorch_lightning as pl

from pl_bolts.datamodules import MNISTDataModule, FashionMNISTDataModule, CIFAR10DataModule, ImagenetDataModule

  "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package"


# Iris

In [7]:
from sklearn.datasets import load_iris
from pl_bolts.datamodules import SklearnDataModule
import pytorch_lightning as pl

# use any numpy or sklearn dataset
X, y = load_iris(return_X_y=True)
dm = SklearnDataModule(X, y, batch_size=12)

# build model; the input and output sizes come from the data instead of being
# hard-coded (4 features, 3 classes for Iris), so swapping in another dataset
# above does not silently produce a mis-sized model
model = LogisticRegression(input_dim=X.shape[1], num_classes=len(np.unique(y)))



In [9]:

# fit
# Lightning trains for up to 1000 epochs when max_epochs is left unset; state it
# explicitly so the cost of this cell is visible and easy to tune.
IRIS_MAX_EPOCHS = 1000
trainer = pl.Trainer(tpu_cores=8, max_epochs=IRIS_MAX_EPOCHS)
# Hand Lightning the datamodule rather than individual loaders: the
# `train_dataloader=` keyword was renamed in later Lightning releases, while
# `datamodule=` works across versions.
trainer.fit(model, datamodule=dm)


GPU available: False, used: False
TPU available: True, using: 8 TPU cores

  | Name   | Type   | Params
----------------------------------
0 | linear | Linear | 15    
----------------------------------
15        Trainable params
0         Non-trainable params
15        Total params
0.000     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Exception in device=TPU:4: 'mappingproxy' object does not support item assignment
Traceback (most recent call last):
Exception in device=TPU:2: 'mappingproxy' object does not support item assignment
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
Exception in device=TPU:7: 'mappingproxy' object does not support item assignment
Exception in device=TPU:6: 'mappingproxy' object does not support item assignment
Traceback (most recent call last):
Exception in device=TPU:3: 'mappingproxy' object does not support item assignment
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    fn(gindex, *args)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
Exception in device=TPU:1: 'mappingproxy' object does not suppo

ProcessExitedException: ignored

In [None]:
# Evaluate the model's current weights on the datamodule's held-out test split.
# Passing the model explicitly avoids depending on a checkpoint produced by the
# preceding trainer.fit (which may not exist if fit failed).
trainer.test(model, datamodule=dm)

# CIFAR-10

In [10]:
# CIFAR-10 datamodule stored under ./data, loading in the main process
# (swap in MNISTDataModule here to run on MNIST instead).
dm = CIFAR10DataModule(data_dir='data', num_workers=0)
# Trigger the download now rather than on first use.
dm.prepare_data()


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified


In [11]:
# Sizes needed to build the logistic-regression model.
input_shape = dm.size()
ndims = np.prod(input_shape)  # flattened number of input features
nclasses = dm.num_classes
print(input_shape)
print(nclasses)
# [input features, classes, weight-matrix entries excluding the bias]
print([ndims, nclasses, ndims*nclasses])



(3, 32, 32)
10
[3072, 10, 30720]


In [12]:
# One linear layer from the flattened image to the CIFAR-10 classes.
model = LogisticRegression(
    input_dim=ndims,
    num_classes=nclasses,
    learning_rate=0.001,
)
print(model)

LogisticRegression(
  (linear): Linear(in_features=3072, out_features=10, bias=True)
)


In [13]:
# Two epochs on all 8 TPU cores; drop `tpu_cores` to train on the CPU instead.
trainer = pl.Trainer(max_epochs=2, tpu_cores=8)
trainer.fit(model, datamodule=dm)


GPU available: False, used: False
TPU available: True, using: 8 TPU cores

  | Name   | Type   | Params
----------------------------------
0 | linear | Linear | 30.7 K
----------------------------------
30.7 K    Trainable params
0         Non-trainable params
30.7 K    Total params
0.123     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Exception in device=TPU:1: 'mappingproxy' object does not support item assignment
Exception in device=TPU:3: 'mappingproxy' object does not support item assignment
Exception in device=TPU:5: 'mappingproxy' object does not support item assignment
Exception in device=TPU:4: 'mappingproxy' object does not support item assignment
Traceback (most recent call last):
Exception in device=TPU:6: 'mappingproxy' object does not support item assignment
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
Traceback (most recent call last):
Exception in device=TPU:7: 'mappingproxy' object does not support item assignment
  File "/usr/local/lib/python3.7/dist-packages/tor

ProcessExitedException: ignored

In [12]:
# Evaluate on the datamodule's held-out test split (not the validation loader),
# so the numbers reported under "TEST RESULTS" actually come from test data.
trainer.test(model, datamodule=dm)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


[{}]