# Stockformer Demo

In [None]:
import sys

# if not 'Informer2020' in sys.path:
#     sys.path += ['Informer2020']

## Open log_dir

In [None]:
from utils.tools import dotdict
import torch
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns
from utils.ipynb_helpers import (
    args_from_setting,
    setting_from_args,
    handle_gpu,
    read_data,
)
import yaml

# Run directory (under lightning_logs/) whose hparams.yaml and results/ files
# the rest of this notebook reads.
log_dir = "lightning_logs/stockformer_custom_ftMS_sl16_ll0_pl1_ei9_diNone_co1_iTrue_dm128_nh8_el12_dlNone_df2048_atprob_fc5_ebNone_dtFalse_mxFalse_full_1h_0/version_3"

# Parse the saved hyper-parameters and wrap them in a dotdict so later cells
# can read them as attributes (e.g. args.target, args.root_path).
hparams_path = os.path.join(log_dir, "hparams.yaml")
with open(hparams_path, "r") as file:
    hparams = yaml.load(file, Loader=yaml.FullLoader)
args = dotdict(hparams)
args  # echoed as the cell output

## Visualization

In [None]:
# Load the arrays saved under the run directory. For every split (flag) and
# device index the following files are looked up:
#   {log_dir}/results/true_{flag}_{device}.npy   ground truth
#   {log_dir}/results/pred_{flag}_{device}.npy   predictions
#   {log_dir}/results/date_{flag}_{device}.npy   timestamps
# Devices are probed as 0, 1, 2, ... until one of the three files is missing.
# tp_dict[flag] ends up as [trues, preds, dates], sorted by date; splits with
# no files on disk get no entry.

tp_dict = {}
for flag in ["train", "val", "test"]:
    per_device = []  # one [trues, preds, dates] triple per device found
    device = 0
    while True:  # Device Loop
        paths = [
            os.path.join(log_dir, f"results/{kind}_{flag}_{device}.npy")
            for kind in ("true", "pred", "date")
        ]
        if not all(os.path.exists(path) for path in paths):
            # Done searching for devices
            break
        per_device.append([np.load(path) for path in paths])
        device += 1

    if not per_device:
        continue

    # Stitch the per-device pieces together along the sample axis.
    merged = [np.concatenate(pieces, axis=0) for pieces in zip(*per_device)]

    # Order samples by the first date column, the same column the later cells
    # use as each sample's timestamp (dates[:, 0]). Sorting the flattened date
    # array instead would produce indices past the sample axis whenever the
    # date array has more than one column.
    order = np.argsort(merged[2][:, 0])
    tp_dict[flag] = [arr[order] for arr in merged]


# Summarise what was loaded and compare the model's error with the error of
# an all-zero forecast for every available split.
print("Open true/pred data for:", list(tp_dict.keys()))

# [samples, pred_len, dimensions]
# Only splits that were actually loaded are listed, so a missing split no
# longer raises KeyError here.
print(
    *(tp_dict[name][0].shape for name in ("train", "val", "test") if name in tp_dict),
    "\n\n",
)

for flag in tp_dict:
    trues, preds, dates = tp_dict[flag]
    # The last field reports dates.shape (it previously repeated preds.shape).
    print(
        f"{flag}\ttrues.shape: {trues.shape}, preds.shape: {preds.shape}, dates.shape: {dates.shape}"
    )

    # Error of the model's predictions.
    MSE = np.square(np.subtract(trues, preds)).mean()
    RMSE = np.sqrt(MSE)
    print("against preds", MSE, RMSE)

    # Baseline: error of predicting zero everywhere.
    MSE = np.square(np.subtract(trues, np.zeros(preds.shape))).mean()
    RMSE = np.sqrt(MSE)
    print("against 0s", MSE, RMSE)

In [None]:
# For every split, draw two figures from element [:, 0, 0] of the arrays:
#   1. ground truth and prediction as dots over time, with a red zero line;
#   2. a histogram of |true| ("Diff 0") next to |true - pred| ("Diff Pred").
dot_style = {"linestyle": "", "marker": ".", "markersize": 4}
for flag, (trues, preds, dates) in tp_dict.items():
    true, pred, date = trues[:, 0, 0], preds[:, 0, 0], dates[:, 0]

    plt.figure(num=flag, figsize=(16, 4))
    plt.title(flag)
    for series, series_label in ((true, "GroundTruth"), (pred, "Prediction")):
        plt.plot(date, series, label=series_label, **dot_style)
    plt.plot(date, np.zeros(date.shape), color="red")
    # plt.scatter(range(trues.shape[0]), trues[:,0,0], marker='v', color='r', label='GroundTruth')
    # plt.scatter(range(trues.shape[0]), preds[:,0,0], marker='^', color='m', label='Prediction')

    plt.legend()
    plt.show()

    abs_true = np.abs(true)
    abs_err = np.abs(true - pred)
    plt.figure(num=flag, figsize=(16, 4))
    plt.title("Diff histogram")
    # plt.hist(np.abs(true), bins=len(true)//6, label='Diff 0', alpha=0.5)
    # plt.hist(np.abs(true - pred), bins=len(true)//6, label='Diff Pred', alpha=0.5)
    plt.hist([abs_true, abs_err], bins=60, label=["Diff 0", "Diff Pred"])
    plt.xlabel("Diff Value")
    plt.ylabel("Count")
    plt.legend()
    plt.show()

    # df = pd.concat([pd.DataFrame(a, columns=[f"{i}"]) for i, a in enumerate([np.abs(true - pred), np.abs(true)])], axis=1)

    # # plot the data
    # df.plot.hist(stacked=True, bins=len(true), density=True, figsize=(10, 6), grid=True)

## Basic back-test: trade in the predicted direction when the prediction's magnitude reaches a threshold

In [None]:
# (best tuning metric seen on the train split, threshold that produced it)
max_tracker = (0, 0)

# Tracks results: tracker[thresh][flag] -> dict of metrics
tracker = {}

df = read_data(os.path.join(args.root_path, args.data_path))

# Get the percentile to check thresh until.
# percentile = [which percentile, its value averaged over percentile_flags]
percentile = [50, 0.0]
percentile_flags = ["train"]  # tp_dict
for flag in percentile_flags:
    _, preds, _ = tp_dict[flag]
    percentile[1] += np.percentile(
        np.abs(preds), percentile[0]
    )  # np.median(np.abs(preds))
# Average over the splits that were actually summed. Dividing by len(tp_dict)
# under-reported the value whenever tp_dict held more splits than the loop used.
percentile[1] /= len(percentile_flags)
print(f"{percentile[0]}'th percentile: {percentile[1]}")

# The target name is expected to look like "<ticker>_pctchange".
ticker, field = args.target.split("_")
assert field == "pctchange"

# Sweep candidate thresholds. For each threshold and each split, keep only the
# samples whose |prediction| reaches the threshold and score a simple strategy
# on them. Values are treated as log-scale returns: they are summed and then
# exponentiated to get a compounded growth factor.
for thresh in [0.0002]:  # np.linspace(0, 0.00025, 501):
    # print("thresh:", thresh)
    tracker[thresh] = {}
    for flag in tp_dict:
        trues, preds, dates = tp_dict[flag]
        # trues, preds = np.exp(trues), np.exp(preds)
        true = trues[:, 0, 0].copy()
        pred = preds[:, 0, 0].copy()
        date = pd.DatetimeIndex(dates[:, 0], tz="UTC")

        # Selection mask, computed once and reused for every filtered view.
        keep = np.abs(pred) >= thresh

        # Rows of the raw data for the kept samples; only the commented-out
        # per-share back-test below reads this.
        df_flag = df.loc[date][keep]

        # Filter by thresh. Note in log scale
        true_c_log = true[keep]
        pred_c_log = pred[keep]
        n_kept = len(pred_c_log)

        # Percent direction correct, ie up or down. With nothing selected the
        # ratio is undefined, so report NaN rather than dividing by zero.
        if n_kept:
            pct_dir_correct = (
                np.sum(np.sign(true_c_log) == np.sign(pred_c_log)) / n_kept
            )
        else:
            pct_dir_correct = float("nan")

        true_c, pred_c = np.exp(true_c_log), np.exp(pred_c_log)

        # # Turn pct_change to price change
        # true_price_change = df_flag[ticker]["open"] * (true_c-1)
        # pred_price_change = df_flag[ticker]["open"] * (pred_c-1)
        # # Profit if you always bought one share with shorting
        # p_one_share_wshort = (true_price_change * np.sign(pred_price_change)).sum()
        # # Profit if you always bought one share without shorting
        # p_one_share = (true_price_change * np.sign(pred_price_change))[pred_price_change > 0].sum()

        # Samples where the prediction is positive (the "without shorting" subset).
        long_only = pred_c_log > 0

        # Important: Percent profit with & without shorting.
        # Each true return is signed by the predicted direction.
        signed_returns = true_c_log * np.sign(pred_c_log)
        pct_profit_wshort = np.exp(signed_returns.sum())
        pct_profit = np.exp(signed_returns[long_only].sum())

        # Important: percent profit with & without shorting with partial purchase.
        # tanh(1000 * pred) scales the position smoothly between -1 and 1.
        tanh_returns = true_c_log * np.tanh(1000 * pred_c_log)
        pct_profit_tanh_wshort = np.exp(tanh_returns.sum())
        pct_profit_tanh = np.exp(tanh_returns[long_only].sum())

        # Optimal percent profit without shorting: take every positive true
        # return among the kept samples (sign(x) is 1 there, so x * sign(x) == x).
        pct_profit_opt = np.exp(true_c_log[true_c_log > 0].sum())

        # Tune threshhold based off of train's metric we care about
        tune_metric = pct_profit_tanh if args.loss == "stock_tanh" else pct_profit
        if tune_metric > max_tracker[0] and flag == "train":
            max_tracker = (tune_metric, thresh)

        # Save
        tracker[thresh][flag] = {
            "pct_profit": pct_profit,
            "pct_profit_wshort": pct_profit_wshort,
            # "p_one_share": p_one_share, "p_one_share_wshort": p_one_share_wshort,
            "pct_profit_tanh": pct_profit_tanh,
            "pct_profit_tanh_wshort": pct_profit_tanh_wshort,
            "pct_excluded": (len(pred) - len(pred_c_log[long_only])) / len(pred),
            "pct_excluded_wshort": (len(pred) - n_kept) / len(pred),
            "pct_dir_correct": pct_dir_correct,
            "pct_profit_opt": pct_profit_opt,
        }


# Show each split's metrics at the threshold recorded in max_tracker.
best_thresh = max_tracker[1]
print("best thresh:", best_thresh)
for split_name, split_metrics in tracker[best_thresh].items():
    print(f"{split_name}\t", split_metrics)


In [None]:
# Cumulative growth over time for every split at a fixed threshold:
#   top panel    - trade every kept sample in the predicted direction;
#   bottom panel - only the kept samples with a positive prediction.
fig, axs = plt.subplots(2, 1, sharex=True, figsize=(16, 8))
best_thresh = 0.0002
for flag, (trues, preds, dates) in tp_dict.items():
    true = trues[:, 0, 0].copy()
    pred = preds[:, 0, 0].copy()
    date = pd.DatetimeIndex(dates[:, 0], tz="UTC")

    # Filter by best_thresh. Note in log scale
    keep = np.abs(pred) >= best_thresh
    true_c_log, pred_c_log, date_c = true[keep], pred[keep], date[keep]

    # True return signed by the predicted direction, and the positive-prediction subset.
    signed_returns = true_c_log * np.sign(pred_c_log)
    long_only = pred_c_log > 0

    # Final compounded values (not plotted; the curves below show the running product).
    pct_profit_wshort = np.exp(signed_returns.sum())
    pct_profit = np.exp(signed_returns[long_only].sum())

    axs[0].plot(date_c, np.exp(np.cumsum(signed_returns)), label=flag)
    # No label on the second panel - presumably so fig.legend() lists each split once.
    axs[1].plot(
        date_c[long_only], np.exp(np.cumsum(signed_returns[long_only]))
    )  # , label=flag)

    for ax, metric_name in zip(axs, ("pct_profit_wshort", "pct_profit")):
        ax.set_ylabel(metric_name)
        ax.set_title(metric_name)
        ax.grid(axis="y")

fig.legend()
fig.suptitle("Cumulative metrics overtime")

fig.show()

## Attention Visualization

In [None]:
# args.output_attention = True

# exp = Exp(args)

# model = exp.model

# path = os.path.join(args.checkpoints, setting, "checkpoint.pth")

# print(model.load_state_dict(torch.load(path)))

# df = pd.read_csv(os.path.join(args.root_path, args.data_path))
# df[args.cols].head()

In [None]:
# from data_provider.data_loader import Dataset_Custom
# from torch.utils.data import DataLoader

# Data = Dataset_Custom
# timeenc = 0 if args.embed != "timeF" else 1
# flag = "test"
# shuffle_flag = False
# drop_last = True
# batch_size = 1
# data_set = Data(args, flag=flag)

# data_loader = DataLoader(
#     data_set,
#     batch_size=batch_size,
#     shuffle=shuffle_flag,
#     num_workers=args.num_workers,
#     drop_last=drop_last,
# )


# idx = 0
# for i, (batch_x, batch_y, batch_x_mark, batch_y_mark, ds_index) in enumerate(
#     data_loader
# ):
#     if i != idx:
#         continue
#     batch_x = batch_x.float().to(exp.device)
#     batch_y = batch_y.float()

#     batch_x_mark = batch_x_mark.float().to(exp.device)
#     batch_y_mark = batch_y_mark.float().to(exp.device)

#     dec_inp = torch.zeros_like(batch_y[:, -args.pred_len :, :]).float()
#     dec_inp = (
#         torch.cat([batch_y[:, : args.label_len, :], dec_inp], dim=1)
#         .float()
#         .to(exp.device)
#     )

#     outputs, attn = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)


# print(attn[0].shape, attn[1].shape)  # , attn[2].shape


# layers = [0, 1]
# distil = "Distil" if args.distil else "NoDistil"
# for layer in layers:
#     print("\n\n==========================")
#     print("Showing attention layer", layer)
#     print("==========================\n\n")
#     for h in range(0, args.n_heads):
#         plt.figure(figsize=[10, 8])
#         plt.title(f"Informer, {distil}, attn:{args.attn} layer:{layer} head:{h}")
#         A = attn[layer][0, h].detach().cpu().numpy()
#         ax = sns.heatmap(A, vmin=0, vmax=A.max() + 0.01)
#         plt.show()