In [1]:
import gym
import numpy as np
from tqdm import tqdm
import torch
import wandb

from fractal_zero.config import FMCConfig
from fractal_zero.search.fmc import FMC
from fractal_zero.models.prediction import FullyConnectedPredictionModel
from fractal_zero.vectorized_environment import (
    RayVectorizedEnvironment,
    VectorizedDynamicsModelEnvironment,

)
from fractal_zero.trainers.online import OnlineFMCPolicyTrainer

from fractal_zero.tests.test_vectorized_environment import build_test_joint_model



In [2]:
# Walker count handed to OnlineFMCPolicyTrainer when the trainer is constructed below.
NUM_WALKERS = 64

class CartpolePolicy(torch.nn.Module):
    """Small MLP policy: 4 input features -> 2 output scores (one per action).

    ``forward`` either returns the raw scores or the index of the best-scoring
    action; ``parse_actions`` turns an action tensor into Python values.
    """

    def __init__(self):
        super().__init__()

        self.net = torch.nn.Sequential(
            torch.nn.Linear(4, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 2),
        )

    def forward(self, observations, with_randomness: bool = False, argmax: bool = True):
        """Score observations and optionally select the highest-scoring action.

        Args:
            observations: Anything ``torch.as_tensor`` accepts (tensor, numpy
                array, nested list) whose last dimension holds 4 features.
            with_randomness: If True, add uniform noise in ``[-s/2, s/2)`` to
                the scores, where ``s`` is the variance of all score entries.
                ``s`` is treated as a constant (no gradient flows through it).
            argmax: If True, return ``argmax`` over the last dimension;
                otherwise return the (possibly noisy) scores.

        Returns:
            An int64 tensor of action indices when ``argmax`` is True, else a
            float32 tensor of scores whose last dimension is 2.
        """
        # as_tensor casts straight to float32 and, unlike torch.tensor(), does
        # not warn (or copy needlessly) when handed an existing tensor.
        obs = torch.as_tensor(observations, dtype=torch.float32)

        scores = self.net(obs)

        if with_randomness:
            # The noise width is data-dependent; detach it so it is not trained.
            # Add out-of-place: an in-place `+=` would mutate the tensor that
            # var() saved for backward, making any backward pass raise.
            noise_scale = scores.detach().var()
            noise = torch.rand_like(scores) * noise_scale - noise_scale / 2
            scores = scores + noise

        if argmax:
            return torch.argmax(scores, dim=-1)
        return scores

    def parse_actions(self, actions):
        """Convert an action tensor into Python-native values via ``tolist()``."""
        return actions.tolist()

# Seed torch before building the policy so its weight initialisation (and the
# exploration noise drawn in CartpolePolicy.forward) is reproducible on
# Restart & Run All.
POLICY_SEED = 0
torch.manual_seed(POLICY_SEED)
policy_model = CartpolePolicy()

In [3]:
# Optimiser / scheduler hyperparameters (values unchanged from the original run).
LEARNING_RATE = 0.01
WEIGHT_DECAY = 1e-4
# StepLR multiplies the learning rate by LR_DECAY_GAMMA every LR_DECAY_STEP_SIZE
# calls to lr_scheduler.step(); the training loop calls it once per episode.
# NOTE(review): with gamma=0.5 and step_size=1 the lr after k episodes is
# 0.01 * 0.5**k -- below 1e-8 after 20 episodes and ~8e-33 after 100 -- so later
# episodes barely move the weights. Confirm this aggressive decay is intended.
LR_DECAY_GAMMA = 0.5
LR_DECAY_STEP_SIZE = 1
ENV_ID = "CartPole-v0"

optimizer = torch.optim.Adam(
    policy_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=LR_DECAY_STEP_SIZE, gamma=LR_DECAY_GAMMA
)
policy_trainer = OnlineFMCPolicyTrainer(ENV_ID, policy_model, optimizer, NUM_WALKERS)

  logger.warn(
  deprecation(
  deprecation(
2022-09-27 02:21:30,890	INFO worker.py:1518 -- Started a local Ray instance.


In [4]:
# Open the Weights & Biases run for this experiment.
# NOTE(review): nothing in this notebook calls wandb.log directly; presumably
# OnlineFMCPolicyTrainer logs to the active run -- confirm.
WANDB_PROJECT = "fz-policy-trainer-game-tree"
wandb.init(project=WANDB_PROJECT)

  return LooseVersion(v) >= LooseVersion(check)
[34m[1mwandb[0m: Currently logged in as: [33mdyllan[0m. Use [1m`wandb login --relogin`[0m to force relogin
[2m[36m(_RayWrappedEnvironment pid=2077806)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2077806)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2077806)[0m   deprecation(
  from IPython.core.display import HTML, display  # type: ignore


In [5]:
# Training schedule.
num_episodes = 100            # outer iterations; each generates fresh episode data
train_steps_per_episode = 2   # train_on_latest_episode() calls per episode
eval_every = 20               # run evaluate_policy() every `eval_every` training steps
max_steps = 200               # passed as max_steps to episode generation and evaluation

# Count training steps across all episodes. The per-episode inner index resets
# every episode (and only reaches train_steps_per_episode - 1), so testing it
# against eval_every would evaluate on every episode and ignore eval_every.
global_train_step = 0
for _ in range(num_episodes):
    policy_trainer.generate_episode_data(max_steps)

    for _ in range(train_steps_per_episode):
        policy_trainer.train_on_latest_episode()

        if global_train_step % eval_every == 0:
            policy_trainer.evaluate_policy(max_steps)
        global_train_step += 1

    # Advance the StepLR schedule once per episode.
    lr_scheduler.step()

[2m[36m(_RayWrappedEnvironment pid=2078985)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2078985)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2078985)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2078015)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2078015)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2078015)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2078019)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2078019)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2078019)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2077793)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2077793)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2077793)[0m   deprecation(
[2m[36m(_RayWrappedEnvironment pid=2078014)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2078610)[0m   logger.warn(
[2m[36m(_RayWrappedEnvironment pid=2078610)[0m   deprecation(
[2m[36m(_RayWrappedEnvi

In [None]:
# Debug inspection: display policy_trainer.fmc.clone_receives after training
# (this cell was not executed in the saved run -- In [None]).
policy_trainer.fmc.clone_receives

In [None]:
# Debug inspection: render the tree stored on policy_trainer.fmc. What render()
# produces depends on the tree class, which is not visible here (cell not run).
policy_trainer.fmc.tree.render()

In [None]:
# Debug inspection: string form of the tree's `g` attribute (presumably the
# underlying graph object -- confirm); shown as the cell's output (cell not run).
str(policy_trainer.fmc.tree.g)