In [1]:
import sys, os, inspect
current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)

In [21]:
import numpy as np
import random
import copy
import math
from tqdm import tqdm
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from martins.complex_transformer import ComplexTransformer
from GAN import Generator, Discriminator
from data_utils import *
import argparse
import logging
import logging.handlers
import pickle

# DataLoader

In [25]:
class JoinDataset(Dataset):
    def __init__(self, root_dir, file_name):
        f = open(os.path.join(root_dir, file_name), "rb")
        dataset = pickle.load(f)
        self.source_x = dataset['tr_data']
        self.source_y = dataset['tr_lbl']
        self.target_x = dataset['te_data']
        self.target_y = dataset['te_lbl']

    def __len__(self):
        return self.target_y.shape[0]

    def __getitem__(self, idx):
        return (source_x[index], source_y[index]), (target_x[index], target_y[index])

# Parser

In [4]:
# Parameters
parser = argparse.ArgumentParser(description='JDA Time series adaptation')
parser.add_argument("--data_path", type=str, default="/projects/rsalakhugroup/complex/domain_adaptation", help="dataset path")
parser.add_argument("--task", type=str, help='3A or 3E')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--epochs', type=int, default=50, help='number of epochs')
parser.add_argument('--lr_gan', type=float, default=1e-4, help='learning rate for adversarial')
parser.add_argument('--lr_clf', type=float, default=1e-4, help='learning rate for classification')
parser.add_argument('--gap', type=int, default=4, help='gap: Generator train GAP times, discriminator train once')
parser.add_argument('--lbl_percentage', type=float, default=0.2, help='percentage of which target data has label')
parser.add_argument('--num_per_class', type=int, default=-1, help='number of sample per class when training local discriminator')
parser.add_argument('--seed', type=int, help='manual seed')
parser.add_argument('--classifier', type=str, help='cnet model file')
parser.add_argument('--save_path', type=str, default='../train_related/JDA_GAN', help='where to store data')
parser.add_argument('--model_save_period', type=int, default=2, help='period in which the model is saved')

args = parser.parse_args()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# seed
if args.seed is None:
    args.seed = random.randint(1, 10000)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
np.random.seed(args.seed)
cudnn.deterministic = True
torch.backends.cudnn.deterministic = True


args.task = '3Av2' if args.task == '3A' else '3E'
d_out = 50 if args.task == "3Av2" else 65
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

if args.num_per_class == -1:
    args.num_per_class = math.ceil(args.batch_size / d_out)
    
model_sub_folder = '/task_%s_gap_%s_lblPer_%i_numPerClass_%i'%(args.task, args.gap, args.lbl_percentage, args.num_per_class)
    


usage: ipykernel_launcher.py [-h] [--data_path DATA_PATH] [--task TASK]
                             [--batch_size BATCH_SIZE] [--epochs EPOCHS]
                             [--lr_gan LR_GAN] [--lr_clf LR_CLF] [--gap GAP]
                             [--lbl_percentage LBL_PERCENTAGE]
                             [--num_per_class NUM_PER_CLASS] [--seed SEED]
                             [--classifier CLASSIFIER] [--save_path SAVE_PATH]
                             [--model_save_period MODEL_SAVE_PERIOD]
ipykernel_launcher.py: error: unrecognized arguments: -f /Users/stevenliu/Library/Jupyter/runtime/kernel-59462722-2028-457e-8e13-a96282293391.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [6]:
# local only
class local_args:
    def __init__(self, **entries):
        self.__dict__.update(entries)
        
args = local_args(**{
    'data_path': '/Users/stevenliu/time-series-adaption/time-series-domain-adaptation/data_unzip',
    'task': '3Av2',
    'num_class': 50,
    'batch_size': 10,
    'num_per_class': -1,
    'gap': 5,
    'lbl_percentage':0.2,
    'lr_gan': 1e-4,
    'lr_FNN': 1e-4,
    'lr_encocer': 1e-4
})

In [6]:
args.task = '3Av2' if args.task == '3A' else '3E'
d_out = 50 if args.task == "3Av2" else 65
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

if args.num_per_class == -1:
    args.num_per_class = math.ceil(args.batch_size / d_out)
    
model_sub_folder = '/task_%s_gap_%s_lblPer_%i_numPerClass_%i'%(args.task, args.gap, args.lbl_percentage, args.num_per_class)
    

# Logger

In [7]:
logger = logging.getLogger()
logger.setLevel(logging.INFO)

file_log_handler = logging.FileHandler('/Users/stevenliu/Downloads/aws/logfile.log')
logger.addHandler(file_log_handler)

stdout_log_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stdout_log_handler)


# Data loading

In [8]:
target_dict, (target_unlabel_x, target_unlabel_y),(a,b), target_len  = get_target_dict(args.data_path+'/processed_file_%s.pkl'%args.task, d_out, args.lbl_percentage)
source_dict, source_len = get_source_dict(args.data_path+'/processed_file_%s.pkl'%args.task, d_out, data_len=target_len)


In [26]:
join_dataset = JoinDataset(args.data_path, 'processed_file_%s.pkl'%args.task)
join_dataloader = DataLoader(join_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)

# GAN Initialize

In [9]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('LayerNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


In [10]:
class FNN(nn.Module):
    def __init__(self, d_in, d_h, d_out, dp):
        super(FNN, self).__init__()
        self.fc1 = nn.Linear(d_in, d_h)
        self.fc2 = nn.Linear(d_h, d_out)
        self.dp = nn.Dropout(dp)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dp(x)
        x = self.fc2(x)

        return x

# model creation

In [12]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

real_label = 0.99 # target domain
fake_label = 0.01 # source domain


seq_len = 10
feature_dim = 160
encoder = ComplexTransformer(layers=1,
                             time_step=seq_len,
                             input_dim=feature_dim,
                             hidden_size=64,
                             output_dim=64,
                             num_heads=8,
                             out_dropout=0.5,
                             leaky_slope=0.2)
encoder.to(device)

CNet = FNN(d_in=feature_dim * 2 * seq_len, d_h=500, d_out=d_out, dp=0.7)
CNet.to(device)

DNet_global = Discriminator(feature_dim=64*20, d_out=d_out).to(device)
DNet_local = Discriminator(feature_dim=64*20, d_out=d_out).to(device)
GNet = Generator(feature_dim=64*20).to(device)
DNet_global.apply(weights_init)
DNet_local.apply(weights_init)
GNet.apply(weights_init)
encoder.apply(weights_init)
FNN.apply(weights_init)
optimizerD_global = torch.optim.Adam(DNet_global.parameters(), lr=args.lr_gan)
optimizerD_local = torch.optim.Adam(DNet_local.parameters(), lr=args.lr_gan)
optimizerG = torch.optim.Adam(GNet.parameters(), lr=args.lr_gan)
optimizerFNN = torch.optim.Adam(FNN.parameters(), lr=args.lr_FNN)
optimizerEncoder = torch.optim.Adam(encoder.parameters(), lr=args.lr_encoder)

AssertionError: embed_dim must be divisible by num_heads

# Train

In [None]:
target_acc_ = []
error_D_global = []
error_G_global = []
error_D_local = []
error_G_local = []

for epoch in range(args.epochs):
    # update classifier
    # TODO: complete this
    for batch_id, (x_batch, y_batch) in tqdm(enumerate()):
    
    # Assign Pesudo Label
    correct_target = 0.0
    target_pesudo_y = []
    for batch in range(math.ceil(target_unlabel_x.shape[0]/args.batch_size)):
        target_unlabel_x_batch = torch.Tensor(target_unlabel_x[batch*args.batch_size:(batch+1)*args.batch_size], device=device).to(device).float()
        target_unlabel_y_batch = torch.Tensor(target_unlabel_y[batch*args.batch_size:(batch+1)*args.batch_size], device=device)
        pred = classifier_inference(encoder, CNet, target_unlabel_x_batch, target_mean, target_std)
        correct_target += (pred.argmax(-1) == target_unlabel_y_batch.argmax(-1)).sum().item()
        target_pesudo_y.extend(pred.argmax(-1).numpy())
        
    target_pesudo_y = np.array(target_pesudo_y)
    pesudo_dict = get_class_data_dict(target_unlabel_x, target_pesudo_y, num_class)
        
    for batch in range(math.ceil(target_label_x.shape[0]/args.batch_size)):
        target_label_x_batch = torch.Tensor(target_label_x[batch*args.batch_size:(batch+1)*args.batch_size], device=device).to(device).float()
        target_label_y_batch = torch.Tensor(target_label_y[batch*args.batch_size:(batch+1)*args.batch_size], device=device)
        pred = classifier_inference(encoder, CNet, target_label_x_batch, target_mean, target_std)
        correct_target += (pred.argmax(-1) == target_label_y_batch.argmax(-1)).sum().item()

    logger.info('Epoch: %i, assigned pesudo label with accuracy %f'%(epoch+1, correct_target/(target_unlabel_x.size(0)+target_label_x.size(0)))
    target_acc_.append(correct_target/(target_unlabel_x.size(0)+target_label_x.size(0)))
                
    # Update GAN
    # Update global Discriminator
    total_error_D_global = 0
    total_error_G = 0
    for batch_id, ((source_x, source_y), (target_x, target_y)) in tqdm(enumerate(joint_dataloader)):
        batch_size = target_x.shape[0]
        target_x = target_x.reshape(batch_size, -1)
        source_x = source_x.reshape(batch_size, -1)

        """Update D Net"""
        optimizerD_global.zero_grad()
        source_data = source_x.to(device).float()
        source_embedding = encoder(source_data)
        target_data = source_y.to(device).float()
        target_embedding = encoder(target_data)
        fake_source_embedding = GNet(target_embedding).detach()
        
        # adversarial loss
        loss_D_global = DNet_global(fake_source_embedding).mean() - DNet_global(source_embedding).mean()
        
        total_error_D_global += loss_D_global.item()
        
        loss_D_global.backward()
        optimizerD_global.step()
        
        # Clip weights of discriminator
        for p in DNet_global.parameters():
            p.data.clamp_(-args.clip_value, args.clip_value)
        
        if batch_id % args.n_critic == 0:
            """Update G Network"""
            optimizerG.zero_grad()
            fake_source_embedding = GNet(target_embedding)
            
            # adversarial loss
            loss_G = -DNet_global(fake_source_embedding).mean()
            
            total_error_G += loss_G.item()
            
            loss_G.step()
            optimizerG.step()
            
    logger.info('Epoch: %i, Global Discrimator Updates: Loss D_global: %f, Loss G: %f'%(epoch+1, total_error_D_global, total_error_G))
    error_D_global.append(total_error_D_global)
    error_G_global.append(total_error_G)
                
    # Update local Discriminator
    total_error_D_local = 0
    total_error_G = 0
    for batch_id in tqdm(range(math.ceil(target_len/args.batch_size))):
        target_x, target_y, target_weight = get_batch_target_data_on_class(target_dict, pesudo_dict, target_unlabel_x, args.num_per_class)
        source_x, source_y = get_batch_source_data_on_class(source_dict, args.num_per_class)
        
        source_x = torch.Tensor(source_x, device=device)
        target_x = torch.Tensor(target_x, device=device)
        source_y = torch.LongTensor(target_y, device=device)
        target_y = torch.LongTensor(target_y, device=device)
        source_mask = torch.zeros(source_x.size(0), d_out).scatter_(1, source_y.unsqueeze(-1), 1)
        target_mask = torch.zeros(target_x.size(0), d_out).scatter_(1, target_y.unsqueeze(-1), 1)
        target_weight = torch.zeros(target_x.size(0), d_out).scatter_(1, target_y.unsqueeze(-1), target_weight.unsqueeze(-1))
        
        torch.Tensor(target_weight, device=device)
        batch_size = target_x.shape[0]
        
        source_x = source_x.reshape(source_x.size(0), -1)
        target_x = target_x.reshape(target_x.size(0), -1
    
        """Update D Net"""
        optimizerD_local.zero_grad()
        source_data = source_x.to(device).float()
        source_embedding = encoder(source_data)
        target_data = source_y.to(device).float()
        target_embedding = encoder(target_data)
        fake_source_embedding = GNet(target_embedding).detach()
        
        # adversarial loss
        source_DNet_local = DNet_local(source_embedding, source_mask)
        target_DNet_local = DNet_local(fake_source_embedding, target_mask)
        
        source_weight_count = source_mask.sum(dim=0)
        target_weight_count = target_weight.sum(dim=0)
        
        source_DNet_local_mean = source_DNet_local.sum(dim=0) / source_weight_count
        target_DNet_local_mean = (target_DNet_local * target_weight).sum(dim=0) / target_weight_count        
        
        loss_D_local = (target_DNet_local_mean - source_DNet_local_mean).sum()
        
        total_error_D_local += loss_D_local.item()
        
        loss_D_local.backward()
        optimizerD_local.step()
        
        # Clip weights of discriminator
        for p in DNet_local.parameters():
            p.data.clamp_(-args.clip_value, args.clip_value)
        
        if batch_id % args.n_critic == 0:
            """Update G Network"""
            optimizerG.zero_grad()
            fake_source_embedding = GNet(target_embedding)
            
            # adversarial loss
            loss_G = -DNet_local(fake_source_embedding).mean()
            
            total_error_G += loss_G.item()
            
            loss_G.step()
            optimizerG.step()
            
    logger.info('Epoch: %i, Local Discrimator Updates: Loss D_global: %f, Loss G: %f'%(epoch+1, total_error_D_local, total_error_G))
    error_D_local.append(total_error_D_local)
    error_G_global.append(total_error_G)
