In [69]:
import numpy as np
import torch
from torch.utils import data
from torch import nn
import torch.optim as optim
from tqdm import tqdm
import argparse


In [2]:
class TimeSeriesChunkDataset(data.Dataset):
    def __init__(self, x, y, context):
        super(TimeSeriesChunkDataset, self).__init__()
        self.x = x
        self.y = y
        self.context = context
        self.points_per_series = self.x.shape[1] - self.context + 1
        
    def __len__(self):
        return self.x.shape[0] * self.points_per_series

    def __getitem__(self, index):
        index_series = index // self.points_per_series
        index_point = index % self.points_per_series
        return_x = self.x[index_series, index_point:index_point+self.context, :]
        return_y = np.argmax(self.y[index_series, :])
        return return_x, return_y

In [3]:
class Classifier(nn.Module):
    def __init__(self, layers_size, dim_out, dropout=0.3):
        super(Classifier, self).__init__()
        self.layers = []
        self.activs = []
        self.dropouts = []
        self.layers_size = layers_size
        
        for i in range(len(layers_size)-1):
            self.layers.append(nn.Linear(layers_size[i], layers_size[i+1]))
            self.activs.append(nn.ReLU())
            if i < len(layers_size)-2:
                self.dropouts.append(nn.Dropout(p=dropout))

        self.layers.append(nn.Linear(layers_size[-1], dim_out))
        self.nlayer = len(self.layers)
        self.nactivs = len(self.activs)
        self.ndropouts = len(self.dropouts)

        self.layers = nn.ModuleList(self.layers)
        
    def forward(self, x):
        out = x
        for i in range(self.nlayer):
            out = self.layers[i](out)
            if i < self.nactivs:
                out = self.activs[i](out)
            if i < self.ndropouts:
                out = self.dropouts[i](out)

        return out


In [4]:
class Discriminator(nn.Module):
    '''
    credit: from https://github.com/martinmamql/complex_da
    '''
    def __init__(self, feature_dim, d_out):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(

            nn.Linear(feature_dim, feature_dim),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Linear(feature_dim, feature_dim),
            nn.LayerNorm(feature_dim),
            nn.LeakyReLU(0.2, inplace=True),
        ) 
        self.fc = nn.Linear(3200, 1)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        # x: [bs, seq, feature_dim]
        x = self.net(x)
        bs = x.shape[0]
        x = x.reshape(bs, -1)
        out = self.sigmoid(self.fc(x))
        return out

In [102]:
class Generator(nn.Module):
    def __init__(self, **kwargs):
        super(Generator, self).__init__()
        self.transformer_layer_args = {'d_model':2, 'nhead':1, 'dim_feedforward':1024, 'dropout':0.1, 'activation':'gelu'}
        self.transformer_args = {'num_layers':3, 'norm':None}
        self.transformer_layer_args.update(kwargs['transformer_layer'])
        self.transformer_args.update(kwargs['transformer'])
        
        self.transformer_layer = nn.TransformerEncoderLayer(**self.transformer_layer_args)
        self.transformer = nn.TransformerEncoder(self.transformer_layer, **self.transformer_args)
        
    def forward(self, x):
        out = self.transformer(x)
        return out
        

In [139]:
class SourceDomainClassifier(nn.Module):
    def __init__(self, **kwargs):
        super(SourceDomainClassifier, self).__init__()
        self.net = nn.Sequential(
            Generator(**(kwargs["generator"])),
            nn.Flatten(),
            Classifier(**(kwargs['classifier']))
        )
        
    def forward(self, x):
        x = self.net(x)
        return x
        

In [148]:
def get_accuracy(preds, label):
    class_preds = torch.argmax(preds, dim=1)
    correct = torch.eq(class_preds, label).float()
    acc = correct.sum()
    return acc

def inference(model, dataloader, device, testing=False):
    model.eval()

    if testing: 
        result = []
    else:
        inference_acc = 0.0

    num_data = 0.0
    with torch.no_grad():
        for key, (x_batch, y_batch) in enumerate(dataloader):
            num_data += y_batch.shape[0]
            x_batch = x_batch.to(device)
            y_batch = y_batch.long().to(device)
            preds = model(x_batch)
            if testing:
                result.extend(torch.argmax(nn.functional.softmax(preds, dim=1), dim=1).cpu().numpy())
            else:
                inference_acc += get_accuracy(preds, y_batch.squeeze_()).item()
    

    if testing:
        return result
    else:
        return inference_acc / num_data
    
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.00)
    

In [143]:
def train(model, train_dataloader, vali_dataloader, lr, n_epochs, device):
    optimizer = optim.Adam(model.parameters(), lr=lr)
#     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.7, patience=3)
    loss_fn = nn.CrossEntropyLoss()
    train_loss_ = []
    train_acc_ = []
    vali_acc_ = []
    
    for epoch in range(1, n_epochs+1):
        train_loss = 0.0
        train_acc = 0.0

        num_data = 0.0
        num_batches = len(train_dataloader)
        for batches, (x_batch, y_batch) in tqdm(enumerate(train_dataloader), total=num_batches):
            model.train()
            num_data += y_batch.shape[0]
            x_batch = x_batch.to(device)
            y_batch = y_batch.long().to(device)
            optimizer.zero_grad()
            preds = model(x_batch)
            loss = loss_fn(preds, y_batch.squeeze_())
            acc = get_accuracy(preds, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_acc += acc.item()
            
            if batches > 1 and batches % int(num_batches/3) == 0:
                vali_acc = inference(model, vali_dataloader, device)
                print("validation_acc: ", vali_acc)

        vali_acc = inference(model, vali_dataloader, device)
        train_loss_.append(train_loss/num_data)
        train_acc_.append(train_acc/num_data)
        vali_acc_.append(vali_acc)
        # scheduler.step(vali_acc)
        name = "model_" + str(epoch) + ".t7"   
        np.save("train_loss_.npy",train_loss_)
        np.save("train_acc_.npy",train_acc_)
        np.save("vali_acc_.npy",vali_acc_)
        torch.save(model.state_dict(), name)

        print("epoch {}: train_loss: {}, train_acc: {}, vali_acc: {}".format(epoch, train_loss/num_data, train_acc/num_data, vali_acc))


In [146]:
#local only

class fake_args():
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

        
args = fake_args(data_path='./data_unzip/', 
                 task='3A', 
                 batch_size=30,
                 epochs=2,
                 lr=1e-3,
                 context=10)

args.task = "processed_file_3Av2.pkl" if args.task == "3A" else "processed_file_3E.pkl"
args.data_path = args.data_path + args.task

In [149]:
if __name__ == "__main__":
    cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    data_dict = np.load(args.data_path, allow_pickle=True)
    
    # TODO: to be commented
    data_dict['tr_data'] = data_dict['tr_data'][:10]
    data_dict['tr_data'] = data_dict['tr_data'][:10]
    
    # split train data and validation data
    np.random.seed(seed=0)
    indices = np.random.permutation(data_dict['tr_data'].shape[0])
    train_x = data_dict['tr_data'][indices[:int(indices.shape[0]*0.9)],:,:].astype("float32")
    train_y = data_dict['tr_lbl'][indices[:int(indices.shape[0]*0.9)],:].astype("float32")
    vali_x = data_dict['tr_data'][indices[int(indices.shape[0]*0.9):],:,:].astype("float32")
    vali_y = data_dict['tr_lbl'][indices[int(indices.shape[0]*0.9):],:].astype("float32")

    
    train_dataset = TimeSeriesChunkDataset(train_x, train_y, args.context)
    train_dataloader = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)
    vali_dataset = TimeSeriesChunkDataset(vali_x, vali_y, args.context)
    vali_dataloader = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)
    
    model_args = {
        'classifier': {
            'layers_size': [args.context*2, 100, 100], 
            'dim_out': 50 if args.task == '3A' else 65
        },
        'generator':{
            'transformer_layer': {},
            'transformer': {}
        }
    }
    
    
    
    source_classifier = SourceDomainClassifier(**model_args)
    source_classifier.apply(init_weights)
    
    train(source_classifier, train_dataloader, vali_dataloader, args.lr, args.epochs, device)
    



  0%|          | 0/478 [00:00<?, ?it/s][A
  1%|          | 4/478 [00:00<00:12, 37.79it/s][A
  2%|▏         | 10/478 [00:00<00:11, 41.13it/s][A
  3%|▎         | 15/478 [00:00<00:10, 42.67it/s][A
  4%|▍         | 20/478 [00:00<00:10, 43.94it/s][A
  5%|▌         | 25/478 [00:00<00:10, 44.99it/s][A
  6%|▋         | 30/478 [00:00<00:09, 46.18it/s][A
  8%|▊         | 36/478 [00:00<00:09, 47.59it/s][A
  9%|▉         | 42/478 [00:00<00:09, 48.27it/s][A
 10%|▉         | 47/478 [00:00<00:09, 47.05it/s][A
 11%|█         | 53/478 [00:01<00:08, 48.15it/s][A
 12%|█▏        | 59/478 [00:01<00:08, 49.19it/s][A
 13%|█▎        | 64/478 [00:01<00:08, 49.34it/s][A
 15%|█▍        | 70/478 [00:01<00:08, 49.84it/s][A
 16%|█▌        | 75/478 [00:01<00:08, 49.39it/s][A
 17%|█▋        | 80/478 [00:01<00:08, 48.72it/s][A
 18%|█▊        | 85/478 [00:01<00:08, 48.52it/s][A
 19%|█▉        | 91/478 [00:01<00:07, 49.53it/s][A
 20%|██        | 96/478 [00:01<00:07, 49.63it/s][A
 21%|██▏       | 102/

validation_acc:  0.3329841469376353



 36%|███▌      | 172/478 [00:05<00:20, 15.07it/s][A
 37%|███▋      | 177/478 [00:05<00:15, 19.05it/s][A
 38%|███▊      | 182/478 [00:05<00:12, 23.00it/s][A
 39%|███▉      | 187/478 [00:05<00:10, 26.79it/s][A
 40%|████      | 192/478 [00:05<00:09, 30.79it/s][A
 41%|████      | 197/478 [00:05<00:08, 34.55it/s][A
 42%|████▏     | 202/478 [00:06<00:07, 37.61it/s][A
 43%|████▎     | 207/478 [00:06<00:06, 39.77it/s][A
 44%|████▍     | 212/478 [00:06<00:06, 41.12it/s][A
 45%|████▌     | 217/478 [00:06<00:06, 42.83it/s][A
 46%|████▋     | 222/478 [00:06<00:05, 44.39it/s][A
 47%|████▋     | 227/478 [00:06<00:05, 44.95it/s][A
 49%|████▊     | 232/478 [00:06<00:05, 44.72it/s][A
 50%|████▉     | 237/478 [00:06<00:05, 45.40it/s][A
 51%|█████     | 242/478 [00:06<00:05, 46.28it/s][A
 52%|█████▏    | 247/478 [00:06<00:04, 47.07it/s][A
 53%|█████▎    | 252/478 [00:07<00:04, 47.65it/s][A
 54%|█████▍    | 257/478 [00:07<00:04, 47.25it/s][A
 55%|█████▍    | 262/478 [00:07<00:04, 46.21i

validation_acc:  0.3674139255534604



 69%|██████▉   | 332/478 [00:10<00:11, 12.90it/s][A
 71%|███████   | 337/478 [00:10<00:08, 16.50it/s][A
 72%|███████▏  | 342/478 [00:10<00:06, 20.56it/s][A
 73%|███████▎  | 347/478 [00:11<00:05, 24.78it/s][A
 74%|███████▎  | 352/478 [00:11<00:04, 28.93it/s][A
 75%|███████▍  | 357/478 [00:11<00:03, 32.43it/s][A
 76%|███████▌  | 362/478 [00:11<00:03, 35.40it/s][A
 77%|███████▋  | 367/478 [00:11<00:02, 37.64it/s][A
 78%|███████▊  | 372/478 [00:11<00:02, 39.49it/s][A
 79%|███████▉  | 377/478 [00:11<00:02, 40.26it/s][A
 80%|███████▉  | 382/478 [00:11<00:02, 40.84it/s][A
 81%|████████  | 387/478 [00:11<00:02, 41.43it/s][A
 82%|████████▏ | 392/478 [00:12<00:02, 41.74it/s][A
 83%|████████▎ | 397/478 [00:12<00:01, 43.05it/s][A
 84%|████████▍ | 402/478 [00:12<00:01, 44.36it/s][A
 85%|████████▌ | 407/478 [00:12<00:01, 45.11it/s][A
 86%|████████▌ | 412/478 [00:12<00:01, 45.60it/s][A
 87%|████████▋ | 417/478 [00:12<00:01, 44.93it/s][A
 88%|████████▊ | 422/478 [00:12<00:01, 44.47i

validation_acc:  0.3686709965779733





epoch 1: train_loss: 0.059469761641317426, train_acc: 0.3227180669041134, vali_acc: 0.37202318597667433



  0%|          | 0/478 [00:00<?, ?it/s][A
  1%|          | 4/478 [00:00<00:11, 39.57it/s][A
  2%|▏         | 9/478 [00:00<00:11, 41.30it/s][A
  3%|▎         | 14/478 [00:00<00:10, 42.58it/s][A
  4%|▍         | 19/478 [00:00<00:10, 43.44it/s][A
  5%|▌         | 24/478 [00:00<00:10, 44.07it/s][A
  6%|▌         | 29/478 [00:00<00:10, 44.75it/s][A
  7%|▋         | 34/478 [00:00<00:09, 45.26it/s][A
  8%|▊         | 39/478 [00:00<00:09, 45.40it/s][A
  9%|▉         | 44/478 [00:00<00:09, 45.88it/s][A
 10%|█         | 49/478 [00:01<00:09, 45.83it/s][A
 11%|█▏        | 54/478 [00:01<00:09, 45.58it/s][A
 12%|█▏        | 59/478 [00:01<00:09, 45.60it/s][A
 13%|█▎        | 64/478 [00:01<00:09, 45.86it/s][A
 14%|█▍        | 69/478 [00:01<00:08, 46.21it/s][A
 15%|█▌        | 74/478 [00:01<00:08, 46.24it/s][A
 17%|█▋        | 79/478 [00:01<00:08, 46.38it/s][A
 18%|█▊        | 84/478 [00:01<00:08, 46.24it/s][A
 19%|█▊        | 89/478 [00:01<00:08, 46.42it/s][A
 20%|█▉        | 94/47

validation_acc:  0.3845938962218032



 36%|███▋      | 174/478 [00:05<00:24, 12.61it/s][A
 37%|███▋      | 179/478 [00:05<00:18, 16.11it/s][A
 38%|███▊      | 184/478 [00:05<00:14, 19.75it/s][A
 40%|███▉      | 189/478 [00:06<00:12, 23.29it/s][A
 41%|████      | 194/478 [00:06<00:10, 27.27it/s][A
 42%|████▏     | 199/478 [00:06<00:09, 30.95it/s][A
 43%|████▎     | 204/478 [00:06<00:08, 34.16it/s][A
 44%|████▎     | 209/478 [00:06<00:07, 36.14it/s][A
 45%|████▍     | 214/478 [00:06<00:06, 38.25it/s][A
 46%|████▌     | 219/478 [00:06<00:06, 39.65it/s][A
 47%|████▋     | 224/478 [00:06<00:06, 41.18it/s][A
 48%|████▊     | 229/478 [00:06<00:05, 42.24it/s][A
 49%|████▉     | 234/478 [00:07<00:05, 43.15it/s][A
 50%|█████     | 239/478 [00:07<00:05, 44.25it/s][A
 51%|█████     | 244/478 [00:07<00:05, 44.70it/s][A
 52%|█████▏    | 249/478 [00:07<00:05, 45.09it/s][A
 53%|█████▎    | 254/478 [00:07<00:04, 45.60it/s][A
 54%|█████▍    | 259/478 [00:07<00:04, 45.83it/s][A
 55%|█████▌    | 264/478 [00:07<00:04, 45.89i

validation_acc:  0.3907395767860884



 69%|██████▉   | 329/478 [00:11<00:11, 12.83it/s][A
 70%|██████▉   | 334/478 [00:11<00:08, 16.34it/s][A
 71%|███████   | 339/478 [00:11<00:06, 20.25it/s][A
 72%|███████▏  | 344/478 [00:11<00:05, 24.39it/s][A
 73%|███████▎  | 349/478 [00:11<00:04, 28.44it/s][A
 74%|███████▍  | 354/478 [00:11<00:03, 32.05it/s][A
 75%|███████▌  | 359/478 [00:11<00:03, 35.31it/s][A
 76%|███████▌  | 364/478 [00:11<00:03, 37.79it/s][A
 77%|███████▋  | 369/478 [00:11<00:02, 40.03it/s][A
 78%|███████▊  | 374/478 [00:12<00:02, 41.60it/s][A
 79%|███████▉  | 379/478 [00:12<00:02, 42.62it/s][A
 80%|████████  | 384/478 [00:12<00:02, 43.26it/s][A
 81%|████████▏ | 389/478 [00:12<00:02, 44.02it/s][A
 82%|████████▏ | 394/478 [00:12<00:01, 44.42it/s][A
 83%|████████▎ | 399/478 [00:12<00:01, 45.07it/s][A
 85%|████████▍ | 404/478 [00:12<00:01, 45.47it/s][A
 86%|████████▌ | 409/478 [00:12<00:01, 45.47it/s][A
 87%|████████▋ | 414/478 [00:12<00:01, 45.67it/s][A
 88%|████████▊ | 419/478 [00:13<00:01, 45.65i

validation_acc:  0.3951393253718835





epoch 2: train_loss: 0.05015768307332129, train_acc: 0.3864795027585725, vali_acc: 0.3920664850897409


In [122]:
for name, param in source_classifier.named_parameters():
    if param.requires_grad:
        print (name, param.data, param.dtype)

transformation.transformer_layer.self_attn.in_proj_weight tensor([[-0.0365, -0.8347],
        [-0.1136, -0.7268],
        [-0.8044, -0.5697],
        [ 0.5512, -0.5657],
        [ 0.5618, -0.2861],
        [-0.3158,  0.6102]]) torch.float32
transformation.transformer_layer.self_attn.in_proj_bias tensor([0., 0., 0., 0., 0., 0.]) torch.float32
transformation.transformer_layer.self_attn.out_proj.weight tensor([[-1.0541, -0.2033],
        [-0.3923, -1.0161]]) torch.float32
transformation.transformer_layer.self_attn.out_proj.bias tensor([0., 0.]) torch.float32
transformation.transformer_layer.linear1.weight tensor([[ 0.0526,  0.0169],
        [-0.0167, -0.0424],
        [ 0.0643, -0.0719],
        ...,
        [ 0.0023, -0.0527],
        [-0.0460, -0.0076],
        [-0.0416, -0.0403]]) torch.float32
transformation.transformer_layer.linear1.bias tensor([0., 0., 0.,  ..., 0., 0., 0.]) torch.float32
transformation.transformer_layer.linear2.weight tensor([[-0.0639,  0.0704, -0.0497,  ..., -0.02