In [None]:
!pip install gym


In [None]:

import tensorflow as tf
import random
import numpy as np
from statistics import mean 


In [None]:

from mcts import Node, run_mcts
from gamecomponents import Policy
from game import Game
from replaybuffer import ReplayBuffer
from network import Network, SharedStorage
from helpers import KnownBounds
from muzero_core import play_game, train_network, MuZeroConfig, make_aigym_config

In [None]:
# Objects shared by every later cell: the run configuration (c), the network
# store (s) and the replay buffer (r).
c = make_aigym_config('CartPole-v0')
s = SharedStorage()
r = ReplayBuffer()

# Tweaks applied on top of the configuration returned by make_aigym_config.
# Each key is an attribute name on the config object; insertion order is the
# order in which the attributes are assigned.
config_overrides = {
    'num_simulations': 60,
    'training_steps': 150,
    'batch_size': 32,
    'td_steps': 25,
    'discount': 1.,
    'lr_init': 0.005,
    'lr_decay_steps': 100e3,
    'max_moves': 2000,
    'num_unroll_steps': 4,
    'epsilon': 0.001,
    # A little bit of help: a starting point for the value function bounds.
    'known_bounds': KnownBounds(max=20, min=-20),
}
for option_name, option_value in config_overrides.items():
  setattr(c, option_name, option_value)

# Class-level settings: these are set on the classes themselves, not on an
# instance.
Network.N = 4                          # network parameter
Node.root_exploration_fraction = 0.30  # MCTS parameter


In [None]:
# Warm-up: play a fixed number of games before any training so the replay
# buffer has something in it, printing each game's summed rewards.
warmup_game_count = 30
for _ in range(warmup_game_count):
  n = s.latest_network()
  g1 = play_game(c, n)
  r.save_game(g1)
  total_reward = sum(g1.rewards)
  print(total_reward)

In [None]:


# Tunables for the train / self-play cycle below.
NUM_ITERATIONS = 400          # number of train-then-play rounds
MAX_GAMES_PER_ITERATION = 25  # upper bound on games played after one training round
KEEP_PLAYING_RATIO = 0.5      # keep playing while the mean length beats this share of the best in r.game_len

# The loop variable is named `iteration` (not `iter`) so the builtin iter()
# is not shadowed for the rest of the notebook session.
for iteration in range(NUM_ITERATIONS):

  train_network(c, s, r)

  n = s.latest_network()

  # Play at least one game with the freshly trained network. Keep playing
  # (up to MAX_GAMES_PER_ITERATION games) for as long as the running mean
  # game length stays above KEEP_PLAYING_RATIO * max(r.game_len), so that a
  # network that is doing well adds more new information to the replay buffer.
  tot = 0
  N = 0
  better = True
  while better:
    g1 = play_game(c, n)
    game_length = g1.length()
    print(game_length)
    tot += game_length
    N += 1

    # Compared against r.game_len as it stands *before* g1 is saved.
    # NOTE(review): presumably r.game_len holds the lengths of the stored
    # games and is non-empty thanks to the warm-up cell - confirm.
    better = tot/N > max(r.game_len)*KEEP_PLAYING_RATIO and N < MAX_GAMES_PER_ITERATION
    r.save_game(g1)

  # (iteration index, mean game length over the games played this round)
  print((iteration, tot/N))
  print('----')




