<a href="https://colab.research.google.com/github/tony-pitchblack/finrl-dt/blob/custom-backtesting/finrl_dt_replicate_eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installs

In [None]:
%%capture
!pip install stable-baselines3
!pip install finrl
!pip install alpaca_trade_api
!pip install exchange_calendars
!pip install stockstats
!pip install wrds

# Imports

In [None]:
import os
import wandb
import pandas as pd

In [None]:
# Never hard-code credentials in a notebook: read the W&B API key from the environment,
# or prompt for it interactively. NOTE: the key previously committed in this cell should
# be treated as leaked and rotated in the W&B account settings.
if not os.environ.get("WANDB_API_KEY"):
    from getpass import getpass
    os.environ["WANDB_API_KEY"] = getpass("W&B API key: ")

PROJECT = 'finrl-dt-replicate'
ENTITY = "overfit1010"

# Main

In [None]:
RUN_ID = "1chsdmof"  # W&B run to resume and evaluate; its logged artifacts are downloaded below.

In [None]:
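# WARNING: destructive — uncommenting this deletes everything in the working directory,
# including the artifacts downloaded below. Keep disabled unless you need a clean slate.
# !rm -rf ./*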

In [None]:
# Resume the existing W&B run RUN_ID (resume='must') in this kernel. Its config is exposed
# as `config` and supplies the settings read by later cells (e.g. load_data, init_env).
wandb.init(entity=ENTITY, project=PROJECT, id=RUN_ID, resume='must')
config = wandb.run.config

In [None]:
#@title Download artifacts

def download_artifacts(run_id):
    """Download every artifact logged by a W&B run into the working directory.

    Each artifact is written to ``./<artifact.type>`` (e.g. ``./dataset``), so
    artifacts that share a type share a directory.

    Args:
        run_id: ID of a run in ``ENTITY/PROJECT``.
    """
    run_path = f"{ENTITY}/{PROJECT}/{run_id}"
    logged = wandb.Api().run(run_path).logged_artifacts()
    for art in logged:
        art.download(f'./{art.type}')

download_artifacts(RUN_ID)

True


[34m[1mwandb[0m: \ 1 of 3 files downloaded...[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [None]:
# %load_ext tensorboard
# %tensorboard --logdir ./results/

In [None]:
#@title Load data

def load_data(config, data_dir='./dataset'):
    """Load the train / (optional) validation / test splits from CSV files.

    Each CSV is read and indexed by its ``date`` column.

    Args:
        config: Run configuration (W&B config or a plain dict). Its
            ``'dataset_type'`` decides whether a validation split is loaded.
        data_dir: Directory holding ``train_data.csv``, ``test_data.csv`` and,
            for ``'quarterly_train_val_test'``, ``val_data.csv``.

    Returns:
        ``(train_df, val_df, test_df)``; ``val_df`` is ``None`` when the dataset
        type has no validation split.
    """
    def _read_split(name):
        # One CSV per split, indexed by date.
        return pd.read_csv(os.path.join(data_dir, f'{name}_data.csv')).set_index('date')

    train_df = _read_split('train')
    test_df = _read_split('test')

    # Use the config that was passed in (not the global ``wandb.config``) so this
    # also works outside an active W&B run and honours the caller's config.
    if config['dataset_type'] == 'quarterly_train_val_test':
        val_df = _read_split('val')
    else:
        val_df = None

    return train_df, val_df, test_df

# Load the downloaded splits; `valid` is None unless the run's dataset_type is
# 'quarterly_train_val_test'. `trade` is the test split used for back-testing below.
train, valid, trade = load_data(config)

In [None]:
#@title imports

import sys
import os
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
from finrl.agents.stablebaselines3.models import DRLAgent
from stable_baselines3 import A2C, DDPG, PPO, SAC, TD3

from finrl.config import INDICATORS, TRAINED_MODEL_DIR

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#@title Define metric functions

def calculate_mdd(asset_values):
    """
    Return the maximum drawdown of ``asset_values`` as a (non-positive) percentage.

    The drawdown at each point is measured relative to the highest value seen so far.
    """
    peak = asset_values.cummax()
    worst_drawdown = ((asset_values - peak) / peak).min()
    return worst_drawdown * 100  # fraction -> percent

def calculate_sharpe_ratio(asset_values, risk_free_rate=0.0):
    """
    Return the annualized Sharpe ratio of a portfolio value series.

    Uses simple per-period returns, a per-period risk-free rate of
    ``risk_free_rate / 252``, and annualizes with ``sqrt(252)`` (252 trading days).
    Returns 0.0 when the excess returns have zero volatility.
    """
    trading_days = 252
    period_returns = asset_values.pct_change().dropna()
    excess = period_returns - risk_free_rate / trading_days
    volatility = excess.std()
    if volatility == 0:
        return 0.0
    return excess.mean() / volatility * np.sqrt(trading_days)

def calculate_annualized_return(asset_values):
    """
    Calculate the annualized return of a portfolio, in percent.

    ``asset_values`` must be indexed by dates (anything supporting
    ``(index[-1] - index[0]).days``). The total return is compounded over
    365-day years, the same annualization used by the other metric helpers here.

    Returns NaN when the series spans less than one day, because the
    annualization exponent ``365 / num_days`` is undefined there.
    """
    # Keep the total return as a fraction: compounding ``1 + percent`` was the bug.
    total_return = asset_values.iloc[-1] / asset_values.iloc[0] - 1
    num_days = (asset_values.index[-1] - asset_values.index[0]).days
    if num_days <= 0:
        return np.nan
    return ((1 + total_return) ** (365 / num_days) - 1) * 100

## Evaluate (train / test)

In [None]:
#@title construct_daily_index
def make_daily_index(data_df, date_column='date', new_index_name='date_index'):
    """
    Map each row's date to its position among the sorted unique dates.

    Rows sharing a date (e.g. one row per ticker) get the same integer, giving a
    per-day index 0, 1, 2, ...

    Parameters:
        data_df (pd.DataFrame): The input DataFrame.
        date_column (str): The name of the column containing dates.
        new_index_name (str): Unused; accepted for signature compatibility with
            ``set_daily_index``.

    Returns:
        pd.Series: Integer day index aligned with ``data_df``'s rows.
    """
    sorted_dates = sorted(data_df[date_column].unique())
    date_to_index = {date: idx for idx, date in enumerate(sorted_dates)}
    return data_df[date_column].map(date_to_index)


def set_daily_index(data_df, date_column='date', new_index_name='date_index'):
    """
    Replace the index of ``data_df`` (in place) with a per-day integer index.

    Parameters:
        data_df (pd.DataFrame): The input DataFrame; it is modified in place.
        date_column (str): The name of the column containing dates.
        new_index_name (str): Name of the temporary column used to build the
            index. The resulting index name is cleared to ``''``.

    Returns:
        pd.DataFrame: ``data_df`` itself, now with a daily index.
    """
    # Pass ``date_column`` through; it used to be hard-coded to 'date' and silently ignored.
    data_df[new_index_name] = make_daily_index(data_df, date_column=date_column)

    data_df.set_index(new_index_name, inplace=True)
    data_df.index.name = ''  # Remove the index name for simplicity

    return data_df


def fix_daily_index(df, date_column='date'):
    """
    Ensure ``df`` is indexed by the per-day integer index derived from its dates.

    If the dates are currently the index (named ``date_column``), they are first
    moved back into a column. The index is replaced (and its name cleared to
    ``''``) only when it differs from the expected daily index. ``df`` is
    modified in place and also returned.

    Parameters:
        df (pd.DataFrame): Frame with a ``date_column`` column or index.
        date_column (str): Name of the date column / index.

    Returns:
        pd.DataFrame: ``df`` with a daily integer index.
    """
    if df.index.name == date_column:
        df.reset_index(inplace=True)

    daily_index = make_daily_index(df, date_column=date_column)
    if (df.index.values != daily_index.values).any():
        df.index = daily_index
        df.index.name = ''

    return df

# trade = fix_daily_index(trade)
# trade.index

In [None]:
#@title init env

def init_env(trade, config):
    """Build the FinRL ``StockTradingEnv`` used to back-test the trained agents on ``trade``.

    The state size is ``1 + 2 * n_tickers + len(INDICATORS) * n_tickers``; presumably
    cash, per-ticker price and holdings, then per-ticker indicators (FinRL layout —
    confirm). Every episode starts with zero shares held. ``initial_amount`` and the
    buy/sell ``cost_pct`` come from ``config``.
    """
    n_stocks = len(trade.tic.unique())
    state_space = 1 + 2 * n_stocks + len(INDICATORS) * n_stocks
    print(f"Stock Dimension: {n_stocks}, State Space: {state_space}")

    env_kwargs = dict(
        hmax=100,
        initial_amount=config['initial_amount'],
        num_stock_shares=[0] * n_stocks,
        buy_cost_pct=config['cost_pct'],
        sell_cost_pct=config['cost_pct'],
        state_space=state_space,
        stock_dim=n_stocks,
        tech_indicator_list=INDICATORS,
        action_space=n_stocks,
        reward_scaling=1e-4,
    )

    return StockTradingEnv(
        df=trade,
        turbulence_threshold=70,
        risk_indicator_col='vix',
        **env_kwargs,
    )

e_trade_gym = init_env(trade, config)

Stock Dimension: 29, State Space: 291


In [None]:
#@title get metrics
import wandb

def get_account_value_metrics(df_account_value: pd.DataFrame):
    """
    Compute summary metrics from an account-value (total asset value) time series.

    Args:
        df_account_value: DataFrame with exactly two columns: ``'date'`` and the
            account value of one algorithm (e.g. ``'a2c'``). Dates may be strings
            or datetimes; they are parsed with ``pd.to_datetime``.

    Returns:
        dict with ``sharpe_ratio``, ``mdd`` (%), ``ann_return`` (%) and
        ``cum_return`` (%). ``ann_return`` is NaN if the series spans < 1 day.
    """
    assert isinstance(df_account_value, pd.DataFrame)
    assert 'date' in df_account_value.columns
    assert len(df_account_value.columns) == 2

    account_values = df_account_value.dropna().set_index('date').iloc[:, 0]
    # Dates coming from the trading env may be strings (get_predictions converts
    # them too); the ``.days`` arithmetic below needs datetimes.
    account_values.index = pd.to_datetime(account_values.index)

    sharpe = calculate_sharpe_ratio(account_values)
    mdd = calculate_mdd(account_values)
    cum_ret = (account_values.iloc[-1] - account_values.iloc[0]) / account_values.iloc[0] * 100
    num_days = (account_values.index[-1] - account_values.index[0]).days
    # Annualizing needs a positive span; ``365 / 0`` would raise ZeroDivisionError.
    ann_ret = ((1 + cum_ret / 100) ** (365 / num_days) - 1) * 100 if num_days > 0 else np.nan

    return {
        'sharpe_ratio': sharpe,
        'mdd': mdd,
        'ann_return': ann_ret,
        'cum_return': cum_ret,
    }

def get_env_metrics(env):
    """
    Summarize a FinRL trading env's totals.

    The end total asset is ``state[0]`` plus the element-wise product of
    ``state[1:stock_dim+1]`` and ``state[stock_dim+1:2*stock_dim+1]``, presumably
    cash + sum(price * shares held) — confirm against the env's state layout.
    """
    n = env.stock_dim
    cash = env.state[0]
    prices = np.array(env.state[1:n + 1])
    holdings = np.array(env.state[n + 1:2 * n + 1])
    end_total_asset = cash + sum(prices * holdings)

    return {
        'begin_total_asset': env.asset_memory[0],
        'end_total_asset': end_total_asset,
        'total_cost': env.cost,
        'total_trades': env.trades,
    }

In [None]:
#@title log metrics

def log_metrics(metrics, model_name, split_label, best_model=False, step=None):
    """
    Log ``metrics`` to W&B nested under ``split_label``, tagging every key with a model suffix.

    With ``best_model=True`` keys are suffixed ``/best_model`` and ``model_name`` is
    logged separately as ``best_model_name``; otherwise keys are suffixed with
    ``model_name`` itself.

    NOTE(review): a later cell redefines this function with the same behavior;
    once both cells have run, the later definition is the one in effect.
    """
    assert model_name in ['a2c', 'ddpg', 'sac', 'ppo', 'td3']

    tag = 'best_model' if best_model else model_name
    tagged = {f"{key}/{tag}": value for key, value in metrics.items()}
    wandb.log({split_label: tagged}, step=step)
    if best_model:
        wandb.log({'best_model_name': model_name})

def update_best_model_metrics(metrics, model_name, split_label):
    """
    Compare ``model_name``'s Sharpe ratio against those recorded for ``split_label``.

    The metrics are logged as the best model when this is the first model recorded,
    or when its ratio is strictly higher than the current leader's. The ratio is then
    stored in ``wandb.run.config['sharpe_ratios'][split_label][model_name]``.

    NOTE(review): a later cell redefines this function; once both cells have run,
    the later definition is the one in effect.
    """
    print(f"DEBUG ({split_label}): {wandb.run.id}")

    run_config = wandb.run.config
    if 'sharpe_ratios' not in run_config:
        run_config['sharpe_ratios'] = {}
    if split_label not in run_config['sharpe_ratios']:
        run_config['sharpe_ratios'][split_label] = {}
    recorded = run_config['sharpe_ratios'][split_label]

    print(f"DEBUG ({split_label}): updating best model based on sharpe_ratios: {recorded}")
    score = metrics['sharpe_ratio']

    if not recorded:
        print(f"DEBUG ({split_label}): no models logged yet, new best model is current one: {model_name}")
        log_metrics(metrics, model_name, split_label, best_model=True)
    else:
        leader = max(recorded, key=recorded.get)
        leader_score = recorded[leader]
        if score > leader_score:
            print(
                f"DEBUG ({split_label}): {round(score, 2)} ({model_name})"
                f" > {round(leader_score, 2)} ({leader})"
                f". New best model: {model_name}."
            )
            log_metrics(metrics, model_name, split_label, best_model=True)
        else:
            print(
                f"DEBUG ({split_label}): {round(score, 2)} ({model_name})"
                f" <= {round(leader_score, 2)} ({leader})"
                ". Not updating best model."
            )

    recorded[model_name] = score

In [None]:
#@title result_df_to_metrics_dict

def result_df_to_metrics_dict(metrics_summary_df):
    """
    Convert the formatted metrics summary table into ``{model_name: {metric: float}}``.

    ``metrics_summary_df`` has a ``Method`` column plus metric columns (holding
    numbers or numeric strings). Method names are lower-cased to form the outer
    keys. Known columns are renamed to short keys (``cum_return``, ``ann_return``,
    ``mdd``, ``sharpe_ratio``); every non-``Method`` value is converted to float.
    """
    cols_to_rename = {
        'Cumulative Return (%)': 'cum_return',
        'Annualized Return (%)': 'ann_return',
        'MDD (%)': 'mdd',
        'Sharpe Ratio': 'sharpe_ratio',
    }

    metrics_all = {}
    for _, row in metrics_summary_df.iterrows():
        metrics_all[row['Method'].lower()] = {
            cols_to_rename.get(col, col): float(value)
            for col, value in row.items()
            if col != 'Method'
        }
    return metrics_all

# Example usage (run after `calculate_metrics` has produced `metrics_summary_df`;
# calling it here, before that cell, raises NameError on a fresh kernel):
# result_df_to_metrics_dict(metrics_summary_df)

{'a2c': {'cum_return': 2.27,
  'ann_return': 9.52,
  'mdd': -3.61,
  'sharpe_ratio': 0.78}}

In [None]:
#@title Custom DRLAgent (custom prediction)
from finrl.agents.stablebaselines3.models import DRLAgent, TensorboardCallback
from stable_baselines3.common.callbacks import CallbackList
import wandb

class DRLAgent(DRLAgent):
    """FinRL ``DRLAgent`` with a customised back-testing prediction loop."""

    @staticmethod
    def DRL_prediction(model, environment, deterministic=True):
        """
        Run ``model`` through ``environment`` and return the recorded memories.

        Args:
            model: Trained model exposing ``predict(obs, deterministic=...)``.
            environment: Trading env providing ``get_sb_env()`` and a ``df`` whose
                unique index values count the steps to run (presumably one per
                trading day — confirm).
            deterministic: Passed through to ``model.predict``.

        Returns:
            ``(account_memory, actions_memory)`` from the env's
            ``save_asset_memory`` / ``save_action_memory`` methods.

        Raises:
            RuntimeError: if the episode ended before the memories were captured.
        """
        test_env, test_obs = environment.get_sb_env()
        account_memory = None  # filled in on the second-to-last step
        actions_memory = None

        test_env.reset()
        n_days = len(environment.df.index.unique())
        max_steps = n_days - 1

        for i in range(n_days):
            action, _states = model.predict(test_obs, deterministic=deterministic)
            test_obs, rewards, dones, info = test_env.step(action)

            # Capture the memories one step before the end. NOTE(review): presumably
            # because the vectorized env resets itself on ``done`` — confirm.
            if i == max_steps - 1:
                account_memory = test_env.env_method(method_name="save_asset_memory")
                actions_memory = test_env.env_method(method_name="save_action_memory")

            if dones[0]:
                print("hit end!")
                break

        if account_memory is None or actions_memory is None:
            raise RuntimeError(
                "Episode ended before the asset/action memories were captured; "
                "check that the environment spans the expected number of days."
            )
        return account_memory[0], actions_memory[0]

In [None]:
#@title get predictions

def get_predictions(e_trade_gym, config):
    """
    Back-test every enabled trained agent on ``e_trade_gym``.

    For each algorithm whose ``if_using_<algo>`` flag in ``config`` is truthy, the
    saved model ``TRAINED_MODEL_DIR/agent_<algo>`` is loaded and run through
    ``DRLAgent.DRL_prediction``. All enabled models are loaded before any predictions run.

    Returns:
        DataFrame indexed by datetime with one account-value column per enabled
        algorithm (``A2C``, ``DDPG``, ``PPO``, ``TD3``, ``SAC``), outer-joined on date.

    Raises:
        ValueError: if no algorithm is enabled.
    """
    model_classes = {"A2C": A2C, "DDPG": DDPG, "PPO": PPO, "TD3": TD3, "SAC": SAC}
    enabled = {name: config[f"if_using_{name.lower()}"] for name in model_classes}

    if not any(enabled.values()):
        raise ValueError("At least one algorithm must be set to True for the script to run.")

    trained = {
        name: cls.load(os.path.join(TRAINED_MODEL_DIR, f"agent_{name.lower()}"))
        for name, cls in model_classes.items()
        if enabled[name]
    }

    result = pd.DataFrame()
    for name, model in trained.items():
        df_account_value, _df_actions = DRLAgent.DRL_prediction(
            model=model, environment=e_trade_gym
        )
        df_result = df_account_value.set_index('date')
        df_result.columns = [f"{name}_{col}" for col in df_result.columns]
        result = pd.merge(result, df_result, how="outer", left_index=True, right_index=True)

    # 'A2C_account_value' -> 'A2C', etc.; other columns keep their names.
    rename_dict = {f"{name}_account_value": name for name in model_classes}
    result = result.rename(columns=rename_dict, errors='ignore')
    result.index = pd.to_datetime(result.index)
    return result

# result = get_predictions(e_trade_gym, config)
# result.head()

In [None]:
#@title add DJIA for test period

def add_djia_test(trade, result, config):
    """
    Add a Dow Jones Industrial Average benchmark column (``'DJIA'``) to ``result``.

    ``^DJI`` closes are downloaded for the test period covered by ``trade``'s
    ``date`` column, rescaled to start at ``config['initial_amount']``, and
    inner-joined with ``result`` on date, so only dates present in both are kept.

    Args:
        trade: Test-period frame with a ``date`` column.
        result: Account values indexed by datetime (see ``get_predictions``).
        config: Run config providing ``initial_amount``.

    Returns:
        ``result`` restricted to the common dates, with an extra ``'DJIA'`` column.

    Raises:
        ValueError: if no ``^DJI`` data is returned for the period.
    """
    # min/max rather than first/last row so an unsorted frame still yields the full range.
    test_start_date = trade['date'].min()
    test_end_date = trade['date'].max()

    print(test_start_date)
    print(test_end_date)

    # Yahoo Finance treats ``end`` as exclusive; request one extra day so the last
    # test date is included (otherwise the inner join below silently drops it).
    download_end = (pd.to_datetime(test_end_date) + pd.Timedelta(days=1)).strftime('%Y-%m-%d')

    df_dji = YahooDownloader(
        start_date=test_start_date,
        end_date=download_end,
        ticker_list=['^DJI'],  # `dji` is delisted, `DJIA` is an ETF, not an index
    ).fetch_data()
    if df_dji.empty:
        raise ValueError(f"No ^DJI data returned for {test_start_date} .. {test_end_date}")
    df_dji['date'] = pd.to_datetime(df_dji['date'])

    # Rescale the index so it starts at the same capital as the agents.
    first_close = df_dji['close'].iloc[0]
    dji = pd.DataFrame(
        {'DJIA': df_dji['close'].div(first_close).mul(config['initial_amount']).values},
        index=df_dji['date'],
    )

    # Inner join keeps only dates present in both; bfill covers any remaining gaps.
    return pd.merge(result, dji, how='inner', left_index=True, right_index=True).bfill()

# result = add_djia_test(trade, result, config)
# result

In [None]:
#@title Calculate metrics for individual algorithms (w/annualized returns)

def calculate_metrics(result):
    """
    Compute per-algorithm performance metrics from account-value series.

    Args:
        result: DataFrame indexed by datetime with one account-value column per
            algorithm (any of ``A2C``, ``DDPG``, ``TD3``, ``SAC``, ``PPO``); other
            columns such as ``DJIA`` are ignored.

    Returns:
        metrics_df: numeric metrics, one row per algorithm (the Std columns are 0
            because each algorithm is a single run).
        metrics_summary_df: the same metrics rounded to 2 decimals and formatted as
            strings, with columns ``Method``, ``Cumulative Return (%)``,
            ``Annualized Return (%)``, ``MDD (%)`` and ``Sharpe Ratio``.
        experiment_stats: ``{method: {'mean': Series, 'std': zero Series}}`` for plotting.

    Rows with NaN metrics (e.g. a series spanning less than one day, where the
    annualized return is undefined) are dropped.
    """
    label_mapping = {
        'A2C': 'A2C',
        'DDPG': 'DDPG',
        'PPO': 'PPO',
        'TD3': 'TD3',
        'SAC': 'SAC',
        'DJIA': 'Dow Jones Index',
    }

    metrics_dict = {
        'Method': [],
        'Cumulative Return Mean (%)': [],
        'Annualized Return Mean (%)': [],
        'MDD Mean (%)': [],
        'MDD Std (%)': [],
        'Sharpe Ratio Mean': [],
        'Sharpe Ratio Std': [],
    }

    experiment_stats = {}
    individual_algos = ['A2C', 'DDPG', 'TD3', 'SAC', 'PPO']
    for algo in individual_algos:
        if algo not in result.columns:
            continue

        mapped_algo = label_mapping.get(algo, algo)
        if mapped_algo in experiment_stats:
            print(f"Info: '{algo}' is already included in experiment groups. Skipping individual plotting to avoid duplication.")
            continue

        account_values = result[algo].dropna()
        if account_values.empty:
            print(f"Warning: No valid asset values for individual algorithm '{algo}'. Skipping metrics calculation.")
            continue

        cum_ret = (account_values.iloc[-1] - account_values.iloc[0]) / account_values.iloc[0] * 100
        if np.isinf(cum_ret) or np.isnan(cum_ret):
            cum_ret = np.nan

        num_days = (account_values.index[-1] - account_values.index[0]).days
        # Annualizing needs a positive span; ``365 / 0`` used to raise ZeroDivisionError.
        ann_ret = ((1 + cum_ret / 100) ** (365 / num_days) - 1) * 100 if num_days > 0 else np.nan

        metrics_dict['Method'].append(mapped_algo)
        metrics_dict['Cumulative Return Mean (%)'].append(cum_ret)
        metrics_dict['Annualized Return Mean (%)'].append(ann_ret)
        metrics_dict['MDD Mean (%)'].append(calculate_mdd(account_values))
        metrics_dict['MDD Std (%)'].append(0.00)  # Single run, std is 0
        metrics_dict['Sharpe Ratio Mean'].append(calculate_sharpe_ratio(account_values))
        metrics_dict['Sharpe Ratio Std'].append(0.00)  # Single run, std is 0

        experiment_stats[mapped_algo] = {
            'mean': account_values,
            'std': pd.Series([0] * len(account_values), index=account_values.index),
        }

    # Drop any rows with NaN metrics to ensure clean tables.
    metrics_df = pd.DataFrame(metrics_dict).dropna(subset=[
        'Cumulative Return Mean (%)', 'Annualized Return Mean (%)', 'MDD Mean (%)', 'Sharpe Ratio Mean',
    ])

    def _fmt(column):
        # Round to 2 decimals and render as text for the summary table.
        return metrics_df[column].round(2).astype(str)

    metrics_summary_df = pd.DataFrame({
        'Method': metrics_df['Method'],
        'Cumulative Return (%)': _fmt('Cumulative Return Mean (%)'),
        'Annualized Return (%)': _fmt('Annualized Return Mean (%)'),
        'MDD (%)': _fmt('MDD Mean (%)'),
        'Sharpe Ratio': _fmt('Sharpe Ratio Mean'),
    })

    return metrics_df, metrics_summary_df, experiment_stats

# Example usage (requires `result` from `get_predictions` / `add_djia_test`, which is not
# defined at this point on a fresh kernel):
# metrics_df, metrics_summary_df, experiment_stats = calculate_metrics(result)

In [None]:
#@title Print the comparison table (w/annualized returns)

def print_comparison_table(metrics_summary_df, metrics_df):
    """
    Print the formatted metrics table, then a descending ranking per metric.

    ``metrics_summary_df`` is printed as-is; rankings use the numeric columns of
    ``metrics_df``. MDD values are negative, so sorting them in descending order
    puts the smallest drawdown first.
    """
    print("\n=== Metrics Comparison ===")
    print(metrics_summary_df.to_string(index=False))
    print("\n")

    # (heading, numeric column, unit suffix) for each ranking, sorted up front.
    ranking_specs = [
        ("\nCumulative Return (%):", 'Cumulative Return Mean (%)', '%'),
        ("\nAnnualized Return (%):", 'Annualized Return Mean (%)', '%'),
        ("\nMaximum Drawdown (MDD %) [Lower absolute values is Better]:", 'MDD Mean (%)', '%'),
        ("\nSharpe Ratio [Higher is Better]:", 'Sharpe Ratio Mean', ''),
    ]
    rankings = [
        (heading, column, suffix,
         metrics_df[['Method', column]].sort_values(by=column, ascending=False))
        for heading, column, suffix in ranking_specs
    ]

    print("=== Rankings ===")
    for heading, column, suffix, ordered in rankings:
        print(heading)
        for method, value in zip(ordered['Method'], ordered[column]):
            print(f"{method}: {value:.2f}{suffix}")

    print("\n")

# print_comparison_table(metrics_summary_df, metrics_df)

In [None]:
#@title log_test_metrics

def log_test_metrics(metrics_summary_df):
    """
    Log per-model and best-model test metrics to the active W&B run.

    ``metrics_summary_df`` is the summary table from ``calculate_metrics``, whose
    metric columns hold formatted strings. Values are converted to float before
    choosing the best model and before logging. This way the best model is the one
    with the numerically highest Sharpe ratio, and W&B receives numbers it can chart.
    Nothing is logged when the table is empty.
    """
    metric_columns = {
        'Cumulative Return (%)': 'cum_return',
        'Annualized Return (%)': 'ann_return',
        'MDD (%)': 'mdd',
        'Sharpe Ratio': 'sharpe_ratio',
    }

    if metrics_summary_df.empty:
        print("log_test_metrics: no metrics to log.")
        return

    # Comparing the formatted strings would be lexicographic (e.g. '9.2' > '10.5'),
    # so convert to numbers first.
    numeric = metrics_summary_df[list(metric_columns)].apply(pd.to_numeric)
    model_names = metrics_summary_df['Method'].str.lower()

    # Log metrics for each model
    for idx in metrics_summary_df.index:
        wandb.run.log({
            f'test/{key}/{model_names[idx]}': float(numeric.at[idx, col])
            for col, key in metric_columns.items()
        })

    # Log metrics for the best model (highest Sharpe ratio)
    best_idx = numeric['Sharpe Ratio'].idxmax()
    best_metrics = {
        f'test/{key}/best_model': float(numeric.at[best_idx, col])
        for col, key in metric_columns.items()
    }
    best_metrics['test/best_model_name'] = model_names[best_idx]
    wandb.run.log(best_metrics)

    print(wandb.run.config)

# Example usage:
# log_test_metrics(metrics_summary_df)


In [None]:
#@title log_metrics
def log_metrics(metrics, model_name, split_label, best_model=False, step=None):
    """
    Log ``metrics`` to W&B under ``split_label`` with every key suffixed by a model tag.

    The suffix is ``/best_model`` when ``best_model`` is True (and ``model_name`` is
    then also logged as ``best_model_name``); otherwise it is ``/<model_name>``.

    NOTE(review): this redefines a ``log_metrics`` from an earlier cell with the same
    behavior; this later definition is the one in effect once both cells have run.
    """
    assert model_name in ['a2c', 'ddpg', 'sac', 'ppo', 'td3']

    suffix = model_name
    if best_model:
        suffix = 'best_model'

    payload = {}
    for key, value in metrics.items():
        payload[f"{key}/{suffix}"] = value

    wandb.log({split_label: payload}, step=step)
    if best_model:
        wandb.log({'best_model_name': model_name})

def update_best_model_metrics(metrics, model_name, split_label):
    """
    Record ``model_name``'s Sharpe ratio for ``split_label`` and, if it beats every
    previously recorded model (or is the first), log its metrics as the best model.

    Ratios are kept in ``wandb.run.config['sharpe_ratios'][split_label]`` as
    ``{model_name: sharpe_ratio}``.

    Args:
        metrics: Dict containing at least ``'sharpe_ratio'``.
        model_name: Model tag accepted by ``log_metrics``.
        split_label: Data split name, e.g. ``'test'``.
    """
    if 'sharpe_ratios' not in wandb.run.config:
        wandb.run.config['sharpe_ratios'] = {}
    all_ratios = wandb.run.config['sharpe_ratios']
    sharpe_ratios = all_ratios.setdefault(split_label, {})

    print(f"DEBUG ({split_label}): updating best model based on sharpe_ratios: {sharpe_ratios}")
    if len(sharpe_ratios) > 0:
        best_model_name = max(sharpe_ratios, key=sharpe_ratios.get)
        if metrics['sharpe_ratio'] > sharpe_ratios[best_model_name]:
            print(
                f"DEBUG ({split_label}): {round(metrics['sharpe_ratio'], 2)} ({model_name})"
                f" > {round(sharpe_ratios[best_model_name], 2)} ({best_model_name})"
                f". New best model: {model_name}."
            )
            log_metrics(metrics, model_name, split_label, best_model=True)
        else:
            print(
                f"DEBUG ({split_label}): {round(metrics['sharpe_ratio'], 2)} ({model_name})"
                f" <= {round(sharpe_ratios[best_model_name], 2)} ({best_model_name})"
                ". Not updating best model."
            )
    else:
        print(f"DEBUG ({split_label}): no models logged yet, new best model is current one: {model_name}")
        log_metrics(metrics, model_name, split_label, best_model=True)

    # Record every model. Previously only the first model was stored, so later models
    # were always compared against the first one rather than the current best.
    sharpe_ratios[model_name] = metrics['sharpe_ratio']
    # Mutating the nested dict in place may not be synced by W&B; re-assign the whole
    # mapping (NOTE(review): confirm against the installed wandb Config semantics).
    wandb.run.config.update({'sharpe_ratios': all_ratios}, allow_val_change=True)

In [None]:
#@title plot results
%matplotlib inline

def _format_experiment_name(exp_name):
    """Map an experiment key to its plot label (single-token names are returned unchanged; 3-part names get None)."""
    parts = exp_name.split('_')
    if len(parts) == 1:
        return exp_name
    if len(parts) == 2:
        return parts[1].upper()
    if len(parts) == 3:
        return None
    if len(parts) == 4:
        return parts[1].upper() + ' LoRA ' + 'GPT2'
    if len(parts) == 6:
        return parts[1].upper() + ' LoRA ' + 'Random Weight ' + 'GPT2'
    return exp_name


def plot_results(result, experiment_stats):
    """
    Plot each agent's account value against the Dow Jones benchmark.

    Args:
        result: DataFrame indexed by date containing a ``'DJIA'`` column.
        experiment_stats: ``{name: {'mean': Series, 'std': Series}}`` (as built by
            ``calculate_metrics``); each mean curve is drawn with a ±1 std band.

    Returns:
        The created Matplotlib figure.
    """
    fig = plt.figure(figsize=(16, 9))  # Large figure for readability
    method_styles = {
        'A2C': {'color': '#8c564b', 'linestyle': '--'},   # Brown dashed
        'DDPG': {'color': '#e377c2', 'linestyle': '-'},   # Pink solid
        'PPO': {'color': '#7f7f7f', 'linestyle': '-'},    # Gray solid
        'TD3': {'color': '#bcbd22', 'linestyle': '--'},   # Olive dashed
        'SAC': {'color': '#17becf', 'linestyle': '-'},    # Cyan solid
        'DJIA': {'color': '#000000', 'linestyle': '-'},   # Black solid
    }
    plt.plot(result.index, result['DJIA'], label="Dow Jones Index",
             linestyle=method_styles['DJIA']['linestyle'], color=method_styles['DJIA']['color'])

    # Fallback styles for methods without an entry in ``method_styles``
    # (these used to raise KeyError).
    color_palette = plt.get_cmap('tab10').colors
    line_styles = ['-', '--', '-.', ':']

    for idx, (exp_name, stats) in enumerate(experiment_stats.items()):
        # Align each series with the plotted dates.
        mean = stats['mean'].reindex(result.index).ffill()
        std = stats['std'].reindex(result.index).fillna(0)

        label = _format_experiment_name(exp_name)
        style = method_styles.get(label, {
            'color': color_palette[idx % len(color_palette)],
            'linestyle': line_styles[idx % len(line_styles)],
        })

        plt.plot(result.index, mean, label=label, linestyle=style['linestyle'], color=style['color'])
        # Error band (mean ± 1 std)
        plt.fill_between(result.index, mean - std, mean + std, color=style['color'], alpha=0.2)

    # Date range for the annotation: span of all plotted experiments, falling back to
    # ``result`` when there are none (previously a NameError on empty input).
    if experiment_stats:
        min_date = min(s['mean'].index.min() for s in experiment_stats.values())
        max_date = max(s['mean'].index.max() for s in experiment_stats.values())
    else:
        min_date, max_date = result.index.min(), result.index.max()

    plt.grid(True, linestyle='--', alpha=0.3)
    plt.title("Performance Comparison of DRL agents", fontsize=20, fontweight='bold')
    plt.xlabel("Date", fontsize=16, fontweight='bold')
    plt.ylabel("Total Asset Value ($)", fontsize=16, fontweight='bold')
    plt.xticks(result.index[0::30])

    plt.text(0.5, 0.95, f'Test Phase: {min_date.date()} to {max_date.date()}',
             transform=plt.gca().transAxes, fontsize=14, ha='center',
             bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.5))

    plt.legend(loc='lower right')
    # Lay out last so the title, labels and legend are taken into account.
    plt.tight_layout()
    return fig

# fig = plot_results(result, experiment_stats)
# plt.show()

In [None]:
#@title log_plot_as_artifact

import os
import matplotlib.pyplot as plt
import wandb

def log_plot_as_artifact(fig, artifact_name_prefix, artifact_type="plot"):
    """
    Save a Matplotlib figure without clipping and log it as a W&B artifact.

    The figure is written to ``<artifact_name_prefix>-<run id>.png`` in the
    current working directory. It is added to a new artifact with the same
    name (without the extension), and that artifact is logged to the active
    run. The local PNG is removed afterwards, and the figure is closed in
    every case so it does not linger in memory.

    Parameters:
        fig (matplotlib.figure.Figure): The Matplotlib figure to save and log.
        artifact_name_prefix (str): Prefix of the artifact name; the active
            run's id is appended as ``<prefix>-<run id>``.
        artifact_type (str): The type of the artifact (default is "plot").

    Raises:
        AttributeError: when no W&B run is active (``wandb.run`` is None).
    """
    # Pre-bind so the cleanup below is safe even when we fail before the name
    # is built (e.g. no active run); otherwise an UnboundLocalError raised in
    # `finally` would mask the real error.
    filename = None
    try:
        # Get full artifact name
        artifact_name = f'{artifact_name_prefix}-{wandb.run.id}'
        filename = artifact_name + '.png'

        # Save the figure with tight layout and proper padding
        fig.savefig(filename, bbox_inches='tight', pad_inches=0.1, dpi=300)

        # Create and log the W&B artifact
        artifact = wandb.Artifact(artifact_name, type=artifact_type)
        artifact.add_file(filename)
        wandb.log_artifact(artifact)
    finally:
        # Close the figure even if saving/logging failed, to free up memory
        plt.close(fig)
        # Ensure the file is deleted after use
        if filename is not None and os.path.exists(filename):
            os.remove(filename)


In [None]:
#@title run_prediction_and_log_metrics

def run_prediction_and_log_metrics(run_id):
    """Re-evaluate an existing W&B run on its test split and log the results back to it.

    The function resumes run ``run_id``, which must already exist, and
    downloads its artifacts. It rebuilds the trading environment on the test
    data and collects the agents' predictions together with the DJIA
    benchmark. It then logs the test metrics, updates the per-model
    best-model metrics under the ``'test'`` split, and uploads a
    performance-comparison plot as an artifact.

    Args:
        run_id: W&B run id inside ``ENTITY/PROJECT``.
    """
    with wandb.init(project=PROJECT, entity=ENTITY, id=run_id, resume='must'):
        run_config = wandb.run.config

        download_artifacts(run_id)
        # Only the test split is needed here; train/val frames are ignored.
        _, _, trade_df = load_data(run_config)
        trade_df = fix_daily_index(trade_df)
        print(f"DEBUG trade end:{trade_df['date'].max()}")

        env = init_env(trade_df, run_config)
        predictions = add_djia_test(trade_df, get_predictions(env, run_config), run_config)

        _, summary_df, stats = calculate_metrics(predictions)
        log_test_metrics(summary_df)

        per_model = result_df_to_metrics_dict(summary_df)
        for name, model_metrics in per_model.items():
            update_best_model_metrics(model_metrics, name, 'test')

        figure = plot_results(predictions, stats)
        log_plot_as_artifact(
            figure,
            artifact_name_prefix="performance_comparison_DRL_agents",
            artifact_type="plot",
        )

# run_prediction_and_log_metrics(RUN_ID)

In [None]:
# Sweep whose runs are re-evaluated below (alternative sweep: 'l9jr39py').
SWEEP_ID = "k43io2uh"

In [None]:
# Re-evaluate every run of the sweep, logging test metrics and plots back to each run.
from tqdm.notebook import tqdm

api = wandb.Api()
sweep = api.sweep('/'.join([ENTITY, PROJECT, SWEEP_ID]))

for run in tqdm(sweep.runs):
    run_prediction_and_log_metrics(run.id)

True
True
True
True
True
True
True
True
True
True


  0%|          | 0/17 [00:00<?, ?it/s]

True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  


DEBUG trade end:2020-03-31
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2020-01-02
2020-03-31
Shape of DataFrame:  (61, 8)
{'dataset_name': 'DOW-30 | 2009-01 | 2019 Q4 | 2020 Q1', 'train_start_date': '2009-01-01', 'REFERENCE_PRICE_END_DATE': '2024-12-21', 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589463, 0.01093735188446354, 0.11420864277075855, 0.01665255165637118, 0.010252194206940846, 0.03947813656400475, 0.008462643895002689, 0.01919729692524921, 0.024719550845968045, 0.005734449144229167, 0.03210832438433615, 0.014379031318205482, 0.00982580254909786, 0.0044498520553265045, 0.007984135151425499, 0.05881990501271046, 0.270946766931955, 0.026832260273111815], 'dataset_type': 'quarterly_train_val_test', 'if_using_sac': True, 'if_using_td3': True, 'min_test_start_date': '2016-01-01', 'if_using_ddpg': True, 'REFERNCE_PRICE_W




0,1
best_model_name,ppo
test/ann_return/a2c,-68.06
test/ann_return/best_model,-68.06
test/ann_return/ddpg,-62.82
test/ann_return/ppo,-30.44
test/ann_return/sac,-75.24
test/ann_return/td3,-61.1
test/best_model_name,a2c
test/cum_return/a2c,-24.05
test/cum_return/best_model,-24.05


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  


DEBUG trade end:2019-12-31
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2019-10-01
2019-12-31
Shape of DataFrame:  (63, 8)
{'finetune': True, 'stock_index_name': 'DOW-30', 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589463, 0.01093735188446354, 0.11420864277075855, 0.01665255165637118, 0.010252194206940846, 0.03947813656400475, 0.008462643895002689, 0.01919729692524921, 0.024719550845968045, 0.005734449144229167, 0.03210832438433615, 0.014379031318205482, 0.00982580254909786, 0.0044498520553265045, 0.007984135151425499, 0.05881990501271046, 0.270946766931955, 0.026832260273111815], 'train_start_date': '2009-01-01', 'max_test_end_date': '2020-08-05', 'REFERNCE_PRICE_WINDOW_DAYS': 30, 'date_range': {'test_end_date': '2020-01-01 00:00:00', 'val_start_date': '2019-07-01 00:00:00', 'test_start_date': '2019-10-01 00:00:00', 'train_s




0,1
best_model_name,ppo
test/ann_return/a2c,23.69
test/ann_return/best_model,60.99
test/ann_return/ddpg,41.26
test/ann_return/ppo,60.99
test/ann_return/sac,-23.35
test/ann_return/td3,53.12
test/best_model_name,ppo
test/cum_return/a2c,5.38
test/cum_return/best_model,12.46


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  


DEBUG trade end:2019-09-30
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2019-07-01
2019-09-30
Shape of DataFrame:  (63, 8)
{'cost_abs': 2.5, 'REFERENCE_PRICE_END_DATE': '2024-12-21', 'if_using_a2c': True, 'train_params': {'ddpg': {'steps': 50000}, 'a2c': {'steps': 50000}, 'ppo': {'steps': 200000}, 'sac': {'steps': 70000}, 'td3': {'steps': 50000}}, 'if_using_ppo': True, 'train_start_date': '2009-01-01', 'min_test_start_date': '2016-01-01', 'REFERNCE_PRICE_WINDOW_DAYS': 30, 'max_test_end_date': '2020-08-05', 'dataset_type': 'quarterly_train_val_test', 'dataset_name': 'DOW-30 | 2009-01 | 2019 Q2 | 2019 Q3', 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589463, 0.01093735188446354, 0.11420864277075855, 0.01665255165637118, 0.010252194206940846, 0.03947813656400475, 0.008462643895002689, 0.01919729692524921, 0.024719550845968045, 0.0




0,1
best_model_name,sac
test/ann_return/a2c,10.63
test/ann_return/best_model,22.49
test/ann_return/ddpg,2.79
test/ann_return/ppo,-5.3
test/ann_return/sac,22.49
test/ann_return/td3,9.49
test/best_model_name,sac
test/cum_return/a2c,2.46
test/cum_return/best_model,5.01


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  


DEBUG trade end:2019-06-28
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2019-04-01
2019-06-28
Shape of DataFrame:  (62, 8)
{'max_test_end_date': '2020-08-05', 'date_range': {'test_end_date': '2019-07-01 00:00:00', 'val_start_date': '2019-01-01 00:00:00', 'test_start_date': '2019-04-01 00:00:00', 'train_start_date': '2009-01-01 00:00:00'}, 'stock_index_name': 'DOW-30', 'REFERNCE_PRICE_WINDOW_DAYS': 30, 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589463, 0.01093735188446354, 0.11420864277075855, 0.01665255165637118, 0.010252194206940846, 0.03947813656400475, 0.008462643895002689, 0.01919729692524921, 0.024719550845968045, 0.005734449144229167, 0.03210832438433615, 0.014379031318205482, 0.00982580254909786, 0.0044498520553265045, 0.007984135151425499, 0.05881990501271046, 0.270946766931955, 0.026832260273111815], 'initial_amount'




0,1
best_model_name,ddpg
test/ann_return/a2c,11.98
test/ann_return/best_model,17.96
test/ann_return/ddpg,17.96
test/ann_return/ppo,11.98
test/ann_return/sac,3.58
test/ann_return/td3,8.06
test/best_model_name,ddpg
test/cum_return/a2c,2.73
test/cum_return/best_model,4.01


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  


DEBUG trade end:2019-03-29
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2019-01-02
2019-03-29
Shape of DataFrame:  (60, 8)
{'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589463, 0.01093735188446354, 0.11420864277075855, 0.01665255165637118, 0.010252194206940846, 0.03947813656400475, 0.008462643895002689, 0.01919729692524921, 0.024719550845968045, 0.005734449144229167, 0.03210832438433615, 0.014379031318205482, 0.00982580254909786, 0.0044498520553265045, 0.007984135151425499, 0.05881990501271046, 0.270946766931955, 0.026832260273111815], 'if_using_sac': True, 'if_using_ddpg': True, 'train_start_date': '2009-01-01', 'if_using_td3': True, 'max_test_end_date': '2020-08-05', 'REFERNCE_PRICE_WINDOW_DAYS': 30, 'date_range': {'test_end_date': '2019-04-01 00:00:00', 'val_start_date': '2018-10-01 00:00:00', 'test_start_date': '2019-01-01 




0,1
best_model_name,ddpg
test/ann_return/a2c,27.92
test/ann_return/best_model,58.62
test/ann_return/ddpg,58.62
test/ann_return/ppo,38.34
test/ann_return/sac,41.45
test/ann_return/td3,27.63
test/best_model_name,ddpg
test/cum_return/a2c,5.9
test/cum_return/best_model,11.34


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  


DEBUG trade end:2018-12-31
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2018-10-01
2018-12-31
Shape of DataFrame:  (62, 8)
{'REFERNCE_PRICE_WINDOW_DAYS': 30, 'if_using_a2c': True, 'min_test_start_date': '2016-01-01', 'stock_index_name': 'DOW-30', 'if_using_ddpg': True, 'dataset_name': 'DOW-30 | 2009-01 | 2018 Q3 | 2018 Q4', 'train_start_date': '2009-01-01', 'REFERENCE_PRICE_END_DATE': '2024-12-21', 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589463, 0.01093735188446354, 0.11420864277075855, 0.01665255165637118, 0.010252194206940846, 0.03947813656400475, 0.008462643895002689, 0.01919729692524921, 0.024719550845968045, 0.005734449144229167, 0.03210832438433615, 0.014379031318205482, 0.00982580254909786, 0.0044498520553265045, 0.007984135151425499, 0.05881990501271046, 0.270946766931955, 0.026832260273111815], 'if_using_td3': Tru




0,1
best_model_name,ppo
test/ann_return/a2c,-50.36
test/ann_return/best_model,-67.15
test/ann_return/ddpg,-44.06
test/ann_return/ppo,-33.73
test/ann_return/sac,-67.15
test/ann_return/td3,-37.03
test/best_model_name,sac
test/cum_return/a2c,-15.54
test/cum_return/best_model,-23.54


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  


DEBUG trade end:2018-09-28
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2018-07-02
2018-09-28
Shape of DataFrame:  (62, 8)
{'REFERENCE_PRICE_END_DATE': '2024-12-21', 'REFERNCE_PRICE_WINDOW_DAYS': 30, 'max_test_end_date': '2020-08-05', 'min_test_start_date': '2016-01-01', 'train_start_date': '2009-01-01', 'if_using_sac': True, 'if_using_ppo': True, 'train_params': {'a2c': {'steps': 50000}, 'ppo': {'steps': 200000}, 'sac': {'steps': 70000}, 'td3': {'steps': 50000}, 'ddpg': {'steps': 50000}}, 'if_using_ddpg': True, 'finetune': True, 'date_range': {'test_end_date': '2018-10-01 00:00:00', 'val_start_date': '2018-04-01 00:00:00', 'test_start_date': '2018-07-01 00:00:00', 'train_start_date': '2009-01-01 00:00:00'}, 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589463, 0.01093735188446354, 0.11420864277075855, 0.01665255165637118, 0.010




0,1
best_model_name,ppo
test/ann_return/a2c,43.77
test/ann_return/best_model,55.87
test/ann_return/ddpg,33.81
test/ann_return/ppo,55.87
test/ann_return/sac,36.07
test/ann_return/td3,47.46
test/best_model_name,ppo
test/cum_return/a2c,9.04
test/cum_return/best_model,11.16


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  


DEBUG trade end:2018-06-29
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2018-04-02
2018-06-29
Shape of DataFrame:  (63, 8)
{'REFERNCE_PRICE_WINDOW_DAYS': 30, 'initial_amount': 50000, 'REFERENCE_PRICE_END_DATE': '2024-12-21', 'cost_abs': 2.5, 'max_test_end_date': '2020-08-05', 'finetune': True, 'date_range': {'test_end_date': '2018-07-01 00:00:00', 'val_start_date': '2018-01-01 00:00:00', 'test_start_date': '2018-04-01 00:00:00', 'train_start_date': '2009-01-01 00:00:00'}, 'if_using_ddpg': True, 'stock_index_name': 'DOW-30', 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589463, 0.01093735188446354, 0.11420864277075855, 0.01665255165637118, 0.010252194206940846, 0.03947813656400475, 0.008462643895002689, 0.01919729692524921, 0.024719550845968045, 0.005734449144229167, 0.03210832438433615, 0.014379031318205482, 0.00982580254909786,




0,1
best_model_name,ppo
test/ann_return/a2c,8.97
test/ann_return/best_model,13.71
test/ann_return/ddpg,-15.69
test/ann_return/ppo,13.71
test/ann_return/sac,-25.97
test/ann_return/td3,12.51
test/best_model_name,ppo
test/cum_return/a2c,2.07
test/cum_return/best_model,3.11


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   2 of 2 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   4 of 4 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  


DEBUG trade end:2018-03-29
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2018-01-02
2018-03-29
Shape of DataFrame:  (60, 8)
{'finetune': False, 'date_range': {'test_end_date': '2018-04-01 00:00:00', 'val_start_date': '2017-10-01 00:00:00', 'test_start_date': '2018-01-01 00:00:00', 'train_start_date': '2009-01-01 00:00:00'}, 'min_test_start_date': '2016-01-01', 'dataset_type': 'quarterly_train_val_test', 'if_using_a2c': True, 'max_test_end_date': '2020-08-05', 'if_using_ppo': True, 'REFERNCE_PRICE_WINDOW_DAYS': 30, 'REFERENCE_PRICE_END_DATE': '2024-12-21', 'dataset_name': 'DOW-30 | 2009-01 | 2017 Q4 | 2018 Q1', 'train_params': {'a2c': {'steps': 50000}, 'ppo': {'steps': 200000}, 'sac': {'steps': 70000}, 'td3': {'steps': 50000}, 'ddpg': {'steps': 50000}}, 'if_using_sac': True, 'if_using_ddpg': True, 'cost_abs': 2.5, 'stock_index_name': 'DOW-30', 'initial_amount': 50000, 'train_start_date': '2009-01-01', 'if_using_td3': True, 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.007




0,1
best_model_name,td3
test/ann_return/a2c,-10.87
test/ann_return/best_model,-23.72
test/ann_return/ddpg,-23.72
test/ann_return/ppo,-22.05
test/ann_return/sac,-15.79
test/ann_return/td3,-4.05
test/best_model_name,ddpg
test/cum_return/a2c,-2.64
test/cum_return/best_model,-6.11


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   2 of 2 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   4 of 4 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  


DEBUG trade end:2017-12-29
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!
hit end!
2017-10-02
2017-12-29


[*********************100%***********************]  1 of 1 completed


Shape of DataFrame:  (62, 8)
{'train_params': {'td3': {'steps': 50000}, 'ddpg': {'steps': 50000}, 'a2c': {'steps': 50000}, 'ppo': {'steps': 200000}, 'sac': {'steps': 70000}}, 'if_using_sac': True, 'finetune': False, 'date_range': {'test_end_date': '2018-01-01 00:00:00', 'val_start_date': '2017-07-01 00:00:00', 'test_start_date': '2017-10-01 00:00:00', 'train_start_date': '2009-01-01 00:00:00'}, 'max_test_end_date': '2020-08-05', 'cost_abs': 2.5, 'if_using_ppo': True, 'if_using_ddpg': True, 'dataset_type': 'quarterly_train_val_test', 'REFERENCE_PRICE_END_DATE': '2024-12-21', 'REFERNCE_PRICE_WINDOW_DAYS': 30, 'train_start_date': '2009-01-01', 'min_test_start_date': '2016-01-01', 'if_using_td3': True, 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589463, 0.0109373518844

0,1
best_model_name,sac
test/ann_return/a2c,27.08
test/ann_return/best_model,39.31
test/ann_return/ddpg,23.05
test/ann_return/ppo,14.03
test/ann_return/sac,39.31
test/ann_return/td3,33.54
test/best_model_name,sac
test/cum_return/a2c,5.88
test/cum_return/best_model,8.22


True
True
True
True
True
True
True


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  


DEBUG trade end:2017-09-29
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2017-07-03
2017-09-29
Shape of DataFrame:  (62, 8)
{'max_test_end_date': '2020-08-05', 'if_using_a2c': True, 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589463, 0.01093735188446354, 0.11420864277075855, 0.01665255165637118, 0.010252194206940846, 0.03947813656400475, 0.008462643895002689, 0.01919729692524921, 0.024719550845968045, 0.005734449144229167, 0.03210832438433615, 0.014379031318205482, 0.00982580254909786, 0.0044498520553265045, 0.007984135151425499, 0.05881990501271046, 0.270946766931955, 0.026832260273111815], 'if_using_ddpg': True, 'if_using_ppo': True, 'dataset_type': 'quarterly_train_val_test', 'finetune': True, 'REFERENCE_PRICE_END_DATE': '2024-12-21', 'cost_abs': 2.5, 'if_using_sac': True, 'dataset_name': 'DOW-30 | 2009-01 | 2017 Q2 | 2017 Q




0,1
best_model_name,ddpg
test/ann_return/a2c,15.42
test/ann_return/best_model,71.61
test/ann_return/ddpg,71.61
test/ann_return/ppo,16.7
test/ann_return/sac,12.02
test/ann_return/td3,-14.36
test/best_model_name,ddpg
test/cum_return/a2c,3.48
test/cum_return/best_model,13.74


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  


DEBUG trade end:2017-06-30
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2017-04-03
2017-06-30
Shape of DataFrame:  (62, 8)
{'cost_abs': 2.5, 'dataset_type': 'quarterly_train_val_test', 'if_using_ppo': True, 'dataset_name': 'DOW-30 | 2009-01 | 2017 Q1 | 2017 Q2', 'initial_amount': 50000, 'REFERNCE_PRICE_WINDOW_DAYS': 30, 'if_using_td3': True, 'if_using_ddpg': True, 'min_test_start_date': '2016-01-01', 'date_range': {'test_start_date': '2017-04-01 00:00:00', 'train_start_date': '2009-01-01 00:00:00', 'test_end_date': '2017-07-01 00:00:00', 'val_start_date': '2017-01-01 00:00:00'}, 'stock_index_name': 'DOW-30', 'train_start_date': '2009-01-01', 'if_using_sac': True, 'REFERENCE_PRICE_END_DATE': '2024-12-21', 'finetune': True, 'train_params': {'ddpg': {'steps': 50000}, 'a2c': {'steps': 50000}, 'ppo': {'steps': 200000}, 'sac': {'steps': 70000}, 'td3': {'steps': 50000}}, 'if_using_a2c': True, 'max_test_end_date': '2020-08-05', 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072




0,1
best_model_name,ppo
test/ann_return/a2c,31.24
test/ann_return/best_model,36.91
test/ann_return/ddpg,25.86
test/ann_return/ppo,36.91
test/ann_return/sac,4.88
test/ann_return/td3,-30.21
test/best_model_name,ppo
test/cum_return/a2c,6.7
test/cum_return/best_model,7.78


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  


DEBUG trade end:2017-03-31
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2017-01-03
2017-03-31
Shape of DataFrame:  (61, 8)
{'min_test_start_date': '2016-01-01', 'REFERNCE_PRICE_WINDOW_DAYS': 30, 'cost_abs': 2.5, 'train_params': {'a2c': {'steps': 50000}, 'ppo': {'steps': 200000}, 'sac': {'steps': 70000}, 'td3': {'steps': 50000}, 'ddpg': {'steps': 50000}}, 'if_using_ppo': True, 'stock_index_name': 'DOW-30', 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589463, 0.01093735188446354, 0.11420864277075855, 0.01665255165637118, 0.010252194206940846, 0.03947813656400475, 0.008462643895002689, 0.01919729692524921, 0.024719550845968045, 0.005734449144229167, 0.03210832438433615, 0.014379031318205482, 0.00982580254909786, 0.0044498520553265045, 0.007984135151425499, 0.05881990501271046, 0.270946766931955, 0.026832260273111815], 'if_using_dd




0,1
best_model_name,ppo
test/ann_return/a2c,14.29
test/ann_return/best_model,38.41
test/ann_return/ddpg,30.33
test/ann_return/ppo,38.41
test/ann_return/sac,3.37
test/ann_return/td3,-2.16
test/best_model_name,ppo
test/cum_return/a2c,3.2
test/cum_return/best_model,7.96


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  


DEBUG trade end:2016-12-30
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2016-10-03
2016-12-30
Shape of DataFrame:  (62, 8)
{'REFERNCE_PRICE_WINDOW_DAYS': 30, 'finetune': True, 'if_using_ddpg': True, 'cost_abs': 2.5, 'min_test_start_date': '2016-01-01', 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589463, 0.01093735188446354, 0.11420864277075855, 0.01665255165637118, 0.010252194206940846, 0.03947813656400475, 0.008462643895002689, 0.01919729692524921, 0.024719550845968045, 0.005734449144229167, 0.03210832438433615, 0.014379031318205482, 0.00982580254909786, 0.0044498520553265045, 0.007984135151425499, 0.05881990501271046, 0.270946766931955, 0.026832260273111815], 'dataset_name': 'DOW-30 | 2009-01 | 2016 Q3 | 2016 Q4', 'if_using_ppo': True, 'train_start_date': '2009-01-01', 'if_using_a2c': True, 'initial_amount': 50000, 'max_test




0,1
best_model_name,sac
test/ann_return/a2c,5.86
test/ann_return/best_model,57.72
test/ann_return/ddpg,28.89
test/ann_return/ppo,31.57
test/ann_return/sac,57.72
test/ann_return/td3,24.12
test/best_model_name,sac
test/cum_return/a2c,1.37
test/cum_return/best_model,11.47


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  


DEBUG trade end:2016-09-30
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2016-07-01
2016-09-30
Shape of DataFrame:  (63, 8)
{'dataset_name': 'DOW-30 | 2009-01 | 2016 Q2 | 2016 Q3', 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589463, 0.01093735188446354, 0.11420864277075855, 0.01665255165637118, 0.010252194206940846, 0.03947813656400475, 0.008462643895002689, 0.01919729692524921, 0.024719550845968045, 0.005734449144229167, 0.03210832438433615, 0.014379031318205482, 0.00982580254909786, 0.0044498520553265045, 0.007984135151425499, 0.05881990501271046, 0.270946766931955, 0.026832260273111815], 'stock_index_name': 'DOW-30', 'if_using_td3': True, 'REFERENCE_PRICE_END_DATE': '2024-12-21', 'cost_abs': 2.5, 'finetune': True, 'date_range': {'val_start_date': '2016-04-01 00:00:00', 'test_start_date': '2016-07-01 00:00:00', 'train_start_d




0,1
best_model_name,sac
test/ann_return/a2c,4.11
test/ann_return/best_model,4.84
test/ann_return/ddpg,-3.25
test/ann_return/ppo,-12.25
test/ann_return/sac,4.84
test/ann_return/td3,-23.11
test/best_model_name,sac
test/cum_return/a2c,1.0
test/cum_return/best_model,1.17


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  


DEBUG trade end:2016-06-30
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2016-04-01
2016-06-30
Shape of DataFrame:  (63, 8)
{'if_using_ddpg': True, 'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589463, 0.01093735188446354, 0.11420864277075855, 0.01665255165637118, 0.010252194206940846, 0.03947813656400475, 0.008462643895002689, 0.01919729692524921, 0.024719550845968045, 0.005734449144229167, 0.03210832438433615, 0.014379031318205482, 0.00982580254909786, 0.0044498520553265045, 0.007984135151425499, 0.05881990501271046, 0.270946766931955, 0.026832260273111815], 'initial_amount': 50000, 'dataset_name': 'DOW-30 | 2009-01 | 2016 Q1 | 2016 Q2', 'dataset_type': 'quarterly_train_val_test', 'max_test_end_date': '2020-08-05', 'min_test_start_date': '2016-01-01', 'if_using_a2c': True, 'cost_abs': 2.5, 'finetune': True, 'date_range': {'test




0,1
best_model_name,ppo
test/ann_return/a2c,-7.25
test/ann_return/best_model,-13.69
test/ann_return/ddpg,-2.82
test/ann_return/ppo,-1.4
test/ann_return/sac,-13.69
test/ann_return/td3,-12.49
test/best_model_name,sac
test/cum_return/a2c,-1.82
test/cum_return/best_model,-3.53


True


[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  
[34m[1mwandb[0m:   2 of 2 files downloaded.  
[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m:   4 of 4 files downloaded.  
[34m[1mwandb[0m:   5 of 5 files downloaded.  
[34m[1mwandb[0m:   1 of 1 files downloaded.  


DEBUG trade end:2016-03-31
Stock Dimension: 29, State Space: 291
hit end!
hit end!
hit end!
hit end!


[*********************100%***********************]  1 of 1 completed

hit end!
2016-01-04
2016-03-31
Shape of DataFrame:  (60, 8)
{'cost_pct': [0.010314179718673032, 0.009076606561755468, 0.008307067100246136, 0.015516293581809112, 0.006395481989312903, 0.0072110283262062785, 0.04249566955190291, 0.015979716445954528, 0.021738725278835313, 0.004222453318177087, 0.005968143721526647, 0.010896151366589463, 0.01093735188446354, 0.11420864277075855, 0.01665255165637118, 0.010252194206940846, 0.03947813656400475, 0.008462643895002689, 0.01919729692524921, 0.024719550845968045, 0.005734449144229167, 0.03210832438433615, 0.014379031318205482, 0.00982580254909786, 0.0044498520553265045, 0.007984135151425499, 0.05881990501271046, 0.270946766931955, 0.026832260273111815], 'dataset_type': 'quarterly_train_val_test', 'initial_amount': 50000, 'train_params': {'a2c': {'steps': 50000}, 'ppo': {'steps': 200000}, 'sac': {'steps': 70000}, 'td3': {'steps': 50000}, 'ddpg': {'steps': 50000}}, 'if_using_ddpg': True, 'dataset_name': 'DOW-30 | 2009-01 | 2015 Q4 | 2016 Q1', 'if_




0,1
best_model_name,a2c
test/ann_return/a2c,20.46
test/ann_return/best_model,20.46
test/ann_return/ddpg,-16.13
test/ann_return/ppo,-0.68
test/ann_return/sac,3.3
test/ann_return/td3,-7.96
test/best_model_name,a2c
test/cum_return/a2c,4.48
test/cum_return/best_model,4.48


# Fix metrics

## Compute cumulative metrics

In [None]:
# The cumulative-metrics experiments below use this sweep (overrides the SWEEP_ID set earlier).
SWEEP_ID = 'l9jr39py'

api = wandb.Api()
sweep = api.sweep('/'.join((ENTITY, PROJECT, SWEEP_ID)))

In [None]:
# Scratch: fetch the sweep's first run as a full public-API Run object for inspection.
# `run` is left unchanged when the sweep has no runs.
_first_sweep_run = next(iter(sweep.runs), None)
if _first_sweep_run is not None:
    run = api.run(f'{ENTITY}/{PROJECT}/{_first_sweep_run.id}')



True


In [None]:
# TODO: planned cumulative-metrics loop (pseudocode only; kept as comments because
# `for each ...` is not valid Python and made this cell raise SyntaxError):
#   for each run in sweep.runs:
#       for each model in models:
#           update_cumulative_metrics()

In [None]:
#@title run_prediction_and_log_metrics (cumulative)

class CumulativeMetricsLogger():
    """Re-evaluate sweep runs one at a time, logging per-run and cumulative test metrics.

    Each call to ``run_prediction_and_log_metrics`` does the following:
      1. Resumes one W&B run and downloads its artifacts.
      2. Runs predictions on that run's test data and logs the metrics of those
         predictions.
      3. Appends the predictions to ``self.cum_result`` and logs the metrics computed
         over everything accumulated so far.

    Calling it twice for the same run appends that run's predictions twice.

    NOTE(review): the per-run and the cumulative metrics go through the same
    ``log_test_metrics(...)`` / ``update_best_model_metrics(..., 'test')`` calls. They
    presumably land under the same ``test/...`` keys, so the cumulative values would
    overwrite the per-run ones in the run summary. Confirm against those helpers and
    give the cumulative metrics their own prefix.
    """

    def __init__(self, sweep_id):
        # TODO: assert every run in sweep has same models enabled?
        self.sweep_id = sweep_id  # recorded for reference; not used by the logging below
        self.cum_result = pd.DataFrame()  # start with empty results; one run appended per call

    def run_prediction_and_log_metrics(self, run_id):
        """Evaluate ``run_id`` on its test split and log per-run and cumulative metrics.

        Args:
            run_id (str): ID of a W&B run in ``ENTITY/PROJECT``. The run is resumed with
                ``resume='must'``, so it has to exist already.
        """
        with wandb.init(project=PROJECT, entity=ENTITY, id=run_id, resume='must'):
            config = wandb.run.config

            download_artifacts(run_id)
            _train, _val, trade = load_data(config)
            trade = fix_daily_index(trade)
            print(f"DEBUG trade end:{trade['date'].max()}")

            e_trade_gym = init_env(trade, config)
            result = get_predictions(e_trade_gym, config)
            result = add_djia_test(trade, result, config)

            # Metrics of this run's predictions only.
            _, run_summary_df, run_stats = calculate_metrics(result)
            log_test_metrics(run_summary_df)
            for model_name, metrics in result_df_to_metrics_dict(run_summary_df).items():
                update_best_model_metrics(metrics, model_name, 'test')

            # Metrics over the predictions of every run processed so far, this one included.
            self.cum_result = pd.concat([self.cum_result, result])
            _, cum_summary_df, _cum_stats = calculate_metrics(self.cum_result)
            log_test_metrics(cum_summary_df)
            for model_name, metrics in result_df_to_metrics_dict(cum_summary_df).items():
                update_best_model_metrics(metrics, model_name, 'test')

            # Plot this run's predictions together with the stats computed from the same
            # predictions, not the stats of the cumulative frame.
            fig = plot_results(result, run_stats)
            log_plot_as_artifact(fig, artifact_name_prefix="performance_comparison_DRL_agents", artifact_type="plot")

In [None]:
from tqdm.auto import tqdm

# Re-fetch the sweep so this cell does not depend on the earlier inspection cells.
api = wandb.Api()
sweep = api.sweep(f'{ENTITY}/{PROJECT}/{SWEEP_ID}')

# One logger for the whole sweep, so that `cum_result` accumulates across runs.
cum_logger = CumulativeMetricsLogger(SWEEP_ID)
for run in tqdm(sweep.runs, desc="runs"):
    cum_logger.run_prediction_and_log_metrics(run.id)

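## Un-annualize Sharpe Ratio

Convert an annualized Sharpe ratio back to the length of the evaluated period: $S_{\text{period}} = S_{\text{ann}} \sqrt{N_{\text{period}} / N_{\text{year}}}$. Count $N_{\text{period}}$ and $N_{\text{year}}$ in the same kind of days (trading days with $N_{\text{year}} = 252$, or calendar days with $N_{\text{year}} = 365$).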

In [None]:
#@title unannualize_sharpe_ratio

import numpy as np

def unannualize_sharpe_ratio(annual_sharpe_ratio, days_in_period, days_in_year=252):
    """Scale an annualized Sharpe ratio down to a period of ``days_in_period`` days.

    Annualizing multiplies a per-day Sharpe ratio by ``sqrt(days_in_year)``. The
    Sharpe ratio over the period is therefore
    ``annual_sharpe_ratio * sqrt(days_in_period / days_in_year)``. Both day counts must
    use the same convention, e.g. trading days with the default 252, or calendar days
    with 365.

    Args:
        annual_sharpe_ratio: annualized Sharpe ratio (scalar or numpy array).
        days_in_period: length of the period in days, > 0 (scalar or numpy array).
        days_in_year: days per year in the same unit as ``days_in_period``, > 0.

    Returns:
        The Sharpe ratio expressed over the period.

    Raises:
        ValueError: if ``days_in_period`` or ``days_in_year`` is not positive.
    """
    if np.any(np.asarray(days_in_period) <= 0) or np.any(np.asarray(days_in_year) <= 0):
        raise ValueError(
            f"days_in_period and days_in_year must be positive, "
            f"got {days_in_period!r} and {days_in_year!r}"
        )
    return annual_sharpe_ratio * np.sqrt(days_in_period / days_in_year)

# Example usage
annual_sharpe = 2.5
days_in_period = 63  # e.g., for a quarter
un_annualized_sharpe = unannualize_sharpe_ratio(annual_sharpe, days_in_period)
print(f"Un-annualized Sharpe Ratio: {un_annualized_sharpe}")

Un-annualized Sharpe Ratio: 1.25


In [None]:
#@title fix_sharpe_ratios

def log_nested_metrics(run, key, new_value):
    """Merge ``new_value`` into the nested dict stored under ``key`` in the summary, then log it.

    The current value of ``run.summary[key]`` is read first and ``new_value`` is merged
    on top of it (new entries win). The existing entries are therefore re-logged
    together with the new ones instead of being dropped.

    Args:
        run: run whose ``summary`` holds the current value of ``key``.
        key: top-level summary key, e.g. a split label such as 'test'.
        new_value: dict of entries to add or overwrite inside ``key``.

    Raises:
        TypeError: if the existing summary value under ``key`` is not dict-like.
    """
    # Fetch existing metrics for the split; a missing or None entry counts as empty.
    old_value = run.summary.get(key)
    if old_value is None:
        old_value = {}
    elif not hasattr(old_value, 'keys'):
        raise TypeError(
            f"Summary value under {key!r} is {type(old_value).__name__}, expected a dict of metrics"
        )

    # Merge with new metrics
    updated_metrics = {**old_value, **new_value}

    # NOTE: logs through wandb.log, i.e. to the currently active run, which callers
    # are expected to pass in as `run`.
    wandb.log({key: updated_metrics})

def fix_sharpe_ratios(run, days_in_year=365):
    """Log un-annualized copies of every split's Sharpe ratios into the run summary.

    This looks at the 'train', 'val' and 'test' splits that are present in
    ``run.summary``. In each split's nested summary dict, every
    ``sharpe_ratio/<model>`` entry is converted with ``unannualize_sharpe_ratio``. The
    result is written back into the same nested dict as ``unann_sharpe_ratio/<model>``,
    with one ``log_nested_metrics`` call per split.

    Args:
        run: the resumed W&B run (as returned by ``wandb.init``). ``run.config['date_range']``
            must provide ``train_start_date``, ``val_start_date``, ``test_start_date``
            and ``test_end_date``.
        days_in_year: days per year, in the same unit as the split lengths. Split lengths
            are calendar-day differences between boundary dates, so this defaults to 365.
            Pairing calendar-day lengths with 252 trading days per year would overstate
            the un-annualized values by about sqrt(365 / 252) ≈ 1.2.
    """
    date_range = run.config['date_range']

    # Each split's length is the number of days from its start date to the next boundary date.
    split_bounds = {
        'train': ('train_start_date', 'val_start_date'),
        'val': ('val_start_date', 'test_start_date'),
        'test': ('test_start_date', 'test_end_date'),
    }
    days_per_split = {
        split_label: (pd.Timestamp(date_range[end_key]) - pd.Timestamp(date_range[start_key])).days
        for split_label, (start_key, end_key) in split_bounds.items()
    }

    for split_label, days_in_split in days_per_split.items():
        if split_label not in run.summary.keys():
            print(f"Skipping '{split_label}' split")
            continue

        split_summary = run.summary._as_dict()[split_label]
        unann_sharpe_ratios = {
            f"unann_{ann_key}": unannualize_sharpe_ratio(
                sharpe_ratio, days_in_split, days_in_year=days_in_year
            )
            for ann_key, sharpe_ratio in split_summary.items()
            if ann_key.startswith('sharpe_ratio/')
        }

        # Log all of a split's un-annualized values together rather than one call per metric.
        if unann_sharpe_ratios:
            log_nested_metrics(run, split_label, unann_sharpe_ratios)

In [None]:
#@title fix_sharpe_ratios_all (per RUN)

# RUN_ID = 'zyokpqqt'

# # if wandb.run is not None and RUN_ID != wandb.run.id:
#     # wandb.finish()

# wandb.finish()

# run = wandb.init(id=RUN_ID, project=PROJECT, entity=ENTITY, resume="must")

# fix_sharpe_ratios(run)

In [None]:
#@title fix_sharpe_ratios_all (per SWEEP)

def fix_sharpe_ratios_all(sweep_id):
    """Resume every run of ``sweep_id`` and log its un-annualized Sharpe ratios."""
    sweep_path = f"{ENTITY}/{PROJECT}/{sweep_id}"
    sweep = wandb.Api().sweep(sweep_path)

    for sweep_run in sweep.runs:
        # Resume the run so fix_sharpe_ratios can read its summary and log into it.
        with wandb.init(id=sweep_run.id, project=PROJECT, entity=ENTITY, resume="must") as active_run:
            fix_sharpe_ratios(active_run)

SWEEP_ID = 'k43io2uh'
fix_sharpe_ratios_all(SWEEP_ID)

True
True
True
True
True
True
True
True
True
True


Skipping 'dummy' split


0,1
best_model_name,ppo
test/ann_return/a2c,-68.06
test/ann_return/best_model,-68.06
test/ann_return/ddpg,-62.82
test/ann_return/ppo,-30.44
test/ann_return/sac,-75.24
test/ann_return/td3,-61.1
test/best_model_name,a2c
test/cum_return/a2c,-24.05
test/cum_return/best_model,-24.05


Skipping 'dummy' split


0,1
best_model_name,ppo
test/ann_return/a2c,23.69
test/ann_return/best_model,60.99
test/ann_return/ddpg,41.26
test/ann_return/ppo,60.99
test/ann_return/sac,-23.35
test/ann_return/td3,53.12
test/best_model_name,ppo
test/cum_return/a2c,5.38
test/cum_return/best_model,12.46


Skipping 'dummy' split


0,1
best_model_name,sac
test/ann_return/a2c,10.63
test/ann_return/best_model,22.49
test/ann_return/ddpg,2.79
test/ann_return/ppo,-5.3
test/ann_return/sac,22.49
test/ann_return/td3,9.49
test/best_model_name,sac
test/cum_return/a2c,2.46
test/cum_return/best_model,5.01


Skipping 'dummy' split


0,1
best_model_name,ddpg
test/ann_return/a2c,11.98
test/ann_return/best_model,17.96
test/ann_return/ddpg,17.96
test/ann_return/ppo,11.98
test/ann_return/sac,3.58
test/ann_return/td3,8.06
test/best_model_name,ddpg
test/cum_return/a2c,2.73
test/cum_return/best_model,4.01


Skipping 'dummy' split


0,1
best_model_name,ddpg
test/ann_return/a2c,27.92
test/ann_return/best_model,58.62
test/ann_return/ddpg,58.62
test/ann_return/ppo,38.34
test/ann_return/sac,41.45
test/ann_return/td3,27.63
test/best_model_name,ddpg
test/cum_return/a2c,5.9
test/cum_return/best_model,11.34


Skipping 'dummy' split


0,1
best_model_name,ppo
test/ann_return/a2c,-50.36
test/ann_return/best_model,-67.15
test/ann_return/ddpg,-44.06
test/ann_return/ppo,-33.73
test/ann_return/sac,-67.15
test/ann_return/td3,-37.03
test/best_model_name,sac
test/cum_return/a2c,-15.54
test/cum_return/best_model,-23.54


Skipping 'dummy' split


0,1
best_model_name,ppo
test/ann_return/a2c,43.77
test/ann_return/best_model,55.87
test/ann_return/ddpg,33.81
test/ann_return/ppo,55.87
test/ann_return/sac,36.07
test/ann_return/td3,47.46
test/best_model_name,ppo
test/cum_return/a2c,9.04
test/cum_return/best_model,11.16


Skipping 'dummy' split


0,1
best_model_name,ppo
test/ann_return/a2c,8.97
test/ann_return/best_model,13.71
test/ann_return/ddpg,-15.69
test/ann_return/ppo,13.71
test/ann_return/sac,-25.97
test/ann_return/td3,12.51
test/best_model_name,ppo
test/cum_return/a2c,2.07
test/cum_return/best_model,3.11


Skipping 'dummy' split


0,1
best_model_name,td3
test/ann_return/a2c,-10.87
test/ann_return/best_model,-23.72
test/ann_return/ddpg,-23.72
test/ann_return/ppo,-22.05
test/ann_return/sac,-15.79
test/ann_return/td3,-4.05
test/best_model_name,ddpg
test/cum_return/a2c,-2.64
test/cum_return/best_model,-6.11


Skipping 'dummy' split


0,1
best_model_name,sac
test/ann_return/a2c,27.08
test/ann_return/best_model,39.31
test/ann_return/ddpg,23.05
test/ann_return/ppo,14.03
test/ann_return/sac,39.31
test/ann_return/td3,33.54
test/best_model_name,sac
test/cum_return/a2c,5.88
test/cum_return/best_model,8.22


True
True
True
True
True
True
True


Skipping 'dummy' split


0,1
best_model_name,ddpg
test/ann_return/a2c,15.42
test/ann_return/best_model,71.61
test/ann_return/ddpg,71.61
test/ann_return/ppo,16.7
test/ann_return/sac,12.02
test/ann_return/td3,-14.36
test/best_model_name,ddpg
test/cum_return/a2c,3.48
test/cum_return/best_model,13.74


Skipping 'dummy' split


0,1
best_model_name,ppo
test/ann_return/a2c,31.24
test/ann_return/best_model,36.91
test/ann_return/ddpg,25.86
test/ann_return/ppo,36.91
test/ann_return/sac,4.88
test/ann_return/td3,-30.21
test/best_model_name,ppo
test/cum_return/a2c,6.7
test/cum_return/best_model,7.78


Skipping 'dummy' split


0,1
best_model_name,ppo
test/ann_return/a2c,14.29
test/ann_return/best_model,38.41
test/ann_return/ddpg,30.33
test/ann_return/ppo,38.41
test/ann_return/sac,3.37
test/ann_return/td3,-2.16
test/best_model_name,ppo
test/cum_return/a2c,3.2
test/cum_return/best_model,7.96


Skipping 'dummy' split


0,1
best_model_name,sac
test/ann_return/a2c,5.86
test/ann_return/best_model,57.72
test/ann_return/ddpg,28.89
test/ann_return/ppo,31.57
test/ann_return/sac,57.72
test/ann_return/td3,24.12
test/best_model_name,sac
test/cum_return/a2c,1.37
test/cum_return/best_model,11.47


Skipping 'dummy' split


0,1
best_model_name,sac
test/ann_return/a2c,4.11
test/ann_return/best_model,4.84
test/ann_return/ddpg,-3.25
test/ann_return/ppo,-12.25
test/ann_return/sac,4.84
test/ann_return/td3,-23.11
test/best_model_name,sac
test/cum_return/a2c,1.0
test/cum_return/best_model,1.17


Skipping 'dummy' split


0,1
best_model_name,ppo
test/ann_return/a2c,-7.25
test/ann_return/best_model,-13.69
test/ann_return/ddpg,-2.82
test/ann_return/ppo,-1.4
test/ann_return/sac,-13.69
test/ann_return/td3,-12.49
test/best_model_name,sac
test/cum_return/a2c,-1.82
test/cum_return/best_model,-3.53


Skipping 'dummy' split


0,1
best_model_name,a2c
test/ann_return/a2c,20.46
test/ann_return/best_model,20.46
test/ann_return/ddpg,-16.13
test/ann_return/ppo,-0.68
test/ann_return/sac,3.3
test/ann_return/td3,-7.96
test/best_model_name,a2c
test/cum_return/a2c,4.48
test/cum_return/best_model,4.48


# Rename metrics

In [None]:
#@title rename_metrics_in_sweep
import wandb

def rename_metrics_in_sweep(sweep_id, metric_rename_map):
    """
    Copy summary metrics of every run in a sweep to new names.

    For each run of the sweep, every old name in ``metric_rename_map`` that is present
    in the run's summary is logged again under its new name. The value is converted to
    ``float`` when possible and kept as-is otherwise.

    Behavior to be aware of:
      - The old summary entries are not removed, so this copies rather than moves.
      - Runs with no matching metric are not resumed.
      - A failure on one run is printed and does not stop the loop.

    Parameters:
        sweep_id (str): The ID of the sweep.
        metric_rename_map (dict): A dictionary where keys are old metric names and values are the new names.
    """
    api = wandb.Api()
    sweep = api.sweep(f"{ENTITY}/{PROJECT}/{sweep_id}")

    for run in sweep.runs:
        try:
            # Summary of the run as fetched through the public API.
            summary = run.summary

            renamed_metrics = {}
            for old_metric, new_metric in metric_rename_map.items():
                if old_metric not in summary:
                    continue
                metric_value = summary[old_metric]
                try:
                    # float() raises TypeError (not ValueError) for None, dicts, lists, ...
                    renamed_metrics[new_metric] = float(metric_value)
                except (TypeError, ValueError):
                    renamed_metrics[new_metric] = metric_value  # keep non-numeric values as-is

            if not renamed_metrics:
                print(f"No metrics to rename for run: {run.name}")
                continue

            # Resume the run and log all renamed metrics in a single call.
            with wandb.init(id=run.id, project=PROJECT, entity=ENTITY, resume="allow"):
                wandb.log(renamed_metrics)

            print(f"Updated metrics for run: {run.name}")
        except Exception as e:
            print(f"Failed to update run {run.name}: {e}")

In [None]:
#@title rename default metrics

# Rename maps applied in earlier passes, kept as data for reference / re-use.
LEGACY_TRAIN_METRIC_RENAMES = {
    "max_sharpe_ratio": "train.sharpe_ratio/best_model",
    "max_sharpe_ratio_model": "train.best_model_name",
}

DOTTED_TEST_METRIC_RENAMES = {
    "test.sharpe_ratio/best_model": "test/sharpe_ratio/best_model",
    "test.best_model_name": "test/best_model_name",
    "test.cum_return/best_model": "test/cum_return/best_model",
    "test.ann_return/best_model": "test/ann_return/best_model",
    "test.mdd/best_model": "test/mdd/best_model",
}

# Set to one of the maps above before running. It is left empty so that re-running
# the notebook does not rewrite run summaries by accident.
# NOTE: SWEEP_ID is whichever value the most recently executed cell assigned.
metric_rename_map = {}

if metric_rename_map:
    rename_metrics_in_sweep(sweep_id=SWEEP_ID, metric_rename_map=metric_rename_map)
else:
    print("metric_rename_map is empty; nothing to rename")

In [None]:
# Finish the W&B run that is still active at this point, e.g. the one opened by the
# top-level wandb.init(..., resume='must') call, which is not used as a context manager.
wandb.finish()