In [None]:
%matplotlib inline

# These are useful for debugging, but make code slower:
%load_ext autoreload
%autoreload 2

import logging
import sys

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import jax
import jax.numpy as jnp
import jax.experimental.optimizers as jopt
import jax.random as jrandom
import numpy as np

from pref_bootstrap.envs import gridworld, mdp_interface, avse_world
from pref_bootstrap.algos import mce_irl
import pref_bootstrap.feedback_learner_blind_irl as fbl_blind_irl
import pref_bootstrap.feedback_learner_paired_comparisons as fbl_paired_comp
import pref_bootstrap.reward_models as r_models
import pref_bootstrap.expert_base as experts

sns.set(context='notebook', style='darkgrid')
logging.basicConfig(level=logging.INFO)
np.set_printoptions(precision=4, linewidth=100)

In [None]:
def backtracking_line_search(point, direction, eval_fn, *, init_step=1.0, armijo_c=0.05, armijo_tau=0.1, min_step=1e-5, grad=None):
    """Performs backtracking line search until Armijo condition is satisfied.
    Assumes function is continuously differentiable (which should be true
    as long as we stay away from L1 penalties, hard maxes, etc.).

    Candidate points are `point - step_size * direction`, i.e. `direction`
    is the vector that gets *subtracted* from `point`.

    Args:
        point: current iterate.
        direction: vector to step against.
        eval_fn: objective being minimised; called as `eval_fn(x)`.
        init_step: first step size to try.
        armijo_c: sufficient-decrease constant of the Armijo condition.
        armijo_tau: multiplicative shrink factor applied on every backtrack.
        min_step: give up once the candidate step size drops below this.
        grad: gradient of `eval_fn` at `point`. When omitted, `direction` is
            assumed to be that gradient (plain gradient descent).

    Returns:
        The step size with the lowest objective value among those tried.
        This is normally the one satisfying the Armijo condition. If the
        search gave up early it is merely the best candidate seen, or
        `init_step` when no candidate was evaluated at all.
    """
    if grad is None:
        grad = direction
    value_at_point = eval_fn(point)
    # Armijo: f(x) - f(x - a*d) >= a * c * <grad, d>. The slope term must be
    # built from the gradient, not from the coordinates of `point`.
    t = armijo_c * jnp.dot(grad, direction)
    num_steps = 0
    best_step_size = init_step
    value_for_best_step_size = float('inf')
    while True:
        # compute step size and update num_steps
        step_size = init_step * armijo_tau ** num_steps
        num_steps += 1

        # break early if we're running for suspiciously long
        if step_size < min_step:
            print(
                f"Backtracked to step size {step_size} after {num_steps} iterations, "
                "but Armijo is still not satisfied; exiting early")
            break

        value_at_step = eval_fn(point - step_size * direction)

        # we keep the 'best' step size in case we need to exit early (this is a bit of hack, but oh well)
        if value_at_step < value_for_best_step_size:
            value_for_best_step_size = value_at_step
            best_step_size = step_size

        # stop as soon as the decrease is large enough
        if value_at_point - value_at_step >= step_size * t:
            break

    return best_step_size

# generate some random trajectories & compare a random subset of them
def generate_comparison_dataset(pc_ntraj, env, pc_expert, *, rng=None):
    """Generates a paired comparisons dataset by generating pc_ntraj trajectories
    and then doing a single comparison (with a random other trajectory) for
    each generated trajectory.

    Args:
        pc_ntraj: number of trajectories to sample (and of comparisons made).
        env: environment handed to `mce_irl.mce_irl_sample`; must expose
            `n_states`. Trajectories are sampled under an all-ones reward.
        pc_expert: object whose `interact(traj1, traj2)` is truthy when the
            first trajectory is preferred.
        rng: optional `numpy.random.Generator` used to pick comparison
            partners. When omitted, the global `np.random` state is used.

    Returns:
        Dict with 'trajectories' (the sampler's output) and 'comparisons',
        an array whose rows are `(better_idx, worse_idx)` pairs.

    Raises:
        ValueError: if `pc_ntraj` is 1, since there is no partner to compare to.
    """
    if pc_ntraj == 1:
        raise ValueError(
            "need at least two trajectories to form a paired comparison, got pc_ntraj=1")

    pc_trajectories = mce_irl.mce_irl_sample(env, pc_ntraj, R=np.ones((env.n_states, )))
    traj_states = pc_trajectories['states']
    draw_index = np.random.randint if rng is None else rng.integers

    comparisons = []
    for first_idx in range(pc_ntraj):
        # Uniform choice among the *other* trajectories: draw from
        # [0, pc_ntraj - 1) and shift past first_idx.
        second_idx = int(draw_index(pc_ntraj - 1))
        if second_idx >= first_idx:
            second_idx += 1
        traj1_is_better = pc_expert.interact(
            dict(states=traj_states[first_idx]),
            dict(states=traj_states[second_idx]))
        if traj1_is_better:
            # the better trajectory comes before the worse one
            comparisons.append((first_idx, second_idx))
        else:
            comparisons.append((second_idx, first_idx))
    return {
        'trajectories': pc_trajectories,
        'comparisons': np.asarray(comparisons),
    }

def multi_optimize(model_list, data_list, rmodel, bias_list, *, optimize_reward=True, lr=1, steps=1000, log_every=100):
    """Joint optimization of several models at once.

    Runs plain gradient ascent on the summed log-likelihood of all feedback
    models: every step first updates the shared reward parameters, then each
    model's bias parameters (log-likelihood gradient plus bias-prior gradient,
    followed by projection onto the prior's support).

    Args:
        model_list: feedback models; each provides `log_likelihood`,
            `log_likelihood_grad_rew`, `log_likelihood_grad_bias` and a
            `bias_prior` with `log_prior_grad` / `project_to_support`.
        data_list: one dataset per model, in the same order.
        rmodel: shared reward model (`get_params` / `set_params`); updated
            in place.
        bias_list: one bias-parameter value per model. The list is updated
            in place and also returned.
        optimize_reward: when False, the reward gradient is treated as zero
            and only the biases move.
        lr: step size used for both reward and bias updates.
        steps: number of gradient steps.
        log_every: print each model's log-likelihood and bias every this
            many steps; a falsy value disables logging.

    Returns:
        `(model_list, rmodel, bias_list)`.
    """
    for step in range(steps):
        # The accumulator takes its shape from the parameters it is added to,
        # so no global environment is needed here.
        grew = jnp.zeros_like(rmodel.get_params())

        if optimize_reward:
            for model, data, bias_params in zip(model_list, data_list, bias_list):
                grew = grew + model.log_likelihood_grad_rew(data, rmodel, bias_params)

        rmodel.set_params(rmodel.get_params() + lr * grew)

        for k, (model, data, bias_params) in enumerate(zip(model_list, data_list, bias_list)):
            gbias = model.log_likelihood_grad_bias(data, rmodel, bias_params)
            # bias prior
            gbias = gbias + model.bias_prior.log_prior_grad(bias_params)
            bias_list[k] = model.bias_prior.project_to_support(bias_params + lr * gbias)

        # Only evaluate the (potentially expensive) likelihood when it is reported.
        if log_every and step % log_every == 0:
            for k, (model, data, bias_params) in enumerate(zip(model_list, data_list, bias_list)):
                loss = model.log_likelihood(data, rmodel, bias_params)
                print('step %d loss %.3f model %d' % (step, loss, k))
                print('---', bias_params)

    return model_list, rmodel, bias_list

def init_models(env, irl_ntraj, pc_ntraj, *, dummy_use_opt_reward=False):
    """Initialize paired comparisons & IRL models.

    Args:
        env: environment shared by both feedback models and both experts.
        irl_ntraj: number of demonstration trajectories requested from the
            IRL expert.
        pc_ntraj: number of trajectories (and comparisons) in the paired
            comparisons dataset.
        dummy_use_opt_reward: debugging aid; when True the reward model is
            initialised with `env.reward_matrix` and a warning is written
            to stderr.

    Returns:
        Tuple `(model_list, data_list, rmodel, bias_list, use_bias_list,
        names, biases_actual)`. All lists are ordered paired comparisons
        first, blind IRL second.
    """
    rng = jrandom.PRNGKey(42)

    # Blind IRL: feedback model, initial bias, and demonstrations.
    irl_feedback_model = fbl_blind_irl.BlindIRLFeedbackModel(env)
    rng, irl_bias_params = irl_feedback_model.init_bias_params(rng)
    irl_expert = experts.MEDemonstratorExpert(env, np.random.randint((1 << 31) - 1))
    irl_dataset = irl_expert.interact(irl_ntraj)

    # Paired comparisons: feedback model, initial bias, and comparison data.
    pc_feedback_model = fbl_paired_comp.PairedCompFeedbackModel(env)
    rng, pc_bias_params = pc_feedback_model.init_bias_params(rng)
    pc_expert = experts.PairedComparisonExpert(env, boltz_temp=1.0, seed=42)
    comparison_dataset = generate_comparison_dataset(pc_ntraj, env, pc_expert)

    # Reward model shared by both feedback models.
    rmodel = r_models.LinearRewardModel(env.obs_dim)
    if dummy_use_opt_reward:
        print("WARNING: Initializing with optimal reward parameters!", file=sys.stderr)
        rmodel.set_params(env.reward_matrix)

    model_list = [pc_feedback_model, irl_feedback_model]
    data_list = [comparison_dataset, irl_dataset]
    bias_list = [pc_bias_params, irl_bias_params]
    use_bias_list = [True, True]
    names = ['paired_comparisons', 'blind_irl']

    # One "actual" bias per model, aligned with `names`.
    # NOTE(review): the blind-IRL entry is just the initial bias parameters,
    # not a value read from the expert -- confirm this is intended.
    biases_actual = [pc_expert.boltz_temp, irl_bias_params]

    return model_list, data_list, rmodel, bias_list, use_bias_list, names, biases_actual

In [None]:
# Build the "vase world" environment, then visualise its per-state reward and
# the occupancy measure returned by MCE IRL on the environment's grid.
vw_env = vase_world.VaseWorld()


def _show_grid_heatmap(flat_values, title):
    """Reshape per-state values to (n_rows, n_cols), draw a heatmap, and show it."""
    sns.heatmap(flat_values.reshape((vw_env.n_rows, vw_env.n_cols)))
    plt.title(title)
    plt.show()


_show_grid_heatmap(vw_env.reward_matrix, 'Reward values')

opt_om_ts, opt_om = mce_irl.mce_occupancy_measures(vw_env)
_show_grid_heatmap(opt_om, 'Occupancy measure')