# Domain Specific Batch Normalization (DSBN)

 > Domain-Specific Batch Normalization for Unsupervised Domain Adaptation (DSBN) (https://arxiv.org/abs/1906.03950)

## Overview

DSBN assigns a separate batch normalization (BN) layer to each domain, so that each BN layer captures the information of its own domain through its BN parameters.

Normalization then removes this domain-specific information, which lets the model learn domain-invariant features.

DSBN is trained in two stages; this exercise implements the training of the first stage.

<img src = "https://github.com/wgchang/DSBN/raw/master/captions/dsbn.jpg">

In [None]:
import argparse
import logging
import pprint
import datetime
import sys
import random
from collections import defaultdict
import math

import torch.nn.functional as F
import torch.optim as optim
import os
import torch
import torch.nn as nn
from torch.nn import init
from torch.utils import data
from torchvision.datasets import MNIST, SVHN
from torchvision import transforms
import numpy as np

To get reproducible training results, we fix the random seed.

In [None]:
# basic random seed
import os
import random
import numpy as np

random_seed= 2023

def seedBasic(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)

# torch random seed
import torch
def seedTorch(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# basic + torch
def seedEverything(seed):
    seedBasic(seed)
    seedTorch(seed)
    print(f"Random Seed: {seed}")


seedEverything(random_seed)

Random Seed: 2023


# Implementing the model

We implement DSBN and then use it to build a LeNet that includes DSBN.

<br>
<br>


## Implementing DSBN (Domain Specific Batch Normalization)
First, we define DSBN, the core component of this exercise.
Domain-specific batch normalization can be built on top of the BatchNorm2d module that torch already provides.

- nn.BatchNorm2d

<img src="https://ifh.cc/g/gdkh0f.png">

  - num_features: C in an input of size (N, C, H, W).
  - eps: a value added to the denominator for numerical stability; the default is 1e-5. (It corresponds to ϵ in the formula above and prevents a zero denominator from producing NaN.)
  - momentum: the value used to compute running_mean and running_var; the default is 0.1. (During training, batch normalization uses the mean and variance of the current batch; at test time it uses the accumulated running_mean and running_var.)
  - affine: a Boolean. When True, the module has learnable affine parameters. The default is True. (When False, γ=1 and β=0 in the formula above.)
  - track_running_stats: a Boolean. When True, the module tracks the running mean and variance. When False, the module does not track them and initializes running_mean and running_var to None; while these are None, the module always uses batch statistics. The default is True.
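
To make the roles of `eps` and `momentum` concrete, here is a minimal plain-Python sketch of what batch normalization computes for a single channel. It is an illustration only, not the `nn.BatchNorm2d` implementation, and the function name `batch_norm_1ch` and its arguments are hypothetical. It assumes, as the PyTorch documentation describes, that the batch is normalized with the biased variance while the running variance is updated with the unbiased estimate.

```python
import math

def batch_norm_1ch(values, running_mean, running_var, training,
                   eps=1e-5, momentum=0.1, gamma=1.0, beta=0.0):
    """Normalize one channel; return (outputs, running_mean, running_var)."""
    if training:
        n = len(values)
        mean = sum(values) / n
        var = sum((v - mean) ** 2 for v in values) / n  # biased batch variance
        unbiased_var = var * n / (n - 1) if n > 1 else var
        # Exponential moving averages, weighted by momentum
        running_mean = (1 - momentum) * running_mean + momentum * mean
        running_var = (1 - momentum) * running_var + momentum * unbiased_var
    else:
        # Evaluation mode: rely on the accumulated statistics
        mean, var = running_mean, running_var
    out = [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in values]
    return out, running_mean, running_var

batch = [1.0, 2.0, 3.0, 4.0]
out, rm, rv = batch_norm_1ch(batch, running_mean=0.0, running_var=1.0, training=True)
print([round(v, 3) for v in out])
print(round(rm, 4), round(rv, 4))
```

With `training=False` the same call would normalize with `running_mean` and `running_var` instead of the batch statistics, which is why each domain needs its own accumulated statistics.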

In [None]:
class DomainSpecificBatchNorm2d(nn.Module):
    _version = 2

    def __init__(self, num_features, num_classes, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(DomainSpecificBatchNorm2d, self).__init__()
        ############ TODO #############
        '''
        https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
        Referring to the BatchNorm2d documentation linked above, fill in self.bns below.
        '''
        self.bns = nn.ModuleList(
            [nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_classes)])

        ############ TODO #############

    def reset_running_stats(self):
        for bn in self.bns:
            bn.reset_running_stats()

    def reset_parameters(self):
        for bn in self.bns:
            bn.reset_parameters()

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))

    def forward(self, x, domain_label):
        self._check_input_dim(x)
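        # All samples in a batch are assumed to share one domain,
        # so the first label selects the BN branch for the whole batch.
        bn = self.bns[domain_label[0]]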
        return bn(x), domain_label
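
The dispatch idea in `forward` can be sketched without torch. The snippet below is a simplified illustration under stated assumptions: `RunningMean` and `DomainSpecificStats` are hypothetical stand-ins that track only a running mean, not full BN layers. It shows that each domain accumulates its own statistics because the domain label picks the branch.

```python
class RunningMean:
    """Hypothetical stand-in for one BN branch: tracks only a running mean."""
    def __init__(self, momentum=0.1):
        self.momentum = momentum
        self.value = 0.0

    def update(self, batch):
        batch_mean = sum(batch) / len(batch)
        self.value = (1 - self.momentum) * self.value + self.momentum * batch_mean
        return self.value

class DomainSpecificStats:
    """One branch per domain, selected by the domain label."""
    def __init__(self, num_domains):
        self.branches = [RunningMean() for _ in range(num_domains)]

    def update(self, batch, domain_labels):
        # Like the forward pass above, the first label decides for the whole batch.
        return self.branches[domain_labels[0]].update(batch)

stats = DomainSpecificStats(num_domains=2)
stats.update([10.0, 12.0], domain_labels=[0, 0])  # source-like batch
stats.update([-1.0, 1.0], domain_labels=[1, 1])   # target-like batch
print([b.value for b in stats.branches])
```

Because only the selected branch is updated, a batch that mixes domains would be attributed entirely to the domain of its first sample.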


In [None]:
def init_weights(obj):
    for m in obj.modules():
        if isinstance(m, nn.Conv2d):
            # init.xavier_normal_(m.weight)
            m.weight.data.normal_(0, 0.01).clamp_(min=-0.02, max=0.02)
            try:
                m.bias.data.zero_()
            except AttributeError:
                # no bias
                pass
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01).clamp_(min=-0.02, max=0.02)
            try:
                m.bias.data.zero_()
            except AttributeError:
                # no bias
                pass
        elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
            m.reset_parameters()
        elif isinstance(m, nn.Embedding):
            init.normal_(m.weight, 0, 0.01)

## Implementing LeNet
Next, we implement LeNet. The basic structure of LeNet is as follows.

<img src='https://miro.medium.com/v2/resize:fit:1204/format:webp/1*9MRcNBz9uHXplXzd3ii6VQ.png'>

In this exercise we follow the basic structure of LeNet, but modify a few parameters and add Batch Normalization modules for the experiments.

In [None]:
class LeNet(nn.Module):
    """Network used for MNIST or USPS experiments."""

    def __init__(self, num_classes=10, weights_init_path=None):
        super(LeNet, self).__init__()
        self.num_classes = num_classes
        self.num_channels = 3
        self.image_size = 28
        self.name = 'LeNet'
        self.setup_net()

        if weights_init_path is not None:
            init_weights(self)
            self.load(weights_init_path)
        else:
            init_weights(self)

    def setup_net(self):
        self.conv1 = nn.Conv2d(self.num_channels, 20, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(20)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(50)
        self.pool2 = nn.MaxPool2d(2)

        # 28 - 5 + 1 = 24
        # (24 - 2)/2 + 1 = 12
        # 12 - 5 + 1 = 8
        # (8 - 2)/2 + 1 = 4
        self.fc1 = nn.Linear(50 * 4 * 4, 500)
        self.fc2 = nn.Linear(500, self.num_classes)

    def forward(self, x, with_ft=False):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.pool1(F.relu(x))
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.pool2(F.relu(x))
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(F.relu(x))
        feat = x

        if with_ft:
            return x, feat
        else:
            return x

    def load(self, init_path):
        net_init_dict = torch.load(init_path)
        init_weights(self)
        updated_state_dict = self.state_dict()
        print('load {} params.'.format(init_path))
        for k, v in updated_state_dict.items():
            if k in net_init_dict:
                if v.shape == net_init_dict[k].shape:
                    updated_state_dict[k] = net_init_dict[k]
                else:
                    print(
                        "{0} params' shape not the same as pretrained params. Initialize with default settings.".format(
                            k))
            else:
                print("{0} params does not exist. Initialize with default settings.".format(k))
        self.load_state_dict(updated_state_dict)
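
The shape comments in `setup_net` can be checked with a few lines of arithmetic. This is a small sketch; the helper names `conv_out` and `pool_out` are hypothetical, and the layers are assumed to use no padding, with the pooling stride equal to the kernel size, as in the LeNet definition above.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel, stride=None):
    """Spatial output size of max pooling (stride defaults to the kernel size)."""
    stride = kernel if stride is None else stride
    return (size - kernel) // stride + 1

size = 28
size = conv_out(size, kernel=5)  # conv1
size = pool_out(size, kernel=2)  # pool1
size = conv_out(size, kernel=5)  # conv2
size = pool_out(size, kernel=2)  # pool2
flat_features = 50 * size * size  # 50 output channels from conv2
print(size, flat_features)
```

The result matches the `50 * 4 * 4` input size of `fc1`.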

## Implementing LeNet with DSBN
Now we implement a LeNet that includes DSBN.

In [None]:
class DSBNLeNet(nn.Module):
    """Network used for MNIST or USPS experiments. Conditional Batch Normalization is added."""

    def __init__(self, num_classes=10, weights_init_path=None, num_domains=2):
        super(DSBNLeNet, self).__init__()
        self.num_classes = num_classes
        self.num_channels = 3
        self.image_size = 28
        self.num_domains = num_domains
        self.name = 'DSBNLeNet'
        self.setup_net()

        if weights_init_path is not None:
            init_weights(self)
            self.load(weights_init_path)
        else:
            init_weights(self)

    def setup_net(self):
        ############ TODO #############
        '''
        Referring to the LeNet above, complete DSBNLeNet so that it includes DSBN.
        '''

        self.conv1 = nn.Conv2d(self.num_channels, 20, kernel_size=5)
        self.bn1 = DomainSpecificBatchNorm2d(20, self.num_domains)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
        self.bn2 = DomainSpecificBatchNorm2d(50, self.num_domains)
        self.pool2 = nn.MaxPool2d(2)

        self.fc1 = nn.Linear(50 * 4 * 4, 500)
        self.fc2 = nn.Linear(500, self.num_classes)


        ############ TODO #############

    def forward(self, x, y, with_ft=False):
        x = self.conv1(x)
        x, _ = self.bn1(x, y)
        x = self.pool1(F.relu(x))
        x = self.conv2(x)
        x, _ = self.bn2(x, y)
        x = self.pool2(F.relu(x))
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(F.relu(x))
        feat = x

        if with_ft:
            return x, feat
        else:
            return x

    def load(self, init_path):
        net_init_dict = torch.load(init_path)
        init_weights(self)
        updated_state_dict = self.state_dict()
        print('load {} params.'.format(init_path))
        for k, v in updated_state_dict.items():
            if k in net_init_dict:
                if v.shape == net_init_dict[k].shape:
                    updated_state_dict[k] = net_init_dict[k]
                else:
                    print(
                        "{0} params' shape not the same as pretrained params. Initialize with default settings.".format(
                            k))
            else:
                print("{0} params does not exist. Initialize with default settings.".format(k))
        self.load_state_dict(updated_state_dict)

We also define the additional models needed for training.

In [None]:
class Discriminator(nn.Module):
    def __init__(self, in_features):
        super(Discriminator, self).__init__()
        self.in_features = in_features

        self.discriminator = nn.Sequential(
            nn.Linear(self.in_features, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, 1)
        )

        init_weights(self)

    def forward(self, x):
        return self.discriminator(x)

In [None]:
class Centroids(nn.Module):
  '''
  Learning Semantic Representations for Unsupervised Domain Adaptation
  http://proceedings.mlr.press/v80/xie18c/xie18c.pdf
  '''
  def __init__(self, feature_dim, num_classes, decay_const=0.3):
      super(Centroids, self).__init__()
      self.decay_const = decay_const
      self.num_classes = num_classes
      self.centroids = nn.Parameter(torch.randn(num_classes, feature_dim))
      self.centroids.requires_grad = False
      self.reset_parameters()

  def reset_parameters(self):
      self.centroids.data.zero_()

  def forward(self, x, y, y_mask=None):
      classes = torch.unique(y)
      current_centroids = []
      for c in range(self.num_classes):
          if c in classes:
              if y_mask is not None:
                  avg_c = torch.sum(x[(y == c) & y_mask, :], dim=0) / torch.sum((y == c) & y_mask).float()
              else:
                  avg_c = torch.sum(x[(y == c), :], dim=0) / torch.sum((y == c)).float()
              current_centroids.append(avg_c * self.decay_const + (1 - self.decay_const) * self.centroids[c:c + 1, :])
          else:
              current_centroids.append(self.centroids[c:c + 1, :])
      current_centroids = torch.cat(current_centroids, 0)
      return current_centroids
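
The centroid update in `forward` is an exponential moving average of per-class feature means. The following plain-Python sketch illustrates that update for a single class; the function name `update_centroid` is hypothetical, and plain lists stand in for tensors.

```python
def update_centroid(old_centroid, class_features, decay_const=0.3):
    """Blend the batch mean of one class into its stored centroid."""
    dim = len(old_centroid)
    n = len(class_features)
    batch_mean = [sum(f[d] for f in class_features) / n for d in range(dim)]
    return [decay_const * batch_mean[d] + (1 - decay_const) * old_centroid[d]
            for d in range(dim)]

centroid = [0.0, 0.0]                # centroids start at zero, as in reset_parameters
features = [[1.0, 2.0], [3.0, 4.0]]  # features of one class in the current batch
centroid = update_centroid(centroid, features)
print(centroid)
```

A class that does not appear in the batch simply keeps its previous centroid, which corresponds to the `else` branch in `forward`.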


# Implementing the DataLoader

The domain adaptation setting for this exercise is svhn --> mnist.
<br>
<br>

### MNIST
MNIST is a large dataset of handwritten digit images,

consisting of 60,000 training images and 10,000 test images.

<img src='https://upload.wikimedia.org/wikipedia/commons/thumb/2/27/MnistExamples.png/440px-MnistExamples.png'>

<br>
<br>

### SVHN
SVHN is a digit dataset collected from Google Street View.

Like MNIST, it consists of images of the ten digits 0 through 9.

<img src='https://production-media.paperswithcode.com/datasets/SVHN-0000000424-c12734ed_mMXUnWD.jpg' width="500" height="300">


For more details about these datasets, see the following link.

https://kjhov195.github.io/2020-02-09-image_dataset_1/

### Implement svhn as the source dataset and mnist as the target dataset.

In [None]:
MNIST_DIR = './data/mnist'
SVHN_DIR  = './data/svhn'
batch_size = 40
num_workers = 2

def get_dataloaders(src_data_path=SVHN_DIR, trg_data_path=MNIST_DIR):
    ############ TODO #############
    '''
    source_transform:
        1. resize to 28 by 28
        2. convert to tensor
        3. normalize by mean (0.5, 0.5, 0.5) and std (0.5, 0.5, 0.5)

    target_transform:
        1. convert to RGB from Grayscale (Hint: Use transforms.Lambda)
        2. convert to tensor
        3. normalize by mean (0.5, 0.5, 0.5) and std (0.5, 0.5, 0.5)

    source_train_dataset:
        Implement using SVHN
        (참고: https://pytorch.org/vision/main/generated/torchvision.datasets.SVHN.html#torchvision.datasets.SVHN)

    target_train_dataset & target_val_dataset:
        Implement using MNIST
        (참고: https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html)
    '''

    source_transform = transforms.Compose([
        transforms.Resize([28, 28]),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    target_transform = transforms.Compose([
        transforms.Lambda(lambda x: x.convert("RGB")),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    source_train_dataset = SVHN(root=src_data_path,
                                split='train',
                                download=True,
                                transform=source_transform)

    target_train_dataset = MNIST(root=trg_data_path,
                                 train=True,
                                 download=True,
                                 transform=target_transform)


    target_val_dataset = MNIST(root=trg_data_path,
                               train=False,
                               download=True,
                               transform=target_transform)

    ############ TODO #############

    source_train_dataloader = data.DataLoader(source_train_dataset,
                                              batch_size = batch_size,
                                              shuffle = True,
                                              num_workers=num_workers,
                                              drop_last=True,
                                              pin_memory=True)

    target_train_dataloader = data.DataLoader(target_train_dataset,
                                              batch_size = batch_size,
                                              shuffle = True,
                                              num_workers=num_workers,
                                              drop_last=True,
                                              pin_memory=True)

    target_val_dataloader = data.DataLoader(target_val_dataset,
                                            batch_size = batch_size,
                                            shuffle = False,
                                            num_workers=num_workers,
                                            drop_last=True,
                                            pin_memory=True)

    return source_train_dataloader, target_train_dataloader, target_val_dataloader
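
As a quick check of the normalization step in both transforms, the per-channel arithmetic of `transforms.Normalize` is `(x - mean) / std`. The sketch below applies it to scalar pixel values; the helper name `normalize` is hypothetical. With mean 0.5 and std 0.5, the `[0, 1]` range produced by `ToTensor` is mapped to `[-1, 1]`.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Per-channel arithmetic of Normalize: (x - mean) / std."""
    return (pixel - mean) / std

for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
```

Using the same mean and std for both domains keeps the SVHN and MNIST inputs on the same value range before they reach the shared convolution layers.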

# Utils for training

This part implements various utilities needed for training.

In [None]:
def get_optimizer_params(modules, lr, weight_decay=0.0005, base_weight_factor=0.1):
    weights = []
    biases = []
    base_weights = []
    base_biases = []

    module = modules
    for key, value in dict(module.named_parameters()).items():
        if value.requires_grad:
            if 'fc' in key or 'score' in key:
                if 'bias' in key:
                    biases += [value]
                else:
                    weights += [value]
            else:
                if 'bias' in key:
                    base_biases += [value]
                else:
                    base_weights += [value]
    if base_weight_factor:
        params = [
            {'params': weights, 'lr': lr, 'weight_decay': weight_decay},
            {'params': biases, 'lr': lr },
            {'params': base_weights, 'lr': lr * base_weight_factor, 'weight_decay': weight_decay},
            {'params': base_biases, 'lr': lr * base_weight_factor},
        ]
    else:
        params = [
            {'params': base_weights + weights, 'lr': lr, 'weight_decay': weight_decay},
            {'params': base_biases + biases, 'lr': lr},
        ]
    return params
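
To see which group a given parameter ends up in, the name-based rule of `get_optimizer_params` can be sketched on plain strings. The helper `group_name` is hypothetical, and the parameter names are illustrative examples of what `named_parameters()` might return for the networks above.

```python
def group_name(key):
    """Return the group a parameter name falls into under the rule above."""
    is_head = 'fc' in key or 'score' in key
    is_bias = 'bias' in key
    if is_head:
        return 'biases' if is_bias else 'weights'
    return 'base_biases' if is_bias else 'base_weights'

names = ['conv1.weight', 'conv1.bias', 'bn1.bns.0.weight', 'fc1.weight', 'fc2.bias']
groups = {name: group_name(name) for name in names}
for name, group in groups.items():
    print(f'{name:18s} -> {group}')
```

Under this rule the fully connected layers train at the full learning rate, while convolution and BN parameters fall into the base groups, which are scaled by `base_weight_factor`.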

In [None]:
class Monitor:
    def __init__(self):
        self.reset()

    def reset(self):
        self._cummulated_losses = defaultdict(lambda: 0.0)
        self._total_counts = defaultdict(lambda: 0)

    def update(self, losses_dict):
        for key in losses_dict:
            self._cummulated_losses[key] += losses_dict[key]
            self._total_counts[key] += 1

    @property
    def cummulated_losses(self):
        return self._cummulated_losses

    @property
    def total_counts(self):
        return self._total_counts

    @property
    def losses(self):
        losses = {}
        for k, v in self._cummulated_losses.items():
            if self._total_counts[k] > 0:
                losses[k] = v / float(self._total_counts[k])
            else:
                losses[k] = 0.0
        return losses

    def __repr__(self):
        sorted_loss_keys = sorted([k for k in self._cummulated_losses.keys()])
        losses = self.losses
        repr_str = ''
        for key in sorted_loss_keys:
            repr_str += ', {0}={1:.4f}'.format(key, losses[key])
        return repr_str[2:]


def one_hot_encoding(y, n_classes):
    tensor_size = [y.size(i) for i in range(len(y.size()))]
    if tensor_size[-1] != 1:
        tensor_size += [1]
    tensor_size = tuple(tensor_size)
    y_one_hot = torch.zeros(tensor_size[:-1] + (n_classes,)).to(y.device).scatter_(len(tensor_size) - 1,
                                                                                   y.view(tensor_size), 1)
    return y_one_hot

In [None]:
def lr_poly(base_lr, i_iter, alpha=10, beta=0.75, num_steps=250000):
    if i_iter < 0:
        return base_lr
    return base_lr / ((1 + alpha * float(i_iter) / num_steps) ** (beta))

class LRScheduler:
    def __init__(self, learning_rate, num_steps=200000, alpha=10,
                 beta=0.75, base_weight_factor=False):
        self.learning_rate = learning_rate
        self.num_steps = num_steps
        self.alpha = alpha
        self.beta = beta
        self.base_weight_factor = base_weight_factor

    def __call__(self, optimizer, i_iter):
        lr_i_iter = i_iter
        lr = self.learning_rate

        if len(optimizer.param_groups) == 1:
            optimizer.param_groups[0]['lr'] = lr_poly(lr, lr_i_iter, alpha=self.alpha, beta=self.beta,
                                                      num_steps=self.num_steps)
        elif len(optimizer.param_groups) == 2:
            optimizer.param_groups[0]['lr'] = lr_poly(lr, lr_i_iter, alpha=self.alpha, beta=self.beta,
                                                      num_steps=self.num_steps)
            optimizer.param_groups[1]['lr'] = lr_poly(lr, lr_i_iter, alpha=self.alpha, beta=self.beta,
                                                      num_steps=self.num_steps)
        elif len(optimizer.param_groups) == 4:
            optimizer.param_groups[0]['lr'] = lr_poly(lr, lr_i_iter, alpha=self.alpha, beta=self.beta,
                                                      num_steps=self.num_steps)
            optimizer.param_groups[1]['lr'] = lr_poly(lr, lr_i_iter, alpha=self.alpha, beta=self.beta,
                                                      num_steps=self.num_steps)
            optimizer.param_groups[2]['lr'] = self.base_weight_factor * lr_poly(lr, lr_i_iter, alpha=self.alpha,
                                                                                beta=self.beta,
                                                                                num_steps=self.num_steps)
            optimizer.param_groups[3]['lr'] = self.base_weight_factor * lr_poly(lr, lr_i_iter, alpha=self.alpha,
                                                                                beta=self.beta,
                                                                                num_steps=self.num_steps)
        else:
            raise RuntimeError('Wrong optimizer param groups')

    def current_lr(self, i_iter):
        return lr_poly(self.learning_rate, i_iter, alpha=self.alpha, beta=self.beta, num_steps=self.num_steps)
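
The schedule implemented by `lr_poly` divides the base learning rate by `(1 + alpha * i / num_steps) ** beta`. The sketch below evaluates the same formula in plain Python; the function name `inv_decay_lr` is hypothetical, and `num_steps=30000` is assumed to match the `end_iter` value that is later passed to `LRScheduler`.

```python
def inv_decay_lr(base_lr, i_iter, alpha=10, beta=0.75, num_steps=30000):
    """Same formula as lr_poly: base_lr / (1 + alpha * i / num_steps) ** beta."""
    if i_iter < 0:
        return base_lr
    return base_lr / ((1 + alpha * float(i_iter) / num_steps) ** beta)

for step in (0, 10, 100, 30000):
    print(step, '{:.6f}'.format(inv_decay_lr(0.001, step)))
```

The values at iterations 10 and 100 agree with the `lr=` entries printed in the training log of this notebook.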


# Train

In [None]:
save_dir = './logs'
num_classes = 10
num_domains = 2 # source + target
num_source_domains = 1
num_target_domains = 1

start_iter = 1
end_iter = 30000
adaptation_gamma = 10

learning_rate = 0.001
weight_decay = 0.0
base_weight_factor = 0.1

best_accuracy = 0.0
best_accuracy_each_c = 0.0
best_mean_val_accuracy = 0.0
best_total_val_accuracy = 0.0

disp_interval = 10
save_interval = 500
domain_loss_adjust_factor = 0.1


In [None]:
def adaptation_factor(p, gamma=10):
    p = max(min(p, 1.0), 0.0)
    den = 1.0 + math.exp(-gamma * p)
    lamb = 2.0 / den - 1.0
    return min(lamb, 1.0)

def semantic_loss_calc(x, y, mean=True):
    loss = (x - y) ** 2
    if mean:
        return torch.mean(loss)
    else:
        return loss

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    num_samples = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.div_(num_samples))
    return res


def accuracy_of_c(output, target, class_idx, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    # num_samples = target.size(0)
    selection = target == class_idx
    target_selected = target[selection]
    output_selected = output[selection]
    num_samples = torch.sum(selection).float()

    _, pred = output_selected.topk(maxk, 1, True, True)
    pred = pred.t().float()
    correct = pred.eq((target_selected.view(1, -1).expand_as(pred)).float())

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.div_(num_samples))
    return res
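
`adaptation_factor` ramps the weight of the adaptation losses from 0 towards 1 as training progresses. The standalone sketch below repeats the same formula so that it runs on its own, and prints the factor at a few values of the progress `p`.

```python
import math

def adaptation_factor(p, gamma=10):
    """Sigmoid-shaped ramp: 0 at p=0, approaching 1 as p goes to 1."""
    p = max(min(p, 1.0), 0.0)  # clamp training progress to [0, 1]
    return min(2.0 / (1.0 + math.exp(-gamma * p)) - 1.0, 1.0)

for p in (0.0, 0.1, 0.5, 1.0):
    print(p, round(adaptation_factor(p), 4))
```

With `gamma=10` the factor rises quickly, so the adaptation losses are close to full strength well before the halfway point of training.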

# Train with no adaptation

### Let's look at the result of training without domain adaptation.

In [None]:
def train_no_adapt(model, dataloaders, optimizer, lr_scheduler, ce_loss, start_time):
    source_train_loader, target_train_loader, target_val_loader = dataloaders
    source_train_loader_iter, target_train_loader_iter, target_val_loader_iter = map(iter, dataloaders)

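    best_accuracy = 0.0
    # These names are assigned inside this function, so they are local here;
    # initialize them so the "Best ..." printout never reads an unbound variable.
    best_accuracy_each_c = []
    best_mean_val_accuracy = 0.0
    best_total_val_accuracy = 0.0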
    monitor = Monitor()
    for i_iter in range(start_iter, 10001):
        try:
            x_s, y_s = next(source_train_loader_iter)
        except StopIteration:
            source_train_loader_iter = iter(source_train_loader)
            x_s, y_s = next(source_train_loader_iter)

        x_s, y_s = x_s.cuda(), y_s.cuda()
        current_lr = lr_scheduler.current_lr(i_iter)

        # init optimizer
        optimizer.zero_grad()
        lr_scheduler(optimizer, i_iter)

        ########################################################################################################
        #                                               Train                                                  #
        ########################################################################################################

        pred_s, f_s = model(x_s, with_ft=True)

        Closs_src = ce_loss(pred_s, y_s)
        monitor.update({"Loss/Closs_src": float(Closs_src)})

        Floss = Closs_src

        # Floss backward
        Floss.backward()
        optimizer.step()


        if i_iter % disp_interval == 0  and i_iter != 0:
            disp_msg = 'iter[{:8d}/{:8d}], '.format(i_iter, 10000)
            disp_msg += str(monitor)
            disp_msg += ', lr={:.6f}'.format(current_lr)
            print(disp_msg)

            monitor.reset()

        if i_iter % save_interval == 0 and i_iter != 0:
            print("Elapsed Time: {}".format(datetime.datetime.now() - start_time))
            print("Start Evaluation at {:d}".format(i_iter))

            model.eval()

            pred_vals = []
            y_vals = []
            x_val = None
            y_val = None
            pred_val = None

            with torch.no_grad():
                for i, (x_val, y_val) in enumerate(target_val_loader):
                    y_vals.append(y_val.cpu())
                    x_val = x_val.cuda()
                    y_val = y_val.cuda()

                    pred_val = model(x_val, with_ft=False)
                    pred_vals.append(pred_val.cpu())

            pred_vals = torch.cat(pred_vals, 0)
            y_vals = torch.cat(y_vals, 0)
            total_val_accuracy = float(accuracy(pred_vals, y_vals, topk=(1,))[0])

            val_accuracy_each_c = [(c_name, float(accuracy_of_c(pred_vals, y_vals,
                                                                class_idx=c, topk=(1,))[0]))
                                   for c, c_name in enumerate(range(num_classes))]
            print('\nMNIST Accuracy of Each class')
            print(''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                            for c_name, c_val_acc in val_accuracy_each_c]))

            mean_val_accuracy = float(
                torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))

            print('MNIST mean Accuracy: {:.2f}%'.format(100 * mean_val_accuracy))
            print("MNIST Accuracy: {:.2f}%".format(total_val_accuracy * 100))

            model.train()

            val_accuracy = total_val_accuracy

            del x_val, y_val, pred_val, pred_vals, y_vals

            if val_accuracy > best_accuracy:
                #save best model
                best_accuracy = val_accuracy
                best_accuracy_each_c = val_accuracy_each_c
                best_mean_val_accuracy = mean_val_accuracy
                best_total_val_accuracy = total_val_accuracy

                model = model.cuda()

            print('\nBest MNIST Accuracy of Each class')
            print(''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                           for c_name, c_val_acc in best_accuracy_each_c]))
            print('Best Accs: ' + ''.join(["{:.2f}% ".format(100 * c_val_acc)
                                           for _, c_val_acc in best_accuracy_each_c]))
            print('Best mean Accuracy: {:.2f}%'.format(100 * best_mean_val_accuracy))
            print('Best Accuracy: {:.2f}%'.format(100 * best_total_val_accuracy))
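
The `try` / `except StopIteration` pattern at the top of the training loop restarts the loader once it is exhausted. One possible alternative, shown here only as a sketch, is a small generator; the name `cycle_batches` is hypothetical, and a plain list stands in for a DataLoader.

```python
def cycle_batches(loader):
    """Yield batches forever, restarting the loader whenever it is exhausted."""
    while True:
        for batch in loader:
            yield batch

loader = [('x0', 'y0'), ('x1', 'y1'), ('x2', 'y2')]  # stand-in for a DataLoader
batches = cycle_batches(loader)
seen = [next(batches) for _ in range(5)]
print(seen)
```

Unlike `itertools.cycle`, which caches the items of the first pass, this generator iterates over the loader again on every pass, so a loader with `shuffle=True` would be reshuffled each epoch.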

In [None]:
def no_adapt():
    start_time = datetime.datetime.now()

    # make save_dir
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    dataloaders = get_dataloaders()

    ###################################################################################################################
    #                                               Model Loading                                                     #
    ###################################################################################################################
    model = LeNet(num_classes = num_classes)

    model.train(True)
    model = model.cuda()
    params = get_optimizer_params(model,
                                  lr = learning_rate,
                                  weight_decay=weight_decay,
                                  base_weight_factor=base_weight_factor)

    ###################################################################################################################
    #                                               Train Configurations                                              #
    ###################################################################################################################
    ce_loss = nn.CrossEntropyLoss()

    lr_scheduler = LRScheduler(learning_rate, end_iter, base_weight_factor=base_weight_factor)

    optimizer = optim.Adam(params, betas=(0.9, 0.999))

    train_no_adapt(model, dataloaders, optimizer, lr_scheduler, ce_loss, start_time)
    print('Total Time:  {}'.format((datetime.datetime.now() - start_time)))


In [None]:
no_adapt()

iter[      10/   10000], Loss/Closs_src=2.3011, lr=0.000998
iter[      20/   10000], Loss/Closs_src=2.2565, lr=0.000995
iter[      30/   10000], Loss/Closs_src=2.2474, lr=0.000993
iter[      40/   10000], Loss/Closs_src=2.1834, lr=0.000990
iter[      50/   10000], Loss/Closs_src=2.0746, lr=0.000988
iter[      60/   10000], Loss/Closs_src=2.0082, lr=0.000985
iter[      70/   10000], Loss/Closs_src=1.8850, lr=0.000983
iter[      80/   10000], Loss/Closs_src=1.6392, lr=0.000980
iter[      90/   10000], Loss/Closs_src=1.5203, lr=0.000978
iter[     100/   10000], Loss/Closs_src=1.3351, lr=0.000976
iter[     110/   10000], Loss/Closs_src=1.2181, lr=0.000973
iter[     120/   10000], Loss/Closs_src=1.1553, lr=0.000971
iter[     130/   10000], Loss/Closs_src=0.9906, lr=0.000969
iter[     140/   10000], Loss/Closs_src=0.9666, lr=0.000966
iter[     150/   10000], Loss/Closs_src=0.9805, lr=0.000964
iter[     160

# Train with DSBN

The DSBN code uses the following methods for domain adaptation.

<img src="https://ifh.cc/g/Jrshtr.png">

    1. DSBN
    2. Adversarial loss
    3. Semantic Matching loss

In the training code, we additionally implement the adversarial loss.
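
The sign convention of the adversarial loss can be illustrated with scalar values. The sketch below is plain Python under stated assumptions: it assumes `bce_loss` behaves like a binary cross-entropy on raw logits (for example `nn.BCEWithLogitsLoss`), which is not shown in this section, and it follows the label convention of the training code (0 for source, 1 for target). The helper `bce_with_logits` is hypothetical.

```python
import math

def bce_with_logits(logit, label):
    """Numerically stable binary cross-entropy on a raw logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

domain_loss_adjust_factor = 0.1
d_out_src, d_out_trg = 0.0, 0.0  # an undecided discriminator outputs logit 0
loss_adv_src = domain_loss_adjust_factor * bce_with_logits(d_out_src, 0.0)  # source label 0
loss_adv_trg = domain_loss_adjust_factor * bce_with_logits(d_out_trg, 1.0)  # target label 1
g_loss = -(loss_adv_src + loss_adv_trg) / 2  # feature extractor maximizes the domain loss
print(round(g_loss, 4))
```

The negative sign means the feature extractor is rewarded when the discriminator's loss is high, that is, when source and target features are hard to tell apart.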



In [None]:
def train(model, discriminator, centroids,
          dataloaders, optimizers, lr_scheduler, ce_loss, bce_loss, start_time):
    source_train_loader, target_train_loader, target_val_loader = dataloaders
    source_train_loader_iter, target_train_loader_iter, target_val_loader_iter = map(iter, dataloaders)
    src_centroid, trg_centroid = centroids
    optimizer, optimizer_D = optimizers

    best_accuracy = 0.0
    monitor = Monitor()
    for i_iter in range(start_iter, end_iter+1):
        try:
            x_s, y_s = next(source_train_loader_iter)
        except StopIteration:
            source_train_loader_iter = iter(source_train_loader)
            x_s, y_s = next(source_train_loader_iter)

        try:
            x_t, y_t = next(target_train_loader_iter)
        except StopIteration:
            target_train_loader_iter = iter(target_train_loader)
            x_t, y_t = next(target_train_loader_iter)

        x_s, y_s, x_t, y_t = x_s.cuda(), y_s.cuda(), x_t.cuda(), y_t.cuda()
        current_lr = lr_scheduler.current_lr(i_iter)
        adaptation_lambda = adaptation_factor(i_iter / float(end_iter),
                                              gamma=adaptation_gamma)

        # init optimizer
        optimizer.zero_grad()
        lr_scheduler(optimizer, i_iter)
        optimizer_D.zero_grad()
        lr_scheduler(optimizer_D, i_iter)

        ########################################################################################################
        #                                               Train G                                                #
        ########################################################################################################
        for param in discriminator.parameters():
            param.requires_grad = False

        src_domain_id = torch.zeros(x_s.shape[0], dtype=torch.long).cuda()
        trg_domain_id = torch.ones(x_t.shape[0], dtype=torch.long).cuda()

        pred_s, f_s = model(x_s, src_domain_id, with_ft=True)
        pred_t, f_t = model(x_t, trg_domain_id, with_ft=True)

        Closs_src = ce_loss(pred_s, y_s)
        monitor.update({"Loss/Closs_src": float(Closs_src)})

        Floss = Closs_src

        ############ TODO #############
        '''
        Use the discriminator to compute the loss terms that make up Gloss.
        discriminator:
            Input: f_s or f_t

        loss:
            domain_loss_adjust_factor * bce_loss( , )
        '''
        Dout_s = discriminator(f_s)
        source_label = torch.zeros_like(Dout_s).cuda()
        loss_adv_src = domain_loss_adjust_factor * bce_loss(Dout_s, source_label)

        Dout_t = discriminator(f_t)
        target_label = torch.ones_like(Dout_t).cuda()
        loss_adv_trg = domain_loss_adjust_factor * bce_loss(Dout_t, target_label)

        ############ TODO #############

        Gloss =  - (loss_adv_src + loss_adv_trg) / 2
        monitor.update({'Loss/Gloss': float(Gloss)})

        Floss = Floss + adaptation_lambda * Gloss

        # pseudo label generation
        pred_t_pseudo = []
        with torch.no_grad():
            model.eval()
            pred_t_pseudo = model(x_t, trg_domain_id, with_ft=False)
            model.train(True)

        # moving semantic loss
        current_src_centroid = src_centroid(f_s, y_s)
        current_trg_centroid = trg_centroid(f_t, torch.argmax(pred_t_pseudo, 1))

        semantic_loss = semantic_loss_calc(current_src_centroid, current_trg_centroid)
        monitor.update({'Loss/SMloss': float(semantic_loss)})

        Floss = Floss + adaptation_lambda * semantic_loss

        # Floss backward
        Floss.backward()
        optimizer.step()
        ########################################################################################################
        #                                               Train D                                                #
        ########################################################################################################
        for param in discriminator.parameters():
            param.requires_grad = True

        ############ TODO #############
        '''
        Compute Dloss, following the Gloss computation above.

        Hint: the features passed to the discriminator must be detached.
        '''
        Dout_s = discriminator(f_s.detach())
        source_label = torch.zeros_like(Dout_s).cuda()
        loss_adv_src = domain_loss_adjust_factor * bce_loss(Dout_s, source_label)

        Dout_t = discriminator(f_t.detach())
        target_label = torch.ones_like(Dout_t).cuda()
        loss_adv_trg = domain_loss_adjust_factor * bce_loss(Dout_t, target_label)

        ############ TODO #############

        Dloss = (loss_adv_src + loss_adv_trg) / 2
        monitor.update({'Loss/Dloss': float(Dloss)})
        Dloss = adaptation_lambda * Dloss
        Dloss.backward()
        optimizer_D.step()

        src_centroid.centroids.data = current_src_centroid.data
        trg_centroid.centroids.data = current_trg_centroid.data

        if i_iter % disp_interval == 0  and i_iter != 0:
            disp_msg = 'iter[{:8d}/{:8d}], '.format(i_iter, end_iter)
            disp_msg += str(monitor)
            disp_msg += ', lambda={:.6f}'.format(adaptation_lambda)
            disp_msg += ', lr={:.6f}'.format(current_lr)
            print(disp_msg)

            monitor.reset()

        if i_iter % save_interval == 0 and i_iter != 0:
            print("Elapsed Time: {}".format(datetime.datetime.now() - start_time))
            print("Start Evaluation at {:d}".format(i_iter))

            model.eval()

            pred_vals = []
            y_vals = []
            x_val = None
            y_val = None
            pred_val = None

            with torch.no_grad():
                for i, (x_val, y_val) in enumerate(target_val_loader):
                    y_vals.append(y_val.cpu())
                    x_val = x_val.cuda()
                    y_val = y_val.cuda()

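                    # Build the target domain id per validation batch so its size matches x_val
                    val_domain_id = torch.ones_like(y_val)
                    pred_val = model(x_val, val_domain_id, with_ft=False)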
                    pred_vals.append(pred_val.cpu())

            pred_vals = torch.cat(pred_vals, 0)
            y_vals = torch.cat(y_vals, 0)
            total_val_accuracy = float(accuracy(pred_vals, y_vals, topk=(1,))[0])

            val_accuracy_each_c = [(c_name, float(accuracy_of_c(pred_vals, y_vals,
                                                                class_idx=c, topk=(1,))[0]))
                                   for c, c_name in enumerate(range(num_classes))]
            print('\nMNIST Accuracy of Each class')
            print(''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                            for c_name, c_val_acc in val_accuracy_each_c]))

            mean_val_accuracy = float(
                torch.mean(torch.FloatTensor([c_val_acc for _, c_val_acc in val_accuracy_each_c])))

            print('MNIST mean Accuracy: {:.2f}%'.format(100 * mean_val_accuracy))
            print("MNIST Accuracy: {:.2f}%".format(total_val_accuracy * 100))

            model.train()

            val_accuracy = total_val_accuracy

            del x_val, y_val, pred_val, pred_vals, y_vals

            if val_accuracy > best_accuracy:
                # track the best metrics so far (no checkpoint is written in this cell)
                best_accuracy = val_accuracy
                best_accuracy_each_c = val_accuracy_each_c
                best_mean_val_accuracy = mean_val_accuracy
                best_total_val_accuracy = total_val_accuracy

                model = model.cuda()
                discriminator = discriminator.cuda()
                src_centroid = src_centroid.cuda()
                trg_centroid = trg_centroid.cuda()

            print('\nBest MNIST Accuracy of Each class')
            print(''.join(["{:<25}: {:.2f}%\n".format(c_name, 100 * c_val_acc)
                           for c_name, c_val_acc in best_accuracy_each_c]))
            print('Best Accs: ' + ''.join(["{:.2f}% ".format(100 * c_val_acc)
                                           for _, c_val_acc in best_accuracy_each_c]))
            print('Best mean Accuracy: {:.2f}%'.format(100 * best_mean_val_accuracy))
            print('Best Accuracy: {:.2f}%'.format(100 * best_total_val_accuracy))
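
The semantic matching step in `train` relies on `Centroids` and `semantic_loss_calc`, which are defined outside this section. The sketch below is not their implementation. It is one plausible reading of "moving semantic loss" under stated assumptions: per-class feature means are blended into stored centroids with a moving average, and the loss is the mean squared distance between matching source and target centroids. All function names and the `decay=0.3` value are hypothetical.

```python
def class_means(features, labels, num_classes):
    """Average the feature vectors of each class; absent classes yield None."""
    dim = len(features[0])
    sums = [[0.0] * dim for _ in range(num_classes)]
    counts = [0] * num_classes
    for feat, label in zip(features, labels):
        counts[label] += 1
        for d, value in enumerate(feat):
            sums[label][d] += value
    return [[s / counts[c] for s in sums[c]] if counts[c] else None
            for c in range(num_classes)]

def update_centroids(old, batch_means, decay=0.3):
    """Moving average of centroids. decay=0.3 is a hypothetical example value."""
    new = []
    for old_c, mean_c in zip(old, batch_means):
        if mean_c is None:  # class not seen in this batch: keep the old centroid
            new.append(list(old_c))
        else:
            new.append([(1 - decay) * o + decay * m for o, m in zip(old_c, mean_c)])
    return new

def semantic_loss(src_centroids, trg_centroids):
    """Mean squared distance between matching source and target centroids."""
    total, n = 0.0, 0
    for s_c, t_c in zip(src_centroids, trg_centroids):
        for s, t in zip(s_c, t_c):
            total += (s - t) ** 2
            n += 1
    return total / n

# Two classes, 2-D features. Target labels stand in for argmax pseudo-labels.
src_feats, src_labels = [[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]], [0, 0, 1]
trg_feats, trg_pseudo = [[2.0, 1.0], [0.0, 4.0]], [0, 1]

zeros = [[0.0, 0.0], [0.0, 0.0]]
src_c = update_centroids(zeros, class_means(src_feats, src_labels, 2))
trg_c = update_centroids(zeros, class_means(trg_feats, trg_pseudo, 2))
sm_loss = semantic_loss(src_c, trg_c)
print(src_c, trg_c, sm_loss)
```

The target side has no ground-truth labels during adaptation, which is why the training loop first predicts pseudo-labels in `eval` mode and uses their `argmax` to assign target features to classes.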



In [None]:
def main():
    start_time = datetime.datetime.now()

    # make save_dir
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    dataloaders = get_dataloaders()

    ###################################################################################################################
    #                                               Model Loading                                                     #
    ###################################################################################################################
    model = DSBNLeNet(num_classes = num_classes,
                      num_domains = num_domains)

    model.train(True)
    model = model.cuda()
    params = get_optimizer_params(model,
                                  lr = learning_rate,
                                  weight_decay=weight_decay,
                                  base_weight_factor=base_weight_factor)

    discriminator = Discriminator(in_features=num_classes).cuda()
    # optimizer_D must update the discriminator's parameters, not the model's
    D_params = get_optimizer_params(discriminator,
                                    lr = learning_rate,
                                    weight_decay=weight_decay,
                                    base_weight_factor=None)
    ### For sm_loss
    src_centroid = Centroids(num_classes, num_classes).cuda()
    trg_centroid = Centroids(num_classes, num_classes).cuda()
    centroids = [src_centroid, trg_centroid]

    ###################################################################################################################
    #                                               Train Configurations                                              #
    ###################################################################################################################
    ce_loss = nn.CrossEntropyLoss()
    bce_loss = nn.BCEWithLogitsLoss()

    lr_scheduler = LRScheduler(learning_rate, end_iter, base_weight_factor=base_weight_factor)

    optimizer = optim.Adam(params, betas=(0.9, 0.999))
    optimizer_D = optim.Adam(D_params, betas=(0.9, 0.999))
    optimizers = [optimizer, optimizer_D]

    train(model, discriminator, centroids, dataloaders, optimizers, lr_scheduler, ce_loss, bce_loss, start_time)
    print('Total Time:  {}'.format((datetime.datetime.now() - start_time)))


In [None]:
main()

Using downloaded and verified file: ./data/svhn/train_32x32.mat
iter[      10/   30000], Loss/Closs_src=2.3026, Loss/Dloss=0.0693, Loss/Gloss=-0.0693, Loss/SMloss=0.1029, lambda=0.001667, lr=0.000998
iter[      20/   30000], Loss/Closs_src=2.2428, Loss/Dloss=0.0693, Loss/Gloss=-0.0693, Loss/SMloss=0.0674, lambda=0.003333, lr=0.000995
iter[      30/   30000], Loss/Closs_src=2.2603, Loss/Dloss=0.0693, Loss/Gloss=-0.0693, Loss/SMloss=0.1183, lambda=0.005000, lr=0.000993
iter[      40/   30000], Loss/Closs_src=2.1815, Loss/Dloss=0.0693, Loss/Gloss=-0.0693, Loss/SMloss=0.0876, lambda=0.006667, lr=0.000990
iter[      50/   30000], Loss/Closs_src=2.1922, Loss/Dloss=0.0693, Loss/Gloss=-0.0693, Loss/SMloss=0.1229, lambda=0.008333, lr=0.000988
iter[      60/   30000], Loss/Closs_src=2.1760, Loss/Dloss=0.0693, Loss/Gloss=-0.0693, Loss/SMloss=0.0701, lambda=0.010000, lr=0.000985
iter[      70/   30000], Loss/Closs_src=2.0487, Loss/Dloss=0.0693, Loss/Gloss=-0.0693, Loss/SMloss=0.0957, lambda=0.0116
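
The `lambda` column in the log above grows almost linearly with the iteration count. `adaptation_factor` is defined elsewhere in the notebook, so the following is only a sketch of one schedule that is consistent with the logged values: the sigmoid ramp commonly used in adversarial domain adaptation, here with an assumed `gamma = 10`.

```python
import math

def adaptation_factor(progress, gamma=10.0):
    """Sigmoid ramp from 0 (progress=0) towards 1 (progress=1).

    Assumed form: 2 / (1 + exp(-gamma * progress)) - 1. gamma=10 is a guess
    that agrees with the logged lambda values, not a confirmed setting.
    """
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0

end_iter = 30000
for i_iter in (10, 20, 30, 60, 15000, 30000):
    lam = adaptation_factor(i_iter / float(end_iter))
    print(f"iter {i_iter:>6}: lambda={lam:.6f}")
```

A small `lambda` early in training means the adversarial and semantic matching terms barely influence the first updates, so the classifier can first settle on the labelled source data before domain alignment takes over.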

# Summary
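- Domain-specific batch normalization (DSBN) performs batch normalization independently for the source domain and the target domain, which effectively addresses the domain shift problem.
- However, this method cannot be used when the target domain is unknown, so it cannot be applied to domain generalization problems.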
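
The first point can be illustrated with a toy sketch in plain Python. It normalises scalar activations of a single channel and leaves out the learnable affine parameters and running statistics that `nn.BatchNorm2d` maintains. The data values and the helper name `normalize` are made up for the example.

```python
from statistics import fmean, pstdev

def normalize(values, eps=1e-5):
    """Batch-normalise a list of scalars with its own mean and biased variance."""
    mean = fmean(values)
    var = pstdev(values) ** 2
    return [(v - mean) / (var + eps) ** 0.5 for v in values]

# Toy activations of one channel: the target domain is shifted and rescaled.
source = [0.0, 1.0, 2.0, 3.0]
target = [10.0, 12.0, 14.0, 16.0]

# Shared BN: one set of statistics over the mixed batch keeps the domains apart.
shared = normalize(source + target)
shared_src, shared_trg = shared[:4], shared[4:]

# Domain-specific BN: each domain is normalised with its own statistics.
dsbn_src, dsbn_trg = normalize(source), normalize(target)

shared_gap = fmean(shared_trg) - fmean(shared_src)
dsbn_gap = fmean(dsbn_trg) - fmean(dsbn_src)
print("shared BN gap between domain means:", shared_gap)
print("DSBN gap between domain means:     ", dsbn_gap)
```

With shared statistics the two domains remain clearly separated after normalization, while per-domain statistics remove the offset. This also shows the limitation in the second point: per-domain statistics can only be computed when samples from the target domain are available.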