<a href="https://colab.research.google.com/github/secutron/FixMatch-pytorch/blob/master/FixMatch_pytorch_on_TPUs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# FixMatch with PyTorch on TPUs - experimental

Code: https://github.com/vfdev-5/FixMatch-pytorch/


### Install requirements:

In [None]:
import os
# Use .get() so a missing variable triggers the helpful message instead of a KeyError.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'


try:
  import torch_xla
  import ignite
except ImportError:
  VERSION = "20210304"    
  # VERSION = "nightly"
  # VERSION = "20200607"
  !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
  !python pytorch-xla-env-setup.py --version $VERSION
  !pip install --upgrade git+https://github.com/pytorch/ignite
  !pip install --upgrade --pre hydra-core
  !pip install wandb
  !pip uninstall -y pillow && CC="cc -mavx2" pip install --no-cache-dir --force-reinstall pillow-simd


Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-dev20210304 ...
Found existing installation: torch 1.9.0+cu102
Collecting cloud-tpu-client
  Downloading cloud_tpu_client-0.10-py3-none-any.whl (7.4 kB)
Collecting google-api-python-client==1.8.0
  Downloading google_api_python_client-1.8.0-py3-none-any.whl (57 kB)
Uninstalling torch-1.9.0+cu102:
Installing collected packages: google-api-python-client, cloud-tpu-client
  Attempting uninstall: google-api-python-client
    Found existing installation: google-api-python-client 1.12.8
    Uninstalling google-api-python-client-1.12.8:
      Succes

Collecting wandb
  Downloading wandb-0.11.2-py2.py3-none-any.whl (1.8 MB)
Collecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.1-py3-none-any.whl (7.5 kB)
Collecting urllib3>=1.26.5
  Downloading urllib3-1.26.6-py2.py3-none-any.whl (138 kB)
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.3.1-py2.py3-none-any.whl (133 kB)
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting configparser>=3.8.1
  Downloading configparser-5.0.2-py3-none-any.whl (19 kB)
Collecting GitPython>=1.0.0
  Downloading GitPython-3.1.18-py3-none-any.whl (170 kB)
Collecting subprocess32>=3.5.3
  Downloading subprocess32-3.5.4.tar.gz (97 kB)

### Download dataset:

In [None]:
# Download dataset:
from torchvision.datasets import CIFAR10

CIFAR10("/tmp/cifar10", train=True, download=True);

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar10/cifar-10-python.tar.gz


Extracting /tmp/cifar10/cifar-10-python.tar.gz to /tmp/cifar10


### Get the code:

In [None]:
!git clone https://github.com/vfdev-5/FixMatch-pytorch/

#### Optionally, login to `W&B`

To skip logging to `W&B`, set `online_exp_tracking.wandb=false` in the training commands below.
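
The `key.subkey=value` form used for this setting is a dotted override of a nested configuration entry. The plain-Python sketch below only illustrates that idea: `apply_override` and the `defaults` dictionary are hypothetical names invented for this example, and the real parsing done by the training script's configuration library is richer than this.

```python
import copy


def apply_override(config, override):
    """Return a copy of `config` with one `a.b.c=value` override applied.

    Hypothetical helper for illustration only. Values stay strings,
    except the literals "true"/"false", which become booleans.
    """
    dotted_key, raw_value = override.split("=", 1)
    value = {"true": True, "false": False}.get(raw_value, raw_value)

    updated = copy.deepcopy(config)  # leave the caller's dict untouched
    node = updated
    *parents, leaf = dotted_key.split(".")
    for name in parents:
        # Walk down the nested dicts, creating levels that do not exist yet.
        node = node.setdefault(name, {})
    node[leaf] = value
    return updated


# Assumed starting configuration, made up for this example.
defaults = {"online_exp_tracking": {"wandb": True}}
config = apply_override(defaults, "online_exp_tracking.wandb=false")
print(config["online_exp_tracking"]["wandb"])
```

Each additional `key=value` argument on the command line would be applied the same way, one override at a time.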

In [None]:
# !wandb login <token>

wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
Successfully logged in to Weights & Biases!


### Let's train a ResNet18 model in a faster mode

In [None]:
!cd FixMatch-pytorch && export PYTHONPATH=$PWD:$PYTHONPATH && \
  python main_fixmatch.py distributed.backend=xla-tpu distributed.nproc_per_node=8 online_exp_tracking.wandb=true solver.num_epochs=50 \
    ssl.confidence_threshold=0.7 ema_decay=0.9 ssl.cta_update_every=15 solver.optimizer.params.lr=0.1

2020-06-15 14:45:26,154 ignite.distributed.launcher.Parallel INFO: Initialized distributed launcher with backend: 'xla-tpu'
2020-06-15 14:45:26,154 ignite.distributed.launcher.Parallel INFO: - Parameters to spawn processes: 
	nproc_per_node: 8
	nnodes: 1
	node_rank: 0
2020-06-15 14:45:26,154 ignite.distributed.launcher.Parallel INFO: Spawn function '<function training at 0x7f7b29cdbd90>' in 8 processes
2020-06-15 14:46:42,829 FixMatch Training INFO: name: fixmatch
seed: 543
debug: false
model: resnet18
num_classes: 10
ema_decay: 0.9
solver:
  unsupervised_criterion:
    cls: torch.nn.CrossEntropyLoss
    params:
      reduction: none
  num_epochs: 50
  epoch_length: 128
  checkpoint_every: 500
  validate_every: 1
  resume_from: null
  optimizer:
    cls: torch.optim.SGD
    params:
      lr: 0.1
      momentum: 0.9
      weight_decay: 0.0001
      nesterov: false
  supervised_criterion:
    cls: torch.nn.CrossEntropyLoss
  lr_scheduler:
    cls: torch.optim.lr_scheduler.CosineAnnealing
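
The run above sets `ssl.confidence_threshold=0.7` and uses an unsupervised criterion with `reduction: none`, which is consistent with FixMatch's masked pseudo-label loss: a prediction on a weakly augmented image becomes a target for the strongly augmented view only when its top probability reaches the threshold. The following is a minimal pure-Python sketch of that rule, assuming the formulation from the FixMatch method. The function names are hypothetical and the repository's actual implementation (in PyTorch) may differ.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pseudo_label(weak_logits, threshold=0.7):
    """Return (predicted class, keep flag) for one weakly augmented sample."""
    probs = softmax(weak_logits)
    confidence = max(probs)
    return probs.index(confidence), confidence >= threshold


def masked_unsup_loss(weak_batch, strong_batch, threshold=0.7):
    """Mean over the whole batch of per-sample cross-entropy, zeroed when masked."""
    total = 0.0
    for weak_logits, strong_logits in zip(weak_batch, strong_batch):
        label, keep = pseudo_label(weak_logits, threshold)
        if keep:
            # Cross-entropy of the strong view against the pseudo-label.
            total += -math.log(softmax(strong_logits)[label])
    return total / len(weak_batch)


# Made-up logits: the first sample is confident, the second is not.
weak = [[4.0, 0.0, 0.0], [0.2, 0.1, 0.0]]
strong = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
loss = masked_unsup_loss(weak, strong, threshold=0.7)
print(loss)
```

Only the first sample contributes to the loss here. Lowering the threshold lets more (and noisier) pseudo-labels through, which is presumably why the faster run uses 0.7.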

### Let's train a WRN-28-2 model

In [None]:
!cd FixMatch-pytorch && export PYTHONPATH=$PWD:$PYTHONPATH && \
  python -u main_fixmatch.py model=WRN-28-2 distributed.backend=xla-tpu distributed.nproc_per_node=8 online_exp_tracking.wandb=true
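
The `ema_decay` setting (0.9 in the ResNet18 run above) suggests that the training script keeps an exponential moving average (EMA) of the model weights, as FixMatch implementations commonly do. Here is a minimal sketch, assuming the standard update `ema = decay * ema + (1 - decay) * param`. The `ema_update` helper is hypothetical, parameters are shown as plain floats, and the repository's own code may differ.

```python
def ema_update(ema_params, model_params, decay=0.9):
    """Blend the current model parameters into the running average."""
    return [
        decay * ema_value + (1.0 - decay) * param
        for ema_value, param in zip(ema_params, model_params)
    ]


# Made-up values: the averaged weights drift toward the model's weights.
ema = [0.0, 1.0]
for model_params in ([1.0, 1.0], [1.0, 1.0]):
    ema = ema_update(ema, model_params, decay=0.9)
print(ema)
```

A smaller decay makes the averaged weights track the live model more quickly, which fits a short run of 50 epochs.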