In [None]:
from torch import tensor
import os
from tqdm import tqdm
import torch
import h5py
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

## Enter Dataset, transition matrix, and features

In [None]:
# Root directory that contains the experiment outputs.
folder = '/ceph/hdd/students/schmitj/MA_Diffusion_based_trajectory_prediction/experiments'

# Which run to inspect (allowed values are listed after each setting).
dataset = 'pneuma'              # pneuma | tdrive | geolife | munich
model = 'residual'              # residual | GAT | MLP
transition_matrix = 'custom'    # custom | marginal_prior
noising = 'linear'              # cosine | linear

# Feature string exactly as it appears in the stored file names.
features = 'one_hot_edges_coordinates_pos_encoding_pw_distance_edge_length_edge_angles'

# History / future lengths exactly as they appear in the stored file names.
history, future = 5, 2

# Single sample
Metrics

# Multiple samples
Metrics

## Load Data

In [None]:
# All artefacts of one run live in the same directory and share the same file-name
# suffix; both are built once so the six paths below cannot drift apart.
_run_dir = f'{folder}/{dataset}_{model}/{transition_matrix}_{noising}'
_run_suffix = f'{features}_hist{history}_fut_{future}.pth'


def _load_artifact(prefix):
    """Load the file ``<_run_dir>/<prefix>_<_run_suffix>`` with ``torch.load``."""
    # NOTE: torch.load unpickles its input, so only point this at files from trusted runs.
    return torch.load(f'{_run_dir}/{prefix}_{_run_suffix}')


sample_list = _load_artifact('samples_one_shot')
valid_ids = _load_artifact('valid_ids')
samples_raw = _load_artifact('samples_raw')
samples_valid = _load_artifact('samples_valid')
ground_truth_hist = _load_artifact('gt_hist')
ground_truth_fut = _load_artifact('gt_fut')

In [None]:
# Bundle every loaded artefact under its own name.
res = dict(
    sample_list=sample_list,
    samples_valid=samples_valid,
    valid_ids=valid_ids,
    samples_raw=samples_raw,
    ground_truth_hist=ground_truth_hist,
    ground_truth_fut=ground_truth_fut,
)

## Load Graph

In [None]:
def load_new_format(file_path, device):
    """Read the road graph and the trajectories from an HDF5 file.

    The file is expected to contain ``graph/node_coordinates``, ``graph/edges`` and a
    ``trajectories`` group with one sub-group per trajectory.

    Args:
        file_path: Path of the HDF5 file to open read-only.
        device: Torch device on which all returned tensors are created.

    Returns:
        A tuple ``(paths, nodes, edges, edge_coordinates)``:

        * ``paths``: list with one dict per trajectory; only the datasets named
          ``coordinates``, ``edge_idxs`` and ``edge_orientations`` are kept, each
          converted to a tensor.
        * ``nodes``: list of ``(node_index, {'pos': tensor})`` pairs, one per row of
          ``graph/node_coordinates``.
        * ``edges``: list of tuples, one per row of ``graph/edges``.
        * ``edge_coordinates``: ``node_coordinates`` indexed by the edge array, i.e.
          the coordinates of the nodes referenced by every edge.
    """
    kept_attributes = ('coordinates', 'edge_idxs', 'edge_orientations')
    paths = []
    with h5py.File(file_path, 'r') as new_hf:
        node_coordinates = torch.tensor(new_hf['graph']['node_coordinates'][:], dtype=torch.float, device=device)
        edge_array = new_hf['graph']['edges'][:]
        edge_coordinates = node_coordinates[edge_array]

        # Each row is already a tensor on `device`; clone().detach() makes the same
        # independent copy as torch.tensor(row) without the copy-construct UserWarning.
        nodes = [(i, {'pos': pos.clone().detach()}) for i, pos in enumerate(node_coordinates)]
        edges = [tuple(edge) for edge in edge_array]

        for key in tqdm(new_hf['trajectories'].keys()):
            path_group = new_hf['trajectories'][key]
            path = {
                attr: torch.tensor(path_group[attr][()], device=device)
                for attr in path_group.keys()
                if attr in kept_attributes
            }
            paths.append(path)

    return paths, nodes, edges, edge_coordinates
    
# Read the validation split of the selected dataset onto the CPU.
paths, nodes, edges, edge_coordinates = load_new_format(
    f'/ceph/hdd/students/schmitj/MA_Diffusion_based_trajectory_prediction/data/{dataset}_val.h5',
    'cpu',
)

# Pair every (start, end) edge with its position in the edge list.
indexed_edges = [((u, v), position) for position, (u, v) in enumerate(edges)]

# Undirected graph: nodes carry their position, edges carry their list index and
# the node order in which they were stored.
G = nx.Graph()
G.add_nodes_from(nodes)
for (start, end), index in indexed_edges:
    G.add_edge(start, end, index=index, default_orientation=(start, end))

## Plot Single Sample

In [None]:
def plot_paths_random(ground_truth_hist, ground_truth_fut, samples, edge_coordinates, num_paths_to_plot=4, zoom_in=True, valid=False, valid_ids=None):
    """Plot randomly drawn trajectories: history, ground-truth future and one prediction.

    One subplot is created per drawn trajectory. Trajectories are drawn with
    replacement (a random batch, then a random entry of that batch), so the same
    trajectory can show up more than once.

    Args:
        ground_truth_hist: Indexed as ``ground_truth_hist[batch][i]``; every entry is a
            tensor of edge indices for the observed history.
        ground_truth_fut: Same layout, holding the ground-truth future edge indices.
        samples: Same layout, holding the predicted future edge indices.
        edge_coordinates: Tensor indexed by edge index; every entry is drawn as a line
            with column 0 as x and column 1 as y.
        num_paths_to_plot: Number of subplots / trajectories to draw.
        zoom_in: If True, limit each axes to the bounding box of the history, the
            ground truth and the prediction, enlarged by a 10 % margin.
        valid: If True, entries whose ``valid_ids[batch][i]`` is ``None`` are skipped.
        valid_ids: Required when ``valid`` is True; indexed like ``ground_truth_hist``.

    Raises:
        ValueError: If ``valid`` is True but ``valid_ids`` is missing, or if there is
            no trajectory that could be drawn at all.
    """
    if valid and valid_ids is None:
        raise ValueError('valid_ids must be provided when valid=True')

    def is_candidate(batch_idx, idx):
        # With valid=True only entries carrying a non-None valid id may be drawn.
        return not valid or valid_ids[batch_idx][idx] is not None

    # The rejection sampling below only terminates if at least one entry can be drawn,
    # so check that up front (before any figure is created).
    if not any(
        is_candidate(b, i)
        for b in range(len(ground_truth_hist))
        for i in range(len(ground_truth_hist[b]))
    ):
        raise ValueError('no trajectory available to plot')

    # squeeze=False + flatten gives a 1-D array of axes even for a single subplot.
    fig, axs = plt.subplots(num_paths_to_plot, 1, figsize=(10, 5 * num_paths_to_plot), squeeze=False)
    axs = axs.flatten()

    def draw_trajectory(ax, edge_indices, color, label):
        """Draw every edge of one trajectory; only the first edge gets the legend label."""
        if edge_indices.dim() == 2:
            edge_indices = edge_indices.squeeze(0)
        if edge_indices.numel() == 0:
            return
        predicted = label == 'Predicted Future'
        line_style = ':' if predicted else '-'  # dotted lines for the prediction
        lw = 4 if predicted else 2              # thicker lines for the prediction
        for position, edge_idx in enumerate(edge_indices):
            edge = edge_coordinates[edge_idx]
            legend_kwargs = {'label': label} if position == 0 else {}
            ax.plot(edge[:, 0], edge[:, 1], color=color, linewidth=lw, linestyle=line_style, **legend_kwargs)

    path_count = 0
    while path_count < num_paths_to_plot:
        batch_idx = torch.randint(0, len(ground_truth_hist), (1,)).item()
        if len(ground_truth_hist[batch_idx]) == 0:
            continue  # nothing to draw from an empty batch
        idx = torch.randint(0, len(ground_truth_hist[batch_idx]), (1,)).item()
        if not is_candidate(batch_idx, idx):
            continue

        ax = axs[path_count]

        if zoom_in:
            all_edges = torch.cat([ground_truth_hist[batch_idx][idx], ground_truth_fut[batch_idx][idx], samples[batch_idx][idx]])
            if all_edges.numel() > 0:
                all_coords = edge_coordinates[all_edges].view(-1, 2)
                lower, upper = all_coords.min(0)[0], all_coords.max(0)[0]
                margin = (upper - lower) * 0.1
                lower = lower - margin
                upper = upper + margin
                ax.set_xlim(lower[0].item(), upper[0].item())
                ax.set_ylim(lower[1].item(), upper[1].item())

        # Whole network as a thin grey background.
        for edge in edge_coordinates:
            ax.plot(edge[:, 0], edge[:, 1], color='grey', linewidth=0.5)

        draw_trajectory(ax, ground_truth_hist[batch_idx][idx], 'blue', 'History')
        draw_trajectory(ax, ground_truth_fut[batch_idx][idx], 'green', 'Ground Truth Future')
        draw_trajectory(ax, samples[batch_idx][idx], 'red', 'Predicted Future')

        ax.set_title(f'Trajectory {idx+1} of batch {batch_idx+1}')
        ax.legend(loc='upper left', fontsize=16)
        ax.axis('off')

        path_count += 1

    plt.tight_layout()
    plt.show()

# Draw ten randomly chosen trajectories using the predictions in res['sample_list'].
plot_paths_random(
    res['ground_truth_hist'],
    res['ground_truth_fut'],
    res['sample_list'],
    edge_coordinates,
    num_paths_to_plot=10,
    zoom_in=True,
)

## Plot Valid Samples

In [None]:
# Same plot for res['samples_valid'], skipping entries whose valid id is None.
plot_paths_random(
    res['ground_truth_hist'],
    res['ground_truth_fut'],
    res['samples_valid'],
    edge_coordinates,
    num_paths_to_plot=10,
    zoom_in=True,
    valid=True,
    valid_ids=res['valid_ids'],
)

## Plot Path Density Random

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torch
import matplotlib.lines as mlines

def plot_path_density_random(ground_truth_hist, ground_truth_fut, samples, edge_coordinates, num_paths_to_plot=4, random=False, start_id=0, zoom_in=True):
    """Plot random trajectories with the predicted edges weighted by how often they were sampled.

    For every drawn trajectory the history (blue) and ground-truth future (green) are
    drawn as solid lines. Every edge that occurs in any of the samples is drawn as a
    dashed red line whose width and opacity grow with its share of the most frequently
    sampled edge.

    Args:
        ground_truth_hist: Indexed as ``ground_truth_hist[batch][i]``; every entry is a
            tensor of edge indices for the observed history.
        ground_truth_fut: Same layout, holding the ground-truth future edge indices.
        samples: Indexed as ``samples[batch][s][i]``: for each batch a sequence of
            samples, each holding one tensor of edge indices per trajectory.
        edge_coordinates: Tensor indexed by edge index; every entry is drawn as a line
            with column 0 as x and column 1 as y.
        num_paths_to_plot: Number of subplots / trajectories to draw.
        random: Unused; kept so existing calls keep working.
        start_id: Unused; kept so existing calls keep working.
        zoom_in: If True, limit each axes to the bounding box of all drawn trajectory
            edges, enlarged by a 10 % margin.

    Raises:
        ValueError: If ``ground_truth_hist`` contains no trajectory at all.
    """
    # Without this check torch.randint fails with an opaque error on empty input.
    if not any(len(ground_truth_hist[b]) > 0 for b in range(len(ground_truth_hist))):
        raise ValueError('no trajectory available to plot')

    # squeeze=False + flatten gives a 1-D array of axes even for a single subplot.
    fig, axs = plt.subplots(num_paths_to_plot, 1, figsize=(10, 5 * num_paths_to_plot), squeeze=False)
    axs = axs.flatten()

    def draw_trajectory(ax, edge_indices, color):
        """Draw every edge of one trajectory as a solid line."""
        if edge_indices.dim() == 2:
            edge_indices = edge_indices.squeeze(0)
        if edge_indices.numel() > 0:
            for edge_idx in edge_indices:
                edge = edge_coordinates[edge_idx]
                ax.plot(edge[:, 0], edge[:, 1], color=color, linewidth=2, linestyle='-')

    def draw_density(ax, edge_count, color):
        """Draw each counted edge dashed, scaled by its count relative to the maximum."""
        max_count = max(edge_count.values()) if edge_count else 1
        for edge_idx, count in edge_count.items():
            share = count / max_count
            edge = edge_coordinates[edge_idx]
            ax.plot(edge[:, 0], edge[:, 1], color=color, linewidth=2 + 3 * share, alpha=0.3 + 0.7 * share, linestyle='--')

    path_count = 0
    while path_count < num_paths_to_plot:
        batch_idx = torch.randint(0, len(ground_truth_hist), (1,)).item()
        if len(ground_truth_hist[batch_idx]) == 0:
            continue  # nothing to draw from an empty batch
        idx = torch.randint(0, len(ground_truth_hist[batch_idx]), (1,)).item()

        ax = axs[path_count]

        if zoom_in:
            # The bounding box covers history, ground truth and every sample.
            all_edges = torch.cat(
                [ground_truth_hist[batch_idx][idx], ground_truth_fut[batch_idx][idx]]
                + [sample[idx] for sample in samples[batch_idx]]
            )
            if all_edges.numel() > 0:
                all_coords = edge_coordinates[all_edges].view(-1, 2)
                lower, upper = all_coords.min(0)[0], all_coords.max(0)[0]
                margin = (upper - lower) * 0.1
                lower = lower - margin
                upper = upper + margin
                ax.set_xlim(lower[0].item(), upper[0].item())
                ax.set_ylim(lower[1].item(), upper[1].item())

        # Whole network as a thin grey background.
        for edge in edge_coordinates:
            ax.plot(edge[:, 0], edge[:, 1], color='grey', linewidth=0.5)

        # Count how often each edge occurs across all samples of this trajectory.
        edge_count = {}
        for sample in samples[batch_idx]:
            for edge_idx in sample[idx]:
                key = edge_idx.item()
                edge_count[key] = edge_count.get(key, 0) + 1

        draw_trajectory(ax, ground_truth_hist[batch_idx][idx], 'blue')
        draw_trajectory(ax, ground_truth_fut[batch_idx][idx], 'green')
        draw_density(ax, edge_count, 'red')

        # Explicit handles: one legend entry per category instead of one per drawn edge.
        legend_handles = [
            mlines.Line2D([], [], color='blue', linestyle='-', markersize=15, label='History'),
            mlines.Line2D([], [], color='green', linestyle='-', markersize=15, label='Ground Truth Future'),
            mlines.Line2D([], [], color='red', linestyle='--', markersize=15, label='Predicted Future'),
        ]

        ax.set_title(f'Trajectory {idx+1} of batch {batch_idx+1}')
        ax.legend(handles=legend_handles, loc='upper left', fontsize=12)
        ax.axis('off')

        path_count += 1

    plt.tight_layout()
    plt.show()

# Edge-usage density over all samples in res['samples_raw'] for ten random trajectories.
plot_path_density_random(
    res['ground_truth_hist'],
    res['ground_truth_fut'],
    res['samples_raw'],
    edge_coordinates,
    num_paths_to_plot=10,
    zoom_in=True,
)

## Plot Multiple Samples

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torch

def plot_multiple_paths_random(ground_truth_hist, ground_truth_fut, samples, edge_coordinates, num_paths_to_plot=4, random=False, start_id=0, zoom_in=True):
    """Plot random trajectories together with every predicted sample in its own colour.

    For every drawn trajectory the history (blue) and ground-truth future (green) are
    drawn, followed by each sample as a thicker, semi-transparent line coloured from
    the ``tab10`` colormap and labelled ``Predicted Future <n>``.

    Args:
        ground_truth_hist: Indexed as ``ground_truth_hist[batch][i]``; every entry is a
            tensor of edge indices for the observed history.
        ground_truth_fut: Same layout, holding the ground-truth future edge indices.
        samples: Indexed as ``samples[batch][s][i]``: for each batch a sequence of
            samples, each holding one tensor of edge indices per trajectory.
        edge_coordinates: Tensor indexed by edge index; every entry is drawn as a line
            with column 0 as x and column 1 as y.
        num_paths_to_plot: Number of subplots / trajectories to draw.
        random: Unused; kept so existing calls keep working.
        start_id: Unused; kept so existing calls keep working.
        zoom_in: If True, limit each axes to the bounding box of all drawn trajectory
            edges, enlarged by a 10 % margin.

    Raises:
        ValueError: If ``ground_truth_hist`` contains no trajectory at all.
    """
    # Without this check torch.randint fails with an opaque error on empty input.
    if not any(len(ground_truth_hist[b]) > 0 for b in range(len(ground_truth_hist))):
        raise ValueError('no trajectory available to plot')

    # squeeze=False + flatten gives a 1-D array of axes even for a single subplot.
    fig, axs = plt.subplots(num_paths_to_plot, 1, figsize=(10, 5 * num_paths_to_plot), squeeze=False)
    axs = axs.flatten()

    def draw_trajectory(ax, edge_indices, color, label):
        """Draw every edge of one trajectory; only the first edge gets the legend label."""
        if edge_indices.dim() == 2:
            edge_indices = edge_indices.squeeze(0)
        if edge_indices.numel() == 0:
            return
        predicted = 'Predicted' in label
        lw = 4 if predicted else 2
        # Overlapping samples stay distinguishable when drawn semi-transparent.
        alpha = 0.5 if predicted else 1
        for position, edge_idx in enumerate(edge_indices):
            edge = edge_coordinates[edge_idx]
            legend_kwargs = {'label': label} if position == 0 else {}
            ax.plot(edge[:, 0], edge[:, 1], color=color, linewidth=lw, linestyle='-', alpha=alpha, **legend_kwargs)

    path_count = 0
    while path_count < num_paths_to_plot:
        batch_idx = torch.randint(0, len(ground_truth_hist), (1,)).item()
        if len(ground_truth_hist[batch_idx]) == 0:
            continue  # nothing to draw from an empty batch
        idx = torch.randint(0, len(ground_truth_hist[batch_idx]), (1,)).item()

        ax = axs[path_count]

        if zoom_in:
            # The bounding box covers history, ground truth and every sample.
            all_edges = torch.cat(
                [ground_truth_hist[batch_idx][idx], ground_truth_fut[batch_idx][idx]]
                + [sample[idx] for sample in samples[batch_idx]]
            )
            if all_edges.numel() > 0:
                all_coords = edge_coordinates[all_edges].view(-1, 2)
                lower, upper = all_coords.min(0)[0], all_coords.max(0)[0]
                margin = (upper - lower) * 0.1
                lower = lower - margin
                upper = upper + margin
                ax.set_xlim(lower[0].item(), upper[0].item())
                ax.set_ylim(lower[1].item(), upper[1].item())

        # Whole network as a thin grey background.
        for edge in edge_coordinates:
            ax.plot(edge[:, 0], edge[:, 1], color='grey', linewidth=0.5)

        draw_trajectory(ax, ground_truth_hist[batch_idx][idx], 'blue', 'History')
        draw_trajectory(ax, ground_truth_fut[batch_idx][idx], 'green', 'Ground Truth Future')

        # One colour per sample. plt.get_cmap is used because matplotlib.cm.get_cmap
        # is no longer available in recent Matplotlib releases.
        sample_colors = plt.get_cmap('tab10', len(samples[batch_idx]))
        for j, sample in enumerate(samples[batch_idx]):
            draw_trajectory(ax, sample[idx], sample_colors(j), f'Predicted Future {j+1}')

        ax.set_title(f'Trajectory {idx+1} of batch {batch_idx+1}')
        ax.legend(loc='upper left')
        ax.axis('off')

        path_count += 1

    plt.tight_layout()
    plt.show()

# Every sample in res['samples_raw'] drawn separately for ten random trajectories.
plot_multiple_paths_random(
    res['ground_truth_hist'],
    res['ground_truth_fut'],
    res['samples_raw'],
    edge_coordinates,
    num_paths_to_plot=10,
    zoom_in=True,
)

In [None]:
'''def plot_paths(ground_truth_hist, ground_truth_fut, samples, edge_coordinates, num_paths_to_plot=4, random=False, start_id=0, zoom_in=True):
    # Setup the subplot layout based on num_paths_to_plot
    fig, axs = plt.subplots(num_paths_to_plot, 1, figsize=(10, 5 * num_paths_to_plot), squeeze=False)
    axs = axs.flatten()  # Flatten in case of a single subplot to standardize indexing

    path_count = 0
    for batch_idx in range(len(ground_truth_hist)):
        # Determine the starting index
        start_idx = torch.randint(0, len(ground_truth_hist[batch_idx]) - num_paths_to_plot + 1, (1,)).item() if random else start_id
        for i in range(len(ground_truth_hist[batch_idx])):
            if path_count >= num_paths_to_plot:
                break  # Stop if we have plotted the desired number of paths

            ax = axs[path_count]
            idx = start_idx + i
            if idx >= len(ground_truth_hist[batch_idx]):
                continue  # Prevent indexing beyond the number of samples

            # Configure zooming in on relevant trajectories
            if zoom_in:
                all_edges = torch.cat([ground_truth_hist[batch_idx][idx], ground_truth_fut[batch_idx][idx], samples[batch_idx][idx]])
                if all_edges.numel() > 0:
                    all_coords = edge_coordinates[all_edges].view(-1, 2)
                    xmin, xmax = all_coords.min(0)[0], all_coords.max(0)[0]
                    margin = (xmax - xmin) * 0.1
                    xmin -= margin
                    xmax += margin
                    ax.set_xlim(xmin[0].item(), xmax[0].item())
                    ax.set_ylim(xmin[1].item(), xmax[1].item())

            # Plot all edges as background
            for edge in edge_coordinates:
                ax.plot(edge[:, 0], edge[:, 1], color='grey', linewidth=0.5)

            # Define the plot_trajectory within this function to ensure it uses 'ax'
            def plot_trajectory(edge_indices, color, label):
                added_label = False  # Flag to add label only once
                if edge_indices.dim() == 2:
                    edge_indices = edge_indices.squeeze(0)
                if edge_indices.numel() > 0:
                    for edge_idx in edge_indices:
                        edge = edge_coordinates[edge_idx]
                        line_style = ':' if label == 'Predicted Future' else '-'  # Use dotted lines for 'Predicted Future'
                        lw = 4 if label == 'Predicted Future' else 2  # Use thicker lines for 'Predicted Future'
                        if not added_label:
                            ax.plot(edge[:, 0], edge[:, 1], color=color, linewidth=lw, linestyle=line_style, label=label)
                            added_label = True
                        else:
                            ax.plot(edge[:, 0], edge[:, 1], color=color, linewidth=lw, linestyle=line_style)

            plot_trajectory(ground_truth_hist[batch_idx][idx], 'blue', 'History')
            plot_trajectory(ground_truth_fut[batch_idx][idx], 'green', 'Ground Truth Future')
            plot_trajectory(samples[batch_idx][idx], 'red', 'Predicted Future')

            ax.set_title(f'Trajectory {idx+1} of batch {batch_idx+1}')
            ax.legend(loc='upper left', fontsize=16)
            ax.axis('off')

            path_count += 1  # Increment the path counter

    plt.tight_layout()
    plt.show()

plot_paths(res['ground_truth_hist'], res['ground_truth_fut'], res['sample_list'], edge_coordinates, num_paths_to_plot=70, zoom_in=True)'''

In [None]:
'''import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torch
import matplotlib.lines as mlines

def plot_path_density(ground_truth_hist, ground_truth_fut, samples, edge_coordinates, num_paths_to_plot=4, random=False, start_id=0, zoom_in=True):
    # Setup the subplot layout based on num_paths_to_plot
    fig, axs = plt.subplots(num_paths_to_plot, 1, figsize=(10, 5 * num_paths_to_plot), squeeze=False)
    axs = axs.flatten()  # Flatten in case of a single subplot to standardize indexing

    path_count = 0
    for batch_idx in range(len(ground_truth_hist)):
        # Determine the starting index
        start_idx = torch.randint(0, len(ground_truth_hist[batch_idx]) - num_paths_to_plot + 1, (1,)).item() if random else start_id
        for i in range(len(ground_truth_hist[batch_idx])):
            if path_count >= num_paths_to_plot:
                break  # Stop if we have plotted the desired number of paths

            ax = axs[path_count]
            idx = start_idx + i
            if idx >= len(ground_truth_hist[batch_idx]):
                continue  # Prevent indexing beyond the number of samples

            # Configure zooming in on relevant trajectories
            all_samples_edges = torch.cat([sample[idx] for sample in samples[batch_idx]])
            all_edges = torch.cat([ground_truth_hist[batch_idx][idx], ground_truth_fut[batch_idx][idx], all_samples_edges])
            if zoom_in:
                # all_edges = torch.cat([ground_truth_hist[batch_idx][idx], ground_truth_fut[batch_idx][idx]] + [sample[idx] for sample in samples[batch_idx]])
                if all_edges.numel() > 0:
                    all_coords = edge_coordinates[all_edges].view(-1, 2)
                    xmin, xmax = all_coords.min(0)[0], all_coords.max(0)[0]
                    margin = (xmax - xmin) * 0.1
                    xmin -= margin
                    xmax += margin
                    ax.set_xlim(xmin[0].item(), xmax[0].item())
                    ax.set_ylim(xmin[1].item(), xmax[1].item())

            # Plot background edges
            for edge in edge_coordinates:
                ax.plot(edge[:, 0], edge[:, 1], color='grey', linewidth=0.5)
                
            # Track how often each edge is used in samples
            edge_count = {}
            for sample in samples[batch_idx]:
                for edge_idx in sample[idx]:
                    edge_count[edge_idx.item()] = edge_count.get(edge_idx.item(), 0) + 1

            max_count = max(edge_count.values()) if edge_count else 1
                
                
            # Define the plot_trajectory within this function to ensure it uses 'ax'
            def plot_trajectory(edge_indices, color, label):
                if isinstance(edge_indices, dict):
                    edge_indices = edge_indices.items()
                    added_label = False  # Flag to add label only once
                    for edge_idx, count in edge_indices:
                        edge = edge_coordinates[edge_idx]
                        ax.plot(edge[:, 0], edge[:, 1], color='red', linewidth=2 + 3 * (count / max_count), alpha=0.3 + 0.7 * (count / max_count), linestyle='--')
                else:
                    added_label = False  # Flag to add label only once
                    if edge_indices.dim() == 2:
                        edge_indices = edge_indices.squeeze(0)
                    if edge_indices.numel() > 0:
                        for edge_idx in edge_indices:
                            edge = edge_coordinates[edge_idx]
                            line_style = '-'
                            lw = 2
                            ax.plot(edge[:, 0], edge[:, 1], color=color, linewidth=lw, linestyle=line_style, label=label)

            plot_trajectory(ground_truth_hist[batch_idx][idx], 'blue', 'History')
            plot_trajectory(ground_truth_fut[batch_idx][idx], 'green', 'Ground Truth Future')
            plot_trajectory(edge_count, 'red', 'Predicted Future')
            
            custom_handle1 = mlines.Line2D([], [], color='blue', linestyle='-', markersize=15, label='History')
            custom_handle2 = mlines.Line2D([], [], color='green', linestyle='-', markersize=15, label='Ground Truth Future')
            custom_handle3 = mlines.Line2D([], [], color='red', linestyle='--', markersize=15, label='Predicted Future')


            ax.set_title(f'Trajectory {idx+1} of batch {batch_idx+1}')
            ax.legend(handles=[custom_handle1, custom_handle2, custom_handle3], loc='upper left', fontsize=12)
            ax.axis('off')

            path_count += 1  # Increment the path counter

    plt.tight_layout()
    plt.show()

plot_path_density(res['ground_truth_hist'], res['ground_truth_fut'], res['samples_raw'], edge_coordinates, num_paths_to_plot=50, zoom_in=True)'''

In [None]:
'''import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torch

def plot_multiple_paths(ground_truth_hist, ground_truth_fut, samples, edge_coordinates, num_paths_to_plot=4, random=False, start_id=0, zoom_in=True):
    # Setup the subplot layout based on num_paths_to_plot
    fig, axs = plt.subplots(num_paths_to_plot, 1, figsize=(10, 5 * num_paths_to_plot), squeeze=False)
    axs = axs.flatten()  # Flatten in case of a single subplot to standardize indexing

    path_count = 0
    for batch_idx in range(len(ground_truth_hist)):
        # Determine the starting index
        start_idx = torch.randint(0, len(ground_truth_hist[batch_idx]) - num_paths_to_plot + 1, (1,)).item() if random else start_id
        for i in range(len(ground_truth_hist[batch_idx])):
            if path_count >= num_paths_to_plot:
                break  # Stop if we have plotted the desired number of paths

            ax = axs[path_count]
            idx = start_idx + i
            if idx >= len(ground_truth_hist[batch_idx]):
                continue  # Prevent indexing beyond the number of samples

            # Configure zooming in on relevant trajectories
            if zoom_in:
                all_edges = torch.cat([ground_truth_hist[batch_idx][idx], ground_truth_fut[batch_idx][idx]] + [sample[idx] for sample in samples[batch_idx]])
                if all_edges.numel() > 0:
                    all_coords = edge_coordinates[all_edges].view(-1, 2)
                    xmin, xmax = all_coords.min(0)[0], all_coords.max(0)[0]
                    margin = (xmax - xmin) * 0.1
                    xmin -= margin
                    xmax += margin
                    ax.set_xlim(xmin[0].item(), xmax[0].item())
                    ax.set_ylim(xmin[1].item(), xmax[1].item())

            # Plot all edges as background
            for edge in edge_coordinates:
                ax.plot(edge[:, 0], edge[:, 1], color='grey', linewidth=0.5)

            # Define the plot_trajectory within this function to ensure it uses 'ax'
            def plot_trajectory(edge_indices, color, label):
                added_label = False  # Flag to add label only once
                if edge_indices.dim() == 2:
                    edge_indices = edge_indices.squeeze(0)
                if edge_indices.numel() > 0:
                    for edge_idx in edge_indices:
                        edge = edge_coordinates[edge_idx]
                        line_style = '-' if 'Predicted' in label else '-'  # Use dotted lines for 'Predicted Future'
                        lw = 4 if 'Predicted' in label else 2
                        alpha = 0.5 if 'Predicted' in label else 1
                        if not added_label:
                            ax.plot(edge[:, 0], edge[:, 1], color=color, linewidth=lw, linestyle=line_style, label=label)
                            added_label = True
                        else:
                            ax.plot(edge[:, 0], edge[:, 1], color=color, linewidth=lw, linestyle=line_style)

            plot_trajectory(ground_truth_hist[batch_idx][idx], 'blue', 'History')
            plot_trajectory(ground_truth_fut[batch_idx][idx], 'green', 'Ground Truth Future')
            
            # Plot all samples for this data point
            sample_colors = cm.get_cmap('tab10', len(samples[batch_idx]))
            for j, sample in enumerate(samples[batch_idx]):
                plot_trajectory(sample[idx], sample_colors(j), f'Predicted Future {j+1}')

            ax.set_title(f'Trajectory {idx+1} of batch {batch_idx+1}')
            ax.legend(loc='upper left')
            ax.axis('off')

            path_count += 1  # Increment the path counter

    plt.tight_layout()
    plt.show()

plot_multiple_paths(res['ground_truth_hist'], res['ground_truth_fut'], res['samples_raw'], edge_coordinates, num_paths_to_plot=50, zoom_in=True)'''