In [None]:
from enum import Enum
import random
from typing import Any
from pathlib import Path
import datetime
import math

from pydantic import BaseModel
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import gymnasium as gym
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.algorithm import Algorithm
import torch

import stock

In [None]:
class DataLoader:
    """Serve flattened sliding-window observations over a randomly drawn slice of market data.

    Every ``reset`` picks a random start inside ``[start_date, end_date]`` and loads the
    range ``[start, start + episode_length + window_size - 1 minutes]`` for ``code`` via
    ``stock.crypto.read_data``. Window ``idx`` covers data rows ``idx`` to
    ``idx + window_size - 1``; its "current" bar is the last of those rows.

    NOTE(review): offsets are computed in minutes, so the data is presumably 1-minute bars
    and ``read_data`` presumably includes both endpoints -- confirm against ``stock``.

    Args:
        code (str): Instrument code forwarded to ``stock.crypto.read_data``.
        start_date (datetime.datetime): Earliest time an episode may start.
        end_date (datetime.datetime): Latest time an episode may end.
        episode_length (int): Number of steps in one episode.
        window_size (int): Number of rows in each observation window.
        target_columns (list[str] | None): Columns flattened into each observation.
            Defaults to ``["open", "high", "low", "close", "volume"]``.

    Raises:
        ValueError: From ``reset`` if the date range is too short for one episode.
    """

    DEFAULT_TARGET_COLUMNS = ("open", "high", "low", "close", "volume")

    def __init__(self, code: str, start_date: datetime.datetime, end_date: datetime.datetime, episode_length: int, window_size: int,
                 target_columns: list[str] | None = None):
        self.code = code
        self.start_date = start_date
        self.end_date = end_date
        self.episode_length = episode_length
        self.window_size = window_size
        # Copy so neither a shared default nor the caller's list can be mutated through this object.
        self.target_columns = list(self.DEFAULT_TARGET_COLUMNS if target_columns is None else target_columns)
        self.reset()

    def __len__(self) -> int:
        """Number of complete observation windows in the loaded data."""
        return len(self.data) - self.window_size + 1

    def _last_row(self, idx: int) -> int:
        """Validate window index ``idx`` and return the data row its window ends on."""
        n_windows = len(self.data) - self.window_size + 1
        if not 0 <= idx < n_windows:
            raise IndexError(f"window index {idx} out of range [0, {n_windows})")
        return idx + self.window_size - 1

    def __getitem__(self, idx: int) -> np.ndarray:
        """Return window ``idx`` flattened row-major, as float32 to match ``observation_space``."""
        self._last_row(idx)
        window = self.data.slice(idx, self.window_size).select(self.target_columns)
        return window.to_numpy().astype(np.float32).flatten()

    def get_price(self, idx: int) -> float:
        """Return the close of the last bar in window ``idx``."""
        return self.data["close"][self._last_row(idx)]

    def get_ohlcv(self, idx: int) -> np.ndarray:
        """Return ``[open, high, low, close, volume]`` of the last bar in window ``idx``."""
        row = self.data.select("open", "high", "low", "close", "volume").row(self._last_row(idx))
        return np.asarray(row, dtype=np.float64)

    def reset(self):
        """Load a new randomly positioned slice of data for the next episode.

        Raises:
            ValueError: If ``[start_date, end_date]`` is too short for one episode plus a window.
        """
        span_minutes = int((self.end_date - self.start_date) / datetime.timedelta(minutes=1))
        max_offset = span_minutes - self.window_size - self.episode_length
        if max_offset < 0:
            raise ValueError(
                f"date range of {span_minutes} minutes is too short for episode_length="
                f"{self.episode_length} plus window_size={self.window_size}"
            )
        offset = random.randint(0, max_offset)
        start_date = self.start_date + datetime.timedelta(minutes=offset)
        end_date = start_date + datetime.timedelta(minutes=self.episode_length + self.window_size - 1)
        self.data = stock.crypto.read_data(self.code, start_date, end_date)

    @property
    def observation_space(self) -> gym.spaces.Space:
        """Unbounded float32 Box with one entry per (row, target column) in a window."""
        return gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(int(self.window_size * len(self.target_columns)),),
            dtype=np.float32,
        )

In [None]:
class Action:
    """Parsed agent action drawn from ``Portfolio.action_space``.

    Attributes:
        action (Action.ActionEnum): Whether to buy or sell.
        portion (float): Fraction in ``[0, 1]`` handed to ``Portfolio.buy`` / ``Portfolio.sell``.
    """

    class ActionEnum(Enum):
        BUY = 0
        SELL = 1

    def __init__(self, action):
        """Parse a ``{"action": ..., "portion": ...}`` mapping.

        ``"action"`` may be a Python/NumPy integer or a one-element array. ``"portion"`` may be
        a scalar or the shape-(1,) array produced by the Box space; it is reduced to a plain
        float and clipped into ``[0, 1]``.

        Raises:
            ValueError: If ``"action"`` is not a valid ``ActionEnum`` value.
        """
        self.action = Action.ActionEnum(int(np.asarray(action["action"]).reshape(-1)[0]))
        portion = float(np.asarray(action["portion"], dtype=np.float64).reshape(-1)[0])
        # Continuous policy outputs are not guaranteed to stay inside the Box bounds.
        self.portion = min(max(portion, 0.0), 1.0)
    

class Portfolio:
    """Crypto portfolio: cash plus one position counted in lots of ``min_transaction_unit``.

    Buys fill at ``price * (1 + market_impact + taker_fee)``; sells fill at
    ``price * (1 - market_impact)`` and credit ``(1 - maker_fee)`` of the proceeds
    (``maker_fee`` is negative, i.e. a rebate). While a position is held, a stop-loss sells
    all of it as soon as a bar's low reaches ``(1 - max_loss_rate)`` times the average cost.
    """
    def __init__(self, initial_cash: float):
        self.initial_cash = initial_cash  # capital restored by reset()
        self.cash = initial_cash  # cash on hand
        self.num_unit = 0  # position size, in lots of min_transaction_unit
        self.acquision_price = 0  # total cost paid for the current position
        self.total = initial_cash  # total asset value at the last mark
        self.prev_total = self.total

        self.market_impact = 0.0001  # price impact applied to every fill
        self.maker_fee = -0.0001  # maker fee applied to sells (negative = rebate)
        self.taker_fee = 0.0005  # taker fee applied to buys

        self.min_transaction_unit = 0.0001  # asset quantity of one lot
        self.max_transaction_unit = 5.0  # NOTE(review): not used by this class
        self.max_loss_rate = 0.1  # stop-loss threshold as a fraction of average cost

        self.action_space = gym.spaces.Dict({
            "action": gym.spaces.Discrete(2),
            "portion": gym.spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
        })

    def reset(self):
        """Start a new episode with ``initial_cash`` and no position."""
        self.cash = self.initial_cash
        self.num_unit = 0
        self.acquision_price = 0
        self.total = self.initial_cash
        self.prev_total = self.total

    @staticmethod
    def _as_fraction(portion) -> float:
        """Coerce a portion (scalar or one-element array) to a float clipped into [0, 1]."""
        value = float(np.asarray(portion, dtype=np.float64).reshape(-1)[0])
        return min(max(value, 0.0), 1.0)

    def _stop_loss(self, ohlcv: np.ndarray) -> bool:
        """Sell the whole position if this bar's low touched the stop price; return whether it fired."""
        if self.num_unit <= 0:
            return False
        average_cost = self.acquision_price / (self.num_unit * self.min_transaction_unit)
        stop_price = (1 - self.max_loss_rate) * average_cost
        if ohlcv[2] > stop_price:  # low stayed above the stop
            return False
        # A stop fills at the stop price, or at the open if the bar gapped below it.
        self.sell(min(ohlcv[0], stop_price), 1.0)
        return True

    def action(self, ohlcv: np.ndarray, action: "Action") -> float:
        """Execute one step on a bar and return the change in total asset value.

        Args:
            ohlcv: ``[open, high, low, close, volume]`` of the bar (the order DataLoader.get_ohlcv uses).
            action: Parsed agent action; ignored on a step where the stop-loss fires.

        Returns:
            ``total - prev_total`` after marking the position to the bar's close.
        """
        close = ohlcv[3]
        if not self._stop_loss(ohlcv):
            if action.action == Action.ActionEnum.BUY:
                self.buy(close, action.portion)
            elif action.action == Action.ActionEnum.SELL:
                self.sell(close, action.portion)

        self.prev_total = self.total
        self.total = self.cash + self.num_unit * self.min_transaction_unit * close

        return self.reward()

    def reward(self):
        """Change in total asset value over the last step."""
        return self.total - self.prev_total

    def buy(self, price: float, portion: float):
        """Buy as many whole lots as ``min(total * portion, cash)`` affords at the fee-adjusted price."""
        portion = self._as_fraction(portion)
        price = price * (1 + self.market_impact + self.taker_fee)
        units = math.floor(min(self.total * portion, self.cash) / (price * self.min_transaction_unit))
        if units <= 0:
            return
        amount_price = price * self.min_transaction_unit * units

        self.cash -= amount_price
        self.num_unit += units
        self.acquision_price += amount_price

    def sell(self, price: float, portion: float):
        """Sell lots worth up to ``total * portion`` (never more than held) at the impact-adjusted price."""
        if self.num_unit == 0:
            return
        portion = self._as_fraction(portion)
        price = price * (1 - self.market_impact)
        held_value = self.num_unit * self.min_transaction_unit * price
        # Cap at num_unit so float rounding can never oversell the position.
        units = min(self.num_unit, round(min(self.total * portion, held_value) / (price * self.min_transaction_unit)))
        if units <= 0:
            return
        amount_price = price * self.min_transaction_unit * units

        self.cash += amount_price * (1.0 - self.maker_fee)
        # Reduce the cost basis proportionally to the fraction of lots sold.
        self.acquision_price *= (1.0 - units / self.num_unit)
        self.num_unit -= units

    def __str__(self):
        return f"cash: {self.cash}, num_unit: {self.num_unit}, total: {self.total}"

In [None]:
class TradingEnv(gym.Env):
    """Gymnasium environment that replays market data and trades it through a portfolio.

    ``env_config`` must provide:
        dataloader (DataLoader): Source of observation windows and OHLCV bars.
        portfolio (Portfolio): Account that executes actions and produces rewards.

    ``reset`` returns window 0. Each ``step`` executes the action on the bar that follows
    the previous observation, returns the window ending on that bar, and rewards the
    portfolio's change in total value. The episode terminates when the loaded data is exhausted.
    """

    def __init__(self, env_config):
        super().__init__()
        self.dataloader: DataLoader = env_config["dataloader"]
        self.portfolio: Portfolio = env_config["portfolio"]
        self.current_index = 1  # window whose last bar the next step trades on
        self.total_reward = 0
        self.history = []  # one info dict per step, consumed by render()

        self.action_space = self.portfolio.action_space
        self.observation_space = self.dataloader.observation_space

    def _observation(self, idx: int) -> np.ndarray:
        """Return window ``idx`` as float32, the dtype declared by ``observation_space``."""
        return np.asarray(self.dataloader[idx], dtype=np.float32)

    def step(self, action) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Execute ``action`` (a sample of ``action_space``) and advance one bar.

        Raises:
            RuntimeError: If called after the episode terminated without a ``reset``.
        """
        if self.current_index >= len(self.dataloader):
            raise RuntimeError("Episode is done. Please reset the environment.")

        parsed = Action(action)
        reward = self.portfolio.action(self.dataloader.get_ohlcv(self.current_index), parsed)
        obs = self._observation(self.current_index)
        self.current_index += 1
        is_terminated = self.current_index >= len(self.dataloader)
        is_truncated = False
        # Accumulate first so the recorded total_reward includes this step's reward.
        self.total_reward += reward
        info = {
            "index": self.current_index - 1,
            "action": action,
            # Plain scalars so render() can build a DataFrame without the raw action dict.
            "action_type": parsed.action.value,
            "portion": float(np.asarray(parsed.portion, dtype=np.float64).reshape(-1)[0]),
            "reward": reward,
            "total_reward": self.total_reward,
            "is_terminated": is_terminated,
            "is_truncated": is_truncated,
        }
        self.history.append(info)

        return obs, reward, is_terminated, is_truncated, info

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[np.ndarray, dict[str, Any]]:
        """Start a new episode on a freshly drawn data slice with a reset portfolio."""
        super().reset(seed=seed)
        if seed is not None:
            # DataLoader.reset draws its window offset from the module-level `random` generator.
            random.seed(seed)
        self.current_index = 1
        self.total_reward = 0
        self.history = []
        self.dataloader.reset()
        self.portfolio.reset()
        return self._observation(0), {}

    def render(self):
        """Plot close prices of the traded bars, marking each step where the action type changed.

        Returns:
            The matplotlib Figure, or None if no step has been taken since the last reset.
        """
        if not self.history:
            return None

        indices = [h["index"] for h in self.history]
        df = pl.DataFrame({
            "index": indices,
            "action_type": [h["action_type"] for h in self.history],
            "price": [float(self.dataloader.get_price(i)) for i in indices],
        })

        # Keep only the first step of each run of identical action types.
        changes = df.filter((pl.col("action_type") != pl.col("action_type").shift(1)).fill_null(True))
        buy_df = changes.filter(pl.col("action_type") == Action.ActionEnum.BUY.value)
        sell_df = changes.filter(pl.col("action_type") == Action.ActionEnum.SELL.value)

        fig, ax = plt.subplots(1, 1, figsize=(10, 5))
        ax.plot(df["index"], df["price"], label="price")
        ax.scatter(sell_df["index"], sell_df["price"], marker="^", color="red", label="sell")
        ax.scatter(buy_df["index"], buy_df["price"], marker="v", color="green", label="buy")
        ax.set(title=f"{self.dataloader.code}: close price and action changes", xlabel="window index", ylabel="close price")
        ax.legend()
        return fig

    def close(self):
        pass

In [None]:
# Seed every RNG the notebook touches; DataLoader.reset draws its episode offset from `random`.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

SYMBOL = "BTC"
DATA_START = datetime.datetime(2022, 1, 1)
DATA_END = datetime.datetime(2023, 1, 1)
EPISODE_LENGTH = 1440  # steps per episode (one day, if the data is 1-minute bars)
WINDOW_SIZE = 25  # rows of history in each observation
INITIAL_CASH = 100000

dataloader = DataLoader(SYMBOL, DATA_START, DATA_END, EPISODE_LENGTH, WINDOW_SIZE)
portfolio = Portfolio(INITIAL_CASH)

In [None]:
# Learning-rate schedule as [timestep, lr] pairs; lr drops 10x at each listed timestep.
LR_SCHEDULE = [
    [0, 1e-1],
    [100, 1e-2],
    [1_000, 1e-3],
    [10_000, 1e-4],
    [100_000, 1e-5],
    [1_000_000, 1e-6],
    [10_000_000, 1e-7],
]

# The same DataLoader/Portfolio objects are handed to every TradingEnv RLlib creates.
ppo_env_config = {
    "dataloader": dataloader,
    "portfolio": portfolio,
}

# Each builder call mutates and returns the same PPOConfig object.
config = PPOConfig()
config = config.api_stack(
    enable_rl_module_and_learner=True,
    enable_env_runner_and_connector_v2=True,
)
config = config.learners(num_learners=0, num_gpus_per_learner=0)
config = config.training(
    gamma=0,
    lr_schedule=LR_SCHEDULE,
    lr=8e-6,
    model={"uses_new_env_runners": True},
    lambda_=0.72,
    vf_loss_coeff=0.5,
    entropy_coeff=0.01,
)
config = config.environment(
    clip_rewards=True,
    env=TradingEnv,
    env_config=ppo_env_config,
)
algo = config.build()

In [None]:
# Train for a fixed number of iterations, reporting the mean episode return after each.
N_TRAIN_ITERATIONS = 30

for iteration in range(N_TRAIN_ITERATIONS):
    result = algo.train()
    mean_return = result["env_runners"]["episode_return_mean"]
    print("Episode reward mean: ", mean_return)

checkpoint_dir = algo.save_to_path()

In [None]:
# Restore the trained policy and replay one greedy (non-exploratory) evaluation episode.
algo = Algorithm.from_checkpoint(checkpoint_dir)
rl_module = algo.get_module()
# Turns the module's "action_dist_inputs" into a distribution over the Dict action space.
action_dist_cls = rl_module.get_inference_action_dist_cls()

# TradingEnv reads both "dataloader" and "portfolio" from env_config.
env = TradingEnv({"dataloader": dataloader, "portfolio": portfolio})
obs, info = env.reset()
while True:
    # Batch of one observation, float32 to match the module's weights.
    obs_batch = torch.from_numpy(np.asarray([obs], dtype=np.float32))
    with torch.no_grad():
        fwd_out = rl_module.forward_inference({"obs": obs_batch})
    action_dist = action_dist_cls.from_logits(fwd_out["action_dist_inputs"])
    sampled = action_dist.to_deterministic().sample()
    # Rebuild the {"action", "portion"} dict that Action expects.
    action = {
        "action": int(sampled["action"][0]),
        "portion": np.clip(sampled["portion"][0].numpy(), 0.0, 1.0).astype(np.float32),
    }
    obs, reward, is_terminated, is_truncated, info = env.step(action)
    if is_terminated or is_truncated:
        break

In [None]:
# Sum of per-step rewards (portfolio value changes) over the evaluation episode.
evaluation_return = env.total_reward
evaluation_return

In [None]:
# Price chart of the evaluation episode with markers wherever the action type changed.
env.render();