In [None]:
import gym
import numpy as np
from tqdm import tqdm
import torch
import wandb

from fractal_zero.config import FMCConfig
from fractal_zero.search.fmc import FMC
from fractal_zero.models.prediction import FullyConnectedPredictionModel
from fractal_zero.vectorized_environment import (
    RayVectorizedEnvironment,
    VectorizedDynamicsModelEnvironment,

)
from fractal_zero.trainers.online import OnlineFMCPolicyTrainer

from fractal_zero.tests.test_vectorized_environment import build_test_joint_model

In [None]:
# Passed to OnlineFMCPolicyTrainer when the trainer is constructed further down.
# NOTE(review): presumably the number of parallel FMC walkers — confirm against the trainer.
NUM_WALKERS = 64

class CartpolePolicy(torch.nn.Module):
    """Small fully connected policy: 4 input features -> 2 action logits.

    The network is a three-layer MLP (4 -> 16 -> 16 -> 2) with ReLU
    activations between the linear layers.
    """

    def __init__(self):
        super().__init__()

        self.net = torch.nn.Sequential(
            torch.nn.Linear(4, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 2),
        )

    def forward(self, observations, with_randomness: bool = False, argmax: bool = True):
        """Compute actions (or raw logits) for a batch of observations.

        Args:
            observations: a tensor, numpy array, or (nested) sequence that can
                be converted to a float32 tensor whose last dimension is 4.
                The input is always copied; tensor inputs are additionally
                detached, so no gradient flows back into ``observations``.
            with_randomness: when True, add zero-centred uniform noise to the
                logits. The noise width is the variance taken over all logits
                of the batch.
            argmax: when True, return the index of the largest logit along the
                last dimension; when False, return the logits themselves.

        Returns:
            An integer tensor of action indices if ``argmax`` is True,
            otherwise a float tensor of logits with a trailing dimension of 2.
        """
        if isinstance(observations, torch.Tensor):
            # torch.tensor(<tensor>) emits a UserWarning on every call. This is
            # the explicit equivalent: detached float32 copy of the input.
            observations = observations.detach().to(dtype=torch.float32, copy=True)
        else:
            # Going through a single ndarray first avoids torch's slow (and
            # warned-about) conversion path for lists of numpy arrays.
            observations = torch.tensor(np.asarray(observations), dtype=torch.float32)

        y = self.net(observations)

        if with_randomness:
            # Uniform noise in [-spread / 2, spread / 2), where spread is the
            # variance (not the standard deviation) over all logits.
            # NOTE(review): spread is not detached, so gradients also flow
            # through the noise scale — confirm this is intended.
            spread = y.var()
            centered_uniform_noise = (torch.rand_like(y) * spread) - (spread / 2)
            # Out-of-place add: never mutate a tensor that is part of the graph.
            y = y + centered_uniform_noise

        # TODO: refac
        if argmax:
            return torch.argmax(y, dim=-1)
        return y

    def parse_actions(self, actions):
        """Convert a tensor of actions into plain (nested) Python lists."""
        return actions.tolist()

# Notebook-wide policy instance: its parameters are handed to the optimizer and
# the module itself to the trainer in the cells that follow.
policy_model = CartpolePolicy()

In [None]:
# Adam with L2 weight decay updates the policy weights. (Plain SGD with the
# same lr / weight_decay is the alternative that was tried here before.)
optimizer = torch.optim.Adam(
    policy_model.parameters(),
    lr=0.01,
    weight_decay=1e-4,
)

# step_size=1 with gamma=0.5: every lr_scheduler.step() call halves the
# learning rate.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=1,
    gamma=0.5,
)

# The trainer receives the environment id, the policy, its optimizer and the
# walker count.
policy_trainer = OnlineFMCPolicyTrainer(
    "CartPole-v0",
    policy_model,
    optimizer,
    NUM_WALKERS,
)

In [None]:
# Start a Weights & Biases run under this project name.
# NOTE(review): presumably the trainer logs its metrics to the active run —
# confirm in OnlineFMCPolicyTrainer.
wandb.init(project="fz-policy-trainer-epsilon-greedy")

In [None]:
# Training schedule.
num_episodes = 20  # outer iterations; one generate_episode_data() call each
train_steps_per_episode = 1000  # train_on_latest_episode() calls per episode
eval_every = 20  # evaluate on training step 0 and every 20th step after it
max_steps = 32  # step limit passed to both episode generation and evaluation

# The loop variable is deliberately not `_`: at notebook top level that would
# bind `_` in the kernel namespace, and IPython stops updating its last-output
# variable `_` once user code has assigned to it.
for episode_index in range(num_episodes):
    policy_trainer.generate_episode_data(max_steps)

    for i in range(train_steps_per_episode):
        policy_trainer.train_on_latest_episode()

        if i % eval_every == 0:
            policy_trainer.evaluate_policy(max_steps)

    # Advance the learning-rate schedule once per episode, not per train step.
    lr_scheduler.step()

In [None]:
# Cell output: display the `clone_receives` attribute of the trainer's FMC object.
# NOTE(review): presumably a per-walker cloning statistic from the most recent
# search — confirm in FMC.
policy_trainer.fmc.clone_receives