In [None]:
from pathlib import Path
import re
import datetime
import shutil
import xml.etree.ElementTree as ET
import random

import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import polars as pl
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import stock

In [None]:
# Negative samples: the candidates from which the steepest decliners are picked next.
negative_data_dir = stock.DATA_DIR / "train/neg"
negative_data_list = list(negative_data_dir.glob("*.png"))
negative_data_list.sort()

In [None]:
# Keep only negatives whose low fell more than 20% below the first open in the
# window starting the day after the chart date.
regex = re.compile(r"code(\d+)_date(\d+)_rate\d+\.png")  # raw string: "\d" is not a valid str escape
SUPER_NEG_DROP_RATIO = 0.8
SUPER_NEG_HORIZON_DAYS = 28
target = []

for neg_path in tqdm(negative_data_list):
    res = regex.search(neg_path.name)
    if res is None:
        raise ValueError("unexpected file name in negative data: {}".format(neg_path.name))
    code, date = res.group(1), datetime.datetime.strptime(res.group(2), "%Y%m%d").date()
    df = stock.kabutan.read_data_csv(
        code,
        start_date=date + datetime.timedelta(days=1),
        end_date=date + datetime.timedelta(days=SUPER_NEG_HORIZON_DAYS),
    )
    if len(df) == 0:
        continue  # no price rows in the window (e.g. date too recent): nothing to judge
    start = df["open"][0]
    minimum = df["low"].min()
    if minimum < start * SUPER_NEG_DROP_RATIO:
        target.append(neg_path)

In [None]:
# Destination directory for the "super negative" samples selected above.
dst_dir = stock.DATA_DIR / "train/super_neg"
dst_dir.mkdir(parents=True, exist_ok=True)

# Copying is off by default: dst_dir is assumed to be populated already (the
# training-negative cell reads it). Set to True to (re)copy the selected files.
COPY_SUPER_NEG = False
if COPY_SUPER_NEG:
    for src in target:
        shutil.copy(src, dst_dir)

In [None]:
def write_image(code, date, before_days, output_dir=None, width=256, height=256):
    """Render a normalised candlestick + volume chart ending at ``date``.

    The last ``before_days`` rows up to ``date`` are plotted with prices divided
    by the final close and volume divided by the final volume; axes, legend and
    margins are removed so only the data is drawn.

    Args:
        code: stock code passed to ``stock.kabutan.read_data_csv``.
        date: last date of the window (``datetime.date``).
        before_days: number of most recent rows to plot.
        output_dir: directory for ``code{code}_date{YYYYMMDD}.jpg`` (created if
            missing). When None, nothing is written.
        width, height: image size in pixels.

    Returns:
        The output path when ``output_dir`` is given (an existing file is
        returned as-is without re-rendering); otherwise the plotly ``Figure``.

    Raises:
        ValueError: if no price rows are available up to ``date``.
    """
    output_path = None
    if output_dir is not None:
        output_path = Path(output_dir) / "code{}_date{}.jpg".format(code, date.strftime("%Y%m%d"))
        if output_path.exists():
            return output_path  # already rendered: keeps re-runs cheap

    df = stock.kabutan.read_data_csv(code, end_date=date)[-before_days:]
    if len(df) == 0:
        raise ValueError("no price data for code {} up to {}".format(code, date))
    base = df["close"][-1]
    # Single cell with a secondary y axis; "r" is the cell's right padding.
    fig = make_subplots(specs=[[{"secondary_y": True, "r": -0.06}]])

    x = list(range(len(df)))
    # Trading volume on the secondary axis, relative to the last row's volume.
    # NOTE(review): a zero last-row volume would turn this into inf.
    fig.add_trace(
        go.Scatter(
            x=x,
            y=df["volume"] / df["volume"][-1],
            name="volume",
            line_color="rgba(0, 0, 255, 0.5)",
        ),
        secondary_y=True,
    )
    # Candles on the primary axis, relative to the last close.
    fig.add_trace(
        go.Candlestick(
            x=x,
            open=df["open"] / base,
            high=df["high"] / base,
            low=df["low"] / base,
            close=df["close"] / base,
            name="candle",
        ),
        secondary_y=False,
    )
    # Hide everything but the data; the price axis is fixed to 0.7-1.3x the last close.
    fig.update_layout(
        xaxis_rangeslider_visible=False,
        showlegend=False,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        yaxis_range=[0.7, 1.3],
        width=width,
        height=height,
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
    )
    # Volume axis spans 0-3x the last row's volume.
    fig.layout.yaxis2.update(showticklabels=False, range=[0, 3])

    if output_path is None:
        return fig  # preview mode: hand the figure back instead of discarding it
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.write_image(output_path, width=width, height=height, scale=1.0)
    return output_path

In [None]:
# Prepare train (negative) data: (code, date) pairs parsed from the super_neg chart file names.
regex = re.compile(r"code(\d+)_date(\d+)_rate\d+\.png")  # raw string: "\d" is not a valid str escape
train_neg_list = []
output_dir = stock.DATA_DIR / "train/20240727/neg"
# sorted(): glob order is filesystem-dependent; sorting keeps this list (and the
# dataset JSON built from it) reproducible across runs and machines.
for p in tqdm(sorted(dst_dir.glob("*.png"))):
    res = regex.search(p.name)
    if res is None:
        raise ValueError("unexpected file name in {}: {}".format(dst_dir, p.name))
    code, date = res.group(1), datetime.datetime.strptime(res.group(2), "%Y%m%d").date()
    train_neg_list.append([code, date])
    # To collect rendered chart paths instead of (code, date) pairs, use:
    # train_neg_list.append(write_image(code, date, 30, output_dir, width=196, height=196))

In [None]:
# Prepare train (positive) data: names of images annotated with the "proper base" tag.
xml_path = stock.TRAIN_DATA_DIR / "annotations_pos.xml"
tree = ET.parse(xml_path)
root = tree.getroot()

POSITIVE_LABEL = "proper base"
target_positive = []
images = root.findall("image")  # direct <image> children of the root
for image in images:
    # any(): an image carrying the tag more than once is still listed only once.
    if any(tag.attrib["label"] == POSITIVE_LABEL for tag in image.findall("tag")):
        target_positive.append(image.attrib["name"])

In [None]:
# (code, date) pairs parsed from the positive image names.
regex = re.compile(r"code(\d+)_date(\d+)_rate\d+\.png")  # raw string: "\d" is not a valid str escape
train_pos_list = []
output_dir = stock.DATA_DIR / "train/20240727/pos"
for fname in tqdm(target_positive):
    res = regex.search(fname)
    if res is None:
        raise ValueError("unexpected image name in annotations: {}".format(fname))
    code, date = res.group(1), datetime.datetime.strptime(res.group(2), "%Y%m%d").date()
    train_pos_list.append([code, date])
    # To collect rendered chart paths instead of (code, date) pairs, use:
    # train_pos_list.append(write_image(code, date, 30, output_dir, width=196, height=196))

In [None]:
def create_schema(code, date, image_dir, schema_dir, label):
    """Write a ``DataSchema`` JSON file pointing at an already-rendered chart image.

    Args:
        code: stock code used in the file stem.
        date: ``datetime.date`` used in the file stem (``YYYYMMDD``).
        image_dir: directory containing ``code{code}_date{YYYYMMDD}.jpg`` (str or Path).
        schema_dir: directory to write the JSON into (str or Path); created if missing.
        label: class label stored in the schema.

    Returns:
        Path of the written ``.json`` file.

    Raises:
        FileNotFoundError: if the image has not been generated yet.
    """
    image_dir, schema_dir = Path(image_dir), Path(schema_dir)
    stem = "code{}_date{}".format(code, date.strftime("%Y%m%d"))
    image_path = image_dir / (stem + ".jpg")
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if not image_path.exists():
        raise FileNotFoundError("chart image not found: {}".format(image_path))
    schema = stock.dl.dataloader.ImageDataloader.DataSchema(image_path=image_path, label=label)
    schema_dir.mkdir(parents=True, exist_ok=True)
    schema_path = schema_dir / (stem + ".json")
    schema_path.write_text(schema.model_dump_json(indent=4), encoding="utf-8")
    return schema_path

In [None]:
image_dir = stock.PROJECT_ROOT / "data/train/20240727/image"
schema_dir = stock.PROJECT_ROOT / "data/train/20240727/schema"
# Negatives get label 0, positives label 1 (negatives written first).
for label, samples in ((0, train_neg_list), (1, train_pos_list)):
    for code, date in samples:
        create_schema(code, date, image_dir, schema_dir, label)

In [None]:
# Prepare validation data. Each row is labelled from its forward price move:
# growing_rate at row 0 = max(close over rows 1..max_hold_days) / open of row 1.
csv_path = stock.TRAIN_DATA_DIR / "valid.csv"
valid_df = pl.read_csv(csv_path)
max_hold_days = 10
POSITIVE_GROWTH = 1.4  # above this -> label 1
NEGATIVE_GROWTH = 1.1  # below this -> label 0; anything in between is skipped as ambiguous
valid_schema_paths = []

for row in valid_df.iter_rows(named=True):
    code = row["code"]
    date = datetime.datetime.strptime(row["date"], "%Y-%m-%d").date()
    df = stock.kabutan.read_data_csv(code, start_date=date, end_date=date + datetime.timedelta(days=28))
    if df.height == 0:
        continue  # no price rows in the window
    df = df.with_columns(
        (pl.col("close").rolling_max(window_size=max_hold_days).shift(-max_hold_days) / pl.col("open").shift(-1)).alias("growing_rate")
    )
    growing_rate = df["growing_rate"][0]
    if growing_rate is None:
        continue  # fewer than max_hold_days + 1 rows: outcome unknown
    if growing_rate > POSITIVE_GROWTH:
        label = 1
    elif growing_rate < NEGATIVE_GROWTH:
        label = 0
    else:
        continue
    # Render only samples that receive a label (create_schema requires the image to exist).
    write_image(code, date, 30, image_dir, width=196, height=196)
    valid_schema_paths.append(create_schema(code, date, image_dir, schema_dir, label))

In [None]:
tuple(map(len, (train_pos_list, train_neg_list, valid_schema_paths)))

In [None]:
# Schema JSON paths as written by create_schema, in the same order as the sample lists.
def _schema_json(sample):
    code, date = sample
    return schema_dir / "code{}_date{}.json".format(code, date.strftime("%Y%m%d"))


train_pos_schemas = list(map(_schema_json, train_pos_list))
train_neg_schemas = list(map(_schema_json, train_neg_list))

In [None]:
# Training groups: group 0 = positives, group 1 = negatives. Group order presumably
# pairs with `ratio_per_group` in the loader params (get_next zips them) — keep in sync.
dataset = stock.dl.dataloader.ImageDataloader.Dataset(
    train=[train_pos_schemas, train_neg_schemas],
    valid=[valid_schema_paths],
)

In [None]:
# Persist the dataset definition; the dataloader is configured from this file.
dataset_path = stock.TRAIN_DATA_DIR / "20240727/dataset.json"
dataset_path.parent.mkdir(parents=True, exist_ok=True)
dataset_path.write_text(dataset.model_dump_json(indent=4), encoding="utf-8")

In [None]:
# Training loader over the saved dataset JSON: batches of 32, split 1:1 across the two groups.
dataset_path = stock.TRAIN_DATA_DIR / "20240727/dataset.json"
loader_params = dict(
    batch_size=32,
    dataset_json_path=dataset_path,
    ratio_per_group=[1, 1],
    num_class=2,
)
params = stock.dl.dataloader.ImageDataloader.Params(**loader_params)

dataloader = stock.dl.dataloader.ImageDataloader(params, is_train=True)

In [None]:
def _allocate_batch(batch_size, ratio_per_group):
    """Split ``batch_size`` slots across groups in proportion to ``ratio_per_group``.

    Each group first receives the integer part of its share; the leftover slots go
    to *distinct* groups, drawn at random with probability proportional to the
    fractional parts (a randomised largest-remainder scheme).
    """
    total = sum(ratio_per_group)
    exact = [batch_size * (rate / total) for rate in ratio_per_group]
    counts = [int(n) for n in exact]
    residual = batch_size - sum(counts)
    if residual > 0:
        remainders = [f - i for i, f in zip(counts, exact)]
        remainder_sum = sum(remainders)
        probs = [r / remainder_sum for r in remainders]
        # replace=False: each fractional share is < 1, so a group may gain at most one extra slot.
        for i in np.random.choice(len(counts), size=residual, replace=False, p=probs):
            counts[i] += 1
    return counts


def get_next(self):
    """Draw one training batch from ``self`` (an ``ImageDataloader``), called as ``get_next(dataloader)``.

    Presumably mirrors the loader's own batching logic for inspection in this notebook.
    Reads ``self.params.batch_size``, ``self.params.ratio_per_group``,
    ``self.params.num_class`` and ``self.data_schema`` (one list of schemas per group).

    Returns:
        dict with ``"input"``: float32 stack of images scaled to [0, 1], in the
        channel order ``cv2.imread`` produces (BGR); and ``"y_true"``: float32
        one-hot labels of shape (batch, num_class).

    Raises:
        FileNotFoundError: if an image file cannot be read.
        ValueError: if a group holds fewer schemas than requested (from ``random.sample``).
    """
    sample_per_group = _allocate_batch(self.params.batch_size, self.params.ratio_per_group)
    sample = [
        s
        for num_sample, group in zip(sample_per_group, self.data_schema)
        for s in random.sample(group, num_sample)
    ]

    images = []
    for s in sample:
        # str(): older OpenCV bindings reject pathlib.Path arguments.
        img = cv2.imread(str(s.image_path))
        if img is None:  # imread reports failure by returning None, not by raising
            raise FileNotFoundError("cannot read image: {}".format(s.image_path))
        images.append(img)
    image = (np.stack(images) / 255.0).astype(np.float32)

    # Row k of the identity matrix is the one-hot vector for class k.
    one_hot = np.identity(self.params.num_class)
    label = np.stack([one_hot[s.label] for s in sample]).astype(np.float32)
    return {"input": image, "y_true": label}


In [None]:
# Sanity check: taking 16 schemas from each of the two groups should give a batch of 32.
demo_batch = [
    s
    for num_sample, group in zip([16, 16], dataloader.data_schema)
    for s in random.sample(group, num_sample)
]
len(demo_batch)

In [None]:
# Draw one batch with the notebook copy of the sampler for inspection below.
res = get_next(self=dataloader)

In [None]:
np.shape(res["input"])  # expected: (batch_size, height, width, channels)

In [None]:
# cv2.imread yields BGR; reverse the channel axis so matplotlib shows true colours.
fig, ax = plt.subplots()
ax.imshow(res["input"][0][..., ::-1])
ax.set_title("batch sample 0 (class {})".format(int(res["y_true"][0].argmax())))
ax.axis("off")
plt.show()