In [None]:
import gym
import numpy as np
from tqdm import tqdm
import torch
import wandb
from copy import deepcopy

from fractal_zero.config import FMCConfig
from fractal_zero.search.fmc import FMC
from fractal_zero.models.prediction import FullyConnectedPredictionModel
from fractal_zero.models.policies.cartpole_policy import CartpolePolicy
from fractal_zero.vectorized_environment import (
    RayVectorizedEnvironment,
    VectorizedDynamicsModelEnvironment,

)
from fractal_zero.trainers.online import OfflineFMCPolicyTrainer


In [None]:
# Walker count; passed positionally to OfflineFMCPolicyTrainer when the trainer is built.
NUM_WALKERS = 256

class CartpolePolicy(torch.nn.Module):
    """Small fully-connected policy network for CartPole.

    Maps observations with 4 features to 2 raw (un-normalised) action scores.

    NOTE(review): this notebook-local class shadows the ``CartpolePolicy``
    imported from ``fractal_zero.models.policies.cartpole_policy`` in the
    imports cell; every later use of the name refers to this definition.
    """

    def __init__(self):
        super().__init__()

        # 4 inputs -> 16 hidden units (ReLU) -> 2 outputs, with no final
        # activation. (A Sigmoid head was tried earlier to keep an MSE loss
        # from exploding; it is intentionally not part of the network.)
        self.net = torch.nn.Sequential(
            torch.nn.Linear(4, 16),
            torch.nn.ReLU(),
            torch.nn.Linear(16, 2),
        )

    def forward(self, observations, with_randomness: bool = False):
        """Compute action scores for one observation or a batch of them.

        Args:
            observations: Anything ``torch.as_tensor`` accepts (tensor, NumPy
                array, nested sequence of numbers) whose last dimension is 4.
            with_randomness: Placeholder for adding noise to the output; not
                implemented, so it must be falsy.

        Returns:
            A float32 tensor shaped like ``observations`` with the last
            dimension replaced by 2.

        Raises:
            NotImplementedError: If ``with_randomness`` is truthy.
        """
        if with_randomness:
            # Earlier sketch: add centred uniform noise scaled by y.var().
            # Fail before doing any work since the feature does not exist yet.
            raise NotImplementedError

        # torch.tensor() copy-constructs (and emits a UserWarning) when handed
        # an existing tensor; as_tensor() only converts when it has to.
        # detach() keeps the previous guarantee that no gradient flows back
        # into the observations themselves.
        observations = torch.as_tensor(observations, dtype=torch.float32).detach()

        return self.net(observations)

    def parse_action(self, actions):
        """Turn action scores into discrete action indices.

        Args:
            actions: Tensor of scores; the best action is picked along the
                last dimension.

        Returns:
            The argmax indices converted with ``Tensor.tolist()``: a list of
            ints for batched scores, or a single int for a 1-D score vector.
        """
        # Previously experimented with thresholding a single output at 0.5;
        # with two outputs the choice is simply the larger score.
        return actions.argmax(-1).tolist()

# Instantiate the CartpolePolicy class defined earlier in this cell; this object is what gets trained.
policy_model = CartpolePolicy()

In [None]:
# Optimiser over the policy's parameters: Adam with a learning rate of 0.01.
# Alternatives kept for reference (currently disabled):
#   optimizer = torch.optim.SGD(policy_model.parameters(), lr=0.01, weight_decay=1e-4)
#   lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
optimizer = torch.optim.Adam(
    policy_model.parameters(),
    lr=0.01,
)

# Cross-entropy is handed to the trainer as its `loss_spec`.
loss_func = torch.nn.functional.cross_entropy

# Trainer for the "CartPole-v0" environment, built with Ray disabled.
policy_trainer = OfflineFMCPolicyTrainer(
    "CartPole-v0",
    policy_model,
    optimizer,
    NUM_WALKERS,
    loss_spec=loss_func,
    use_ray=False,
)

In [None]:
# Start a Weights & Biases run in the "fz-policy-trainer-game-tree" project.
# NOTE(review): no cell in this notebook calls wandb.log directly; presumably the
# trainer logs to the active run — confirm.
wandb.init(project="fz-policy-trainer-game-tree")

In [None]:
# Training schedule.
num_episodes = 100            # outer iterations; one generated episode each
train_steps_per_episode = 5   # train_on_latest_episode() calls per episode
eval_every = 20               # evaluate once every N training steps, counted across episodes
max_steps = 200               # passed to both generate_episode_data() and evaluate_policy()

best_total_rewards = float("-inf")
best_model = None

# Training steps taken so far over the whole run. The evaluation cadence is
# keyed on this counter rather than on the per-episode step index: that index
# restarts at 0 every episode and never reaches `eval_every`, so testing it
# would evaluate once per episode (after the first step) no matter what
# `eval_every` is set to.
total_train_steps = 0

for _ in range(num_episodes):
    policy_trainer.generate_episode_data(max_steps)
    print("best path reward", policy_trainer.fmc.tree.best_path.total_reward)

    for _ in range(train_steps_per_episode):
        policy_trainer.train_on_latest_episode()
        total_train_steps += 1

        if total_train_steps % eval_every != 0:
            continue

        total_rewards = policy_trainer.evaluate_policy(max_steps)

        if total_rewards > best_total_rewards:
            best_total_rewards = total_rewards
            # Keep an independent snapshot; `policy_model` itself is the
            # object handed to the trainer and optimizer.
            best_model = deepcopy(policy_model)

            # torch.save(best_model, "models/best_cartpole_policy.pth")

    # lr_scheduler.step()  # only relevant if a learning-rate scheduler is enabled

In [None]:
# Show the `clone_receives` attribute of the trainer's FMC instance as the cell output.
# NOTE(review): presumably a per-walker tally of clones received during search — confirm against FMC.
policy_trainer.fmc.clone_receives