In [1]:
# encoding=utf-8

import data_preprocess
import matplotlib.pyplot as plt
import transferlearningnetwork as net
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import math
import tqdm
import maximum_mean_discrepancy

BATCH_SIZE = 64
N_EPOCH = 10
lr = 0.01
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
result = []
LAMBDA = 0.25
l2_decay = 5e-4
momentum = 0.9
GAMMA = 10 ^ 3

In [2]:
def mmd_loss(x_src, x_tar):
    return maximum_mean_discrepancy.mmd_rbf_noaccelerate(x_src, x_tar)

def train(model, source_loader, target_loader):
    
    n_batch = len(source_loader.dataset) // BATCH_SIZE
    criterion = nn.CrossEntropyLoss()
    iter_target = iter(target_loader)
    len_target_loader = len(target_loader)
    for e in range(N_EPOCH):
        LEARNING_RATE = lr / math.pow((1 + 10 * e / N_EPOCH), 0.75)
        optimizer = optim.SGD(params=model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=l2_decay)
        
        model.train()
        correct, total_loss = 0, 0
        total = 0
        
        for index, (sample, target) in enumerate(source_loader):
            
            data_target, label_target = iter_target.next()
            if index % (len_target_loader-1) == 0:
                iter_target = iter(target_loader)
            
            sample, target = sample.to(DEVICE).float(), target.to(DEVICE).long()
            sample = sample.view(-1, 3, 1, 200)
            data_target, label_target = data_target.to(DEVICE).float(), label_target.to(DEVICE).long()
            data_target = data_target.view(-1, 3, 1, 200)
            
            optimizer.zero_grad()
            y_src, x_src_mmd, x_tar_mmd = model(sample, data_target)
            loss_c = criterion(y_src, target)
            loss_mmd = mmd_loss(x_src_mmd, x_tar_mmd)
            _, predicted = torch.max(y_src.data, 1)
            correct += (predicted == target).sum()
            
            loss_cls = F.nll_loss(F.log_softmax(y_src, dim=1), target)
            gamma = 2 / (1 + math.exp(-10 * (e) / N_EPOCH)) - 1
            loss = loss_cls + gamma * loss_mmd
            #loss = loss_c + LAMBDA * loss_mmd
            
            loss.backward()
            optimizer.step()
            total_loss += loss.data
            total += target.size(0)
            
            if index % 20 == 0:
                tqdm.tqdm.write('Epoch: [{}/{}], Batch: [{}/{}], loss:{:.4f}'.format(e + 1, N_EPOCH, index + 1, n_batch, loss.data))

    acc_train = float(correct) * 100.0 / (BATCH_SIZE * n_batch)
    tqdm.tqdm.write('Epoch: [{}/{}], Total loss: {:.4f}, train acc: {:.2f}%'.format(e + 1, N_EPOCH, total_loss * 1.0 / n_batch, acc_train))
    return model
    

In [3]:
def test(model, target_loader):
    total_loss_test = 0
    criterion = nn.CrossEntropyLoss()
    model.eval()
    
    with torch.no_grad():
        correct, total = 0, 0
        for index, (sample, target) in enumerate(target_loader):
            sample, target = sample.to(DEVICE).float(), target.to(DEVICE).long()
            sample = sample.view(-1, 3, 1, 200)
            output, _, _  = model(sample, sample)
            loss = criterion(output, target)
            #print('loss : ', loss)
            _, predicted = torch.max(output.data, 1)
            total_loss_test += loss.data
            total += target.size(0)
            correct += (predicted == target).sum()
            
    acc_test = float(correct) * 100 / total
    tqdm.tqdm.write('Test: loss: {:.6f}, correct: [{}/{}], test acc: {:.4f}%'.format(loss, correct, total, acc_test))

In [4]:
if __name__ == '__main__':
    torch.manual_seed(10)
    source_loader, target_loader = data_preprocess.load(batch_size=BATCH_SIZE)
    model = net.Network().to(DEVICE)
    model = train(model, source_loader, target_loader)
    dummysource_loader, test_loader = data_preprocess.load(batch_size=BATCH_SIZE)
    test(model, test_loader)

['Jogging', 'Sit_Down', 'Skip', 'Stand_Up', 'Stay', 'Walk']
Hasc_X_train.shape :  (9832, 3, 1, 200)
Hasc_Y_train.shape :  (9832,)
Hasc_Y_train values :  [0 0 0 ... 1 1 1]
Wisdm_X_train.shape :  (38319, 3, 1, 200)
Wisdm_Y_train shape :  (38319,)
Wisdm_Y_train values :  [0 0 0 ... 1 1 1]
Epoch: [1/10], Batch: [1/153], loss:0.6943
Epoch: [1/10], Batch: [21/153], loss:0.6228
Epoch: [1/10], Batch: [41/153], loss:0.5109
Epoch: [1/10], Batch: [61/153], loss:0.3008
Epoch: [1/10], Batch: [81/153], loss:0.2106
Epoch: [1/10], Batch: [101/153], loss:0.2374
Epoch: [1/10], Batch: [121/153], loss:0.0740
Epoch: [1/10], Batch: [141/153], loss:0.0484
Epoch: [2/10], Batch: [1/153], loss:1.2106
Epoch: [2/10], Batch: [21/153], loss:0.7577
Epoch: [2/10], Batch: [41/153], loss:0.7529
Epoch: [2/10], Batch: [61/153], loss:0.7154
Epoch: [2/10], Batch: [81/153], loss:0.7305
Epoch: [2/10], Batch: [101/153], loss:0.7176
Epoch: [2/10], Batch: [121/153], loss:0.6555
Epoch: [2/10], Batch: [141/153], loss:0.6828
Epoch