In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable

from tqdm import tqdm
from typing import Any, Sequence, Union, List, Dict

import numpy as np
from torch import nn
import torch as tr

import copy

from grid_env import GridEnv, GridSize, GridObservation
from envs import FourRooms
from grid_templates import GridTemplate
from agents import GeneralQ, LambdaQ
from runners import run_nn_experiment_episodic
from utils import (
    save_results,
    load_results,
    gif_from_frames,
    set_seed_everywhere,
    create_neuronav_gif,
    plot_neuronav_frame,
    errorfill,
    plot_values,
    plot_action_values,)

# display options
np.set_printoptions(precision=4, suppress=1)

colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
# Define the color segments for the colormap
segments = [(i/(len(colors)-1), colors[i]) for i in range(len(colors))]
# Create a LinearSegmentedColormap from the color segments
cmap = LinearSegmentedColormap.from_list(name='my_colormap', colors=segments)

%load_ext autoreload
%autoreload 2

In [None]:
# --- Tabular lambda-Q agent on the two-rooms grid ------------------------------
# Fix the RNG state so the epsilon-greedy run (and the GIF it produces) is
# reproducible under Restart & Run All.
# NOTE(review): assumes `set_seed_everywhere(seed)` seeds every RNG the env and
# agent draw from — confirm against utils.
SEED = 0
set_seed_everywhere(SEED)

agent_lambda_ = 0.5

# Goals are given as state indices; they are converted to coordinates below.
goals = [12, 20, 67]
goal_rewards = [10, 5, 5]  # one reward per entry of `goals`, same order
env_lambda_ = 0.5
discount = 0.99
max_ep_len = 100
num_episodes = 500
gif_max_value = 200  # passed to the GIF renderer as `max_value`

env = GridEnv(
    template=GridTemplate.two_rooms,
    size=GridSize.small,
    use_noop=True,
    lambda_=env_lambda_,
    obs_type=GridObservation.index,
    im_size=128)

goal_xys = [env.idx_to_state_coords(g) for g in goals]
objects = {"rewards": dict(zip(goal_xys, goal_rewards))}

agent = LambdaQ(
    env.state_size,
    env.action_space.n,
    env.reset(objects=objects),  # presumably the initial observation — confirm LambdaQ signature
    method='q',
    double=False,
    step_size=0.1,
    use_ez_greedy=False,
    epsilon=0.2,
    optimistic_init=True,
    decay_explore=None,
    lambda_=agent_lambda_,)

results = run_nn_experiment_episodic(
    env, agent, num_episodes, objects=objects, discount=discount,
    terminate_on_reward=False, random_start=False, use_underlying_pos=False,
    display_eps=10, respect_done=True, max_ep_len=max_ep_len, record=True,
    goals_always_available=True, eval_every=20)

# Render only the first `max_ep_len` entries of each recorded history.
# NOTE(review): this assumes all four histories are indexed the same way as
# `frames` (e.g. per step) — confirm in runners.run_nn_experiment_episodic.
lim = max_ep_len
create_neuronav_gif(results['frames'][:lim], results['rewards_remaining_hist'][:lim],
                    f"{agent_lambda_}.gif", return_sequence=results['ep_return_hist'][:lim],
                    show_map=False, contains_map=False, num_goals=len(goals),
                    max_reward=np.max(goal_rewards), max_value=gif_max_value, duration=300,
                    show_values=True, value_sequence=results['value_hist'][:lim])




In [None]:
# actor critic
from utils import DmEnvWrapper, smooth
from runners import jax_run
from lambda_rac import default_lambda_agent

In [None]:
# --- Actor-critic (JAX) agent on the same two-rooms task ------------------------
# Seed for reproducibility, consistent with the tabular experiment.
# NOTE(review): `set_seed_everywhere` may not cover JAX PRNG keys; check whether
# `default_lambda_agent` accepts its own `seed` argument.
SEED = 0
set_seed_everywhere(SEED)

agent_lambda_ = 0.5
print(f"\nTraining with lambda = {agent_lambda_} ==========================")
goals = [12, 20, 67]        # goal state indices, converted to coordinates below
goal_rewards = [10, 5, 5]   # one reward per entry of `goals`, same order
env_lambda_ = 0.5
discount = 0.99
max_ep_len = 100
num_episodes = 7_500

# Raw grid env gets its own name so `env` is bound to one kind of object only
# (the dm_env-style wrapper the JAX runner consumes).
grid_env = GridEnv(
    template=GridTemplate.two_rooms,
    size=GridSize.small,
    use_noop=True,
    lambda_=env_lambda_,
    obs_type=GridObservation.onehot,  # one-hot observations for the network agent
    im_size=128)

goal_xys = [grid_env.idx_to_state_coords(g) for g in goals]
objects = {"rewards": dict(zip(goal_xys, goal_rewards))}

env = DmEnvWrapper(grid_env, objects=objects, terminate_on_reward=False, random_start=False, discount=discount)

agent = default_lambda_agent(
    env.observation_spec(),
    env.action_spec(),
    lf_lambda=agent_lambda_,
    lf_wt=0.1,
    hidden_size=128,
    # NOTE(review): hard-coded; presumably meant to match grid_env.state_size
    # for GridSize.small — confirm before changing the grid size.
    feature_dim=121,
    entropy_cost=0.01)

results = jax_run(agent, env, num_episodes, log_every=5, max_episode_len=max_ep_len)

