In [None]:
import datetime
from pathlib import Path
from typing import Any, Type

import gymnasium as gym
import numpy as np
import torch
from pydantic import BaseModel, ConfigDict
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
import matplotlib.pyplot as plt

import stock

In [None]:
# Training configuration: start from the project's default TrainParams and
# override the fields this experiment cares about.
params = stock.crypto.TrainParams()

# Data range and episode shape. The dates are midnight timestamps; the units of
# episode_length / window_size are defined by the dataloader
# (presumably number of bars/steps -- TODO confirm).
loader_cfg = params.dataloader
loader_cfg.start_date = datetime.datetime(2023, 4, 1)
loader_cfg.end_date = datetime.datetime(2024, 4, 1)
loader_cfg.episode_length = 1440
loader_cfg.window_size = 100

# Set every trading-cost field on the portfolio config to zero.
for cost_field in ("maker_fee", "market_impact", "taker_fee"):
    setattr(params.portfolio, cost_field, 0.0)

params.epoch = 10

In [None]:
# Build the trainer from the configuration above. Later cells read
# `trainer.algo`, `trainer.dataloader` and `trainer.portfolio` from this object.
trainer = stock.crypto.Trainer(params)

In [None]:
# Run training with the parameters configured above.
# NOTE(review): presumably the long-running cell of this notebook -- confirm
# the expected runtime before a full Restart & Run All.
trainer.train()

In [None]:
# Display the history recorded on the trainer's portfolio object.
# NOTE(review): this shows the entire object; consider displaying only a slice
# if it grows large.
trainer.portfolio.history

In [None]:
# Display the portfolio's `acquision_price` attribute.
# NOTE(review): the name looks like a misspelling of "acquisition_price"; it is
# defined in the `stock` package, so any rename has to happen there first.
trainer.portfolio.acquision_price

In [None]:
# Deterministic evaluation rollout: step a freshly reset TradingEnv with the
# RLModule held by the trainer and collect every action that was taken.
#
# Alternative: restore the module from a saved checkpoint instead of using the
# in-memory trainer:
# ckpt_dir = Path(trainer.checkpoint_dir)
# algo = Algorithm.from_checkpoint(ckpt_dir.as_posix())
# rl_module = algo.get_module()
rl_module = trainer.algo.get_module()

# Reset the trainer's dataloader and portfolio before handing them to the
# evaluation environment.
trainer.dataloader.reset()
trainer.portfolio.reset()
env = stock.crypto.environment.TradingEnv(
    {"dataloader": trainer.dataloader, "portfolio": trainer.portfolio}
)
obs, info = env.reset()

# The action-distribution class does not depend on the observation, so it is
# looked up once instead of on every step.
dist_cls = rl_module.get_inference_action_dist_cls()

acts = []
while True:
    # Wrap the single observation in a batch of size 1 for the module.
    # (Named `obs_batch` rather than `input` to avoid shadowing the builtin.)
    obs_batch = torch.from_numpy(np.array([obs]))
    action_logits = rl_module.forward_inference({"obs": obs_batch})["action_dist_inputs"]

    # Build the distribution for the first (only) batch element and switch to
    # its deterministic variant before sampling.
    dist = dist_cls.from_logits(action_logits[0])
    dist = dist.to_deterministic()
    action = dist.sample()
    # action["portion"] = np.clip(action["portion"].detach().numpy(), 0, 1)
    # Convert the tensor to a NumPy value before passing the action to env.step().
    action["action"] = action["action"].numpy()

    obs, reward, is_terminated, is_truncated, info = env.step(action)
    acts.append(action)

    # A Gymnasium episode is over when it is either terminated OR truncated;
    # checking only `is_terminated` would keep stepping a finished environment.
    if is_terminated or is_truncated:
        break

In [None]:
# Display the last record in the environment's history.
env.history[-1]

In [None]:
def _history_column(key):
    """Return the value stored under `key` in every record of `env.history`."""
    return [record[key] for record in env.history]


# One list per field of the environment history, in record order.
total = _history_column("total_portfolio")
price = _history_column("price")
units = _history_column("current_holding_unit")
# The action is nested one level deeper: record["action"]["action"].
action = [record["action"]["action"] for record in env.history]
equity = _history_column("current_equity_value")

In [None]:
# Plot the `action` series extracted from env.history, one point per record.
# An explicit Axes with title/labels lets the figure stand on its own; the
# trailing `;` suppresses the repr of the last expression in the notebook.
fig, ax = plt.subplots()
ax.plot(action)
ax.set(title="Action per step", xlabel="step (index into env.history)", ylabel="action");

In [None]:
# Plot the `current_equity_value` series extracted from env.history, one point
# per record. An explicit Axes with title/labels lets the figure stand on its
# own; the trailing `;` suppresses the repr of the last expression.
fig, ax = plt.subplots()
ax.plot(equity)
ax.set(
    title="Current equity value per step",
    xlabel="step (index into env.history)",
    ylabel="current_equity_value",
);

In [None]:
# Plot the `current_holding_unit` series extracted from env.history, one point
# per record. An explicit Axes with title/labels lets the figure stand on its
# own; the trailing `;` suppresses the repr of the last expression.
fig, ax = plt.subplots()
ax.plot(units)
ax.set(
    title="Current holding units per step",
    xlabel="step (index into env.history)",
    ylabel="current_holding_unit",
);

In [None]:
# Plot the `price` series extracted from env.history, one point per record.
# An explicit Axes with title/labels lets the figure stand on its own; the
# trailing `;` suppresses the repr of the last expression in the notebook.
fig, ax = plt.subplots()
ax.plot(price)
ax.set(title="Price per step", xlabel="step (index into env.history)", ylabel="price");

In [None]:
# Plot the `total_portfolio` series extracted from env.history, one point per
# record. An explicit Axes with title/labels lets the figure stand on its own;
# the trailing `;` suppresses the repr of the last expression in the notebook.
fig, ax = plt.subplots()
ax.plot(total)
ax.set(
    title="Total portfolio value per step",
    xlabel="step (index into env.history)",
    ylabel="total_portfolio",
);