In [1]:
import argparse
import errno
import os
import random
from importlib.metadata import requires
from timeit import timeit
import dill as pickle
import numpy as np
import scipy
import torch
import wandb
import yaml
from sympy import Matrix, MatrixSymbol, derive_by_array, symarray
from torch.distributions import Categorical

from subrl.utils.environment import GridWorld
from subrl.utils.network import append_state
from subrl.utils.network import policy as agent_net
from visualization import Visu

# TODO:
#  1. Remove the dependence on the matrix so the experiment can run several times in
#     parallel (.sh script, run it on the server); check how to plot multiple runs on W&B.
#  2. Check whether W&B can plot different kappa values on the same chart via grouping.
#  3. Apply a policy-gradient algorithm; since the policy is deterministic, use policy
#     iteration / value iteration, as the dynamics are known.
workspace = "subrl"

# Index of the pre-generated environment instance to load (".../env_<ENV_NUM>").
ENV_NUM = 1

# 1) Load the config file
with open(workspace + "/params/" + "two_rooms/subrl_NM" + ".yaml") as file:
    params = yaml.load(file, Loader=yaml.FullLoader)
print(params)

# 2) Set the path and copy params from file
env_load_path = workspace + \
    "/environments/" + params["env"]["node_weight"] + "/env_" + \
    str(ENV_NUM)

params['env']['num'] = ENV_NUM

epochs = params["alg"]["epochs"]

H = params["env"]["horizon"]
# Normaliser for the marginal returns used in train(); presumably an upper bound on the
# achievable return for the given grid size — TODO confirm.
MAX_Ret = 2*(H+1)
if params["env"]["disc_size"] == "large":
    MAX_Ret = 3*(H+2)

# 3) Set up the environment
env = GridWorld(
    env_params=params["env"], common_params=params["common"], visu_params=params["visu"], env_file_path=env_load_path)
node_size = params["env"]["shape"]['x']*params["env"]["shape"]['y']

# Some node-weight types keep extra per-environment data in a pickle next to the
# environment file; map each such type to the GridWorld attribute that receives it.
# NOTE: (dill) pickle.load can execute arbitrary code — only load trusted files.
PICKLED_WEIGHT_ATTR = {
    "entropy": "cov",
    "steiner_covering": "items_loc",
    "GP": "weight",
}
if params["env"]["node_weight"] in PICKLED_WEIGHT_ATTR:
    # `with` guarantees the file is closed even if unpickling fails.
    with open(env_load_path + ".pkl", "rb") as a_file:
        data = pickle.load(a_file)
    setattr(env, PICKLED_WEIGHT_ATTR[params["env"]["node_weight"]], data)

visu = Visu(env_params=params["env"])

env.get_horizon_transition_matrix()


  from .autonotebook import tqdm as notebook_tqdm


{'env': {'start': 1, 'step_size': 0.1, 'shape': {'x': 7, 'y': 14}, 'horizon': 40, 'node_weight': 'constant', 'disc_size': 'small', 'n_players': 3, 'Cx_lengthscale': 2, 'Cx_noise': 0.001, 'Fx_lengthscale': 1, 'Fx_noise': 0.001, 'Cx_beta': 1.5, 'Fx_beta': 1.5, 'generate': False, 'env_file_name': 'env_data.pkl', 'cov_module': 'Matern', 'stochasticity': 0.0, 'domains': 'two_room'}, 'alg': {'gamma': 1, 'type': 'NM', 'ent_coef': 0.0, 'epochs': 140, 'lr': 0.02}, 'common': {'a': 1, 'subgrad': 'greedy', 'grad': 'pytorch', 'algo': 'both', 'init': 'deterministic', 'batch_size': 3000}, 'visu': {'wb': 'disabled', 'a': 1}}
x_ticks [-0.5001, -0.4999, 0.4999, 0.5001, 1.4999, 1.5001, 2.4999, 2.5001, 3.4999, 3.5001, 4.4999, 4.5001, 5.4999, 5.5001, 6.4999, 6.5001, 7.4999, 7.5001, 8.4999, 8.5001, 9.4999, 9.5001, 10.4999, 10.5001, 11.4999, 11.5001, 12.4999, 12.5001, 13.4999, 13.5001]
y_ticks [-0.5001, -0.4999, 0.4999, 0.5001, 1.4999, 1.5001, 2.4999, 2.5001, 3.4999, 3.5001, 4.4999, 4.5001, 5.4999, 5.5001, 6

In [2]:
def train(verbose=True):
    """Train a fresh policy network with REINFORCE on the notebook's GridWorld.

    Relies on globals set in the setup cell: ``params``, ``env``, ``H``, ``epochs``,
    ``MAX_Ret`` and ``visu``.

    For algorithm types ``"M"``/``"SRL"`` the policy input is the current state
    concatenated with the time index (input size 2). Otherwise the input is the
    trajectory encoding from ``append_state(mat_state, H-1)`` (input size ``H-1``).

    Each epoch does the following:
      * rolls out ``H-1`` steps;
      * credits every action with the marginal change of
        ``env.weighted_traj_return`` (divided by ``MAX_Ret``);
      * takes one Adam step on the negative policy-gradient objective plus an
        entropy bonus weighted by ``params["alg"]["ent_coef"] / (epoch + 1)``.

    Args:
        verbose: if True (default), print return statistics after every epoch.

    Returns:
        The trained policy network.
    """
    alg_type = params["alg"]["type"]
    # Markovian variants condition on (state, time); the others on the trajectory so far.
    markovian = alg_type in ("M", "SRL")

    # Agent's policy
    agent = agent_net(2 if markovian else H-1, env.action_dim)
    optim = torch.optim.Adam(agent.parameters(), lr=params["alg"]["lr"])

    for t_eps in range(epochs):
        mat_action = []
        mat_state = []
        mat_return = []
        marginal_return = []
        list_batch_state = []

        env.initialize()
        mat_state.append(env.state)
        for h_iter in range(H-1):
            if markovian:
                batch_state = mat_state[-1].reshape(-1, 1).float()
                # append time index to the state
                batch_state = torch.cat(
                    [batch_state, h_iter*torch.ones_like(batch_state)], 1)
            else:
                batch_state = append_state(mat_state, H-1)
            # The rollout only needs samples; log-probs for the gradient are recomputed
            # below on the stacked states, so don't build an autograd graph here.
            with torch.no_grad():
                action_prob = agent(batch_state)
            actions = Categorical(action_prob).sample()
            mat_action.append(actions)
            env.step(h_iter, actions)
            mat_state.append(env.state)  # s+1
            mat_return.append(env.weighted_traj_return(mat_state, type=alg_type))
            # Credit each action with the gain it produced, not the cumulative return.
            if h_iter == 0:
                marginal_return.append(mat_return[h_iter])
            else:
                marginal_return.append(mat_return[h_iter] - mat_return[h_iter-1])
            list_batch_state.append(batch_state)

        ###################
        # Compute gradients
        ###################

        states_visited = torch.vstack(list_batch_state).float()
        policy_dist = Categorical(agent(states_visited))
        log_prob = policy_dist.log_prob(torch.hstack(mat_action))
        batch_return = torch.hstack(marginal_return)/MAX_Ret
        entropy = policy_dist.entropy().mean()

        # Minimise the negative objective; the entropy bonus decays as 1/(epoch+1).
        J_obj = -1*(torch.mean(log_prob*batch_return)
                    + params["alg"]["ent_coef"]*entropy/(t_eps+1))
        optim.zero_grad()
        J_obj.backward()
        optim.step()

        if verbose:
            obj = env.weighted_traj_return(mat_state).float()
            print(visu.JPi_optimal, " mean ", obj.mean(), " max ",
                  obj.max(), " median ", obj.median(), " min ", obj.min(),
                  " ent ", entropy.detach())
    return agent

In [3]:
# Evaluate the training recipe over several independent runs: train a fresh policy,
# roll it out once, and record statistics of the final returns over the batch.
N_RUNS = 10

min_return = []
max_return = []
mean_return = []
median_return = []
for run_idx in range(N_RUNS):
    agent = train()
    mat_state = []
    env.initialize()
    mat_state.append(env.state)
    init_state = env.state
    for h_iter in range(H-1):
        # Must match the state encoding used in train(): (state, time) for M/SRL,
        # trajectory history otherwise — the network's input size depends on it.
        if params["alg"]["type"] in ("M", "SRL"):
            batch_state = mat_state[-1].reshape(-1, 1).float()
            batch_state = torch.cat(
                [batch_state, h_iter*torch.ones_like(batch_state)], 1)
        else:
            batch_state = append_state(mat_state, H-1)
        with torch.no_grad():  # evaluation only, no gradients needed
            action_prob = agent(batch_state)
        policy_dist = Categorical(action_prob)
        actions = policy_dist.sample()
        env.step(h_iter, actions)
        mat_state.append(env.state)  # s+1

    returns = env.weighted_traj_return(mat_state, type=params["alg"]["type"]).float()
    # Store plain floats so numpy aggregates numbers rather than torch tensors.
    min_return.append(returns.min().item())
    max_return.append(returns.max().item())
    mean_return.append(returns.mean().item())
    median_return.append(returns.median().item())

mean_min_return = np.mean(min_return)
std_min_return = np.std(min_return)
mean_max_return = np.mean(max_return)
std_max_return = np.std(max_return)
mean_mean_return = np.mean(mean_return)
std_mean_return = np.std(mean_return)
mean_median_return = np.mean(median_return)
std_median_return = np.std(median_return)
print(f"min: {mean_min_return:.2f}±{std_min_return:.2f}, max: {mean_max_return:.2f}±{std_max_return:.2f}, mean: {mean_mean_return:.2f}±{std_mean_return:.2f}, median: {mean_median_return:.2f}±{std_median_return:.2f}")

None  mean  tensor(18.8360)  max  tensor(33.)  median  tensor(20.)  min  tensor(6.)  ent  tensor(1.5493)
None  mean  tensor(22.3720)  max  tensor(34.)  median  tensor(22.)  min  tensor(14.)  ent  tensor(1.2196)
None  mean  tensor(24.1693)  max  tensor(35.)  median  tensor(24.)  min  tensor(11.)  ent  tensor(1.2411)
None  mean  tensor(25.8920)  max  tensor(36.)  median  tensor(26.)  min  tensor(8.)  ent  tensor(1.2206)
None  mean  tensor(27.7177)  max  tensor(50.)  median  tensor(27.)  min  tensor(14.)  ent  tensor(1.1722)
None  mean  tensor(31.2220)  max  tensor(49.)  median  tensor(30.)  min  tensor(14.)  ent  tensor(1.0043)
None  mean  tensor(32.5220)  max  tensor(48.)  median  tensor(34.)  min  tensor(14.)  ent  tensor(0.8255)
None  mean  tensor(33.1043)  max  tensor(48.)  median  tensor(35.)  min  tensor(14.)  ent  tensor(0.7269)
None  mean  tensor(34.6487)  max  tensor(49.)  median  tensor(36.)  min  tensor(14.)  ent  tensor(0.7169)
None  mean  tensor(36.8723)  max  tensor(50.)  m