In [2]:
import os
import sys
from pathlib import Path

sys.path.append(str(Path(os.getcwd()).parent))

from copy import deepcopy
import pandas as pd
import yaml
import torch
from ipywidgets import widgets
from matplotlib import pyplot as plt
from seaborn import color_palette

from gpa.common.helpers import plot_bboxes, parse_into_subgraphs
from gpa.datasets.attribution import DetectionGraph
from gpa.datasets.attribution import PriceAttributionDataset
from gpa.configs import TrainingConfig
from gpa.models.attributors import LightningPriceAttributor
from gpa.common.helpers import connect_products_with_nearest_price_tag

In [3]:
def _prompt_for_file(prompt: str, error_message: str) -> Path:
    """Ask for a path repeatedly until it points to an existing regular file.

    Surrounding whitespace and quotes (commonly left over when a path is
    copy-pasted) are stripped and a leading ``~`` is expanded before the
    path is checked.

    Args:
        prompt: Text shown by ``input``.
        error_message: Text printed when the entered path is not a file.

    Returns:
        The validated path.
    """
    while True:
        candidate = Path(input(prompt).strip().strip("'\"")).expanduser()
        # is_file() is already False for paths that do not exist, so no
        # separate exists() check is needed.
        if candidate.is_file():
            return candidate
        print(error_message)


chkp_path = _prompt_for_file(
    "Input a path to the model checkpoint: ",
    "Invalid checkpoint path. Please try again.",
)
config_path = _prompt_for_file(
    "Input a path to the config file: ",
    "Invalid config path. Please try again.",
)

KeyboardInterrupt: Interrupted by user

In [4]:
# Load the training config, the validation split, its product-box table
# (indexed by attributionset_id) and the trained model.
with open(config_path, "r", encoding="utf-8") as config_file:
    config = TrainingConfig(**yaml.safe_load(config_file))
# NOTE(review): the ".." prefix assumes the notebook runs from a subdirectory
# of the repository root (consistent with the sys.path tweak in the imports cell).
dataset_dir = Path("..") / config.dataset_dir / "val"
dataset = PriceAttributionDataset(root=dataset_dir)
products_df = pd.read_csv(
    dataset_dir / "raw" / "product_boxes.csv", index_col="attributionset_id"
)
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
# machine; the plotting code moves the model to CPU anyway.
model = LightningPriceAttributor.load_from_checkpoint(chkp_path, map_location="cpu")

In [5]:
def _with_edges(
    graph: DetectionGraph, edge_index: torch.Tensor, edge_attr: torch.Tensor
) -> DetectionGraph:
    """Return a deep copy of ``graph`` whose ``edge_index``/``edge_attr`` are replaced."""
    new_graph = deepcopy(graph)
    new_graph.edge_index = edge_index
    new_graph.edge_attr = edge_attr
    return new_graph


def _resolve_image_path(graph_id) -> Path:
    """Return the image file of a scene, looked up in the module-level ``products_df``.

    ``products_df.loc[graph_id, "local_path"]`` yields a Series when the index
    label occurs in several rows and a plain scalar when it occurs once; both
    cases are handled (the first row's path is used for the Series case).
    """
    local_path = products_df.loc[graph_id, "local_path"]
    if isinstance(local_path, pd.Series):
        local_path = local_path.iloc[0]
    return dataset_dir.parent / local_path


def _plot_gt_boxes(graph: DetectionGraph, ax, width: int, height: int) -> None:
    """Draw the scene's product (solid) and price (dashed) boxes on ``ax``.

    Nodes are coloured by the subgraph id returned by ``parse_into_subgraphs``
    for the ground-truth product-price edges (presumably connected
    components — TODO confirm); subgraphs made of a single node are drawn white.
    """
    subgraph_of_node = parse_into_subgraphs(
        graph.gt_prod_price_edge_index, graph.num_nodes
    )
    subgraph_ids = torch.unique(subgraph_of_node)
    palette = color_palette(n_colors=len(subgraph_ids))
    for subgraph_id, color in zip(subgraph_ids, palette):
        members = torch.argwhere(subgraph_of_node == subgraph_id).flatten()
        if len(members) == 1:
            color = (1.0, 1.0, 1.0)
        for kind_indices, linestyle in (
            (graph.product_indices, "solid"),
            (graph.price_indices, "dashed"),
        ):
            selected = members[torch.isin(members, kind_indices)]
            # NOTE(review): assumes the first four node features are the box coordinates.
            plot_bboxes(
                graph.x[selected, :4],
                ax=ax,
                color=color,
                linestyle=linestyle,
                width=width,
                height=height,
            )


@torch.inference_mode()
def plot_scene_graph(
    idx: int,
    dataset: PriceAttributionDataset,
    model: LightningPriceAttributor,
    threshold: float = 0.5,
):
    """Plot heuristic, model-pruned and ground-truth graphs of one scene next to its image.

    Panels, left to right:

    1. "Heuristic": every product linked to its nearest price tag
       (``connect_products_with_nearest_price_tag``), wrong edges marked.
    2. "Pruned": the heuristic edges whose model probability exceeds
       ``threshold``, wrong edges marked.
    3. "Actual": the ground-truth product-price edges.
    4. The scene image with boxes coloured by ground-truth subgraph.

    The image location is read from the module-level ``products_df`` and
    ``dataset_dir``.

    Args:
        idx: Index of the scene in ``dataset``.
        dataset: Dataset yielding ``DetectionGraph`` objects.
        model: Trained attributor; switched to eval mode and moved to CPU as a
            side effect.
        threshold: Edges with probability strictly above this value are kept.
    """
    plt.close()
    fig, axs = plt.subplots(1, 4, figsize=(15, 4), width_ratios=[1, 1, 1, 2])
    graph: DetectionGraph = dataset[idx]

    model.eval()
    model.to(device="cpu")
    graph.to(device="cpu")

    gt_edges = graph.gt_prod_price_edge_index
    actual_graph = _with_edges(graph, gt_edges, torch.ones(gt_edges.size(1), 1))

    # NOTE(review): assumes the first two node features are the box centroid.
    src, dst = connect_products_with_nearest_price_tag(
        centroids=graph.x[:, :2],
        product_indices=graph.product_indices,
        price_indices=graph.price_indices,
    )
    heuristic_graph = _with_edges(
        graph, torch.stack([src, dst], dim=0), torch.ones(len(src), 1)
    )

    # The model scores each heuristic (src, dst) candidate edge. Flattening makes
    # the mask valid whether it returns (E,) or (E, 1).
    edge_probs = (
        model(
            x=graph.x,
            edge_index=graph.edge_index,
            edge_attr=graph.edge_attr,
            src=src,
            dst=dst,
        )
        .sigmoid()
        .flatten()
    )
    keep = edge_probs > threshold
    pruned_graph = _with_edges(
        graph,
        torch.stack([src[keep], dst[keep]], dim=0),
        edge_probs[keep].view(-1, 1),
    )

    heuristic_graph.plot(ax=axs[0], prod_price_only=True, mark_wrong_edges=True)
    axs[0].set_title("Heuristic", fontsize=10)
    pruned_graph.plot(ax=axs[1], prod_price_only=True, mark_wrong_edges=True)
    axs[1].set_title("Pruned", fontsize=10)
    actual_graph.plot(ax=axs[2], prod_price_only=True)
    axs[2].set_title("Actual", fontsize=10)

    image = plt.imread(_resolve_image_path(graph.graph_id))
    height, width = image.shape[:2]
    _plot_gt_boxes(graph, axs[3], width=width, height=height)
    axs[3].imshow(image)
    axs[3].axis("off")

    fig.tight_layout()
    fig.set_dpi(100)
    plt.show()

In [6]:
def display_func(idx: int, threshold: float) -> None:
    """Render validation scene ``idx`` with the given pruning ``threshold``."""
    plot_scene_graph(idx=idx, dataset=dataset, model=model, threshold=threshold)


assert len(dataset) > 0, "The validation dataset is empty; nothing to display."

# continuous_update=False: every update runs model inference and reads an
# image, so only re-render when the slider is released.
idx_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(dataset) - 1,
    step=1,
    description="Graph Index",
    continuous_update=False,
)
threshold_slider = widgets.FloatSlider(
    value=0.5,
    min=0,
    max=1,
    step=0.01,
    description="Threshold",
    continuous_update=False,
)
# interact() displays the widget itself and returns the function; the trailing
# semicolon suppresses echoing that function's repr as cell output.
widgets.interact(display_func, idx=idx_slider, threshold=threshold_slider);

interactive(children=(IntSlider(value=0, description='Graph Index', max=188), FloatSlider(value=0.5, descripti…

<function __main__.<lambda>(idx, threshold)>