# LSTM Training with Ray RLlib

This notebook demonstrates training a PPO agent with LSTM memory using Ray RLlib and Optuna hyperparameter optimization.

## 📚 Related Tutorials

| Tutorial | Description |
|----------|-------------|
| [Ray RLlib Deep Dive](../docs/tutorials/04-training/02-ray-rllib.md) | Distributed training configuration |
| [Optuna Optimization](../docs/tutorials/04-training/03-optuna.md) | Hyperparameter tuning guide |
| [First Training](../docs/tutorials/04-training/01-first-training.md) | Training fundamentals |
| [Common Failures](../docs/tutorials/02-domains/track-b-rl-for-traders/02-common-failures.md) | **Critical pitfalls to avoid** |
| [Walk-Forward Validation](../docs/tutorials/05-advanced/03-walk-forward.md) | Proper evaluation methodology |

### 📊 Best Hyperparameters from Experiments

The best hyperparameters found in our Optuna experiments were:
```python
{
    "lr": 3.29e-05,
    "gamma": 0.992,
    "entropy_coeff": 0.015,
    "clip_param": 0.123,
    "hidden_layers": [128, 128],
}
```

See [EXPERIMENTS.md](../docs/EXPERIMENTS.md) for full results.

---

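Ray Tune caches results under `~/ray_results`, so training can be resumed as long as they are not deleted. The next cell deletes them so that training starts from scratch; skip it to keep earlier results.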

In [2]:
# Delete previous Ray Tune results so training starts from scratch.
# NOTE: this removes the whole ~/ray_results directory, not only ~/ray_results/PPO.
# Skip this cell to keep earlier results.
!rm -rf ~/ray_results/PPO ~/ray_results/

In [None]:
!pip install --upgrade ray[default,rllib,tune,serve]==2.37.0 optuna tensortrade ta



In [1]:
from tensortrade_platform.data.cdd import CryptoDataDownload

import pandas as pd
import numpy as np

def prepare_data(df):
    """Normalise a raw OHLCV frame in place and return it.

    Casts ``volume`` to int64, orders the rows by ``date`` (oldest first)
    with a fresh 0..n-1 index, and re-renders ``date`` as strings such as
    ``"2018-05-15 06:00 AM"``.

    Note: ``df`` itself is modified; the object returned is ``df``.
    """
    date_format = '%Y-%m-%d %I:%M %p'

    df['volume'] = np.int64(df['volume'])
    df['date'] = pd.to_datetime(df['date'])

    # Chronological order, then renumber rows from 0.
    df.sort_values('date', inplace=True)
    df.reset_index(inplace=True, drop=True)

    df['date'] = df['date'].dt.strftime(date_format)
    return df

def fetch_data():
    """Download Bitfinex BTC/USD "1h" candles and normalise them.

    Only the date and OHLCV columns are kept; the result is passed through
    ``prepare_data`` (sorted, int64 volume, string timestamps).
    """
    ohlcv_columns = ['date', 'open', 'high', 'low', 'close', 'volume']
    raw = CryptoDataDownload().fetch("Bitfinex", "BTC", "USD", "1h")
    return prepare_data(raw[ohlcv_columns])

def load_csv(filename):
    """Load a CryptoDataDownload-style CSV from ``data/`` and normalise it.

    The first line of the file is skipped (``skiprows=1``), the ``symbol``
    and ``volume_btc`` columns are dropped, and hour-only timestamps are
    padded before ``prepare_data`` is applied.
    """
    frame = pd.read_csv('data/' + filename, skiprows=1)
    frame = frame.drop(columns=['symbol', 'volume_btc'])

    # Fix timestamp from "2019-10-17 09-AM" to "2019-10-17 09-00-00 AM":
    # keep the first 14 characters, insert minutes/seconds, re-append AM/PM.
    stamps = frame['date']
    frame['date'] = stamps.str[:14] + '00-00 ' + stamps.str[-2:]

    return prepare_data(frame)

In [None]:
import ta

def rsi(price: 'pd.Series[pd.Float64Dtype]', period: float) -> 'pd.Series[pd.Float64Dtype]':
    """Relative Strength Index of ``price``.

    RSI = 100 - 100 / (1 + RS), where RS is the ratio of the exponentially
    weighted mean gain to the exponentially weighted mean loss
    (``ewm(alpha=1 / period)``). Values lie in [0, 100]: 100 when there have
    been no down-moves, 0 when there have been no up-moves.

    Args:
        price: price series.
        period: smoothing period; the EWM uses ``alpha = 1 / period``.

    Returns:
        RSI series aligned with ``price``. The first element is NaN (no
        previous price to diff against); a series with no moves at all
        yields NaN (0 / 0).
    """
    delta = price.diff()
    # Up-moves feed the numerator of RS, down-moves (as magnitudes) the denominator.
    gains = np.maximum(delta, 0)
    losses = np.minimum(delta, 0).abs()
    rs = gains.ewm(alpha=1 / period).mean() / losses.ewm(alpha=1 / period).mean()
    return 100 * (1 - (1 + rs) ** -1)

def macd(price: 'pd.Series[pd.Float64Dtype]', fast: float, slow: float, signal: float) -> 'pd.Series[pd.Float64Dtype]':
    """MACD histogram of ``price``.

    The MACD line is EMA(fast) - EMA(slow); the signal line is an EMA of the
    MACD line with span ``signal``. The histogram (MACD line minus signal
    line) is returned. All EMAs use ``adjust=False``.
    """
    def _ema(series, span):
        # Recursive (adjust=False) exponential moving average with the given span.
        return series.ewm(span=span, adjust=False).mean()

    macd_line = _ema(price, fast) - _ema(price, slow)
    signal_line = _ema(macd_line, signal)
    return macd_line - signal_line

def generate_features(data):
    """Augment an OHLCV frame with custom and ``ta``-library indicators.

    Args:
        data: frame with ``date``, ``open``, ``high``, ``low``, ``close`` and
            ``volume`` columns (e.g. the output of ``prepare_data``). It is
            not modified; a renamed copy is worked on.

    Returns:
        A new frame whose first column is ``date``. It is followed by the
        OHLCV columns, every indicator added by ``ta.add_all_ta_features``
        and the custom indicators below. The first 200 rows (indicator
        warm-up) are dropped and the index is reset to 0..n-1. Keeping
        ``date`` as the first column matches what ``create_env`` expects
        from the CSV splits (``parse_dates=['date']``, ``data.columns[1:]``).
    """
    # Naming convention across most technical indicator libraries
    data = data.rename(columns={'date': 'Date',
                                'open': 'Open',
                                'high': 'High',
                                'low': 'Low',
                                'close': 'Close',
                                'volume': 'Volume'})
    data = data.set_index('Date')
    close = data['Close']

    # Custom indicators, computed from the renamed frame's Close column
    features = pd.DataFrame.from_dict({
        'dfast': close.rolling(window=10).std().abs(),
        'dmedium': close.rolling(window=50).std().abs(),
        'dslow': close.rolling(window=100).std().abs(),
        'fast': close.rolling(window=10).mean(),
        'medium': close.rolling(window=50).mean(),
        'slow': close.rolling(window=100).mean(),
        'ema_fast': ta.trend.ema_indicator(close, window=5, fillna=True),
        'ema_medium': ta.trend.ema_indicator(close, window=10, fillna=True),
        'ema_slow': ta.trend.ema_indicator(close, window=64, fillna=True),
        'lr': np.log(close).diff().fillna(0),
        'rsi_5': rsi(close, period=5),
        'rsi_10': rsi(close, period=10),
        'rsi_100': rsi(close, period=100),
        'rsi_7': rsi(close, period=7),
        'rsi_14': rsi(close, period=14),
        'rsi_28': rsi(close, period=28),
        'macd_normal': macd(close, fast=12, slow=26, signal=9),
        'macd_short': macd(close, fast=10, slow=50, signal=5),
        'macd_long': macd(close, fast=200, slow=100, signal=50),
    })

    # Generate all default indicators from ta library
    with_ta = ta.add_all_ta_features(data, 'Open', 'High', 'Low', 'Close', 'Volume', fillna=True)

    # Concatenate both manually and automatically generated features
    data = pd.concat([with_ta, features], axis='columns').ffill()

    # Remove potential column duplicates (the first occurrence wins)
    data = data.loc[:, ~data.columns.duplicated()]

    # A lot of indicators generate NaNs at the beginning of DataFrames, so
    # drop the first 200 rows, then move the Date index back into a column.
    data = data.iloc[200:].rename_axis('Date').reset_index()

    # Revert naming convention
    data = data.rename(columns={'Date': 'date',
                                'Open': 'open',
                                'High': 'high',
                                'Low': 'low',
                                'Close': 'close',
                                'Volume': 'volume'})

    return data

In [3]:
from sklearn.model_selection import train_test_split

def split_data(data):
    """Split ``data`` chronologically into train / test / validation parts.

    ``shuffle=False`` keeps row order. The first ~67% of rows is halved into
    train and test (~33.5% each), and the last ~33% becomes the validation
    set. The targets are the close-to-close percentage change (the first
    row's target is NaN).

    Returns:
        ``(X_train, X_test, X_valid, y_train, y_test, y_valid)``, all split
        from a copy of ``data``.
    """
    X = data.copy()
    y = X['close'].pct_change()

    X_train_test, X_valid, y_train_test, y_valid = train_test_split(
        X, y, train_size=0.67, test_size=0.33, shuffle=False)

    X_train, X_test, y_train, y_test = train_test_split(
        X_train_test, y_train_test, train_size=0.50, test_size=0.50, shuffle=False)

    return X_train, X_test, X_valid, y_train, y_test, y_valid

In [4]:
# Download Bitfinex BTC/USD "1h" candles (see fetch_data) and preview them.
data = fetch_data()
data

Unnamed: 0,date,open,high,low,close,volume
0,2018-05-15 06:00 AM,8723.8,8793.0,8714.9,8739.0,8988053
1,2018-05-15 07:00 AM,8739.0,8754.8,8719.3,8743.0,2288904
2,2018-05-15 08:00 AM,8743.0,8743.1,8653.2,8723.7,8891773
3,2018-05-15 09:00 AM,8723.7,8737.8,8701.2,8708.1,2054868
4,2018-05-15 10:00 AM,8708.1,8855.7,8695.8,8784.4,17309722
...,...,...,...,...,...,...
65042,2025-10-15 09:00 PM,111320.0,111500.0,110710.0,110960.0,5257176
65043,2025-10-15 10:00 PM,110970.0,111500.0,110830.0,110830.0,3081976
65044,2025-10-15 11:00 PM,110820.0,111180.0,110770.0,110920.0,3135737
65045,2025-10-16 12:00 AM,110950.0,111010.0,110540.0,110660.0,5250807


In [5]:
# Feature engineering via generate_features is currently disabled; the raw
# OHLCV data is used as-is. create_env computes its own indicators from the
# CSV columns (TODO: confirm whether pre-computed features are still wanted).
#dataset = generate_features(data)
dataset = data.copy()
dataset

Unnamed: 0,date,open,high,low,close,volume
0,2018-05-15 06:00 AM,8723.8,8793.0,8714.9,8739.0,8988053
1,2018-05-15 07:00 AM,8739.0,8754.8,8719.3,8743.0,2288904
2,2018-05-15 08:00 AM,8743.0,8743.1,8653.2,8723.7,8891773
3,2018-05-15 09:00 AM,8723.7,8737.8,8701.2,8708.1,2054868
4,2018-05-15 10:00 AM,8708.1,8855.7,8695.8,8784.4,17309722
...,...,...,...,...,...,...
65042,2025-10-15 09:00 PM,111320.0,111500.0,110710.0,110960.0,5257176
65043,2025-10-15 10:00 PM,110970.0,111500.0,110830.0,110830.0,3081976
65044,2025-10-15 11:00 PM,110820.0,111180.0,110770.0,110920.0,3135737
65045,2025-10-16 12:00 AM,110950.0,111010.0,110540.0,110660.0,5250807


In [6]:
# Split the prepared dataset chronologically (see split_data) and write each
# split to CSV; create_env later reads these files via "csv_filename".
X_train, X_test, X_valid, y_train, y_test, y_valid = \
    split_data(dataset)

import os
cwd = os.getcwd()
train_csv = os.path.join(cwd, 'train.csv')
test_csv = os.path.join(cwd, 'test.csv')
valid_csv = os.path.join(cwd, 'valid.csv')
X_train.to_csv(train_csv, index=False)
X_test.to_csv(test_csv, index=False)
X_valid.to_csv(valid_csv, index=False)

In [7]:
# Things to understand here:
# Writing a Renderer

import matplotlib.pyplot as plt

from tensortrade.env.generic import Renderer


class PositionChangeChart(Renderer):
    """Plot position changes on the price series next to the portfolio's net worth.

    Reads ``env.observer.renderer_history`` (its ``action`` and ``close``
    fields) and ``env.action_scheme.portfolio.performance``.
    """

    def __init__(self, color: str = "orange"):
        # Colour of the price line.
        self.color = color

    def render(self, env, **kwargs):
        """Draw the trading chart (left) and net-worth curve (right), then show them."""
        history = pd.DataFrame(env.observer.renderer_history)

        actions = list(history.action)
        price = list(history.close)

        buy = {}
        sell = {}

        # Compare each action with the next one. A 0 -> 1 switch is drawn as a
        # buy, any other switch as a sell; the marker sits on bar i, i.e. the
        # last bar before the new action.
        for i, (current, upcoming) in enumerate(zip(actions, actions[1:])):
            if current == upcoming:
                continue
            if current == 0 and upcoming == 1:
                buy[i] = price[i]
            else:
                sell[i] = price[i]

        buy = pd.Series(buy)
        sell = pd.Series(sell)

        fig, axs = plt.subplots(1, 2, figsize=(15, 5))

        fig.suptitle("Performance")

        axs[0].plot(np.arange(len(price)), price, label="price", color=self.color)
        axs[0].scatter(buy.index, buy.values, marker="^", color="green")
        # Downward triangle so sells differ from buys by shape, not only colour.
        axs[0].scatter(sell.index, sell.values, marker="v", color="red")
        axs[0].set_title("Trading Chart")

        performance_df = pd.DataFrame.from_dict(env.action_scheme.portfolio.performance, orient='index')
        performance_df.plot(ax=axs[1])
        axs[1].set_title("Net Worth")

        plt.show()

In [None]:
# Things to understand here:
# execution_order
# Types of execution logic
# Exchange
# DataFeed
# renderer_feed
# default (env)

import ray
import numpy as np
import pandas as pd

from ray import tune
from ray.tune.registry import register_env

import tensortrade.env.default as default

from tensortrade.env.default.rewards import PBR, RiskAdjustedReturns
from tensortrade.env.default.rewards import SimpleProfit
from tensortrade.env.default.actions import BSH, ManagedRiskOrders
from tensortrade.feed.core import DataFeed, Stream
from tensortrade.feed.core.base import NameSpace
from tensortrade.oms.exchanges import Exchange, ExchangeOptions
from tensortrade.oms.instruments import USD, BTC
from tensortrade.oms.services.execution.simulated import execute_order
from tensortrade.oms.wallets import Wallet, Portfolio

def create_env(config):
    """Build a TensorTrade trading environment from one CSV split.

    Args:
        config: env config dict (RLlib passes ``env_config`` here). Keys read:
            ``csv_filename`` (required): CSV expected to have ``date`` as its
                first column plus ``open/high/low/close/volume``.
            ``window_size`` (required): observation window length.
            ``max_allowed_loss`` (optional, default 0.9): forwarded to
                ``default.create``.
            Other keys (e.g. ``reward_window_size``) are not read here.

    Returns:
        The environment returned by ``tensortrade.env.default.create``.
        It starts with 100000 USD and 0 BTC, uses a BSH action scheme with a
        PBR reward and a 0.1% commission, and observes every CSV column
        except the first plus the custom indicators below.
    """
    data = pd.read_csv(filepath_or_buffer=config["csv_filename"],
                       parse_dates=['date']).bfill().ffill()

    # TODO: adjust according to your commission percentage, if present
    commission = 0.001
    price = Stream.source(list(data["close"]),
                          dtype="float").rename("USD-BTC")
    bitstamp_options = ExchangeOptions(commission=commission)
    bitstamp = Exchange("bitstamp",
                        service=execute_order,
                        options=bitstamp_options)(price)

    cash = Wallet(bitstamp, 100000 * USD)
    asset = Wallet(bitstamp, 0 * BTC)

    portfolio = Portfolio(USD, [cash, asset])

    # Custom indicators
    features = pd.DataFrame.from_dict({
        'dfast': data['close'].rolling(window=10).std().abs(),
        'dmedium': data['close'].rolling(window=50).std().abs(),
        'dslow': data['close'].rolling(window=100).std().abs(),
        'fast': data['close'].rolling(window=10).mean(),
        'medium': data['close'].rolling(window=50).mean(),
        'slow': data['close'].rolling(window=100).mean(),
        'ema_fast': ta.trend.ema_indicator(data['close'], window=5, fillna=True),
        'ema_medium': ta.trend.ema_indicator(data['close'], window=10, fillna=True),
        'ema_slow': ta.trend.ema_indicator(data['close'], window=64, fillna=True),
        'lr': np.log(data['close']).diff().fillna(0),
        'rsi_5': rsi(data['close'], period=5),
        'rsi_10': rsi(data['close'], period=10),
        'rsi_100': rsi(data['close'], period=100),
        'rsi_7': rsi(data['close'], period=7),
        'rsi_14': rsi(data['close'], period=14),
        'rsi_28': rsi(data['close'], period=28),
        'macd_normal': macd(data['close'], fast=12, slow=26, signal=9),
        'macd_short': macd(data['close'], fast=10, slow=50, signal=5),
        'macd_long': macd(data['close'], fast=200, slow=100, signal=50),
    })

    # Use the returned frame explicitly rather than relying on in-place mutation.
    data = ta.add_all_ta_features(data,
                                  'open',
                                  'high',
                                  'low',
                                  'close',
                                  'volume',
                                  fillna=True)

    with NameSpace("bitstamp"):
        data = pd.concat([data, features], axis='columns')
        # Every column except the first (assumed to be ``date``) becomes an observed stream.
        automatic_features = [
            Stream.source(list(data[c]),
                          dtype="float").rename(c) for c in data.columns[1:]
        ]

    feed = DataFeed(automatic_features)
    feed.compile()

    reward_scheme = PBR(price=price)

    action_scheme = BSH(
        cash=cash,
        asset=asset
    ).attach(reward_scheme)

    renderer_feed = DataFeed([
        Stream.source(list(data["date"])).rename("date"),
        Stream.source(list(data["open"]), dtype="float").rename("open"),
        Stream.source(list(data["high"]), dtype="float").rename("high"),
        Stream.source(list(data["low"]), dtype="float").rename("low"),
        Stream.source(list(data["close"]), dtype="float").rename("close"),
        Stream.source(list(data["volume"]), dtype="float").rename("volume"),
        Stream.sensor(action_scheme,
                      lambda s: s.action, dtype="float").rename("action")
    ])

    environment = default.create(
        feed=feed,
        portfolio=portfolio,
        action_scheme=action_scheme,
        reward_scheme=reward_scheme,
        renderer_feed=renderer_feed,
        renderer=[
            PositionChangeChart(),
            default.renderers.PlotlyTradingChart(),
        ],
        window_size=config["window_size"],
        # Honour the per-run setting (training / evaluation / validation configs
        # all provide one); 0.9 is the fallback when the key is absent.
        max_allowed_loss=config.get("max_allowed_loss", 0.9)
    )
    return environment

# Start Ray (ignore_reinit_error=True makes re-running this cell harmless) and
# register create_env under the name "TradingEnv" used by the tuning config.
ray.init(num_cpus=6,
         include_dashboard=True,
         address=None,  # leave as None for a local (laptop) run; set a cluster address to use a remote cluster
         ignore_reinit_error=True)

register_env("TradingEnv", create_env)

In [None]:
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search import ConcurrencyLimiter
from ray.tune.search.optuna import OptunaSearch
from ray.tune import TuneConfig, RunConfig
from ray.train import CheckpointConfig

# Hyperparameter search space sampled by Optuna for each trial.
LR = tune.loguniform(1e-5, 1e-2)
GAMMA = tune.uniform(0.8, 0.9999)
LAMBDA = tune.uniform(0.1, 0.8)
VF_LOSS_COEFF = tune.uniform(0.01, 1.0)
ENTROPY_COEFF = tune.uniform(1e-8, 1e-1)

# Metric used by the searcher, the scheduler, checkpoint ranking and model selection.
checkpoint_metric = 'env_runners/episode_reward_mean'

# Specific configuration keys that will be used during training
env_config_training = {
    "window_size": 14,  # The number of past samples to look at (hours)
    "reward_window_size": 7,  # Intended reward horizon in hours (not read by create_env)
    "max_allowed_loss": 0.9,  # If it goes past 90% loss during the episode, we don't want to waste time on a "loser".
    "csv_filename": train_csv  # The variable that will be used to differentiate training and validation datasets
}
# Evaluation runs on the test split. Start from the full training config and
# override only what differs, so the evaluation env never lacks a key that
# create_env requires (e.g. window_size).
env_config_evaluation = {
    **env_config_training,
    "max_allowed_loss": 1.00,  # During evaluation runs we want to see how bad it would go. Even up to 100% loss.
    "csv_filename": test_csv,
}

search_alg = OptunaSearch()
# Run at most 4 trials at the same time.
search_alg = ConcurrencyLimiter(search_alg, max_concurrent=4)

scheduler = ASHAScheduler(
    max_t=35,  # Max training iterations per trial
    grace_period=5,  # Min iterations before early stopping
)

import time
start = time.time()

# Ray 2.x API: use tune.Tuner instead of tune.run
tuner = tune.Tuner(
    "PPO",
    param_space={
        "env": "TradingEnv",
        "env_config": env_config_training,
        "log_level": "ERROR",
        "framework": "torch",
        # Old API stack: required for the ``model`` dict (use_lstm, ...) below.
        "enable_rl_module_and_learner": False,
        "enable_env_runner_and_connector_v2": False,
        "ignore_env_runner_failures": True,
        "num_env_runners": 2,  # Ray 2.x: num_workers -> num_env_runners
        "num_gpus": 0,
        "clip_rewards": True,
        # No ``lr_schedule`` here: on the old API stack a schedule replaces
        # ``lr`` entirely, which would turn the LR search into a no-op.
        "lr": LR,
        "model": {
            "use_lstm": True,
            "lstm_cell_size": 512,
            # vf_share_layers is a model-config option, not a top-level key.
            "vf_share_layers": True,
        },
        "gamma": GAMMA,
        "observation_filter": "MeanStdFilter",
        "lambda_": LAMBDA,  # Ray 2.x: "lambda" -> "lambda_"
        "vf_loss_coeff": VF_LOSS_COEFF,
        "entropy_coeff": ENTROPY_COEFF,
        "evaluation_interval": 1,  # Run evaluation on every iteration
        "evaluation_config": {
            "env_config": env_config_evaluation,  # The dictionary we built before
            "explore": False,  # We don't want to explore during evaluation
        },
    },
    tune_config=TuneConfig(
        search_alg=search_alg,
        scheduler=scheduler,
        num_samples=10,  # Total number of trials sampled by the search algorithm
        metric=checkpoint_metric,
        mode="max",
    ),
    run_config=RunConfig(
        checkpoint_config=CheckpointConfig(
            checkpoint_score_attribute=checkpoint_metric,
            checkpoint_score_order="max",
            num_to_keep=10,
            # Checkpoint every iteration and at the end so every trial
            # (including ones ASHA stops early) has something to restore.
            checkpoint_frequency=1,
            checkpoint_at_end=True,
        ),
    ),
)

# Execute the tuning
results = tuner.fit()

taken = time.time() - start
print(f"Time taken: {taken:.2f} seconds.")

# Get best result
best_result = results.get_best_result(metric=checkpoint_metric, mode="max")
print(f"Best config: {best_result.config}")

In [None]:
# Plot the selection metric (checkpoint_metric) for every trial, one line per trial.
ax = None
for result in results:
    trial_df = result.metrics_dataframe
    if trial_df is None or checkpoint_metric not in trial_df.columns:
        continue  # trial reported nothing usable (e.g. it failed early)
    ax = trial_df[checkpoint_metric].plot(ax=ax, legend=False)

if ax is not None:
    ax.set_title("Episode reward mean per trial")
    ax.set_xlabel("Reported result #")
    ax.set_ylabel(checkpoint_metric)

In [None]:
# Ray 2.x API: Use PPO.from_checkpoint() instead of PPOTrainer
from ray.rllib.algorithms.ppo import PPO

# Best trial according to checkpoint_metric. Note that ``Result.checkpoint`` is
# that trial's most recent checkpoint, not necessarily its best-scoring one.
best_result = results.get_best_result(metric=checkpoint_metric, mode='max')
if best_result.checkpoint is None:
    raise RuntimeError(
        "The best trial has no checkpoint; enable checkpointing in "
        "CheckpointConfig (e.g. checkpoint_frequency / checkpoint_at_end)."
    )
checkpoint_path = best_result.checkpoint.path

env_config_validation = {
    "window_size": 14,  # The number of past samples to look at (hours)
    "reward_window_size": 7,  # Intended reward horizon in hours (not read by create_env)
    "max_allowed_loss": 1.0,  # Allow 100% loss during evaluation
    "csv_filename": valid_csv  # The variable that will be used to differentiate training and validation datasets
}

# Restore algorithm from checkpoint
algo = PPO.from_checkpoint(checkpoint_path)

# Copy of the best trial's config with the env pointed at the validation split
config = best_result.config.copy()
config['env_config'] = env_config_validation

In [None]:
# Inspect the restored policy's model; with "use_lstm": True in the model
# config, this shows how the base network is wrapped by the LSTM.
algo.get_policy().model

In [None]:
# Instantiate the environment on the validation split
env = create_env(env_config_validation)

# Run until episode ends
# Gymnasium API: reset() returns (obs, info) tuple
obs, info = env.reset()

# Initial LSTM state taken from the restored policy, so its shape and dtype
# always match the model (lstm_cell_size etc.).
hidden_state = algo.get_policy().get_initial_state()

done = truncated = False
total_reward = 0

while not (done or truncated):
    # full_fetch=True makes compute_single_action return (action, state_out, info);
    # state_out is fed back in so the LSTM keeps its memory between steps.
    # explore=False matches the "explore": False evaluation setting used in tuning.
    action, hidden_state, _ = algo.compute_single_action(
        obs, state=hidden_state, explore=False, full_fetch=True
    )
    # Gymnasium API: step() returns (obs, reward, terminated, truncated, info)
    obs, reward, done, truncated, info = env.step(action)
    total_reward += reward

print(f"Total reward: {total_reward}")
env.render()