This notebook visualizes the following graph sequencing strategy: for each price tag with at least one connected product, add its connected products one at a time, nearest first, ordered by the Euclidean distance between the price tag's centroid and each product's centroid.

For visual clarity, edges between product nodes that share a UPC are not drawn; only product-price edges are visualized.

In [1]:
from typing import Iterator
import ipywidgets as widgets
from IPython.display import display, clear_output
from pathlib import Path
from tempfile import TemporaryDirectory
from tqdm.notebook import tqdm

import imageio
import matplotlib.pyplot as plt
import pandas as pd
import torch

from gpa.datasets.attribution import PriceAttributionDataset, DetectionGraph

In [2]:
# Keep prompting until the user supplies a path to an existing directory.
# (`Path.is_dir()` is already False for paths that do not exist.)
dataset_dir = Path(input("Input the dataset directory: "))
while not dataset_dir.is_dir():
    print("Invalid dataset directory. Please try again.")
    dataset_dir = Path(input("Input the dataset directory: "))

# Load the attribution dataset rooted at that directory, together with the raw
# product-box table indexed by attribution set id (used later to look up the
# image path that belongs to each graph).
dataset = PriceAttributionDataset(root=dataset_dir)
products_df = pd.read_csv(
    dataset_dir.joinpath("raw", "product_boxes.csv"),
    index_col="attributionset_id",
)

In [3]:
def order_prod_nodes_per_price(graph: DetectionGraph) -> dict[int, list[int]]:
    """Rank each price tag's ground-truth products from nearest to farthest.

    Distances are Euclidean (p=2) and are computed on the first two feature
    columns of `graph.x` (the notebook intro describes these as centroids).

    Args:
        graph (DetectionGraph): The graph to process.

    Returns:
        dict[int, list[int]]: For every price tag node index, the product node indices that share a ground-truth edge with it, sorted by increasing distance to the price tag. Price tags without any ground-truth product map to an empty list.
    """
    # Ground-truth (source, target) pairs as a set for constant-time lookups.
    linked_pairs = {tuple(pair) for pair in graph.gt_prod_price_edge_index.T.tolist()}

    # Pairwise price-to-product distances: one row per price tag.
    price_xy = graph.x[graph.price_indices, :2]
    product_xy = graph.x[graph.product_indices, :2]
    dist_matrix = torch.cdist(price_xy, product_xy, p=2)

    # Row i lists *all* product node indices, closest to price tag i first.
    nearest_first = torch.argsort(dist_matrix, dim=-1)
    ranked_products = graph.product_indices[nearest_first]

    def is_linked(price_idx: int, prod_idx: int) -> bool:
        # The ground-truth edge may be stored in either orientation.
        return (price_idx, prod_idx) in linked_pairs or (
            prod_idx,
            price_idx,
        ) in linked_pairs

    return {
        price_idx: [prod_idx for prod_idx in ranked if is_linked(price_idx, prod_idx)]
        for price_idx, ranked in zip(
            graph.price_indices.tolist(), ranked_products.tolist()
        )
    }


def graph_to_seq(graph: DetectionGraph) -> Iterator[DetectionGraph]:
    """Convert a graph into a sequence of subgraphs by incrementally adding edges
    between price tags and their associated products.

    The first yielded graph (t=0) contains only the shared-UPC edges. At every
    later step t, each price tag that still has an unvisited product adds its
    t-th nearest ground-truth product, so all price tags grow in lockstep.
    Each price-product connection is added in both directions.

    Args:
        graph (DetectionGraph): Original detection graph.

    Yields:
        Iterator[DetectionGraph]: Sequence of subgraphs with progressively more edges. Every yielded graph reuses the node features and metadata of `graph`, carries its own copy of the edge index, and has an all-ones `edge_attr` with one entry per edge.
    """
    prod_sequences = order_prod_nodes_per_price(graph)
    # `default=0` keeps a graph with no price tags from raising on an empty
    # max(); such a graph simply yields the single t=0 snapshot.
    max_seq_len = max(map(len, prod_sequences.values()), default=0)
    curr_graph_edge_index = graph.shared_upc_edge_index

    for t in range(max_seq_len + 1):
        if t > 0:  # At t=0, we just take the curr_graph_edge_index as is.
            step = t - 1
            new_edges = []
            for price_idx, prod_seq in prod_sequences.items():
                if step < len(prod_seq):
                    prod_idx = prod_seq[step]
                    # Add the connection in both directions. (Previously the
                    # second tuple was a product self-loop, not the reverse
                    # price-product edge.)
                    new_edges.extend([(price_idx, prod_idx), (prod_idx, price_idx)])
            if new_edges:
                # List of (src, dst) pairs -> 2 x E tensor, appended column-wise.
                new_edges = torch.tensor(new_edges).T
                curr_graph_edge_index = torch.cat(
                    [curr_graph_edge_index, new_edges], dim=1
                )

        yield DetectionGraph(
            x=graph.x,
            # Clone so that each snapshot owns its edge index.
            edge_index=curr_graph_edge_index.clone(),
            edge_attr=torch.ones(curr_graph_edge_index.shape[1]),
            graph_id=graph.graph_id,
            bbox_ids=graph.bbox_ids,
            product_indices=graph.product_indices,
            price_indices=graph.price_indices,
            shared_upc_edge_index=graph.shared_upc_edge_index,
            gt_prod_price_edge_index=graph.gt_prod_price_edge_index,
        )

In [4]:
def interactive_graph_viewer(graph_sequence: Iterator[DetectionGraph]):
    """Display a widget for stepping back and forth through a graph sequence.

    Graphs are pulled from `graph_sequence` lazily, only when the user moves
    forward past what has already been loaded, and are cached so that moving
    backward never re-consumes the iterator. Each step shows the graph plot
    (product-price edges only) next to the image the graph was built from.

    Args:
        graph_sequence (Iterator[DetectionGraph]): Sequence of graphs to browse.
    """
    exhausted = object()  # Sentinel marking the end of `graph_sequence`.

    class Viewer:
        def __init__(self):
            self.graphs: list[DetectionGraph] = []
            self.idx = 0
            self.output = widgets.Output()

            self.back_button = widgets.Button(description="<")
            self.back_button.on_click(self.on_back)
            self.forward_button = widgets.Button(description=">")
            self.forward_button.on_click(self.on_forward)
            self.controls = widgets.HBox([self.back_button, self.forward_button])

            # Render the first graph before the widget tree is displayed.
            self.load_graph(0)
            self.show_graph()
            display(widgets.VBox([self.controls, self.output]))

        def load_graph(self, target_idx: int):
            """Consume the iterator until `target_idx` is cached or it runs dry."""
            while len(self.graphs) <= target_idx:
                upcoming = next(graph_sequence, exhausted)
                if upcoming is exhausted:
                    return
                self.graphs.append(upcoming)

        def show_graph(self):
            """Redraw the output area with the graph at the current index."""
            with self.output:
                clear_output(wait=True)
                graph = self.graphs[self.idx]

                # Image paths in the product table are relative to the parent
                # of the dataset directory.
                relative_path = products_df.loc[graph.graph_id]["local_path"].values[0]
                image = plt.imread(dataset_dir.parent / relative_path)

                fig, (graph_ax, image_ax) = plt.subplots(1, 2, figsize=(10, 4))
                graph.plot(ax=graph_ax, prod_price_only=True)
                graph_ax.set_title(graph.graph_id, fontsize=10)
                image_ax.imshow(image)
                image_ax.axis("off")
                fig.tight_layout()
                fig.set_dpi(100)
                plt.show()

        def on_forward(self, _):
            next_idx = self.idx + 1
            self.load_graph(next_idx)
            if next_idx >= len(self.graphs):
                return  # Already at the last graph of the sequence.
            self.idx = next_idx
            self.show_graph()

        def on_back(self, _):
            if self.idx <= 0:
                return  # Already at the first graph.
            self.idx -= 1
            self.show_graph()

    Viewer()


def interactive_graph_selector():
    """Display a slider for choosing a dataset graph, plus a viewer for it.

    Whenever the slider value changes, the output area is cleared and a fresh
    step-through viewer is built for the graph sequence of the selected dataset
    entry. The viewer for the initial slider value is created immediately.
    """
    output = widgets.Output()
    index_slider = widgets.IntSlider(
        description="Graph Index:",
        value=0,
        min=0,
        max=len(dataset) - 1,
        step=1,
        continuous_update=False,  # Fire only once the slider is released.
    )

    def on_index_change(change):
        if change["name"] != "value":  # only respond to actual value changes
            return
        output.clear_output()
        with output:
            selected_graph = dataset[change["new"]]
            interactive_graph_viewer(graph_to_seq(selected_graph))

    index_slider.observe(on_index_change, names="value")
    # Populate the output for the initial slider position by hand, since no
    # change event has been fired yet.
    on_index_change({"name": "value", "new": index_slider.value})

    display(widgets.VBox([index_slider, output]))

In [None]:
# Launch the interactive UI: a dataset-index slider plus the step-through
# viewer for the selected graph's sequence.
interactive_graph_selector()

In [6]:
def generate_graph_sequence_video(
    graph_sequence: Iterator[DetectionGraph], output_path: str, fps: int = 2
):
    """Render a graph sequence to a video file, one frame per graph.

    Each frame shows the graph plot (product-price edges only) next to the
    image the graph was built from. Frames are written as PNG files into a
    temporary directory, read back, and appended to the video; the temporary
    directory is removed when the function finishes.

    Args:
        graph_sequence (Iterator[DetectionGraph]): Graphs to render, in order.
        output_path (str): Destination path of the video file.
        fps (int, optional): Frames per second of the video. Defaults to 2.
    """
    with TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)
        with imageio.get_writer(output_path, fps=fps) as writer:
            for i, graph in enumerate(graph_sequence):
                # Image paths in the product table are relative to the parent
                # of the dataset directory.
                image_path = (
                    dataset_dir.parent
                    / products_df.loc[graph.graph_id]["local_path"].values[0]
                )
                image = plt.imread(image_path)

                fig, axs = plt.subplots(1, 2, figsize=(9, 4))
                try:
                    graph.plot(ax=axs[0], prod_price_only=True)
                    axs[0].set_title(graph.graph_id, fontsize=10)
                    axs[1].imshow(image)
                    axs[1].axis("off")
                    fig.tight_layout()
                    fig.set_dpi(100)

                    frame_file = tmpdir / f"frame_{i:04d}.png"
                    fig.savefig(frame_file)
                finally:
                    # Always release the figure, even if plotting or saving
                    # fails, so an error does not leave it open in pyplot.
                    plt.close(fig)
                writer.append_data(imageio.v2.imread(frame_file))

    print(f"Video saved to {output_path}")

In [None]:
# Render one video per selected dataset index. Each video is written to the
# current working directory as euclidean_<idx>.mp4 at 10 frames per second,
# with a notebook progress bar tracking the loop.
indices_of_interest = [0, 1, 4, 12, 15]
for idx in tqdm(indices_of_interest):
    generate_graph_sequence_video(
        graph_to_seq(dataset[idx]), f"euclidean_{idx}.mp4", fps=10
    )