In [1]:
import gym
import wandb
from tqdm.notebook import tqdm_notebook

from fractal_zero.models import (
    FullyConnectedDynamicsModel, 
    FullyConnectedRepresentationModel, 
    FullyConnectedPredictionModel,
    JointModel,
)

from fractal_zero.data.data_handler import DataHandler
from fractal_zero.fractal_zero import FractalZero
from fractal_zero.trainer import FractalZeroTrainer

from fractal_zero.config import FractalZeroConfig




In [2]:
# Build the Gym environment that the models and the config below are constructed from.
env_id = "CartPole-v0"
env = gym.make(env_id)

  logger.warn(
  deprecation(
  deprecation(


In [3]:
# Embedding width passed to all three sub-models.
embedding_size = 16
# Output width requested from the dynamics model (the printed config shows it
# as the dynamics `auxiliary_net` out_features) -- presumably the predicted
# reward; TODO confirm.
out_features = 1

# The three sub-models are built in a fixed order: representation, dynamics,
# prediction. Keep that order so parameter initialisation is unchanged.
representation_model = FullyConnectedRepresentationModel(
    env,
    embedding_size,
)
dynamics_model = FullyConnectedDynamicsModel(env, embedding_size, out_features=out_features)
prediction_model = FullyConnectedPredictionModel(
    env,
    embedding_size,
)

# Bundle the three sub-models into the single module handed to the config.
joint_model = JointModel(
    representation_model,
    dynamics_model,
    prediction_model,
)

In [6]:
# Single configuration object shared by the data handler, agent and trainer.
# Keyword arguments are grouped by concern; the values are the experiment's
# hyperparameters.
config = FractalZeroConfig(
    env,
    joint_model,
    # --- game generation ---
    num_games=128,
    max_game_steps=200,
    # --- replay buffer / batching ---
    max_replay_buffer_size=8,
    replay_buffer_pop_strategy="balanced",
    max_batch_size=128,
    unroll_steps=16,
    # --- optimizer ---
    optimizer="SGD",
    learning_rate=0.003,
    momentum=0.9,  # only if optimizer is SGD
    weight_decay=1e-4,
    # --- walker search / lookahead ---
    num_walkers=64,
    balance=1.0,
    lookahead_steps=64,
    evaluation_lookahead_steps=64,
    # Uncomment to enable Weights & Biases logging:
    # wandb_config={"project": "fractal_zero_cartpole"},
)

# The config does not move the model itself, so place it on the configured
# device here.
# TODO: make this logic automatic in config somehow?
config.joint_model = config.joint_model.to(config.device)

In [7]:
# Show the resolved configuration as a dict (including the model's layer
# summary) as a sanity check before training.
config.asdict()

{'joint_model': JointModel(
   (representation_model): FullyConnectedRepresentationModel(
     (net): Sequential(
       (0): Linear(in_features=4, out_features=16, bias=True)
       (1): ReLU()
       (2): Linear(in_features=16, out_features=16, bias=True)
       (3): ReLU()
       (4): Linear(in_features=16, out_features=16, bias=True)
     )
   )
   (dynamics_model): FullyConnectedDynamicsModel(
     (embedding_net): Sequential(
       (0): Linear(in_features=17, out_features=16, bias=True)
       (1): ReLU()
       (2): Linear(in_features=16, out_features=16, bias=True)
       (3): ReLU()
       (4): Linear(in_features=16, out_features=16, bias=True)
       (5): ReLU()
     )
     (auxiliary_net): Sequential(
       (0): Linear(in_features=16, out_features=1, bias=True)
     )
   )
   (prediction_model): FullyConnectedPredictionModel(
     (policy_head): Linear(in_features=16, out_features=1, bias=True)
     (value_head): Linear(in_features=16, out_features=1, bias=True)
   )
 ),
 

In [8]:
%%prun -s cumtime -T cartpole_profile.txt -q -l 50

# TODO: move into config?
train_every = 1
train_batches = 2
evaluate_every = 16
eval_steps = 16

data_handler = DataHandler(config)
fractal_zero = FractalZero(config)
trainer = FractalZeroTrainer(
    fractal_zero,
    data_handler,
)

for i in tqdm_notebook(
    range(config.num_games),
    desc="Playing games and training",
    total=config.num_games,
):
    fractal_zero.train()
    game_history = fractal_zero.play_game()
    data_handler.replay_buffer.append(game_history)

    if i % train_every == 0:
        for _ in range(train_batches):
            trainer.train_step()

    if i % evaluate_every == 0:
        trainer.evaluate(eval_steps)

Playing games and training:   0%|          | 0/128 [00:00<?, ?it/s]