In [1]:
import io
from pathlib import Path

from copy import deepcopy
import pandas as pd
import yaml
import torch
from torch_geometric.utils import scatter
from ipywidgets import widgets
from matplotlib import pyplot as plt
from seaborn import color_palette
import imageio
from IPython.display import display
from tqdm import tqdm

from gpa.common.helpers import plot_bboxes, parse_into_subgraphs, edge_index_union, edge_index_diff
from gpa.datasets.attribution import DetectionGraph
from gpa.datasets.attribution import PriceAttributionDataset
from gpa.configs import TrainingConfig
from gpa.models.attributors import LightningPriceAttributor
from gpa.common.helpers import connect_products_with_nearest_price_tag

In [2]:
chkp_root = Path("../chkp")
# Each run's config.yaml lives at the same relative location under ../logs.
logs_root = chkp_root.with_name("logs")
# rglob order is filesystem-dependent; sort so the default selection is reproducible.
candidate_chkps = sorted(chkp_root.rglob("*.ckpt"))

if not candidate_chkps:
    raise FileNotFoundError(f"No checkpoint files found under {chkp_root.resolve()}.")

dropdown: widgets.Dropdown = widgets.Dropdown(
    options=[(str(p.relative_to(chkp_root)), p) for p in candidate_chkps],
    description='Checkpoint:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='60%')
)


def get_chkp_and_config_paths() -> tuple[Path, Path]:
    """Return the currently selected checkpoint and the ``config.yaml`` of its run.

    The config is expected at the checkpoint's relative location mirrored under
    ``logs_root``, e.g. ``../chkp/0.0.1/version_0/last.ckpt`` ->
    ``../logs/0.0.1/version_0/config.yaml``.
    """
    chkp_path: Path = dropdown.value
    # Map via the relative path instead of str.replace("chkp", "logs"), which would
    # also rewrite any "chkp" that happens to occur in a run or file name.
    config_path = (logs_root / chkp_path.relative_to(chkp_root)).with_name("config.yaml")
    return chkp_path, config_path


def load_selected_checkpoint(change=None) -> None:
    """(Re)load config, validation dataset, product boxes and model for the selected checkpoint.

    Registered as a dropdown observer so that changing the selection actually takes
    effect; reading ``dropdown.value`` once right after displaying the widget would
    always pick the default option. Results are published as notebook globals because
    the cells below read them at call time.

    NOTE(review): widgets created in later cells (e.g. slider ``max=len(dataset)-1``)
    capture the dataset size when they are built and are not updated on reload.
    """
    global chkp_path, config_path, config, dataset_dir, dataset, products_df, model
    chkp_path, config_path = get_chkp_and_config_paths()
    with open(config_path, "r") as f:
        config = TrainingConfig(**yaml.safe_load(f))
    dataset_dir = Path("..") / config.dataset_dir / "val"
    dataset = PriceAttributionDataset(root=dataset_dir)
    products_df = pd.read_csv(dataset_dir / "raw" / "product_boxes.csv", index_col="attributionset_id")
    model = LightningPriceAttributor.load_from_checkpoint(chkp_path).eval()


# Load the default selection eagerly so "Restart & Run All" works without interaction.
load_selected_checkpoint()
dropdown.observe(load_selected_checkpoint, names="value")
display(dropdown)

Dropdown(description='Checkpoint:', layout=Layout(width='60%'), options=(('0.0.1/version_0/last.ckpt', PosixPa…

In [3]:
@torch.inference_mode()
def iterative_inference_v2(graph: DetectionGraph, model: LightningPriceAttributor, max_iter: int = 50, threshold: float = 0.5) -> list[DetectionGraph]:
    """Greedily attach whole UPC clusters to price tags, one assignment per step.

    At every step the model scores each product/price-tag pair that is not yet an
    edge. For each candidate price tag, the edge probabilities are averaged per UPC
    cluster. The (cluster, price tag) pair with the highest mean probability that also
    exceeds ``threshold`` is accepted, and every node of that cluster is connected to
    the price tag.

    Iteration stops when:
        * no pair clears the threshold,
        * no candidate pairs remain, or
        * ``max_iter`` steps have been made.

    Args:
        graph: Input detection graph. It is not modified; the function works on a clone.
        model: Trained attributor used to score candidate edges.
        max_iter: Maximum number of (cluster, price tag) assignments.
        threshold: Minimum mean cluster probability needed to accept an assignment.

    Returns:
        Graph snapshots: the initial state, followed by one clone per accepted assignment.
    """
    model.eval()
    # Work on a private copy; the loop reassigns ``edge_index``, which would otherwise
    # leak the predicted edges into the caller's graph.
    graph = graph.clone()
    graphs = [graph.clone()]

    for _ in range(max_iter):
        # Candidates: every product x price-tag pair not already present as an edge.
        src, dst = edge_index_diff(torch.cartesian_prod(graph.product_indices, graph.price_indices).T, graph.edge_index)
        if src.numel() == 0:
            break  # nothing left to score
        edge_probs = model.forward(
            x=graph.x,
            edge_index=graph.edge_index,
            src=src,
            dst=dst,
            cluster_assignment=graph.upc_clusters,
        ).sigmoid()

        # Find the (UPC cluster, price tag) pair with the highest mean edge probability
        # above the threshold. Strict ">" plus argmax's first-max rule means ties keep the
        # first pair found (lowest price index, then lowest cluster id).
        max_avg_prob = -torch.inf
        max_cluster_price_idx_pair = None
        for price_idx in dst.unique():
            to_price = dst == price_idx
            assoc_upc_clusters = graph.upc_clusters[src[to_price]]
            # Mean probability per cluster id. Ids with no candidate edge to this price tag
            # form empty groups (PyG's scatter fills them with 0), so they never beat a
            # sigmoid output.
            scattered = scatter(edge_probs[to_price], assoc_upc_clusters, dim=0, reduce="mean")
            max_cluster = torch.argmax(scattered)
            max_cluster_prob = scattered[max_cluster]
            if max_cluster_prob > max_avg_prob and max_cluster_prob > threshold:
                max_avg_prob = max_cluster_prob
                max_cluster_price_idx_pair = (max_cluster, price_idx)

        if max_cluster_price_idx_pair is None:
            break  # no pair clears the threshold
        cluster_idx, price_idx = max_cluster_price_idx_pair
        # Connect every node of the winning cluster to the price tag; edges already present
        # are presumably merged away by edge_index_union.
        # NOTE(review): assumes only product nodes can carry this id in ``upc_clusters``.
        src_in_cluster = torch.where(graph.upc_clusters == cluster_idx)[0]
        new_edges = torch.stack([src_in_cluster, price_idx.repeat(len(src_in_cluster))])
        graph.edge_index = edge_index_union(graph.edge_index, new_edges)
        graphs.append(graph.clone())
    return graphs

In [4]:
@torch.inference_mode()
def iterative_inference(
    graph: DetectionGraph,
    model: LightningPriceAttributor,
    threshold: float = 0.5,
    max_iter: int = 50,
) -> list[DetectionGraph]:
    """Iteratively add every individual product -> price-tag edge scored above ``threshold``.

    Each step scores all product/price-tag pairs that are not yet edges and adds all of
    those with probability > ``threshold``. There is no per-UPC-cluster grouping.

    Iteration stops when:
        * no new edge clears the threshold,
        * no candidate pairs remain, or
        * ``max_iter`` steps have been made.

    Args:
        graph: Input detection graph. It is not modified; the function works on a clone.
        model: Trained attributor used to score candidate edges.
        threshold: Minimum edge probability for an edge to be added.
        max_iter: Maximum number of iterations.

    Returns:
        Graph snapshots: the initial state, followed by one clone per iteration that added edges.
    """
    model.eval()
    # Work on a private copy so the caller's graph keeps its original edges.
    graph = graph.clone()
    graphs = [graph.clone()]

    for _ in range(max_iter):
        # Candidates: every product x price-tag pair not already present as an edge.
        src, dst = edge_index_diff(torch.cartesian_prod(graph.product_indices, graph.price_indices).T, graph.edge_index)
        if src.numel() == 0:
            break  # nothing left to score
        edge_probs = model.forward(
            x=graph.x,
            edge_index=graph.edge_index,
            src=src,
            dst=dst,
            cluster_assignment=graph.upc_clusters,
        ).sigmoid()

        # Accept every candidate edge above the threshold.
        keep = edge_probs > threshold
        if not keep.any():
            break
        new_edges = torch.stack([src[keep], dst[keep]], dim=0)
        graph.edge_index = edge_index_union(graph.edge_index, new_edges)
        graphs.append(graph.clone())
    return graphs

In [6]:
index_selector: widgets.IntSlider = widgets.IntSlider(
    value=0, min=0, max=len(dataset)-1, step=1,
    description="Index",
    continuous_update=False
)
generate_button: widgets.Button = widgets.Button(description="Generate Video")
output: widgets.Output = widgets.Output()


def _render_frame(graph: DetectionGraph, image, t: int):
    """Render one frame (graph at step ``t`` next to the scene image) and return it as an image array."""
    fig, axs = plt.subplots(1, 2, dpi=100)
    try:
        graph.plot(ax=axs[0], prod_price_only=True)
        axs[0].set_title(f"t={t}")
        axs[1].imshow(image)
        axs[1].axis("off")
        with io.BytesIO() as buffer:
            fig.savefig(buffer, format="png")
            buffer.seek(0)
            return imageio.v2.imread(buffer)
    finally:
        # Release the figure even if plotting fails, so repeated clicks don't leak figures.
        plt.close(fig)


def generate_video(_):
    """Button callback: run iterative inference on the selected graph and write an MP4.

    Writes one frame per inference step (2 fps) to ``inference-videos/graph_{idx}.mp4``.
    All work happens inside the ``output`` widget, so progress bars and tracebacks appear
    under the button instead of being lost in the kernel log.
    """
    with output:
        output.clear_output(wait=True)
        idx = index_selector.value
        initial_graph = dataset[idx]
        graphs = iterative_inference_v2(initial_graph, model, threshold=0.5)
        # Select with a list so .loc always returns a Series. A bare scalar label returns a
        # single row when the set has exactly one product, and then ".values" fails on a str.
        image_path = dataset_dir.parent / products_df.loc[[initial_graph.graph_id], "local_path"].iloc[0]
        image = plt.imread(image_path)

        video_path = Path(f"inference-videos/graph_{idx}.mp4")
        video_path.parent.mkdir(parents=True, exist_ok=True)

        with imageio.get_writer(video_path, fps=2) as writer:
            for t, graph in enumerate(tqdm(graphs, desc="Generating video...")):
                writer.append_data(_render_frame(graph, image, t))
        print(f"Saved {len(graphs)} frames to {video_path}")


generate_button.on_click(generate_video)
display(index_selector, generate_button, output)

IntSlider(value=0, continuous_update=False, description='Index', max=188)

Button(description='Generate Video', style=ButtonStyle())

Output()

### Old code: one-shot pruning of the nearest-price-tag heuristic

In [5]:
@torch.inference_mode()
def plot_scene_graph(idx: int, dataset: PriceAttributionDataset, model: LightningPriceAttributor, threshold: float = 0.5):
    """Plot four panels for graph ``idx`` of ``dataset``.

    1. *Heuristic*: each product linked to its nearest price tag, with wrong edges marked.
    2. *Pruned*: the heuristic edges whose model probability exceeds ``threshold``,
       with wrong edges marked.
    3. *Actual*: the ground-truth product -> price-tag edges.
    4. The source image with product (solid) and price-tag (dashed) boxes, coloured by
       ground-truth subgraph. Single-node subgraphs are drawn in white.

    Side effects: moves ``model`` to CPU and puts it in eval mode.
    """
    plt.close()
    fig, axs = plt.subplots(1, 4, figsize=(15, 4), width_ratios=[1, 1, 1, 2])
    graph: DetectionGraph = dataset[idx]

    model.eval()
    model.to(device="cpu")
    graph.to(device="cpu")

    # Panel 3 source: ground-truth edges, all with weight 1.
    actual_graph = deepcopy(graph)
    actual_graph.edge_index = graph.gt_prod_price_edge_index
    actual_graph.edge_attr = torch.ones(len(graph.gt_prod_price_edge_index[0]), 1)

    # Panel 1: baseline linking each product to its nearest price tag.
    heuristic_graph = deepcopy(graph)
    src, dst = connect_products_with_nearest_price_tag(
        centroids=graph.x[:, :2],
        product_indices=graph.product_indices,
        price_indices=graph.price_indices,
    )
    heuristic_graph.edge_index = torch.stack([src, dst], dim=0)
    heuristic_graph.edge_attr = torch.ones(len(src), 1)
    heuristic_graph.plot(ax=axs[0], prod_price_only=True, mark_wrong_edges=True)
    axs[0].set_title("Heuristic", fontsize=10)

    # Panel 2: score the heuristic edges and keep those above the threshold.
    # NOTE(review): this older call passes edge_attr and no cluster_assignment, unlike the
    # model.forward calls in the iterative-inference cells; confirm it still matches the
    # current model signature.
    edge_probs = model.forward(
        x=graph.x,
        edge_index=graph.edge_index,
        edge_attr=graph.edge_attr,
        src=src,
        dst=dst,
    ).sigmoid()
    keep = edge_probs > threshold
    graph.edge_index = torch.stack([src[keep], dst[keep]], dim=0)
    graph.edge_attr = edge_probs[keep].view(-1, 1)
    graph.plot(ax=axs[1], prod_price_only=True, mark_wrong_edges=True)
    axs[1].set_title("Pruned", fontsize=10)

    actual_graph.plot(ax=axs[2], prod_price_only=True)
    axs[2].set_title("Actual", fontsize=10)

    _plot_image_with_gt_groups(axs[3], graph)

    fig.tight_layout()
    fig.set_dpi(100)
    plt.show()


def _plot_image_with_gt_groups(ax, graph: DetectionGraph) -> None:
    """Show the graph's source image on ``ax`` with boxes coloured by ground-truth subgraph."""
    # Select with a list so .loc always returns a Series. A bare scalar label returns a
    # single row when the set has exactly one product, and then ".values" fails on a str.
    image_path = dataset_dir.parent / products_df.loc[[graph.graph_id], "local_path"].iloc[0]
    image = plt.imread(image_path)
    height, width = image.shape[:2]

    # One id per ground-truth subgraph (presumably connected components of the GT edges).
    subgraph_indices = parse_into_subgraphs(graph.gt_prod_price_edge_index, graph.num_nodes)
    subgraph_ids = torch.unique(subgraph_indices)
    colors = color_palette(n_colors=len(subgraph_ids))
    for subgraph_id, color in zip(subgraph_ids, colors):
        node_indices = torch.argwhere(subgraph_indices == subgraph_id).flatten()
        if len(node_indices) == 1:
            color = (1., 1., 1.)  # unconnected node: draw in white
        product_indices = node_indices[torch.isin(node_indices, graph.product_indices)]
        price_indices = node_indices[torch.isin(node_indices, graph.price_indices)]
        plot_bboxes(graph.x[product_indices, :4], ax=ax, color=color, linestyle="solid", width=width, height=height)
        plot_bboxes(graph.x[price_indices, :4], ax=ax, color=color, linestyle="dashed", width=width, height=height)

    ax.imshow(image)
    ax.axis("off")

In [6]:
def display_func(idx: int, threshold: float) -> None:
    """Slider callback: redraw the scene-graph panels for the chosen graph and threshold."""
    # dataset/model are looked up in the notebook globals at call time, not captured here.
    plot_scene_graph(idx=idx, dataset=dataset, model=model, threshold=threshold)


idx_slider = widgets.IntSlider(value=0, min=0, max=len(dataset)-1, step=1, description="Graph Index")
threshold_slider = widgets.FloatSlider(value=0.5, min=0, max=1, step=0.01, description="Threshold")
# interactive() builds the widget without showing it, so it is displayed exactly once.
# interact() would auto-display it and return the callback, whose repr leaks into the output.
scene_graph_widget = widgets.interactive(display_func, idx=idx_slider, threshold=threshold_slider)
display(scene_graph_widget)

interactive(children=(IntSlider(value=0, description='Graph Index', max=188), FloatSlider(value=0.5, descripti…

<function __main__.<lambda>(idx, threshold)>