# <center>Visualizing Results of Trained Transformers</center>

In [None]:
from collections import Counter
import os
import shutil
os.chdir('../')

import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

from main import build_parser, AutoformerEstimator

In [None]:
# Figure output: when SAVE is true, every plot is also written under SAVEDIR.
SAVE, SAVEDIR = True, 'notebooks/img'
# Device that `step` moves each test batch to before calling the model.
DEVICE = torch.device(0)

In [None]:
def find_ckpt_synthetic(dataset, attn, dm, dff, lwin, nth=1): 
    exp_dir = '/usr2/home/yongyiw/ckpt/lstf/Synthetic/{}/EXP_{}.npy_tfM_fh_le96_ll48_lp192_mautoformer_attn{}_ein1_din1_dout1_dm{}_dff{}_nh8_ne2_nd1_lw{}_c3_t0_dr0.05_E6_B32_p2_lr0.0001_sch1'.format(
        dataset, dataset, attn, dm, dff, lwin
    )
    return os.path.join(exp_dir, os.listdir(exp_dir)[nth], 'ckpt.pt')

def find_ckpt_real(dataset, dpath, lenc, llabel, lpred, attn, dim, lwin, nth=1): 
    exp_dir = '/usr2/home/yongyiw/ckpt/lstf/{}/autoformer/EXP_{}_tfM_fh_le{}_ll{}_lp{}_mautoformer_attn{}_ein{}_din{}_dout{}_dm512_dff2048_nh8_ne2_nd1_lw{}_c3_t1_dr0.05_E6_B32_p2_lr0.0001_sch1/'.format(
        dataset, os.path.basename(dpath), lenc, llabel, lpred, attn, dim, dim, dim, lwin
    )
    return os.path.join(exp_dir, os.listdir(exp_dir)[nth], 'ckpt.pt')

## Prediction Plots

### Synthetic Dataset

In [None]:
# Synthetic series to plot (stems of the .npy files in the synth data directory).
DATASETS = [
    'sinx', 'sinx_sin2x_sin4x', 'xsinx', 'sinx_c', 'x',
    'sinx_x', 'sinx_x2_sym', 'sinx_x2_asym', 'sinx_sqrtx',
]
# Values passed as --len_enc / --len_label / --len_pred.
LEN_ENC, LEN_LABEL, LEN_PRED = 96, 48, 192
# Values passed as --d_model / --d_ff.
D_MODEL, D_FF = 512, 2048
# Index along the first axis of the test outputs of the sample that gets plotted.
START = 700

In [None]:
# Model variants in plotting order: (attn, len_window) -> legend label.
PLOT_VARIANTS = [
    (('dot', 0), 'Transformer'),
    (('autocorrelation', 0), '+ AutoCorrelation'),
    (('dot', 25), '+ Decomposition'),
    (('autocorrelation', 25), '+ AutoCorrelation + Decomposition'),
]

if SAVE:
    # savefig does not create missing directories.
    os.makedirs(SAVEDIR, exist_ok=True)

for dataset in DATASETS:
    # Test-set predictions keyed by (attn, lwin); the labels go under 'y'.
    res = {}
    dpath = '/usr2/home/yongyiw/data/synth/{}.npy'.format(dataset)
    for attn in ['autocorrelation', 'dot']:
        for lwin in [25, 0]:
            parser = build_parser()
            cfg = parser.parse_args([
                '--data', 'Synthetic',
                '--data_path', dpath,
                '--ckpt', './temp',
                '--len_enc', str(LEN_ENC),
                '--len_label', str(LEN_LABEL),
                '--len_pred', str(LEN_PRED),
                '--model', 'autoformer',
                '--attn', attn,
                '--d_model', str(D_MODEL),
                '--d_ff', str(D_FF),
                '--n_enc_layers', '2',
                '--n_dec_layers', '1',
                '--len_window', str(lwin),
                '--no_temporal',
                '--lr_schedule',
                '--devices', '0', '1', '2', '3',
                '--no_verbose',
            ])
            cfg.config = ''

            try:
                estimator = AutoformerEstimator(cfg)
                _, _, testloader = estimator.get_data()
                _, yhats, ys = estimator.test(
                    testloader,
                    find_ckpt_synthetic(dataset, attn, D_MODEL, D_FF, lwin, nth=0)
                )
            finally:
                # ./temp is only a scratch --ckpt dir here; remove it even if testing fails.
                shutil.rmtree('temp', ignore_errors=True)
            res[(attn, lwin)] = yhats
    # All runs share the same data arguments, so the last run's labels are used for all.
    res['y'] = ys

    plt.figure(figsize=(16, 4))
    plt.plot(res['y'][START, :], label='Label', color='black')
    for key, label in PLOT_VARIANTS:
        plt.plot(res[key][START, :], label=label)

    plt.title('{}'.format(dataset))
    plt.legend(loc='lower right')
    if SAVE:
        plt.savefig(os.path.join(SAVEDIR, 'pred_{}.png'.format(dataset)))
    plt.show()

### Real Dataset

In [None]:
# One row per real dataset: (name, data file path, number of channels passed
# as --d_enc_in / --d_dec_in / --d_dec_out).
_DATASET_SPECS = (
    ('ETTm2', '/usr2/home/yongyiw/data/ETT-small/ETTm2.csv', 7),
    ('Electricity', '/usr2/home/yongyiw/data/electricity/electricity.csv', 321),
    ('Exchange', '/usr2/home/yongyiw/data/exchange_rate/exchange_rate.csv', 8),
    ('Traffic', '/usr2/home/yongyiw/data/traffic/traffic.csv', 862),
    ('Weather', '/usr2/home/yongyiw/data/weather/weather.csv', 21),
    ('ILI', '/usr2/home/yongyiw/data/illness/national_illness.csv', 7),
)
DATASETS = [name for name, _, _ in _DATASET_SPECS]
CONFIG = {name: {'dpath': path, 'dim': dim} for name, path, dim in _DATASET_SPECS}

START = 42  # index along the first axis of the test outputs of the plotted sample
DIM = -1    # index along the last axis: which channel is plotted

In [None]:
# Model variants in plotting order: (attn, len_window) -> legend label.
PLOT_VARIANTS = [
    (('dot', 0), 'Transformer'),
    (('autocorrelation', 0), '+ AutoCorrelation'),
    (('dot', 25), '+ Decomposition'),
    (('autocorrelation', 25), '+ AutoCorrelation + Decomposition'),
]

if SAVE:
    # savefig does not create missing directories.
    os.makedirs(SAVEDIR, exist_ok=True)

for dataset in DATASETS:
    # ILI uses shorter encoder / label / prediction windows than the other datasets.
    LEN_ENC = 96 if dataset != 'ILI' else 36
    LEN_LABEL = 48 if dataset != 'ILI' else 18
    LEN_PRED = 192 if dataset != 'ILI' else 48
    dpath = CONFIG[dataset]['dpath']
    data_dim = CONFIG[dataset]['dim']

    # Test-set predictions keyed by (attn, lwin); the labels go under 'y'.
    res = {}
    for attn in ['autocorrelation', 'dot']:
        for lwin in [25, 0]:
            parser = build_parser()
            cfg = parser.parse_args([
                '--data', dataset,
                '--data_path', dpath,
                '--ckpt', './temp',
                '--len_enc', str(LEN_ENC),
                '--len_label', str(LEN_LABEL),
                '--len_pred', str(LEN_PRED),
                '--model', 'autoformer',
                '--attn', attn,
                '--d_enc_in', str(data_dim),
                '--d_dec_in', str(data_dim),
                '--d_dec_out', str(data_dim),
                '--n_enc_layers', '2',
                '--n_dec_layers', '1',
                '--len_window', str(lwin),
                '--lr_schedule',
                '--devices', '0', '1', '2', '3',
                '--no_verbose',
            ])
            cfg.config = ''

            try:
                estimator = AutoformerEstimator(cfg)
                _, _, testloader = estimator.get_data()
                _, yhats, ys = estimator.test(
                    testloader,
                    find_ckpt_real(dataset, dpath, LEN_ENC, LEN_LABEL, LEN_PRED, attn, data_dim, lwin)
                )
            finally:
                # ./temp is only a scratch --ckpt dir here; remove it even if testing fails.
                shutil.rmtree('temp', ignore_errors=True)
            res[(attn, lwin)] = yhats
    # All runs share the same data arguments, so the last run's labels are used for all.
    res['y'] = ys

    plt.figure(figsize=(12, 4))
    plt.plot(res['y'][START, :, DIM], label='Label', color='black')
    for key, label in PLOT_VARIANTS:
        plt.plot(res[key][START, :, DIM], label=label)

    plt.title('{}'.format(dataset))
    plt.legend(loc='upper left')
    if SAVE:
        plt.savefig(os.path.join(SAVEDIR, 'pred_{}.png'.format(dataset)))
    plt.show()

## Explainability of $\tau$

In [None]:
# Synthetic series whose lag histograms are drawn.
DATASETS = [
    'sinx', 'sinx_sin2x_sin4x', 'xsinx', 'sinx_c', 'x',
    'sinx_x', 'sinx_x2_sym', 'sinx_x2_asym', 'sinx_sqrtx',
]
LEN_ENC, LEN_LABEL, LEN_PRED = 96, 48, 192

# Settings read by `step`.
ENC_WEIGHTS = True  # True: use the encoder self-attention weights; False: decoder self-attention
I_LAYER = 0         # layer index into the selected weights
TOPK = 1            # number of largest entries kept by torch.topk

def step(model, data, counter):
    """Run one test batch through ``model`` and tally the top-k lag indices in ``counter``.

    Depends on the notebook globals ``DEVICE``, ``ENC_WEIGHTS``, ``I_LAYER`` and
    ``TOPK``, and on the global ``estimator`` (used for ``get_dec_input``). Build
    ``estimator`` for the current dataset before calling this.

    Args:
        model: trained model. It is called with ``dec_self_mask=None`` and must return
            ``(output, ((enc_self_weights, _), (dec_self_weights, dec_cross_weights)))``.
        data: batch dict with tensor entries ``'x'``, ``'x_time'``, ``'y'``, ``'y_time'``.
        counter: ``collections.Counter`` updated in place with the selected indices.
    """
    # Pure inference: skip building an autograd graph for every batch.
    with torch.no_grad():
        enc_x = data['x'].to(DEVICE, dtype=torch.float)
        enc_x_time = data['x_time'].to(DEVICE, dtype=torch.float)
        y = data['y'].to(DEVICE, dtype=torch.float)
        dec_y_time = data['y_time'].to(DEVICE, dtype=torch.float)
        dec_y_s, dec_y_t = estimator.get_dec_input(enc_x, y)

        _, ((enc_self_weights, _), (dec_self_weights, _)) = model(
            enc_x, enc_x_time, dec_y_s, dec_y_t, dec_y_time,
            dec_self_mask=None
        )

        corr = (enc_self_weights if ENC_WEIGHTS else dec_self_weights)[I_LAYER]
        # Average over dim -2, then keep the TOPK largest indices along the last dim.
        # NOTE(review): presumably the last dim indexes the lag tau -- confirm against the model.
        _, taus = torch.topk(torch.mean(corr, dim=-2), TOPK, dim=-1)
    counter.update(taus.reshape(-1).tolist())

In [None]:
attn = 'autocorrelation'
lwin = 25

if SAVE:
    # savefig does not create missing directories.
    os.makedirs(SAVEDIR, exist_ok=True)

for dataset in DATASETS:
    for D_MODEL, D_FF in [(512, 2048)]:  # add (128, 512) to also inspect the smaller model
        counter = Counter()
        dpath = '/usr2/home/yongyiw/data/synth/{}.npy'.format(dataset)
        parser = build_parser()
        os.makedirs('temp', exist_ok=True)
        cfg = parser.parse_args([
            '--data', 'Synthetic',
            '--data_path', dpath,
            '--ckpt', './temp',
            '--len_enc', str(LEN_ENC),
            '--len_label', str(LEN_LABEL),
            '--len_pred', str(LEN_PRED),
            '--model', 'autoformer',
            '--attn', attn,
            '--d_model', str(D_MODEL),
            '--d_ff', str(D_FF),
            '--n_enc_layers', '2',
            '--n_dec_layers', '1',
            '--len_window', str(lwin),
            '--no_temporal',
            '--output_attn',
            '--lr_schedule',
            '--devices', '0', '1', '2', '3',
            '--no_verbose',
        ])
        cfg.config = ''

        try:
            # Must stay a module-level name: `step` reads the global `estimator`.
            estimator = AutoformerEstimator(cfg)
            _, _, testloader = estimator.get_data()
            estimator.load(find_ckpt_synthetic(dataset, attn, D_MODEL, D_FF, lwin, nth=0))
            # Inference mode (e.g. no dropout) so the tallied lags are deterministic.
            # NOTE(review): assumes estimator.model is a torch.nn.Module.
            estimator.model.eval()
            for data in tqdm(testloader, dynamic_ncols=True):
                step(estimator.model, data, counter)
        finally:
            # ./temp is only a scratch --ckpt dir here; remove it even on failure.
            shutil.rmtree('temp', ignore_errors=True)

        if not counter:
            # zip(*...) below cannot unpack an empty tally.
            print('{}: no test batches, skipping plot'.format(dataset))
            continue

        labels, values = zip(*sorted(counter.items()))
        indexes = np.arange(len(labels))
        width = 0.8

        plt.figure(figsize=(24, 4))
        plt.bar(indexes, values, width)
        plt.xticks(indexes, labels)
        plt.title(dataset)
        plt.xlabel('tau (0 ~ 96)')
        plt.ylabel('Count')
        if SAVE:
            # 'tau_' prefix: 'pred_<dataset>.png' already holds the prediction plot.
            plt.savefig(os.path.join(SAVEDIR, 'tau_{}.png'.format(dataset)))
        plt.show()