In [15]:
from torch import tensor
import os
from tqdm import tqdm
import torch
import h5py
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

In [16]:
# Experiment selection shared by every cell below.
folder = '/ceph/hdd/students/schmitj/MA_Diffusion_based_trajectory_prediction/experiments'
dataset = 'geolife'  # Common dataset for all configurations
test = True  # Selects the 'test_'-prefixed result files and the *_test.h5 split
history = 5  # Appears as 'hist{history}' in the result file names
conditional = True  # Set to True if using conditional future lengths

# One entry per row-group of the final figure. Each entry names the run whose
# saved samples are loaded (model / transition matrix / noising / features).
# NOTE(review): every 'conditional_fut_len' is None while `conditional` is True,
# so after the normalisation below BOTH length fields are None — confirm that
# this is intended before regenerating figures.
configurations = [
    dict(
        future_len=0,
        conditional_fut_len=None,
        model='residual',
        features='one_hot_edges_coordinates_pos_encoding_pw_distance_edge_length_edge_angles_num_pred_edges_future_len',
        transition_matrix='custom',
        noising='cosine',
    ),
    dict(
        future_len=0,
        conditional_fut_len=None,
        model='residual',
        features='one_hot_edges_coordinates_pos_encoding_pw_distance_edge_length_edge_angles_num_pred_edges_future_len',
        transition_matrix='custom',
        noising='cosine',
    ),
]
# Alternative feature string (without the trailing '_future_len'):
#one_hot_edges_coordinates_pos_encoding_pw_distance_edge_length_edge_angles_num_pred_edges

# Keep only the length field that matches the current mode; blank out the other
# one in every configuration.
inactive_key = 'future_len' if conditional else 'conditional_fut_len'
for config in configurations:
    config[inactive_key] = None


In [17]:
import torch


def _load_artifact(path):
    """Load one saved result file onto the CPU.

    `map_location='cpu'` keeps the notebook usable on a kernel without a GPU and
    keeps every loaded tensor on the same device as the graph data, which this
    notebook loads with device 'cpu'.

    NOTE(review): `torch.load` unpickles the file; only point it at result files
    you produced yourself. `weights_only` is deliberately left at the library
    default because the exact contents of these files are not visible here.
    """
    return torch.load(path, map_location='cpu')


# Maps (future length, configuration index) -> loaded artifacts for that run.
res_dict = {}

for idx, config in enumerate(configurations):
    model = config['model']
    features = config['features']
    transition_matrix = config['transition_matrix']
    noising = config['noising']
    future_len = config['future_len']
    conditional_fut_len = config['conditional_fut_len']

    # The active length depends on the mode; the file prefix encodes both the
    # split ('test_') and the mode ('cond_'), in that order.
    fut_len = conditional_fut_len if conditional else future_len
    prefix = ('test_' if test else '') + ('cond_' if conditional else '')

    # Construct the file path using the configuration settings
    base_path = f'{folder}/{dataset}_{model}/{transition_matrix}_{noising}'
    file_suffix = f'{features}_hist{history}_fut_{0 if conditional else fut_len}'
    # Conditional runs additionally carry 'fut_len_<n>_' in front of the artifact name.
    stem_prefix = f'{prefix}fut_len_{fut_len}_' if conditional else prefix

    sample_list, valid_ids, samples_raw, samples_valid, ground_truth_hist, ground_truth_fut = [
        _load_artifact(f'{base_path}/{stem_prefix}{stem}_{file_suffix}.pth')
        for stem in ('samples_one_shot', 'valid_ids', 'samples_raw', 'samples_valid', 'gt_hist', 'gt_fut')
    ]

    res = {
        'sample_list': sample_list,
        'samples_valid': samples_valid,
        'valid_ids': valid_ids,
        'samples_raw': samples_raw,
        'ground_truth_hist': ground_truth_hist,
        'ground_truth_fut': ground_truth_fut,
        'config': config,  # Store the configuration for reference
    }
    # Use a tuple of (future_length, configuration index) as the key
    res_dict[(fut_len, idx)] = res


  sample_list = torch.load(f'{base_path}/{prefix}fut_len_{fut_len}_samples_one_shot_{file_suffix}.pth')
  valid_ids = torch.load(f'{base_path}/{prefix}fut_len_{fut_len}_valid_ids_{file_suffix}.pth')
  samples_raw = torch.load(f'{base_path}/{prefix}fut_len_{fut_len}_samples_raw_{file_suffix}.pth')
  samples_valid = torch.load(f'{base_path}/{prefix}fut_len_{fut_len}_samples_valid_{file_suffix}.pth')
  ground_truth_hist = torch.load(f'{base_path}/{prefix}fut_len_{fut_len}_gt_hist_{file_suffix}.pth')
  ground_truth_fut = torch.load(f'{base_path}/{prefix}fut_len_{fut_len}_gt_fut_{file_suffix}.pth')


In [20]:
# NOTE: the triple-quoted block below is a bare string expression, not a
# function definition — Python evaluates it and discards it, so nothing in it
# is ever executed. The loader that this cell actually calls is
# TrajectoryGeoDataset.load_new_format (imported further down).
# NOTE(review): presumably kept as a reference copy of that loader — confirm it
# still matches the real implementation, or delete it.
'''def load_new_format(file_path, edge_features, device):
        paths = []
        with h5py.File(file_path, 'r') as new_hf:
            node_coordinates = torch.tensor(new_hf['graph']['node_coordinates'][:], dtype=torch.float, device=device)
            # Normalize the coordinates to (0, 1) if any of the coordinates is larger than 1
            if node_coordinates.max() > 1:
                max_values = node_coordinates.max(0)[0]
                min_values = node_coordinates.min(0)[0]
                node_coordinates[:, 0] = (node_coordinates[:, 0] - min_values[0]) / (max_values[0] - min_values[0])
                node_coordinates[:, 1] = (node_coordinates[:, 1] - min_values[1]) / (max_values[1] - min_values[1])
            edges = new_hf['graph']['edges'][:]
            edge_coordinates = node_coordinates[edges]
            nodes = [(i, {'pos': torch.tensor(pos, device=device)}) for i, pos in enumerate(node_coordinates)]
            edges = [tuple(edge) for edge in edges]

            for i in tqdm(new_hf['trajectories'].keys()):
                path_group = new_hf['trajectories'][i]
                path = {attr: torch.tensor(path_group[attr][()], device=device) for attr in path_group.keys() if attr in ['coordinates', 'edge_idxs', 'edge_orientations']}
                paths.append(path)
            if 'road_type' in edge_features:
                onehot_encoded_road_type = new_hf['graph']['road_type'][:]
                return paths, nodes, edges, edge_coordinates, onehot_encoded_road_type
            else:
                return paths, nodes, edges, edge_coordinates
        return paths, nodes, edges, edge_coordinates'''

import sys

# The project package is not installed; make it importable, but do not keep
# appending the same entry every time this cell is re-run.
project_root = '/ceph/hdd/students/schmitj/MA_Diffusion_based_trajectory_prediction'
if project_root not in sys.path:
    sys.path.append(project_root)

from dataset.trajectory_dataset_geometric import TrajectoryGeoDataset

# Take the feature string explicitly from the last configuration. This is the
# same value that previously reached this cell only as a leftover loop variable
# (`features`) of the result-loading cell.
edge_features = configurations[-1]['features']
split = 'test' if test else 'val'

loaded = TrajectoryGeoDataset.load_new_format(f'{project_root}/data/{dataset}_{split}.h5', edge_features, 'cpu')
# Only the first four return values are used here.
# NOTE(review): the reference copy of the loader in this notebook returns a
# fifth element (one-hot road types) when 'road_type' is part of the features;
# slicing keeps this cell working in that case too — confirm against the real loader.
paths, nodes, edges, edge_coordinates = loaded[:4]

# Build an undirected graph in which every edge remembers its index in `edges`
# and the (start, end) orientation it was stored with.
indexed_edges = [((start, end), index) for index, (start, end) in enumerate(edges)]
G = nx.Graph()
G.add_nodes_from(nodes)
for (start, end), index in indexed_edges:
    G.add_edge(start, end, index=index, default_orientation=(start, end))

  nodes = [(i, {'pos': torch.tensor(pos, device=device)}) for i, pos in enumerate(node_coordinates)]
100%|██████████| 2665/2665 [00:02<00:00, 942.94it/s] 


In [None]:
def plot_paths_random_multiple_configs(res_dict, edge_coordinates, configurations, conditional=True, valid=False,
                                       num_paths_per_config=8, num_cols=4, zoom_in=True,
                                       save_path='trajectory_grid.png', show_labels=False):
    """
    Plots a grid of randomly chosen trajectories for multiple configurations.

    Every configuration gets its own block of rows. Each subplot shows the whole
    road network in light grey and, on top, the history, the ground-truth future
    and the predicted future of one randomly drawn trajectory. Trajectories are
    drawn with `torch.randint`, so the figure differs between calls unless the
    caller seeds torch beforehand. The figure is saved to `save_path` and shown.

    Note: this function changes the global matplotlib style and rcParams.

    Parameters:
    - res_dict: Dictionary with keys as (future_length, config index), values as data dictionaries
      holding 'sample_list', 'samples_valid', 'valid_ids', 'ground_truth_hist' and 'ground_truth_fut'.
    - edge_coordinates: Tensor of shape (num_edges, 2, 2) with edge coordinates.
    - configurations: List of configuration dictionaries.
    - conditional: Boolean indicating whether the model is conditional.
    - valid: If True, plot 'samples_valid' instead of 'sample_list' and skip trajectories
      whose entry in 'valid_ids' is None.
    - num_paths_per_config: Number of trajectories to plot per configuration.
    - num_cols: Number of columns in the grid.
    - zoom_in: If True, zooms into the trajectory area.
    - save_path: File path to save the figure.
    - show_labels: If True, write the future length of each configuration to the left of
      its block of rows. Defaults to False (no labels).

    Raises:
    - ValueError: If there is nothing to plot (no configurations, or a non-positive
      number of paths or columns).
    """
    import matplotlib.pyplot as plt
    from matplotlib.collections import LineCollection
    from matplotlib.lines import Line2D

    # Set the plotting style and parameters
    plt.style.use('seaborn-v0_8-paper')
    plt.rcParams.update({
        'axes.labelsize': 14,
        'axes.titlesize': 16,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'legend.fontsize': 12,
        'axes.grid': False,
        'grid.alpha': 0.4,
        'lines.linewidth': 2,
        'mathtext.fontset': 'cm',  # Use Computer Modern fonts for math
        'mathtext.rm': 'serif',
    })

    num_configs = len(configurations)
    if num_configs == 0 or num_paths_per_config <= 0 or num_cols <= 0:
        raise ValueError('Nothing to plot: need at least one configuration, one path and one column.')
    num_rows_per_config = (num_paths_per_config + num_cols - 1) // num_cols  # Ceiling division
    total_rows = num_configs * num_rows_per_config

    # squeeze=False guarantees a 2-D axes array, so axs[row, col] is valid even
    # for a single row or a single column.
    fig, axs = plt.subplots(total_rows, num_cols, figsize=(5 * num_cols, 5 * total_rows),
                            gridspec_kw={'hspace': 0.2}, squeeze=False)

    # Adjust layout to make room for the legend, labels, and to accommodate the left label
    fig.subplots_adjust(top=0.88, bottom=0.05, left=0.12, right=0.95, wspace=0.1)

    # Create legend elements manually to ensure all lines are included
    legend_elements = [
        Line2D([0], [0], color='blue', lw=2, label='History'),
        Line2D([0], [0], color='green', lw=2, label='Ground Truth Future'),
        Line2D([0], [0], color='orange', lw=2, linestyle=':', label='Predicted Future'),
    ]

    # Convert the edge geometry once; it is reused for every subplot.
    edge_segments = torch.as_tensor(edge_coordinates).detach().cpu().numpy()
    full_xlim = (edge_segments[:, :, 0].min(), edge_segments[:, :, 0].max())
    full_ylim = (edge_segments[:, :, 1].min(), edge_segments[:, :, 1].max())

    def valid_edge_ids(*index_tensors):
        """Flatten the given edge-index tensors into one array without negative entries."""
        # Negative indices are dropped, as in the original filtering (presumably padding).
        flat = torch.cat([torch.as_tensor(t).reshape(-1) for t in index_tensors])
        return flat[flat >= 0].long().cpu().numpy()

    def draw_trajectory(ax, edge_indices, color, linestyle='-', lw=2):
        """Draw every edge referenced by `edge_indices` on `ax`."""
        for segment in edge_segments[valid_edge_ids(edge_indices)]:
            ax.plot(segment[:, 0], segment[:, 1], color=color, linewidth=lw, linestyle=linestyle, zorder=2)

    max_attempts = 1000  # Upper bound on random draws per configuration
    used_axes = set()    # (row, col) of every subplot that received a trajectory

    for config_idx, config in enumerate(configurations):
        fut_len = config['conditional_fut_len'] if conditional else config['future_len']
        res = res_dict[(fut_len, config_idx)]
        sample_list = res['samples_valid'] if valid else res['sample_list']
        ground_truth_hist = res['ground_truth_hist']
        ground_truth_fut = res['ground_truth_fut']
        valid_ids = res['valid_ids']

        path_count = 0
        attempts = 0
        while path_count < num_paths_per_config and attempts < max_attempts:
            batch_idx = torch.randint(0, len(ground_truth_hist), (1,)).item()
            idx = torch.randint(0, len(ground_truth_hist[batch_idx]), (1,)).item()
            attempts += 1
            if valid and valid_ids[batch_idx][idx] is None:
                continue

            row = config_idx * num_rows_per_config + path_count // num_cols
            col = path_count % num_cols
            ax = axs[row, col]

            # Road network as background: a single collection instead of one
            # line artist per edge keeps drawing and saving fast.
            ax.add_collection(LineCollection(edge_segments, colors='lightgrey', linewidths=0.5, zorder=1))
            ax.autoscale_view()

            history_ids = ground_truth_hist[batch_idx][idx]
            future_ids = ground_truth_fut[batch_idx][idx]
            predicted_ids = sample_list[batch_idx][idx]

            draw_trajectory(ax, future_ids, color='green', linestyle='-', lw=2)
            draw_trajectory(ax, history_ids, color='blue', linestyle='-', lw=2)
            draw_trajectory(ax, predicted_ids, color='orange', linestyle=':', lw=3)

            if zoom_in:
                shown_ids = valid_edge_ids(history_ids, future_ids, predicted_ids)
                if shown_ids.size > 0:
                    shown_coords = edge_segments[shown_ids].reshape(-1, 2)
                    xmin, xmax = shown_coords[:, 0].min(), shown_coords[:, 0].max()
                    ymin, ymax = shown_coords[:, 1].min(), shown_coords[:, 1].max()
                    x_margin = (xmax - xmin) * 0.2
                    y_margin = (ymax - ymin) * 0.2
                    ax.set_xlim(xmin - x_margin, xmax + x_margin)
                    ax.set_ylim(ymin - y_margin, ymax + y_margin)
            else:
                ax.set_xlim(*full_xlim)
                ax.set_ylim(*full_ylim)

            # Remove axes ticks and labels
            ax.set_xticks([])
            ax.set_yticks([])
            ax.axis('off')

            used_axes.add((row, col))
            path_count += 1

    # Remove every subplot that did not receive a trajectory. Unused slots sit at
    # the end of EACH configuration's block, not only at the end of the grid.
    for row in range(total_rows):
        for col in range(num_cols):
            if (row, col) not in used_axes:
                fig.delaxes(axs[row, col])

    # Adjust the figure canvas to update positions
    fig.canvas.draw()

    # Increase the vertical space between the legend and the first row of plots
    legend_y = 0.94  # Adjust this value to move the legend higher or lower
    fig.legend(handles=legend_elements, loc='upper center', ncol=3, bbox_to_anchor=(0.5, legend_y),
               frameon=True, fontsize=22)

    # Add a separator line between configurations and, optionally, a label per configuration
    for config_idx in range(num_configs):
        row_start = config_idx * num_rows_per_config
        row_end = row_start + num_rows_per_config - 1

        # Positions of the top and bottom of this configuration's block
        bbox_top = axs[row_start, 0].get_position()
        bbox_bottom = axs[row_end, 0].get_position()

        if show_labels:
            section_fut_len = (configurations[config_idx]['conditional_fut_len'] if conditional
                               else configurations[config_idx]['future_len'])
            if conditional:
                fut_len_display = f'$f_{{cond}} = {section_fut_len}$'
            else:
                fut_len_display = f'$f = {section_fut_len}$'
            # Text to the left of the plots, centered vertically on the block
            section_middle = (bbox_top.y1 + bbox_bottom.y0) / 2
            fig.text(0.05, section_middle, fut_len_display, ha='left', va='center', fontsize=24)

        # Add horizontal line at the bottom of the section, except for the last one
        if config_idx < num_configs - 1:
            # Position for the line is between this section and the next
            next_bbox_top = axs[row_end + 1, 0].get_position()
            line_y = (bbox_bottom.y0 + next_bbox_top.y1) / 2

            # Add horizontal line across the figure
            fig.add_artist(Line2D([0.1, 0.99], [line_y, line_y], transform=fig.transFigure, color='lightgrey', linewidth=1.5))

    # Save and show the figure
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
import os

# Output directory for figures that combine several configurations.
save_path = f'{folder}/{dataset}_multiple_configs/output'
os.makedirs(save_path, exist_ok=True)  # No separate existence check needed

# The file name lists the future length of every configuration, joined by '_'.
length_key = 'conditional_fut_len' if conditional else 'future_len'
future_len_str = "_".join(str(config[length_key]) for config in configurations)
if conditional:
    save_filename = f'cond_fut_lens_{future_len_str}_multiple_configs.png'
else:
    save_filename = f'fut_lens_{future_len_str}_multiple_configs.png'

save_full_path = os.path.join(save_path, save_filename)

plot_paths_random_multiple_configs(res_dict, edge_coordinates, configurations,
                                   conditional=conditional, valid=False, num_paths_per_config=8, num_cols=4,
                                   save_path=save_full_path, zoom_in=True)

# Report the location only after the figure has actually been written.
print("Saved at", save_full_path)


# Valid Paths

In [None]:
# Save the valid-only figure next to the unfiltered one under its own name
# ('<name>_valid<ext>'); reusing `save_full_path` would overwrite the figure
# produced by the previous cell.
valid_root, valid_ext = os.path.splitext(save_full_path)
valid_save_full_path = f'{valid_root}_valid{valid_ext}'

plot_paths_random_multiple_configs(res_dict, edge_coordinates, configurations,
                                   conditional=conditional, valid=True, num_paths_per_config=8, num_cols=4,
                                   save_path=valid_save_full_path, zoom_in=True)