In [36]:
import numpy as np
import os
import argparse
import random
from tqdm import tqdm
import datetime
import time

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F
from martins.complex_transformer import ComplexTransformer
from dataset import TimeSeriesDataset


In [37]:

# fake args
parameters = {
    'data_path': '/home/tianqinl/time-series-domain-adaptation/data_unzip/',
    'train_file': 'train_{}.pkl',
    'vali_file': 'validation_{}.pkl',
    'task': '3Av2',
    'batch_size': 32,
    'lr_clf': 1e-3,
    'epochs': 2,
    'PATH': '/home/tianqinl/time-series-domain-adaptation/JDA/data_resultsb400.e500.lr0.0001.task3Av2'
}

class ARGS:
    def __init__(self, parameters):
        self.data_path = parameters['data_path']
        self.train_file = parameters['train_file']
        self.vali_file = parameters['vali_file']
        self.task = parameters['task']
        self.batch_size = parameters['batch_size']
        self.lr_clf = parameters['lr_clf']
        self.epochs = parameters['epochs']
        self.PATH = parameters['PATH']


args = ARGS(parameters)


In [40]:
def call_model(x, y, model_dir="FNN_trained_model/Final_FNN_3Av2/", epoch="model.ep100", task = "3Av2"):
    
    class FNN(nn.Module):
        def __init__(self, d_in, d_h, d_out, dp):
            super(FNN, self).__init__()
            self.fc1 = nn.Linear(d_in, d_h)
            self.fc2 = nn.Linear(d_h, d_out)
            self.dp = nn.Dropout(dp)

        def forward(self, x):
            x = self.fc1(x)
            x = F.relu(x)
            x = self.dp(x)
            x = self.fc2(x)

            return x


    
    seq_len = 10
    feature_dim = 160
    d_out = 50 if task == "3Av2" else 65
    device = torch.device("cuda:0")

    
    CNet_path = model_dir + "CNet_" + epoch
    encoder_path = model_dir + "Encoder_" + epoch
    
    CNet = FNN(d_in=feature_dim * 2 * seq_len, d_h=500, d_out=d_out, dp=0.5)
    
    encoder = ComplexTransformer(layers=1,
                               time_step=seq_len,
                               input_dim=feature_dim,
                               hidden_size=512,
                               output_dim=512,
                               num_heads=8,
                               out_dropout=0.5)

    if torch.cuda.is_available():
        CNet.load_state_dict(torch.load(CNet_path))
        encoder.load_state_dict(torch.load(encoder_path))
    else:
        CNet.load_state_dict(torch.load(CNet_path, map_location=torch.device('cpu')))
        encoder.load_state_dict(torch.load(encoder_path, map_location=torch.device('cpu')))
    batch_size = x.shape[0]
    CNet.eval()
    encoder.eval()

    with torch.no_grad():
        #normalize data
        x = (x - x_mean_tr) / x_std_tr
        # take the real and imaginary part out
        real = x[:,:,0].reshape(batch_size, seq_len, feature_dim).float()
        imag = x[:,:,1].reshape(batch_size, seq_len, feature_dim).float()
        if torch.cuda.is_available():
            real.to(device)
            imag.to(device)
        real, imag = encoder(real, imag)
        pred = CNet(torch.cat((real, imag), -1).reshape(x.shape[0], -1))
        loss = criterion(pred, y.argmax(-1))
        
    return pred, loss, CNet, encoder


In [41]:
# test

training_set = TimeSeriesDataset(root_dir=args.data_path, file_name=args.train_file.format(args.task), train=True)
vali_set = TimeSeriesDataset(root_dir=args.data_path, file_name=args.vali_file.format(args.task), train=True)
# test_set = TimeSeriesDataset(root_dir=args.data_path, file_name=args.file.format(args.task), train=False)

train_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=True)
vali_loader = DataLoader(vali_set, batch_size=args.batch_size, shuffle=True)
# test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True)

# Encoding by complex transformer
x_mean_tr = training_set.data_mean
x_std_tr = training_set.data_std

preds = []
correct_vali = 0
total_bs_vali = 0
vali_loss = 0
criterion = nn.CrossEntropyLoss()

for batch_id, (x, y) in enumerate(tqdm(train_loader)):
    if torch.cuda.is_available():
        x, y = x.to(device), y.to(device)
    ############ Call Model ############
    pred, loss, _, _ = call_model(x, y)
    ####################################
    #print(pred.argmax(-1), y.argmax(-1))
    correct_vali += (pred.argmax(-1) == y.argmax(-1)).sum().item()
    total_bs_vali += y.shape[0]
    
vali_acc = float(correct_vali) / total_bs_vali
vali_log_str = " validation_acc: "+ str(vali_acc)
preds.append(pred)

print(vali_log_str)


100%|██████████| 154/154 [00:07<00:00, 20.91it/s]

 validation_acc: 1.0



