In [None]:
from torch import tensor
import os
from tqdm import tqdm
import torch
import h5py
import numpy as np
import matplotlib.pyplot as plt

# Single sample
Metrics

# Multiple samples
Metrics

## Load Data

In [None]:
# All artifacts of one experiment run live in one directory and share a file-name
# suffix; switching runs only requires editing these two constants.
EXPERIMENT_DIR = '/ceph/hdd/students/schmitj/MA_Diffusion_based_trajectory_prediction/experiments/pneuma_residual/marginal_prior_cosine'
RUN_SUFFIX = 'one_hot_edges_coordinates_pos_encoding_pw_distance_hist5_fut_2.pth'


def _load_run_artifact(prefix):
    """Load ``<EXPERIMENT_DIR>/<prefix>_<RUN_SUFFIX>`` with ``torch.load``."""
    # NOTE(review): torch.load unpickles the file; only use it on trusted artifacts.
    return torch.load(os.path.join(EXPERIMENT_DIR, f'{prefix}_{RUN_SUFFIX}'))


sample_list = _load_run_artifact('samples')
samples_raw = _load_run_artifact('samples_raw')
ground_truth_hist = _load_run_artifact('gt_hist')
ground_truth_fut = _load_run_artifact('gt_fut')

In [None]:
# Bundle the loaded artifacts under stable keys for the evaluation cells below.
res = dict(
    sample_list=sample_list,
    samples_raw=samples_raw,
    ground_truth_hist=ground_truth_hist,
    ground_truth_fut=ground_truth_fut,
)

## Load Graph

In [None]:
def load_new_format(file_path, device):
    """Load the road graph and all trajectories from an HDF5 file.

    Args:
        file_path: HDF5 file with a ``graph`` group (datasets ``node_coordinates``
            and ``edges``) and a ``trajectories`` group with one sub-group per
            trajectory.
        device: Torch device for the created tensors.

    Returns:
        tuple: ``(paths, nodes, edges, edge_coordinates)``:
            - ``paths``: list of dicts, one per trajectory, holding whichever of the
              datasets ``coordinates``, ``edge_idxs``, ``edge_orientations`` exist,
              as tensors;
            - ``nodes``: list of ``(node_id, {'pos': tensor})`` pairs;
            - ``edges``: list of node-index tuples, one per edge;
            - ``edge_coordinates``: ``node_coordinates[edges]``, the endpoint
              coordinates of every edge.
    """
    wanted_attrs = ('coordinates', 'edge_idxs', 'edge_orientations')
    paths = []
    with h5py.File(file_path, 'r') as hf:
        node_coordinates = torch.tensor(hf['graph']['node_coordinates'][:], dtype=torch.float, device=device)
        edge_array = hf['graph']['edges'][:]
        edge_coordinates = node_coordinates[edge_array]
        # clone().detach() is the recommended tensor copy; torch.tensor(tensor) emits a UserWarning.
        nodes = [(i, {'pos': pos.clone().detach()}) for i, pos in enumerate(node_coordinates)]
        edges = [tuple(edge) for edge in edge_array]

        for key in tqdm(hf['trajectories'].keys()):
            path_group = hf['trajectories'][key]
            paths.append({attr: torch.tensor(path_group[attr][()], device=device)
                          for attr in path_group.keys() if attr in wanted_attrs})

    return paths, nodes, edges, edge_coordinates

paths, nodes, edges, edge_coordinates = load_new_format('/ceph/hdd/students/schmitj/MA_Diffusion_based_trajectory_prediction/data/pneuma_val.h5', 'cpu')

## Calculate Best-of-N Metrics from Multiple Raw Samples

In [None]:
def find_trajectory_endpoints(edge_sequence, edge_coordinates, edge_list=None):
    """Find the first and last node of a trajectory given as a sequence of edges.

    The travel direction of an edge is not stored, so the start node is the endpoint of
    the first edge that is *farther* from the second edge, and the end node is the
    endpoint of the last edge that is farther from the second-to-last edge. "Distance to
    an edge" is the distance to the nearer of its two endpoints, so the result does not
    depend on how the neighbouring edge happens to be oriented.

    Args:
        edge_sequence (torch.Tensor): Indices of the edges forming the trajectory,
            shape (sequence_length,), at least one element.
        edge_coordinates (torch.Tensor): Endpoint coordinates of all edges in the graph,
            shape (num_edges, 2, 2); each edge is [point1, point2].
        edge_list (sequence, optional): ``(u, v)`` node-index pairs aligned with
            ``edge_coordinates``. Defaults to the notebook-global ``edges``.

    Returns:
        tuple: ``(start_point_coord, end_point_coord, start_point, end_point)`` —
        coordinates and node indices of the first and last node. For a single-edge
        trajectory the direction cannot be inferred and the edge's stored orientation
        ``u -> v`` is used.
    """
    # NOTE(review): -1 padding is not filtered here; confirm histories are unpadded.
    pairs = edges if edge_list is None else edge_list
    trajectory_edges = edge_coordinates[edge_sequence]
    first_pair = pairs[edge_sequence[0]]
    last_pair = pairs[edge_sequence[-1]]

    if len(trajectory_edges) < 2:
        return trajectory_edges[0, 0], trajectory_edges[-1, 1], first_pair[0], last_pair[1]

    def _dist_to_edge(point, edge):
        # Distance from `point` to the closer of the two endpoints of `edge` (shape (2, 2)).
        return torch.norm(edge - point, dim=-1).min()

    # The endpoint of the first edge touching the second edge is the connection; the
    # other endpoint is where the trajectory starts (ties fall back to endpoint 0).
    first, second = trajectory_edges[0], trajectory_edges[1]
    start_k = 1 if _dist_to_edge(first[0], second) < _dist_to_edge(first[1], second) else 0

    # Symmetrically, the endpoint of the last edge touching the previous edge is the
    # connection and the other one is the end (ties fall back to endpoint 1).
    last, prev = trajectory_edges[-1], trajectory_edges[-2]
    end_k = 0 if _dist_to_edge(last[1], prev) < _dist_to_edge(last[0], prev) else 1

    return trajectory_edges[0, start_k], trajectory_edges[-1, end_k], first_pair[start_k], last_pair[end_k]

In [None]:
def build_node_sequence(edge_sequence, edges, start_point):
    """Walk an ordered edge sequence from ``start_point`` and return the visited nodes.

    Each edge must share a node with the current position; the walk then moves to that
    edge's other endpoint. The result has ``len(edge_sequence) + 1`` entries:
    ``start_point`` followed by one node per edge.

    Args:
        edge_sequence: Ordered edge indices (e.g. a 1-D integer tensor).
        edges: Sequence of ``(u, v)`` node-index pairs for all edges of the graph.
        start_point: Node index at which the walk starts.

    Returns:
        list: The node sequence, or ``[]`` if ``edge_sequence`` is empty or its edges do
        not form a connected walk starting at ``start_point``.
    """
    # Look up only the edges on this path rather than converting the whole graph's
    # edge list to a tensor on every call.
    path_edges = [edges[int(e)] for e in edge_sequence]
    if not path_edges:
        return []

    node_sequence = [start_point]
    current_node = start_point
    for u, v in path_edges:
        if u == current_node:
            current_node = int(v)
        elif v == current_node:
            current_node = int(u)
        else:
            return []  # gap in the walk
        node_sequence.append(current_node)
    return node_sequence

In [None]:
def _node_pos(node):
    """Coordinates of graph node ``node``, read from the notebook-global ``nodes`` list."""
    return nodes[node][1]['pos']


def _chain_predicted_edges(sample, start_node, edges):
    """Greedily chain the edges of one prediction into a walk starting at ``start_node``.

    Repeatedly takes the first remaining edge of ``sample`` that touches the current
    node, moves to its other endpoint and removes (all copies of) that edge. Stops when
    no edges remain or none of the remaining edges touches the current node.

    Returns:
        tuple: ``(traj, visited)`` — the chained edge ids (ints) and the node reached
        after each of them.
    """
    traj, visited = [], []
    current_node = start_node
    remaining = sample
    while len(remaining) > 0:
        for e in remaining:
            if current_node in edges[e]:
                traj.append(e.item())
                remaining = remaining[remaining != e]
                current_node = edges[e][0] if edges[e][1] == current_node else edges[e][1]
                visited.append(current_node)
                break
        else:
            # No remaining edge connects to the walk; the rest of the sample is ignored.
            break
    return traj, visited


def best_ade_and_fde(batched_preds, batched_gt_futs, batched_gt_hists, edge_coordinates, edges):
    """Best-of-N ADE and FDE of sampled edge predictions against the ground-truth future.

    The future starts at the last node of the ground-truth history
    (``find_trajectory_endpoints``). Each predicted edge set is chained into a walk from
    that node (``_chain_predicted_edges``) and compared node-by-node with the
    ground-truth node sequence (``build_node_sequence``):

    * ADE: for predicted step k, the distance to ground-truth node k (or to the final
      ground-truth node once the prediction is longer). If the prediction is shorter,
      the lengths of the remaining ground-truth steps are added. The sum is divided
      once by ``max(len(pred_walk), len(gt_fut))``.
    * FDE: distance between the last node reached and the final ground-truth node.
    * A prediction that cannot be chained at all gets ADE ``1`` and an FDE equal to the
      distance from the history's end node to the final ground-truth node.

    Trajectories whose ground-truth future is not a connected walk from the history's
    end node are skipped. Node positions come from the notebook-global ``nodes``.

    Args:
        batched_preds: ``batched_preds[b][s][i]`` edge-index tensor predicted by
            sample ``s`` for trajectory ``i`` of batch ``b``.
        batched_gt_futs: ``batched_gt_futs[b][i]`` ground-truth future edges; entries
            equal to -1 are treated as padding.
        batched_gt_hists: ``batched_gt_hists[b][i]`` ground-truth history edges.
        edge_coordinates: Endpoint coordinates of every edge, shape (num_edges, 2, 2).
        edges: ``(u, v)`` node-index pairs of every edge.

    Returns:
        tuple: ``(gt_futs, gt_hists, best_ade_samples, mean_ade, best_fde_samples,
        mean_fde)``. The list outputs are nested per batch and contain only the
        trajectories that were not skipped. If ``preds`` is empty for a trajectory,
        its best ADE/FDE is ``inf``.
    """
    unmatched_ade = 1  # fixed ADE penalty when no predicted edge connects to the history

    ade_list, fde_list = [], []
    best_ade_samples, best_fde_samples = [], []
    gt_futs, gt_hists = [], []
    for batch_idx in tqdm(range(len(batched_gt_futs))):
        preds = batched_preds[batch_idx]
        best_ade_sample_list, best_fde_sample_list = [], []
        gt_fut_list, gt_hist_list = [], []
        for idx in range(len(batched_gt_futs[batch_idx])):
            gt_hist = batched_gt_hists[batch_idx][idx]
            _, _, _, end_point = find_trajectory_endpoints(gt_hist, edge_coordinates)
            gt_fut_ = batched_gt_futs[batch_idx][idx]
            gt_fut = gt_fut_[gt_fut_ != -1]
            gt_fut_nodes = build_node_sequence(gt_fut, edges, end_point)
            if len(gt_fut_nodes) == 0:
                continue  # empty or disconnected ground truth; cannot be scored
            n_gt = len(gt_fut)
            gt_final_pos = _node_pos(gt_fut_nodes[-1])

            # Start from +inf so the best value always comes from a prediction; a finite
            # start value would silently cap the metric.
            best_ade, best_fde = float('inf'), float('inf')
            best_ade_sample, best_fde_sample = torch.tensor([]), torch.tensor([])
            for pred in preds:
                traj, visited = _chain_predicted_edges(pred[idx], end_point, edges)
                if not traj:
                    ade = unmatched_ade
                    fde = torch.norm(_node_pos(end_point) - gt_final_pos)
                else:
                    ade = sum(
                        torch.norm(_node_pos(node) - _node_pos(gt_fut_nodes[min(step, n_gt)]))
                        for step, node in enumerate(visited, start=1)
                    )
                    # Prediction shorter than the ground truth: add the remaining GT step lengths.
                    ade = ade + sum(
                        torch.norm(_node_pos(gt_fut_nodes[i]) - _node_pos(gt_fut_nodes[i + 1]))
                        for i in range(len(traj), n_gt)
                    )
                    # Normalise exactly once by the number of compared steps.
                    ade = ade / max(len(traj), n_gt)
                    fde = torch.norm(_node_pos(visited[-1]) - gt_final_pos)

                if ade < best_ade:
                    best_ade = ade
                    best_ade_sample = torch.tensor(traj)
                if fde < best_fde:
                    best_fde = fde
                    best_fde_sample = torch.tensor(traj)

            ade_list.append(best_ade)
            fde_list.append(best_fde)
            best_ade_sample_list.append(best_ade_sample)
            best_fde_sample_list.append(best_fde_sample)
            gt_fut_list.append(gt_fut_)
            gt_hist_list.append(gt_hist)

        best_ade_samples.append(best_ade_sample_list)
        best_fde_samples.append(best_fde_sample_list)
        gt_futs.append(gt_fut_list)
        gt_hists.append(gt_hist_list)

    # Convert to floats first: a list of plain ints would build a long tensor and
    # torch.mean rejects integer dtypes.
    mean_ade = torch.mean(torch.tensor([float(a) for a in ade_list]))
    mean_fde = torch.mean(torch.tensor([float(f) for f in fde_list]))
    return gt_futs, gt_hists, best_ade_samples, mean_ade, best_fde_samples, mean_fde

In [None]:
def f1_and_tpr(batched_preds, batched_gt_futs, edge_coordinates, threshold=1.0):
    """Best-of-N edge-set F1 score and true-positive rate (recall).

    Predicted and ground-truth futures are compared as *sets* of edges. For every
    trajectory the highest F1 and, independently, the highest recall over all samples
    are taken. Both are then averaged over all trajectories of all batches.

    Args:
        batched_preds: ``batched_preds[b][s][i]`` edge-index tensor of sample ``s`` for
            trajectory ``i`` of batch ``b``.
        batched_gt_futs: ``batched_gt_futs[b][i]`` ground-truth future edges; entries
            equal to -1 are treated as padding and ignored.
        edge_coordinates: Edge coordinate tensor; only ``shape[0]`` (number of edges)
            is used.
        threshold: Unused; kept for backward compatibility.

    Returns:
        tuple: ``(mean_f1, mean_tpr, best_samples)`` where ``best_samples[b][i]`` is the
        highest-F1 sample as an int64 tensor (empty if every sample scored F1 = 0).
    """
    num_edges = edge_coordinates.shape[0]
    f1_scores = []
    tprs = []
    best_samples = []
    for batch_idx in tqdm(range(len(batched_gt_futs))):
        preds = batched_preds[batch_idx]
        best_sample_list = []
        for idx in range(len(batched_gt_futs[batch_idx])):
            gt_fut = batched_gt_futs[batch_idx][idx]
            # Drop -1 padding: indexing with -1 would mark the graph's last edge as ground truth.
            gt_fut = gt_fut[gt_fut != -1]
            gt_fut_bin = torch.zeros(num_edges)
            gt_fut_bin[gt_fut] = 1
            n_true = torch.sum(gt_fut_bin)

            best_f1 = 0
            best_tpr = 0
            best_sample = torch.tensor([])
            for pred in preds:
                sample = pred[idx]
                sample_bin = torch.zeros(num_edges)
                sample_bin[sample] = 1
                tp = torch.sum(gt_fut_bin * sample_bin)
                n_pred = torch.sum(sample_bin)

                precision = tp / n_pred if n_pred > 0 else 0
                recall = tp / n_true if n_true > 0 else 0
                f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0

                if recall > best_tpr:
                    best_tpr = recall
                if f1 > best_f1:
                    best_f1 = f1
                    best_sample = sample

            # One score per trajectory (not one per batch).
            f1_scores.append(float(best_f1))
            tprs.append(float(best_tpr))
            # int64 indices: int16 would overflow for graphs with more than 32767 edges.
            best_sample_list.append(best_sample.to(torch.long))
        best_samples.append(best_sample_list)
    return torch.mean(torch.tensor(f1_scores)), torch.mean(torch.tensor(tprs)), best_samples

In [None]:
# Best-of-N ADE / FDE over all raw samples.
gt_futs, gt_hists, best_ade_samples, ade, best_fde_samples, fde = best_ade_and_fde(
    res['samples_raw'],
    res['ground_truth_fut'],
    res['ground_truth_hist'],
    edge_coordinates,
    edges,
)
for metric_name, metric_value in (("ADE", ade), ("FDE", fde)):
    print(metric_name, format(metric_value, '12f'))

In [None]:
f1, tpr, best_samples = f1_and_tpr(res['samples_raw'], res['ground_truth_fut'], edge_coordinates)
print("Best F1 Score: ", f1.item())
print("Best TPR: ", tpr.item())
# Mean number of edges in the best-F1 sample over all trajectories of all batches.
sample_lengths = [len(sample) for batch in best_samples for sample in batch]
total_len = sum(sample_lengths)
total = len(sample_lengths)
print("Average sample length: ", round(total_len / total, 3))

## Plot Single Sample

In [None]:
def plot_paths(ground_truth_hist, ground_truth_fut, samples, edge_coordinates, num_paths_to_plot=4, random=False, start_id=0, zoom_in=True):
    """Plot history, ground-truth future and one predicted future per trajectory.

    One subplot is filled per trajectory, walking the batches in order. Within each
    batch, plotting starts at ``start_id`` (or at a random offset when ``random`` is
    true) and stops once ``num_paths_to_plot`` subplots are filled.

    Args:
        ground_truth_hist: ``ground_truth_hist[b][i]`` history edge indices.
        ground_truth_fut: ``ground_truth_fut[b][i]`` ground-truth future edge indices.
        samples: ``samples[b][i]`` edge indices of the prediction to draw.
        edge_coordinates: Endpoint coordinates of all edges, shape (num_edges, 2, 2).
        num_paths_to_plot: Number of subplots (one trajectory each).
        random: Use a random start offset per batch instead of ``start_id``.
        start_id: Non-negative start offset within each batch.
        zoom_in: Limit each subplot to the drawn trajectories plus a 10 % margin.

    Edge indices equal to -1 are treated as padding and are not drawn.
    """
    from matplotlib.collections import LineCollection

    fig, axs = plt.subplots(num_paths_to_plot, 1, figsize=(10, 5 * num_paths_to_plot), squeeze=False)
    axs = axs.flatten()  # uniform 1-D indexing even for a single subplot
    background_segments = edge_coordinates.detach().cpu().numpy()

    def edge_ids(edge_indices):
        # 1-D int64 indices with -1 padding removed; a leading dim of size 1 is squeezed.
        if edge_indices.dim() == 2:
            edge_indices = edge_indices.squeeze(0)
        edge_indices = edge_indices.long()
        return edge_indices[edge_indices != -1]

    def zoom_to(ax, edge_indices):
        if edge_indices.numel() == 0:
            return
        coords = edge_coordinates[edge_indices].reshape(-1, 2)
        lo, hi = coords.min(0)[0], coords.max(0)[0]
        margin = (hi - lo) * 0.1
        lo, hi = lo - margin, hi + margin
        ax.set_xlim(lo[0].item(), hi[0].item())
        ax.set_ylim(lo[1].item(), hi[1].item())

    def plot_trajectory(ax, edge_indices, color, label):
        predicted = label == 'Predicted Future'
        # Predictions are dotted and thicker so they stay visible on top of the ground truth.
        style = {'color': color, 'linewidth': 4 if predicted else 2, 'linestyle': ':' if predicted else '-'}
        for k, edge_idx in enumerate(edge_indices):
            edge = edge_coordinates[edge_idx]
            # Label only the first segment so the legend has one entry per trajectory.
            extra = {'label': label} if k == 0 else {}
            ax.plot(edge[:, 0], edge[:, 1], **style, **extra)

    path_count = 0
    for batch_idx in range(len(ground_truth_hist)):
        n_traj = len(ground_truth_hist[batch_idx])
        if random:
            # max(..., 1): randint needs high > low even when the batch is small.
            start_idx = torch.randint(0, max(n_traj - num_paths_to_plot + 1, 1), (1,)).item()
        else:
            start_idx = start_id
        for idx in range(start_idx, n_traj):
            if path_count >= num_paths_to_plot:
                break
            ax = axs[path_count]
            hist = edge_ids(ground_truth_hist[batch_idx][idx])
            fut = edge_ids(ground_truth_fut[batch_idx][idx])
            pred = edge_ids(samples[batch_idx][idx])

            # One LineCollection instead of one Line2D per graph edge per subplot.
            ax.add_collection(LineCollection(background_segments, colors='grey', linewidths=0.5))
            ax.autoscale_view()
            if zoom_in:
                zoom_to(ax, torch.cat([hist, fut, pred]))

            plot_trajectory(ax, hist, 'blue', 'History')
            plot_trajectory(ax, fut, 'green', 'Ground Truth Future')
            plot_trajectory(ax, pred, 'red', 'Predicted Future')

            ax.set_title(f'Trajectory {idx+1} of batch {batch_idx+1}')
            ax.legend(loc='upper left', fontsize=16)
            ax.axis('off')
            path_count += 1

    plt.tight_layout()
    plt.show()

plot_paths(res['ground_truth_hist'], res['ground_truth_fut'], best_samples, edge_coordinates, num_paths_to_plot=70, zoom_in=True)

## Density Plot of Multiple Samples

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torch
import matplotlib.lines as mlines

def plot_path_density(ground_truth_hist, ground_truth_fut, samples, edge_coordinates, num_paths_to_plot=4, random=False, start_id=0, zoom_in=True):
    """Plot history, ground truth and a usage-density map of all predicted samples.

    One subplot is filled per trajectory, walking the batches in order. Within each
    batch, plotting starts at ``start_id`` (or at a random offset when ``random`` is
    true). Predicted edges are drawn red and dashed. Their width and opacity grow with
    the number of samples containing the edge, relative to the most frequent edge.

    Args:
        ground_truth_hist: ``ground_truth_hist[b][i]`` history edge indices.
        ground_truth_fut: ``ground_truth_fut[b][i]`` ground-truth future edge indices.
        samples: ``samples[b][s][i]`` edge indices predicted by sample ``s`` for
            trajectory ``i`` of batch ``b``.
        edge_coordinates: Endpoint coordinates of all edges, shape (num_edges, 2, 2).
        num_paths_to_plot: Number of subplots (one trajectory each).
        random: Use a random start offset per batch instead of ``start_id``.
        start_id: Non-negative start offset within each batch.
        zoom_in: Limit each subplot to all drawn edges plus a 10 % margin.

    Edge indices equal to -1 are treated as padding and are not drawn.
    """
    from collections import Counter
    from matplotlib.collections import LineCollection

    fig, axs = plt.subplots(num_paths_to_plot, 1, figsize=(10, 5 * num_paths_to_plot), squeeze=False)
    axs = axs.flatten()  # uniform 1-D indexing even for a single subplot
    background_segments = edge_coordinates.detach().cpu().numpy()

    def edge_ids(edge_indices):
        # 1-D int64 indices with -1 padding removed; a leading dim of size 1 is squeezed.
        if edge_indices.dim() == 2:
            edge_indices = edge_indices.squeeze(0)
        edge_indices = edge_indices.long()
        return edge_indices[edge_indices != -1]

    def zoom_to(ax, edge_indices):
        if edge_indices.numel() == 0:
            return
        coords = edge_coordinates[edge_indices].reshape(-1, 2)
        lo, hi = coords.min(0)[0], coords.max(0)[0]
        margin = (hi - lo) * 0.1
        lo, hi = lo - margin, hi + margin
        ax.set_xlim(lo[0].item(), hi[0].item())
        ax.set_ylim(lo[1].item(), hi[1].item())

    path_count = 0
    for batch_idx in range(len(ground_truth_hist)):
        batch_samples = samples[batch_idx]
        n_traj = len(ground_truth_hist[batch_idx])
        if random:
            # max(..., 1): randint needs high > low even when the batch is small.
            start_idx = torch.randint(0, max(n_traj - num_paths_to_plot + 1, 1), (1,)).item()
        else:
            start_idx = start_id
        for idx in range(start_idx, n_traj):
            if path_count >= num_paths_to_plot:
                break
            ax = axs[path_count]
            hist = edge_ids(ground_truth_hist[batch_idx][idx])
            fut = edge_ids(ground_truth_fut[batch_idx][idx])
            sample_edges = [edge_ids(sample[idx]) for sample in batch_samples]

            # One LineCollection instead of one Line2D per graph edge per subplot.
            ax.add_collection(LineCollection(background_segments, colors='grey', linewidths=0.5))
            ax.autoscale_view()
            if zoom_in:
                zoom_to(ax, torch.cat([hist, fut, *sample_edges]))

            for edge_indices, color in ((hist, 'blue'), (fut, 'green')):
                for edge_idx in edge_indices:
                    edge = edge_coordinates[edge_idx]
                    ax.plot(edge[:, 0], edge[:, 1], color=color, linewidth=2, linestyle='-')

            # How many samples use each edge; scales width/opacity of the density lines.
            edge_count = Counter(int(e) for edges_of_sample in sample_edges for e in edges_of_sample)
            max_count = max(edge_count.values()) if edge_count else 1
            for edge_idx, count in edge_count.items():
                edge = edge_coordinates[edge_idx]
                share = count / max_count
                ax.plot(edge[:, 0], edge[:, 1], color='red', linewidth=2 + 3 * share, alpha=0.3 + 0.7 * share, linestyle='--')

            legend_handles = [
                mlines.Line2D([], [], color='blue', linestyle='-', markersize=15, label='History'),
                mlines.Line2D([], [], color='green', linestyle='-', markersize=15, label='Ground Truth Future'),
                mlines.Line2D([], [], color='red', linestyle='--', markersize=15, label='Predicted Future'),
            ]
            ax.set_title(f'Trajectory {idx+1} of batch {batch_idx+1}')
            ax.legend(handles=legend_handles, loc='upper left', fontsize=12)
            ax.axis('off')
            path_count += 1

    plt.tight_layout()
    plt.show()

plot_path_density(res['ground_truth_hist'], res['ground_truth_fut'], samples_raw, edge_coordinates, num_paths_to_plot=50, zoom_in=True)

## Plot Individual Samples

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torch

def plot_multiple_paths(ground_truth_hist, ground_truth_fut, samples, edge_coordinates, num_paths_to_plot=4, random=False, start_id=0, zoom_in=True):
    """Plot history, ground truth and every predicted sample of each trajectory.

    One subplot is filled per trajectory, walking the batches in order. Within each
    batch, plotting starts at ``start_id`` (or at a random offset when ``random`` is
    true). Each sample gets its own colour from the ``tab10`` colormap.

    Args:
        ground_truth_hist: ``ground_truth_hist[b][i]`` history edge indices.
        ground_truth_fut: ``ground_truth_fut[b][i]`` ground-truth future edge indices.
        samples: ``samples[b][s][i]`` edge indices predicted by sample ``s`` for
            trajectory ``i`` of batch ``b``.
        edge_coordinates: Endpoint coordinates of all edges, shape (num_edges, 2, 2).
        num_paths_to_plot: Number of subplots (one trajectory each).
        random: Use a random start offset per batch instead of ``start_id``.
        start_id: Non-negative start offset within each batch.
        zoom_in: Limit each subplot to all drawn edges plus a 10 % margin.

    Edge indices equal to -1 are treated as padding and are not drawn.
    """
    from matplotlib.collections import LineCollection

    fig, axs = plt.subplots(num_paths_to_plot, 1, figsize=(10, 5 * num_paths_to_plot), squeeze=False)
    axs = axs.flatten()  # uniform 1-D indexing even for a single subplot
    background_segments = edge_coordinates.detach().cpu().numpy()

    def edge_ids(edge_indices):
        # 1-D int64 indices with -1 padding removed; a leading dim of size 1 is squeezed.
        if edge_indices.dim() == 2:
            edge_indices = edge_indices.squeeze(0)
        edge_indices = edge_indices.long()
        return edge_indices[edge_indices != -1]

    def zoom_to(ax, edge_indices):
        if edge_indices.numel() == 0:
            return
        coords = edge_coordinates[edge_indices].reshape(-1, 2)
        lo, hi = coords.min(0)[0], coords.max(0)[0]
        margin = (hi - lo) * 0.1
        lo, hi = lo - margin, hi + margin
        ax.set_xlim(lo[0].item(), hi[0].item())
        ax.set_ylim(lo[1].item(), hi[1].item())

    def plot_trajectory(ax, edge_indices, color, label):
        # Predicted samples are drawn thicker than the ground truth; all lines are solid.
        lw = 4 if 'Predicted' in label else 2
        for k, edge_idx in enumerate(edge_indices):
            edge = edge_coordinates[edge_idx]
            # Label only the first segment so the legend has one entry per trajectory.
            extra = {'label': label} if k == 0 else {}
            ax.plot(edge[:, 0], edge[:, 1], color=color, linewidth=lw, linestyle='-', **extra)

    path_count = 0
    for batch_idx in range(len(ground_truth_hist)):
        batch_samples = samples[batch_idx]
        n_traj = len(ground_truth_hist[batch_idx])
        if random:
            # max(..., 1): randint needs high > low even when the batch is small.
            start_idx = torch.randint(0, max(n_traj - num_paths_to_plot + 1, 1), (1,)).item()
        else:
            start_idx = start_id
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is the supported form.
        sample_colors = plt.get_cmap('tab10', max(len(batch_samples), 1))
        for idx in range(start_idx, n_traj):
            if path_count >= num_paths_to_plot:
                break
            ax = axs[path_count]
            hist = edge_ids(ground_truth_hist[batch_idx][idx])
            fut = edge_ids(ground_truth_fut[batch_idx][idx])
            preds = [edge_ids(sample[idx]) for sample in batch_samples]

            # One LineCollection instead of one Line2D per graph edge per subplot.
            ax.add_collection(LineCollection(background_segments, colors='grey', linewidths=0.5))
            ax.autoscale_view()
            if zoom_in:
                zoom_to(ax, torch.cat([hist, fut, *preds]))

            plot_trajectory(ax, hist, 'blue', 'History')
            plot_trajectory(ax, fut, 'green', 'Ground Truth Future')
            for j, pred in enumerate(preds):
                plot_trajectory(ax, pred, sample_colors(j), f'Predicted Future {j+1}')

            ax.set_title(f'Trajectory {idx+1} of batch {batch_idx+1}')
            ax.legend(loc='upper left')
            ax.axis('off')
            path_count += 1

    plt.tight_layout()
    plt.show()

plot_multiple_paths(res['ground_truth_hist'], res['ground_truth_fut'], samples_raw, edge_coordinates, num_paths_to_plot=50, zoom_in=True)

In [None]:
def plot_paths_random(ground_truth_hist, ground_truth_fut, samples, edge_coordinates, num_paths_to_plot=4, zoom_in=True, seed=None):
    """Plot history, ground truth and one prediction for randomly drawn trajectories.

    Each subplot shows a trajectory drawn uniformly (with replacement): first a
    non-empty batch, then a trajectory within it.

    Args:
        ground_truth_hist: ``ground_truth_hist[b][i]`` history edge indices.
        ground_truth_fut: ``ground_truth_fut[b][i]`` ground-truth future edge indices.
        samples: ``samples[b][i]`` edge indices of the prediction to draw.
        edge_coordinates: Endpoint coordinates of all edges, shape (num_edges, 2, 2).
        num_paths_to_plot: Number of subplots.
        zoom_in: Limit each subplot to the drawn trajectories plus a 10 % margin.
        seed: Optional seed for a private ``torch.Generator``; ``None`` uses the global RNG.

    Raises:
        ValueError: If every batch is empty.

    Edge indices equal to -1 are treated as padding and are not drawn.
    """
    from matplotlib.collections import LineCollection

    non_empty = [b for b in range(len(ground_truth_hist)) if len(ground_truth_hist[b]) > 0]
    if not non_empty:
        raise ValueError('plot_paths_random: every batch is empty, nothing to plot')
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    fig, axs = plt.subplots(num_paths_to_plot, 1, figsize=(10, 5 * num_paths_to_plot), squeeze=False)
    axs = axs.flatten()  # uniform 1-D indexing even for a single subplot
    background_segments = edge_coordinates.detach().cpu().numpy()

    def edge_ids(edge_indices):
        # 1-D int64 indices with -1 padding removed; a leading dim of size 1 is squeezed.
        if edge_indices.dim() == 2:
            edge_indices = edge_indices.squeeze(0)
        edge_indices = edge_indices.long()
        return edge_indices[edge_indices != -1]

    def zoom_to(ax, edge_indices):
        if edge_indices.numel() == 0:
            return
        coords = edge_coordinates[edge_indices].reshape(-1, 2)
        lo, hi = coords.min(0)[0], coords.max(0)[0]
        margin = (hi - lo) * 0.1
        lo, hi = lo - margin, hi + margin
        ax.set_xlim(lo[0].item(), hi[0].item())
        ax.set_ylim(lo[1].item(), hi[1].item())

    def plot_trajectory(ax, edge_indices, color, label):
        predicted = label == 'Predicted Future'
        # Predictions are dotted and thicker so they stay visible on top of the ground truth.
        style = {'color': color, 'linewidth': 4 if predicted else 2, 'linestyle': ':' if predicted else '-'}
        for k, edge_idx in enumerate(edge_indices):
            edge = edge_coordinates[edge_idx]
            # Label only the first segment so the legend has one entry per trajectory.
            extra = {'label': label} if k == 0 else {}
            ax.plot(edge[:, 0], edge[:, 1], **style, **extra)

    for ax in axs:
        # Drawing among non-empty batches only avoids randint(0, 0) on an empty batch.
        batch_idx = non_empty[torch.randint(0, len(non_empty), (1,), generator=generator).item()]
        idx = torch.randint(0, len(ground_truth_hist[batch_idx]), (1,), generator=generator).item()
        hist = edge_ids(ground_truth_hist[batch_idx][idx])
        fut = edge_ids(ground_truth_fut[batch_idx][idx])
        pred = edge_ids(samples[batch_idx][idx])

        # One LineCollection instead of one Line2D per graph edge per subplot.
        ax.add_collection(LineCollection(background_segments, colors='grey', linewidths=0.5))
        ax.autoscale_view()
        if zoom_in:
            zoom_to(ax, torch.cat([hist, fut, pred]))

        plot_trajectory(ax, hist, 'blue', 'History')
        plot_trajectory(ax, fut, 'green', 'Ground Truth Future')
        plot_trajectory(ax, pred, 'red', 'Predicted Future')

        ax.set_title(f'Trajectory {idx+1} of batch {batch_idx+1}')
        ax.legend(loc='upper left', fontsize=16)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

plot_paths_random(res['ground_truth_hist'], res['ground_truth_fut'], best_samples, edge_coordinates, num_paths_to_plot=10, zoom_in=True)

## Plot Path Density Random

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torch
import matplotlib.lines as mlines

def plot_path_density_random(ground_truth_hist, ground_truth_fut, samples, edge_coordinates, num_paths_to_plot=4, random=False, start_id=0, zoom_in=True, seed=None):
    """Density plot of all predicted samples for randomly drawn trajectories.

    Each subplot shows a trajectory drawn uniformly (with replacement): first a
    non-empty batch, then a trajectory within it. Predicted edges are drawn red and
    dashed. Their width and opacity grow with the number of samples containing the
    edge, relative to the most frequent edge.

    Args:
        ground_truth_hist: ``ground_truth_hist[b][i]`` history edge indices.
        ground_truth_fut: ``ground_truth_fut[b][i]`` ground-truth future edge indices.
        samples: ``samples[b][s][i]`` edge indices predicted by sample ``s`` for
            trajectory ``i`` of batch ``b``.
        edge_coordinates: Endpoint coordinates of all edges, shape (num_edges, 2, 2).
        num_paths_to_plot: Number of subplots.
        random: Unused; kept for signature compatibility with ``plot_path_density``.
        start_id: Unused; kept for signature compatibility with ``plot_path_density``.
        zoom_in: Limit each subplot to all drawn edges plus a 10 % margin.
        seed: Optional seed for a private ``torch.Generator``; ``None`` uses the global RNG.

    Raises:
        ValueError: If every batch is empty.

    Edge indices equal to -1 are treated as padding and are not drawn.
    """
    from collections import Counter
    from matplotlib.collections import LineCollection

    non_empty = [b for b in range(len(ground_truth_hist)) if len(ground_truth_hist[b]) > 0]
    if not non_empty:
        raise ValueError('plot_path_density_random: every batch is empty, nothing to plot')
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    fig, axs = plt.subplots(num_paths_to_plot, 1, figsize=(10, 5 * num_paths_to_plot), squeeze=False)
    axs = axs.flatten()  # uniform 1-D indexing even for a single subplot
    background_segments = edge_coordinates.detach().cpu().numpy()

    def edge_ids(edge_indices):
        # 1-D int64 indices with -1 padding removed; a leading dim of size 1 is squeezed.
        if edge_indices.dim() == 2:
            edge_indices = edge_indices.squeeze(0)
        edge_indices = edge_indices.long()
        return edge_indices[edge_indices != -1]

    def zoom_to(ax, edge_indices):
        if edge_indices.numel() == 0:
            return
        coords = edge_coordinates[edge_indices].reshape(-1, 2)
        lo, hi = coords.min(0)[0], coords.max(0)[0]
        margin = (hi - lo) * 0.1
        lo, hi = lo - margin, hi + margin
        ax.set_xlim(lo[0].item(), hi[0].item())
        ax.set_ylim(lo[1].item(), hi[1].item())

    for ax in axs:
        # Drawing among non-empty batches only avoids randint(0, 0) on an empty batch.
        batch_idx = non_empty[torch.randint(0, len(non_empty), (1,), generator=generator).item()]
        idx = torch.randint(0, len(ground_truth_hist[batch_idx]), (1,), generator=generator).item()
        hist = edge_ids(ground_truth_hist[batch_idx][idx])
        fut = edge_ids(ground_truth_fut[batch_idx][idx])
        sample_edges = [edge_ids(sample[idx]) for sample in samples[batch_idx]]

        # One LineCollection instead of one Line2D per graph edge per subplot.
        ax.add_collection(LineCollection(background_segments, colors='grey', linewidths=0.5))
        ax.autoscale_view()
        if zoom_in:
            zoom_to(ax, torch.cat([hist, fut, *sample_edges]))

        for edge_indices, color in ((hist, 'blue'), (fut, 'green')):
            for edge_idx in edge_indices:
                edge = edge_coordinates[edge_idx]
                ax.plot(edge[:, 0], edge[:, 1], color=color, linewidth=2, linestyle='-')

        # How many samples use each edge; scales width/opacity of the density lines.
        edge_count = Counter(int(e) for edges_of_sample in sample_edges for e in edges_of_sample)
        max_count = max(edge_count.values()) if edge_count else 1
        for edge_idx, count in edge_count.items():
            edge = edge_coordinates[edge_idx]
            share = count / max_count
            ax.plot(edge[:, 0], edge[:, 1], color='red', linewidth=2 + 3 * share, alpha=0.3 + 0.7 * share, linestyle='--')

        legend_handles = [
            mlines.Line2D([], [], color='blue', linestyle='-', markersize=15, label='History'),
            mlines.Line2D([], [], color='green', linestyle='-', markersize=15, label='Ground Truth Future'),
            mlines.Line2D([], [], color='red', linestyle='--', markersize=15, label='Predicted Future'),
        ]
        ax.set_title(f'Trajectory {idx+1} of batch {batch_idx+1}')
        ax.legend(handles=legend_handles, loc='upper left', fontsize=12)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

plot_path_density_random(res['ground_truth_hist'], res['ground_truth_fut'], samples_raw, edge_coordinates, num_paths_to_plot=10, zoom_in=True)