<a href="https://colab.research.google.com/github/srijabiswas-01/reinforcement-learning-trading-bot/blob/main/stock_ananlysis_with_company.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Elite Reinforcement Learning–Based Automated Trading Agent

This notebook demonstrates the initial data analysis and environment setup
for a reinforcement learning–based stock trading agent.


In [1]:
!pip install -q yfinance pandas numpy plotly scipy stable-baselines3 gymnasium seaborn matplotlib

In [2]:
import yfinance as yf
import pandas as pd
import numpy as np
from datetime import datetime, timezone

import plotly.express as px
import seaborn as sns
import matplotlib.pyplot as plt

import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

import warnings
warnings.filterwarnings("ignore")

np.random.seed(42)

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


In [3]:
start_date = "2015-01-01"
end_date = datetime.today().strftime("%Y-%m-%d")

ACTION_MAP = {0: "HOLD", 1: "BUY", 2: "SELL"}

### **Market ETFs (tradable proxies)**

In [4]:
MARKETS = {
    "USA_S&P500": "SPY",
    "Australia_ASX200": "STW.AX",
    "UK_FTSE100": "ISF.L",
    "Japan_Nikkei225": "EWJ",
    "China_SSE": "FXI",
    "India_NIFTY50": "NIFTYBEES.NS"
}

In [5]:
MARKET_COMPANIES = {
    "USA_S&P500": {
        "Apple": "AAPL",
        "Microsoft": "MSFT",
        "Amazon": "AMZN",
        "NVIDIA": "NVDA",
        "Tesla": "TSLA"
    },
    "Australia_ASX200": {
        "BHP Group": "BHP.AX",
        "CSL": "CSL.AX",
        "Commonwealth Bank": "CBA.AX",
        "Westpac": "WBC.AX",
        "ANZ": "ANZ.AX"
    },
    "UK_FTSE100": {
        "Shell": "SHEL.L",
        "HSBC": "HSBA.L",
        "Unilever": "ULVR.L",
        "BP": "BP.L",
        "AstraZeneca": "AZN.L"
    },
    "Japan_Nikkei225": {
        "Toyota": "7203.T",
        "Sony": "6758.T",
        "Nintendo": "7974.T",
        "SoftBank": "9984.T",
        "Keyence": "6861.T"
    },
    "China_SSE": {
        "ICBC": "601398.SS",
        "China Life": "601628.SS",
        "PetroChina": "601857.SS",
        "SAIC Motor": "600104.SS",
        "CRRC": "601766.SS"
    },
    "India_NIFTY50": {
        "Reliance": "RELIANCE.NS",
        "TCS": "TCS.NS",
        "HDFC Bank": "HDFCBANK.NS",
        "Infosys": "INFY.NS",
        "ICICI Bank": "ICICIBANK.NS"
    }
}

In [6]:
def build_features(df):
    price_col = "Adj Close" if "Adj Close" in df.columns else "Close"

    df = df.copy()
    df["Return"] = df[price_col].pct_change()
    df["Volatility_20"] = df["Return"].rolling(20).std()
    df["Momentum_10"] = df[price_col] / df[price_col].shift(10) - 1
    df["Volume_MA20"] = df["Volume"].rolling(20).mean()

    df.dropna(inplace=True)
    df.reset_index(drop=True, inplace=True)

    return df[[
        "Date", "Close", "Volume",
        "Return", "Volatility_20",
        "Momentum_10", "Volume_MA20"
    ]]

In [7]:
market_data = {}

for market, ticker in MARKETS.items():
    print(f"Downloading market ETF: {market} ({ticker})")
    df = yf.download(ticker, start=start_date, end=end_date, progress=False)

    if isinstance(df.columns, pd.MultiIndex):
        df.columns = df.columns.get_level_values(0)

    df.reset_index(inplace=True)
    df.dropna(inplace=True)

    market_data[market] = build_features(df)

Downloading market ETF: USA_S&P500 (SPY)
Downloading market ETF: Australia_ASX200 (STW.AX)
Downloading market ETF: UK_FTSE100 (ISF.L)
Downloading market ETF: Japan_Nikkei225 (EWJ)
Downloading market ETF: China_SSE (FXI)
Downloading market ETF: India_NIFTY50 (NIFTYBEES.NS)


In [8]:
train_data, test_data = {}, {}
split_date = "2022-01-01"

for market, df in market_data.items():
    train_data[market] = df[df["Date"] < split_date].reset_index(drop=True)
    test_data[market]  = df[df["Date"] >= split_date].reset_index(drop=True)

In [9]:
class TradingEnv(gym.Env):
    def __init__(self, df, risk_penalty=0.001):
        super().__init__()
        self.df = df.reset_index(drop=True)
        self.risk_penalty = risk_penalty

        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(5,), dtype=np.float32
        )

        self.reset()

    def reset(self, seed=None, options=None):
        self.step_idx = 0
        self.position = 0
        self.prev_price = self.df.loc[self.step_idx, "Close"]
        return self._obs(), {}

    def _obs(self):
        r = self.df.loc[self.step_idx]
        volume_norm = r["Volume"] / (r["Volume_MA20"] + 1e-8)
        return np.array([
            r["Return"],
            r["Volatility_20"],
            r["Momentum_10"],
            volume_norm,
            self.position
        ], dtype=np.float32)

    def step(self, action):
        price = self.df.loc[self.step_idx, "Close"]

        if action == 1:
            self.position = 1
        elif action == 2:
            self.position = -1

        ret = (price - self.prev_price) / self.prev_price
        pnl = self.position * ret

        reward = 100 * pnl - 10 * self.risk_penalty * self.df.loc[self.step_idx, "Volatility_20"]

        self.prev_price = price
        self.step_idx += 1

        done = self.step_idx >= len(self.df) - 1
        return self._obs() if not done else np.zeros(5), reward, done, False, {}

In [10]:
env = DummyVecEnv([
    lambda m=m: TradingEnv(train_data[m])
    for m in train_data
])

model = PPO(
    "MlpPolicy",
    env,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    gamma=0.99,
    verbose=1
)

model.learn(total_timesteps=400_000)

Using cuda device
------------------------------
| time/              |       |
|    fps             | 2360  |
|    iterations      | 1     |
|    time_elapsed    | 5     |
|    total_timesteps | 12288 |
------------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 1308         |
|    iterations           | 2            |
|    time_elapsed         | 18           |
|    total_timesteps      | 24576        |
| train/                  |              |
|    approx_kl            | 0.0061533856 |
|    clip_fraction        | 0.0182       |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.09        |
|    explained_variance   | -1.47e-05    |
|    learning_rate        | 0.0003       |
|    loss                 | 332          |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.00178     |
|    value_loss           | 483          |
----------------------------------

<stable_baselines3.ppo.ppo.PPO at 0x786ac3f274a0>

In [11]:
company_data = {}

for market, companies in MARKET_COMPANIES.items():
    company_data[market] = {}

    for name, ticker in companies.items():
        print(f"Downloading company: {name} ({ticker})")
        df = yf.download(ticker, start=start_date, end=end_date, progress=False)
        if df.empty:
            continue

        if isinstance(df.columns, pd.MultiIndex):
            df.columns = df.columns.get_level_values(0)

        df.reset_index(inplace=True)
        df.dropna(inplace=True)

        company_data[market][name] = build_features(df)

Downloading company: Apple (AAPL)
Downloading company: Microsoft (MSFT)
Downloading company: Amazon (AMZN)
Downloading company: NVIDIA (NVDA)
Downloading company: Tesla (TSLA)
Downloading company: BHP Group (BHP.AX)
Downloading company: CSL (CSL.AX)
Downloading company: Commonwealth Bank (CBA.AX)
Downloading company: Westpac (WBC.AX)
Downloading company: ANZ (ANZ.AX)
Downloading company: Shell (SHEL.L)
Downloading company: HSBC (HSBA.L)
Downloading company: Unilever (ULVR.L)
Downloading company: BP (BP.L)
Downloading company: AstraZeneca (AZN.L)
Downloading company: Toyota (7203.T)
Downloading company: Sony (6758.T)
Downloading company: Nintendo (7974.T)
Downloading company: SoftBank (9984.T)
Downloading company: Keyence (6861.T)
Downloading company: ICBC (601398.SS)
Downloading company: China Life (601628.SS)
Downloading company: PetroChina (601857.SS)
Downloading company: SAIC Motor (600104.SS)
Downloading company: CRRC (601766.SS)
Downloading company: Reliance (RELIANCE.NS)
Download

In [12]:
def build_state(df, position):
    r = df.iloc[-1]
    volume_norm = r["Volume"] / (r["Volume_MA20"] + 1e-8)
    return np.array([
        r["Return"],
        r["Volatility_20"],
        r["Momentum_10"],
        volume_norm,
        position
    ], dtype=np.float32)

In [13]:
def agent_cli(model):
    print("\n=== Elite RL Trading Agent (CLI) ===")

    while True:
        print("\nAvailable Markets:")
        for m in MARKET_COMPANIES:
            print("•", m)

        market = input("\nSelect market (or 'exit'): ").strip()
        if market.lower() == "exit":
            print("Exiting agent.")
            break

        if market not in MARKET_COMPANIES:
            print("Invalid market.")
            continue

        print("\nAvailable Companies:")
        for c in MARKET_COMPANIES[market]:
            print("•", c)

        company = input("\nSelect company: ").strip()
        if company not in company_data[market]:
            print("Invalid company.")
            continue

        position = int(input("Current position (-1 short, 0 flat, 1 long): "))

        state = build_state(company_data[market][company], position)
        action, _ = model.predict(state.reshape(1, -1), deterministic=True)

        print("\n=== AGENT RESPONSE ===")
        print("Market:", market)
        print("Company:", company)
        print("Suggested Action:", ACTION_MAP[int(action)])
        print("Current Position:", position)
        print("Timestamp:", datetime.now(timezone.utc).isoformat())

In [14]:
agent_cli(model)


=== Elite RL Trading Agent (CLI) ===

Available Markets:
• USA_S&P500
• Australia_ASX200
• UK_FTSE100
• Japan_Nikkei225
• China_SSE
• India_NIFTY50

Select market (or 'exit'): USA_S&P500

Available Companies:
• Apple
• Microsoft
• Amazon
• NVIDIA
• Tesla

Select company: NVIDIA
Current position (-1 short, 0 flat, 1 long): 1

=== AGENT RESPONSE ===
Market: USA_S&P500
Company: NVIDIA
Suggested Action: SELL
Current Position: 1
Timestamp: 2026-02-06T11:30:55.321880+00:00

Available Markets:
• USA_S&P500
• Australia_ASX200
• UK_FTSE100
• Japan_Nikkei225
• China_SSE
• India_NIFTY50

Select market (or 'exit'): exit
Exiting agent.
