In [None]:
from pathlib import Path

from copy import deepcopy
import pandas as pd
import yaml
import torch
from ipywidgets import widgets
from matplotlib import pyplot as plt
from IPython.display import display

from gpa.datasets.attribution import DetectionGraph, PriceAttributionDataset
from gpa.datamodules.attribution import PriceAttributionDataModule
from gpa.configs import TrainingConfig
from gpa.models.attributors import LightningPriceAttributor

In [None]:
chkp_root = Path("../chkp")
# Sorted so the dropdown order - and therefore the default (first) selection -
# does not depend on filesystem enumeration order.
candidate_chkps = sorted(chkp_root.rglob("*.ckpt"))

if not candidate_chkps:
    raise FileNotFoundError(f"No checkpoint files (*.ckpt) found under {chkp_root.resolve()}.")

# Each option shows the path relative to the checkpoint root; the value is the full Path.
dropdown = widgets.Dropdown(
    options=[(str(p.relative_to(chkp_root)), p) for p in candidate_chkps],
    description='Checkpoint:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='60%')
)

# Output area displayed directly below the dropdown.
output = widgets.Output()
display(dropdown, output)

def on_dropdown_change(change):
    """Load config, test data and model for the checkpoint selected in ``dropdown``.

    Observer for the dropdown's ``value`` trait. Rebinds the notebook-level
    names ``model``, ``serving_dataset``, ``viewing_dataset``, ``products_df``
    and ``data_dir`` that the plotting cells read.

    Args:
        change: traitlets change dict. Only ``name == "value"`` /
            ``type == "change"`` events are handled; ``change["new"]`` is the
            selected checkpoint path (a path under ``chkp_root``).
    """
    global model, serving_dataset, viewing_dataset, products_df, data_dir

    if change["name"] != "value" or change["type"] != "change":
        return

    # Exceptions raised in widget callbacks are not shown in the cell (and are
    # invisible in JupyterLab); the Output context manager displays them here.
    output.clear_output()
    with output:
        chkp_path = Path(change["new"])

        # Configs mirror the checkpoint tree under ../logs. Swap only the root
        # directory - a blanket str.replace("chkp", "logs") would also rewrite
        # any run or file name that happens to contain "chkp".
        logs_root = chkp_root.with_name("logs")
        config_path = (logs_root / chkp_path.relative_to(chkp_root)).with_name("config.yaml")
        with open(config_path, "r") as f:
            config = TrainingConfig(**yaml.safe_load(f))

        new_data_dir = Path("..") / config.dataset_dir
        datamodule = PriceAttributionDataModule(
            data_dir=new_data_dir,
            use_visual_info=config.model.use_visual_info,
            aggregate_by_upc=config.model.aggregate_by_upc,
            use_spatially_invariant_coords=config.model.use_spatially_invariant_coords,
            initial_connection_strategy=config.model.initial_connection_strategy,
        )
        # NOTE(review): empty stage string kept from the original; presumably the
        # datamodule builds the test split regardless of stage - confirm.
        datamodule.setup("")
        new_serving = datamodule.test
        # Same data with the transform disabled; used for drawing the graphs.
        new_viewing = new_serving.copy()
        new_viewing.transform = None
        new_products = pd.read_csv(
            new_data_dir / "test" / "raw" / "product_boxes.csv",
            index_col="attributionset_id",
        )
        # map_location="cpu": plotting runs inference on CPU anyway, and this lets
        # checkpoints saved on a GPU be loaded on a CPU-only machine.
        new_model = LightningPriceAttributor.load_from_checkpoint(chkp_path, map_location="cpu").eval()

        # Publish only once everything has loaded, so a failure part-way cannot
        # leave e.g. a new dataset paired with the previously loaded model.
        data_dir = new_data_dir
        serving_dataset = new_serving
        viewing_dataset = new_viewing
        products_df = new_products
        model = new_model
        print(f"Loaded {chkp_path.relative_to(chkp_root)}")

# The handler only acts on "value" changes. traitlets does not notify observers
# about the value a widget starts with, so load the default selection explicitly;
# otherwise `model` / `serving_dataset` stay undefined until the user picks a
# different checkpoint, and the cells below fail under Restart & Run All.
dropdown.observe(on_dropdown_change, names="value")
on_dropdown_change({"name": "value", "type": "change", "old": None, "new": dropdown.value, "owner": dropdown})

In [None]:
@torch.inference_mode()
def plot_scene_graph(
    idx: int,
    serving_dataset: PriceAttributionDataset,
    viewing_dataset: PriceAttributionDataset,
    model: LightningPriceAttributor,
    threshold: float = 0.5,
):
    """Plot predicted vs. ground-truth product-price edges for one scene beside its image.

    Scores every (product, price) node pair with ``model``, keeps the pairs whose
    sigmoid probability is strictly above ``threshold`` and draws three panels:
    the predicted graph, the ground-truth graph, and the scene image.

    Args:
        idx: Index of the graph in both datasets.
        serving_dataset: Dataset whose items are fed to the model.
        viewing_dataset: Same data as ``serving_dataset``; its items are used for drawing.
        model: Trained attributor; switched to eval mode and moved to CPU.
        threshold: Probability cut-off for drawing a predicted edge.

    Note:
        Reads the notebook-level globals ``products_df`` and ``data_dir`` to
        locate the scene image.
    """
    plt.close()
    fig, axs = plt.subplots(1, 3, figsize=(12, 4), width_ratios=[1, 1, 2])
    graph: DetectionGraph = serving_dataset[idx]
    viewing_graph: DetectionGraph = viewing_dataset[idx]

    model.eval()
    model.to(device="cpu")
    graph.to(device="cpu")
    viewing_graph.to(device="cpu")

    # Ground truth: the viewing graph with its edges replaced by the labelled
    # product->price links, each with weight 1.
    actual_graph = deepcopy(viewing_graph)
    actual_graph.edge_index = graph.gt_prod_price_edge_index
    actual_graph.edge_attr = torch.ones(len(graph.gt_prod_price_edge_index[0]), 1)

    # Candidate edges: every product node paired with every price node.
    src, dst = torch.cartesian_prod(graph.product_indices, graph.price_indices).T
    # Call the module (not .forward) so any registered hooks run.
    # NOTE(review): assumes one logit per (src, dst) pair, i.e. a 1-D result -
    # the boolean mask below indexes the 1-D src/dst tensors.
    edge_probs = model(
        x=graph.x,
        edge_index=graph.edge_index,
        src=src,
        dst=dst,
        cluster_assignment=graph.get("upc_clusters"),
    ).sigmoid()
    keep = edge_probs > threshold

    viewing_graph.edge_index = torch.stack([src[keep], dst[keep]], dim=0)
    viewing_graph.edge_attr = edge_probs[keep].view(-1, 1)
    viewing_graph.plot(ax=axs[0], prod_price_only=True, mark_wrong_edges=True)
    axs[0].set_title("Predicted", fontsize=10)

    actual_graph.plot(ax=axs[1], prod_price_only=True)
    axs[1].set_title("Actual", fontsize=10)

    # products_df has one row per product box, indexed by attribution-set id.
    # .loc yields a Series when the id has several rows but a plain scalar when
    # it has exactly one; handle both instead of assuming duplicates.
    local_paths = products_df.loc[graph.graph_id, "local_path"]
    local_path = local_paths.iloc[0] if isinstance(local_paths, pd.Series) else local_paths
    image = plt.imread(data_dir / local_path)

    axs[2].imshow(image)
    axs[2].set_title("Image", fontsize=10)
    axs[2].axis("off")

    fig.tight_layout()
    fig.set_dpi(100)
    plt.show()

In [None]:
def display_func(idx: int, threshold: float) -> None:
    """Redraw the scene-graph comparison for graph ``idx`` at ``threshold``.

    Looks up ``serving_dataset``, ``viewing_dataset`` and ``model`` at call time,
    so it always uses whatever the checkpoint dropdown loaded last.
    """
    plot_scene_graph(
        idx=idx,
        serving_dataset=serving_dataset,
        viewing_dataset=viewing_dataset,
        model=model,
        threshold=threshold,
    )


# continuous_update=False: every redraw runs model inference, so refresh only when
# the slider is released instead of on each intermediate drag position.
# NOTE: the index range is taken from the dataset loaded when this cell runs;
# re-run this cell after switching to a checkpoint with a different test set.
idx_slider = widgets.IntSlider(
    value=0, min=0, max=max(len(serving_dataset) - 1, 0), step=1,
    description="Graph Index", continuous_update=False,
)
threshold_slider = widgets.FloatSlider(
    value=0.5, min=0, max=1, step=0.01,
    description="Threshold", continuous_update=False,
)
# Trailing ';' suppresses the returned function's repr in the cell output.
widgets.interact(display_func, idx=idx_slider, threshold=threshold_slider);