In [None]:
import ray
from ray import tune

from ray.tune.logger import DEFAULT_LOGGERS
from ray.tune.integration.wandb import WandbLoggerCallback
import wandb
from abides_gym.envs.markets_execution_custom_metrics import MyCallbacks

# Weights & Biases API key, read from the `wandb` client's API object; it is
# handed to `WandbLoggerCallback(api_key=...)` when the experiment is launched.
# NOTE(review): presumably resolved from the local wandb login / environment --
# confirm it is set (not None) before starting a run.
api_key = wandb.api.api_key

# Example with custom callbacks and WandB logging.
# Importing abides_gym registers its environments as a side effect.
import abides_gym

from ray.tune.registry import register_env

# import env
from abides_gym.envs.markets_execution_environment_v0 import (
    SubGymMarketsExecutionEnv_v0,
)

In [None]:
def _make_execution_env(env_config):
    """Build the execution environment from the config mapping it is called with.

    The mapping is unpacked as keyword arguments of the environment constructor.
    """
    return SubGymMarketsExecutionEnv_v0(**env_config)


# Make the environment selectable by its string id in the trainer config.
register_env("markets-execution-v0", _make_execution_env)

# Tear down any Ray instance left over from a previous execution of this cell,
# then start a local one capped at 8 CPUs.
# A memory cap of (10**9) / 2 bytes (0.5 GB; "10x == 5GB") was considered as
# well but is left disabled -- see https://stackoverflow.com/a/57796710
ray.shutdown()
ray.init(num_cpus=8)

In [None]:
# Restart Ray with its default resources before launching the experiment.
# NOTE(review): this replaces any Ray instance started earlier in the notebook
# (including one started with an explicit CPU cap) -- confirm that is intended.
ray.shutdown()
ray.init()

"""
DQN's default:
train_batch_size=32, sample_batch_size=4, timesteps_per_iteration=1000 -> workers collect chunks of 4 ts and add these to the replay buffer (of size buffer_size ts), then at each train call, at least 1000 ts are pulled altogether from the buffer (in batches of 32) to update the network.
"""
# NOTE(review): the parameter names in the note above may not match the
# installed RLlib version (e.g. `sample_batch_size`) -- verify before relying
# on them.

# Register the environment here as well, so this cell can be run on its own.
register_env(
    "markets-execution-v0",
    lambda config: SubGymMarketsExecutionEnv_v0(**config),
)

# Experiment name; also used as the W&B run group. Change to your convenience.
name_xp = "dqn_execution_demo_3"

# Keyword arguments forwarded to the environment constructor.
execution_env_config = {
    "background_config": "rmsc04",
    "timestep_duration": "10S",
    "execution_window": "04:00:00",
    "parent_order_size": 20000,
    "order_fixed_size": 50,
    "not_enough_reward_update": -100,  # penalty
}

# Trainer configuration. `seed` and `lr` are grid-searched, so Tune launches
# one trial per (seed, lr) combination.
dqn_config = {
    "env": "markets-execution-v0",
    "env_config": execution_env_config,
    "seed": tune.grid_search([1, 2, 3]),
    "num_gpus": 0,
    "num_workers": 0,
    "hiddens": [50, 20],
    "gamma": 1,
    "lr": tune.grid_search([0.001, 0.0001, 0.01]),
    "framework": "torch",
    "observation_filter": "MeanStdFilter",
}

# Stream trial results to Weights & Biases, grouped by experiment name.
wandb_logger = WandbLoggerCallback(
    project="abides_markets_execution_ABM",
    group=name_xp,
    api_key=api_key,
    log_config=False,
)

# Train DQN for 100 iterations per trial, checkpointing every 5 iterations and
# once more at the end. Kept as the last expression so the cell still displays
# the value returned by `tune.run`.
tune.run(
    "DQN",
    name=name_xp,
    config=dqn_config,
    stop={"training_iteration": 100},
    resume=False,
    checkpoint_freq=5,
    checkpoint_at_end=True,
    callbacks=[wandb_logger],
)