> Notebook for generating SVG predictions with DINO-DETR, for the semi-automatic annotation of diagrams.

In [None]:
import os, sys
import torch, json
import numpy as np
from PIL import Image
import svgwrite
from tqdm.notebook import tqdm
import cv2
import matplotlib.pyplot as plt
from pathlib import Path

from main import build_model_main
from util.slconfig import SLConfig
from datasets import build_dataset
from util.visualizer import COCOVisualizer, get_angles_from_arc_points
from util import box_ops
# from datasets.transforms import arc_cxcywh2_to_xy3, arc_xy3_to_cxcywh2
import datasets.transforms as T

from util import DATA_DIR, MODEL_DIR

### Select the model and epoch, and set up the NPZ prediction (input) and SVG (output) folders

In [None]:
# VARIABLES TO CHANGE
model_name = "main_model"
epoch = '0036'
dataset_name = "eida_dataset"
extension = "jpg"

# Derived settings (nothing to edit below this line).
model_folder = MODEL_DIR / model_name
id2name = {0: "line", 1: "circle", 2: "arc"}  # class index -> primitive type

dataset_root = DATA_DIR / dataset_name
# Prediction folders are tagged with the model name immediately followed by the epoch.
npz_dir = dataset_root / ("npz_preds_" + model_name + epoch)
output_dir = dataset_root / ("svg_preds_" + model_name + epoch)
diagram_dir = dataset_root / "images"

output_dir.mkdir(parents=True, exist_ok=True)

# Helper functions

In [None]:
def read_npz(npz_dir, im_name, *, with_scores=False):
    """Load the line, circle and arc predictions stored for one image.

    Args:
        npz_dir: directory (a ``pathlib.Path``) containing ``<im_name>.npz``.
        im_name: image name without its extension.
        with_scores: if True, also return the per-primitive score arrays.

    Returns:
        ``(lines, circles, arcs)`` exactly as stored in the archive, or
        ``(lines, circles, arcs, line_scores, circle_scores, arc_scores)``
        when ``with_scores`` is True. No reshaping is done here.
    """
    # np.load on an .npz returns an NpzFile that keeps the archive open;
    # indexing it materialises each array, after which the `with` closes it.
    with np.load(npz_dir / f"{im_name}.npz") as model_pred:
        lines = model_pred["lines"]
        circles = model_pred["circles"]
        arcs = model_pred["arcs"]
        if not with_scores:
            return lines, circles, arcs
        return (
            lines,
            circles,
            arcs,
            model_pred["line_scores"],
            model_pred["circle_scores"],
            model_pred["arc_scores"],
        )

In [None]:
def box_xyxy_to_cxcyr(x):
    """
    Convert circle bounding boxes from (x0, y0, x1, y1) to (cx, cy, r).

    Only valid for circles: the radius is the mean of the half-width and the
    half-height of the box. The four coordinates are unpacked from ``x.T``,
    so ``x`` may be a single box of shape (4,) or a stack of shape (N, 4).
    """
    left, top, right, bottom = x.T
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    radius = ((right - left) + (bottom - top)) / 4
    return np.stack([center_x, center_y, radius], axis=-1)

In [None]:
def calculate_angle(p1, p2, p3):
    """Return the signed angle at ``p2`` from vector ``p2->p3`` to ``p2->p1``.

    The result is the difference of the two vectors' ``arctan2`` headings,
    wrapped into ``[-pi, pi)``.
    """
    def _heading(vec):
        return np.arctan2(vec[1], vec[0])

    diff = _heading(p1 - p2) - _heading(p3 - p2)
    # Wrap into [-pi, pi).
    return (diff + np.pi) % (2 * np.pi) - np.pi

def write_svg_dwg(dwg, lines=None, circles=None, arcs=None, show_image=False, image=None):
    """Add predicted circles, line segments and arcs to an svgwrite drawing.

    In the SVG, circles are drawn in blue, lines in green and arcs in
    firebrick, all unfilled with stroke width 3. When ``show_image`` is True,
    the circles and lines (not the arcs) are also overlaid in red on a
    matplotlib figure showing ``image``.

    Args:
        dwg: drawing that receives the shapes (used via its ``add``,
            ``circle`` and ``path`` methods).
        lines: iterable of segments; ``line[0]`` and ``line[1]`` are the two
            ``(x, y)`` endpoints. ``None`` means no lines.
        circles: iterable of circles; ``circle[:2]`` is the centre and
            ``circle[-1]`` the radius. ``None`` means no circles.
        arcs: iterable of arcs, each holding three numpy points
            ``(p0, p1, pmid)``: the two endpoints and a point on the arc.
            ``None`` means no arcs.
        show_image: also draw a matplotlib preview over ``image``.
        image: image shown with ``ax.imshow`` when ``show_image`` is True.

    Returns:
        The same ``dwg`` object.
    """
    # None is the advertised default; treat it as "nothing to draw" rather
    # than failing with TypeError when the loops below iterate it.
    lines = () if lines is None else lines
    circles = () if circles is None else circles
    arcs = () if arcs is None else arcs

    if show_image:
        # Imported lazily: matplotlib is only needed for the optional preview.
        from matplotlib.patches import Polygon, Circle

        plt.figure(dpi=100)
        plt.rcParams["font.size"] = "5"
        ax = plt.gca()
        ax.imshow(image)

    for circle in circles:
        cx, cy = circle[:2]
        radius = circle[-1]
        if show_image:
            ax.add_patch(Circle(circle[:2], radius, fill=None, color="red", linewidth=1))
        dwg.add(dwg.circle(center=[str(cx), str(cy)], r=str(radius), fill="none", stroke="blue", stroke_width=3))

    for line in lines:
        (p1x, p1y), (p2x, p2y) = line[0], line[1]
        if show_image:
            ax.add_patch(Polygon(line, fill=None, color="red", linewidth=1))
        # `!s` reproduces str() output exactly; format() on numpy scalars can
        # produce a different textual representation (e.g. for float32).
        dwg.add(dwg.path(d=f"M {p1x!s} {p1y!s} L {p2x!s} {p2y!s}", stroke="green", stroke_width=3, fill="none"))

    for arc in arcs:
        p0, p1, pmid = arc[0], arc[1], arc[2]
        # Only the centre is needed for the SVG arc command; angles are unused.
        _, _, _, arc_center = get_angles_from_arc_points(p0, p1, pmid)
        arc_radius = np.linalg.norm(p0 - arc_center)
        # Distance from the chord midpoint to pmid exceeds the radius only for
        # arcs spanning more than half the circle (assuming pmid is the arc's
        # midpoint — TODO confirm against how the predictions are encoded).
        large_arc_flag = np.linalg.norm((p0 + p1) / 2 - pmid) > arc_radius
        # The side of the chord p0-p1 on which pmid lies picks the direction.
        sweep_flag = calculate_angle(p0, p1, pmid) < 0
        arc_args = {
            "x0": p0[0],
            "y0": p0[1],
            "xradius": arc_radius,
            "yradius": arc_radius,
            "ellipseRotation": 0,  # has no effect for circles
            "x1": p1[0],
            "y1": p1[1],
            "large_arc_flag": int(large_arc_flag),
            "sweep_flag": int(sweep_flag),  # set sweep-flag to 1 for clockwise arc
        }
        dwg.add(dwg.path(
            d="M %(x0)f,%(y0)f A %(xradius)f,%(yradius)f %(ellipseRotation)f %(large_arc_flag)d,%(sweep_flag)d %(x1)f,%(y1)f"
            % arc_args,
            fill="none",
            stroke="firebrick",
            stroke_width=3,
        ))

    return dwg


In [None]:
def get_arc_param(arc_path, arc_transform=None):
    """Extract circular-arc parameters from a single parsed SVG arc segment.

    Args:
        arc_path: sequence holding exactly one arc segment (as produced by
            ``svg.path``): an object exposing complex ``start``, ``end`` and
            ``center`` plus ``sweep`` and ``rotation``.
        arc_transform: unused; kept for backward compatibility.

    Returns:
        ``(center, radius, start_angle, end_angle, p0, p1)``.
        ``center``, ``p0`` and ``p1`` are ``(x, y)`` numpy arrays.
        ``radius`` is the distance from ``center`` to ``p0``; the arc is
        treated as circular.
        The angles are in radians, in ``[0, 2*pi)``, measured at ``center``
        with ``arctan2``.
        When the segment's ``sweep`` flag is false, the endpoints are
        swapped, so that ``p0 -> p1`` always follows the sweep=1 direction.

    Raises:
        AssertionError: if ``arc_path`` does not contain exactly one segment,
            or if the arc has a non-zero rotation.
    """
    def to_2pi(angle):
        # Map arctan2's (-pi, pi] output into [0, 2*pi).
        return (angle + 2 * np.pi) % (2 * np.pi)

    assert len(arc_path) == 1, f"expected exactly one arc segment, got {len(arc_path)}: {arc_path}"
    arc = arc_path[0]

    assert arc.rotation == 0, f"arc path with non-zero rotation {arc}"

    p0 = np.array([arc.start.real, arc.start.imag])
    p1 = np.array([arc.end.real, arc.end.imag])
    if not arc.sweep:
        p0, p1 = p1, p0
    center = np.array([arc.center.real, arc.center.imag])

    radius = np.linalg.norm(p0 - center)
    start_angle = to_2pi(np.arctan2(p0[1] - center[1], p0[0] - center[0]))
    end_angle = to_2pi(np.arctan2(p1[1] - center[1], p1[0] - center[0]))

    return center, radius, start_angle, end_angle, p0, p1

# Convert each image's NPZ predictions into an SVG (with Inkscape arc metadata) and save it

In [None]:
from calendar import c
import xml.etree.ElementTree as ET
import re
import shutil

from svg.path import parse_path
from svg.path.path import Line, Move, Arc


SVG_NS = "http://www.w3.org/2000/svg"
SODIPODI_NS = "http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
INKSCAPE_NS = "http://www.inkscape.org/namespaces/inkscape"

# Prefix registration is global ElementTree state and loop-invariant: do it once.
ET.register_namespace('', SVG_NS)
ET.register_namespace('xlink', "http://www.w3.org/1999/xlink")
ET.register_namespace('sodipodi', SODIPODI_NS)
ET.register_namespace('inkscape', INKSCAPE_NS)

# Matches the 'a' / 'A' (arc) command in a path's 'd' attribute.
arc_regex = re.compile(r'[aA]')


def _mark_inkscape_arcs(svg_file):
    """Rewrite ``svg_file`` in place, tagging arc paths with sodipodi metadata.

    Every ``<path>`` whose ``d`` contains an arc command gets
    ``sodipodi:type="arc"`` and ``sodipodi:arc-type="arc"``.
    Each arc segment in it additionally sets centre, radii and start/end
    angles, computed with ``get_arc_param``.
    Presumably this lets Inkscape treat the paths as editable arc objects.
    """
    tree = ET.parse(svg_file)
    root = tree.getroot()

    # Declared explicitly because the attributes below use literal
    # "prefix:name" keys, not "{namespace}name" keys, so ElementTree does not
    # emit these declarations by itself.
    root.set('xmlns:inkscape', INKSCAPE_NS)
    root.set('xmlns:sodipodi', SODIPODI_NS)
    root.set('inkscape:version', '1.3 (0e150ed, 2023-07-21)')

    for path in root.findall(f'{{{SVG_NS}}}path'):
        d = path.get('d', '')
        if not arc_regex.search(d):
            continue
        path.set('sodipodi:type', 'arc')
        path.set('sodipodi:arc-type', 'arc')
        for segment in parse_path(d):
            if not isinstance(segment, Arc):
                continue
            center, radius, start_angle, end_angle, _, _ = get_arc_param([segment])
            path.set('sodipodi:cx', f'{center[0]}')
            path.set('sodipodi:cy', f'{center[1]}')
            path.set('sodipodi:rx', f'{radius}')
            path.set('sodipodi:ry', f'{radius}')
            path.set('sodipodi:start', f'{start_angle}')
            path.set('sodipodi:end', f'{end_angle}')

    tree.write(svg_file, xml_declaration=True)


for image_path in tqdm(sorted(diagram_dir.glob(f"*.{extension}"))):
    # The SVG links the image by file name only, so keep a copy next to it.
    shutil.copy2(image_path, output_dir)

    diagram_name = image_path.stem
    image_name = image_path.name
    # Only the pixel size is needed; Image.open reads it without decoding.
    with Image.open(image_path) as im:
        size = im.size

    # Use the stem instead of slicing off a fixed-length extension.
    lines, circles, arcs = read_npz(npz_dir, diagram_name)
    lines = lines.reshape(-1, 2, 2)  # two (x, y) endpoints per segment
    arcs = arcs.reshape(-1, 3, 2)  # three (x, y) points per arc

    svg_file = output_dir / f"{diagram_name}.svg"
    dwg = svgwrite.Drawing(str(svg_file), profile="tiny", size=size)
    dwg.add(dwg.image(href=image_name, insert=(0, 0), size=size))
    dwg = write_svg_dwg(dwg, lines, circles, arcs, show_image=False, image=None)
    dwg.save(pretty=True)

    _mark_inkscape_arcs(svg_file)