In [None]:
import pandas as pd
from ray.rllib.agents import ppo
from ray.tune.logger import pretty_print

from src.rllib_gym_trade_environment import prepare_dict
from src.rllib_gym_trade_environment import CryptoEnv

In [None]:
# --- Training schedule and rollout sizing ---------------------------------
stop_iters = 200            # how many times the training cell calls trainer.train()
n_workers = 5               # passed to RLlib as "num_workers"
n_envs_per_worker = 20      # passed to RLlib as "num_envs_per_worker"
r_fragment_length = 2000    # passed as "rollout_fragment_length" and as the env's "max_steps"

# The training batch is sized as one fragment from every environment copy on
# every worker; each SGD minibatch is one tenth of that batch.
train_batch_size = n_workers * n_envs_per_worker * r_fragment_length
sgd_minibatch_size = train_batch_size // 10

# --- Training data ---------------------------------------------------------
df_train = pd.read_parquet("data/df_train.parquet")
train_dict = prepare_dict(df_train)

# --- Environment settings (delivered to CryptoEnv through "env_config") -----
# NOTE(review): "gamma" here is handed to the environment, it is not a
# top-level trainer key in this dict; confirm how CryptoEnv uses it.
# NOTE(review): "max_steps" presumably caps the episode length at one rollout
# fragment; confirm against CryptoEnv.
env_config = {
    "price_array": train_dict["price_array"],
    "observations": train_dict["observations"],
    "initial_capital": 1e4,
    "gamma": 0.99,
    "max_steps": r_fragment_length,
}

# --- Trainer settings overlaid on the PPO defaults in the training cell -----
config = {
    "env": CryptoEnv,
    "env_config": env_config,
    "num_gpus": 1,
    "model": {"vf_share_layers": False},
    "num_workers": n_workers,
    "num_envs_per_worker": n_envs_per_worker,
    "rollout_fragment_length": r_fragment_length,
    "train_batch_size": train_batch_size,
    "sgd_minibatch_size": sgd_minibatch_size,
    "batch_mode": "complete_episodes",
    "framework": "tf",
}

In [None]:
# Build the trainer configuration: shallow-copy RLlib's PPO defaults, overlay
# the settings from `config`, then pin the learning rate.
ppo_config = dict(ppo.DEFAULT_CONFIG)
ppo_config.update(config, lr=1e-5)
trainer = ppo.PPOTrainer(env=CryptoEnv, config=ppo_config)

# Run `stop_iters` training iterations. After every iteration the metrics are
# printed and a checkpoint is saved, so `result` and `checkpoint` hold the
# values from the last completed iteration once the loop ends.
for iteration in range(stop_iters):
    result = trainer.train()
    print(pretty_print(result))
    checkpoint = trainer.save()
    print("Checkpoint saved at", checkpoint, "\n")