In [1]:
# load the autoreload extension
%load_ext autoreload
# Set extension to reload modules every time before executing code
%autoreload 2
import os
import sys
sys.path.append('..')
import time
import argparse
from tqdm import tqdm
import numpy as np
from sklearn import metrics
import warnings
warnings.filterwarnings("ignore")
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from datetime import datetime
from tensorboardX import SummaryWriter

from scripts.dataset import ImgDataset, readLabel, readImg
from scripts.models import DG_model, Discriminator, Feature_Generator_ResNet, Classifier
from scripts.utils import auc_acc
from scripts.hard_triplet_loss import HardTripletLoss

In [3]:
# Root data directory, given relative to the notebook's working directory.
workspace_dir = '../../final_project/data/oulu_npu_cropped'

# Per-channel normalisation constants shared by both pipelines.
# NOTE(review): these match the statistics torchvision documents for its
# ImageNet-pretrained models -- confirm the backbone is actually pretrained.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]

# Training pipeline: geometric/photometric augmentation on a PIL image, then
# conversion to a tensor, random erasing, and finally normalisation.
_train_steps = [
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(degrees=30),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
    transforms.RandomErasing(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
]
train_transform = transforms.Compose(_train_steps)

# Evaluation pipeline: no augmentation, only tensor conversion + normalisation.
_test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
]
test_transform = transforms.Compose(_test_steps)

In [3]:
# Task / loader configuration used by every cell below.
binary = True
bs = 60

# ---- Training split ---------------------------------------------------------
# Labels are read once from the 'train_rembg' folder and reused for both image
# arrays.
# NOTE(review): this assumes train_x.npy and train_x_rembg.npy are stored in the
# same sample order as that folder -- confirm against the preprocessing script.
train_folder = os.path.join(workspace_dir, 'train_rembg')
phones, sess, ids, train_y = readLabel(train_folder, order='single', binary=binary)
train_x_rembg = np.load(os.path.join(workspace_dir, 'train_x_rembg.npy'))
train_x_orig = np.load(os.path.join(workspace_dir, 'train_x.npy'))

# Original images first, then the rembg images; labels and session ids are
# repeated twice so they line up with the concatenated image array.
train_x = np.concatenate([train_x_orig, train_x_rembg])
train_dataset = ImgDataset(
    train_x,
    np.tile(train_y, 2),
    np.tile(sess, 2),
    train_transform,
    'single',
    binary,
)
train_dataloader = DataLoader(train_dataset, batch_size=bs, shuffle=True)

# ---- Validation split -------------------------------------------------------
# Note: phones / sess / ids are rebound here to the validation values.
val_folder = os.path.join(workspace_dir, 'val')
val_x = np.load(os.path.join(workspace_dir, 'val_x.npy'))
phones, sess, ids, val_y = readLabel(val_folder, order='single', binary=binary)
val_set = ImgDataset(val_x, val_y, sess, test_transform, 'single', binary)
val_loader = DataLoader(val_set, batch_size=bs, shuffle=False)

In [4]:
# The main network gets a single output unit in the binary setting and three
# outputs otherwise; the domain classifier always predicts two classes.
_n_outputs = 1 if binary else 3
model = DG_model(model='resnet18', num_cls=_n_outputs).cuda()
domain_classifier = Discriminator(num_cls=2).cuda()

In [5]:
# Classification loss depends on the task: BCE for the binary setting,
# cross-entropy for the multi-class one.
if binary:
    _class_criterion = nn.BCELoss().cuda()
else:
    _class_criterion = nn.CrossEntropyLoss().cuda()

# Loss functions looked up by name inside the training loop.
criterion = {
    'class': _class_criterion,
    'triplet': HardTripletLoss(margin=0.1, hardest=False).cuda(),
    'domain': nn.CrossEntropyLoss().cuda(),
}

# The main model is optimised with SGD; the domain classifier with Adam.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, weight_decay=1e-5)
optimizer_D = torch.optim.Adam(domain_classifier.parameters(), lr=1e-5, weight_decay=1e-4)

num_epoch = 1

In [6]:
# Weights applied to the auxiliary terms of the training loss:
#   total = cls_loss + triplet_coef * triplet + sess_coef * domain_loss
# With sess_coef at 0.0 the domain loss does not contribute to the main model.
triplet_coef, sess_coef = 1.0, 0.0

In [9]:
# Training / validation loop.
#
# Per batch:
#   1. update the domain classifier on features from the main model (no
#      gradient flows back into the main model in this step);
#   2. update the main model on
#      cls_loss + triplet_coef * triplet + sess_coef * domain_loss.
# After each epoch the model is evaluated on `val_loader` and the per-batch
# mean losses plus accuracy / AUC (from `auc_acc`) are printed.
for epoch in range(num_epoch):

    epoch_start_time = time.time()
    # Running sums over batches (plain Python floats); divided by the number of
    # batches before printing.
    train_loss = 0.0
    triplet_loss = 0.0
    domain_loss_total = 0.0
    cls_loss_total = 0.0
    val_loss = 0.0
    running_D_loss = 0.0
    model.train()  # make sure the model is in training mode (enables Dropout etc.)
    train_logits = []
    train_labels = []
    # The third item is named `sess_batch` (not `sess`) so the loop does not
    # overwrite the notebook-level `sess` array created in the data cell.
    for data, label, sess_batch in tqdm(train_dataloader):
        data = data.cuda()
        label = label.cuda()
        domain_label = sess_batch.cuda()

        # ---- 1) train the domain classifier --------------------------------
        # Features are computed without building an autograd graph: this step
        # must not update the main model, so the graph would only waste memory.
        with torch.no_grad():
            _, feature = model(data)
        domain_logits = domain_classifier(feature)

        loss = criterion['domain'](domain_logits, domain_label)
        running_D_loss += loss.item()

        # Zero right before backward so stale gradients (e.g. from an
        # interrupted previous run of this cell, or from the main-model step
        # below) are never applied.
        optimizer_D.zero_grad()
        loss.backward()
        optimizer_D.step()

        # ---- 2) train the main model ---------------------------------------
        pred, feature = model(data)
        class_logits = pred[:data.shape[0]]
        domain_logits = domain_classifier(feature)

        # In the binary setting the prediction is reshaped to the label's shape
        # so the loss sees identical shapes here and in validation, including
        # for a final batch of size 1 (where a bare .squeeze() would drop the
        # batch dimension).
        # NOTE(review): assumes the model emits one value per sample when
        # num_cls == 1 -- confirm against DG_model.
        cls_input = class_logits.reshape(label.shape) if binary else class_logits
        cls_loss = criterion['class'](cls_input, label)
        triplet = criterion["triplet"](feature, label)
        domain_loss = criterion['domain'](domain_logits, domain_label)

        # NOTE(review): the domain loss is added with a positive sign, so with
        # sess_coef > 0 the feature extractor is trained to *help* the domain
        # classifier. Confirm whether an adversarial (negated / gradient
        # reversal) term was intended.
        loss = cls_loss + triplet_coef*triplet + sess_coef*domain_loss

        # Accumulate plain floats only; adding the tensors themselves would
        # keep every batch's autograd graph alive for the whole epoch.
        cls_loss_total += cls_loss.item()
        triplet_loss += triplet_coef*triplet.item()
        domain_loss_total += sess_coef*domain_loss.item()

        optimizer.zero_grad()  # clear the main model's gradients
        loss.backward()
        optimizer.step()  # update the parameters from the gradients

        train_loss += loss.item()
        # Detach so the stored predictions do not retain the graph.
        train_logits.append(class_logits.detach())
        train_labels.append(label)

    train_auc, train_acc = auc_acc(torch.cat(train_logits), torch.cat(train_labels), binary)

    # ---- validation ---------------------------------------------------------
    model.eval()
    val_logits = []
    val_labels = []
    with torch.no_grad():
        for data, label, _ in tqdm(val_loader):
            class_logits, _ = model(data.cuda())
            target = label.cuda()
            # Same shape handling as in training (see above).
            cls_input = class_logits.reshape(target.shape) if binary else class_logits
            batch_loss = criterion['class'](cls_input, target)
            val_logits.append(class_logits)
            val_labels.append(label)
            val_loss += batch_loss.item()
        val_auc, val_acc = auc_acc(torch.cat(val_logits), torch.cat(val_labels), binary)

    # ---- epoch summary: per-batch means -------------------------------------
    n_train_batches = len(train_dataloader)
    running_D_loss /= n_train_batches
    train_loss /= n_train_batches
    cls_loss_total /= n_train_batches
    triplet_loss /= n_train_batches
    domain_loss_total /= n_train_batches
    val_loss /= len(val_loader)
    print('[{:03}/{:03}] {:3.2f} sec(s) Train Acc: {:.6f} Train AUC: {:.6f} Train  loss: {:3.6f} D_loss: {:3.6f}'\
        .format(epoch + 1, num_epoch, time.time()-epoch_start_time, train_acc, train_auc, train_loss, running_D_loss))
    print(' '*24+'cls  loss: {:3.6f} trip loss: {:3.6f} domain loss: {:3.6f}'\
        .format(cls_loss_total, triplet_loss, domain_loss_total ))
    print(' '*24+'Valid Acc: {:.6f} Valid AUC: {:.6f} Valid  loss: {:3.6f} '\
        .format(val_acc, val_auc, val_loss))

100%|██████████| 165/165 [00:12<00:00, 13.36it/s]

[001/001] 12.35 sec(s) Train Acc: 0.818447 Train AUC: 0.785311 Train  loss: 0.000000 D_loss: 0.000000
                        cls  loss: 0.000000 trip loss: 0.000000 domain loss: 0.000000
                        Valid Acc: 0.890202 Valid AUC: 0.935827 Valid  loss: 0.269006 



