In [1]:
# import matplotlib as mpl
# mpl.use('Agg')

import os
import sys
import pickle
import numpy as np
from collections import defaultdict

# Utils
sys.path.append("../../irl/")
from IRLProblem import IRLProblem

# GeoLife
sys.path.append("../../dataset/GeolifeTrajectories1.3/")
import geolife_data as GLData

# MLIRL
from mlirl_parallel import *

# Torch
import torch

# Plotting
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2

In [2]:
# Read the Google Maps API key from the environment rather than committing it
# to the notebook. The key that used to be hardcoded here has been exposed in
# the notebook history and should be revoked/rotated.
# When GMAPS_API_KEY is unset this is None, which is also the default that
# GeoLife2.features() forwards to GLData.feature.
gmaps_api_key = os.environ.get("GMAPS_API_KEY")

# IRL Problem

In [4]:
# --- Dataset locations -------------------------------------------------------
GEOLIFE_DIR = "../../dataset/GeolifeTrajectories1.3/"
IMAGE_FEATURE_DIR = "./features/satellite_rgb/"

# --- Train / validation / test split -----------------------------------------
# Fractions in (train, val, test) order; rand_split=False selects the fixed
# index lists used in the split cell instead of a random permutation.
traj_split = [0.6, 0.3, 0.1]
rand_split = False

# --- Debugging / envelope options --------------------------------------------
# Trajectories are truncated to this many steps for debugging; -1 disables it.
max_traj_length = 30
# state_expl_cost_bound only takes effect when recompute_envelope is True.
# NOTE(review): neither name is referenced by the cells visible in this
# notebook - confirm whether they are still consumed anywhere.
recompute_envelope = True
state_expl_cost_bound = 1.1

In [5]:
class GeoLife2(IRLProblem):
    """IRL problem backed by pre-chopped GeoLife trajectories.

    Trajectories are loaded once from a pickle file under ``data_dir``; state
    features are image crops obtained through ``GLData.feature`` and the
    dynamics snap ``GLData.trans_func`` successors onto a per-trajectory
    envelope of grid cells.
    """
    
    def __init__(self, data_dir="../../dataset/GeolifeTrajectories1.3/",
                 satimg_dir="./features/satellite_rgb/",
                 satimg_zoom=19,
                 satimg_dim=(1,128,128), 
                 feature_mu=0., 
                 feature_std=1.,
                 lat_resolution=0.00010952786669320442,
                 lng_resolution=0.00014278203208415482,
                 st_fp_precision=20):
        """Load the trajectories and store the feature / grid settings.

        Args:
            data_dir: dataset root; the trajectories are read from
                ``chopped_trajectories/turn_trajectories/turn_trajectories_v1.pk``
                below it.
            satimg_dir: passed to ``GLData.feature`` as ``store_dir``.
            satimg_zoom: passed to ``GLData.feature`` as ``zoom``.
            satimg_dim: (channels, height, width) of the stored images; only
                height and width are used, to build the requested image size
                and the crop margins.
            feature_mu, feature_std: normalisation applied to the pixel values
                after scaling them by 1/255.
            lat_resolution, lng_resolution: grid cell size passed to
                ``GLData.trans_func``; half of each is the snapping tolerance
                of the dynamics.
            st_fp_precision: passed through to ``GLData.trans_func``.
        """
        # SECURITY NOTE: pickle.load can execute arbitrary code contained in
        # the file - only point data_dir at trusted data.
        with open(os.path.join(data_dir, "chopped_trajectories/turn_trajectories/turn_trajectories_v1.pk"), "rb") as f:
            self.trajectories = pickle.load(f)
            
        self.satimg_dir = satimg_dir
        self.satimg_zoom = satimg_zoom
        self.satimg_dim = satimg_dim
        self.feature_mu = feature_mu
        self.feature_std = feature_std
        self.lat_resolution = lat_resolution
        self.lng_resolution = lng_resolution
        self.st_fp_precision = st_fp_precision
    
    def features(self, state, feature_dim, gmaps_api_key=None, verbose=False):
        """Return the cropped, normalised image feature tensor for ``state``.

        Args:
            state: forwarded unchanged to ``GLData.feature``.
            feature_dim: (channels, height, width) of the returned feature.
                ``channels == 1`` requests mode "L", anything else "RGB".
            gmaps_api_key: forwarded to ``GLData.feature`` as ``api_key``.
            verbose: forwarded to ``GLData.feature``.

        Returns:
            ``torch.FloatTensor`` viewed as ``(-1, channels, height, width)``
            holding ``(pixels / 255 - feature_mu) / feature_std``.
        """
        cc, hh, ww = feature_dim
        _, h, w = self.satimg_dim
        
        # Split the surplus rows/columns between both sides (floor on the
        # top/left, ceil on the bottom/right).
        crop_top, crop_bottom = int(np.floor((h - hh) / 2)), int(np.ceil((h - hh) / 2))
        crop_left, crop_right = int(np.floor((w - ww) / 2)), int(np.ceil((w - ww) / 2))
        
        mode = "L" if cc == 1 else "RGB"
        img_size = str(h) +"x" + str(w)
        
        img = GLData.feature(state, img_size=img_size,
                             zoom=self.satimg_zoom, mode=mode,
                             api_key=gmaps_api_key, verbose=verbose,
                             store_dir=self.satimg_dir)

        # A slice stop of ``-0`` is just ``0`` and yields an empty array, so a
        # zero margin (feature size == image size) must become an open stop.
        rows = slice(crop_top, -crop_bottom or None)
        cols = slice(crop_left, -crop_right or None)
        if mode == "L":
            img_crop = img[rows, cols]
        else:
            # NOTE(review): if GLData.feature returns RGB images as
            # (height, width, channel), the .view() below reinterprets the
            # buffer instead of moving the channel axis first - confirm.
            img_crop = img[rows, cols, :]
            
        return torch.FloatTensor(
            (img_crop/255. - self.feature_mu) / self.feature_std).view(-1, cc, hh, ww)
    
    def sample_trajectories(self):
        """Return the trajectories loaded in ``__init__`` (the same list object)."""
        return self.trajectories

    def get_dynamics(self):
        """Return the transition function ``T(s, a, lat_to_lngs)``."""
        return self._envelope_gridded_dynamics
    
    def _envelope_gridded_dynamics(self, s, a, lat_to_lngs):
        """Apply action ``a`` in state ``s`` and snap the result onto the envelope.

        Args:
            s, a: state and action, forwarded to ``GLData.trans_func``.
            lat_to_lngs: mapping from a grid latitude to the list of grid
                longitudes available at that latitude.

        Returns:
            ``(lat, lng)`` of the closest envelope cell if the successor lies
            within half a grid cell of it in both coordinates, else ``None``
            (the move leaves the envelope).
        """
        s_prime = GLData.trans_func(
            s, a, 
            lat_resolution=self.lat_resolution, 
            lng_resolution=self.lng_resolution, 
            st_fp_precision=self.st_fp_precision)
    
        lat, lng = s_prime
        # The successor may deviate from a cell centre by at most half a cell;
        # anything further away is outside the envelope border.
        lat_eps, lng_eps = self.lat_resolution/2., self.lng_resolution/2.

        # An empty envelope has no cell to snap to (argmin would raise).
        if len(lat_to_lngs) == 0:
            return None

        grid_lats = np.asarray(list(lat_to_lngs.keys())).reshape(-1,1)
        grid_lat_dists = np.linalg.norm(grid_lats - lat, axis=1)
        best_lat_idx = np.argmin(grid_lat_dists)
        if not grid_lat_dists[best_lat_idx] < lat_eps:
            return None

        best_lat = grid_lats[best_lat_idx][0]
        lngs_list = lat_to_lngs[best_lat]
        if len(lngs_list) == 0:
            return None

        _lng_dists = np.linalg.norm(np.asarray(lngs_list).reshape(-1,1) - lng, axis=1)
        best_lng_idx = np.argmin(_lng_dists)
        if _lng_dists[best_lng_idx] < lng_eps:
            return best_lat, lngs_list[best_lng_idx]
        return None
        
def get_geolife_problem(geolife, image_dim, max_traj_length, gmaps_api_key, cost_ubound=1.1):
    """Assemble the per-trajectory IRL problem pieces for a GeoLife problem.

    Args:
        geolife: problem object providing ``get_dynamics()``,
            ``sample_trajectories()``, ``features()``, ``lat_resolution`` and
            ``lng_resolution``.
        image_dim: feature dimensions forwarded to ``geolife.features``.
        max_traj_length: keep only the first ``max_traj_length`` entries of
            both components of every trajectory; ``-1`` disables trimming.
        gmaps_api_key: forwarded to ``geolife.features``.
        cost_ubound: forwarded to ``GLData.find_enevelope_a_star``; the default
            is the value that used to be hardcoded.

    Returns:
        ``(trajectories, S_list, phi, A, T, S_lat_to_lngs_list)`` where
        ``S_list[i]`` is the state envelope found for trajectory ``i`` and
        ``S_lat_to_lngs_list[i]`` maps each envelope latitude to the list of
        its longitudes.
    """
    A = ["E", "N", "W", "S"] # TODO: Need 8 actions or at least 2 turns.
    T = geolife.get_dynamics()
    
    trajectories = geolife.sample_trajectories()
    # Trim trajectories (for debugging only). Build a new list so the list
    # owned by `geolife` is not truncated as a side effect.
    if max_traj_length != -1:
        trajectories = [(t[0][:max_traj_length], t[1][:max_traj_length])
                        for t in trajectories]

    def phi(s):
        return geolife.features(s, image_dim, gmaps_api_key=gmaps_api_key)

    S_list = []
    S_lat_to_lngs_list = []
    for trajectory in trajectories:
        # The second return value (the expert's cost) is not used here.
        S, _expert_cost = GLData.find_enevelope_a_star(
            trajectory, A, GLData.trans_func, geolife.lat_resolution, geolife.lng_resolution,
            cost_ubound=cost_ubound,
            g_fn=lambda p1, p2, lat_res, lng_res: GLData.heuristic_l1(p1, p2, lat_res, lng_res),
            h_fn=lambda p1, p2, lat_res, lng_res: GLData.heuristic_l2(p1, p2, lat_res, lng_res),)
        S_list.append(S)

        # defaultdict(list) instead of a lambda factory keeps the mapping picklable.
        S_lat_to_lngs = defaultdict(list)
        for (lat, lng) in S:
            S_lat_to_lngs[lat].append(lng)
        S_lat_to_lngs_list.append(S_lat_to_lngs)
    return trajectories, S_list, phi, A, T, S_lat_to_lngs_list


In [6]:
# (channels, height, width) of the image features requested from the problem.
image_dim = (1, 64, 64)

trajectories, S_list, phi, A, T, S_lat_to_lngs_list = get_geolife_problem(
    geolife=GeoLife2(data_dir=GEOLIFE_DIR, satimg_dir=IMAGE_FEATURE_DIR),
    image_dim=image_dim,
    max_traj_length=max_traj_length,
    gmaps_api_key=gmaps_api_key,
)

In [7]:
# Seed both RNGs so that a random split (when enabled) is repeatable.
np.random.seed(0)
torch.manual_seed(0)

# Split sizes derived from the configured fractions; the test split takes
# whatever remains after train and validation.
N = len(trajectories)
N_tr, N_val = (int(frac * N) for frac in traj_split[:2])
N_te = N - N_tr - N_val

if not rand_split:
    # Fixed, hand-recorded split.
    tr_idxs = [4, 25, 10, 31, 27, 11, 36, 28, 20, 38, 2, 40, 18, 15, 22, 16, 37, 8, 13, 5, 17]
    val_idxs = [14, 34, 7, 33, 1, 26, 12, 32, 24, 6, 23, 21]
    te_idxs = [19, 9, 39, 3, 0]
else:
    rand_idxs = np.random.permutation(N)
    tr_idxs = rand_idxs[:N_tr]
    val_idxs = rand_idxs[N_tr:N_tr + N_val]
    te_idxs = rand_idxs[N_tr + N_val:]

print("Train traj idxs", tr_idxs)
print("Val traj idxs", val_idxs)
print("Test traj idxs", te_idxs)

# Select the same indices from every per-trajectory list.
train_trajectories, val_trajectories, test_trajectories = (
    [trajectories[i] for i in split] for split in (tr_idxs, val_idxs, te_idxs))

train_S_lat_to_lngs_list, val_S_lat_to_lngs_list, test_S_lat_to_lngs_list = (
    [S_lat_to_lngs_list[i] for i in split] for split in (tr_idxs, val_idxs, te_idxs))

train_S_list, val_S_list, test_S_list = (
    [S_list[i] for i in split] for split in (tr_idxs, val_idxs, te_idxs))

Train traj idxs [4, 25, 10, 31, 27, 11, 36, 28, 20, 38, 2, 40, 18, 15, 22, 16, 37, 8, 13, 5, 17]
Val traj idxs [14, 34, 7, 33, 1, 26, 12, 32, 24, 6, 23, 21]
Test traj idxs [19, 9, 39, 3, 0]


# Model

In [8]:
# Re-seed NumPy and PyTorch right before the model cell so the random weight
# initialisation performed there starts from the same RNG state on every run.
# torch.manual_seed returns the Generator object, which is what this cell displays.
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7f7f36871cd0>

In [9]:
import copy

# ConvAE
sys.path.append("../../models/")
import ConvAE as ConvAE
from pprint import pprint

# ---- Architecture ------------------------------------------------------------
nw_depth = 6
conv_k_size = 16
fc_latent_multiplier = 4
code_size = 128
strided_conv_freq = 3

# Start from pretrained ConvAE weights (loaded from convae_model_state) when True.
use_convae_phi_priors = True
convae_model_state = (
    "../../models/EXP_GEOLIFE_FEATURES_CONVAE_STRIDED_FINE_STATES_CHOPPED_TRAJ/"
    "states=100x100$img_size=128x128$conv_k_size=16$fc_latent_multiplier=4$"
    "code_size=128$strided_conv_freq=3$lr=0.0001$weight_decay=1e-09/results/model_state_ae.pt"
)
# Path of a pickled reward model to restore, or None to create a fresh one.
model_restore_file = None

# ---- Training ----------------------------------------------------------------
lr = 0.0005
weight_decay = 1e-4


def optimizer_fn(params, lr, weight_decay):
    """Build the Adam optimizer used for the reward model parameters."""
    return optim.Adam(params, lr=lr, weight_decay=weight_decay)

def weights_init(m):
    """Initialise the weights of ``torch.nn.Linear`` modules from N(0, 0.1).

    Intended for ``Module.apply``: modules of any other type, and the bias of
    Linear modules, are left untouched. Uses ``torch.nn`` explicitly so the
    function does not depend on names pulled in by a star import.
    """
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.normal_(m.weight, mean=0., std=0.1)
        
# Architecture
print("\nPreparing Net..")
CONV_BLOCK = [("conv1", conv_k_size), ("relu1", None)]
# pooling_freq=1e100 presumably disables pooling (no pooling layer shows up in
# the printed config) - confirm against ConvAE.create_network.
CONV_LAYERS = ConvAE.create_network(CONV_BLOCK, nw_depth, pooling_freq=1e100,
                                    strided_conv_freq=strided_conv_freq, strided_conv_channels=conv_k_size)
CONV_NW = CONV_LAYERS + [("flatten1", None),
                         ("linear1", fc_latent_multiplier * code_size), ("linear1", code_size), ]

print("NW Config (depth={}):\n\tBlock: ".format(len(CONV_NW)), end="")
pprint(CONV_BLOCK)
print("\tNet: \n", end="")
pprint(CONV_NW)

print("\nCreating model..")
# https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530
# torch.cuda.empty_cache() # Needed for repeated experiments
if use_convae_phi_priors:
    convae = ConvAE.ConvAE(image_dim, enc_config=CONV_NW, states_file=convae_model_state)
else:
    # Was `input_dim`, a name that is never defined in this notebook; both
    # branches build the autoencoder for the same input size.
    convae = ConvAE.ConvAE(image_dim,
                           enc_config=CONV_NW)
encoder = torch.nn.Sequential(*convae.encoder)
# Independent copy of the bare encoder, taken before the reward head is
# attached; used by the sanity check at the end of this cell.
encoder2 = copy.deepcopy(encoder)
decoder = torch.nn.Sequential(*convae.decoder)

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# print("\n\nUsing device {}".format(device))
# encoder.to(device)
# decoder.to(device)
    
# The reward head takes the encoder's code vector as its input.
phi_dim = code_size

if model_restore_file is None:
    # Fresh linear reward model; its Linear weights are re-initialised by weights_init.
    M = LinearRewardModel(phi_dim)
    M.apply(weights_init)
else:
    # SECURITY NOTE: pickle.load can execute arbitrary code - restore only
    # from files you produced yourself.
    # NOTE(review): iter_handler pickles the r_model it is given; if that is
    # the whole encoder+R stack, such a file is not a bare reward head - confirm.
    with open(model_restore_file, "rb") as mlirl_model:
        print("Restoring model from {}".format(model_restore_file))
        M = pickle.load(mlirl_model)

# Feature Encoding + Linear Model.
# Two separate models are chained so the features can be pretrained with the
# ConvAE (Encoder: 1x64x64 -> code size, Linear: code_size -> 1).
encoder.add_module("R", M)

def dec_model(encoding): 
    """Run ``encoding`` through the module-level ``decoder`` with autograd disabled."""
    with torch.no_grad():
        decoded = decoder(encoding)
    return decoded
    
# ConvAE features fixed, only scalar R mapping is learned
#     for param in M.parameters():
#         param.grad = torch.zeros_like(param)
#     M.share_memory() 

# End to end: ConvAE features are learned end to end.
# Pre-allocate zero gradients for every parameter. torch.zeros_like replaces
# the deprecated Variable(torch.zeros(shape)) wrapper and keeps the gradient's
# dtype and device in line with the parameter.
for param in encoder.parameters():
    param.grad = torch.zeros_like(param)

# This is required for the ``fork`` method to work (https://pytorch.org/docs/master/notes/multiprocessing.html)
encoder.share_memory()

# Optimization params
r_optimizer = optimizer_fn(encoder.parameters(), lr, weight_decay)

# Test: the full stack (encoder + R) must equal the clamped linear reward
# applied to the untouched encoder copy's code for the same state.
x = phi(S_list[0][0])
assert torch.allclose(-torch.nn.ReLU()(-M.w(encoder2(x)[0])), encoder(x)), \
    "encoder+R output does not match the reward head applied to encoder2's code"


Preparing Net..
NW Config (depth=17):
	Block: [('conv1', 16), ('relu1', None)]
	Net: 
[('conv1', 16),
 ('relu1', None),
 ('conv1', 16),
 ('relu1', None),
 ('conv1', 16),
 ('relu1', None),
 ('conv-strided1', 16),
 ('conv1', 16),
 ('relu1', None),
 ('conv1', 16),
 ('relu1', None),
 ('conv1', 16),
 ('relu1', None),
 ('conv-strided1', 16),
 ('flatten1', None),
 ('linear1', 512),
 ('linear1', 128)]

Creating model..
Loading states from: ../../models/EXP_GEOLIFE_FEATURES_CONVAE_STRIDED_FINE_STATES_CHOPPED_TRAJ/states=100x100$img_size=128x128$conv_k_size=16$fc_latent_multiplier=4$code_size=128$strided_conv_freq=3$lr=0.0001$weight_decay=1e-09/results/model_state_ae.pt


In [10]:
from IPython import display

def _remove_if_present(path):
    """Delete ``path``; a file that is already gone is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def iter_handler(_iter, loss_history, r_model, results_dir=None, dec_model=None, 
                 render=False, redraw=False, params_inspection_fn=None, training_interrupted=False):
    """Per-iteration callback: checkpoint the model and plot training progress.

    Args:
        _iter: current iteration index (used in file names and messages).
        loss_history: sequence of losses; plotted as ``exp(-loss)``.
        r_model: model to pickle. When ``dec_model`` is given it is indexed
            with ``[-1]`` and that element's ``w.weight`` is decoded.
        results_dir: directory for checkpoints and plots. When falsy, nothing
            is written to disk.
        dec_model: optional callable mapping the reward weights to an image.
        render: show the figure.
        redraw: with ``render``, replace the previous cell output instead of
            appending a new figure.
        params_inspection_fn: optional ``fn(_iter, r_model)`` hook.
        training_interrupted: when True, only an "interrupted" checkpoint is
            written (if ``results_dir`` is set) and no plot is produced.

    Only the latest checkpoint/plot is kept: the files of iteration
    ``_iter - 1`` are removed if they exist.
    """
    # results_dir defaults to None, so every filesystem step is guarded.
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    if training_interrupted: # i.e., training interrupted
        print("Iter: {:04d} Training interrupted! "
              "Returning current states..".format(_iter))
        if results_dir:
            with open(os.path.join(results_dir, "./interrupted_mlirl_iter_model.pkl"), "wb") as model_file:
                pickle.dump(r_model, model_file)
        return
    
    if params_inspection_fn:
        params_inspection_fn(_iter, r_model)
    
    if results_dir:
        with open(os.path.join(results_dir, "./mlirl_iter_{:03d}_model.pkl".format(
                _iter)), "wb") as model_file:
            pickle.dump(r_model, model_file)            
        # Remove previous backup (it may not exist, e.g. after a restart).
        if _iter > 0:
            _remove_if_present(os.path.join(results_dir, "./mlirl_iter_{:03d}_model.pkl".format(
                _iter-1)))

    # A new figure is created on every call and closed again at the end, so
    # figures do not pile up over the iterations (which also makes clearing
    # the previous figure's axes unnecessary).
    fig = plt.figure(figsize=(20, 18))
    try:
        plt.subplot(211)

        if dec_model is not None:
            with torch.no_grad():
                code = torch.FloatTensor(r_model[-1].w.weight.detach().numpy())
                out = dec_model(code).detach()
                plt.imshow(out[0, 0], cmap="gray")
                plt.title("Max likelihood image")

        plt.subplot(212)
        plt.plot(np.exp(-np.asarray(loss_history)))
        plt.title("Linear MLIRL (GeoLife Trajectories)")
        plt.xlabel("Iterations")
        plt.ylabel("Likelihood")
        
        if results_dir:
            plt.savefig(os.path.join(results_dir, "mlirl_iter_{:03d}_plot.png".format(_iter)))
            if _iter > 0:
                _remove_if_present(os.path.join(results_dir, "mlirl_iter_{:03d}_plot.png".format(_iter-1)))
                
        if render:     
            if redraw:
                display.clear_output(wait=True)
                display.display(fig)
            else:
                plt.show()
    finally:
        plt.close(fig)
        
def params_inspection_fn(_iter, r_model):
    """Print the ``[0, 0]`` slice of the first parameter of ``r_model`` and of its gradient.

    Only the first parameter yielded by ``r_model.parameters()`` is reported;
    nothing is printed when the model has no parameters.
    """
    first = next(iter(r_model.parameters()), None)
    if first is None:
        return

    fmt = "{:+012.7f}".format
    weights = ' '.join(fmt(v) for v in first[0, 0].flatten())
    grads = ' '.join(fmt(v) for v in first.grad[0, 0].flatten())
    report = ("\n\t\t\t Params @ Iter {}"
              "\n\t\t\t\t     w: [{}]"
              "\n\t\t\t\t    dw: [{}]")
    print(report.format(_iter, weights, grads))

# Deep MLIRL

In [None]:
# Checkpoints and progress plots written by iter_handler end up here.
RESULTS_DIR = "./__mlirl_results/"

r_model, loss_history, Pi_list, V_list, Q_list = MLIRL(
    # Problem definition (training split only)
    train_trajectories, train_S_list, phi, A, T, train_S_lat_to_lngs_list,
    # Reward model and its optimizer
    r_model=encoder,
    r_optimizer=r_optimizer,
    # Optimisation / planning settings
    n_iter=20,
    n_vi_iter=200,
    gamma=0.95,
    boltzmann_temp=1.,
    loss_eps=1e-3,
    max_goals=4,
    dtype=torch.float64,
    max_processes=1,
    # Diagnostics
    debug=True,
    verbose=True,
    perf_debug=True,
    # Per-iteration callback and the extra arguments it uses
    iter_handler=iter_handler,
    results_dir=RESULTS_DIR,
    dec_model=dec_model,
)

	Forking process #0..
		Running process: 0, traj len: 30, states: 422
