# Purpose

To test the proposed OpenCV pipeline for detecting the tongue.

Pipeline: image > grayscale > CLAHE > Otsu thresholding > contour detection > centroid of the largest contour, taken as the center of the tongue

We then prompt the SAM2 model with this center point to segment the tongue, and inspect the results.

In [None]:
# pip install if running on colab
# !pip install ultralytics

In [26]:
import cv2

from PIL import Image
import numpy as np

from ultralytics import SAM

import torch

In [22]:
def find_tongue_center(image_path):
    '''
    Find the center of the tongue in an image. Process:
    1. Read the image and convert it to grayscale
    2. Apply CLAHE to enhance local contrast
    3. Binarize with Otsu's thresholding
    4. Find the external contours of the binary image
    5. Take the largest contour by area and compute its centroid from its moments

    Args:
        image_path (str): path to the image

    Returns:
        list of ints: [x, y] pixel coordinates of the tongue center. If no center can be
        found (no contours, or the largest contour has zero area), [-1, -1] is returned
        instead, so callers can tell failure apart from a real pixel.

    Raises:
        FileNotFoundError: if OpenCV cannot read an image from image_path.
    '''
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising; fail here with
    # a clear message instead of a cryptic error from cvtColor.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Enhance contrast
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)

    # With THRESH_OTSU the threshold value is chosen automatically; the 0 passed here is ignored
    _, binary = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Outer contours only
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        print(f"No contours found for image path: {image_path}")
        # Sentinel instead of a real pixel, so callers can detect the failure
        return [-1, -1]

    # Assume the largest external contour is the tongue
    tongue_contour = max(contours, key=cv2.contourArea)
    M = cv2.moments(tongue_contour)

    # A zero-area contour has no centroid. Return the same failure sentinel rather than
    # [0, 0], which is a valid pixel (and the point segment() uses as its background prompt).
    if M["m00"] == 0:
        print(f"Largest contour has zero area for image path: {image_path}")
        return [-1, -1]

    cx = int(M["m10"] / M["m00"])
    cy = int(M["m01"] / M["m00"])
    return [cx, cy]

In [23]:
# Detect the tongue center for each sample image, in order sample_1 ... sample_9
sample_paths = [f"./Samples/sample_{idx}.jpg" for idx in range(1, 10)]

coords = []
for path in sample_paths:
    center_x, center_y = find_tongue_center(path)
    coords.append((center_x, center_y))

In [24]:
# Display the (x, y) tongue center found for each of sample_1 ... sample_9
coords

[(318, 363),
 (320, 300),
 (380, 239),
 (301, 308),
 (334, 266),
 (254, 303),
 (324, 349),
 (312, 278),
 (300, 423)]

In [30]:
# Draw the centroid of every external contour in one sample, to sanity-check which
# contours the thresholding produces (find_tongue_center keeps only the largest one).
image_path = "./Samples/sample_7.jpg"
image = cv2.imread(image_path)
# cv2.imread returns None instead of raising when the file cannot be read
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

# Enhance contrast
clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)

# With THRESH_OTSU the threshold value is chosen automatically; the 0 passed here is ignored
_, binary = cv2.threshold(enhanced, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

# Outer contours only
contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

# Centroid [x, y] of each contour; zero-area contours (m00 == 0) have none and are skipped
contour_lst = []
for contour in contours:
    M = cv2.moments(contour)
    if M["m00"] != 0:
        contour_lst.append([int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"])])

# Mark each centroid with a filled red dot (OpenCV colors are BGR)
for cx, cy in contour_lst:
    cv2.circle(image, (cx, cy), 5, (0, 0, 255), -1)

cv2.imshow("image", image)
cv2.waitKey(0)
cv2.destroyAllWindows()

## Using SAM2 to segment the tongue

In [None]:
# Load the SAM 2.1 checkpoint "sam2.1_l.pt" through ultralytics
model = SAM("sam2.1_l.pt")

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

# Echo the selected device as the cell output
device

In [None]:
def segment(image, center, background_point=(0, 0)):
    '''
    Segment the tongue out of an image with the SAM2 model, using point prompts.

    The tongue center is given to SAM as a positive (label 1) point and
    background_point as a negative (label 0) point. The first mask SAM returns is
    then used to zero out every pixel outside it. Uses the module-level `model`.

    Args:
        image (PIL.Image.Image | np.ndarray): input image. A PIL image is handed to
            SAM as-is, so ultralytics performs its own color conversion. A NumPy array
            is handed over unchanged, so it must already be in the BGR channel order
            ultralytics expects for arrays.
        center (list | tuple): (x, y) pixel coordinates of the tongue center, e.g. as
            returned by find_tongue_center().
        background_point (list | tuple): (x, y) pixel assumed not to lie on the tongue.
            Defaults to the top-left corner (0, 0), an arbitrary choice.

    Returns:
        segmented_image (np.ndarray): np.array(image) with every pixel outside the
        mask set to 0.

    Raises:
        ValueError: if center has a negative coordinate, e.g. the [-1, -1] sentinel
            that find_tongue_center() returns when detection fails.
        RuntimeError: if SAM returns no mask for the prompts.
    '''
    # Refuse the failure sentinel rather than prompting SAM with an off-image point
    if min(center) < 0:
        raise ValueError(f"Invalid tongue center {center}; tongue detection probably failed")

    image_np = np.array(image)

    # ultralytics treats NumPy inputs as BGR but PIL inputs as RGB. Feeding the
    # RGB array made from a PIL image would swap the red and blue channels, so give
    # PIL images to the model directly.
    source = image if isinstance(image, Image.Image) else image_np
    results = model(source, points=[list(center), list(background_point)], labels=[1, 0])

    masks = results[0].masks
    if masks is None:
        raise RuntimeError("SAM returned no mask for the given prompts")
    mask = masks.data[0].cpu().numpy()

    binary_mask = mask > 0.5
    # For channelled images add a trailing axis so the mask broadcasts over every
    # channel (RGB, RGBA, ...); 2-D grayscale images use the mask as is.
    if image_np.ndim == 3:
        binary_mask = binary_mask[:, :, np.newaxis]
    segmented_image = image_np * binary_mask

    return segmented_image

In [27]:
# read images into a list of images
# Read every sample image into a list of PIL images
im_lst = []

for i in range(1, 10):
    im = Image.open(f"./Samples/sample_{i}.jpg")
    # Image.open is lazy: it keeps the file open and defers decoding until the pixels
    # are first used. load() decodes now (surfacing corrupt files here) and, for
    # single-frame files opened by path, closes the underlying file handle.
    im.load()
    im_lst.append(im)

In [28]:
# Display the loaded images (one per sample file)
im_lst

[<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x640>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x640>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x640>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x640>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x640>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x640>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x640>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x640>,
 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x640>]