This notebook is an interactive application for viewing and exploring a price attribution dataset. Once you specify the directory where the dataset is located, you can view each graph in the dataset, annotated with the ground-truth product-price connections.

Two views are displayed:

1. A graph view, where each node is a product or price tag detection, and the depicted edges connect product and price nodes that have been labeled as part of the same pricing group.
2. A visual view, where we show the display image the detections were computed from, with colors indicating which products/prices are in the same labeled pricing group. Note that in this view, any individual product/price bounding boxes (bboxes) that are not in a pricing group are colored white. Product bboxes are plotted with solid lines, while price bboxes are plotted with dashed lines.

In [None]:
import os
import sys
from pathlib import Path

sys.path.append(str(Path(os.getcwd()).parent))

import pandas as pd
import torch
from ipywidgets import widgets
from matplotlib import pyplot as plt
from seaborn import color_palette

from gpa.common.helpers import plot_bboxes, parse_into_subgraphs
from gpa.datasets.detection_graphs import DetectionGraph
from gpa.datasets.detection_graphs import PriceAttributionDataset

In [None]:
# Keep prompting until the user supplies an existing directory.
while True:
    raw_dir = input("Input the dataset directory: ").strip()
    # An empty entry would become Path("."), silently selecting the current
    # working directory, so treat it as invalid instead.
    if raw_dir:
        # Expand "~" and make the path absolute (without resolving symlinks) so
        # that `dataset_dir.parent` is meaningful even for relative input like ".".
        dataset_dir = Path(os.path.abspath(Path(raw_dir).expanduser()))
        if dataset_dir.is_dir():
            break
    print("Invalid dataset directory. Please try again.")

In [None]:
# Load the graph dataset rooted at the chosen directory.
dataset = PriceAttributionDataset(root=dataset_dir)
# Product box table, indexed by attribution set id; it is used later to look up
# the "local_path" of the image belonging to a graph.
products_df = pd.read_csv(
    dataset_dir.joinpath("raw", "product_boxes.csv"),
    index_col="attributionset_id",
)

In [None]:
def plot_scene_graph(idx: int, dataset: PriceAttributionDataset):
    """Plot the graph view and the image view of one dataset graph side by side.

    The left axis shows the graph drawn with its ground-truth product-price
    edges. The right axis shows the source image with product bboxes (solid
    lines) and price bboxes (dashed lines), colored per connected group; a
    group consisting of a single node is drawn in white.

    Relies on the notebook-level `dataset_dir` and `products_df` to locate the
    image file for the graph.

    Args:
        idx: Index of the graph within `dataset`.
        dataset: Dataset the graph is read from.

    Raises:
        KeyError: If the graph's id has no row in `products_df`.
    """
    plt.close()
    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    graph: DetectionGraph = dataset[idx]

    # Replace the graph's edges with the ground-truth product-price edges,
    # giving every edge a weight of one, before drawing the graph view.
    graph.edge_index = graph.gt_prod_price_edge_index
    graph.edge_attr = torch.ones(graph.edge_index.shape[1])
    graph.plot(ax=axs[0], prod_price_only=True, mark_wrong_edges=True)
    axs[0].set_title(graph.graph_id, fontsize=10)

    # Select with a list label so the result is always a Series of paths, even
    # when the id matches exactly one row (a scalar label would then yield a
    # plain string and break the positional lookup).
    local_path = products_df.loc[[graph.graph_id], "local_path"].iloc[0]
    image_path = dataset_dir.parent / local_path
    image = plt.imread(image_path)

    height, width = image.shape[:2]

    # One id per node (presumably its connected group under the ground-truth
    # edges -- confirm against `parse_into_subgraphs`); one color per id.
    subgraph_indices = parse_into_subgraphs(graph.edge_index, graph.num_nodes)
    subgraph_ids = torch.unique(subgraph_indices)
    colors = color_palette(n_colors=len(subgraph_ids))
    for i, color in zip(subgraph_ids, colors):
        node_indices = torch.argwhere(subgraph_indices == i).flatten()
        if len(node_indices) == 1:
            # A lone node is not part of any pricing group: draw it in white.
            color = (1., 1., 1.)
        product_indices = node_indices[torch.isin(node_indices, graph.product_indices)]
        price_indices = node_indices[torch.isin(node_indices, graph.price_indices)]
        # The first four feature columns are passed as the box coordinates.
        plot_bboxes(graph.x[product_indices, :4], ax=axs[1], color=color, linestyle="solid", width=width, height=height)
        plot_bboxes(graph.x[price_indices, :4], ax=axs[1], color=color, linestyle="dashed", width=width, height=height)

    axs[1].imshow(image)
    axs[1].axis("off")

    fig.tight_layout()
    fig.set_dpi(100)
    plt.show()

In [None]:
def display_func(idx):
    """Render the graph at position `idx` of the loaded `dataset`."""
    plot_scene_graph(idx=idx, dataset=dataset)


idx_slider = widgets.IntSlider(value=0, min=0, max=len(dataset)-1, step=1, description="Graph Index")
# `interact` renders the slider and output itself and returns the function it
# was given, so wrapping it in `display(...)` only adds that function's repr to
# the cell output. The trailing semicolon suppresses the returned value.
widgets.interact(display_func, idx=idx_slider);