# Rectify map-matching errors

In [None]:
import os
import copy
import dill
import pickle
from tqdm import tqdm, trange

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms

import osmnx as ox
import geopandas as gpd
from shapely.geometry import Point, LineString

from rectify_shortestpath import rectify_shortest_path_avg_dist

In [None]:
# Root directory of the project; replace the placeholder with a real path before running.
# The road graph file and the "Data" sub-directory are resolved relative to it.
proj_dir = "<YOUR_PROJECT_DIRECTORY>"

# Plot theme: seaborn defaults first, then switch to the "paper" plotting context
sns.set_theme()
sns.set_context("paper")

## Load datasets

In [None]:
# Load the road network graph; swap the active `graph_name` line to change the study area
graph_name = "aa_road_graph_drive_service_bbox_time_speed_bearing.graphml"  # Ann Arbor
# graph_name = "la_road_graph_drive_service_bbox_time_speed_bearing.graphml"  # Los Angeles
roadnet_graph = ox.load_graphml(os.path.join(proj_dir, graph_name))
# Report the CRS of the graph that was just loaded (the previous code referenced an
# undefined name here and raised NameError)
print("Road graph loaded! CRS:", roadnet_graph.graph["crs"])

Project the graph to UTM. Note that the UTM zone must match the study area (zone 17N is used below for the Ann Arbor graph).

In [None]:
# Project the loaded road graph to WGS 84 / UTM zone 17N (EPSG:32617).
# NOTE(review): the zone is hard-coded to match the Ann Arbor graph; if the Los Angeles
# graph is loaded instead, the target zone presumably has to change — confirm.
aa_graph_utm17n = ox.project_graph(roadnet_graph, to_crs="EPSG:32617")  # WGS 84 / UTM zone 17N
# Show the CRS of the projected graph as the cell output
aa_graph_utm17n.graph["crs"]

In [None]:
# Restore the list of sampled points from its pickle under <proj_dir>/Data
pk_name = "samp_pts_ls_stride2_labeled184.pkl"
with open(
    os.path.join(proj_dir, "Data", pk_name), mode="rb"
) as my_file_obj:
    samp_pts_ls = pickle.load(my_file_obj)

# Sanity check: how many samples were read
print("Number of samples:", len(samp_pts_ls))

In [None]:
# Restore the points table (pts_gdf) and the per-trajectory point-id list, which are
# stored together in a single pickle under <proj_dir>/Data
pk_name = "labeled184_pts_traj_ptid.pkl"
with open(
    os.path.join(proj_dir, "Data", pk_name), mode="rb"
) as my_file_obj:
    pts_gdf, traj_ptid_ls = pickle.load(my_file_obj)

# Report the sizes, then preview the first rows as the cell output
print("Number of trajectories:", len(traj_ptid_ls))
print("Number of points:", len(pts_gdf))
pts_gdf.head()

In [None]:
# Restore the predicted probabilities and predicted labels, stored as one pair in a
# single pickle under <proj_dir>/Data
pk_name = "probs_pred_labels_ls_allsampleset_stride2_labeled184_aa.pkl"
with open(os.path.join(proj_dir, "Data", pk_name), mode="rb") as my_file_obj:
    probs_ls, pred_labels_ls = pickle.load(my_file_obj)

# Sanity check: one label entry per sample
print("Number of samples:", len(pred_labels_ls))

## Error rectification

In [None]:
# Rectify every sample in turn (with a progress bar) and collect the per-sample results.
# NOTE(review): err_seg_ls, samp_trajs_gdf and subgraph_ls are not defined above this cell
# in the visible notebook; samp_trajs_gdf / subgraph_ls are loaded in the
# "Sample subgraph" section further down — make sure they exist before running this.
rect_traj_gdf_ls = [
    rectify_shortest_path_avg_dist(sample_id, err_seg_ls, samp_trajs_gdf, subgraph_ls)
    for sample_id in trange(len(err_seg_ls))
]

# Stack the per-sample frames into one GeoDataFrame that carries the CRS of the first result
rect_trajs_gdf = gpd.GeoDataFrame(pd.concat(rect_traj_gdf_ls), crs=rect_traj_gdf_ls[0].crs)
rect_trajs_gdf.head()

In [None]:
# Persist the rectified trajectories under <proj_dir>/Data so they can be reloaded
# without re-running the rectification
pk_name = "rect_trajs_gdf_stride2_labeled184_aa.pkl"  # plain string: no f-string fields needed
with open(
    os.path.join(proj_dir, "Data", pk_name), 'wb'
) as my_file_obj:
    pickle.dump(rect_trajs_gdf, my_file_obj)

In [None]:
# Reload the rectified trajectories previously pickled under <proj_dir>/Data
pk_name = "rect_trajs_gdf_stride2_labeled184_aa.pkl"  # plain string: no f-string fields needed
with open(
    os.path.join(proj_dir, "Data", pk_name), 'rb'
) as my_file_obj:
    rect_trajs_gdf = pickle.load(my_file_obj)

# Sanity check: number of rows restored
print("Number of rectified trajectories:", len(rect_trajs_gdf))

## Visualization

### Sample subgraph

In [None]:
from utils import sample_subgraph

In [None]:
# Build the road subgraphs for the sample trajectories from the UTM-projected graph
# (presumably one subgraph per sample, since they are later indexed by sample id — confirm).
# NOTE(review): buffer_dist=500 is presumably in the units of the projected CRS (metres) —
# confirm against utils.sample_subgraph.
# NOTE(review): samp_trajs_gdf is not created above this cell in the visible notebook; it is
# only restored from the .dpk file further down, so it must already exist here — verify.
subgraph_ls = sample_subgraph(aa_graph_utm17n, samp_trajs_gdf, buffer_dist=500)

In [None]:
# Persist the sample trajectories together with their subgraphs (serialized with dill)
# under <proj_dir>/Data
pk_name = "samp_trajs_gdf_subgraph_ls_stride2_labeled184_aa.dpk"  # plain string: no f-string fields needed
with open(
    os.path.join(proj_dir, "Data", pk_name), 'wb'
) as my_file_obj:
    dill.dump([samp_trajs_gdf, subgraph_ls], my_file_obj)

In [None]:
# Reload the sample trajectories and their subgraphs, stored as one pair in a dill file
# under <proj_dir>/Data
pk_name = "samp_trajs_gdf_subgraph_ls_stride2_labeled184_aa.dpk"  # plain string: no f-string fields needed
with open(
    os.path.join(proj_dir, "Data", pk_name), 'rb'
) as my_file_obj:
    samp_trajs_gdf, subgraph_ls = dill.load(my_file_obj)

# Sanity check: both collections are reported so their sizes can be compared
print("Number of sample trajectories:", len(samp_trajs_gdf))
print("Number of sample subgraphs:", len(subgraph_ls))

### Plot rectified trajectory on road network graph

In [None]:
def plot_rect_traj(
        sample_id, samp_trajs_gdf, pred_labels_ls, rect_trajs_gdf, fig=None, ax=None, figsize=None,
        s=3, show_legend=False, legend_loc="best", save_figure=False, save_dir=None, fig_name="Sample.png"
):
    """Plot one sample's original and rectified trajectory on its road subgraph.

    The points of the original trajectory are split by the predicted labels into
    correct and erroneous points; the rectified positions of the erroneous points
    are drawn as a third group. Both trajectories are drawn as lines, followed by
    the edges of the sample's road subgraph.

    Relies on two module-level names: ``subgraph_ls`` (indexed by ``sample_id`` for
    the road graph) and ``proj_dir`` (root of the output path when saving).

    Parameters
    ----------
    sample_id : int
        Positional index of the sample in ``samp_trajs_gdf``, ``rect_trajs_gdf``,
        ``pred_labels_ls`` and ``subgraph_ls``.
    samp_trajs_gdf : GeoDataFrame
        Original map-matched trajectories, one row per sample.
    pred_labels_ls : sequence of arrays
        Per-sample label arrays; entries that are truthy after ``astype(bool)``
        mark erroneous points. The array must have one entry per coordinate of
        the sample's trajectory because it is used as a boolean row mask.
    rect_trajs_gdf : GeoDataFrame
        Rectified trajectories, one row per sample.
    fig, ax : matplotlib Figure / Axes, optional
        Existing axes to draw on. If ``ax`` is None a new figure and axes are
        created; if only ``ax`` is given, its parent figure is used.
    figsize : tuple, optional
        Size of a newly created figure; None uses matplotlib's default.
    s : float
        Marker size for the scatter points.
    show_legend : bool
        Whether to draw a legend at ``legend_loc``.
    legend_loc : str
        Legend location passed to ``ax.legend``.
    save_figure : bool
        Whether to save the figure to ``proj_dir/save_dir/fig_name``.
    save_dir : str, optional
        Directory relative to ``proj_dir``; None saves directly into ``proj_dir``.
    fig_name : str
        File name of the saved figure.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    if ax is None:
        # plt.subplots(figsize=None) falls back to matplotlib's default size
        fig, ax = plt.subplots(figsize=figsize)
    elif fig is None:
        # Honor a caller-supplied axes instead of discarding it for a new figure
        fig = ax.figure

    # Predicted labels of the sample, as a boolean mask (True = erroneous point)
    err_labels_arr = pred_labels_ls[sample_id].astype(bool)

    # Coordinates of the original map-matched trajectory
    samp_traj_gdf = samp_trajs_gdf.take([sample_id])
    coords_df = samp_traj_gdf.get_coordinates()
    # Error points
    x_error = coords_df.loc[err_labels_arr, "x"]
    y_error = coords_df.loc[err_labels_arr, "y"]
    # Correct points
    x_correct = coords_df.loc[~err_labels_arr, "x"]
    y_correct = coords_df.loc[~err_labels_arr, "y"]

    # Rectified coordinates; only the positions flagged as errors are highlighted
    rect_traj_gdf = rect_trajs_gdf.take([sample_id])
    rect_coords_df = rect_traj_gdf.get_coordinates()
    x_rect = rect_coords_df.loc[err_labels_arr, "x"]
    y_rect = rect_coords_df.loc[err_labels_arr, "y"]

    # Plot the points with a high zorder so they sit on top of the lines
    ax.scatter(x_correct, y_correct, s=s, color='lime', marker='o', zorder=5, label="Correct point")
    ax.scatter(x_error, y_error, s=s, color='red', marker='o', zorder=5, label="Erroneous point")
    ax.scatter(x_rect, y_rect, s=s, color='dodgerblue', marker='o', zorder=5, label="Rectified point")

    # Original map-matched trajectory
    samp_traj_gdf.plot(ax=ax, linewidth=1, color="navy", linestyle='dashed', label="Error-containing trajectory")
    # Rectified trajectory
    rect_traj_gdf.plot(ax=ax, linewidth=2, color="hotpink", label="Rectified trajectory")

    # Road network graph (node_size=0 skips plotting the nodes). show=False keeps osmnx
    # from displaying the figure here, before the legend is added and the file is saved.
    ox.plot_graph(subgraph_ls[sample_id], ax=ax, node_size=0, show=False, close=False)

    if show_legend:
        ax.legend(loc=legend_loc)

    if save_figure:
        # save_dir=None would make os.path.join raise TypeError; save into proj_dir instead
        out_dir = proj_dir if save_dir is None else os.path.join(proj_dir, save_dir)
        fig.savefig(os.path.join(out_dir, fig_name), dpi=800, bbox_inches='tight')

    plt.show()

In [None]:
# Plot one example sample with a legend and save the figure.
# Replace the save_dir placeholder with a directory relative to proj_dir before running.
sample_id = 8000
plot_rect_traj(
    sample_id, samp_trajs_gdf, pred_labels_ls, rect_trajs_gdf, figsize=(10, 11),
    show_legend=True, legend_loc="lower right", save_figure=True,
    save_dir="<YOUR_SAVE_DIRECTORY>", fig_name=f"Rectify_Sample{sample_id}.png"
)