# Homework 11 - Transfer Learning (Domain Adversarial Training)

> Author: Arvin Liu (r09922071@ntu.edu.tw)

If you have any questions, feel free to email the TA at kafuchino0410@gmail.com.


# Readme


The task of this assignment is Domain Adversarial Training, a technique within Transfer Learning.

<img src="https://i.imgur.com/iMVIxCH.png" width="500px">

> That is, the block in the lower-left corner.

## Scenario and Why Domain Adversarial Training
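You have Source Data with labels, plus Target Data that may be somewhat related to the Source Data, so you want to train a model on the Source Data and use it to make predictions on the Target Data.

What is the problem with that? Anyone who has studied Anomaly Detection knows that when the model sees data that never appeared in the Source Data (i.e., abnormal data), it is unfamiliar with that data and will often produce arbitrary predictions.

To illustrate, we split the model into a Feature Extractor (upper half) and a Classifier (lower half):
<img src="https://i.imgur.com/IL0PxCY.png" width="500px">

While the whole model learns from the Source Data, the Feature Extractor sees the Source Data many times, so the features it extracts are likely to be quite meaningful. For example, the blue distribution in the figure has already separated the images into clusters, so the Classifier can predict the result from these clusters.

On the Target Data, however, the Feature Extractor has never seen such data, so the Target Features it outputs may not lie on the Source Feature distribution, and a Classifier fed such features will clearly not perform well.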

## Domain Adversarial Training of Neural Networks (DaNN)
Given this, would things work well if the Feature Extractor mapped both the Source Data and the Target Data onto the same distribution? That is the core idea of DaNN.

<img src="https://i.imgur.com/vrOE5a6.png" width="500px">

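We add a Domain Classifier. During training, the Domain Classifier tries to determine which domain each feature coming out of the Feature Extractor originates from, while the Feature Extractor learns to produce features that **fool** the Domain Classifier. Over time, the Feature Extractor usually wins against the Domain Classifier, because the Domain Classifier's input comes entirely from the Feature Extractor, and for the Feature Extractor the domain-confusion and classification tasks do not conflict.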

This way, we can expect the Feature Extractor to map inputs from either domain onto the same feature distribution.
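The adversarial game above can be written as a single objective, in the form used by the DaNN paper. Here θ_f, θ_y, θ_d denote the parameters of the Feature Extractor, Label Predictor and Domain Classifier, L_y is the classification loss on the source data, L_d is the domain-classification loss on both domains, and λ is the trade-off coefficient (these symbol names are chosen for illustration):

$$
E(\theta_f,\theta_y,\theta_d) = \mathcal{L}_y(\theta_f,\theta_y) - \lambda\,\mathcal{L}_d(\theta_f,\theta_d)
$$

$$
(\hat\theta_f,\hat\theta_y) = \arg\min_{\theta_f,\theta_y} E(\theta_f,\theta_y,\hat\theta_d),
\qquad
\hat\theta_d = \arg\max_{\theta_d} E(\hat\theta_f,\hat\theta_y,\theta_d)
$$

Maximizing E over θ_d is the same as minimizing L_d, i.e. the Domain Classifier tries to tell the domains apart, while the Feature Extractor lowers L_y and raises L_d.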

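# Data Introduction

In this task, the Source Data are real photos and the Target Data are hand-drawn doodles.

We need to let the model see the real photos and their labels, and then try to predict the labels of the hand-drawn doodles.

The data is available [here](https://drive.google.com/file/d/1e4CaQ5VUF3F04XRDGXrnRQGogo89TiF8/view?usp=sharing). The cells below help locate the data and preview roughly what it looks like.

One thing to note in particular: **the images in both the source and target data are class-balanced, and you may use this information for other purposes.**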

In [None]:
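import os
# Print the current user's home directory (USERPROFILE on Windows, HOME elsewhere)
# to help locate where the data was unzipped.
print(os.path.expanduser('~'))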

In [None]:
import matplotlib.pyplot as plt

def no_axis_show(img, title='', cmap=None):
  # imshow, with the interpolation mode set to "nearest".
  fig = plt.imshow(img, interpolation='nearest', cmap=cmap)
  # Do not show the axes in the images.
  fig.axes.get_xaxis().set_visible(False)
  fig.axes.get_yaxis().set_visible(False)
  plt.title(title)

titles = ['horse', 'bed', 'clock', 'apple', 'cat', 'plane', 'television', 'dog', 'dolphin', 'spider']
plt.figure(figsize=(18, 18))
for i in range(10):
  plt.subplot(1, 10, i+1)
  # Show one source image per class: the file train_data/{i}/{500*i}.bmp.
  no_axis_show(plt.imread(f'C:/Users/terry/desktop/real_or_drawing/train_data/{i}/{500*i}.bmp'), title=titles[i])

In [None]:
plt.figure(figsize=(18, 18))
for i in range(10):
  plt.subplot(1, 10, i+1)
  # Target images are named with zero-padded 5-digit indices: 00000.bmp, 00001.bmp, ...
  no_axis_show(plt.imread(f'C:/Users/terry/desktop/real_or_drawing/test_data/0/{i:05d}.bmp'))

# Special Domain Knowledge

Since people usually draw only outlines when doodling, we can apply edge detection to the source data to make it look more like the target data.

## Canny Edge Detection
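We won't go into the algorithm here, only how to use it. If you're interested, see Wikipedia or [this article](https://medium.com/@pomelyu5199/canny-edge-detector-%E5%AF%A6%E4%BD%9C-opencv-f7d1a0a57d19).

cv2.Canny is very easy to use: besides the image, it only needs two parameters, low_threshold and high_threshold.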

```cv2.Canny(image, low_threshold, high_threshold)```

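In short, a pixel whose edge strength (gradient magnitude) exceeds high_threshold is definitely an edge. A pixel that only exceeds low_threshold is kept as an edge only if it is connected to a definite edge; pixels below low_threshold are discarded.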
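To make the two-threshold rule concrete, here is a minimal pure-Python sketch of the hysteresis step only. It assumes an already-computed 2D grid of gradient magnitudes (a list of lists) and 8-connectivity; it skips the smoothing, gradient computation and non-maximum suppression that full Canny performs. `hysteresis` is a hypothetical helper, not an OpenCV function, and OpenCV's exact boundary comparisons may differ slightly.

```python
from collections import deque

def hysteresis(magnitude, low, high):
    """Double threshold + edge tracking on a 2D list of gradient magnitudes.

    Pixels above `high` are strong edges; pixels above `low` are kept only if they
    are 8-connected (directly or through other kept pixels) to a strong edge.
    """
    rows, cols = len(magnitude), len(magnitude[0])
    edges = [[0] * cols for _ in range(rows)]
    queue = deque()
    # Seed the search with every strong edge pixel.
    for r in range(rows):
        for c in range(cols):
            if magnitude[r][c] > high:
                edges[r][c] = 255
                queue.append((r, c))
    # Grow outward from strong edges into neighbouring weak pixels.
    while queue:
        r, c = queue.popleft()
        for dr in (-1, 0, 1):
            for dc in (-1, 0, 1):
                nr, nc = r + dr, c + dc
                if (0 <= nr < rows and 0 <= nc < cols
                        and not edges[nr][nc] and magnitude[nr][nc] > low):
                    edges[nr][nc] = 255
                    queue.append((nr, nc))
    return edges

mag = [[0, 120, 0, 0],
       [0,  80, 0, 60],
       [0,   0, 0, 0]]
# 80 survives because it touches the strong 120; the isolated 60 is dropped.
print(hysteresis(mag, 50, 100))  # [[0, 255, 0, 0], [0, 255, 0, 0], [0, 0, 0, 0]]
```

Raising both thresholds, as in the Canny(150, 200) and Canny(250, 300) panels below, leaves fewer strong seeds and therefore fewer, cleaner edges.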

Let's try it directly on the source data.

In [None]:
import cv2
import matplotlib.pyplot as plt
titles = ['horse', 'bed', 'clock', 'apple', 'cat', 'plane', 'television', 'dog', 'dolphin', 'spider']
plt.figure(figsize=(18, 18))

original_img = plt.imread(f'C:/Users/terry/desktop/real_or_drawing/train_data/0/0.bmp')
plt.subplot(1, 5, 1)
no_axis_show(original_img, title='original')

gray_img = cv2.cvtColor(original_img, cv2.COLOR_RGB2GRAY)
plt.subplot(1, 5, 2)
no_axis_show(gray_img, title='gray scale', cmap='gray')

canny_50100 = cv2.Canny(gray_img, 50, 100)
plt.subplot(1, 5, 3)
no_axis_show(canny_50100, title='Canny(50, 100)', cmap='gray')

canny_150200 = cv2.Canny(gray_img, 150, 200)
plt.subplot(1, 5, 4)
no_axis_show(canny_150200, title='Canny(150, 200)', cmap='gray')

canny_250300 = cv2.Canny(gray_img, 250, 300)
plt.subplot(1, 5, 5)
no_axis_show(canny_250300, title='Canny(250, 300)', cmap='gray')
  

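# Data Processing

I deliberately stored the data in the directory layout expected by `torchvision.datasets.ImageFolder`, so a dataset can be built with a single call to it.

For the transforms, see the comments in the code below.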
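<!-- 
#### Some details

In the standard version, applying RandomRotation to grayscale images only requires ```transforms.RandomRotation(15)```. On Colab, however, you need to add ```fill=(0,)``` for it to run.
On n98, you need to remove ```fill=(0,)``` for it to run. -->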
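For reference, `ImageFolder` expects a `root/<class_name>/<image>` layout and assigns labels by sorting the subfolder names. The sketch below imitates that mapping with the standard library only; `find_classes_like_imagefolder` and `list_samples` are hypothetical helpers, not torchvision's code, and for brevity they only pick up `.bmp` files. Since `test_data` has a single subfolder `0`, every target image gets the placeholder label 0, which the training loop ignores.

```python
import os
import tempfile

def find_classes_like_imagefolder(root):
    """Mimic ImageFolder: class names are the sorted subfolder names, labels are 0..N-1."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx

def list_samples(root, extension='.bmp'):
    """Return (path, label) pairs for every image file under root/<class>/ (simplified)."""
    classes, class_to_idx = find_classes_like_imagefolder(root)
    samples = []
    for name in classes:
        folder = os.path.join(root, name)
        for fname in sorted(os.listdir(folder)):
            if fname.lower().endswith(extension):
                samples.append((os.path.join(folder, fname), class_to_idx[name]))
    return samples

# Usage: build a tiny fake train_data/ tree and inspect the labels.
with tempfile.TemporaryDirectory() as root:
    for cls, fname in [('0', '0.bmp'), ('1', '500.bmp'), ('2', '1000.bmp')]:
        os.makedirs(os.path.join(root, cls))
        open(os.path.join(root, cls, fname), 'wb').close()
    print([label for _, label in list_samples(root)])  # [0, 1, 2]
```

Because the sort is on strings, ten folders named `0` to `9` map to labels 0 to 9, but a folder named `10` would sort before `2`.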


In [None]:
from tqdm import tqdm
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

source_transform = transforms.Compose([
    # Convert RGB to grayscale, since Canny is applied to a single-channel image here.
    transforms.Grayscale(),
    # cv2 does not accept PIL.Image objects, so we convert to np.array first,
    # and then apply the cv2.Canny algorithm.
    transforms.Lambda(lambda x: cv2.Canny(np.array(x), 170, 300)),
    # Convert the np.array back to a PIL.Image.
    transforms.ToPILImage(),
    # 50% horizontal flip (for augmentation).
    transforms.RandomHorizontalFlip(),
    # Rotate by up to +-15 degrees (for augmentation), filling any empty
    # pixels created by the rotation with zero.
    transforms.RandomRotation(15, fill=(0,)),
    # Convert to a tensor for model inputs.
    transforms.ToTensor(),
])
target_transform = transforms.Compose([
    # Convert RGB to grayscale.
    transforms.Grayscale(),
    # Resize: the source images are 32x32, so we need to
    #  enlarge the target images from 28x28 to 32x32.
    transforms.Resize((32, 32)),
    # 50% horizontal flip (for augmentation).
    transforms.RandomHorizontalFlip(),
    # Rotate by up to +-15 degrees (for augmentation), filling any empty
    # pixels created by the rotation with zero.
    transforms.RandomRotation(15, fill=(0,)),
    # Convert to a tensor for model inputs.
    transforms.ToTensor(),
])

source_dataset = ImageFolder('C:/Users/terry/desktop/real_or_drawing/train_data', transform=source_transform)
target_dataset = ImageFolder('C:/Users/terry/desktop/real_or_drawing/test_data', transform=target_transform)

source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)

# Model

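Feature Extractor: a typical VGG-like stack.

Label Predictor / Domain Classifier: plain MLPs all the way through.

By this point in the homework you should be familiar with all the layers below, so we won't go into detail.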
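A quick check of why the Linear layers take 512 inputs: each of the five blocks keeps the spatial size through `Conv2d(kernel=3, stride=1, padding=1)` and halves it through `MaxPool2d(2)`, so a 32x32 input ends as a 512x1x1 feature map. The helpers below are hypothetical and simply apply the standard PyTorch output-size formulas (dilation 1, no ceil mode).

```python
def conv2d_out(size, kernel=3, stride=1, padding=1):
    """Output height/width of nn.Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1

def maxpool2d_out(size, kernel=2):
    """Output height/width of nn.MaxPool2d(kernel); stride defaults to kernel_size."""
    return (size - kernel) // kernel + 1

def feature_map_sizes(size=32, num_blocks=5):
    """Spatial size after each Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d block."""
    sizes = []
    for _ in range(num_blocks):
        size = maxpool2d_out(conv2d_out(size))
        sizes.append(size)
    return sizes

print(feature_map_sizes())  # [16, 8, 4, 2, 1]
```

By the same arithmetic, `feature_map_sizes(28)` collapses to `[14, 7, 3, 1, 0]`, which is one more reason the target transform resizes the 28x28 doodles to 32x32.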

In [None]:
class FeatureExtractor(nn.Module):

    def __init__(self):
        super(FeatureExtractor, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        
    def forward(self, x):
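        # (B, 512, 1, 1) -> (B, 512). flatten(1) keeps the batch dimension even
        # when B == 1, whereas squeeze() would drop it too.
        x = self.conv(x).flatten(1)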
        return x

class LabelPredictor(nn.Module):

    def __init__(self):
        super(LabelPredictor, self).__init__()

        self.layer = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),

            nn.Linear(512, 512),
            nn.ReLU(),

            nn.Linear(512, 10),
        )

    def forward(self, h):
        c = self.layer(h)
        return c

class DomainClassifier(nn.Module):

    def __init__(self):
        super(DomainClassifier, self).__init__()

        self.layer = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, 1),
        )

    def forward(self, h):
        y = self.layer(h)
        return y

# Pre-processing

Here we use Adam as the optimizer.

In [None]:
feature_extractor = FeatureExtractor().cuda()
label_predictor = LabelPredictor().cuda()
domain_classifier = DomainClassifier().cuda()

class_criterion = nn.CrossEntropyLoss()
domain_criterion = nn.BCEWithLogitsLoss()

optimizer_F = optim.Adam(feature_extractor.parameters())
optimizer_C = optim.Adam(label_predictor.parameters())
optimizer_D = optim.Adam(domain_classifier.parameters())

# Start Training


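## How to implement DaNN?

In theory, the original paper adds a Gradient Reversal Layer and trains the Feature Extractor / Label Predictor / Domain Classifier jointly. However, we can also train the Domain Classifier and the Feature Extractor alternately (just like training the Generator and Discriminator of a GAN), and this works too.

In the code, we take the latter approach.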
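For comparison, a Gradient Reversal Layer can be sketched with the `Function` class that the data-processing cell already imports. This is a minimal sketch assuming a recent PyTorch; `GradReverse` and `grad_reverse` are hypothetical names that the notebook itself does not use. The forward pass is the identity and the backward pass multiplies the incoming gradient by `-lamb`.

```python
import torch
from torch.autograd import Function

class GradReverse(Function):
    """Identity in the forward pass; multiplies the gradient by -lamb in the backward pass."""

    @staticmethod
    def forward(ctx, x, lamb):
        ctx.lamb = lamb
        # view_as returns a new tensor object so autograd records this op.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: reversed for x, None for the float lamb.
        return grad_output.neg() * ctx.lamb, None

def grad_reverse(x, lamb=1.0):
    return GradReverse.apply(x, lamb)

# Tiny check: forward is unchanged, gradient is flipped and scaled by 0.5.
x = torch.tensor([1.0, 2.0], requires_grad=True)
grad_reverse(x, 0.5).sum().backward()
print(x.grad)  # tensor([-0.5000, -0.5000])
```

With such a layer, a single backward pass on `class_criterion(class_logits, source_label) + domain_criterion(domain_classifier(grad_reverse(feature, lamb)), domain_label)` would train the Domain Classifier to separate the domains while pushing the Feature Extractor the opposite way, replacing the two alternating steps used here.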

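## Tips
* The lambda in the original paper (the coefficient of the Domain Adversarial Loss) has an adaptive version; if you're interested, see the [paper](https://arxiv.org/pdf/1505.07818.pdf).
* Since we have no target labels at all, the only way to see how well it works is to submit to Kaggle :)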
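The adaptive schedule in the paper is λ_p = 2 / (1 + exp(-γ·p)) - 1, where p is the training progress going linearly from 0 to 1 and γ = 10. A minimal sketch, where `adaptive_lambda` is a hypothetical helper name:

```python
import math

def adaptive_lambda(p, gamma=10.0):
    """DANN schedule: lambda_p = 2 / (1 + exp(-gamma * p)) - 1, with p in [0, 1]."""
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0

# Lambda at a few points of a 200-epoch run, taking p = epoch / 200.
for epoch in (0, 20, 50, 100, 199):
    print(epoch, round(adaptive_lambda(epoch / 200), 4))
```

Passing `lamb=adaptive_lambda(epoch / 200)` would be one way to plug it into the training loop. Like the hand-tuned formula used there (which rises from about 0 to about 0.96 over 200 epochs), it starts at 0 and grows toward 1; the paper motivates this as suppressing the noisy domain-classifier signal early in training.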

In [18]:
import math
def train_epoch(source_dataloader, target_dataloader, lamb):
    '''
      Args:
        source_dataloader: dataloader for the source data
        target_dataloader: dataloader for the target data
        lamb: controls the balance between domain adaptation and classification.
    '''

    # D loss: loss of the Domain Classifier
    # F loss: loss of the Feature Extractor & Label Predictor
    running_D_loss, running_F_loss = 0.0, 0.0
    total_hit, total_num = 0.0, 0.0

    for i, ((source_data, source_label), (target_data, _)) in enumerate(zip(source_dataloader, target_dataloader)):

        source_data = source_data.cuda()
        source_label = source_label.cuda()
        target_data = target_data.cuda()
        
        # Mix the source and target data into one batch; otherwise the running statistics
        #   of batch_norm would be misled (the running mean/var of source and target data differ).
        mixed_data = torch.cat([source_data, target_data], dim=0)
        domain_label = torch.zeros([source_data.shape[0] + target_data.shape[0], 1]).cuda()
        # set domain label of source data to be 1.
        domain_label[:source_data.shape[0]] = 1

        # Step 1 : train domain classifier
        feature = feature_extractor(mixed_data)
        # We don't want to train the feature extractor in step 1,
        # so we detach the features to stop backpropagation into it.
        domain_logits = domain_classifier(feature.detach())
        loss = domain_criterion(domain_logits, domain_label)
        running_D_loss+= loss.item()
        loss.backward()
        optimizer_D.step()

        # Step 2 : train feature extractor and label classifier
        class_logits = label_predictor(feature[:source_data.shape[0]])
        domain_logits = domain_classifier(feature)
        # loss = classification cross entropy - lamb * domain binary cross entropy.
        #  The subtraction plays the same role as the generator loss in a GAN: the feature
        #  extractor is rewarded for making the domain classifier's loss large.
        loss = class_criterion(class_logits, source_label) - lamb * domain_criterion(domain_logits, domain_label)
        running_F_loss+= loss.item()
        loss.backward()
        optimizer_F.step()
        optimizer_C.step()

        optimizer_D.zero_grad()
        optimizer_F.zero_grad()
        optimizer_C.zero_grad()

        total_hit += torch.sum(torch.argmax(class_logits, dim=1) == source_label).item()
        total_num += source_data.shape[0]
        print(i, end='\r')

    return running_D_loss / (i+1), running_F_loss / (i+1), total_hit / total_num

# Train for 200 epochs.
for epoch in tqdm(range(200)):
    # You should choose lambda carefully.
    train_D_loss, train_F_loss, train_acc = train_epoch(
        source_dataloader, target_dataloader,
        lamb=(2 - math.exp(-(0.00330612557 * epoch - 0.693))))

    torch.save(feature_extractor.state_dict(), 'extractor_model.bin')
    torch.save(label_predictor.state_dict(), 'predictor_model.bin')

    print('epoch {:>3d}: train D loss: {:6.4f}, train F loss: {:6.4f}, acc {:6.4f}'.format(epoch, train_D_loss, train_F_loss, train_acc))


epoch   0: train D loss: 0.6382, train F loss: 0.0191, acc 0.9934
epoch   1: train D loss: 0.6212, train F loss: 0.0181, acc 0.9932
epoch   2: train D loss: 0.5898, train F loss: 0.0097, acc 0.9948
epoch   3: train D loss: 0.5857, train F loss: 0.0020, acc 0.9960
epoch   4: train D loss: 0.5764, train F loss: 0.0077, acc 0.9942
epoch   5: train D loss: 0.5825, train F loss: -0.0033, acc 0.9946
epoch   6: train D loss: 0.5748, train F loss: -0.0066, acc 0.9950
epoch   7: train D loss: 0.5775, train F loss: -0.0125, acc 0.9950
epoch   8: train D loss: 0.6305, train F loss: 0.0824, acc 0.9788
epoch   9: train D loss: 0.6142, train F loss: -0.0143, acc 0.9926
epoch  10: train D loss: 0.5896, train F loss: -0.0252, acc 0.9958
epoch  11: train D loss: 0.6006, train F loss: -0.0300, acc 0.9958
epoch  12: train D loss: 0.6132, train F loss: -0.0345, acc 0.9954
epoch  13: train D loss: 0.6120, train F loss: -0.0364, acc 0.9946
epoch  14: train D loss: 0.6197, train F loss: -0.0347, acc 0.9934
epoch  15: train D loss: 0.6183, train F loss: -0.0394, acc 0.9932
epoch  16: train D loss: 0.6227, train F loss: -0.0398, acc 0.9930
epoch  17: train D loss: 0.6331, train F loss: -0.0478, acc 0.9928
epoch  18: train D loss: 0.6264, train F loss: -0.0563, acc 0.9950
epoch  19: train D loss: 0.6243, train F loss: -0.0603, acc 0.9952
epoch  20: train D loss: 0.6367, train F loss: -0.0705, acc 0.9968
epoch  21: train D loss: 0.6332, train F loss: -0.0723, acc 0.9960
epoch  22: train D loss: 0.6346, train F loss: -0.0679, acc 0.9934
epoch  23: train D loss: 0.6334, train F loss: -0.0802, acc 0.9958
epoch  24: train D loss: 0.6388, train F loss: -0.0699, acc 0.9920
epoch  25: train D loss: 0.6357, train F loss: -0.0747, acc 0.9928
epoch  26: train D loss: 0.6362, train F loss: -0.0886, acc 0.9946
epoch  27: train D loss: 0.6374, train F loss: -0.0907, acc 0.9948
epoch  28: train D loss: 0.6445, train F loss: -0.0943, acc 0.9938
epoch  29: train D loss: 0.6427, train F loss: -0.0937, acc 0.9932
epoch  30: train D loss: 0.6472, train F loss: -0.1056, acc 0.9954
epoch  31: train D loss: 0.6486, train F loss: -0.1109, acc 0.9962
epoch  32: train D loss: 0.6451, train F loss: -0.1134, acc 0.9950
epoch  33: train D loss: 0.6470, train F loss: -0.1177, acc 0.9948
epoch  34: train D loss: 0.6581, train F loss: -0.0923, acc 0.9904
epoch  35: train D loss: 0.6563, train F loss: -0.1147, acc 0.9908
epoch  36: train D loss: 0.6462, train F loss: -0.1263, acc 0.9936
epoch  37: train D loss: 0.6446, train F loss: -0.1255, acc 0.9926
epoch  38: train D loss: 0.6526, train F loss: -0.1356, acc 0.9936
epoch  39: train D loss: 0.6438, train F loss: -0.1389, acc 0.9946
epoch  40: train D loss: 0.6599, train F loss: -0.1331, acc 0.9888
epoch  41: train D loss: 0.6528, train F loss: -0.1475, acc 0.9942
epoch  42: train D loss: 0.6472, train F loss: -0.1564, acc 0.9962
epoch  43: train D loss: 0.6484, train F loss: -0.1583, acc 0.9964
epoch  44: train D loss: 0.6521, train F loss: -0.1651, acc 0.9964
epoch  45: train D loss: 0.6588, train F loss: -0.1531, acc 0.9910
epoch  46: train D loss: 0.6538, train F loss: -0.1566, acc 0.9934
epoch  47: train D loss: 0.6562, train F loss: -0.0795, acc 0.9678
epoch  48: train D loss: 0.6515, train F loss: -0.1616, acc 0.9922
epoch  49: train D loss: 0.6489, train F loss: -0.1728, acc 0.9936
epoch  50: train D loss: 0.6461, train F loss: -0.1801, acc 0.9950
epoch  51: train D loss: 0.6499, train F loss: -0.1841, acc 0.9938
epoch  52: train D loss: 0.6506, train F loss: -0.1871, acc 0.9942
epoch  53: train D loss: 0.6524, train F loss: -0.1857, acc 0.9944
epoch  54: train D loss: 0.6503, train F loss: -0.1913, acc 0.9936
epoch  55: train D loss: 0.6524, train F loss: -0.1893, acc 0.9932
epoch  56: train D loss: 0.6538, train F loss: -0.2032, acc 0.9950
epoch  57: train D loss: 0.6575, train F loss: -0.2038, acc 0.9932
epoch  58: train D loss: 0.6633, train F loss: -0.2103, acc 0.9942
epoch  59: train D loss: 0.6573, train F loss: -0.2109, acc 0.9936
epoch  60: train D loss: 0.6520, train F loss: -0.2165, acc 0.9946
epoch  61: train D loss: 0.6597, train F loss: -0.2222, acc 0.9944
epoch  62: train D loss: 0.6553, train F loss: -0.2240, acc 0.9938
epoch  63: train D loss: 0.6617, train F loss: -0.2187, acc 0.9910
epoch  64: train D loss: 0.6577, train F loss: -0.2240, acc 0.9924
epoch  65: train D loss: 0.6584, train F loss: -0.2386, acc 0.9950
epoch  66: train D loss: 0.6523, train F loss: -0.2392, acc 0.9958
epoch  67: train D loss: 0.6628, train F loss: -0.2327, acc 0.9916
epoch  68: train D loss: 0.6621, train F loss: -0.2462, acc 0.9938
epoch  69: train D loss: 0.6550, train F loss: -0.2502, acc 0.9946
epoch  70: train D loss: 0.6610, train F loss: -0.2514, acc 0.9932
epoch  71: train D loss: 0.6525, train F loss: -0.2498, acc 0.9932
epoch  72: train D loss: 0.6589, train F loss: -0.2571, acc 0.9940
epoch  73: train D loss: 0.6616, train F loss: -0.2679, acc 0.9950
epoch  74: train D loss: 0.6664, train F loss: -0.2597, acc 0.9914
epoch  75: train D loss: 0.6644, train F loss: -0.2486, acc 0.9888
epoch  76: train D loss: 0.6570, train F loss: -0.2752, acc 0.9952
epoch  77: train D loss: 0.6640, train F loss: -0.2784, acc 0.9938
epoch  78: train D loss: 0.6622, train F loss: -0.2770, acc 0.9916
epoch  79: train D loss: 0.6620, train F loss: -0.2890, acc 0.9964
epoch  80: train D loss: 0.6630, train F loss: -0.2932, acc 0.9958
epoch  81: train D loss: 0.6619, train F loss: -0.2948, acc 0.9952
epoch  82: train D loss: 0.6660, train F loss: -0.2789, acc 0.9900
epoch  83: train D loss: 0.6623, train F loss: -0.2931, acc 0.9924
epoch  84: train D loss: 0.6650, train F loss: -0.2498, acc 0.9864
epoch  85: train D loss: 0.6543, train F loss: -0.2934, acc 0.9930
epoch  86: train D loss: 0.6582, train F loss: -0.3077, acc 0.9952
epoch  87: train D loss: 0.6635, train F loss: -0.3126, acc 0.9940
epoch  88: train D loss: 0.6595, train F loss: -0.3063, acc 0.9926
epoch  89: train D loss: 0.6592, train F loss: -0.3178, acc 0.9950
epoch  90: train D loss: 0.6629, train F loss: -0.3210, acc 0.9942
epoch  91: train D loss: 0.6667, train F loss: -0.3294, acc 0.9956
epoch  92: train D loss: 0.6781, train F loss: -0.2576, acc 0.9796
epoch  93: train D loss: 0.6711, train F loss: -0.3119, acc 0.9874
epoch  94: train D loss: 0.6601, train F loss: -0.3351, acc 0.9952
epoch  95: train D loss: 0.6619, train F loss: -0.3388, acc 0.9944
epoch  96: train D loss: 0.6626, train F loss: -0.3450, acc 0.9952
epoch  97: train D loss: 0.6667, train F loss: -0.3446, acc 0.9942
epoch  98: train D loss: 0.6669, train F loss: -0.3419, acc 0.9928
epoch  99: train D loss: 0.6687, train F loss: -0.3492, acc 0.9924
epoch 100: train D loss: 0.6606, train F loss: -0.3547, acc 0.9952
epoch 101: train D loss: 0.6653, train F loss: -0.3538, acc 0.9930
epoch 102: train D loss: 0.6698, train F loss: -0.3633, acc 0.9952
epoch 103: train D loss: 0.6717, train F loss: -0.3683, acc 0.9934
epoch 104: train D loss: 0.6669, train F loss: -0.3564, acc 0.9924
epoch 105: train D loss: 0.6678, train F loss: -0.3666, acc 0.9938
epoch 106: train D loss: 0.6709, train F loss: -0.3731, acc 0.9946
epoch 107: train D loss: 0.6683, train F loss: -0.3779, acc 0.9938
epoch 108: train D loss: 0.6720, train F loss: -0.3801, acc 0.9928
epoch 109: train D loss: 0.6654, train F loss: -0.3768, acc 0.9938
156

 56%|████████████████████████████████████████████▍                                   | 111/200 [14:30<11:32,  7.78s/it]

epoch 110: train D loss: 0.6726, train F loss: -0.3839, acc 0.9934
154

 56%|████████████████████████████████████████████▊                                   | 112/200 [14:37<11:23,  7.77s/it]

epoch 111: train D loss: 0.6664, train F loss: -0.3924, acc 0.9948
156

 56%|█████████████████████████████████████████████▏                                  | 113/200 [14:45<11:17,  7.79s/it]

epoch 112: train D loss: 0.6745, train F loss: -0.3529, acc 0.9820
154

 57%|█████████████████████████████████████████████▌                                  | 114/200 [14:53<11:10,  7.80s/it]

epoch 113: train D loss: 0.6615, train F loss: -0.3808, acc 0.9912
70

 57%|█████████████████████████████████████████████▌                                  | 114/200 [14:57<11:16,  7.87s/it]

7172




KeyboardInterrupt: 
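A reference point for reading the `train D loss` column in the log above. This rests on an assumption not shown in this section: the domain classifier is trained as a binary (source vs. target) classifier with binary cross-entropy. If it cannot tell the two domains apart and always outputs a probability of 0.5, its loss is ln 2 ≈ 0.693. One reading of the logged values, roughly 0.66 to 0.68, is that the domain classifier is only slightly better than chance. That is the outcome the feature extractor is trying to force. The helper `bce` below is hypothetical and written only for illustration.

```python
import math

def bce(p: float, y: int) -> float:
    """Binary cross-entropy for one sample: predicted probability p, true label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# A domain classifier that cannot distinguish source from target outputs p = 0.5 for every input,
# so its loss is the same whichever domain the sample actually came from.
chance_loss = bce(0.5, 1)
print(f"BCE at chance level: {chance_loss:.4f}")  # prints 0.6931 (= ln 2)

# For comparison, a confident and correct prediction gives a much lower loss.
print(f"BCE when confident and correct (p=0.95, y=1): {bce(0.95, 1):.4f}")  # prints 0.0513
```

The further the D loss falls below ln 2, the more easily the domain classifier separates source from target features.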

# Inference

Same as in the previous assignments. Here I use pd (pandas) to generate the CSV, because it looks cooler (?).

Also, the accuracy after 200 epochs may be somewhat unstable, so you may want to run it a few more times or train longer.

In [19]:
import numpy as np
import pandas as pd
import torch

result = []
label_predictor.eval()
feature_extractor.eval()
# Disable gradient tracking during inference to save memory and time
with torch.no_grad():
    for test_data, _ in test_dataloader:
        test_data = test_data.cuda()

        class_logits = label_predictor(feature_extractor(test_data))

        # Pick the most likely class for each image in the batch
        x = torch.argmax(class_logits, dim=1).cpu().numpy()
        result.append(x)

result = np.concatenate(result)

# Generate your submission
df = pd.DataFrame({'id': np.arange(0, len(result)), 'label': result})
df.to_csv('DaNN_submission.csv', index=False)
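For readers who would rather not pull in pandas for this step, here is a minimal sketch of the same CSV layout using only Python's standard `csv` module. The layout is an `id,label` header, followed by one row per test image with `id` counting up from 0. The function name `write_submission` and the sample predictions are hypothetical. Line endings may differ from what `DataFrame.to_csv` writes on your platform.

```python
import csv
import io

def write_submission(predictions, fileobj):
    """Write predictions as 'id,label' rows, with ids counting up from 0."""
    writer = csv.writer(fileobj, lineterminator="\n")
    writer.writerow(["id", "label"])
    for idx, label in enumerate(predictions):
        writer.writerow([idx, int(label)])

# Hypothetical per-batch predictions, like the arrays collected in `result` above.
batches = [[3, 0, 7], [1, 9]]
flat = [label for batch in batches for label in batch]  # plays the role of np.concatenate

buf = io.StringIO()
write_submission(flat, buf)
print(buf.getvalue())
```

To write an actual file, open it with `open('DaNN_submission.csv', 'w', newline='')` and pass that file object instead of the `StringIO` buffer.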

# Training Statistics

- Number of parameters:
  - Feature Extractor: 2,142,336
  - Label Predictor: 530,442
  - Domain Classifier: 1,055,233

- Simple
  - Training time on Colab: ~1 hr
- Medium
  - Training time on Colab: 2-4 hrs
- Strong
  - Training time on Colab: 5-6 hrs
- Boss
  - **Unmeasurable**
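The parameter counts listed above follow directly from the layer shapes. In PyTorch, the usual way to get such a count is `sum(p.numel() for p in model.parameters())`. The pure-Python sketch below instead does the arithmetic by hand so each term is visible. The layer configurations are **assumptions**, because the model definitions are not shown in this section. We chose them because this arithmetic reproduces the three listed totals. The helper functions are hypothetical.

```python
def conv2d_params(c_in, c_out, k):
    # weights (c_out * c_in * k * k) plus one bias per output channel
    return c_out * c_in * k * k + c_out

def linear_params(n_in, n_out):
    # weight matrix plus bias vector
    return n_in * n_out + n_out

def batchnorm_params(num_features):
    # learnable scale (gamma) and shift (beta); running mean/var are buffers, not parameters
    return 2 * num_features

# ASSUMED feature extractor: five [Conv2d 3x3 -> BatchNorm2d -> ReLU -> MaxPool] blocks
channels = [1, 64, 128, 256, 256, 512]
feature_extractor = sum(
    conv2d_params(c_in, c_out, 3) + batchnorm_params(c_out)
    for c_in, c_out in zip(channels, channels[1:])
)

# ASSUMED label predictor: fully connected 512 -> 512 -> 512 -> 10
label_predictor = linear_params(512, 512) * 2 + linear_params(512, 10)

# ASSUMED domain classifier: four [Linear 512->512 -> BatchNorm1d -> ReLU] blocks, then Linear 512->1
domain_classifier = (linear_params(512, 512) + batchnorm_params(512)) * 4 + linear_params(512, 1)

print(f"Feature Extractor: {feature_extractor:,}")  # 2,142,336
print(f"Label Predictor:   {label_predictor:,}")    # 530,442
print(f"Domain Classifier: {domain_classifier:,}")  # 1,055,233
```

ReLU and MaxPool layers have no learnable weights, so they contribute nothing to these totals.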

# Learning Curve (Strong Baseline)
* This method differs slightly from the one in this Colab notebook.

![Loss Curve](https://i.imgur.com/vIujQyo.png)

# Accuracy Curve (Strong Baseline)
* Note that you will not have access to the testing accuracy. This plot shows that even after the model has overfit the training data, the testing accuracy keeps improving, which is why you should train for more epochs.

![Acc Curve](https://i.imgur.com/4W1otXG.png)



# Special Thanks
Below are the reference assignments originally provided by the NTU TAs.

[NTU_r08942071_太神啦 / Team leader: 劉正仁](https://drive.google.com/open?id=11uNDcz7_eMS8dMQxvnWsbrdguu9k4c-c)

[NTU_r08921a08_CAT / Team leader: 廖子毅](https://drive.google.com/open?id=1xIkSs8HAShdcfV1E0NEnf4JDbL7POZTf)
