In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2

# Directory (relative to the notebook's working directory) holding the example images.
ASSETS_FOLDER = 'assets/'
# Morphological-gradient responses below this value are zeroed by edge_detection.
EDGE_THRESHOLD = 5
# Seed NumPy's global RNG so the random region colours are reproducible on Restart & Run All.
SEED = 42
np.random.seed(SEED)

In [None]:
def draw_on_image(image, color=(0, 0, 0), thickness=3, window_name='test draw'):
  """Let the user draw freehand strokes on a copy of an RGB image in an OpenCV window.

  Hold the left mouse button and drag to draw; press Esc to finish.

  Args:
    image: RGB image (e.g. from ``load_image(..., cv2.COLOR_BGR2RGB)``). It is
      not modified: ``cv2.cvtColor`` returns a new array that is drawn on.
    color: stroke colour as an (R, G, B) tuple (default black).
    thickness: stroke thickness in pixels.
    window_name: title of the OpenCV window.

  Returns:
    A new RGB image containing the drawn strokes.
  """
  # OpenCV's imshow/line work in BGR order, so draw on a BGR copy.
  canvas = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  bgr_color = tuple(color[::-1])
  drawing = False
  pt1_x, pt1_y = None, None

  def line_drawing(event, x, y, _flags, _param):
    # `nonlocal`: the stroke state lives in this call's scope, not in module
    # globals (which may not exist when the first event arrives).
    nonlocal pt1_x, pt1_y, drawing

    if event == cv2.EVENT_LBUTTONDOWN:
      drawing = True
      pt1_x, pt1_y = x, y

    elif event == cv2.EVENT_MOUSEMOVE and drawing:
      cv2.line(canvas, (pt1_x, pt1_y), (x, y), color=bgr_color, thickness=thickness)
      pt1_x, pt1_y = x, y

    elif event == cv2.EVENT_LBUTTONUP and drawing:
      # Guarded by `drawing`: a release whose press happened outside the window
      # would otherwise try to draw from (None, None).
      drawing = False
      cv2.line(canvas, (pt1_x, pt1_y), (x, y), color=bgr_color, thickness=thickness)

  cv2.namedWindow(window_name)
  cv2.setMouseCallback(window_name, line_drawing)

  try:
    while True:
      cv2.imshow(window_name, canvas)
      if cv2.waitKey(1) & 0xFF == 27:  # 27 == Esc
        break
  finally:
    # Close the window even if the loop is interrupted (e.g. kernel interrupt).
    cv2.destroyAllWindows()

  return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)

In [None]:
def display_images(input, edges, output):
  """Plot the input, its edge map and the segmentation side by side in one row.

  ``input`` and ``edges`` are drawn with the gray colormap; ``output`` is drawn
  with matplotlib's default (presumably an RGB image, where the colormap is unused).
  """
  panels = (
    ("Input", input, "gray"),
    ("Edges", edges, "gray"),
    ("Output", output, None),
  )
  _, axes = plt.subplots(1, len(panels), figsize=(20, 20), squeeze=False)

  for ax, (title, picture, cmap) in zip(axes[0], panels):
    ax.imshow(picture, cmap=cmap)
    ax.set_title(title)
    ax.axis("off")

  plt.tight_layout()

In [None]:
def load_image(path, color=cv2.COLOR_BGR2GRAY):
  """Read an image from disk and convert it with ``cv2.cvtColor``.

  Args:
    path: file path passed to ``cv2.imread``.
    color: ``cv2.COLOR_BGR2*`` conversion code applied to the BGR image read
      from disk (default: single-channel grayscale).

  Returns:
    The converted image as a NumPy array.

  Raises:
    FileNotFoundError: if ``cv2.imread`` cannot read ``path`` (imread signals
      failure by returning ``None`` rather than raising).
  """
  # IMREAD_COLOR always yields an 8-bit, 3-channel BGR image, which is what the
  # BGR2* conversion codes expect. IMREAD_UNCHANGED (-1) could return a 2-D
  # grayscale or 16-bit image that cvtColor rejects or the 0..255 pipeline mishandles.
  image = cv2.imread(path, cv2.IMREAD_COLOR)
  if image is None:
    raise FileNotFoundError('could not read image: {!r}'.format(path))

  return cv2.cvtColor(image, color)

In [None]:
def edge_detection(image):
  """Return the morphological gradient of ``image`` with weak responses removed.

  The gradient is a 3x3 dilation minus a 3x3 erosion. Values below
  ``EDGE_THRESHOLD`` are set to 0 in the returned array; ``image`` itself is not
  modified.
  """
  structuring_element = np.ones((3, 3), np.uint8)
  dilated = cv2.dilate(image, structuring_element)
  eroded = cv2.erode(image, structuring_element)
  gradient = dilated - eroded

  weak_response = gradient < EDGE_THRESHOLD
  gradient[weak_response] = 0
  return gradient

In [None]:
def get_pixel(image, i, j):
  """Return ``image[i][j]``, or -1 when (i, j) lies outside the first two dimensions."""
  rows, cols = image.shape[0], image.shape[1]
  inside = 0 <= i < rows and 0 <= j < cols
  return image[i][j] if inside else -1

In [None]:
def get_neighourhood(image, i, j):
  """Return the 8 neighbours of (i, j) as a NumPy array.

  Order is row-major around the centre: NW, N, NE, W, E, SW, S, SE. Positions
  outside the image contribute -1 (via ``get_pixel``).
  """
  offsets = [(di, dj) for di in (-1, 0, 1) for dj in (-1, 0, 1) if (di, dj) != (0, 0)]
  return np.array([get_pixel(image, i + di, j + dj) for di, dj in offsets])

In [None]:
def watershed_basic(image, regions=None):
  """Segment an intensity image by growing regions from low to high intensity.

  Each pixel whose value is an integer in 0..255 is visited once, in order of
  increasing intensity (ties in row-major order). A visited pixel joins the
  smallest label among its already-labelled 8-neighbours; if none is labelled
  it starts a new region. Each region gets a random RGB colour drawn from
  NumPy's global RNG (seed it with ``np.random.seed`` for reproducible colours).

  Args:
    image: 2-D array of intensities (e.g. the uint8 output of ``edge_detection``).
      Pixels whose value is not an integer in 0..255 are never visited and stay
      black in the result.
    regions: optional 2-D int array of seed labels with the same shape as
      ``image``; values <= -1 mean "unlabelled". Seed pixels keep their label.
      The array is copied, never modified.

  Returns:
    ``uint8`` array of shape ``image.shape + (3,)`` with one colour per region.

  Raises:
    ValueError: if ``image`` is not 2-D or ``regions`` has a different shape.
  """
  image = np.asarray(image)
  if image.ndim != 2:
    raise ValueError('image must be 2-D, got shape {}'.format(image.shape))
  rows, cols = image.shape

  # `is None`, not `== None`: on an ndarray `==` is element-wise and cannot be
  # used as an `if` condition.
  if regions is None:
    regions = np.full((rows, cols), -1, dtype=int)
  else:
    # Copy so the caller's seed array is left untouched.
    regions = np.array(regions, dtype=int)
    if regions.shape != image.shape:
      raise ValueError('regions shape {} does not match image shape {}'.format(
        regions.shape, image.shape))

  # New regions are numbered after any seed labels.
  next_label = max(int(regions.max()) + 1, 0) if regions.size else 0
  colors = {}
  # Build the result directly as uint8 RGB; colours are drawn in 0..255, so no
  # clipping is needed (casting to np.byte/int8 would wrap values above 127).
  segmented = np.zeros((rows, cols, 3), dtype=np.uint8)

  flat = image.reshape(-1)
  visitable = np.isin(flat, np.arange(256))
  # A stable sort of the row-major pixels gives the order "for each intensity,
  # for each row, for each column" in one pass instead of 256 full-image scans.
  for index in np.argsort(flat, kind='stable'):
    if not visitable[index]:
      continue
    i, j = divmod(int(index), cols)

    label = regions[i, j]
    if label < 0:
      # 3x3 window clipped at the border. The centre is still unlabelled, so
      # only the 8 neighbours can contribute a label.
      window = regions[max(i - 1, 0):i + 2, max(j - 1, 0):j + 2]
      labelled = window[window > -1]
      if labelled.size:
        label = labelled.min()
      else:
        label = next_label
        next_label += 1
      regions[i, j] = label

    label = int(label)
    if label not in colors:
      colors[label] = np.random.randint(0, 256, size=3)
    segmented[i, j] = colors[label]

  return segmented

In [None]:
image = load_image(ASSETS_FOLDER + 'primer1.jpg')
# load_image may return None for an unreadable file; fail here with a clear message.
assert image is not None, 'could not load ' + ASSETS_FOLDER + 'primer1.jpg'
edges = edge_detection(image)
segmented = watershed_basic(edges)

display_images(image, edges, segmented)
plt.show()

In [None]:
image = load_image(ASSETS_FOLDER + 'primer2.jpg')
# load_image may return None for an unreadable file; fail here with a clear message.
assert image is not None, 'could not load ' + ASSETS_FOLDER + 'primer2.jpg'
edges = edge_detection(image)
segmented = watershed_basic(edges)

display_images(image, edges, segmented)
plt.show()

In [None]:
image = load_image(ASSETS_FOLDER + 'primer3.jpg')
# load_image may return None for an unreadable file; fail here with a clear message.
assert image is not None, 'could not load ' + ASSETS_FOLDER + 'primer3.jpg'
edges = edge_detection(image)
segmented = watershed_basic(edges)

display_images(image, edges, segmented)
plt.show()

In [None]:
image = load_image(ASSETS_FOLDER + 'primer1.jpg', cv2.COLOR_BGR2RGB)
# load_image may return None for an unreadable file; fail here with a clear message.
assert image is not None, 'could not load ' + ASSETS_FOLDER + 'primer1.jpg'
image = draw_on_image(image)

plt.imshow(image)
plt.title('primer1.jpg with user drawing')
plt.axis('off')
plt.show()