# PyKale Tutorial: Domain Adaptation on Digits with Lightning

| [Open in Colab](https://colab.research.google.com/github/pykale/pykale/blob/main/examples/digits_dann_lightn/tutorial.ipynb) (click `Runtime` → `Run all (Ctrl+F9)`) | [Launch Binder](https://mybinder.org/v2/gh/pykale/pykale/HEAD?filepath=examples%2Fdigits_dann_lightn%2Ftutorial.ipynb) (click `Run` → `Run All Cells`) |

If you are using [Google Colab](https://colab.research.google.com), you can enable a free GPU to speed up training via `Runtime` → `Change runtime type` → `Hardware accelerator: GPU`.

## Introduction

[Domain Adaptation](https://en.wikipedia.org/wiki/Domain_adaptation) takes a model trained and evaluated on one set of data (the source) and adapts it to another (the target). In this tutorial, a model is trained on one digits dataset (source) and adapted to another (target). The tutorial is based on `main.py` from the `digits_dann_lightn` example, which is in turn refactored from [ADA: (Yet) Another Domain Adaptation library](https://github.com/criteo-research/pytorch-ada). It is designed to run interactively on online hosting platforms such as [Google Colab](https://colab.research.google.com) and [myBinder](https://mybinder.org), but it can also be downloaded and run locally by following the [PyKale installation instructions](https://pykale.readthedocs.io/en/latest/installation.html).

## Setup

The first few blocks of code are necessary to set up the notebook execution environment and import the required modules, including PyKale.

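This cell checks whether the notebook is running on Google Colab. If so, it installs the required packages and clones the PyKale repository, then changes into the example directory so that the example's local `config` and `model` modules can be imported.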

In [None]:
if 'google.colab' in str(get_ipython()):
    print('Running on Colab')
    !pip uninstall --yes imgaug && pip uninstall --yes albumentations && pip install git+https://github.com/aleju/imgaug.git
    !pip install git+https://github.com/pykale/pykale.git

    !git clone https://github.com/pykale/pykale.git
    %cd pykale/examples/digits_dann_lightn
else:
    print('Not running on Colab')
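
The check above relies on IPython's `get_ipython()`, which is only defined inside a notebook. A minimal sketch of an alternative that also works in a plain Python script, assuming (as is common practice) that Colab has already imported the `google.colab` package when the kernel starts. The helper name `running_on_colab` is hypothetical:

```python
import sys


def running_on_colab() -> bool:
    """Return True if the Colab runtime package has been imported in this process."""
    # Colab preloads `google.colab`, so its presence in sys.modules is a common
    # heuristic. Unlike get_ipython(), this also works in plain Python scripts.
    return "google.colab" in sys.modules


print("Running on Colab" if running_on_colab() else "Not running on Colab")
```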

This imports required modules.

In [None]:
import logging
import os

import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, SequentialSampler

from kale.loaddata.image_access import DigitDataset
from kale.loaddata.multi_domain import MultiDomainDatasets
from kale.utils.seed import set_seed

# Local modules from examples/digits_dann_lightn
from config import get_cfg_defaults
from model import get_model

## Configuration

In this tutorial, we override the [default configuration for domain adaptation problems](https://github.com/pykale/pykale/blob/main/examples/digits_dann_lightn/config.py) with a customized [`.yaml` file for the specific application in this tutorial](https://github.com/pykale/pykale/blob/main/examples/digits_dann_lightn/configs/TUTORIAL.yaml). The following cell prints the resulting configuration.

In [None]:
cfg_path = "./configs/tutorial.yaml" # Path to `.yaml` config file

cfg = get_cfg_defaults()
cfg.merge_from_file(cfg_path)
cfg.freeze()
print(cfg)
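
`get_cfg_defaults`, `merge_from_file` and `freeze` follow the yacs `CfgNode` API that PyKale examples use for configuration. The sketch below reproduces the behavior this cell relies on with plain dictionaries rather than yacs itself: values from the YAML file override the defaults, keys that the defaults do not declare are rejected, and the frozen result is read-only. The helper names and all config values are hypothetical.

```python
import copy
from types import MappingProxyType


def merge_config(defaults: dict, overrides: dict) -> dict:
    """Recursively merge `overrides` into a copy of `defaults`.

    Keys missing from `defaults` are rejected, so a YAML file can only
    change options that the defaults already declare.
    """
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if key not in merged:
            raise KeyError(f"Non-existent config key: {key}")
        if isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


def freeze(cfg: dict) -> MappingProxyType:
    """Return a read-only view of `cfg`, frozen recursively."""
    return MappingProxyType({k: freeze(v) if isinstance(v, dict) else v for k, v in cfg.items()})


# Hypothetical defaults and YAML overrides, for illustration only.
defaults = {"DATASET": {"SOURCE": "mnist", "TARGET": "usps"}, "SOLVER": {"MAX_EPOCHS": 100}}
overrides = {"SOLVER": {"MAX_EPOCHS": 5}}

cfg_sketch = freeze(merge_config(defaults, overrides))
print(cfg_sketch["SOLVER"]["MAX_EPOCHS"])  # 5
```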

## Check GPU Availability

If a CUDA GPU is available, it should be used to accelerate training. The cell below checks for one and reports which device will be used.

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using: " + device)
gpus = 1 if device == "cuda" else None
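
The `gpus` argument passed to `pl.Trainer` later in this tutorial belongs to older PyTorch Lightning releases: it was deprecated in the 1.7 series and removed in 2.0 in favor of `accelerator` and `devices`. If you run the notebook with a newer release, a helper like the hypothetical `trainer_hardware_kwargs` below could build the equivalent arguments; check the documentation of your installed version before relying on it.

```python
def trainer_hardware_kwargs(cuda_available: bool) -> dict:
    """Map CUDA availability to Trainer arguments for the accelerator/devices API."""
    # Newer Lightning releases select hardware with `accelerator` (device type)
    # and `devices` (how many), instead of the older `gpus` argument.
    if cuda_available:
        return {"accelerator": "gpu", "devices": 1}
    return {"accelerator": "cpu", "devices": 1}


print(trainer_hardware_kwargs(False))  # {'accelerator': 'cpu', 'devices': 1}

# Usage with a newer Lightning release (not executed here):
# trainer = pl.Trainer(max_epochs=..., **trainer_hardware_kwargs(torch.cuda.is_available()))
```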

## Select Datasets

Source and target datasets are specified using `DigitDataset.get_source_target` from values in the configuration (`cfg`) above. In this tutorial, we specify a subset of classes (1, 3 and 8) to make training and testing quicker.

In [None]:
source, target, num_channels = DigitDataset.get_source_target(
    DigitDataset(cfg.DATASET.SOURCE.upper()), DigitDataset(cfg.DATASET.TARGET.upper()), cfg.DATASET.ROOT
)

class_subset = [1, 3, 8]

dataset = MultiDomainDatasets(
    source,
    target,
    config_weight_type=cfg.DATASET.WEIGHT_TYPE,
    config_size_type=cfg.DATASET.SIZE_TYPE,
    valid_split_ratio=cfg.DATASET.VALID_SPLIT_RATIO,
    class_ids=class_subset,
)
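
`MultiDomainDatasets` performs the class filtering and the validation split internally. The sketch below only illustrates the two steps that its `class_ids` and `valid_split_ratio` arguments describe: keep the samples whose digit label is in the subset, then hold out a fraction of them for validation. The function `subset_and_split` and the labels are hypothetical, and PyKale's own implementation may shuffle, round or stratify differently.

```python
import random


def subset_and_split(labels, class_ids, valid_split_ratio, seed=0):
    """Return (train_indices, valid_indices) for samples whose label is in class_ids."""
    wanted = set(class_ids)
    keep = [i for i, label in enumerate(labels) if label in wanted]
    random.Random(seed).shuffle(keep)  # seeded shuffle, so the split is reproducible
    n_valid = int(len(keep) * valid_split_ratio)
    return keep[n_valid:], keep[:n_valid]


# Hypothetical labels for ten digit images.
labels = [1, 3, 8, 0, 1, 3, 8, 5, 1, 8]
train_idx, valid_idx = subset_and_split(labels, class_ids=[1, 3, 8], valid_split_ratio=0.25)
print(len(train_idx), len(valid_idx))  # 6 2
```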

## Set Seed

Model training relies on pseudo-random numbers, for example for weight initialization and data shuffling. Setting the seed from which these numbers are generated makes results reproducible.

In [None]:
seed = cfg.SOLVER.SEED
# pytorch_lightning's seed_everything did not set torch.backends.cudnn, so PyKale's set_seed is used instead
set_seed(seed)
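
The principle behind seeding can be seen with Python's own generator: seeding it with the same value reproduces the same sequence. A real training run draws from several generators (Python's `random`, NumPy, and PyTorch on CPU and GPU), and cuDNN has separate settings that affect determinism, which is what the comment in the cell above refers to. The helper `draw` below is purely illustrative.

```python
import random


def draw(seed: int, n: int = 5) -> list:
    """Draw n pseudo-random digits from a generator seeded with `seed`."""
    rng = random.Random(seed)  # a private generator; global random state is untouched
    return [rng.randint(0, 9) for _ in range(n)]


first_run = draw(seed=123)
second_run = draw(seed=123)
print(first_run == second_run)  # True: the same seed reproduces the same sequence
```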

## Setup Model

Here, we use the previously defined configuration and dataset to set up the model we will subsequently train.

In [None]:
%time model, train_params = get_model(cfg, dataset, num_channels)

The output reports which data files are used.
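
The example's name refers to DANN (domain-adversarial neural network; Ganin and Lempitsky). In DANN, a domain classifier learns to tell source features from target features, while a gradient reversal layer flips the sign of its gradient before it reaches the feature extractor, pushing the extractor towards features the domain classifier cannot separate. The DANN paper ramps the weight of this adversarial term up from 0 with the schedule sketched below. This is an illustration of that schedule only; whether and how the model built by `get_model` uses it depends on PyKale's implementation and the configuration, and the helper name is hypothetical.

```python
import math


def adaptation_factor(progress: float, gamma: float = 10.0) -> float:
    """Weight of the domain-adversarial loss at training progress `progress` in [0, 1].

    Implements lambda_p = 2 / (1 + exp(-gamma * p)) - 1 from the DANN paper:
    it starts at 0, suppressing the noisy domain-classifier signal early in
    training, and rises smoothly towards 1.
    """
    if not 0.0 <= progress <= 1.0:
        raise ValueError("progress must lie in [0, 1]")
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0


for p in (0.0, 0.1, 0.5, 1.0):
    print(f"progress={p:.1f}  lambda={adaptation_factor(p):.4f}")
```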

## Setup Logger

A TensorBoard logger records the metrics generated during model training, which can be used to assess the effectiveness of training and to identify problems. Logs are written to `cfg.OUTPUT.TB_DIR`, in a subdirectory named after the seed; by default, the checkpoints saved by `ModelCheckpoint` are stored there too.

In [None]:
tb_logger = TensorBoardLogger(cfg.OUTPUT.TB_DIR, name="seed{}".format(seed))
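
`TensorBoardLogger` writes each run to a versioned subdirectory, so repeated runs with the same seed do not overwrite each other. The sketch below builds the expected paths, assuming Lightning's default layout (`<save_dir>/<name>/version_<N>`, with checkpoints in a `checkpoints` subfolder when `ModelCheckpoint` is given no `dirpath`); the helper and the example values are hypothetical. In a notebook, the logs can then be viewed with `%load_ext tensorboard` followed by `%tensorboard --logdir` and the value of `cfg.OUTPUT.TB_DIR`.

```python
import os


def tensorboard_run_dir(save_dir: str, name: str, version: int) -> str:
    """Run directory used by TensorBoardLogger: <save_dir>/<name>/version_<version>."""
    return os.path.join(save_dir, name, f"version_{version}")


# Hypothetical values: TB_DIR = "outputs/tb", seed = 123, first run of this notebook.
run_dir = tensorboard_run_dir("outputs/tb", "seed123", 0)
print(run_dir)
print(os.path.join(run_dir, "checkpoints"))  # default ModelCheckpoint location
```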

## Setup Checkpoint

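A `ModelCheckpoint` callback saves the model during training. Here it monitors the validation loss (`valid_loss`) and, with `mode="min"`, treats lower values as better; each checkpoint filename records the epoch, step and validation loss.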

In [None]:
checkpoint_callback = ModelCheckpoint(
    filename="{epoch}-{step}-{valid_loss:.4f}",
    monitor="valid_loss",
    mode="min",
)
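
With its default `save_top_k=1`, `ModelCheckpoint` keeps only the checkpoint with the best monitored value, which for `mode="min"` is the lowest `valid_loss`. The sketch below mimics the file naming and that selection in plain Python; the helpers and loss values are hypothetical, and the `epoch=`/`step=` prefixes assume Lightning's default `auto_insert_metric_name=True`.

```python
def checkpoint_filename(epoch: int, step: int, valid_loss: float) -> str:
    """Mimic how "{epoch}-{step}-{valid_loss:.4f}" is expanded into a file name."""
    # With auto_insert_metric_name=True (the default), each placeholder is
    # prefixed with its name, e.g. "epoch=1-step=100-valid_loss=0.4500.ckpt".
    return f"epoch={epoch}-step={step}-valid_loss={valid_loss:.4f}.ckpt"


def best_checkpoint(history: list, mode: str = "min") -> tuple:
    """Return the (filename, monitored value) pair that would be kept with save_top_k=1."""
    choose = min if mode == "min" else max
    return choose(history, key=lambda item: item[1])


# Hypothetical validation losses at the end of three epochs.
history = [
    (checkpoint_filename(epoch, step, loss), loss)
    for epoch, step, loss in [(0, 50, 0.91), (1, 100, 0.45), (2, 150, 0.52)]
]
print(best_checkpoint(history, mode="min")[0])  # epoch=1-step=100-valid_loss=0.4500.ckpt
```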

A `TQDMProgressBar` callback displays training progress. `cfg.OUTPUT.PB_FRESH` sets its refresh rate, i.e. how many batches pass between progress bar updates; set it to `0` to disable the bar.

In [None]:
progress_bar = TQDMProgressBar(cfg.OUTPUT.PB_FRESH)
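
A simplified model of the refresh rate, assuming the bar redraws every `refresh_rate` batches plus once at the end of the epoch (the helper `refresh_points` is hypothetical and ignores Lightning's internal details). Larger values mean fewer redraws and therefore less console output, which can help on hosted notebooks.

```python
def refresh_points(num_batches: int, refresh_rate: int) -> list:
    """1-based batch indices at which the bar would redraw in this simplified model."""
    if refresh_rate == 0:
        return []  # a refresh rate of 0 disables the progress bar
    return [b for b in range(1, num_batches + 1) if b % refresh_rate == 0 or b == num_batches]


print(refresh_points(num_batches=10, refresh_rate=4))  # [4, 8, 10]
```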

## Setup Trainer

A `Trainer` object runs the training, validation and test loops. Here, it is configured with the number of epochs to train for, the checkpoint and progress bar callbacks, the logger, and the hardware to use.

In [None]:
trainer = pl.Trainer(
    min_epochs=cfg.SOLVER.MIN_EPOCHS,
    max_epochs=cfg.SOLVER.MAX_EPOCHS,
    callbacks=[checkpoint_callback, progress_bar],
    logger=tb_logger,
    gpus=gpus,
)

The output reports the available GPU and TPU resources.
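
`MIN_EPOCHS` and `MAX_EPOCHS` bound the length of training. In Lightning, a request to stop (for example from an `EarlyStopping` callback) is ignored until `min_epochs` have been completed, and training never runs beyond `max_epochs`. Since no early-stopping callback is passed to the trainer above, training is expected to run for `MAX_EPOCHS`. The hypothetical helper below models that interaction in plain Python.

```python
from typing import Optional


def epochs_run(min_epochs: int, max_epochs: int, stop_requested_after: Optional[int]) -> int:
    """Number of epochs a loop bounded by min_epochs/max_epochs would run.

    `stop_requested_after` is the epoch (1-based) from which a callback such as
    early stopping asks to stop, or None if no stop is ever requested.
    """
    for epoch in range(1, max_epochs + 1):
        stop_requested = stop_requested_after is not None and epoch >= stop_requested_after
        # A stop request is only honored once min_epochs have been completed.
        if stop_requested and epoch >= min_epochs:
            return epoch
    return max_epochs


print(epochs_run(min_epochs=5, max_epochs=20, stop_requested_after=2))  # 5
print(epochs_run(min_epochs=5, max_epochs=20, stop_requested_after=None))  # 20
```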

## Train Model

Optimize model parameters using the trainer.

In [None]:
%time trainer.fit(model)

## Test Optimized Model

Evaluate the trained model on test data that was not used during training.

In [None]:
# test scores
%time trainer.test()

The reported metrics are:

* `test_domain_acc`: accuracy of classifying which domain (source or target) each test sample came from.
* `test_source_acc`: classification accuracy on test data drawn from the source dataset.
* `test_target_acc`: classification accuracy on test data drawn from the target dataset.
* `test_loss`: value of the loss function on the test data.
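
Each accuracy is the fraction of test samples whose predicted label matches the true one; for the domain accuracy, the label is the domain rather than the digit. The sketch below computes the three quantities on hypothetical predictions (the `accuracy` helper and all values are illustrative, not PyKale's code). Because adversarial training pushes the feature extractor to confuse the domain classifier, a domain accuracy near 0.5 with two balanced domains is commonly read as a sign of well-aligned features, while a large gap between source and target accuracy suggests that adaptation is incomplete.

```python
def accuracy(predictions: list, targets: list) -> float:
    """Fraction of predictions that match the targets."""
    if not targets or len(predictions) != len(targets):
        raise ValueError("expected two non-empty sequences of equal length")
    return sum(p == t for p, t in zip(predictions, targets)) / len(targets)


# Hypothetical predictions on a tiny test set drawn from the class subset (1, 3, 8).
source_true, source_pred = [1, 3, 8, 8], [1, 3, 8, 3]
target_true, target_pred = [1, 3, 8, 1], [1, 8, 3, 1]
# Domain labels, assuming 0 = source and 1 = target.
domain_true = [0, 0, 0, 0, 1, 1, 1, 1]
domain_pred = [0, 0, 1, 0, 1, 1, 0, 1]

print("source_acc:", accuracy(source_pred, source_true))  # 0.75
print("target_acc:", accuracy(target_pred, target_true))  # 0.5
print("domain_acc:", accuracy(domain_pred, domain_true))  # 0.75
```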