In [None]:
!rm -rf clone && git clone https://github.com/rusgu-real/ecg_digitisation clone && cp -a clone/. .

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
from abc import ABC, abstractmethod
from google.colab.patches import cv2_imshow
from collections import deque

class ImageData:
    """Bundle an image with a free-form metadata dictionary.

    Both attributes are plain public fields that pipeline stages read and
    overwrite directly.
    """

    def __init__(self, image, metadata=None):
        """Store ``image`` and ``metadata``.

        When ``metadata`` is ``None`` a new empty dict is created; any other
        value is stored as-is (no copy is made).
        """
        if metadata is None:
            metadata = {}
        self.image = image
        self.metadata = metadata

    def show(self, text):
        """Print ``text``, then display the held image with ``cv2_imshow``."""
        print(text)
        cv2_imshow(self.image)

class Processor(ABC):
    """Abstract base for one stage of the image-processing pipeline."""

    @abstractmethod
    def process(self, data: ImageData) -> ImageData:
        """Handle ``data``; every concrete subclass must override this method."""

class Crop(Processor):
    """Crop the image to the bounding box of its largest edge contour.

    Contour detection runs on a copy resized to ``WORK_WIDTH`` pixels wide;
    the winning contour is then scaled back so the crop is taken from the
    full-resolution original.
    """

    # Width in pixels of the resized working copy used for contour detection.
    WORK_WIDTH = 800

    def process(self, data: ImageData) -> ImageData:
        """Crop ``data.image`` and record the result.

        Writes ``data.metadata["cropped"]`` (the cropped image) and
        ``data.metadata["paper_contour"]`` (int32 polygon in original-image
        coordinates), and replaces ``data.image`` with the crop. The Canny
        edge map is displayed with ``cv2_imshow`` as a debugging aid.

        If no contour is found, the whole image is kept and the stored
        contour is the full-frame rectangle.

        Raises:
            IOError: if ``data.image`` is ``None`` (e.g. ``cv2.imread`` could
                not read the file).
        """
        orig = data.image
        if orig is None:
            # Same failure style as GridDetect; avoids an opaque AttributeError on .shape.
            raise IOError("Image not found in data.image")

        scale = self.WORK_WIDTH / orig.shape[1]
        image = cv2.resize(orig, None, fx=scale, fy=scale)

        # Grayscale + blur, then CLAHE to even out local contrast before edge detection.
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        blur = cv2.GaussianBlur(gray, (5, 5), 0)
        clahe = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(8, 8))
        blur = clahe.apply(blur)

        # Edge detection
        edges = cv2.Canny(blur, 50, 200)
        cv2_imshow(edges)

        # Outer contours only.
        contours, _ = cv2.findContours(
            edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        if contours:
            # Largest contour by area, simplified to a polygon so it can later
            # serve as input to a perspective transform.
            largest = max(contours, key=cv2.contourArea)
            perimeter = cv2.arcLength(largest, True)
            approx = cv2.approxPolyDP(largest, 0.02 * perimeter, True)

            # Scale the polygon back to original-image coordinates.
            paper_contour = (approx / scale).astype(np.int32)
            x, y, w, h = cv2.boundingRect(paper_contour)
            cropped = orig[y:y + h, x:x + w]
        else:
            # Nothing detected: keep the whole frame instead of crashing on a
            # missing contour.
            height, width = orig.shape[:2]
            paper_contour = np.array(
                [
                    [[0, 0]],
                    [[width - 1, 0]],
                    [[width - 1, height - 1]],
                    [[0, height - 1]],
                ],
                dtype=np.int32,
            )
            cropped = orig

        data.metadata["cropped"] = cropped
        data.metadata["paper_contour"] = paper_contour
        data.image = cropped  # Pass the cropped image to the next processor
        return data


class Unwarp(Processor):
    """Placeholder stage for perspective correction (not implemented yet)."""

    def process(self, data: ImageData) -> ImageData:
        """Return ``data`` unchanged.

        The stage must still return the ``ImageData`` it received:
        ``TreatImage.run`` feeds each stage's return value into the next one,
        so returning ``None`` would break every stage placed after this one.
        """
        return data

class GridDetect(Processor):
    """Find the red grid, redraw it as straight lines, and derive grid-free images.

    Metadata keys written: ``image_height``, ``image_width``,
    ``red_mask_clean``, ``h_lines_unique``, ``v_lines_unique``, ``warped``,
    ``gray_no_grid`` and ``thresh_ecg``. ``data.image`` is replaced by the
    copy with the reconstructed grid drawn on it. Every intermediate result
    is announced with ``print`` and displayed with ``cv2_imshow``.
    """

    @staticmethod
    def _cluster_positions(positions, tolerance=3):
        """Merge sorted coordinates into clusters and return each cluster's integer mean.

        A value joins the current cluster when it lies within ``tolerance``
        of the last value added to that cluster.
        """
        ordered = sorted(positions)
        if not ordered:
            return []
        groups = [[ordered[0]]]
        for value in ordered[1:]:
            if abs(value - groups[-1][-1]) <= tolerance:
                groups[-1].append(value)
            else:
                groups.append([value])
        return [int(np.mean(group)) for group in groups]

    def process(self, data: ImageData) -> ImageData:
        """Run grid detection on ``data.image`` and return the updated ``data``.

        Raises:
            IOError: if ``data.image`` is ``None``.
        """
        source = data.image  # the cropped image when this stage follows Crop
        if source is None:
            raise IOError("Image not found in data.image")

        height, width = source.shape[:2]
        data.metadata["image_height"] = height
        data.metadata["image_width"] = width

        # --- Step 1: mask of red pixels (two HSV ranges, one at each end of the hue axis) ---
        hsv = cv2.cvtColor(source, cv2.COLOR_BGR2HSV)
        hsv_ranges = (
            (np.array([0, 50, 50]), np.array([30, 255, 255])),
            (np.array([150, 50, 50]), np.array([180, 255, 255])),
        )
        low_band, high_band = [cv2.inRange(hsv, lo, hi) for lo, hi in hsv_ranges]
        red_mask = cv2.bitwise_or(low_band, high_band)

        # A single 2x2 morphological opening removes speckle noise.
        red_mask_clean = cv2.morphologyEx(
            red_mask, cv2.MORPH_OPEN, np.ones((2, 2), np.uint8), iterations=1
        )
        data.metadata["red_mask_clean"] = red_mask_clean
        print("Step 1: Red grid mask cleaned")
        cv2_imshow(red_mask_clean)

        # --- Step 2: edges of the mask ---
        edges = cv2.Canny(red_mask_clean, 20, 60)
        print("Step 2: Edges detected for grid")
        cv2_imshow(edges)

        # --- Step 3: probabilistic Hough segments, split by orientation ---
        lines = cv2.HoughLinesP(
            edges, 1, np.pi / 180, threshold=20, minLineLength=20, maxLineGap=15
        )
        segments = [] if lines is None else lines[:, 0]

        horizontal_lines = []
        vertical_lines = []
        grid_detected_img = source.copy()
        for x1, y1, x2, y2 in segments:
            if abs(y1 - y2) < 5:  # nearly constant y -> horizontal
                horizontal_lines.append((x1, y1, x2, y2))
            elif abs(x1 - x2) < 5:  # nearly constant x -> vertical
                vertical_lines.append((x1, y1, x2, y2))
            # Every segment is drawn, including ones that are neither.
            cv2.line(grid_detected_img, (x1, y1), (x2, y2), (0, 255, 0), 1)
        print(f"Step 3: Grid detected - {len(horizontal_lines)} horizontal, {len(vertical_lines)} vertical lines")
        cv2_imshow(grid_detected_img)

        # --- Step 4: collapse near-duplicate segments into one coordinate per grid line ---
        h_lines_unique = self._cluster_positions([seg[1] for seg in horizontal_lines])
        v_lines_unique = self._cluster_positions([seg[0] for seg in vertical_lines])
        data.metadata["h_lines_unique"] = h_lines_unique
        data.metadata["v_lines_unique"] = v_lines_unique

        grid_reconstructed = source.copy()
        for row in h_lines_unique:
            cv2.line(grid_reconstructed, (0, row), (width, row), (255, 0, 0), 1)
        for col in v_lines_unique:
            cv2.line(grid_reconstructed, (col, 0), (col, height), (255, 0, 0), 1)
        print(f"Step 4: Grid reconstructed - {len(h_lines_unique)} horizontal, {len(v_lines_unique)} vertical lines")
        cv2_imshow(grid_reconstructed)

        # The grid overlay is this stage's main output: stored under "warped"
        # and handed to the next processor as data.image.
        data.metadata["warped"] = grid_reconstructed
        data.image = grid_reconstructed

        # --- Step 5a: grayscale copy with the masked grid pixels painted white ---
        gray_no_grid = cv2.cvtColor(source, cv2.COLOR_BGR2GRAY)
        gray_no_grid[red_mask_clean > 0] = 255
        data.metadata["gray_no_grid"] = gray_no_grid
        print("Step 5a: Grayscale with red grid removed (for curve detection)")
        cv2_imshow(gray_no_grid)

        # --- Step 5b: inverted adaptive threshold, so dark strokes become white ---
        thresh_ecg = cv2.adaptiveThreshold(
            gray_no_grid,
            255,
            cv2.ADAPTIVE_THRESH_MEAN_C,
            cv2.THRESH_BINARY_INV,
            blockSize=15,
            C=10,
        )
        data.metadata["thresh_ecg"] = thresh_ecg
        print("Step 5b: Thresholded ECG lines (grid removed)")
        cv2_imshow(thresh_ecg)

        return data

class CurveDetectPerfect(Processor):
    """Group dark pixels of the cropped image into connected components.

    Two pixels belong to the same component when both are dark and one lies
    inside the ``n`` x ``n`` window centred on the other, so strokes with
    small gaps stay in one component.
    """

    @staticmethod
    def _bfs_component(dark, start_x, start_y, offsets, visited):
        """Flood-fill from ``(start_x, start_y)`` and return the visited ``(x, y)`` pixels.

        ``dark`` is a 2-D boolean mask; ``visited`` is a same-shaped boolean
        array that is updated in place. ``offsets`` lists the ``(dx, dy)``
        neighbour displacements to try from each pixel.
        """
        h, w = dark.shape
        queue = deque([(start_x, start_y)])
        component = []
        visited[start_y, start_x] = True

        while queue:
            x, y = queue.popleft()
            component.append((x, y))
            for dx, dy in offsets:
                nx, ny = x + dx, y + dy
                if (
                    0 <= nx < w
                    and 0 <= ny < h
                    and dark[ny, nx]
                    and not visited[ny, nx]
                ):
                    visited[ny, nx] = True
                    queue.append((nx, ny))

        return component

    @classmethod
    def _find_all_dark_components(cls, image, n, threshold=50):
        """Return one ``(k, 2)`` array of ``(x, y)`` pixels per dark component.

        A pixel is dark when its value in the 2-D ``image`` is below
        ``threshold``. Seeds are taken in row-major order.
        """
        dark = image < threshold
        visited = np.zeros(dark.shape, dtype=bool)

        half = n // 2
        offsets = [
            (dx, dy)
            for dy in range(-half, half + 1)
            for dx in range(-half, half + 1)
        ]

        components = []
        # Only dark pixels can seed a component, so iterate over those instead
        # of every pixel; argwhere yields (row, col) pairs in row-major order.
        for y, x in np.argwhere(dark).tolist():
            if visited[y, x]:
                continue
            pixels = cls._bfs_component(dark, x, y, offsets, visited)
            components.append(np.array(pixels))

        return components

    @staticmethod
    def _draw_components(image_gray, components):
        """Return a BGR copy of ``image_gray`` with each component in a random colour."""
        # Convert grayscale -> BGR so coloured pixels can be written.
        canvas = cv2.cvtColor(image_gray, cv2.COLOR_GRAY2BGR)
        for comp in components:
            color = np.random.randint(0, 255, 3).tolist()
            # Columns of comp are (x, y); index as [row, col] = [y, x].
            canvas[comp[:, 1], comp[:, 0]] = color
        return canvas

    def process(self, data: ImageData) -> ImageData:
        """Detect dark components and return ``data`` with the results attached.

        Reads ``data.metadata["cropped"]`` (falling back to ``data.image``
        when that key is absent). Writes ``data.metadata["components"]`` (list
        of ``(k, 2)`` pixel arrays) and ``data.metadata["components_overlay"]``
        (colour visualisation). ``data.image`` is left untouched. A binarised
        preview and the overlay are displayed with ``cv2_imshow``.
        """
        source = data.metadata.get("cropped")
        if source is None:
            source = data.image
        img = source.copy()

        # The HSV value channel serves as the brightness image.
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        gray = hsv[:, :, 2]

        # Preview only: dark pixels black, everything else white.
        mask = gray < 60
        img[mask] = 0
        img[~mask] = 255
        cv2_imshow(img)

        # Single scan with the threshold whose result was actually used
        # (an earlier scan at threshold 100 was computed and then discarded).
        components = self._find_all_dark_components(image=gray, n=5, threshold=50)
        overlay = self._draw_components(gray, components)
        cv2_imshow(overlay)
        # No cv2.waitKey / destroyAllWindows here: cv2_imshow renders inline,
        # so there is no HighGUI window to wait on or close.

        data.metadata["components"] = components
        data.metadata["components_overlay"] = overlay
        return data



class TreatImage():
    """Run an image through an ordered sequence of processors."""

    def __init__(self, processors: list[Processor]) -> None:
        """Remember ``processors``; they are applied in list order by ``run``."""
        self.processors = processors

    def run(self, image) -> ImageData:
        """Wrap ``image`` in a fresh ``ImageData`` and pass it through every stage.

        Each stage receives whatever the previous stage returned; the last
        stage's return value is the result.
        """
        current = ImageData(image, {})
        for stage in self.processors:
            current = stage.process(current)
        return current

# Earlier pipeline variant, kept for reference only (not executed):
#   treat = TreatImage([Crop(), GridDetect(), CurveDetect()])
#   treat.run(img)
# NOTE(review): `CurveDetect` is not defined in the visible code - confirm
# where it lives before re-enabling this variant.

# Input image for this run.
img_path = "data/test/1053922973.png"

# cv2.imread returns None (it does not raise) when the file is missing or
# cannot be decoded, so fail here with a clear message.
input_image = cv2.imread(img_path)
if input_image is None:
    raise FileNotFoundError(f"Could not read image: {img_path}")

test = TreatImage([Crop(), CurveDetectPerfect()])
res = test.run(input_image)

# Optional: display outputs stored in metadata.
# Set by Crop:
# cv2_imshow(res.metadata["cropped"])
# print(res.metadata.get("paper_contour"))
# Set by GridDetect only, so these need GridDetect in the pipeline above:
# cv2_imshow(res.metadata["warped"])
# cv2_imshow(res.metadata["gray_no_grid"])
# cv2_imshow(res.metadata["thresh_ecg"])
# NOTE(review): no processor in the visible code writes this key - confirm:
# cv2_imshow(res.metadata["ecg_traces_overlaid_image"])
