In [None]:
from typing import Any
from pathlib import Path
import datetime

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import gymnasium as gym
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.algorithm import Algorithm
import torch

import stock

In [None]:
def create_feat_df(code: str, date: datetime.date, window_size: int = 30) -> pl.DataFrame | None:
    """Build the normalised one-minute feature table for one stock on one day.

    Prices (open/high/low/close) are divided by the previous trading day's
    close. Volume is divided by the mean daily volume of the preceding
    7 calendar days spread evenly over the session's one-minute slots.
    Session minutes without a bar are padded (close carried forward, volume 0)
    and ``window_size - 1`` neutral rows (price 1.0, volume 0) are prepended so
    that a full window exists from the first real bar.

    Args:
        code: Stock code used to locate both the daily and the minute data.
        date: Trading day to load.
        window_size: Observation window length; fixes the number of leading
            margin rows.

    Returns:
        A DataFrame with the columns ``open, high, low, close, volume``, or
        ``None`` when the minute file is missing, no bar falls inside the two
        sessions, there is no earlier daily bar, or the recent daily volume is
        unavailable or zero.
    """
    # Load the daily bars.
    daily_df = stock.kabutan.read_data_csv(code)
    # Load the minute bars.
    date_str = date.strftime("%Y%m%d")
    minutes_csv_path = stock.PROJECT_ROOT / f"data/minutes/{date_str}" / f"{code}_{date_str}.arrow"
    if not minutes_csv_path.exists():
        return None

    # Session boundaries, both ends inclusive: 00:00-02:30 and 03:30-06:00.
    # NOTE(review): presumably UTC for the 09:00-11:30 / 12:30-15:00 JST sessions - confirm.
    day_start = datetime.datetime(date.year, date.month, date.day, 0, 0)
    am_close = day_start + datetime.timedelta(hours=2, minutes=30)
    pm_open = day_start + datetime.timedelta(hours=3, minutes=30)
    pm_close = day_start + datetime.timedelta(hours=6)
    session_day = day_start.date()

    # Every one-minute slot of the two sessions (151 + 151 = 302 slots).
    all_minutes = (day_start + datetime.timedelta(minutes=m) for m in range(6 * 60 + 1))
    session_minutes = [t for t in all_minutes if t <= am_close or t >= pm_open]

    df = pl.read_ipc(minutes_csv_path)
    df = df.filter(
        pl.col("datetime").is_between(day_start, am_close, closed="both")
        | pl.col("datetime").is_between(pm_open, pm_close, closed="both")
    ).sort(pl.col("datetime"))
    if len(df) == 0:
        # Nothing to normalise or pad around; indexing the first row would raise.
        stock.logger.debug("No minute data in trading sessions")
        return None

    # Fetch the data before this day and normalise against it.
    prev_df = daily_df.filter(pl.col("date") < session_day).sort(pl.col("date"))
    if len(prev_df) == 0:
        stock.logger.debug("No previous data")
        return None

    prev_val = prev_df["close"][-1]
    recent_volume = prev_df.filter(
        pl.col("date") > session_day - datetime.timedelta(days=7)
    )["volume"].mean()
    if recent_volume is None or recent_volume == 0:
        # mean() of an empty selection is None, and a zero mean would turn
        # every normalised volume into inf/NaN.
        stock.logger.debug("No recent volume data")
        return None
    # Average volume per one-minute slot.
    mean_volume = recent_volume / len(session_minutes)

    feat_df = df.select(
        pl.col("datetime"),
        pl.col("open") / prev_val,
        pl.col("high") / prev_val,
        pl.col("low") / prev_val,
        pl.col("close") / prev_val,
        pl.col("volume") / mean_volume,
    )

    # Pad the session minutes that have no bar. A set keeps the lookup cheap.
    present = set(df["datetime"].to_list())
    add_dates = [t for t in session_minutes if t not in present]
    pad_df = pl.DataFrame(
        {
            "datetime": add_dates,
            "open": [None] * len(add_dates),
            "high": [None] * len(add_dates),
            "low": [None] * len(add_dates),
            "close": [None] * len(add_dates),
            "volume": [0.0] * len(add_dates),
        },
        # Explicit schema so the dtypes match feat_df even when the list is
        # empty or all-None and nothing could be inferred.
        schema=feat_df.schema,
    )
    feat_df = pl.concat([feat_df, pad_df]).sort(pl.col("datetime"))

    # Prepend margin_size (window_size - 1) dummy rows.
    margin_size = window_size - 1
    margin_df = pl.DataFrame(
        {
            "datetime": [day_start] * margin_size,
            "open": np.ones(margin_size),
            "high": np.ones(margin_size),
            "low": np.ones(margin_size),
            "close": np.ones(margin_size),
            "volume": np.zeros(margin_size),
        },
        schema=feat_df.schema,
    )
    feat_df = pl.concat([margin_df, feat_df])

    # Fill the nulls: close first, because the other prices fall back to it.
    feat_df = feat_df.with_columns(
        # Carry the last close forward; 1.0 (= previous day's close) covers a
        # leading gap when there is no margin row to carry from.
        pl.col("close").fill_null(strategy="forward").fill_null(1.0),
        pl.col("volume").fill_null(0.0),
    ).with_columns(
        pl.col("open").fill_null(pl.col("close")),
        pl.col("high").fill_null(pl.col("close")),
        pl.col("low").fill_null(pl.col("close")),
    )

    return feat_df.select("open", "high", "low", "close", "volume")

In [None]:
class DataLoader:
    """Serves sliding-window observations from one randomly chosen minute-bar file.

    Every ``reset()`` picks a file from ``data_csvs`` at random, derives the
    stock code and the date from its name (``<code>_<YYYYMMDD>.<ext>``) and
    builds the feature table with ``create_feat_df``.

    Args:
        data_csvs (list[Path]): Candidate minute-bar files.
        window_size (int): Number of consecutive rows in one observation.
    """
    def __init__(self, data_csvs: list[Path], window_size: int):
        self.data_csvs = data_csvs
        self.window_size = window_size
        self.reset()

    def __len__(self) -> int:
        """Number of complete windows in the current feature table."""
        return len(self.data) - self.window_size + 1

    def __getitem__(self, idx: int) -> np.ndarray:
        """Return window ``idx`` flattened row by row as a float32 vector.

        Raises:
            IndexError: If ``idx`` does not address a complete window.
        """
        # Bound by the number of windows, not rows: a start index in the last
        # window_size - 1 rows would yield a short slice of the wrong shape.
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        window = self.data.slice(idx, self.window_size).to_numpy()
        # Match the dtype declared by observation_space.
        return window.flatten().astype(np.float32)

    def get_price(self, idx: int) -> float:
        """Return the close of the last row of window ``idx``."""
        return self.data["close"][idx + self.window_size - 1]

    def reset(self):
        """Load the feature table of a randomly chosen file into ``self.data``.

        Files for which ``create_feat_df`` returns ``None`` are skipped; the
        candidates are tried in random order until one succeeds.

        Raises:
            ValueError: If ``data_csvs`` is empty.
            RuntimeError: If no candidate file yields a feature table.
        """
        if len(self.data_csvs) == 0:
            raise ValueError("data_csvs is empty")
        for target_index in np.random.permutation(len(self.data_csvs)):
            target_csv = self.data_csvs[target_index]
            name_parts = target_csv.stem.split("_")
            code = name_parts[0]
            date = datetime.datetime.strptime(name_parts[1], "%Y%m%d").date()
            data = create_feat_df(code, date, self.window_size)
            if data is not None:
                self.data = data
                return
        raise RuntimeError("create_feat_df returned no data for any file in data_csvs")

    @property
    def observation_space(self) -> gym.spaces.Space:
        """Unbounded float32 box with one entry per cell of a window."""
        return gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(int(self.window_size * len(self.data.columns)),),
            dtype=np.float32,
        )

In [None]:
class TradingEnv(gym.Env):
    """Gymnasium environment that walks through the windows of a data loader.

    The loader supplies preprocessed feature windows. At every step the agent
    picks one of two actions: action 1 earns the change of the price since the
    previous step, any other action earns its negative.
    """

    def __init__(self, env_config):
        """
        Args:
            env_config: Mapping whose ``"dataloader"`` entry provides the
                observations, prices and observation space.
        """
        super().__init__()
        self.dataloader = env_config["dataloader"]
        self.current_index = 1
        self.total_reward = 0
        self.history = []

        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = self.dataloader.observation_space

    def _calc_reward(self, action) -> float:
        """Signed price change between the previous and the current index."""
        cur = self.dataloader.get_price(self.current_index)
        pre = self.dataloader.get_price(self.current_index - 1)
        direction = 1 if action == 1 else -1
        return (cur - pre) * direction

    def step(self, action)-> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Advance one index and return ``(obs, reward, terminated, truncated, info)``.

        Raises:
            RuntimeError: If called after the last window was already served.
        """
        if self.current_index >= len(self.dataloader):
            raise RuntimeError("Episode is done. Please reset the environment.")
        reward = self._calc_reward(action)
        obs = self.dataloader[self.current_index]
        step_index = self.current_index
        self.current_index += 1
        # Accumulate before building info so that total_reward includes this step.
        self.total_reward += reward
        is_terminated = self.current_index >= len(self.dataloader)
        is_truncated = False
        info = {
            "index": step_index,
            # Plain int keeps the history usable as a table even when the
            # caller passes a 0-d array or a numpy scalar.
            "action": int(action),
            "reward": reward,
            "total_reward": self.total_reward,
            "is_terminated": is_terminated,
            "is_truncated": is_truncated,
        }
        self.history.append(info)

        return obs, reward, is_terminated, is_truncated, info

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[np.ndarray, dict[str, Any]]:
        """Start a new episode on freshly loaded data and return the first window."""
        # Lets the base class seed self.np_random as the Gymnasium API expects.
        # NOTE(review): the data loader draws from the global numpy RNG, so the
        # seed does not make the file choice reproducible.
        super().reset(seed=seed)
        self.current_index = 1
        self.total_reward = 0
        self.history = []
        self.dataloader.reset()
        return self.dataloader[0], {}

    def render(self):
        """Plot the price path of the episode and mark every change of action."""
        if len(self.history) == 0:
            return

        df = pl.from_dicts(self.history)
        df = df.with_columns(
            pl.col("index").map_elements(lambda s: self.dataloader.get_price(s), return_dtype=pl.Float64).alias("price"),
        )

        # Keep only the steps where the action differs from the previous one
        # (the first step always counts).
        sellbuy_df = df.filter((pl.col("action") != pl.col("action").shift(1)).fill_null(True))
        sell_df = sellbuy_df.filter(pl.col("action") == 0)
        buy_df = sellbuy_df.filter(pl.col("action") == 1)

        fig, axes = plt.subplots(1, 1, figsize=(10, 5))
        axes.plot(df["index"], df["price"], label="price")
        axes.scatter(sell_df["index"], sell_df["price"], marker="^", color="red", label="sell")
        axes.scatter(buy_df["index"], buy_df["price"], marker="v", color="green", label="buy")
        axes.set_xlabel("index")
        axes.set_ylabel("price")
        # The labels above are only shown once the legend is drawn.
        axes.legend()

    def close(self):
        """Nothing to release."""
        pass

In [None]:
window_size = 25
# NOTE(review): files are discovered under data/minutes_yf/, while
# create_feat_df reads data/minutes/<YYYYMMDD>/ - confirm both trees hold the
# same files.
csv_list = list((stock.PROJECT_ROOT / "data/minutes_yf/").rglob("1570*.arrow"))
# The PPO config and the evaluation cell below both use `dataloader`, so it has
# to be created here for the notebook to run from a fresh kernel.
dataloader = DataLoader(csv_list, window_size=window_size)

In [None]:
# Number of `1570*.arrow` files found under data/minutes_yf/.
len(csv_list)

In [None]:
config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    ).learners(
        # Train in the local process without GPUs.
        num_learners=0,
        num_gpus_per_learner=0,
    ).training(
        # gamma=0: only the immediate reward of a step counts.
        gamma=0,
        # On the new API stack a schedule is given through `lr` itself as
        # [[timestep, value], ...]; `lr_schedule` is an old-stack setting.
        # NOTE(review): the former scalar lr=8e-6 conflicted with this schedule
        # and was dropped - confirm the schedule is the intended setting.
        lr=[
            [0, 1e-1],
            [int(1e2), 1e-2],
            [int(1e3), 1e-3],
            [int(1e4), 1e-4],
            [int(1e5), 1e-5],
            [int(1e6), 1e-6],
            [int(1e7), 1e-7]
        ],
        model={"uses_new_env_runners": True},
        lambda_=0.72,
        vf_loss_coeff=0.5,
        entropy_coeff=0.01,
    ).environment(
        clip_rewards=True,
        env=TradingEnv,
        env_config={"dataloader": dataloader},
    )
)
algo = config.build()

In [None]:
# Disabled alternative: build the algorithm directly through ppo.PPO instead of
# PPOConfig.
# import ray
# from ray.rllib.algorithms import ppo
#algo = ppo.PPO(env=TradingEnv, config={"env_config": {"dataloader": dataloader}})

# Train for a fixed 30 iterations and print the mean episode return after each.
for i in range(30):
    result = algo.train()
    # Disabled early stop: leave the loop once the mean return turns positive.
    # if result["env_runners"]["episode_return_mean"] > 0:
    #     break
    print("Episode reward mean: ", result["env_runners"]["episode_return_mean"])

# Save the trained algorithm; the result is passed to Algorithm.from_checkpoint
# in the next cell.
checkpoint_dir = algo.save_to_path()    

In [None]:
# Reload the trained algorithm and run one greedy episode on a fresh environment.
algo = Algorithm.from_checkpoint(checkpoint_dir)
rl_module = algo.get_module()
print(rl_module.input_specs_inference())

env = TradingEnv({"dataloader": dataloader})
obs, info = env.reset()
is_terminated = is_truncated = False
while not (is_terminated or is_truncated):
    # Batch of one, cast to float32 to match the dtype declared by the
    # observation space. Named obs_batch so the builtin `input` stays intact.
    obs_batch = torch.from_numpy(np.asarray([obs], dtype=np.float32))
    action_logits = rl_module.forward_inference({"obs": obs_batch})["action_dist_inputs"]
    # Greedy action as a plain int, so the env history holds scalars.
    action = int(torch.argmax(action_logits[0]).item())
    obs, reward, is_terminated, is_truncated, _ = env.step(action)

In [None]:
# Sum of the step rewards collected in the evaluation episode.
env.total_reward

In [None]:
# Plot the price path of the evaluation episode with the action changes marked.
env.render()