In [None]:
from pathlib import Path

from copy import deepcopy
import pandas as pd
import yaml
import torch
from ipywidgets import widgets
from matplotlib import pyplot as plt
from seaborn import color_palette
from IPython.display import display

from gpa.common.helpers import plot_bboxes, parse_into_subgraphs
from gpa.datasets.attribution import DetectionGraph
from gpa.datasets.attribution import PriceAttributionDataset
from gpa.configs import TrainingConfig
from gpa.models.attributors import LightningPriceAttributor

In [None]:
chkp_root = Path("../chkp")
# Assumed to mirror chkp_root's directory layout, with a config.yaml next to each run's checkpoints.
logs_root = Path("../logs")

# Sorted so the option order (and therefore the default selection) is stable across runs.
candidate_chkps = sorted(chkp_root.rglob("*.ckpt"))

if not candidate_chkps:
    raise FileNotFoundError("No checkpoint files found.")

dropdown = widgets.Dropdown(
    options=[(str(p.relative_to(chkp_root)), p) for p in candidate_chkps],
    description='Checkpoint:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='60%')
)
# Shows status messages and tracebacks from the dropdown callback below.
output = widgets.Output()


def load_checkpoint(chkp_path: Path) -> None:
    """Load a checkpoint, its training config, and the matching validation data.

    Sets the notebook globals `model`, `dataset`, `dataset_dir` and `products_df`,
    which the plotting cells read.

    Args:
        chkp_path: a `.ckpt` file located under `chkp_root`.

    Raises:
        FileNotFoundError: if the run's `config.yaml` does not exist under `logs_root`.
    """
    global model, dataset, dataset_dir, products_df
    # ../chkp/<run>/<file>.ckpt -> ../logs/<run>/config.yaml. Only the root directory
    # is swapped; a plain str.replace("chkp", "logs") would also rewrite run/file names.
    config_path = (logs_root / chkp_path.relative_to(chkp_root)).with_name("config.yaml")
    with open(config_path, "r") as f:
        config = TrainingConfig(**yaml.safe_load(f))
    dataset_dir = Path("..") / config.dataset_dir / "val"
    dataset = PriceAttributionDataset(dataset_dir)
    products_df = pd.read_csv(dataset_dir / "raw" / "product_boxes.csv", index_col="attributionset_id")
    model = LightningPriceAttributor.load_from_checkpoint(chkp_path).eval()


def on_dropdown_change(change):
    """Reload model and data whenever a different checkpoint is picked in `dropdown`."""
    if change["name"] == "value" and change["type"] == "change":
        output.clear_output()
        # Inside `with output:` any exception is rendered in the output widget
        # instead of being lost in the widget front end.
        with output:
            load_checkpoint(change["new"])
            print(f"Loaded {change['new'].relative_to(chkp_root)}")


dropdown.observe(on_dropdown_change, names="value")
display(dropdown, output)

# The observer only fires on *changes*, so load the default selection explicitly;
# otherwise `model`/`dataset` are undefined on Restart & Run All.
load_checkpoint(dropdown.value)

In [None]:
def _predict_edge_probs(graph: "DetectionGraph", model: "LightningPriceAttributor"):
    """Score every candidate product->price edge of `graph` with `model`.

    Returns:
        (src, dst, probs): `src`/`dst` enumerate every (product, price) node pair,
        i.e. the Cartesian product of `graph.product_indices` and
        `graph.price_indices`, and `probs` is the sigmoid of the model output
        for those pairs.
    """
    src, dst = torch.cartesian_prod(graph.product_indices, graph.price_indices).T
    if model.use_visual_info:
        node_features = graph.x
    else:
        # Without visual info the model only receives the 4 bbox columns plus the
        # last feature column (presumably a node-type flag — TODO confirm).
        node_features = torch.cat([graph.x[:, :4], graph.x[:, -1].view(-1, 1)], dim=1)
    logits = model.forward(
        x=node_features,
        edge_index=graph.edge_index,
        z=graph.global_embedding,
        src=src,
        dst=dst,
        cluster_assignment=graph.upc_clusters,
    )
    return src, dst, logits.sigmoid()


def _plot_gt_subgraph_boxes(graph: "DetectionGraph", ax, width: int, height: int) -> None:
    """Draw ground-truth product/price groups on `ax` (products solid, prices dashed)."""
    subgraph_indices = parse_into_subgraphs(graph.gt_prod_price_edge_index, graph.num_nodes)
    subgraph_ids = torch.unique(subgraph_indices)
    colors = color_palette(n_colors=len(subgraph_ids))
    for subgraph_id, color in zip(subgraph_ids, colors):
        node_indices = torch.argwhere(subgraph_indices == subgraph_id).flatten()
        if len(node_indices) == 1:
            # Single-node groups (presumably nodes with no ground-truth partner) are drawn white.
            color = (1., 1., 1.)
        product_indices = node_indices[torch.isin(node_indices, graph.product_indices)]
        price_indices = node_indices[torch.isin(node_indices, graph.price_indices)]
        plot_bboxes(graph.x[product_indices, :4], ax=ax, color=color, linestyle="solid", width=width, height=height)
        plot_bboxes(graph.x[price_indices, :4], ax=ax, color=color, linestyle="dashed", width=width, height=height)


@torch.inference_mode()
def plot_scene_graph(idx: int, dataset: PriceAttributionDataset, model: LightningPriceAttributor, threshold: float = 0.5):
    """Plot predicted vs. ground-truth product-price edges for one scene.

    Draws three panels:
      1. the predicted graph (edges whose probability is > `threshold`, wrong edges marked),
      2. the ground-truth graph,
      3. the scene image with ground-truth groups drawn as colored boxes.

    Args:
        idx: index of the scene in `dataset`.
        dataset: dataset whose items are `DetectionGraph`s.
        model: attributor used to score candidate edges; put in eval mode and moved to CPU.
        threshold: minimum (exclusive) predicted probability for an edge to be drawn.

    Note:
        The image panel reads the notebook globals `dataset_dir` and `products_df`
        set by the checkpoint-loading cell.
    """
    plt.close()
    fig, axs = plt.subplots(1, 3, figsize=(12, 4), width_ratios=[1, 1, 2])
    graph: DetectionGraph = dataset[idx]

    model.eval()
    model.to(device="cpu")
    graph.to(device="cpu")

    # Ground-truth view: same nodes, edges replaced by the labelled product-price pairs.
    actual_graph = deepcopy(graph)
    actual_graph.edge_index = graph.gt_prod_price_edge_index
    actual_graph.edge_attr = torch.ones(graph.gt_prod_price_edge_index.size(1), 1)

    src, dst, edge_probs = _predict_edge_probs(graph, model)
    keep = edge_probs > threshold
    # Build the predicted view on a copy so the item returned by `dataset[idx]` is never
    # mutated (its edge_index is the model input, should the dataset cache items).
    pred_graph = deepcopy(graph)
    pred_graph.edge_index = torch.stack([src[keep], dst[keep]], dim=0)
    pred_graph.edge_attr = edge_probs[keep].view(-1, 1)
    pred_graph.plot(ax=axs[0], prod_price_only=True, mark_wrong_edges=True)
    axs[0].set_title("Predicted", fontsize=10)

    actual_graph.plot(ax=axs[1], prod_price_only=True)
    axs[1].set_title("Actual", fontsize=10)

    # `products_df` can hold several rows per attributionset_id. A list indexer always
    # yields a Series here; a scalar `.loc[id]` on an id with a single row yields a row
    # Series whose "local_path" is a plain str (which has no `.values`).
    rel_image_path = products_df.loc[[graph.graph_id], "local_path"].iloc[0]
    image = plt.imread(dataset_dir.parent / rel_image_path)
    height, width = image.shape[:2]

    _plot_gt_subgraph_boxes(graph, axs[2], width=width, height=height)
    axs[2].imshow(image)
    axs[2].axis("off")

    fig.tight_layout()
    fig.set_dpi(100)
    plt.show()

In [None]:
def display_func(idx: int, threshold: float) -> None:
    """Plot scene `idx` with whichever `dataset`/`model` the checkpoint cell loaded last."""
    # `dataset` and `model` are looked up at call time, so a newly selected
    # checkpoint is used on the next redraw without re-running this cell.
    plot_scene_graph(idx=idx, dataset=dataset, model=model, threshold=threshold)


idx_slider = widgets.IntSlider(value=0, min=0, max=len(dataset) - 1, step=1, description="Graph Index")
threshold_slider = widgets.FloatSlider(value=0.5, min=0, max=1, step=0.01, description="Threshold")


def _sync_idx_slider(change):
    """Keep the index slider within range after another checkpoint/dataset is loaded."""
    if change["name"] == "value":
        idx_slider.max = len(dataset) - 1


# Registered after the checkpoint-loading observer, so `dataset` is already the new one
# when this runs; lowering `max` also clamps an out-of-range slider value.
dropdown.observe(_sync_idx_slider)
# Trailing `;` suppresses the repr of the function that `interact` returns.
widgets.interact(display_func, idx=idx_slider, threshold=threshold_slider);