In [None]:
from pathlib import Path
import pickle
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.cm as cm
from ucb.value_functions import FiniteHorizonQRegressor
from ucb.models import RBFGP, GPTrainer
import ucb.envs
import gym
from sklearn.metrics import explained_variance_score
from functools import partial
%matplotlib inline

In [None]:
# Root directory that every experiment path below is resolved against.
base_path = Path('..') / 'experiments'

In [None]:
# Which environment to analyse; must be a key of `all_exps` and `horizons`.
env_name = 'Uniform Beta + Rotation'

- Try a uniform p_0 (implemented; running on Beta Tracking).
- Run FOVI and TQRL for longer to see what happens at the end (in progress).
- Think about whether we can theoretically justify weighting the current acquisition function by the occupancy distribution.
- Bias the prior in an environment like Cartpole so that the prior mean is low.
- Try a lower beta (in progress).
- Run with the mean policy (in progress).

In [None]:
# Experiment registries: one dict per environment, mapping a method label (used as the plot
# legend entry) to a run directory relative to `base_path`. A run directory either holds an
# `info.pkl` directly (single run) or `seed_0`, `seed_1`, ... subdirectories (see process_expt).
# Labels appear to use the newer method names (AE-LSVI ~ TQRL_LCB, LSVI ~ FOVI) while run
# directories keep the older names. Commented-out entries are excluded from the analysis.
mc_exps = {'GP All data': 'rbf_mountaincar_2022-09-08/17-06-15',
           'CB All data': 'cb_mountaincar_2022-09-08/14-18-29',
           "GP TQRL All data": 'default_2022-09-09/14-06-25',
           'CB TQRL All data': 'cb_mc_tqrl_2022-09-09/14-17-11/',
           'CB TQRL New': 'mc_cb_tqrl_2022-09-12/11-03-51/',
           'GP TQRL New': 'mc_rbf_tqrl_2022-09-12/11-03-09/',
           'CB TQRL LCB': 'mc_cb_tqrl_lcb_eval_2022-09-12/11-21-52/',
           'GP TQRL LCB': 'mc_rbf_tqrl_q_lcb_eval_2022-09-12/15-11-46/',
          }
cp_exps = {"CB All Data": "cb_cartpole_2022-09-09/10-30-58/",
           "RBF All Data": "rbf_cartpole_2022-09-08/17-06-23/",
           "TQRL LCB Visited": "TQRL_LCB_visited_cartpole_2022-09-19/14-06-29/",
            }
dense_cp_exps = {
    # 'GP All data': 'dense_cp_rbf_2022-09-12/15-11-17/',
    # 'GP All data 2': 'dense_cp_rbf_2022-09-13/08-43-50/',
    # 'GP All data TQRL LCB': 'dense_cp_rbf_lcb_tqrl_2022-09-14/09-35-28/',
    # 'GP All data TQRL': 'dense_cp_rbf_tqrl_2022-09-14/23-02-12/',
    'TQRL (seeds)': 'TQRL_dense_cartpole_2022-09-19/02-15-39/',
    'US LCB': 'US_LCB_dense_cartpole_2022-09-19/05-37-27/',
    'TQRL LCB': 'TQRL_LCB_dense_cartpole_2022-09-18/22-45-27/',
    'FOVI': 'FOVI_dense_cartpole_2022-09-18/20-23-07/',
    'Random': 'RANDOM_LCB_dense_cartpole_2022-09-19/07-52-35/',
    'TQRL LCB Visited': 'TQRL_LCB_visited_dense_cartpole_2022-09-19/14-06-20/',
    'Greedy': 'greedy_dense_cartpole_2022-09-25/08-54-32/',
}

uniform_dense_cp_exps = {
    'AE-LSVI': 'AE_LSVI_dense_uniform_cartpole_2022-09-23/14-41-37/',
    'LSVI': 'LSVI_dense_uniform_cartpole_2022-09-26/15-44-56/',
    'AE-LSVI (visited)': 'AE_LSVI_visited_dense_uniform_cartpole_2022-09-23/14-41-26/',
    'Random': 'RANDOM_LCB_dense_uniform_cartpole_2022-09-26/16-01-15/',
    'US': 'US_LCB_dense_uniform_cartpole_2022-09-25/10-37-02/',
    'Greedy': 'greedy_dense_uniform_cartpole_2022-09-25/10-39-49/',
}


beta_tracking_exps = {
    # "GP All Data": 'beta_tracking_rbf_2022-09-13/16-22-03/',
    # 'GP All data TQRL LCB': 'beta_tracking_rbf_lcb_tqrl_2022-09-14/09-42-37/',
    # "GP All Data TQRL": 'beta_tracking_rbf_tqrl_2022-09-14/22-40-34/',
    # 'TQRL LCB visited': 'beta_tracking_rbf_tqrl_lcb_visited_2022-09-18/13-39-42/',
    'AE-LSVI (mean)': 'TQRL_beta_tracking_2022-09-18/22-14-46/',
    'US': 'US_LCB_beta_tracking_2022-09-18/23-21-08/',
    'AE-LSVI': 'TQRL_LCB_beta_tracking_2022-09-18/21-04-07/',
    'LSVI': 'FOVI_beta_tracking_2022-09-18/20-22-36/',
    'Random': 'RANDOM_LCB_beta_tracking_2022-09-19/00-05-14/',
    'TQRL LCB Visited': 'TQRL_LCB_visited_beta_tracking_2022-09-19/14-05-44/',
    'Greedy': 'greedy_beta_tracking_2022-09-20/20-35-26/',
}

uniform_beta_tracking_exps = {
    'AE-LSVI': 'AE_LSVI_uniform_beta_tracking_2022-09-23/14-32-37/',
    'LSVI': "LSVI_uniform_beta_tracking_eval_2022-09-27/14-46-59/",
    'Random': 'RANDOM_LCB_uniform_beta_tracking_2022-09-24/20-20-09/',
    'US': 'US_LCB_uniform_beta_tracking_2022-09-23/14-37-29/',
    'Greedy': 'greedy_uniform_beta_tracking_2022-09-21/07-19-51/',
}

beta_rotation_exps = {
    # 'GP All Data': 'beta_rotation_rbf_2022-09-13/16-34-16/',
    # 'GP All data TQRL LCB': 'beta_rotation_rbf_lcb_tqrl_2022-09-14/09-47-07/',
    # 'GP All Data TQRL': 'beta_rotation_rbf_tqrl_2022-09-14/22-40-14/',
    # 'TQRL LCB visited': 'beta_rotation_rbf_lcb_tqrl_visited_2022-09-17/12-20-38/',
    # 'TQRL (seeds)': 'TQRL_beta_rotation_2022-09-19/00-01-29/',
    # 'US LCB': 'US_LCB_beta_rotation_2022-09-19/02-28-00/',
    'AE-LSVI': 'TQRL_LCB_beta_rotation_2022-09-18/21-23-37/',
    'LSVI': 'FOVI_beta_rotation_2022-09-18/19-42-28/',
    'Random': 'RANDOM_LCB_beta_rotation_2022-09-19/02-28-46/',
    'TQRL LCB Visited': 'TQRL_LCB_visited_beta_rotation_2022-09-19/22-58-26/',
    'US': 'US_LCB_beta_rotation_2022-09-25/06-51-24/',
    'Greedy': 'greedy_beta_rotation_2022-09-25/16-36-30/',
}

uniform_beta_rotation_expts = {
    'AE-LSVI': 'TQRL_LCB_uniform_beta_rotation_2022-09-24/17-56-33/',
    'LSVI': 'FOVI_uniform_beta_rotation_2022-09-23/06-58-33/',
    'Random': 'RANDOM_LCB_uniform_beta_rotation_2022-09-25/23-43-55/',
    'US': 'US_LCB_uniform_beta_rotation_2022-09-25/23-43-45/',
    'Greedy': 'greedy_uniform_beta_rotation_2022-09-23/16-09-10/',
}

uniform_weird_gain_expts = {
    'LSVI': 'LSVI_uniform_weird_gain_2022-09-21/19-19-40/',
    'AE-LSVI Visited': 'AE-LSVI_LCB_visited_uniform_weird_gain_2022-09-21/19-19-30/',
}

navigation_expts = {
    'LSVI': 'FOVI_navigation_2022-09-23/09-44-50/',
    'AE-LSVI': 'TQRL_LCB_navigation_2022-09-23/09-48-49/',

}
    
nav_easy_expts = {
    # NOTE(review): unlike every other entry this points at `seed_0/` directly, so process_expt
    # treats it as a single run and ignores any other seeds — confirm this is intended.
    'LSVI': 'FOVI_nav_easy_2022-09-23/20-03-00/seed_0/',
    'AE-LSVI visited': 'TQRL_LCB_visited_nav_easy_2022-09-23/20-02-42',
    'AE-LSVI': 'TQRL_LCB_nav_easy_2022-09-24/21-53-19/',
    'US': 'US_LCB_nav_easy_2022-09-24/13-16-01/',
    'Greedy': 'GREEDY_LCB_nav_easy_2022-09-25/06-51-09/',
    'Random': 'RANDOM_LCB_nav_easy_2022-09-26/00-18-21/',
}    

uniform_nav_easy_expts = {
    'AE-LSVI': 'TQRL_LCB_uniform_nav_easy_2022-09-27/05-55-00/',
    'LSVI': 'FOVI_uniform_nav_easy_2022-09-25/17-57-38/',
    'Greedy': 'GREEDY_LCB_uniform_nav_easy_2022-09-27/15-31-09/'
}
    
# Gym environment IDs for the base environments.
# NOTE(review): not referenced anywhere in the visible notebook, and has no entries for the
# 'Uniform ...', 'Weird Gain' or 'Easy Navigation' variants.
gym_env_names = {"Mountain Car": 'densemountaincar-dt10-v0', 
                 "Cartpole": "cartpoleswingup-v0",
                 "Dense Cartpole": 'cartpoleswingup-dense-v0',
                 "Beta Tracking": 'betatracking-v0',
                 "Beta + Rotation": 'betarotation-v0',
                 'Navigation': "navigation-v0"
                }
# Episode length per environment. The plotting cell multiplies exploration-episode indices by
# this value to place exploration returns on a number-of-datapoints axis, so its keys must
# cover every key of `all_exps`.
horizons = {"Mountain Car": 25, "Cartpole": 25, "Dense Cartpole": 25, "Uniform Dense Cartpole": 25, 
            "Beta Tracking": 15, "Uniform Beta Tracking": 15, "Beta + Rotation": 20, "Uniform Beta + Rotation": 20,
            "Weird Gain": 30,
            "Navigation": 30, "Easy Navigation": 30, "Uniform Easy Navigation": 30}

In [None]:
# Environment label -> experiment registry; `env_name` selects the registry analysed below.
all_exps = {"Mountain Car": mc_exps,
            "Cartpole": cp_exps,
            "Dense Cartpole": dense_cp_exps,
            "Uniform Dense Cartpole": uniform_dense_cp_exps,
            "Beta Tracking": beta_tracking_exps,
            "Uniform Beta Tracking": uniform_beta_tracking_exps,
            "Beta + Rotation": beta_rotation_exps,
            "Uniform Beta + Rotation": uniform_beta_rotation_expts,
            "Weird Gain": uniform_weird_gain_expts,
            "Navigation": navigation_expts,
            "Easy Navigation": nav_easy_expts,
            "Uniform Easy Navigation": uniform_nav_easy_expts,
           }
# Fail with the list of valid names rather than a bare KeyError on a typo in `env_name`.
if env_name not in all_exps:
    raise KeyError(f"Unknown env_name {env_name!r}; expected one of {sorted(all_exps)}")
exps = all_exps[env_name]

In [None]:
def process_single_expt(path):
    """Load one run's pickle and summarise its evaluation returns.

    Args:
        path: ``pathlib.Path`` to a pickle holding at least the keys ``'Eval ndata'``,
            ``'Eval Returns'`` and ``'Exploration Returns'``.

    Returns:
        dict with ``'Eval ndata'`` and ``'Exploration Returns'`` passed through unchanged,
        plus ``'Mean Returns'`` / ``'Stderr Returns'``: mean and standard error of
        ``'Eval Returns'`` taken over axis 1.
    """
    # NOTE: pickle.load can execute arbitrary code; only use this on run outputs you produced.
    with path.open('rb') as f:
        data = pickle.load(f)
    # assumes 'Eval Returns' is (n_evals, n_episodes_per_eval) — TODO confirm
    eval_returns = np.asarray(data['Eval Returns'])
    n_episodes = eval_returns.shape[1]
    return {
        "Eval ndata": data['Eval ndata'],
        "Mean Returns": np.mean(eval_returns, axis=1),
        "Stderr Returns": np.std(eval_returns, axis=1) / np.sqrt(n_episodes),
        "Exploration Returns": data['Exploration Returns'],
    }


def _count_seed_dirs(path, max_seeds=None):
    """Count consecutive ``seed_0``, ``seed_1``, ... entries under ``path`` (at most ``max_seeds``)."""
    nseeds = 0
    while (max_seeds is None or nseeds < max_seeds) and (path / f'seed_{nseeds}').exists():
        nseeds += 1
    return nseeds


def _mean_and_stderr(curves):
    """Truncate curves to the shortest one; return (length, mean, stderr) across curves."""
    length = min(len(curve) for curve in curves)
    stacked = np.array([np.asarray(curve)[:length] for curve in curves])
    return length, np.mean(stacked, axis=0), np.std(stacked, axis=0) / np.sqrt(stacked.shape[0])


def process_seeds(path, max_seeds=5):
    """Aggregate a multi-seed experiment stored as ``path/seed_<i>/info.pkl``.

    Every consecutive ``seed_<i>`` directory is used, up to ``max_seeds`` of them (``None``
    means no limit). Curves are truncated to the shortest seed before averaging, and
    ``'Eval ndata'`` is truncated to the same length so it lines up with ``'Mean Returns'``.

    Args:
        path: ``pathlib.Path`` of the experiment directory.
        max_seeds: upper bound on the number of seeds to load.

    Returns:
        dict with ``'Eval ndata'``, ``'Mean Returns'``, ``'Stderr Returns'`` (stderr across
        seeds), ``'Exploration Returns'`` and ``'Expl Stderr Returns'``.

    Raises:
        FileNotFoundError: if ``path/seed_0`` does not exist.
    """
    nseeds = _count_seed_dirs(path, max_seeds)
    if nseeds == 0:
        raise FileNotFoundError(f"No seed_0 directory (or info.pkl) found under {path}")
    seed_data = [process_single_expt(path / f'seed_{i}' / 'info.pkl') for i in range(nseeds)]
    n_eval, seed_mean, stderr = _mean_and_stderr([dat['Mean Returns'] for dat in seed_data])
    _, expl_seed_mean, expl_stderr = _mean_and_stderr([dat['Exploration Returns'] for dat in seed_data])
    # Assumes all seeds were evaluated at the same ndata values — TODO confirm.
    eval_ndata = seed_data[0]["Eval ndata"][:n_eval]
    return {"Eval ndata": eval_ndata, "Mean Returns": seed_mean, "Stderr Returns": stderr,
            "Exploration Returns": expl_seed_mean, "Expl Stderr Returns": expl_stderr}


def process_expt(path):
    """Summarise an experiment directory in either single-run or multi-seed layout.

    If ``path/info.pkl`` exists the directory is a single run (``process_single_expt``);
    otherwise it must contain ``seed_0``, ``seed_1``, ... subdirectories (``process_seeds``).
    """
    single_expt_path = path / 'info.pkl'
    if single_expt_path.exists():
        return process_single_expt(single_expt_path)
    return process_seeds(path)

In [None]:
# Per-environment datapoint budget at which methods are compared (the vertical marker on the
# test-return plot and the point reported by the summary cell).
# NOTE(review): there are no entries for 'Uniform Beta Tracking', 'Uniform Beta + Rotation' or
# 'Uniform Easy Navigation', so re-enabling the lookup below for those environments would raise
# KeyError. The per-environment lookup is currently disabled in favour of a single hard-coded cutoff.
cutoffs = {
    "Mountain Car": 25, 
    "Cartpole": 500,
    "Dense Cartpole": 500,
    "Uniform Dense Cartpole": 500,
    "Beta Tracking": 300,
    "Beta + Rotation": 400,
    "Weird Gain": 30,
    "Navigation": 30,
    "Easy Navigation": 850}
# cutoff = cutoffs[env_name]

In [None]:
# Processed per-method summaries, keyed by method label; populated by the plotting cell.
expt_data = dict()
# Draw a vertical marker at `cutoff` datapoints on the test-return plot.
plot_cutoffs = True
# Datapoint budget used for the marker and for the summary table.
cutoff = 1_000

In [None]:
plt.rcParams["figure.figsize"] = (10, 6)
fig, (ax1, ax2) = plt.subplots(1, 2)
# Episode length for this environment; exploration episode i is placed at i * env_horizon datapoints.
env_horizon = horizons[env_name]
for expt_name, expt_path in exps.items():
    # Keep the processed summary so the cutoff table below can reuse it without re-reading pickles.
    this_data = expt_data[expt_name] = process_expt(base_path / expt_path)
    eval_ndata = this_data['Eval ndata']
    mean_returns = this_data['Mean Returns']
    stderr_returns = this_data['Stderr Returns']
    ax1.plot(eval_ndata, mean_returns, label=expt_name)
    ax1.fill_between(eval_ndata, mean_returns - stderr_returns, mean_returns + stderr_returns, alpha=0.2)

    expl_returns = np.asarray(this_data["Exploration Returns"])
    expl_eps = np.arange(len(expl_returns)) * env_horizon
    ax2.plot(expl_eps, expl_returns, label=expt_name)
    # Multi-seed runs also report a stderr across seeds; single runs have no such key.
    expl_stderr = this_data.get("Expl Stderr Returns")
    if expl_stderr is not None:
        ax2.fill_between(expl_eps, expl_returns - expl_stderr, expl_returns + expl_stderr, alpha=0.2)
if plot_cutoffs:
    ax1.axvline(cutoff, color='red', label=f"cutoff ({cutoff})")
ax1.set_title("Test Returns")
ax2.set_title("Exploration Returns")
for ax in (ax1, ax2):
    ax.set_xlabel("Number of Datapoints")
    ax.set_ylabel("Returns")
# Colours are shared across both panels, so one legend covers both.
ax1.legend()
fig.suptitle(f"Performance on {env_name}");


In [None]:
# Report each method's test return at the evaluation point closest to `cutoff` datapoints.
# The loop variable is not called `data`, so it does not clobber the global used by later cells.
for expt_name, summary in expt_data.items():
    mean_returns = summary['Mean Returns']
    stderr_returns = summary['Stderr Returns']
    # Align ndata with the (possibly seed-truncated) curves so idx can never run past their end.
    ndata = np.asarray(summary['Eval ndata'])[:len(mean_returns)]
    idx = np.argmin(np.abs(ndata - cutoff))
    print(f"{expt_name}: {mean_returns[idx]:.2f} +- {stderr_returns[idx]:.2f} (ndata = {ndata[idx]})")

In [None]:
# Scratch analysis of a single raw run.
# NOTE(review): `data` must be a raw info.pkl dict here (it needs 'Eval Returns'); the dicts
# returned by process_expt do not contain that key.
eval_returns = np.array(data['Eval Returns'])
mean_returns = np.mean(eval_returns, axis=1)
std_returns = np.std(eval_returns, axis=1)
# `env` is never defined in this notebook; use the per-environment horizon table instead.
env_horizon = horizons[env_name]
eval_eps = (np.arange(len(mean_returns)) + 1) * env_horizon
# Size the exploration axis from the exploration curve itself (not len(mean_returns) + 2),
# so that it always matches the y-values it is plotted against.
expl_eps = np.arange(len(data["Exploration Returns"])) * env_horizon

In [None]:
# Evaluation returns (mean +- std over episodes) and exploration returns for one raw run.
fig, ax = plt.subplots()
ax.plot(eval_eps, mean_returns, label="Eval Returns")
lower, upper = mean_returns - std_returns, mean_returns + std_returns
ax.fill_between(eval_eps, lower, upper, alpha=0.2)
ax.plot(expl_eps, data["Exploration Returns"], label="Expl Returns")
ax.legend()
ax.set_xlabel('Number of Datapoints')
ax.set_ylabel("Returns");


In [None]:
# Training inputs and targets stored in the raw run dict.
Xtrain, Ytrain = data['Xtrain'], data['Ytrain']

In [None]:
# Empirical CDF of each row of Ytrain; labels count t from the LAST row (t = 0) backwards.
fig, ax = plt.subplots()
for t, row in enumerate(Ytrain[::-1]):
    sorted_targets = np.sort(row)
    n_points = len(sorted_targets)
    cum_fraction = np.arange(n_points) / float(n_points)
    ax.plot(sorted_targets, cum_fraction, label=f"t = {t}")
ax.legend()
ax.set_ylabel("Cumulative Density")
ax.set_xlabel("Ytrain Value")
plt.show()

In [None]:
# First training set (index 0 along the leading axis of Xtrain / Ytrain).
X1, Y1 = Xtrain[0, ...], Ytrain[0, ...]
Y1

In [None]:
# Targets of the first training set (doubled) against the first input coordinate.
fig, ax = plt.subplots()
ax.scatter(X1[:, 0], 2 * Y1)
ax.set_xlabel("X position")
ax.set_ylabel("Reward x 2");

In [None]:
all_data = data['all_data']
all_obs = all_data.next_obs
# Residual between the first next-observation coordinate and an affine map of the reward;
# all zeros would mean obs[:, 0] == reward * 2.2 - 1.2.
# NOTE(review): the origin of the constants 2.2 and -1.2 is not visible here — document it.
position_residual = all_obs[:, 0] - (2.2 * all_data.rewards - 1.2)
position_residual

In [None]:
# `horizon` was never defined in this notebook (NameError on a fresh kernel); use the table.
env_horizon = horizons[env_name]
# Assumes all_obs stores complete episodes of length env_horizon back to back, giving
# (n_episodes, env_horizon, obs_dim) after the reshape.
obs_by_time = all_obs.reshape((-1, env_horizon, all_obs.shape[1]))
print(obs_by_time.shape)
n_episodes = len(obs_by_time)
fig, ax = plt.subplots()
for ep, ep_data in enumerate(obs_by_time):
    # Colour episodes from dark (early) to bright (late) along the `hot` colormap.
    ax.scatter(ep_data[:, 0], ep_data[:, 1], color=cm.hot(ep / n_episodes), label=f"Episode {ep}")
# ax.scatter(6, 9, s=100, color="green", label="goal")
ax.legend();

In [None]:
# NOTE(review): `MaternGP` is never imported (only RBFGP and GPTrainer come from ucb.models),
# so this cell raises NameError on a fresh kernel — import it or switch to RBFGP.
model = MaternGP(noise=0.05, jitter=0.3)
# num_iters=1 with load_params=False: presumably an almost-untrained fit for debugging — TODO confirm.
trainer = GPTrainer(lr=0.01, num_iters=1, seed=0, weight_decay=0.001, constrain_gd=True, load_params=False)

In [None]:
# Run GPTrainer.train on the first training set (Xtrain[0], Ytrain[0]).
# NOTE(review): the meaning of the trailing positional argument 24 is not visible here — name it.
train_state = trainer.train(model, Xtrain[0, ...], Ytrain[0, ...], 24)

In [None]:
# Display the state returned by trainer.train.
train_state

In [None]:
# Overwrite the learned hyperparameters with hand-picked values: log_rho = -3 on every input
# dimension (presumably log length scales, i.e. tiny length scales) to test whether the GP overfits.
# NOTE(review): assumes params is a flax FrozenDict; copy({'params': {...}}) replaces the whole
# 'params' subtree, so any other parameters stored there are dropped — confirm that is intended.
train_state = train_state.replace(params=train_state.params.copy({'params': {
        'log_rho': np.array([-3, -3, -3]), 
        'log_sigma': 1.34}}))

In [None]:
# Partially apply the (private) trainer._pred with the current state and the first training set.
pred_fn = partial(trainer._pred, train_state=train_state, Xtrain=Xtrain[0, ...], Ytrain=Ytrain[0, ...], train_diag=None)

In [None]:
# Standardised residuals of the training targets under the GP posterior.
# NOTE(review): `mean` and `var` are never defined in this notebook (presumably the output of
# pred_fn on Xtrain[0]); add that call or this cell fails on a fresh kernel.
z_scores = (Ytrain[0, ...] - mean[:, 0]) / np.sqrt(var[:, 0])

In [None]:
# Targets of the first training set, for comparison with the GP fit.
Y1

In [None]:
# Fraction of target variance explained by the GP mean.
# NOTE(review): `mean` is never defined in this notebook — see the z-score cell above.
explained_variance_score(Y1, mean)

In [None]:
# TODO: set xdot and the action to zero and check whether the GP can fit the resulting R -> R map from x.

In [None]:
# TODO: plot the GP predictions for those inputs against the ground-truth data and check whether they fit.

In [None]:
# TODO: force the length scales to be tiny and check whether that makes the GP overfit the training data.

In [None]:
# An older greedy dense-cartpole run (not the one listed in dense_cp_exps); rebinds `data`.
greedy_expt_dir = base_path / 'greedy_dense_cartpole_2022-09-20/20-36-44/'
data = process_expt(greedy_expt_dir)

In [None]:
# Display the processed summary dict for the greedy run loaded above.
data