# 上昇するかどうかを予測するAIを学習する

In [None]:
from datetime import datetime

from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

import stock

In [None]:
# Tag this run with the current local time; the tag becomes part of the
# output_dir handed to the trainer parameters.
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
params = stock.dl.train.TrainerParams(output_dir=f"./tmp/{timestamp}")
trainer = stock.dl.train.Trainer(params)

In [None]:
# Run the training routine of the Trainer built in the previous cell.
trainer.train()

In [None]:
# trainer.test() returns a pair; presumably ground-truth values and model
# predictions with matching shapes -- confirm against stock.dl.train.
trues, preds = trainer.test()

In [None]:
# Inspect one column (index 20) of the test outputs.
idx = 20
ts = trues[:, idx]
ps = preds[:, idx]
# Pearson correlation coefficient between `ts` and `ps` for this column.
print(np.corrcoef(ts, ps)[0, 1])
plt.scatter(ts, ps)
# Draw the zero lines so the sign quadrants of the scatter are visible.
plt.axhline(0)
plt.axvline(0)

In [None]:
# Convert the test outputs via .numpy() (presumably torch tensors -> ndarrays;
# confirm the return type of trainer.test()).
trues_arr = trues.numpy()
preds_arr = preds.numpy()
# Size of axis 1; the profit loop below iterates over this axis.
num_codes = trues_arr.shape[1]

In [None]:
# For every column, add up the true values on rows where the prediction is
# positive and subtract the true values on all remaining rows.
profits = []
for i in range(num_codes):
    ts, ps = trues_arr[:, i], preds_arr[:, i]
    ps_thr = ps > 0
    total_when_positive = ts[ps_thr].sum()
    total_otherwise = ts[~ps_thr].sum()
    profit = total_when_positive - total_otherwise
    profits.append(profit)

In [None]:
# Average of the per-column profit figures collected above.
np.mean(profits)

In [None]:
# Split the dataset matrix column-wise into a US part and a JP part.
# NOTE(review): this relies on the dataset's private attributes
# (_us_data_indices / _jp_data_indices); confirm there is no public accessor.
dataset = trainer.dataset
us_data = dataset.data[:, dataset._us_data_indices]
jp_data = dataset.data[:, dataset._jp_data_indices]
# Column counts of each part.
n_us = us_data.shape[1]
n_jp = jp_data.shape[1]

In [None]:
# Show how many US and JP symbols the dataset lists.
len(dataset.us_symbols), len(dataset.jp_symbols)

In [None]:
# Mean absolute value over the US columns, one number per row.
changes = np.abs(us_data).mean(axis=1)
# Keep only rows whose mean exceeds 1.25. The same row mask is applied to both
# matrices so their rows stay aligned.
row_mask = changes > 1.25
us_data1 = us_data[row_mask]
jp_data1 = jp_data[row_mask]

In [None]:
# Compare shapes before and after the row filter.
us_data1.shape, us_data.shape, jp_data1.shape, jp_data.shape

In [None]:
def calc_corres(data1, data2):
    """Return the Pearson correlation between every column pair of two arrays.

    Args:
        data1: 2-D array-like of shape ``(n_samples, n_d1)``.
        data2: 2-D array-like of shape ``(n_samples, n_d2)``; it must have the
            same number of rows as ``data1``.

    Returns:
        ``np.ndarray`` of shape ``(n_d1, n_d2)`` whose element ``[i, j]`` is
        the correlation coefficient between ``data1[:, i]`` and
        ``data2[:, j]``. Entries involving a zero-variance column are NaN.
    """
    data1 = np.asarray(data1)
    data2 = np.asarray(data2)
    # Sizes come from the arguments themselves, never from notebook globals.
    n_d1 = data1.shape[1]
    n_d2 = data2.shape[1]
    if n_d1 == 0 or n_d2 == 0:
        return np.zeros((n_d1, n_d2))

    # One corrcoef call over the side-by-side columns yields the full
    # (n_d1 + n_d2) square matrix; the upper-right block holds exactly the
    # data1-vs-data2 coefficients, so no Python-level pair loop is needed.
    stacked = np.concatenate([data1, data2], axis=1)
    full = np.corrcoef(stacked, rowvar=False)
    # Copy so the large square matrix can be freed once we return.
    return full[:n_d1, n_d1:].copy()

In [None]:
# Correlation matrix: rows index columns of us_data1, columns index columns
# of jp_data1.
corres = calc_corres(us_data1, jp_data1)

In [None]:
# Shapes of the two filtered matrices fed into calc_corres.
us_data1.shape, jp_data1.shape

In [None]:
# With no axis given, argmax returns the position of the maximum in the
# flattened array; np.unravel_index(..., corres.shape) converts it to (row, col).
corres.argmax()

In [None]:
# Overall maximum coefficient, and for each column of corres the row index
# holding that column's maximum.
corres.max(), corres.argmax(axis=0)

In [None]:
# Hand-picked column pair to examine.
us_idx, jp_idx = 132, 13

# Compound a multiplier row by row: when the US value is positive the JP value
# (taken as a percentage) is applied as-is, otherwise with its sign flipped.
# cnt counts the rows on which the two values disagree in sign.
profit, cnt = 1.0, 0
num_day = us_data1.shape[0]
for i in range(num_day):
    us = us_data1[i, us_idx] > 0
    jp = jp_data1[i, jp_idx] > 0
    change = jp_data1[i, jp_idx] / 100
    profit *= (1.0 + change) if us else (1.0 - change)
    if us != jp:
        cnt += 1

print(f"{profit}, {cnt} / {num_day}")

In [None]:
# Scatter the selected US column against the selected JP column, one point
# per retained row.
plt.scatter(us_data1[:, us_idx], jp_data1[:, jp_idx])