# Train agent recurrent with GPU

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install pyro-ppl &> /dev/null

In [3]:
import sys
cwd = "drive/Shareddrives/Active_Inference_Interaction/"
sys.path.append(cwd)

In [4]:
import os
import glob
import json
import pickle
import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch 
import torch.nn as nn

# setup imports
from src.data.train_utils import load_data, update_record_id
from src.data.train_utils import train_test_split, count_parameters
from src.data.data_filter import filter_segment_by_length
from src.data.ego_dataset import RelativeDataset, collate_fn

# model imports
from src.agents.rule_based import IDM
from src.agents.nn_agents import MLPAgent, RNNAgent
from src.agents.vin_agent import VINAgent
from src.agents.hyper_vin_agent import HyperVINAgent
from src.algo.bc import RecurrentBehaviorCloning
from src.algo.hyper_bc import HyperBehaviorCloning

# training imports
from src.algo.utils import train, SaveCallback

import warnings
warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

device: cuda


In [5]:
def main(arglist):
    torch.manual_seed(arglist["seed"])
    
    # load data files
    df_track = []
    for filename in arglist["filenames"]:
        df_track.append(load_data(arglist["data_path"], arglist["scenario"], filename))
    df_track = pd.concat(df_track, axis=0)
    df_track = update_record_id(df_track)
    
    # filter invalid lanes
    valid_lanes = [int(l) for l in arglist["valid_lanes"]]
    is_train = df_track["is_train"].values
    is_train[df_track["ego_lane_id"].isin(valid_lanes) == False] = np.nan
    df_track = df_track.assign(is_train=is_train)

    # filter episode length
    eps_id, df_track["eps_len"] = filter_segment_by_length(
        df_track["eps_id"].values, arglist["min_eps_len"]
    )
    df_track = df_track.assign(eps_id=eps_id.astype(float))    

    df_track = df_track.loc[(df_track["is_train"] == 1) & (df_track["eps_id"] != np.nan)]

    feature_set = arglist["feature_set"]
    action_set = arglist["action_set"]

    # compute obs and ctl mean and variance stats
    obs_mean = torch.from_numpy(df_track.loc[df_track["is_train"] == 1][feature_set].mean().values).to(torch.float32)
    obs_var = torch.from_numpy(df_track.loc[df_track["is_train"] == 1][feature_set].var().values).to(torch.float32)
    ctl_mean = torch.from_numpy(df_track.loc[df_track["is_train"] == 1][action_set].mean().values).to(torch.float32)
    ctl_var = torch.from_numpy(df_track.loc[df_track["is_train"] == 1][action_set].var().values).to(torch.float32)
    
    # init dataset
    dataset = RelativeDataset(
        df_track, feature_set, action_set, train_labels_col="is_train",
        max_eps_len=arglist["max_eps_len"], max_eps=10000, state_action=False,
        seed=arglist["seed"]
    )
    train_loader, test_loader = train_test_split(
        dataset, arglist["train_ratio"], arglist["batch_size"], 
        collate_fn=collate_fn, seed=arglist["seed"]
    )
    obs_dim, ctl_dim = len(feature_set), len(action_set)

    print(f"feature set: {feature_set}")
    print(f"action set: {action_set}")
    print(f"train size: {len(train_loader.dataset)}, test size: {len(test_loader.dataset)}")

    # init agent
    if arglist["agent"] == "vin":
        agent = VINAgent(
            arglist["state_dim"], arglist["act_dim"], obs_dim, ctl_dim, arglist["hmm_rank"],
            arglist["horizon"], alpha=arglist["alpha"], beta=arglist["beta"], obs_model=arglist["obs_model"],
            obs_cov=arglist["obs_cov"], ctl_cov=arglist["ctl_cov"], rwd=arglist["rwd"], detach=arglist["detach"],
        )
        agent.obs_model.init_batch_norm(obs_mean, obs_var)
        agent.ctl_model.init_batch_norm(ctl_mean, ctl_var)

    elif arglist["agent"] == "hvin":
        agent = HyperVINAgent(
            arglist["state_dim"], arglist["act_dim"], obs_dim, ctl_dim, arglist["hmm_rank"],
            arglist["horizon"], arglist["hyper_dim"], arglist["hidden_dim"], arglist["num_hidden"], 
            arglist["gru_layers"], arglist["activation"], alpha=arglist["alpha"], beta=arglist["beta"], 
            obs_model=arglist["obs_model"], obs_cov=arglist["obs_cov"], ctl_cov=arglist["ctl_cov"], rwd=arglist["rwd"],
            hyper_cov=arglist["hyper_cov"], train_prior=arglist["train_prior"]
        )
    
    elif arglist["agent"] == "rnn":
        agent = RNNAgent(
            obs_dim, ctl_dim, arglist["act_dim"], arglist["hidden_dim"],
            arglist["num_hidden"], arglist["gru_layers"], arglist["activation"]
        )

    elif arglist["agent"] == "mlp":
        agent = MLPAgent(
            obs_dim, ctl_dim, arglist["act_dim"], arglist["hidden_dim"], 
            arglist["num_hidden"], arglist["activation"]
        )

    elif arglist["agent"] == "idm":
        agent = IDM(feature_set)
    
    # preload stats
    if hasattr(agent, "obs_mean") and arglist["norm_obs"]:
        agent.obs_mean.data = obs_mean
        agent.obs_variance.data = obs_var

    if hasattr(agent, "ctl_model") and (arglist["action_set"] == ["dds"] or arglist["action_set"] == ["dds_smooth"]):
        # load ctl gmm parameters
        with open(os.path.join(arglist["exp_path"], "agents", "ctl_model", "model.p"), "rb") as f:
            [ctl_means, ctl_covs, weights] = pickle.load(f)

        agent.ctl_model.init_params(ctl_means, ctl_covs)
        print("action model loaded")

    # init trainer
    if arglist["agent"] == "hvin":
        model = HyperBehaviorCloning(
            agent, arglist["train_mode"], arglist["detach"], arglist["bptt_steps"], arglist["pred_steps"],
            arglist["bc_penalty"], arglist["obs_penalty"], arglist["pred_penalty"], arglist["reg_penalty"], 
            arglist["post_obs_penalty"], arglist["kl_penalty"], 
            lr=arglist["lr"], lr_flow=arglist["lr_flow"], lr_post=arglist["lr_post"],
            decay=arglist["decay"], grad_clip=arglist["grad_clip"], decay_steps=arglist["decay_steps"],
            decay_rate=arglist["decay_rate"],
        )
    
    else:
        model = RecurrentBehaviorCloning(
            agent, arglist["bptt_steps"], arglist["pred_steps"], arglist["bc_penalty"], 
            arglist["obs_penalty"], arglist["pred_penalty"], arglist["reg_penalty"], 
            lr=arglist["lr"], lr_flow=arglist["lr_flow"], decay=arglist["decay"], 
            grad_clip=arglist["grad_clip"], decay_steps=arglist["decay_steps"],
            decay_rate=arglist["decay_rate"],

        )

    model.to(device)    
    print(f"num parameters: {count_parameters(model)}")
    print(model)

    # load from check point
    cp_history = None
    if arglist["checkpoint_path"] != "none":
        cp_path = os.path.join(
            arglist["exp_path"], "agents", 
            arglist["agent"], arglist["checkpoint_path"]
        )

        # load state dict
        cp_model_path = glob.glob(os.path.join(cp_path, "model/*.pt"))
        cp_model_path.sort(key=lambda x: int(os.path.basename(x).replace(".pt", "").split("_")[-1]))
        
        state_dict = torch.load(cp_model_path[-1])
        model.load_state_dict(state_dict["model_state_dict"], strict=False)
        model.optimizer.load_state_dict(state_dict["optimizer_state_dict"])
        model.scheduler.load_state_dict(state_dict["scheduler_state_dict"])

        # load history
        cp_history = pd.read_csv(os.path.join(cp_path, "history.csv"))
        print(f"loaded checkpoint from {cp_path}")

    callback = None
    if arglist["save"]:
        callback = SaveCallback(arglist, model, cp_history)

    model, df_history = train(
        model, train_loader, test_loader, arglist["epochs"], callback=callback, verbose=1
    )
    if arglist["save"]:
        callback.save_checkpoint(model)
        callback.save_history(df_history)
    return model.cpu(), df_history

In [14]:
# train config
arglist = {
    "data_path": os.path.join(cwd, "interaction-dataset-master"),
    "exp_path": os.path.join(cwd, "exp"),
    "scenario": "DR_CHN_Merging_ZS",
    "filenames": ["vehicle_tracks_000.csv", "vehicle_tracks_003.csv", "vehicle_tracks_007.csv"],
    "valid_lanes": ["3", "4"],
    "checkpoint_path": "none",
    "feature_set": ["lv_s_rel", "lv_ds_rel", "lv_inv_tau"],
    "action_set": ["dds_smooth"],
    # agent args
    "agent": "vin",
    "state_dim": 20,
    "act_dim": 15,
    "horizon": 30,
    "obs_model": "flow",
    "obs_cov": "tied",
    "ctl_cov": "diag",
    "hmm_rank": 0,
    "alpha": 1., 
    "beta": 0., 
    "rwd": "efe",
    "detach": False,
    "hyper_dim": 4,    
    "hyper_cov": True, 
    "train_prior": False,
    # nn args
    "hidden_dim": 64,
    "num_hidden": 2,
    "gru_layers": 2, # use 2 gru layers for inference
    "activation": "relu",
    "norm_obs": True,
    # algo args
    "train_mode": "marginal",
    "bptt_steps": 500,
    "pred_steps": 5,
    "bc_penalty": 1.,
    "obs_penalty": 1.,
    "pred_penalty": 0.2,
    "reg_penalty": 0.1,
    "post_obs_penalty": 0.1,
    "kl_penalty": 1.,
    # training args
    "min_eps_len": 50,
    "max_eps_len": 500,
    "train_ratio": 0.7,
    "batch_size": 100,
    "epochs": 500,
    "lr": 0.01, # use 0.005 for nn models
    "lr_flow": 0.001,
    "lr_post": 0.005,
    "decay": 3e-5,
    "grad_clip": 20,
    "decay_steps": 100,
    "decay_rate": 0.9,
    "cp_every": 50,
    "seed": 0,
    "save": True
}

In [15]:
model, df_history = main(arglist)

feature set: ['lv_s_rel', 'lv_ds_rel', 'lv_inv_tau']
action set: ['dds_smooth']
train size: 615, test size: 263
action model loaded
num parameters: 7640
RecurrentBehaviorCloning(bptt_steps=500, pred_steps=5, bc_penalty=1.0, obs_penalty=1.0, pred_penalty=0.2, reg_penalty=0.1, lr=0.01, lr_flow=0.001, decay=3e-05, grad_clip=20, decay_steps=100, decay_rate=0.9,
agent=VINAgent(
  (rnn): QMDPLayer(state_dim=20, act_dim=15, rank=0, horizon=30, detach=False)
  (obs_model): ConditionalFlow(x_dim=3, z_dim=20, hidden_dim=30, cov=tied, batch_norm=True)
  (ctl_model): ConditionalGaussian(x_dim=1, z_dim=15, cov=diag, batch_norm=False)
))
e: 1/500, loss_u: 2.1402/2.1425, loss_o: 4.2226/4.0426, t: 7.24
e: 2/500, loss_u: 2.1342/2.1346, loss_o: 4.0650/3.8912, t: 14.68
e: 3/500, loss_u: 2.1256/2.1232, loss_o: 3.9288/3.7480, t: 22.01
e: 4/500, loss_u: 2.1128/2.1072, loss_o: 3.7893/3.6078, t: 29.36
e: 5/500, loss_u: 2.0951/2.0858, loss_o: 3.6349/3.4638, t: 36.68
e: 6/500, loss_u: 2.0726/2.0589, loss_o: 3.4

KeyboardInterrupt: ignored