### commons.py

In [337]:
import numpy as np
import cv2

def is_number(n):
    is_number = True
    try:
        num = float(n)
        # check for "nan" floats
        is_number = num == num   # or use `math.isnan(num)`
    except ValueError:
        is_number = False
    return is_number

def minmaxToContours(xyxy):
    x1, y1, x2, y2 = xyxy
    return np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype=np.int32)  # Convert to contour format

def fill_contours(preprocessed_img, contours):
    """Fills the contours in the preprocessed image."""
    img_filled = preprocessed_img.copy()
    
    cv2.drawContours(img_filled, contours, -1, (255), thickness=cv2.FILLED)
    return img_filled   

def remove_contours(preprocessed_img, contours):
    """Removes the detected node shapes (circles and rectangles) from the image."""

    contours_mask = np.zeros_like(preprocessed_img)
    contours_mask = fill_contours(contours_mask, contours)
    contours_mask = cv2.bitwise_not(contours_mask)

    img_no_contours = cv2.bitwise_and(preprocessed_img, contours_mask)
    return img_no_contours

def filter_enclosed_contours(
    contours_to_filter: list[np.ndarray],
    enclosing_contours_input: list[np.ndarray],
    include_border: bool = True
) -> list[np.ndarray]:
    """
    Filters contours from contours_to_filter that are entirely enclosed by ANY contour
    in enclosing_contours_input.

    Args:
        contours_to_filter: List of contours (np.ndarray) to be filtered.
                            Each contour is typically an array of shape (N, 1, 2).
        enclosing_contours_input: List of contours (np.ndarray) that act as enclosing shapes.
                                  Each contour is typically an array of shape (M, 1, 2).
        include_border: Whether points on the border of an enclosing_contour
                        count as inside.

    Returns:
        A list of contours from contours_to_filter that are fully enclosed.
        Each contour in the returned list is a reference to an object in 
        the input contours_to_filter list.
    """
    if not contours_to_filter or not enclosing_contours_input:
        return []

    # Filter out empty enclosing contours and precompute their bounding rects
    enclosing_contours = []
    brects_enclosing = []
    for c_enclosing in enclosing_contours_input:
        # Ensure contour has points (shape[0] is the number of points)
        if c_enclosing.shape[0] > 0:
            enclosing_contours.append(c_enclosing)
            brects_enclosing.append(cv2.boundingRect(c_enclosing))

    if not enclosing_contours: # No valid enclosing contours
        return []

    result_contours = []
    # Keep track of indices of contours from contours_to_filter that have been added
    added_contour_indices = set()

    test_threshold = 0 if include_border else 1

    # Precompute bounding rects for contours_to_filter
    # (x, y, w, h)
    brects_to_filter = []
    for c in contours_to_filter:
        if c.shape[0] > 0:
            brects_to_filter.append(cv2.boundingRect(c))
        else:
            # Placeholder for empty contours, they will be skipped later
            brects_to_filter.append((0,0,0,0))


    for idx1, contour1 in enumerate(contours_to_filter):
        if idx1 in added_contour_indices:
            continue

        # An empty contour (no points) cannot be considered enclosed
        if contour1.shape[0] == 0:
            continue

        x1, y1, w1, h1 = brects_to_filter[idx1]
        is_contour1_enclosed_by_any = False

        for idx2, contour2 in enumerate(enclosing_contours):
            x2, y2, w2, h2 = brects_enclosing[idx2]

            # AABB Pruning: For contour1 to be enclosed by contour2,
            # contour1's bounding box must be within contour2's bounding box.
            if not (x1 >= x2 and \
                    y1 >= y2 and \
                    (x1 + w1) <= (x2 + w2) and \
                    (y1 + h1) <= (y2 + h2)):
                continue # Bounding box of contour1 is not contained in bounding box of contour2

            # Precise point-in-polygon test:
            # All points of contour1 must be inside (or on border of) contour2
            all_points_inside = True
            # Reshape contour1 from (N,1,2) to (N,2) for easier iteration
            points_contour1 = contour1.reshape(-1, 2)

            for pt_x, pt_y in points_contour1:
                # cv2.pointPolygonTest expects point as (float, float)
                dist = cv2.pointPolygonTest(contour2, (float(pt_x), float(pt_y)), False)
                if dist < test_threshold:
                    all_points_inside = False
                    break # This point of contour1 is outside contour2

            if all_points_inside:
                is_contour1_enclosed_by_any = True
                break # contour1 is enclosed by contour2; no need to check other enclosing_contours

        if is_contour1_enclosed_by_any:
            result_contours.append(contour1)
            added_contour_indices.add(idx1)

    return result_contours

def dilate_contour(contour, image_shape, config):
    if contour is None or len(contour) == 0:
        raise ValueError("Invalid contour provided for dilation.")
    
    dilation_kernel_size = config.get('shape_detection', {}).get('remove_nodes_dilation_kernel_size', [3, 3])
    dilation_iterations = config.get('shape_detection', {}).get('remove_nodes_dilation_iterations', 3)


    height, width = image_shape
    mask = np.zeros((height, width), dtype=np.uint8)

    cv2.drawContours(mask, [contour], -1, (255), thickness=cv2.FILLED)

    kernel = np.ones((dilation_kernel_size[0], dilation_kernel_size[1]), np.uint8)
    dilated_mask = cv2.dilate(mask, kernel, iterations=dilation_iterations)
    dilated_contours, _ = cv2.findContours(dilated_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    if dilated_contours:
        largest_dilated_contour = max(dilated_contours, key=cv2.contourArea)
        return largest_dilated_contour
    else:
        raise ValueError("No contours found after dilation.")
    
def find_closest_distance_to_contour(point_obj: Point, contour_np: np.ndarray) -> float:
    """
    Calculates the shortest distance from a Point object to a contour.
    The contour is a numpy array of points.
    - If contour has 1 point, it's point-to-point distance.
    - If contour has 2 points, it's distance to the line segment defined by these points.
    - If contour has >2 points, it's distance to the boundary of the polygon defined by these points.

    Args:
        point_obj: The Point object from which to measure the distance.
        contour_np: A NumPy array representing the contour, shape (N, 1, 2) or (N, 2).
                    Coordinates are typically integers.

    Returns:
        The shortest distance as a float.
    """

    # Validate and standardize contour_np shape for point extraction
    if contour_np.ndim == 3 and contour_np.shape[1] == 1 and contour_np.shape[2] == 2:
        # Shape (N, 1, 2), reshape to (N, 2) for easier iteration
        processed_contour_points = contour_np.reshape(-1, 2)
    elif contour_np.ndim == 2 and contour_np.shape[1] == 2:
        # Shape (N, 2), use as is
        processed_contour_points = contour_np
    else:
        raise ValueError(f"Contour numpy array has an unsupported shape: {contour_np.shape}. "
                         "Expected (N, 1, 2) or (N, 2).")

    num_contour_points = processed_contour_points.shape[0]

    if num_contour_points == 0:
        return float('inf') # No points in contour, distance is infinite

    if num_contour_points == 1:
        contour_pt_coords = processed_contour_points[0]
        contour_pt_obj = Point(contour_pt_coords[0], contour_pt_coords[1])
        return point_obj.get_distance_between_points(contour_pt_obj)

    if num_contour_points == 2:
        pt_a_coords = processed_contour_points[0]
        pt_b_coords = processed_contour_points[1]
        
        # Create Point objects for the segment endpoints
        segment_pt_a = Point(pt_a_coords[0], pt_a_coords[1])
        segment_pt_b = Point(pt_b_coords[0], pt_b_coords[1])
        
        line_segment = Line(segment_pt_a, segment_pt_b)
        return line_segment.distance_point_to_segment(point_obj)
    
    # num_contour_points > 2 (Polygon case)
    else:
        # cv2.pointPolygonTest requires contour in (N, 1, 2) format and float32 type.
        # We use the original contour_np for this, as it might already be (N,1,2).
        if contour_np.ndim == 2: # Original was (N,2)
            contour_for_cv2 = contour_np.reshape((-1, 1, 2)).astype(np.float32)
        else: # Original was (N,1,2)
            contour_for_cv2 = contour_np.astype(np.float32)
            
        # The query point for pointPolygonTest needs to be a float tuple
        query_point_tuple = (float(point_obj.x), float(point_obj.y))
        
        # measureDist=True returns signed distance:
        # The absolute value is the shortest distance to any edge of the contour.
        distance = cv2.pointPolygonTest(contour_for_cv2, query_point_tuple, True)
        return abs(distance)


### renderer.py

### recognize_text.py

In [338]:
import cv2
import numpy as np
from doctr.models import ocr_predictor
from models import Text # Import the Text class

def _geometry_to_absolute_coords(relative_geom: tuple[tuple[float, float], tuple[float, float]], 
                                 img_width: int, img_height: int) -> tuple[tuple[int, int], tuple[int, int]]:
    """Converts doctr's relative coordinates to absolute integer coordinates."""
    (xmin_rel, ymin_rel), (xmax_rel, ymax_rel) = relative_geom
    
    xmin_abs = int(xmin_rel * img_width)
    ymin_abs = int(ymin_rel * img_height)
    xmax_abs = int(xmax_rel * img_width)
    ymax_abs = int(ymax_rel * img_height)
    
    return ((xmin_abs, ymin_abs), (xmax_abs, ymax_abs))

def detect_text(img_color_resized: np.ndarray, config: dict) -> list[Text]:
    """
    Detects text using doctr and returns a list of Text objects with absolute coordinates.
    (Implementation remains the same)
    """
    predictor_params = config.get('text_detection', {})
    predictor = ocr_predictor(
        det_arch='db_resnet50', 
        reco_arch='crnn_vgg16_bn', 
        pretrained=True,
    )
    predictor.det_predictor.model.postprocessor.bin_thresh = predictor_params.get('bin_thresh', 0.3)
    predictor.det_predictor.model.postprocessor.box_thresh = predictor_params.get('box_thresh', 0.1)

    out = predictor([img_color_resized])

    detected_texts: list[Text] = []
    img_height, img_width = img_color_resized.shape[:2]

    if out.pages:
        for block in out.pages[0].blocks:
            for line in block.lines:
                for word in line.words:
                    abs_geom = _geometry_to_absolute_coords(word.geometry, img_width, img_height)
                    text_obj = Text(value=word.value, 
                                      geometry_abs=abs_geom, 
                                      confidence=word.confidence)
                    detected_texts.append(text_obj)
    
    return detected_texts

def get_img_no_text(preprocessed_img: np.ndarray, detected_texts: list[Text]) -> np.ndarray:
    """
    Removes text from the preprocessed image by finding contours within text bounding boxes
    and applying a mask using bitwise_and (similar to original notebook).

    Args:
        preprocessed_img: The thresholded image (e.g., from Otsu).
        detected_texts: List of Text objects with absolute coordinates.

    Returns:
        The image with text contours removed (blacked out).
    """
    if not detected_texts:
        ### throw an error
        raise ValueError("No detected texts to process.")

    img_contours_list, _ = cv2.findContours(preprocessed_img, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

    bbox_to_contours = [minmaxToContours([text.pt1.x, text.pt1.y, text.pt2.x, text.pt2.y])  for text in detected_texts]

    text_contours = filter_enclosed_contours(img_contours_list, bbox_to_contours, include_border=True)

    img_no_text = remove_contours(preprocessed_img, text_contours)

    return img_no_text 

# Helper function to get the center of a Text object's bounding box
def get_text_center(text_obj: Text) -> Point:
    """Calculates the center point of a Text object's bounding box."""
    center_x = (text_obj.pt1.x + text_obj.pt2.x) / 2.0
    center_y = (text_obj.pt1.y + text_obj.pt2.y) / 2.0
    return Point(center_x, center_y)


def link_text_to_elements(
    detected_text_list: list[Text],
    places_list: list[Place],
    transitions_list: list[Transition],
    arcs_list: list[Arc],
    distance_threshold: float
):
    """
    Associates Text objects with the closest Place, Transition, or Arc
    if the distance from the text's center to the element is within a given threshold.
    The association is done by appending the Text object to the `text` list
    of the corresponding Place, Transition, or Arc. 
    """

    # Clear any previous text associations from elements
    for element_list in [places_list, transitions_list, arcs_list]:
        for element in element_list:
            element.text = [] 

    for text_obj in detected_text_list:
        text_center = get_text_center(text_obj)
        
        min_overall_distance = float('inf')
        closest_element_overall = None

        # 1. Check Places
        for place in places_list:
            dist_to_place_center = text_center.get_distance_between_points(place.center)
            distance = max(0, dist_to_place_center - place.radius) # Distance to circumference
            
            if distance < min_overall_distance:
                min_overall_distance = distance
                closest_element_overall = place

        # 2. Check Transitions
        for transition in transitions_list:
            contour_to_use = None
            if transition.original_detection_data is not None and \
               isinstance(transition.original_detection_data, np.ndarray) and \
               transition.original_detection_data.shape[0] > 0:
                contour_to_use = transition.original_detection_data
            elif transition.points and len(transition.points) > 0: # Fallback to box_points
                contour_to_use = np.array([p.get_numpy_array() for p in transition.points], dtype=np.int32).reshape((-1, 1, 2))

            if contour_to_use is None or contour_to_use.shape[0] == 0:
                continue 
                
            distance = find_closest_distance_to_contour(text_center, contour_to_use)
            if distance < min_overall_distance:
                min_overall_distance = distance
                closest_element_overall = transition
        
        # 3. Check Arcs
        for arc in arcs_list:
            arc_contour_for_dist_calc = None
            # Prioritize arc.points if available, as it represents the path
            if arc.points and len(arc.points) >= 1:
                arc_contour_for_dist_calc = np.array([p.get_numpy_array() for p in arc.points], dtype=np.int32).reshape((-1, 1, 2))
            # Fallback if arc.points is empty but start/end points are defined (simple line arc)
            elif arc.start_point and arc.end_point:
                arc_contour_for_dist_calc = np.array([
                    arc.start_point.get_numpy_array(), 
                    arc.end_point.get_numpy_array()
                ], dtype=np.int32).reshape((-1, 1, 2))
            # If arc is defined by arc.lines (more complex, potentially disjoint segments)
            # This path is less common if arc.points is expected to be canonical.
            elif arc.lines:
                current_arc_min_dist_lines = float('inf')
                for line_segment in arc.lines:
                    dist_to_segment = line_segment.distance_point_to_segment(text_center)
                    current_arc_min_dist_lines = min(current_arc_min_dist_lines, dist_to_segment)
                
                if current_arc_min_dist_lines < min_overall_distance:
                    min_overall_distance = current_arc_min_dist_lines
                    closest_element_overall = arc
                continue # Skip contour-based distance if lines were processed

            if arc_contour_for_dist_calc is not None and arc_contour_for_dist_calc.shape[0] > 0:
                distance = find_closest_distance_to_contour(text_center, arc_contour_for_dist_calc)
                if distance < min_overall_distance:
                    min_overall_distance = distance
                    closest_element_overall = arc
        
        # Associate text with the overall closest element if within threshold
        if closest_element_overall is not None and min_overall_distance <= distance_threshold:
            closest_element_overall.text.append(text_obj)
            # print(f"Associated '{text_obj.value}' (center: {text_center}) with {closest_element_overall.__class__.__name__} id={id(closest_element_overall)} (dist: {min_overall_distance:.2f})")
        # else:
            # print(f"Text '{text_obj.value}' (center: {text_center}) not associated, min_dist {min_overall_distance:.2f} > threshold {distance_threshold}")


### recognize_node.py

In [339]:
import cv2
import numpy as np
import largestinteriorrectangle as lir
from constants import EPS

def get_circle_overlap(contour):
    """Checks if a contour is roughly circular based on the ratio of its area to its minimum enclosing circle."""
    (x, y), radius = cv2.minEnclosingCircle(contour)
    enclosing_area = np.pi * (radius ** 2) + EPS
    contour_area = cv2.contourArea(contour)
    
    return contour_area / enclosing_area 

def get_rectangle_overlap(contour):
    """Checks if a contour is roughly rectangular based on the ratio of its area to its minimum area bounding box."""
    rect = cv2.minAreaRect(contour)
    box_area = rect[1][0] * rect[1][1] + EPS
    contour_area = cv2.contourArea(contour)
        
    return contour_area / box_area 

def detect_shapes(preprocessed_img, config):

    circle_threshold = config.get('shape_detection', {}).get('fill_circle_enclosing_threshold', 0.8)
    rect_threshold = config.get('shape_detection', {}).get('fill_rect_enclosing_threshold', 0.95)
    
    contours_list, hierarchy = cv2.findContours(preprocessed_img, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) 

    circles = []
    rectangles = []
    for contour in contours_list:
        circle_overlap_percentage = get_circle_overlap(contour)
        rectangle_overlap_percentage = get_rectangle_overlap(contour)

        if circle_overlap_percentage > rectangle_overlap_percentage and circle_overlap_percentage > circle_threshold:
            circles.append(contour)
            
        elif rectangle_overlap_percentage > circle_overlap_percentage and rectangle_overlap_percentage > rect_threshold:
            # print(f"Rectangle detected: rectangle_overlap_percentage: {rectangle_overlap_percentage} circle_overlap_percentage: {circle_overlap_percentage}")
            rectangles.append(contour)

    return circles, rectangles

def get_nodes_mask(img_empty_nodes_filled, config):
    """
    Isolates node structures using an iterative erosion/dilation heuristic based on contour count stability.
    """
    # Default values, will be overridden by config if available
    erosion_kernel_size = tuple(config.get('shape_detection', {}).get('erosion_kernel_size', [3, 3]))
    min_stable_length = config.get('shape_detection', {}).get('min_stable_length', 3)
    max_erosion_iterations = config.get('shape_detection', {}).get('max_erosion_iterations', 30)
    verbose = config.get('shape_detection', {}).get('verbose', True) # Control printing

    erosion_kernel = np.ones(erosion_kernel_size, np.uint8)
    contour_counts_history = []
    optimal_erosion_iterations = 0  # Default if loop doesn't run or no erosions found necessary
    optimal_condition_found = False # Flag to indicate if stability or zero contours was met

    # This image is progressively eroded to find the optimal number of iterations
    image_for_iterative_erosion = img_empty_nodes_filled.copy()

    # Loop to determine the optimal number of erosion iterations
    # If max_erosion_iterations is 0, this loop won't execute, and optimal_erosion_iterations will remain 0.
    for current_iteration in range(1, max_erosion_iterations + 1):
        # Perform one erosion step
        eroded_this_step = cv2.erode(image_for_iterative_erosion, erosion_kernel, iterations=1)
        contours, _ = cv2.findContours(eroded_this_step, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
        num_contours_at_step = len(contours)
        contour_counts_history.append(num_contours_at_step)
        
        image_for_iterative_erosion = eroded_this_step # Update for the next iteration

        # Check for stability in contour count
        if len(contour_counts_history) >= min_stable_length:
            last_n_counts = contour_counts_history[-min_stable_length:]
            if all(c == last_n_counts[0] for c in last_n_counts):
                # Optimal iterations = iteration count at the start of the stable sequence
                optimal_erosion_iterations = current_iteration - min_stable_length + 1
                if verbose:
                    print(f"Stability detected: Contour count {last_n_counts[0]} stable for {min_stable_length} iterations.")
                    print(f"Optimal number of erosions determined as: {optimal_erosion_iterations}.")
                optimal_condition_found = True
                break 

        # Check if all contours have disappeared
        if num_contours_at_step == 0:
            if not optimal_condition_found: # Only set if stability wasn't the primary reason
                optimal_erosion_iterations = current_iteration # All contours gone after this many erosions
                if verbose:
                    print(f"All contours disappeared after {current_iteration} erosions.")
                    print(f"Optimal number of erosions determined as: {optimal_erosion_iterations}.")
            optimal_condition_found = True # This is a definitive condition to stop
            break
    # Loop ends

    # If the loop completed fully (max_erosion_iterations reached) without finding stability or zero contours
    if not optimal_condition_found and max_erosion_iterations > 0:
        optimal_erosion_iterations = max_erosion_iterations
        if verbose:
            print(f"Max erosions ({max_erosion_iterations}) reached without specific stability/zero-contour condition. "
                  f"Using {optimal_erosion_iterations} erosions.")

    # Obtain the node mask by applying the optimal number of erosions to the original filled image
    if optimal_erosion_iterations > 0:
        node_mask_eroded = cv2.erode(img_empty_nodes_filled, erosion_kernel, iterations=optimal_erosion_iterations)
        if verbose:
            print(f"Applied {optimal_erosion_iterations} erosions to input image to get the node mask.")
    else:
        # If no erosions are optimal, return a copy of the input to maintain consistency (always a new image object)
        node_mask_eroded = img_empty_nodes_filled.copy()
        if verbose:
            print("No erosions applied for node mask (0 optimal erosions). Using a copy of input.")

    # Dilate the eroded node mask to recover node sizes
    if optimal_erosion_iterations > 0:
        # Dilate by the same number of steps and with the same kernel
        dilated_node_mask = cv2.dilate(node_mask_eroded, erosion_kernel, iterations=optimal_erosion_iterations)
        if verbose:
            print(f"Applied {optimal_erosion_iterations} dilations to recover node sizes.")
    else:
        # If no erosions were done, no dilations are needed either.
        # node_mask_eroded is already a copy of the original (or the optimally eroded one if erosions > 0).
        dilated_node_mask = node_mask_eroded 
        if verbose:
            print("No dilations applied (0 optimal erosions).")

    if verbose: # Print history only if verbose mode is on
        print(f"Contour counts per erosion iteration: {contour_counts_history}")

    return dilated_node_mask

### recognize_arrow.py

In [340]:
import base64
import requests
import cv2
import numpy as np
import supervision as sv

# Function to load configuration (can be moved to a central utils.py if used elsewhere)
# For now, keep it simple. The main caller (notebook) will load and pass the config.

def detect_arrowheads(
    image: np.ndarray,
    config: dict # Expects the full loaded YAML config
    # image_path: str = None # This was unused and can be removed if image is always passed as np.ndarray
) -> dict:
    """
    Detects objects (arrowheads) in an image using the Roboflow API.
    Configuration for the API (project_id, version, api_key, confidence)
    is expected to be in the passed config dictionary.
    """

    api_config = config.get('connection_processing', {}).get('arrowhead_api', {})
    project_id = api_config.get('project_id')
    version = api_config.get('version')
    api_key = api_config.get('api_key')
    # Roboflow API expects confidence as a percentage (0-100)
    confidence = api_config.get('confidence_threshold_percent', 10.0) # Default if not found

    if not all([project_id, version, api_key]):
        raise ValueError("Missing Roboflow API configuration (project_id, version, or api_key) in config.")
    
    if api_key == "YOUR_API_KEY":
        # It's good practice to warn or error if the placeholder API key is still there.
        # For now, let's raise an error to prevent accidental calls with a dummy key.
        raise ValueError("Roboflow API key is set to 'YOUR_API_KEY'. Please update it in your config.yaml.")

    # Encode the image
    # The original code had a commented-out section for reading from image_path.
    # Sticking to encoding the provided numpy array.
    success, encoded_image_bytes = cv2.imencode(".png", image) # Using .png as it's lossless; .jpg was also an option
    if not success:
        raise ValueError("Could not encode image to PNG format.")
    
    # Base64-encode the image bytes
    b64_encoded_image = base64.b64encode(encoded_image_bytes.tobytes()).decode("utf-8")

    # Build the request URL with query parameters
    # Note: The confidence parameter in the URL is the threshold.
    url = (
        f"https://detect.roboflow.com/{project_id}/{version}"
        f"?api_key={api_key}"
        f"&confidence={confidence}"  # This should be the percentage value
        "&format=json"
        # Consider adding other parameters like overlap, stroke, labels if needed,
        # and managing them via config.
    )

    # Send the POST request with the base64-encoded image
    headers = {"Content-Type": "application/x-www-form-urlencoded"} # Roboflow expects this for base64 data
    response = requests.post(url, data=b64_encoded_image, headers=headers)
    response.raise_for_status() # Raises an HTTPError for bad responses (4xx or 5xx)

    return response.json()


def show_arrows(img, arrowhead_result):
    img_drawn = img.copy()
    detections = sv.Detections.from_inference(arrowhead_result)

    # create supervision annotators
    bounding_box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator()

    # annotate the image with our inference results
    annotated_image = bounding_box_annotator.annotate(
        scene=img_drawn, detections=detections)

    sv.plot_image(annotated_image)

### recognize_arc.py

In [341]:
import math 
from skimage.morphology import skeletonize
import copy 

def get_hough_lines(img, config):
    """
    Detects lines in the image using Hough Transform.
    Returns the detected lines.
    """
    # Default values, will be overridden by config if available
    rho = config.get('connection_processing', {}).get('hough_rho', 1)
    theta = config.get('connection_processing', {}).get('hough_theta', np.pi / 180)
    threshold = config.get('connection_processing', {}).get('hough_threshold', 10)
    min_line_length = config.get('connection_processing', {}).get('min_line_length', 10)
    max_line_gap = config.get('connection_processing', {}).get('max_line_gap', 20)
    min_line_length = max(min_line_length, 1)  # Ensure it's at least 1
    

        # Skeletonize the image
    skeleton = skeletonize(img / 255).astype(np.uint8)*255
    hough_lines = cv2.HoughLinesP(skeleton, rho, np.pi/180, threshold, minLineLength=min_line_length, maxLineGap=max_line_gap)
    return hough_lines

class HoughBundler:     
    def __init__(self,min_distance=5,min_angle=2):
        self.min_distance = min_distance
        self.min_angle = min_angle
    
    def get_orientation(self, line):
        orientation = math.atan2(abs((line[3] - line[1])), abs((line[2] - line[0])))
        return math.degrees(orientation)

    def check_is_line_different(self, line_1, groups, min_distance_to_merge, min_angle_to_merge):
        for group in groups:
            for line_2 in group:
                if self.get_distance(line_2, line_1) < min_distance_to_merge:
                    orientation_1 = self.get_orientation(line_1)
                    orientation_2 = self.get_orientation(line_2)
                    if abs(orientation_1 - orientation_2) < min_angle_to_merge:
                        group.append(line_1)
                        return False
        return True

    def distance_point_to_line(self, point, line):
        px, py = point
        x1, y1, x2, y2 = line

        def line_magnitude(x1, y1, x2, y2):
            line_magnitude = math.sqrt(math.pow((x2 - x1), 2) + math.pow((y2 - y1), 2))
            return line_magnitude

        lmag = line_magnitude(x1, y1, x2, y2)
        if lmag < 0.00000001:
            distance_point_to_line = 9999
            return distance_point_to_line

        u1 = (((px - x1) * (x2 - x1)) + ((py - y1) * (y2 - y1)))
        u = u1 / (lmag * lmag)

        if (u < 0.00001) or (u > 1):
            #// closest point does not fall within the line segment, take the shorter distance
            #// to an endpoint
            ix = line_magnitude(px, py, x1, y1)
            iy = line_magnitude(px, py, x2, y2)
            if ix > iy:
                distance_point_to_line = iy
            else:
                distance_point_to_line = ix
        else:
            # Intersecting point is on the line, use the formula
            ix = x1 + u * (x2 - x1)
            iy = y1 + u * (y2 - y1)
            distance_point_to_line = line_magnitude(px, py, ix, iy)

        return distance_point_to_line

    def get_distance(self, a_line, b_line):
        dist1 = self.distance_point_to_line(a_line[:2], b_line)
        dist2 = self.distance_point_to_line(a_line[2:], b_line)
        dist3 = self.distance_point_to_line(b_line[:2], a_line)
        dist4 = self.distance_point_to_line(b_line[2:], a_line)

        return min(dist1, dist2, dist3, dist4)

    def merge_lines_into_groups(self, lines):
        groups = []  # all lines groups are here
        # first line will create new group every time
        groups.append([lines[0]])
        # if line is different from existing gropus, create a new group
        for line_new in lines[1:]:
            if self.check_is_line_different(line_new, groups, self.min_distance, self.min_angle):
                groups.append([line_new])

        return groups

    def merge_line_segments(self, lines):
        orientation = self.get_orientation(lines[0])
      
        if(len(lines) == 1):
            return np.block([[lines[0][:2], lines[0][2:]]])

        points = []
        for line in lines:
            points.append(line[:2])
            points.append(line[2:])
        if 45 < orientation <= 90:
            #sort by y
            points = sorted(points, key=lambda point: point[1])
        else:
            #sort by x
            points = sorted(points, key=lambda point: point[0])

        return np.block([[points[0],points[-1]]])

    def process_lines(self, lines):
        lines_horizontal  = []
        lines_vertical  = []
  
        for line_i in [l[0] for l in lines]:
            orientation = self.get_orientation(line_i)
            # if vertical
            if 45 < orientation <= 90:
                lines_vertical.append(line_i)
            else:
                lines_horizontal.append(line_i)

        lines_vertical  = sorted(lines_vertical , key=lambda line: line[1])
        lines_horizontal  = sorted(lines_horizontal , key=lambda line: line[0])
        merged_lines_all = []

        # for each cluster in vertical and horizantal lines leave only one line
        for i in [lines_horizontal, lines_vertical]:
            if len(i) > 0:
                groups = self.merge_lines_into_groups(i)
                merged_lines = []
                for group in groups:
                    merged_lines.append(self.merge_line_segments(group))
                merged_lines_all.extend(merged_lines)
                    
        return np.asarray(merged_lines_all)
    
def assign_proximity_nodes(
    original_lines: list[Line], 
    original_places: list[Place], 
    original_transitions: list[Transition], 
    config: dict,
) -> tuple[list[Line], list[Place], list[Transition]]:
    """
    Assigns a 'proximity_node' attribute to points of lines if they are close
    to a place or transition. Operates on deep copies of the input objects.

    Args:
        original_lines: A list of Line objects.
        original_places: A list of Place objects.
        original_transitions: A list of Transition objects.
        proximity_threshold: A factor to expand node boundaries for proximity checks.
                           For Places, it scales the radius.
                           For Transitions, it scales the height and width.

    Returns:
        A tuple containing:
        - lines_copy: Copied lines, where points may have a 'proximity_node' attribute.
        - places_copy: Deep copies of the original places.
        - transitions_copy: Deep copies of the original transitions.
        The 'proximity_node' attributes will refer to objects within places_copy or transitions_copy.
    """
    proximity_thres_place = config.get('connection_processing', {}).get('proximity_thres_place', 1.5)
    proximity_thres_trans_width = config.get('connection_processing', {}).get('proximity_thres_trans_width', 3)
    proximity_thres_trans_height = config.get('connection_processing', {}).get('proximity_thres_trans_height', 1.2)

    # 1. Create deep copies of all input object lists
    lines_copy = copy.deepcopy(original_lines)
    places_copy = copy.deepcopy(original_places)
    transitions_copy = copy.deepcopy(original_transitions)

    all_copied_node_centers = [node.center for node in places_copy] + \
                              [node.center for node in transitions_copy]

    # 3. Iterate through copied lines and their points
    for line in lines_copy:
        for line_point in [line.point1, line.point2]:
            for node_center_copy in all_copied_node_centers:
                node_copy = node_center_copy.part_of 
                # print(f"Checking proximity for line point {line_point} to node {node_copy}")

                if isinstance(node_copy, Place):
                    # print(f"Node {node_copy} is a place")
                    distance = line_point.get_distance_between_points(node_center_copy)
                    if distance < proximity_thres_place * node_copy.radius:
                        line_point.proximity_node = node_copy
                
                elif isinstance(node_copy, Transition):
                    # print(f"Node {node_copy} is a transition")
                    expanded_height = node_copy.height * proximity_thres_trans_height
                    expanded_width = node_copy.width * proximity_thres_trans_width
                    
                    
                    expanded_bbox_contour = cv2.boxPoints(((float(node_center_copy.x), float(node_center_copy.y)),
                                                            (expanded_height, expanded_width), node_copy.angle))
                    current_line_point_coords = (float(line_point.x), float(line_point.y))
                    
                    if cv2.pointPolygonTest(expanded_bbox_contour, current_line_point_coords, False) >= 0:
                        line_point.proximity_node = node_copy

                else:
                    print(f"Node {node_copy} is not a recognized type")
                    # This case should ideally not be reached if inputs are as expected.
                    raise ValueError(f"Unknown node type encountered: {type(node_copy)}")
    
    return lines_copy, places_copy, transitions_copy

def get_entry_points_from_lines(lines_list):
    """
    Original function provided by user, slightly adapted to use a local list.
    Extracts all unique points marked as 'is_entry' from a list of lines.
    """
    entry_points_set = set()
    for line in lines_list:
        if hasattr(line.point1, "proximity_node") and line.point1.proximity_node:
            entry_points_set.add(line.point1)
        if hasattr(line.point2, "proximity_node") and line.point2.proximity_node:
            entry_points_set.add(line.point2)
    return list(entry_points_set)


def cosine_similarity(vec1_norm: np.ndarray, vec2_norm: np.ndarray) -> float:
    """Computes the cosine similarity (dot product of normalized vectors)."""
    return np.dot(vec1_norm, vec2_norm)

def find_line_paths(
    initial_lines_list: list[Line],
    proximity_threshold: float = 30.0,
    dot_product_weight: float = 0.6,
    distance_to_line_weight: float = 0.2,
    endpoint_distance_weight: float = 0.2
) -> list[dict]:
    """
    Connects lines from a pool into paths, starting from an entry point
    and ending at another entry point.

    Args:
        initial_lines_list: A list of Line objects.
        proximity_threshold: Maximum distance to search for next point.
        dot_product_weight: Weight for vector alignment score.
        distance_to_line_weight: Weight for point-to-line distance score.
        endpoint_distance_weight: Weight for endpoint-to-endpoint distance score.

    Returns:
        A list of paths. Each path is a dictionary with 'lines' (list of Line)
        and 'points' (ordered list of Point forming the path).
    """
    lines_pool = set(initial_lines_list) # Use a set for efficient removal (O(1) on average)
    all_paths_found = []
    
    # Keep track of entry points that have successfully started a path to avoid re-processing
    # or entry points that have been used as an end of a path.
    consumed_entry_points = set()

    while True:
        current_start_line = None
        current_start_entry_point = None

        # Find a new starting line with an available entry point
        # Iterate over a temporary list as lines_pool can be modified
        for line in list(lines_pool):
            potential_start_points = []
            if hasattr(line.point1, "proximity_node") and line.point1.proximity_node and line.point1 not in consumed_entry_points:
                potential_start_points.append(line.point1)
            if hasattr(line.point2, "proximity_node") and line.point2.proximity_node and line.point2 not in consumed_entry_points:
                potential_start_points.append(line.point2)
            
            if potential_start_points:
                current_start_line = line
                # Prefer point1 if both are entries and available, or just take the first one.
                current_start_entry_point = potential_start_points[0]
                break
        
        if not current_start_line:
            break # No more available entry points or lines to start a path

        current_path_lines = [current_start_line]
        current_path_points = [current_start_entry_point]
        
        lines_pool.remove(current_start_line)
        consumed_entry_points.add(current_start_entry_point) # Mark this entry point as used for path initiation

        last_line_in_path = current_start_line
        # The current tip of the path is the other point of the start_line
        current_tip_of_path = last_line_in_path.get_other_point(current_start_entry_point)
        current_path_points.append(current_tip_of_path)

        # Inner loop to extend the current path
        while True:
            # Check if the current_tip_of_path is a destination entry point
            if hasattr(current_tip_of_path, "proximity_node") and current_tip_of_path.proximity_node:
                all_paths_found.append({"lines": list(current_path_lines), "points": list(current_path_points)})
                consumed_entry_points.add(current_tip_of_path) # Mark end entry point
                break # Path successfully found, break from inner loop

            candidate_extensions = []
            # Vector of the last segment, oriented towards the current tip
            vec_last_segment_norm = last_line_in_path.get_normalized_vector(
                start_point=last_line_in_path.get_other_point(current_tip_of_path),
                end_point=current_tip_of_path
            )

            for candidate_line in list(lines_pool): # Iterate over a copy of the pool for safe removal
                for point_on_candidate in [candidate_line.point1, candidate_line.point2]:
                    # Must not connect via an intermediate entry point
                    if hasattr(point_on_candidate, "proximity_node") and point_on_candidate.proximity_node:
                        continue

                    endpoint_dist = current_tip_of_path.get_distance_between_points(point_on_candidate)

                    if endpoint_dist <= proximity_threshold:
                        # Scoring Criterion 1: Dot product of normalized vectors
                        # Vector of candidate_line, oriented away from point_on_candidate
                        vec_candidate_norm = candidate_line.get_normalized_vector(
                            start_point=point_on_candidate,
                            end_point=candidate_line.get_other_point(point_on_candidate)
                        )
                        dot_prod_score = (cosine_similarity(vec_last_segment_norm, vec_candidate_norm) + 1) / 2 # Scale to [0,1]

                        # Scoring Criterion 2: Start point of "to be merged" line is close to the infinite line
                        # formed by our last merged line.
                        dist_to_prev_line_inf = last_line_in_path.distance_point_to_infinite_line(point_on_candidate)
                        # Score: higher is better (closer to 0 distance)
                        # Avoid division by zero; add 1. Max possible distance could normalize this.
                        # For simplicity, using 1 / (1 + dist).
                        dist_line_score = 1.0 / (1.0 + dist_to_prev_line_inf) if proximity_threshold > 0 else 1.0


                        # Bonus: endpoint_distance score (closer is better)
                        endpoint_dist_score = (proximity_threshold - endpoint_dist) / proximity_threshold \
                                              if proximity_threshold > 0 else 1.0
                        
                        # Combined score
                        total_score = (dot_product_weight * dot_prod_score +
                                       distance_to_line_weight * dist_line_score +
                                       endpoint_distance_weight * endpoint_dist_score)
                        
                        candidate_extensions.append({
                            "line": candidate_line,
                            "connection_point_on_candidate": point_on_candidate,
                            "score": total_score
                        })
            
            if not candidate_extensions:
                # No suitable extension found, path terminates here (not at an entry point).
                # This path is considered "noise" or incomplete.
                break # Break from inner loop

            # Select the best candidate extension
            candidate_extensions.sort(key=lambda x: x["score"], reverse=True)
            best_extension = candidate_extensions[0]

            # Add best extension to the current path
            lines_pool.remove(best_extension["line"]) # Remove from available lines
            current_path_lines.append(best_extension["line"])
            
            last_line_in_path = best_extension["line"]
            # The connection point on the candidate becomes part of the path
            current_path_points.append(best_extension["connection_point_on_candidate"])
            # The new tip is the other end of the newly added line
            current_tip_of_path = last_line_in_path.get_other_point(best_extension["connection_point_on_candidate"])
            current_path_points.append(current_tip_of_path)
            # Continue extending this path

    return all_paths_found

def assign_arrowheads(found_paths_original: list[Line], arrowhead_result: dict, config) -> tuple[list[Line], int]:
    arrowhead_proximity_thres = config.get('connection_processing', {}).get('arrowhead_proximity_threshold', 30)

    found_paths_copy = copy.deepcopy(found_paths_original)
    path_endpoints = []
    for path in found_paths_copy:
        if path["points"]: 
            path_endpoints.extend([path["points"][0], path["points"][-1]])
        else:
            raise ValueError("Path points list is empty. Cannot assign arrowheads.")
        
    rejected_arrowhead_count = 0

    for arrowhead in arrowhead_result["predictions"]:
        arrowhead_center = Point(arrowhead["x"], arrowhead["y"])
        
        closest_point = None
        closest_distance = float("inf")
        for endpoint in path_endpoints:
            distance = arrowhead_center.get_distance_between_points(endpoint)
            if distance < closest_distance and distance < arrowhead_proximity_thres:
                closest_distance = distance
                closest_point = endpoint
                ### remove endpoint from endpoints to avoid reusing
                path_endpoints.remove(endpoint)
        
        ### check if the closest point is None and throw error
        if closest_point is None:
            print("No closest point found for the arrowhead center.")
            rejected_arrowhead_count += 1
        else: 
            closest_point.is_arrow = True

    return found_paths_copy, rejected_arrowhead_count

def get_arcs(paths):
    """
    Links the nodes of the paths based on the proximity_node attribute of the points.
    This function assumes that the paths are already processed and contain points with proximity_node.
    """

    arcs = []

    for path in paths:
        if not path["points"][0].proximity_node or not path["points"][-1].proximity_node:
            raise ValueError("Path must start and end with a proximity node.")
        if len(path["points"]) < 2:
            raise ValueError("Path must contain at least two points.")
        if len(path["lines"]) < 1:
            raise ValueError("Path must contain at least one line.")
        # Assuming a path of N points connected sequentially has N-1 lines.
        if len(path["points"]) != len(path["lines"]) * 2:
             raise ValueError("Path points and lines are inconsistent.")

        start_point = path["points"][0]
        end_point = path["points"][-1]

        # Add arc from start to end unless start is an arrow and end is not
        if not (start_point.is_arrow and not end_point.is_arrow):
            arcs.append(Arc(
                source=start_point.proximity_node,
                target=end_point.proximity_node,
                start_point=start_point,
                end_point=end_point,
                points=path["points"],
                lines=path["lines"]
            ))

        # Add arc from end to start if start is an arrow
        if start_point.is_arrow:
             arcs.append(Arc(
                source=end_point.proximity_node,
                target=start_point.proximity_node,
                start_point=end_point,
                end_point=start_point,
                points=path["points"],
                lines=path["lines"]
            ))

    return arcs


### data_loading.py

In [342]:
import cv2
import numpy as np
import os

def preprocess(img):
    """Applies Otsu's thresholding to the input image."""
    # Ensure input is grayscale if it's not already
    gray_img = img
    if len(img.shape) == 3 and img.shape[2] == 3:
        gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    
    _, thresh_otsu = cv2.threshold(gray_img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return thresh_otsu

def load_and_preprocess_image(image_path: str, config: dict):
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found at: {image_path}")

    img_color = cv2.imread(image_path)
    if img_color is None:
        raise ValueError(f"Could not read image file: {image_path}")

    img_gray = cv2.cvtColor(img_color, cv2.COLOR_BGR2GRAY)

    # --- Upscaling Heuristic ---
    cfg_proc = config.get('image_processing', {})
    min_dimension_threshold = cfg_proc.get('min_dimension_threshold', 800)
    upscale_factor = cfg_proc.get('upscale_factor', 2)

    h, w = img_gray.shape
    img_color_resized = img_color
    img_gray_resized = img_gray

    if h < min_dimension_threshold or w < min_dimension_threshold:
        print(f"Image dimensions ({w}x{h}) below threshold ({min_dimension_threshold}px). Upscaling by {upscale_factor}x.")
        new_w, new_h = w * upscale_factor, h * upscale_factor
        img_gray_resized = cv2.resize(img_gray, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
        img_color_resized = cv2.resize(img_color, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)

    # --- Initial Preprocessing (Inversion + Thresholding) ---
    img_inverted = cv2.bitwise_not(img_gray_resized)
    preprocessed_img = preprocess(img_inverted) 

    return preprocessed_img, img_color_resized, img_gray_resized 

### models.py

In [343]:
import math
import numpy as np
import largestinteriorrectangle as lir
import cv2

# Epsilon for floating point comparisons if needed, though not used in current definitions
EPS = 1e-6 

class Point:
    def __init__(self, x, y):
        self.x = int(x) # Ensure integer coordinates if they represent pixels
        self.y = int(y)

        self.proximity_node = None # Placeholder for proximity node assignment
        self.is_arrow = False # Placeholder for entry point assignment

    def get_distance_between_points(self, other_point):
        """Calculate Euclidean distance between this point and another point."""
        return math.sqrt((self.x - other_point.x) ** 2 + (self.y - other_point.y) ** 2)
    
    def is_inside_contour(self, contour):
        """Check if this point is inside a given contour using cv2.pointPolygonTest"""
        # Note: This requires cv2, which might be better placed in a different module
        point_tuple = (float(self.x), float(self.y)) # pointPolygonTest needs float tuple
        # Ensure contour is in the correct format (e.g., Nx1x2 or Nx2)
        try:
            # >= 0 means inside or on the boundary
            return cv2.pointPolygonTest(contour, point_tuple, False) >= 0 
        except NameError:
            print("Warning: cv2 not imported. is_inside_contour cannot function.")
            return False
        except Exception as e:
            print(f"Error during pointPolygonTest: {e}")
            return False
        
    def get_numpy_array(self):
        """Returns the point as a numpy array."""
        return np.array([self.x, self.y], dtype=np.int32)

    def __repr__(self):
        return f"Point({self.x}, {self.y})"

    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return self.x == other.x and self.y == other.y

    def __hash__(self):
        """Allows Point objects to be added to sets or used as dictionary keys."""
        return hash((self.x, self.y))

class Line:
    def __init__(self, start_point: Point, end_point: Point, angle=None, length=None):
        """
        Initializes a Line object.
        If angle and length are not provided, they are calculated.
        """
        self.point1 = start_point
        self.point2 = end_point

        # Assign self to the points for back-reference if needed later
        # self.point1.part_of = self 
        # self.point2.part_of = self

        if angle is None or length is None:
            dx = self.point2.x - self.point1.x
            dy = self.point2.y - self.point1.y
            # Calculate angle in degrees
            self.angle = math.degrees(math.atan2(dy, dx)) if not (dx == 0 and dy == 0) else 0.0
            # Calculate length
            self.length = self.point1.get_distance_between_points(self.point2)
        else:
            self.angle = angle
            self.length = length

    def get_other_point(self, point: Point) -> Point:
        """Given one point of the line, returns the other point."""
        if point == self.point1:
            return self.point2
        elif point == self.point2:
            return self.point1
        else:
            # This case should ideally not be reached if logic is correct
            raise ValueError("Point is not part of this line.")

    def get_vector(self, start_point: Point = None, end_point: Point = None) -> np.ndarray:
        """
        Returns the vector of the line.
        If start_point and end_point are provided, computes vector from start to end.
        Otherwise, defaults to point1 -> point2.
        """
        if start_point and end_point:
            return np.array([end_point.x - start_point.x, end_point.y - start_point.y])
        return np.array([self.point2.x - self.point1.x, self.point2.y - self.point1.y])

    def get_normalized_vector(self, start_point: Point = None, end_point: Point = None) -> np.ndarray:
        """Returns the normalized (unit) vector of the line."""
        vec = self.get_vector(start_point, end_point)
        norm = np.linalg.norm(vec)
        if norm == 0:
            return np.array([0, 0]) # Represents a zero-length line segment
        return vec / norm

    def distance_point_to_infinite_line(self, point: Point) -> float:
        """
        Calculates the perpendicular distance from a point to the infinite line
        defined by this line segment.
        """
        p1_np = np.array([self.point1.x, self.point1.y])
        p2_np = np.array([self.point2.x, self.point2.y])
        p3_np = np.array([point.x, point.y])

        if np.array_equal(p1_np, p2_np): # If the line is just a point
            return np.linalg.norm(p3_np - p1_np)

        numerator = np.abs(np.cross(p2_np - p1_np, p1_np - p3_np))
        denominator = np.linalg.norm(p2_np - p1_np)
        if denominator == 0:
            return np.linalg.norm(p3_np - p1_np) # Distance to the single point
        return numerator / denominator
    
    def distance_point_to_segment(self, point: Point) -> float:
        """
        Calculates the shortest distance from a query point to this line segment.
        """
        # Convert query point and segment endpoints to numpy arrays
        p_np = point.get_numpy_array().astype(float)
        a_np = self.point1.get_numpy_array().astype(float) # Segment start (self.point1)
        b_np = self.point2.get_numpy_array().astype(float) # Segment end (self.point2)

        # If the segment is essentially a point (point1 and point2 are the same)
        if self.point1 == self.point2: # Relies on Point.__eq__
            return point.get_distance_between_points(self.point1)

        # Vector from A to B (segment vector)
        vec_ab = b_np - a_np
        # Vector from A to P (point relative to segment start)
        vec_ap = p_np - a_np

        t = np.dot(vec_ap, vec_ab) / np.dot(vec_ab, vec_ab)

        if 0.0 <= t <= 1.0:
            # The projection falls on the segment AB.
            # The shortest distance is the perpendicular distance from P to the line AB.
            # This can be calculated by self.distance_point_to_infinite_line(point).
            return self.distance_point_to_infinite_line(point)
        elif t < 0.0:
            # The projection falls outside the segment, on the side of A.
            # The closest point on the segment to P is A (self.point1).
            return point.get_distance_between_points(self.point1)
        else: # t > 1.0
            return point.get_distance_between_points(self.point2)

    def __repr__(self):
        return f"Line(start={self.point1}, end={self.point2}, angle={self.angle:.2f}, length={self.length:.2f})"

    def __eq__(self, other):
        if not isinstance(other, Line):
            return NotImplemented
        # A line is considered equal if its endpoints are the same, regardless of order.
        return (self.point1 == other.point1 and self.point2 == other.point2) or \
               (self.point1 == other.point2 and self.point2 == other.point1)

    def __hash__(self):
        """Allows Line objects to be added to sets. The hash is order-invariant for points."""
        # Hash the tuple of sorted point hashes
        return hash(tuple(sorted((hash(self.point1), hash(self.point2)))))


#####################################################################
#####################################################################
class Place:
    def __init__(
        self,
        circle: tuple[int, int, int], # (x, y, radius)
        original_detection_data=None, # Placeholder for any original detection data
    ):
        self.center = Point(circle[0], circle[1])
        self.radius = circle[2]
        self.center.part_of = self # Link back to the Place object

        self.text = [] # Placeholder for any text associated with this place
        self.original_detection_data = original_detection_data 

        self.markers = 0 # Placeholder for markers associated with this place

    @classmethod
    def from_contour(cls, contour: np.ndarray):
        (x, y), radius = cv2.minEnclosingCircle(contour)
        return cls((x, y, radius), original_detection_data= contour)
    
    def update_markers_from_text(self):
        """
        Recalculates and updates self.markers by summing numeric values
        from associated Text objects in self.text.
        Only text values that consist purely of digits after stripping whitespace
        are considered numeric.
        """
        current_sum_of_markers = 0
        for text_obj in self.text: # self.text is a list of Text objects
            value_str = text_obj.value.strip()
            if is_number(value_str):
                try:
                    num_val = float(value_str)
                    # Check for infinity, as int(inf) raises OverflowError
                    if num_val != float('inf') and num_val != float('-inf'):
                        current_sum_of_markers += int(num_val)
                    # else: print(f"Info: Skipped infinite value '{value_str}' for markers.") # Optional logging
                except ValueError:
                    pass 
        self.markers = current_sum_of_markers
        

    def __repr__(self):
        return f"Place(center={self.center}, radius={self.radius})"

class Transition:
    def __init__(
        self,
        center_coords: tuple[int, int], # (x, y)
        height: int,
        width: int,
        angle: float = 0.0, # Default angle
        original_detection_data=None, 
    ):
        self.center = Point(center_coords[0], center_coords[1])
        self.center.part_of = self

        self.height = height
        self.width = width
        self.angle = angle # Angle in degrees

        self.box_points = cv2.boxPoints(((self.center.x, self.center.y), (self.height, self.width), angle))

        self.points = [Point(int(pt[0]), int(pt[1])) for pt in self.box_points]
        for point in self.points:
            point.part_of = self

        self.text = [] 

        self.original_detection_data = original_detection_data 

    @classmethod
    def from_contour(cls, contour: np.ndarray):
        min_area_rect = cv2.minAreaRect(contour)
        return cls(min_area_rect[0], min_area_rect[1][0], min_area_rect[1][1], min_area_rect[2], original_detection_data=contour)
    
    def __repr__(self):
        return f"Transition(center={self.center}, height={self.height}, width={self.width}, angle={self.angle})"

### Potentially add an Arc class later if needed to represent the final connections
class Arc:
    def __init__(self, source, target, start_point, end_point, points=None, lines=None):
        self.source = source # Place or Transition object
        self.target = target # Place or Transition object
        self.start_point = start_point # Point object
        self.end_point = end_point # Point object
        self.points = points # Optional: Ordered list of points forming the arc geometry
        self.lines = lines   # Optional: List of Line segments forming the arc geometry

        self.text = [] # Placeholder for any text associated with this place

    def __repr__(self):
        return f"Arc(source={self.source}, target={self.target})"
    
    def __eq__(self, other):
        if not isinstance(other, Arc):
            return NotImplemented
        return (self.source == other.source and self.target == other.target)

class Text:
    """Represents a detected text element with its content and bounding box."""
    # Store geometry as absolute integer coordinates
    def __init__(self, value: str, geometry_abs: tuple[tuple[int, int], tuple[int, int]], confidence: float):
        """
        Args:
            value: The recognized text string.
            geometry_abs: Bounding box absolute coordinates ((xmin, ymin), (xmax, ymax)).
            confidence: The recognition confidence score.
        """
        self.value = value
        self.pt1 = Point(geometry_abs[0][0], geometry_abs[0][1])
        self.pt2 = Point(geometry_abs[1][0], geometry_abs[1][1])
        self.center = Point(
            (self.pt1.x + self.pt2.x) // 2,
            (self.pt1.y + self.pt2.y) // 2
        )
        self.confidence = confidence

    def __repr__(self):
        return f"Text(value='{self.value}', box=({self.pt1.x},{self.pt1.y})-({self.pt2.x},{self.pt2.y}), conf={self.confidence:.2f})"



### workflow.py

In [344]:
import yaml
from PIL import Image

# ## 2. Configuration Loading
CONFIG_PATH = 'config.yaml'
config = {}

try:
    with open(CONFIG_PATH, 'r') as f:
        config = yaml.safe_load(f)
    print(f"Configuration loaded from {CONFIG_PATH}") # Kept simple confirmation
except FileNotFoundError:
    print(f"Error: {CONFIG_PATH} not found. Using empty config.")
except Exception as e:
    print(f"Error loading or parsing {CONFIG_PATH}: {e}. Using empty config.")


Configuration loaded from config.yaml


In [345]:
INPUT_IMAGE_PATH = '../data/local/mid_petri_2.png' # Example relative path
# INPUT_IMAGE_PATH = '../data/internet/petri_net_19.jpg' # rect thresh 0.85
# INPUT_IMAGE_PATH = '../data/internet/petri_net_12.jpg' # rect thresh 0.85

img_steps = []

preprocessed_img, img_color_resized, img_gray_resized = load_and_preprocess_image(INPUT_IMAGE_PATH, config)
print(f"Image loaded and preprocessed from: {INPUT_IMAGE_PATH}") # Kept simple confirmation

Image dimensions (1057x619) below threshold (800px). Upscaling by 2x.
Image loaded and preprocessed from: ../data/local/mid_petri_2.png


In [346]:
img_steps.append(Image.fromarray(img_color_resized))
Image.fromarray(img_color_resized).show(title="Image without shapes")

In [347]:
detected_text_list = detect_text(img_color_resized, config)
img_no_text = get_img_no_text(preprocessed_img, detected_text_list)

In [348]:
Image.fromarray(img_no_text).show(title="Image without text")

In [349]:
circles, rectangles = detect_shapes(img_no_text, config)
img_empty_nodes_filled = fill_contours(img_no_text, circles + rectangles)

nodes_mask = get_nodes_mask(img_empty_nodes_filled, config) 
# Image.fromarray(nodes_mask).show(title="Isolated Nodes Mask")

detected_circles, detected_rectangles = detect_shapes(nodes_mask, config)

Stability detected: Contour count 31 stable for 3 iterations.
Optimal number of erosions determined as: 6.
Applied 6 erosions to input image to get the node mask.
Applied 6 dilations to recover node sizes.
Contour counts per erosion iteration: [936, 72, 69, 69, 42, 31, 31, 31]


In [350]:
dialated_circles = [dilate_contour(c, img_no_text.shape, config) for c in detected_circles]
dialated_rectangles = [dilate_contour(r, img_no_text.shape, config) for r in detected_rectangles]
img_no_shapes = remove_contours(img_empty_nodes_filled, dialated_circles + dialated_rectangles)

In [351]:
# img_drawn = img_color_resized.copy() # Convert to BGR for drawing
img_drawn = cv2.cvtColor(img_no_shapes, cv2.COLOR_GRAY2BGR) # Convert to BGR for drawing
for contour in detected_circles + detected_rectangles:
    cv2.drawContours(img_drawn, [contour], -1, (0, 255, 0), 2) # Draw contours in green
Image.fromarray(img_drawn).show(title="Detected Shapes")

In [352]:
Image.fromarray(img_no_shapes).show(title="Image without shapes")

In [353]:
type(detected_circles[0])

numpy.ndarray

In [354]:
places = [Place.from_contour(circle) for circle in detected_circles]
transitions = [Transition.from_contour(rect) for rect in detected_rectangles]
   
len(transitions), len(detected_rectangles)

(14, 14)

In [355]:
### draw places and transitions
img_drawn = cv2.cvtColor(img_no_shapes, cv2.COLOR_GRAY2BGR) # Convert to BGR for drawing
for place in places:
    cv2.circle(img_drawn, (place.center.x, place.center.y), int(place.radius), (255, 0, 0), 2) # Draw places in blue
for transition in transitions:
    cv2.drawContours(img_drawn, [transition.box_points.astype(np.int32)], -1, (0, 255, 0), 2) # Draw transitions in green

Image.fromarray(img_drawn).show(title="Detected Places and Transitions")

In [356]:
np.unique(skeleton.flatten())

array([  0, 255], dtype=uint8)

In [357]:
# skeleton = skeletonize(img_no_shapes / 255).astype(np.uint8)*255
# img_draw = cv2.cvtColor(skeleton.copy(), cv2.COLOR_GRAY2BGR)
# hough_lines = cv2.HoughLinesP(skeleton, 1, np.pi/180, 15, minLineLength=10, maxLineGap=25)
# for line in hough_lines:
#     cv2.line(img_draw, (line[0][0], line[0][1]), (line[0][2], line[0][3]), (np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)), 2)
# Image.fromarray(img_draw).show()


In [358]:
hough_lines = get_hough_lines(img_no_shapes, config)
img_draw = cv2.cvtColor(skeletonize(img_no_shapes / 255).astype(np.uint8)*255, cv2.COLOR_GRAY2BGR)
for line in hough_lines:
    cv2.line(img_draw, (line[0][0], line[0][1]), (line[0][2], line[0][3]), (np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)), 2)
Image.fromarray(img_draw).show()

In [359]:
hough_bundler_min_distance = config.get('connection_processing', {}).get('hough_bundler_min_distance', 10)
hough_bundler_min_angle = config.get('connection_processing', {}).get('hough_bundler_min_angle', 5) 

bundler = HoughBundler(min_distance = hough_bundler_min_distance ,min_angle=hough_bundler_min_angle)
merged_hough_lines = bundler.process_lines(hough_lines)

lines = [Line(Point(line[0][0], line[0][1]), Point(line[0][2], line[0][3])) for line in merged_hough_lines]
len(hough_lines), len(merged_hough_lines), len(lines)

(132, 59, 59)

In [360]:
img_drawn = cv2.cvtColor(img_no_shapes, cv2.COLOR_GRAY2BGR) # Convert to BGR for drawing
for line in lines:
    cv2.line(img_drawn, (line.point1.x, line.point1.y), (line.point2.x, line.point2.y), (255, 0, 0), 2) # Draw lines in blue

Image.fromarray(img_drawn).show(title="Detected Places and Transitions")

In [361]:
points_with_proximity = list(filter(lambda x: x.proximity_node, [point for line in lines for point in [line.point1, line.point2]]))
len(points_with_proximity)

0

In [362]:
isinstance(places[0], Place)

True

In [363]:
processed_lines, processed_places, processed_transitions = assign_proximity_nodes(
    lines, 
    places, 
    transitions, 
    config
)
entry_points = get_entry_points_from_lines(processed_lines)

points_with_proximity = list(filter(lambda x: x.proximity_node, [point for line in processed_lines for point in [line.point1, line.point2]]))
len(processed_lines), len(points_with_proximity)

(59, 90)

In [364]:
img_drawn = cv2.cvtColor(img_no_shapes, cv2.COLOR_GRAY2BGR) 
for line in processed_lines:

    color = (np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256))
    cv2.line(img_drawn, (line.point1.x, line.point1.y), (line.point2.x, line.point2.y), color , 2) # Draw lines in blue

for point in entry_points:
    cv2.circle(img_drawn, (point.x, point.y), 5, (0, 255, 0), -1) # Draw proximity points in green

Image.fromarray(img_drawn).show(title="Detected Places and Transitions")

In [365]:
### filtered lines still includes lines that are not connected to any proximity node
filtered_lines = []
for line in processed_lines:
    if line.point1.proximity_node == line.point2.proximity_node != None:
        continue
    else:
        filtered_lines.append(line)
        
len(filtered_lines)
img_draw = cv2.cvtColor(np.zeros_like(img_no_shapes), cv2.COLOR_GRAY2BGR)
for line in filtered_lines:
    color = (np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256))
    
    cv2.line(img_draw, (line.point1.x, line.point1.y), (line.point2.x, line.point2.y), color, 2)
# for point in entry_points:
#     cv2.circle(img_draw, (point.x, point.y), 5, (0, 255, 0), -1) # Draw proximity points in green

Image.fromarray(img_draw).show()

In [366]:
found_paths_result = find_line_paths(
    filtered_lines,
    proximity_threshold=100.0, # Max distance between points to consider connecting
    dot_product_weight=0.5,
    distance_to_line_weight=0.25,
    endpoint_distance_weight=0.25
)
img_draw = cv2.cvtColor(np.zeros_like(img_no_shapes), cv2.COLOR_GRAY2BGR)
for path in found_paths_result:
    color = (np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256))
    for line in path["lines"]:
        cv2.line(img_draw, (line.point1.x, line.point1.y), (line.point2.x, line.point2.y), color, 2)

Image.fromarray(img_draw).show()




In [367]:
arrowhead_result = detect_arrowheads(
    image=img_color_resized,
    config=config
)
detections = sv.Detections.from_inference(arrowhead_result)
arrow_contours = [minmaxToContours(xyxy) for xyxy in detections.xyxy]
# img_no_arrows = remove_contours(img_empty_nodes_filled, arrow_contours)
# Image.fromarray(img_no_arrows).show(title="Image with Arrowheads Removed")

In [368]:
len(arrowhead_result["predictions"]), len(arrow_contours)

(40, 40)

In [369]:
paths_with_arrows, rejected_arrows_count = assign_arrowheads(found_paths_result, arrowhead_result, config)

No closest point found for the arrowhead center.
No closest point found for the arrowhead center.
No closest point found for the arrowhead center.
No closest point found for the arrowhead center.
No closest point found for the arrowhead center.
No closest point found for the arrowhead center.


In [370]:
points_with_arrows_before = list(filter(lambda x: x.is_arrow, [point for path in found_paths_result for line in path["lines"] for point in [line.point1, line.point2]]))
points_with_arrows_after = list(filter(lambda x: x.is_arrow, [point for path in paths_with_arrows for line in path["lines"] for point in [line.point1, line.point2]]))
len(points_with_arrows_before), len(points_with_arrows_after), rejected_arrows_count

(0, 34, 6)

In [371]:
img_draw = cv2.cvtColor(np.zeros_like(img_no_shapes), cv2.COLOR_GRAY2BGR)
for path in paths_with_arrows:
    color = (np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256))
    for line in path["lines"]:
        cv2.line(img_draw, (line.point1.x, line.point1.y), (line.point2.x, line.point2.y), color, 2)

Image.fromarray(img_draw).show()


In [372]:
arcs = get_arcs(paths_with_arrows)
arcs[:5], len(arcs)

([Arc(source=Place(center=Point(1896, 194), radius=42.63063049316406), target=Transition(center=Point(2046, 97), height=106.0, width=18.0, angle=90.0)),
  Arc(source=Transition(center=Point(1233, 307), height=106.0, width=28.0, angle=90.0), target=Place(center=Point(1505, 559), radius=43.1812858581543)),
  Arc(source=Place(center=Point(2047, 454), radius=42.58375930786133), target=Transition(center=Point(1233, 307), height=106.0, width=28.0, angle=90.0)),
  Arc(source=Place(center=Point(396, 643), radius=43.05750274658203), target=Transition(center=Point(600, 483), height=106.0, width=20.0, angle=90.0)),
  Arc(source=Transition(center=Point(2046, 97), height=106.0, width=18.0, angle=90.0), target=Place(center=Point(2047, 454), radius=42.58375930786133))],
 41)

In [373]:
### check arcs for errors
### filter arcs to remove cycles from the same source and target
arcs_filtered = []
for arc in arcs:
    if arc.source != arc.target and type(arc.source) != type(arc.target):
        arcs_filtered.append(arc)

len(arcs_filtered)

39

In [374]:
### visualize arcs
img_draw = cv2.cvtColor(np.zeros_like(img_no_shapes), cv2.COLOR_GRAY2RGB)
for arc in arcs_filtered:
    src_color = (0, 0, 255) # Red for source
    tgt_color = (255, 0, 0) # Blue for target
    
    cv2.circle(img_draw, (arc.start_point.x, arc.start_point.y), 5, src_color, -1) # Draw source point
    cv2.circle(img_draw, (arc.end_point.x, arc.end_point.y), 5, tgt_color, -1) # Draw target point

    cv2.line(img_draw, (arc.start_point.x, arc.start_point.y), (arc.end_point.x, arc.end_point.y), (0, 255, 0), 2) # Draw line in green

Image.fromarray(img_draw).show(title="Detected Arcs")

In [375]:
arc_contour = np.array(([point.get_numpy_array() for point in arcs[0].points])).reshape((-1, 1, 2))
place_contour = places[0].original_detection_data
transition_contour = transitions[0].original_detection_data

arc_contour.shape, place_contour.shape, transition_contour.shape

# Define a sample point
my_point = Point(50, 50)

img_draw = cv2.cvtColor(np.zeros_like(img_no_shapes), cv2.COLOR_GRAY2BGR)
cv2.circle(img_draw, (my_point.x, my_point.y), 5, (0, 255, 0), -1) # Draw point in green
cv2.polylines(img_draw, [arc_contour], isClosed=False, color=(0, 255, 0), thickness=2) # Draw polyline in green
cv2.polylines(img_draw, [place_contour], isClosed=True, color=(255, 0, 0), thickness=2) # Draw place in blue
cv2.polylines(img_draw, [transition_contour], isClosed=True, color=(0, 0, 255), thickness=2) # Draw transition in red
Image.fromarray(img_draw).show()

In [376]:

dist_to_arc = find_closest_distance_to_contour(my_point, arc_contour)
dist_to_place = find_closest_distance_to_contour(my_point, place_contour)
dist_to_transition = find_closest_distance_to_contour(my_point, transition_contour)
print(f"Distance from {my_point} to arc contour: {dist_to_arc}")
print(f"Distance from {my_point} to place contour: {dist_to_place}")
print(f"Distance from {my_point} to transition contour: {dist_to_transition}")

Distance from Point(50, 50) to arc contour: 1890.6237066111278
Distance from Point(50, 50) to place contour: 1643.8795576318844
Distance from Point(50, 50) to transition contour: 1344.5463919106696


In [377]:
# Set a threshold
text_linking_threshold = config.get('connection_processing', {}).get('text_linking_threshold',  25.0 )

link_text_to_elements(
    detected_text_list,
    places,
    transitions,
    arcs_filtered,
    text_linking_threshold
)

# Check associations
print("--- Associated Text ---")
for p in places:
    if p.text: print(f"{p} has text: {[t.value for t in p.text]}")
for t in transitions:
    if t.text: print(f"{t} has text: {[txt.value for txt in t.text]}") # changed t.text to txt.value
for a in arcs_filtered:
    if a.text: print(f"{a} has text: {[t.value for t in a.text]}")

--- Associated Text ---
Place(center=Point(1303, 1178), radius=42.83219909667969) has text: ['0']
Place(center=Point(1928, 987), radius=42.33136749267578) has text: ['deadend', '0']
Place(center=Point(849, 772), radius=42.84130096435547) has text: ['0']
Place(center=Point(1506, 764), radius=42.801334381103516) has text: ['orders', '0']
Place(center=Point(2052, 744), radius=42.6140251159668) has text: ['P1', '0']
Place(center=Point(396, 643), radius=43.05750274658203) has text: ['customers', '0']
Place(center=Point(86, 643), radius=42.58401870727539) has text: ['P1', '1']
Place(center=Point(1505, 559), radius=43.1812858581543) has text: ['orders', '0']
Place(center=Point(795, 471), radius=42.31279373168945) has text: ['U']
Place(center=Point(2047, 454), radius=42.58375930786133) has text: ['0']
Place(center=Point(812, 296), radius=43.51975631713867) has text: ['size', '72']
Place(center=Point(1896, 194), radius=42.63063049316406) has text: ['P1', '1']
Place(center=Point(425, 111), radiu



In [378]:
for place in places:
    place.update_markers_from_text()
    print(f"Place {place} has {place.markers} markers.")

Place Place(center=Point(1303, 1178), radius=42.83219909667969) has 0 markers.
Place Place(center=Point(1298, 999), radius=42.35411071777344) has 0 markers.
Place Place(center=Point(1928, 987), radius=42.33136749267578) has 0 markers.
Place Place(center=Point(849, 772), radius=42.84130096435547) has 0 markers.
Place Place(center=Point(1506, 764), radius=42.801334381103516) has 0 markers.
Place Place(center=Point(2052, 744), radius=42.6140251159668) has 0 markers.
Place Place(center=Point(396, 643), radius=43.05750274658203) has 0 markers.
Place Place(center=Point(86, 643), radius=42.58401870727539) has 1 markers.
Place Place(center=Point(1505, 559), radius=43.1812858581543) has 0 markers.
Place Place(center=Point(795, 471), radius=42.31279373168945) has 0 markers.
Place Place(center=Point(2047, 454), radius=42.58375930786133) has 0 markers.
Place Place(center=Point(812, 296), radius=43.51975631713867) has 72 markers.
Place Place(center=Point(1896, 194), radius=42.63063049316406) has 1 