# In [2]:
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch

def _to_numpy(x):
    """Return ``x`` as a NumPy array, detaching torch tensors and moving them to the CPU."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _valid_edge_indices(edge_indices):
    """Flatten ``edge_indices`` to a 1-D int64 array and drop negative entries.

    Accepts Python lists, NumPy arrays or torch tensors of any rank. Negative
    entries (presumably padding markers) are removed, as the original code did
    with its ``>= 0`` mask.
    """
    idx = _to_numpy(edge_indices).astype(np.int64, copy=False).ravel()
    return idx[idx >= 0]


def _draw_background(ax, coords):
    """Draw every edge of the graph as thin light-grey lines in one collection.

    A single ``LineCollection`` replaces one ``ax.plot`` call per edge, which is
    much cheaper when the graph is large and the figure has many panels.
    """
    from matplotlib.collections import LineCollection

    ax.add_collection(LineCollection(coords, colors='lightgrey', linewidths=0.5, zorder=1))
    # Make sure the view covers the collection even if no other artist
    # triggers autoscaling in this panel.
    ax.autoscale_view()


def _plot_edges(ax, coords, edge_indices, color, label, linestyle='-', lw=2, zorder=2):
    """Plot the edges selected by ``edge_indices``; only the first line carries ``label``.

    Labelling a single line keeps each label to one legend entry. Nothing is
    drawn when ``edge_indices`` is empty.
    """
    for k, edge in enumerate(coords[edge_indices]):
        ax.plot(edge[:, 0], edge[:, 1], color=color, linewidth=lw, linestyle=linestyle,
                label=label if k == 0 else '_nolegend_', zorder=zorder)


def _zoom_to_edges(ax, coords, edge_indices):
    """Zoom ``ax`` onto the given edges with a 20% margin; show the whole graph if there are none."""
    if edge_indices.size > 0:
        pts = coords[edge_indices][..., :2].reshape(-1, 2)
        (xmin, ymin), (xmax, ymax) = pts.min(axis=0), pts.max(axis=0)
        x_margin = (xmax - xmin) * 0.2
        y_margin = (ymax - ymin) * 0.2
        ax.set_xlim(xmin - x_margin, xmax + x_margin)
        ax.set_ylim(ymin - y_margin, ymax + y_margin)
    else:
        ax.set_xlim(coords[:, :, 0].min(), coords[:, :, 0].max())
        ax.set_ylim(coords[:, :, 1].min(), coords[:, :, 1].max())


def _draw_separators(fig, axs):
    """Draw light-grey lines midway between neighbouring grid columns and rows.

    The lines live on a frameless overlay axes spanning the whole figure, whose
    data coordinates (0..1) coincide with figure-fraction coordinates.
    """
    num_rows, num_cols = axs.shape
    overlay = fig.add_axes([0, 0, 1, 1], frameon=False)
    overlay.set_xticks([])
    overlay.set_yticks([])
    overlay.set_xlim(0, 1)
    overlay.set_ylim(0, 1)

    for col in range(1, num_cols):
        left, right = axs[0, col - 1].get_position(), axs[0, col].get_position()
        overlay.axvline(x=(left.x1 + right.x0) / 2, color='lightgrey', linewidth=1)

    for row in range(1, num_rows):
        upper, lower = axs[row - 1, 0].get_position(), axs[row, 0].get_position()
        overlay.axhline(y=(upper.y0 + lower.y1) / 2, color='lightgrey', linewidth=1)


def plot_partial_denoising(ground_truth_hist, ground_truth_fut, samples, edge_coordinates,
                           num_paths_to_plot=2, zoom_in=True, save_path='trajectory_grid.png',
                           sample_indices=(0, 2, 4, 6, 8, 10)):
    """
    Plot intermediate denoising snapshots as a grid: one column per trajectory,
    one row per entry of ``sample_indices``.

    Each panel shows the whole graph in light grey. On top of it are the
    ground-truth history (blue), the ground-truth future (green) and the edges
    of one snapshot (orange, thicker). The panel is annotated with
    ``t = 100 - 10 * sample_idx``.

    Parameters:
    - ground_truth_hist: Sequence with one entry per trajectory. Each entry holds the
      ground-truth history edge indices (list, array or tensor).
    - ground_truth_fut: Sequence with one entry per trajectory. Each entry holds the
      ground-truth future edge indices.
    - samples: Sequence with one entry per trajectory. ``samples[p][i]`` holds the edge
      indices of snapshot ``i`` for trajectory ``p``. It must contain every index in
      ``sample_indices`` (0..10 with the default).
    - edge_coordinates: Array/tensor of shape ``(num_edges, num_points, 2)``, where
      ``edge_coordinates[e]`` is the polyline of edge ``e``.
    - num_paths_to_plot: Number of trajectories, i.e. grid columns.
    - zoom_in: If True, zoom each panel onto its trajectories, or onto the full graph
      extent when the panel has no edges. If False, keep matplotlib's autoscaled view.
    - save_path: File path to save the figure.
    - sample_indices: Snapshot indices to show, one per grid row.

    Negative edge indices are ignored everywhere.

    Note: ``plt.style.use`` / ``plt.rcParams.update`` modify matplotlib's global
    state, so the style also applies to figures created afterwards.

    Raises:
    - ValueError: if ``num_paths_to_plot`` < 1 or ``sample_indices`` is empty.
    """
    # Set the plotting style and parameters (global, see note above).
    plt.style.use('seaborn-v0_8-paper')
    plt.rcParams.update({
        'axes.labelsize': 14,
        'axes.titlesize': 14,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'legend.fontsize': 11,
        'axes.grid': False,
        'grid.alpha': 0.4,
        'lines.linewidth': 2,
    })

    sample_indices = list(sample_indices)
    num_cols = num_paths_to_plot      # one column per trajectory
    num_rows = len(sample_indices)    # one row per denoising snapshot
    if num_cols < 1 or num_rows < 1:
        raise ValueError("need num_paths_to_plot >= 1 and a non-empty sample_indices")

    # Convert once; every panel indexes the same coordinate array.
    coords = _to_numpy(edge_coordinates)

    # squeeze=False keeps axs 2-D even for a single row or column.
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(4 * num_cols, 2.3 * num_rows),
                            squeeze=False)
    # Leave room at the top for the figure-level legend.
    fig.subplots_adjust(top=0.96, bottom=0.0, wspace=0.2, hspace=0.05)

    for path_idx in range(num_paths_to_plot):
        # Ground truth is the same for every snapshot of this trajectory.
        history = _valid_edge_indices(ground_truth_hist[path_idx])
        future = _valid_edge_indices(ground_truth_fut[path_idx])

        for row, sample_idx in enumerate(sample_indices):
            ax = axs[row, path_idx]
            sample_edges = _valid_edge_indices(samples[path_idx][sample_idx])

            _draw_background(ax, coords)
            _plot_edges(ax, coords, history, color='blue', label='History', lw=2)
            _plot_edges(ax, coords, future, color='green', label='Future', lw=2)
            _plot_edges(ax, coords, sample_edges, color='orange', label='Sampled Edges',
                        linestyle='-', lw=3)

            if zoom_in:
                _zoom_to_edges(ax, coords, np.concatenate([history, future, sample_edges]))

            # Remove axes ticks and labels
            ax.set_xticks([])
            ax.set_yticks([])
            ax.axis('off')

            t = 100 - 10 * sample_idx
            ax.annotate(f"t = {t}", xy=(0.05, 0.95), xycoords='axes fraction', fontsize=12, color='black',
                        verticalalignment='top', bbox=dict(boxstyle="round,pad=0.3", edgecolor='gray', facecolor='white', alpha=0.5))

    # Separators are drawn once, after all panels exist.
    _draw_separators(fig, axs)

    # One legend for the whole figure, collected from every panel. The first
    # handle seen for each label is kept, in order of first appearance.
    unique = {}
    for panel in axs.flat:
        for handle, label in zip(*panel.get_legend_handles_labels()):
            unique.setdefault(label, handle)
    if unique:
        fig.legend(list(unique.values()), list(unique.keys()), loc='upper center', ncol=3,
                   bbox_to_anchor=(0.5, 1.02), frameon=True, fontsize=11)

    # Save and show the figure
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

#ground_truth_hist = [hist_1, hist_1]
#ground_truth_fut = [fut_1, fut_1]
#samples = [samples_1_uniform, samples_1_prior]
#plot_partial_denoising(ground_truth_hist, ground_truth_fut, samples, edge_coordinates,
#                      num_paths_to_plot=2, zoom_in=False, save_path='partial_denoising_comb.png')