In [8]:
import sys, os, inspect
current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)

In [15]:
import numpy as np
import random
import copy
import math
from tqdm import tqdm
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from martins.complex_transformer import ComplexTransformer
from GAN import Generator, Discriminator
from data_utils import *
import argparse
import logging
import logging.handlers

# DataLoader

In [44]:
def JoinDataset(Dataset):
    def __init__(self, source_x, source_y, target_x, target_y):
        super(JoinDataset, self).__init__()
        self.source_x = source_x
        self.source_y = source_y
        self.target_x = target_x
        self.target_y = target_y
    
    def __len__(self):
        return self.target_y.shape[0]
    
    def __getitem__(self, index):
        return (source_x[index], source_y[index]), (target_x[index], target_y[index])

# Parser

In [10]:
# Parameters
parser = argparse.ArgumentParser(description='JDA Time series adaptation')
parser.add_argument("--data_path", type=str, default="/projects/rsalakhugroup/complex/domain_adaptation", help="dataset path")
parser.add_argument("--task", type=str, help='3A or 3E')
parser.add_argument('--batch_size', type=int, default=256, help='batch size')
parser.add_argument('--epochs', type=int, default=50, help='number of epochs')
parser.add_argument('--lr_gan', type=float, default=1e-4, help='learning rate for adversarial')
parser.add_argument('--lr_clf', type=float, default=1e-4, help='learning rate for classification')
parser.add_argument('--gap', type=int, default=4, help='gap: Generator train GAP times, discriminator train once')
parser.add_argument('--lbl_percentage', type=float, default=0.2, help='percentage of which target data has label')
parser.add_argument('--num_per_class', type=int, default=-1, help='number of sample per class when training local discriminator')
parser.add_argument('--seed', type=int, help='manual seed')
parser.add_argument('--classifier', type=str, help='cnet model file')
parser.add_argument('--save_path', type=str, default='../train_related/JDA_GAN', help='where to store data')
parser.add_argument('--model_save_period', type=int, default=2, help='period in which the model is saved')

args = parser.parse_args()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# seed
if args.seed is None:
    args.seed = random.randint(1, 10000)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
np.random.seed(args.seed)
cudnn.deterministic = True
torch.backends.cudnn.deterministic = True


args.task = '3Av2' if args.task == '3A' else '3E'
d_out = 50 if args.task == "3Av2" else 65
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

if args.num_per_class == -1:
    args.num_per_class = math.ceil(args.batch_size / d_out)
    
model_sub_folder = '/task_%s_gap_%s_lblPer_%i_numPerClass_%i'%(args.task, args.gap, args.lbl_percentage, args.num_per_class)
    


usage: ipykernel_launcher.py [-h] [--data_path DATA_PATH] [--task TASK]
                             [--batch_size BATCH_SIZE] [--epochs EPOCHS]
                             [--lr_gan LR_GAN] [--lr_clf LR_CLF] [--gap GAP]
                             [--lbl_percentage LBL_PERCENTAGE]
                             [--num_per_class NUM_PER_CLASS] [--seed SEED]
                             [--classifier CLASSIFIER] [--save_path SAVE_PATH]
                             [--model_save_period MODEL_SAVE_PERIOD]
ipykernel_launcher.py: error: unrecognized arguments: -f /Users/stevenliu/Library/Jupyter/runtime/kernel-a776d127-4b7f-40e6-b44a-7f61be33204a.json


SystemExit: 2

# Logger

In [41]:
logger = logging.getLogger()
logger.setLevel(logging.INFO)

file_log_handler = logging.FileHandler('/Users/stevenliu/Downloads/aws/logfile.log')
logger.addHandler(file_log_handler)

stdout_log_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stdout_log_handler)


In [None]:
target_dict, (target_unlabel_x, target_unlabel_y),(a,b), target_len  = get_target_dict(args.data_path+'processed_file_%s.pkl'%args.task, d_out, args.lbl_percentage)
source_dict, source_len = get_source_dict(args.data_path+'/processed_file_%s.pkl'%args.task, d_out, data_len=target_len)


# GAN Initialize

In [13]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('LayerNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


In [None]:
real_label = 0.99 # target domain
fake_label = 0.01 # source domain

encoder = ComplexTransformer


feature_dim_joint = 2 * feature_dim
DNet_global = Discriminator(feature_dim=feature_dim_joint, d_out=d_out).to(device)
DNet_local = Discriminator(feature_dim=feature_dim_joint, d_out=d_out).to(device)
GNet = Generator(feature_dim=feature_dim_joint).to(device)
DNet_global.apply(weights_init)
DNet_local.apply(weights_init)
GNet.apply(weights_init)
optimizerD_global = torch.optim.Adam(DNet_global.parameters(), lr=args.lr_gan)
optimizerD_local = torch.optim.Adam(DNet_local.parameters(), lr=args.lr_gan)
optimizerG = torch.optim.Adam(GNet.parameters(), lr=args.lr_gan)

# Train

In [45]:
join_dataset = JoinDataset(source_x, source_y, target_x, target_y)
join_dataloader = DataLoader(join_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_worker=4)

NameError: name 'source_x' is not defined

In [None]:
for epoch in range(args.epochs):
    # update classifier
    # TODO: complete this
    for batch_id, (x_batch, y_batch) in tqdm(enumerate())
    
    # Assign Pesudo Label
    correct_target = 0.0
    target_pesudo_y = []
    for batch in range(math.ceil(target_unlabel_x.shape[0]/args.batch_size)):
        target_unlabel_x_batch = torch.Tensor(target_unlabel_x[batch*args.batch_size:(batch+1)*args.batch_size], device=device).to(device).float()
        target_unlabel_y_batch = torch.Tensor(target_unlabel_y[batch*args.batch_size:(batch+1)*args.batch_size], device=device)
        pred = classifier_inference(encoder, CNet, target_unlabel_x_batch, target_mean, target_std)
        correct_target += (pred.argmax(-1) == target_unlabel_y_batch.argmax(-1)).sum().item()
        target_pesudo_y.extend(pred.argmax(-1).numpy())

    target_pesudo_y = np.array(target_pesudo_y)
    pesudo_dict = get_class_data_dict(target_unlabel_x, target_pesudo_y, num_class)
    logger.info('Epoch: %i, assigned pesudo label with accuracy %f'%(epoch+1, correct_target/target_unlabel_x.shape[0]))
    
    # Update GAN
    # Update global Discriminator
    total_error_D_global = 0
    total_error_G = 0
    # TODO: write/check a dataloader to return source and target seperately
    for batch_id, ((source_x, source_y), (target_x, target_y)) in tqdm(enumerate(joint_dataloader)):
        batch_size = target_x.shape[0]
        # TODO: why???
#         target_x = target_x.reshape(batch_size, seq_len, feature_dim_joint)
#         source_x = source_x.reshape(batch_size, seq_len, feature_dim_joint)

#         # Data Normalization
#         target_x = (target_x - target_mean) / target_std
#         source_x = (source_x - source_mean) / source_std

        """Update D Net"""
        optimizerD_global.zero_grad()
        source_data = source_x.to(device).float()
        source_embedding = encoder(source_data)
        target_data = source_y.to(device).float()
        target_embedding = encoder(target_data)
        fake_source_embedding = GNet(target_embedding).detach()
        
        # adversarial loss
        loss_D_global = DNet_global(fake_source_embedding).mean() - DNet_global(source_embedding).mean()
        
        total_error_D_global += loss_D_global.item()
        
        loss_D_global.backward()
        optimizerD_global.step()
        
        # Clip weights of discriminator
        for p in DNet_global.parameters():
            p.data.clamp_(-args.clip_value, args.clip_value)
        
        if batch_id % args.n_critic == 0:
            """Update G Network"""
            optimizerG.zero_grad()
            fake_source_embedding = GNet(target_embedding)
            
            # adversarial loss
            loss_G = -DNet_global(fake_source_embedding).mean()
            
            total_error_G += loss_G.item()
            
            loss_G.step()
            optimizerG.step()
            
    logger.info('Epoch: %i, Global Discrimator Updates: Loss D_global: %f, Loss G: %f'%(epoch+1, total_error_D_global, total_error_G))
    
    # Update local Discriminator
    for batch_id in tqdm(range(math.ceil(target_len/args.batch_size))):
        target_x, target_y, target_weight = get_batch_target_data_on_class(target_dict, pesudo_dict, target_unlabel_x, args.num_per_class)
        source_x, source_y = get_batch_source_data_on_class(source_dict, args.num_per_class)
        
        source_x = torch.Tensor(source_x, device=device)
        target_x = torch.Tensor(target_x, device=device)
        source_y = torch.LongTensor(target_y, device=device)
        target_y = torch.LongTensor(target_y, device=device)
        source_mask = torch.zeros(source_x.size(0), d_out).scatter_(1, source_y, 1)
        target_mask = torch.zeros(target_x.size(0), d_out).scatter_(1, target_y, 1)
        target_weight = torch.Tensor(target_weight, device=device)
        batch_size = target_x.shape[0]
        
        #TODO: why???
        source_x = source_x.reshape(source_x.size(0), seq_len, feature_dim_joint)
        target_x = target_x.reshape(target_x.size(0), seq_len, feature_dim_joint)
        
#         # Data Normalization
#         target_x = (target_x - target_mean) / target_std
#         source_x = (source_x - source_mean) / source_std
        
    
    
        """Update D Net"""
        optimizerD_local.zero_grad()
        source_data = source_x.to(device).float()
        source_embedding = encoder(source_data)
        target_data = source_y.to(device).float()
        target_embedding = encoder(target_data)
        fake_source_embedding = GNet(target_embedding).detach()
        
        # adversarial loss
        source_DNet_local = DNet_local(source_embedding, source_mask)
        target_DNet_local = DNet_local(fake_source_embedding, target_mask)
        
        source_mask_count = source_mask.sum(dim=0)
        target_mask_count = target_mask.sum(dim=0)
        
        source_DNet_local_mean = source_DNet_local.sum(dim=0) / source_mask_count
        target_DNet_local_mean = target_DNet_local.sum(dim=0) / target_mask_count        
        
        loss_D_local = target_DNet_local_mean - source_DNet_local_mean
        
        total_error_D_local += loss_D_local.item()
        
        loss_D_local.backward()
        optimizerD_local.step()
        
        # Clip weights of discriminator
        for p in DNet_local.parameters():
            p.data.clamp_(-args.clip_value, args.clip_value)
        
        if batch_id % args.n_critic == 0:
            """Update G Network"""
            optimizerG.zero_grad()
            fake_source_embedding = GNet(target_embedding)
            
            # adversarial loss
            loss_G = -DNet_local(fake_source_embedding).mean()
            
            total_error_G += loss_G.item()
            
            loss_G.step()
            optimizerG.step()
            
    logger.info('Epoch: %i, Local Discrimator Updates: Loss D_global: %f, Loss G: %f'%(epoch+1, total_error_D_global, total_error_G))
    
    
    
    
    
    
    
    
        """Update D Net"""
        # train with source domain
        DNet_local.zero_grad()
        source_data = source_x.to(device).float()
        label = torch.full((batch_size,), real_label, device=device)
        output = DNet_local(source_data).view(-1)
        #print(output.mean().item())
        errD_local_source = criterion_gan_local(output, label).mean()
        errD_local_source.backward()

        # train with target domain
        target_data = target_x.to(device).float()
        fake = GNet(target_data)
        #print(fake)
        label.fill_(fake_label)
        output = DNet_local(fake.detach()).view(-1)
        errD_local_target = criterion_gan_local(output, label)
        errD_local_target = (errD_local_target * target_weight).mean()
        errD_local_target.backward()
        total_error_D_local += (errD_local_source + errD_local_target).item()
        
        if batch_id % args.gap == 0:
            optimizerD_local.step()

        """Update G Network"""
        GNet.zero_grad()
        label.fill_(real_label) # fake labels are real for generator cost
        output = DNet_local(fake).view(-1)

        errG = criterion_gan_local(output, label)
        errG = (errG * target_weight).mean()
        errG.backward()
        optimizerG.step()
        total_error_G += errG.item()



In [81]:
a

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])