In [None]:
from pathlib import Path
import re
import datetime
import shutil
import xml.etree.ElementTree as ET

from tqdm import tqdm
import polars as pl
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import stock

In [None]:
# Among the negative samples, pick out the ones followed by a large price drop.
# First, list every negative chart image in a deterministic (sorted) order.
negative_data_dir = stock.DATA_DIR.joinpath("train", "neg")
negative_data_list = sorted(negative_data_dir.glob("*.png"))

In [None]:
# Keep only the negative samples whose low fell below DROP_THRESHOLD times the
# first open in the LOOKAHEAD_DAYS calendar days after the sample date.
# Raw string: "\d" in a normal string literal is an invalid escape sequence.
regex = re.compile(r"code(\d+)_date(\d+)_rate\d+\.png")
DROP_THRESHOLD = 0.8   # "large drop" = some low < 80% of the starting open
LOOKAHEAD_DAYS = 28
target = []

for neg_path in tqdm(negative_data_list):
    res = regex.search(neg_path.name)
    if res is None:
        # Not named code<code>_date<YYYYMMDD>_rate<n>.png; nothing to parse.
        continue
    code = res.group(1)
    date = datetime.datetime.strptime(res.group(2), "%Y%m%d").date()
    price_df = stock.kabutan.read_data_csv(
        code,
        start_date=date + datetime.timedelta(days=1),
        end_date=date + datetime.timedelta(days=LOOKAHEAD_DAYS),
    )
    if len(price_df) == 0:
        # No rows in the look-ahead window, so there is no starting open to compare against.
        continue
    start = price_df["open"][0]
    minimum = price_df["low"].min()
    if minimum < start * DROP_THRESHOLD:
        target.append(neg_path)

In [None]:
# Destination for the "super negative" samples selected in `target`.
dst_dir = stock.DATA_DIR / "train/super_neg"
dst_dir.mkdir(parents=True, exist_ok=True)

# Copying is opt-in and disabled by default. Set to True to (re)populate dst_dir.
# Files already present there are skipped, so the cell is safe to re-run.
COPY_SUPER_NEG = False
if COPY_SUPER_NEG:
    for src in target:
        if not (dst_dir / src.name).exists():
            shutil.copy(src, dst_dir)

In [None]:
def write_image(code, date, before_days, output_dir=None, width=256, height=256):
    """Render a normalized candlestick + volume chart for `code` ending at `date`.

    The last `before_days` rows of price data up to `date` are plotted. Prices
    are divided by the last close and volume by the last volume. The price axis
    is fixed to [0.7, 1.3] and the volume axis to [0, 3].

    Args:
        code: Stock code. It is formatted into the file name and passed to
            `stock.kabutan.read_data_csv`.
        date: Last date of the window (must support `strftime`, e.g. datetime.date).
        before_days: Number of trailing rows to plot.
        output_dir: Directory to write `code{code}_date{YYYYMMDD}.jpg` into. It
            is created if missing. When None, nothing is written.
        width, height: Image size in pixels.

    Returns:
        When `output_dir` is given, the output path. An existing file is
        returned as-is without re-rendering. Otherwise, the plotly Figure.

    Raises:
        ValueError: if no price data is available up to `date`.
    """
    output_path = None
    if output_dir is not None:
        output_path = Path(output_dir) / "code{}_date{}.jpg".format(code, date.strftime("%Y%m%d"))
        # Already-rendered images are reused, which keeps re-runs cheap.
        if output_path.exists():
            return output_path

    df = stock.kabutan.read_data_csv(code, end_date=date)[-before_days:]
    if len(df) == 0:
        raise ValueError("no price data for code {} up to {}".format(code, date))
    # Every price is divided by the last close.
    base = df["close"][-1]
    # Single panel with a secondary y-axis for volume.
    # NOTE(review): the negative right padding ("r") presumably trims space
    # reserved for the secondary axis. Confirm.
    fig = make_subplots(specs=[[{"secondary_y": True, "r": -0.06}]])

    # Plot against the row index rather than dates, so bars are evenly spaced.
    x = list(range(len(df)))
    # Trading volume relative to the last row's volume, on the secondary y-axis.
    fig.add_trace(
        go.Scatter(
            x=x,
            y=df["volume"] / df["volume"][-1],
            name="volume",
            line_color="rgba(0, 0, 255, 0.5)",
        ),
        secondary_y=True,
    )
    # Candlesticks relative to the last close, on the primary y-axis.
    fig.add_trace(
        go.Candlestick(
            x=x,
            open=df["open"] / base,
            high=df["high"] / base,
            low=df["low"] / base,
            close=df["close"] / base,
            name="candle",
        ),
        secondary_y=False,
    )
    # Chart settings: hide the x/price axes, legend, range slider and margins.
    # Fix the price axis range.
    fig.update_layout(
        xaxis_rangeslider_visible=False,
        showlegend=False,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        yaxis_range=[0.7, 1.3],
        width=width,
        height=height,
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
    )
    # Volume axis: no tick labels, fixed range.
    fig.layout.yaxis2.update(showticklabels=False, range=[0, 3])

    if output_path is None:
        # Nothing to write: return the figure so the caller can display it.
        return fig
    # On a fresh run the output directory may not exist yet.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.write_image(output_path, width=width, height=height, scale=1.0)
    return output_path

In [None]:
# Prepare training data (negative): render a chart of the 30 rows up to each
# "super negative" sample date.
regex = re.compile(r"code(\d+)_date(\d+)_rate\d+\.png")
train_neg_list = []
output_dir = stock.DATA_DIR / "train/20240727/neg"
output_dir.mkdir(parents=True, exist_ok=True)
# sorted() makes the sample order independent of the filesystem's listing order.
for p in tqdm(sorted(dst_dir.glob("*.png"))):
    res = regex.search(p.name)
    if res is None:
        # Not named code<code>_date<YYYYMMDD>_rate<n>.png; skip it.
        continue
    code, date = res.group(1), datetime.datetime.strptime(res.group(2), "%Y%m%d").date()
    train_neg_list.append(write_image(code, date, 30, output_dir, width=196, height=196))

In [None]:
# Prepare training data (positive): from the annotation XML, collect the name
# of every <image> that has a <tag> child labelled "proper base".
# An image with several such tags contributes its name once per tag.
xml_path = stock.TRAIN_DATA_DIR / "annotations_pos.xml"
tree = ET.parse(xml_path)
root = tree.getroot()

# findall() with a bare tag name matches direct children only.
images = root.findall("image")
target_positive = [
    image.attrib["name"]
    for image in images
    for tag_elem in image.findall("tag")
    if tag_elem.attrib["label"] == "proper base"
]

In [None]:
# Render a chart of the 30 rows up to each positive ("proper base") sample date.
regex = re.compile(r"code(\d+)_date(\d+)_rate\d+\.png")
train_pos_list = []
output_dir = stock.DATA_DIR / "train/20240727/pos"
output_dir.mkdir(parents=True, exist_ok=True)
# An image with several "proper base" tags appears more than once in
# target_positive. Keep the first occurrence so it is not a duplicate sample.
for fname in tqdm(list(dict.fromkeys(target_positive))):
    res = regex.search(fname)
    if res is None:
        # Not named code<code>_date<YYYYMMDD>_rate<n>.png; skip it.
        continue
    code, date = res.group(1), datetime.datetime.strptime(res.group(2), "%Y%m%d").date()
    train_pos_list.append(write_image(code, date, 30, output_dir, width=196, height=196))

In [None]:
# Prepare validation data: render one chart per (code, date) row of valid.csv.
csv_path = stock.TRAIN_DATA_DIR / "valid.csv"
valid_df = pl.read_csv(csv_path)


def _to_date(value):
    """Coerce a CSV date cell (date, YYYYMMDD int/str, or ISO string) to datetime.date."""
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    text = str(value)
    if text.isdigit():
        return datetime.datetime.strptime(text, "%Y%m%d").date()
    return datetime.date.fromisoformat(text)


# NOTE(review): the output directory and image size are assumed to mirror the
# train data layout. Confirm before using these images for evaluation.
valid_output_dir = stock.DATA_DIR / "valid/20240727"
valid_output_dir.mkdir(parents=True, exist_ok=True)
valid_list = []
# Iterate over valid_df itself, not a leftover `df` from an earlier cell.
for code, date in tqdm(valid_df.select(["code", "date"]).iter_rows(), total=len(valid_df)):
    valid_list.append(
        write_image(str(code), _to_date(date), 30, valid_output_dir, width=196, height=196)
    )

In [None]:
# Inspect the validation table read from valid.csv.
valid_df

In [None]:
# Number of positive training entries collected (one per write_image call).
len(train_pos_list)

In [None]:
# Sanity check on the negative samples: each entry should point to an image
# file that write_image produced.
# TODO(review): this cell held an unfinished `for data in train_neg_list:` loop
# with no body (a SyntaxError). Replace this check if per-item processing was intended.
missing_neg = [p for p in train_neg_list if p is None or not Path(p).exists()]
len(train_neg_list), len(missing_neg)

In [None]:
schema_dir = stock.PROJECT_ROOT / "data/train/20240727/schema"

# Build the dataset from the positive and negative training image paths.
# TODO(review): this cell was unfinished. It was missing a comma after `train=...`
# and had no value for `valid=`, both SyntaxErrors. `valid=None` is a placeholder
# because the Dataset signature is not visible here. Confirm what `valid` expects
# (presumably the validation image paths) and whether `schema_dir` should be passed.
dataset = stock.dl.dataloader.ImageDataloader.Dataset(
    train=[train_pos_list, train_neg_list],
    valid=None,
)