In [None]:
from fractal_zero.trainers.muzero_discriminator import FractalMuZeroDiscriminatorTrainer, FMZGModel
from fractal_zero.models.joint_model import JointModel

from fractal_zero.data.expert_dataset import ExpertDatasetGenerator
from fractal_zero.vectorized_environment import load_environment


import torch

In [None]:
# Create the "CartPole-v0" environment through the project's load_environment
# helper; `env` is shared by the expert dataset generator and the trainer below.
env = load_environment("CartPole-v0")

In [None]:
# get_expert_action = lambda x: env.action_space.sample()  # random policy

# NOTE: torch.load unpickles the checkpoint, which can execute arbitrary code;
# only load files that come from a trusted source.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects fully pickled model objects. The loaded value is used as a model
# object below, so on those versions weights_only=False may be required — confirm
# against the installed PyTorch version before changing the call.
_expert_policy_model = torch.load("models/best_cartpole_policy.pth")


def get_expert_action(x):
    """Return the expert policy's action(s) for the observation(s) ``x``.

    Runs the loaded expert model's forward pass on ``x`` and converts the
    result with the model's own ``parse_actions`` method.

    The call is wrapped in ``torch.no_grad()``: the expert is only queried here
    (its parameters are not handed to the optimizer in this notebook), so
    recording an autograd graph for every query would be wasted work.
    """
    with torch.no_grad():
        embedded_actions = _expert_policy_model.forward(x)
        return _expert_policy_model.parse_actions(embedded_actions)

# play with the expert for a bit
# obs = env.reset()
# total_reward = 0
# for _ in range(200):
#     action = get_expert_action(obs)
#     obs, reward, done, info = env.step(action)
#     total_reward += reward
#     env.render()
#     if done:
#         break

# env.close()
# print(f"total_reward={total_reward}")

In [None]:
# Layer widths shared by the three networks below. Naming them keeps the
# "embedding + action" input width in sync across the dynamics and
# discriminator models instead of repeating 4 / 5 by hand.
OBS_DIM = 4          # width of the observation vector fed to rep_model
EMBED_DIM = 4        # width of the embedding produced by rep_model and dyn_model
ACTION_DIM = 1       # width of the action appended to the embedding
DYN_HIDDEN_DIM = 16  # hidden width of the dynamics model

# Representation model: maps an observation to an embedding of width EMBED_DIM.
rep_model = torch.nn.Sequential(
    torch.nn.Linear(OBS_DIM, EMBED_DIM),
    torch.nn.ReLU(),
    torch.nn.Linear(EMBED_DIM, EMBED_DIM),
    torch.nn.ReLU(),
)
# rep_model = torch.nn.Identity()  # alternative; only valid while EMBED_DIM == OBS_DIM

# Dynamics model: takes embedding + action (4 + 1 = 5 inputs) and produces an
# output of width EMBED_DIM.
dyn_model = torch.nn.Sequential(
    torch.nn.Linear(EMBED_DIM + ACTION_DIM, DYN_HIDDEN_DIM),
    torch.nn.ReLU(),
    torch.nn.Linear(DYN_HIDDEN_DIM, EMBED_DIM),
    torch.nn.ReLU(),
)

# Discriminator: a single linear layer followed by a sigmoid, so each output is
# a confidence value between 0 and 1.
# NOTE(review): its input width is EMBED_DIM + ACTION_DIM = 5, not the bare
# 4-wide representation embedding — presumably it receives the embedding
# together with the action; confirm against FMZGModel.
disc_model = torch.nn.Sequential(
    torch.nn.Linear(EMBED_DIM + ACTION_DIM, 1),
    torch.nn.Sigmoid(),
)

def action_vec(x):
    """Convert a single action ``x`` into a plain Python ``int``.

    ``x`` may be anything ``torch.as_tensor`` accepts (a number, a sequence,
    an array or a tensor) as long as it holds exactly one element. The value
    is flattened, cast to an integer tensor with ``.int()`` (a fractional part
    is truncated), and unwrapped with ``.item()``.

    ``torch.as_tensor`` is used instead of ``torch.tensor`` so that an input
    which is already a tensor is not copy-constructed, which PyTorch warns
    about; the returned value is the same either way.
    """
    return torch.as_tensor(x).flatten().int().item()

# Bundle the three networks into the project's FMZGModel, using 256 walkers and
# the shared action vectorizer.
model = FMZGModel(rep_model, dyn_model, disc_model, num_walkers=256, action_vectorizer=action_vec)

# Dataset generator built from the expert policy callable and the environment;
# it is given the same action vectorizer as the model.
expert_dataset = ExpertDatasetGenerator(get_expert_action, env, action_vectorizer=action_vec)

# A single Adam optimizer over every parameter of the representation, dynamics
# and discriminator networks (in that order).
discriminator_optimizer = torch.optim.Adam(
    [
        weight
        for network in (rep_model, dyn_model, disc_model)
        for weight in network.parameters()
    ],
    lr=0.1,
)

# The trainer receives the environment, the model, the expert dataset and the
# optimizer.
trainer = FractalMuZeroDiscriminatorTrainer(
    env,
    model,
    expert_dataset,
    discriminator_optimizer,
)

In [None]:
# Call trainer.train_step 100 times, passing max_steps=64 each time, and print
# whatever each call returns.
max_steps = 64
for _iteration in range(100):
    step_output = trainer.train_step(max_steps)
    print(step_output)