In [None]:
import gym
import wandb
from tqdm.notebook import tqdm_notebook

from fractal_zero.models import (
    FullyConnectedDynamicsModel, 
    FullyConnectedRepresentationModel, 
    FullyConnectedPredictionModel,
    JointModel,
)

from fractal_zero.data.data_handler import DataHandler
from fractal_zero.fractal_zero import FractalZero
from fractal_zero.trainer import FractalZeroTrainer

from fractal_zero.config import FractalZeroConfig

In [None]:
# CartPole-v0 caps every episode at 200 steps; `max_game_steps` in the config cell matches that cap.
env_id = "CartPole-v0"
env = gym.make(env_id)

In [None]:
# Width of the latent embedding; the same value is handed to all three sub-networks.
embedding_size = 8
# Output width of the dynamics model (presumably a scalar reward head — TODO confirm).
out_features = 1

# Keep this construction order: the sub-networks are presumably randomly
# initialised, so reordering could change the initial weights.
representation_model = FullyConnectedRepresentationModel(env, embedding_size)
dynamics_model = FullyConnectedDynamicsModel(env, embedding_size, out_features=out_features)
prediction_model = FullyConnectedPredictionModel(env, embedding_size)

# Bundle representation -> dynamics -> prediction into one model object.
joint_model = JointModel(
    representation_model,
    dynamics_model,
    prediction_model,
)

In [None]:
config = FractalZeroConfig(
    env,
    joint_model,
    # --- replay buffer ---
    max_replay_buffer_size=512,
    replay_buffer_pop_strategy="balanced",
    # --- game budget ---
    num_games=5_000,
    max_game_steps=200,  # cartpole maxes out at 200 steps.
    # --- batching / unrolling ---
    max_batch_size=128,
    unroll_steps=16,
    # --- optimizer ---
    optimizer="SGD",
    learning_rate=0.01,
    weight_decay=1e-4,
    momentum=0.9,  # only used when optimizer is SGD
    # --- walkers / lookahead ---
    num_walkers=64,
    balance=1.0,
    lookahead_steps=16,
    evaluation_lookahead_steps=16,
    # --- experiment tracking ---
    wandb_config={"project": "fractal_zero_cartpole"},
)

# TODO: make this logic automatic in config somehow?
# Move the model held by the config onto the config's device and store it back.
config.joint_model = config.joint_model.to(config.device)

In [None]:
# Show the configuration (via `asdict()`) as this cell's output so the run's settings are recorded in the notebook.
config.asdict()

In [None]:
# %%prun -s cumtime -T cartpole_profile.txt -q -l 100

# To profile this cell, uncomment the `%%prun` line above (a cell magic must stay
# on the first line of the cell); the report is written to cartpole_profile.txt.

# TODO: move into config?
train_every = 1      # run a training round after every `train_every` games
train_batches = 1    # number of trainer.train_step() calls per training round
evaluate_every = 16  # evaluate after every `evaluate_every` games
eval_steps = 16      # argument passed to trainer.evaluate()

# NOTE(review): the trainer is built from `config`, whose `joint_model` is not
# re-created here, so re-running only this cell presumably continues from the
# previously trained weights — re-run the model and config cells for a fresh start.
data_handler = DataHandler(config)
fractal_zero = FractalZero(config)
trainer = FractalZeroTrainer(
    fractal_zero,
    data_handler,
)

# 1-based game counter so the schedule checks read "every N games".
# tqdm infers the total from len(range(...)).
for game_number in tqdm_notebook(
    range(1, config.num_games + 1),
    desc="Playing games and training",
):
    trainer.play_game_store_history()

    if game_number % train_every == 0:
        for _ in range(train_batches):
            trainer.train_step()

    if game_number % evaluate_every == 0:
        trainer.evaluate(eval_steps)

# Evaluations inside the loop only happen on multiples of `evaluate_every`
# (e.g. 5_000 games with evaluate_every=16 leaves the last 8 games unevaluated),
# so evaluate the final model once more if the loop did not just do so.
if config.num_games % evaluate_every != 0:
    trainer.evaluate(eval_steps)