<a href="https://colab.research.google.com/github/tony-pitchblack/finrl-dt/blob/custom-backtesting/finrl_dt_replicate_sweep_rllib.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installs

In [2]:
!pip install stable-baselines3=='2.6.0a0'



In [3]:
# %load_ext autoreload
# %autoreload 2

In [4]:
# %%bash
# git clone https://github.com/tony-pitchblack/FinRL.git
# cd ./FinRL
# git checkout benchmarking
# pip install -r FinRL/requirements.txt

In [5]:
%%time
!pip install git+https://github.com/tony-pitchblack/FinRL.git@benchmarking --no-deps \
    # --force-reinstall --no-deps

Collecting git+https://github.com/tony-pitchblack/FinRL.git@benchmarking
  Cloning https://github.com/tony-pitchblack/FinRL.git (to revision benchmarking) to /tmp/pip-req-build-2z4p4kj9
  Running command git clone --filter=blob:none --quiet https://github.com/tony-pitchblack/FinRL.git /tmp/pip-req-build-2z4p4kj9
  Running command git checkout -b benchmarking --track origin/benchmarking
  Switched to a new branch 'benchmarking'
  Branch 'benchmarking' set up to track remote branch 'benchmarking' from 'origin'.
  Resolved https://github.com/tony-pitchblack/FinRL.git to commit f7dab0a4cda4ae68bfe3b45c3be038bf01327878
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
CPU times: user 88.5 ms, sys: 14 ms, total: 103 ms
Wall time: 12.9 s


In [6]:
%%time
!wget https://raw.githubusercontent.com/tony-pitchblack/FinRL/benchmarking/requirements.txt -O requirements.txt
!pip install -r requirements.txt

--2025-02-13 20:01:00--  https://raw.githubusercontent.com/tony-pitchblack/FinRL/benchmarking/requirements.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 764 [text/plain]
Saving to: ‘requirements.txt’


2025-02-13 20:01:00 (39.9 MB/s) - ‘requirements.txt’ saved [764/764]

CPU times: user 101 ms, sys: 14.5 ms, total: 115 ms
Wall time: 9.81 s


# Imports

In [7]:
import pandas as pd

from stable_baselines3.common.logger import configure
from finrl import config_tickers
from finrl.main import check_and_make_directories
from finrl.config import INDICATORS, TRAINED_MODEL_DIR, RESULTS_DIR

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [8]:
import os
from pathlib import Path
import pandas as pd
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [9]:
# Credentials must never be written into the notebook: everything committed
# here is readable by anyone with access to the repository.
# NOTE(review): the key that was previously inlined on this line has to be
# treated as leaked and revoked in the W&B account settings.
import getpass

# Reuse a key that is already exported in the environment; otherwise prompt
# for it without echoing it into the cell output.
if not os.environ.get("WANDB_API_KEY"):
    os.environ["WANDB_API_KEY"] = getpass.getpass("W&B API key: ")

PROJECT = 'finrl-dt-replicate'
ENTITY = "overfit1010"

# General funcs

In [10]:
#@title YahooDownloader

"""Contains methods and classes to collect data from
Yahoo Finance API
"""

from __future__ import annotations

import pandas as pd
import yfinance as yf


class YahooDownloader:
    """Provides methods for retrieving daily stock data from
    Yahoo Finance API

    Attributes
    ----------
        start_date : str
            start date of the data, forwarded to ``yf.download(start=...)``
        end_date : str
            end date of the data, forwarded to ``yf.download(end=...)``
        ticker_list : list
            a list of stock tickers; each ticker is downloaded separately

    Methods
    -------
    fetch_data()
        Fetches data from yahoo API
    select_equal_rows_stock()
        Keeps the tickers that have at least the mean number of rows

    """

    def __init__(self, start_date: str, end_date: str, ticker_list: list):
        self.start_date = start_date
        self.end_date = end_date
        self.ticker_list = ticker_list

    def fetch_data(self, proxy=None) -> pd.DataFrame:
        """Fetches data from Yahoo API
        Parameters
        ----------
        proxy : optional
            forwarded unchanged to ``yf.download``

        Returns
        -------
        `pd.DataFrame`
            8 columns: date, tic, open, high, low, close, volume and day,
            sorted by date and tic. ``close`` holds the adjusted close,
            ``date`` is a ``%Y-%m-%d`` string and ``day`` is the day of the
            week (monday = 0). Rows containing missing values are dropped.

        Raises
        ------
        ValueError
            if no ticker in ``ticker_list`` returned any rows
        """
        # Download and save the data in a pandas DataFrame:
        data_df = pd.DataFrame()
        num_failures = 0
        for tic in self.ticker_list:
            temp_df = yf.download(
                tic, start=self.start_date, end=self.end_date, proxy=proxy
            )
            # This helper column is dropped again before the wide -> long reshape below.
            temp_df["tic"] = tic
            if len(temp_df) > 0:
                # data_df = data_df.append(temp_df)
                data_df = pd.concat([data_df, temp_df], axis=0)
            else:
                # An empty frame counts as a failed download for this ticker.
                num_failures += 1
        if num_failures == len(self.ticker_list):
            raise ValueError("no data is fetched.")
        # reset the index, we want to use numbers as index instead of dates
        data_df = data_df.reset_index()

        try:
            # Convert wide to long format
            # NOTE(review): this presumes yf.download returns columns with a
            # 'Ticker' level and a 'Date' index - confirm against the pinned
            # yfinance version.
            # print(f"DATA COLS: {data_df.columns}")
            data_df = data_df.sort_index(axis=1).set_index(['Date']).drop(columns=['tic']).stack(level='Ticker', future_stack=True)
            data_df.reset_index(inplace=True)
            data_df.columns.name = ''

            # convert the column names to standardized names
            data_df.rename(columns={'Ticker': 'Tic', 'Adj Close': 'Adjcp'}, inplace=True)
            data_df.rename(columns={col: col.lower() for col in data_df.columns}, inplace=True)

            columns = [
                "date",
                "tic",
                "open",
                "high",
                "low",
                "close",
                "adjcp",
                "volume",
            ]

            # NOTE(review): raises KeyError (not caught below) when the download
            # has no 'Adj Close' column - verify the yf.download defaults in use.
            data_df = data_df[columns]
            # use adjusted close price instead of close price
            data_df["close"] = data_df["adjcp"]
            # drop the adjusted close price column
            data_df = data_df.drop(labels="adjcp", axis=1)

        except NotImplementedError:
            # Only NotImplementedError is handled; execution continues with
            # whatever state data_df was left in.
            print("the features are not supported currently")

        # create day of the week column (monday = 0)
        data_df["day"] = data_df["date"].dt.dayofweek
        # convert date to standard string format, easy to filter
        data_df["date"] = data_df.date.apply(lambda x: x.strftime("%Y-%m-%d"))
        # drop missing data
        data_df = data_df.dropna()
        data_df = data_df.reset_index(drop=True)
        print("Shape of DataFrame: ", data_df.shape)
        # print("Display DataFrame: ", data_df.head())

        data_df = data_df.sort_values(by=["date", "tic"]).reset_index(drop=True)

        return data_df

    def select_equal_rows_stock(self, df):
        """Keep only the tickers whose row count is at least the mean row count.

        Parameters
        ----------
        df : pd.DataFrame
            frame with a ``tic`` column

        Returns
        -------
        `pd.DataFrame`
            the rows of ``df`` whose ``tic`` has ``>=`` the average number of
            rows per ticker
        """
        df_check = df.tic.value_counts()
        df_check = pd.DataFrame(df_check).reset_index()
        df_check.columns = ["tic", "counts"]
        mean_df = df_check.counts.mean()
        # Boolean mask aligned with the value_counts() ordering of the tickers.
        equal_list = list(df.tic.value_counts() >= mean_df)
        names = df.tic.value_counts().index
        select_stocks_list = list(names[equal_list])
        df = df[df.tic.isin(select_stocks_list)]
        return df


In [11]:
#@title fix_daily_index

def make_daily_index(data_df, date_column='date', new_index_name='date_index'):
    """Map every row to the rank of its date among the frame's unique dates.

    Parameters:
        data_df (pd.DataFrame): The input DataFrame.
        date_column (str): The name of the column containing dates.
        new_index_name (str): Accepted for signature symmetry; not used here.

    Returns:
        pd.Series: For each row, the 0-based position of its date in the
        sorted list of unique dates (same index as ``data_df``).
    """
    date_series = data_df[date_column]
    ordered_dates = sorted(date_series.unique())
    # Earliest date -> 0, next date -> 1, ...
    rank_by_date = dict(zip(ordered_dates, range(len(ordered_dates))))
    return date_series.map(rank_by_date)

def set_daily_index(data_df, date_column='date', new_index_name='date_index'):
    """
    Replace the index of ``data_df`` with a daily integer index built from
    the unique dates in ``date_column``.

    The frame is modified in place (a helper column is added and then moved
    into the index) and the same object is returned.

    Parameters:
        data_df (pd.DataFrame): The input DataFrame.
        date_column (str): The name of the column containing dates.
        new_index_name (str): Name of the temporary column that becomes the index.

    Returns:
        pd.DataFrame: ``data_df`` itself, now indexed by the daily index.
    """
    # Forward the caller's arguments; they used to be overridden with the
    # literal defaults, so a non-default ``date_column`` was silently ignored.
    data_df[new_index_name] = make_daily_index(
        data_df, date_column=date_column, new_index_name=new_index_name
    )

    data_df.set_index(new_index_name, inplace=True)
    data_df.index.name = ''  # Remove the index name for simplicity

    return data_df

def fix_daily_index(df):
    """Make sure ``df`` is indexed by the daily index derived from its 'date' column.

    If the index is named 'date' it is first moved back into a column. The
    index is only replaced when it differs from the expected daily index.
    The frame is modified in place and returned.
    """
    if df.index.name == 'date':
        df.reset_index(inplace=True)

    expected_index = make_daily_index(df, date_column='date', new_index_name='date_index')

    # Nothing to do when every position already carries the expected value.
    if not (df.index.values != expected_index.values).any():
        return df

    df.index = expected_index
    df.index.name = ''
    return df

# trade = fix_daily_index(trade)
# trade.index

In [12]:
#@title get dataset name

def get_quarterly_dataset_name(prefix, train_start_date, val_start_date, test_start_date):
    """Build a label of the form ``"<prefix> | YYYY-MM | YYYY Qn | YYYY Qn"``.

    The three date arguments only need ``.year`` and ``.month`` attributes.
    The train start is rendered as year-month, the validation and test
    starts as year plus quarter.
    """
    def quarter_label(date):
        # Months 1-3 -> Q1, 4-6 -> Q2, 7-9 -> Q3, 10-12 -> Q4.
        return f'Q{(date.month - 1) // 3 + 1}'

    name_parts = [
        f"{prefix}",
        f"{train_start_date.year}-{train_start_date.month:02}",
        f"{val_start_date.year} {quarter_label(val_start_date)}",
        f"{test_start_date.year} {quarter_label(test_start_date)}",
    ]
    return " | ".join(name_parts)

def get_yearly_dataset_name(prefix, train_start, test_start, test_end):
    """Build a label of the form ``"<prefix> | YYYY-MM | YYYY-MM | YYYY-MM"``.

    The date arguments only need ``.year`` and ``.month`` attributes; they
    are rendered in the order train start, test start, test end.
    """
    def year_month(date):
        # Zero-padded month, e.g. 2020-01.
        return f"{date.year}-{date.month:02}"

    return f"{prefix} | {year_month(train_start)} | {year_month(test_start)} | {year_month(test_end)}"


In [13]:
#@title add_dataset

def add_dataset(stock_index_name, train_df, test_df):
    """Register a train/test pair in the module-level ``datasets`` dict.

    The global ``datasets`` dict is created on first use. Both frames are
    modified in place: their dates are parsed, and they end up indexed by a
    daily integer index. The entry key is the name produced by
    ``get_yearly_dataset_name`` from the first train date and the first and
    last test dates.
    """
    global datasets
    if 'datasets' not in globals():
        datasets = {}

    # Ensure datetime format
    for frame in (train_df, test_df):
        if 'date' in frame.columns:
            frame.set_index('date', inplace=True)
        frame.index = pd.to_datetime(frame.index)

    # Boundaries are read positionally, i.e. the frames are taken as ordered by date.
    train_start_date = train_df.index[0]
    test_start_date, test_end_date = test_df.index[0], test_df.index[-1]

    dataset_name = get_yearly_dataset_name(
        stock_index_name,
        train_start_date, test_start_date, test_end_date
    )

    # Move the dates back into a column before building the daily index.
    for frame in (train_df, test_df):
        frame.reset_index(inplace=True)

    train_df = set_daily_index(train_df)
    test_df = set_daily_index(test_df)

    tickers = train_df.tic.unique().tolist()

    metadata = dict(
        stock_index_name = stock_index_name,
        train_start_date = train_start_date,
        test_start_date = test_start_date,
        test_end_date = test_end_date,
        num_tickers = len(tickers),
        ticker_list = tickers,
    )
    datasets[dataset_name] = {
        'train': train_df,
        'test': test_df,
        'metadata': metadata,
    }

# Load data

## DATA: DOW-30 (quarterly train/val/test)

In [14]:
# Small example window. generate_quarterly_date_ranges asserts that the first
# two dates fall on the first day of a quarter (Jan/Apr/Jul/Oct 1st).
train_start_date = '2015-01-01'
min_test_start_date = '2016-01-01'
max_test_end_date = '2016-10-01'

In [15]:
#@title generate_quarterly_date_ranges

def generate_quarterly_date_ranges(
    train_start_date,
    min_test_start_date,
    max_test_end_date,
    return_strings=False,
    finetune_previous_val=False
):
    """Build consecutive quarterly train/val/test windows.

    Each window has a 3-month test period starting at ``test_start_date`` and
    a 3-month validation period immediately before it. Windows are produced
    until the next test period would end after ``max_test_end_date``.

    Parameters:
        train_start_date: Start of the first training period (quarter start).
        min_test_start_date: Start of the first test period (quarter start).
        max_test_end_date: Latest allowed end of a test period.
        return_strings (bool): Convert every date to ``str`` when True.
        finetune_previous_val (bool): When True, every window after the first
            starts training at the previous window's validation start;
            otherwise all windows share ``train_start_date``.

    Returns:
        list[dict]: Dicts with keys ``train_start_date``, ``val_start_date``,
        ``test_start_date`` and ``test_end_date``.

    Raises:
        AssertionError: If ``train_start_date`` or ``min_test_start_date`` is
        not the first day of a quarter.
    """
    def is_quarter_start(date):
        return date.month in (1, 4, 7, 10) and date.day == 1

    train_start_date = pd.Timestamp(train_start_date)
    min_test_start_date = pd.Timestamp(min_test_start_date)
    max_test_end_date = pd.Timestamp(max_test_end_date)

    assert is_quarter_start(train_start_date), f"train_start_date {train_start_date} is not a quarter start date."
    assert is_quarter_start(min_test_start_date), f"min_test_start_date {min_test_start_date} is not a quarter start date."

    one_quarter = pd.DateOffset(months=3)
    date_ranges = []
    test_start_date = min_test_start_date

    # Slide forward one quarter at a time while the test period still fits.
    while test_start_date + one_quarter <= max_test_end_date:
        test_end_date = test_start_date + one_quarter

        if date_ranges and finetune_previous_val:
            # Use the previous validation range as the training range
            train_start_date = date_ranges[-1]['val_start_date']

        window = {
            'train_start_date': train_start_date,
            'val_start_date': test_start_date - one_quarter,
            'test_start_date': test_start_date,
            'test_end_date': test_end_date,
        }
        if return_strings:
            window = {name: str(value) for name, value in window.items()}

        date_ranges.append(window)
        test_start_date = test_end_date

    return date_ranges

# Example run: every window after the first starts training at the previous
# window's validation start (finetune_previous_val=True).
date_ranges = generate_quarterly_date_ranges(
    train_start_date,
    min_test_start_date,
    max_test_end_date,
    finetune_previous_val=True
)

# print(*date_ranges[:2], sep='\n')
# Print one window per line.
print(*date_ranges, sep='\n')

{'train_start_date': Timestamp('2015-01-01 00:00:00'), 'val_start_date': Timestamp('2015-10-01 00:00:00'), 'test_start_date': Timestamp('2016-01-01 00:00:00'), 'test_end_date': Timestamp('2016-04-01 00:00:00')}
{'train_start_date': Timestamp('2015-10-01 00:00:00'), 'val_start_date': Timestamp('2016-01-01 00:00:00'), 'test_start_date': Timestamp('2016-04-01 00:00:00'), 'test_end_date': Timestamp('2016-07-01 00:00:00')}
{'train_start_date': Timestamp('2016-01-01 00:00:00'), 'val_start_date': Timestamp('2016-04-01 00:00:00'), 'test_start_date': Timestamp('2016-07-01 00:00:00'), 'test_end_date': Timestamp('2016-10-01 00:00:00')}


In [16]:
#@title split_data

def split_data(data_df, date_range):
    """Cut ``data_df`` into train/val/test frames using the window boundaries.

    Each split covers the half-open interval ``[start, end)`` on the 'date'
    column and is re-indexed with ``fix_daily_index``.

    Returns:
        dict: ``{'train': ..., 'val': ..., 'test': ...}``.
    """
    def subset_date_range(df, start_date, end_date):
        in_range = (df['date'] >= start_date) & (df['date'] < end_date)
        return fix_daily_index(df[in_range])

    # split name -> (start key, end key) inside ``date_range``
    boundary_keys = {
        'train': ('train_start_date', 'val_start_date'),
        'val': ('val_start_date', 'test_start_date'),
        'test': ('test_start_date', 'test_end_date'),
    }
    return {
        split_name: subset_date_range(data_df, date_range[start_key], date_range[end_key])
        for split_name, (start_key, end_key) in boundary_keys.items()
    }

# data_splits = split_data(preproc_df, date_ranges[0])
# data_splits['train'].head()

In [17]:
#@title prepare_data (for np env)
import hashlib
from finrl.config_tickers import DOW_30_TICKER
from finrl.meta.data_processor import DataProcessor
import os

# Directory where get_env_config stores and re-reads the processed CSV files;
# created up front so the first write cannot fail on a missing folder.
CACHE_DIR = './cache'
os.makedirs(CACHE_DIR, exist_ok=True)

def stable_hash(data):
    """Return the SHA-256 hex digest of ``str(data)``.

    Unlike the built-in ``hash``, the result does not vary between
    interpreter sessions, so it can be used in file names.
    """
    digest = hashlib.sha256()
    digest.update(str(data).encode())
    return digest.hexdigest()

def get_env_config(
    start_date,
    end_date,
    if_train,
    ticker_list=DOW_30_TICKER,
    technical_indicator_list=INDICATORS,
    time_interval='1d',
    if_vix=True,
    **kwargs
):
    """Build the array-based env config for one date range, with CSV caching.

    The processed frame is cached in ``CACHE_DIR``. The cache file name is
    derived from the date range, the time interval, a hash of the sorted
    tickers and indicators, and the ``if_vix`` flag.

    Parameters:
        start_date, end_date: Boundaries passed to ``DataProcessor.download_data``.
        if_train (bool): Stored verbatim in the returned config under "if_train".
        ticker_list (list): Tickers to download.
        technical_indicator_list (list): Indicators to add.
        time_interval (str): Bar size passed to ``download_data``.
        if_vix (bool): Whether ``add_vix`` is applied and passed on to ``df_to_array``.
        **kwargs: Forwarded to the ``DataProcessor`` constructor.
            NOTE(review): these are not part of the cache key - confirm that
            none of them changes the downloaded data.

    Returns:
        dict: keys "price_array", "tech_array", "turbulence_array",
        "timestamp_array" and "if_train".
    """
    print(f"Loading {'train' if if_train else 'eval'} data from {start_date} to {end_date}.")

    data_hash = stable_hash(tuple(sorted(ticker_list) + sorted(technical_indicator_list)))
    # add_vix is only applied when if_vix is True, so a frame cached with one
    # setting must not be reused for the other. The default (if_vix=True)
    # keeps the historical file name so existing cache files remain valid.
    vix_suffix = "" if if_vix else "_novix"
    file_path = Path(CACHE_DIR) / f"{start_date}_{end_date}_{time_interval}_{data_hash}{vix_suffix}.csv"
    dp = DataProcessor(data_source='yahoofinance', tech_indicator=technical_indicator_list, vix=if_vix, **kwargs)
    if os.path.isfile(file_path):
        print(f"Using cached data: {file_path}")
        data = pd.read_csv(file_path, index_col=0)
    else:
        print("Creating new data.")
        data = dp.download_data(ticker_list, start_date, end_date, time_interval)
        data = dp.clean_data(data)
        data = dp.add_technical_indicator(data, technical_indicator_list)
        if if_vix:
            data = dp.add_vix(data)
        data.to_csv(file_path)

    (
        price_array,
        tech_array,
        turbulence_array,
        timestamp_array,
    ) = dp.df_to_array(
        data,
        if_vix,
        return_timestamps=True,
    )

    env_config = {
        "price_array": price_array,
        "tech_array": tech_array,
        "turbulence_array": turbulence_array,
        "timestamp_array": timestamp_array,
        "if_train": if_train
    }

    return env_config

# date_range=date_ranges[0]
# train_env_config = get_env_config(
#     start_date=date_range['val_start_date'],
#     end_date=date_range['test_start_date'],
#     if_train=True
# )
# val_env_config = get_env_config(
#     start_date=date_range['test_start_date'],
#     end_date=date_range['test_end_date'],
#     if_train=False
# )

# Main

## Wandb artifacts

In [18]:
#@title update_artifact

def update_artifact(folder_path, name_prefix, type):
    """
    Upload a folder to W&B as an artifact of the active run.

    Args:
        folder_path: Path to the folder that is added to the artifact.
        name_prefix (str): Artifact name prefix; the active run id is appended
            as ``<name_prefix>-<run id>``.
        type (str): Artifact type passed to ``wandb.Artifact``.
    """
    active_run = wandb.run
    artifact_name = f'{name_prefix}-{active_run.id}'

    # Build the artifact from the folder contents and attach it to the run.
    folder_artifact = wandb.Artifact(name=artifact_name, type=type)
    folder_artifact.add_dir(folder_path)
    active_run.log_artifact(folder_artifact)

    print(f"Artifact '{artifact_name}' has been updated and uploaded.")

In [19]:
#@title update_model_artifacts

def update_model_artifacts(log_results_folder=True):
    """Upload the trained-models folder (and optionally the results folder) to W&B.

    Args:
        log_results_folder (bool): When True, ``RESULTS_DIR`` is uploaded
            first as a 'results' artifact.
    """
    # (folder, label) pairs; the label doubles as name prefix and artifact type.
    uploads = []
    if log_results_folder:
        uploads.append((RESULTS_DIR, 'results'))
    uploads.append((TRAINED_MODEL_DIR, 'trained_models'))

    for folder, label in uploads:
        update_artifact(folder_path=folder, name_prefix=label, type=label)

In [20]:
#@title update_dataset_artifact

from pathlib import Path

def update_dataset_artifact(config, train_df, val_df=None, test_df=None):
    """Write the dataset splits to ``./dataset`` and upload the folder to W&B.

    Args:
        config: Accepted for call-site compatibility; not used here.
        train_df (pd.DataFrame): Written to ``train_data.csv``.
        val_df (pd.DataFrame | None): Written to ``val_data.csv`` when given.
        test_df (pd.DataFrame | None): Written to ``test_data.csv`` when given.

    The folder is reused between calls, so the file of a split that is not
    provided this time is removed; otherwise a CSV left over from an earlier
    call would be uploaded as part of the new artifact.
    """
    DATASET_DIR = Path('./dataset')
    os.makedirs(DATASET_DIR, exist_ok=True)

    train_df.to_csv(DATASET_DIR / 'train_data.csv', index=False)

    for file_name, split_df in (('test_data.csv', test_df), ('val_data.csv', val_df)):
        split_path = DATASET_DIR / file_name
        if split_df is not None:
            split_df.to_csv(split_path, index=False)
        elif split_path.exists():
            # Stale file from a previous call: keep it out of this artifact.
            split_path.unlink()

    update_artifact(
        folder_path = DATASET_DIR,
        name_prefix = 'dataset',
        type = 'dataset'
    )

In [21]:
#@title update_env_state_artifact

def update_env_state_artifact(val_env_end_state):
    """Save the given env state to a CSV file and log it as a W&B artifact.

    The state is written as a single 'last_state' column to
    ``val_env_end_state.csv`` in the working directory and attached to the
    active run as an artifact of type 'env_state'.
    """
    file_path = 'val_env_end_state.csv'
    pd.DataFrame({"last_state": val_env_end_state}).to_csv(file_path, index=False)

    state_artifact = wandb.Artifact(name=f'val_env_end_state-{wandb.run.id}', type='env_state')
    state_artifact.add_file(file_path)
    wandb.run.log_artifact(state_artifact)

## Build & helper funcs

In [22]:
#@title build_yearly_train_test
def build_yearly_train_test(config):
    """Create the yearly train/test split described by ``config`` and upload it.

    Reads ``train_years_count``, ``test_years_count``, ``test_start_year`` and
    ``stock_index_name`` from ``config``, slices the module-level
    ``preproc_df`` on its 'date' column, and adds the resolved dates and the
    dataset name back into ``config``.

    NOTE(review): relies on ``generate_yearly_train_test_dates`` and
    ``preproc_df`` being defined elsewhere in the notebook - confirm they
    exist before this function is called.

    Returns:
        tuple: ``(train_df, test_df)``, both indexed by a daily integer index.
    """
    train_start_date, test_start_date, test_end_date = generate_yearly_train_test_dates(
        config['train_years_count'],
        config['test_years_count'],
        config['test_start_year']
    )

    dates = preproc_df['date']
    # set_daily_index modifies its argument in place, so work on explicit
    # copies rather than on slices of the shared frame.
    train_df = preproc_df[(dates >= train_start_date) & (dates < test_start_date)].copy()
    test_df = preproc_df[(dates >= test_start_date) & (dates < test_end_date)].copy()

    train_df = set_daily_index(train_df)
    test_df = set_daily_index(test_df)

    dataset_name = get_yearly_dataset_name(
        config['stock_index_name'], train_start_date, test_start_date, test_end_date
    )

    config.update(dict(
        train_start_date=train_start_date,
        test_start_date=test_start_date,
        test_end_date=test_end_date,
        dataset_name=dataset_name
    ))

    # This split has no validation set. The previous code passed an undefined
    # ``val_df`` name here, which raises NameError (or picks up an unrelated
    # global of that name).
    update_dataset_artifact(
        config,
        train_df=train_df,
        val_df=None,
        test_df=test_df,
    )
    return train_df, test_df

In [23]:
#@title build_quarterly_train_val_test
from finrl.meta.data_processors.processor_yahoofinance import YahooFinanceProcessor

def build_quarterly_train_val_test(config):
    """Build the train/val/test env configs for the window in ``config['date_range']``.

    The four boundary dates are parsed with ``pd.Timestamp``. ``config`` is
    updated with the dataset name and the number of data points per split
    (length of each split's price array).

    Returns:
        tuple: ``(train_env_config, val_env_config, test_env_config)``.
    """
    boundaries = {name: pd.Timestamp(raw_date) for name, raw_date in config['date_range'].items()}

    train_start = boundaries['train_start_date']
    val_start = boundaries['val_start_date']
    test_start = boundaries['test_start_date']
    test_end = boundaries['test_end_date']

    # NOTE(review): the train split is built with if_train=False and the
    # validation split with if_train=True. This looks inverted - confirm it
    # is intentional before relying on the flag.
    train_env_config = get_env_config(start_date=train_start, end_date=val_start, if_train=False)
    val_env_config = get_env_config(start_date=val_start, end_date=test_start, if_train=True)
    test_env_config = get_env_config(start_date=test_start, end_date=test_end, if_train=False)

    config.update({
        "dataset_name": get_quarterly_dataset_name(
            config['stock_index_name'], train_start, val_start, test_start
        ),
        "train.num_datapoints": len(train_env_config['price_array']),
        "val.num_datapoints": len(val_env_config['price_array']),
        "test.num_datapoints": len(test_env_config['price_array']),
    })

    return train_env_config, val_env_config, test_env_config

# train_env_config, val_env_config, test_env_config = build_quarterly_train_val_test(config)

In [24]:
#@title Init StockTradingEnv (numpy)

from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv
from finrl.meta.data_processor import DataProcessor
from finrl.config_tickers import DOW_30_TICKER
from finrl.config import INDICATORS
# from finrl.config import CACHE_DIR

def init_env(
    np_env_config,

    # run_config,
    initial_amount,
    cost_pct,

    mode,
    turbulence_threshold=99,
):
    """Create a ``StockTradingEnv`` from an array-based env config.

    Args:
        np_env_config (dict): Passed to the env as ``config``.
        initial_amount: Passed as ``initial_capital``.
        cost_pct: Used for both the buy and the sell cost percentage.
        mode (str): One of 'train', 'val', 'test'; only validated here.
        turbulence_threshold: Passed as ``turbulence_thresh``.

    Raises:
        AssertionError: If ``mode`` is not one of the accepted values.
    """
    assert mode in ['train', 'val', 'test']

    print('Initializing env...', end=' ')
    env_kwargs = dict(
        config=np_env_config,
        initial_capital=initial_amount,
        buy_cost_pct=cost_pct,
        sell_cost_pct=cost_pct,
        turbulence_thresh=turbulence_threshold,
    )
    trading_env = StockTradingEnv(**env_kwargs)
    print('Done.')

    return trading_env

# env = init_env(
#     train_np_env_config,
#     run_config,
#     'train'
# )

In [25]:
#@title Define metric functions

def calculate_mdd(asset_values):
    """
    Maximum drawdown of a portfolio value series, in percent.

    The drawdown at each point is measured relative to the running maximum,
    so the result is zero or negative (e.g. -25.0 for a 25% drop).
    """
    peaks = asset_values.cummax()
    worst_drawdown = ((asset_values - peaks) / peaks).min()
    return worst_drawdown * 100  # Convert to percentage

def calculate_sharpe_ratio(asset_values, risk_free_rate=0.0):
    """
    Annualized Sharpe ratio of a portfolio value series.

    Period-over-period returns are taken from ``asset_values``; the annual
    ``risk_free_rate`` is spread over 252 trading days. Returns 0.0 when the
    excess returns have zero standard deviation.
    """
    trading_days = 252  # Assuming 252 trading days

    period_returns = asset_values.pct_change().dropna()
    excess = period_returns - risk_free_rate / trading_days

    volatility = excess.std()
    if volatility == 0:
        return 0.0
    return excess.mean() / volatility * np.sqrt(trading_days)  # Annualized

def calculate_annualized_return(asset_values):
    """
    Annualized return of a portfolio value series, in percent.

    Parameters:
        asset_values (pd.Series): Portfolio values indexed by datetime; the
            first and last index entries define the holding period.

    Returns:
        float: The compound annual growth rate in percent, based on a
        365-day year (10.0 means +10% per year).

    Raises:
        ValueError: If the series spans less than one calendar day, because
        the return cannot be annualized over a zero-length period.
    """
    num_days = (asset_values.index[-1] - asset_values.index[0]).days
    if num_days <= 0:
        raise ValueError("asset_values must span at least one calendar day to annualize the return")

    # Compound the growth *factor*. The earlier version added 1 to a value
    # already scaled to percent, which inflated the result enormously.
    growth_factor = asset_values.iloc[-1] / asset_values.iloc[0]
    annualized_return = (growth_factor ** (365 / num_days) - 1) * 100
    return annualized_return

In [26]:
#@title compute metrics
import wandb
from typing import List
import numpy as np

def compute_metrics(account_values: "pd.DataFrame | pd.Series | np.ndarray", use_round=True):
    """
    Compute Sharpe ratio, max drawdown, annualized and cumulative return.

    Parameters:
        account_values: Portfolio values over time. A DataFrame must contain
            a 'date' column (or an index named 'date'); the first remaining
            column is used as the value series, e.g. 'a2c'. A Series is used
            as is; a numpy array, list or tuple is wrapped in a Series.
        use_round (bool): Round every metric to 2 decimals when True.

    Returns:
        dict: keys 'sharpe_ratio', 'mdd', 'ann_return', 'cum_return'; the
        last three are percentages.

    Raises:
        ValueError: If a DataFrame has neither a 'date' column nor a 'date' index.
    """
    if isinstance(account_values, pd.DataFrame):
        if 'date' not in account_values.columns:
            if account_values.index.name == 'date':
                # reset_index() returns a new frame, so the caller's object
                # keeps its index (this used to be an in-place reset).
                account_values = account_values.reset_index()
            else:
                raise ValueError("should contain 'date' column or index")
        account_values = account_values.dropna().set_index('date').iloc[:, 0]
    elif isinstance(account_values, (np.ndarray, list, tuple)):
        account_values = pd.Series(account_values)

    sharpe = calculate_sharpe_ratio(account_values)
    mdd = calculate_mdd(account_values)
    cum_ret = (account_values.iloc[-1] - account_values.iloc[0]) / account_values.iloc[0] * 100
    # num_days = (account_values.index.max() - account_values.index.min()).days
    # NOTE(review): num_days counts rows, yet the exponent uses a 365-day
    # year; if rows are trading days, 252 may be the intended base - confirm.
    num_days = len(account_values)
    ann_ret = ((1 + cum_ret / 100) ** (365 / num_days) - 1) * 100

    metrics = {
        'sharpe_ratio': sharpe,
        'mdd': mdd,
        'ann_return': ann_ret,
        'cum_return': cum_ret,
    }

    if use_round:
        metrics = {k: round(v, 2) for k, v in metrics.items()}

    return metrics

def get_env_metrics(env):
    """Summarize the final state of a trading env.

    The end total asset is ``env.state[0]`` plus the element-wise product of
    the two ``stock_dim``-long blocks that follow it in ``env.state``
    (presumably prices and holdings - verify against the env definition).

    Returns:
        dict: 'begin_total_asset', 'end_total_asset', 'total_cost', 'total_trades'.
    """
    n_stocks = env.stock_dim
    first_block = np.array(env.state[1 : (n_stocks + 1)])
    second_block = np.array(env.state[(n_stocks + 1) : (n_stocks * 2 + 1)])
    end_total_asset = env.state[0] + sum(first_block * second_block)

    summary = {
        'begin_total_asset': env.asset_memory[0],
        'end_total_asset': end_total_asset,
        'total_cost': env.cost,
        'total_trades': env.trades,
    }
    return summary

In [27]:
#@title log_metrics

def log_metrics(metrics, model_name, split_label, step=None):
    """Log ``metrics`` to W&B under ``split_label``, keyed as ``<metric>/<model_name>``.

    Args:
        metrics (dict): Metric name -> value.
        model_name (str): Suffix appended to every metric name.
        split_label (str): Top-level key the renamed metrics are nested under.
        step: Forwarded to ``wandb.log``.
    """
    print(f'log_metrics for {model_name}')

    renamed_metrics = {}
    for key, value in metrics.items():
        renamed_metrics[f"{key}/{model_name}"] = value

    wandb.log({split_label: renamed_metrics}, step=step)
    # wandb.run.save()
    # wandb.run.save()

In [28]:
#@title update_best_model_metrics

def update_best_model_metrics(metrics, model_name, split_label):
    """Track per-split Sharpe ratios in the run config and log the best model.

    The Sharpe ratio of every model is stored under
    ``wandb.run.config['sharpe_ratios'][split_label]``. When ``model_name``
    is the first model of the split, or beats the best stored Sharpe ratio,
    its metrics are logged as 'best_model' together with 'best_model_name'.

    Args:
        metrics (dict): Must contain 'sharpe_ratio'.
        model_name (str): Name of the model the metrics belong to.
        split_label (str): Split the metrics were computed on.
    """
    if 'sharpe_ratios' not in wandb.run.config:
        wandb.run.config['sharpe_ratios'] = {}

    if split_label not in wandb.run.config['sharpe_ratios']:
        wandb.run.config['sharpe_ratios'][split_label] = {}

    sharpe_ratios = wandb.run.config['sharpe_ratios'][split_label]

    print(f"DEBUG ({split_label}): run.id = {wandb.run.id}")
    print(f"DEBUG ({split_label}): sharpe_ratios = {sharpe_ratios}")
    print(f"DEBUG ({split_label}): updating best model based on sharpe_ratios: {sharpe_ratios}")
    if len(sharpe_ratios) > 0:
        best_model_name = max(sharpe_ratios, key=sharpe_ratios.get)
        is_new_best = metrics['sharpe_ratio'] > sharpe_ratios[best_model_name]
        if is_new_best:
            print(
                f"DEBUG ({split_label}): {round(metrics['sharpe_ratio'], 2)} ({model_name})"
                f" > {round(sharpe_ratios[best_model_name], 2)} ({best_model_name})"
                f". New best model: {model_name}."
            )
        else:
            print(
                f"DEBUG ({split_label}): {round(metrics['sharpe_ratio'], 2)} ({model_name})"
                f" <= {round(sharpe_ratios[best_model_name], 2)} ({best_model_name})"
                ". Not updating best model."
            )
    else:
        is_new_best = True
        print(f"DEBUG ({split_label}): no models logged yet, new best model is current one: {model_name}")
        print(f"DEBUG ({split_label}): wandb.run.config['sharpe_ratios'] = {wandb.run.config['sharpe_ratios']}")

    if is_new_best:
        log_metrics(metrics, 'best_model', split_label)
        # Also record the name when the first model becomes the best one;
        # previously only a later, better model logged 'best_model_name'.
        wandb.log({split_label: {'best_model_name': model_name}})

    wandb.run.config['sharpe_ratios'][split_label][model_name] = metrics['sharpe_ratio']

# Config

In [29]:
#@title init config
# Shared parameter dict: sweep_config holds a reference to it, so the
# parameters_dict.update(...) calls in the later CONFIG cells extend the sweep.
parameters_dict = {}
sweep_config = {
    # Grid method over the parameter entries collected in parameters_dict.
    'method': 'grid',
    'metric': {
        'name': 'test.sharpe_ratio/best_model'
    },
    'parameters': parameters_dict
}

In [30]:
#@title CONFIG: create dataset - yearly_train_test

# Test start years swept over: range(min, max), i.e. the upper bound is exclusive.
min_test_start_year = 2020
max_test_start_year = 2025

########################################################

# Sweep parameters for the yearly train/test dataset; the count and year keys
# are the ones read by build_yearly_train_test.
yearly_dataset_params = dict(
    # dataset_period = {'value': 'year'},
    # dataset_splits = {'parameters': {
    #     'train': {'value': True },
    #     'val': {'value': False },
    #     'test': {'value': True },
    # }},

    dataset_type = {'value':'yearly_train_test'},

    stock_index_name = {'value': 'DOW-30'},

    train_years_count = {'value': 10},
    test_years_count = {'value': 1},
    test_start_year = {
        'values': list(range(min_test_start_year, max_test_start_year))
    }
)

In [31]:
#@title CONFIG: create dataset - quarterly_train_test

# Full date range (18 backtests)
# NOTE: these names were already assigned by the earlier example cell and are
# overwritten here.
train_start_date = '2009-01-01'
min_test_start_date = '2016-01-01'
max_test_end_date = '2020-08-05'

# Number of leading date ranges kept for the sweep (slice bound below).
# NUM_DATE_RANGES = None
NUM_DATE_RANGES = 2

# How many times each date range is repeated in copied_or_truncated_date_ranges.
# NUM_RUNS_PER_DATERANGE = 1
NUM_RUNS_PER_DATERANGE = 1

#################################################################

# Dates are returned as strings (return_strings=True).
date_ranges = generate_quarterly_date_ranges(
    train_start_date,
    min_test_start_date,
    max_test_end_date,
    return_strings=True,
    finetune_previous_val=True
)

truncated_date_ranges = date_ranges[:NUM_DATE_RANGES]
# Repeated variant of the truncated list; currently not referenced by the
# sweep parameters below (see the commented-out 'values' line).
copied_or_truncated_date_ranges = [
    date_range
    for date_range in truncated_date_ranges
    for _ in range(NUM_RUNS_PER_DATERANGE)
]

# Sweep parameters for the quarterly train/val/test dataset; 'date_range' is
# the key read by build_quarterly_train_val_test.
quarterly_dataset_params = dict(
    dataset_type = {'value': 'quarterly_train_val_test'},
    stock_index_name = {'value': 'DOW-30'},
    train_start_date = {'value': train_start_date},
    min_test_start_date = {'value': min_test_start_date},
    max_test_end_date = {'value': max_test_end_date},
    date_range = {
        # 'values': copied_or_truncated_date_ranges
        'values': truncated_date_ranges
    }
)

In [32]:
#@title CONFIG: choose dataset
# Select exactly one dataset definition for the sweep by leaving it uncommented.
parameters_dict.update(
    # yearly_dataset_params,
    quarterly_dataset_params
)

In [33]:
#@title CONFIG: number of seeds

NUM_SEEDS = 2

# Draw non-negative integer seeds. The previous randn-based draw was symmetric
# around zero, so roughly half of the seeds were negative, and negative values
# are rejected by seeders such as numpy.random.seed.
parameters_dict.update({
    'seed': {'values': np.random.randint(0, 2**31 - 1, size=NUM_SEEDS).tolist()}
})

In [34]:
#@title CONFIG: env params
# Environment-related sweep parameters (single fixed value each).
parameters_dict.update(dict(
    cost_pct = {'value': 1e-3},
    initial_amount = {'value': 50_000},
    turbulence_threshold = {
        'value': 99,
        # 'values': [30, 40, 50, 60, 70]
    },
    # NOTE(review): not read anywhere in the visible part of the notebook -
    # confirm where this threshold is consumed.
    eval_turbulence_thresh = {'value': 25},
    if_vix = {'value': True}
))

In [35]:
#@title CONFIG: models_used

MODELS_USED = [
    'ppo'
]

parameters_dict.update({
    'models_used': {'value': MODELS_USED}
})

# TODO: remove legacy config used for backward compat
# Every legacy flag mirrors whether the matching algorithm is in MODELS_USED.
# The a2c flag used to test for 'ppo' (copy-paste slip), which switched a2c on
# whenever ppo was selected.
parameters_dict.update({
    f'if_using_{algo}': {'value': algo in MODELS_USED}
    for algo in ('a2c', 'ddpg', 'ppo', 'td3', 'sac')
})

In [36]:
#@title CONFIG: model and training params

# Nested sweep parameters: one sub-dict of hyperparameters per algorithm name.
training_params = {
    "parameters": {
        "ppo": {
            "parameters": dict(
                steps={"value": 2_048},
                # steps={"values": [2_048, 4_096]}, # gridsearch steps demo
                # steps={"values": [2**i for i in range(11, 18)]}, # [2048, 4096, 8192, 16384, 32768, 65536, 131072]

                train_batch_size={"value": 2048},
                num_epochs={"value": 10},
                minibatch_size={"value": 128},
                lr={"value": 5e-5},
                gamma={"value": 0.99},
            )
        }
    }
}

# Rollout worker settings, registered as a nested sweep parameter below.
env_runners_params = {
    'parameters': dict(
        num_envs_per_env_runner = {'value': 1},
        num_env_runners = {'value': 0},
    )
}

parameters_dict.update({'env_runners_params': env_runners_params})
parameters_dict.update({'training_params': training_params})

# Train & eval funcs

In [None]:
#@title MetricsLoggerCallback (class)
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from typing import Optional, Sequence
import gymnasium as gym
from gymnasium.vector import AsyncVectorEnv

class MetricsLoggerCallback(DefaultCallbacks):
    """RLlib callback that records portfolio metrics for every episode.

    On each env step the metrics returned by `compute_metrics` are stored on
    the episode as temporary timestep data. When the episode ends, the
    NaN-ignoring mean of each metric is logged twice through the metrics
    logger (EMA-smoothed and window-smoothed) and, optionally, sent unsmoothed
    to wandb.

    Args:
        model_name: Name used as the suffix of the wandb metric keys.
        ema_coeff: EMA coefficient passed to the metrics logger.
        ma_window: Window size passed to the metrics logger.
        log_to_wandb: When True, also call `wandb.log` at episode end.
    """

    def __init__(self, model_name, ema_coeff=0.2, ma_window=20, log_to_wandb=False):
        super().__init__()

        self.model_name = model_name
        self.log_to_wandb = log_to_wandb
        self.ema_coeff = ema_coeff
        self.ma_window = ma_window
        # Every metric name seen in `on_episode_step`; iterated at episode end.
        self.metric_names = set()

    def unwrap_env(self, env):
        """Return the env object the metrics are read from.

        An `AsyncVectorEnv` is returned as-is. Otherwise the first sub-env is
        taken and two wrapper layers are peeled off; per the debug notes kept
        from the original code these were observed to be
        SyncVectorEnv -> OrderEnforcing -> PassiveEnvChecker -> StockTradingEnv.
        """
        base_env = env.unwrapped
        if isinstance(base_env, AsyncVectorEnv):
            return base_env

        first_sub_env = base_env.envs[0]       # OrderEnforcing wrapper
        checker_env = first_sub_env.env        # PassiveEnvChecker wrapper
        return checker_env.env                 # the trading env itself

    def on_episode_step(
            self,
            *,
            episode,
            env_runner,
            metrics_logger,
            env,
            env_index,
            rl_module,
            **kwargs,
        ) -> None:
        """Compute metrics from the env's asset memory and attach them to the episode."""
        trading_env = self.unwrap_env(env)

        if isinstance(trading_env, AsyncVectorEnv):
            per_env_assets = trading_env.get_attr('asset_memory')
            asset_values = pd.concat(
                [pd.Series(assets) for assets in per_env_assets], axis=1
            ).mean(axis=1)

            # TODO: save_asset_memory
            raise NotImplementedError

        asset_values = trading_env.save_asset_memory()

        for name, value in compute_metrics(asset_values).items():
            episode.add_temporary_timestep_data(name, value)
            self.metric_names.add(name)

    def on_episode_end(
            self,
            *,
            episode,
            env_runner,
            metrics_logger,
            env,
            env_index,
            rl_module,
            **kwargs,
        ) -> None:
        """Log the per-episode mean of every tracked metric (EMA, windowed, wandb)."""
        for name in self.metric_names:
            per_step_values = episode.get_temporary_timestep_data(name)
            episode_mean = np.nanmean(np.array(per_step_values))

            # EMA-smoothed value, logged locally.
            metrics_logger.log_value(
                f"{name}_EMA_{self.ema_coeff}",
                episode_mean,
                reduce='mean',
                ema_coeff=self.ema_coeff
            )

            # Window-smoothed value, logged locally.
            metrics_logger.log_value(
                f"{name}_ma_{self.ma_window}",
                episode_mean,
                reduce='mean',
                window=self.ma_window
            )

            if not self.log_to_wandb:
                continue

            # Unsmoothed value goes to wandb, tagged by train/val mode.
            mode = 'val' if env_runner.config.in_evaluation else 'train'
            wandb.log({
                f"{mode}.{name}/{self.model_name}": episode_mean,
            }) # TODO: log on every episode step

In [None]:
#@title print_result

# Key prefixes of the smoothed metrics worth printing. The suffixes must match
# what MetricsLoggerCallback logs: `<metric>_ma_<window>` (lower-case "ma") and
# `<metric>_EMA_<coeff>`.
RESULT_KEYS_TO_INCLUDE = [
    'sharpe_ratio_ma',
    'ann_return_ma',
    'mdd_ma',

    'sharpe_ratio_EMA',
    'ann_return_EMA',
    'mdd_EMA',
]

def print_result(result):
    """Print the smoothed train/val metrics found in a training result dict.

    Args:
        result: Mapping with an `env_runners` sub-dict (printed as `train/...`)
            and, optionally, `evaluation -> env_runners` (printed as `val/...`).
            Missing sections are skipped instead of raising.

    Only keys starting with one of RESULT_KEYS_TO_INCLUDE are printed, rounded
    to two decimals. The output is framed by a blank line before and after.
    """
    prefixes = tuple(RESULT_KEYS_TO_INCLUDE)

    def _print_section(label, section):
        # Print every whitelisted metric of one section as "<label>/<key>: <value>".
        for key, value in section.items():
            if key.startswith(prefixes):
                print(f"{label}/{key}: {round(value, 2)}")

    print()
    _print_section('train', result.get('env_runners') or {})
    # Evaluation results are only present when an evaluation actually ran.
    evaluation = result.get('evaluation') or {}
    _print_section('val', evaluation.get('env_runners') or {})
    print()

In [None]:
#@title benchmark_exec_time
import pandas as pd
from time import perf_counter
from functools import wraps

def benchmark_exec_time(func):
    """Decorator that times a call with `perf_counter`.

    The wrapped callable returns a tuple `(output, exec_time_sec)` instead of
    just `output`, where `exec_time_sec` is the wall-clock duration of the call
    in seconds.

    Works for any callable, including ones without a `__name__` attribute
    (e.g. `functools.partial` objects); the previous version built an unused
    dict from `func.__name__` and crashed on those after the call had run.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = perf_counter()
        output = func(*args, **kwargs)
        exec_time_sec = perf_counter() - start
        return output, exec_time_sec

    return wrapper

In [None]:
#@title Train_eval_models

from ray.rllib.algorithms.ppo import PPOConfig

# Maps a model name to the RLlib config class used to build that model.
AVAILABLE_MODELS_CONFIGS = dict(ppo=PPOConfig)

def create_stock_trading_env(env_config):
    """Env creator: build the trading env by expanding `env_config` into `init_env` keyword arguments."""
    stock_trading_env = init_env(**env_config)
    return stock_trading_env

def train_eval_rllib_models(
        run_config,
        train_np_env_config,
        val_np_env_config,
        test_np_env_config,
        model_list = ['ppo'], # TODO: discard in favor of 'if_using_{model_name}'
        pretrained_models = None # pretrained on previous train set (not validation set)
    ):
    """Train every enabled model, then evaluate and plot it on the val and test splits.

    NOTE: a later cell of this notebook defines another function with the same
    name (the threshold-gridsearch variant); whichever cell ran last is the one
    actually bound to `train_eval_rllib_models`.

    Args:
        run_config: Run configuration; `run_config[f"if_using_{name}"]` decides
            whether model `name` is trained.
        train_np_env_config / val_np_env_config / test_np_env_config: Env
            configs of the three splits, forwarded to training / evaluation.
        model_list: Candidate model names; must all be keys of
            AVAILABLE_MODELS_CONFIGS.
        pretrained_models: Optional dict `model name -> algo` to continue
            training from. It is updated in place with the newly trained algos;
            when omitted a fresh dict is created for this call (a shared
            mutable default would leak models between unrelated calls).

    Returns:
        The `pretrained_models` dict, containing the algo of every trained model.
    """
    if pretrained_models is None:
        pretrained_models = {}

    assert set(model_list).issubset(AVAILABLE_MODELS_CONFIGS)
    check_and_make_directories([TRAINED_MODEL_DIR])

    register_env("stock_trading_env", create_stock_trading_env)

    eval_splits = (
        ('val', val_np_env_config),
        ('test', test_np_env_config),
    )

    for model_name in model_list:
        if not run_config[f"if_using_{model_name}"]:
            print(f"Skipping {model_name.upper()} agent")
            continue

        print(f"Training {model_name.upper()} agent")
        model_config = AVAILABLE_MODELS_CONFIGS[model_name]
        algo = train_rllib_model(
            run_config,
            model_name,
            model_config,
            train_np_env_config,
            val_np_env_config,
            pretrained_algo = pretrained_models.get(model_name, None)
        )

        print(f"Evaluating {model_name.upper()} agent")
        for split_label, split_np_env_config in eval_splits:
            # `evaluate_model` takes the split as `split_label` (it has no
            # `mode` parameter), and `plot_results` asserts that `split_label`
            # is 'val' or 'test', so both must receive it explicitly.
            eval_result = evaluate_model(
                algo,
                model_name,
                split_label=split_label,
                np_env_config=split_np_env_config,
                run_config=run_config,
            )
            fig = plot_results(**eval_result, split_label=split_label)
            log_plot_as_artifact(fig, f"{split_label}_cumulative_return", artifact_type="plot")

        pretrained_models[model_name] = algo

    return pretrained_models

In [None]:
#@title Train_eval_models (w/ threshold gridsearch)

from ray.rllib.algorithms.ppo import PPOConfig

# Model name -> RLlib config class. Same mapping as in the earlier
# "Train_eval_models" cell; re-declared so this cell can run on its own.
AVAILABLE_MODELS_CONFIGS = dict(ppo=PPOConfig)

def create_stock_trading_env(env_config):
    """Env creator: forwards every entry of `env_config` to `init_env` as a keyword argument."""
    env_kwargs = env_config
    return init_env(**env_kwargs)

def train_eval_rllib_models(
        run_config,
        train_np_env_config,
        val_np_env_config,
        test_np_env_config,
        model_list = ['ppo'], # TODO: discard in favor of 'if_using_{model_name}'
        pretrained_val_models = None,
    ):
    """Train (or reuse) each enabled model and evaluate it over a threshold grid.

    For every enabled model:
      1. reuse the model from `pretrained_val_models` if present, otherwise
         train one from scratch;
      2. evaluate it on the validation split over the threshold grid and keep
         the best threshold;
      3. continue training it via `train_rllib_model(pretrained_algo=...)`;
      4. evaluate the result on the test split, also reporting the threshold
         chosen in step 2.

    Args:
        run_config: Run configuration; `run_config[f"if_using_{name}"]` decides
            whether model `name` is processed.
        train_np_env_config / val_np_env_config / test_np_env_config: Env
            configs of the three splits.
        model_list: Candidate model names; must all be keys of
            AVAILABLE_MODELS_CONFIGS.
        pretrained_val_models: Optional dict `model name -> algo` produced by a
            previous call. It is updated in place. When omitted, a fresh dict is
            created for this call; the former shared `{}` default silently
            carried models over between unrelated calls.

    Returns:
        The `pretrained_val_models` dict with the finetuned algo of every
        processed model.
    """
    if pretrained_val_models is None:
        pretrained_val_models = {}

    assert set(model_list).issubset(AVAILABLE_MODELS_CONFIGS)
    check_and_make_directories([TRAINED_MODEL_DIR])

    register_env("stock_trading_env", create_stock_trading_env)

    for model_name in model_list:
        if run_config[f"if_using_{model_name}"]:
            model_config = AVAILABLE_MODELS_CONFIGS[model_name]

            # Skip training if a model pretrained on previous validation set is passed.
            # The training should happen only for first run,
            # further runs should reuse pretrained val models.
            if model_name in pretrained_val_models:
                algo_train = pretrained_val_models[model_name]
            else:
                algo_train = train_rllib_model(
                    run_config,
                    model_name,
                    model_config,
                    train_np_env_config,
                    val_np_env_config,
                    pretrained_algo = None # Train from scratch
                )

            wandb.run.summary['val.num_pretrain_iters'] = algo_train.training_iteration

            # Evaluate on validation set
            # using model trained on current train set (previous val set)
            val_best_th = evaluate_threshold_grid(
                algo_train,
                model_name,
                run_config,
                val_np_env_config,
                split_label='val',
            )

            # Finetune on current validation set (previous test set)
            # NOTE(review): when `pretrained_algo` is given, `train_rllib_model`
            # reuses that algo object as-is and only reads the env configs when
            # building a new algo, so the configs passed here do not appear to
            # change what the finetuning run trains on; confirm this is intended.
            algo_val = train_rllib_model(
                run_config,
                model_name,
                model_config,
                train_np_env_config,
                val_np_env_config,
                pretrained_algo = algo_train
            )

            # Evaluate on test set
            # using model finetuned on current validation set (previous test set)
            _ = evaluate_threshold_grid(
                algo_val,
                model_name,
                run_config,
                test_np_env_config,
                split_label='test',
                chosen_th=val_best_th # chosen before fine-tuning
            )

            pretrained_val_models[model_name] = algo_val
        else:
            print(f"Skipping {model_name.upper()} agent")

    return pretrained_val_models

In [None]:
#@title Train RLlib model

from functools import partial
from pathlib import Path
import torch
from ray.tune.registry import register_env
from time import perf_counter
from math import ceil
import ray

def train_rllib_model(
    run_config,
    algo_name,
    algo_cls_config,
    train_np_env_config,
    val_np_env_config,
    pretrained_algo=None
):
    """Train (or continue training) an RLlib algorithm and checkpoint it.

    Args:
        run_config: Run configuration. Reads `env_runners_params`,
            `initial_amount`, `cost_pct` and `training_params[algo_name]`
            (`steps`, plus the optional `train_batch_size`, `num_epochs`,
            `minibatch_size`, `lr` and `gamma`).
        algo_name: Model name; used for log keys, the metrics callback and the
            checkpoint sub-directory.
        algo_cls_config: RLlib config class, instantiated to build a new algo.
        train_np_env_config: Env config of the training env (new algo only).
        val_np_env_config: Env config of the evaluation env (new algo only).
        pretrained_algo: Existing algo to keep training. When given, no new
            algo is built and the env configs above are not used.

    Returns:
        The trained algo.

    Side effects: writes the training duration to `wandb.run.summary`, saves a
    checkpoint under `TRAINED_MODEL_DIR/<algo_name>` and calls
    `update_model_artifacts`.
    """

    num_envs_per_env_runner = run_config['env_runners_params']['num_envs_per_env_runner']
    num_env_runners = run_config['env_runners_params']['num_env_runners']

    if pretrained_algo:
        print(f"Finetuning {algo_name.upper()} agent.")
        algo = pretrained_algo
        print(f'Using agent with {algo.training_iteration} pretrain iterations.')
    else:
        print(f"Training {algo_name.upper()} agent from scratch.")

        algo_params = run_config['training_params'][algo_name]

        def _hyperparam(name, default):
            # Read a configured hyperparameter, falling back to `default` when absent.
            try:
                return algo_params[name]
            except (KeyError, TypeError):
                return default

        # Honour the hyperparameters declared in the config cell instead of
        # hard-coding them; the fallbacks are the previously hard-coded values.
        training_kwargs = dict(
            train_batch_size=_hyperparam('train_batch_size', 2048),
            num_epochs=_hyperparam('num_epochs', 10),
            minibatch_size=_hyperparam('minibatch_size', 128),
        )
        # `lr` / `gamma` are only forwarded when configured, so the algorithm's
        # own defaults stay in effect otherwise.
        for optional_name in ('lr', 'gamma'):
            optional_value = _hyperparam(optional_name, None)
            if optional_value is not None:
                training_kwargs[optional_name] = optional_value

        algo_config = (
            algo_cls_config()
            .environment(
                env="stock_trading_env",
                env_config={
                    # "df": train,
                    "np_env_config": train_np_env_config,

                    # "run_config": run_config,
                    "initial_amount": run_config['initial_amount'],
                    "cost_pct": run_config['cost_pct'],

                    "mode": 'train'
                },
            )
            .env_runners(
                num_envs_per_env_runner=num_envs_per_env_runner,
                num_env_runners=num_env_runners,
                num_cpus_per_env_runner= (2/num_env_runners) if num_env_runners > 2 else None,

                # gym_env_vectorize_mode=gym.envs.registration.VectorizeMode.ASYNC,
            )
            .training(**training_kwargs)
            .evaluation(
                # Set up the validation environment
                evaluation_interval=1,  # Specify evaluation frequency (1=after each training step)
                evaluation_config={
                    "env": "stock_trading_env",
                    "env_config": {
                        # "df": val,
                        "np_env_config": val_np_env_config,

                        # "run_config": run_config,
                        "initial_amount": run_config['initial_amount'],
                        "cost_pct": run_config['cost_pct'],

                        "mode": 'val'
                    },
                },
            )
            # Tag the logged metrics with the model actually being trained
            # (this used to be hard-coded to 'ppo').
            .callbacks(partial(MetricsLoggerCallback, model_name=algo_name, log_to_wandb=True))
            .resources(
                num_gpus=1 if torch.cuda.is_available() else None
            )
        )

        # 2. build the algorithm ..
        if algo_config.num_env_runners > 0:
            # Restart Ray and re-register the env before building with remote
            # env runners.
            ray.shutdown()
            ray.init()

            register_env("stock_trading_env", create_stock_trading_env)

        algo = algo_config.build()
        print(f'Created new {algo_name} agent.')

    print(f"Envs: {algo.config.num_envs_per_env_runner}, Runners: {algo.config.num_env_runners}")

    # 3. .. train it ..
    @benchmark_exec_time
    def _train_rllib(total_timesteps):
        print('Started training.')
        results = []
        total_batches = ceil(total_timesteps / algo.config.train_batch_size)
        print(f"total_batches: {total_batches}")
        print(f"total_timesteps: {total_timesteps}")
        for _ in range(total_batches):
            results.append(algo.train())

        # Nothing to print when no training iteration ran (total_timesteps <= 0).
        if results:
            print_result(results[-1])
        print('Training complete.')
        return results

    results, exec_time_sec = _train_rllib(run_config['training_params'][algo_name]['steps'])
    print(f"TRAINING DURATION: {exec_time_sec}")

    # Log train duration
    duration_minutes = round(exec_time_sec / 60, 1)
    wandb.run.summary[f"train.duration_minutes/{algo_name}"] = duration_minutes

    # Save model
    ckpt_path = (Path(TRAINED_MODEL_DIR) / algo_name).resolve()
    algo.save(ckpt_path)
    update_model_artifacts(log_results_folder=False)

    return algo

In [None]:
#@title Evaluate RLlib model
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.core.columns import Columns
from ray.rllib.utils.numpy import convert_to_numpy, softmax
from ray.rllib.models.torch.torch_distributions import TorchDiagGaussian
from ray.rllib.core.rl_module.rl_module import RLModule
import torch

def evaluate_model(algo_or_rl_module, model_name, split_label, np_env_config, run_config, turbulence_thresh=None,
                   log_to_wandb=False, return_metrics=False):
    """Run one deterministic episode on a freshly created env and compute its metrics.

    Args:
        algo_or_rl_module: An RLlib `Algorithm` (its env-runner module is used)
            or an `RLModule`.
        model_name: Name used for the account-value column and wandb keys.
        split_label: Split name; passed to the env as `mode` and used in logs.
        np_env_config: Env config of the split; also provides
            `turbulence_array` for the returned turbulence series.
        run_config: Run configuration (`initial_amount`, `cost_pct`, `if_vix`).
        turbulence_thresh: Threshold passed to the env as
            `turbulence_threshold`; see the note in the body for the fallback.
        log_to_wandb: When True, write every metric to `wandb.run.summary`.
        return_metrics: When True, return `(eval_result, metrics)`.

    Returns:
        `eval_result` dict with keys `account_value`, `turbulence_series` and
        `turbulence_thresh` (directly usable as `plot_results(**eval_result)`),
        plus the metrics dict when `return_metrics` is True.

    Raises:
        NotImplementedError: if `algo_or_rl_module` is neither an `Algorithm`
            nor an `RLModule`.
    """

    if turbulence_thresh is None:
        # NOTE(review): the config cell in this notebook defines
        # `turbulence_threshold` and `eval_turbulence_thresh` but no
        # `turbulence_thresh`, so as far as the visible config goes this always
        # falls back to 99. Confirm which key is intended.
        turbulence_thresh = run_config.get('turbulence_thresh', 99)
    turbulence_name = 'turbulence' if not run_config.get('if_vix', None) else '^VIX'
    print(f"Evaluating for `{split_label}` using `{turbulence_name}`: {turbulence_thresh}")

    env_config = {
        # "df": val,
        "np_env_config": np_env_config,

        # "run_config": run_config,
        "initial_amount": run_config['initial_amount'],
        "cost_pct": run_config['cost_pct'],

        "mode": split_label,
        'turbulence_threshold': turbulence_thresh
    }

    # Create the evaluation environment
    eval_env = create_stock_trading_env(env_config)
    state, _ = eval_env.reset()
    done = False

    # Perform inference using the trained RLlib agent
    if isinstance(algo_or_rl_module, Algorithm):
        rl_module = algo_or_rl_module.env_runner.module
    elif isinstance(algo_or_rl_module, RLModule):
        rl_module = algo_or_rl_module
    else:
        raise NotImplementedError

    # Feed observations on the module's own device so inference also works for
    # a module that does not live on the CPU.
    try:
        device = next(rl_module.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cpu')

    with torch.no_grad():
        while not done:
            # Compute action using the RLlib trained agent
            obs = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
            rl_module_out = rl_module.forward_inference({Columns.OBS: obs})
            logits = rl_module_out[Columns.ACTION_DIST_INPUTS]

            # The distribution inputs are split into two halves (mean, log_std);
            # the deterministic action is the mean.
            mean, log_std = logits.chunk(2, dim=-1)

            # Drop only the batch dimension: a bare `.squeeze()` would also
            # collapse a single-asset action of shape (1, 1) into a 0-d array.
            action = mean.detach().cpu().numpy().squeeze(0)

            # Clip the action to ensure it's within the action space bounds
            action = np.clip(action, eval_env.action_space.low, eval_env.action_space.high)

            # Perform action
            state, reward, terminated, truncated, _ = eval_env.step(action)
            done = terminated or truncated

    df_account_value = eval_env.save_asset_memory()
    metrics = compute_metrics(df_account_value)
    print(metrics)

    if log_to_wandb:
        turbulence_log_name = 'ti' if not run_config.get('if_vix', None) else 'vix'
        for metric_name, metric_value in metrics.items():
            metric_name = f"{split_label}.{turbulence_log_name}_{turbulence_thresh}.{metric_name}/{model_name}"
            wandb.run.summary[metric_name] = round(metric_value, 2)

    # Turbulence values truncated to the episode length, indexed by date.
    turbulence_series = pd.Series(
        np_env_config['turbulence_array'][:len(df_account_value)],
        index=df_account_value['date'],
        name=turbulence_name
    )

    eval_result = {
        'account_value': df_account_value.rename(columns={'account_value': model_name}),
        'turbulence_series': turbulence_series,
        'turbulence_thresh': turbulence_thresh
    }

    if return_metrics:
        return eval_result, metrics
    else:
        return eval_result

In [None]:
#@title plot_results (enhanced)

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick

def plot_results(
        account_value,
        turbulence_series,
        turbulence_thresh,
        figsize='small',
        split_label=None,
        metrics=None,
        ylim_bottom = None,
        ylim_top = None
    ):
    """Plot account values (top panel) and the turbulence series (bottom panel).

    Args:
        account_value: DataFrame with one column per model (optionally a
            `DJIA` column and a `date` column that becomes the x-axis). The
            frame is NOT modified; plotting works on a copy.
        turbulence_series: Series named 'turbulence' or '^VIX'.
        turbulence_thresh: Threshold drawn as a horizontal line in the bottom
            panel and shown in the title.
        figsize: 'small' or 'medium'.
        split_label: 'val' or 'test' (required despite the `None` default).
        metrics: Optional dict `metric name -> value` rendered as a subtitle.
        ylim_bottom / ylim_top: Optional y-limits of the top panel.

    Returns:
        The matplotlib figure.
    """
    assert split_label in ['val', 'test']
    assert turbulence_series.name in ['turbulence', '^VIX']
    assert figsize in ['small', 'medium']

    figsizes = {
        'medium': (14, 10),
        'small': (8.3, 8)
    }

    # Normalise a private copy (date as index, upper-case column names) so the
    # caller's DataFrame keeps its original columns and index.
    account_value = account_value.copy()
    if 'date' in account_value.columns:
        account_value = account_value.set_index('date')
    account_value = account_value.rename(columns={col: col.upper() for col in account_value.columns})

    # Create figure and subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsizes[figsize], sharex=True, gridspec_kw={'height_ratios': [3, 1]})

    # Method styles
    method_styles = {
        'A2C': {'color': '#8c564b', 'linestyle': '--'},
        'DDPG': {'color': '#e377c2', 'linestyle': '-'},
        'PPO': {'color': '#7f7f7f', 'linestyle': '-'},
        'TD3': {'color': '#bcbd22', 'linestyle': '--'},
        'SAC': {'color': '#17becf', 'linestyle': '-'},
        'DJIA': {'color': '#000000', 'linestyle': '-'},
    }

    # Plot every column exactly once, against the (date) index. The DJIA
    # benchmark gets its long label; it used to be drawn twice, the first time
    # before `date` had become the index.
    for model_name in account_value.columns:
        style = method_styles.get(model_name, {'color': 'blue', 'linestyle': '-'})  # Default style fallback
        label = "Dow Jones Index" if model_name == 'DJIA' else model_name
        ax1.plot(account_value.index, account_value[model_name], label=label, **style)

    # Construct title text
    turbulence_label = "Turbulence Index" if turbulence_series.name == 'turbulence' else "VIX Coefficient"
    split_label_name = ('validation' if split_label == 'val' else split_label).capitalize()
    title = f"{split_label_name} split | {turbulence_label} threshold: {turbulence_thresh}"
    fig.suptitle(title, fontsize=20, fontweight='bold')

    # Define the mapping for metric names
    full_names = {
        'mdd': 'MDD',
        'ann_return': 'Annualized Return',
        'cum_return': 'Cumulative Return',
        'sharpe_ratio': 'Sharpe Ratio'
    }

    # Add subtitle, properly positioned and centered
    if metrics:
        metric_text = ", ".join(
            f"{full_names.get(name, name.replace('_', ' ').capitalize())}: {value:.2f}"
            for name, value in metrics.items()
        )
        ax1.set_title(metric_text, fontsize=12, color='gray', ha='center')

    # Prettify y-axis numbers (thousands separators, no decimals)
    ax1.yaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: f"{x:,.0f}"))

    # Horizontal line at the initial asset value
    initial_asset_value = account_value.iloc[0].mean()  # Assuming initial value from mean of first row
    ax1.axhline(y=initial_asset_value, color='gray', linestyle='-.', linewidth=1.5, label="Initial Asset Value")

    # Main plot customization
    ax1.set_ylabel("Total Asset Value ($)", fontsize=16, fontweight='bold')
    if ylim_bottom is not None:
        ax1.set_ylim(bottom=ylim_bottom)
    if ylim_top is not None:
        ax1.set_ylim(top=ylim_top)

    ax1.legend(loc='lower right')
    ax1.grid(True, linestyle='--', alpha=0.3)

    # Turbulence plot
    ax2.plot(turbulence_series.index, turbulence_series, label=turbulence_label, color='red', linestyle='--', linewidth=2)
    ax2.axhline(y=turbulence_thresh, color='red', linestyle=':', label=f'Threshold = {turbulence_thresh}')

    ax2.set_ylabel(turbulence_label, fontsize=16, fontweight='bold')
    ax2.legend(loc='upper left')
    ax2.grid(True, linestyle='--', alpha=0.3)

    # Leave headroom above whichever is larger: the series or the threshold.
    max_turbulence = max(turbulence_series.max(), turbulence_thresh)
    ax2.set_ylim(0, max_turbulence + 10)

    # Shared x-axis label
    ax2.set_xlabel("Date", fontsize=16, fontweight='bold')

    return fig

In [None]:
#@title log_plot_as_artifact

import os
import matplotlib.pyplot as plt
import wandb

def log_plot_as_artifact(fig, artifact_name_prefix, artifact_type="plot"):
    """
    Save a Matplotlib figure without clipping and log it as a W&B artifact.

    The artifact is named `<artifact_name_prefix>-<wandb.run.id>`; the figure is
    written to `<artifact name>.png` in the current directory, uploaded, and the
    temporary file is removed again.

    Parameters:
        fig (matplotlib.figure.Figure): The Matplotlib figure to save and log.
        artifact_name_prefix (str): Prefix of the W&B artifact name.
        artifact_type (str): The type of the artifact (default is "plot").
    """
    # Resolve the names BEFORE entering `try`: previously a failure here (e.g.
    # no active wandb run) left `filename` unbound, and the `finally` block
    # replaced the real error with an UnboundLocalError.
    artifact_name = f'{artifact_name_prefix}-{wandb.run.id}'
    filename = artifact_name + '.png'

    try:
        # Save the figure with tight layout and proper padding
        fig.savefig(filename, bbox_inches='tight', pad_inches=0.1, dpi=300)

        # Create and log the W&B artifact
        artifact = wandb.Artifact(artifact_name, type=artifact_type)
        artifact.add_file(filename, skip_cache=True, overwrite=True)
        wandb.log_artifact(artifact)
    finally:
        # Always close the figure to free memory and delete the temporary file
        plt.close(fig)
        if os.path.exists(filename):
            os.remove(filename)

# log_plot_as_artifact(fig, "performance_comparison_DRL_agents", artifact_type="plot")

In [None]:
#@title batch_log_plots_as_artifact

import os
import shutil
import wandb
import matplotlib.pyplot as plt

def batch_log_plots_as_artifact(
        figs, fig_names, artifact_name_prefix, artifact_type="plot"
    ):
    """
    Save several Matplotlib figures into one folder and log that folder as a
    single W&B artifact; the folder is removed afterwards.

    The folder is named `artifact_name_prefix` (relative to the current
    directory) and the artifact `<artifact_name_prefix>-<wandb.run.id>`.

    Parameters:
        figs (list): Matplotlib figure objects.
        fig_names (list): File stems, paired with `figs` by position; each
            figure is written to `<folder>/<fig_name>.png`.
        artifact_name_prefix (str): Folder name and prefix of the artifact name.
        artifact_type (str): The type of the artifact (default is "plot").
    """
    assert wandb.run.id

    plots_dir = artifact_name_prefix
    os.makedirs(plots_dir, exist_ok=True)

    try:
        for figure, stem in zip(figs, fig_names):
            png_path = os.path.join(plots_dir, f"{stem}.png")
            figure.savefig(png_path, bbox_inches='tight', pad_inches=0.1, dpi=300)
            plt.close(figure)  # release the figure's memory right away

        # Bundle the whole folder into one artifact and upload it.
        plots_artifact = wandb.Artifact(
            f"{artifact_name_prefix}-{wandb.run.id}", type=artifact_type
        )
        plots_artifact.add_dir(plots_dir, skip_cache=True)
        wandb.log_artifact(plots_artifact)
    finally:
        # The folder is temporary: remove it whether or not logging succeeded.
        if os.path.exists(plots_dir):
            shutil.rmtree(plots_dir)

In [None]:
#@title log_eval_results

def log_eval_results(
        model_name,
        metrics,
        split_label,
        turbulence_log_name,
        turbulence_thresh=None,
        postfix=None
    ):
    """Write evaluation metrics to `wandb.run.summary`.

    Each metric is stored, rounded to two decimals, under
    `<split_label>.<turbulence_log_name>_<label>.<metric_name>/<model_name>`,
    where `<label>` is `postfix` when given and `turbulence_thresh` otherwise.
    Without a threshold, `postfix` must be 'best' or 'chosen'.
    """
    label = postfix

    if turbulence_thresh is None:
        assert postfix in ['best', 'chosen']

    if label is None:
        assert turbulence_thresh is not None
        label = turbulence_thresh

    key_prefix = f"{split_label}.{turbulence_log_name}_{label}."
    key_suffix = f"/{model_name}"
    for name, value in metrics.items():
        wandb.run.summary[key_prefix + f"{name}" + key_suffix] = round(value, 2)

In [None]:
#@title evaluate_threshold_grid

import numpy as np
import pandas as pd
import wandb

def evaluate_threshold_grid(
    algo_or_rl_module,
    model_name,
    run_config,
    np_env_config,
    num_grid_points=10,
    split_label='val',
    chosen_th=None,
    plot_padding=500
):
    """Evaluate a model over a grid of turbulence thresholds and log the results.

    The grid consists of `num_grid_points` evenly spaced values between the min
    and max of `np_env_config['turbulence_array']` (rounded up to ints), one
    extra value above the max, and `chosen_th` when given. Each distinct
    threshold is evaluated once via `evaluate_model`.

    Side effects: logs a metrics table to wandb, writes the metrics of the best
    (highest `sharpe_ratio`) and of the chosen threshold to the run summary,
    and logs one plot per remaining threshold as an artifact.

    Args:
        algo_or_rl_module, model_name, run_config, np_env_config: Forwarded to
            `evaluate_model`.
        num_grid_points: Number of evenly spaced grid points.
        split_label: Split name used in log keys and plots.
        chosen_th: Optional threshold selected beforehand (e.g. on validation);
            it is always evaluated and reported under the `chosen` label.
        plot_padding: Margin added below/above the overall min/max account
            value to get common y-limits for all plots.

    Returns:
        The threshold with the highest `sharpe_ratio` (int).

    Raises:
        NotImplementedError: when `run_config['if_vix']` is falsy.
    """

    if run_config.get('if_vix', True):
        turbulence_log_name = 'vix'
    else:
        turbulence_log_name = 'ti'
        raise NotImplementedError

    # Calculate threshold grid
    turb_ary = np_env_config['turbulence_array']
    raw_grid = np.linspace(turb_ary.min(), turb_ary.max(), num_grid_points)
    raw_grid = np.ceil(raw_grid).astype(int).tolist()
    raw_grid.append(max(raw_grid) + 1)

    # Rounding up can produce repeated thresholds; evaluate each one only once.
    # `chosen_th` is kept LAST on purpose: the `keep="last"` de-duplication
    # below must never drop its row in favour of another threshold with
    # identical metrics.
    threshold_grid = [
        th for th in dict.fromkeys(raw_grid)
        if chosen_th is None or th != chosen_th
    ]
    if chosen_th is not None:
        threshold_grid.append(chosen_th)

    th_metrics = []  # List to store metrics for all thresholds
    th_results = {}  # Dictionary to store results per threshold

    for th in threshold_grid:
        result, metrics = evaluate_model(
            turbulence_thresh=th,
            model_name=model_name,
            algo_or_rl_module=algo_or_rl_module,
            run_config=run_config,
            np_env_config=np_env_config,
            split_label=split_label,
            log_to_wandb=True,
            return_metrics=True
        )

        metrics = {turbulence_log_name: th, **metrics}
        th_metrics.append(metrics)
        th_results[th] = result

    # Convert metrics to a DataFrame and log to WandB
    df = pd.DataFrame(th_metrics).astype({turbulence_log_name: int})
    metric_cols = df.drop(columns=[turbulence_log_name]).columns # take only true metric columns

    # Identify the protected row where turbulence_log_name == chosen_th
    protected_row = df[df[turbulence_log_name] == chosen_th].tail(1)  # Keep only the last occurrence

    # Drop thresholds whose metrics duplicate another row (the last one wins)
    df = df.drop_duplicates(metric_cols, keep="last")

    # Drop rows where all metric columns are 0, excluding the protected row
    df = df[~((df[metric_cols] == 0).all(axis=1) & ~df.index.isin(protected_row.index))]

    wandb_table = wandb.Table(dataframe=df)
    wandb.log({f"threshold_grid_metrics-{split_label}": wandb_table})

    # Log metrics for best threshold
    best_idx = df['sharpe_ratio'].idxmax()
    best_metrics = df.loc[best_idx].to_dict()
    best_metrics[turbulence_log_name] = int(best_metrics[turbulence_log_name])
    # Pass the label as `postfix`: positionally it would land in the
    # `turbulence_thresh` parameter of `log_eval_results`.
    log_eval_results(model_name, best_metrics, split_label, turbulence_log_name, postfix='best')

    # Log metrics for chosen threshold (if any)
    if chosen_th is not None:
        chosen_metrics = df[df[turbulence_log_name] == chosen_th].iloc[0].to_dict()
        log_eval_results(model_name, chosen_metrics, split_label, turbulence_log_name, postfix='chosen')

    # Common y-limits: overall min / max account value across all thresholds
    min_account_value = float('+inf')
    max_account_value = float('-inf')
    for th_value, result in th_results.items():
        min_account_value = min(min_account_value, result['account_value'][model_name].min())
        max_account_value = max(max_account_value, result['account_value'][model_name].max())

    min_account_value -= plot_padding
    max_account_value += plot_padding

    # Plot returns
    figs = []
    fig_names = []
    fig_collection_name = f"cum_return-{split_label}-{turbulence_log_name}"
    best_th_value = best_metrics[turbulence_log_name]

    for th_value in df[turbulence_log_name].values:
        result = th_results[th_value]
        metrics = df.drop(columns=[turbulence_log_name])[
            df[turbulence_log_name] == th_value].iloc[0].to_dict()

        fig = plot_results(
            **result,
            figsize='small',
            split_label=split_label,
            metrics=metrics,
            ylim_bottom=min_account_value,
            ylim_top=max_account_value
        )
        fig_name = f"{fig_collection_name}_{th_value}"
        if th_value == best_th_value:
            fig_name += "_best"
        if chosen_th is not None and th_value == chosen_th:
            fig_name += "_chosen"

        figs.append(fig)
        fig_names.append(fig_name)

    batch_log_plots_as_artifact(
        figs,
        fig_names,
        artifact_name_prefix=fig_collection_name
    )

    best_th = best_metrics[turbulence_log_name]
    return best_th

# Sweep Runner

In [37]:
#@title load_model

import ray
from ray.rllib.algorithms.ppo import PPO

# Model name -> RLlib algorithm class used to restore a checkpoint.
AVAILABLE_MODELS_CLASSES = dict(ppo=PPO)

def load_model(model_name, trained_model_dir = TRAINED_MODEL_DIR):
    """Restore the algorithm of `model_name` from `<trained_model_dir>/<model_name>`.

    Raises KeyError when `model_name` is not in AVAILABLE_MODELS_CLASSES.
    """
    algo_cls = AVAILABLE_MODELS_CLASSES[model_name]
    # The checkpoint location is passed on as an absolute path.
    return algo_cls.from_checkpoint(
        os.path.abspath(f"{trained_model_dir}/{model_name}")
    )

In [38]:
#@title Download artifacts

def download_artifacts(run_id, artifact_types):
    """Download the artifacts of the given types that a W&B run logged.

    Each matching artifact is downloaded into `./<artifact.type>`; that folder
    is wiped first so that stale files from an earlier download do not survive.

    Args:
        run_id: Id of the run inside `ENTITY/PROJECT`.
        artifact_types: Container of artifact type names to download.
    """
    # Local import keeps this cell runnable on its own.
    import shutil

    # Initialize the W&B API
    api = wandb.Api()

    # Retrieve the run
    run = api.run(f"{ENTITY}/{PROJECT}/{run_id}")

    # Iterate over the artifacts logged by the run
    for artifact in run.logged_artifacts():
        if artifact.type not in artifact_types:
            continue

        artifact_folder = f'./{artifact.type}'
        # The former `!rm -rf artifact_folder` removed a path literally named
        # "artifact_folder" (the variable was never interpolated), so the real
        # target folder was never cleaned.
        shutil.rmtree(artifact_folder, ignore_errors=True)
        artifact.download(artifact_folder)

In [39]:
#@title find_previous_run (sweep)

def find_prev_run_id(
        sweep_id,
        curr_run_id,
        curr_config_hash # HACK: pass along with curr_run_id since uploading new config_hash to wandb might be slow
    ):
    """Find the sweep run that directly precedes the current one.

    A run qualifies when it is not the current run, its summary `config_hash`
    equals `curr_config_hash`, and its `date_range.test_start_date` equals the
    current run's `date_range.val_start_date`.

    Runs that never logged a `config_hash` or have no `date_range` in their
    config (e.g. runs that failed early) are skipped instead of raising
    KeyError.

    NOTE(review): the sweep runner in this notebook derives `config_hash` with
    the built-in `hash()` of a string, which Python salts per process unless
    PYTHONHASHSEED is fixed; hashes written by another kernel session may
    therefore never match. Confirm whether that matters here.

    Returns:
        The id of the first matching run, or None when there is none.
    """
    api = wandb.Api()
    sweep_runs = api.sweep(f"{ENTITY}/{PROJECT}/{sweep_id}").runs
    curr_run = api.run(f"{ENTITY}/{PROJECT}/{curr_run_id}")

    # Sentinel so that a missing value can never compare equal to a real one
    # (not even to None).
    missing = object()

    for run in sweep_runs:
        if run.id == curr_run.id:
            continue
        if run.summary.get('config_hash', missing) != curr_config_hash:
            continue

        # Previous run's `test_start_date` must equal the current run's `val_start_date`.
        prev_date_range = run.config.get('date_range') or {}
        prev_test_start = prev_date_range.get('test_start_date', missing)
        if prev_test_start == curr_run.config['date_range']['val_start_date']:
            return run.id

    return None

# SWEEP_ID = '9otblgq2'
# CURR_RUN_ID = 'kxjtk9qg'
# prev_run_id = find_prev_run_id(SWEEP_ID, CURR_RUN_ID)
# prev_run_id

In [40]:
#@title Sweep Runner
from copy import deepcopy
from datetime import datetime

import random
import string
import json

def config_to_canonical_string(cfg_dict):
    """Serialize `cfg_dict` to JSON with sorted keys, so equal configs give equal strings."""
    canonical = json.dumps(cfg_dict, sort_keys=True)
    return canonical

def set_run_name(prefix, n=5):
    """Rename the active wandb run to "<prefix> | <run id>" and save it.

    `n` is accepted but not used.
    """
    active_run = wandb.run
    active_run.name = f"{prefix} | {active_run.id}"
    active_run.save()

class SweepRunner:
    """Runs W&B sweep trials and reuses models between trials in one process.

    The models returned by a trial are cached in memory under a hash of the
    trial's ``seed`` and ``training_params`` and are handed to the next trial
    that produces the same hash.
    """

    def __init__(self, sweep_id):
        """Remember the sweep id and start with an empty model cache.

        Args:
            sweep_id: ID of the W&B sweep this runner serves.
        """
        self.sweep_id = sweep_id
        # config_hash -> models returned by the latest trial with that hash
        self.pretrained_val_models = {}

    def main(self, run_config=None):
        """Execute one sweep trial inside a W&B run.

        Builds the quarterly train/val/test configs, trains and evaluates via
        ``train_eval_rllib_models`` (passing any cached models for the same
        config hash), stores the returned models in the cache and records the
        run duration in the run summary.

        Args:
            run_config: Optional config forwarded to ``wandb.init``.

        Raises:
            NotImplementedError: If ``dataset_type`` is ``'yearly_train_test'``.
        """
        # Local import so the method does not depend on another cell's imports.
        from hashlib import sha256

        run_timer_start = perf_counter()

        current_time = datetime.now()
        print("START TIME:", current_time)

        with wandb.init(config=run_config):
            run_config = wandb.config

            try:
                if run_config['dataset_type'] == 'yearly_train_test':
                    raise NotImplementedError
                elif run_config['dataset_type'] == 'quarterly_train_val_test':
                    (
                        train_np_env_config,
                        val_np_env_config,
                        test_np_env_config
                    ) = build_quarterly_train_val_test(run_config)
                    set_run_name(run_config['dataset_name'])

                    # Identify the hyperparameter combination (seed + training
                    # params). `date_range` is deliberately left out so the
                    # cached models persist across the whole date range.
                    config_canonical_str = config_to_canonical_string({
                        'seed': run_config['seed'],
                        'training_params': run_config['training_params']
                    })
                    # sha256 instead of built-in hash(): string hashes are
                    # salted per interpreter process, so hash() values logged
                    # to W&B could not be compared across kernels or agents.
                    config_hash = sha256(config_canonical_str.encode()).hexdigest()

                    pretrained_val_models = self.pretrained_val_models.get(config_hash, {})
                    wandb.run.summary['config_hash'] = config_hash

                    # Fine-tune only when models for this hash already exist.
                    run_config.update({'finetune': bool(pretrained_val_models)})

                    pretrained_val_models = train_eval_rllib_models(
                        run_config,
                        train_np_env_config,
                        val_np_env_config,
                        test_np_env_config,
                        pretrained_val_models=pretrained_val_models
                    )

                    self.pretrained_val_models[config_hash] = pretrained_val_models
            finally:
                # Shut Ray down even when the trial fails, so a failed trial
                # does not leave Ray running for the next one.
                ray.shutdown()

            run_timer_end = perf_counter()
            run_duration_minutes = round((run_timer_end - run_timer_start) / 60, 1)
            wandb.run.summary["run.duration_minutes"] = run_duration_minutes
            print(f"RUN DURATION: {run_duration_minutes}")

        current_time = datetime.now()
        print("END TIME:", current_time)

In [41]:
#@title Sweep Runner (load model weights + continue = async agent)
from copy import deepcopy
from datetime import datetime
from hashlib import sha256
import random
import string

def config_to_canonical_string(cfg_dict):
    """Return ``cfg_dict`` as JSON text with keys in sorted order.

    Key order in the input does not affect the output, which makes the string
    suitable as input for hashing a config.
    """
    as_json = json.dumps(cfg_dict, sort_keys=True)
    return as_json

def set_run_name(prefix, n=5):
    """Give the active W&B run the name ``"<prefix> | <run id>"`` and save it.

    Args:
        prefix: Leading part of the new run name.
        n: Ignored by the function; present only for call compatibility.
    """
    current_run = wandb.run
    new_name = f"{prefix} | {current_run.id}"
    current_run.name = new_name
    current_run.save()

class SweepRunner:
    """Runs W&B sweep trials, continuing from models of an earlier trial.

    Models are cached in memory per config hash (sha256 of the canonical JSON
    of ``seed`` and ``training_params``). When the cache has nothing for a
    hash, the runner looks for a matching previous run in the sweep and loads
    its downloaded models instead.
    """

    def __init__(self, sweep_id):
        """Remember the sweep id and start with an empty model cache.

        Args:
            sweep_id: ID of the W&B sweep this runner serves.
        """
        self.sweep_id = sweep_id
        # config_hash -> models returned by the latest trial with that hash
        self.pretrained_val_models = {}

    def main(self, run_config=None):
        """Execute one sweep trial inside a W&B run.

        Builds the quarterly train/val/test configs, resolves pretrained
        models (in-memory cache first, then the previous run found by
        ``find_prev_run_id``), trains and evaluates via
        ``train_eval_rllib_models``, caches the returned models and records
        the run duration in the run summary.

        Args:
            run_config: Optional config forwarded to ``wandb.init``.

        Raises:
            NotImplementedError: If ``dataset_type`` is ``'yearly_train_test'``.
        """
        run_timer_start = perf_counter()

        current_time = datetime.now()
        print("START TIME:", current_time)

        with wandb.init(config=run_config):
            run_config = wandb.config

            try:
                if run_config['dataset_type'] == 'yearly_train_test':
                    raise NotImplementedError
                elif run_config['dataset_type'] == 'quarterly_train_val_test':
                    (
                        train_np_env_config,
                        val_np_env_config,
                        test_np_env_config
                    ) = build_quarterly_train_val_test(run_config)
                    set_run_name(run_config['dataset_name'])

                    # Identify the hyperparameter combination (seed + training
                    # params). `date_range` is deliberately left out so the
                    # pretrained models persist across the whole date range.
                    config_canonical_str = config_to_canonical_string({
                        'seed': run_config['seed'],
                        'training_params': run_config['training_params']
                    })
                    config_hash = sha256(config_canonical_str.encode()).hexdigest()

                    pretrained_val_models = self.pretrained_val_models.get(config_hash, {})
                    wandb.run.summary['config_hash'] = config_hash

                    if not pretrained_val_models:
                        # Nothing cached in this process: fall back to the
                        # models logged by the preceding run of the sweep.
                        print("Looking for models from previous run for this date range...")
                        prev_run_id = find_prev_run_id(self.sweep_id, wandb.run.id, config_hash)
                        if prev_run_id is not None:
                            download_artifacts(prev_run_id, artifact_types=['trained_models'])

                            # The env is registered before the checkpoints are loaded.
                            register_env("stock_trading_env", create_stock_trading_env)
                            pretrained_val_models = {
                                model_name: load_model(model_name)
                                for model_name in run_config['models_used']
                            }
                            print("Models downloaded.")
                        else:
                            print("No models found.")

                    # Fine-tune only when pretrained models are available.
                    run_config.update({'finetune': bool(pretrained_val_models)})

                    pretrained_val_models = train_eval_rllib_models(
                        run_config,
                        train_np_env_config,
                        val_np_env_config,
                        test_np_env_config,
                        pretrained_val_models=pretrained_val_models
                    )

                    self.pretrained_val_models[config_hash] = pretrained_val_models
            finally:
                # Shut Ray down even when the trial fails, so a failed trial
                # does not leave Ray running for the next one.
                ray.shutdown()

            run_timer_end = perf_counter()
            run_duration_minutes = round((run_timer_end - run_timer_start) / 60, 1)
            wandb.run.summary["run.duration_minutes"] = run_duration_minutes
            print(f"RUN DURATION: {run_duration_minutes}")

        current_time = datetime.now()
        print("END TIME:", current_time)

# Run sweep

In [None]:
#@title RUN SWEEP
def run_sweep(n_runs, sweep_config=None, sweep_id=None):
    wandb.finish()
    if sweep_id is None:
        assert sweep_config is not None
        sweep_id = wandb.sweep(sweep_config, project=PROJECT)
    else:
        assert sweep_config is None

    !rm -rf {TRAINED_MODEL_DIR}/*

    sweep_runner = SweepRunner(sweep_id)
    wandb.agent(sweep_id, sweep_runner.main, project=PROJECT, count=n_runs)

# Attach to the existing sweep and run a single trial.
run_sweep(
    n_runs=1,             # alternative: n_runs=None
    sweep_id='jrebsx44',  # alternative: sweep_config=sweep_config (creates a new sweep)
)