In [1]:
import xml.etree.ElementTree as ET
import re

import matplotlib.pyplot as plt

svg_path = "data/01.svg"

from IPython.display import SVG, display

# Display the SVG in small window
# display(SVG(filename=svg_path))

In [13]:
def extract_shapes_from_svg(root):
    # Namespace handling
    ns = {'svg': 'http://www.w3.org/2000/svg'}

    # Step 1: Extract styles from <style> tag
    style_text = root.find(".//svg:style", ns).text

    # Parse class styles
    style_dict = {}
    for match in re.finditer(r'\.(\w+)\s*{\s*fill:\s*(#[0-9a-fA-F]+);', style_text):
        class_name, fill_color = match.groups()
        style_dict[class_name] = fill_color

    # Step 2: Extract all shape elements
    shape_data = []

    for elem in root.findall(".//svg:*", ns):
        tag = elem.tag.split('}')[-1]  # strip namespace
        if tag in {"rect", "polygon", "path", "circle"}:
            class_name = elem.attrib.get("class")
            style = style_dict.get(class_name, None)
            shape_info = {
                "tag": tag,
                "class": class_name,
                "style": style,
                "attributes": elem.attrib,
                "element": elem
                
            }
            shape_data.append(shape_info)

        elif tag in {"style", "defs"}:
            continue
        else:
            print(f"Skipping unknown shape type: {tag}")
            continue

    return shape_data



def extract_styles_from_svg(root):
    """
    Extract style information from an SVG root element.
    Returns a dictionary mapping class names to fill colors.
    """
    # Namespace (required for SVG parsing)
    ns = {'svg': 'http://www.w3.org/2000/svg'}

    # Find the <style> tag
    style_element = root.find(".//svg:style", ns)
    style_text = style_element.text if style_element is not None else ""

    # Clean and extract styles from the CSS block
    style_dict = {}
    matches = re.findall(r'\.(cls-\d+)\s*{\s*fill:\s*(#[0-9a-fA-F]+);', style_text)

    for class_name, fill_color in matches:
        style_dict[class_name] = fill_color

    return style_dict


In [68]:
import xml.etree.ElementTree as ET
from svgpathtools import parse_path, Path
import numpy as np
from svgpathtools import parse_path
from shapely.geometry import Polygon, Point, MultiPoint
import numpy as np

def get_path_interior_point(d_attr, samples=200):
    try:
        path = parse_path(d_attr)

        # Sample many points along the path
        sampled_points = [segment.point(t) for segment in path for t in np.linspace(0, 1, 10)]
        coords = [(p.real, p.imag) for p in sampled_points]

        if len(coords) < 3:
            raise ValueError("Not enough points to form a polygon.")

        # Ensure it's closed
        if coords[0] != coords[-1]:
            coords.append(coords[0])

        polygon = Polygon(coords)

        # Fix invalid polygons
        if not polygon.is_valid:
            polygon = polygon.buffer(0)

        if not polygon.is_valid or polygon.area == 0:
            raise ValueError("Polygon still invalid or zero-area after fix.")

        # Safe point inside
        pt = polygon.representative_point()
        return pt.x, pt.y, True

    except Exception as e:
        print(f"Error creating polygon from path: {e}")
        path = parse_path(d_attr)
        xmin, xmax, ymin, ymax = path.bbox()
        return (xmin + xmax) / 2, (ymin + ymax) / 2, False


def get_rect_center(attrib):
    x = float(attrib.get("x", 0))
    y = float(attrib.get("y", 0))
    w = float(attrib["width"])
    h = float(attrib["height"])
    return x + w / 2, y + h / 2

def get_polygon_centroid(points_str):
    nums = list(map(float, points_str.strip().split()))
    if len(nums) % 2 != 0:
        raise ValueError("Odd number of coordinates in polygon points.")
    
    points = list(zip(nums[::2], nums[1::2]))  # Pair x, y
    x_vals, y_vals = zip(*points)
    return sum(x_vals) / len(x_vals), sum(y_vals) / len(y_vals)

def get_circle_center(attrib):
    cx = float(attrib["cx"])
    cy = float(attrib["cy"])
    return cx, cy

def add_number_labels_to_svg(root, shapes, style_dict):
    ns = {'svg': 'http://www.w3.org/2000/svg'}
    ET.register_namespace('', ns['svg'])  # ensure output has correct namespace

    # Create style mapping from values to numbers
    style_mapping = {key: str(i) for i, key in enumerate(style_dict.keys())}

    shape_type_color_mapping = {
        "rect": "black",
        "polygon": "orange",
        "circle": "red",
        "path": "blue"
    }

    for idx, shape in enumerate(shapes):
        tag = shape['tag']
        attrib = shape['attributes']

        if tag == "rect":
            cx, cy = get_rect_center(attrib)
        elif tag == "polygon":
            cx, cy = get_polygon_centroid(attrib["points"])
        elif tag == "circle":
            cx, cy = get_circle_center(attrib)
        elif tag == "path":
            cx, cy, is_valid = get_path_interior_point(attrib["d"])
        else:
            print(f"Skipping unknown shape type: {tag}")
            continue

        text_color = shape_type_color_mapping[tag]
        if tag == "path" and not is_valid:
            text_color = "yellow"

        text_color = "black"

        # Create a new <text> element
        text_elem = ET.Element("text", {
            "x": str(cx),
            "y": str(cy),
            "font-size": "3",
            "fill": text_color,
            "text-anchor": "middle",
            "dominant-baseline": "central"
        })
        text_elem.text = style_mapping[shape['attributes']['class']]

        root.append(text_elem)  # You could also insert after the original shape

    return root

def outline_shapes_for_paint_by_numbers(shapes):
    for shape in shapes:
        elem = shape['element']

        # Remove 'class' attribute
        elem.attrib.pop('class', None)

        # Set stroke and remove fill
        elem.attrib['fill'] = 'white'
        elem.attrib['stroke'] = 'lightgray'
        elem.attrib['stroke-width'] = '0.5'

In [69]:
# Load and parse the SVG file
tree = ET.parse(svg_path)
root = tree.getroot()

# Extract styles and shapes from the SVG
style_dict = extract_styles_from_svg(root)
shapes = extract_shapes_from_svg(root)

# Modify SVG in-place
add_number_labels_to_svg(root, shapes, style_dict)
outline_shapes_for_paint_by_numbers(shapes)  # no assignment!

# Save modified SVG
tree.write("labeled_output.svg", encoding="utf-8", xml_declaration=True)

Error creating polygon from path: Polygon still invalid or zero-area after fix.
Error creating polygon from path: Polygon still invalid or zero-area after fix.
Error creating polygon from path: Polygon still invalid or zero-area after fix.
Error creating polygon from path: Polygon still invalid or zero-area after fix.
Error creating polygon from path: Polygon still invalid or zero-area after fix.


In [16]:
style_dict

{'cls-1': '#48392a',
 'cls-2': '#60924a',
 'cls-3': '#336430',
 'cls-4': '#263776',
 'cls-5': '#28488d',
 'cls-6': '#f3e0b6',
 'cls-7': '#7f7660',
 'cls-8': '#2b3c66',
 'cls-9': '#b4a795',
 'cls-10': '#191b44',
 'cls-11': '#dc9e8d',
 'cls-12': '#c54c32',
 'cls-13': '#df804b',
 'cls-14': '#e1e3e8',
 'cls-15': '#f7efde',
 'cls-16': '#74955c'}

In [17]:
for shape in shapes:
    print(shape)

{'tag': 'path', 'class': 'cls-15', 'style': None, 'attributes': {'class': 'cls-15', 'd': 'M736,0v734H0V0h736ZM278,73c-.16,1.64.13,3.35-.01,5h-.98c.45-1.77-.54-2.74-1.01-4-.84-2.21-1.95-7.01-6.01-4.51-2.51,1.54,2.22,6.01,3.01,7.51,1.23,2.33,2.42,4.03,2,7-1.82-1.07-1.55-3.04-3.46-3.09-1.53-.04-7.32,1.85-8.08,3.07-1.54,2.47.67,4.56,1.32,6.78.98,3.35-1.39,3.85,2.22,4.24-.33.65-1.5,1.33-2,2-.89-.38-2.92.04-4-1-.36-1.08-2.07-3.23-3-4-.32-.26-.55-.87-1-1-.28-.08-.66.05-1,0-6-1.85-3.12-6.33-5-11l1.99-11-4.99,8c-2.06-2.03-5.55-5.15-8-6-4.12-1.42-6.16-1.71-6.82,3.78-.72,5.94,3.1,4.05,4.82,5.22.17.12.06,2.56,2.51,3,.55.1.93-7.42,8.48-1.9.05.6-.99.81-.99.9-2.34,2.49-3.64,5.69-4,9-.97.11-2.08-.15-3,0-4.04.68-1.8,2.94-3.06,3.96-.37.3-6.81-2.89-4.94,5.04-.66.02-1.33,0-2,0l.99-3.99c-.58-.53-3.59,1.35-4,1.99.05-2.04-.62-2.59-2.53-3.02-4.45-.98-3.26,1.88-3.47,2.02-.22.14-4.43.02-5,0-1.33-.05-2.31-.91-3-1,1.86-6.03,4.26-11.2,4.88-17.63.11-1.11.92-3.59-.86-3.36-1.44,7.19-10.18,18.47-11.01,24.99-.08.65.06,

In [14]:
root.

<Element '{http://www.w3.org/2000/svg}svg' at 0x0000018CA0457510>