# Domain adaptation



![](./figs/da.png)

### Resources


- The original paper "Unsupervised Domain Adaptation by Backpropagation" (2014)

  http://sites.skoltech.ru/compvision/projects/grl/files/paper.pdf
  
  
- Video from the ICCV 2017 conference,

  in which the research is framed in a broader context.

  https://www.youtube.com/watch?v=uUUvieVxCMs&t=1210s

In [None]:
import os

# Output folder for the per-epoch checkpoints written by the training loop.
# exist_ok=True replaces the separate existence check (no check-then-create
# race) and makes re-running this cell a no-op.
os.makedirs('da_results', exist_ok=True)

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Function
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
import mnistm
from PIL import Image

In [None]:
class GrayscaleToRgb:
    """Transform that turns a single-channel (grayscale) PIL image into RGB.

    The gray values are copied into all three channels, so the picture keeps
    its appearance but gains the 3-channel layout that Net's first
    convolution (``nn.Conv2d(3, ...)``) expects.
    """

    def __call__(self, image):
        gray = np.array(image)
        rgb = np.dstack([gray] * 3)
        return Image.fromarray(rgb)

In [None]:
class Net(nn.Module):
    """Small LeNet-style CNN producing 10 class logits.

    The 320-unit input of the classifier implies 3x28x28 input images
    (28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4; 20*4*4 = 320).
    The two halves are exposed as ``feature_extractor`` and ``classifier`` so
    they can be used separately. Their names and layer order define the
    state_dict keys of the pretrained checkpoint and must not change.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Convolutional trunk: 3x28x28 -> 20x4x4 (320 values once flattened).
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 10, kernel_size=5),   # -> 10x24x24
            nn.MaxPool2d(2),                   # -> 10x12x12
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5),  # -> 20x8x8
            nn.MaxPool2d(2),                   # -> 20x4x4
            nn.Dropout2d(),
        )
        # MLP head: 320 flattened features -> 10 class logits.
        self.classifier = nn.Sequential(
            nn.Linear(320, 50),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(50, 10),
        )

    def forward(self, x):
        """Return class logits of shape (batch, 10) for the image batch ``x``."""
        n_images = x.shape[0]
        flat = self.feature_extractor(x).view(n_images, -1)
        return self.classifier(flat)

In [None]:
# Taken from https://github.com/jvanvugt/pytorch-domain-adaptation/
class GradientReversalFunction(Function):
    """Autograd function behind the Gradient Reversal Layer (GRL).

    From "Unsupervised Domain Adaptation by Backpropagation"
    (Ganin & Lempitsky, 2015).

    Forward: returns a copy of ``x``; values are unchanged.
    Backward: the upstream gradient is multiplied by ``-lambda_`` (the
    gradient is reversed); ``lambda_`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        # Remember the scale for the backward pass.
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        # Materialise -lambda_ with the gradient's dtype/device, then scale.
        neg_scale = -grads.new_tensor(ctx.lambda_)
        return neg_scale * grads, None

class GradientReversal(torch.nn.Module):
    """``nn.Module`` wrapper around ``GradientReversalFunction``.

    Identity in the forward pass. In the backward pass it multiplies the
    incoming gradient by ``-lambda_``.

    Args:
        lambda_: scale applied to the reversed gradient (default 1).
    """

    def __init__(self, lambda_=1):
        super().__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        reverse = GradientReversalFunction.apply
        return reverse(x, self.lambda_)

In [None]:
# Training hyper-parameters (set epochs = 50 for a longer run).
epochs = 10
seed = 1101
batch_size = 64

# Use the GPU when available and seed PyTorch's global RNG.
use_cuda = torch.cuda.is_available()
torch.manual_seed(seed)
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Extra DataLoader options, enabled only when running on CUDA.
kwargs = {}
if use_cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}

In [None]:
# Build the network and load its pretrained weights.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can also be loaded on a CPU-only machine (and vice versa).
model = Net().to(device)
model.load_state_dict(torch.load('conv_for_domain_adaptation.pt', map_location=device))

In [None]:
# Aliases for the two halves of the pretrained network.
feature_extractor, clf = model.feature_extractor, model.classifier

# Domain discriminator on the 320-dim flattened features, outputting a single
# logit. The leading gradient-reversal layer flips the sign of the gradient
# flowing back into the feature extractor, pushing it to *increase* the
# domain loss that this head is trained to decrease.
discriminator = nn.Sequential(
    GradientReversal(),
    nn.Linear(320, 50), nn.ReLU(),
    nn.Linear(50, 20), nn.ReLU(),
    nn.Linear(20, 1),
)
discriminator = discriminator.to(device)

In [None]:
# Each training step concatenates half a batch of source images with half a
# batch of target images.
half_batch = batch_size // 2

# Source domain: MNIST. Its grayscale digits are replicated to 3 channels so
# they match Net's 3-channel input.
source_loader = DataLoader(
    datasets.MNIST(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([GrayscaleToRgb(), transforms.ToTensor()]),
    ),
    batch_size=half_batch,
    shuffle=True,
    **kwargs
)

In [None]:
# Target domain: MNIST-M, with the same batch settings as the source loader.
# No grayscale-to-RGB conversion is applied here.
target_loader = DataLoader(
    mnistm.MNISTM(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=half_batch,
    shuffle=True,
    **kwargs
)

In [None]:
# Draw one batch from the target loader and print the image tensor's shape.
(imgs, labels) = next(iter(target_loader))
print(imgs.shape)

In [None]:
# Display one target image: move the channel axis last (C,H,W -> H,W,C),
# the layout plt.imshow expects for colour images.
plt.imshow(imgs[19].permute(1, 2, 0).numpy())
plt.show()

In [None]:
# A single Adam optimizer (default settings) jointly updates the discriminator
# and the whole Net (feature extractor + label classifier).
optim = torch.optim.Adam([*discriminator.parameters(), *model.parameters()])

In [None]:
for epoch in range(1, epochs + 1):
    # Explicit train mode so Net's Dropout/Dropout2d layers are active even if
    # an earlier cell (e.g. an evaluation pass) switched the model to eval().
    model.train()
    discriminator.train()

    # zip() stops at the shorter loader, so that length is the number of steps.
    batches = zip(source_loader, target_loader)
    n_batches = min(len(source_loader), len(target_loader))

    total_domain_loss = total_label_accuracy = 0
    for (source_x, source_labels), (target_x, _) in tqdm(batches, leave=False, total=n_batches):
        n_source = source_x.shape[0]
        x = torch.cat([source_x, target_x]).to(device)
        # Domain labels: 1 for source images, 0 for target images.
        domain_y = torch.cat([torch.ones(n_source),
                              torch.zeros(target_x.shape[0])]).to(device)
        label_y = source_labels.to(device)

        features = feature_extractor(x).view(x.shape[0], -1)
        # squeeze(1) rather than squeeze(): a batch of one must stay 1-D,
        # otherwise the BCE target/input shapes no longer match.
        domain_preds = discriminator(features).squeeze(1)
        # Class labels are only available for the source half of the batch.
        label_preds = clf(features[:n_source])

        domain_loss = F.binary_cross_entropy_with_logits(domain_preds, domain_y)  # L_d
        label_loss = F.cross_entropy(label_preds, label_y)  # L_y
        # The gradient-reversal layer inside `discriminator` already flips the
        # domain-loss gradient reaching the feature extractor, so a plain sum
        # implements the adversarial objective.
        loss = domain_loss + label_loss

        optim.zero_grad()
        loss.backward()
        optim.step()

        total_domain_loss += domain_loss.item()
        total_label_accuracy += (label_preds.argmax(1) == label_y).float().mean().item()

    mean_loss = total_domain_loss / n_batches
    mean_accuracy = total_label_accuracy / n_batches
    tqdm.write(f'EPOCH {epoch:03d}: domain_loss={mean_loss:.4f}, '
               f'source_accuracy={mean_accuracy:.4f}')

    # One checkpoint per epoch, e.g. da_results/revgrad_1.pt
    torch.save(model.state_dict(), os.path.join('da_results', f'revgrad_{epoch}.pt'))


# Optional work

- Assess the performance of the model at every epoch on the source and target datasets. Define new loaders with a suitable batch size, and pay special attention to the source dataset (e.g. its images are grayscale, so they need the same `GrayscaleToRgb` conversion used for training). Plot the accuracy on the source and target domains as a function of the epoch, including epoch 0 (before domain adaptation).

- Visualize the features before and after domain adaptation with t-SNE. Is the feature adaptation visible?

- Compute the intrinsic dimension (ID) of the source and target features before and during the domain adaptation. What do you expect to see?