In [None]:
%matplotlib inline

# These are useful for debugging, but make code slower:
%load_ext autoreload
%autoreload 2

import logging
import sys

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import jax
import jax.numpy as jnp
import jax.experimental.optimizers as jopt
import jax.random as jrandom
import numpy as np
import copy

from pref_bootstrap.envs import gridworld, mdp_interface, vase_world
from pref_bootstrap.algos import mce_irl
import pref_bootstrap.feedback_learner_blind_irl as fbl_blind_irl
import pref_bootstrap.feedback_learner_paired_comparisons as fbl_paired_comp
import pref_bootstrap.reward_models as r_models
import pref_bootstrap.expert_base as experts

# Notebook-wide display and logging configuration.
# Seaborn theme applied to every figure drawn below.
sns.set(context='notebook', style='darkgrid')
# Emit INFO-level (and above) log records from the root logger.
logging.basicConfig(level=logging.INFO)
# Print arrays with 4 digits of precision and wrap lines at 100 characters.
np.set_printoptions(precision=4, linewidth=100)

In [None]:
def backtracking_line_search(point, direction, eval_fn, *, init_step=1.0, armijo_c=0.05, armijo_tau=0.1, min_step=1e-5, grad=None):
    """Performs backtracking line search until Armijo condition is satisfied.

    The search *minimises* `eval_fn` over candidates of the form
    `point - step_size * direction`. The step size starts at `init_step` and
    is multiplied by `armijo_tau` after every rejected candidate. A candidate
    is accepted once it gives sufficient decrease:

        eval_fn(point) - eval_fn(point - s * direction) >= s * armijo_c * <grad, direction>

    Assumes function is continuously differentiable (which should be true
    as long as we stay away from L1 penalties, hard maxes, etc.).

    Args:
        point: current iterate (1-D array).
        direction: vector that is *subtracted* from `point`, scaled by the
            step size. When `grad` is not given this is assumed to be the
            gradient of `eval_fn` at `point` (i.e. steepest descent).
        eval_fn: callable mapping a point to the scalar objective to minimise.
        init_step: first step size to try.
        armijo_c: sufficient-decrease constant.
        armijo_tau: multiplicative shrink factor applied per backtracking step.
        min_step: once the step size drops below this value the search gives
            up, prints a warning and returns the best step size seen so far.
        grad: optional gradient of `eval_fn` at `point`, for use when
            `direction` is not the gradient itself. Defaults to `direction`.

    Returns:
        The step size (among those evaluated) that produced the lowest value
        of `eval_fn`; `init_step` if no candidate was evaluated at all. Note
        that on early exit this step is not guaranteed to decrease `eval_fn`.
    """
    if grad is None:
        # steepest descent: the vector we step along *is* the gradient
        grad = direction
    value_at_point = eval_fn(point)
    armijo_satisfied = False
    # Required decrease per unit step size. The slope of eval_fn along the
    # step (-direction) is -<grad, direction>, so the Armijo threshold is
    # +armijo_c * <grad, direction>. (The previous code used `point` in place
    # of the gradient, which could make the threshold negative and accept
    # steps that do not decrease eval_fn at all.)
    t = armijo_c * jnp.dot(grad, direction)
    num_steps = 0
    best_step_size = init_step
    value_for_best_step_size = float('inf')
    while not armijo_satisfied:
        # compute step size and update num_steps
        step_size = init_step * armijo_tau ** num_steps
        num_steps += 1

        # break early if we're running for suspiciously long
        if step_size < min_step:
            print(
                f"Backtracked to step size {step_size} after {num_steps} iterations, "
                "but Armijo is still not satisfied; exiting early")
            break

        # compute new value & Armijo condition
        value_at_step = eval_fn(point - step_size * direction)
        armijo_satisfied = bool(value_at_point - value_at_step >= step_size * t)

        # we keep the 'best' step size in case we need to exit early (this is a bit of hack, but oh well)
        if value_at_step < value_for_best_step_size:
            value_for_best_step_size = value_at_step
            best_step_size = step_size

    return best_step_size

# generate some random trajectories & compare a random subset of them
def generate_comparison_dataset(pc_ntraj, env, pc_expert):
    """Generates a paired comparisons dataset by generating pc_ntraj trajectories
    and then doing a single comparison (with a random other trajectory) for
    each generated trajectory.

    Args:
        pc_ntraj: number of trajectories to sample; also the number of
            comparisons produced. Must not be 1, since every trajectory needs
            a *different* trajectory to be compared against.
        env: environment handed to `mce_irl.mce_irl_sample`; its `n_states`
            attribute sizes the all-ones `R` vector used for sampling.
        pc_expert: object whose `interact(traj_a, traj_b)` returns a truthy
            value when `traj_a` is judged better than `traj_b`.

    Returns:
        A dict with keys:
          'trajectories': whatever `mce_irl.mce_irl_sample` returned.
          'comparisons': array with one `(better_idx, worse_idx)` row per
              sampled trajectory.

    Raises:
        ValueError: if `pc_ntraj == 1`.

    Note: partner indices are drawn from the global `np.random` state, so
    seed it beforehand if reproducible comparisons are needed.
    """
    if pc_ntraj == 1:
        # fail before doing any sampling work; previously this surfaced as an
        # opaque error from np.random.randint(0) inside the loop
        raise ValueError(
            "pc_ntraj must not be 1: each trajectory has to be compared "
            "against a different trajectory")

    pc_trajectories = mce_irl.mce_irl_sample(env, pc_ntraj, R=np.ones((env.n_states, )))
    state_trajectories = pc_trajectories['states']
    comparisons = []
    for first_idx in range(pc_ntraj):
        # uniform draw over the pc_ntraj - 1 indices other than first_idx:
        # sample from [0, pc_ntraj - 1) and shift everything >= first_idx up by one
        second_idx = np.random.randint(pc_ntraj - 1)
        if second_idx >= first_idx:
            second_idx += 1
        traj1_is_better = pc_expert.interact(
            dict(states=state_trajectories[first_idx]),
            dict(states=state_trajectories[second_idx]))
        if traj1_is_better:
            # the better trajectory comes before the worse one
            comparisons.append((first_idx, second_idx))
        else:
            comparisons.append((second_idx, first_idx))
    return {
        'trajectories': pc_trajectories,
        'comparisons': np.asarray(comparisons),
    }

def multi_optimize(model_list, data_list, rmodel, bias_list, env, names, lr=0.1, steps=1000, *, optimize_reward=True, debug=False, log_every=100):
    """Joint optimization of several models at once.

    Runs `steps` iterations of fixed-step gradient ascent. Each iteration:
      1. sums `log_likelihood_grad_rew` over all feedback models and moves
         the shared reward parameters by `lr` times that sum (the sum stays
         zero when `optimize_reward` is False);
      2. for every feedback model, moves its bias parameters by `lr` times
         (`log_likelihood_grad_bias` + bias-prior gradient), evaluated at the
         already-updated reward, then projects the result back with
         `bias_prior.project_to_support`.

    Args:
        model_list: feedback models, one per feedback modality.
        data_list: datasets, aligned with `model_list`.
        rmodel: shared reward model exposing `get_params`/`set_params`.
            Updated in place.
        bias_list: current bias parameters, aligned with `model_list`.
            Entries are replaced in place.
        env: unused here; kept so the signature matches existing callers.
        names: display names used in log output, aligned with `model_list`.
        lr: step size for both the reward and bias updates.
        steps: number of iterations.
        optimize_reward: if False, the reward gradient is left at zero.
        debug: if True, print every intermediate gradient.
        log_every: print progress every `log_every` steps; a falsy value
            disables progress output entirely.

    Returns:
        `(model_list, rmodel, bias_list)` -- the same objects that were
        passed in, with `rmodel` and `bias_list` holding the final values.
    """
    for step in range(steps):
        rew_grad = jnp.zeros_like(rmodel.get_params())
        if debug: print('rew_grad:', rew_grad)  # XXX

        if optimize_reward:
            for model, data, bias_params in zip(model_list, data_list, bias_list):
                rew_grad += model.log_likelihood_grad_rew(data, rmodel, bias_params)
                if debug: print('rew_grad:', rew_grad)  # XXX

        new_r = rmodel.get_params() + lr * rew_grad
        rmodel.set_params(new_r)

        bias_grad_norms = []
        for k, (model, data, bias_params) in enumerate(zip(model_list, data_list, bias_list)):
            bias_grad = model.log_likelihood_grad_bias(data, rmodel, bias_params)
            # bias prior
            bias_prior_grad = model.bias_prior.log_prior_grad(bias_params)
            bias_grad = bias_grad + bias_prior_grad
            bias_grad_norms.append(np.linalg.norm(np.asarray(bias_grad)))
            bias_list[k] = model.bias_prior.project_to_support(bias_params + lr * bias_grad)
            if debug: print('bias_grad:', bias_grad)  # XXX

        # The log-likelihood is only needed for the progress report, so it is
        # evaluated on reporting steps only rather than on every iteration.
        if log_every and step % log_every == 0:
            for model, data, bias_params, name, bias_grad_norm in zip(model_list, data_list, bias_list, names, bias_grad_norms):
                log_likelihood = model.log_likelihood(data, rmodel, bias_params)
                print('[%s] step %d: ll %.3f, |bias_grad| %.3f' % (name, step, log_likelihood, bias_grad_norm))
                print('[%s] biases' % (name, ), bias_params)
            print(f'[reward] params {np.asarray(rmodel.get_params())}, |grad| {np.linalg.norm(rew_grad):.3f}')
            print()

    return model_list, rmodel, bias_list

def multi_optimize_line(model_list, data_list, rmodel, bias_list, env, names, lr=0.01, steps=1000, *, optimize_reward=True, debug=False): 
    """Like multi_optimize but does line search.

    UNFINISHED: after the first reward update (whose step size is chosen by
    `backtracking_line_search`) this function raises NotImplementedError, so
    it never returns and the bias-update/logging code below the `raise` is
    currently unreachable. Note that `rmodel` has already been updated in
    place by the time the exception is raised.

    Args:
        model_list: feedback models, one per feedback modality.
        data_list: datasets, aligned with `model_list`.
        rmodel: shared reward model exposing `get_params`/`set_params`.
        bias_list: current bias parameters, aligned with `model_list`.
        env: unused here; kept so the signature matches `multi_optimize`.
        names: display names for log output, aligned with `model_list`.
        lr: step size for the (draft, unreachable) bias update.
        steps: number of iterations.
        optimize_reward: if False, the reward gradient is left at zero.
        debug: if True, print every intermediate gradient.

    Raises:
        NotImplementedError: always, once the first reward step is done
            (provided `steps >= 1`).
    """
    for step in range(steps):
        rew_grad = jnp.zeros_like(rmodel.get_params())
        if debug: print('rew_grad:', rew_grad)  # XXX
        
        if optimize_reward: 
            for model, data, bias_params in zip(model_list, data_list, bias_list):
                rew_grad += model.log_likelihood_grad_rew(data, rmodel, bias_params)
                if debug: print('rew_grad:', rew_grad)  # XXX

        def neg_rew_log_likelihood(rparams):
            """Negative total log-likelihood at reward parameters `rparams`."""
            # FIXME(sam): this is probably inefficient
            rmodel2 = copy.deepcopy(rmodel)
            rmodel2.set_params(rparams)
            total = 0.0
            for model, data, bias_params in zip(model_list, data_list, bias_list):
                total += model.log_likelihood(data, rmodel2, bias_params)
            return -total

        # backtracking_line_search *minimises* eval_fn over candidates
        # `point - step * direction`. We want to *maximise* the likelihood by
        # moving along +rew_grad, so hand it the negative log-likelihood and
        # direction=-rew_grad (which is that objective's gradient). The
        # candidates it evaluates are then exactly the points
        # `params + step * rew_grad` that the update below moves to.
        rew_step_size = backtracking_line_search(point=rmodel.get_params(), direction=-rew_grad, eval_fn=neg_rew_log_likelihood)
        new_r = rmodel.get_params() + rew_step_size * rew_grad
        rmodel.set_params(new_r)
        raise NotImplementedError("Need to finish the rest of htis")
        
        # ---- draft below this line is unreachable until the raise is removed ----
        bias_grad_norms = []
        for k, (model, data, bias_params) in enumerate(zip(model_list, data_list, bias_list)):
            bias_grad = model.log_likelihood_grad_bias(data, rmodel, bias_params)
            # bias prior
            bias_prior_grad = model.bias_prior.log_prior_grad(bias_params)
            bias_grad = bias_grad + bias_prior_grad
            bias_grad_norms.append(np.linalg.norm(np.asarray(bias_grad)))
            bias_list[k] = model.bias_prior.project_to_support(bias_params + lr * bias_grad)
            if debug: print('bias_grad:', bias_grad)  # XXX

        for k, (model, data, bias_params, name, bias_grad_norm) in enumerate(zip(model_list, data_list, bias_list, names, bias_grad_norms)):
            log_likelihood = model.log_likelihood(data, rmodel, bias_params)
            if step % 100 == 0:
                print('[%s] step %d: ll %.3f, |bias_grad| %.3f' % (name, step, log_likelihood, bias_grad_norm))
                print('[%s] biases' % (name, ), bias_params)
        if step % 100 == 0:
            print(f'[reward] params {np.asarray(rmodel.get_params())}, |grad| {np.linalg.norm(rew_grad):.3f}')
            print()

    return model_list, rmodel, bias_list

def init_models(env, irl_ntraj, pc_ntraj, seed=42, *,
                dummy_use_opt_reward=False):
    """Set up the blind-IRL and paired-comparison feedback models.

    For each modality this draws two bias-parameter samples from the feedback
    model (one as the learner's starting point, one as the "true" value given
    to a simulated expert) and then queries that expert for a dataset.

    Args:
        env: environment shared by models, experts and the reward model.
        irl_ntraj: number of demonstrations requested from the IRL expert.
        pc_ntraj: number of trajectories (and comparisons) in the paired
            comparisons dataset.
        seed: base seed; it is XOR-ed with a distinct constant wherever a
            fresh RNG has to be seeded so the streams differ.
        dummy_use_opt_reward: if True, overwrite the reward model's initial
            parameters with `env.reward_matrix` (a warning goes to stderr).

    Returns:
        `(model_list, data_list, rmodel, bias_list, names, biases_actual)`,
        where every list is ordered paired-comparisons first, blind-IRL second.
    """
    key = jrandom.PRNGKey(seed)

    # ---- blind IRL: learner init, ground-truth biases, demonstrations ----
    irl_model = fbl_blind_irl.BlindIRLFeedbackModel(env)
    key, irl_bias_init = irl_model.init_bias_params(key)
    key, irl_bias_true_raw = irl_model.init_bias_params(key)
    # ground-truth bias parameters are snapped to 0/1
    irl_bias_true = np.round(np.asarray(irl_bias_true_raw))
    irl_demonstrator = experts.BlindMEDemonstratorExpert(
        env=env,
        feature_weights=irl_bias_true,
        seed=seed ^ 0x7364181f,
    )
    irl_demos = irl_demonstrator.interact(irl_ntraj)

    # ---- paired comparisons: learner init, ground-truth temperature, data ----
    pc_model = fbl_paired_comp.PairedCompFeedbackModel(env)
    key, pc_bias_init = pc_model.init_bias_params(key)
    key, pc_temp_true_raw = pc_model.init_bias_params(key)
    pc_temp_true = float(pc_temp_true_raw)
    pc_expert = experts.PairedComparisonExpert(
        env,
        boltz_temp=pc_temp_true_raw,
        seed=seed ^ 0x35e0251f,
    )
    # sanity check: the expert stored the temperature we handed it
    assert pc_expert.boltz_temp == pc_temp_true
    pc_comparisons = generate_comparison_dataset(
        pc_ntraj=pc_ntraj, env=env, pc_expert=pc_expert)

    # ---- shared reward model ----
    rmodel = r_models.LinearRewardModel(env.obs_dim, seed=seed ^ 0x6b2a8d53)
    if dummy_use_opt_reward:
        print("WARNING: Initializing with optimal reward parameters!", file=sys.stderr)
        rmodel.set_params(env.reward_matrix)

    return (
        [pc_model, irl_model],
        [pc_comparisons, irl_demos],
        rmodel,
        [pc_bias_init, irl_bias_init],
        ['paired_comparisons', 'blind_irl'],
        [pc_expert.boltz_temp, irl_bias_true],
    )

In [None]:
# Build the "vase world" environment and show its reward next to the
# occupancy measure returned by mce_irl.mce_occupancy_measures.
vw_env = vase_world.VaseWorld()
opt_om_ts, opt_om = mce_irl.mce_occupancy_measures(vw_env)

grid_shape = (vw_env.n_rows, vw_env.n_cols)
fig, (reward_ax, occupancy_ax) = plt.subplots(1, 2, figsize=(10, 4))

# left panel: per-state reward laid out on the grid
sns.heatmap(vw_env.reward_matrix.reshape(grid_shape), ax=reward_ax)
reward_ax.set_title('Reward values')

# right panel: occupancy measure laid out on the same grid
sns.heatmap(opt_om.reshape(grid_shape), ax=occupancy_ax)
occupancy_ax.set_title('Occupancy measure')

plt.show()

In [None]:
# Create feedback models, simulated datasets and initial parameters for the
# vase world, then print the starting point next to the ground truth.
(
    model_list,
    data_list,
    rmodel,
    bias_list,
    names,
    biases_actual,
) = init_models(
    vw_env,
    irl_ntraj=5,
    pc_ntraj=30,
    seed=45,
)
print('Reward model params:', rmodel.get_params())
print('True reward weights:', vw_env.reward_weights)
print('Bias params:', bias_list)
print('Real biases:', biases_actual)

In [None]:
# Jointly fit the reward and bias parameters. multi_optimize updates `rmodel`
# and `bias_list` in place, so re-running this cell continues from the
# current parameters rather than from the initial ones.
final_model_list, final_rmodel, final_bias_list = multi_optimize(
    model_list=model_list,
    data_list=data_list,
    rmodel=rmodel,
    bias_list=bias_list,
    env=vw_env,
    names=names,
    lr=0.01,
    steps=5000,
)