In [13]:
import xml.etree.ElementTree as ET
import re

import matplotlib.pyplot as plt

svg_path = "data/02.svg"

# Load and parse the SVG file
tree = ET.parse(svg_path)
root = tree.getroot()

from IPython.display import SVG, display

# Display the SVG in small window
# display(SVG(filename=svg_path))

In [14]:
def extract_shapes_from_svg(root):
    # Namespace handling
    ns = {'svg': 'http://www.w3.org/2000/svg'}

    # Step 1: Extract styles from <style> tag
    style_text = root.find(".//svg:style", ns).text

    # Parse class styles
    style_dict = {}
    for match in re.finditer(r'\.(\w+)\s*{\s*fill:\s*(#[0-9a-fA-F]+);', style_text):
        class_name, fill_color = match.groups()
        style_dict[class_name] = fill_color

    # Step 2: Extract all shape elements
    shape_data = []

    for elem in root.findall(".//svg:*", ns):
        tag = elem.tag.split('}')[-1]  # strip namespace
        if tag in {"rect", "polygon"}:
            class_name = elem.attrib.get("class")
            style = style_dict.get(class_name, None)
            shape_info = {
                "tag": tag,
                "class": class_name,
                "style": style,
                "attributes": elem.attrib
            }
            shape_data.append(shape_info)

    return shape_data



def extract_styles_from_svg(root):
    """
    Extract style information from an SVG root element.
    Returns a dictionary mapping class names to fill colors.
    """
    # Namespace (required for SVG parsing)
    ns = {'svg': 'http://www.w3.org/2000/svg'}

    # Find the <style> tag
    style_element = root.find(".//svg:style", ns)
    style_text = style_element.text if style_element is not None else ""

    # Clean and extract styles from the CSS block
    style_dict = {}
    matches = re.findall(r'\.(cls-\d+)\s*{\s*fill:\s*(#[0-9a-fA-F]+);', style_text)

    for class_name, fill_color in matches:
        style_dict[class_name] = fill_color

    return style_dict

# Extract styles and shapes from the SVG
style_dict = extract_styles_from_svg(root)
shapes = extract_shapes_from_svg(root)


In [18]:
print("Style dict:")
print(style_dict)

print()

# Print shapes
print("Shapes:")
for shape in shapes:
    print(shape)

Style dict:
{'cls-1': '#60924a', 'cls-2': '#263776', 'cls-3': '#f3e0b6', 'cls-4': '#df804b'}

Shapes:
{'tag': 'rect', 'class': 'cls-2', 'style': None, 'attributes': {'class': 'cls-2', 'y': '562.14', 'width': '283.3', 'height': '171.86'}}
{'tag': 'rect', 'class': 'cls-4', 'style': None, 'attributes': {'class': 'cls-4', 'x': '.57', 'y': '391.44', 'width': '283.3', 'height': '171.86'}}
{'tag': 'rect', 'class': 'cls-3', 'style': None, 'attributes': {'class': 'cls-3', 'y': '171.11', 'width': '283.3', 'height': '222.29'}}
{'tag': 'rect', 'class': 'cls-4', 'style': None, 'attributes': {'class': 'cls-4', 'width': '283.3', 'height': '171.86'}}
{'tag': 'rect', 'class': 'cls-3', 'style': None, 'attributes': {'class': 'cls-3', 'x': '283.06', 'y': '561.66', 'width': '283.3', 'height': '171.44'}}
{'tag': 'rect', 'class': 'cls-1', 'style': None, 'attributes': {'class': 'cls-1', 'x': '283.63', 'y': '391.38', 'width': '283.3', 'height': '171.44'}}
{'tag': 'rect', 'class': 'cls-2', 'style': None, 'attri

In [21]:
import xml.etree.ElementTree as ET

def get_rect_center(attrib):
    x = float(attrib.get("x", 0))
    y = float(attrib.get("y", 0))
    w = float(attrib["width"])
    h = float(attrib["height"])
    return x + w / 2, y + h / 2

def get_polygon_centroid(points_str):
    nums = list(map(float, points_str.strip().split()))
    if len(nums) % 2 != 0:
        raise ValueError("Odd number of coordinates in polygon points.")
    
    points = list(zip(nums[::2], nums[1::2]))  # Pair x, y
    x_vals, y_vals = zip(*points)
    return sum(x_vals) / len(x_vals), sum(y_vals) / len(y_vals)

def add_number_labels_to_svg(root, shapes):
    ns = {'svg': 'http://www.w3.org/2000/svg'}
    ET.register_namespace('', ns['svg'])  # ensure output has correct namespace

    for idx, shape in enumerate(shapes):
        tag = shape['tag']
        attrib = shape['attributes']
        if tag == "rect":
            cx, cy = get_rect_center(attrib)
        elif tag == "polygon":
            cx, cy = get_polygon_centroid(attrib["points"])
        else:
            print(f"Skipping unknown shape type: {tag}")
            continue

        # Create a new <text> element
        text_elem = ET.Element("text", {
            "x": str(cx),
            "y": str(cy),
            "font-size": "24",
            "fill": "black",
            "text-anchor": "middle",
            "dominant-baseline": "central"
        })
        text_elem.text = str(idx + 1)

        root.append(text_elem)  # You could also insert after the original shape

    return root


In [22]:
# Add text labels to the SVG
updated_root = add_number_labels_to_svg(root, shapes)

# Save modified SVG
tree = ET.ElementTree(updated_root)
tree.write("labeled_output.svg", encoding="utf-8", xml_declaration=True)