# Notes

Notebook for running the model on every image in a folder and saving the predictions for each image in `.npz` format.


In [None]:
import os, sys
import torch, json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from torch import nn
from PIL import Image
import glob

ROOT_DIR = os.path.join(os.getcwd(), "../..")
sys.path.append(os.path.join(ROOT_DIR, "src"))

from src.util import DATA_DIR, MODEL_DIR
import src.datasets.transforms as T
from src.main import build_model_main
from src.util.slconfig import SLConfig
from src.datasets import build_dataset
from src.util.visualizer import COCOVisualizer
from src.util import box_ops

# 0. Finding model and input folder

In [None]:
# VARIABLES TO CHANGE
# These four values select which checkpoint is loaded and which image folder is processed.
model_name = "main_model"
epoch = '0036'  # kept as a string: it is spliced verbatim into the checkpoint filename below
dataset_name = "eida_dataset"
extension = "jpg"  # only files matching images/*.<extension> are picked up by the glob below

# Both the config file and the checkpoint are looked up inside MODEL_DIR/<model_name>.
model_folder = MODEL_DIR / model_name
model_config_path = f"{model_folder}/config_cfg.py"
model_checkpoint_path = f"{model_folder}/checkpoint{epoch}.pth"

# Load the training config, then override a couple of fields for inference.
args = SLConfig.fromfile(model_config_path) 
args.device = 'cuda' 
args.num_select = 200  # NOTE(review): presumably the number of predictions kept per image - confirm against the postprocessor

corpus_folder = DATA_DIR / dataset_name
image_paths = glob.glob(f"{corpus_folder}/images/*.{extension}")
# Output folder for the per-image .npz files; its name encodes the model name and epoch.
npz_dir = DATA_DIR / dataset_name / f"npz_preds_{model_name}{epoch}"

# Without exist_ok this raises FileExistsError when the folder is already there,
# which guards against silently overwriting an earlier run.
os.makedirs(npz_dir) 
# os.makedirs(npz_dir, exist_ok=True)  # If you are not bothered by overwriting the existing folder, uncomment this and comment the above line

In [None]:
# Build the network (plus its criterion and postprocessors) from the config loaded above.
model, criterion, postprocessors = build_model_main(args)
# The checkpoint is read as a dict; only its 'model' entry (the state dict) is used here.
checkpoint = torch.load(model_checkpoint_path, map_location='cuda')
model.load_state_dict(checkpoint['model'])
# Switch to evaluation mode; assigning to `_` keeps the notebook from echoing the model repr.
_ = model.eval()

In [None]:
# Extra dataset/evaluation settings written onto `args`.
# NOTE(review): the model has already been built by the time this cell runs, and none of
# the visible cells read these fields afterwards - confirm whether they are still needed.
image_set = "val"
args.dataset_file = 'synthetic'
args.mode = "primitives"
args.relative = False
args.common_queries = True
args.eval = (image_set == "val")  # evaluates to True with the image_set chosen above
args.coco_path = f"{DATA_DIR}/synthetic_processed" # the path of coco
args.fix_size = False
args.batch_size = 1
args.boxes_only=False
encoder_only = False
vslzr = COCOVisualizer()
# Class id -> primitive name; the same ids (0, 1, 2) are used when predictions are split per class.
id2name = {0: 'line', 1: 'circle', 2: 'arc'}
primitives_to_show = ['line', 'circle', 'arc']

# 1. Helper functions


In [None]:
def line_to_xy(x):
    """Turn a centre/extent line encoding into its two endpoints.

    Args:
        x: Four values ``(c_x, c_y, w, h)``.

    Returns:
        ``(x0, y0, x1, y1)``, where the first point is the centre minus half
        of ``(w, h)`` and the second point is the centre plus half of ``(w, h)``.
    """
    c_x, c_y, w, h = x
    first_point = (c_x - w / 2, c_y - h / 2)
    second_point = (c_x + w / 2, c_y + h / 2)
    return (*first_point, *second_point)

def circle_to_xy(x):
    """Turn a centre/extent circle encoding into centre plus radius.

    Args:
        x: Four values ``(c_x, c_y, w, h)``.

    Returns:
        ``(c_x, c_y, r)`` with ``r = (w + h) / 4``, i.e. half of the mean of
        the two extents.
    """
    c_x, c_y, w, h = x
    return c_x, c_y, (w + h) / 4

def arc_to_xy(x):
    """Turn a centre/extent arc encoding into three points.

    Args:
        x: Six values ``(cx, cy, w, h, w_mid, h_mid)``.

    Returns:
        ``(x0, y0, x1, y1, x_mid, y_mid)``: the centre minus half of ``(w, h)``,
        the centre plus half of ``(w, h)``, and the centre plus half of
        ``(w_mid, h_mid)``.
    """
    cx, cy, w, h, w_mid, h_mid = x
    first_point = (cx - w / 2, cy - h / 2)
    second_point = (cx + w / 2, cy + h / 2)
    mid_point = (cx + w_mid / 2, cy + h_mid / 2)
    return (*first_point, *second_point, *mid_point)

In [None]:
def get_outputs_per_class(pred_dict):
    """Split a prediction dict into per-primitive numpy arrays.

    Args:
        pred_dict: Mapping with ``"labels"``, ``"parameters"`` and ``"scores"``
            entries. Each must support boolean-mask indexing and
            ``.cpu().numpy()``.

    Returns:
        ``(lines, line_scores, circles, circle_scores, arcs, arc_scores)``.
        Label 0 rows use parameter columns 0-3 and go through ``line_to_xy``,
        label 1 rows use columns 4-7 and go through ``circle_to_xy``, and
        label 2 rows use columns 8-13 and go through ``arc_to_xy``.
    """
    labels = pred_dict["labels"]
    parameters = pred_dict["parameters"]
    scores = pred_dict["scores"]

    def _select(class_id, columns):
        # Rows of one class, restricted to that class's parameter columns.
        keep = labels == class_id
        return parameters[keep][:, columns].cpu().numpy(), scores[keep].cpu().numpy()

    raw_lines, line_scores = _select(0, slice(None, 4))
    raw_circles, circle_scores = _select(1, slice(4, 8))
    raw_arcs, arc_scores = _select(2, slice(8, 14))

    lines = np.array([line_to_xy(row) for row in raw_lines])
    circles = np.array([circle_to_xy(row) for row in raw_circles])
    arcs = np.array([arc_to_xy(row) for row in raw_arcs])

    return lines, line_scores, circles, circle_scores, arcs, arc_scores


In [None]:
from tqdm import tqdm_notebook as tqdm

In [None]:
def is_large_arc(rad_angle):
    """Tell whether going from the first angle to the second spans a large arc.

    Args:
        rad_angle: Indexable pair ``(start, end)`` of angles in radians.

    Returns:
        For ``start <= pi``: True unless ``end`` lies strictly between
        ``start`` and ``start + pi``. For ``start > pi``: True only when
        ``end`` lies strictly between ``start - pi`` and ``start``.
    """
    start, end = rad_angle[0], rad_angle[1]
    if start <= np.pi:
        within_half_turn = start < end < (np.pi + start)
        return not within_half_turn
    return (start - np.pi) < end < start

def find_circle_center(p1, p2, p3):
    """Return the centre of the circle passing through three 2D points.

    Args:
        p1, p2, p3: Indexable points; only elements ``[0]`` and ``[1]`` are read.

    Returns:
        ``np.array([cx, cy])``, or the tuple ``(None, None)`` when the
        determinant is below ``1e-10`` in absolute value (the points are
        collinear or nearly so).
    """
    x1, y1 = p1[0], p1[1]
    x2, y2 = p2[0], p2[1]
    x3, y3 = p3[0], p3[1]

    sq_norm_2 = x2 * x2 + y2 * y2
    bc = (x1 * x1 + y1 * y1 - sq_norm_2) / 2
    cd = (sq_norm_2 - x3 * x3 - y3 * y3) / 2
    det = (x1 - x2) * (y2 - y3) - (x2 - x3) * (y1 - y2)

    # A vanishing determinant means no unique circle exists.
    if abs(det) < 1.0e-10:
        return (None, None)

    center_x = (bc * (y2 - y3) - cd * (y1 - y2)) / det
    center_y = ((x1 - x2) * cd - (x2 - x3) * bc) / det
    return np.array([center_x, center_y])

def find_circle_center_arr(p1, p2, p3):
    """Batched circle centre from three points per row.

    Args:
        p1, p2, p3: Arrays indexed as ``[:, 0]`` and ``[:, 1]``, one point
            per row.

    Returns:
        Array with ``cx`` and ``cy`` stacked on the last axis. Rows whose
        determinant is below ``1e-10`` in absolute value get the centre
        ``(0, 0)`` instead of dividing by zero.
    """
    x1, y1 = p1[:, 0], p1[:, 1]
    x2, y2 = p2[:, 0], p2[:, 1]
    x3, y3 = p3[:, 0], p3[:, 1]

    sq_norm_2 = x2 ** 2 + y2 ** 2
    bc = (x1 ** 2 + y1 ** 2 - sq_norm_2) / 2
    cd = (sq_norm_2 - x3 ** 2 - y3 ** 2) / 2
    det = (x1 - x2) * (y2 - y3) - (x2 - x3) * (y1 - y2)

    # Neutralise degenerate rows: a unit determinant avoids the division by
    # zero, and zeroed numerator terms pin the centre to the origin.
    degenerate = np.abs(det) < 1.0e-10
    det = np.where(degenerate, 1.0, det)
    bc = np.where(degenerate, 0.0, bc)
    cd = np.where(degenerate, 0.0, cd)

    center_x = (bc * (y2 - y3) - cd * (y1 - y2)) / det
    center_y = ((x1 - x2) * cd - (x2 - x3) * bc) / det
    return np.stack([center_x, center_y], axis=-1)


def get_angles_from_arc_points(p0, p_mid, p1):
    """Polar angles of three arc points around their common circle centre.

    Args:
        p0: Start point of the arc.
        p_mid: A point on the arc between the two ends.
        p1: End point of the arc.

    Returns:
        ``(start_angle, mid_angle, end_angle, arc_center)``. The angles are in
        radians as returned by ``np.arctan2``, and ``arc_center`` is a
        ``(cx, cy)`` tuple obtained from ``find_circle_center``.
    """
    center = find_circle_center(p0, p_mid, p1)
    center_x, center_y = center[0], center[1]

    def _angle_of(point):
        return np.arctan2(point[1] - center_y, point[0] - center_x)

    return _angle_of(p0), _angle_of(p_mid), _angle_of(p1), (center_x, center_y)


def get_arc_plot_params(arc):
    """Derive the plotting parameters of an arc given as three points.

    Args:
        arc: Flat array read as ``arc[:2]`` (start point), ``arc[2:4]`` (end
            point) and ``arc[4:]`` (mid point).

    Returns:
        ``(start_angle, mid_angle, end_angle, arc_center, diameter)``. The
        angles are in degrees wrapped into ``[0, 360)``. ``diameter`` is twice
        the distance from the start point to the centre.
    """
    start_point, end_point, mid_point = arc[:2], arc[2:4], arc[4:]
    *angles_rad, arc_center = get_angles_from_arc_points(
        start_point, mid_point, end_point
    )
    diameter = 2 * np.linalg.norm(start_point - arc_center)
    start_angle, mid_angle, end_angle = (
        (angle * 180 / np.pi) % 360 for angle in angles_rad
    )
    return start_angle, mid_angle, end_angle, arc_center, diameter

## Some helpers for removing duplicates
These helpers remove near-identical duplicates, very short lines, and arcs that lie on top of lines or circles. Merging of fragmented lines and arcs is not implemented yet and may need to be added.

In [1]:
def remove_duplicate_circles(circle_coords_list, image_size, circle_scores=None):
    """Drop circles that are near-duplicates of a later circle.

    Two circles count as duplicates when the Euclidean distance between their
    ``(cx, cy, r)`` vectors is below ``(image_size[0] + image_size[1]) / 80``.
    An entry is dropped whenever any *later* entry is that close, so of a
    duplicate pair the later one survives, regardless of the scores.

    Args:
        circle_coords_list: Circles as a list or array that reshapes to
            ``(n_circles, 3)``.
        image_size: Two image dimensions; only their sum is used.
        circle_scores: Optional per-circle scores aligned with the circles.

    Returns:
        The kept circles as an ``(n_kept, 3)`` array, or the tuple
        ``(kept_circles, kept_scores)`` when ``circle_scores`` is given.
    """
    circle_coords = np.asarray(circle_coords_list).reshape(-1, 3)  # (n_circles, 3)
    distances = np.linalg.norm(
        circle_coords[None, :, :] - circle_coords[:, None, :], axis=-1
    )
    threshold = (image_size[0] + image_size[1]) / 80

    # The strict upper triangle restricts each row to comparisons with later entries.
    has_later_duplicate = np.triu(distances < threshold, k=1).any(axis=1)
    indices_to_keep = np.flatnonzero(~has_later_duplicate)

    # Always index the reshaped array, so that Python lists and flat inputs
    # work and both return paths yield the same (n_kept, 3) layout.
    kept_circles = circle_coords[indices_to_keep]
    if circle_scores is not None:
        return kept_circles, np.asarray(circle_scores)[indices_to_keep]
    return kept_circles

def remove_duplicate_lines(line_coords_list, image_size, line_scores=None):
    """Drop lines that are near-duplicates of a later line.

    The distance between two lines is the smaller of the Euclidean distances
    between their ``(x0, y0, x1, y1)`` vectors with and without swapping the
    endpoints of one of them, so direction does not matter. Lines closer than
    ``(image_size[0] + image_size[1]) / 80`` are duplicates. An entry is
    dropped whenever any *later* entry is that close, regardless of scores.

    Args:
        line_coords_list: Lines as a list or array that reshapes to
            ``(n_lines, 4)``.
        image_size: Two image dimensions; only their sum is used.
        line_scores: Optional per-line scores aligned with the lines.

    Returns:
        The kept lines as an ``(n_kept, 4)`` array, or the tuple
        ``(kept_lines, kept_scores)`` when ``line_scores`` is given.
    """
    line_coords = np.asarray(line_coords_list).reshape(-1, 4)  # (n_lines, 4)
    # Same segments with their endpoints swapped.
    reversed_lines = np.hstack((line_coords[:, 2:4], line_coords[:, 0:2]))

    same_direction = np.linalg.norm(
        line_coords[None, :, :] - line_coords[:, None, :], axis=-1
    )
    opposite_direction = np.linalg.norm(
        line_coords[None, :, :] - reversed_lines[:, None, :], axis=-1
    )
    distances = np.minimum(same_direction, opposite_direction)
    threshold = (image_size[0] + image_size[1]) / 80

    # The strict upper triangle restricts each row to comparisons with later entries.
    has_later_duplicate = np.triu(distances < threshold, k=1).any(axis=1)
    indices_to_keep = np.flatnonzero(~has_later_duplicate)

    # Index the reshaped array so that Python lists and flat inputs work too.
    kept_lines = line_coords[indices_to_keep]
    if line_scores is not None:
        return kept_lines, np.asarray(line_scores)[indices_to_keep]
    return kept_lines

def remove_small_lines(line_coords_list, image_size, line_scores=None):
    """Keep only lines strictly longer than a size-dependent threshold.

    A line is kept when the distance between its two endpoints exceeds
    ``(image_size[0] + image_size[1]) / 50``.

    Args:
        line_coords_list: Lines as a list or array that reshapes to
            ``(n_lines, 4)``, each row being ``(x0, y0, x1, y1)``.
        image_size: Two image dimensions; only their sum is used.
        line_scores: Optional per-line scores aligned with the lines.

    Returns:
        The kept lines as an ``(n_kept, 4)`` array, or the tuple
        ``(kept_lines, kept_scores)`` when ``line_scores`` is given.
    """
    line_coords = np.asarray(line_coords_list).reshape(-1, 4)  # (n_lines, 4)
    lengths = np.linalg.norm(line_coords[:, :2] - line_coords[:, 2:], axis=-1)
    threshold = (image_size[0] + image_size[1]) / 50
    indices_to_keep = np.flatnonzero(lengths > threshold)

    # Index the reshaped array so that Python lists and flat inputs work too.
    kept_lines = line_coords[indices_to_keep]
    if line_scores is not None:
        return kept_lines, np.asarray(line_scores)[indices_to_keep]
    return kept_lines

def remove_duplicate_arcs(line_coords_list, image_size, line_scores=None):
    """Drop arcs that are near-duplicates of a later arc.

    Each arc is a 6-vector ``(x0, y0, x1, y1, x_mid, y_mid)``. The distance
    between two arcs is the smaller of the Euclidean distances with and
    without swapping the two endpoints of one arc, while the mid point stays
    in place. Arcs closer than ``(image_size[0] + image_size[1]) / 50`` are
    duplicates. An entry is dropped whenever any *later* entry is that close,
    regardless of scores.

    Args:
        line_coords_list: Arcs as a list or array that reshapes to
            ``(n_arcs, 6)``. The parameter name is historical.
        image_size: Two image dimensions; only their sum is used.
        line_scores: Optional per-arc scores aligned with the arcs.

    Returns:
        The kept arcs as an ``(n_kept, 6)`` array, or the tuple
        ``(kept_arcs, kept_scores)`` when ``line_scores`` is given.
    """
    arc_coords = np.asarray(line_coords_list).reshape(-1, 6)  # (n_arcs, 6)
    # Same arcs traversed the other way: endpoints swapped, mid point unchanged.
    reversed_arcs = np.hstack(
        (arc_coords[:, 2:4], arc_coords[:, 0:2], arc_coords[:, 4:6])
    )

    same_direction = np.linalg.norm(
        arc_coords[None, :, :] - arc_coords[:, None, :], axis=-1
    )
    opposite_direction = np.linalg.norm(
        arc_coords[None, :, :] - reversed_arcs[:, None, :], axis=-1
    )
    distances = np.minimum(same_direction, opposite_direction)
    threshold = (image_size[0] + image_size[1]) / 50

    # The strict upper triangle restricts each row to comparisons with later entries.
    has_later_duplicate = np.triu(distances < threshold, k=1).any(axis=1)
    indices_to_keep = np.flatnonzero(~has_later_duplicate)

    # Index the reshaped array so that Python lists and flat inputs work too.
    kept_arcs = arc_coords[indices_to_keep]
    if line_scores is not None:
        return kept_arcs, np.asarray(line_scores)[indices_to_keep]
    return kept_arcs

def remove_arcs_on_top_of_circles(arc_coords_list, circle_coords_list, image_size, arc_scores=None):
    """Drop arcs whose supporting circle matches one of the given circles.

    For every arc the circle through its three points is computed with
    ``find_circle_center_arr``; the radius is the distance from the arc's
    start point to that centre. The arc is dropped when its
    ``(cx, cy, r)`` vector lies within
    ``(image_size[0] + image_size[1]) / 80`` of any circle's ``(cx, cy, r)``.

    Args:
        arc_coords_list: Arcs as a list or array that reshapes to
            ``(n_arcs, 6)``, rows being ``(x0, y0, x1, y1, x_mid, y_mid)``.
        circle_coords_list: Circles as a list or array that reshapes to
            ``(n_circles, 3)``.
        image_size: Two image dimensions; only their sum is used.
        arc_scores: Optional per-arc scores aligned with the arcs.

    Returns:
        The kept arcs as an ``(n_kept, 6)`` array, or the tuple
        ``(kept_arcs, kept_scores)`` when ``arc_scores`` is given.
    """
    arc_coords = np.asarray(arc_coords_list).reshape(-1, 6)  # (n_arcs, 6)
    circle_coords = np.asarray(circle_coords_list).reshape(-1, 3)  # (n_circles, 3)

    # Describe each arc by the circle it lies on: start, mid and end points.
    arc_centers = find_circle_center_arr(
        arc_coords[:, :2],
        arc_coords[:, 4:],
        arc_coords[:, 2:4],
    )
    radii = np.linalg.norm(arc_coords[:, :2] - arc_centers, axis=1)
    arc_circle_coords = np.hstack((arc_centers, radii[:, None]))

    # distances[i, j]: arc i's circle versus circle j.
    distances = np.linalg.norm(
        circle_coords[None, :, :] - arc_circle_coords[:, None, :], axis=-1
    )
    threshold = (image_size[0] + image_size[1]) / 80
    overlaps_a_circle = (distances < threshold).any(axis=1)
    indices_to_keep = np.flatnonzero(~overlaps_a_circle)

    # Index the reshaped array so that Python lists and flat inputs work too.
    kept_arcs = arc_coords[indices_to_keep]
    if arc_scores is not None:
        return kept_arcs, np.asarray(arc_scores)[indices_to_keep]
    return kept_arcs

def remove_arcs_on_top_of_lines(arc_coords_list, line_coords_list, image_size, arc_scores=None):
    """Drop arcs that practically coincide with one of the given lines.

    Every line is extended to a 6-vector ``(x0, y0, x1, y1, x_c, y_c)`` by
    appending its midpoint, once in the given direction and once with the
    endpoints swapped. An arc ``(x0, y0, x1, y1, x_mid, y_mid)`` is dropped
    when its Euclidean distance to either variant of any line is below
    ``(image_size[0] + image_size[1]) / 50``.

    Args:
        arc_coords_list: Arcs as a list or array that reshapes to
            ``(n_arcs, 6)``.
        line_coords_list: Lines as a list or array that reshapes to
            ``(n_lines, 4)``.
        image_size: Two image dimensions; only their sum is used.
        arc_scores: Optional per-arc scores aligned with the arcs.

    Returns:
        The kept arcs as an ``(n_kept, 6)`` array, or the tuple
        ``(kept_arcs, kept_scores)`` when ``arc_scores`` is given.
    """
    arc_coords = np.asarray(arc_coords_list).reshape(-1, 6)  # (n_arcs, 6)
    line_coords = np.asarray(line_coords_list).reshape(-1, 4)  # (n_lines, 4)

    starts, ends = line_coords[:, :2], line_coords[:, 2:]
    midpoints = (starts + ends) / 2
    # A straight line written in the arc layout, in both traversal directions.
    lines_as_arcs = np.hstack((starts, ends, midpoints))
    lines_as_arcs_reversed = np.hstack((ends, starts, midpoints))

    # distances[i, j]: arc i versus line j, taking the better of the two directions.
    distances = np.minimum(
        np.linalg.norm(lines_as_arcs[None, :, :] - arc_coords[:, None, :], axis=-1),
        np.linalg.norm(
            lines_as_arcs_reversed[None, :, :] - arc_coords[:, None, :], axis=-1
        ),
    )
    threshold = (image_size[0] + image_size[1]) / 50
    overlaps_a_line = (distances < threshold).any(axis=1)
    indices_to_keep = np.flatnonzero(~overlaps_a_line)

    # Index the reshaped array so that Python lists and flat inputs work too.
    kept_arcs = arc_coords[indices_to_keep]
    if arc_scores is not None:
        return kept_arcs, np.asarray(arc_scores)[indices_to_keep]
    return kept_arcs

In [2]:
# arc visualization, very clunky in python but this should work
from matplotlib.patches import Arc
def plot_arc(ax, arc, c='r', linewidth=2):
    """Draw a three-point arc on ``ax`` as two matplotlib ``Arc`` patches.

    The arc is drawn in two halves, start-to-mid and mid-to-end. For each
    half the two bounding angles are swapped when ``is_large_arc`` reports
    that the given order would span the large side.

    Args:
        ax: Object exposing ``add_patch`` (a matplotlib axes).
        arc: Array that flattens to ``(x0, y0, x1, y1, x_mid, y_mid)``.
        c: Colour passed to the patches.
        linewidth: Line width passed to the patches.
    """
    flat_arc = arc.reshape(-1)
    theta1, theta_mid, theta2, c_xy, diameter = get_arc_plot_params(flat_arc)
    # Swap the two ends when the mid angle lies strictly between them in
    # descending order.
    if theta2 < theta_mid < theta1:
        theta1, theta2 = theta2, theta1

    def _to_rad(degrees):
        return (degrees * np.pi / 180) % (2 * np.pi)

    def _add_half(angle_from, angle_to):
        # One circular segment of the full diameter, centred at c_xy.
        ax.add_patch(
            Arc(
                c_xy,
                diameter,
                diameter,
                angle=0.0,
                theta1=angle_from,
                theta2=angle_to,
                fill=None,
                color=c,
                linewidth=linewidth,
            )
        )

    for angle_from, angle_to in ((theta1, theta_mid), (theta_mid, theta2)):
        if is_large_arc([_to_rad(angle_from), _to_rad(angle_to)]):
            angle_from, angle_to = angle_to, angle_from
        _add_half(angle_from, angle_to)

# 2. Making model predictions and saving them to npz format

In [None]:
# Run the model on every image and store the filtered predictions as one .npz per image.

# The preprocessing pipeline and the score thresholds are identical for all
# images, so they are set up once instead of on every iteration.
transform = T.Compose([
    T.RandomResize([800], max_size=1333),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
threshold, arc_threshold = 0.3, 0.3  # minimum score for lines/circles and for arcs
model = model.cuda()

for image_path in tqdm(image_paths):
    im_name = Path(image_path).stem

    # To resume an interrupted run, uncomment to skip images that already have a result:
    # if os.path.exists(npz_dir / f"{im_name}.npz"):
    #     continue

    # Close the file handle as soon as the pixels have been converted.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    orig_img_size = image.size  # PIL reports (width, height)
    tr_img, _ = transform(image, None)
    size = torch.Tensor([tr_img.shape[1], tr_img.shape[2]])  # not used below; kept in case a later cell reads it
    out_size = torch.Tensor([[orig_img_size[1], orig_img_size[0]]])  # (height, width) of the original image

    # Inference only: no_grad avoids building the autograd graph for every image.
    with torch.no_grad():
        output = model(tr_img[None].cuda())
        output = postprocessors['param'](output, out_size.cuda(), to_xyxy=False)[0]

    # Keep confident predictions only; arcs (label 2) have their own threshold.
    scores = output['scores']
    labels = output['labels']
    boxes = output['parameters']
    select_mask = ((scores > threshold) & (labels != 2)) | ((scores > arc_threshold) & (labels == 2))
    pred_dict = {
        'parameters': boxes[select_mask],
        'labels': labels[select_mask],
        'scores': scores[select_mask],
    }
    lines, line_scores, circles, circle_scores, arcs, arc_scores = get_outputs_per_class(pred_dict)

    # Duplicate clean-up. Order matters: arcs are compared against the
    # already-filtered circles and lines.
    lines, line_scores = remove_duplicate_lines(lines, orig_img_size, line_scores)
    lines, line_scores = remove_small_lines(lines, orig_img_size, line_scores)
    circles, circle_scores = remove_duplicate_circles(circles, orig_img_size, circle_scores)
    arcs, arc_scores = remove_duplicate_arcs(arcs, orig_img_size, arc_scores)
    arcs, arc_scores = remove_arcs_on_top_of_circles(arcs, circles, orig_img_size, arc_scores)
    arcs, arc_scores = remove_arcs_on_top_of_lines(arcs, lines, orig_img_size, arc_scores)

    # save to npz
    np.savez(
        npz_dir / f"{im_name}.npz",
        lines=lines,
        line_scores=line_scores,
        circles=circles,
        circle_scores=circle_scores,
        arcs=arcs,
        arc_scores=arc_scores,
    )