In [None]:
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import gymnasium as gym
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.algorithm import Algorithm
import torch

In [None]:
class DataLoader:
    """Sliding-window view over a preprocessed ``pl.DataFrame``.

    Item ``idx`` is the window of ``window_size`` consecutive rows starting at row
    ``idx``, flattened row-major (all columns of the first row, then all columns of
    the next row, ...) and cast to float32 so it matches ``observation_space``.

    Args:
        data (pl.DataFrame): Preprocessed data (features computed, missing values
            handled). Must contain a ``close`` column (used by ``get_price``); all
            columns are assumed numeric since windows are cast to float32.
        window_size (int): Number of rows (days) in each window; must be >= 1.

    Raises:
        ValueError: If ``window_size`` < 1.
    """

    def __init__(self, data: pl.DataFrame, window_size: int):
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.data = data
        self.window_size = window_size

    def __len__(self) -> int:
        """Number of *complete* windows (0 if data is shorter than one window)."""
        # len() raises if __len__ returns a negative number, so clamp at 0.
        return max(len(self.data) - self.window_size + 1, 0)

    def __getitem__(self, idx: int) -> np.ndarray:
        """Return window ``idx`` as a flat float32 array of length ``window_size * n_columns``.

        Raises:
            IndexError: If ``idx`` does not start a complete window.
        """
        # Bound by len(self), not len(self.data): the latter accepted start
        # indices whose window runs past the end and returned a truncated array.
        if not 0 <= idx < len(self):
            raise IndexError(f"window index {idx} out of range [0, {len(self)})")
        window = self.data.slice(idx, self.window_size).to_numpy()
        # Cast so the observation actually lies in the float32 Box space.
        return window.astype(np.float32).flatten()

    def get_price(self, idx: int) -> float:
        """Close price on the last row of window ``idx`` (the latest price in that window).

        Raises:
            IndexError: If ``idx`` does not start a complete window.
        """
        # Guard explicitly: a negative idx would otherwise wrap around silently.
        if not 0 <= idx < len(self):
            raise IndexError(f"window index {idx} out of range [0, {len(self)})")
        return self.data["close"][idx + self.window_size - 1]

    @property
    def observation_space(self) -> gym.spaces.Space:
        """Unbounded float32 Box whose shape matches one flattened window."""
        return gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(int(self.window_size * len(self.data.columns)),),
            dtype=np.float32,
        )

In [None]:
class TradingEnv(gym.Env):
    """Single-asset long/short trading environment over windows from a ``DataLoader``.

    Takes ``env_config["dataloader"]`` (wrapping a preprocessed, feature-engineered
    ``pl.DataFrame``). On every ``step`` the environment advances one window and
    returns the next observation and the reward of the chosen position.

    Actions (``Discrete(2)``): ``1`` -> long (+1), anything else -> short (-1).
    Reward at step t: ``(price[t] - price[t-1]) * position``.

    Args:
        env_config (dict): Must contain ``"dataloader"``, an object providing
            ``__len__``, ``__getitem__``, ``get_price`` and ``observation_space``.
    """

    def __init__(self, env_config):
        super().__init__()
        self.dataloader = env_config["dataloader"]
        # Window 0 is the observation returned by reset(); the first step rewards 0 -> 1.
        self.current_index = 1
        self.total_reward = 0
        self.history = []

        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = self.dataloader.observation_space

    def _to_obs(self, idx: int) -> np.ndarray:
        """Fetch window ``idx`` cast to the observation space's dtype."""
        # A float64 window would not be contained in a float32 Box.
        return np.asarray(self.dataloader[idx], dtype=self.observation_space.dtype)

    def _calc_reward(self, action) -> float:
        """Reward for holding ``action``'s position from ``current_index - 1`` to ``current_index``."""
        cur = self.dataloader.get_price(self.current_index)
        pre = self.dataloader.get_price(self.current_index - 1)
        position = 1 if action == 1 else -1
        return (cur - pre) * position

    def step(self, action) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]:
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``.

        Raises:
            RuntimeError: If called after the episode has terminated without ``reset``.
        """
        if self.current_index >= len(self.dataloader):
            raise RuntimeError("Episode is done. Please reset the environment.")
        # Policies often emit 0-d numpy arrays / tensors; store a plain int so that
        # `history` converts cleanly with pl.from_dicts in render().
        action = int(action)
        reward = self._calc_reward(action)
        obs = self._to_obs(self.current_index)
        self.current_index += 1
        # Accumulate before building `info` so info["total_reward"] includes this step.
        self.total_reward += reward
        is_terminated = self.current_index >= len(self.dataloader)
        is_truncated = False
        info = {
            "index": self.current_index - 1,
            "action": action,
            "reward": reward,
            "total_reward": self.total_reward,
            "is_terminated": is_terminated,
            "is_truncated": is_truncated,
        }
        self.history.append(info)

        return obs, reward, is_terminated, is_truncated, info

    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[np.ndarray, dict[str, Any]]:
        """Restart the episode at window 0 and clear the running totals/history."""
        # Required by the Gymnasium API so self.np_random honours `seed`.
        super().reset(seed=seed)
        self.current_index = 1
        self.total_reward = 0
        self.history = []
        return self._to_obs(0), {}

    def render(self):
        """Plot the price series with markers at every step where the position changes.

        Does nothing if no step has been taken since the last reset.
        """
        if not self.history:
            return

        df = pl.from_dicts(self.history)
        df = df.with_columns(
            pl.col("index").map_elements(self.dataloader.get_price, return_dtype=pl.Float64).alias("price"),
        )

        # Keep only steps whose action differs from the previous one (the first step always counts).
        position_changes = df.filter((pl.col("action") != pl.col("action").shift(1)).fill_null(True))
        sell_df = position_changes.filter(pl.col("action") == 0)
        buy_df = position_changes.filter(pl.col("action") == 1)

        fig, ax = plt.subplots(1, 1, figsize=(10, 5))
        ax.plot(df["index"], df["price"], label="price")
        ax.scatter(buy_df["index"], buy_df["price"], marker="^", color="green", label="buy")
        ax.scatter(sell_df["index"], sell_df["price"], marker="v", color="red", label="sell")
        ax.set(title="Position changes during episode", xlabel="step index", ylabel="price")
        # The labels above are only shown once a legend is drawn.
        ax.legend()

    def close(self):
        """No resources to release."""
        pass

In [None]:
def create_artifitial_data(
    window_size,
    num_data_points,
    *,
    num_cycles: float = 3.0,
    amplitude: float = 50.0,
    offset: float = 100.0,
) -> pl.DataFrame:
    """Generate a synthetic sine-wave ``close`` price series.

    The series is left-padded with ``window_size - 1`` zeros so that a ``DataLoader``
    built with the same ``window_size`` yields exactly ``num_data_points`` windows,
    and its ``get_price(i)`` returns the i-th sine sample.

    Args:
        window_size: Window length the data will be consumed with; must be >= 1.
        num_data_points: Number of sine samples (i.e. resulting windows).
        num_cycles: Number of full sine periods over the series (default 3).
        amplitude: Peak deviation from ``offset`` (default 50).
        offset: Mean price level (default 100).

    Returns:
        pl.DataFrame with a single ``close`` column of length
        ``window_size - 1 + num_data_points``.

    Raises:
        ValueError: If ``window_size`` < 1.
    """
    window_size = int(window_size)
    num_data_points = int(num_data_points)
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    t = np.linspace(0, 2 * np.pi * num_cycles, num_data_points)
    prices = np.sin(t) * amplitude + offset
    # The zero padding only appears inside the first observations; prices used for
    # rewards start at the first sine sample.
    padded = np.concatenate([np.zeros(window_size - 1), prices])
    return pl.DataFrame({
        "close": padded,
    })

In [None]:
# Synthetic price data and the sliding-window loader used by the environment.
window_size = 25
num_data_points = 1000
data = create_artifitial_data(window_size, num_data_points)
# Must reuse `window_size`: the data was padded for it, and the assertion below
# only holds when the loader uses the same window length.
dataloader = DataLoader(data, window_size)
assert len(dataloader) == num_data_points

In [None]:
# PPO on the new RLlib API stack (RLModule + Learner, EnvRunner + ConnectorV2).
config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .learners(
        num_learners=0,  # train in the local process
        num_gpus_per_learner=0,
    )
    .training(
        # TradingEnv's reward depends only on the current action and the next
        # price move, so a fully myopic objective (gamma=0) is sufficient.
        gamma=0,
        # On the new API stack a schedule is passed via `lr` as [[timestep, value], ...];
        # `lr_schedule` is an old-stack setting, and it also made the previous
        # scalar `lr=8e-6` dead configuration.
        # NOTE(review): 1e-1 is an unusually high starting lr for PPO/Adam — confirm intended.
        lr=[
            [0, 1e-1],
            [int(1e2), 1e-2],
            [int(1e3), 1e-3],
            [int(1e4), 1e-4],
            [int(1e5), 1e-5],
            [int(1e6), 1e-6],
            [int(1e7), 1e-7],
        ],
        # NOTE(review): confirm this model key is recognised by the installed RLlib version.
        model={"uses_new_env_runners": True},
        lambda_=0.72,
        vf_loss_coeff=0.5,
        entropy_coeff=0.01,
    )
    .environment(
        # RLlib: True clips each reward to its sign (-1/0/+1) — verify it is honoured
        # on the new env-runner stack.
        clip_rewards=True,
        env=TradingEnv,
        env_config={"dataloader": dataloader},
    )
)
algo = config.build()

In [None]:
# Train until the mean episode return exceeds TARGET_RETURN, with a hard cap on
# iterations so this cell cannot spin forever if training plateaus.
TARGET_RETURN = 500
MAX_ITERATIONS = 200  # NOTE(review): tune to your compute budget

for iteration in range(MAX_ITERATIONS):
    result = algo.train()
    # The metric can be missing (or NaN) before the first episode has finished.
    episode_return_mean = result["env_runners"].get("episode_return_mean")
    print(f"[iter {iteration}] Episode reward mean: ", episode_return_mean)
    if episode_return_mean is not None and episode_return_mean > TARGET_RETURN:
        break
else:
    print(f"Target return {TARGET_RETURN} not reached after {MAX_ITERATIONS} iterations.")

checkpoint_dir = algo.save_to_path()

In [None]:
# Restore the trained policy and roll out one greedy (argmax) episode.
algo = Algorithm.from_checkpoint(checkpoint_dir)
rl_module = algo.get_module()

env = TradingEnv({"dataloader": dataloader})
obs, info = env.reset()
done = False
while not done:
    # The module's weights are float32; add a leading batch dimension of size 1.
    obs_batch = torch.from_numpy(np.asarray(obs, dtype=np.float32)[None, ...])
    with torch.no_grad():
        action_logits = rl_module.forward_inference({"obs": obs_batch})["action_dist_inputs"]
    # Plain int action (not a 0-d array) for the environment.
    action = int(torch.argmax(action_logits[0]).item())
    obs, reward, is_terminated, is_truncated, _ = env.step(action)
    done = is_terminated or is_truncated

In [None]:
# Cumulative raw (unclipped) reward accumulated by `env` over the evaluation episode.
env.total_reward

In [None]:
# Plot the evaluation episode's prices with markers where the chosen position changes.
env.render()