In [None]:
import gym
import numpy as np
from tqdm import tqdm
import torch

from fractal_zero.config import FMCConfig
from fractal_zero.search.fmc import FMC
from fractal_zero.models.prediction import FullyConnectedPredictionModel
from fractal_zero.vectorized_environment import (
    RayVectorizedEnvironment,
    VectorizedDynamicsModelEnvironment,

)
from fractal_zero.trainers.online import OnlineFMCPolicyTrainer

from fractal_zero.tests.test_vectorized_environment import build_test_joint_model

In [None]:
# Passed as the third positional argument to OnlineFMCPolicyTrainer below.
# NOTE(review): presumably the number of parallel FMC walkers / environment
# copies -- confirm against the trainer's signature.
NUM_WALKERS = 2

class CartpolePolicy(torch.nn.Module):
    """Tiny fully connected policy that maps observations to action scores.

    The network is ``Linear -> (optional uniform noise) -> ReLU -> Linear``.
    The defaults (4 observation features, 4 embedding units, 2 actions and a
    noise width of 0.1) reproduce the original hard-coded architecture.

    Args:
        observation_dim: Size of the last axis of the observations.
        embedding_dim: Width of the intermediate embedding.
        num_actions: Number of action scores produced per observation.
        noise_scale: Total width of the zero-centred uniform noise added to
            the embeddings when ``forward(..., with_randomness=True)`` is
            used; samples lie in ``[-noise_scale / 2, noise_scale / 2)``.
    """

    def __init__(
        self,
        observation_dim: int = 4,
        embedding_dim: int = 4,
        num_actions: int = 2,
        noise_scale: float = 0.1,
    ):
        super().__init__()

        self.noise_scale = noise_scale

        # The Sequential wrappers are kept (rather than bare Linear layers) so
        # the parameter names stay `embedding.0.*` / `action_head.1.*` and
        # previously saved state dicts remain loadable.
        self.embedding = torch.nn.Sequential(
            torch.nn.Linear(observation_dim, embedding_dim),
        )
        self.action_head = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Linear(embedding_dim, num_actions),
        )

    @staticmethod
    def _as_float_tensor(observations):
        """Convert observations to a float32 tensor that carries no gradient."""
        if isinstance(observations, torch.Tensor):
            # `torch.tensor(t)` is `t.detach().clone()` plus a UserWarning;
            # detaching keeps the "no gradient into the input" behaviour
            # without the warning or the extra copy.
            return observations.detach().float()
        # Going through a single ndarray avoids the slow element-wise path
        # that `torch.tensor` takes for a list of numpy arrays.
        return torch.tensor(np.asarray(observations), dtype=torch.float32)

    def forward(self, observations, with_randomness: bool = False):
        """Compute action scores for a batch of observations.

        Args:
            observations: Tensor, numpy array or (nested) sequence of numbers
                whose last axis has ``observation_dim`` entries.
            with_randomness: When true, zero-centred uniform noise of total
                width ``noise_scale`` is added to the embeddings before the
                action head.

        Returns:
            Float tensor of action scores whose last axis has ``num_actions``
            entries.
        """
        observations = self._as_float_tensor(observations)

        embeddings = self.embedding(observations)

        if with_randomness:
            width = self.noise_scale
            zero_centered_uniform_noise = (torch.rand_like(embeddings) * width) - (width / 2)
            # Out-of-place add: leaves the tensor produced by the embedding
            # layer untouched instead of mutating it.
            embeddings = embeddings + zero_centered_uniform_noise

        return self.action_head(embeddings)

    def parse_actions(self, actions):
        """Return the index of the highest score along the last axis as Python ints."""
        return torch.argmax(actions, dim=-1).tolist()

# Build the policy network with its default layer sizes.
policy_model = CartpolePolicy()
# Hand the policy to the online FMC trainer together with the environment id
# string and the NUM_WALKERS constant.
# NOTE(review): argument meanings are inferred from the values passed; confirm
# against OnlineFMCPolicyTrainer's signature.
policy_trainer = OnlineFMCPolicyTrainer("CartPole-v0", policy_model, NUM_WALKERS)

In [None]:
# Run the trainer once with an argument of 8.
# NOTE(review): `train_epsiode` looks like a misspelling of `train_episode`; it
# is left untouched because the method is defined outside this notebook --
# confirm the spelling and the meaning of the argument against
# OnlineFMCPolicyTrainer.
policy_trainer.train_epsiode(8)