In [1]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from datasets import load_dataset

from ultralytics import SAM
from skimage import io
import os
import csv

In [12]:
dataset = load_dataset("e1010101/tongue-images-384")
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'labels', 'pixel_values'],
        num_rows: 746
    })
    validation: Dataset({
        features: ['image', 'labels', 'pixel_values'],
        num_rows: 214
    })
    test: Dataset({
        features: ['image', 'labels', 'pixel_values'],
        num_rows: 106
    })
})

In [14]:
type(dataset['train'][0]['image'])

PIL.JpegImagePlugin.JpegImageFile

In [3]:
# Please check the documentation at https://docs.ultralytics.com/models/sam-2
# to get the latest models
model = SAM("sam2.1_b.pt")
model.info()

Model summary: 566 layers, 80,850,178 parameters, 80,850,178 gradients


(566, 80850178, 80850178, 0.0)

In [17]:
def segment(examples):
    image = examples['image']    
    image_np = np.array(image)
    
    # Segmentation
    results_pil = model(image_np, points=[[350, 320], [0, 0]], labels=[1, 0])
    
    # Get the mask from the results
    mask_pil = results_pil[0].masks.data[0].cpu().numpy()
    
    # Masking
    binary_mask = mask_pil > 0.5
    rgb_mask = np.repeat(binary_mask[:, :, np.newaxis], 3, axis=2)
    segmented_image = image_np * rgb_mask
    
    return { 'image': segmented_image }

In [None]:
dataset = dataset.map(segment)

Map:   0%|          | 0/746 [00:00<?, ? examples/s]


0: 1024x1024 1 0, 1 1, 1387.5ms
Speed: 34.9ms preprocess, 1387.5ms inference, 2.9ms postprocess per image at shape (1, 3, 1024, 1024)

0: 1024x1024 1 0, 1 1, 1209.0ms
Speed: 10.5ms preprocess, 1209.0ms inference, 0.9ms postprocess per image at shape (1, 3, 1024, 1024)

0: 1024x1024 1 0, 1 1, 1209.0ms
Speed: 10.6ms preprocess, 1209.0ms inference, 0.0ms postprocess per image at shape (1, 3, 1024, 1024)

0: 1024x1024 1 0, 1 1, 1204.0ms
Speed: 11.8ms preprocess, 1204.0ms inference, 2.5ms postprocess per image at shape (1, 3, 1024, 1024)

0: 1024x1024 1 0, 1 1, 1197.9ms
Speed: 9.3ms preprocess, 1197.9ms inference, 0.0ms postprocess per image at shape (1, 3, 1024, 1024)

0: 1024x1024 1 0, 1 1, 1201.0ms
Speed: 8.5ms preprocess, 1201.0ms inference, 0.0ms postprocess per image at shape (1, 3, 1024, 1024)

0: 1024x1024 1 0, 1 1, 1211.5ms
Speed: 8.5ms preprocess, 1211.5ms inference, 0.0ms postprocess per image at shape (1, 3, 1024, 1024)

0: 1024x1024 1 0, 1 1, 1215.2ms
Speed: 15.0ms preprocess,