# FADA implementation in PyTorch
* Paper: Few-Shot Adversarial Domain Adaptation (https://arxiv.org/abs/1711.02536)

## 准备工作

### 导入相关的包

In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torch import optim
from torch.autograd import Variable
% matplotlib inline 

### 定义相关的参数（训练）

### use_functions

#### accuracy

In [None]:
def accuracy(y_pred, y):
    """Return the fraction of rows of ``y_pred`` whose arg-max equals ``y``.

    Args:
        y_pred: 2-D tensor of per-class scores or probabilities; the
            prediction is taken along dimension 1.
        y: tensor of integer class labels, one per row of ``y_pred``.

    Returns:
        The accuracy as a plain Python float in [0, 1].
    """
    predicted = torch.max(y_pred, 1)[1]
    # ``.item()`` extracts the scalar; indexing a 0-dim tensor with ``[0]``
    # (the old ``.data[0]`` idiom) raises IndexError in modern PyTorch.
    return (predicted == y).float().mean().item()

#### eval_on_test
* Returns the mean accuracy on the test set, given a model 

In [None]:
def eval_on_test(test_dataloader, model_fn):
    """Return the accuracy of ``model_fn`` over every sample of a test loader.

    Accuracy is computed per sample (total correct / total seen), so a smaller
    final batch is not over-weighted the way a mean of per-batch accuracies
    would be. Evaluation runs under ``torch.no_grad()`` so no autograd graph
    is built.

    Args:
        test_dataloader: iterable of ``(x, y)`` batches, where ``y`` holds
            integer class labels.
        model_fn: callable mapping a batch ``x`` to 2-D per-class scores.

    Returns:
        The accuracy rounded to 3 decimal places.

    Raises:
        ValueError: if the loader yields no samples.
    """
    use_cuda = torch.cuda.is_available()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_dataloader:
            if use_cuda:
                x, y = x.cuda(), y.cuda()
            predicted = torch.max(model_fn(x), 1)[1]
            correct += (predicted == y).sum().item()
            total += y.numel()
    if total == 0:
        raise ValueError("test_dataloader yielded no samples")
    return round(correct / float(total), 3)

#### into_tensor
* Converts a list of (x1, x2) pairs into two stacked tensors (moved to the GPU when one is available)

In [None]:
def into_tensor(data, into_vars=True):
    """Stack a list of ``(a, b)`` pairs into two batched tensors.

    Args:
        data: sequence of pairs of equally-shaped tensors.
        into_vars: accepted for backward compatibility but not used; both
            results are always wrapped in ``Variable``.

    Returns:
        ``(firsts, seconds)``: the stacked first and second elements, moved to
        the GPU whenever CUDA is available.
    """
    firsts = Variable(torch.stack([pair[0] for pair in data]))
    seconds = Variable(torch.stack([pair[1] for pair in data]))
    if not torch.cuda.is_available():
        return firsts, seconds
    return firsts.cuda(), seconds.cuda()

## 构建数据集

### mnist_dataloader
* Returns the MNIST dataloader

In [None]:
def mnist_dataloader(batch_size=256, train=True, cuda=False):
    """Build a shuffled DataLoader over MNIST (downloaded to ./data if missing).

    Args:
        batch_size: samples per batch.
        train: use the training split when True, the test split otherwise.
        cuda: enable pinned memory for faster host-to-GPU copies.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)``.
    """
    mnist = torchvision.datasets.MNIST(
        './data',
        train=train,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=cuda)
    return torch.utils.data.DataLoader(mnist, **loader_options)

### svhn_dataloader
* Returns the SVHN dataloader

In [None]:
def svhn_dataloader(batch_size=256, train=True, cuda=False):
    """Build a shuffled DataLoader over SVHN (downloaded to ./data if missing).

    Images are resized to 28x28 and converted to one grayscale channel, so
    they have the same size and channel count as the MNIST tensors.

    Args:
        batch_size: samples per batch.
        train: use the 'train' split when True, the 'test' split otherwise.
        cuda: enable pinned memory for faster host-to-GPU copies.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)``.
    """
    split = 'train' if train else 'test'
    to_mnist_like = torchvision.transforms.Compose([
        torchvision.transforms.Resize((28, 28)),
        torchvision.transforms.Grayscale(),
        torchvision.transforms.ToTensor(),
    ])
    svhn = torchvision.datasets.SVHN('./data', split=split, download=True, transform=to_mnist_like)
    return torch.utils.data.DataLoader(
        svhn,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=cuda,
    )

### sample_data
* Loads a random subset of `n` source-domain (MNIST training) images into memory

In [None]:
def sample_data(n=2000):
    """Load ``n`` randomly chosen MNIST training images (source domain) into memory.

    Args:
        n: number of samples to draw without replacement; must lie between 0
            and the size of the MNIST training set.

    Returns:
        ``(X, Y)``: a float32 tensor of shape ``(n, 1, 28, 28)`` with the images
        and an int64 tensor of shape ``(n,)`` with their labels.

    Raises:
        ValueError: if ``n`` is outside the valid range. Otherwise the rows of
            the uninitialised output buffers beyond the dataset size would be
            returned as garbage.
    """
    dataset = torchvision.datasets.MNIST('./data', download=True, train=True, transform=torchvision.transforms.ToTensor())
    if not 0 <= n <= len(dataset):
        raise ValueError("n must be between 0 and {} (got {})".format(len(dataset), n))
    X = torch.empty(n, 1, 28, 28, dtype=torch.float32)
    Y = torch.empty(n, dtype=torch.long)
    # Every one of the n rows is written below, so no uninitialised memory leaks out.
    for i, index in enumerate(torch.randperm(len(dataset))[:n].tolist()):
        X[i], Y[i] = dataset[index]
    return X, Y

### create_target_samples
* Returns a subset of the target domain (SVHN training split) containing `n` samples per class

In [None]:
def create_target_samples(n=1):
    """Return ``n`` labelled target-domain (SVHN) samples per class.

    The SVHN training split is scanned in order and the first ``n`` samples of
    each of the 10 classes are kept. Images are resized to 28x28 grayscale to
    match MNIST.

    Args:
        n: samples per class; must be at least 1.

    Returns:
        ``(X, Y)``: stacked images of shape ``(10 * n, 1, 28, 28)`` and an int64
        label tensor of shape ``(10 * n,)``.

    Raises:
        ValueError: if ``n < 1`` or the split does not contain ``n`` samples of
            every class.
    """
    if n < 1:
        raise ValueError("n must be a positive number of samples per class (got {})".format(n))
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((28, 28)),
        torchvision.transforms.Grayscale(),
        torchvision.transforms.ToTensor()
    ])
    dataset = torchvision.datasets.SVHN('./data', download=True, split='train', transform=transform)
    X, Y = [], []
    # Remaining quota per class; assumes labels are integers 0-9.
    remaining = 10 * [n]
    for i in range(len(dataset)):
        x, y = dataset[i]
        if remaining[y] > 0:
            X.append(x)
            Y.append(y)
            remaining[y] -= 1
            if len(X) == n * 10:
                break
    else:
        # The loop ran out of data before every class reached its quota.
        raise ValueError("SVHN train split has fewer than {} samples for some class".format(n))
    # Explicit int64 labels (np.array's default integer type is platform dependent).
    return torch.stack(X), torch.tensor(Y, dtype=torch.long)

### create_groups
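* Builds groups G1 (same class) and G3 (different classes) from pairs in D_s x D_s, and groups G2 (same class) and G4 (different classes) from pairs in D_s x D_t; each group has as many pairs as there are target samples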

In [None]:
def _collect_pairs(xs_a, labels_a, xs_b, labels_b, n, skip_same_index):
    """Scan ``xs_a x xs_b`` in order; return the first ``n`` same-label and the first ``n`` different-label pairs."""
    same, different = [], []
    if n <= 0:
        return same, different
    for i, label_a in enumerate(labels_a):
        for j, label_b in enumerate(labels_b):
            # Pairing a sample with itself only makes sense to exclude when
            # both sides index the same dataset.
            if skip_same_index and i == j:
                continue
            if label_a == label_b:
                if len(same) < n:
                    same.append((xs_a[i], xs_b[j]))
            elif len(different) < n:
                different.append((xs_a[i], xs_b[j]))
            # Stop as soon as both groups are full instead of scanning the
            # whole product.
            if len(same) == n and len(different) == n:
                return same, different
    return same, different


def create_groups(X_s, y_s, X_t, y_t):
    """Build the four FADA pair groups.

    G1 / G3 hold source-source pairs with the same / different labels
    (distinct samples only). G2 / G4 hold source-target pairs with the same /
    different labels. Each group contains ``n = len(X_t)`` pairs, taken in
    scan order (source index outer, partner index inner).

    Args:
        X_s, y_s: source samples and their integer labels.
        X_t, y_t: target samples and their integer labels.

    Returns:
        ``[G1, G2, G3, G4]``: lists of ``(x1, x2)`` tensor pairs.

    Raises:
        ValueError: if some group cannot be filled with ``n`` pairs.
    """
    n = X_t.shape[0]
    # Plain Python ints make the label comparisons cheap (no tensor ops).
    source_labels = [int(v) for v in y_s]
    target_labels = [int(v) for v in y_t]
    # Groups G1 and G3 come from the source domain.
    G1, G3 = _collect_pairs(X_s, source_labels, X_s, source_labels, n, skip_same_index=True)
    # Groups G2 and G4 mix the source and target domains.
    G2, G4 = _collect_pairs(X_s, source_labels, X_t, target_labels, n, skip_same_index=False)
    groups = [G1, G2, G3, G4]
    # Make sure we sampled enough pairs.
    for name, g in zip(("G1", "G2", "G3", "G4"), groups):
        if len(g) != n:
            raise ValueError("could only build {} of {} pairs for group {}".format(len(g), n, name))
    return groups

### sample_groups
*  Sample groups G1, G2, G3, G4

In [None]:
def sample_groups(n_target_samples=2):
    """Draw source and target data and build the four FADA pair groups.

    Args:
        n_target_samples: labelled target samples per class.

    Returns:
        ``(groups, data)`` where ``groups`` is ``[G1, G2, G3, G4]`` and ``data``
        is ``(X_s, y_s, X_t, y_t)``.
    """
    source_x, source_y = sample_data()
    target_x, target_y = create_target_samples(n_target_samples)
    print("Sampling groups")
    groups = create_groups(source_x, source_y, target_x, target_y)
    return groups, (source_x, source_y, target_x, target_y)

## 定义网络结构

### Domain-Class Discriminator
* Domain-Class Discriminator (see Eq. (3) in the paper). Takes the concatenated latent representations of a pair of samples from G1, G2, G3 or G4 and outputs probabilities over the four group labels [0, 1, 2, 3]

In [None]:
class DCD(nn.Module):
    """Domain-Class Discriminator (Eq. (3) of the paper).

    Takes the concatenation of two latent representations and returns
    probabilities over the four pair groups (labels 0-3).

    NOTE(review): ``forward`` returns softmax probabilities, not logits. Feeding
    them to ``nn.CrossEntropyLoss`` applies softmax a second time.
    NOTE(review): there is no non-linearity between ``fc2`` and ``out``, so the
    two layers compose into one affine map. Confirm whether a ReLU was intended.

    Args:
        H: hidden layer width.
        D_in: input size (length of the concatenated representations).
    """

    def __init__(self, H=64, D_in=784):
        super().__init__()
        self.fc1 = nn.Linear(D_in, H)
        self.fc2 = nn.Linear(H, H)
        self.out = nn.Linear(H, 4)

    def forward(self, x):
        hidden = self.fc2(F.relu(self.fc1(x)))
        group_logits = self.out(hidden)
        return F.softmax(group_logits, dim=1)

### Classifier
* Called h in the paper. Gives class predictions based on the latent representation 

In [None]:
class Classifier(nn.Module):
    """Class predictor h: maps a latent vector to probabilities over 10 classes.

    NOTE: ``forward`` returns softmax probabilities, not logits.

    Args:
        D_in: size of the latent representation (64 matches the Encoder output).
    """

    def __init__(self, D_in=64):
        super().__init__()
        self.out = nn.Linear(D_in, 10)

    def forward(self, x):
        class_logits = self.out(x)
        return class_logits.softmax(dim=1)

### Encoder
* Creates the latent representation of the input; called g in the paper. As in the paper, we use g_s = g_t = g, i.e. source and target representations share weights. The architecture is the one specified in Section 4.1 of the paper (LeNet-style); see https://github.com/kuangliu/pytorch-cifar/blob/master/models/lenet.py

In [None]:
class Encoder(nn.Module):
    """LeNet-style feature extractor g shared by source and target domains.

    Two conv(5x5) + ReLU + 2x2 max-pool stages, then three fully connected
    layers producing a 64-dimensional representation (no activation on the
    last layer). The input is expected to be single-channel 28x28 images:
    ``conv1`` takes 1 channel and ``fc1``'s 256 inputs equal 16 * 4 * 4, which
    is what a 28x28 input yields after the two stages.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 64)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        x = x.view(x.size(0), -1)
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)

## 训练网络

In [None]:
def model_fn(encoder, classifier):
    """Compose encoder g and classifier h into one callable ``x -> h(g(x))``."""
    def predict(x):
        features = encoder(x)
        return classifier(features)
    return predict

def pretrain(data, epochs=5, batch_size=128, cuda=False):
    """Pretrain the encoder and classifier on source data (step (a) in figure 2).

    Encoder g and classifier h are optimised jointly with one Adam optimiser
    on random mini-batches drawn from the in-memory source samples. After every
    epoch the loss of the last mini-batch and the MNIST test accuracy are
    printed.

    Args:
        data: tuple ``(X_s, y_s, X_t, y_t)``; only the source part is used.
        epochs: number of training epochs.
        batch_size: samples per mini-batch.
        cuda: move the models and batches to the GPU.

    Returns:
        ``(encoder, classifier)``: the freshly trained modules.
    """
    X_s, y_s, _, _ = data
    test_dataloader = mnist_dataloader(train=False, cuda=cuda)
    classifier = Classifier()
    encoder = Encoder()
    if cuda:
        classifier.cuda()
        encoder.cuda()
    # Jointly optimise both encoder and classifier.
    optimizer = optim.Adam(list(encoder.parameters()) + list(classifier.parameters()))
    predict = model_fn(encoder, classifier)
    # At least one step per epoch, so a loss is always available to report
    # even when there are fewer samples than batch_size.
    n_batches = max(1, len(X_s) // batch_size)
    for e in range(epochs):
        for _ in range(n_batches):
            inds = torch.randperm(len(X_s))[:batch_size]
            x, y = X_s[inds], y_s[inds].long()
            if cuda:
                x, y = x.cuda(), y.cuda()
            optimizer.zero_grad()
            probs = predict(x)
            # The classifier already outputs softmax probabilities, so use NLL
            # on their log; nn.CrossEntropyLoss would apply softmax again.
            loss = F.nll_loss(torch.log(probs.clamp(min=1e-12)), y)
            loss.backward()
            optimizer.step()
        print("Epoch", e, "Loss", loss.item(), "Accuracy", eval_on_test(test_dataloader, predict))
    return encoder, classifier

def train_discriminator(encoder, groups, n_target_samples=2, cuda=False, epochs=20):
    """Train the Domain-Class Discriminator (DCD) on top of a frozen encoder.

    Each step draws one pair from a random group and trains the DCD to predict
    that pair's group index (G1=0, G2=1, G3=2, G4=3) by minimising
    cross-entropy. The encoder is used for inference only: no gradients are
    computed for it.

    Args:
        encoder: pretrained encoder g (not updated).
        groups: ``[G1, G2, G3, G4]``, lists of ``(x1, x2)`` tensor pairs.
        n_target_samples: kept for backward compatibility. The number of steps
            per epoch is derived from the size of G2 instead.
        cuda: move the DCD and samples to the GPU.
        epochs: number of epochs.

    Returns:
        The trained DCD.
    """
    discriminator = DCD(D_in=128)  # input = two concatenated 64-d encodings
    if cuda:
        discriminator.cuda()
    # Only the DCD is optimised.
    optimizer = optim.Adam(discriminator.parameters())
    # Size of group G2 times the number of groups.
    n_iters = len(groups) * len(groups[1])
    print("Training DCD")
    for e in range(epochs):
        loss = None
        for _ in range(n_iters):
            # Sample a pair from a random group; the group index is the label.
            group = random.choice([0, 1, 2, 3])
            x1, x2 = groups[group][random.randint(0, len(groups[group]) - 1)]
            if cuda:
                x1, x2 = x1.cuda(), x2.cuda()
            # The encoder is frozen: build no autograd graph through it.
            with torch.no_grad():
                x_cat = torch.cat([encoder(x1.unsqueeze(0)), encoder(x2.unsqueeze(0))], 1)
            optimizer.zero_grad()
            probs = discriminator(x_cat)
            y = torch.tensor([group])
            if cuda:
                y = y.cuda()
            # The DCD outputs probabilities, so NLL on their log is the
            # cross-entropy. The loss is minimised (not negated) so the DCD
            # learns to identify the group.
            loss = F.nll_loss(torch.log(probs.clamp(min=1e-12)), y)
            loss.backward()
            optimizer.step()
        if loss is not None:
            print("Epoch", e, "Loss", loss.item())
    return discriminator

def fada_loss(y_pred_g2, g1_true, y_pred_g4, g3_true, gamma=0.2):
    """Adversarial FADA loss for the encoder, Eq. (4) of the paper.

    Computes
    ``-gamma * (E[sum_k g1_true_k * log p2_k] + E[sum_k g3_true_k * log p4_k])``:
    the cross-entropy between the frozen DCD's predictions and the target
    distributions, summed over the class axis and averaged over pairs. With
    one-hot targets for G1 and G3, minimising it pushes the DCD to mistake G2
    pairs for G1 and G4 pairs for G3. The leading minus sign makes the result a
    positive quantity to minimise.

    Args:
        y_pred_g2: DCD probabilities for G2 pairs, shape ``(batch, 4)``.
        g1_true: target distribution for the G2 pairs (normally the one-hot
            vector of group G1), broadcastable against ``y_pred_g2``.
        y_pred_g4: DCD probabilities for G4 pairs, shape ``(batch, 4)``. The
            batch size may differ from that of ``y_pred_g2``.
        g3_true: target distribution for the G4 pairs (normally the one-hot
            vector of group G3), broadcastable against ``y_pred_g4``.
        gamma: weight of the adversarial term.

    Returns:
        A scalar tensor.
    """
    # Clamp before the log so a zero probability with a zero target gives 0,
    # not 0 * -inf = NaN.
    eps = 1e-12
    log_p2 = torch.log(y_pred_g2.clamp(min=eps))
    log_p4 = torch.log(y_pred_g4.clamp(min=eps))
    g2_term = torch.sum(g1_true * log_p2, dim=1).mean()
    g4_term = torch.sum(g3_true * log_p4, dim=1).mean()
    return -gamma * (g2_term + g4_term)

def train(encoder, discriminator, classifier, data, groups, n_target_samples=2, cuda=False, epochs=20, batch_size=256, plot_accuracy=False):
    """Step three of FADA: train encoder and classifier against the frozen DCD.

    Minimises the full objective (5) of the paper: the classification loss on
    a source mini-batch, plus the classification loss on one labelled target
    sample, plus the adversarial loss (4). In the adversarial loss the encoder
    tries to make the DCD label G2 pairs as G1 (index 0) and G4 pairs as G3
    (index 2). These indices are the group positions in ``groups``, which
    ``train_discriminator`` uses as DCD labels.

    Args:
        encoder, classifier: modules updated in place.
        discriminator: trained DCD. Its parameters are frozen during this call
            and their ``requires_grad`` flags are restored afterwards.
        data: ``(X_s, Y_s, X_t, Y_t)`` source/target samples and labels.
        groups: ``[G1, G2, G3, G4]`` lists of ``(x1, x2)`` pairs; G2 and G4 are used.
        n_target_samples: sets the steps per epoch (``4 * n_target_samples``).
        cuda: move tensors to the GPU.
        epochs: number of epochs; SVHN test accuracy is printed after each.
        batch_size: source mini-batch size.
        plot_accuracy: plot the per-epoch test accuracy at the end.
    """
    # For evaluation only.
    test_dataloader = svhn_dataloader(train=False, cuda=cuda)
    X_s, Y_s, X_t, Y_t = data
    _, G2, _, G4 = groups

    def to_device(t):
        return t.cuda() if cuda else t

    def nll_from_probs(probs, target):
        # The classifier outputs softmax probabilities; nn.CrossEntropyLoss
        # would apply a second softmax, so use NLL on the log-probabilities.
        return F.nll_loss(torch.log(probs.clamp(min=1e-12)), target)

    # The pair groups never change, so stack them once. The whole of G2 and G4
    # is used at every step, so their order does not matter.
    g2_one = to_device(torch.stack([a for a, _ in G2]))
    g2_two = to_device(torch.stack([b for _, b in G2]))
    g4_one = to_device(torch.stack([a for a, _ in G4]))
    g4_two = to_device(torch.stack([b for _, b in G4]))
    # One-hot DCD targets: confuse G2 with G1 (label 0) and G4 with G3 (label 2).
    g1_true = to_device(torch.tensor([1.0, 0.0, 0.0, 0.0]))
    g3_true = to_device(torch.tensor([0.0, 0.0, 1.0, 0.0]))

    # Only the encoder and classifier are optimised; the DCD stays frozen.
    class_optimizer = optim.Adam(list(encoder.parameters()) + list(classifier.parameters()))
    predict = model_fn(encoder, classifier)
    n_iters = 4 * n_target_samples
    accuracies = []

    dcd_params = list(discriminator.parameters())
    saved_flags = [p.requires_grad for p in dcd_params]
    for p in dcd_params:
        # Gradients still flow through the DCD to the encoder.
        p.requires_grad_(False)
    try:
        for e in range(epochs):
            loss = None
            for _ in range(n_iters):
                class_optimizer.zero_grad()
                # Source classification loss on a random mini-batch.
                inds = torch.randperm(X_s.shape[0])[:batch_size]
                x_s, y_s = to_device(X_s[inds]), to_device(Y_s[inds].long())
                loss_s = nll_from_probs(predict(x_s), y_s)
                # Target classification loss on one random labelled target sample.
                ind = random.randint(0, X_t.shape[0] - 1)
                x_t = to_device(X_t[ind:ind + 1])
                y_t = to_device(Y_t[ind:ind + 1].long())
                loss_t = nll_from_probs(predict(x_t), y_t)
                # Adversarial loss over all G2 and G4 pairs.
                y_pred_g2 = discriminator(torch.cat([encoder(g2_one), encoder(g2_two)], 1))
                y_pred_g4 = discriminator(torch.cat([encoder(g4_one), encoder(g4_two)], 1))
                # This is the full loss given by (5) in the paper.
                loss = fada_loss(y_pred_g2, g1_true, y_pred_g4, g3_true) + loss_s + loss_t
                loss.backward()
                class_optimizer.step()
            acc = eval_on_test(test_dataloader, predict)
            loss_value = loss.item() if loss is not None else float('nan')
            print("Epoch", e, "Loss", loss_value, "Accuracy", acc)
            accuracies.append(acc)
    finally:
        for p, flag in zip(dcd_params, saved_flags):
            p.requires_grad_(flag)

    if plot_accuracy:
        plt.plot(range(len(accuracies)), accuracies)
        plt.title("SVHN test accuracy")
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy")
        plt.show()

In [None]:
# Run the three FADA stages end to end: pretrain g and h on the source domain,
# train the DCD with g frozen, then adapt g and h adversarially.
cuda = torch.cuda.is_available()
n_target_samples = 7  # labelled target samples per class
plot_accuracy = True

groups, data = sample_groups(n_target_samples=n_target_samples)

# Stage 1: supervised pretraining on the source samples.
encoder, classifier = pretrain(data, epochs=20, cuda=cuda)

# Stage 2: train the Domain-Class Discriminator.
discriminator = train_discriminator(
    encoder,
    groups,
    n_target_samples=n_target_samples,
    cuda=cuda,
    epochs=50,
)

# Stage 3: adversarial adaptation of encoder and classifier.
train(
    encoder,
    discriminator,
    classifier,
    data,
    groups,
    n_target_samples=n_target_samples,
    epochs=150,
    cuda=cuda,
    plot_accuracy=plot_accuracy,
)

In [None]:
# 测试网络

In [None]:
# 使用训练好的网络