# PyKale example: Domain Adaptation on Digits with Lightning

This example is constructed by refactoring [ADA: (Yet) Another Domain Adaptation library](https://github.com/criteo-research/pytorch-ada), which includes many domain adaptation algorithms.

It has been put together to run interactively on online hosting platforms such as [Google Colab](https://colab.research.google.com) or [myBinder](https://mybinder.org), but it can also be downloaded and run locally. To run it locally, follow the [PyKale](https://github.com/pykale/pykale) installation instructions.

*ToDo: Description of what this example actually does.*

# Setup

The first few blocks of code are necessary to set up the notebook execution environment and import the required modules, including PyKale.

This checks if the notebook is running on Google Colab and installs required packages.

In [None]:
if 'google.colab' in str(get_ipython()):
    print('Running on CoLab')
    !pip install pykale[extras] 

    !git clone -b digits-notebook https://github.com/pykale/pykale.git
    %cd pykale/examples/digits_dann_lightn
else:
    print('Not running on CoLab')

This imports required modules.

In [None]:
import logging
import os

from config import get_cfg_defaults
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler
import torchvision

from model import get_model

from kale.loaddata.digits_access import DigitDataset
from kale.loaddata.multi_domain import MultiDomainDatasets
from kale.utils.csv_logger import setup_logger
from kale.utils.seed import set_seed

## Configuration

In this example we provide a [default configuration for domain adaptation problems](https://github.com/pykale/pykale/blob/main/examples/digits_dann_lightn/config.py), which is tailored using a [`.yaml` file for the specific application in this example](https://github.com/pykale/pykale/blob/main/examples/digits_dann_lightn/configs/TUTORIAL.yaml).

Whether GPUs are used at runtime is specified with a separate variable, `gpus`. If you are running this example on Google Colab, or on a machine with GPU support, you might [set this](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-class-api) to make use of GPU acceleration.

The configuration is summarised below the following cell.

In [None]:
cfg_path = "./configs/TUTORIAL.yaml" # Path to `.yaml` config file
gpus = None # GPU settings

cfg = get_cfg_defaults()
cfg.merge_from_file(cfg_path)
cfg.freeze()
print(cfg)
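
The defaults-plus-overrides pattern used in the cell above can be illustrated without any dependencies. The sketch below is not the implementation behind `get_cfg_defaults` or `merge_from_file`; the helper `merge_overrides` and all keys and values in it are hypothetical, chosen only to show how a small set of overrides replaces selected defaults. Rejecting keys that the defaults do not declare is a design choice of this sketch.

```python
import copy


def merge_overrides(defaults, overrides):
    """Return a copy of `defaults` with `overrides` applied recursively."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if key not in merged:
            # Design choice of this sketch: typos in override files fail loudly.
            raise KeyError(f"Unknown config key: {key}")
        if isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = merge_overrides(merged[key], value)
        else:
            merged[key] = value
    return merged


# Hypothetical values; the real ones live in config.py and TUTORIAL.yaml.
defaults = {"SOLVER": {"SEED": 2020, "MAX_EPOCHS": 100}, "OUTPUT": {"DIR": "./outputs"}}
overrides = {"SOLVER": {"MAX_EPOCHS": 5}}

cfg_sketch = merge_overrides(defaults, overrides)
print(cfg_sketch)
```

Only `MAX_EPOCHS` changes here; every key that the overrides leave out keeps its default value, and `defaults` itself is not modified.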

## Setup Output

In [None]:
os.makedirs(cfg.OUTPUT.DIR, exist_ok=True)
format_str = "@%(asctime)s %(name)s [%(levelname)s] - (%(message)s)"
logging.basicConfig(format=format_str)
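
To see what the format string in the cell above produces, the following minimal sketch formats a single message into an in-memory buffer instead of the notebook output. The logger name `demo` and the message text are hypothetical and used only for illustration.

```python
import io
import logging

format_str = "@%(asctime)s %(name)s [%(levelname)s] - (%(message)s)"

# Capture the formatted record in memory so it can be inspected.
stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter(format_str))

demo_logger = logging.getLogger("demo")  # hypothetical logger name
demo_logger.addHandler(handler)
demo_logger.propagate = False  # keep the record out of the root logger's handlers
demo_logger.warning("training started")

formatted = stream.getvalue().strip()
print(formatted)
```

The printed line starts with `@` and a timestamp and ends with `demo [WARNING] - (training started)`. Because `logging.basicConfig` is called without a `level` argument, the root logger keeps its default `WARNING` threshold, so `INFO` and `DEBUG` messages are not shown unless a level is set.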

## Select Datasets

In [None]:
source, target, num_channels = DigitDataset.get_source_target(
    DigitDataset(cfg.DATASET.SOURCE.upper()), DigitDataset(cfg.DATASET.TARGET.upper()), cfg.DATASET.ROOT
)

dataset = MultiDomainDatasets(
    source,
    target,
    config_weight_type=cfg.DATASET.WEIGHT_TYPE,
    config_size_type=cfg.DATASET.SIZE_TYPE,
    val_split_ratio=cfg.DATASET.VAL_SPLIT_RATIO,
)

## Train Model

In [None]:
i = 0  # run index; this notebook performs a single run
seed = cfg.SOLVER.SEED + i * 10
# seed_everything in pytorch_lightning did not set torch.backends.cudnn
set_seed(seed)
print(f"==> Building model for seed {seed} ......")
# ---- setup model and logger ----
model, train_params = get_model(cfg, dataset, num_channels)
logger, results, checkpoint_callback, test_csv_file = setup_logger(
    train_params, cfg.OUTPUT.DIR, cfg.DAN.METHOD, seed
)

# Arguments shared by the CPU and GPU cases
trainer_kwargs = dict(
    progress_bar_refresh_rate=cfg.OUTPUT.PB_FRESH,  # in steps
    min_epochs=cfg.SOLVER.MIN_EPOCHS,
    max_epochs=cfg.SOLVER.MAX_EPOCHS,
    callbacks=[checkpoint_callback],
    logger=False,
)
if gpus is not None:
    trainer_kwargs["gpus"] = gpus  # only request GPUs when explicitly configured

trainer = pl.Trainer(**trainer_kwargs)

trainer.fit(model)
results.update(
    is_validation=True, method_name=cfg.DAN.METHOD, seed=seed, metric_values=trainer.callback_metrics,
)
# test scores
trainer.test()
results.update(
    is_validation=False, method_name=cfg.DAN.METHOD, seed=seed, metric_values=trainer.callback_metrics,
)
results.to_csv(test_csv_file)
results.print_scores(cfg.DAN.METHOD)
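
The comment about `torch.backends.cudnn` in the cell above concerns reproducibility. As a rough illustration of what a seeding helper of this kind typically does, the sketch below seeds Python's `random` module and, only if they are importable, NumPy and PyTorch. It is not the source of `kale.utils.seed.set_seed`; the name `seed_all_sketch` and the seed value are hypothetical.

```python
import os
import random


def seed_all_sketch(seed):
    """Seed common random number generators; NumPy and PyTorch are optional."""
    # Only affects hash randomisation in Python processes started afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    try:
        import numpy as np

        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch

        torch.manual_seed(seed)
        # Prefer deterministic cuDNN kernels over auto-tuned ones.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass


seed_all_sketch(2020)
first = random.random()
seed_all_sketch(2020)
second = random.random()
print(first == second)
```

Re-seeding with the same value reproduces the same random draw, which is what makes repeated runs with a fixed seed comparable.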