In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

# Import make_axes_locatable
from mpl_toolkits.axes_grid1 import make_axes_locatable

def _as_numpy(values):
    """Return ``values`` as a NumPy array, moving torch tensors to host memory first."""
    if isinstance(values, torch.Tensor):
        # A bare .numpy() raises for CUDA tensors and for tensors that require grad.
        return values.detach().cpu().numpy()
    return np.asarray(values)


def _count_edge_usage(trajectories, num_edges):
    """Count how often each edge index occurs in the 'edge_idxs' of all trajectories."""
    usage = np.zeros(num_edges, dtype=int)
    index_arrays = [_as_numpy(traj['edge_idxs']).reshape(-1) for traj in trajectories]
    # Drop empty trajectories: an empty list becomes a float array and would
    # poison the dtype of the concatenated index array.
    index_arrays = [arr.astype(np.int64, copy=False) for arr in index_arrays if arr.size > 0]
    if not index_arrays:
        return usage
    used_idx, occurrences = np.unique(np.concatenate(index_arrays), return_counts=True)
    usage[used_idx] = occurrences
    return usage


def plot_graphs_with_heatmap(edge_coordinates_list, trajectories_list, save_path='graphs_grid.png', titles=None):
    """
    Plot one panel per graph, colouring every edge by how often trajectories use it.

    Edges that never occur in any trajectory are drawn thin and light grey. Used
    edges are coloured with the ``coolwarm`` colormap and drawn thicker the more
    often they occur. Usage is normalised per graph by that graph's most-used
    edge, so the shared colorbar runs from 0 to 1 for every panel.

    The plotting style and rcParams are applied only while the figure is built;
    matplotlib's global settings are left untouched afterwards.

    Parameters:
    - edge_coordinates_list: List of tensors/arrays of edge coordinates, one per graph.
                             Each element has shape [num_edges, 2, 2]
                             (edge, endpoint, (x, y)).
    - trajectories_list: List of trajectory collections, one per graph. Each trajectory
                         is a dictionary whose 'edge_idxs' entry is a tensor, list or
                         array of edge indices.
    - save_path: File path to save the figure.
    - titles: Optional sequence of panel titles. Defaults to
              ["T-Drive", "Geolife", "MoLe", "pNEUMA"]; panels without a
              title are labelled "Graph <n>".

    Returns:
    - percentages_used_edges: List with the percentage of used edges for each graph
                              (0.0 for a graph without edges).

    Raises:
    - ValueError: If no graph is given, or if there are fewer trajectory
                  collections than graphs.
    """
    # Imported locally so this cell does not depend on edits to the import cell.
    from matplotlib.collections import LineCollection
    from matplotlib.colors import to_rgba

    num_graphs = len(edge_coordinates_list)
    if num_graphs == 0:
        raise ValueError("edge_coordinates_list must contain at least one graph")
    if len(trajectories_list) < num_graphs:
        raise ValueError(
            f"got {num_graphs} graphs but only {len(trajectories_list)} trajectory collections"
        )

    if titles is None:
        titles = ["T-Drive", "Geolife", "MoLe", "pNEUMA"]
    titles = list(titles)
    panel_titles = [
        titles[i] if i < len(titles) else f"Graph {i + 1}"
        for i in range(num_graphs)
    ]

    # Two rows whenever possible (2x2 for the usual four graphs); round the
    # column count up so an odd number of graphs does not lose its last panel.
    n_rows = min(2, num_graphs)
    n_cols = (num_graphs + n_rows - 1) // n_rows

    rc_overrides = {
        'axes.labelsize': 14,
        'axes.titlesize': 18,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'axes.grid': False,
        'grid.alpha': 0.4,
        'lines.linewidth': 1,
    }

    # Shared colour mapping for all panels and for the colorbar.
    cmap = plt.cm.coolwarm
    norm = plt.Normalize(vmin=0, vmax=1)
    unused_rgba = to_rgba('lightgrey')
    lw_unused, lw_min, lw_max = 0.5, 1.0, 4.0

    percentages_used_edges = []

    # Context managers restore the previous style/rcParams on exit, so later
    # plots in the same kernel are not affected by this function.
    with plt.style.context('seaborn-v0_8-paper'), plt.rc_context(rc_overrides):
        # squeeze=False always yields a 2-D array of axes, even for a single panel.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 10), squeeze=False)
        axes = axes.ravel()

        # Tight panel spacing; leave room on the right for the colorbar.
        fig.subplots_adjust(wspace=0.05, hspace=0.1, right=0.85)

        for ax, edge_coordinates, trajectories, title in zip(
            axes, edge_coordinates_list, trajectories_list, panel_titles
        ):
            coords = _as_numpy(edge_coordinates)
            num_edges = coords.shape[0]

            usage = _count_edge_usage(trajectories, num_edges)
            is_used = usage > 0

            num_used_edges = np.count_nonzero(usage)
            percentages_used_edges.append(
                (num_used_edges / num_edges) * 100 if num_edges else 0.0
            )

            # Normalise to [0, 1]; the max(…, 1) guard avoids dividing by zero
            # when no edge of this graph is used at all.
            max_count = usage.max() if num_edges else 0
            usage_norm = usage / max(max_count, 1)

            # Per-edge RGBA colours and line widths.
            edge_colors = np.tile(unused_rgba, (num_edges, 1))
            edge_colors[is_used] = cmap(norm(usage_norm[is_used]))
            edge_widths = np.where(
                is_used, lw_min + usage_norm * (lw_max - lw_min), lw_unused
            )

            # Remove axes ticks and frame, keep geometry undistorted.
            ax.set_xticks([])
            ax.set_yticks([])
            ax.axis('off')
            ax.set_aspect('equal')

            if num_edges:
                # One collection per graph instead of one Line2D per edge:
                # same picture, far fewer artists for large road networks.
                ax.add_collection(
                    LineCollection(
                        coords[:, :, :2],
                        colors=edge_colors,
                        linewidths=edge_widths,
                        zorder=1,
                    )
                )

                # Collections do not autoscale the view; set limits with a 5% margin.
                xs, ys = coords[:, :, 0], coords[:, :, 1]
                xmin, xmax = float(xs.min()), float(xs.max())
                ymin, ymax = float(ys.min()), float(ys.max())
                x_margin = (xmax - xmin) * 0.05
                y_margin = (ymax - ymin) * 0.05
                ax.set_xlim(xmin - x_margin, xmax + x_margin)
                ax.set_ylim(ymin - y_margin, ymax + y_margin)

            ax.set_title(title, fontsize=18, pad=5)

        # Hide grid cells that have no graph (odd number of graphs).
        for spare_ax in axes[num_graphs:]:
            spare_ax.set_visible(False)

        # Add a single colorbar for all subplots.
        cbar_ax = fig.add_axes([0.88, 0.15, 0.02, 0.7])
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, cax=cbar_ax)
        cbar.ax.tick_params(labelsize=14)
        cbar.ax.set_ylabel('Usage Frequency', fontsize=16, rotation=270, labelpad=15)

        # Save and show the figure while the style is still active.
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    return percentages_used_edges


In [None]:
import sys

# Root of the project checkout: the `dataset` package is imported from here and
# the .h5 files are read from its `data/` subdirectory.
PROJECT_ROOT = '/ceph/hdd/students/schmitj/MA_Diffusion_based_trajectory_prediction'
DATA_DIR = f'{PROJECT_ROOT}/data'

# Re-running this cell must not keep appending the same entry to sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from dataset.trajectory_dataset_geometric import TrajectoryGeoDataset


def _load_split(dataset_name, split):
    """Load ``<DATA_DIR>/<dataset_name>_<split>.h5`` on the CPU via ``TrajectoryGeoDataset.load_new_format``."""
    return TrajectoryGeoDataset.load_new_format(f'{DATA_DIR}/{dataset_name}_{split}.h5', 'cpu')


# Each call is unpacked as (paths, nodes, edges, edge_coordinates). The
# nodes/edges/coordinates names are overwritten by every split, so the values
# kept are those of the test split.
# NOTE(review): presumably the graph is identical across splits — confirm.
paths_tdrive_0, nodes_tdrive, edges_tdrive, tdrive_edge_coordinates = _load_split('tdrive', 'train')
paths_tdrive_1, nodes_tdrive, edges_tdrive, tdrive_edge_coordinates = _load_split('tdrive', 'val')
paths_tdrive_2, nodes_tdrive, edges_tdrive, tdrive_edge_coordinates = _load_split('tdrive', 'test')

paths_geolife_0, nodes_geolife, edges_geolife, geolife_edge_coordinates = _load_split('geolife', 'train')
paths_geolife_1, nodes_geolife, edges_geolife, geolife_edge_coordinates = _load_split('geolife', 'val')
paths_geolife_2, nodes_geolife, edges_geolife, geolife_edge_coordinates = _load_split('geolife', 'test')

paths_munich_0, nodes_munich, edges_munich, munich_edge_coordinates = _load_split('munich', 'train')
paths_munich_1, nodes_munich, edges_munich, munich_edge_coordinates = _load_split('munich', 'val')
paths_munich_2, nodes_munich, edges_munich, munich_edge_coordinates = _load_split('munich', 'test')

paths_pneuma_0, nodes_pneuma, edges_pneuma, pneuma_edge_coordinates = _load_split('pneuma', 'train')
paths_pneuma_1, nodes_pneuma, edges_pneuma, pneuma_edge_coordinates = _load_split('pneuma', 'val')
paths_pneuma_2, nodes_pneuma, edges_pneuma, pneuma_edge_coordinates = _load_split('pneuma', 'test')

In [None]:
# Edge geometry per road graph; the order here fixes the panel order in the figure.
coordinates_list = [
    tdrive_edge_coordinates,
    geolife_edge_coordinates,
    munich_edge_coordinates,
    pneuma_edge_coordinates,
]

# Pool the train (0), val (1) and test (2) trajectories of each dataset so the
# usage statistics cover every split.
tdrive_all = paths_tdrive_0 + paths_tdrive_1 + paths_tdrive_2
geolife_all = paths_geolife_0 + paths_geolife_1 + paths_geolife_2
munich_all = paths_munich_0 + paths_munich_1 + paths_munich_2
pneuma_all = paths_pneuma_0 + paths_pneuma_1 + paths_pneuma_2

# Must stay aligned index-by-index with coordinates_list.
trajectories_list = [
    tdrive_all,
    geolife_all,
    munich_all,
    pneuma_all,
]

In [None]:
# Draw the usage heatmaps, write them to heatmap.png and report the share of
# edges (in percent) that appear in at least one trajectory, per dataset.
# The variable keeps its original spelling in case later cells refer to it.
precentages = plot_graphs_with_heatmap(
    edge_coordinates_list=coordinates_list,
    trajectories_list=trajectories_list,
    save_path='heatmap.png',
)
print(precentages)