In [1]:
import numpy as np
from PIL import Image
from tqdm import tqdm
from datasets import load_dataset

from ultralytics import SAM
import os

In [None]:
# Hugging Face Hub identifier of the dataset to segment.
DATASET_ID = "e1010101/tongue-images-384"

dataset = load_dataset(DATASET_ID)

# Last expression of the cell: the notebook renders the loaded object's repr.
dataset

In [None]:
# SAM checkpoint to load. The list of currently available checkpoints is
# published at https://docs.ultralytics.com/models/sam-2
SAM_WEIGHTS = "sam2.1_l.pt"

model = SAM(SAM_WEIGHTS)

# Last expression of the cell: show the model summary.
model.info()

In [4]:
def segment(image, points=None, labels=None):
    """Segment ``image`` with the notebook-level SAM ``model`` and zero out
    every pixel outside the first predicted mask.

    Parameters
    ----------
    image : PIL.Image.Image or array-like
        Anything ``np.array`` can convert. Both single-channel (H, W) and
        multi-channel (H, W, C) images are supported.
    points : list, optional
        Point prompts forwarded unchanged to the model. Defaults to
        ``[[350, 320], [0, 0]]``, the prompts this notebook has always used
        (presumably tuned for the 384-pixel dataset images -- TODO confirm).
    labels : list, optional
        Labels matching ``points``, forwarded unchanged to the model.
        Defaults to ``[1, 0]``.

    Returns
    -------
    numpy.ndarray
        Array with the same shape and dtype as ``np.array(image)`` in which
        pixels outside the mask are 0.

    Raises
    ------
    ValueError
        If the model returns no mask for the image.
    """
    image_np = np.array(image)

    # Defaults are created per call so no mutable default argument is shared.
    if points is None:
        points = [[350, 320], [0, 0]]
    if labels is None:
        labels = [1, 0]

    # NOTE(review): the Ultralytics SAM docs describe the nested form
    # points=[[[x1, y1], [x2, y2]]], labels=[[1, 0]] for several points that
    # belong to ONE object; the flat form used by default here may instead be
    # read as one prompt per object, in which case only the mask at index 0 is
    # used below. Confirm against the installed ultralytics version.
    results = model(image_np, points=points, labels=labels)

    # Fail with a clear message instead of an AttributeError on None.
    masks = results[0].masks
    if masks is None or len(masks.data) == 0:
        raise ValueError("SAM returned no mask for the given image and prompts")

    # Only the first predicted mask is kept.
    mask = masks.data[0].cpu().numpy()
    binary_mask = mask > 0.5

    # Broadcast the 2-D mask over the channel axis, whatever its size
    # (RGB, RGBA, ...); a 2-D grayscale image is multiplied directly.
    if image_np.ndim == 3:
        binary_mask = binary_mask[:, :, np.newaxis]

    return image_np * binary_mask

In [None]:
# Inspect the 'train' split; as the only expression in the cell, its repr is
# what the notebook displays.
dataset['train']

In [None]:
# Dataset splits to export; each one gets its own sub-directory of "output".
splits = ['train', 'validation', 'test']

for split in splits:
    split_dir = os.path.join("output", split)
    os.makedirs(split_dir, exist_ok=True)

    ds = dataset[split]
    progress = tqdm(ds, total=len(ds), desc=f"Processing {split}")

    for idx, item in enumerate(progress):
        # Mask the image, then write it as a PNG named after its position
        # in the split.
        masked = segment(item['image'])
        output_path = os.path.join(split_dir, f"image_{idx}.png")
        Image.fromarray(masked.astype(np.uint8)).save(output_path)

print("Saved all segmented images!")

After the run finishes, manually browse the images saved under `output/` and remove any that are poorly segmented.