# D4RL data parsing
* Compare the outputs of three parsers: the d4rl built-in data parsing method, a modification of it that uses the recorded next observations, and a custom fast parsing method.
* The main finding is that the d4rl built-in parsing method incorrectly takes the initial obs of the next episode as the final next_obs of the current episode (not yet verified whether this holds for all episodes or only for early-terminated episodes).

In [1]:
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [57]:
# load data
# Exactly one dataset path is active at a time; uncomment a different line to
# inspect another environment.
# filepath = "../data/d4rl/halfcheetah-medium-expert-v2.p"
filepath = "../data/d4rl/hopper-medium-expert-v2.p"
# filepath = "../data/d4rl/walker2d-medium-expert-v2.p"
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# open dataset files that come from a trusted source.
with open(filepath, "rb") as f:
    dataset = pickle.load(f)

In [58]:
# get first episode
# A step closes an episode when it is flagged as a terminal or as a timeout.
episode_end = (dataset["terminals"] == 1) | (dataset["timeouts"] == 1)
first_terminal = np.flatnonzero(episode_end)[0]

# The flagged step is itself the last transition of the episode, so the slice
# must include it (hence the +1).
first_episode = slice(0, first_terminal + 1)
obs_eps = dataset["observations"][first_episode]
act_eps = dataset["actions"][first_episode]
rwd_eps = dataset["rewards"][first_episode]
next_obs_eps = dataset["next_observations"][first_episode]

# Within one episode, obs[t + 1] should be exactly the recorded next_obs[t].
print(
    "obs equivalent to next_obs by shift:",
    np.array_equal(obs_eps[1:], next_obs_eps[:-1])
)

obs equivalent to next_obs by shift: True


In [59]:
def parse_data_d4rl(dataset, use_next_obs=False, terminate_on_end=False, max_episode_steps=1000):
    """Parse a dataset into per-transition arrays the d4rl way.

    Args:
        dataset (dict): Dict of numpy arrays with keys ``observations``,
            ``actions``, ``rewards`` and ``terminals``. ``timeouts`` is
            optional; ``next_observations`` is only read when
            ``use_next_obs`` is True.
        use_next_obs (bool): If True, take the next observation of step ``i``
            from ``next_observations[i]``. If False, use
            ``observations[i + 1]``, which on the last kept step of an
            episode is the first observation of the following episode.
        terminate_on_end (bool): If False, the final (timeout) step of each
            episode is dropped from the output.
        max_episode_steps (int): Episode horizon used to detect the final
            step when the dataset has no ``timeouts`` field.

    Returns:
        dict: ``obs``, ``act`` and ``next_obs`` as float32 arrays, ``rwd`` as
        a float32 column, ``done`` (terminal flags) and ``eps_id`` (running
        episode index) as columns. All entries have one row per kept step.
    """
    observations = dataset['observations']
    terminals = dataset['terminals']
    num_steps = dataset['rewards'].shape[0]

    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    timeouts = dataset['timeouts'] if 'timeouts' in dataset else None

    keep_idx = []
    done_ = []
    eps_id_ = []
    eps_id = 0
    episode_step = 0
    # The very last recorded step is never emitted: with use_next_obs=False
    # there is no observations[i + 1] to pair it with.
    for i in range(num_steps - 1):
        done_bool = bool(terminals[i])

        if timeouts is not None:
            final_timestep = bool(timeouts[i])
        else:
            final_timestep = episode_step == max_episode_steps - 1

        if final_timestep and not terminate_on_end:
            # Skip this transition and don't apply terminals on the last step
            # of an episode. The episode still ends here, so the following
            # step must get a fresh episode id.
            episode_step = 0
            eps_id += 1
            continue

        keep_idx.append(i)
        done_.append(done_bool)
        eps_id_.append(eps_id)
        episode_step += 1

        if done_bool or final_timestep:
            episode_step = 0
            eps_id += 1

    # Gather all kept rows in one fancy-index per field instead of stacking
    # one small array per step.
    keep = np.asarray(keep_idx, dtype=np.intp)
    if use_next_obs:
        next_obs = dataset['next_observations'][keep]
    else:
        next_obs = observations[keep + 1]

    return {
        "obs": observations[keep].astype(np.float32),
        "act": dataset['actions'][keep].astype(np.float32),
        "rwd": dataset['rewards'][keep].astype(np.float32).reshape(-1, 1),
        "next_obs": next_obs.astype(np.float32),
        "done": np.array(done_).reshape(-1, 1),
        "eps_id": np.array(eps_id_).reshape(-1, 1),
    }

In [5]:
def parse_data_fast(dataset, skip_timeout=True):
    """Vectorized parser meant to mimic ``parse_data_d4rl(use_next_obs=True)``.

    Args:
        dataset (dict): Dict of numpy arrays with keys ``observations``,
            ``actions``, ``rewards``, ``next_observations`` and ``terminals``;
            ``timeouts`` is read only when ``skip_timeout`` is True.
        skip_timeout (bool): If True, drop every step flagged as a timeout
            (following d4rl qlearning_dataset).

    Returns:
        dict: ``obs``, ``act``, ``next_obs`` as stored in the dataset, plus
        ``rwd`` and ``done`` (terminal flags) reshaped to columns.
    """
    fields = {
        "obs": dataset["observations"],
        "act": dataset["actions"],
        "rwd": dataset["rewards"].reshape(-1, 1),
        "next_obs": dataset["next_observations"],
        "done": dataset["terminals"].reshape(-1, 1),
    }
    if not skip_timeout:
        return fields

    # One boolean row mask, shared by every field, keeps the non-timeout steps.
    not_timeout = np.logical_not(dataset["timeouts"])
    return {name: values[not_timeout] for name, values in fields.items()}

In [60]:
# Parse the same dataset twice with the d4rl-style parser: once taking next_obs
# from observations[i+1] (use_next_obs=False) and once from the recorded
# next_observations (use_next_obs=True).
parsed_dataset_1 = parse_data_d4rl(dataset, use_next_obs=False)
parsed_dataset_2 = parse_data_d4rl(dataset, use_next_obs=True)

In [61]:
# Vectorized parser; skip_timeout=True drops every step flagged as a timeout.
parsed_dataset_3 = parse_data_fast(dataset, skip_timeout=True)

In [62]:
# d4rl parsing does not handle next_obs properly
print("parsed_dataset_1 shape", parsed_dataset_1["obs"].shape)
print("parsed_dataset_2 shape", parsed_dataset_2["obs"].shape)

# np.array_equal is a strict element-wise comparison. Counting unique values
# of the difference would also report "equal" for arrays that differ by a
# constant offset, or (for the boolean check) when every element differs.
print("obs equal",
    np.array_equal(parsed_dataset_1["obs"], parsed_dataset_2["obs"])
)
print("next obs equal",
    np.array_equal(parsed_dataset_1["next_obs"], parsed_dataset_2["next_obs"])
)
print("done equal",
    np.array_equal(parsed_dataset_1["done"], parsed_dataset_2["done"])
)

parsed_dataset_1 shape (1998966, 11)
parsed_dataset_2 shape (1998966, 11)
obs equal True
next obs equal False
done equal True


In [44]:
# fast parsing is equal to parsing with use_next_obs=True
print("parsed_dataset_3 shape", parsed_dataset_3["obs"].shape)

# parsed_dataset_3 is compared without its last row ([:-1]) to line it up with
# parsed_dataset_2. np.array_equal is a strict element-wise comparison;
# counting unique values of the difference would also accept a constant
# offset, or (for the boolean check) arrays where every element differs.
print("obs equal",
    np.array_equal(parsed_dataset_2["obs"], parsed_dataset_3["obs"][:-1])
)
print("next obs equal",
    np.array_equal(parsed_dataset_2["next_obs"], parsed_dataset_3["next_obs"][:-1])
)
print("done equal",
    np.array_equal(parsed_dataset_2["done"], parsed_dataset_3["done"][:-1])
)

parsed_dataset_3 shape (1998319, 17)
obs equal True
next obs equal True
done equal True


In [46]:
# check why next obs is not equal
# Flag every row where the two parsers disagree in at least one next_obs dimension.
next_obs_gap = parsed_dataset_1["next_obs"] - parsed_dataset_2["next_obs"]
next_obs_not_equal = (next_obs_gap != 0).any(axis=1)

In [47]:
# compare two parsing methods around time step where they are different
t_not_equal = np.where(next_obs_not_equal)[0][0]

# Show the 5 rows before the mismatch through the 4 rows after it. The start
# is clamped at 0 because a negative slice start would wrap around to the end
# of the array and produce an empty window.
window = slice(max(t_not_equal - 5, 0), t_not_equal + 5)

# Columns: eps_id | first 4 obs dims | first 4 next_obs dims | done | mismatch flag
obs_window_1, obs_window_2 = (
    np.hstack([
        parsed["eps_id"][window].reshape(-1, 1),
        parsed["obs"][window, :4],
        parsed["next_obs"][window, :4],
        parsed["done"][window],
        next_obs_not_equal[window].reshape(-1, 1),
    ]).round(2)
    for parsed in (parsed_dataset_1, parsed_dataset_2)
)

print("use_next_obs=False")
print(obs_window_1)
print()
print("use_next_obs=True")
print(obs_window_2)

# big jump in final step if not use_next_obs (d4rl implementation is wrong?)

use_next_obs=False
[[ 0.    0.89  0.59 -0.05 -0.77  0.88  0.6  -0.04 -0.74  0.    0.  ]
 [ 0.    0.88  0.6  -0.04 -0.74  0.86  0.6  -0.05 -0.69  0.    0.  ]
 [ 0.    0.86  0.6  -0.05 -0.69  0.84  0.62 -0.02 -0.68  0.    0.  ]
 [ 0.    0.84  0.62 -0.02 -0.68  0.82  0.62 -0.01 -0.73  0.    0.  ]
 [ 0.    0.82  0.62 -0.01 -0.73  0.8   0.63  0.02 -0.81  0.    0.  ]
 [ 0.    0.8   0.63  0.02 -0.81  1.25 -0.    0.   -0.    1.    1.  ]
 [ 1.    1.25 -0.    0.   -0.    1.25 -0.03 -0.03  0.    0.    0.  ]
 [ 1.    1.25 -0.03 -0.03  0.    1.24 -0.08 -0.09  0.01  0.    0.  ]
 [ 1.    1.24 -0.08 -0.09  0.01  1.24 -0.12 -0.12 -0.03  0.    0.  ]
 [ 1.    1.24 -0.12 -0.12 -0.03  1.23 -0.15 -0.13 -0.05  0.    0.  ]]

use_next_obs=True
[[ 0.    0.89  0.59 -0.05 -0.77  0.88  0.6  -0.04 -0.74  0.    0.  ]
 [ 0.    0.88  0.6  -0.04 -0.74  0.86  0.6  -0.05 -0.69  0.    0.  ]
 [ 0.    0.86  0.6  -0.05 -0.69  0.84  0.62 -0.02 -0.68  0.    0.  ]
 [ 0.    0.84  0.62 -0.02 -0.68  0.82  0.62 -0.01 -0.73  0.    0

In [48]:
# compare two parsing methods around the second time step where they are different
t_not_equal = np.where(next_obs_not_equal)[0][1]

# Show the 5 rows before the mismatch through the 4 rows after it. The start
# is clamped at 0 because a negative slice start would wrap around to the end
# of the array and produce an empty window.
window = slice(max(t_not_equal - 5, 0), t_not_equal + 5)

# Columns: eps_id | first 4 obs dims | first 4 next_obs dims | done | mismatch flag
obs_window_1, obs_window_2 = (
    np.hstack([
        parsed["eps_id"][window].reshape(-1, 1),
        parsed["obs"][window, :4],
        parsed["next_obs"][window, :4],
        parsed["done"][window],
        next_obs_not_equal[window].reshape(-1, 1),
    ]).round(2)
    for parsed in (parsed_dataset_1, parsed_dataset_2)
)

print("use_next_obs=False")
print(obs_window_1)
print()
print("use_next_obs=True")
print(obs_window_2)

# big jump in final step if not use_next_obs (d4rl implementation is wrong?)

use_next_obs=False
[[ 1.    0.89  0.8  -0.1  -0.7   0.87  0.79 -0.11 -0.76  0.    0.  ]
 [ 1.    0.87  0.79 -0.11 -0.76  0.86  0.78 -0.11 -0.84  0.    0.  ]
 [ 1.    0.86  0.78 -0.11 -0.84  0.84  0.78 -0.09 -0.93  0.    0.  ]
 [ 1.    0.84  0.78 -0.09 -0.93  0.83  0.78 -0.07 -1.03  0.    0.  ]
 [ 1.    0.83  0.78 -0.07 -1.03  0.81  0.75 -0.09 -1.12  0.    0.  ]
 [ 1.    0.81  0.75 -0.09 -1.12  1.25  0.    0.    0.    1.    1.  ]
 [ 2.    1.25  0.    0.    0.    1.25 -0.03 -0.04  0.    0.    0.  ]
 [ 2.    1.25 -0.03 -0.04  0.    1.25 -0.08 -0.09  0.01  0.    0.  ]
 [ 2.    1.25 -0.08 -0.09  0.01  1.24 -0.1  -0.11 -0.02  0.    0.  ]
 [ 2.    1.24 -0.1  -0.11 -0.02  1.23 -0.12 -0.1  -0.08  0.    0.  ]]

use_next_obs=True
[[ 1.    0.89  0.8  -0.1  -0.7   0.87  0.79 -0.11 -0.76  0.    0.  ]
 [ 1.    0.87  0.79 -0.11 -0.76  0.86  0.78 -0.11 -0.84  0.    0.  ]
 [ 1.    0.86  0.78 -0.11 -0.84  0.84  0.78 -0.09 -0.93  0.    0.  ]
 [ 1.    0.84  0.78 -0.09 -0.93  0.83  0.78 -0.07 -1.03  0.    0

In [49]:
# check between step observation difference
# Step-to-step change of each observation dimension, in units of that
# dimension's standard deviation. Both parses share the statistics of
# parsed_dataset_1 so their numbers are directly comparable.
obs_mean = parsed_dataset_1["obs"].mean(0)  # kept for later inspection; it cancels in the difference below
obs_std = parsed_dataset_1["obs"].std(0)

# (next_obs - mean) / std - (obs - mean) / std simplifies to (next_obs - obs) / std.
obs_diff_1 = np.abs(parsed_dataset_1["next_obs"] - parsed_dataset_1["obs"]) / obs_std
obs_diff_2 = np.abs(parsed_dataset_2["next_obs"] - parsed_dataset_2["obs"]) / obs_std

# Summarize each array once and pick rows by label rather than by position.
obs_diff_1_summary = pd.DataFrame(obs_diff_1).describe()
obs_diff_1_75_percent = obs_diff_1_summary.loc["75%"].values
obs_diff_1_max = obs_diff_1_summary.loc["max"].values

obs_diff_2_summary = pd.DataFrame(obs_diff_2).describe()
obs_diff_2_75_percent = obs_diff_2_summary.loc["75%"].values
obs_diff_2_max = obs_diff_2_summary.loc["max"].values

In [50]:
# Row 0 of each printed table is use_next_obs=False, row 1 is use_next_obs=True.
for heading, stat_1, stat_2 in (
    ("compare 75 percentile", obs_diff_1_75_percent, obs_diff_2_75_percent),
    ("compare max", obs_diff_1_max, obs_diff_2_max),
):
    print(heading)
    print(np.stack([stat_1, stat_2]).round(3))

# d4rl parsing method incurs significant error

compare 75 percentile
[[0.073 0.041 0.095 0.104 0.062 0.018 0.096 0.104 0.126 0.191 0.401 0.431
  0.582 0.224 0.248 0.446 0.317]
 [0.073 0.041 0.095 0.104 0.062 0.018 0.096 0.104 0.126 0.191 0.401 0.431
  0.582 0.224 0.248 0.446 0.316]]
compare max
[[ 4.565  3.858  7.849 11.142  2.052 11.057  7.137  2.266  5.334  6.22
   9.789  9.181  6.103  4.196 17.067  4.046  3.395]
 [ 0.391  1.06   2.325  2.173  1.002  2.431  1.081  1.072  1.691  2.676
   9.789  9.181  6.103  4.196 17.067  4.046  3.395]]
