In [1]:
import sys

sys.path.append("../")
from pathlib import Path

import polars as pl
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env

from common.make_vec_env import make_vec_env
from envs.single_stock_trading_env import StockTradingEnv

In [2]:
TICKER = "SBIN.NS"
TRAIN_FILE = Path("../datasets") / f"{TICKER}_train"
EVAL_FILE = Path("../datasets") / f"{TICKER}_trade"

CLOSE_PRICES = pl.read_parquet(TRAIN_FILE)["Close"].to_numpy()
EVAL_CLOSE_PRICES = pl.read_parquet(EVAL_FILE)["Close"].to_numpy()

In [3]:
check_env(StockTradingEnv(CLOSE_PRICES, seed=0))

In [4]:
env = StockTradingEnv(CLOSE_PRICES, seed=0)

In [5]:
obs, info = env.reset()
obs, info

(array([  457.65, 10000.  ,     0.  ,     0.  ], dtype=float32), {})

In [9]:
s, r, d, t, i = env.step(2)
i

{'seed': 0,
 'counter': 3,
 'close_price': 447.4,
 'predicted_action': 'SELL',
 'description': '22.0 shares sold at 447.40 with profit of -34.1005859375',
 'available_amount': 9965.899,
 'shares_holdings': 0,
 'buy_price': 0,
 'buy_price_index': -1,
 'reward': -34.1005859375,
 'done': False,
 'truncated': False,
 'correct_trade': 4,
 'wrong_trade': 0,
 'correct_trade %': 100.0,
 'buy_counter': 1,
 'sell_counter': 1,
 'hold_counter': 0,
 'good_hold_counter': 0,
 'good_sell_counter': 0,
 'good_buy_counter': 1,
 'bad_hold_counter': 0,
 'bad_sell_counter': 1,
 'bad_buy_counter': 0,
 'hold_with_no_shares_counter': 2,
 'good_hold_streak': 0,
 'bad_hold_streak': 0,
 'buy_counter %': 25.0,
 'good_sell_counter %': 0,
 'good_hold_counter %': 0,
 'bad_sell_counter %': 25.0,
 'bad_hold_counter %': 0,
 'holds_with_no_shares_counter %': 50.0,
 'good_hold_profit': 0,
 'good_sell_profit': 0,
 'good_buy_profit': 0,
 'bad_hold_loss': 0,
 'bad_sell_loss': -34.1005859375,
 'bad_buy_loss': 0,
 'good_moves'

In [18]:
df = pl.read_excel("../datasets/LabelTradeSBI.NS.xlsx")
df

Datetime,Close,RSI,EMA9,EMA21,MACD,MACD_SIGNAL,BBANDS_UPPER,BBANDS_MIDDLE,BBANDS_LOWER,ADX,STOCH_K,STOCH_D,ATR,CCI,MOM,ROC,WILLR,PPO,Actions
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,str
"""2023-08-04 03:…",592.599976,34.215289,593.04732,599.098184,-6.307788,-5.752122,617.59649,599.995004,582.393519,49.057268,66.097559,44.168844,4.162341,-56.254793,-3.900024,-0.653818,-68.801719,-1.776887,"""HOLD"""
"""2023-08-04 04:…",595.200012,39.898269,593.477859,598.743805,-5.79683,-5.761063,615.672932,599.085004,582.497075,49.078604,80.213751,61.716574,4.193601,-7.76646,-4.049988,-0.675843,-49.502458,-1.701393,"""HOLD"""
"""2023-08-04 05:…",596.799988,43.152731,594.142285,598.567094,-5.202812,-5.649413,613.624074,598.280002,582.93593,47.995942,89.26872,78.526677,4.158344,55.752468,-2.100037,-0.350649,-37.994763,-1.567384,"""HOLD"""
"""2023-08-04 06:…",593.0,37.903187,593.913828,598.060995,-4.981256,-5.515782,611.614952,597.360001,583.10505,47.412213,75.581632,81.688034,4.229174,18.573737,0.799988,0.135087,-49.683428,-1.479955,"""HOLD"""
"""2023-08-04 07:…",576.200012,24.00159,590.371065,596.073633,-6.091075,-5.63084,611.045368,595.585001,580.124633,48.706005,48.939613,71.263322,5.516376,-222.763293,-16.849976,-2.84124,-86.188757,-1.517362,"""HOLD"""
"""2023-08-04 08:…",573.900024,22.770303,587.076857,594.05785,-7.074652,-5.919603,609.851195,593.662503,577.473811,49.921356,22.926329,49.149192,5.67235,-233.637316,-18.0,-3.041054,-93.414121,-1.482601,"""HOLD"""
"""2023-08-04 09:…",573.200012,22.393764,584.301488,592.161683,-7.82048,-6.299778,608.931511,591.902502,574.873494,51.07946,9.433371,27.099771,5.45647,-190.258119,-17.149963,-2.90505,-94.20779,-1.48486,"""HOLD"""
"""2023-08-07 03:…",570.900024,21.155885,581.621195,590.228805,-8.49917,-6.739656,608.61177,590.242505,571.87324,52.291359,6.454252,12.937984,5.552435,-148.452113,-15.099976,-2.576788,-94.237247,-1.508338,"""HOLD"""
"""2023-08-07 04:…",568.200012,19.774005,578.936958,588.226187,-9.149436,-7.221612,608.403215,588.542505,568.681795,53.533983,5.361881,7.083168,5.50226,-133.378831,-22.599976,-3.825318,-96.671987,-1.534055,"""HOLD"""
"""2023-08-07 05:…",569.650024,22.694395,577.079572,586.537445,-9.438967,-7.665083,607.681385,587.000006,566.318627,54.687848,9.918647,7.244927,5.344954,-105.989029,-21.75,-3.677714,-92.076067,-1.497092,"""HOLD"""


In [22]:
# backtest
start_price = 10_000
(
    df.with_columns([pl.lit(10_000).alias("Available_Amount")]).with_columns(
        [
            pl.when(pl.col("Actions") == "BUY")
            .then(pl.col("Available_Amount") // pl.col("Close"))
            .then(
                pl.col("Available_Amount")
                - pl.col("Available_Amount") // pl.col("Close")
            )
            # .otherwise(0)
            .alias("Shares_Holdings")
        ]
    )
)

AttributeError: 'Then' object has no attribute 'then'