# PyKale Tutorial: Domain Adaptation on Digits with Lightning

[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/pykale/pykale/HEAD?filepath=examples%2Fdigits_dann_lightn%2Ftutorial.ipynb) 

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pykale/pykale/blob/main/examples/digits_dann_lightn/tutorial.ipynb)

This tutorial is constructed based on the `digits_dann_lightn` example `main.py`, which is in turn refactored from the [ADA: (Yet) Another Domain Adaptation library](https://github.com/criteo-research/pytorch-ada).

It has been put together to run interactively on online hosting platforms such as [Google Colab](https://colab.research.google.com) or [myBinder](https://mybinder.org), but can also be downloaded and run locally. To run it locally, follow the [PyKale installation instructions](https://pykale.readthedocs.io/en/latest/installation.html).

[Domain Adaptation](https://en.wikipedia.org/wiki/Domain_adaptation) takes a model trained and evaluated on one set of data (the source) and adapts it to another (the target). In this tutorial, a model is trained on one digits dataset (the source) and adapted to another (the target).
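The example's name refers to DANN (domain-adversarial neural networks), although the method actually trained is selected by `cfg.DAN.METHOD`. As a hedged summary of the DANN objective proposed by Ganin et al. (not a description of PyKale's code), a feature extractor $G_f$, a label predictor $G_y$ and a domain classifier $G_d$ are trained on $n$ labeled source samples $(x_i^s, y_i)$ and $n'$ unlabeled target samples $x_j^t$, with domain label 0 for source and 1 for target:

$$
E(\theta_f, \theta_y, \theta_d) = \frac{1}{n}\sum_{i=1}^{n} \mathcal{L}_y\big(G_y(G_f(x_i^s)), y_i\big) - \lambda \left( \frac{1}{n}\sum_{i=1}^{n} \mathcal{L}_d\big(G_d(G_f(x_i^s)), 0\big) + \frac{1}{n'}\sum_{j=1}^{n'} \mathcal{L}_d\big(G_d(G_f(x_j^t)), 1\big) \right)
$$

$$
(\hat\theta_f, \hat\theta_y) = \arg\min_{\theta_f, \theta_y} E(\theta_f, \theta_y, \hat\theta_d), \qquad \hat\theta_d = \arg\max_{\theta_d} E(\hat\theta_f, \hat\theta_y, \theta_d)
$$

The feature extractor and label predictor minimize $E$, while the domain classifier maximizes it (that is, minimizes its own domain classification loss). The features are therefore pushed to be predictive of the digit class but uninformative about the domain, with $\lambda$ weighting the two terms.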

## Setup

The first few blocks of code are necessary to set up the notebook execution environment and import the required modules, including PyKale.

This checks if the notebook is running on Google Colab and installs required packages.

In [None]:
if 'google.colab' in str(get_ipython()):
    print('Running on Colab')
    !pip install git+https://github.com/pykale/pykale.git#egg=pykale[extras]

    !git clone https://github.com/pykale/pykale.git
    %cd pykale/examples/digits_dann_lightn
else:
    print('Not running on Colab')
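The check above relies on `get_ipython()`, which is only defined inside IPython/Jupyter. A minimal alternative sketch, assuming that hosted Colab runtimes import the `google.colab` package at startup, inspects `sys.modules` instead and also works from a plain Python interpreter. The helper name `running_on_colab` is hypothetical.

```python
import sys


def running_on_colab() -> bool:
    """Return True if the google.colab package is already loaded in this interpreter."""
    # Assumption: Colab's kernel imports google.colab before any user code runs
    return "google.colab" in sys.modules


print("Running on Colab" if running_on_colab() else "Not running on Colab")
```

Because it only looks up an already-imported module, the check has no side effects and does not try to import `google.colab` on machines where it is not installed.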

This imports required modules.

In [None]:
import logging
import os

from config import get_cfg_defaults
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler
import torchvision

from model import get_model

from kale.loaddata.digits_access import DigitDataset
from kale.loaddata.multi_domain import MultiDomainDatasets
from kale.utils.csv_logger import setup_logger
from kale.utils.seed import set_seed

## Configuration

In this tutorial we provide a [default configuration for domain adaptation problems](https://github.com/pykale/pykale/blob/main/examples/digits_dann_lightn/config.py), which is tailored using a [`.yaml` file for the specific application in this tutorial](https://github.com/pykale/pykale/blob/main/examples/digits_dann_lightn/configs/TUTORIAL.yaml).

If GPUs are to be used at runtime, this is specified using a separate variable, `gpus`. If you are running this tutorial on Google Colab, or on a machine with GPU support, you might [set this](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-class-api) to make use of GPU acceleration. On Google Colab, click Runtime > Change runtime type and select GPU as the hardware accelerator, then set `gpus = 1`.

The following cell prints a summary of the configuration.

In [None]:
cfg_path = "./configs/tutorial.yaml"  # Path to `.yaml` config file
gpus = None  # GPU settings: None to run on CPU, or e.g. 1 to use a single GPU

cfg = get_cfg_defaults()
cfg.merge_from_file(cfg_path)
cfg.freeze()
print(cfg)
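The `get_cfg_defaults()`/`merge_from_file()` pattern layers the values in the `.yaml` file over a set of defaults. PyKale's example configurations appear to be built on the yacs `CfgNode` class; the standard-library sketch below only illustrates the idea with plain dictionaries and is not yacs itself. The function `merge_config` and the example keys and values are hypothetical.

```python
import copy
from typing import Any, Dict


def merge_config(defaults: Dict[str, Any], overrides: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Return a copy of `defaults` with values from `overrides` applied, rejecting unknown keys."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if key not in merged:
            # Unknown keys are rejected rather than silently added, which catches typos in the .yaml file
            raise KeyError(f"Non-existent config key: {full_key}")
        if isinstance(merged[key], dict):
            if not isinstance(value, dict):
                raise TypeError(f"Config section {full_key} must be overridden by a mapping")
            # Recurse into nested sections such as DATASET or SOLVER
            merged[key] = merge_config(merged[key], value, full_key)
        else:
            merged[key] = value
    return merged


# Hypothetical defaults and overrides; in the tutorial these come from config.py and the .yaml file
defaults = {"DATASET": {"SOURCE": "mnist", "TARGET": "usps"}, "SOLVER": {"SEED": 0}}
overrides = {"SOLVER": {"SEED": 2020}}
cfg_demo = merge_config(defaults, overrides)
print(cfg_demo["SOLVER"]["SEED"])  # 2020
print(defaults["SOLVER"]["SEED"])  # 0 (the defaults are left untouched)
```

Rejecting unknown keys, as yacs also does, means a misspelled option in the `.yaml` file fails loudly instead of being ignored. `cfg.freeze()` then makes the merged configuration read-only, which this dictionary sketch does not attempt.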

## Setup Output

If you are running online in myBinder or Google Colab, you will not have easy access to files output by this tutorial. If you are running locally, the cell below creates a folder (`cfg.OUTPUT.DIR`) to store the model training output and configures the format of log messages.

In [None]:
os.makedirs(cfg.OUTPUT.DIR, exist_ok=True)
format_str = "@%(asctime)s %(name)s [%(levelname)s] - (%(message)s)"
logging.basicConfig(format=format_str)
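To see what `format_str` produces, the sketch below attaches the same format to a separate, hypothetically named logger that writes to an in-memory stream. Note that `logging.basicConfig` configures only the root logger and does nothing if the root logger already has handlers.

```python
import io
import logging

format_str = "@%(asctime)s %(name)s [%(levelname)s] - (%(message)s)"

# Route records from a dedicated logger to an in-memory stream so the formatted text can be inspected
stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter(format_str))

demo_logger = logging.getLogger("tutorial_demo")  # hypothetical logger name
demo_logger.addHandler(handler)
demo_logger.setLevel(logging.INFO)
demo_logger.propagate = False  # avoid a second copy of each message via the root logger

demo_logger.info("Starting training")
print(stream.getvalue())
```

The printed line has the form `@<timestamp> tutorial_demo [INFO] - (Starting training)`.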

## Select Datasets

Source and target datasets are specified using `DigitDataset.get_source_target` from values in the configuration (`cfg`) above. In this tutorial, we specify a subset of classes (1, 3 and 8) to make training and testing quicker.

In [None]:
source, target, num_channels = DigitDataset.get_source_target(
    DigitDataset(cfg.DATASET.SOURCE.upper()), DigitDataset(cfg.DATASET.TARGET.upper()), cfg.DATASET.ROOT
)

class_subset = [1, 3, 8]

dataset = MultiDomainDatasets(
    source,
    target,
    config_weight_type=cfg.DATASET.WEIGHT_TYPE,
    config_size_type=cfg.DATASET.SIZE_TYPE,
    val_split_ratio=cfg.DATASET.VAL_SPLIT_RATIO,
    class_ids=class_subset,
)
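`class_ids` and `val_split_ratio` restrict the data to the chosen digits and hold out part of it for validation. The following sketch shows one way such filtering and splitting can work on a plain list of labels. It is an illustration under assumed behavior, not the implementation inside `MultiDomainDatasets`, and the function `filter_and_split` is hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def filter_and_split(
    labels: Sequence[int],
    class_ids: Sequence[int],
    val_split_ratio: float,
    seed: int = 0,
) -> Tuple[List[int], List[int]]:
    """Return (train_indices, val_indices) for the samples whose label is in class_ids."""
    wanted = set(class_ids)
    # Keep only the indices of samples from the selected classes
    keep = [i for i, y in enumerate(labels) if y in wanted]
    # Shuffle with a local generator so the split is reproducible for a given seed
    rng = random.Random(seed)
    rng.shuffle(keep)
    n_val = int(len(keep) * val_split_ratio)
    return keep[n_val:], keep[:n_val]


labels = [0, 1, 3, 8, 5, 1, 3, 8, 9, 1]  # toy labels, not real data
train_idx, val_idx = filter_and_split(labels, class_ids=[1, 3, 8], val_split_ratio=0.25)
print(len(train_idx), len(val_idx))
```

Here 7 of the 10 toy labels belong to classes 1, 3 and 8, and `int(7 * 0.25) = 1` of them is held out for validation, so the call prints `6 1`.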

## Set Seed

Model training relies on pseudo-random numbers, for example when initializing weights and shuffling data. Setting the seed from which these are generated makes results reproducible from run to run.

In [None]:
seed = cfg.SOLVER.SEED
# Use kale's set_seed because pytorch_lightning's seed_everything did not set the torch.backends.cudnn flags
set_seed(seed)
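A minimal sketch of what a seeding helper typically does is shown below. It is an assumption about the general technique, not a copy of `kale.utils.seed.set_seed`: it seeds Python's `random` module and, if they are installed, NumPy and PyTorch, and sets the cuDNN flags mentioned in the cell above. The name `set_seed_sketch` is hypothetical.

```python
import random


def set_seed_sketch(seed: int = 1000) -> None:
    """Seed the common sources of randomness; NumPy and PyTorch are seeded only if installed."""
    random.seed(seed)
    try:
        import numpy as np

        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch

        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # safe to call even without a GPU
        # Ask cuDNN for deterministic algorithms and disable autotuning, which can pick different kernels per run
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass


set_seed_sketch(2020)
first = [random.random() for _ in range(3)]
set_seed_sketch(2020)
second = [random.random() for _ in range(3)]
print(first == second)  # True
```

Re-seeding with the same value reproduces the same random draws. Deterministic cuDNN settings can make training slower, which is the usual trade-off for reproducibility.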

## Setup Model

Here, we use the previously defined configuration and dataset to set up the model we will subsequently train.

In [None]:
%time model, train_params = get_model(cfg, dataset, num_channels)

The output reports on the data files used, followed by the time taken to set up the model (reported by `%time`).

## Setup Logger

A logger is used to store output generated during and after model training. This information can be used to assess the effectiveness of the training and to identify problems.

In [None]:
logger, results, checkpoint_callback, test_csv_file = setup_logger(
    train_params, cfg.OUTPUT.DIR, cfg.DAN.METHOD, seed
)

## Setup Trainer

A trainer object runs the training loop that optimizes the model parameters. Here, it is configured with how the model should be trained (for example, the minimum and maximum number of epochs) and which hardware will be used.

In [None]:
trainer = pl.Trainer(
    progress_bar_refresh_rate=cfg.OUTPUT.PB_FRESH,  # in steps
    min_epochs=cfg.SOLVER.MIN_EPOCHS,
    max_epochs=cfg.SOLVER.MAX_EPOCHS,
    callbacks=[checkpoint_callback],
    logger=False,
    gpus=gpus)
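The `progress_bar_refresh_rate` and `gpus` arguments belong to older PyTorch Lightning releases; they were deprecated and later removed, with newer versions using `accelerator`/`devices` and setting the refresh rate through the `TQDMProgressBar(refresh_rate=...)` callback. If you run this tutorial with a newer release, a small helper such as the hypothetical `trainer_kwargs` below could translate the tutorial's `gpus` variable. Check the documentation of your installed version before relying on it.

```python
from typing import Any, Dict, Optional


def trainer_kwargs(gpus: Optional[int], min_epochs: int, max_epochs: int) -> Dict[str, Any]:
    """Translate the tutorial's `gpus` variable into accelerator/devices arguments for newer Lightning."""
    kwargs: Dict[str, Any] = {"min_epochs": min_epochs, "max_epochs": max_epochs}
    if gpus:  # a positive GPU count
        kwargs.update(accelerator="gpu", devices=gpus)
    else:  # None or 0: run on the CPU
        kwargs.update(accelerator="cpu", devices=1)
    return kwargs


print(trainer_kwargs(None, min_epochs=0, max_epochs=5))
```

With Lightning installed, the result would be unpacked into the trainer, e.g. `pl.Trainer(callbacks=[checkpoint_callback], logger=False, **trainer_kwargs(gpus, cfg.SOLVER.MIN_EPOCHS, cfg.SOLVER.MAX_EPOCHS))`.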

The output reports whether GPU and TPU resources are available and whether they are used.

## Train Model

Optimize the model parameters using the trainer, then record the resulting validation metrics in `results`.

In [None]:
%time trainer.fit(model)
results.update(
    is_validation=True, method_name=cfg.DAN.METHOD, seed=seed, metric_values=trainer.callback_metrics,
)

## Test Optimized Model

Check the performance of the optimized model on test data that was not used in training.

In [None]:
# test scores
%time trainer.test()
results.update(
    is_validation=False, method_name=cfg.DAN.METHOD, seed=seed, metric_values=trainer.callback_metrics,
)
results.print_scores(cfg.DAN.METHOD)

Outputs are defined as:

* 'Te_domain_acc': Accuracy on classifying the domain (source or target) from which data came.
* 'Te_source_acc': Accuracy on test data drawn from the source dataset.
* 'Te_target_acc': Accuracy on test data drawn from the target dataset.
* 'test_loss': Loss function value on the test data.
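All three accuracies are the fraction of correctly classified test samples; they differ only in which samples and which labels are compared. The sketch below computes them on small, made-up predictions (domain label 0 for source, 1 for target). The values are illustrative only, and in the tutorial the metrics are computed inside the PyKale model.

```python
from typing import Sequence


def accuracy(predictions: Sequence[int], targets: Sequence[int]) -> float:
    """Fraction of predictions that match the targets."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    if not targets:
        raise ValueError("cannot compute the accuracy of an empty set")
    return sum(p == t for p, t in zip(predictions, targets)) / len(targets)


# Made-up digit predictions and labels for 4 source and 4 target test samples (classes 1, 3, 8)
src_pred, src_true = [1, 3, 8, 8], [1, 3, 8, 3]
tgt_pred, tgt_true = [1, 1, 8, 8], [1, 3, 8, 3]
# Made-up domain classifier outputs for the same 8 samples: 0 = source, 1 = target
dom_pred = [0, 1, 1, 0, 1, 1, 0, 1]
dom_true = [0] * 4 + [1] * 4

print(accuracy(src_pred, src_true))  # 0.75  (source accuracy)
print(accuracy(tgt_pred, tgt_true))  # 0.5   (target accuracy)
print(accuracy(dom_pred, dom_true))  # 0.625 (domain accuracy)
```

For domain-adversarial methods, a domain accuracy close to chance level (0.5 when both domains are equally represented) suggests that the learned features make source and target hard to tell apart, which is what the adversarial training aims for.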

## Store Log

Save the test results to the CSV file (`test_csv_file`) returned by `setup_logger`.

In [None]:
results.to_csv(test_csv_file)