In [None]:
from pathlib import Path
import datetime 
import random
import re

from tqdm import tqdm
import numpy as np
import polars as pl
import tensorflow as tf
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd

import stock

# Training artifacts live under <project root>/data/train; the archive written
# at the end of this notebook is named after today's date (YYYYMMDD.npz).
train_data_dir = stock.PROJECT_ROOT / "data" / "train"
date_stamp = datetime.date.today().strftime("%Y%m%d")
output_file_path = train_data_dir / "{}.npz".format(date_stamp)

In [None]:
# Estimate the market capitalization from EPS, net income and the latest close.
def calc_estimated_capitalization(code, current_date=None):
    """Estimate the market capitalization of ``code`` as of ``current_date``.

    The share count is approximated as ``net_income * 1000000 / eps`` taken
    from the last financial row (ordered by announcement date) that has
    ``duration == 3`` and an EPS whose magnitude exceeds 1e-5. That count is
    multiplied by the last available close price.

    Args:
        code: Stock code understood by ``stock.kabutan``.
        current_date: Passed as ``end_date`` when reading price data. ``None``
            (the default) means "today" and is resolved on every call, so a
            long-lived kernel never keeps using the day the function was
            defined.

    Returns:
        The estimated capitalization, or ``-1`` when there is no usable
        financial row or no price row.
    """
    if current_date is None:
        current_date = datetime.date.today()

    # NOTE(review): "annoounce_date" looks like a typo but is kept verbatim; it
    # presumably matches the column name in the financial CSV - confirm.
    fdf = stock.kabutan.read_financial_csv(code).filter(
        (pl.col("duration") == 3) & (pl.col("eps").abs() > 1e-5)
    ).sort(pl.col("annoounce_date"))
    if len(fdf) == 0:
        # Nothing to estimate from; skip reading the price data altogether.
        return -1

    df = stock.kabutan.read_data_csv(code, end_date=current_date).sort(pl.col("date"))
    if len(df) == 0:
        # No price row up to current_date, so there is no close to multiply by.
        return -1

    # presumably net_income is expressed in millions, hence the 1000000 factor - confirm
    num_stock = fdf["net_income"][-1] * 1000000 / fdf["eps"][-1]
    return num_stock * df["close"][-1]

In [None]:
# Step 1: build the pool of candidate training rows.
target_data_dict = {}
stacked = []
codes = stock.kabutan.get_code_list()
max_hold_days = 10

# Highest close over the next `max_hold_days` rows divided by the next row's
# open. The expression is the same for every stock, so it is built once here.
growing_rate_expr = (
    pl.col("close").rolling_max(window_size=max_hold_days).shift(-max_hold_days)
    / pl.col("open").shift(-1)
).alias("growing_rate")

for code in codes:
    capt = calc_estimated_capitalization(code)
    if capt > 100000000000:
        # Skip stocks whose estimated capitalization exceeds 100 billion yen.
        continue

    df = stock.trend_template.calc_for_watch_list(code).with_columns(growing_rate_expr)
    # Keep only the rows flagged by the `watch_list` column, tagged with their code.
    watched = df.filter(pl.col("watch_list"))
    stacked.append(watched.with_columns(pl.lit(code).alias("code")))

stacked_df = pl.concat(stacked)

In [None]:
# Choose the date that separates the training rows from the validation rows.
dates = stacked_df["date"].sort()
# Rows dated on or before this are train; anything later is validation.
split_index = int(len(dates) * 0.8)
split_date = dates[split_index]

is_train = pl.col("date") <= split_date
train_df = stacked_df.filter(is_train)
valid_df = stacked_df.filter(pl.col("date") > split_date)

split_summary = "Split date = {}, num train = {}, num_valid = {}"
print(split_summary.format(split_date, len(train_df), len(valid_df)))

In [None]:
# Record which (date, code) pairs ended up in the validation split.
output_csv = stock.TRAIN_DATA_DIR / "valid.csv"
valid_pairs = valid_df.select(pl.col("date"), pl.col("code"))
valid_pairs.write_csv(output_csv)

In [None]:
def write_image(code, date, df, before_days):
    """Render a candlestick + volume chart for one sample and save it as a PNG.

    The file is written to
    ``./tmp/<bucket>/code<code>_date<YYYYMMDD>_rate<NNN>.png`` where ``<NNN>``
    is ``int(growing_rate * 100)`` of row ``before_days - 1`` and ``<bucket>``
    is ``pos`` when that rate is above 1.2, ``mid`` when above 1.1 and ``neg``
    otherwise. Prices are drawn relative to the close of that same row, and
    row ``before_days`` is marked as the buy point.

    Nothing is drawn when the target file already exists, or when the rate is
    missing (the sample can then be neither classified nor named).

    Args:
        code: Stock code; only used in the file name.
        date: Date of the sample (needs ``strftime``); only used in the file name.
        df: Table with ``date``, ``open``, ``high``, ``low``, ``close``,
            ``volume`` and ``growing_rate`` columns.
        before_days: Number of leading rows that precede the buy point.
    """
    rate = df["growing_rate"][before_days - 1]
    if rate is None:
        return

    if rate > 1.2:
        output_dirname = "pos"
    elif rate > 1.1:
        output_dirname = "mid"
    else:
        output_dirname = "neg"
    output_path = Path("./tmp/{}/code{}_date{}_rate{:03d}.png".format(
        output_dirname, code, date.strftime("%Y%m%d"), int(rate * 100)
    ))
    # parents=True so that a missing ./tmp is created as well.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if output_path.exists():
        return

    # All prices are shown relative to the close of the last "before" row.
    base = df["close"][before_days - 1]
    fig = make_subplots(
        rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.0, row_heights=[0.7, 0.3]
    )
    fig.add_trace(
        go.Candlestick(
            x=df["date"],
            open=df["open"] / base,
            high=df["high"] / base,
            low=df["low"] / base,
            close=df["close"] / base,
            name="candle",
        ),
        row=1,
        col=1,
    )
    # Buy point: the open of the first row after the "before" window.
    fig.add_trace(
        go.Scatter(
            x=df[before_days]["date"],
            y=df[before_days]["open"] / base,
            mode="markers",
            name="buy",
            marker=dict(size=10, color="blue"),
        ),
        row=1,
        col=1,
    )
    # Trading volume
    fig.add_trace(go.Bar(x=df["date"], y=df["volume"], name="volume"), row=2, col=1)
    # Layout settings
    fig.update_layout(
        xaxis_rangeslider_visible=False,
        xaxis2_rangeslider_visible=False,
        margin=go.layout.Margin(l=5, r=5, t=5, b=5, autoexpand=True),
    )
    fig.update_layout(hovermode="x unified")
    fig.update_layout(yaxis_range=[0.8, 1.6])
    fig.update_traces(xaxis="x2")
    fig.update_xaxes(rangebreaks=[dict(bounds=["sat", "mon"])])  # exclude weekends

    fig.write_image(output_path)
    

In [None]:
# Collect, for every (code, date) row, a window of rows around that date.
def get_data_list(df, before_days = 30, after_days = 15):
    """Build per-sample windows and render an image for each of them.

    For every row of ``df`` (columns ``code`` and ``date``) the watch-list
    table of that stock is loaded for a calendar range of twice the requested
    row counts around the date. The last ``before_days`` rows up to and
    including the date are stacked on top of the first ``after_days`` rows
    after it. Samples without enough rows on either side are skipped.

    Relies on the notebook-level ``max_hold_days`` and on ``write_image``.

    Returns:
        A list of ``[code, date, window_df]`` entries, one per kept sample.
    """
    outputs = []
    for row_idx in tqdm(range(len(df))):
        code = df["code"][row_idx]
        date = df["date"][row_idx]

        range_start = date - datetime.timedelta(before_days * 2)
        range_end = date + datetime.timedelta(after_days * 2)
        growing_rate = (
            pl.col("close").rolling_max(window_size=max_hold_days).shift(-max_hold_days)
            / pl.col("open").shift(-1)
        ).alias("growing_rate")
        target_df = stock.trend_template.calc_for_watch_list(
            code, start_date=range_start, end_date=range_end
        ).with_columns(growing_rate)

        before_df = target_df.filter(pl.col("date") <= date)
        after_df = target_df.filter(pl.col("date") > date)
        # Only keep samples with a full window on both sides of the date.
        if len(before_df) >= before_days and len(after_df) >= after_days:
            out_df = before_df[-before_days:].vstack(after_df[:after_days])
            outputs.append([code, date, out_df])
            write_image(code, date, out_df, before_days)
    return outputs

In [None]:
# Window sizes (in rows) used by the sample-extraction cells.
before_days, after_days = 30, 15
outputs = get_data_list(train_df, before_days=before_days, after_days=after_days)

In [None]:
# Run the extraction again for only the rows whose growing_rate is above 1.4;
# note that this replaces the previous contents of `outputs`.
strong_growth = pl.col("growing_rate") > 1.4
good_df = train_df.filter(strong_growth)
outputs = get_data_list(good_df, before_days=before_days, after_days=after_days)

In [None]:
# Gather every PNG found directly inside the "neg" training-data folder.
neg_data_dir = stock.TRAIN_DATA_DIR / "neg"
file_list = list(neg_data_dir.glob("*.png"))
    

In [None]:
len(file_list)
# Keep only the images whose file name encodes a rate below 100
# (the integer that follows "rate" in the stem).
target = [file for file in file_list if int(file.stem.split("rate")[1]) < 100]

In [None]:
import random
import shutil

# Copy a random subset of the filtered images into "neg_selected".
# random.sample raises ValueError when asked for more items than the
# population holds, so the request is capped at the number of available files.
num_selected = min(1000, len(target))
selected = random.sample(target, num_selected)
output_dir = stock.TRAIN_DATA_DIR / "neg_selected"
output_dir.mkdir(parents=True, exist_ok=True)

for d in selected:
    shutil.copy(d, output_dir)

In [None]:
import xml.etree.ElementTree as ET



In [None]:
# Sort the annotated image names into "proper base" (poss) and
# "invalid base" (negs) according to the tag attached to each image.
poss = []
negs = []

xml_paths = [
    stock.TRAIN_DATA_DIR / "annotations_pos.xml",
    stock.TRAIN_DATA_DIR / "annotations_neg.xml"
]

# Tag label -> list that collects the matching image names.
names_by_label = {"invalid base": negs, "proper base": poss}

for xml_path in xml_paths:
    tree = ET.parse(xml_path)
    root = tree.getroot()

    # Only direct <image> children of the root and their direct <tag> children count.
    for image in root.findall("image"):
        for child in image.findall("tag"):
            bucket = names_by_label.get(child.attrib["label"])
            if bucket is not None:
                bucket.append(image.attrib["name"])


In [None]:
# Display how many image names were collected as (proper base, invalid base).
len(poss), len(negs)

In [None]:
def get_data(data_list, before_days=30):
    """Build one flattened, normalized OHLCV vector per annotated image name.

    Every name must contain ``code<digits>_date<YYYYMMDD>_rate<digits>``. The
    price table of that code is read with the encoded date as ``end_date`` and
    its last ``before_days`` rows are taken. Open/high/low/close are divided
    by the close of the final row and volume by the volume of the final row,
    then the ``(before_days, 5)`` block is flattened.

    Names that do not match the pattern, or whose table has fewer than
    ``before_days`` rows, are reported with ``print`` and skipped, so the
    result can be shorter than ``data_list``.

    Args:
        data_list: Iterable of image names.
        before_days: Number of trailing rows per sample (positive; default 30).

    Returns:
        A list of 1-D numpy arrays, one per name that could be loaded.
    """
    regex = re.compile(r"code(\d+)_date(\d+)_rate\d+")
    datas = []
    for name in data_list:
        res = regex.search(name)
        if res is None:
            print("Regex not found : {}".format(name))
            continue
        code = res.group(1)
        date = datetime.datetime.strptime(res.group(2), "%Y%m%d").date()

        df = stock.kabutan.read_data_csv(code, end_date=date)
        if len(df) < before_days:
            print("Insufficient data length : {}, {}".format(name, len(df)))
            continue

        values = df[-before_days:].select(
            pl.col("open"), pl.col("high"), pl.col("low"), pl.col("close"), pl.col("volume")
        ).to_numpy()
        # Columns 0-3 are prices (scaled by the last close), column 4 is volume
        # (scaled by the last volume).
        # NOTE(review): a zero volume on the last row yields inf/nan - confirm
        # whether such samples should be dropped.
        values[:, :4] /= values[-1, 3]
        values[:, 4] /= values[-1, 4]

        datas.append(values.reshape(-1))
    return datas

In [None]:
# Turn the per-sample vectors of each class into numpy arrays.
pos_data = np.asarray(get_data(poss))
neg_data = np.asarray(get_data(negs))

In [None]:
# Labels are the integer rate encoded in each image name ("...rate125.png" -> 125),
# in the same pos-then-neg order as the inputs.
rate_pattern = re.compile(r"rate(\d+)\.png")
train_true = np.array([int(rate_pattern.search(d).group(1)) for d in poss + negs])
train_input = np.concatenate([pos_data, neg_data])

# get_data() skips names it cannot load, while the labels above cover every
# name; refuse to save inputs and labels that no longer line up.
if len(train_input) != len(train_true):
    raise ValueError(
        "Input/label count mismatch: {} inputs vs {} labels".format(
            len(train_input), len(train_true)
        )
    )

output_file_path.parent.mkdir(parents=True, exist_ok=True)
# NOTE(review): the same input/label pair is stored twice (arr_0..arr_3);
# presumably the second pair was meant to be validation data - confirm.
np.savez(output_file_path, train_input, train_true, train_input, train_true)