# Stockformer Demo

In [None]:
import sys

# if not 'Informer2020' in sys.path:
#     sys.path += ['Informer2020']

## Open log_dir and load the run's hyperparameters (hparams.yaml)

In [None]:
from utils.tools import dotdict
import torch
import numpy as np
import pandas as pd
import os
from pprint import pprint
import matplotlib.pyplot as plt

from utils.ipynb_helpers import (
    setting_from_args,
    handle_gpu,
    read_data,
)
import yaml
from utils.stock_metrics import (
    apply_threshold_metric,
    PctProfitDirection,
    PctProfitTanh,
    LogPctProfitDirection,
    LogPctProfitTanhV1,
    pct_direction,
)
from utils.results_analysis import open_results, get_tuned_metrics

# log_dir = "bbtest_logs/2023_01_25_11_46_44_stockformer_sl16_ei9_dm512_nh16_el4_ebtime2vec_app/version_45"
# log_dir = "bbtest_logs/2023_01_29_12_38_22_stockformer_sl16_ei9_dm512_nh16_el4_ebtime2vec_app/version_45"
# log_dir = "bbtest_logs/2023_01_29_17_53_13_stockformer_sl16_ei12_dm512_nh16_el4_ebtime2vec_app/version_6"
# log_dir = "bbtest_logs/2023_01_31_14_29_23_stockformer_sl16_ei10_dm512_nh16_el4_ebtime2vec_add/version_5"
log_dir = "lightning_logs/2023_01_31_17_58_24_stockformer_sl16_ei9_dm512_nh16_el4_ebtime2vec_add/version_0"

# Load the hyperparameters stored with the run and wrap them in a dotdict so the
# rest of the notebook can use attribute access (args.root_path, args.loss, ...).
hparams_path = os.path.join(log_dir, "hparams.yaml")
with open(hparams_path, "r") as hparams_file:
    raw_hparams = yaml.load(hparams_file, Loader=yaml.FullLoader)
args = dotdict(raw_hparams)

args

In [None]:
# Read the dataset the run was trained on, located via the loaded hyperparameters.
data_file = os.path.join(args.root_path, args.data_path)
df = read_data(data_file)
df.head()

## Visualization

In [None]:
# Load the saved ground truths, predictions and dates for every data split.
tpd_dict = open_results(log_dir, args, df)

print("Open true/pred data for:", list(tpd_dict.keys()))

# [samples, pred_len, dimensions]
print(
    tpd_dict["train"]["trues"].shape,
    tpd_dict["val"]["trues"].shape,
    tpd_dict["test"]["trues"].shape,
    "\n\n",
)


def _mse_rmse(target, estimate):
    """Return ``(MSE, RMSE)`` of ``estimate`` measured against ``target``."""
    mse = np.square(np.subtract(target, estimate)).mean()
    return mse, np.sqrt(mse)


for group_name, group in tpd_dict.items():
    trues, preds, dates = group["trues"], group["preds"], group["dates"]
    print(
        f"{group_name}\ttrues.shape: {trues.shape}, preds.shape: {preds.shape}, dates.shape: {dates.shape}"
    )

    # Error of the model's predictions.
    print("against preds", *_mse_rmse(trues, preds))
    # Baseline: error of an all-zeros prediction of the same shape.
    print("against 0s", *_mse_rmse(trues, np.zeros(preds.shape)))

In [None]:
for data_group in tpd_dict:
    true = tpd_dict[data_group]["trues"]
    pred = tpd_dict[data_group]["preds"]
    date = tpd_dict[data_group]["dates"]

    if "stock" in args.loss:
        # Rescale the ground truth to the same L2 norm as the predictions so both
        # series are on a comparable scale in the plots below. An all-zero ground
        # truth is left as-is instead of producing 0/0 -> NaN values.
        true_norm = np.linalg.norm(true)
        if true_norm > 0:
            true = (true / true_norm) * np.linalg.norm(pred)

    # Ground truth vs. prediction over time, with a red zero reference line.
    # Each figure gets its own ``num``. plt.figure(num=...) returns an existing
    # figure with that id, so reusing the id would draw the histogram on top of
    # this plot under backends that keep figures open after show().
    plt.figure(num=f"{data_group} predictions", figsize=(16, 4))
    plt.title(data_group)
    plt.plot(date, true, label="GroundTruth", linestyle="", marker=".", markersize=4)
    plt.plot(date, pred, label="Prediction", linestyle="", marker=".", markersize=4)
    plt.plot(date, np.zeros(date.shape), color="red")
    plt.legend()
    plt.show()

    # Absolute-error histogram: |true| (the error of always predicting 0)
    # vs. |true - pred| (the model's error).
    plt.figure(num=f"{data_group} diff histogram", figsize=(16, 4))
    plt.title(f"Diff histogram ({data_group})")
    plt.hist(
        [np.abs(true), np.abs(true - pred)], bins=60, label=["Diff 0", "Diff Pred"]
    )
    plt.xlabel("Diff Value")
    plt.ylabel("Count")
    plt.legend()
    plt.show()

## Basic back-test: trade in the predicted direction (long or short) whenever the prediction's magnitude reaches a threshold

In [None]:
# Threshold tuning on the saved results. Judging by the unpacked names, this yields
# the chosen threshold plus metrics at that threshold and at a zero threshold.
tuned = get_tuned_metrics(args, tpd_dict)
best_thresh, best_thresh_metrics, zero_thresh_metrics = tuned

In [None]:
# Choose the cumulative-profit metric that matches the training loss. The choice
# does not depend on the data group, so it is made once, before plotting. An
# unrecognised loss now fails with a clear error instead of a NameError (or a
# stale `metric` from an earlier run of the cell) inside the loop.
if "lpp" in args.loss:
    metric, metric_name = LogPctProfitDirection, "pct_profit_dir"
elif "tanh" in args.loss:
    metric, metric_name = LogPctProfitTanhV1, "pct_profit_tanh"
elif "mse" in args.loss or "mae" in args.loss:
    metric, metric_name = LogPctProfitDirection, "pct_profit_dir"
else:
    raise ValueError(f"No back-test metric configured for loss {args.loss!r}")

fig, axs = plt.subplots(4, 1, sharex=True, figsize=(16, 8))

# Only the held-out splits are plotted; the train split is skipped.
for data_group in ["val", "test"]:
    true = tpd_dict[data_group]["trues"]
    pred = tpd_dict[data_group]["preds"]
    date = tpd_dict[data_group]["dates"]

    # Keep samples with |pred| >= best_thresh (threshold is in log scale, per the
    # original note). Dates are masked with the same rule so they line up with
    # pred_f / true_f. Assumes apply_threshold_metric filters identically -- TODO confirm.
    pred_f, true_f = apply_threshold_metric(pred, true, best_thresh)
    date_f = date[np.abs(pred) >= best_thresh]

    # All filtered trades.
    axs[0].plot(
        date_f, metric.accumulate(pred_f, true_f, short_filter=None), label=data_group
    )
    # short_filter="ns", plotted on the dates where pred_f > 0.
    axs[1].plot(date_f[pred_f > 0], metric.accumulate(pred_f, true_f, short_filter="ns"))
    # short_filter="os", plotted on the dates where pred_f < 0.
    axs[2].plot(date_f[pred_f < 0], metric.accumulate(pred_f, true_f, short_filter="os"))
    # exp(cumsum(true_f)) over the same filtered dates; this presumably treats
    # true_f as log returns -- confirm. Labelled per split so the legend does not
    # show two identical "Market" entries.
    axs[3].plot(date_f, np.exp(np.cumsum(true_f)), label=f"Market ({data_group})")

# Axis decorations are split-independent, so set them once after plotting.
axs[0].set_ylabel(metric_name)
axs[0].set_title(metric_name)
axs[1].set_ylabel(f"{metric_name}_nshort")
axs[1].set_title(f"{metric_name}_nshort")
axs[2].set_ylabel(f"{metric_name}_oshort")
axs[2].set_title(f"{metric_name}_oshort")
axs[3].set_title("Market")
for ax in axs:
    # Pass visible=True explicitly. grid() with visible=None *toggles* the grid,
    # so calling it once per split (twice) turned the grid back off.
    ax.grid(True, axis="y")

fig.legend()
fig.suptitle("Cumulative metrics overtime")

# plt.show() matches the other cells. Figure.show() warns on non-GUI (inline)
# backends and does not block in scripts.
plt.show()

## Attention Visualization

In [None]:
# args.output_attention = True

# exp = Exp(args)

# model = exp.model

# path = os.path.join(args.checkpoints, setting, "checkpoint.pth")

# print(model.load_state_dict(torch.load(path)))

# df = pd.read_csv(os.path.join(args.root_path, args.data_path))
# df[args.cols].head()






# from data_provider.data_loader import Dataset_Custom
# from torch.utils.data import DataLoader

# Data = Dataset_Custom
# timeenc = 0 if args.t_embed != "timeF" else 1
# data_group = "test"
# shuffle_data_group = False
# drop_last = True
# batch_size = 1
# data_set = Data(args, data_group=data_group)

# data_loader = DataLoader(
#     data_set,
#     batch_size=batch_size,
#     shuffle=shuffle_data_group,
#     num_workers=args.num_workers,
#     drop_last=drop_last,
# )


# idx = 0
# for i, (batch_x, batch_y, batch_x_mark, batch_y_mark, ds_index) in enumerate(
#     data_loader
# ):
#     if i != idx:
#         continue
#     batch_x = batch_x.float().to(exp.device)
#     batch_y = batch_y.float()

#     batch_x_mark = batch_x_mark.float().to(exp.device)
#     batch_y_mark = batch_y_mark.float().to(exp.device)

#     dec_inp = torch.zeros_like(batch_y[:, -args.pred_len :, :]).float()
#     dec_inp = (
#         torch.cat([batch_y[:, : args.label_len, :], dec_inp], dim=1)
#         .float()
#         .to(exp.device)
#     )

#     outputs, attn = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)


# print(attn[0].shape, attn[1].shape)  # , attn[2].shape


# layers = [0, 1]
# distil = "Distil" if args.distil else "NoDistil"
# for layer in layers:
#     print("\n\n==========================")
#     print("Showing attention layer", layer)
#     print("==========================\n\n")
#     for h in range(0, args.n_heads):
#         plt.figure(figsize=[10, 8])
#         plt.title(f"Informer, {distil}, attn:{args.attn} layer:{layer} head:{h}")
#         A = attn[layer][0, h].detach().cpu().numpy()
#         ax = sns.heatmap(A, vmin=0, vmax=A.max() + 0.01)
#         plt.show()






# import pytorch_lightning as pl
# from exp.exp_timeseries import ExpTimeseries
# from data_provider.data_module import CustomDataModule

# trainer = pl.Trainer(accelerator="gpu",devices=1)#, log_dir=os.path.abspath(log_dir))

# exp = ExpTimeseries.load_from_checkpoint(
#     os.path.join(log_dir, "checkpoints/checkpoint.ckpt"), config=args
# )
# data_module = CustomDataModule(args, 0)

# # Test Model
# # t = trainer.test(exp, data_module)

# # # Predict and Save Results
# results = trainer.predict(exp, data_module)
# results