In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.collections import LineCollection

def _to_numpy(coords):
    """Return ``coords`` as a NumPy array, detaching and moving torch tensors to the CPU first."""
    if isinstance(coords, torch.Tensor):
        return coords.detach().cpu().numpy()
    return np.asarray(coords)


def plot_graphs(edge_coordinates_list, save_path='graphs_grid.png', titles=None):
    """
    Plot several graphs side by side in one row, annotate them and save the figure.

    Each graph is drawn as thin grey polylines on its own panel with an equal
    aspect ratio, no axes, and a 5 % margin around its data. Panel titles are
    placed at a common height just above the tallest panel.

    Parameters:
    - edge_coordinates_list: Sequence of edge-coordinate tensors/arrays, one per
      panel. Each is indexed as ``[edge, point, xy]``: ``[..., 0]`` is x and
      ``[..., 1]`` is y.
    - save_path: File path to save the figure (written at 300 dpi).
    - titles: Optional sequence of panel titles, one per graph. Defaults to the
      Beijing / Munich / Athens labels of the three datasets.

    Raises:
    - ValueError: if no graphs are given, or fewer titles than graphs.

    Side effects: applies the 'seaborn-v0_8-paper' style and updates
    ``plt.rcParams`` globally, writes ``save_path`` and calls ``plt.show()``.
    """
    if titles is None:
        titles = ["Beijing (T-Drive/Geolife)", "Munich (MoLe)", "Athens (pNEUMA)"]

    n_graphs = len(edge_coordinates_list)
    if n_graphs == 0:
        raise ValueError("edge_coordinates_list must contain at least one graph")
    if len(titles) < n_graphs:
        raise ValueError(
            f"got {n_graphs} graphs but only {len(titles)} titles; pass `titles` explicitly"
        )

    # Set the plotting style and parameters (this mutates matplotlib's global state).
    plt.style.use('seaborn-v0_8-paper')
    plt.rcParams.update({
        'axes.labelsize': 14,
        'axes.titlesize': 18,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'axes.grid': False,
        'grid.alpha': 0.4,
        'lines.linewidth': 1,
    })

    # One row, one panel per graph.
    fig = plt.figure(figsize=(16, 10))
    axes = fig.subplots(1, n_graphs, squeeze=False)[0]

    for ax, edge_coordinates in zip(axes, edge_coordinates_list):
        coords = _to_numpy(edge_coordinates)

        # Draw all edges as one LineCollection: calling ax.plot() once per edge
        # creates one Line2D artist per edge, which is very slow for large graphs.
        ax.add_collection(
            LineCollection(coords, colors='grey', linewidths=0.5, zorder=1)
        )

        # No ticks, labels or frame; keep the geometry undistorted.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        ax.set_aspect('equal')

        if coords.size == 0:
            # Nothing to frame: min()/max() are undefined on an empty array.
            continue

        # Frame the data with a 5 % margin on each side.
        xs, ys = coords[..., 0], coords[..., 1]
        xmin, xmax = float(xs.min()), float(xs.max())
        ymin, ymax = float(ys.min()), float(ys.max())
        x_margin = (xmax - xmin) * 0.05
        y_margin = (ymax - ymin) * 0.05
        ax.set_xlim(xmin - x_margin, xmax + x_margin)
        ax.set_ylim(ymin - y_margin, ymax + y_margin)

    # Tight horizontal spacing between panels (single row, so hspace is irrelevant).
    fig.subplots_adjust(wspace=0.05)

    # Use figure-level text so that all titles share one vertical position,
    # a small offset above the tallest panel.
    positions = [ax.get_position() for ax in axes]
    title_y = max(pos.y1 for pos in positions) + 0.02
    for pos, title in zip(positions, titles):
        fig.text(pos.x0 + pos.width / 2, title_y, title, ha='center', fontsize=18)

    # Save and show the figure.
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
import os
import sys
from pathlib import Path

# Root of the project checkout that provides the `dataset` package and the
# `data/` directory. Set MA_PROJECT_ROOT to run on a machine with a different layout.
PROJECT_ROOT = Path(os.environ.get(
    'MA_PROJECT_ROOT',
    '/ceph/hdd/students/schmitj/MA_Diffusion_based_trajectory_prediction',
))
DATA_DIR = PROJECT_ROOT / 'data'

# Append only once so that re-running this cell does not keep growing sys.path.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from dataset.trajectory_dataset_geometric import TrajectoryGeoDataset


def _load_val_split(name):
    """Load ``<DATA_DIR>/<name>_val.h5`` on the CPU via ``TrajectoryGeoDataset.load_new_format``.

    Returns whatever ``load_new_format`` returns. The unpacking below expects
    a 4-tuple (paths, nodes, edges, edge_coordinates).
    """
    return TrajectoryGeoDataset.load_new_format(str(DATA_DIR / f'{name}_val.h5'), 'cpu')


paths_tdrive, nodes_tdrive, edges_tdrive, tdrive_edge_coordinates = _load_val_split('tdrive')
paths_munich, nodes_munich, edges_munich, munich_edge_coordinates = _load_val_split('munich')
paths_pneuma, nodes_pneuma, edges_pneuma, pneuma_edge_coordinates = _load_val_split('pneuma')
edge_coordinates_list = [tdrive_edge_coordinates, munich_edge_coordinates, pneuma_edge_coordinates]

In [None]:
# Render the three validation graphs side by side and write them to disk.
plot_graphs(edge_coordinates_list, 'graphs_grid.png')