In [10]:
import gym
import numpy as np
from tqdm import tqdm
import torch
import wandb

from fractal_zero.config import FMCConfig
from fractal_zero.search.fmc import FMC
from fractal_zero.models.prediction import FullyConnectedPredictionModel
from fractal_zero.vectorized_environment import (
    RayVectorizedEnvironment,
    VectorizedDynamicsModelEnvironment,

)
from fractal_zero.trainers.online import OnlineFMCPolicyTrainer

from fractal_zero.tests.test_vectorized_environment import build_test_joint_model

In [11]:
# Passed to OnlineFMCPolicyTrainer as its fourth positional argument when the
# trainer is built; presumably the number of FMC walkers -- TODO confirm
# against the trainer's signature.
NUM_WALKERS = 64

class CartpolePolicy(torch.nn.Module):
    """Small fully connected policy network producing 2 action logits.

    The network takes inputs whose last dimension is 4 and passes them
    through three hidden ReLU layers (16 -> 32 -> 16 units) before a final
    linear layer with 2 outputs.
    """

    def __init__(self):
        super().__init__()

        self.net = torch.nn.Sequential(
            torch.nn.Linear(4, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 2),
        )

    def forward(self, observations, with_randomness: bool = False, argmax: bool = True):
        """Compute action logits (or greedy actions) for ``observations``.

        Args:
            observations: A tensor, or anything ``torch.tensor`` accepts
                (nested lists, numpy arrays, ...), whose last dimension is 4.
            with_randomness: When True, add zero-centered uniform noise to the
                logits. The noise width equals the variance of the logits, so
                each logit is perturbed by a value in ``[-var/2, var/2)``.
            argmax: When True, return the index of the largest logit along the
                last dimension; otherwise return the (possibly noisy) logits.

        Returns:
            An integer tensor of action indices if ``argmax`` is True,
            otherwise a float tensor of logits with last dimension 2.
        """
        if isinstance(observations, torch.Tensor):
            # torch.tensor(<tensor>) emits a copy-construct UserWarning on
            # every call. This yields the same detached float32 copy without it.
            observations = observations.detach().to(dtype=torch.float32, copy=True)
        else:
            observations = torch.tensor(observations, dtype=torch.float32)

        y = self.net(observations)

        if with_randomness:
            center = y.var()
            centered_uniform_noise = (torch.rand_like(y) * center) - (center / 2)
            # Out-of-place add: do not mutate the network output tensor.
            y = y + centered_uniform_noise

        # TODO: refac
        if argmax:
            return torch.argmax(y, dim=-1)
        return y

    def parse_actions(self, actions):
        """Convert a tensor of actions into a plain (nested) Python list."""
        return actions.tolist()

# Single policy instance for this notebook; its parameters are handed to the
# optimizer and the model itself to the trainer in the following cell.
policy_model = CartpolePolicy()

In [12]:
# Adam with a small weight decay. Plain SGD with the same lr / weight_decay
# (torch.optim.SGD(policy_model.parameters(), lr=0.01, weight_decay=1e-4))
# was the earlier alternative.
optimizer = torch.optim.Adam(policy_model.parameters(), lr=0.01, weight_decay=1e-4)

# lr_scheduler.step() is called once per episode by the training loop.
# Halving on every call (step_size=1) shrinks the learning rate to
# 0.01 * 0.5**100 ~= 8e-33 over a 100-episode run, which effectively stops
# training after a few dozen episodes. Halve every 20 scheduler steps instead.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

policy_trainer = OnlineFMCPolicyTrainer("CartPole-v0", policy_model, optimizer, NUM_WALKERS)

  logger.warn(
  deprecation(
  deprecation(


In [13]:
# Start a Weights & Biases run in the "fz-policy-trainer-game-tree" project.
# NOTE(review): the run summary printed by this cell lists train/ and eval/
# metrics, presumably logged by the trainer during a previous run -- confirm.
wandb.init(project="fz-policy-trainer-game-tree")

  from IPython.core.display import HTML, display  # type: ignore


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

[2m[36m(pid=2067124)[0m  0xffffffff
[2m[36m(_RayWrappedEnvironment pid=2066785)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2066785)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2066785)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2066790)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2066790)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2066790)[0m   deprecation(


0,1
eval/total_rewards,▁▆██████████████████████████████████████
train/epsiode_reward,███████████▆█████████████████▁██████████
train/loss,▃▃▂▁▃▃▅▅▆▁▂▃▄▁▁▃▃▁▂▃▃▄▂▂▂▄▄█▂▃▂▁▂▅▂▂▃▁▅▃

0,1
eval/total_rewards,32.0
train/epsiode_reward,32.0
train/loss,0.69525


[2m[36m(_RayWrappedEnvironment pid=2066784)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2066784)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2066784)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2066774)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2066774)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2066774)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2066787)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2066787)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2066787)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2066973)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2066973)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2066973)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2066775)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2066775)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2066775)[0m   deprecation(
  from IPython.core.displ

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.033337100346883135, max=1.0…

[2m[36m(_RayWrappedEnvironment pid=2066819)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2066819)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2066819)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2067023)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2067023)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2067023)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2067143)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2067143)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2067143)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2066777)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2066777)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2066777)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2067469)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2067469)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2067469)[0m   deprecation(
[2m[36m(_RayWrappedEnvi

In [14]:
num_episodes = 100
train_steps_per_episode = 2
# Evaluate the policy once every `eval_every` training steps, counted across
# all episodes (not per episode).
eval_every = 20
max_steps = 200

# Running count of training steps over the whole run. The per-episode step
# index only ranges over 0..train_steps_per_episode-1, so testing it against
# `eval_every` evaluated on every episode and made `eval_every` meaningless.
total_train_steps = 0

for _ in tqdm(range(num_episodes), desc="episodes"):
    policy_trainer.generate_episode_data(max_steps)

    for _ in range(train_steps_per_episode):
        policy_trainer.train_on_latest_episode()

        if total_train_steps % eval_every == 0:
            policy_trainer.evaluate_policy(max_steps)
        total_train_steps += 1

    # One scheduler step per episode.
    lr_scheduler.step()

[2m[36m(_RayWrappedEnvironment pid=2066973)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2066775)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2067611)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2067068)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2067785)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2067313)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2067688)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2067881)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2066789)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2066886)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2067109)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2067338)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2067071)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2067993)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2067899)[0m   logger.warn(
[2m[36m(_RayWrappedEnvi

In [None]:
# Display the `clone_receives` attribute of the trainer's FMC search object.
# NOTE(review): presumably per-walker clone bookkeeping -- confirm in FMC.
policy_trainer.fmc.clone_receives

In [None]:
# Call render() on the tree held by the trainer's FMC search object.
# NOTE(review): presumably draws the search/game tree -- confirm in the tree class.
policy_trainer.fmc.tree.render()

In [None]:
# Show the string form of the tree's `g` attribute.
# NOTE(review): `g` is presumably the underlying graph object -- confirm.
str(policy_trainer.fmc.tree.g)