<a href="https://colab.research.google.com/github/secutron/TesTime/blob/main/FixMatch_ignite.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Install requirements:

In [1]:
import os
assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'


try:
  import torch_xla
  import ignite
except ImportError:
  # VERSION = "nightly"
  VERSION = "20200607"
  # VERSION = "2020060f"

  !pip install gsutil

  !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
  !python pytorch-xla-env-setup.py --version $VERSION
  !pip install --upgrade git+https://github.com/pytorch/ignite
  !pip install --upgrade --pre hydra-core
  !pip install wandb
  !pip uninstall -y pillow && CC="cc -mavx2" pip install --no-cache-dir --force-reinstall pillow-simd


Collecting gsutil
[?25l  Downloading https://files.pythonhosted.org/packages/6a/fc/b0200efe90fa46eddf65df967ffe9bdfd2d42fc7777c88f2a8b50cd67361/gsutil-4.63.tar.gz (2.9MB)
[K     |████████████████████████████████| 2.9MB 7.0MB/s 
[?25hCollecting argcomplete>=1.9.4
  Downloading https://files.pythonhosted.org/packages/b7/9e/9dc74d330c07866d72f62d553fe8bdbe32786ff247a14e68b5659963e6bd/argcomplete-1.12.3-py2.py3-none-any.whl
Collecting fasteners>=0.14.1
  Downloading https://files.pythonhosted.org/packages/31/91/6630ebd169ca170634ca8a10dfcc5f5c11b0621672d4c2c9e40381c6d81a/fasteners-0.16.3-py2.py3-none-any.whl
Collecting gcs-oauth2-boto-plugin>=2.7
  Downloading https://files.pythonhosted.org/packages/f7/ab/3cc16742de84b76aa328c4b9e09fbf88447027827c12fb3913c5907be23b/gcs-oauth2-boto-plugin-2.7.tar.gz
Collecting google-apitools>=0.5.32
[?25l  Downloading https://files.pythonhosted.org/packages/5e/cb/cb0311f2ec371c83d6510847476c665edc9cc97564a51923557bc8f0b680/google_apitools-0.5.32-py3-no

Collecting wandb
[?25l  Downloading https://files.pythonhosted.org/packages/e0/b4/9d92953d8cddc8450c859be12e3dbdd4c7754fb8def94c28b3b351c6ee4e/wandb-0.10.32-py2.py3-none-any.whl (1.8MB)
[K     |████████████████████████████████| 1.8MB 6.7MB/s 
[?25hCollecting subprocess32>=3.5.3
[?25l  Downloading https://files.pythonhosted.org/packages/32/c8/564be4d12629b912ea431f1a50eb8b3b9d00f1a0b1ceff17f266be190007/subprocess32-3.5.4.tar.gz (97kB)
[K     |████████████████████████████████| 102kB 9.0MB/s 
Collecting pathtools
  Downloading https://files.pythonhosted.org/packages/e7/7f/470d6fcdf23f9f3518f6b0b76be9df16dcc8630ad409947f8be2eb0ed13a/pathtools-0.1.2.tar.gz
Collecting sentry-sdk>=0.4.0
[?25l  Downloading https://files.pythonhosted.org/packages/1c/4a/a54b254f67d8f4052338d54ebe90126f200693440a93ef76d254d581e3ec/sentry_sdk-1.1.0-py2.py3-none-any.whl (131kB)
[K     |████████████████████████████████| 133kB 14.2MB/s 
[?25hCollecting shortuuid>=0.5.0
  Downloading https://files.pythonhosted

### Download dataset:

In [2]:
# Download dataset:
from torchvision.datasets import CIFAR10

CIFAR10("./cifar10", train=True, download=True);

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar10/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./cifar10/cifar-10-python.tar.gz to ./cifar10


### Get the code:

In [3]:
!git clone https://github.com/secutron/FixMatch-pytorch/
#!git clone https://github.com/vfdev-5/FixMatch-pytorch/


Cloning into 'FixMatch-pytorch'...
remote: Enumerating objects: 248, done.[K
remote: Counting objects: 100% (248/248), done.[K
remote: Compressing objects: 100% (150/150), done.[K
remote: Total 248 (delta 139), reused 164 (delta 76), pack-reused 0[K
Receiving objects: 100% (248/248), 78.18 KiB | 1.78 MiB/s, done.
Resolving deltas: 100% (139/139), done.


#### Optionally, login to `W&B`

To skip logging to `W&B`, please set `online_exp_tracking.wandb=false` 

In [4]:
# !wandb login <token>

In [4]:
%cd FixMatch-pytorch/

/content/FixMatch-pytorch


In [5]:
import torch

import ignite.distributed as idist
from ignite.engine import Events
from ignite.utils import manual_seed, setup_logger

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig

import utils
import trainers
from ctaugment import get_default_cta, OPS, interleave, deinterleave


sorted_op_names = sorted(list(OPS.keys()))

In [7]:
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import OmegaConf

with initialize(config_path="config"):
    cfg=compose(config_name="fixmatch.yaml")
    print(OmegaConf.to_yaml(cfg))

dataflow:
  name: cifar10
  data_path: /tmp/cifar10
  batch_size: 64
  num_workers: 12
solver:
  num_epochs: 1024
  epoch_length: 128
  checkpoint_every: 500
  validate_every: 1
  resume_from: null
  optimizer:
    cls: torch.optim.SGD
    params:
      lr: 0.03
      momentum: 0.9
      weight_decay: 0.0001
      nesterov: false
  supervised_criterion:
    cls: torch.nn.CrossEntropyLoss
  lr_scheduler:
    cls: torch.optim.lr_scheduler.CosineAnnealingLR
    params:
      eta_min: 0.0
      T_max: null
  unsupervised_criterion:
    cls: torch.nn.CrossEntropyLoss
    params:
      reduction: none
ssl:
  num_train_samples_per_class: 25
  confidence_threshold: 0.95
  lambda_u: 1.0
  mu_ratio: 7
  cta_update_every: 1
name: fixmatch
seed: 543
debug: false
model: resnet18
num_classes: 10
ema_decay: 0.999
distributed:
  backend: null
  nproc_per_node: null
online_exp_tracking:
  wandb: false



See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
  See {url} for more information"""
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
  See {url} for more information"""
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
  See {url} for more information"""


In [8]:
cfg.distributed.backend = "xla-tpu"
cfg.distributed.nproc_per_node = 8
cfg.online_exp_tracking.wandb = "true"
cfg.solver.num_epochs = 50
cfg.ssl.confidence_threshold = 0.7
cfg.ema_decay = 0.9
cfg.ssl.cta_update_every = 15
cfg.solver.optimizer.params.lr = 0.1
#cfg.solver.optimizer.param_groups.lr = 0.1


print(OmegaConf.to_yaml(cfg))

dataflow:
  name: cifar10
  data_path: /tmp/cifar10
  batch_size: 64
  num_workers: 12
solver:
  num_epochs: 50
  epoch_length: 128
  checkpoint_every: 500
  validate_every: 1
  resume_from: null
  optimizer:
    cls: torch.optim.SGD
    params:
      lr: 0.1
      momentum: 0.9
      weight_decay: 0.0001
      nesterov: false
  supervised_criterion:
    cls: torch.nn.CrossEntropyLoss
  lr_scheduler:
    cls: torch.optim.lr_scheduler.CosineAnnealingLR
    params:
      eta_min: 0.0
      T_max: null
  unsupervised_criterion:
    cls: torch.nn.CrossEntropyLoss
    params:
      reduction: none
ssl:
  num_train_samples_per_class: 25
  confidence_threshold: 0.7
  lambda_u: 1.0
  mu_ratio: 7
  cta_update_every: 15
name: fixmatch
seed: 543
debug: false
model: resnet18
num_classes: 10
ema_decay: 0.9
distributed:
  backend: xla-tpu
  nproc_per_node: 8
online_exp_tracking:
  wandb: 'true'



In [9]:
model, ema_model, optimizer, sup_criterion, lr_scheduler = utils.initialize(cfg)

In [10]:
lr_scheduler

{'cls': 'torch.optim.lr_scheduler.CosineAnnealingLR', 'params': {'eta_min': 0.0, 'T_max': None}, 'T_max': 6400}

In [11]:
optimizer

{'cls': 'torch.optim.SGD', 'params': {'lr': 0.1, 'momentum': 0.9, 'weight_decay': 0.0001, 'nesterov': False}}

In [6]:
%cd /content

/content


### Let's train ResNet18 model in a faster mode

In [4]:
!cd FixMatch-pytorch && export PYTHONPATH=$PWD:$PYTHONPATH && \
  python main_fixmatch.py distributed.backend=xla-tpu distributed.nproc_per_node=8 online_exp_tracking.wandb=true solver.num_epochs=50 \
    ssl.confidence_threshold=0.7 ema_decay=0.9 ssl.cta_update_every=15 solver.optimizer.params.lr=0.1

See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
  See {url} for more information"""
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
  See {url} for more information"""
See https://hydra.cc/docs/next/upgrades/1.0_to_1.1/changes_to_package_header for more information
  See {url} for more information"""
2021-06-17 04:31:00,037 ignite.distributed.launcher.Parallel INFO: Initialized distributed launcher with backend: 'xla-tpu'
2021-06-17 04:31:00,038 ignite.distributed.launcher.Parallel INFO: - Parameters to spawn processes: 
	nproc_per_node: 8
	nnodes: 1
	node_rank: 0
2021-06-17 04:31:00,038 ignite.distributed.launcher.Parallel INFO: Spawn function '<function training at 0x7f4d574a4ef0>' in 8 processes
2021-06-17 04:31:47,676 FixMatch Training INFO: {'dataflow': {'name': 'cifar10', 'data_path': '/content/cifar10', 'batch_size': 64, 'num_workers': 12}, 'solver': {'num_epochs': 50, 'epoch_leng

### Let's train WRN-28-2 model

In [None]:
!cd FixMatch-pytorch && export PYTHONPATH=$PWD:$PYTHONPATH && \
  python -u main_fixmatch.py model=WRN-28-2 distributed.backend=xla-tpu distributed.nproc_per_node=8 online_exp_tracking.wandb=true