This notebook visualizes the following graph subsampling strategy: for each input graph, we keep each (product-price) edge independently with probability p and return the subgraph made up of all the original nodes and only the sampled edges.

For the sake of visual clarity, edges between product nodes with the same UPC are not shown (we only visualize product-price edges).

In [58]:
import ipywidgets as widgets
from IPython.display import display, clear_output
from pathlib import Path
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import pandas as pd
import torch

from gpa.datasets.attribution import PriceAttributionDataset, DetectionGraph

In [43]:
# Keep prompting until the user supplies a path that is an existing directory.
dataset_dir = Path(input("Input the dataset directory: "))
while not dataset_dir.is_dir():
    print("Invalid dataset directory. Please try again.")
    dataset_dir = Path(input("Input the dataset directory: "))

# Load the graph dataset and the per-box product table, indexed by attribution-set id.
dataset = PriceAttributionDataset(root=dataset_dir)
product_boxes_csv = dataset_dir / "raw" / "product_boxes.csv"
products_df = pd.read_csv(product_boxes_csv, index_col="attributionset_id")

In [44]:
def get_random_subgraph(
    graph: DetectionGraph,
    p: float = 0.5,
    generator: torch.Generator | None = None,
) -> DetectionGraph:
    """Randomly sample a subgraph from the given graph.

    Each undirected edge of the ground-truth product-price graph is kept
    independently with probability `p` (one Bernoulli(p) draw per edge). The
    returned graph keeps every node of the input; only its `edge_index` /
    `edge_attr` are replaced by the sampled edges, stored in both directions.

    Args:
        graph (DetectionGraph): Original detection graph.
        p (float): Probability of keeping any given edge. Must lie in [0, 1].
            Zero-dimensional tensors (e.g. entries of `torch.linspace`) are
            accepted and converted to float.
        generator (torch.Generator | None): Optional random generator used for the
            edge draws. Pass a seeded generator for reproducible samples; `None`
            uses torch's global RNG.

    Returns:
        DetectionGraph: A new graph whose edges are the sampled product-price edges.
            If the input has no ground-truth edges, the result has an empty edge set.

    Raises:
        ValueError: If `p` is outside [0, 1].
    """
    p = float(p)
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"p must be in [0, 1], got {p}.")

    gt_edges = graph.gt_prod_price_edge_index
    if gt_edges.numel() == 0:
        # No ground-truth edges: go through the same construction path with an empty
        # edge set instead of handing the caller's own object back.
        unique_edges = torch.empty((2, 0), dtype=torch.long, device=gt_edges.device)
    else:
        # Sorting each column makes (u, v) and (v, u) identical, so torch.unique
        # collapses the two directions of an edge into a single undirected edge.
        unique_edges = torch.unique(torch.sort(gt_edges, dim=0).values, dim=1)

    # One independent Bernoulli(p) draw per undirected edge.
    keep_edges = torch.rand(unique_edges.shape[1], generator=generator) < p
    unique_edges = unique_edges[:, keep_edges.to(unique_edges.device)]

    # Re-add the reverse direction so the sampled graph stays symmetric.
    sampled_edge_index = torch.cat([unique_edges, unique_edges.flip(0)], dim=1)

    return DetectionGraph(
        x=graph.x,
        edge_index=sampled_edge_index,
        edge_attr=torch.ones(sampled_edge_index.shape[1]),
        graph_id=graph.graph_id,
        bbox_ids=graph.bbox_ids,
        product_indices=graph.product_indices,
        price_indices=graph.price_indices,
        shared_upc_edge_index=graph.shared_upc_edge_index,
        gt_prod_price_edge_index=graph.gt_prod_price_edge_index,
    )

In [45]:
# Interactive controls: which graph to show, the edge-keep probability, a button that
# triggers a fresh sample, and the output area the plots are rendered into.
idx_selector = widgets.IntSlider(
    description="Graph Index:",
    min=0,
    max=len(dataset) - 1,
    value=0,
)
p_selector = widgets.FloatSlider(
    description="Sampling p",
    min=0,
    max=1.0,
    step=0.01,
    value=0.5,
)
sample_button = widgets.Button(description="Sample Subgraph")
output = widgets.Output()


def on_sample_button_clicked(b):
    """Redraw the full graph, a freshly sampled subgraph, and the source image.

    Click handler for `sample_button`. Reads the currently selected graph index and
    sampling probability from the sliders and renders three panels into `output`:
    the full product-price graph, a random subgraph of it, and the image located via
    the `local_path` column of `products_df`.

    Args:
        b: The clicked button (unused).
    """
    with output:
        plt.close()
        clear_output(wait=True)

        graph = dataset[idx_selector.value]
        subgraph = get_random_subgraph(graph, p=p_selector.value)

        # Show only the ground-truth product-price edges in the "full graph" panel.
        graph.edge_index = graph.gt_prod_price_edge_index
        graph.edge_attr = torch.ones(graph.gt_prod_price_edge_index.shape[1])

        # A list indexer makes .loc return a Series no matter how many rows share
        # this id; a scalar indexer yields a bare value when only one row matches,
        # which would break the subsequent element access.
        local_path = products_df.loc[[graph.graph_id], "local_path"].iloc[0]
        image = plt.imread(dataset_dir.parent / local_path)

        # Create the figure only after all data loaded, so a failed lookup does not
        # leave an empty figure behind.
        fig, axs = plt.subplots(1, 3, figsize=(12, 8))

        graph.plot(ax=axs[0], prod_price_only=True)
        subgraph.plot(ax=axs[1], prod_price_only=True)
        axs[0].set_title("Full Graph")
        axs[1].set_title("Subgraph")

        axs[2].imshow(image)
        axs[2].set_title("Image")
        axs[2].axis("off")

        fig.tight_layout()
        plt.show()


def on_index_change(change):
    """Clear the plot area when the graph-index slider's value changes.

    Args:
        change: Change notification mapping; only notifications whose "name" entry
            is "value" trigger a clear.
    """
    if change["name"] != "value":
        return
    output.clear_output()


# Hook up the callbacks: a click draws a new sample, moving the index slider clears
# the stale plot.
sample_button.on_click(on_sample_button_clicked)
idx_selector.observe(on_index_change, names="value")

# Stack the controls above the output area and render them.
controls = widgets.VBox(children=[idx_selector, p_selector, sample_button, output])
display(controls)

VBox(children=(IntSlider(value=0, description='Graph Index:', max=188), FloatSlider(value=0.5, description='Sa…

In [57]:
from tempfile import TemporaryDirectory
import imageio


def generate_sampling_video(
    graph_idx: int,
    output_path: str | Path,
    fps: int = 2,
    samples_per_p: int = 5,
):
    """Render a video of random subgraphs of one graph for increasing `p`.

    For each keep-probability in 0.1, 0.2, ..., 0.9, `samples_per_p` subgraphs are
    drawn with `get_random_subgraph`. Every frame shows three panels: the full
    product-price graph, the sampled subgraph, and the source image. Frames are
    written as PNGs to a temporary directory and appended to the video at
    `output_path`.

    Relies on the notebook-level `dataset`, `dataset_dir` and `products_df`.

    Args:
        graph_idx (int): Index of the graph in `dataset`.
        output_path (str | Path): Destination of the video file.
        fps (int): Frames per second of the resulting video.
        samples_per_p (int): Number of subgraphs (frames) drawn per value of `p`.
    """
    # The graph, its ground-truth view and the image are the same for every frame,
    # so load them once instead of once per frame.
    graph = dataset[graph_idx]
    graph.edge_index = graph.gt_prod_price_edge_index
    graph.edge_attr = torch.ones(graph.gt_prod_price_edge_index.shape[1])

    # A list indexer makes .loc return a Series no matter how many rows share this
    # id; a scalar indexer yields a bare value when only one row matches, which
    # would break the subsequent element access.
    local_path = products_df.loc[[graph.graph_id], "local_path"].iloc[0]
    image = plt.imread(dataset_dir.parent / local_path)

    p_vals = torch.linspace(0.1, 0.9, 9).tolist()
    with imageio.get_writer(
        output_path, fps=fps
    ) as writer, TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)
        frame_idx = 0
        for p in p_vals:
            for _ in range(samples_per_p):
                subgraph = get_random_subgraph(graph, p=p)
                fname = tmpdir / f"frame_{frame_idx:04d}.png"

                fig, axs = plt.subplots(1, 3, figsize=(12, 4))
                try:
                    graph.plot(ax=axs[0], prod_price_only=True)
                    subgraph.plot(ax=axs[1], prod_price_only=True)
                    axs[0].set_title("Graph")
                    axs[1].set_title(f"Subgraph (p={p:.1f})")
                    axs[2].imshow(image)
                    axs[2].axis("off")
                    fig.savefig(fname)
                finally:
                    # Always release the figure, even if plotting or saving failed.
                    plt.close(fig)

                writer.append_data(imageio.v2.imread(fname))
                frame_idx += 1

In [59]:
# Render one sampling video per hand-picked graph index.
indices_of_interest = [0, 1, 4, 12, 15]
for idx in tqdm(indices_of_interest):
    video_path = f"random_subgraph_sampling_{idx}.mp4"
    generate_sampling_video(idx, video_path)

  0%|          | 0/5 [00:00<?, ?it/s]