In [None]:
import torch
import yaml
from glob import glob
from pathlib import Path
import torch.nn.functional as F
import matplotlib.pyplot as plt
from postprocess import process_image, load_model, dict_to_namespace, plot_results
import numpy as np
import os

#### Load model

In [None]:
# Resolve every path relative to the directory the notebook is started from.
root = Path(os.getcwd())

# Toggle between the stage-1 and stage-2 experiment checkpoints.
stage2 = False
exp_folder = (
    root / "exp" / "res50_stage2_circles" / "checkpoints"
    if stage2
    else root / "exp" / "res50_stage1_circles" / "checkpoints"
)


In [None]:
# Pick the YAML config that matches the selected training stage.
if stage2:
    cfg_file_name = "config_primitives_stage2.yaml"
else:
    cfg_file_name = "config_primitives.yaml"

# Parse the YAML into a plain dict, then convert it with dict_to_namespace
# into the config object that load_model expects.
with open(cfg_file_name, "r") as f:
    config_dict = yaml.safe_load(f)
cfg = dict_to_namespace(config_dict)

# Build the model from the config and the checkpoint folder, then put it
# in evaluation mode for inference.
model = load_model(cfg, exp_folder)
model.eval()

#### Load image

In [None]:
# Name (without extension) of the diagram to run the model on.
diagram_name = "diagram13"
im_path = root / "data" / "diagrams" / "images" / f"{diagram_name}.png"

# Keep only the first three channels of the loaded image (drops an alpha
# channel if the PNG has one).
loaded_img = plt.imread(im_path)
raw_img = loaded_img[:, :, :3]

#### Predict

In [None]:
# Convert the raw image into model inputs; orig_size is used later to scale
# the predictions back to the original image.
inputs, orig_size = process_image(raw_img)

# Inference only: disable autograd so no computation graph is built or kept
# alive for the forward pass.
with torch.no_grad():
    outputs = model(inputs)
outputs = outputs["shapes"]
out_logits, out_line = outputs['pred_logits'], outputs['pred_shapes']

# Minimum class probability a prediction must reach to survive the filtering
# step. The class probabilities themselves are computed from out_logits in
# the filtering cell, so they are not duplicated here.
threshold = 0.05


In [None]:
# Class probabilities for the first (only) batch element; the last class
# column is dropped (presumably the "no object" class -- confirm against the
# model definition).
prob = F.softmax(out_logits, -1)[0, :, :-1]
keep = prob.max(-1).values > threshold
prob = prob[keep]

# Store the filtered predictions under a new name instead of overwriting
# out_line, so this cell can be re-run without indexing an already filtered
# tensor.
kept_shapes = out_line[0, keep]

# Scale the 4 predicted coordinates per shape by (w, h, w, h).
img_h, img_w = orig_size.unbind(0)
scale_fct = torch.unsqueeze(torch.stack([img_w, img_h, img_w, img_h], dim=0), dim=0)
lines = kept_shapes * scale_fct  # (K, 4) * (1, 4) -> (K, 4)

# Swap each (x, y) pair to (y, x): the result is in yxyx format.
lines = lines.view(-1, 2, 2).flip([-1])
# An explicit column count (instead of -1) keeps this valid when no
# prediction passes the threshold and the tensor has zero rows.
lines = lines.reshape(-1, 4)


In [None]:
def plot_results(ax, prob, boxes, thresh_line, thresh_circle, color="green", linewidth=1):
    """Draw the predicted lines and circles on a matplotlib axes.

    Note: this definition replaces the ``plot_results`` name imported from
    ``postprocess`` at the top of the notebook.

    Args:
        ax: Axes-like object providing ``plot`` and ``add_patch``.
        prob: Iterable of per-prediction class scores; index 0 is treated as
            "line", any other argmax as "circle".
        boxes: Iterable of 4-element tensors in (ymin, xmin, ymax, xmax) order.
        thresh_line: A line is drawn only if its score is strictly greater
            than this value.
        thresh_circle: A circle is drawn only if its score is strictly
            greater than this value.
        color: Colour used for all drawn primitives (default "green").
        linewidth: Line width used for all drawn primitives (default 1).
    """
    for p, line in zip(prob, boxes):
        # .cpu() makes this work for tensors living on an accelerator too;
        # it is a no-op for CPU tensors.
        ymin, xmin, ymax, xmax = line.detach().cpu().numpy()
        cl = p.argmax()
        score = p[cl]

        if cl == 0:
            # Line: the two corner points are its end points.
            if score > thresh_line:
                ax.plot([xmin, xmax], [ymin, ymax], c=color, linewidth=linewidth)
        elif score > thresh_circle:
            # Circle: centred on the box; only the vertical half-extent is
            # used as the radius (the horizontal one just locates the centre).
            r1 = (xmax - xmin) / 2
            r2 = (ymax - ymin) / 2
            center = (xmin + r1, ymin + r2)
            ax.add_patch(plt.Circle(center, r2, color=color, fill=False, linewidth=linewidth))

In [None]:
def show_results(img, prob, boxes, thresh_line, thresh_circle, relative=False, dpi=300, show_in_console=False, savedir=None, img_name=None, plot_img=True):
    """Draw the predicted primitives for one image and optionally save/show them.

    Args:
        img: Image array. Shown as the background when ``plot_img`` is true;
            otherwise only ``img.shape[0]`` and ``img.shape[1]`` are used to
            size an empty canvas.
        prob, boxes, thresh_line, thresh_circle: Passed unchanged to
            ``plot_results``.
        relative: Unused; kept so existing callers keep working.
        dpi: Resolution of the created figure.
        show_in_console: If true, call ``plt.show()`` (after any saving).
        savedir: Directory to save the figure into; nothing is saved if None.
        img_name: File name (relative to ``savedir``) of the saved figure.
            Required when ``savedir`` is given.
        plot_img: Whether to draw ``img`` under the predictions.

    Raises:
        ValueError: If ``savedir`` is given without ``img_name``.
    """
    # Validate before creating a figure so a bad call leaves nothing behind
    # (previously this silently wrote a file literally named "None").
    if savedir is not None and img_name is None:
        raise ValueError("img_name is required when savedir is given")

    fig = plt.figure(dpi=dpi)
    # NOTE: this changes matplotlib's global font size, not just this figure.
    plt.rcParams["font.size"] = "5"
    ax = plt.gca()
    if plot_img:
        ax.imshow(img)
    else:
        # Empty canvas with image-sized limits; y axis inverted like imshow.
        ax.set_xlim([0, img.shape[1]])
        ax.set_ylim([img.shape[0], 0])
        ax.set_aspect('equal', adjustable='box')

    plot_results(ax, prob, boxes, thresh_line, thresh_circle)

    # Save BEFORE showing: once plt.show() returns the figure may already be
    # closed, and a later savefig would write a new, empty figure.
    if savedir is not None:
        savename = f"{savedir}/{img_name}"
        os.makedirs(os.path.dirname(savename), exist_ok=True)
        # Axes are hidden for the saved image; this also applies to the
        # figure if it is displayed afterwards.
        ax.axis("off")
        fig.savefig(savename, bbox_inches="tight", pad_inches=0)
    if show_in_console:
        plt.show()

In [None]:
# Overlay the confident predictions (score > 0.7 for both lines and circles)
# on a copy of the input image and save the result under real_predictions/,
# using the source image's file name.
show_results(
    raw_img.copy(),
    prob,
    lines,
    thresh_line=0.7,
    thresh_circle=0.7,
    show_in_console=False,
    savedir="real_predictions",
    img_name=im_path.name,
)