<a href="https://colab.research.google.com/github/tony-pitchblack/finrl-dt/blob/custom-backtesting/finrl_dt_replicate_sweep_rllib.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installs

In [None]:
# Enable IPython's autoreload extension so edited modules are re-imported automatically.
# NOTE(review): the install cell's output below shows autoreload failing on packages
# that pip replaced mid-session (packaging.utils, urllib3.exceptions); restarting the
# kernel after installing avoids that noise.
%load_ext autoreload
%autoreload 2

In [None]:
# %%bash
# git clone https://github.com/tony-pitchblack/FinRL.git
# cd ./FinRL
# git checkout benchmarking
# pip install -r FinRL/requirements.txt

In [None]:
# Install the FinRL fork straight from its `benchmarking` branch.
# The trailing commented flags are kept as an optional override for forcing a reinstall.
!pip install git+https://github.com/tony-pitchblack/FinRL.git@benchmarking \
    # --force-reinstall --no-deps

In [None]:
# Download requirements.txt from the fork's `benchmarking` branch (overwriting any
# local copy) and install the dependencies it lists.
!wget https://raw.githubusercontent.com/tony-pitchblack/FinRL/benchmarking/requirements.txt -O requirements.txt
!pip install -r requirements.txt

[autoreload of packaging.utils failed: Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/usr/local/lib/python3.11/dist-packages/IPython/extensions/autoreload.py", line 410, in superreload
    update_generic(old_obj, new_obj)
  File "/usr/local/lib/python3.11/dist-packages/IPython/extensions/autoreload.py", line 347, in update_generic
    update(a, b)
  File "/usr/local/lib/python3.11/dist-packages/IPython/extensions/autoreload.py", line 266, in update_function
    setattr(old, name, getattr(new, name))
ValueError: canonicalize_version() requires a code object with 2 free vars, not 0
]
[autoreload of urllib3.exceptions failed: Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/usr/local/lib/python3.11/dist-pack

--2025-02-02 10:44:46--  https://raw.githubusercontent.com/tony-pitchblack/FinRL/benchmarking/requirements.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 766 [text/plain]
Saving to: ‘requirements.txt’


2025-02-02 10:44:46 (36.7 MB/s) - ‘requirements.txt’ saved [766/766]

Collecting alpaca-py (from -r requirements.txt (line 2))
  Downloading alpaca_py-0.37.0-py3-none-any.whl.metadata (13 kB)
Collecting selenium (from -r requirements.txt (line 3))
  Downloading selenium-4.28.1-py3-none-any.whl.metadata (7.1 kB)
Collecting webdriver-manager (from -r requirements.txt (line 4))
  Downloading webdriver_manager-4.0.2-py2.py3-none-any.whl.metadata (12 kB)
Collecting exchange_calendars==3.6.3 (from -r requirements.txt (line 7))
  Downloading exchange_calendars-

# Imports

In [None]:
import pandas as pd

from stable_baselines3.common.logger import configure
from finrl import config_tickers
from finrl.main import check_and_make_directories
from finrl.config import INDICATORS, TRAINED_MODEL_DIR, RESULTS_DIR

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [None]:
import os
from pathlib import Path
import pandas as pd
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [None]:
# Credentials must never be committed with the notebook: take the W&B API key from
# the environment and only prompt for it (without echo) when it is missing.
# NOTE: the key that used to be hardcoded here is exposed in version history and
# should be revoked in the W&B account settings.
if not os.environ.get("WANDB_API_KEY"):
    from getpass import getpass

    os.environ["WANDB_API_KEY"] = getpass("W&B API key: ")

# W&B project and entity (user/team) that runs and sweeps are logged under.
PROJECT = 'finrl-dt-replicate'
ENTITY = "overfit1010"

# General funcs

In [None]:
#@title YahooDownloader

"""Contains methods and classes to collect data from
Yahoo Finance API
"""

from __future__ import annotations

import pandas as pd
import yfinance as yf


class YahooDownloader:
    """Provides methods for retrieving daily stock data from
    Yahoo Finance API

    Attributes
    ----------
        start_date : str
            start date of the data (modified from neofinrl_config.py)
        end_date : str
            end date of the data (modified from neofinrl_config.py)
        ticker_list : list
            a list of stock tickers (modified from neofinrl_config.py)

    Methods
    -------
    fetch_data()
        Fetches data from yahoo API
    select_equal_rows_stock(df)
        Keeps only tickers with at least the mean number of rows

    """

    def __init__(self, start_date: str, end_date: str, ticker_list: list):
        self.start_date = start_date
        self.end_date = end_date
        self.ticker_list = ticker_list

    def fetch_data(self, proxy=None) -> pd.DataFrame:
        """Fetches data from Yahoo API

        Parameters
        ----------
        proxy : optional
            Forwarded unchanged to ``yf.download``.

        Returns
        -------
        `pd.DataFrame`
            8 columns: date, tic, open, high, low, close, volume and day
            (day of week, monday = 0), sorted by date and tic. ``close``
            holds the adjusted close price.

        Raises
        ------
        ValueError
            If no ticker returned any rows.
        """
        # Collect the per-ticker frames and concatenate them once; growing a
        # DataFrame with concat inside the loop re-copies all rows each time.
        frames = []
        for tic in self.ticker_list:
            temp_df = yf.download(
                tic, start=self.start_date, end=self.end_date, proxy=proxy
            )
            temp_df["tic"] = tic
            if len(temp_df) > 0:
                frames.append(temp_df)
        if not frames:
            raise ValueError("no data is fetched.")
        # reset the index, we want to use numbers as index instead of dates
        data_df = pd.concat(frames, axis=0).reset_index()

        try:
            # Convert wide to long format: the 'Ticker' column level becomes a
            # regular column, one row per (Date, Ticker).
            data_df = data_df.sort_index(axis=1).set_index(['Date']).drop(columns=['tic']).stack(level='Ticker', future_stack=True)
            data_df.reset_index(inplace=True)
            data_df.columns.name = ''

            # convert the column names to standardized names
            data_df.rename(columns={'Ticker': 'Tic', 'Adj Close': 'Adjcp'}, inplace=True)
            data_df.rename(columns={col: col.lower() for col in data_df.columns}, inplace=True)

            columns = [
                "date",
                "tic",
                "open",
                "high",
                "low",
                "close",
                "adjcp",
                "volume",
            ]

            data_df = data_df[columns]
            # use adjusted close price instead of close price
            data_df["close"] = data_df["adjcp"]
            # drop the adjusted close price column
            data_df = data_df.drop(labels="adjcp", axis=1)

        except NotImplementedError:
            # NOTE(review): execution continues after this message, so the
            # steps below will most likely fail on the un-normalized frame.
            print("the features are not supported currently")

        # create day of the week column (monday = 0)
        data_df["day"] = data_df["date"].dt.dayofweek
        # convert date to standard string format, easy to filter
        data_df["date"] = data_df.date.apply(lambda x: x.strftime("%Y-%m-%d"))
        # drop missing data
        data_df = data_df.dropna()
        data_df = data_df.reset_index(drop=True)
        print("Shape of DataFrame: ", data_df.shape)

        data_df = data_df.sort_values(by=["date", "tic"]).reset_index(drop=True)

        return data_df

    def select_equal_rows_stock(self, df):
        """Keep only the tickers whose row count is at least the mean row count.

        Parameters
        ----------
        df : pd.DataFrame
            Frame with a ``tic`` column.

        Returns
        -------
        pd.DataFrame
            The rows of ``df`` belonging to the selected tickers.
        """
        # Count rows per ticker once and reuse the result.
        counts = df.tic.value_counts()
        select_stocks_list = list(counts.index[counts >= counts.mean()])
        return df[df.tic.isin(select_stocks_list)]


In [None]:
#@title fix_daily_index

def make_daily_index(data_df, date_column='date', new_index_name='date_index'):
    """Map every row's date to its rank among the sorted unique dates.

    Returns a Series aligned with ``data_df`` holding 0 for the earliest
    date, 1 for the next one, and so on. ``new_index_name`` is accepted
    but not used here.
    """
    ordered_dates = sorted(data_df[date_column].unique())
    rank_of_date = dict(zip(ordered_dates, range(len(ordered_dates))))
    return data_df[date_column].map(rank_of_date)

def set_daily_index(data_df, date_column='date', new_index_name='date_index'):
    """
    Constructs a daily index from unique dates in the specified column.

    The frame is modified in place (a helper column is added and then
    promoted to the index) and the same object is returned.

    Parameters:
        data_df (pd.DataFrame): The input DataFrame.
        date_column (str): The name of the column containing dates.
        new_index_name (str): Name of the temporary column that becomes the index.

    Returns:
        pd.DataFrame: DataFrame with a daily index.
    """

    # Honour the caller's column names; these used to be hardcoded to
    # 'date' / 'date_index' regardless of the arguments.
    data_df[new_index_name] = make_daily_index(
        data_df, date_column=date_column, new_index_name=new_index_name
    )

    data_df.set_index(new_index_name, inplace=True)
    data_df.index.name = ''  # Remove the index name for simplicity

    return data_df

def fix_daily_index(df):
    """Make sure ``df`` is indexed by the daily index derived from its 'date' column.

    A 'date' index is first moved back into a column. The index is replaced
    only when it differs from the expected daily index. ``df`` is modified in
    place and returned.
    """
    if df.index.name == 'date':
        df.reset_index(inplace=True)

    expected_index = make_daily_index(df, date_column='date', new_index_name='date_index')
    index_is_stale = (df.index.values != expected_index.values).any()
    if not index_is_stale:
        return df

    df.index = expected_index
    df.index.name = ''
    return df

# trade = fix_daily_index(trade)
# trade.index

In [None]:
#@title get dataset name

def get_quarterly_dataset_name(prefix, train_start_date, val_start_date, test_start_date):
    """Build a label of the form '<prefix> | YYYY-MM | YYYY Qn | YYYY Qn'.

    The three fields after the prefix are the train start (year-month), the
    validation start (year and quarter) and the test start (year and quarter).
    The date arguments need ``.year`` and ``.month`` attributes.
    """
    def _quarter_label(date):
        return f'Q{(date.month - 1) // 3 + 1}'

    fields = [
        f"{prefix}",
        f"{train_start_date.year}-{train_start_date.month:02}",
        f"{val_start_date.year} {_quarter_label(val_start_date)}",
        f"{test_start_date.year} {_quarter_label(test_start_date)}",
    ]
    return " | ".join(fields)

def get_yearly_dataset_name(prefix, train_start, test_start, test_end):
    """Build a label of the form '<prefix> | YYYY-MM | YYYY-MM | YYYY-MM'.

    The year-month fields are, in order, the train start, the test start and
    the test end. The date arguments need ``.year`` and ``.month`` attributes.
    """
    def _year_month(date):
        return f"{date.year}-{date.month:02}"

    year_months = [_year_month(date) for date in (train_start, test_start, test_end)]
    return " | ".join([f"{prefix}", *year_months])


In [None]:
#@title add_dataset

def add_dataset(stock_index_name, train_df, test_df):
    """Register a train/test pair in the module-level ``datasets`` dict.

    The ``datasets`` global is created on first use. Both frames are modified
    in place: their dates are parsed to datetimes and they end up indexed by
    the daily index from ``set_daily_index``. The entry is keyed by the name
    from ``get_yearly_dataset_name`` and stores the two frames plus metadata
    (index name, split dates, ticker count and ticker list).
    """
    global datasets
    if 'datasets' not in globals():
        datasets = {}

    # Ensure datetime format: dates go to the index (if still a column) and get parsed.
    for frame in (train_df, test_df):
        if 'date' in frame.columns:
            frame.set_index('date', inplace=True)
        frame.index = pd.to_datetime(frame.index)

    train_start_date = train_df.index[0]
    test_start_date, test_end_date = test_df.index[0], test_df.index[-1]

    dataset_name = get_yearly_dataset_name(
        stock_index_name, train_start_date, test_start_date, test_end_date
    )

    # Move the dates back into a column before building the daily index.
    train_df.reset_index(inplace=True)
    test_df.reset_index(inplace=True)
    train_df = set_daily_index(train_df)
    test_df = set_daily_index(test_df)

    ticker_list = train_df.tic.unique().tolist()
    metadata = {
        'stock_index_name': stock_index_name,
        'train_start_date': train_start_date,
        'test_start_date': test_start_date,
        'test_end_date': test_end_date,
        'num_tickers': len(ticker_list),
        'ticker_list': ticker_list,
    }

    datasets[dataset_name] = {'train': train_df, 'test': test_df, 'metadata': metadata}

# Load data

## DATA: DOW-30 (quarterly train/val/test)

In [None]:
# Short demo window passed to generate_quarterly_date_ranges in the next cell.
# NOTE: the same three names are reassigned later by the quarterly dataset CONFIG
# cell, so their values depend on which cell ran last.
train_start_date = '2015-01-01'
min_test_start_date = '2016-01-01'
max_test_end_date = '2016-10-01'

In [None]:
#@title generate_quarterly_date_ranges

def generate_quarterly_date_ranges(
    train_start_date,
    min_test_start_date,
    max_test_end_date,
    return_strings=False,
    finetune_previous_val=False
):
    """Generate rolling quarterly train/val/test windows.

    Each window tests on one 3-month period, validates on the 3 months just
    before it, and trains from ``train_start_date`` up to the validation start.
    Windows advance by one quarter and stop once the test end would pass
    ``max_test_end_date``.

    Parameters:
        train_start_date: Start of training; must be a quarter start.
        min_test_start_date: Start of the first test period; must be a quarter start.
        max_test_end_date: Latest allowed test end.
        return_strings (bool): Convert every date in the result with ``str``.
        finetune_previous_val (bool): For every window after the first, train
            from the previous window's validation start instead.

    Returns:
        list[dict]: Dicts with keys train_start_date, val_start_date,
        test_start_date and test_end_date.
    """
    def _starts_quarter(date):
        return date.month in [1, 4, 7, 10] and date.day == 1

    min_test_start_date = pd.Timestamp(min_test_start_date)
    train_start_date = pd.Timestamp(train_start_date)
    max_test_end_date = pd.Timestamp(max_test_end_date)

    assert _starts_quarter(train_start_date), f"train_start_date {train_start_date} is not a quarter start date."
    assert _starts_quarter(min_test_start_date), f"min_test_start_date {min_test_start_date} is not a quarter start date."

    one_quarter = pd.DateOffset(months=3)
    windows = []
    window_train_start = train_start_date
    window_test_start = min_test_start_date
    window_test_end = window_test_start + one_quarter

    while not window_test_end > max_test_end_date:
        if not windows:
            # The first window always covers the full training period.
            window_train_start = train_start_date
        elif finetune_previous_val:
            # Later windows train from the previous window's validation start.
            window_train_start = windows[-1]['val_start_date']

        window = {
            'train_start_date': window_train_start,
            'val_start_date': window_test_start - one_quarter,
            'test_start_date': window_test_start,
            'test_end_date': window_test_end,
        }
        if return_strings:
            window = {name: str(value) for name, value in window.items()}
        windows.append(window)

        # Slide forward by one quarter.
        window_test_start = window_test_end
        window_test_end = window_test_start + one_quarter

    return windows

# Demo: build the fine-tuning windows for the short date span above and
# print one window per line.
date_ranges = generate_quarterly_date_ranges(
    train_start_date, min_test_start_date, max_test_end_date, finetune_previous_val=True
)

print('\n'.join(str(date_range) for date_range in date_ranges))

{'train_start_date': Timestamp('2015-01-01 00:00:00'), 'val_start_date': Timestamp('2015-10-01 00:00:00'), 'test_start_date': Timestamp('2016-01-01 00:00:00'), 'test_end_date': Timestamp('2016-04-01 00:00:00')}
{'train_start_date': Timestamp('2015-10-01 00:00:00'), 'val_start_date': Timestamp('2016-01-01 00:00:00'), 'test_start_date': Timestamp('2016-04-01 00:00:00'), 'test_end_date': Timestamp('2016-07-01 00:00:00')}
{'train_start_date': Timestamp('2016-01-01 00:00:00'), 'val_start_date': Timestamp('2016-04-01 00:00:00'), 'test_start_date': Timestamp('2016-07-01 00:00:00'), 'test_end_date': Timestamp('2016-10-01 00:00:00')}


In [None]:
#@title split_data

def split_data(data_df, date_range):
    """Cut ``data_df`` into train/val/test frames using ``date_range``.

    Each split keeps the rows whose 'date' lies in a half-open interval
    [start, end) and is then passed through ``fix_daily_index``:
      train: train_start_date .. val_start_date
      val:   val_start_date   .. test_start_date
      test:  test_start_date  .. test_end_date
    """
    def _rows_between(start_key, end_key):
        lower, upper = date_range[start_key], date_range[end_key]
        in_window = (data_df['date'] >= lower) & (data_df['date'] < upper)
        return fix_daily_index(data_df[in_window])

    return {
        'train': _rows_between('train_start_date', 'val_start_date'),
        'val': _rows_between('val_start_date', 'test_start_date'),
        'test': _rows_between('test_start_date', 'test_end_date'),
    }

# data_splits = split_data(preproc_df, date_ranges[0])
# data_splits['train'].head()

In [None]:
#@title prepare_data (for np env)
import hashlib
from finrl.config_tickers import DOW_30_TICKER
from finrl.meta.data_processor import DataProcessor
import os

# Directory where get_env_config stores processed market data as CSV files;
# it is created up front so the first cache write cannot fail on a missing folder.
CACHE_DIR = './cache'
os.makedirs(CACHE_DIR, exist_ok=True)

def stable_hash(data):
    """Return the SHA-256 hex digest of ``str(data)``.

    Unlike the built-in ``hash``, the result does not change between
    interpreter sessions.
    """
    digest = hashlib.sha256()
    digest.update(str(data).encode())
    return digest.hexdigest()

def get_env_config(
    start_date,
    end_date,
    if_train,
    ticker_list=DOW_30_TICKER,
    technical_indicator_list=INDICATORS,
    time_interval='1d',
    if_vix=True,
    **kwargs
):
    """Build the config dict for the numpy StockTradingEnv over a date window.

    Processed data is cached as CSV in ``CACHE_DIR``. The file name combines
    the dates, the time interval and a hash of the tickers, the indicators
    and, when VIX is disabled, a marker for that.

    Parameters:
        start_date, end_date: Window passed to ``DataProcessor.download_data``.
        if_train: Stored unchanged under the 'if_train' key.
        ticker_list: Tickers to download.
        technical_indicator_list: Indicators to add.
        time_interval: Bar interval passed to the data processor.
        if_vix: Whether VIX data is added and used when building the arrays.
        **kwargs: Forwarded to the ``DataProcessor`` constructor.

    Returns:
        dict: Keys 'price_array', 'tech_array', 'turbulence_array', 'if_train'.
    """
    key_parts = sorted(ticker_list) + sorted(technical_indicator_list)
    if not if_vix:
        # A cache written with VIX must not be reused for a run without it.
        # The marker is added only for the non-default case so that existing
        # cache files from the default (VIX) path keep their names.
        key_parts.append('no_vix')
    data_hash = stable_hash(tuple(key_parts))
    file_path = Path(CACHE_DIR) / f"{start_date}_{end_date}_{time_interval}_{data_hash}.csv"
    dp = DataProcessor(data_source='yahoofinance', tech_indicator=technical_indicator_list, vix=if_vix, **kwargs)
    if os.path.isfile(file_path):
        print(f"Using cached data: {file_path}")
        data = pd.read_csv(file_path, index_col=0)
    else:
        print("Creating new data.")
        data = dp.download_data(ticker_list, start_date, end_date, time_interval)
        data = dp.clean_data(data)
        data = dp.add_technical_indicator(data, technical_indicator_list)
        if if_vix:
            data = dp.add_vix(data)
        data.to_csv(file_path)

    price_array, tech_array, turbulence_array = dp.df_to_array(data, if_vix)

    env_config = {
        "price_array": price_array,
        "tech_array": tech_array,
        "turbulence_array": turbulence_array,
        "if_train": if_train
    }

    return env_config

# date_range=date_ranges[0]
# train_env_config = get_env_config(
#     start_date=date_range['val_start_date'],
#     end_date=date_range['test_start_date'],
#     if_train=True
# )
# val_env_config = get_env_config(
#     start_date=date_range['test_start_date'],
#     end_date=date_range['test_end_date'],
#     if_train=False
# )

# Main

## Wandb artifacts

In [None]:
#@title update_artifact

def update_artifact(folder_path, name_prefix, type):
    """
    Upload a folder as a W&B artifact of the currently active run.

    The artifact is named '<name_prefix>-<run id>'.

    Args:
        folder_path (str): Path to the folder to upload.
        name_prefix (str): Prefix of the artifact name.
        type (str): Artifact type passed to ``wandb.Artifact``.
    """
    run = wandb.run
    artifact_name = f'{name_prefix}-{run.id}'

    # Build the artifact from the folder contents, then attach it to the run.
    artifact = wandb.Artifact(name=artifact_name, type=type)
    artifact.add_dir(folder_path)
    run.log_artifact(artifact)

    print(f"Artifact '{artifact_name}' has been updated and uploaded.")

In [None]:
#@title update_model_artifacts

def update_model_artifacts(log_results_folder=True):
    """Upload the trained-models folder, optionally preceded by the results folder.

    Each folder goes through ``update_artifact`` with the same string used as
    both name prefix and artifact type.
    """
    uploads = []
    if log_results_folder:
        uploads.append((RESULTS_DIR, 'results'))
    uploads.append((TRAINED_MODEL_DIR, 'trained_models'))

    for folder, label in uploads:
        update_artifact(folder_path=folder, name_prefix=label, type=label)

In [None]:
#@title update_dataset_artifact

from pathlib import Path

def update_dataset_artifact(config, train_df, val_df=None, test_df=None):
    """Write the dataset splits to ./dataset as CSV and upload that folder.

    ``train_df`` is always written; ``test_df`` and ``val_df`` only when
    given. ``config`` is accepted but not used here.
    """
    dataset_dir = Path('./dataset')
    os.makedirs(dataset_dir, exist_ok=True)

    train_df.to_csv(dataset_dir / 'train_data.csv', index=False)

    # Optional splits, written in the same order as before (test, then val).
    optional_splits = (('test_data.csv', test_df), ('val_data.csv', val_df))
    for file_name, frame in optional_splits:
        if frame is not None:
            frame.to_csv(dataset_dir / file_name, index=False)

    update_artifact(folder_path=dataset_dir, name_prefix='dataset', type='dataset')

In [None]:
#@title update_env_state_artifact

def update_env_state_artifact(val_env_end_state):
    """Save the given state to a one-column CSV and upload it as a W&B artifact.

    The file 'val_env_end_state.csv' has a single 'last_state' column. The
    artifact is named 'val_env_end_state-<run id>' with type 'env_state'.
    """
    file_path = 'val_env_end_state.csv'
    pd.DataFrame({"last_state": val_env_end_state}).to_csv(file_path, index=False)

    run = wandb.run
    artifact = wandb.Artifact(name=f'val_env_end_state-{run.id}', type='env_state')
    artifact.add_file(file_path)
    run.log_artifact(artifact)

## Build & helper funcs

In [None]:
#@title build_yearly_train_test
def build_yearly_train_test(config):
    """Build the yearly train/test frames described by ``config`` and upload them.

    Reads 'train_years_count', 'test_years_count', 'test_start_year' and
    'stock_index_name' from ``config``. The split dates and the dataset name
    are written back into ``config``.

    Relies on the globals ``preproc_df`` and ``generate_yearly_train_test_dates``,
    which are not defined in the visible part of this notebook.

    Returns:
        tuple: (train_df, test_df), both indexed by the daily index.
    """
    train_start_date, test_start_date, test_end_date = generate_yearly_train_test_dates(
        config['train_years_count'],
        config['test_years_count'],
        config['test_start_year']
    )

    # Copy the slices: set_daily_index modifies its argument in place, which
    # must not happen on a view of the shared preproc_df.
    train_df = preproc_df[(preproc_df['date'] >= train_start_date) & (preproc_df['date'] < test_start_date)].copy()
    test_df = preproc_df[(preproc_df['date'] >= test_start_date) & (preproc_df['date'] < test_end_date)].copy()

    train_df = set_daily_index(train_df)
    test_df = set_daily_index(test_df)

    dataset_name = get_yearly_dataset_name(
        config['stock_index_name'], train_start_date, test_start_date, test_end_date
    )

    config.update(dict(
        train_start_date=train_start_date,
        test_start_date=test_start_date,
        test_end_date=test_end_date,
        dataset_name=dataset_name
    ))

    # A yearly split has no validation set; this used to reference an
    # undefined `val_df` name.
    update_dataset_artifact(
        config,
        train_df=train_df,
        val_df=None,
        test_df=test_df,
    )
    return train_df, test_df

In [None]:
#@title build_quarterly_train_val_test
from finrl.meta.data_processors.processor_yahoofinance import YahooFinanceProcessor

def build_quarterly_train_val_test(config):
    """Build the train/val/test env configs for the window in ``config['date_range']``.

    The four dates in the window are converted to ``pd.Timestamp``; one env
    config per split is created through ``get_env_config``; the dataset name
    is stored in ``config['dataset_name']``.

    Returns:
        tuple: (train_env_config, val_env_config, test_env_config)
    """
    bounds = {name: pd.Timestamp(value) for name, value in config['date_range'].items()}
    train_start = bounds['train_start_date']
    val_start = bounds['val_start_date']
    test_start = bounds['test_start_date']
    test_end = bounds['test_end_date']

    # NOTE(review): the train split gets if_train=False and the val split
    # if_train=True, which looks swapped — TODO confirm this is intended.
    train_env_config = get_env_config(start_date=train_start, end_date=val_start, if_train=False)
    val_env_config = get_env_config(start_date=val_start, end_date=test_start, if_train=True)
    test_env_config = get_env_config(start_date=test_start, end_date=test_end, if_train=False)

    config.update({
        'dataset_name': get_quarterly_dataset_name(
            config['stock_index_name'], train_start, val_start, test_start
        )
    })

    return train_env_config, val_env_config, test_env_config

# train_env_config, val_env_config, test_env_config = build_quarterly_train_val_test(config)

In [None]:
#@title Init StockTradingEnv (numpy)

from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv
from finrl.meta.data_processor import DataProcessor
from finrl.config_tickers import DOW_30_TICKER
from finrl.config import INDICATORS
# from finrl.config import CACHE_DIR

def init_env(
    np_env_config,
    run_config,
    mode,
    turbulence_threshold=99,
):
    """Create a StockTradingEnv from an env config and the run settings.

    ``run_config['initial_amount']`` is the starting capital and
    ``run_config['cost_pct']`` is used for both the buy and the sell cost.
    ``mode`` is only validated ('train', 'val' or 'test'); it is not passed
    to the environment.
    """
    assert mode in ('train', 'val', 'test')

    transaction_cost = run_config['cost_pct']
    return StockTradingEnv(
        config=np_env_config,
        initial_capital=run_config['initial_amount'],
        buy_cost_pct=transaction_cost,
        sell_cost_pct=transaction_cost,
        turbulence_thresh=turbulence_threshold,
    )

# env = init_env(
#     train_np_env_config,
#     run_config,
#     'train'
# )

In [None]:
#@title Define metric functions

def calculate_mdd(asset_values):
    """
    Return the maximum drawdown of a series of portfolio values, in percent.

    The result is the most negative relative drop from a running peak
    (0 when the series never falls below an earlier peak).
    """
    peaks = asset_values.cummax()
    relative_drop = (asset_values - peaks) / peaks
    return relative_drop.min() * 100  # Convert to percentage

def calculate_sharpe_ratio(asset_values, risk_free_rate=0.0):
    """
    Calculate the annualized Sharpe Ratio of a series of portfolio values.

    Parameters:
        asset_values (pd.Series): Portfolio value per period.
        risk_free_rate (float): Annual risk-free rate; divided by 252 to get
            the per-period rate.

    Returns:
        float: Mean excess return over its standard deviation, scaled by
        sqrt(252). Returns 0.0 when the standard deviation is zero or
        undefined (fewer than two returns).
    """
    # Calculate daily returns
    returns = asset_values.pct_change().dropna()
    excess_returns = returns - risk_free_rate / 252  # Assuming 252 trading days

    volatility = excess_returns.std()
    # std() is NaN for fewer than two returns; treat that like zero
    # volatility instead of letting NaN propagate into the logged metrics.
    if np.isnan(volatility) or volatility == 0:
        return 0.0
    return excess_returns.mean() / volatility * np.sqrt(252)  # Annualized

def calculate_annualized_return(asset_values):
    """
    Calculate the annualized return of a portfolio, in percent.

    Parameters:
        asset_values (pd.Series): Portfolio values indexed by date; the
            difference between the last and first index entries must have a
            ``.days`` attribute.

    Returns:
        float: The total return compounded to a 365-day year, as a percentage.
    """
    # Keep the total return as a fraction while compounding. It used to be
    # multiplied by 100 first, so a 10% gain was compounded as (1 + 10).
    total_return = asset_values.iloc[-1] / asset_values.iloc[0] - 1
    num_days = (asset_values.index[-1] - asset_values.index[0]).days
    annualized_return = ((1 + total_return) ** (365 / num_days) - 1) * 100
    return annualized_return

In [None]:
#@title compute metrics
import wandb
from typing import List
import numpy as np

def compute_metrics(account_values: "pd.DataFrame | pd.Series | np.ndarray"):
    """
    Compute performance metrics for a series of account values.

    Parameters:
        account_values: A Series, a numpy array, or a DataFrame. A DataFrame
            must have a 'date' column or an index named 'date'; after dropping
            NaN rows, its first non-date column is used as the value series
            (e.g. a column named after the algo, like 'a2c'). The input
            object is not modified.

    Returns:
        dict: 'sharpe_ratio', 'mdd', 'ann_return' and 'cum_return'.
        'mdd', 'ann_return' and 'cum_return' are percentages.

    Raises:
        ValueError: If a DataFrame has neither a 'date' column nor a 'date' index.
    """
    if isinstance(account_values, pd.DataFrame):
        if 'date' not in account_values.columns:
            if account_values.index.name != 'date':
                raise ValueError("should contain 'date' column or index")
            # reset_index() returns a new frame, so the caller's DataFrame is
            # left untouched (the in-place variant used to alter it).
            account_values = account_values.reset_index()
        account_values = account_values.dropna().set_index('date').iloc[:, 0]
    elif isinstance(account_values, np.ndarray):
        account_values = pd.Series(account_values)

    sharpe = calculate_sharpe_ratio(account_values)
    mdd = calculate_mdd(account_values)
    cum_ret = (account_values.iloc[-1] - account_values.iloc[0]) / account_values.iloc[0] * 100
    # NOTE(review): num_days counts observations (presumably trading days) but
    # is annualized against 365, while the Sharpe ratio assumes 252 — TODO confirm.
    num_days = len(account_values)
    ann_ret = ((1 + cum_ret / 100) ** (365 / num_days) - 1) * 100

    return {
        'sharpe_ratio': sharpe,
        'mdd': mdd,
        'ann_return': ann_ret,
        'cum_return': cum_ret,
    }

def get_env_metrics(env):
    """Summarize an environment's starting/ending assets, cost and trade count.

    The ending total is ``env.state[0]`` plus the element-wise product of the
    two ``stock_dim``-long state segments that follow it, summed up
    (presumably prices and holdings — TODO confirm against the env).
    """
    n_stocks = env.stock_dim
    first_segment = np.array(env.state[1 : (n_stocks + 1)])
    second_segment = np.array(env.state[(n_stocks + 1) : (n_stocks * 2 + 1)])
    end_total_asset = env.state[0] + sum(first_segment * second_segment)

    summary = {}
    summary['begin_total_asset'] = env.asset_memory[0]
    summary['end_total_asset'] = end_total_asset
    summary['total_cost'] = env.cost
    summary['total_trades'] = env.trades
    return summary

In [None]:
#@title log_metrics

def log_metrics(metrics, model_name, split_label, step=None):
    """Log ``metrics`` to W&B, nested under ``split_label``.

    Every metric key is suffixed with '/<model_name>' so several models can
    be logged side by side for the same split.
    """
    print(f'log_metrics for {model_name}')

    tagged_metrics = {}
    for key, value in metrics.items():
        tagged_metrics[f"{key}/{model_name}"] = value

    wandb.log({split_label: tagged_metrics}, step=step)

In [None]:
#@title update_best_model_metrics

def update_best_model_metrics(metrics, model_name, split_label):
    """Track the best model per split by Sharpe ratio and log it to W&B.

    Sharpe ratios seen so far are kept in
    ``wandb.run.config['sharpe_ratios'][split_label]``, keyed by model name.
    When ``metrics['sharpe_ratio']`` beats the best stored value, or no model
    was stored yet, the metrics are logged under the 'best_model' name
    together with the model's name.

    Parameters:
        metrics (dict): Must contain 'sharpe_ratio'.
        model_name (str): Name of the model the metrics belong to.
        split_label (str): Key the values are grouped under.
    """
    # NOTE(review): this mutates nested dicts read back from wandb.run.config;
    # whether such nested updates are synced by W&B is not visible here — TODO confirm.
    if 'sharpe_ratios' not in wandb.run.config:
        wandb.run.config['sharpe_ratios'] = {}

    if split_label not in wandb.run.config['sharpe_ratios']:
        wandb.run.config['sharpe_ratios'][split_label] = {}

    sharpe_ratios = wandb.run.config['sharpe_ratios'][split_label]

    print(f"DEBUG ({split_label}): run.id = {wandb.run.id}")
    print(f"DEBUG ({split_label}): sharpe_ratios = {sharpe_ratios}")
    print(f"DEBUG ({split_label}): updating best model based on sharpe_ratios: {sharpe_ratios}")
    if len(sharpe_ratios) > 0:
        best_model_name = max(sharpe_ratios, key=sharpe_ratios.get)
        if metrics['sharpe_ratio'] > sharpe_ratios[best_model_name]:
            print(
                f"DEBUG ({split_label}): {round(metrics['sharpe_ratio'], 2)} ({model_name})"
                f" > {round(sharpe_ratios[best_model_name], 2)} ({best_model_name})"
                f". New best model: {model_name}."
            )
            log_metrics(metrics, 'best_model', split_label)
            wandb.log({split_label: {'best_model_name': model_name}})
        else:
            print(
                f"DEBUG ({split_label}): {round(metrics['sharpe_ratio'], 2)} ({model_name})"
                f" <= {round(sharpe_ratios[best_model_name], 2)} ({best_model_name})"
                ". Not updating best model."
            )
    else:
        print(f"DEBUG ({split_label}): no models logged yet, new best model is current one: {model_name}")
        print(f"DEBUG ({split_label}): wandb.run.config['sharpe_ratios'] = {wandb.run.config['sharpe_ratios']}")
        log_metrics(metrics, 'best_model', split_label)
        # Record the name here too; otherwise it is never logged when the
        # first model stays the best one.
        wandb.log({split_label: {'best_model_name': model_name}})

    wandb.run.config['sharpe_ratios'][split_label][model_name] = metrics['sharpe_ratio']

# Config

In [None]:
#@title init
# Shared parameter dict; the CONFIG cells below fill it in, and the sweep config
# holds a reference to the same object so later updates are visible through it.
parameters_dict = {}
sweep_config = dict(
    method='grid',
    metric=dict(name='test.sharpe_ratio/best_model'),
    parameters=parameters_dict,
)

In [None]:
#@title CONFIG: create dataset - yearly_train_test

min_test_start_year = 2020
max_test_start_year = 2025

########################################################

# Sweep parameters for the yearly train/test dataset.
# Kept for reference (currently disabled):
#     dataset_period = {'value': 'year'},
#     dataset_splits = {'parameters': {
#         'train': {'value': True },
#         'val': {'value': False },
#         'test': {'value': True },
#     }},
yearly_dataset_params = {
    'dataset_type': {'value': 'yearly_train_test'},
    'stock_index_name': {'value': 'DOW-30'},
    'train_years_count': {'value': 10},
    'test_years_count': {'value': 1},
    # One grid value per test start year; max_test_start_year itself is excluded.
    'test_start_year': {
        'values': [*range(min_test_start_year, max_test_start_year)]
    },
}

In [None]:
#@title CONFIG: create dataset - quarterly_train_test

# Full date range (18 backtests)
train_start_date = '2009-01-01'
min_test_start_date = '2016-01-01'
max_test_end_date = '2020-08-05'

#################################################################

# One grid value per quarterly window; dates are strings in the sweep config.
date_ranges = generate_quarterly_date_ranges(
    train_start_date,
    min_test_start_date,
    max_test_end_date,
    return_strings=True,
)

quarterly_dataset_params = {
    'dataset_type': {'value': 'quarterly_train_val_test'},
    'stock_index_name': {'value': 'DOW-30'},
    'train_start_date': {'value': train_start_date},
    'min_test_start_date': {'value': min_test_start_date},
    'max_test_end_date': {'value': max_test_end_date},
    'date_range': {'values': date_ranges},
}

In [None]:
#@title CONFIG: choose dataset
# Select which dataset definition the sweep uses by merging exactly one of the
# parameter dicts into parameters_dict (swap the commented line to switch).
parameters_dict.update(
    # yearly_dataset_params,
    quarterly_dataset_params
)

In [None]:
#@title CONFIG: env params
# Environment settings shared by every run of the sweep.
parameters_dict.update({
    'cost_pct': {'value': 1e-3},
    'initial_amount': {'value': 50_000},
    # Alternative grid for the threshold: 'values': [30, 40, 50, 60, 70]
    'turbulence_threshold': {'value': 99},
    'if_vix': {'value': True},
})

In [None]:
#@title CONFIG: models_used
# Per-algorithm switches stored in the sweep parameters.
parameters_dict.update(dict(
    if_using_a2c={'value': True},
    if_using_ddpg={'value': True},
    if_using_ppo={'value': True},
    if_using_td3={'value': True},
    if_using_sac={'value': True},
))

In [None]:
#@title CONFIG: model and training params
# Per-algorithm model hyperparameters.
ppo_model_settings = dict(
    train_batch_size=2048,
    num_epochs=10,
    minibatch_size=128,
    lr=5e-5,
    gamma=0.99,
)
model_params = {'parameters': {'ppo': {'value': ppo_model_settings}}}

# Per-algorithm training settings.
train_params = {
    'parameters': {
        'ppo': {'parameters': {'steps': {'value': 2_048}}},
    }
}

# Env-runner settings.
env_runners_params = {
    'parameters': {
        'num_envs_per_env_runner': {'value': 1},
        'num_env_runners': {'value': 0},
    }
}

# Register the three groups in the same order as before.
parameters_dict.update({
    'env_runners_params': env_runners_params,
    'train_params': train_params,
    'model_params': model_params,
})

# Train & eval funcs

In [None]:
#@title MetricsLoggerCallback (class)
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from typing import Optional, Sequence
import gymnasium as gym
from gymnasium.vector import AsyncVectorEnv

class MetricsLoggerCallback(DefaultCallbacks):
    """RLlib callback that records account-value metrics for each episode.

    On every episode step the metrics returned by `get_account_value_metrics`
    for the environment's `asset_memory` are stored as temporary timestep data.
    At episode end each metric is averaged (NaN-aware) over the episode and
    logged twice through the RLlib metrics logger: once with EMA smoothing and
    once with a sliding-window mean. Optionally the raw per-episode mean is
    also sent to Weights & Biases.

    Args:
        model_name: Name used in the W&B metric key.
        ema_coeff: EMA coefficient passed to `metrics_logger.log_value`.
        ma_window: Window size passed to `metrics_logger.log_value`.
        log_to_wandb: If True, call `wandb.log` at the end of every episode.
    """

    def __init__(self, model_name, ema_coeff=0.2, ma_window=20, log_to_wandb=False):
        super().__init__()

        self.model_name = model_name
        self.ema_coeff = ema_coeff
        self.ma_window = ma_window
        # Names of all metrics seen so far; filled during on_episode_step.
        self.metric_names = set()
        self.log_to_wandb = log_to_wandb

    def unwrap_env(self, env, env_index=0):
        """Return the object that exposes the trading env's attributes.

        Args:
            env: The (possibly vectorized / wrapped) environment from RLlib.
            env_index: Index of the sub-environment to unwrap when `env` is a
                vector env exposing `.envs`. Defaults to 0, which matches the
                previous behaviour.

        Returns:
            The `AsyncVectorEnv` itself (its sub-envs are only reachable via
            `get_attr`), otherwise the innermost env of the selected sub-env.
        """
        env = env.unwrapped

        if isinstance(env, AsyncVectorEnv):
            return env

        sub_envs = getattr(env, 'envs', None)
        if sub_envs is None:
            # Not a vector env: `.unwrapped` above already gave the base env.
            return env

        # Observed chain: SyncVectorEnv -> OrderEnforcing -> PassiveEnvChecker
        # -> StockTradingEnv. `.unwrapped` reaches the base env regardless of
        # how many wrappers sit in between.
        return sub_envs[env_index].unwrapped

    def on_episode_step(
            self,
            *,
            episode,
            env_runner,
            metrics_logger,
            env,
            env_index,
            rl_module,
            **kwargs,
        ) -> None:
        """Compute account-value metrics for the current step and stash them on the episode."""

        # Use the sub-env that was actually stepped, not always sub-env 0.
        env = self.unwrap_env(env, env_index)

        if isinstance(env, AsyncVectorEnv):
            # Average the asset curves of all sub-envs position by position.
            asset_values = env.get_attr('asset_memory')
            asset_values = pd.concat([pd.Series(av) for av in asset_values], axis=1).mean(axis=1)
        else:
            asset_values = env.asset_memory

        metrics = get_account_value_metrics(asset_values)

        for metric_name, metric_value in metrics.items():
            episode.add_temporary_timestep_data(metric_name, metric_value)
            self.metric_names.add(metric_name)

    def on_episode_end(
            self,
            *,
            episode,
            env_runner,
            metrics_logger,
            env,
            env_index,
            rl_module,
            **kwargs,
        ) -> None:
        """Average each stored metric over the episode and log the smoothed values."""

        for metric_name in self.metric_names:
            metric_values = episode.get_temporary_timestep_data(metric_name)
            metric_value = np.nanmean(np.array(metric_values))

            metrics_logger.log_value(
                f"{metric_name}_EMA_{self.ema_coeff}",
                metric_value,
                reduce='mean',
                ema_coeff=self.ema_coeff
            )

            metrics_logger.log_value(
                f"{metric_name}_ma_{self.ma_window}",
                metric_value,
                reduce='mean',
                window=self.ma_window
            )

            if self.log_to_wandb:
                mode = 'val' if env_runner.config.in_evaluation else 'train'
                wandb.log({
                    f"{mode}.{metric_name}/{self.model_name}": metric_value,
                }) # TODO: log on every episode step

In [None]:
#@title print_result

# Metric-name prefixes that `print_result` reports. Matching is
# case-insensitive because MetricsLoggerCallback logs moving-average metrics
# with a lowercase `_ma_` infix while EMA metrics use uppercase `_EMA_`.
RESULT_KEYS_TO_INCLUDE = [
    'sharpe_ratio_MA',
    'ann_return_MA',
    'mdd_MA',

    'sharpe_ratio_EMA',
    'ann_return_EMA',
    'mdd_EMA',
]

def print_result(result):
    """Print the selected train/val metrics of one training result.

    Args:
        result: Mapping with an 'env_runners' mapping of training metrics and,
            optionally, an 'evaluation' mapping that itself contains
            'env_runners'. Missing sections are skipped.

    Only keys whose name starts (case-insensitively) with one of
    RESULT_KEYS_TO_INCLUDE are printed, rounded to two decimals.
    """
    prefixes = tuple(prefix.lower() for prefix in RESULT_KEYS_TO_INCLUDE)

    def _print_section(label, metrics):
        # Print every selected metric of one section as "<label>/<key>: <value>".
        for key, value in metrics.items():
            if key.lower().startswith(prefixes):
                print(f"{label}/{key}: {round(value, 2)}")

    print()
    _print_section('train', result.get('env_runners') or {})

    # 'evaluation' is absent when no evaluation ran for this result.
    evaluation = result.get('evaluation') or {}
    _print_section('val', evaluation.get('env_runners') or {})
    print()

In [None]:
#@title benchmark_exec_time
import pandas as pd
from time import perf_counter
from functools import wraps

def benchmark_exec_time(func):
    """Decorator that times `func` with `perf_counter`.

    The wrapped callable prints a short benchmark report and returns a tuple
    `(output, exec_time_sec)` instead of the bare output of `func`.
    """
    def timed(*args, **kwargs):
        started_at = perf_counter()
        output = func(*args, **kwargs)
        exec_time_sec = perf_counter() - started_at

        data = dict(func_name=func.__name__, exec_time_sec=exec_time_sec)
        print(f'\nBenchmark results: {data}')
        return output, exec_time_sec

    # Copy name/docstring of the original function onto the wrapper.
    return wraps(func)(timed)

In [None]:
#@title Train_eval_models

from ray.rllib.algorithms.ppo import PPOConfig

# Maps a model name to the RLlib config class used to build it.
AVAILABLE_MODELS_CONFIGS = dict(ppo=PPOConfig)

def create_stock_trading_env(env_config):
    """Env creator: expand `env_config` into keyword arguments for `init_env`."""
    env = init_env(**env_config)
    return env

def train_eval_rllib_models(
        run_config,
        train_np_env_config,
        val_np_env_config,
        test_np_env_config,
        model_list = None, # TODO: discard in favor of 'if_using_{model_name}'
        pretrained_models = None
    ):
    """Train every enabled model, evaluate it on val and test, and log the plots.

    Args:
        run_config: Run configuration; `run_config['if_using_<model>']` decides
            whether a model from `model_list` is trained or skipped.
        train_np_env_config: Env config forwarded to training.
        val_np_env_config: Env config used for in-training evaluation and for
            the validation evaluation.
        test_np_env_config: Env config used for the test evaluation.
        model_list: Model names to consider; defaults to ['ppo']. Every name
            must be a key of AVAILABLE_MODELS_CONFIGS.
        pretrained_models: Optional mapping model name -> already built
            algorithm to continue training from. `None` means no pretrained
            models.

    Returns:
        dict: model name -> trained algorithm. Entries of `pretrained_models`
        for models that were skipped are carried over unchanged, so the result
        can be passed back as `pretrained_models` for the next run.
    """
    if model_list is None:
        model_list = ['ppo']
    if pretrained_models is None:
        pretrained_models = {}

    assert set(model_list).issubset(AVAILABLE_MODELS_CONFIGS)
    check_and_make_directories([TRAINED_MODEL_DIR])

    register_env("stock_trading_env", create_stock_trading_env)

    # Copy so the caller's mapping is never mutated.
    trained_models = dict(pretrained_models)

    for model_name in model_list:
        if run_config[f"if_using_{model_name}"]:
            print(f"Training {model_name.upper()} agent")
            model_config = AVAILABLE_MODELS_CONFIGS[model_name]
            algo = train_rllib_model(
                run_config,
                model_name,
                model_config,
                train_np_env_config,
                val_np_env_config,
                pretrained_algo = pretrained_models.get(model_name, None)
            )
            trained_models[model_name] = algo

            print(f"Evaluating {model_name.upper()} agent")
            val_result = evaluate_model(
                algo,
                model_name,
                run_config=run_config,
                np_env_config = val_np_env_config,
                mode='val',
            )
            fig = plot_results(**val_result)
            log_plot_as_artifact(fig, "asset_value_comparison_validation", artifact_type="plot")

            test_result = evaluate_model(
                algo,
                model_name,
                run_config=run_config,
                np_env_config = test_np_env_config,
                mode='test',
            )
            fig = plot_results(**test_result)
            log_plot_as_artifact(fig, "test_asset_value_comparison_test", artifact_type="plot")
        else:
            print(f"Skipping {model_name.upper()} agent")

    return trained_models

In [None]:
#@title Train RLlib model

from functools import partial
from pathlib import Path
import torch
from ray.tune.registry import register_env
from time import perf_counter
from math import ceil
import ray

def train_rllib_model(
    run_config,
    model_name,
    model_config,
    train_np_env_config,
    val_np_env_config,
    pretrained_algo=None
):
    """Build (or reuse) an RLlib algorithm, train it, and save a checkpoint.

    Args:
        run_config: Run configuration. Read keys:
            `env_runners_params.num_envs_per_env_runner`,
            `env_runners_params.num_env_runners`,
            `train_params.<model_name>.steps` (total training timesteps) and,
            optionally, `model_params.<model_name>` (keyword arguments for the
            algorithm config's `.training(...)`).
        model_name: Model key, e.g. 'ppo'; used for config lookups, the
            callback's metric names, W&B summary keys and the checkpoint dir.
        model_config: Zero-argument callable returning an RLlib algorithm
            config (e.g. `PPOConfig`).
        train_np_env_config: Env config for the training environment.
        val_np_env_config: Env config for the evaluation environment.
        pretrained_algo: Optional already-built algorithm. If given it is
            trained further as-is; no new config is built, so it keeps the
            environment configuration it was originally built with and
            `train_np_env_config` / `val_np_env_config` are not applied.

    Returns:
        The trained algorithm.
    """
    num_envs_per_env_runner = run_config['env_runners_params']['num_envs_per_env_runner']
    num_env_runners = run_config['env_runners_params']['num_env_runners']

    # The step budget lives in the sweep/run config, not in the RLlib config.
    total_timesteps = run_config['train_params'][model_name]['steps']

    if pretrained_algo:
        algo = pretrained_algo
    else:
        # Defaults mirror the previously hard-coded values; entries from the
        # run config's `model_params.<model_name>` (if present) override them.
        training_kwargs = dict(
            train_batch_size=2048,
            num_epochs=10,
            minibatch_size=128,
        )
        model_params = run_config.get('model_params', None) or {}
        training_kwargs.update(model_params.get(model_name, None) or {})

        algo_config = (
            model_config()
            .environment(
                env="stock_trading_env",
                env_config={
                    # "df": train,
                    "np_env_config": train_np_env_config,
                    "run_config": run_config,
                    "mode": 'train'
                },
            )
            .env_runners(
                num_envs_per_env_runner=num_envs_per_env_runner,
                num_env_runners=num_env_runners,
                num_cpus_per_env_runner= (2/num_env_runners) if num_env_runners > 2 else None,

                # gym_env_vectorize_mode=gym.envs.registration.VectorizeMode.ASYNC,
            )
            .training(**training_kwargs)
            .evaluation(
                # Set up the validation environment
                evaluation_interval=1,  # Specify evaluation frequency (1=after each training step)
                evaluation_config={
                    "env": "stock_trading_env",
                    "env_config": {
                        # "df": val,
                        "np_env_config": val_np_env_config,
                        "run_config": run_config,
                        "mode": 'val'
                    },
                },
            )
            # Tag logged metrics with the model actually being trained.
            .callbacks(partial(MetricsLoggerCallback, model_name=model_name, log_to_wandb=True))
            .resources(
                num_gpus=1 if torch.cuda.is_available() else None
            )
        )

        # 2. build the algorithm ..
        if algo_config.num_env_runners > 0:
            # Restart Ray and re-register the env creator before building
            # with remote env runners.
            ray.shutdown()
            ray.init()

            register_env("stock_trading_env", create_stock_trading_env)

        algo = algo_config.build()

    print(f"Envs: {algo.config.num_envs_per_env_runner}, Runners: {algo.config.num_env_runners}")

    # 3. .. train it ..
    @benchmark_exec_time
    def _train_rllib(total_timesteps):
        # One `algo.train()` call per train batch needed to cover the budget.
        results = []
        total_batches = ceil(total_timesteps / algo.config.train_batch_size)
        print(f"total_batches: {total_batches}")
        print(f"total_timesteps: {total_timesteps}")
        for _ in range(total_batches):
            results.append(algo.train())

        # Nothing to report if the budget resulted in zero iterations.
        if results:
            print_result(results[-1])
        return results

    _results, exec_time_sec = _train_rllib(total_timesteps)

    # Log train duration
    duration_minutes = round(exec_time_sec / 60, 1)
    wandb.run.summary[f"train.duration_minutes/{model_name}"] = duration_minutes
    update_model_artifacts(log_results_folder=False)

    # Save model
    ckpt_path = (Path(TRAINED_MODEL_DIR) / model_name).resolve()
    algo.save(ckpt_path)

    return algo

In [None]:
#@title Evaluate RLlib model
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.core.columns import Columns
from ray.rllib.utils.numpy import convert_to_numpy, softmax
from ray.rllib.models.torch.torch_distributions import TorchDiagGaussian

import torch

# Create the testing environment
def evaluate_model(algo, model_name, mode, np_env_config, run_config, turbulence_thresh=None):
    """Roll out the trained policy deterministically on one env and log metrics.

    Args:
        algo: Trained RLlib algorithm; its `env_runner.module` is used for
            inference.
        model_name: Name used in W&B summary keys and as the account-value
            column name of the result.
        mode: Label used as prefix of the W&B summary keys (e.g. 'val', 'test').
        np_env_config: Env config of the evaluation data; its
            'turbulence_array' is also used for the returned turbulence series.
        run_config: Run configuration ('turbulence_threshold', 'if_vix').
        turbulence_thresh: Threshold override. If None, taken from
            `run_config['turbulence_threshold']` (the key set in the sweep
            config), then the legacy key 'turbulence_thresh', then 99.

    Returns:
        dict with keys 'account_value' (DataFrame, value column renamed to
        `model_name`), 'turbulence_series' (Series named 'turbulence' or
        '^VIX', indexed by the account-value dates) and 'turbulence_thresh'.
    """
    if turbulence_thresh is None:
        turbulence_thresh = run_config.get(
            'turbulence_threshold',
            run_config.get('turbulence_thresh', 99),
        )

    if_vix = run_config.get('if_vix', None)
    turbulence_name = '^VIX' if if_vix else 'turbulence'
    turb_name = 'vix' if if_vix else 'ti'

    print(f"Evaluating using `{turbulence_name}`: {turbulence_thresh}")

    env_config = {
        # "df": val,
        "np_env_config": np_env_config,
        "run_config": run_config,
        # NOTE(review): the env is always created with mode 'val', even when
        # `mode` is 'test' — confirm whether `init_env` should get `mode`.
        "mode": 'val',
        'turbulence_threshold': turbulence_thresh
    }

    eval_env = create_stock_trading_env(env_config)
    state, info = eval_env.reset()
    done = False

    # Perform inference using the trained RLlib agent
    rl_module = algo.env_runner.module
    with torch.no_grad():
        while not done:
            # Compute action using the RLlib trained agent
            input_dict = {Columns.OBS: torch.Tensor(state).unsqueeze(0)}
            rl_module_out = rl_module.forward_inference(input_dict)
            logits = rl_module_out[Columns.ACTION_DIST_INPUTS]

            # First half of the distribution inputs is used as the action
            # (mean of the Gaussian); the second half is ignored.
            mean, _ = logits.chunk(2, dim=-1)

            # `.cpu()` keeps the conversion valid if the module is not on CPU.
            action = mean.detach().cpu().numpy().squeeze()

            # Clip the action to ensure it's within the action space bounds
            action = np.clip(action, eval_env.action_space.low, eval_env.action_space.high)

            # Perform action
            state, reward, terminated, truncated, _ = eval_env.step(action)
            done = terminated or truncated

    df_account_value = eval_env.save_asset_memory()
    metrics = compute_metrics(df_account_value)

    for metric_name, metric_value in metrics.items():
        summary_key = f"{mode}.{turb_name}_{turbulence_thresh}.{metric_name}/{model_name}"
        wandb.run.summary[summary_key] = metric_value

    # Align the turbulence values with the dates of the account-value frame.
    turbulence_series = pd.Series(
        np_env_config['turbulence_array'][:len(df_account_value)],
        index=df_account_value['date'],
        name=turbulence_name
    )

    eval_result = {
        'account_value': df_account_value.rename(columns={'account_value': model_name}),
        'turbulence_series': turbulence_series,
        'turbulence_thresh': turbulence_thresh
    }

    return eval_result

# val_result = evaluate_model(
#     algo,
#     model_name,
#     run_config=run_config,
#     np_env_config = val_np_env_config,
#     mode='val',
# )

# test_result = evaluate_model(
#     algo,
#     model_name,
#     run_config=run_config,
#     np_env_config = test_np_env_config,
#     mode='test',
# )

In [None]:
#@title plot_results (w/separate turbulence plot)

%matplotlib inline

def plot_results(
        account_value,
        turbulence_series,
        turbulence_thresh,
    ):
    """Plot account values (top) and the turbulence/VIX series (bottom).

    Args:
        account_value: DataFrame with one column per strategy and optionally a
            'date' column, which becomes the x-axis. Column names are
            upper-cased for the legend. The caller's frame is not modified.
        turbulence_series: Series named 'turbulence' or '^VIX'.
        turbulence_thresh: Threshold drawn as a horizontal line on the lower
            plot.

    Returns:
        The created matplotlib figure.
    """

    assert turbulence_series.name in ['turbulence', '^VIX']

    # Create figure and subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10), sharex=True, gridspec_kw={'height_ratios': [3, 1]})

    # Main plot
    method_styles = {
        'A2C': {'color': '#8c564b', 'linestyle': '--'},
        'DDPG': {'color': '#e377c2', 'linestyle': '-'},
        'PPO': {'color': '#7f7f7f', 'linestyle': '-'},
        'TD3': {'color': '#bcbd22', 'linestyle': '--'},
        'SAC': {'color': '#17becf', 'linestyle': '-'},
        'DJIA': {'color': '#000000', 'linestyle': '-'},
    }

    # Work on new frames (no inplace ops) so the caller's data stays untouched.
    if 'date' in account_value.columns:
        account_value = account_value.set_index('date')
    account_value = account_value.rename(columns={col: col.upper() for col in account_value.columns})

    # Every column is drawn exactly once, against the same (date) index.
    for model_name in account_value.columns:
        # Columns without a predefined style fall back to matplotlib defaults.
        style = method_styles.get(model_name, {})
        label = "Dow Jones Index" if model_name == 'DJIA' else model_name
        ax1.plot(account_value.index, account_value[model_name], label=label, **style)

    # Customize main plot
    ax1.set_title("Performance Comparison of DRL Agents", fontsize=20, fontweight='bold')
    ax1.set_ylabel("Total Asset Value ($)", fontsize=16, fontweight='bold')
    ax1.legend(loc='lower right')
    ax1.grid(True, linestyle='--', alpha=0.3)

    # Turbulence plot
    if turbulence_series.name == 'turbulence':
        turbulence_label = "Turbulence Index"
    else:
        turbulence_label = "VIX Coefficient"
    ax2.plot(turbulence_series.index, turbulence_series, label=turbulence_label, color='red', linestyle='--', linewidth=2)
    ax2.axhline(y=turbulence_thresh, color='red', linestyle=':', label='Turbulence Threshold')
    ax2.set_ylabel(turbulence_label, fontsize=16, fontweight='bold')
    ax2.legend(loc='lower left')
    ax2.grid(True, linestyle='--', alpha=0.3)
    max_turbulence = max(turbulence_series.max(), turbulence_thresh)
    ax2.set_ylim(0, max_turbulence + 10)
    # Add an explicit tick at the threshold value.
    ax2.set_yticks(list(ax2.get_yticks()) + [turbulence_thresh])

    # Shared x-axis label
    ax2.set_xlabel("Date", fontsize=16, fontweight='bold')

    # Tight layout
    plt.tight_layout()
    return fig

# fig = plot_results(**val_result)
# plt.show()

In [None]:
#@title log_plot_as_artifact

import os
import matplotlib.pyplot as plt
import wandb

def log_plot_as_artifact(fig, artifact_name_prefix, artifact_type="plot"):
    """
    Save a Matplotlib figure without clipping and log it as a W&B artifact.

    The artifact is named `<artifact_name_prefix>-<wandb run id>`; the figure
    is written to `<artifact name>.png` in the working directory, attached to
    the artifact, and the local file is removed afterwards. The figure is
    closed in any case.

    Parameters:
        fig (matplotlib.figure.Figure): The Matplotlib figure to save and log.
        artifact_name_prefix (str): Prefix of the W&B artifact name.
        artifact_type (str): The type of the artifact (default is "plot").
    """
    # Resolve names before the try block so the cleanup below can always
    # refer to `filename`, even if a later step fails.
    artifact_name = f'{artifact_name_prefix}-{wandb.run.id}'
    filename = artifact_name + '.png'

    try:
        # Save the figure with tight layout and proper padding
        fig.savefig(filename, bbox_inches='tight', pad_inches=0.1, dpi=300)

        # Create and log the W&B artifact
        artifact = wandb.Artifact(artifact_name, type=artifact_type)
        artifact.add_file(filename, skip_cache=True, overwrite=True)
        wandb.log_artifact(artifact)
    finally:
        plt.close(fig)  # Close the figure to free up memory
        # Ensure the file is deleted after use
        if os.path.exists(filename):
            os.remove(filename)

# log_plot_as_artifact(fig, "performance_comparison_DRL_agents", artifact_type="plot")

# Run sweep

In [None]:
#@title Sweep Runner

import random
import string

def set_run_name(prefix, n=5):
    """Rename the active W&B run to "<prefix> | <run id>" and save the run.

    `n` is accepted but not used.
    """
    active_run = wandb.run
    active_run.name = f"{prefix} | {active_run.id}"
    active_run.save()

class SweepRunner:
    """Runs sweep trials and carries trained models from one run to the next.

    `main` is the function handed to `wandb.agent`; models returned by one run
    are kept in `self.pretrained_models` and passed to the following run.
    """

    def __init__(self, sweep_id):
        self.sweep_id = sweep_id
        # model name -> algorithm trained in an earlier run of this sweep.
        self.pretrained_models = {}

    def main(self, run_config=None):
        """Execute a single W&B run: build the dataset, train, evaluate, log duration.

        Args:
            run_config: Optional config passed to `wandb.init`; inside the run
                the effective configuration is read from `wandb.config`.

        Raises:
            NotImplementedError: For dataset type 'yearly_train_test'.
        """
        run_timer_start = perf_counter()
        with wandb.init(config=run_config):
            run_config = wandb.config

            if run_config['dataset_type'] == 'yearly_train_test':
                raise NotImplementedError
            elif run_config['dataset_type'] == 'quarterly_train_val_test':
                train_np_env_config, val_np_env_config, test_np_env_config = build_quarterly_train_val_test(run_config)
                set_run_name(run_config['dataset_name'])

                # A run is a fine-tuning run when earlier models are available.
                run_config.update({'finetune': bool(self.pretrained_models)})

                pretrained_models = train_eval_rllib_models(
                    run_config,
                    train_np_env_config,
                    val_np_env_config,
                    test_np_env_config,
                    pretrained_models=self.pretrained_models
                )

                # Keep the previous mapping if nothing was returned, so the
                # next run never receives `None` as `pretrained_models`.
                if pretrained_models is not None:
                    self.pretrained_models = pretrained_models

            run_timer_end = perf_counter()
            run_duration_minutes = round( (run_timer_end - run_timer_start) / 60, 1)
            wandb.run.summary["run.duration_minutes"] = run_duration_minutes

In [None]:
#@title RUN SWEEP
def run_sweep(n_runs, sweep_config=None, sweep_id=None):
    """Create (or join) a W&B sweep, clear old model files, and run the agent.

    Exactly one of `sweep_config` and `sweep_id` must be given.

    Args:
        n_runs: Passed to `wandb.agent` as `count`.
        sweep_config: Config for a new sweep created in project `PROJECT`.
        sweep_id: Id of an existing sweep to join instead.

    Raises:
        AssertionError: If both or neither of `sweep_config` / `sweep_id` are set.
        ValueError: If `TRAINED_MODEL_DIR` is empty (cleanup would be unsafe).
    """
    import shutil
    from pathlib import Path

    if sweep_id is None:
        assert sweep_config is not None, "pass either sweep_config or sweep_id"
        sweep_id = wandb.sweep(sweep_config, project=PROJECT)
    else:
        assert sweep_config is None, "pass only one of sweep_config and sweep_id"

    # Pure-Python replacement for `rm -rf {TRAINED_MODEL_DIR}/*`: an empty
    # directory name must never expand to the filesystem root or the CWD.
    if not str(TRAINED_MODEL_DIR).strip():
        raise ValueError("TRAINED_MODEL_DIR is empty; refusing to delete files")

    model_dir = Path(TRAINED_MODEL_DIR)
    if model_dir.is_dir():
        for entry in model_dir.iterdir():
            # Shell `*` does not match dotfiles; keep that behaviour.
            if entry.name.startswith('.'):
                continue
            if entry.is_dir() and not entry.is_symlink():
                shutil.rmtree(entry)
            else:
                entry.unlink()

    sweep_runner = SweepRunner(sweep_id)
    wandb.agent(sweep_id, sweep_runner.main, project=PROJECT, count=n_runs)

# Start a new sweep from `sweep_config` and execute a single run.
# To join an existing sweep instead, pass e.g. sweep_id='50a5ela4' (and drop
# sweep_config); n_runs=None is the other setting that was tried here.
run_sweep(n_runs=1, sweep_config=sweep_config)