# SARL for Portfolio Management on DJ30
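This tutorial demonstrates how to use SARL (State-Augmented Reinforcement Learning) for portfolio management on the DJ30 (Dow Jones 30) dataset. In this configuration the SARL trainer uses a PPO agent.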

## Step 1: Import Packages

In [1]:
# Imports and global setup. Order matters here: ROOT must be appended to
# sys.path before the `trademaster` imports further down.
import warnings
import argparse

# Suppress Python warnings to keep the cell output readable.
warnings.filterwarnings("ignore")
import sys
from pathlib import Path  # NOTE(review): Path is not used in the cells shown
import os
import torch

# Repository root = parent of the current working directory. Presumably the
# notebook is run from a sub-folder of the TradeMaster repo — confirm.
ROOT = os.path.dirname(os.path.abspath("."))
# Lets the `trademaster` imports below resolve against ROOT.
sys.path.append(ROOT)

from trademaster.utils import plot
import argparse  # NOTE(review): duplicate of the import above; harmless
import os.path as osp
from mmengine.config import Config
from trademaster.utils import replace_cfg_vals  # NOTE(review): not used in the cells shown
from trademaster.datasets.builder import build_dataset
from trademaster.trainers.builder import build_trainer
from trademaster.utils import set_seed
# Fix the random seed for reproducibility (which RNGs are seeded is decided by
# trademaster.utils.set_seed — confirm there if exact reproducibility matters).
set_seed(2023)

## Step 2: Set Up the Configuration

In [2]:
# Command-line style arguments. `parse_known_args` returns unrecognized argv
# entries instead of erroring, so the extra arguments a Jupyter kernel is
# started with (e.g. `-f <connection file>`) do not break this cell.
# In this notebook `args.config` is only used to name the dumped config file;
# the configuration itself is built from `cfg_dict` in the next cell.
# NOTE(review): the default file name says "sarl_sarl" while the agent used
# below is "ppo"; `--task_name` is not read anywhere in the cells shown.
parser = argparse.ArgumentParser(description='SARL portfolio management on DJ30')
parser.add_argument("--config", default=osp.join(ROOT, "configs", "portfolio_management", "portfolio_management_dj30_sarl_sarl_adam_mse.py"),
                    help="config file path (its basename is used when dumping the config)")
parser.add_argument("--task_name", type=str, default="train")
args, _ = parser.parse_known_args()

In [3]:
# Experiment identifiers; they build `work_dir`, the data paths and the
# `_base_` paths below.
task_name = "portfolio_management"
dataset_name = "dj30"
net_name = "sarl"
agent_name = "ppo"  # agent used by the SARL trainer (passed as trainer.agent_name)
optimizer_name = "adam"
loss_name = "mse"

work_dir = f"work_dir/{task_name}_{dataset_name}_{net_name}_{agent_name}_{optimizer_name}_{loss_name}"
data_dir = f"data/{task_name}/{dataset_name}"


def _make_dataset_cfg():
    """Return a fresh dataset-config dict for the DJ30 portfolio-management data.

    The same dataset settings are needed twice in ``cfg_dict`` (top-level
    ``data`` and ``trainer.configs.dataset``). Building both from this helper
    keeps the two copies from drifting apart. Returning a new dict (with a new
    indicator list) on every call keeps them from sharing mutable objects.

    Returns:
        dict: dataset type, CSV paths under ``data_dir``, indicator columns,
        window length, starting capital and transaction cost.
    """
    return {
        'type': "PortfolioManagementDataset",
        'data_path': data_dir,
        'train_path': f"{data_dir}/train.csv",
        'valid_path': f"{data_dir}/valid.csv",
        'test_path': f"{data_dir}/test.csv",
        'test_dynamic_path': f"{data_dir}/test_with_label.csv",
        'tech_indicator_list': [
            "high", "low", "open", "close", "adjcp",
            "zopen", "zhigh", "zlow", "zadjcp", "zclose",
            "zd_5", "zd_10", "zd_15", "zd_20", "zd_25", "zd_30"
        ],
        'length_day': 5,
        'initial_amount': 10000,
        'transaction_cost_pct': 0.001,
    }


cfg_dict = {
    # NOTE(review): when a Config is built from a dict (not via Config.fromfile),
    # `_base_` appears to be kept as a plain key rather than merged — the
    # displayed cfg still contains it. Confirm whether these base files matter.
    '_base_': [
        f"../_base_/datasets/{task_name}/{dataset_name}.py",
        f"../_base_/environments/{task_name}/env.py",
        f"../_base_/trainers/{task_name}/sarl_trainer.py",
        f"../_base_/losses/{loss_name}.py",
        f"../_base_/optimizers/{optimizer_name}.py",
    ],
    'data': _make_dataset_cfg(),
    'environment': {
        'type': "PortfolioManagementSARLEnvironment",
    },
    'trainer': {
        'type': "PortfolioManagementSARLTrainer",
        # Single source of truth: reuse the identifier defined above.
        'agent_name': agent_name,
        'if_remove': False,
        'work_dir': work_dir,
        'epochs': 2,
        # Extra settings handed to the trainer; the keys below look like RLlib
        # PPO options — TODO confirm how the trainer consumes them.
        'configs': {
            'dataset': _make_dataset_cfg(),
            'work_dir': work_dir,
            'num_workers': 0,          # Optional: can increase for parallel rollout
            'num_gpus': 0,             # Optional: 1 if you want GPU
            'lr': 5e-5,
            'train_batch_size': 4000,
            'sgd_minibatch_size': 128,
            'num_sgd_iter': 30,
            'gamma': 0.99,
            'model': {
                # Optional: Model config, e.g., custom FC sizes
                "fcnet_hiddens": [256, 256],
                "fcnet_activation": "relu",
            },
        }
    }
}

cfg = Config(cfg_dict)

In [4]:
# Display the assembled config. The output still lists `_base_` as an ordinary
# key, suggesting the base files were not merged when building Config from a
# dict — confirm.
cfg

Config (path: None): {'_base_': ['../_base_/datasets/portfolio_management/dj30.py', '../_base_/environments/portfolio_management/env.py', '../_base_/trainers/portfolio_management/sarl_trainer.py', '../_base_/losses/mse.py', '../_base_/optimizers/adam.py'], 'data': {'type': 'PortfolioManagementDataset', 'data_path': 'data/portfolio_management/dj30', 'train_path': 'data/portfolio_management/dj30/train.csv', 'valid_path': 'data/portfolio_management/dj30/valid.csv', 'test_path': 'data/portfolio_management/dj30/test.csv', 'test_dynamic_path': 'data/portfolio_management/dj30/test_with_label.csv', 'tech_indicator_list': ['high', 'low', 'open', 'close', 'adjcp', 'zopen', 'zhigh', 'zlow', 'zadjcp', 'zclose', 'zd_5', 'zd_10', 'zd_15', 'zd_20', 'zd_25', 'zd_30'], 'length_day': 5, 'initial_amount': 10000, 'transaction_cost_pct': 0.001}, 'environment': {'type': 'PortfolioManagementSARLEnvironment'}, 'trainer': {'type': 'PortfolioManagementSARLTrainer', 'agent_name': 'ppo', 'if_remove': False, 'work

## Step 3: Build the Dataset

In [5]:
# Build the dataset object from the config (presumably from the `data` section,
# type PortfolioManagementDataset — confirm in trademaster.datasets.builder).
dataset = build_dataset(cfg)

## Step 4: Build the Trainer

In [6]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory for the dumped config: ROOT/Sarl/<cfg.trainer.work_dir>.
# NOTE(review): this rebinds `work_dir` to a path under an extra "Sarl" folder,
# while cfg.trainer.work_dir (which the trainer receives via cfg) stays the
# relative path from the config cell — confirm the dumped config ends up next
# to the trainer's outputs.
work_dir = os.path.join(ROOT, "Sarl", cfg.trainer.work_dir)

# exist_ok makes the cell idempotent and avoids the check-then-create race.
os.makedirs(work_dir, exist_ok=True)
# Save a copy of the config, named after the --config file's basename.
cfg.dump(osp.join(work_dir, osp.basename(args.config)))

In [7]:
# Instantiate the trainer from the config (type presumably resolved from
# cfg.trainer.type), injecting the dataset and compute device built above.
trainer = build_trainer(
    cfg,
    default_args={"dataset": dataset, "device": device},
)

2025-06-14 03:26:28,157	INFO worker.py:1917 -- Started a local Ray instance.


## Step 5: Train, Validate, and Test

In [8]:
# Run training for cfg.trainer.epochs epochs with a validation pass after each
# epoch (the log prints "Training Epoch i/N" then "Validating Epoch i/N").
# NOTE(review): in the saved run this cell failed during "Validating Epoch 1/2"
# with AttributeError: 'function' object has no attribute 'local_worker'. This
# looks like a Ray/RLlib API mismatch inside the trainer — TODO confirm and fix.
# The test cells below presumably expect this cell to complete.
trainer.train_and_valid()

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))


Training Epoch 1/2


`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))


Validating Epoch 1/2


AttributeError: 'function' object has no attribute 'local_worker'

In [None]:
import ray
from ray.tune.registry import register_env
from trademaster.environments.portfolio_management.sarl_environment import PortfolioManagementSARLEnvironment
def env_creator(env_name):
    """Return the environment class registered under ``env_name``.

    The class itself is returned, not an instance; callers instantiate it with
    the environment config.

    Args:
        env_name: Environment name. Only ``"portfolio_management_sarl"`` is
            supported.

    Returns:
        The ``PortfolioManagementSARLEnvironment`` class.

    Raises:
        NotImplementedError: If ``env_name`` is not a supported environment.
            The message names the rejected value to ease debugging.
    """
    if env_name == 'portfolio_management_sarl':
        return PortfolioManagementSARLEnvironment
    raise NotImplementedError(f"Unsupported environment name: {env_name!r}")
# Ray was already started when the trainer was built (see the log above);
# ignore_reinit_error suppresses the error a second ray.init() would raise.
ray.init(ignore_reinit_error=True)
# Register an env factory under this name. The factory instantiates the class
# returned by env_creator with the env config it is given.
register_env(
    "portfolio_management_sarl",
    lambda env_config: env_creator("portfolio_management_sarl")(env_config),
)
trainer.test();

In [None]:
# Plot the asset curve recorded by the test environment, labelled "SARL".
test_asset_memory = trainer.test_environment.save_asset_memory()
plot(test_asset_memory, alg="SARL")