<a href="https://colab.research.google.com/github/tony-pitchblack/finrl-dt/blob/custom-backtesting/finrl_dt_replicate_sweep_rllib.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Installs

In [1]:
# %load_ext autoreload
# %autoreload 2

In [2]:
%%capture
!pip install git+https://github.com/tony-pitchblack/FinRL.git@benchmarking --no-deps \
    # --force-reinstall --no-deps

In [3]:
%%capture
!wget https://raw.githubusercontent.com/tony-pitchblack/FinRL/benchmarking/requirements.txt -O requirements.txt
!pip install -r requirements.txt

# Imports

In [2]:
import pandas as pd

from stable_baselines3.common.logger import configure
from finrl import config_tickers
from finrl.main import check_and_make_directories
from finrl.config import INDICATORS, TRAINED_MODEL_DIR, RESULTS_DIR

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [3]:
import os
from pathlib import Path
import pandas as pd
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [4]:
os.environ["WANDB_API_KEY"] = "aee284a72205e2d6787bd3ce266c5b9aefefa42c"

PROJECT = 'finrl-dt-replicate'
ENTITY = "overfit1010"

# General funcs

In [5]:
#@title YahooDownloader

"""Contains methods and classes to collect data from
Yahoo Finance API
"""

from __future__ import annotations

import pandas as pd
import yfinance as yf


class YahooDownloader:
    """Provides methods for retrieving daily stock data from
    Yahoo Finance API

    Attributes
    ----------
        start_date : str
            start date of the data (modified from neofinrl_config.py)
        end_date : str
            end date of the data (modified from neofinrl_config.py)
        ticker_list : list
            a list of stock tickers (modified from neofinrl_config.py)

    Methods
    -------
    fetch_data()
        Fetches data from yahoo API

    """

    def __init__(self, start_date: str, end_date: str, ticker_list: list):
        self.start_date = start_date
        self.end_date = end_date
        self.ticker_list = ticker_list

    def fetch_data(self, proxy=None) -> pd.DataFrame:
        """Fetches data from Yahoo API
        Parameters
        ----------

        Returns
        -------
        `pd.DataFrame`
            7 columns: A date, open, high, low, close, volume and tick symbol
            for the specified stock ticker
        """
        # Download and save the data in a pandas DataFrame:
        data_df = pd.DataFrame()
        num_failures = 0
        for tic in self.ticker_list:
            temp_df = yf.download(
                tic, start=self.start_date, end=self.end_date, proxy=proxy
            )
            temp_df["tic"] = tic
            if len(temp_df) > 0:
                # data_df = data_df.append(temp_df)
                data_df = pd.concat([data_df, temp_df], axis=0)
            else:
                num_failures += 1
        if num_failures == len(self.ticker_list):
            raise ValueError("no data is fetched.")
        # reset the index, we want to use numbers as index instead of dates
        data_df = data_df.reset_index()

        try:
            # Convert wide to long format
            # print(f"DATA COLS: {data_df.columns}")
            data_df = data_df.sort_index(axis=1).set_index(['Date']).drop(columns=['tic']).stack(level='Ticker', future_stack=True)
            data_df.reset_index(inplace=True)
            data_df.columns.name = ''

            # convert the column names to standardized names
            data_df.rename(columns={'Ticker': 'Tic', 'Adj Close': 'Adjcp'}, inplace=True)
            data_df.rename(columns={col: col.lower() for col in data_df.columns}, inplace=True)

            columns = [
                "date",
                "tic",
                "open",
                "high",
                "low",
                "close",
                "adjcp",
                "volume",
            ]

            data_df = data_df[columns]
            # use adjusted close price instead of close price
            data_df["close"] = data_df["adjcp"]
            # drop the adjusted close price column
            data_df = data_df.drop(labels="adjcp", axis=1)

        except NotImplementedError:
            print("the features are not supported currently")

        # create day of the week column (monday = 0)
        data_df["day"] = data_df["date"].dt.dayofweek
        # convert date to standard string format, easy to filter
        data_df["date"] = data_df.date.apply(lambda x: x.strftime("%Y-%m-%d"))
        # drop missing data
        data_df = data_df.dropna()
        data_df = data_df.reset_index(drop=True)
        print("Shape of DataFrame: ", data_df.shape)
        # print("Display DataFrame: ", data_df.head())

        data_df = data_df.sort_values(by=["date", "tic"]).reset_index(drop=True)

        return data_df

    def select_equal_rows_stock(self, df):
        df_check = df.tic.value_counts()
        df_check = pd.DataFrame(df_check).reset_index()
        df_check.columns = ["tic", "counts"]
        mean_df = df_check.counts.mean()
        equal_list = list(df.tic.value_counts() >= mean_df)
        names = df.tic.value_counts().index
        select_stocks_list = list(names[equal_list])
        df = df[df.tic.isin(select_stocks_list)]
        return df


In [6]:
#@title fix_daily_index

def make_daily_index(data_df, date_col_name='date', new_index_name='date_index'):
    # Get unique dates and create a mapping to daily indices
    total_dates = data_df[date_col_name].unique()
    date_to_index = {date: idx for idx, date in enumerate(sorted(total_dates))}
    return data_df[date_col_name].map(date_to_index)

def set_daily_index(data_df, date_col_name='date', new_index_name='date_index'):
    """
    Constructs a daily index from unique dates in the specified column.

    Parameters:
        data_df (pd.DataFrame): The input DataFrame.
        date_column (str): The name of the column containing dates.
        new_index_name (str): The name for the new index.

    Returns:
        pd.DataFrame: DataFrame with a daily index.
    """

    # Map dates to daily indices and set as index
    data_df[new_index_name] = make_daily_index(data_df, date_col_name=date_col_name, new_index_name='date_index')

    data_df.set_index(new_index_name, inplace=True)
    data_df.index.name = ''  # Remove the index name for simplicity

    return data_df

def fix_daily_index(df, date_col_name='date'):
    if df.index.name == date_col_name:
        df.reset_index(inplace=True)

    daily_index = make_daily_index(df, date_col_name=date_col_name, new_index_name='date_index')
    if (df.index.values != daily_index.values).any():

        df.index = daily_index
        df.index.name = ''

    return df

# trade = fix_daily_index(trade)
# trade.index

In [7]:
#@title get dataset name

def get_quarterly_dataset_name(prefix, train_start_date, val_start_date, test_start_date):
    get_quarter = lambda date: f'Q{(date.month - 1) // 3 + 1}'

    val_quarter = get_quarter(val_start_date)
    test_quarter = get_quarter(test_start_date)

    # Extract year and month
    train_start = f"{train_start_date.year}-{train_start_date.month:02}"
    val_start = f"{val_start_date.year}"
    test_start = f"{test_start_date.year}"

    # Construct the dataset name
    dataset_name = f"{prefix} | {train_start} | {val_start} {val_quarter} | {test_start} {test_quarter}"

    return dataset_name

def get_yearly_dataset_name(prefix, train_start, test_start, test_end):
    # Extract year and month
    train_start_str = f"{train_start.year}-{train_start.month:02}"
    test_start_str = f"{test_start.year}-{test_start.month:02}"
    test_end_str = f"{test_end.year}-{test_end.month:02}"

    # Construct the dataset name
    dataset_name = f"{prefix} | {train_start_str} | {test_start_str} | {test_end_str}"
    return dataset_name


In [8]:
#@title add_dataset

def add_dataset(stock_index_name, train_df, test_df):
    if 'datasets' not in globals():
        global datasets
        datasets = {}

    # Ensure datetime format
    if 'date' in train_df.columns:
        train_df.set_index('date', inplace=True)
    train_df.index = pd.to_datetime(train_df.index)

    if 'date' in test_df.columns:
        test_df.set_index('date', inplace=True)
    test_df.index = pd.to_datetime(test_df.index)

    train_start_date = train_df.index[0]
    test_start_date = test_df.index[0]
    test_end_date = test_df.index[-1]

    dataset_name = get_yearly_dataset_name(
        stock_index_name,
        train_start_date, test_start_date, test_end_date
    )

    train_df.reset_index(inplace=True)
    test_df.reset_index(inplace=True)

    train_df = set_daily_index(train_df)
    test_df = set_daily_index(test_df)

    ticker_list = train_df.tic.unique().tolist()

    datasets[dataset_name] = {
        'train': train_df,
        'test': test_df,
        'metadata': dict(
            stock_index_name = stock_index_name,
            train_start_date = train_start_date,
            test_start_date = test_start_date,
            test_end_date = test_end_date,
            num_tickers = len(ticker_list),
            ticker_list = ticker_list,
        )
    }

# Load data

## DATA: DOW-30 (quarterly train/val/test)

In [9]:
train_start_date = '2015-01-01'
min_test_start_date = '2016-01-01'
max_test_end_date = '2016-10-01'

In [10]:
#@title generate_quarterly_date_ranges

def generate_quarterly_date_ranges(
    train_start_date,
    min_test_start_date,
    max_test_end_date,
    return_strings=False,
    finetune_previous_val=False
):
    is_quarter_start = lambda date: date.month in [1, 4, 7, 10] and date.day == 1

    min_test_start_date = pd.Timestamp(min_test_start_date)
    train_start_date = pd.Timestamp(train_start_date)
    max_test_end_date = pd.Timestamp(max_test_end_date)

    assert is_quarter_start(train_start_date), f"train_start_date {train_start_date} is not a quarter start date."
    assert is_quarter_start(min_test_start_date), f"min_test_start_date {min_test_start_date} is not a quarter start date."

    test_start_date = min_test_start_date
    date_ranges = []
    full_train_start_date = train_start_date

    while True:
        val_start_date = test_start_date - pd.DateOffset(months=3)
        test_end_date = test_start_date + pd.DateOffset(months=3)

        if test_end_date > max_test_end_date:
            break

        if len(date_ranges) == 0:
            # The first date_range contains the full training period
            train_start_date = full_train_start_date
        elif finetune_previous_val:
            # Use the previous validation range as the training range
            train_start_date = date_ranges[-1]['val_start_date']

        date_range = dict(
            train_start_date=train_start_date,
            val_start_date=val_start_date,
            test_start_date=test_start_date,
            test_end_date=test_end_date,
        )

        if return_strings:
            date_range = {k: str(v) for k, v in date_range.items()}

        date_ranges.append(date_range)

        test_start_date = test_end_date

    return date_ranges

date_ranges = generate_quarterly_date_ranges(
    train_start_date,
    min_test_start_date,
    max_test_end_date,
    finetune_previous_val=True
)

# print(*date_ranges[:2], sep='\n')
print(*date_ranges, sep='\n')

{'train_start_date': Timestamp('2015-01-01 00:00:00'), 'val_start_date': Timestamp('2015-10-01 00:00:00'), 'test_start_date': Timestamp('2016-01-01 00:00:00'), 'test_end_date': Timestamp('2016-04-01 00:00:00')}
{'train_start_date': Timestamp('2015-10-01 00:00:00'), 'val_start_date': Timestamp('2016-01-01 00:00:00'), 'test_start_date': Timestamp('2016-04-01 00:00:00'), 'test_end_date': Timestamp('2016-07-01 00:00:00')}
{'train_start_date': Timestamp('2016-01-01 00:00:00'), 'val_start_date': Timestamp('2016-04-01 00:00:00'), 'test_start_date': Timestamp('2016-07-01 00:00:00'), 'test_end_date': Timestamp('2016-10-01 00:00:00')}


In [11]:
#@title split_data

def split_data(data_df, date_range, date_col_name='date'):
    def subset_date_range(df, start_date, end_date):
        df = df[(df[date_col_name] >= start_date) & (df[date_col_name] < end_date)]
        df = fix_daily_index(df, date_col_name)
        return df

    return {
        'train': subset_date_range(data_df, date_range['train_start_date'], date_range['val_start_date']),
        'val': subset_date_range(data_df, date_range['val_start_date'], date_range['test_start_date']),
        'test': subset_date_range(data_df, date_range['test_start_date'], date_range['test_end_date']),
    }

# data_splits = split_data(preproc_df, date_ranges[0])
# data_splits['train'].head()

# Main

## Wandb utils

### Artifact management

In [12]:
#@title update_artifact

def update_artifact(folder_path, name_prefix, type):
    """
    Create or update a W&B artifact consisting of a folder.

    Args:
        run: The current W&B run.
        folder_path (str): Path to the folder to upload.
        artifact_name (str): Name of the artifact.
        artifact_type (str): Type of the artifact.
    """
    run = wandb.run
    artifact_name = f'{name_prefix}-{wandb.run.id}'

    # Create a new artifact
    artifact = wandb.Artifact(name=artifact_name, type=type)

    # Add the folder to the artifact
    artifact.add_dir(folder_path)

    # Log the artifact to W&B
    run.log_artifact(artifact)
    print(f"Artifact '{artifact_name}' has been updated and uploaded.")

In [13]:
#@title update_model_artifacts

def update_model_artifacts(log_results_folder=True):
    if log_results_folder:
        update_artifact(
            folder_path = RESULTS_DIR,
            name_prefix = 'results',
            type = 'results'
        )

    update_artifact(
        folder_path = TRAINED_MODEL_DIR,
        name_prefix = 'trained_models',
        type = 'trained_models'
    )

In [14]:
#@title update_dataset_artifact

from pathlib import Path

def update_dataset_artifact(config, train_df, val_df=None, test_df=None):
    DATASET_DIR = Path('./dataset')
    os.makedirs(DATASET_DIR, exist_ok=True)

    train_df.to_csv(DATASET_DIR / 'train_data.csv', index=False)

    if test_df is not None:
        test_df.to_csv(DATASET_DIR / 'test_data.csv', index=False)

    if val_df is not None:
        val_df.to_csv(DATASET_DIR / 'val_data.csv', index=False)

    update_artifact(
        folder_path = DATASET_DIR,
        name_prefix = 'dataset',
        type = 'dataset'
    )

In [15]:
#@title update_env_state_artifact

def update_env_state_artifact(val_env_end_state):
    file_path = 'val_env_end_state.csv'

    df_last_state = pd.DataFrame({"last_state": val_env_end_state})
    df_last_state.to_csv(
        file_path, index=False
    )

    artifact = wandb.Artifact(name=f'val_env_end_state-{wandb.run.id}', type='env_state')
    artifact.add_file(file_path)
    wandb.run.log_artifact(artifact)

In [16]:
#@title Download artifacts

def download_artifacts(run_id, artifact_types):
    # Retrieve the run
    global WANDB_API

    run = WANDB_API.run(f"{ENTITY}/{PROJECT}/{run_id}")

    # Iterate over the artifacts used or logged by the run
    for artifact in run.logged_artifacts():
        if artifact.type in artifact_types:
            artifact_folder = f'./{artifact.type}'
            !rm -rf artifact_folder
            artifact.download(artifact_folder)

### Run management

In [17]:
import wandb
WANDB_API = wandb.Api() # instantiate API once

In [18]:
#@title get_config_hash

import json
from hashlib import sha256

def dict_to_canonical_string(cfg_dict):
    return json.dumps(cfg_dict, sort_keys=True)

def get_config_hash(run_config):
    config_seed = run_config['seed']
    config_training_params = run_config['training_params']
    config_canonical_str = dict_to_canonical_string({
        'seed': config_seed,
        'training_params': config_training_params
    })
    config_hash = sha256(config_canonical_str.encode()).hexdigest()
    return config_hash

In [19]:
#@title group_by_hash_and_sort_by_date (for all runs in a sweep)

from collections import OrderedDict
from datetime import datetime

def group_by_hash_and_sort_by_date(
        sweep_id,
        return_inner_list=True
    ):
    global WANDB_API

    sweep_runs = WANDB_API.sweep(f"{ENTITY}/{PROJECT}/{sweep_id}").runs

    prev_run_id = None
    prev_run_id_list = []

    # Collect run IDs and their test_start_date, grouped by config_hash
    config_hash_to_runs = {}
    for run in sweep_runs:
        config_hash = get_config_hash(run.config)
        if config_hash not in config_hash_to_runs:
            config_hash_to_runs[config_hash] = {}

        config_hash_to_runs[config_hash][run.id] = run.config['date_range']['test_start_date']

    # Sort runs for current config hash by test_start_date
    for config_hash, runs_dict in config_hash_to_runs.items():
        sorted_runs_dict = OrderedDict(
            sorted(
                runs_dict.items(),
                key=lambda x: datetime.strptime(x[1], "%Y-%m-%d %H:%M:%S"),
            )
        )

        if return_inner_list:
            config_hash_to_runs[config_hash] = list(sorted_runs_dict.items())
        else:
            config_hash_to_runs[config_hash] = sorted_runs_dict

    return config_hash_to_runs

In [20]:
#@title find_previous_run (sweep)

def find_prev_run(sweep_id, curr_run_id):
    global WANDB_API

    sweep_runs = WANDB_API.sweep(f"{ENTITY}/{PROJECT}/{sweep_id}").runs
    curr_run = WANDB_API.run(f"{ENTITY}/{PROJECT}/{curr_run_id}")

    prev_run = None
    for run in sweep_runs:
        # find run with same `hash_config`
        # and current previous `test_start_date == current `val_start_date`
        run_hash = get_config_hash(run.config)
        curr_run_hash = get_config_hash(curr_run.config)
        if run_hash == curr_run_hash \
        and run.config['date_range']['test_start_date'] == curr_run.config['date_range']['val_start_date']:
            prev_run = run
            break

    return prev_run

# SWEEP_ID = 'yd7fz9as'
# CURR_RUN_ID = 'rbug5oxn'
# prev_run = find_prev_run(SWEEP_ID, CURR_RUN_ID)
# prev_run.id

## Build & helper funcs

In [21]:
#@title build_yearly_train_test
def build_yearly_train_test(config):
    train_start_date, test_start_date, test_end_date = generate_yearly_train_test_dates(
        config['train_years_count'],
        config['test_years_count'],
        config['test_start_year']
    )

    train_df = preproc_df[(preproc_df['date'] >= train_start_date) & (preproc_df['date'] < test_start_date)]
    test_df = preproc_df[(preproc_df['date'] >= test_start_date) & (preproc_df['date'] < test_end_date)]

    train_df = set_daily_index(train_df)
    test_df = set_daily_index(test_df)

    dataset_name = get_yearly_dataset_name(
        config['stock_index_name'], train_start_date, test_start_date, test_end_date
    )

    config.update(dict(
        train_start_date=train_start_date,
        test_start_date=test_start_date,
        test_end_date=test_end_date,
        dataset_name=dataset_name
    ))

    update_dataset_artifact(
        config,

        train_df=train_df,
        val_df=val_df,
        test_df=test_df,
    )
    return train_df, test_df

In [22]:
#@title build_quarterly_train_val_test
from finrl.meta.data_processors.processor_yahoofinance import YahooFinanceProcessor

def build_quarterly_train_val_test(data, data_processor, date_range, config):
    train_np_env_config = get_np_env_config(
        data,
        data_processor,
        start_date=date_range['train_start_date'],
        end_date=date_range['val_start_date'],
        if_train=True
    )

    val_np_env_config = get_np_env_config(
        data,
        data_processor,
        start_date=date_range['val_start_date'],
        end_date=date_range['test_start_date'],
        if_train=False
    )

    test_np_env_config = get_np_env_config(
        data,
        data_processor,
        start_date=date_range['test_start_date'],
        end_date=date_range['test_end_date'],
        if_train=False
    )

    config.update({
        "train.num_datapoints": len(train_np_env_config['price_array']),
        "val.num_datapoints": len(val_np_env_config['price_array']),
        "test.num_datapoints": len(test_np_env_config['price_array']),
    })

    return dict(
        train_np_env_config = train_np_env_config,
        val_np_env_config = val_np_env_config,
        test_np_env_config = test_np_env_config
    )

# train_np_env_config, val_np_env_config, test_np_env_config = build_quarterly_train_val_test(config)

In [23]:
#@title load_cached_data (w/ tech indicator padding)
from pathlib import Path
from bisect import bisect_left
import pandas as pd
import hashlib
from finrl.config_tickers import DOW_30_TICKER

NY = "America/New_York"

CACHE_DIR = './cache'
os.makedirs(CACHE_DIR, exist_ok=True)

def stable_hash(data):
    return hashlib.sha256(str(data).encode()).hexdigest()

def load_cached_data(
    start_date: str,
    end_date: str,
    ticker_list=DOW_30_TICKER,
    technical_indicator_list=INDICATORS,
    extra_indicator_list=None,
    time_interval='1d',
    if_vix=True,
    ignore_cache=False,
    tech_indicator_padding=60,
):

    """
    Load data with a buffer of N trading days before start_date, so that
    rolling tech indicators can be computed properly within the environment.

    tech_indicator_padding: how many extra timestamps to load for correct tech indicators calculation
    """

    # Create the DataProcessor first so we can use .get_trading_days(...)
    dp = DataProcessor(
        data_source='yahoofinance',
        tech_indicator=technical_indicator_list,
        extra_indicator=extra_indicator_list,
        vix=if_vix,
    )

    # 1) Define a big offset (e.g. 6 months) before start_date to gather trading days
    #    (We assume 6 months is enough to include all desired tech_indicator_padding.)
    raw_buffer_start = pd.Timestamp(start_date) - pd.DateOffset(months=6)

    # 2) Get *all* trading days from that offset up to the real end_date
    all_trading_days = dp.processor.get_trading_days(
        start=str(raw_buffer_start.date()),
        end=str(pd.Timestamp(end_date).date())
    )

    # 3) Find the index of the first trading day >= start_date
    #    We compare strings "YYYY-MM-DD", so make sure to align formats.
    target_str = str(pd.Timestamp(start_date).date())
    i = bisect_left(all_trading_days, target_str)

    # 4) Subtract tech_indicator_padding from that index (clamp to zero)
    earliest_idx = max(0, i - tech_indicator_padding)
    actual_start_date = all_trading_days[earliest_idx]

    # Build a stable hash for caching
    data_hash = stable_hash(tuple(sorted(ticker_list)
                                  + sorted(technical_indicator_list)
                                  + sorted(extra_indicator_list) if extra_indicator_list else []
                                  + [if_vix, time_interval]))

    # File path includes the actual_start_date so we pick up correct caching
    file_path = Path(CACHE_DIR) / (
        f"{actual_start_date}_{end_date}_{time_interval}_{data_hash}.csv"
    )

    if file_path.is_file() and not ignore_cache:
        print(f"Using cached data: {file_path}")
        data = pd.read_csv(file_path, index_col=0)
        data['timestamp'] = (
            pd.to_datetime(data['timestamp'], utc=True)
              .dt.tz_convert(NY)
              .dt.tz_localize(None)
        )
        return data, dp
    else:
        print("Creating new data (no suitable cache).")
        # 5) Download from the earlier (buffered) start date
        data = dp.download_data(
            ticker_list=ticker_list,
            start_date=actual_start_date,
            end_date=end_date,
            time_interval=time_interval
        )

        data = dp.clean_data(data)
        # Add VIX or turbulence
        if if_vix:
            data = dp.add_vix(data)
        else:
            data = dp.add_turbulence(data)

        data = dp.add_technical_indicator(data, technical_indicator_list, extra_indicator_list)

        # -------------------------------------------
        # (6) Trim out rows before the *original* start_date
        # so that the final returned DataFrame is the "original size".
        start_ts = pd.Timestamp(start_date).tz_localize(NY)
        data = data[data["timestamp"] >= start_ts]
        data.reset_index(drop=True, inplace=True)
        # -------------------------------------------

        # Save to cache
        data.to_csv(file_path)

        # Convert timestamps
        data['timestamp'] = (
            pd.to_datetime(data['timestamp'], utc=True)
              .dt.tz_convert(NY)
              .dt.tz_localize(None)
        )

        return data, dp

In [24]:
#@title get_np_env_config

def get_np_env_config(
    data: pd.DataFrame,
    data_processor: DataProcessor,
    start_date,
    end_date,
    if_train,
    use_extra_indicators=False,
    if_extra_indicators_tech=False
):
    """
    Uses data_processor to convert data to array.
    WARNING: data_processor should be same as the one that was used to create the data.
    """

    start_date = pd.Timestamp(start_date)
    end_date = pd.Timestamp(end_date)

    data = data[(data['timestamp'] >= start_date) & (data['timestamp'] < end_date)]

    # print(f'use_extra_indicators: {use_extra_indicators}')
    (
        price_array,
        tech_array,
        turbulence_array,
        timestamp_array,
        *extra_arrays
    ) = data_processor.df_to_array(
        df=data,
        return_timestamps=True,
        use_extra_indicators=use_extra_indicators,
        if_extra_indicators_tech=if_extra_indicators_tech,
    )

    env_config = {
        "price_array": price_array,
        "tech_array": tech_array,
        "turbulence_array": turbulence_array,
        "timestamp_array": timestamp_array,
        "if_train": if_train
    }

    extra_cols = data_processor.extra_indicator_list
    if extra_cols is not None:
        env_config.update({
            col: array for col, array in zip(extra_cols, extra_arrays)
        })

    return env_config

In [25]:
#@title Init NUMPY StockTradingEnv

from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv
from finrl.meta.data_processor import DataProcessor
from finrl.config_tickers import DOW_30_TICKER
from finrl.config import INDICATORS
# from finrl.config import CACHE_DIR

def init_np_env(
    np_env_config,

    initial_amount,
    cost_pct,

    mode,
    turbulence_threshold=99,
    initial_stocks=None
):
    assert mode in ['train', 'val', 'test']

    print(f"Initializing {'train' if np_env_config['if_train'] else 'eval'} env...", end=' ')
    env = StockTradingEnv(
        config=np_env_config,
        initial_capital=initial_amount,
        buy_cost_pct=cost_pct,
        sell_cost_pct=cost_pct,
        turbulence_thresh=turbulence_threshold,
        initial_stocks=np.array(initial_stocks) if isinstance(initial_stocks, list) else initial_stocks
    )
    print('Done.')

    return env

def create_np_stock_trading_env(env_config):
    return init_np_env(**env_config)

In [26]:
#@title Init STOPLOSS env

from finrl.meta.env_stock_trading.env_stocktrading_stoploss import StockTradingEnvStopLoss

def create_stoploss_stock_trading_env(env_config):
    min_date = env_config['df']['timestamp'].min()
    max_date = env_config['df']['timestamp'].max()
    print(f"Creating Stop Loss env from {min_date} to {max_date}")

    env = StockTradingEnvStopLoss(
        df = env_config['df'],
        date_col_name = 'timestamp',

        buy_cost_pct = env_config['cost_pct'],
        sell_cost_pct = env_config['cost_pct'],
        initial_amount = env_config['initial_amount'],
        discrete_actions = env_config['discrete_actions'],
        cache_indicator_data =  env_config['cache_indicator_data'],
        patient = env_config['patient'],
        print_verbosity = env_config['print_verbosity'],
    )

    return env

In [27]:
CREATE_ENV_FN = {
    'np': create_np_stock_trading_env,
    'stoploss': create_stoploss_stock_trading_env
}

In [28]:
#@title Define metric functions

def calculate_mdd(asset_values):
    """
    Calculate the Maximum Drawdown (MDD) of a portfolio.
    """
    running_max = asset_values.cummax()
    drawdown = (asset_values - running_max) / running_max
    mdd = drawdown.min() * 100  # Convert to percentage
    return mdd

def calculate_sharpe_ratio(asset_values, risk_free_rate=0.0):
    """
    Calculate the Sharpe Ratio of a portfolio.
    """
    # Calculate daily returns
    returns = asset_values.pct_change().dropna()
    excess_returns = returns - risk_free_rate / 252  # Assuming 252 trading days

    if excess_returns.std() == 0:
        return 0.0
    sharpe_ratio = excess_returns.mean() / excess_returns.std() * np.sqrt(252)  # Annualized
    return sharpe_ratio

def calculate_annualized_return(asset_values):
    """
    Calculate the annualized return of a portfolio.
    """
    # Assume `asset_values` is indexed by date or trading day
    total_return = (asset_values.iloc[-1] / asset_values.iloc[0] - 1) * 100
    num_days = (asset_values.index[-1] - asset_values.index[0]).days
    annualized_return = (1 + total_return) ** (365 / num_days) - 1
    return annualized_return

In [29]:
#@title compute metrics
import wandb
from typing import List
import numpy as np

def compute_metrics(account_values: List[pd.DataFrame, pd.Series, np.array], use_round=True):
    """
    If DataFrame then should contain two columns - 'date' and name of algo, e.g. 'a2c'.
    """

    if isinstance(account_values, pd.DataFrame):
        assert isinstance(account_values, pd.DataFrame)
        if 'date' not in account_values.columns:
            if account_values.index.name == 'date':
                account_values.reset_index(inplace=True)
            else:
                raise ValueError("should contain 'date' column or index")
        account_values = account_values.dropna().set_index('date').iloc[:, 0]
    elif isinstance(account_values, np.ndarray):
        account_values = pd.Series(account_values)

    sharpe = calculate_sharpe_ratio(account_values)
    mdd = calculate_mdd(account_values)
    cum_ret = (account_values.iloc[-1] - account_values.iloc[0]) / account_values.iloc[0] * 100
    # num_days = (account_values.index.max() - account_values.index.min()).days
    num_days = len(account_values)
    ann_ret = ((1 + cum_ret / 100) ** (365 / num_days) - 1) * 100

    metrics = {
            f'sharpe_ratio': sharpe,
            f'mdd': mdd,
            f'ann_return': ann_ret,
            f'cum_return': cum_ret,
        }

    if use_round:
        metrics = {k: round(v, 2) for k, v in metrics.items()}

    return metrics

def get_env_metrics(env):
    end_total_asset = env.state[0] + sum(
        np.array(env.state[1 : (env.stock_dim + 1)])
        * np.array(env.state[(env.stock_dim + 1) : (env.stock_dim * 2 + 1)])
    )

    return {
        'begin_total_asset': env.asset_memory[0],
        'end_total_asset': end_total_asset,
        'total_cost': env.cost,
        'total_trades': env.trades,
    }

In [30]:
#@title log_metrics

def log_metrics(metrics, model_name, split_label, step=None):
    print(f'log_metrics for {model_name}')

    rename_metrics = lambda model_name: {
        f"{key}/{model_name}": value for key, value in metrics.items()
    }

    renamed_metrics = rename_metrics(model_name)
    wandb.log({split_label: renamed_metrics}, step=step)
    # wandb.run.save()

In [31]:
#@title update_best_model_metrics

def update_best_model_metrics(metrics, model_name, split_label):
    if 'sharpe_ratios' not in wandb.run.config:
        wandb.run.config['sharpe_ratios'] = {}

    if split_label not in wandb.run.config['sharpe_ratios']:
        wandb.run.config['sharpe_ratios'][split_label] = {}

    sharpe_ratios = wandb.run.config['sharpe_ratios'][split_label]

    print(f"DEBUG ({split_label}): run.id = {wandb.run.id}")
    print(f"DEBUG ({split_label}): sharpe_ratios = {sharpe_ratios}")
    print(f"DEBUG ({split_label}): updating best model based on sharpe_ratios: {sharpe_ratios}")
    if len(sharpe_ratios) > 0:
        best_model_name = max(sharpe_ratios, key=sharpe_ratios.get)
        if metrics['sharpe_ratio'] > sharpe_ratios[best_model_name]:
            print(
                f"DEBUG ({split_label}): {round(metrics['sharpe_ratio'], 2)} ({model_name})"
                f" > {round(sharpe_ratios[best_model_name], 2)} ({best_model_name})"
                f". New best model: {model_name}."
            )
            log_metrics(metrics, 'best_model', split_label)
            wandb.log({split_label: {'best_model_name': model_name}})
        else:
            print(
                f"DEBUG ({split_label}): {round(metrics['sharpe_ratio'], 2)} ({model_name})"
                f" <= {round(sharpe_ratios[best_model_name], 2)} ({best_model_name})"
                ". Not updating best model."
            )
    else:
        print(f"DEBUG ({split_label}): no models logged yet, new best model is current one: {model_name}")
        print(f"DEBUG ({split_label}): wandb.run.config['sharpe_ratios'] = {wandb.run.config['sharpe_ratios']}")
        log_metrics(metrics, 'best_model', split_label)

    wandb.run.config['sharpe_ratios'][split_label][model_name] = metrics['sharpe_ratio']

# Config

In [32]:
#@title init config
parameters_dict = {}
sweep_config = {
    'method': 'grid',
    'metric': {
        'name': 'test.sharpe_ratio/best_model'
    },
    'parameters': parameters_dict
}

In [33]:
#@title CONFIG: create dataset - yearly_train_test

min_test_start_year = 2020
max_test_start_year = 2025

########################################################

yearly_dataset_params = dict(
    # dataset_period = {'value': 'year'},
    # dataset_splits = {'parameters': {
    #     'train': {'value': True },
    #     'val': {'value': False },
    #     'test': {'value': True },
    # }},

    dataset_type = {'value':'yearly_train_test'},

    stock_index_name = {'value': 'DOW-30'},

    train_years_count = {'value': 10},
    test_years_count = {'value': 1},
    test_start_year = {
        'values': list(range(min_test_start_year, max_test_start_year))
    }
)

In [71]:
#@title CONFIG: create dataset - quarterly_train_test

# Full date range (18 backtests)
# train_start_date = '2009-01-01'
# min_test_start_date = '2016-01-01'
# max_test_end_date = '2020-08-05'

# Single train/val/test split (9 months = 3 quarters) for smoke test
# train_start_date = '2015-01-01'
# min_test_start_date = '2016-01-01'
# max_test_end_date = '2016-10-01'

# Up to 2025
train_start_date = '2009-01-01'
min_test_start_date = '2016-01-01'
max_test_end_date = '2025-01-01'

NUM_DATE_RANGES = None
# NUM_DATE_RANGES = 2

# NUM_RUNS_PER_DATERANGE = 1
NUM_RUNS_PER_DATERANGE = 1

#################################################################

date_ranges = generate_quarterly_date_ranges(
    train_start_date,
    min_test_start_date,
    max_test_end_date,
    return_strings=True,
    finetune_previous_val=True
)

truncated_date_ranges = date_ranges[:NUM_DATE_RANGES]
copied_or_truncated_date_ranges = [
    date_range
    for date_range in truncated_date_ranges
    for _ in range(NUM_RUNS_PER_DATERANGE)
]

quarterly_dataset_params = dict(
    dataset_type = {'value': 'quarterly_train_val_test'},
    stock_index_name = {'value': 'DOW-30'},
    train_start_date = {'value': train_start_date},
    min_test_start_date = {'value': min_test_start_date},
    max_test_end_date = {'value': max_test_end_date},
    date_range = {
        # 'values': copied_or_truncated_date_ranges
        'values': truncated_date_ranges
    }
)

In [35]:
#@title CONFIG: choose dataset
parameters_dict.update(
    # yearly_dataset_params,
    quarterly_dataset_params
)

In [36]:
#@title CONFIG: number of seeds

NUM_SEEDS = 1

parameters_dict.update({
    'seed': {'values': (np.random.randn(NUM_SEEDS) * 1e8).astype(int).tolist()}
})

In [77]:
#@title CONFIG: env params

# ENV_CLASS = 'np'
ENV_CLASS = 'stoploss'

###########################
assert ENV_CLASS in ['np', 'stoploss']

parameters_dict.update(dict(
    env_class = {'value': ENV_CLASS},
    cost_pct = {'value': 1e-3},
    initial_amount = {'value': 50_000},
    turbulence_threshold = {
        'value': 99,
        # 'values': [30, 40, 50, 60, 70]
    },
    # eval_turbulence_thresh = {'value': 25},
    if_vix = {'value': True}
))

if ENV_CLASS == 'stoploss':
    parameters_dict.update({
        'discrete_actions': {'value': True},
        'patient': {'value': True},
        'cache_indicator_data': {'value': True},
        'print_verbosity': {'value': 0}
    })

In [38]:
#@title CONFIG: models_used

MODELS_USED = [
    'ppo'
]

parameters_dict.update({
    'models_used': {'value': MODELS_USED}
})

# TODO: remove legacy config used for backward compat
parameters_dict.update({
    'if_using_a2c': {'value': 'ppo' in MODELS_USED},
    'if_using_ddpg': {'value': 'ddpg' in MODELS_USED},
    'if_using_ppo': {'value': 'ppo' in MODELS_USED},
    'if_using_td3': {'value': 'td3' in MODELS_USED},
    'if_using_sac': {'value': 'sac' in MODELS_USED}
})

In [76]:
#@title CONFIG: model and training params

training_params = {
    "parameters": {
        "ppo": {
            "parameters": dict(
                steps={"value": 1}, # smoke test
                train_batch_size={"value": 128}, # smoke test

                # steps={"value": 8_192},
                # train_batch_size={"value": 2048},

                minibatch_size={"value": 128},
                num_epochs={"value": 10},
                lr={"value": 5e-5},
                gamma={"value": 0.99},
            )
        }
    }
}

env_runners_params = {
    'parameters': dict(
        num_envs_per_env_runner = {'value': 1},
        num_env_runners = {'value': 0},
    )
}

parameters_dict.update({'env_runners_params': env_runners_params})
parameters_dict.update({'training_params': training_params})

In [40]:
#@title CONFIG: quantile gridsearch

# quantile_gridsearch_params = dict(
#     lookback_window = {'value': 63}, # one quarter
#     is_expanding_insample = {'value': True},
#     is_sliding_window = {'value': True},

#     # q_lower = {'value': 0.1},
#     # q_upper = {'value': 0.9},
#     # q_step = {'value': 0.1},

#     q_lower = {'value': 0.5},
#     q_upper = {'value': 0.5},
#     q_step = {'value': 0.1},
# )

# parameters_dict.update(quantile_gridsearch_params)

# Train & eval funcs

In [41]:
#@title MetricsLoggerCallback (class)
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from typing import Optional, Sequence
import gymnasium as gym
from gymnasium.vector import AsyncVectorEnv

AVAILABLE_ENV_CLASSES = ['np', 'stoploss']

class MetricsLoggerCallback(DefaultCallbacks):
    def __init__(self, model_name, env_class, ema_coeff=0.2, ma_window=20, log_to_wandb=False):
        assert env_class in AVAILABLE_ENV_CLASSES

        super().__init__()

        self.model_name = model_name
        self.env_class = env_class
        self.ema_coeff = ema_coeff
        self.ma_window = ma_window
        self.metric_names = set()
        self.log_to_wandb = log_to_wandb

    def unwrap_env(self, env):
        env = env.unwrapped
        # print(type(env))
        # (SingleAgentEnvRunner pid=127688) <class 'gymnasium.vector.sync_vector_env.SyncVectorEnv'>

        if isinstance(env, AsyncVectorEnv):
            return env
        else:
            env = env.envs[0]
            # print(type(env))
            # (SingleAgentEnvRunner pid=127688) <class 'gymnasium.wrappers.common.OrderEnforcing'>

            env = env.env
            # print(type(env))
            # (SingleAgentEnvRunner pid=127688) <class 'gymnasium.wrappers.common.PassiveEnvChecker'>

            env = env.env
            # print(type(env))
            # (SingleAgentEnvRunner pid=127688) <class 'finrl.meta.env_stock_trading.env_stocktrading.StockTradingEnv'>
        return env

    def on_episode_step(
            self,
            *,
            episode,
            env_runner,
            metrics_logger,
            env,
            env_index,
            rl_module,
            **kwargs,
        ) -> None:

        env = self.unwrap_env(env)

        if isinstance(env, AsyncVectorEnv):
            df_account_value = env.get_attr('asset_memory')
            df_account_value = pd.concat([pd.Series(av) for av in df_account_value], axis=1).mean(axis=1)

            # TODO: save_asset_memory
            raise NotImplementedError
        else:
            # df_account_value = env.asset_memory
            df_account_value = env.save_asset_memory()

            if self.env_class == 'np':
                df_account_value = df_account_value.rename(
                    columns={'account_value': self.model_name.upper()}
                )
            elif self.env_class == 'stoploss':
                df_account_value = df_account_value[['date', 'total_assets']].rename(
                    columns={'total_assets': self.model_name.upper()}
                )

        # print(f"Env {env_index} day: {env.day}")
        metrics = compute_metrics(df_account_value)

        # mode = env.mode
        for metric_name, metric_value in metrics.items():
            # metric_name = f"{mode}/{metric_name}" if mode != "" else metric_name
            episode.add_temporary_timestep_data(metric_name, metric_value)
            self.metric_names.update([metric_name])

    def on_episode_end(
            self,
            *,
            episode,
            env_runner,
            metrics_logger,
            env,
            env_index,
            rl_module,
            **kwargs,
        ) -> None:

        for metric_name in self.metric_names:
            metric_values = episode.get_temporary_timestep_data(metric_name)
            metric_value = np.nanmean(np.array(metric_values))

            metrics_logger.log_value(
                metric_name,
                metric_value,
                reduce='mean',
            )

            metrics_logger.log_value(
                f"{metric_name}_EMA_{self.ema_coeff}",
                metric_value,
                reduce='mean',
                ema_coeff=self.ema_coeff
            )

            metrics_logger.log_value(
                f"{metric_name}_MA_{self.ma_window}",
                metric_value,
                reduce='mean',
                window=self.ma_window
            )

            # mode = 'val' if env_runner.config.in_evaluation else 'train'
            # print(f"Metrics for {mode} episode {episode}:")
            # print({
            #     f"{mode}.{metric_name}/{self.model_name}": metric_value,
            # }) # TODO: log on every episode step

            if self.log_to_wandb:
                mode = 'val' if env_runner.config.in_evaluation else 'train'
                wandb.log({
                    f"{mode}.{metric_name}/{self.model_name}": metric_value,
                }) # TODO: log on every episode step



In [42]:
#@title print_result

RESULT_KEYS_TO_INCLUDE = [
    'sharpe_ratio',
    'ann_return',
    'mdd',

    # 'sharpe_ratio_MA',
    # 'ann_return_MA',
    # 'mdd_MA',

    # 'sharpe_ratio_EMA',
    # 'ann_return_EMA',
    # 'mdd_EMA',
]

def print_result(result):
    print()
    print('-' * 40)

    keys_to_print = [
        key
        for include_key in RESULT_KEYS_TO_INCLUDE
        for key in result['env_runners'].keys()
        if key.startswith(include_key)
    ]

    for key in sorted(keys_to_print):
        print(f"train/{key}: {round(result['env_runners'][key], 2)}")

    if 'evaluation' in result.keys():
        print('*' * 40)
        keys_to_print = [
            key
            for include_key in RESULT_KEYS_TO_INCLUDE
            for key in result['evaluation']['env_runners'].keys()
            if key.startswith(include_key)
        ]

        for key in sorted(keys_to_print):
            print(f"val/{key}: {round(result['evaluation']['env_runners'][key], 2)}")

        print('-' * 40)
        print()

In [43]:
#@title benchmark_exec_time
import pandas as pd
from time import perf_counter
from functools import wraps

def benchmark_exec_time(func):
    @wraps(func)
    def wrapper(*args, **kwargs):

        start = perf_counter()
        output = func(*args, **kwargs)
        end = perf_counter()

        exec_time_sec = end - start

        data = {
            "func_name": func.__name__,
            "exec_time_sec": exec_time_sec,
        }
        # print(f'\nBenchmark results: {data}')
        return output, exec_time_sec

    return wrapper

In [44]:
#@title Train_eval_models

from ray.rllib.algorithms.ppo import PPOConfig

AVAILABLE_MODELS_CONFIGS = {
    'ppo': PPOConfig
}

def train_eval_rllib_models(
        run_config,
        train_np_env_config,
        val_np_env_config,
        test_np_env_config,
        model_list = ['ppo'], # TODO: discard in favor of 'if_using_{model_name}'
        pretrained_models = {} # pretrained on previous train set (not validation set)
    ):

    assert set(model_list).issubset(AVAILABLE_MODELS_CONFIGS)
    check_and_make_directories([TRAINED_MODEL_DIR])

    create_env_fn = CREATE_ENV_FN[run_config['env_class']]
    register_env("stock_trading_env", create_env_fn)

    for model_name in model_list:
        if run_config[f"if_using_{model_name}"]:
            print(f"Training {model_name.upper()} agent")
            model_config = AVAILABLE_MODELS_CONFIGS[model_name]
            algo = train_rllib_model(
                run_config,
                model_name,
                model_config,
                train_np_env_config,
                val_np_env_config,
                pretrained_ckpt_path = pretrained_models.get(model_name, None)
            )

            print(f"Evaluating {model_name.upper()} agent")
            val_result = evaluate_model(
                algo,
                model_name,
                run_config=run_config,
                np_env_config = val_np_env_config,
                mode='val',
            )
            fig = plot_results(**val_result)
            log_plot_as_artifact(fig, "val_cumulative_return", artifact_type="plot")

            test_result = evaluate_model(
                algo,
                model_name,
                run_config=run_config,
                np_env_config = test_np_env_config,
                mode='test',
            )
            fig = plot_results(**test_result)
            log_plot_as_artifact(fig, "test_cumulative_return", artifact_type="plot")
            pretrained_models[model_name] = algo
        else:
            print(f"Skipping {model_name.upper()} agent")

    return pretrained_models

In [45]:
#@title Train_eval_models (w/ threshold gridsearch)

from ray.rllib.algorithms.ppo import PPOConfig

AVAILABLE_MODELS_CONFIGS = {
    'ppo': PPOConfig
}

def train_eval_rllib_models(
        run_config,

        train_np_env_config,
        val_np_env_config,
        test_np_env_config,

        model_list = ['ppo'], # TODO: discard in favor of 'if_using_{model_name}'
        pretrained_val_models = {},
    ):

    assert set(model_list).issubset(AVAILABLE_MODELS_CONFIGS)
    check_and_make_directories([TRAINED_MODEL_DIR])

    create_env_fn = CREATE_ENV_FN[run_config['env_class']]
    register_env("stock_trading_env", create_env_fn)

    for model_name in model_list:
        if run_config[f"if_using_{model_name}"]:
            model_config = AVAILABLE_MODELS_CONFIGS[model_name]

            # Skip training if a model pretrained on previous validation set is passed.
            # The training should happen only for first run,
            # further runs should reuse pretrained val models.
            if model_name in pretrained_val_models:
                algo_train = pretrained_val_models[model_name]
            else:
                algo_train = train_rllib_model(
                    run_config,
                    model_name,
                    model_config,
                    train_np_env_config,
                    val_np_env_config,
                    pretrained_ckpt_path = None # Train from scratch
                )

            wandb.run.summary['val.num_pretrain_iters'] = algo_train.training_iteration

            # Evaluate on validation set
            # using model trained on current train set (previous val set)
            val_best_th = evaluate_threshold_grid(
                algo_train,
                model_name,
                run_config,
                val_np_env_config,
                split_label='val',
            )

            # Finetune on current validation set (previous test set)
            algo_val = train_rllib_model(
                run_config,
                model_name,
                model_config,
                train_np_env_config,
                val_np_env_config,
                pretrained_ckpt_path = algo_train
            )

            # Evaluate on test set
            # using model finetuned on current validation set (previous test set)
            _ = evaluate_threshold_grid(
                algo_val,
                model_name,
                run_config,
                test_np_env_config,
                split_label='test',
                chosen_th=val_best_th # chosen before fine-tuning
            )

            pretrained_val_models[model_name] = algo_val
        else:
            print(f"Skipping {model_name.upper()} agent")

    return pretrained_val_models

In [46]:
#@title Train_eval_models (w/ quantile gridsearch)

from ray.rllib.algorithms.ppo import PPOConfig

AVAILABLE_MODELS_CONFIGS = {
    'ppo': PPOConfig
}

def train_eval_rllib_models(
        run_config,
        train_np_env_config,
        val_np_env_config,
        test_np_env_config,

        data,
        data_processor,
        current_run_id,
        sweep_api,

        model_list = ['ppo'], # TODO: discard in favor of 'if_using_{model_name}'
        pretrained_val_models = {},
    ):

    assert set(model_list).issubset(AVAILABLE_MODELS_CONFIGS)
    check_and_make_directories([TRAINED_MODEL_DIR])

    create_env_fn = CREATE_ENV_FN[run_config['env_class']]
    register_env("stock_trading_env", create_env_fn)

    for model_name in model_list:
        if run_config[f"if_using_{model_name}"]:
            model_config = AVAILABLE_MODELS_CONFIGS[model_name]

            # Skip training if a model pretrained on previous validation set is passed.
            # The training should happen only for first run,
            # further runs should reuse pretrained val models.
            if model_name in pretrained_val_models:
                algo_train = pretrained_val_models[model_name]
            else:
                algo_train = train_rllib_model(
                    run_config,
                    model_name,
                    model_config,
                    train_np_env_config,
                    val_np_env_config,
                    pretrained_ckpt_path = None # Train from scratch
                )

            wandb.run.summary['val.num_pretrain_iters'] = algo_train.training_iteration

            # Evaluate on validation set

            # Finetune on current validation set (previous test set)
            algo_val = train_rllib_model(
                run_config,
                model_name,
                model_config,
                train_np_env_config,
                val_np_env_config,
                pretrained_ckpt_path = algo_train
            )

            val_rl_module = algo_val.env_runner.module

            # eval_quantile_grid(
            #     model_name,
            #     val_rl_module,
            #     data,
            #     data_processor,
            #     current_run_id,
            #     sweep_api,

            #     SILENT = True,

            #     FILTER_HASH_LIST = None,

            #     LOG_PLOTS = 'last_only',

            #     LIMIT_RUNS = None,

            #     #######################

            #     lookback_window = run_config['lookback_window'],
            #     is_expanding_insample = run_config['is_expanding_insample'],
            #     is_sliding_window = run_config['is_sliding_window'],

            #     q_lower = run_config['q_lower'],
            #     q_upper = run_config['q_upper'],
            #     q_step = run_config['q_step'],
            # )

            pretrained_val_models[model_name] = algo_val
        else:
            print(f"Skipping {model_name.upper()} agent")

    return pretrained_val_models

In [47]:
#@title Train_eval_models (StopLossEnv)

from ray.rllib.algorithms.ppo import PPOConfig

AVAILABLE_MODELS_CONFIGS = {
    'ppo': PPOConfig
}

def train_eval_rllib_models(
        run,

        data,
        data_processor,
        current_run_id,
        sweep_api,

        train_np_env_config=None,
        val_np_env_config=None,
        test_np_env_config=None,

        train_stoploss_data_df = None,
        val_stoploss_data_df = None,
        test_stoploss_data_df = None,

        model_list = ['ppo'], # TODO: discard in favor of 'if_using_{model_name}'
        pretrained_ckpt_paths = {},
    ):

    run_config = run.config

    assert set(model_list).issubset(AVAILABLE_MODELS_CONFIGS)
    check_and_make_directories([TRAINED_MODEL_DIR])

    create_env_fn = CREATE_ENV_FN[run_config['env_class']]
    register_env("stock_trading_env", create_env_fn)

    for model_name in model_list:
        if run_config[f"if_using_{model_name}"]:
            model_config = AVAILABLE_MODELS_CONFIGS[model_name]

            # Skip training if a model pretrained on previous validation set is passed.
            # The training should happen only for first run,
            # further runs should reuse pretrained val models.
            if model_name in pretrained_ckpt_paths:
                train_ckpt_path = pretrained_ckpt_paths[model_name]
            else:
                algo_train, train_ckpt_path = train_rllib_model(
                    run_config,
                    model_name,
                    model_config,

                    train_np_env_config = train_np_env_config,
                    val_np_env_config = val_np_env_config,

                    train_stoploss_data_df = train_stoploss_data_df,
                    val_stoploss_data_df = val_stoploss_data_df,

                    split_name = 'train',

                    pretrained_ckpt_path = None # Train from scratch
                )

            eval_run_config = run_config._as_dict().copy() # HACK: ideally contain settings in run_config['env_config']
            eval_run_config.update({'random_start': False})

            # Evaluate on val
            val_result = evaluate_model(
                algo_train,
                model_name,

                run=run,
                run_config=eval_run_config,

                np_env_config = val_np_env_config,
                stoploss_data_df = val_stoploss_data_df,
                split_label='val',

                log_to_wandb=True,
                return_metrics=False,
            )

            # Finetune on current validation set (previous test set)
            algo_val, val_ckpt_path = train_rllib_model(
                run_config,
                model_name,
                model_config,

                train_np_env_config = val_np_env_config,
                val_np_env_config = None,

                train_stoploss_data_df = val_stoploss_data_df,
                val_stoploss_data_df = None,

                split_name = 'val',

                pretrained_ckpt_path = train_ckpt_path
            )

            # Evaluate on test
            val_result = evaluate_model(
                algo_val,
                model_name,

                run=run,
                run_config=eval_run_config,

                np_env_config = test_np_env_config,
                stoploss_data_df = test_stoploss_data_df,
                split_label='test',

                log_to_wandb=True,
                return_metrics=False,
            )


            pretrained_ckpt_paths[model_name] = val_ckpt_path

        else:
            print(f"Skipping {model_name.upper()} agent")

    return pretrained_ckpt_paths

In [48]:
#@title YahooDownloader (compatible with original FinRL)

from __future__ import annotations

import pandas as pd
import yfinance as yf


class YahooDownloader:
    """Provides methods for retrieving daily stock data from
    Yahoo Finance API

    Attributes
    ----------
        start_date : str
            start date of the data (modified from neofinrl_config.py)
        end_date : str
            end date of the data (modified from neofinrl_config.py)
        ticker_list : list
            a list of stock tickers (modified from neofinrl_config.py)

    Methods
    -------
    fetch_data()
        Fetches data from yahoo API

    """

    def __init__(self, start_date: str, end_date: str, ticker_list: list):
        self.start_date = start_date
        self.end_date = end_date
        self.ticker_list = ticker_list

    def fetch_data(self, proxy=None) -> pd.DataFrame:
        """Fetches data from Yahoo API
        Parameters
        ----------

        Returns
        -------
        `pd.DataFrame`
            7 columns: A date, open, high, low, close, volume and tick symbol
            for the specified stock ticker
        """
        # Download and save the data in a pandas DataFrame:
        data_df = pd.DataFrame()
        num_failures = 0
        for tic in self.ticker_list:
            temp_df = yf.download(
                tic, start=self.start_date, end=self.end_date, proxy=proxy
            )
            temp_df["tic"] = tic
            if len(temp_df) > 0:
                # data_df = data_df.append(temp_df)
                data_df = pd.concat([data_df, temp_df], axis=0)
            else:
                num_failures += 1
        if num_failures == len(self.ticker_list):
            raise ValueError("no data is fetched.")
        # reset the index, we want to use numbers as index instead of dates
        data_df = data_df.reset_index()

        try:
            # Convert wide to long format
            # print(f"DATA COLS: {data_df.columns}")
            data_df = data_df.sort_index(axis=1).set_index(['Date']).drop(columns=['tic']).stack(level='Ticker', future_stack=True)
            data_df.reset_index(inplace=True)
            data_df.columns.name = ''

            # convert the column names to standardized names
            data_df.rename(columns={'Ticker': 'Tic', 'Adj Close': 'Adjcp'}, inplace=True)
            data_df.rename(columns={col: col.lower() for col in data_df.columns}, inplace=True)

            columns = [
                "date",
                "tic",
                "open",
                "high",
                "low",
                "close",
                # "adjcp",
                "volume",
            ]

            data_df = data_df[columns]
            if 'adjcp' in data_df.columns:
                # use adjusted close price instead of close price
                data_df["close"] = data_df["adjcp"]
                # drop the adjusted close price column
                data_df = data_df.drop(labels="adjcp", axis=1)

        except NotImplementedError:
            print("the features are not supported currently")

        # create day of the week column (monday = 0)
        data_df["day"] = data_df["date"].dt.dayofweek
        # convert date to standard string format, easy to filter
        data_df["date"] = data_df.date.apply(lambda x: x.strftime("%Y-%m-%d"))
        # drop missing data
        data_df = data_df.dropna()
        data_df = data_df.reset_index(drop=True)
        print("Shape of DataFrame: ", data_df.shape)
        # print("Display DataFrame: ", data_df.head())

        data_df = data_df.sort_values(by=["date", "tic"]).reset_index(drop=True)

        return data_df

    def select_equal_rows_stock(self, df):
        df_check = df.tic.value_counts()
        df_check = pd.DataFrame(df_check).reset_index()
        df_check.columns = ["tic", "counts"]
        mean_df = df_check.counts.mean()
        equal_list = list(df.tic.value_counts() >= mean_df)
        names = df.tic.value_counts().index
        select_stocks_list = list(names[equal_list])
        df = df[df.tic.isin(select_stocks_list)]
        return df

In [49]:
#@title load DIJA

def load_djia(
        start_date,
        end_date,
        initial_amount
    ):

    """
    Load prices of index scaled by initial amount
    """

    # Download DIJA
    df_djia = YahooDownloader(
        start_date = start_date,
        end_date = end_date + pd.Timedelta(days=1), # include last day
        ticker_list=['^DJI'] # `dji` is delisted, `DJIA` is an ETF, not an index
    ).fetch_data()


    # Scale DJIA data
    df_djia = df_djia[['date','close']]
    fst_day = df_djia['close'].iloc[0]
    df_djia = pd.DataFrame({
        'date': df_djia['date'],
        'DJIA': df_djia['close'].div(fst_day).mul(initial_amount).values,
    })

    df_djia['date'] = pd.to_datetime(df_djia['date'])

    return df_djia

In [50]:
#@title eval_agg_dija
import wandb

def eval_agg_dija(first_run_or_sweep, last_run, split_label='test', silent=False):
    """
    Compute DJIA metrics aggregated over period
        from first_run `test_start_date`
        till last_run  `test_end_date`
    """

    if isinstance(first_run_or_sweep, wandb.apis.public.sweeps.Sweep):
        # sweep
        test_start_date = first_run_or_sweep.config['parameters']['min_test_start_date']['value']
        initial_amount = first_run_or_sweep.config['parameters']['initial_amount']['value']
    else:
        # first_run
        test_start_date = first_run_or_sweep.config['date_range']['test_start_date']
        initial_amount = first_run_or_sweep.config['initial_amount']

    df_djia = load_djia(
        start_date=pd.Timestamp(test_start_date),
        end_date=pd.Timestamp(last_run.config['date_range']['test_end_date']),
        initial_amount=initial_amount
    )

    djia_metrics = compute_metrics(df_djia)
    print(f'djia_metrics: {djia_metrics}')

    # assert isinstance(last_run, wandb.sdk.wandb_run.Run), "last_run should be activated"

    log_eval_results(
        'djia',
        djia_metrics,
        split_label=split_label,
        metric_prefix='agg',
        run=last_run
    )

    # wandb.finish()

    return df_djia, djia_metrics

# FILTER_HASH = '1ea39a67636c8e66b0064c4cac304a880b02405eff96aa2bc82017e78cd8461f'

# run_list = config_hash_to_runs[FILTER_HASH]
# last_run = get_last_finished_run(run_list)
# first_run = get_first_finished_run(run_list)
# df_djia, djia_metrics = eval_agg_dija(first_run, last_run, silent=True)
# df_djia.head()

In [51]:
#@title Download artifacts
from pathlib import Path

def load_artifacts(run_id, artifact_types, prepend_run_id=False, skip_cache=False):
    assert isinstance(artifact_types, list)

    # Initialize the W&B API
    api = wandb.Api()

    # Retrieve the run
    run = api.run(f"{ENTITY}/{PROJECT}/{run_id}")

    # Iterate over the artifacts used or logged by the run
    artifacts_cnt = {_type: 0 for _type in artifact_types}
    artifacts_folders_per_run = {_type: [] for _type in artifact_types}
    print(len(run.logged_artifacts()))
    for artifact in run.logged_artifacts():
        # print(f'Considering artifact {artifact.name}')

        if "latest" in artifact.aliases \
        and artifact.type in artifact_types:
            artifacts_cnt[artifact.type] += 1

            artifact_folder = Path(f'{artifact.name}')
            if prepend_run_id:
                artifact_folder = run_id / artifact_folder

            artifacts_folders_per_run[artifact.type].append(artifact_folder)

            if not os.path.exists(artifact_folder) or skip_cache:
                !rm -rf {str(artifact_folder)}

                print(f"Downloading artifacts to {artifact_folder}")
                artifact.download(artifact_folder)
            else:
                print(f"Using cached data at {artifact_folder}")


    print(f"Run id: {run_id}")
    print(f'Artifacts downloaded cnt per type: {artifacts_cnt}')
    print(f"Artifacts folders per type: {artifacts_folders_per_run}")
    print()
    return artifacts_folders_per_run

# artifacts_folders_per_run = load_artifacts(
#     run_id='rbug5oxn',
#     artifact_types=['trained_models'],
#     prepend_run_id=True
# )

# artifacts_folders_per_run

In [52]:
#@title get_env_end_state

def get_env_end_state(
    run,
    model_name,
    split_label,
    **fmt_metric_name_kwargs
):

    end_amount_fmt_name = get_formatted_metric_name(
        model_name = model_name,
        metric_name = 'end_amount',
        split_label = split_label,
        **fmt_metric_name_kwargs
    )

    end_stocks_fmt_name = get_formatted_metric_name(
        model_name = model_name,
        metric_name = 'end_stocks',
        split_label = split_label,
        **fmt_metric_name_kwargs
    )

    end_amount = run.summary.get(end_amount_fmt_name, None)
    end_stocks = run.summary.get(end_stocks_fmt_name, None)

    print(f"Env metrics for run {run.id} with test start date {run.config['date_range']['test_start_date']}")
    print(f"Env {end_amount_fmt_name}: {end_amount}")
    print(f"Env {end_stocks_fmt_name}: {end_stocks}")

    return end_amount, end_stocks

# end_amount, end_stocks = get_env_end_state(
#     wandb.run,
#     model_name='ppo',
#     split_label='val'
# )
# end_amount, end_stocks

In [53]:
#@title (func) evaluate SLIDING WINDOW aggregated metrics for quantile GRID (w/plots & DJIA & caching) # TODO

def eval_quantile_grid(
    model_name,
    rl_module,
    data,
    data_processor,
    current_run_id,
    sweep_api,

    SILENT = True,

    FILTER_HASH_LIST = None,
    # FILTER_HASH_LIST = FILTER_HASH_LIST_DICT[SWEEP_ID],

    # LOG_PLOTS = False,
    # LOG_PLOTS = 'all',
    LOG_PLOTS = 'last_only',

    LIMIT_RUNS = None,
    # LIMIT_RUNS = 1,

    #######################

    lookback_window = 63, # one quarter
    is_expanding_insample = True,
    is_sliding_window = True,

    # q_lower = 0.1,
    # q_upper = 0.9,
    # q_step = 0.2,

    q_lower = 0.5,
    q_upper = 0.5,
    q_step = 0.1,
):
    assert is_expanding_insample, "Only expanding in-sample currently supported for on-line sweep"
    assert isinstance(LOG_PLOTS, bool) or LOG_PLOTS in ['last_only', 'all']

    quantile_log_name = 'q'
    quantile_log_name += 'e' if is_expanding_insample else ''
    quantile_log_name += 's' if is_sliding_window else ''

    quantile_grid = np.arange(q_lower, q_upper + q_step, q_step).round(2)

    current_run_api = WANDB_API.run(f"{PROJECT}/{current_run_id}")
    prev_run_id = find_prev_run(sweep_api.id, current_run_id)
    prev_run_api = WANDB_API.run(f"{PROJECT}/{current_run_id}")

    is_last_run = (current_run_api.config['date_range']['test_end_date'] == sweep_api.config['parameters']['max_test_end_date']['value'])

    if_vix = current_run_api.config.get('if_vix', True)
    if if_vix:
        turbulence_log_name = 'vix'
        turbulence_label = 'vix'
    else:
        turbulence_log_name = 'ti'
        turbulence_label = 'turbulence'

    # print(f"\n\nConfig_hash: {config_hash}")

    turbulence_sma_col = f"{turbulence_label}_{lookback_window}_sma"

    # Init collections for expanding results
    agg_turbulence_thresh = {
        'test': {
            # quantile: pd.DataFrame()
            quantile: []
            for quantile in quantile_grid
        }
    }
    agg_asset_values = {
        'test': {
            # quantile: pd.DataFrame()
            quantile: []
            for quantile in quantile_grid
        }
    }
    agg_turbulence_series = {
        'test': {
            # quantile: pd.Series(name=turbulence_label)
            quantile: []
            for quantile in quantile_grid
        }
    }

    prev_run_api = None

    # load full train period for in-sample turbulence
    current_train_np_env_config = get_env_config(
        data,
        data_processor,
        use_extra_indicators=True if is_sliding_window else False, # TODO: actually dont need for train since for sliding window history only val/test turbulence is used
        start_date = sweep_config['parameters']['train_start_date']['value'],
        end_date = current_run_api.config['date_range']['val_start_date'],
        if_train=True,
    )

    # load val & test periods
    val_np_env_config = get_env_config(
        data,
        data_processor,
        use_extra_indicators=True if is_sliding_window else False,
        start_date=current_run_api.config['date_range']['val_start_date'],
        end_date=current_run_api.config['date_range']['test_start_date'],
        if_train=False
    )

    test_np_env_config = get_env_config(
        data,
        data_processor,
        use_extra_indicators=True if is_sliding_window else False,
        start_date=current_run_api.config['date_range']['test_start_date'],
        end_date=current_run_api.config['date_range']['test_end_date'],
        if_train=False,
    )

    # store figures and names, save later in a single artifact
    fig_list = {
        'val': [],
        'test': []
    }

    fig_names = {
        'val': [],
        'test': []
    }

    agg_metrics_per_quantiles = []
    for quantile in quantile_grid:
        if is_expanding_insample:
            insample_turbulence = current_train_np_env_config['turbulence_array']
        else:
            raise NotImplementedError
            # insample_turbulence = first_train_np_env_config['turbulence_array']

        insample_turbulence_threshold = np.quantile(
            insample_turbulence, quantile
        ).round(2)

        # evaluate model for each split
        for split_label, np_env_config in zip(
            ['val', 'test'],
            [val_np_env_config, test_np_env_config],
        ):

            if split_label not in agg_asset_values.keys():
                # aggregated metrics are computed only for test
                continue

            # compute historical turbulence based on current split data inside lookback window
            if is_sliding_window:
                historical_turbulence_mean = np_env_config[turbulence_sma_col]
            else:
                historical_turbulence_mean = np_env_config['turbulence_array'][-lookback_window:].mean()

            turbulence_thresh = np.where(
                historical_turbulence_mean > insample_turbulence_threshold,
                insample_turbulence_threshold, # high volatility -> limit by in-sample turbulence
                np.quantile(insample_turbulence, 1) # low volatility -> allow all actions
            ).round(2)

            print(
                f"Thresholds for quantile={quantile}: ",
                f"\thistorical_turbulence_mean: {historical_turbulence_mean}",
                f"\tinsample_turbulence_threshold: {insample_turbulence_threshold}",
                f"\tturbulence_threshold: {turbulence_thresh}",
                sep='\n',
                end='\n'
            )

            print("[ASSET VALUES] Running inference to obtain asset values.")

            # perform state transfer if previous run exists
            if split_label == 'test' and prev_run_api is not None:
                assert isinstance(prev_run_api, wandb.apis.public.runs.Run)
                prev_end_amount, prev_end_stocks = get_env_end_state(
                    run=prev_run_api,
                    model_name=model_name,
                    split_label=split_label,
                    quantile_log_name=quantile_log_name,
                    quantile_thresh=quantile,
                )
                print(f"Transfering previous run env state with amount: {prev_end_amount}")
            else:
                print("Creating brand new env.")
                prev_end_amount, prev_end_stocks = None, None

            # evaluate model under given quantile
            result, metrics = evaluate_model(
                turbulence_thresh=turbulence_thresh,
                model_name=model_name,
                algo_or_rl_module=rl_module,
                run_config=current_run_api.config,
                np_env_config=np_env_config,
                split_label=split_label,
                log_to_wandb=False,
                return_metrics=True,
                prev_end_amount=prev_end_amount,
                prev_end_stocks=prev_end_stocks,
            )

            metrics.update({
                f"threshold": turbulence_thresh,
                f"insample_threshold": insample_turbulence_threshold,
                f"historical_mean": historical_turbulence_mean,
                f"lookback_window": lookback_window,
            })

            assert turbulence_log_name == 'vix', "TODO: support other tubulence indices logging for quantiles"
            log_eval_results(
                model_name,
                metrics=metrics,
                split_label=split_label,
                turbulence_log_name=None,
                turbulence_thresh=None,
                run=current_run_api,
                metric_prefix=None,
                quantile_log_name=quantile_log_name,
                quantile_thresh=quantile,
            )

            # asset values for current run
            run_asset_values_per_split = result['account_value']

            print('\ntrain_start_date', current_run_api.config['date_range']['train_start_date'])
            print('val_start_date', current_run_api.config['date_range']['val_start_date'])
            print('test_start_date', current_run_api.config['date_range']['test_start_date'])
            print('test_end_date', current_run_api.config['date_range']['test_end_date'])


            # debug dates
            print("\nagg_asset_values dates:")
            if len(agg_asset_values[split_label][quantile]) > 0:
                display(
                    pd.concat(agg_asset_values[split_label][quantile])['date'].agg(['min', 'max'])
                )
            else:
                print("No dates (empty df)")
            print("\nrun_asset_values_per_split dates:")
            display(run_asset_values_per_split['date'].agg(['min', 'max']))


            ######### EXPAND COLLECTIONS TO STORE RESULTS

            agg_asset_values[split_label][quantile].append(
                run_asset_values_per_split
            )

            assert pd.concat(agg_asset_values[split_label][quantile])['date'].is_monotonic_increasing

            run_turbulence_series_per_split = result['turbulence_series']

            agg_turbulence_series[split_label][quantile].append(
                run_turbulence_series_per_split
            )

            agg_turbulence_thresh[split_label][quantile].append(
                turbulence_thresh
            )

            ############# LOG AGGREGATED METRICS

            agg_metrics_per_split = compute_metrics(pd.concat(agg_asset_values[split_label][quantile]))

            log_eval_results(
                model_name,
                agg_metrics_per_split,
                split_label=split_label,
                turbulence_log_name=None,
                turbulence_thresh=None,
                run=current_run_api,
                metric_prefix='agg',
                quantile_log_name=quantile_log_name,
                quantile_thresh=quantile,
            )

            agg_metrics_per_split.update({
                f"quantile": quantile,
            })
            agg_metrics_per_quantiles.append(agg_metrics_per_split)

            # compute DJIA metrics for split for dates from first run till current run
            full_df_dija, agg_djia_metrics = eval_agg_dija(sweep_api, current_run_api, split_label, silent=SILENT)
            current_df_djia = full_df_dija[
                full_df_dija['date'] >= current_run_api.config['date_range']['test_start_date']
            ]

            # add DJIA metrics to current agg result df
            agg_asset_values[split_label][quantile][-1] = pd.merge(
                agg_asset_values[split_label][quantile][-1],
                current_df_djia,
                on='date'
            )

            if not LOG_PLOTS or LOG_PLOTS == 'last_only' and is_last_run:
                fig = plot_results(
                    account_value=agg_asset_values[split_label][quantile],
                    turbulence_series=agg_turbulence_series[split_label][quantile],
                    turbulence_thresh=agg_turbulence_thresh[split_label][quantile],
                    turbulence_quantile=quantile,
                    figsize='small',
                    split_label='test',
                    metrics=agg_metrics_per_split,
                    index_metrics=agg_djia_metrics
                )

                # Create plot name
                fig_collection_name = f"agg_cum_return-{split_label}-{quantile_log_name}"
                fig_name = f"{fig_collection_name}_{quantile}"
                fig_list[split_label].append(fig)
                fig_names[split_label].append(fig_name)

    current_run_sdk = wandb.init(
        project=PROJECT, id=current_run_id, resume='must',
        settings=wandb.Settings(silent="true" if SILENT else "false")
    )

    agg_metrics_per_quantiles = pd.DataFrame(agg_metrics_per_quantiles)
    table = wandb.Table(dataframe=agg_metrics_per_quantiles)
    wandb.log({f'agg_metrics_per_quantiles-{split_label}-{quantile_log_name}': table})

    if not LOG_PLOTS or LOG_PLOTS == 'last_only' and is_last_run:
        for split_label in ['val', 'test']:
            fig_collection_name = f"agg_cum_return-{split_label}-{quantile_log_name}"
            batch_log_plots_as_artifact(
                fig_list[split_label],
                fig_names[split_label],
                artifact_name_prefix=fig_collection_name,
                run=current_run_sdk,
            )

In [85]:
#@title Train RLlib model

from functools import partial
from pathlib import Path
import torch
from ray.tune.registry import register_env
from time import perf_counter
from math import ceil
import ray

def train_rllib_model(
    run_config,
    algo_name,
    algo_cls_config,
    split_name,

    train_np_env_config=None,
    val_np_env_config=None,

    train_stoploss_data_df = None,
    val_stoploss_data_df = None,

    pretrained_ckpt_path=None,
    skip_training=False
):
    assert split_name in ['train', 'val']

    num_envs_per_env_runner = run_config['env_runners_params']['num_envs_per_env_runner']
    num_env_runners = run_config['env_runners_params']['num_env_runners']

    if run_config['env_class'] == 'np':
        train_env_config = {
            "np_env_config": train_np_env_config,
            "initial_amount": run_config['initial_amount'],
            "cost_pct": run_config['cost_pct'],
            "mode": 'train'
        }

        if val_np_env_config is None:
            val_env_config = None
        else:
            val_env_config = {
                "np_env_config": val_np_env_config,
                "initial_amount": run_config['initial_amount'],
                "cost_pct": run_config['cost_pct'],
                "mode": 'val'
            }
    elif run_config['env_class'] == 'stoploss':
        env_config_kwargs = dict(
            cost_pct = run_config['cost_pct'],
            initial_amount = run_config['initial_amount'],
            discrete_actions = run_config['discrete_actions'],
            cache_indicator_data =  run_config['cache_indicator_data'],
            patient = run_config['patient'],
            print_verbosity = run_config['print_verbosity'],
        )
        train_env_config = {
            **env_config_kwargs,
            "df": train_stoploss_data_df,
        }
        if val_stoploss_data_df is None:
            val_env_config = None
        else:
            val_env_config = {
                **env_config_kwargs,
                "df": val_stoploss_data_df,
            }
    else:
        raise NotImplementedError

    base_config = PPOConfig()

    # if pretrained_ckpt_path is not None:
    if pretrained_ckpt_path is not None:
        def on_algorithm_init(algorithm, **kwargs):
            # module_p0 = algorithm.get_module()
            # weight_before = convert_to_numpy(next(iter(module_p0.parameters())))

            algorithm.restore_from_path(pretrained_ckpt_path)
            algorithm.metrics.reset()
            print(f"Using a pretrained algo with {algorithm.iteration} iterations")

            # # Make sure weights were restored (changed).
            # weight_after = convert_to_numpy(next(iter(module_p0.parameters())))
            # check(weight_before, weight_after, false=True)

        config = (
            base_config
            .callbacks(on_algorithm_init=on_algorithm_init)
        )
    else:
        config = base_config

    training_params = run_config['training_params'][algo_name]
    config = (
        config
        .environment(
            env="stock_trading_env",
            env_config=train_env_config,
        )
        .env_runners(
            batch_mode="complete_episodes",
            num_envs_per_env_runner=num_envs_per_env_runner,
            num_env_runners=num_env_runners,
            num_cpus_per_env_runner= (2/num_env_runners) if num_env_runners > 2 else None,

            # gym_env_vectorize_mode=gym.envs.registration.VectorizeMode.ASYNC,
        )
        .training(
            train_batch_size = training_params['train_batch_size'],
            num_epochs = training_params['num_epochs'],
            minibatch_size = training_params['minibatch_size'],
        )
        .callbacks(partial(
                MetricsLoggerCallback,
                model_name=algo_name,
                env_class=run_config['env_class'],
                log_to_wandb=True
            )
        )
        .resources(
            num_gpus=1 if torch.cuda.is_available() else None
        )
    )

    if val_env_config is not None:
        config = (
            config
            .evaluation(
                # Set up the validation environment
                evaluation_interval=1,  # Specify evaluation frequency (1=after each training step)
                evaluation_config={
                    "explore": False,
                    "env": "stock_trading_env",
                    "env_config": val_env_config,
                },
            )
        )

    algo = config.build_algo()
    wandb.run.summary[f'{split_name}.num_pretrain_iters'] = algo.training_iteration

    if config.num_env_runners > 0:
        ray.shutdown()
        ray.init()
        register_env("stock_trading_env", create_np_stock_trading_env)

    # 3. .. train it ..
    @benchmark_exec_time
    def _train_rllib(total_timesteps):
        print('Started training.')
        results = []
        total_batches = ceil(total_timesteps / algo.config.train_batch_size)
        print(f"total_batches: {total_batches}")
        print(f"total_timesteps: {total_timesteps}")
        for _ in range(total_batches):
            result = algo.train()
            results.append(result)

        print_result(result)
        print('Training complete.')
        return results

    results, exec_time_sec = _train_rllib(run_config['training_params'][algo_name]['steps'])
    print(f"TRAINING DURATION: {exec_time_sec}")

    # Log train duration
    duration_minutes = round(exec_time_sec / 60, 1)
    wandb.run.summary[f"train.duration_minutes/{algo_name}"] = duration_minutes
    wandb.run.summary[f"train.duration_minutes/{algo_name}"] = duration_minutes

    # Save model
    ckpt_path = (Path(TRAINED_MODEL_DIR) / algo_name).resolve()
    algo.save(ckpt_path)
    update_model_artifacts(log_results_folder=False)

    return algo, ckpt_path

In [55]:
#@title RLLib_prediction
from ray.rllib.core.columns import Columns

def RLLib_prediction(
    model_name,
    rl_module,
    eval_env,
    seed=None,
):
    state, info = eval_env.reset(seed=seed)
    done = False

    # i = 0
    # while not done and i < 10:
    #     i += 1

    while not done:

        # Compute action using the RLlib trained agent
        input_dict = {Columns.OBS: torch.Tensor(state).unsqueeze(0)}

        # print("input_dict:")
        # for k, v in input_dict.items():
        #     print(k, v.shape)

        rl_module_out = rl_module.forward_inference(input_dict)
        logits = rl_module_out[Columns.ACTION_DIST_INPUTS]

        # Take mean of multivariate Gaussian distribution
        mean, log_std = logits.chunk(2, dim=-1)

        # action_distribution = TorchDiagGaussian.from_logits(logits)
        # action_distribution = action_distribution.to_deterministic()
        # assert np.allclose(mean, action_distribution.loc)
        # assert np.allclose(log_std.exp(), action_distribution._dist.scale)
        # action = action_distribution.sample()

        action = mean.detach().numpy().squeeze()

        # Clip the action to ensure it's within the action space bounds
        action = np.clip(action, eval_env.action_space.low, eval_env.action_space.high)

        # Perform action
        state, reward, terminated, truncated, _ = eval_env.step(action)

        done = terminated or truncated
        # break

    df_account_values = eval_env.save_asset_memory()
    # display(df_account_values)


    # HACK: hardcoded timezone
    # TODO: transfer timezone handling to FinRL
    # df_account_value['date'] = pd.to_datetime(df_account_value['date'], utc=True).dt.tz_convert(NY)
    return df_account_values

# env_config = {
#     "np_env_config": val_np_env_config,

#     "initial_amount": run_config['initial_amount'],
#     "initial_stocks": None,
#     "cost_pct": run_config['cost_pct'],

#     "mode": 'val',
#     'turbulence_threshold': 99
# }

# eval_env = create_stock_trading_env(env_config)
# rl_module = algo.env_runner.module
# df_account_value = RLLib_prediction(
#     model_name,
#     rl_module,
#     eval_env,
# )

# df_account_value.head()

In [56]:
#@title get_stoploss_env_metrics

def get_stoploss_env_metrics(env, prefix=None):
    state = env.state_memory[-1]

    # Get the initial state (from reset) where:
    # index 0 is cash and indices 1 to 1+N are stock amounts (for N assets)
    cash = state[0]
    stocks = state[1:1+len(env.assets)]

    # Obtain the closing prices for the initial date.
    # (Assuming get_date_vector() returns a vector of indicators when passed cols=["close"])
    initial_closings = np.array(env.get_date_vector(env.date_index, cols=["close"]))

    # Compute asset value: dot product of stock amounts and their closing prices.
    asset_value = np.dot(stocks, initial_closings)

    # Total asset is cash plus the asset value.
    total_asset = cash + asset_value

    # Assemble the metrics dictionary.
    env_metrics = {
        'cash': cash,
        'stocks': stocks,
        'asset_value': asset_value,
        'total_asset': total_asset
    }

    if prefix is not None:
        env_metrics = { f'{prefix}_{name}': value for name, value in env_metrics.items()}

    return env_metrics

In [82]:
#@title evaluate_model
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.utils.numpy import convert_to_numpy, softmax
from ray.rllib.models.torch.torch_distributions import TorchDiagGaussian
from ray.rllib.core.rl_module.rl_module import RLModule
import torch

GLOBAL_SEED = 2025

# Create the testing environment
def evaluate_model(
    algo_or_rl_module,
    model_name,
    split_label,

    run,
    run_config, # could be different from run.config

    np_env_config=None,
    stoploss_data_df=None,

    turbulence_thresh=None,
    log_to_wandb=False,
    return_metrics=False,
    seed=GLOBAL_SEED,

    prev_end_amount=None,
    prev_end_stocks=None,

    # insample_threshold_quantile = None,
    # is_expanding_insample = False
):
    """
        Use np_env_config for numpy env.
        Use stoploss_data_df for stoploss env.
    """

    assert (prev_end_amount is None) == (prev_end_stocks is None), (
        "Either both prev_end_amount and prev_end_stocks must be None, "
        "or both must be non-None."
    )

    if turbulence_thresh is None:
        turbulence_thresh = run_config.get('turbulence_thresh', 99)

    turbulence_name = 'turbulence' if not run_config.get('if_vix', None) else 'vix'
    print(f"\nEvaluating for `{split_label}` using `{turbulence_name}`: {turbulence_thresh}")

    if prev_end_stocks is not None:
        initial_amount = prev_end_amount
    else:
        initial_amount = run_config['initial_amount']
        initial_stocks = None
        print("Init with NEW env state:")

    print(f"\t initial_amount: {initial_amount}")
    print(f"\t initial_stocks: {initial_stocks}")

    if np_env_config is not None:
        assert stoploss_data_df is None, "Use either `np_env_config` for numpy env, or `stoploss_data_df` for stop loss env"
        env_config = {
            "np_env_config": np_env_config,

            "initial_amount": initial_amount,
            "initial_stocks": initial_stocks,
            "cost_pct": run_config['cost_pct'],

            "mode": split_label,
            'turbulence_threshold': turbulence_thresh
        }
        eval_env = create_np_stock_trading_env(env_config)
        _ = eval_env.reset(seed=seed)
        env_metrics = {
            'end_amount': eval_env.amount,
            'end_stocks': eval_env.stocks,
            'end_total_asset': eval_env.total_asset
        }


    elif stoploss_data_df is not None:
        assert np_env_config is None, "Use either `np_env_config` for numpy env, or `stoploss_data_df` for stop loss env"
        env_config = dict(
            cost_pct = run_config['cost_pct'],
            initial_amount = run_config['initial_amount'],
            discrete_actions = run_config['discrete_actions'],
            cache_indicator_data =  run_config['cache_indicator_data'],
            patient = run_config['patient'],
            print_verbosity = run_config['print_verbosity'],
            df = stoploss_data_df,
        )
        print("Initializing StopLossEnv")

        eval_env = create_stoploss_stock_trading_env(env_config)
        _ = eval_env.reset(seed=seed)
        env_metrics = get_stoploss_env_metrics(eval_env, prefix='start')

    assert eval_env is not None

    # Extract rl_module if using Algorithm
    if isinstance(algo_or_rl_module, Algorithm):
        rl_module = algo_or_rl_module.env_runner.module
    elif isinstance(algo_or_rl_module, RLModule):
        rl_module = algo_or_rl_module
    else:
        raise NotImplementedError

    df_account_value = RLLib_prediction(
        model_name,
        rl_module,
        eval_env,
        seed=seed,
    )

    if np_env_config is not None:
        df_account_value = df_account_value.rename(
            columns={'account_value': model_name.upper()}
        )
        end_env_metrics = {
            'start_amount': eval_env.amount,
            'start_stocks': eval_env.stocks,
            'start_total_asset': eval_env.total_asset
        }
    elif stoploss_data_df is not None:
        df_account_value = df_account_value[['date', 'total_assets']].rename(
            columns={'total_assets': model_name.upper()}
        )
        end_env_metrics = get_stoploss_env_metrics(eval_env, prefix='end')

    metrics = compute_metrics(df_account_value)

    env_metrics.update(end_env_metrics)

    # print("\nEnv metrics after evaluation:")
    # for k, v in env_metrics.items():
    #     print(f"{k}: {v}")

    metrics.update(env_metrics)

    if np_env_config is not None:
        turbulence_series = pd.Series(
            np_env_config['turbulence_array'][:len(df_account_value)],
            index=df_account_value['date'],
            name=turbulence_name
        )
    elif stoploss_data_df is not None:
        turbulence_series =\
            stoploss_data_df.set_index('timestamp') \
            .loc[stoploss_data_df['timestamp'].unique()]['vix']
        turbulence_series.index.name = 'date'
    else:
        raise NotImplementedError

    eval_result = {
        'account_value': df_account_value,
        'turbulence_series': turbulence_series,
        'turbulence_thresh': turbulence_thresh
    }

    # optionally log to wandb
    if log_to_wandb:
        # api = wandb.Api()
        # run_api = api.run(f"{PROJECT}/{wandb.run.id}")
        turbulence_log_name = 'ti' if not run_config.get('if_vix', None) else 'vix'
        log_eval_results(
            model_name,
            metrics,
            split_label,
            turbulence_log_name,
            turbulence_thresh,
            run=run,
        )

    # print("Shutting down ray... ", end='')
    # ray.shutdown()
    # print("Done.")

    if return_metrics:
        return eval_result, metrics
    else:
        return eval_result

In [58]:
#@title plot_results (enchanced | list input | list of arrays for thresholds )
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import pandas as pd
import numpy as np
import matplotlib.dates as mdates

def plot_results(
        account_value,
        turbulence_series,
        turbulence_thresh,
        turbulence_quantile=None,
        figsize='small',
        split_label=None,
        metrics=None,
        index_metrics=None,
        index_name='DJIA',
        metrics_highlight_model_name='PPO',
        allowed_metrics=['sharpe_ratio', 'mdd', 'cum_return', 'ann_return'],
        ylim_padding=6_000,
        ylim_bottom=None,
        ylim_top=None
    ):
    """
    This function now supports two input modes:
      1. Single DataFrame/Series/scalar for the three main inputs
      2. Lists of DataFrames/Series/scalars for multiple periods
    """

    # --- 1) Check inputs, possibly concatenate if they are lists ---
    is_list_input = isinstance(account_value, list)
    if is_list_input:
        # If one is a list, all must be lists
        assert isinstance(turbulence_series, list), "If account_value is a list, turbulence_series must be a list."
        assert isinstance(turbulence_thresh, list), "If account_value is a list, turbulence_thresh must be a list."
        assert len(account_value) == len(turbulence_series) == len(turbulence_thresh), \
            "All three lists must have the same length."
        assert turbulence_series[0].name in ['turbulence', 'vix']
        assert metrics_highlight_model_name in account_value[0].columns
        assert turbulence_quantile is not None
        if index_metrics is not None:
            assert index_name in account_value[0].columns

        # Concatenate across periods
        df_concat = []
        ts_concat = []
        thresh_concat = []
        for df_period, ts_period, thr_period in zip(account_value, turbulence_series, turbulence_thresh):
            # Ensure date is index for each period's DataFrame
            if 'date' in df_period.columns:
                df_period = df_period.set_index('date')
            df_period.rename(columns={col: col.upper() for col in df_period.columns}, inplace=True)
            df_concat.append(df_period)

            # Make sure index matches or is alignable (assuming same index)
            ts_concat.append(ts_period)

            # Repeat threshold value for length of this period
            # and store it as a Series to match the index
            if isinstance(thr_period, (int, float)):
                # Scalar threshold: repeat for all timestamps
                thr_series = pd.Series([thr_period]*len(df_period), index=df_period.index)
            elif isinstance(thr_period, (list, np.ndarray)) and len(thr_period) == len(df_period):
                # Array threshold: create a Series from the array
                thr_series = pd.Series(thr_period, index=df_period.index)
            else:
                raise ValueError(
                    "Each turbulence_thresh entry must be either a scalar or "
                    "an array matching the period length."
                )

            thresh_concat.append(thr_series)

        # # Final single DataFrame, Series, Series for plotting
        # account_value = pd.concat(df_concat)
        # turbulence_series = pd.concat(ts_concat)
        # turbulence_thresh = pd.concat(thresh_concat)

        # Final single DataFrame, Series, Series for plotting
        account_value = pd.concat(df_concat).sort_index()
        turbulence_series = pd.concat(ts_concat).sort_index()
        turbulence_thresh = pd.concat(thresh_concat).sort_index()


    else:
        # Single input variant
        assert not isinstance(turbulence_series, list), "turbulence_series must not be a list if account_value is not a list."
        assert not isinstance(turbulence_thresh, list), "turbulence_thresh must not be a list if account_value is not a list."
        assert turbulence_series.name in ['turbulence', 'vix']
        assert metrics_highlight_model_name in account_value.columns
        if index_metrics is not None:
            assert index_name in account_value.columns

        # Ensure date is index
        if 'date' in account_value.columns:
            account_value.set_index('date', inplace=True)
        account_value.rename(columns={col: col.upper() for col in account_value.columns}, inplace=True)

    # --- 2) Original asserts (unchanged) ---
    assert split_label in ['val', 'test']
    assert figsize in ['small', 'medium']

    # --- 3) Prepare figure ---
    figsizes = {
        'medium': (14, 10),
        'small': (8.3, 8)
    }
    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=figsizes[figsize], sharex=True, gridspec_kw={'height_ratios': [3, 1]}
    )

    # --- 4) Method styles ---
    method_styles = {
        'A2C': {'color': '#8c564b', 'linestyle': '--'},
        'DDPG': {'color': '#e377c2', 'linestyle': '-'},
        'PPO': {'color': 'blue', 'linestyle': '-'},
        'TD3': {'color': '#bcbd22', 'linestyle': '--'},
        'SAC': {'color': '#17becf', 'linestyle': '-'},
        'DJIA': {'color': '#000000', 'linestyle': '-'},
    }

    # --- 5) Plot account values ---
    for model_name in account_value.columns:
        style = method_styles.get(model_name, {'color': 'blue', 'linestyle': '-'})  # fallback style
        ax1.plot(account_value.index, account_value[model_name], label=model_name, **style)

    # --- 6) Title and subtitle logic ---
    turbulence_label = "Turbulence Index" if turbulence_series.name == 'turbulence' else "VIX Coefficient"
    split_label_name = ('validation' if split_label == 'val' else split_label).capitalize()

    # If threshold is still a single number, round it; if it's a Series, round all values
    if isinstance(turbulence_thresh, (int, float)):
        # single numeric
        rounded_thresh_text = str(round(turbulence_thresh))
        title = f"{split_label_name} split | {turbulence_label} threshold: {rounded_thresh_text}"
    else:
        # it's a Series, just mention "multiple" or no direct numeric
        turbulence_thresh = turbulence_thresh.round()
        title = f"{split_label_name} split | {turbulence_label} insample quantile: {turbulence_quantile}"

    fig.suptitle(title, fontsize=20, fontweight='bold')

    full_names = {
        'mdd': 'MDD',
        'ann_return': 'Annualized Return',
        'cum_return': 'Cumulative Return',
        'sharpe_ratio': 'Sharpe Ratio'
    }
    # Build subtitle lines
    subtitle_lines = []

    if metrics:
        metric_text = f"{metrics_highlight_model_name} | "
        metric_text += ", ".join(
            f"{full_names.get(name, name.replace('_', ' ').capitalize())}: {value:.2f}"
            for name, value in metrics.items()
            if name in allowed_metrics
        )
        subtitle_lines.append(metric_text)

    # NEW: optionally handle index_metrics in a similar way
    if index_metrics:
        index_metric_text = f"{index_name} | "
        index_metric_text += ", ".join(
            f"{full_names.get(name, name.replace('_', ' ').capitalize())}: {value:.2f}"
            for name, value in index_metrics.items()
            if name in allowed_metrics
        )
        subtitle_lines.append(index_metric_text)

    # If we have any lines to show, join them with newline(s) and set as subtitle
    if subtitle_lines:
        ax1.set_title("\n".join(subtitle_lines), fontsize=12, color='gray', ha='center')

    # --- 7) Y-axis (main) ---
    ax1.yaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: f"{x:,.0f}"))

    # Horizontal line at initial asset value
    initial_asset_value = account_value.iloc[0].mean()
    min_account_value = account_value.min().min()
    max_account_value = account_value.max().max()
    ax1.axhline(y=initial_asset_value, color='gray', linestyle='-.', linewidth=1.5, label="Initial Asset Value")

    ax1.set_ylabel("Total Asset Value ($)", fontsize=16, fontweight='bold')

    if ylim_bottom is None:
        ylim_bottom = min_account_value - ylim_padding
    if ylim_top is None:
        ylim_top = max_account_value + ylim_padding
    ax1.set_ylim(bottom=ylim_bottom, top=ylim_top)

    ax1.legend(loc='lower right')
    ax1.grid(True, linestyle='--', alpha=0.3)

    # --- 8) Second y-axis ticks ---
    ax1_right = ax1.twinx()
    ax1_right.set_ylim(ax1.get_ylim())
    left_ticks = ax1.get_yticks().tolist()
    ymin, ymax = ax1.get_ylim()

    start_value = account_value.iloc[0].mean()
    end_value = account_value.iloc[-1].mean()

    extra_ticks = []
    if ymin <= start_value <= ymax:
        extra_ticks.append(start_value)
    if ymin <= end_value <= ymax:
        extra_ticks.append(end_value)

    all_ticks = sorted(set(left_ticks + extra_ticks))
    labels = []
    for tick in all_ticks:
        if np.isclose(tick, start_value):
            labels.append(f"{tick:,.0f}")
        elif np.isclose(tick, end_value):
            labels.append(f"{tick:,.0f}")
        else:
            labels.append("")
    ax1_right.set_yticks(all_ticks)
    ax1_right.set_yticklabels(labels)
    ax1_right.set_ylabel('')
    ax1_right.grid(False)

    # --- 9) Turbulence plot ---
    ax2.plot(turbulence_series.index, turbulence_series, label=turbulence_label,
             color='red', linestyle='--', linewidth=2)

    # Plot threshold line or series
    if isinstance(turbulence_thresh, (int, float)):
        # single horizontal line
        ax2.axhline(y=turbulence_thresh, color='red', linestyle=':', label=f'Threshold = {round(turbulence_thresh)}')
        max_turbulence = max(turbulence_series.max(), turbulence_thresh)
    else:
        # threshold is a Series
        ax2.plot(turbulence_thresh.index, turbulence_thresh, color='red', linestyle=':',
                 label='Threshold')
        max_turbulence = max(turbulence_series.max(), turbulence_thresh.max())

    ax2.set_ylabel(turbulence_label, fontsize=16, fontweight='bold')
    ax2.legend(loc='upper left')
    ax2.grid(True, linestyle='--', alpha=0.3)
    ax2.set_ylim(0, max_turbulence + 10)
    ax2.set_xlabel("Date", fontsize=16, fontweight='bold')

    for ax in [ax1, ax2]:
        # Major ticks (quarter boundaries)
        ax.xaxis.set_major_locator(mdates.MonthLocator(bymonth=[1,4,7,10]))
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))

        # Minor ticks (monthly)
        ax.xaxis.set_minor_locator(mdates.MonthLocator(interval=1))

        # Draw major grid lines (more pronounced)
        ax.grid(which='major', axis='x', linestyle='-', linewidth=1.2, color='black', alpha=0.5)

        # Draw minor grid lines (finer)
        ax.grid(which='minor', axis='x', linestyle='--', linewidth=0.5, color='gray', alpha=0.3)

    # # ------------------
    # # On ax2, enable labeling of minor ticks
    # ax2.xaxis.set_minor_formatter(mdates.DateFormatter('%Y-%m'))  # or '%m/%Y' etc.
    # ax2.tick_params(axis='x', which='minor', labelsize=8, labelrotation=45)

    ax2.tick_params(axis='x', which='major', labelsize=8, labelrotation=45)

    return fig

In [59]:
#@title batch_log_plots_as_artifact

import gc
import os
import shutil
import wandb
import matplotlib.pyplot as plt

def batch_log_plots_as_artifact(
        fig_list, fig_names, artifact_name_prefix, artifact_type="plot",
        run=None
    ):

    """
    Save a list of Matplotlib figures to a folder, log the folder as a W&B artifact,
    and delete the folder after logging.

    Parameters:
        fig_list (list): List of Matplotlib figure objects.
        folder_name (str): Name of the folder to store plots.
        artifact_name_prefix (str): Prefix for the artifact name.
        artifact_type (str): The type of the artifact (default is "plot").
    """
    if run is None:
        assert wandb.run is not None, "If run is not provided, it should be active in the background"

    assert isinstance(run, wandb.sdk.wandb_run.Run)

    # Ensure the folder exists
    os.makedirs(artifact_name_prefix, exist_ok=True)

    try:
        # Save all figures in the folder
        for i, (fig, fig_name) in enumerate(zip(fig_list, fig_names)):
            filename = os.path.join(artifact_name_prefix, f"{fig_name}.png")
            fig.savefig(filename, bbox_inches='tight', pad_inches=0.1, dpi=300)
            plt.close(fig)  # Close the figure to free up memory
        fig_list.clear()  # remove references
        gc.collect()

        # Create and log the W&B artifact
        artifact_name = f"{artifact_name_prefix}-{run.id}"
        artifact = wandb.Artifact(artifact_name, type=artifact_type)
        artifact.add_dir(artifact_name_prefix, skip_cache=True)
        run.log_artifact(artifact)
    finally:
        # Ensure the folder is deleted after use
        if os.path.exists(artifact_name_prefix):
            shutil.rmtree(artifact_name_prefix)

        for fig in fig_list:
            ax = fig.gca()  # Get the current axis of the figure

            for txt in ax.texts:  # Remove all text objects
                txt.remove()

            for line in ax.lines:
                line.remove()  # Remove previous plot lines

            fig.clf()   # Clear the figure
            plt.close(fig)  # Close the figure to free memory


In [60]:
#@title log_plot_as_artifact

import os
import matplotlib.pyplot as plt
import wandb

def log_plot_as_artifact(fig, artifact_name_prefix, artifact_type="plot"):
    """
    Save a Matplotlib figure without clipping and log it as a W&B artifact.

    Parameters:
        fig (matplotlib.figure.Figure): The Matplotlib figure to save and log.
        artifact_name (str): The name of the W&B artifact.
        artifact_type (str): The type of the artifact (default is "plot").
        filename (str): The filename to save the plot as (default is "plot.png").
    """
    assert wandb.run.id

    # Get full artifact name
    artifact_name = f'{artifact_name_prefix}-{wandb.run.id}'
    filename = artifact_name + '.png'

    try:
        # Save the figure with tight layout and proper padding
        fig.savefig(filename, bbox_inches='tight', pad_inches=0.1, dpi=300)
        plt.close(fig)  # Close the figure to free up memory

        # Create and log the W&B artifact
        artifact = wandb.Artifact(artifact_name, type=artifact_type)
        artifact.add_file(filename, skip_cache=True, overwrite=True)
        wandb.log_artifact(artifact)
    finally:
        # Ensure the file is deleted after use
        if os.path.exists(filename):
            os.remove(filename)

In [61]:
#@title log_eval_results

def log_eval_results(
        model_name,
        metrics,
        split_label,
        turbulence_log_name=None,
        turbulence_thresh=None,
        run=None,
        turbulence_thresh_postfix=None,
        metric_prefix=None,
        quantile_log_name=None,
        quantile_thresh = None
    ):

    # if run is None:
    #     run = wandb.run
    #     assert run is not None, "If no run is provided, wandb should contain an active run"

    for metric_name, metric_value in metrics.items():
        formatted_name = get_formatted_metric_name(
            model_name,
            metric_name,
            split_label,
            turbulence_log_name,
            turbulence_thresh,
            metric_prefix=metric_prefix,
            quantile_log_name=quantile_log_name,
            quantile_thresh = quantile_thresh
        )

        run.summary[formatted_name] = np.round(metric_value, decimals=2).tolist()  # Use formatted_name instead


    if isinstance(run, wandb.apis.public.runs.Run):
        run.summary.update() # for API only


# log_eval_results(
#     model_name,
#     metrics={'sharpe_ratio': 1.5},
#     split_label='val',
#     # turbulence_log_name='vix',
#     # turbulence_thresh=99,
#     run=None,
#     metric_prefix='agg',
#     quantile_log_name='q',
#     quantile_thresh = 0.25
# )

In [62]:
#@title get_formatted_metric_name

def get_formatted_metric_name(
    model_name,
    metric_name,
    split_label,
    turbulence_log_name=None,
    turbulence_thresh=None,
    turbulence_thresh_postfix=None,
    metric_prefix=None,
    quantile_log_name=None,
    quantile_thresh = None
):
    assert turbulence_thresh_postfix in ['best', 'chosen', None]
    # assert quantile_log_name in ['q', 'qe', None]

    if turbulence_log_name is not None:
        assert turbulence_thresh is not None

    if quantile_log_name is not None:
        assert turbulence_thresh is None and turbulence_log_name is None
        assert quantile_thresh is not None and 0 <= quantile_thresh <= 1

    formatted_name = (
        f"{split_label}"

        # Quantile name + thresh
        f"{'.' + quantile_log_name if quantile_log_name else ''}"
        f"{'_' + str(round(quantile_thresh * 100)) if quantile_thresh else ''}"

        # Turbulence name + thresh
        f"{'.' + turbulence_log_name if turbulence_log_name else ''}"
        f"{'_' + str(turbulence_thresh) if turbulence_thresh else ''}"
        f"{'_' + turbulence_thresh_postfix if turbulence_thresh_postfix else ''}"

        f".{metric_prefix + '_' if metric_prefix else ''}{metric_name}/"
        f"{model_name}"
    )
    return formatted_name


# formatted_name = get_formatted_metric_name(
#     model_name,
#     metric_name='sharpe_ratio',
#     split_label='val',
#     # turbulence_log_name='vix',
#     # turbulence_thresh=99,
#     metric_prefix='agg',
#     quantile_log_name='q',
#     quantile_thresh = 0.25
# )

# formatted_name

In [63]:
#@title evaluate_threshold_grid

import numpy as np
import pandas as pd
import wandb

def evaluate_threshold_grid(
    algo_or_rl_module,
    model_name,
    run_config,
    np_env_config,
    num_grid_points=10,
    split_label='val',
    chosen_th=None,
    plot_padding=500
):

    if run_config.get('if_vix', True):
        turbulence_log_name = 'vix'
    else:
        turbulence_log_name = 'ti'
        raise NotImplementedError

    # Calculate threshold grid
    turb_ary = np_env_config['turbulence_array']
    threshold_grid = np.linspace(turb_ary.min(), turb_ary.max(), num_grid_points)
    threshold_grid = np.ceil(threshold_grid).astype(int).tolist()
    threshold_grid.append(max(threshold_grid) + 1)
    if chosen_th is not None:
        threshold_grid.append(chosen_th)

    th_metrics = []  # List to store metrics for all thresholds
    th_results = {}  # Dictionary to store results per threshold

    for th in threshold_grid:
        result, metrics = evaluate_model(
            turbulence_thresh=th,
            model_name=model_name,
            algo_or_rl_module=algo_or_rl_module,
            run_config=run_config,
            np_env_config=np_env_config,
            split_label=split_label,
            log_to_wandb=True,
            return_metrics=True
        )

        metrics = {turbulence_log_name: th, **metrics}
        th_metrics.append(metrics)
        th_results[th] = result

    # Convert metrics to a DataFrame and log to WandB
    df = pd.DataFrame(th_metrics).astype({turbulence_log_name: int})
    metric_cols = df.drop(columns=[turbulence_log_name]).columns # take only true metric columns

    # Identify the protected row where turbulence_log_name == chosen_th
    protected_row = df[df[turbulence_log_name] == chosen_th].tail(1)  # Keep only the last occurrence

    # Drop duplicate rows (except the protected one)
    df = df.drop_duplicates(metric_cols, keep="last")

    # Drop rows where all metric columns are 0, excluding the protected row
    df = df[~((df[metric_cols] == 0).all(axis=1) & ~df.index.isin(protected_row.index))]


    wandb_table = wandb.Table(dataframe=df)
    wandb.log({f"threshold_grid_metrics-{split_label}": wandb_table})

    # Log metrics for best threshold
    best_idx = df['sharpe_ratio'].idxmax()  # Use idxmax() instead of argmax()
    best_metrics = df.loc[best_idx].to_dict()
    best_metrics[turbulence_log_name] = int(best_metrics[turbulence_log_name])
    log_eval_results(model_name, best_metrics, split_label, turbulence_log_name, 'best')

    # Log metrics for chosen threshold (if any)
    if chosen_th is not None:
        chosen_metrics = df[df[turbulence_log_name] == chosen_th].iloc[0].to_dict()
        log_eval_results(model_name, chosen_metrics, split_label, turbulence_log_name, 'chosen')

    # Compute min_account_value and max_account_value
    min_account_value = float('+inf')
    max_account_value = float('-inf')
    for th_value, result in th_results.items():
        min_account_value = min(min_account_value, result['account_value'][model_name].min())
        max_account_value = max(max_account_value, result['account_value'][model_name].max())

    min_account_value -= plot_padding
    max_account_value += plot_padding

    # Plot returns
    figs = []
    fig_names = []
    fig_collection_name = f"cum_return-{split_label}-{turbulence_log_name}"
    best_th_value = best_metrics[turbulence_log_name]

    for th_value in df[turbulence_log_name].values:
        result = th_results[th_value]
        metrics = df.drop(columns=[turbulence_log_name])[
            df[turbulence_log_name] == th_value].iloc[0].to_dict()

        fig = plot_results(
            **result,
            figsize='small',
            split_label=split_label,
            metrics=metrics,
            ylim_bottom=min_account_value,
            ylim_top=max_account_value
        )
        fig_name = f"{fig_collection_name}_{th_value}"
        if th_value == best_th_value:
            fig_name += "_best"
        if chosen_th is not None and th_value == chosen_th:
            fig_name += "_chosen"

        figs.append(fig)
        fig_names.append(fig_name)

    batch_log_plots_as_artifact(
        figs,
        fig_names,
        artifact_name_prefix=fig_collection_name
    )

    best_th = best_metrics[turbulence_log_name]
    return best_th

# Sweep Runner

In [78]:
#@title load_model

import ray
from ray.rllib.algorithms.ppo import PPO

AVAILABLE_MODELS_CLASSES = {
    'ppo': PPO
}

def load_model(model_name, trained_model_dir = TRAINED_MODEL_DIR):
    model_class = AVAILABLE_MODELS_CLASSES[model_name]
    checkpoint_path = os.path.abspath(f"{trained_model_dir}/{model_name}")
    algo = model_class.from_checkpoint(checkpoint_path)
    return algo

In [79]:
#@title Sweep Runner
from copy import deepcopy
from datetime import datetime
import random
import string
import json


def dict_to_canonical_string(cfg_dict):
    return json.dumps(cfg_dict, sort_keys=True)

def set_run_name(prefix, n=5):
    run_name = f"{prefix} | {wandb.run.id}"
    wandb.run.name = run_name
    wandb.run.save()

class SweepRunner:
    def __init__(self, sweep_id):
        self.sweep_id = sweep_id
        self.pretrained_val_models = {}

    def main(self, run_config=None):
        run_timer_start = perf_counter()

        current_time = datetime.now()
        print("START TIME:", current_time)

        with wandb.init(config=run_config):
            run_config = wandb.config
            # print('Debug copy run_config...', end=' ')
            # _ = deepcopy(run_config)
            # print("Done.")

            if run_config['dataset_type'] == 'yearly_train_test':
                raise NotImplementedError
            elif run_config['dataset_type'] == 'quarterly_train_val_test':
                (
                    train_np_env_config,
                    val_np_env_config,
                    test_np_env_config
                ) = build_quarterly_train_val_test(run_config)
                set_run_name(run_config['dataset_name'])

                # Extract pretrained model from previous run
                # for a given unique hyperparameter combination (e.g. (seed, steps) pair).
                # Ignore `date_range` parameter since pretrained models should persist
                # across the whole date range
                config_seed = run_config['seed']
                config_training_params = run_config['training_params']
                # config_hash = hash(str(config_seed) + str(config_training_params))
                config_hash = hash(dict_to_canonical_string({
                    'seed': config_seed,
                    'training_params': config_training_params
                }))

                # run_config_copy = run_config._as_dict().copy()
                # run_config_copy.pop('date_range')
                # config_hash = hash(str(run_config_copy))


                # extract pretrained model from previous run for a given seed
                pretrained_val_models = self.pretrained_val_models.get(config_hash, {})
                wandb.run.summary['config_hash'] = config_hash

                if pretrained_val_models:
                    run_config.update({'finetune': True})
                else:
                    run_config.update({'finetune': False})

                pretrained_val_models = train_eval_rllib_models(
                    run_config,
                    train_np_env_config,
                    val_np_env_config,
                    test_np_env_config,
                    pretrained_val_models=pretrained_val_models
                )

                self.pretrained_val_models[config_hash] = pretrained_val_models

            ray.shutdown()

            run_timer_end = perf_counter()
            run_duration_minutes = round( (run_timer_end - run_timer_start) / 60, 1)
            wandb.run.summary[f"run.duration_minutes"] = run_duration_minutes
            print(f"RUN DURATION: {run_duration_minutes}")

        current_time = datetime.now()
        print("END TIME:", current_time)

In [80]:
#@title Sweep Runner (load model weights + continue = async agent)
from copy import deepcopy
from datetime import datetime
from hashlib import sha256
import random
import string

def set_run_name(prefix, n=5):
    run_name = f"{prefix} | {wandb.run.id}"
    wandb.run.name = run_name
    wandb.run.save()

TECH_INDICATOR_MAX_SLIDING_WINDOW = 60

class SweepRunner:
    def __init__(self, sweep_id):
        self.sweep_id = sweep_id
        self.pretrained_ckpt_paths = {}

        sweep_api = WANDB_API.sweep(f"{ENTITY}/{PROJECT}/{sweep_id}")
        self.sweep_api = sweep_api

        # if_vix = sweep_api.config['parameters']['if_vix']['value'] # assume all runs with same config hash use same volatility index
        # if if_vix:
        #     turbulence_label = 'vix'
        # else:
        #     turbulence_label = 'turbulence'
        # lookback_window = sweep_api.config['parameters']['lookback_window']['value']
        # turbulence_sma_col = f"{turbulence_label}_{lookback_window}_sma"

        self.data, self.data_processor = load_cached_data(
            start_date = sweep_api.config['parameters']['train_start_date']['value'],
            end_date = sweep_api.config['parameters']['max_test_end_date']['value'],
            tech_indicator_padding = TECH_INDICATOR_MAX_SLIDING_WINDOW,
            if_vix = sweep_api.config['parameters']['if_vix']['value'], # should be fixed for a particular hash during training
            technical_indicator_list = INDICATORS,
            # extra_indicator_list = [turbulence_sma_col]
        )

    def main(self, run_config=None):
        run_timer_start = perf_counter()

        current_time = datetime.now()
        print("START TIME:", current_time)

        with wandb.init(config=run_config) as run:
            run_config = wandb.config
            # print('Debug copy run_config...', end=' ')
            # _ = deepcopy(run_config)
            # print("Done.")

            if run_config['dataset_type'] == 'yearly_train_test':
                raise NotImplementedError
            elif run_config['dataset_type'] == 'quarterly_train_val_test':
                date_range = {
                    key: pd.Timestamp(date)
                    for key, date in run_config['date_range'].items()
                }

                dataset_name = get_quarterly_dataset_name(
                    run_config['stock_index_name'],
                    date_range['train_start_date'],
                    date_range['val_start_date'],
                    date_range['test_start_date'],
                )
                run_config.update({"dataset_name": dataset_name})

                # get data splits based on env class
                if run_config['env_class'] == 'np':
                    env_config_kwargs = build_quarterly_train_val_test(
                        self.data,
                        self.data_processor,
                        date_range,
                        run_config,
                    )
                else:
                    data_splits = split_data(self.data, date_range, date_col_name='timestamp')
                    env_config_kwargs = dict(
                        train_stoploss_data_df = data_splits['train'],
                        val_stoploss_data_df = data_splits['val'],
                        test_stoploss_data_df = data_splits['test'],
                    )

                set_run_name(run_config['dataset_name'])

                # Extract pretrained model from previous run
                # for a given unique hyperparameter combination (e.g. (seed, steps) pair).
                # Ignore `date_range` parameter since pretrained models should persist
                # across the whole date range

                config_hash = get_config_hash(run_config)

                # extract pretrained model from previous run for a given seed
                pretrained_ckpt_paths = self.pretrained_ckpt_paths.get(config_hash, {})
                wandb.run.summary['config_hash'] = config_hash

                if not pretrained_ckpt_paths:
                    print("Looking for models from previous run for this date range...")
                    prev_run = find_prev_run(self.sweep_id, wandb.run.id)
                    if prev_run is not None:
                        download_artifacts(prev_run.id, artifact_types=['trained_models'])

                        create_env_fn = CREATE_ENV_FN[run_config['env_class']]
                        register_env("stock_trading_env", create_env_fn)

                        pretrained_ckpt_paths = {
                            model_name: load_model(model_name)
                            for model_name in run_config['models_used']
                        }
                        print("Models downloaded.")
                    else:
                        print("No models found.")

                if pretrained_ckpt_paths:
                    run_config.update({'finetune': True})
                else:
                    run_config.update({'finetune': False})

                pretrained_ckpt_paths = train_eval_rllib_models(
                    run,

                    self.data,
                    self.data_processor,
                    run.id,
                    self.sweep_api,
                    pretrained_ckpt_paths=pretrained_ckpt_paths,

                    **env_config_kwargs,
                )

                self.pretrained_ckpt_paths[config_hash] = pretrained_ckpt_paths

            ray.shutdown()

            run_timer_end = perf_counter()
            run_duration_minutes = round( (run_timer_end - run_timer_start) / 60, 1)
            wandb.run.summary[f"run.duration_minutes"] = run_duration_minutes
            print(f"RUN DURATION: {run_duration_minutes}")

        current_time = datetime.now()
        print("END TIME:", current_time)

# Run sweep

In [86]:
#@title RUN SWEEP

def run_sweep(n_runs, sweep_config=None, sweep_id=None):
    wandb.finish()
    if sweep_id is None:
        assert sweep_config is not None
        sweep_id = wandb.sweep(sweep_config, project=PROJECT)
    else:
        assert sweep_config is None

    !rm -rf {TRAINED_MODEL_DIR}/*

    sweep_runner = SweepRunner(sweep_id)
    wandb.agent(sweep_id, sweep_runner.main, project=PROJECT, count=n_runs)

run_sweep(
    # sweep_id='yj19ygra',
    sweep_config=sweep_config,

    # n_runs=None,
    n_runs=1,
)

Create sweep with ID: od06puvz
Sweep URL: https://wandb.ai/overfit1010/finrl-dt-replicate/sweeps/od06puvz
Using cached data: cache/2014-10-07_2016-10-01_1d_c93df9314227db48fcf165579b5d9a62f68742c52cfedc23d2756e007b0fdd77.csv


[34m[1mwandb[0m: Agent Starting Run: ru9afj14 with config:
[34m[1mwandb[0m: 	cache_indicator_data: True
[34m[1mwandb[0m: 	cost_pct: 0.001
[34m[1mwandb[0m: 	dataset_type: quarterly_train_val_test
[34m[1mwandb[0m: 	date_range: {'test_end_date': '2016-04-01 00:00:00', 'test_start_date': '2016-01-01 00:00:00', 'train_start_date': '2015-01-01 00:00:00', 'val_start_date': '2015-10-01 00:00:00'}
[34m[1mwandb[0m: 	discrete_actions: True
[34m[1mwandb[0m: 	env_class: stoploss
[34m[1mwandb[0m: 	env_runners_params: {'num_env_runners': 0, 'num_envs_per_env_runner': 1}
[34m[1mwandb[0m: 	eval_turbulence_thresh: 25
[34m[1mwandb[0m: 	if_using_a2c: True
[34m[1mwandb[0m: 	if_using_ddpg: False
[34m[1mwandb[0m: 	if_using_ppo: True
[34m[1mwandb[0m: 	if_using_sac: False
[34m[1mwandb[0m: 	if_using_td3: False
[34m[1mwandb[0m: 	if_vix: True
[34m[1mwandb[0m: 	initial_amount: 50000
[34m[1mwandb[0m: 	max_test_end_date: 2016-10-01
[34m[1mwandb[0m: 	min_test_star

START TIME: 2025-02-28 19:09:50.566619


Looking for models from previous run for this date range...




No models found.


2025-02-28 19:09:56,220	INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


Creating Stop Loss env from 2015-01-02 00:00:00 to 2015-09-30 00:00:00
caching data
data cached!


  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


Creating Stop Loss env from 2015-10-01 00:00:00 to 2015-12-31 00:00:00
caching data
data cached!


2025-02-28 19:10:12,950	INFO trainable.py:160 -- Trainable.setup took 19.982 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")
  logger.warn(
  gym.logger.warn("Casting input x to numpy array.")


Started training.
total_batches: 1
total_timesteps: 1


  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")
  logger.warn(
  gym.logger.warn("Casting input x to numpy array.")
[34m[1mwandb[0m: Adding directory to artifact (./trained_models)... 


----------------------------------------
train/ann_return: -21.75
train/ann_return_EMA_0.2: -19.84
train/ann_return_MA_20: -15.14
train/mdd: -6.28
train/mdd_EMA_0.2: -6.08
train/mdd_MA_20: -5.59
train/sharpe_ratio: -1.46
train/sharpe_ratio_EMA_0.2: -1.34
train/sharpe_ratio_MA_20: -1.03
****************************************
val/ann_return: -3.01
val/ann_return_EMA_0.2: 3.04
val/ann_return_MA_20: 3.89
val/mdd: -1.52
val/mdd_EMA_0.2: -1.79
val/mdd_MA_20: -1.78
val/sharpe_ratio: -1.28
val/sharpe_ratio_EMA_0.2: 1.07
val/sharpe_ratio_MA_20: 1.32
----------------------------------------

Training complete.
TRAINING DURATION: 30.267517697000585


Done. 0.1s


Artifact 'trained_models-ru9afj14' has been updated and uploaded.

Evaluating for `val` using `vix`: 99
Init with NEW env state:
	 initial_amount: 50000
	 initial_stocks: None
Initializing StopLossEnv
Creating Stop Loss env from 2015-10-01 00:00:00 to 2015-12-31 00:00:00
caching data
data cached!


  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


Creating Stop Loss env from 2015-10-01 00:00:00 to 2015-12-31 00:00:00
caching data
data cached!
Using a pretrained algo with 1 iterations
Started training.
total_batches: 1
total_timesteps: 1


  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")
  logger.warn(
  gym.logger.warn("Casting input x to numpy array.")
[34m[1mwandb[0m: Adding directory to artifact (./trained_models)... Done. 0.0s



----------------------------------------
train/ann_return: 7.53
train/ann_return_EMA_0.2: 8.24
train/ann_return_MA_20: 9.48
train/mdd: -2.34
train/mdd_EMA_0.2: -2.32
train/mdd_MA_20: -2.28
train/sharpe_ratio: 1.19
train/sharpe_ratio_EMA_0.2: 1.41
train/sharpe_ratio_MA_20: 1.81
Training complete.
TRAINING DURATION: 6.604626036001719
Artifact 'trained_models-ru9afj14' has been updated and uploaded.

Evaluating for `test` using `vix`: 99
Init with NEW env state:
	 initial_amount: 50000
	 initial_stocks: None
Initializing StopLossEnv
Creating Stop Loss env from 2016-01-04 00:00:00 to 2016-03-31 00:00:00
caching data
data cached!
RUN DURATION: 1.2


0,1
train.ann_return/ppo,▁▅▇██
train.cum_return/ppo,▁▆█▇█
train.mdd/ppo,▁▄▇█▇
train.sharpe_ratio/ppo,▁▃▅█▆
val.ann_return/ppo,▁▂██▂▆▄█▂▄
val.cum_return/ppo,▁▂█▇▁▆▂▇▁▃
val.mdd/ppo,▆█▁▁▅▁▃▁▄▃
val.sharpe_ratio/ppo,▁▂▇█▆▅▅█▂▅

0,1
config_hash,cab99dd9dde5ca0d6ee6...
run.duration_minutes,1.2
test.vix_99.ann_return/ppo,67.51
test.vix_99.cum_return/ppo,6.27
test.vix_99.end_asset_value/ppo,52510.37
test.vix_99.end_cash/ppo,562.53
test.vix_99.end_total_asset/ppo,53072.89
test.vix_99.mdd/ppo,-0.99
test.vix_99.sharpe_ratio/ppo,5.17
test.vix_99.start_asset_value/ppo,0


END TIME: 2025-02-28 19:11:04.784102
