<a href="https://colab.research.google.com/github/tony-pitchblack/finrl-dt/blob/custom-backtesting/finrl_dt_replicate_sweep.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installs

In [None]:
!pip install -q yfinance==0.2.50

In [None]:
%%capture
!pip install stable-baselines3
!pip install finrl
!pip install alpaca_trade_api
!pip install exchange_calendars
!pip install stockstats
!pip install wrds

In [None]:
import numpy as np

if np.__version__ != '1.26.4':
    !pip install -q numpy==1.26.4 --force-reinstall

'1.26.4'

In [None]:
%%capture
import pandas as pd

if pd.__version__ != '2.2.2':
    !pip install -q pandas==2.2.2 --force-reinstall

In [None]:
# %load_ext autoreload
# %autoreload 2

# Imports

In [1]:
import pandas as pd

from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
from finrl.agents.stablebaselines3.models import DRLAgent
from stable_baselines3.common.logger import configure
from finrl import config_tickers
from finrl.main import check_and_make_directories
from finrl.config import INDICATORS, TRAINED_MODEL_DIR, RESULTS_DIR

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [2]:
import os
from pathlib import Path
import pandas as pd
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [3]:
import os

# SECURITY: a W&B API key must never be committed to the notebook. Provide it
# through the environment instead (e.g. `export WANDB_API_KEY=...` or a Colab
# secret copied into os.environ before this cell runs).
# NOTE(review): the key that used to be hardcoded here has been published with
# the notebook and should be revoked in the W&B account settings.
if not os.environ.get("WANDB_API_KEY"):
    # Presumably wandb falls back to its own login prompt in this case — confirm.
    print("WANDB_API_KEY is not set; wandb will have to authenticate on first use.")
PROJECT_NAME = 'finrl-dt-replicate'

# General funcs

In [4]:
#@title YahooDownloader

"""Contains methods and classes to collect data from
Yahoo Finance API
"""

from __future__ import annotations

import pandas as pd
import yfinance as yf


class YahooDownloader:
    """Provides methods for retrieving daily stock data from
    Yahoo Finance API

    The downloaded per-ticker frames are merged and reshaped into one long
    table with a row per (date, ticker) pair.

    Attributes
    ----------
        start_date : str
            start date of the data (modified from neofinrl_config.py)
        end_date : str
            end date of the data (modified from neofinrl_config.py)
        ticker_list : list
            a list of stock tickers (modified from neofinrl_config.py)

    Methods
    -------
    fetch_data()
        Fetches data from yahoo API
    select_equal_rows_stock()
        Keeps only the tickers with at least the average number of rows

    """

    def __init__(self, start_date: str, end_date: str, ticker_list: list):
        # The values are stored untouched and handed to yf.download() later.
        self.start_date = start_date
        self.end_date = end_date
        self.ticker_list = ticker_list

    def fetch_data(self, proxy=None) -> pd.DataFrame:
        """Fetches data from Yahoo API

        Parameters
        ----------
        proxy : optional
            forwarded unchanged to ``yf.download``

        Returns
        -------
        `pd.DataFrame`
            8 columns: date, tic, open, high, low, close, volume and day,
            one row per (date, ticker), sorted by date and then ticker.
            ``close`` holds the adjusted close price, ``date`` is a
            ``YYYY-MM-DD`` string and ``day`` is the day of week (monday = 0).

        Raises
        ------
        ValueError
            if every ticker in ``ticker_list`` returned an empty frame
        """
        # Download and save the data in a pandas DataFrame:
        data_df = pd.DataFrame()
        num_failures = 0
        for tic in self.ticker_list:
            temp_df = yf.download(
                tic, start=self.start_date, end=self.end_date, proxy=proxy
            )
            # helper column; it is dropped again before the wide-to-long reshape below
            temp_df["tic"] = tic
            if len(temp_df) > 0:
                # data_df = data_df.append(temp_df)
                data_df = pd.concat([data_df, temp_df], axis=0)
            else:
                # an empty frame counts as a failed download for this ticker
                num_failures += 1
        if num_failures == len(self.ticker_list):
            raise ValueError("no data is fetched.")
        # reset the index, we want to use numbers as index instead of dates
        data_df = data_df.reset_index()

        try:
            # Convert wide to long format
            # NOTE(review): stack(level='Ticker') assumes the columns returned by
            # yf.download carry a level named 'Ticker' — presumably the layout of
            # the yfinance version pinned at the top of the notebook; confirm
            # before upgrading yfinance.
            # print(f"DATA COLS: {data_df.columns}")
            data_df = data_df.sort_index(axis=1).set_index(['Date']).drop(columns=['tic']).stack(level='Ticker', future_stack=True)
            data_df.reset_index(inplace=True)
            data_df.columns.name = ''

            # convert the column names to standardized names
            data_df.rename(columns={'Ticker': 'Tic', 'Adj Close': 'Adjcp'}, inplace=True)
            data_df.rename(columns={col: col.lower() for col in data_df.columns}, inplace=True)

            columns = [
                "date",
                "tic",
                "open",
                "high",
                "low",
                "close",
                "adjcp",
                "volume",
            ]

            # keep only the standardized columns, in this order
            data_df = data_df[columns]
            # use adjusted close price instead of close price
            data_df["close"] = data_df["adjcp"]
            # drop the adjusted close price column
            data_df = data_df.drop(labels="adjcp", axis=1)

        except NotImplementedError:
            # NOTE(review): only a message is printed; execution continues with
            # whatever state data_df is in at that point.
            print("the features are not supported currently")

        # create day of the week column (monday = 0)
        data_df["day"] = data_df["date"].dt.dayofweek
        # convert date to standard string format, easy to filter
        data_df["date"] = data_df.date.apply(lambda x: x.strftime("%Y-%m-%d"))
        # drop missing data
        data_df = data_df.dropna()
        data_df = data_df.reset_index(drop=True)
        print("Shape of DataFrame: ", data_df.shape)
        # print("Display DataFrame: ", data_df.head())

        # final ordering: chronological, tickers alphabetical within a date
        data_df = data_df.sort_values(by=["date", "tic"]).reset_index(drop=True)

        return data_df

    def select_equal_rows_stock(self, df):
        """Keep only the tickers that have at least the mean number of rows.

        Row counts are taken per value of ``df.tic``; tickers whose count is
        below the mean of those counts are filtered out of the returned frame.
        """
        df_check = df.tic.value_counts()
        df_check = pd.DataFrame(df_check).reset_index()
        df_check.columns = ["tic", "counts"]
        mean_df = df_check.counts.mean()
        # boolean mask aligned with the value_counts() ordering of tickers
        equal_list = list(df.tic.value_counts() >= mean_df)
        names = df.tic.value_counts().index
        select_stocks_list = list(names[equal_list])
        df = df[df.tic.isin(select_stocks_list)]
        return df


In [5]:
#@title construct_daily_index
def construct_daily_index(data_df, date_column='date', new_index_name='date_index'):
    """
    Return a copy of ``data_df`` indexed by the ordinal position of each date.

    Every distinct value of ``date_column`` is ranked in ascending order
    (0 for the earliest date) and that rank becomes the row index, so all rows
    that share a date share an index value. The input frame is left untouched,
    which makes the call safe on slices of a larger frame and idempotent when
    a cell is re-run.

    Parameters:
        data_df (pd.DataFrame): The input DataFrame.
        date_column (str): The name of the column containing dates.
        new_index_name (str): Name of the temporary column that is turned into
            the index; the resulting index itself is named ''.

    Returns:
        pd.DataFrame: A new DataFrame with the daily index.
    """
    # Rank the distinct dates: earliest date -> 0, next -> 1, ...
    date_to_index = {
        date: idx for idx, date in enumerate(sorted(data_df[date_column].unique()))
    }

    # assign() builds a new frame, so the caller's object is never modified
    indexed_df = (
        data_df
        .assign(**{new_index_name: data_df[date_column].map(date_to_index)})
        .set_index(new_index_name)
    )
    indexed_df.index.name = ''  # Remove the index name for simplicity

    return indexed_df

In [6]:
#@title add_dataset

def get_dataset_name(prefix, train_start, test_start, test_end):
    """Build the label ``"<prefix> | YYYY-MM | YYYY-MM | YYYY-MM"``.

    The three date-like arguments (anything exposing ``.year`` and ``.month``)
    are rendered as year-month, in the order train start, test start, test end.
    """
    year_month_labels = [
        f"{moment.year}-{moment.month:02}"
        for moment in (train_start, test_start, test_end)
    ]
    return f"{prefix} | " + " | ".join(year_month_labels)

def add_dataset(stock_index_name, train_df, test_df):
    """Register a train/test pair in the module-level ``datasets`` dict.

    The entry is keyed by ``get_dataset_name(...)`` built from the first train
    date and the first/last test dates, and stores the two frames (re-indexed
    with ``construct_daily_index``) plus a ``metadata`` dict.

    The caller's frames are not modified: all re-indexing happens on copies,
    so slices of a larger frame can be passed safely.

    Parameters:
        stock_index_name (str): Label of the ticker universe, used as name prefix.
        train_df (pd.DataFrame): Training rows, with a 'date' column or a date index.
        test_df (pd.DataFrame): Test rows, with a 'date' column or a date index.
    """
    global datasets
    if 'datasets' not in globals():
        datasets = {}

    def _indexed_by_datetime(frame):
        # set_index() returns a new frame; copy() covers the case where the
        # dates are already in the index.
        indexed = frame.set_index('date') if 'date' in frame.columns else frame.copy()
        indexed.index = pd.to_datetime(indexed.index)
        return indexed

    train_by_date = _indexed_by_datetime(train_df)
    test_by_date = _indexed_by_datetime(test_df)

    # Window bounds are read positionally, i.e. the rows are expected to be
    # in chronological order already.
    train_start_date = train_by_date.index[0]
    test_start_date = test_by_date.index[0]
    test_end_date = test_by_date.index[-1]

    dataset_name = get_dataset_name(
        stock_index_name,
        train_start_date, test_start_date, test_end_date
    )

    # reset_index() moves the (now datetime) dates back into a leading column
    train_daily = construct_daily_index(train_by_date.reset_index())
    test_daily = construct_daily_index(test_by_date.reset_index())

    ticker_list = train_daily.tic.unique().tolist()

    datasets[dataset_name] = {
        'train': train_daily,
        'test': test_daily,
        'metadata': dict(
            stock_index_name = stock_index_name,
            train_start_date = train_start_date,
            test_start_date = test_start_date,
            test_end_date = test_end_date,
            num_tickers = len(ticker_list),
            ticker_list = ticker_list,
        )
    }

# Load data

## DATA: DOW-30 (rolling yearly windows)

In [7]:
#@title download full data
# %%capture

# Years in which a rolling test window may start.
min_test_start_year = 2020
max_test_start_year = 2025

# Lengths of the train and test windows, in (possibly fractional) years.
train_years_count = 10
test_years_count = 1.5

# Download span: start of the earliest train window ...
min_date = (
    pd.Timestamp(year=min_test_start_year, month=1, day=1)
    - pd.Timedelta(days=int(train_years_count * 365.2425))
)

# ... up to the end of the latest test window.
max_date = (
    pd.Timestamp(year=max_test_start_year, month=1, day=1)
    + pd.Timedelta(days=int(test_years_count * 365.2425))
)

data_df = YahooDownloader(
    start_date=min_date,
    end_date=max_date,
    ticker_list=config_tickers.DOW_30_TICKER,
).fetch_data()

# fetch_data() returns the dates as strings; the window logic needs timestamps
data_df['date'] = pd.to_datetime(data_df['date'])

# No test window can start after the last year that actually has data.
max_data_date = data_df['date'].max()
max_test_start_year = min(max_test_start_year, max_data_date.year)

[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%********

Shape of DataFrame:  (110753, 8)


In [8]:
#@title add features

from finrl.meta.preprocessor.preprocessors import FeatureEngineer

# Enrich the downloaded prices; per the cell output this adds the technical
# indicators, the VIX and the turbulence index.
fe = FeatureEngineer(use_turbulence=True, use_vix=True)
# The 'date' column is cast back to str before preprocessing.
# NOTE(review): presumably FeatureEngineer expects string dates — confirm.
preprocessed_data_df = fe.preprocess_data(data_df.astype({'date': str}))
preprocessed_data_df.head()

Successfully added technical indicators


[*********************100%***********************]  1 of 1 completed


Shape of DataFrame:  (3768, 8)
Successfully added vix
Successfully added turbulence index


Unnamed: 0,date,tic,open,high,low,close,volume,day,macd,boll_ub,boll_lb,rsi_30,cci_30,dx_30,close_30_sma,close_60_sma,vix,turbulence
0,2010-01-04,AAPL,7.6225,7.660714,7.585,6.447412,493729600.0,0,0.0,6.468749,6.437222,100.0,66.666667,100.0,6.447412,6.447412,21.68,0.0
1,2010-01-04,AMGN,56.630001,57.869999,56.560001,40.915901,5277400.0,0,0.0,6.468749,6.437222,100.0,66.666667,100.0,40.915901,40.915901,21.68,0.0
2,2010-01-04,AXP,40.810001,41.099998,40.389999,32.906174,6894300.0,0,0.0,6.468749,6.437222,100.0,66.666667,100.0,32.906174,32.906174,21.68,0.0
3,2010-01-04,BA,55.720001,56.389999,54.799999,43.77755,6186700.0,0,0.0,6.468749,6.437222,100.0,66.666667,100.0,43.77755,43.77755,21.68,0.0
4,2010-01-04,CAT,57.650002,59.189999,57.509998,39.883911,7325600.0,0,0.0,6.468749,6.437222,100.0,66.666667,100.0,39.883911,39.883911,21.68,0.0


In [9]:
#@title get_train_test_dates
def get_train_test_dates(train_years_count, test_years_count, test_start_year):
    """Return ``(train_start, test_start, test_end)`` for one rolling window.

    The test window starts on January 1st of ``test_start_year``. Window
    lengths are given in years and converted to whole days using 365.2425
    days per year (truncated), so the bounds are not exact calendar
    anniversaries.
    """
    def _years_as_days(years):
        return pd.Timedelta(days=int(years * 365.2425))

    test_start_date = pd.Timestamp(year=test_start_year, month=1, day=1)
    return (
        test_start_date - _years_as_days(train_years_count),
        test_start_date,
        test_start_date + _years_as_days(test_years_count),
    )

In [10]:
# The window bounds are Timestamps, so the 'date' column has to be datetime typed.
preprocessed_data_df['date'] = pd.to_datetime(preprocessed_data_df['date'])

# Preview the date coverage of every rolling train/test window.
for test_start_year in range(min_test_start_year, max_test_start_year + 1):
    train_start_date, test_start_date, test_end_date = get_train_test_dates(
        train_years_count, test_years_count, test_start_year
    )

    # Half-open intervals: [train_start, test_start) and [test_start, test_end)
    row_dates = preprocessed_data_df['date']
    in_train_window = (row_dates >= train_start_date) & (row_dates < test_start_date)
    in_test_window = (row_dates >= test_start_date) & (row_dates < test_end_date)
    train_df = preprocessed_data_df[in_train_window]
    test_df = preprocessed_data_df[in_test_window]

    # add_dataset('DOW_30', train_df, test_df)

    print(f"Train start: {train_df['date'].min()}, Train end: {train_df['date'].max()}")
    print(f"Test start: {test_df['date'].min()}, Test end: {test_df['date'].max()}")
    print()

# print(*list(datasets.keys()), sep='\n')

Train start: 2010-01-04 00:00:00, Train end: 2019-12-31 00:00:00
Test start: 2020-01-02 00:00:00, Test end: 2021-06-30 00:00:00

Train start: 2011-01-03 00:00:00, Train end: 2020-12-31 00:00:00
Test start: 2021-01-04 00:00:00, Test end: 2022-07-01 00:00:00

Train start: 2012-01-03 00:00:00, Train end: 2021-12-31 00:00:00
Test start: 2022-01-03 00:00:00, Test end: 2023-06-30 00:00:00

Train start: 2013-01-02 00:00:00, Train end: 2022-12-30 00:00:00
Test start: 2023-01-03 00:00:00, Test end: 2024-06-28 00:00:00

Train start: 2014-01-02 00:00:00, Train end: 2023-12-29 00:00:00
Test start: 2024-01-02 00:00:00, Test end: 2024-12-20 00:00:00



# Main

## Config

In [11]:
#@title init
# Sweep parameters are collected in `parameters_dict` by the following cells;
# `sweep_config` holds a reference to that same dict, so later updates are
# picked up automatically.
parameters_dict = {}
sweep_config = {
    'method': 'grid',
    'metric': {
        'name': 'max_sharpe_ratio',
        # A higher Sharpe ratio is better; without an explicit goal W&B treats
        # the metric as something to minimize.
        'goal': 'maximize',
    },
    'parameters': parameters_dict,
}

In [12]:
#@title used models

# Enable every supported algorithm; each flag is a fixed (non-swept) value.
parameters_dict.update({
    f'if_using_{algorithm}': {'value': True}
    for algorithm in ('a2c', 'ddpg', 'ppo', 'td3', 'sac')
})

In [13]:
#@title date range
parameters_dict.update({
    'stock_index_name': {'value': 'DOW-30'},
    'train_years_count': {'value': 10},
    'test_years_count': {'value': 1},
    # Swept parameter: one run per test start year.
    # (Use {'value': 2020} instead to pin a single year.)
    'test_start_year': {'values': list(range(2020, 2025))},
})

In [14]:
#@title env params
parameters_dict.update({
    # Nested parameter group consumed as config['env_params'][...]
    'env_params': {
        'parameters': {
            'cost_abs': {'value': 2.5},
            'initial_amount': {'value': 50_000},
        },
    },
    'REFERENCE_PRICE_END_DATE': {'value': '2024-12-21'},
    # NOTE: the "REFERNCE" spelling is the key set_cost_pct reads; rename both together.
    'REFERNCE_PRICE_WINDOW_DAYS': {'value': 30},
})

## Wandb artifacts

In [15]:
#@title update_artifact

def update_artifact(folder_path, name_prefix, type):
    """
    Upload a folder to W&B as an artifact of the currently active run.

    The artifact is named ``"<name_prefix>-<run id>"``, so every run gets its
    own artifact and logging it again from the same run reuses that name.

    Args:
        folder_path (str or path-like): Folder whose contents are added to the artifact.
        name_prefix (str): Prefix of the artifact name; the run id is appended.
        type (str): W&B artifact type. The name shadows the builtin but is kept
            because callers pass it by keyword.

    Raises:
        RuntimeError: If no W&B run is active (``wandb.run`` is None).
    """
    run = wandb.run
    if run is None:
        # Fail with an actionable message instead of an AttributeError on None.
        raise RuntimeError(
            "update_artifact() needs an active W&B run; call wandb.init() first."
        )
    artifact_name = f'{name_prefix}-{run.id}'

    # Create the artifact, attach the folder and log it to the run
    artifact = wandb.Artifact(name=artifact_name, type=type)
    artifact.add_dir(folder_path)
    run.log_artifact(artifact)
    print(f"Artifact '{artifact_name}' has been updated and uploaded.")

In [16]:
#@title update_model_artifacts

def update_model_artifacts():
    """Upload the results folder and the trained-models folder as W&B artifacts.

    Each folder is logged under an artifact whose name prefix and type share
    the same label ('results', then 'trained_models').
    """
    folders_by_label = (
        ('results', RESULTS_DIR),
        ('trained_models', TRAINED_MODEL_DIR),
    )
    for label, folder in folders_by_label:
        update_artifact(folder_path=folder, name_prefix=label, type=label)

In [17]:
#@title update_dataset_artifact

from pathlib import Path

def update_dataset_artifact(train, test, config):
    """Write the train/test frames to ./dataset as CSV and upload that folder.

    The files are named 'train_data.csv' and 'test_data.csv'; the folder is
    logged as a W&B artifact with prefix and type 'dataset'. ``config`` is
    accepted for signature compatibility and is not read here.
    """
    dataset_dir = Path('./dataset')
    dataset_dir.mkdir(parents=True, exist_ok=True)

    for frame, file_name in ((train, 'train_data.csv'), (test, 'test_data.csv')):
        frame.to_csv(dataset_dir / file_name)

    update_artifact(folder_path=dataset_dir, name_prefix='dataset', type='dataset')

## Build funcs

In [18]:
#@title build_dataset
def build_dataset(config):
    """Cut one train/test window out of ``preprocessed_data_df`` and publish it.

    The window bounds come from ``get_train_test_dates`` using the
    'train_years_count', 'test_years_count' and 'test_start_year' entries of
    ``config``. Train rows fall in [train_start, test_start), test rows in
    [test_start, test_end). Both frames are re-indexed with
    ``construct_daily_index``, the resolved dates and the dataset name are
    written back into ``config``, and the frames are uploaded through
    ``update_dataset_artifact``.

    Parameters:
        config: Mapping-like run configuration (supports item access and update()).

    Returns:
        tuple: (train_df, test_df)
    """
    train_start_date, test_start_date, test_end_date = get_train_test_dates(
        config['train_years_count'],
        config['test_years_count'],
        config['test_start_year']
    )

    # Compare on a datetime view of the column so the function does not depend
    # on an earlier cell having converted 'date' in place.
    row_dates = pd.to_datetime(preprocessed_data_df['date'])
    in_train_window = (row_dates >= train_start_date) & (row_dates < test_start_date)
    in_test_window = (row_dates >= test_start_date) & (row_dates < test_end_date)

    # .copy() detaches the windows from the global frame before re-indexing,
    # so no write can reach (or warn about) preprocessed_data_df.
    train_df = construct_daily_index(preprocessed_data_df[in_train_window].copy())
    test_df = construct_daily_index(preprocessed_data_df[in_test_window].copy())

    dataset_name = get_dataset_name(
        config['stock_index_name'], train_start_date, test_start_date, test_end_date
    )

    config.update(dict(
        train_start_date=train_start_date,
        test_start_date=test_start_date,
        test_end_date=test_end_date,
        dataset_name=dataset_name
    ))

    update_dataset_artifact(train_df, test_df, config)
    return train_df, test_df

In [19]:
#@title Calculate fee percent based on average price for past N days

def cost_pct_from_avg_price(df, cost_abs, price_avg_days, verbose=False):
    """Express an absolute per-trade cost as a fraction of each ticker's recent price.

    For every ticker the mid price ``(high + low) / 2`` is averaged over the
    rows dated within ``price_avg_days`` days of that ticker's latest date,
    and ``cost_abs`` is divided by that average.

    The input frame is not modified (its 'date' column is parsed on a copy).

    Parameters:
        df (pd.DataFrame): Prices with 'date', 'tic', 'high' and 'low' columns.
        cost_abs (float): Absolute cost per trade, in price units.
        price_avg_days (int): Length of the averaging window in calendar days.
        verbose (bool): Display the intermediate tables (uses the notebook's ``display``).

    Returns:
        list[float]: One cost fraction per ticker, ordered by ticker
        (the order in which ``groupby('tic')`` yields the groups).
    """
    prices = df.assign(date=pd.to_datetime(df['date']))

    avg_price_dict = {}
    for tic, tic_prices in prices.groupby('tic'):
        window_start = tic_prices['date'].max() - pd.Timedelta(days=price_avg_days)
        recent = tic_prices[tic_prices['date'] >= window_start]
        avg_price_dict[tic] = ((recent['high'] + recent['low']) / 2).mean()

    # One row per ticker, single column with the average reference price
    avg_price_df = pd.DataFrame(avg_price_dict, index=['cost_avg']).T
    cost_pct_df = (cost_abs / avg_price_df).rename(columns={'cost_avg': 'cost_pct'})

    if verbose:
        display(avg_price_df.head())
        print()
        display(cost_pct_df.head())

    return cost_pct_df.values.flatten().tolist()

In [20]:
#@title set_cost_pct

def set_cost_pct(train, config):
    """Derive per-ticker percentage trading costs and store them in ``config``.

    Reference prices for the tickers found in ``train`` are downloaded for the
    window that ends on config['REFERENCE_PRICE_END_DATE'] and spans
    config['REFERNCE_PRICE_WINDOW_DAYS'] days. The absolute cost
    config['env_params']['cost_abs'] is converted with
    ``cost_pct_from_avg_price`` and the resulting list is written to
    config['env_params']['cost_pct'].
    """
    window_end = config['REFERENCE_PRICE_END_DATE']
    window_days = config['REFERNCE_PRICE_WINDOW_DAYS']
    window_start = pd.Timestamp(window_end) - pd.Timedelta(days=window_days)

    # Download the reference prices for exactly the tickers used in training
    downloader = YahooDownloader(
        start_date=window_start,
        end_date=window_end,
        ticker_list=train.tic.unique().tolist(),
    )
    ref_price_df = downloader.fetch_data()

    per_ticker_cost_pct = cost_pct_from_avg_price(
        df=ref_price_df,
        cost_abs=config['env_params']['cost_abs'],
        price_avg_days=window_days,
    )

    config['env_params'].update({'cost_pct': per_ticker_cost_pct})

In [21]:
#@title Init env
def init_env(train, config):
    """Create the StockTradingEnv used for training.

    Parameters:
        train (pd.DataFrame): Training data with a 'tic' column.
        config: Run configuration; reads config['env_params']['cost_pct']
            (a per-ticker list or a single number) and
            config['env_params']['initial_amount'].

    Returns:
        StockTradingEnv: Environment built on ``train``.

    Raises:
        AssertionError: If a cost list does not have one entry per ticker.
        ValueError: If 'cost_pct' is neither a list nor a number.
    """
    stock_dimension = len(train.tic.unique())
    # State layout size: 1 + (2 + number of indicators) values per stock
    state_space = 1 + 2*stock_dimension + len(INDICATORS)*stock_dimension
    print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")

    cost_pct = config['env_params']['cost_pct']
    if isinstance(cost_pct, list):
        assert len(cost_pct) == stock_dimension, (
            f"cost_pct has {len(cost_pct)} entries but there are {stock_dimension} tickers"
        )
        buy_cost_pct = sell_cost_pct = cost_pct
    elif isinstance(cost_pct, (int, float)):
        # Broadcast the scalar to every ticker. (The previous code read the
        # nonexistent key config['COST_PCT'] here.)
        buy_cost_pct = sell_cost_pct = [cost_pct] * stock_dimension
    else:
        raise ValueError(
            f"env_params['cost_pct'] must be a list or a number, got {type(cost_pct).__name__}"
        )

    # Start every episode without any holdings
    num_stock_shares = [0] * stock_dimension

    env_kwargs = {
        "hmax": 100,
        "initial_amount": config['env_params']['initial_amount'],
        "num_stock_shares": num_stock_shares,
        "buy_cost_pct": buy_cost_pct,
        "sell_cost_pct": sell_cost_pct,
        "state_space": state_space,
        "stock_dim": stock_dimension,
        "tech_indicator_list": INDICATORS,
        "action_space": stock_dimension,
        "reward_scaling": 1e-4,

        "print_verbosity": 1,
        # "make_plots": True
    }

    e_train_gym = StockTradingEnv(df = train, **env_kwargs)
    return e_train_gym

## Train models

In [22]:
#@title SharpeRatioCallback
from stable_baselines3.common.callbacks import BaseCallback
import pandas as pd
import numpy as np

class SharpeRatioCallback(BaseCallback):
    """Log the annualized Sharpe ratio of the training environment to W&B.

    On the step at which the first wrapped environment reaches its last
    trading day, the ratio is computed from ``env.asset_memory`` and then:
      * appended to ``self.sharpe_ratios``,
      * logged as ``sharpe_ratio/<model_name>`` with the episode as step,
      * stored in ``wandb.config.sharpe_ratios[<model_name>]``.

    Nothing is logged when the standard deviation of the daily returns is
    zero or undefined (NaN), because the ratio is not meaningful then.
    """

    def __init__(self, model_name, verbose=0):
        self.model_name = model_name
        super().__init__(verbose)
        self.sharpe_ratios = []

    def _on_step(self) -> bool:
        # Access the environment
        env = self.training_env.envs[0]

        # Check if the episode is terminal (the flag is also stored on the env,
        # exactly as before)
        env.terminal = env.day >= len(env.df.index.unique()) - 1
        if not env.terminal:
            return True

        account_value = pd.Series(env.asset_memory, name="account_value")
        daily_return = account_value.pct_change(1)
        return_std = daily_return.std()

        # Skip degenerate episodes: constant account value (std == 0) or too
        # few observations for a standard deviation (std is NaN).
        if pd.notna(return_std) and return_std != 0:
            # 252 trading days per year are used for annualization
            sharpe = (252**0.5) * daily_return.mean() / return_std
            self.sharpe_ratios.append(sharpe)

            if self.verbose > 0:
                print(f"Episode: {env.episode}, Sharpe Ratio: {sharpe:.3f}")

            wandb.log({
                f'sharpe_ratio/{self.model_name}': sharpe,
            }, step=env.episode)

            # Keep the latest value per model in the config for instant access
            if "sharpe_ratios" not in wandb.config:
                wandb.config.sharpe_ratios = {}

            wandb.config.sharpe_ratios[self.model_name] = sharpe

        return True

In [23]:
#@title MaxSharpeRatioCallback

import re
import wandb
from stable_baselines3.common.callbacks import BaseCallback

class MaxSharpeRatioCallback(BaseCallback):
    """Log the best Sharpe ratio recorded so far, and the model that achieved it.

    Reads the ``sharpe_ratios`` mapping from ``wandb.config`` (model name ->
    ratio) on the step at which the first wrapped environment reaches its last
    trading day, and logs ``max_sharpe_ratio`` / ``max_sharpe_ratio_model``.
    Nothing is logged while that mapping is missing or empty.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        # Access the environment
        env = self.training_env.envs[0]

        # Check if the episode is terminal
        env.terminal = env.day >= len(env.df.index.unique()) - 1
        if env.terminal:
            sharpe_ratios = wandb.config.get("sharpe_ratios", {})

            # max() raises ValueError on an empty mapping, which happens when no
            # Sharpe ratio has been recorded yet (e.g. a zero-variance episode).
            if sharpe_ratios:
                max_sharpe_ratio_model = max(sharpe_ratios, key=sharpe_ratios.get)
                max_sharpe_ratio = sharpe_ratios[max_sharpe_ratio_model]

                # Log the max Sharpe ratio and the corresponding model name
                wandb.log({
                    'max_sharpe_ratio': max_sharpe_ratio,
                    'max_sharpe_ratio_model': max_sharpe_ratio_model
                })

        return True

In [24]:
#@title Custom DRLAgent (3 callbacks)
from finrl.agents.stablebaselines3.models import DRLAgent, TensorboardCallback
from stable_baselines3.common.callbacks import CallbackList
import wandb

class DRLAgent(DRLAgent):
    """FinRL's DRLAgent with Sharpe-ratio logging added to every training run.

    NOTE: the class deliberately reuses the imported name, so re-running this
    cell subclasses the previous definition again.
    """

    @staticmethod
    def train_model(
        model,
        tb_log_name,
        total_timesteps=5000,
        callback=None,
    ):
        """Train ``model`` with the default callbacks plus any user callbacks.

        TensorboardCallback, SharpeRatioCallback (named after ``tb_log_name``)
        and MaxSharpeRatioCallback always run, in that order; ``callback`` may
        add a single BaseCallback or a list of them after those.

        Raises:
            ValueError: If ``callback`` is neither None, a BaseCallback nor a list.
        """
        active_callbacks = [
            TensorboardCallback(),
            SharpeRatioCallback(model_name=tb_log_name, verbose=1),
            MaxSharpeRatioCallback(verbose=1),
        ]

        # Append whatever the caller supplied
        if isinstance(callback, BaseCallback):
            active_callbacks.append(callback)
        elif isinstance(callback, list):
            active_callbacks.extend(callback)
        elif callback is not None:
            raise ValueError("callback must be None, a BaseCallback, or a list of BaseCallback instances.")

        # learn() takes a single callback, so bundle them into a CallbackList
        return model.learn(
            total_timesteps=total_timesteps,
            tb_log_name=tb_log_name,
            callback=CallbackList(active_callbacks),
        )


In [25]:
#@title train models

# Training plan: (algorithm name, model kwargs or None for the library
# defaults, total timesteps), listed in the order the models are trained.
_TRAINING_PLAN = (
    ("a2c", None, 50000),
    ("ddpg", None, 50000),
    ("td3", {"batch_size": 100,
             "buffer_size": 1000000,
             "learning_rate": 0.001}, 50000),
    ("sac", {"batch_size": 128,
             "buffer_size": 100000,
             "learning_rate": 0.0001,
             "learning_starts": 100,
             "ent_coef": "auto_0.1"}, 70000),
    ("ppo", {"n_steps": 2048,
             "ent_coef": 0.01,
             "learning_rate": 0.00025,
             "batch_size": 128}, 200000),
)


def _train_single_model(env_train, algorithm, model_kwargs, total_timesteps):
    """Train one algorithm on ``env_train``, save it and upload the artifacts."""
    import shutil

    print(f"training {algorithm.upper()} agent")
    agent = DRLAgent(env = env_train)
    # Pass a copy so the shared plan entry can never be modified by the library
    model = agent.get_model(
        algorithm,
        model_kwargs = dict(model_kwargs) if model_kwargs is not None else None,
    )

    # Start from a clean log folder; configure() is expected to recreate it
    # (pure-Python replacement for the former `!rm -rf <dir>/*`).
    log_dir = RESULTS_DIR + '/' + algorithm
    shutil.rmtree(log_dir, ignore_errors=True)
    model.set_logger(configure(log_dir, ["stdout", "csv", "tensorboard"]))

    trained_model = agent.train_model(model=model,
                                      tb_log_name=algorithm,
                                      total_timesteps=total_timesteps)

    trained_model.save(TRAINED_MODEL_DIR + "/agent_" + algorithm)
    # Upload after every model so partial results survive an interrupted run
    update_model_artifacts()
    return trained_model


def train_models(e_train_gym, config):
    """Train every algorithm enabled in ``config`` on the given environment.

    An algorithm is trained when ``config['if_using_<name>']`` is truthy, for
    the names a2c, ddpg, td3, sac and ppo (trained in that order). Each
    trained model is saved to ``TRAINED_MODEL_DIR/agent_<name>``, its logs go
    to ``RESULTS_DIR/<name>``, and the artifacts are uploaded after each model.

    Parameters:
        e_train_gym: Training environment exposing ``get_sb_env()``.
        config: Mapping-like configuration holding the ``if_using_*`` flags.
    """
    check_and_make_directories([TRAINED_MODEL_DIR])

    env_train, _ = e_train_gym.get_sb_env()
    print(type(env_train))

    # Read all flags up front so a missing key fails before any training starts
    enabled = {
        algorithm: config[f"if_using_{algorithm}"]
        for algorithm in ("a2c", "ddpg", "ppo", "td3", "sac")
    }

    for algorithm, model_kwargs, total_timesteps in _TRAINING_PLAN:
        if enabled[algorithm]:
            _train_single_model(env_train, algorithm, model_kwargs, total_timesteps)

In [26]:
#@title train
from pprint import pprint
import wandb

import random
import string

def generate_run_name(prefix, n=5):
    """Return ``"<prefix> | <suffix>"`` where the suffix is ``n`` random
    ASCII letters/digits drawn with the ``random`` module."""
    alphabet = string.ascii_letters + string.digits
    suffix = ''.join(random.choices(alphabet, k=n))
    return prefix + " | " + suffix if isinstance(prefix, str) else f"{prefix} | {suffix}"

def train(config=None):
    """Entry point for one sweep run: build the data, the env, and train.

    Opens a W&B run (closed automatically by the context manager), and from
    then on works with ``wandb.config``, which carries the values chosen for
    this run.
    """
    with wandb.init(config=config):
        run_config = wandb.config

        # 1. Data: cut the train/test window (this also fills in 'dataset_name')
        train_df, _test_df = build_dataset(run_config)

        # 2. Name the run after its dataset plus a random suffix
        wandb.run.name = generate_run_name(run_config['dataset_name'])
        wandb.run.save()

        # 3. Trading costs, expressed as a percentage per ticker
        set_cost_pct(train_df, run_config)

        # 4. Environment and model training
        training_env = init_env(train_df, run_config)
        train_models(training_env, run_config)

In [27]:
# Register the sweep described by `sweep_config` in the W&B project, then let an
# agent in this process execute `train` for it.
# NOTE(review): count=1 presumably limits the agent to a single run, i.e. only
# one of the swept 'test_start_year' values is trained per execution — confirm
# and raise the count to cover the whole grid.
sweep_id = wandb.sweep(sweep_config, project=PROJECT_NAME)
wandb.agent(sweep_id, train, count=1)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Create sweep with ID: pknx8zsp
Sweep URL: https://wandb.ai/overfit1010/finrl-dt-replicate/sweeps/pknx8zsp


[34m[1mwandb[0m: Agent Starting Run: 2qo4lebf with config:
[34m[1mwandb[0m: 	REFERENCE_PRICE_END_DATE: 2024-12-21
[34m[1mwandb[0m: 	REFERNCE_PRICE_WINDOW_DAYS: 30
[34m[1mwandb[0m: 	env_params: {'cost_abs': 2.5, 'initial_amount': 50000}
[34m[1mwandb[0m: 	if_using_a2c: True
[34m[1mwandb[0m: 	if_using_ddpg: True
[34m[1mwandb[0m: 	if_using_ppo: True
[34m[1mwandb[0m: 	if_using_sac: True
[34m[1mwandb[0m: 	if_using_td3: True
[34m[1mwandb[0m: 	stock_index_name: DOW-30
[34m[1mwandb[0m: 	test_start_year: 2020
[34m[1mwandb[0m: 	test_years_count: 1
[34m[1mwandb[0m: 	train_years_count: 10
[34m[1mwandb[0m: Currently logged in as: [33mtony-pitchblack[0m ([33moverfit1010[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Adding directory to artifact (./dataset)... Done. 0.1s
[*********************100%***********************]  1 of 1 completed

Artifact 'dataset-2qo4lebf' has been updated and uploaded.



[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%***********************]  1 of 1 completed
[*********************100%*******

Shape of DataFrame:  (609, 8)
{'REFERENCE_PRICE_END_DATE': '2024-12-21', 'REFERNCE_PRICE_WINDOW_DAYS': 30, 'env_params': {'cost_abs': 2.5, 'initial_amount': 50000, 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809113, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589465, 0.01093735188446354, 0.11420864277075855, 0.01665255165637118, 0.010252194206940846, 0.03947813656400475, 0.008462643895002689, 0.019197296925249206, 0.024719550845968045, 0.005734449144229167, 0.03210832438433615, 0.014379031318205482, 0.009825802549097859, 0.0044498520553265045, 0.007984135151425499, 0.05881990501271046, 0.270946766931955, 0.026832260273111818]}, 'if_using_a2c': True, 'if_using_ddpg': True, 'if_using_ppo': True, 'if_using_sac': True, 'if_using_td3': True, 'stock_index_name': 'DOW-30', 'test_start_year': 2020, 'test_years_count': 1, 't

[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.
