In [1]:
import gymnasium as gym
from tqdm import tqdm
import numpy as np
from pathlib import Path
import polars as pl
import random

# Record the installed Gymnasium version alongside the notebook output.
print(gym.__version__)

0.28.1


In [2]:
import torch as th
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

In [3]:
from stable_baselines3 import PPO, A2C, SAC, TD3
from stable_baselines3.common.evaluation import evaluate_policy

In [4]:
from stock_prediction_rl.envs.numpy.stock_trading_env import StockTradingEnv
from stock_prediction_rl.sb.utils import (
    create_numpy_array,
    create_envs,
)

In [5]:
# Seed every random number generator used in this notebook with one shared
# value so that a fresh Restart & Run All starts from the same RNG state.
SEED = 1337
th.manual_seed(SEED)  # PyTorch
np.random.seed(SEED)  # NumPy global RNG
random.seed(SEED)  # Python standard-library RNG
# Request deterministic cuDNN kernels as well.
th.backends.cudnn.deterministic = True

In [6]:
# Read the LabelTradeSBI.NS.xlsx workbook from the current user's home
# directory into a polars DataFrame and show it as the cell output.
df = pl.read_excel(
    Path.home().joinpath("Documents/LabelTradeSBI.NS.xlsx")
)
df

Datetime,Close,RSI,EMA9,EMA21,MACD,MACD_SIGNAL,BBANDS_UPPER,BBANDS_MIDDLE,BBANDS_LOWER,ADX,STOCH_K,STOCH_D,ATR,CCI,MOM,ROC,WILLR,PPO,Actions
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,str
"""2023-08-04 03:…",592.599976,34.215289,593.04732,599.098184,-6.307788,-5.752122,617.59649,599.995004,582.393519,49.057268,66.097559,44.168844,4.162341,-56.254793,-3.900024,-0.653818,-68.801719,-1.776887,"""HOLD"""
"""2023-08-04 04:…",595.200012,39.898269,593.477859,598.743805,-5.79683,-5.761063,615.672932,599.085004,582.497075,49.078604,80.213751,61.716574,4.193601,-7.76646,-4.049988,-0.675843,-49.502458,-1.701393,"""HOLD"""
"""2023-08-04 05:…",596.799988,43.152731,594.142285,598.567094,-5.202812,-5.649413,613.624074,598.280002,582.93593,47.995942,89.26872,78.526677,4.158344,55.752468,-2.100037,-0.350649,-37.994763,-1.567384,"""HOLD"""
"""2023-08-04 06:…",593.0,37.903187,593.913828,598.060995,-4.981256,-5.515782,611.614952,597.360001,583.10505,47.412213,75.581632,81.688034,4.229174,18.573737,0.799988,0.135087,-49.683428,-1.479955,"""HOLD"""
"""2023-08-04 07:…",576.200012,24.00159,590.371065,596.073633,-6.091075,-5.63084,611.045368,595.585001,580.124633,48.706005,48.939613,71.263322,5.516376,-222.763293,-16.849976,-2.84124,-86.188757,-1.517362,"""HOLD"""
"""2023-08-04 08:…",573.900024,22.770303,587.076857,594.05785,-7.074652,-5.919603,609.851195,593.662503,577.473811,49.921356,22.926329,49.149192,5.67235,-233.637316,-18.0,-3.041054,-93.414121,-1.482601,"""HOLD"""
"""2023-08-04 09:…",573.200012,22.393764,584.301488,592.161683,-7.82048,-6.299778,608.931511,591.902502,574.873494,51.07946,9.433371,27.099771,5.45647,-190.258119,-17.149963,-2.90505,-94.20779,-1.48486,"""HOLD"""
"""2023-08-07 03:…",570.900024,21.155885,581.621195,590.228805,-8.49917,-6.739656,608.61177,590.242505,571.87324,52.291359,6.454252,12.937984,5.552435,-148.452113,-15.099976,-2.576788,-94.237247,-1.508338,"""HOLD"""
"""2023-08-07 04:…",568.200012,19.774005,578.936958,588.226187,-9.149436,-7.221612,608.403215,588.542505,568.681795,53.533983,5.361881,7.083168,5.50226,-133.378831,-22.599976,-3.825318,-96.671987,-1.534055,"""HOLD"""
"""2023-08-07 05:…",569.650024,22.694395,577.079572,586.537445,-9.438967,-7.665083,607.681385,587.000006,566.318627,54.687848,9.918647,7.244927,5.344954,-105.989029,-21.75,-3.677714,-92.076067,-1.497092,"""HOLD"""


In [8]:
# How much profit by buying and selling 5 shares.
# Each BUY row counts as an outflow of (Close * 5 + 20), each SELL row as an
# inflow of the same expression, and every other row contributes 0; the
# single-cell result is the sum over all rows.
# NOTE(review): if the 20 is meant as a flat per-order fee, it would
# presumably have to be subtracted from SELL proceeds instead of added,
# otherwise it cancels out across each BUY/SELL pair — confirm.
df.select(
    pl.when(pl.col("Actions") == "BUY")
    .then(-(pl.col("Close") * 5 + 20))
    .when(pl.col("Actions") == "SELL")
    .then(pl.col("Close") * 5 + 20)
    .otherwise(pl.lit(0))
    .sum()
    .alias("PROFIT LOSS")
)

PROFIT LOSS
f64
729.9997


In [28]:
# --- Run configuration ------------------------------------------------------
ticker = "SBIN.NS"
model_name = "A2C"
num_envs = 16
seed = 1337  # same value as SEED defined in the seeding cell
datasets = Path.cwd().parent / "datasets"

# The train and trade splits live next to each other in the datasets folder
# and are named after the ticker.
train_file = datasets.joinpath(f"{ticker}_train")
trade_file = datasets.joinpath(f"{ticker}_trade")

# Load both splits and convert each one with the project helper.
train_df = pl.read_parquet(train_file)
trade_df = pl.read_parquet(trade_file)
train_array = create_numpy_array(train_df)
trade_arrays = create_numpy_array(trade_df)

# Build the trade-mode environments from the converted trade split.
trade_envs = create_envs(
    StockTradingEnv,
    trade_arrays,
    num_envs=num_envs,
    mode="trade",
    seed=seed,
)
# Display the trade split as the cell output.
trade_df

Datetime,Close,High,Low,Ticker,Past1Hour,Past2Hour,Past3Hour,Past4Hour,Past5Hour,Past6Hour,Past7Hour,Past8Hour,Past9Hour,Past10Hour,Past11Hour,Past12Hour,Past13Hour,Past14Hour,Past15Hour,Past16Hour,Past17Hour,Past18Hour,Past19Hour,Past20Hour,Past21Hour,Past22Hour,Past23Hour,Past24Hour,Past25Hour,Past26Hour,Past27Hour,Past28Hour,Past29Hour,RSI,EMA9,EMA21,…,ROC,WILLR,PPO,Previous1Action,Previous2Action,Previous3Action,Previous4Action,Previous5Action,Previous6Action,Previous7Action,Previous8Action,Previous9Action,Previous10Action,Previous11Action,Previous12Action,Previous13Action,Previous14Action,Previous15Action,Previous16Action,Previous17Action,Previous18Action,Previous19Action,Previous20Action,Previous21Action,Previous22Action,Previous23Action,Previous24Action,Previous25Action,Previous26Action,Previous27Action,Previous28Action,Previous29Action,PortfolioValue,AvailableAmount,SharesHolding,CummulativeProfitLoss,Actions
"datetime[ns, UTC]",f64,f64,f64,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,i32,str


In [51]:
# Values of the "Actions" column of `df`, in row order, as a plain Python list.
correct_actions = df.get_column("Actions").to_list()

In [10]:
# The saved checkpoint is looked up in the `trained_models` folder of the
# working directory's grandparent; its file name is derived from the model
# name and ticker chosen in the configuration cell.
trained_model_dir = Path.cwd().parent.parent / "trained_models"
model_filename = trained_model_dir.joinpath(
    f"sb_{model_name}_{ticker}_single_digit_reward_default_parameters"
)
# Load the A2C model with the trade environments attached.
a2c_expert = A2C.load(
    model_filename,
    env=trade_envs,
    force_reset=False,
)

In [20]:
# Roll the loaded A2C policy through the trade environments and record, for
# every environment, the cumulative profit/loss reported at the end of its
# FIRST completed episode.
#
# Why the bookkeeping is per environment:
#   * The value is read from `infos[i]` on the very step where `dones[i]` is
#     set, i.e. from the info dict that belongs to the finished episode.
#     Rebuilding the whole list from every env's info on each "done" would mix
#     in mid-episode values of environments that are still running (or, if
#     `trade_envs` is a Stable-Baselines3 VecEnv that auto-resets, values from
#     a freshly restarted episode) — TODO confirm env type in create_envs.
#   * `counter` only advances the first time an environment finishes, so one
#     environment completing several short episodes cannot end the loop before
#     the others have finished once.
obs = trade_envs.reset()
first_episode_profit_loss = {}  # env index -> cumulative P&L at first episode end
counter = 0  # number of distinct environments that have finished an episode
while counter < num_envs:
    # Actions are sampled (deterministic=False), as in the original evaluation.
    action, _ = a2c_expert.predict(obs, deterministic=False)
    obs, rewards, dones, infos = trade_envs.step(action)

    for i in range(num_envs):
        if dones[i] and i not in first_episode_profit_loss:
            first_episode_profit_loss[i] = infos[i]["cummulative_profit_loss"]
            counter += 1

# One entry per environment, ordered by environment index.
profit_loss = [first_episode_profit_loss[i] for i in range(num_envs)]
print(profit_loss)

[-18.5, 0, -4.89996337890625, 0, -32.14996337890625, -33.64996337890625, -23.64996337890625, 0, 0, -32.8499755859375, 0, -0.0999755859375, 0, -6.1500244140625, -20.39996337890625, -35.75]
