In [None]:
import os
import torch
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse, Arc
import numpy as np
import json
import glob
from PIL import Image
from evaluate import get_l2_distance, get_l2_distance_circles, get_l2_distance_arcs

In [None]:
root_dir = Path('../../')
gt_dir = root_dir / "data/eida_dataset"
anno_file = os.path.join(gt_dir, "valid.json")

with open(anno_file, "r") as f:
    dataset = json.load(f)

# Candidate validation images (uncomment one to switch; the last assignment wins):
# im_name = 'ms20_0099_297,2209,2186,1746' # arcs are bad and evaluation is good
# im_name = 'ms9_0517_165,1954,429,166' # good evaluation for arcs 
# im_name = 'ms20_0049_755,200,2753,2833' # good evaluation for circles
# im_name = 'ms154_0026_618,1636,1288,1284' # good evaluation for circles
# im_name = 'ms40_0029_124,598,391,384'  # one bad evaluation for circles
# im_name = 'ms3_0021_411,270,178,218' # bad evaluation for arcs
# im_name = 'ms35_0067_1335,4176,879,836' # good example for lines
# im_name = 'ms102_0330_1569,3402,763,757'
# im_name = 'ms138_0019_832,493,864,848'
# im_name = 'ms9_0507_29,681,1169,1002'
# im_name = 'ms13_0045_1339,578,913,783'
# im_name = 'ms155_0056_178,878,1368,1376'
# im_name = 'ms150_0282_1176,365,607,622'
# im_name = "ms2_0232_202,1199,964,915" # ARC EVAL EXAMPLE
# im_name = 'ms102_0151_1482,1176,484,484' # LINE EVAL EXAMPLE
# im_name = 'ms150_0282_1176,365,607,622' # CIRCLE EVAL EXAMPLE
# im_name = 'ms35_0036_2647,1393,1046,1181'
im_name = 'ms156_0124_217,1233,769,356'

# Annotation record of the selected image. Fail loudly if it is missing:
# a `for ... break` search would silently leave `data` on the LAST record.
data = next((record for record in dataset if im_name in record["filename"]), None)
if data is None:
    raise KeyError(f"{im_name!r} not found in {anno_file}")
gt_path = gt_dir / f"valid_labels/{im_name}.npz"


In [None]:

def scale_positions(lines, heatmap_scale=(128, 128), im_shape=None):
    """Rescale primitive point coordinates from a source to a target resolution.

    Despite the parameter names, ``heatmap_scale`` is the *target* extent and
    ``im_shape`` the *source* extent, both as ``(x_size, y_size)``. In this
    notebook, callers pass ``(points, im_rescale, heatmap_scale)`` to map
    128x128 heatmap coordinates onto the resized image.

    Args:
        lines: array-like of shape (N, P, 2); column 0 of the last axis is
            scaled with the x factor, column 1 with the y factor.
        heatmap_scale: target ``(x_size, y_size)``. Results are clipped to
            ``[0, size - 1e-4]`` per axis.
        im_shape: source ``(x_size, y_size)``. Required.

    Returns:
        A new floating-point array with the rescaled points, or ``[]`` when
        ``lines`` is empty. The input is never modified.

    Raises:
        ValueError: if ``im_shape`` is not given.
    """
    if len(lines) == 0:
        return []
    if im_shape is None:
        raise ValueError("im_shape (source resolution) is required")

    points = np.asarray(lines)
    # Always work on a copy so callers' arrays are untouched. Integer arrays are
    # promoted to float: writing scaled values back into them would silently truncate.
    if np.issubdtype(points.dtype, np.floating):
        scaled = points.copy()
    else:
        scaled = points.astype(float)

    fx = heatmap_scale[0] / im_shape[0]
    fy = heatmap_scale[1] / im_shape[1]
    scaled[:, :, 0] = np.clip(scaled[:, :, 0] * fx, 0, heatmap_scale[0] - 1e-4)
    scaled[:, :, 1] = np.clip(scaled[:, :, 1] * fy, 0, heatmap_scale[1] - 1e-4)
    return scaled


In [None]:
def get_pred_gt_match(primitive_pred, primitive_gt, primitive_k=0):
    """Match each predicted primitive to its nearest ground-truth primitive.

    ``primitive_k`` selects the distance function: 0 for lines, 1 for circles,
    and any other value for arcs. The distance matrix is reduced along axis 1,
    giving one entry per row (per prediction).

    Returns:
        ``(dist, choice)``: for each row, the smallest distance and the column
        (ground-truth index) that achieves it.
    """
    if primitive_k == 0:
        distance_fn = get_l2_distance
    elif primitive_k == 1:
        distance_fn = get_l2_distance_circles
    else:
        distance_fn = get_l2_distance_arcs

    pairwise = distance_fn(primitive_pred, primitive_gt)
    nearest_idx = np.argmin(pairwise, axis=1)
    nearest_dist = np.min(pairwise, axis=1)
    return nearest_dist, nearest_idx

In [None]:
# Load the selected RGB image and resize it so its longer side is 512 px,
# keeping the aspect ratio. Alternatives kept for reference:
# gt_circles = get_bbox_from_center_radii(data["circle_centers"], data["circle_radii"])
# im = Image.open(os.path.join(gt_dir, "valid_labels", f'{im_name}.png')).convert("RGB")
image_path = os.path.join(gt_dir, "images", f'{im_name}.jpg')
im = Image.open(image_path).convert("RGB")
prefix = data["filename"].split(".")[0]

im_shape = im.size  # PIL reports (width, height)
width, height = im_shape
aspect_ratio = width / height
im_rescale = (
    (512, int(512 / aspect_ratio)) if width > height
    else (int(512 * aspect_ratio), 512)
)
# im_rescale = (512, 512)
heatmap_scale = (128, 128)
im = im.resize(im_rescale)



In [None]:

# Ground-truth primitives for the selected image. The context manager closes
# the .npz file handle once the arrays are read.
with np.load(gt_path) as fgt:
    gt_lines, gt_circles, gt_arcs = fgt['lines'], fgt['circles'], fgt['arcs']

model_folder = root_dir / "logs/main_model"
epoch = '0036'
threshold = 0.3  # keep predictions whose score is strictly greater than this

with np.load(model_folder / f"npz_preds{epoch}/{im_name}.npz") as fpred:
    lines, line_scores = fpred["lines"], fpred["line_scores"]
    circles, circle_scores = fpred["circles"], fpred["circle_scores"]
    arcs, arc_scores = fpred["arcs"], fpred["arc_scores"]


def filter_by_score(values, scores, min_score):
    """Return ``(values, scores)`` restricted to entries with ``score > min_score``."""
    keep = scores > min_score
    return values[keep], scores[keep]


lines, line_scores = filter_by_score(lines, line_scores, threshold)
circles, circle_scores = filter_by_score(circles, circle_scores, threshold)
arcs, arc_scores = filter_by_score(arcs, arc_scores, threshold)
preds = [lines, circles, arcs]
gts = [gt_lines, gt_circles, gt_arcs]


In [None]:
# Display the raw (unscaled) ground-truth primitives: [gt_lines, gt_circles, gt_arcs].
gts

In [None]:

def get_4_cardinal_pts_circle(pred_circles):
    """Return the four axis-aligned extreme points of each circle's bounding box.

    Each circle is given by two opposite corners of its bounding box,
    ``pred_circles[i, 0]`` and ``pred_circles[i, 1]``. For every circle the
    result holds four points, all starting from the box center:
    (min-x, center-y), (max-x, center-y), (center-x, min-y), (center-x, max-y).

    Returns:
        Array of shape ``(N, 4, D)``, where ``D`` is the size of the last axis
        of ``pred_circles``.
    """
    corner_a = pred_circles[:, 0, :]
    corner_b = pred_circles[:, 1, :]
    center = (corner_a + corner_b) / 2
    lo = np.minimum(corner_a, corner_b)
    hi = np.maximum(corner_a, corner_b)

    # Start every point at the center, then push one coordinate to an extreme.
    points = np.repeat(center[:, None, :], 4, axis=1)
    points[:, 0, 0] = lo[:, 0]
    points[:, 1, 0] = hi[:, 0]
    points[:, 2, 1] = lo[:, 1]
    points[:, 3, 1] = hi[:, 1]
    return points

In [None]:
primitive_dict = dict(enumerate(["line", "circle", "arc"]))  # primitive index -> name
show_primitives = [True] * 3  # per-type display toggles: [lines, circles, arcs]
dist_thresh = 4  # a prediction whose matching distance is below this counts as TP


In [None]:
# Overlay the ground-truth line segments, scaled to the resized image, on the image.
image = im
gt_lines, gt_circles, gt_arcs = gts
gt_lines = scale_positions(gt_lines.copy(), im_rescale, heatmap_scale)
gt_circles = scale_positions(gt_circles.copy(), im_rescale, heatmap_scale)
gt_arcs = scale_positions(gt_arcs.copy(), im_rescale, heatmap_scale)

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(image, aspect='equal')
for p0, p1 in gt_lines:
    ax.scatter([p0[0], p1[0]], [p0[1], p1[1]], color='g', s=80)
    ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=3, c='g')
ax.set_title("Ground-truth lines")

In [None]:
# Match every prediction to its nearest GT primitive of the same type, and
# rescale predictions and GT to resized-image coordinates for the plots below.
image = im
lines, circles, arcs = preds
gt_lines, gt_circles, gt_arcs = gts
lpos_l = scale_positions(lines.copy(), im_rescale, heatmap_scale)
lpos_c = scale_positions(circles.copy(), im_rescale, heatmap_scale)
lpos_a = scale_positions(arcs.copy(), im_rescale, heatmap_scale)
gt_lines = scale_positions(gt_lines.copy(), im_rescale, heatmap_scale)
gt_circles = scale_positions(gt_circles.copy(), im_rescale, heatmap_scale)
gt_arcs = scale_positions(gt_arcs.copy(), im_rescale, heatmap_scale)

# dists[k] / choices[k]: per-prediction nearest-GT distance and GT index for
# primitive type k (0 placeholder when the type is skipped). Matching uses the
# unscaled `preds` / `gts`, so distances are in their original units.
dists, choices = [], []
for k, (preds_primitive, gts_primitive) in enumerate(zip(preds, gts)):
    if len(preds_primitive) == 0 or len(gts_primitive) == 0:
        reason = "detected" if len(preds_primitive) == 0 else "in ground truth"
        print(f"no {primitive_dict[k]} {reason}")
        show_primitives[k] = False
        dists.append(0)
        choices.append(0)
        continue
    dist, choice = get_pred_gt_match(preds_primitive, gts_primitive, primitive_k=k)
    dists.append(dist)
    choices.append(choice)
show_lines, show_circles, show_arcs = show_primitives


def is_large_arc(rad_angle):
    """Tell whether sweeping from ``rad_angle[0]`` to ``rad_angle[1]`` is large.

    The sweep goes in the direction of increasing angle. It counts as "large"
    when it covers more than pi (boundary cases aside). Angles are expected in
    radians in ``[0, 2*pi)``.
    """
    if rad_angle[0] <= np.pi:
        return not (rad_angle[0] < rad_angle[1] < (np.pi + rad_angle[0]))
    return (rad_angle[0] - np.pi) < rad_angle[1] < rad_angle[0]


In [None]:
num_line_to_save = []  # indices of predicted lines whose figure is saved as PDF
max_plot_dist = 10  # skip predictions farther than this from every GT line

if not show_lines:
    print("No lines detected")
else:
    for num_line, (p0, p1) in enumerate(lpos_l):
        dist = dists[0][num_line]
        if dist > max_plot_dist:
            continue
        print('num_line and dist', num_line, dist)
        fig, ax = plt.subplots(figsize=(8, 8))  # one figure per predicted line
        ax.imshow(image, aspect='equal')

        # Prediction in green, its matched GT line in firebrick.
        matched_gt = gt_lines[choices[0][num_line]]
        for (q0, q1), color in [((p0, p1), 'g'), (matched_gt, 'firebrick')]:
            ax.scatter(q0[0], q0[1], color=color, s=80)
            ax.scatter(q1[0], q1[1], color=color, s=80)
            ax.plot([q0[0], q1[0]], [q0[1], q1[1]], linewidth=3, c=color)
        # Raw f-string: "\d" in a normal string is an invalid escape sequence.
        ax.set_title(rf" $\delta =$ {dist:.2f}", fontsize=40)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.tick_params(bottom=False, left=False)  # Remove ticks
        for side in ("top", "right", "bottom", "left"):
            ax.spines[side].set_visible(False)

        if num_line in num_line_to_save:
            os.makedirs('eval_distance_visu', exist_ok=True)
            plt.savefig(f'eval_distance_visu/{im_name}_{num_line}_line_{dist:.2f}.pdf', bbox_inches='tight', format='pdf')

In [None]:
num_circle_to_save = [4]  # indices of predicted circles whose figure is saved as PDF
max_plot_dist = 10  # skip predictions farther than this from every GT circle

if not show_circles:
    print("No circles detected")
else:
    for num_circle, (pred_p0, pred_p1) in enumerate(lpos_c):
        dist = dists[1][num_circle]
        if dist > max_plot_dist:
            continue
        print('num_circle and dist', num_circle, dist)
        fig, ax = plt.subplots(figsize=(8, 8))  # one figure per predicted circle
        ax.imshow(image, aspect='equal')
        # Raw f-string: "\d" in a normal string is an invalid escape sequence.
        ax.set_title(rf" $\delta =$ {dist:.2f}", fontsize=40)

        # Not named `data`: that global holds the dataset record read in another cell.
        circle_pairs = [([pred_p0, pred_p1], 'g'), (gt_circles[choices[1][num_circle]], 'firebrick')]
        for (c0, c1), color in circle_pairs:
            cardinal = get_4_cardinal_pts_circle(np.array([[c0, c1]]))[0]
            center = (c0[0] + c1[0]) / 2, (c0[1] + c1[1]) / 2
            d1 = np.abs(c1[0] - c0[0])
            d2 = np.abs(c1[1] - c0[1])
            ax.add_patch(Ellipse(center, d1, d2, color=color, fill=False, linewidth=3))
            ax.scatter(cardinal[:, 0], cardinal[:, 1], color=color, s=80)

        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.tick_params(bottom=False, left=False)  # Remove ticks
        for side in ("top", "right", "bottom", "left"):
            ax.spines[side].set_visible(False)
        if num_circle in num_circle_to_save:
            os.makedirs('eval_distance_visu', exist_ok=True)
            plt.savefig(f'eval_distance_visu/{im_name}_{num_circle}_circle_{dist:.2f}.pdf', bbox_inches='tight', format='pdf')

In [None]:
def find_circle_center(p1, p2, p3):
    """Return the center of the circle through three 2-D points.

    Solves the 2x2 linear system of the perpendicular bisectors with Cramer's
    rule. Returns ``np.array([cx, cy])``. When the determinant's magnitude is
    below 1e-10 (the points are numerically collinear), returns the tuple
    ``(None, None)`` instead.
    """
    x1, y1 = p1[0], p1[1]
    x2, y2 = p2[0], p2[1]
    x3, y3 = p3[0], p3[1]

    sq_norm_2 = x2 * x2 + y2 * y2
    rhs_a = (x1 * x1 + y1 * y1 - sq_norm_2) / 2
    rhs_b = (sq_norm_2 - x3 * x3 - y3 * y3) / 2

    dx12, dy12 = x1 - x2, y1 - y2
    dx23, dy23 = x2 - x3, y2 - y3
    det = dx12 * dy23 - dx23 * dy12
    if abs(det) < 1.0e-10:
        return (None, None)

    return np.array([
        (rhs_a * dy23 - rhs_b * dy12) / det,
        (dx12 * rhs_b - dx23 * rhs_a) / det,
    ])


def get_angles_from_arc_points(p0, p_mid, p1):
    """Return the polar angles of an arc's three defining points around its center.

    Args:
        p0: start point ``(x, y)``.
        p_mid: a point on the arc between start and end.
        p1: end point ``(x, y)``.

    Returns:
        ``(start_angle, mid_angle, end_angle, arc_center)``. The angles are in
        radians as returned by ``np.arctan2``, and ``arc_center`` is a
        ``(cx, cy)`` tuple.

    Raises:
        ValueError: if the three points are (numerically) collinear, so no
            circle passes through them.
    """
    arc_center = find_circle_center(p0, p_mid, p1)
    if arc_center[0] is None:
        # Collinear points: fail loudly instead of doing arithmetic on None.
        raise ValueError(f"arc points are collinear: {p0}, {p_mid}, {p1}")
    cx, cy = arc_center[0], arc_center[1]
    start_angle = np.arctan2(p0[1] - cy, p0[0] - cx)
    end_angle = np.arctan2(p1[1] - cy, p1[0] - cx)
    mid_angle = np.arctan2(p_mid[1] - cy, p_mid[0] - cx)
    return start_angle, mid_angle, end_angle, (cx, cy)


def get_arc_plot_params(arc):
    """Convert a flat arc array into the parameters used to draw it.

    ``arc`` is expected to be laid out as ``[x0, y0, x1, y1, xm, ym]``: start
    point, end point, then a point on the arc between them.

    Returns:
        ``(start_deg, mid_deg, end_deg, arc_center, diameter)``. The angles are
        in degrees wrapped to ``[0, 360)``. ``diameter`` is twice the distance
        from the start point to ``arc_center``.
    """
    start_pt, end_pt, mid_pt = arc[:2], arc[2:4], arc[4:]
    *radians, arc_center = get_angles_from_arc_points(start_pt, mid_pt, end_pt)
    diameter = 2 * np.linalg.norm(start_pt - arc_center)

    def _degrees(angle):
        return (angle * 180 / np.pi) % 360

    start_deg, mid_deg, end_deg = (_degrees(a) for a in radians)
    return start_deg, mid_deg, end_deg, arc_center, diameter


In [None]:
num_arcs_to_save = [1, 3]  # indices of predicted arcs whose figure is saved as PDF
max_plot_dist = 10  # skip predictions farther than this from every GT arc
linewidth = 3


def to_rad(deg):
    """Degrees -> radians, wrapped to [0, 2*pi)."""
    return (deg * np.pi / 180) % (2 * np.pi)


def _add_arc_piece(ax, center, diameter, a_deg, b_deg, color):
    """Draw the circle piece between angles ``a_deg`` and ``b_deg`` (degrees).

    ``Arc`` sweeps from ``theta1`` to ``theta2`` by increasing angle. When that
    sweep from ``a_deg`` to ``b_deg`` is large according to ``is_large_arc``,
    the endpoints are swapped so the other side of the circle is drawn.
    """
    if not is_large_arc([to_rad(a_deg), to_rad(b_deg)]):
        start, end = a_deg, b_deg
    else:
        start, end = b_deg, a_deg
    ax.add_patch(
        Arc(
            center,
            diameter,
            diameter,
            angle=0.0,
            theta1=start,
            theta2=end,
            fill=None,
            color=color,
            linewidth=linewidth,
        )
    )


if not show_arcs:
    print("no arcs detected")
else:
    for num_arc, pred_arc in enumerate(lpos_a):
        dist = dists[2][num_arc]
        if dist > max_plot_dist:
            continue
        fig, ax = plt.subplots(figsize=(8, 8))  # one figure per predicted arc
        ax.imshow(image, aspect='equal')
        print('num_arc and dist', num_arc, dist)

        # Raw f-string: "\d" in a normal string is an invalid escape sequence.
        ax.set_title(rf" $\delta =$ {dist:.2f}", fontsize=40)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.tick_params(bottom=False, left=False)  # Remove ticks
        for side in ("top", "right", "bottom", "left"):
            ax.spines[side].set_visible(False)

        # Prediction in green, its matched GT arc in firebrick.
        matched_gt_arc = gt_arcs[choices[2][num_arc]]
        for color, (p0, p1, p2) in zip(['g', 'firebrick'], [pred_arc, matched_gt_arc]):
            for p in (p0, p1, p2):
                ax.scatter(p[0], p[1], c=color, s=80)

            # get_arc_plot_params reads the flat layout [start, end, mid].
            arc = np.concatenate([p0, p1, p2])
            theta1, theta_mid, theta2, c_xy, diameter = get_arc_plot_params(arc)

            if theta_mid < theta1 and theta_mid > theta2:
                theta1, theta2 = theta2, theta1
            # Draw the arc in two pieces: start -> mid and mid -> end.
            _add_arc_piece(ax, c_xy, diameter, theta1, theta_mid, color)
            _add_arc_piece(ax, c_xy, diameter, theta_mid, theta2, color)

        if num_arc in num_arcs_to_save:
            os.makedirs('eval_distance_visu', exist_ok=True)
            plt.savefig(f'eval_distance_visu/{im_name}_{num_arc}_arc_{dist:.2f}.pdf', bbox_inches='tight', format='pdf')

In [None]:
def _hide_frame(ax):
    """Hide all four spines of ``ax``."""
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)


def _plot_tp_fp_lines(image, lpos_l, gt_lines, line_dists, line_choices):
    """Show the first TP (left panel) and first FP (right panel) predicted line with its matched GT line."""
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    drawn = [False, False]  # [TP panel used, FP panel used]
    for num_line, (p0, p1) in enumerate(lpos_l):
        dist = line_dists[num_line]
        result = "TP" if dist < dist_thresh else "FP"
        k = 0 if result == "TP" else 1
        if drawn[k]:
            continue
        drawn[k] = True
        ax[k].imshow(image, aspect='equal')
        matched = gt_lines[line_choices[num_line]]
        for (q0, q1), color in [((p0, p1), 'g'), (matched, 'firebrick')]:
            ax[k].scatter(q0[0], q0[1], color=color, s=80)
            ax[k].scatter(q1[0], q1[1], color=color, s=80)
            ax[k].plot([q0[0], q1[0]], [q0[1], q1[1]], linewidth=3, c=color)
        ax[k].set_title(f"Distance: {dist:.2f}, {result}", fontsize=20)
        ax[k].set_xticklabels([])
        ax[k].set_yticklabels([])
        ax[k].tick_params(bottom=False, left=False)  # Remove ticks
        _hide_frame(ax[k])
        if all(drawn):
            break


def _plot_tp_fp_circles(image, lpos_c, gt_circles, circle_dists, circle_choices, max_per_panel=2):
    """Overlay up to ``max_per_panel`` TP (left) / FP (right) predicted circles with their matched GT circles."""
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    counts = [0, 0]  # circles seen so far per panel: [TP, FP]
    for num_circle, (p0_c, p1_c) in enumerate(lpos_c):
        # Index by this circle's own position (not a leftover loop variable).
        dist = circle_dists[num_circle]
        result = "TP" if dist < dist_thresh else "FP"
        k = 0 if result == "TP" else 1
        counts[k] += 1
        if counts[k] > max_per_panel:
            continue
        ax[k].imshow(image, aspect='equal')
        ax[k].set_title(f"Distance: {dist:.2f}, {result}", fontsize=20)

        pairs = [([p0_c, p1_c], 'g'), (gt_circles[circle_choices[num_circle]], 'firebrick')]
        for (c0, c1), color in pairs:
            cardinal = get_4_cardinal_pts_circle(np.array([[c0, c1]]))[0]
            center = (c0[0] + c1[0]) / 2, (c0[1] + c1[1]) / 2
            d1 = np.abs(c1[0] - c0[0])
            d2 = np.abs(c1[1] - c0[1])
            ax[k].add_patch(Ellipse(center, d1, d2, color=color, fill=False, linewidth=3))
            ax[k].scatter(cardinal[:, 0], cardinal[:, 1], color=color)
        _hide_frame(ax[k])


def _plot_arc_points(image, lpos_a, gt_arcs, arc_dists, arc_choices, max_panels=4):
    """Plot the control points of up to ``max_panels`` predicted arcs (green) and their matched GT arcs (firebrick)."""
    # squeeze=False keeps a 2-D axes array even when max_panels == 1.
    fig, axes = plt.subplots(1, max_panels, figsize=(20, 6), squeeze=False)
    for k, (panel, (p0, p1, p2)) in enumerate(zip(axes[0], lpos_a)):
        dist = arc_dists[k]
        # No hard-coded axis limits: set_ylim(0, 512) would flip the image vertically.
        panel.imshow(image, aspect='equal')
        for p in (p0, p1, p2):
            panel.scatter(p[0], p[1], c='g')
        panel.text(p0[0], p0[1], f"{dist:.2f}", bbox={"facecolor": 'red', "alpha": 0.6, "pad": 1}, fontsize=16, color='black')
        for p in gt_arcs[arc_choices[k]]:
            panel.scatter(p[0], p[1], c='firebrick')
        print(f'arc {k} dist: {dist:.2f}')


def plot_image_and_primitives(image, preds, gts, im_rescale, heatmap_scale):
    """Match predictions to ground truth and draw summary figures for one image.

    Figures produced:
    - Lines: a 1x2 figure with the first TP (left) and first FP (right) prediction
      in green, each next to its nearest GT line in firebrick.
    - Circles: the same layout, overlaying up to two TPs / FPs per panel.
    - Arcs: up to four panels with predicted and matched GT control points.

    A prediction is a TP when its matching distance is below the global
    ``dist_thresh``. Matching uses the unscaled ``preds`` / ``gts``. Drawing
    uses coordinates rescaled with ``scale_positions(..., im_rescale, heatmap_scale)``.
    The global ``show_primitives`` toggles are read but not modified.

    Returns:
        ``(lpos_l, lpos_c, lpos_a)``: the rescaled predicted lines, circles and arcs.
    """
    lines, circles, arcs = preds
    gt_lines, gt_circles, gt_arcs = gts
    lpos_l = scale_positions(lines.copy(), im_rescale, heatmap_scale)
    lpos_c = scale_positions(circles.copy(), im_rescale, heatmap_scale)
    lpos_a = scale_positions(arcs.copy(), im_rescale, heatmap_scale)
    gt_lines = scale_positions(gt_lines.copy(), im_rescale, heatmap_scale)
    gt_circles = scale_positions(gt_circles.copy(), im_rescale, heatmap_scale)
    gt_arcs = scale_positions(gt_arcs.copy(), im_rescale, heatmap_scale)

    # Local copy so the notebook-level toggles are not mutated by this call.
    show_flags = list(show_primitives)
    dists, choices = [], []
    for k, (preds_primitive, gts_primitive) in enumerate(zip(preds, gts)):
        if len(preds_primitive) == 0 or len(gts_primitive) == 0:
            reason = "detected" if len(preds_primitive) == 0 else "in ground truth"
            print(f"no {primitive_dict[k]} {reason}")
            show_flags[k] = False
            dists.append(0)
            choices.append(0)
            continue
        dist, choice = get_pred_gt_match(preds_primitive, gts_primitive, primitive_k=k)
        dists.append(dist)
        choices.append(choice)
    show_lines, show_circles, show_arcs = show_flags

    if show_lines:
        _plot_tp_fp_lines(image, lpos_l, gt_lines, dists[0], choices[0])
    if show_circles:
        _plot_tp_fp_circles(image, lpos_c, gt_circles, dists[1], choices[1])
    if show_arcs:
        _plot_arc_points(image, lpos_a, gt_arcs, dists[2], choices[2])
    return lpos_l, lpos_c, lpos_a


In [None]:
# Run matching + plotting for the current image. The return value is the tuple
# of rescaled *predicted* (lines, circles, arcs), not ground truth, despite the name.
results_gt = plot_image_and_primitives(im, preds, gts, im_rescale, heatmap_scale)

In [None]:
# `results_gt` is a tuple (lines, circles, arcs) and has no `.shape`;
# show the shape of each entry instead (empty entries may be `[]`).
[np.shape(entry) for entry in results_gt]

In [None]:
primitive_dict = {0: "line", 1: "circle", 2: "arc"}


def plot_image_and_primitives(image, preds, gts, im_rescale, heatmap_scale):
    """Plot up to four predictions per primitive type next to their nearest GT.

    NOTE(review): this cell re-defines ``plot_image_and_primitives``; calls made
    after this cell use this version.

    For each primitive type that has both predictions and GT, a 1x4 figure
    shows the first four predictions (green) with their matched GT (firebrick)
    and the matching distance as a red label. Matching uses the unscaled
    inputs. Drawing uses coordinates rescaled with
    ``scale_positions(..., im_rescale, heatmap_scale)``. Panels keep the image's
    own extent: hard-coded 0..512 limits would flip it vertically and crop or
    pad non-square images.

    Returns:
        ``(lpos_l, lpos_c, lpos_a)``: the rescaled predicted lines, circles and arcs.
    """
    lines, circles, arcs = preds
    gt_lines, gt_circles, gt_arcs = gts
    # Four axis-aligned extreme points per predicted circle, drawn as a cross.
    cl_circles = get_4_cardinal_pts_circle(circles)
    lpos_cl = scale_positions(cl_circles.copy(), im_rescale, heatmap_scale)
    lpos_l = scale_positions(lines.copy(), im_rescale, heatmap_scale)
    lpos_c = scale_positions(circles.copy(), im_rescale, heatmap_scale)
    lpos_a = scale_positions(arcs.copy(), im_rescale, heatmap_scale)
    gt_lines = scale_positions(gt_lines.copy(), im_rescale, heatmap_scale)
    gt_circles = scale_positions(gt_circles.copy(), im_rescale, heatmap_scale)
    gt_arcs = scale_positions(gt_arcs.copy(), im_rescale, heatmap_scale)

    # dists[k] / choices[k]: per-prediction nearest-GT distance and GT index
    # for primitive type k (0 placeholder when the type is skipped).
    show_primitives = [True, True, True]
    dists, choices = [], []
    for k, (preds_primitive, gts_primitive) in enumerate(zip(preds, gts)):
        if len(preds_primitive) == 0 or len(gts_primitive) == 0:
            reason = "detected" if len(preds_primitive) == 0 else "in ground truth"
            print(f"no {primitive_dict[k]} {reason}")
            show_primitives[k] = False
            dists.append(0)
            choices.append(0)
            continue
        dist, choice = get_pred_gt_match(preds_primitive, gts_primitive, primitive_k=k)
        dists.append(dist)
        choices.append(choice)
    show_lines, show_circles, show_arcs = show_primitives

    label_box = {"facecolor": 'red', "alpha": 0.6, "pad": 1}

    if show_lines:
        fig, ax = plt.subplots(1, 4, figsize=(20, 6))
        # zip() stops after the 4 panels.
        for k, (panel, (p0, p1)) in enumerate(zip(ax, lpos_l)):
            dist = dists[0][k]
            panel.imshow(image, aspect='equal')
            panel.scatter(p0[0], p0[1])
            panel.scatter(p1[0], p1[1])
            panel.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=1, c='g')
            g0, g1 = gt_lines[choices[0][k]]
            panel.text(g0[0], g0[1], f"{dist:.2f}", bbox=label_box, fontsize=16, color='black')
            panel.scatter(g0[0], g0[1])
            panel.scatter(g1[0], g1[1])
            panel.plot([g0[0], g1[0]], [g0[1], g1[1]], linewidth=1, c='firebrick')

    if show_circles:
        fig, ax = plt.subplots(1, 4, figsize=(20, 6))
        for k, (panel, (c0, c1, c2, c3), (p0, p1)) in enumerate(zip(ax, lpos_cl, lpos_c)):
            dist = dists[1][k]
            panel.imshow(image, aspect='equal')
            for pt in (c0, c1, c2, c3):
                panel.scatter(pt[0], pt[1])
            panel.plot([c0[0], c1[0]], [c0[1], c1[1]], linewidth=1, c='g')
            panel.plot([c2[0], c3[0]], [c2[1], c3[1]], linewidth=1, c='g')
            for (e0, e1), color in [((p0, p1), 'g'), (gt_circles[choices[1][k]], 'firebrick')]:
                center = (e0[0] + e1[0]) / 2, (e0[1] + e1[1]) / 2
                d1 = np.abs(e1[0] - e0[0])
                d2 = np.abs(e1[1] - e0[1])
                panel.add_patch(Ellipse(center, d1, d2, color=color, fill=False, linewidth=1))
            panel.text(p0[0], p0[1], f"{dist:.2f}", bbox=label_box, fontsize=16, color='black')

    if show_arcs:
        fig, ax = plt.subplots(1, 4, figsize=(20, 6))
        for k, (panel, (p0, p1, p2)) in enumerate(zip(ax, lpos_a)):
            dist = dists[2][k]
            panel.imshow(image, aspect='equal')
            for pt in (p0, p1, p2):
                panel.scatter(pt[0], pt[1], c='g')
            panel.text(p0[0], p0[1], f"{dist:.2f}", bbox=label_box, fontsize=16, color='black')
            for pt in gt_arcs[choices[2][k]]:
                panel.scatter(pt[0], pt[1], c='firebrick')
            print(f'arc {k} dist: {dist:.2f}')
    plt.show()
    return lpos_l, lpos_c, lpos_a


### Predictions heatmap

In [None]:
import torch
from pathlib import Path
import torch.utils.data
import sys
sys.path.append("../")
from LETR.data.coco import CocoDetection, make_coco_transforms
from process import process_image, load_model, dict_to_namespace, plot_results
import yaml
from LETR.data import build_dataset

import numpy as np

In [None]:
# Root of the LETR circle-detection project. Set the LETR_ROOT environment
# variable to run on another machine; the default is the original location.
root = Path(os.environ.get(
    "LETR_ROOT",
    "/home/kallelis/PrimitiveExtraction/PrimitiveExtraction/Detection/my_letr_circle",
))
exp_folder = root / "exp/res50_stage1_annos_1000_correct_data_aug/checkpoints"

In [None]:
# Build the config namespace from YAML, then load the trained model from `exp_folder`.
config_file = "/home/kallelis/PrimitiveExtraction/PrimitiveExtraction/Detection/my_letr_circle/config_primitives.yaml"
with open(config_file, "r") as f:
    config_dict = yaml.safe_load(f)
cfg = dict_to_namespace(config_dict)

model = load_model(cfg, exp_folder)
# Alternative: load from a single checkpoint file instead.
# ckpt_path = "/home/kallelis/PrimitiveExtraction/PrimitiveExtraction/Detection/my_letr_circle/exp/model-epoch=99.ckpt"
# model = load_model(cfg, ckpt_path=ckpt_path)

model.eval()

In [None]:
DATADIR = Path("/home/kallelis/PrimitiveExtraction/PrimitiveExtraction/Detection/my_letr_circle/data")
# Dataset to inspect: "synthetic" (active) or "wireframe".
# dataset_name = "wireframe"
dataset_name = "synthetic"
mode = 'primitives'
root = DATADIR.joinpath(f"{dataset_name}_processed")

In [None]:
# Display the processed dataset directory selected above.
root

In [None]:
# Image folder and annotation file per split. Wireframe uses "2017"-suffixed
# names; the other dataset uses bare split names.
split_suffix = "2017" if dataset_name == "wireframe" else ""
PATHS = {
    split: (
        root / f"{split}{split_suffix}",
        root / "annotations" / f"{mode}_{split}{split_suffix}.json",
    )
    for split in ("train", "val")
}
image_set = "train"
img_folder, ann_file = PATHS[image_set]

In [None]:
# Point the dataset config at the processed dataset folder selected above
# (`root` is a pathlib.Path at this point).
cfg.data.coco_path = root

In [None]:
from helper.misc import collate_fn
from torch.utils.data import DataLoader

# Unshuffled validation loader, one image per batch, keeping the last partial batch.
val_dataset = build_dataset("val", cfg.data)
data_loader_val = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    drop_last=False,
    num_workers=cfg.trainer.num_workers,
    collate_fn=collate_fn,
)

In [None]:
# Take the first validation batch. Unlike `for ...: break`, this fails loudly on
# an empty loader instead of leaving `batch` undefined or stale from a previous run.
batch = next(iter(data_loader_val), None)
if batch is None:
    raise RuntimeError("validation data loader is empty")

In [None]:
sample, target = batch
primitive = "annos"
# Inference only: without no_grad, the outputs may carry autograd history.
# That wastes memory and makes `.numpy()` below fail on tensors that require grad.
with torch.no_grad():
    outputs = model(sample)  # forward pass
    orig_target_sizes = torch.stack([t["orig_size"] for t in target], dim=0)
    postprocessor = model.postprocessors["line"]  # FIXME: only line for now
    results = postprocessor(outputs, orig_target_sizes, "prediction")
pred_logits = outputs["pred_logits"]

bz = pred_logits.shape[0]
assert bz == 1, "only support batch size 1"

query = pred_logits.shape[1]
rst = results[0][primitive]
# One 2-point primitive per query; flip the last axis so each point is (y, x).
pred_lines = rst.view(query, 2, 2).flip([-1])

# Rescale from the original image size (h, w) to a 128 x 128 grid.
h, w = target[0]["orig_size"].tolist()
pred_lines[:, :, 0] = pred_lines[:, :, 0] * 128 / h
pred_lines[:, :, 1] = pred_lines[:, :, 1] * 128 / w
pred_lines = pred_lines.flip([-1])  # flip back to the original coordinate order

score = results[0]["scores"].cpu().numpy()
label = results[0]["labels"].cpu().numpy()
line = pred_lines.cpu().numpy()

# Sort all predictions by decreasing score.
score_idx = np.argsort(-score)
line, score, label = line[score_idx], score[score_idx], label[score_idx]
# .item() works for 0-d and 1-element tensors (int() on a 1-d ndarray is deprecated).
img_id = int(target[0]["image_id"].item())


In [None]:
OUTPUT_DIR = "/home/kallelis/PrimitiveExtraction/PrimitiveExtraction/Detection/my_letr_circle/exp/res50_stage1_annos_1000_correct_data_aug/benchmark"


def save_npz(img_id, line, score, label, id_to_img, output_dir=OUTPUT_DIR):
    """Save one image's predictions as ``<output_dir>/<img_id:08d>.npz``.

    The archive stores ``annos`` (the predicted primitives), ``score`` and ``label``.

    Args:
        img_id: image id; converted with ``int()`` for the zero-padded file name.
        line, score, label: arrays to store.
        id_to_img: currently unused; kept for call compatibility.
        output_dir: str or path-like directory, created if missing.

    Returns:
        The path of the written file.
    """
    os.makedirs(output_dir, exist_ok=True)
    # os.path.join (not string +) so path-like output_dir values work too.
    checkpoint_path = os.path.join(output_dir, f"{int(img_id):08d}.npz")
    np.savez(checkpoint_path, annos=line, score=score, label=label)  # TODO: add circles
    return checkpoint_path
    

In [None]:
# Quick look at the per-query class labels (already reordered by descending score).
label

In [None]:
def get_id_to_img(args):
    """Map COCO-style image ids to file names without extension, for train and val.

    Reads ``<coco_path>/annotations/<mode>_val<suffix>.json`` and then
    ``<coco_path>/annotations/<mode>_train<suffix>.json``, where ``suffix`` is
    ``"2017"`` when ``coco_path`` contains ``"wireframe"`` and empty otherwise.

    Args:
        args: Object with ``mode`` (annotation-file prefix) and ``coco_path``
            (dataset root, ``str`` or ``os.PathLike``) attributes.

    Returns:
        dict mapping image id -> file name with its final extension removed.
        Val is read first, so a train entry wins if an id appears in both.

    Raises:
        FileNotFoundError: if either annotation file is missing.
    """
    coco_path = str(args.coco_path)  # accept pathlib.Path as well as str
    suffix = "2017" if "wireframe" in coco_path else ""
    id_to_img = {}
    for split in ("val", "train"):
        ann_path = os.path.join(
            coco_path, "annotations", f"{args.mode}_{split}{suffix}.json"
        )
        with open(ann_path) as f:
            data = json.load(f)
        for d in data["images"]:
            # splitext drops only the real extension, so dotted stems survive intact.
            id_to_img[d["id"]] = os.path.splitext(d["file_name"])[0]
    return id_to_img
# Normalise the dataset root to a plain string, then build the id -> file-stem lookup.
cfg.data.coco_path = f"{cfg.data.coco_path}"
id_to_img = get_id_to_img(cfg.data)


In [None]:
# Sanity check: file stem of the image whose predictions were just computed.
id_to_img[img_id]

In [None]:
# Persist this image's sorted predictions to the benchmark directory.
save_npz(img_id, line, score, label, id_to_img=id_to_img)

In [None]:
# Load saved benchmark predictions and split them by label:
# label 0 -> line segments, any non-zero label -> circles.
fpred = np.load(f"/home/kallelis/PrimitiveExtraction/PrimitiveExtraction/Detection/my_letr_circle/exp/res50_stage1_annos_1000_correct_data_aug/benchmark_val/synthetic_diagram_{im_number}.npz")
primitives = fpred["annos"][:, :, :2]
mask = fpred["label"]
is_circle = mask.astype(bool)
pred_scores = fpred["score"]

lcnn_line = primitives[~is_circle]
scores_lines = pred_scores[~is_circle]
indices_lines = np.argsort(-scores_lines)

lcnn_circle = primitives[is_circle]
scores_circle = pred_scores[is_circle]
indices_circle = np.argsort(-scores_circle)

In [None]:
# Ground truth for the first entry: keep the first two coordinates of every point.
gt_line, gt_circle = (results_gt[key][0][:, :, :2] for key in ("lines", "circles"))

In [None]:
from evaluate import msTPFP, ap


def _truncate_at_first_repeat(lines, scores):
    """Cut ``lines`` just before the first later entry equal to ``lines[0]``.

    Repeats of the first prediction are treated as padding (presumably the
    output is filled by cycling — confirm). ``scores`` is cut to the same
    length so the two arrays stay aligned.
    """
    for i in range(1, len(lines)):
        if (lines[i] == lines[0]).all():
            return lines[:i], scores[:i]
    return lines, scores


lcnn_line, scores_lines = _truncate_at_first_repeat(lcnn_line, scores_lines)
# Re-rank after truncation so the indices are valid for the shorter arrays.
indices_lines = np.argsort(-scores_lines)
# 10: matching threshold passed to msTPFP (presumably the sAP^10 setting — confirm).
tp, fp = msTPFP(lcnn_line, gt_line, 10)


In [None]:
# Number of ground-truth line segments: the denominator for recall below.
n_gt = gt_line.shape[0]

In [None]:
# Walk predictions from most to least confident; cumulative TP / FP counts are
# normalised by the number of GT lines before computing average precision.
recall_curve = np.cumsum(tp[indices_lines]) / n_gt
fp_curve = np.cumsum(fp[indices_lines]) / n_gt
ap_res = ap(recall_curve, fp_curve)

In [None]:
circles = get_bbox_from_center_radii(data["circle_centers"], data["circle_radii"])
image_path = os.path.join(data_root, "images", data["filename"])
im = cv2.imread(image_path)
# cv2.imread signals a missing/unreadable file by returning None rather than raising.
if im is None:
    raise FileNotFoundError(f"could not read image: {image_path}")
# splitext removes only the final extension, so dotted stems are kept whole.
prefix = os.path.splitext(data["filename"])[0]
lines = np.array(data["lines"]).reshape(-1, 2, 2)
# circles = np.array(data["circles"]).reshape(-1, 2, 2)
im_shape = im.shape
im_rescale = (512, 512)

# Since each image can have a different aspect ratio, we rescale the image to a
# fixed size (512, 512) and then resize the heatmap to (128, 128) for comparison.
heatmap_scale = (128, 128)




In [None]:


# OpenCV loads images as BGR; convert to RGB so matplotlib shows the real colours.
image = cv2.cvtColor(cv2.resize(im, im_rescale), cv2.COLOR_BGR2RGB)

# Junction coordinates are on the heatmap grid; this factor maps them onto the
# resized image (512 / 128 = 4 with the settings above). Assumes square sizes.
to_image = im_rescale[0] / heatmap_scale[0]


def _show_transposed(img, title):
    """Show ``img`` with rows and columns swapped on a new axes; return (fig, ax)."""
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.imshow(np.transpose(img, (1, 0, 2)), aspect="equal")
    # Limits span the transposed image; with 0 as the bottom limit the y axis points up.
    ax.set_xlim(0, img.shape[0])
    ax.set_ylim(0, img.shape[1])
    ax.set_title(title)
    return fig, ax


# Resized input as-is.
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(image, aspect="equal")
ax.set_title("resized input")

# Line segments: each (i0, i1) in Lpos indexes two endpoints in `junc`.
fig, ax = _show_transposed(image, "lines")
for i0, i1 in Lpos:
    p0 = (junc[i0][1] * to_image, junc[i0][0] * to_image)
    p1 = (junc[i1][1] * to_image, junc[i1][0] * to_image)
    ax.scatter(p0[0], p0[1])
    ax.scatter(p1[0], p1[1])
    ax.plot([p0[0], p1[0]], [p0[1], p1[1]])

# Circles: each (i0, i1) in Lpos_c gives two opposite corners of the ellipse's bounding box.
fig, ax = _show_transposed(image, "circles")
for i0, i1 in Lpos_c:
    p0 = (junc_c[i0][1] * to_image, junc_c[i0][0] * to_image)
    p1 = (junc_c[i1][1] * to_image, junc_c[i1][0] * to_image)
    center = (p0[0] + p1[0]) / 2, (p0[1] + p1[1]) / 2
    width = np.abs(p1[0] - p0[0])
    height = np.abs(p1[1] - p0[1])
    ax.add_patch(Ellipse(center, width, height, color="g", fill=False, linewidth=1))

plt.show()