In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from gridworld import GridworldMdp
from agents import OptimalAgent, MyopicAgent
from fast_agents import FastMyopicAgent, FastOptimalAgent
from mdp_interface import Mdp
from agent_runner import get_reward_from_trajectory, run_agent
import numpy as np
from maia_chess_backend.maia.tfprocess import get_tfp
import tensorflow as tf
from multiprocessing import Pool
import tqdm

In [3]:
# Display arrays with 5 digits of floating-point precision and wrap lines at 200 characters.
np.set_printoptions(precision=5, linewidth=200)

In [4]:
def gen_gridworld_arr(gridworld):
    """Encode a gridworld as a 3-channel int8 array.

    Channels:
        0 -- ``gridworld.walls`` converted to an array.
        1 -- reward values, written at each key of ``gridworld.rewards``.
        2 -- a single 1 at the position returned by ``gridworld.get_start_state()``.

    The spatial shape is taken from ``gridworld.walls`` itself, so the function
    no longer depends on a module-level ``width`` being defined beforehand and
    also works for non-square grids.

    Args:
        gridworld: object exposing ``walls`` (2-D array-like), ``rewards``
            (mapping from ``(x, y)`` to a reward value) and ``get_start_state()``
            (returning an ``(x, y)`` pair).

    Returns:
        np.ndarray of dtype int8 with shape ``(3,) + walls.shape``.
    """
    walls = np.asarray(gridworld.walls)
    arr = np.zeros((3,) + walls.shape, dtype=np.int8)
    arr[0] = walls

    # NOTE(review): (x, y) is used directly as (axis-1, axis-2) index here, as in
    # the original code; confirm this matches the row/column layout of `walls`.
    # NOTE(review): rewards are stored as int8, so values outside [-128, 127] or
    # fractional rewards would not be preserved -- confirm the reward range.
    for (x, y) in gridworld.rewards:
        arr[1, x, y] = gridworld.rewards[(x, y)]

    (x, y) = gridworld.get_start_state()
    arr[2, x, y] = 1

    return arr

In [5]:
def gen_random_connected(height, width, num_rewards, max_attempts=5):
    """Generate a random connected, noise-free gridworld, retrying on failure.

    Calls ``GridworldMdp.generate_random_connected`` with ``noise=0`` up to
    ``max_attempts`` times and returns the first result that does not raise.

    Args:
        height: grid height forwarded to the generator.
        width: grid width forwarded to the generator.
        num_rewards: number of rewards forwarded to the generator.
        max_attempts: how many times to try before giving up (default 5).

    Returns:
        Whatever ``GridworldMdp.generate_random_connected`` returns.

    Raises:
        ValueError: if every attempt raised; the last underlying error is
            attached as ``__cause__`` so the real failure is not lost.
    """
    last_error = None
    for _ in range(max_attempts):
        try:
            return GridworldMdp.generate_random_connected(
                height=height, width=width, num_rewards=num_rewards, noise=0)
        except Exception as err:
            # Catch only Exception: a bare `except:` would also swallow
            # KeyboardInterrupt / SystemExit and make the loop uninterruptible.
            last_error = err
    raise ValueError('Could not generate Gridworld') from last_error

In [6]:
# Grid side length: gen_data generates width x width gridworlds and sizes its output array with it.
width=6
# Number of rewards gen_data requests from gen_random_connected.
num_rewards=4

# NOTE(review): gamma is not referenced in any visible cell -- presumably a discount factor; confirm it is still needed.
gamma = 0.9

# NOTE(review): num_start_states is not referenced in any visible cell.
num_start_states = 1
# Episode length passed to run_agent in gen_data; also used there as the horizon of the comparison agent.
episode_length = 5

def gen_data(num_grids):
    """Build a dataset of random gridworlds labelled with an intervention gain.

    For each of ``num_grids`` random gridworlds a horizon-2 ``FastMyopicAgent``
    and a ``FastMyopicAgent`` whose horizon equals the module-level
    ``episode_length`` each pick an action at a random start state. When the
    two actions differ, the horizon-2 agent is rolled out twice -- once as is
    and once with ``first_optimal`` set to the longer-horizon agent -- and the
    difference in trajectory reward (intervened minus plain) is recorded.
    When the actions agree the recorded value is 0.

    Reads the module-level ``width``, ``num_rewards`` and ``episode_length``.

    Args:
        num_grids: number of gridworlds to generate.

    Returns:
        Float array of shape ``(num_grids, 4, width, width)``: planes 0-2 are
        the output of ``gen_gridworld_arr``; plane 3 is filled with the scalar
        reward difference.
    """
    short_agent = FastMyopicAgent(horizon=2)
    long_agent = FastMyopicAgent(horizon=episode_length)
    samples = np.zeros((num_grids, 4, width, width))

    for idx in range(num_grids):
        grid = gen_random_connected(width, width, num_rewards)
        env = Mdp(grid)

        # NOTE(review): the start state is assigned on env.gridworld while the
        # agents and gen_gridworld_arr below receive `grid`; whether these are
        # the same object depends on Mdp's constructor -- confirm.
        origin = grid.get_random_start_state()
        env.gridworld.start_state = origin

        short_agent.set_mdp(grid)
        long_agent.set_mdp(grid)

        short_choice = short_agent.get_action(origin)
        long_choice = long_agent.get_action(origin)

        gain = 0.0 - 0.0
        if short_choice != long_choice:
            plain_rollout = run_agent(short_agent, env, episode_length=episode_length)
            plain_reward = get_reward_from_trajectory(plain_rollout)
            nudged_rollout = run_agent(
                short_agent, env,
                episode_length=episode_length,
                first_optimal=long_agent,
            )
            nudged_reward = get_reward_from_trajectory(nudged_rollout)
            gain = nudged_reward - plain_reward

        samples[idx, :3] = gen_gridworld_arr(grid)
        # Scalar is broadcast over the whole width x width plane.
        samples[idx, 3] = gain
    return samples

In [15]:
def get_train_data(width, num_rewards, episode_length, horizon, cost, test=False,
                   data_dir='/scratch1/fs1/chien-ju.ho/RIS/518/scripts',
                   train_n=80000, eval_n=20000):
    """Load a class-balanced train or eval split from a saved ``.npz`` dataset.

    The file ``{data_dir}/{width}_{num_rewards}_{episode_length}_{horizon}.npz``
    must contain arrays ``x`` and ``y``. Targets are binarised as
    ``y > cost`` (1 when strictly greater, else 0). For each class the first
    ``train_n + eval_n`` samples are kept; the last ``eval_n`` of those form the
    eval split and the leading samples form the train split. Positives are
    concatenated before negatives in the returned arrays.

    The train split never reaches into the eval split: when a class has fewer
    than ``train_n + eval_n`` samples, the train portion is shortened instead of
    overlapping the eval portion (which would leak eval data into training).
    With at least ``train_n + eval_n`` samples per class the result is the
    same as a plain first-``train_n`` / last-``eval_n`` split.

    Args:
        width, num_rewards, episode_length, horizon: used to build the file name.
        cost: threshold applied to ``y`` to produce binary labels.
        test: if True return the eval split, otherwise the train split.
        data_dir: directory containing the ``.npz`` files.
        train_n: maximum number of training samples per class.
        eval_n: maximum number of eval samples per class.

    Returns:
        ``(x_split, y_split)`` as numpy arrays.
    """
    data_path = f'{data_dir}/{width}_{num_rewards}_{episode_length}_{horizon}.npz'
    # Use the NpzFile as a context manager so the underlying file is closed.
    with np.load(data_path) as all_data:
        x, y = all_data['x'], all_data['y']
    y = (y > cost).astype(int)

    n = train_n + eval_n
    x_parts, y_parts = [], []
    for label in (1, 0):  # positives first, then negatives
        mask = (y == label)
        x_cls, y_cls = x[mask][:n], y[mask][:n]

        # Index where the eval tail begins; train must stop at or before it.
        eval_start = max(len(x_cls) - eval_n, 0)
        if test:
            chosen = slice(eval_start, None)
        else:
            chosen = slice(0, min(train_n, eval_start))

        x_parts.append(x_cls[chosen])
        y_parts.append(y_cls[chosen])

    return np.concatenate(x_parts), np.concatenate(y_parts)

In [17]:
%%time
for width in [6]:
    for num_rewards in [4]:
        for episode_length in [5,6,7,8]:
            for cost in [0,5,10]:
                for train_horizon in [1, 2, 3]:
                    xtrain, ytrain = get_train_data(width, num_rewards, episode_length, train_horizon, cost, test=False)
                    tfp = get_tfp(filters=64, blocks=6, regularizer=False, input_size=3, board_size=width, output_size=1)
                    optimizer = tfp.optimizer
                    loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)
                    metrics = ['accuracy',tf.keras.metrics.AUC()]
                    tfp.model.compile(optimizer, loss, metrics)

                    tfp.model.fit(xtrain,ytrain, verbose=0)
                    for test_horizon in [1, 2, 3]:
                        xeval, yeval = get_train_data(width, num_rewards, episode_length, test_horizon, cost, test=True)
                        res = tfp.model.evaluate(xeval,yeval, verbose=0)
                        print(f'{episode_length}\t{cost}\t{train_horizon}\t{test_horizon}\t{res[1]}')
                print()

KeyboardInterrupt: 