In [None]:
!pip install -U kaggle_environments cpprb

In [None]:
import gc
from multiprocessing import set_start_method, cpu_count, Process, Event, SimpleQueue
import time

import numpy as np
import tensorflow as tf
import cpprb # Replay Buffer Library: https://ymd_h.gitlab.io/cpprb/
from tqdm.notebook import tqdm

from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, row_col
from kaggle_environments import make

# %load_ext tensorboard
# %tensorboard --logdir logs

In [None]:
# Global config

# Indices of the three network outputs, i.e. the turn relative to the
# current heading (network output index - 1 gives the heading offset).
RIGHT = 0
GO = 1
LEFT = 2

# Cell values written into the encoded board.
GOOSE = -1
FOOD = 1

# Number of network outputs (RIGHT / GO / LEFT).
act_shape = 3

# Board dimensions in cells.
WIDTH = 11
HEIGHT = 7

# Zero-based coordinates of the centre cell that the player's head is
# shifted onto. For an odd size N the middle index is N//2 (5 of 0..10 and
# 3 of 0..6); adding 1 would place the head one cell off-centre.
xc = WIDTH//2
yc = HEIGHT//2

# Absolute headings, numbered counter-clockwise so that +1 is a left turn
# and -1 is a right turn.
code2dir = {0:'EAST', 1:'NORTH', 2:'WEST', 3:'SOUTH'}

# Inverse lookup, derived from code2dir so the two tables cannot drift apart.
dir2code = {name: code for code, name in code2dir.items()}

In [None]:
def create_model():
    """Build the Q-network as a plain fully connected Keras model.

    The input is a flat vector of WIDTH*HEIGHT values. Three ReLU layers of
    100 units are followed by a linear layer with `act_shape` outputs.
    """
    hidden_units = 100

    # The first layer also declares the input shape.
    layers = [tf.keras.layers.Dense(hidden_units, activation="relu",
                                    input_shape=(WIDTH*HEIGHT,))]
    layers += [tf.keras.layers.Dense(hidden_units, activation="relu")
               for _ in range(2)]
    # Linear output head: one value per output index.
    layers.append(tf.keras.layers.Dense(act_shape))

    return tf.keras.Sequential(layers)

In [None]:
def Q_func(model,obs,act):
    """Return, for each row of `obs`, the model output at index `act`.

    The selection is done by masking the model output with a depth-3 one-hot
    encoding of `act` and summing over axis 1.
    """
    q_values = model(obs)
    action_mask = tf.one_hot(act, depth=3)
    return tf.reduce_sum(q_values * action_mask, axis=1)

def Q1_func(model,next_obs,rew,done):
    """Compute the one-step TD target `rew + gamma * max_a model(next_obs)[a]`.

    The bootstrap term is multiplied by `(1 - done)`, so it vanishes for
    transitions flagged as done. The discount factor is fixed at 0.99.
    """
    gamma = 0.99
    best_next_q = tf.reduce_max(model(next_obs), axis=1)
    return gamma*best_next_q*(1.0-done) + rew

#@tf.function
def train_then_absTD(model,target,obs,act,rew,next_obs,done,weights):
    """Run one weighted gradient step on `model` and return the new |TD error|.

    The loss is the mean of `weights * (Q - target)^2`, where the TD target
    is computed with the `target` network. The step is applied with the
    module-level optimizer `opt`. After the update, Q is re-evaluated with
    the updated `model` and the absolute difference to the (unchanged)
    TD target is returned.
    """
    params = model.trainable_weights

    with tf.GradientTape() as tape:
        tape.watch(params)
        q_pred = Q_func(model,obs,act)
        td_target = Q1_func(target,next_obs,rew,done)
        loss = tf.reduce_mean(tf.square(q_pred - td_target) * weights)

    grads = tape.gradient(loss, params)
    opt.apply_gradients(zip(grads, params))

    # Re-evaluate with the freshly updated weights.
    q_updated = Q_func(model,obs,act)
    return tf.abs(q_updated - td_target)

#@tf.function
def abs_TD(model,target,obs,act,rew,next_obs,done):
    """Return |Q(obs, act) - TD target| without updating any weights.

    Q comes from `model`; the TD target is computed with `target`.
    """
    td_error = Q_func(model,obs,act) - Q1_func(target,next_obs,rew,done)
    return tf.abs(td_error)

In [None]:
def pos(index):
    """Split a flat cell index into (index % WIDTH, index // WIDTH)."""
    quotient, remainder = divmod(index, WIDTH)
    return remainder, quotient

def centering(z,dz,Z):
    """Shift coordinate `z` by `dz` on a ring of length `Z`.

    Parameters
    ----------
    z : int
        Original coordinate.
    dz : int
        Offset to add (may be negative).
    Z : int
        Axis length; must be positive.

    Returns
    -------
    int
        The shifted coordinate wrapped into the range 0 .. Z-1.
    """
    # The previous upper-bound test compared Z with itself (always true), so
    # every non-negative result had Z subtracted and came back negative.
    # Modulo wraps both directions and always yields a value in [0, Z).
    return (z + dz) % Z
    

def encode_board(obs,act="NORTH",idx=0):
    """
    Encode an observation as a (WIDTH, HEIGHT) array seen from goose `idx`.

    Every cell is shifted, with wrap-around, by the offset that moves the
    head of goose `idx` onto cell (xc, yc). Goose cells are written as GOOSE,
    then food cells as FOOD, and finally cell (xc, yc) is overwritten with
    dir2code[act]. If goose `idx` has no cells, an all-zero board is returned.
    """
    board = np.zeros((WIDTH,HEIGHT))

    player = obs["geese"][idx]
    if len(player) == 0:
        return board

    head_x, head_y = pos(player[0])
    shift_x, shift_y = xc - head_x, yc - head_y

    def mark(cell, value):
        # Translate one flat cell index into the shifted frame and write it.
        col, row = pos(cell)
        col = centering(col,shift_x,WIDTH)
        row = centering(row,shift_y,HEIGHT)
        board[col,row] = value

    for body in obs["geese"]:
        for cell in body:
            mark(cell, GOOSE)

    for cell in obs["food"]:
        mark(cell, FOOD)

    # The (xc, yc) cell carries the heading code instead of a goose marker.
    board[xc,yc] = dir2code[act]

    return board

In [None]:
def get_obs_action(model,states,idx=0, train=False, eps=0.3):
    """Choose the next move for goose `idx` and return it with the encoded board.

    Parameters
    ----------
    model : callable
        Q-network; called on the flattened board, its argmax picks the turn.
    states : list of dict
        Per-agent environment states. `states[idx]` supplies "action" and
        "status"; `states[0]["observation"]` supplies the board.
    idx : int
        Index of the goose to act for.
    train : bool
        When true, an epsilon-greedy random turn is taken with probability `eps`.
    eps : float
        Exploration probability used when `train` is true.

    Returns
    -------
    (board, action)
        `board` is the encoded observation, or None when the goose is not
        "ACTIVE" (in which case its stored action is returned unchanged).
        `action` is the new absolute direction name.
    """
    act = states[idx]["action"]

    if states[idx]["status"] != "ACTIVE":
        return None, act

    board = encode_board(states[0]["observation"],act=act,idx=idx)

    # Both branches must yield a heading offset in {-1, 0, +1}
    # (network output index minus one).
    if train and np.random.random() < eps:
        # randint(3) alone gives {0, 1, 2}: it could never turn one way and
        # could reverse direction; shift it to match the greedy branch.
        turn = np.random.randint(3) - 1
    else:
        turn = int(tf.math.argmax(tf.squeeze(model(board.reshape(1,-1))))) - 1

    # Wrap around the four compass directions.
    new_act = (dir2code[act] + turn) % 4
    return board, code2dir[new_act]

In [None]:
def create_buffer(buffer_size,env_dict,alpha):
    """Create the shared prioritized replay buffer.

    Thin wrapper around `cpprb.MPPrioritizedReplayBuffer`; `buffer_size` and
    `env_dict` are passed positionally and `alpha` by keyword.
    """
    replay_buffer = cpprb.MPPrioritizedReplayBuffer(buffer_size,
                                                    env_dict,
                                                    alpha=alpha)
    return replay_buffer

In [None]:
def explorer(global_rb,env_dict,is_training_done,queue):
    """Explorer loop: play 4-goose games and feed transitions to `global_rb`.

    Parameters
    ----------
    global_rb
        Shared replay buffer; receives batches of transitions together with
        their initial priorities (absolute TD errors).
    env_dict : dict
        Buffer field specification, reused for the explorer-local buffer.
    is_training_done
        Event-like object; the loop ends once `is_set()` returns true.
    queue
        Queue from which `(model_weights, target_weights)` tuples are read
        to refresh the local networks.
    """
    local_buffer_size = int(1e+2)
    # A few extra slots because up to 4 transitions are added per step
    # before the size check below runs.
    local_rb = cpprb.ReplayBuffer(local_buffer_size+4,env_dict)

    model = create_model()
    target = tf.keras.models.clone_model(model)
    env = make("hungry_geese", debug=False)

    states = env.reset(4)
    while not is_training_done.is_set():
        if not queue.empty():
            w,wt = queue.get()
            model.set_weights(w)
            target.set_weights(wt)

        # Remember each goose's heading before the step so the chosen move
        # can be converted back into a relative turn below.
        prev_acts = [s["action"] for s in states]

        board_act = [get_obs_action(model,states,i,train=True) for i in range(4)]

        states = env.step([a for b,a in board_act])

        for i, (b, a) in enumerate(board_act):
            if b is None:
                # Goose was not active before this step: nothing to store.
                continue

            # Store the network output index (0=RIGHT, 1=GO, 2=LEFT), not the
            # absolute direction code: Q_func one-hots `act` with depth 3.
            # A reversal would map to 3, which one-hots to an all-zero row.
            turn_index = (dir2code[a] - dir2code[prev_acts[i]] + 1) % 4

            local_rb.add(obs=b.ravel(),
                         act=turn_index,
                         next_obs=encode_board(states[0]["observation"],act=a,idx=i).ravel(),
                         rew=states[i]["reward"],
                         # Terminal when the goose is no longer active, so
                         # Q1_func stops bootstrapping from next_obs.
                         done=(states[i]["status"] != "ACTIVE"))

        if all(s["status"] != "ACTIVE" for s in states):
            states = env.reset(4)
            local_rb.on_episode_end()

        if local_rb.get_stored_size() >= local_buffer_size:
            # Flush the local batch to the shared buffer, using the current
            # absolute TD error as the initial priority.
            sample = local_rb.get_all_transitions()
            global_rb.add(**sample,
                          priorities=abs_TD(model,target,
                                            tf.constant(sample["obs"]),
                                            tf.constant(sample["act"].ravel()),
                                            tf.constant(sample["rew"].ravel()),
                                            tf.constant(sample["next_obs"]),
                                            tf.constant(sample["done"].ravel())))
            local_rb.clear()

In [None]:
%%time

# Training
n_warming = 100           # transitions required in the buffer before training
n_train_step = int(1e+5)  # number of gradient steps
batch_size = 64

writer = tf.summary.create_file_writer("./logs")

# Replay Buffer
buffer_size = int(10e+5)  # capacity is a count, so keep it an integer
env_dict = {"obs": {"shape": (WIDTH*HEIGHT)},
            "act": {"dtype": int},
            "next_obs": {"shape": (WIDTH*HEIGHT)},
            "rew": {},
            "done": {}}
alpha = 0.5
rb = create_buffer(buffer_size, env_dict,alpha)

# Model
target_update = 50        # copy model -> target every this many steps


model = create_model()
target = tf.keras.models.clone_model(model)

opt = tf.keras.optimizers.Adam()

# Ape-X
explorer_update_freq = 100
# Leave one core for the learner, but always run at least one explorer:
# with zero explorers the warm-up loop below would never finish.
n_explorer = max(1, cpu_count() - 1)


is_training_done = Event()
is_training_done.clear()

qs = [SimpleQueue() for _ in range(n_explorer)]
ps = [Process(target=explorer,
              args=[rb,env_dict,is_training_done,q])
      for q in qs]

for p in ps:
    p.start()

try:
    print("warm-up")
    while rb.get_stored_size() < n_warming:
        # Fail loudly instead of sleeping forever if every explorer has died.
        if not any(p.is_alive() for p in ps):
            raise RuntimeError("all explorer processes exited during warm-up")
        time.sleep(1)


    print("training")

    epoch = 0
    for i in tqdm(range(n_train_step)):
        sample = rb.sample(batch_size,beta=0.4)

        absTD = train_then_absTD(model,target,
                                 tf.constant(sample["obs"]),
                                 tf.constant(sample["act"].ravel()),
                                 tf.constant(sample["rew"].ravel()),
                                 tf.constant(sample["next_obs"]),
                                 tf.constant(sample["done"].ravel()),
                                 tf.constant(sample["weights"].ravel()))
        # Refresh priorities of the sampled transitions with the new TD errors.
        rb.update_priorities(sample["indexes"],absTD)

        if i % target_update == 0:
            target.set_weights(model.get_weights())

        if i % explorer_update_freq == 0:
            # Broadcast the latest weights to every explorer.
            w = model.get_weights()
            wt = target.get_weights()
            for q in qs:
                q.put((w,wt))
finally:
    # Always tell the explorers to stop, even if training was interrupted
    # or raised; otherwise they would keep running in the background.
    is_training_done.set()

model.save("model")

for p in ps:
    p.join()

In [None]:
# Play one 4-goose episode with the trained model (no exploration), then
# render the finished game inline.
test_env = make("hungry_geese", debug=True)
states = test_env.reset(4)

# Keep stepping until no goose is ACTIVE any more.
while not all(s["status"] != "ACTIVE" for s in states):
    board_act = [get_obs_action(model,states,i) for i in range(4)]
    states = test_env.step([chosen for _, chosen in board_act])

test_env.render(mode='ipython')