# **Automatic Mask Generation Using Unsupervised Approach with Florence-2, SAM2, and Gemma3**

In this notebook, we build an end-to-end unsupervised pipeline for object detection, segmentation, classification, and tracking, with a focus on identifying and following milk pouches without manual labels. The pipeline combines pretrained vision and language models and concludes with lightweight object tracking based on features extracted from segmentation masks.

Key Components:



1.   **Florence-2 Multimodal Model**<br>
A vision-language model that performs prompt-driven object detection, returning bounding boxes around the regions that match a text prompt without any labeled training data.
2.   **SAM2 (Segment Anything Model v2)**<br>
Using the bounding boxes from Florence-2, SAM2 generates precise segmentation masks, enabling instance-level understanding and clean extraction of objects.
3.  **Gemma3 12B QAT Model**<br>
Each masked, cropped region is passed to the open-source Gemma3 12B model, released with quantization-aware training (QAT), which decides whether the region shows a milk pouch. This gives classification without explicit supervised training.
4.  **Object Tracking via Mask Features**<br>
For the final step, we extract distinguishing features from the segmented masks of positively identified milk pouches and use them to track the same objects across frames.



While this Colab focuses on the specific requirement of distinguishing milk sachets from other sachet types (such as oil), the general approach can be adapted to other objects or use cases.
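
Step 4 is described only at a high level here. As a rough illustration, and not necessarily the features this project uses, the sketch below derives a centroid and pixel area from each binary mask and greedily matches masks in consecutive frames by nearest centroid. The helper names `mask_features` and `match_by_centroid` and the `max_distance` threshold are hypothetical.

```python
import numpy as np


def mask_features(mask):
    """Return (centroid_x, centroid_y, area) for a non-empty binary mask."""
    ys, xs = np.nonzero(mask)
    return float(xs.mean()), float(ys.mean()), int(xs.size)


def match_by_centroid(prev_feats, curr_feats, max_distance=50.0):
    """Greedily pair each previous object with the nearest unused current one.

    Returns a dict mapping previous index -> current index. Objects farther
    apart than max_distance (pixels, an assumed threshold) stay unmatched.
    """
    matches, used = {}, set()
    for i, (px, py, _) in enumerate(prev_feats):
        best_j, best_d = None, max_distance
        for j, (cx, cy, _) in enumerate(curr_feats):
            d = float(np.hypot(cx - px, cy - py))
            if j not in used and d <= best_d:
                best_j, best_d = j, d
        if best_j is not None:
            matches[i] = best_j
            used.add(best_j)
    return matches


# Toy example: one 4x4 object that moves 3 pixels to the right between frames.
frame_a = np.zeros((20, 20), dtype=bool)
frame_a[5:9, 2:6] = True
frame_b = np.zeros((20, 20), dtype=bool)
frame_b[5:9, 5:9] = True

prev_feats = [mask_features(frame_a)]
curr_feats = [mask_features(frame_b)]
matches = match_by_centroid(prev_feats, curr_feats)
print(matches)
```

Greedy nearest-centroid matching is the simplest option; with many overlapping or fast-moving objects, an assignment solver or appearance features would likely be needed.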

## Install and upgrade the necessary packages.

In [None]:
!sudo apt-get update
!sudo apt-get install -y pciutils lshw
!pip install ollama

# Install the SAM2 (Segment Anything Model v2) library directly from the official Facebook Research GitHub repository
!pip install 'git+https://github.com/facebookresearch/sam2.git'

In [None]:
# Download the sample image from the CircularNet project.
url = (
    "https://raw.githubusercontent.com/tensorflow/models/master/official/"
    "projects/waste_identification_ml/pre_processing/config/sample_images/"
    "IMG_6509.png"
)

!curl -O {url} > /dev/null 2>&1

## Import the libraries and configure resources.

In [None]:
import torch, torchvision
from transformers import AutoProcessor, AutoModelForCausalLM
import sys
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import tqdm
import math
import cv2
import tempfile
from google.colab.patches import cv2_imshow
from ollama import chat
from ollama import ChatResponse

print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())

# select the device for computation
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # use bfloat16 for the entire notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

In [None]:
#@title Utils

def plot_bbox(image, data):
    # Create a figure and axes
    fig, ax = plt.subplots()

    # Display the image
    ax.imshow(image)

    # Plot each bounding box
    for bbox, label in zip(data['bboxes'], data['labels']):
        # Unpack the bounding box coordinates
        x1, y1, x2, y2 = bbox
        # Create a Rectangle patch
        rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
        # Add the rectangle to the Axes
        ax.add_patch(rect)
        # Annotate the label
        ax.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))

    # Remove the axis ticks and labels
    ax.axis('off')

    # Show the plot
    plt.show()


def run_example(task_prompt, text_input=None):
    if text_input is None:
        prompt = task_prompt
    else:
        prompt = task_prompt + text_input
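    # Uses the global `image`, `model`, and `processor`. Inputs are moved to the
    # model's device, and floating-point tensors are cast to the model's dtype.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, model.dtype)
    generated_ids = model.generate(
      input_ids=inputs["input_ids"],
      pixel_values=inputs["pixel_values"],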
      max_new_tokens=1024,
      early_stopping=False,
      do_sample=False,
      num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height)
    )

    return parsed_answer


def show_mask(
        mask,
        ax,
        random_color=False,
        borders = True
):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image =  mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
    ax.imshow(mask_image)


def show_points(
        coords,
        labels,
        ax,
        marker_size=375
):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))


def show_masks(
        image,
        masks,
        scores,
        point_coords=None,
        box_coords=None,
        input_labels=None,
        borders=True
):
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca(), borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())
        if box_coords is not None:
            # boxes
            show_box(box_coords, plt.gca())
        if len(scores) > 1:
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()
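
`show_mask` builds its overlay by broadcasting a binary mask of shape `(h, w, 1)` against an RGBA color of shape `(1, 1, 4)`. The toy example below, with a made-up 2x2 mask, shows that masked pixels receive the color and all other pixels stay fully transparent.

```python
import numpy as np

# Toy 2x2 binary mask and the default RGBA color used by show_mask.
mask = np.array([[0, 1], [1, 0]], dtype=np.uint8)
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

h, w = mask.shape[-2:]
# (h, w, 1) * (1, 1, 4) broadcasts to an (h, w, 4) RGBA overlay.
overlay = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
print(overlay.shape)
print(overlay[0, 0], overlay[0, 1])
```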

## Download Florence-2 model.

In [None]:
model_id = 'microsoft/Florence-2-large'
# Load the model onto the device selected earlier instead of assuming CUDA.
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype='auto').eval().to(device)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

## Download SAM-2 model.

In [None]:
# Create the 'checkpoints' directory in the current working directory if it doesn't already exist
!mkdir -p checkpoints/

# Download the pre-trained SAM2.1 Hiera Large model checkpoint into the 'checkpoints' directory
!wget -P checkpoints/ https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt

# Path to the pre-trained SAM2 model checkpoint
sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"

# Path to the configuration file for the SAM2 model variant being used
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Build the SAM2 model using the config and checkpoint; `device` is the one selected earlier (CUDA, MPS, or CPU)
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

# Create a predictor object using the loaded SAM2 model for image-based mask prediction
sam2_predictor = SAM2ImagePredictor(sam2_model)

Install Ollama and start its server by running the following commands in the xterm terminal, which the two cells after the commands install and open in your Colab notebook.



```
curl https://ollama.ai/install.sh | sh
ollama serve &
```




In [None]:
!pip install colab-xterm
%load_ext colabxterm

In [None]:
%xterm

In [None]:
# Pull the open-source Gemma3 12B QAT model through Ollama.
!ollama pull gemma3:12b-it-qat

In [None]:
# Check if the model is downloaded.
!ollama list

In [None]:
# Prompt asking whether a cropped packaging image shows a milk packet.
prompt = """
Analyze the provided image of a packaging. Was this packaging used to contain milk or a milk-based product?  Answer in yes or no only.
"""

## Read an image and run inference with all models to detect milk pouches.

In [None]:
path = 'IMG_6509.png'
# Convert to RGB so every model receives a 3-channel image.
image = Image.open(path).convert("RGB")

In [None]:
# Detect candidate objects with Florence-2's phrase-grounding task, using "packets." as the text prompt.
task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
results = run_example(task_prompt, text_input="packets.")
plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
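
The parsed result is a dict keyed by the task prompt, holding `bboxes` as `[x1, y1, x2, y2]` lists and their `labels`. The segmentation loop later in this notebook only processes boxes smaller than 25% of the image area, presumably to discard boxes that cover a whole pile rather than a single packet. A minimal sketch of that filter on made-up boxes follows; `keep_small_boxes` is a hypothetical helper, not part of Florence-2.

```python
import math


def keep_small_boxes(bboxes, image_size, max_fraction=0.25):
    """Keep boxes whose area is below max_fraction of the image area."""
    image_area = math.prod(image_size)  # image_size is (width, height)
    kept = []
    for x1, y1, x2, y2 in bboxes:
        if (x2 - x1) * (y2 - y1) < max_fraction * image_area:
            kept.append([x1, y1, x2, y2])
    return kept


# Made-up result in the shape returned for this task.
example = {
    "<CAPTION_TO_PHRASE_GROUNDING>": {
        "bboxes": [[10.0, 10.0, 60.0, 60.0], [0.0, 0.0, 200.0, 100.0]],
        "labels": ["packets", "packets"],
    }
}
boxes = example["<CAPTION_TO_PHRASE_GROUNDING>"]["bboxes"]
small = keep_small_boxes(boxes, image_size=(200, 100))
print(small)
```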

In [None]:
# Perform segmentation on the bbox coordinates using the SAM2 model.
sam2_predictor.set_image(image)

# Use bounding boxes to extract mask for each object, then use the mask to
# send object level images to LLM for classification.
for idx, bbox in tqdm.tqdm(enumerate(results['<CAPTION_TO_PHRASE_GROUNDING>']['bboxes'])):
  x1, y1, x2, y2 = list(map(round, bbox))
  # Only process boxes that cover less than 25% of the image area.
  if (x2-x1)*(y2-y1) < 0.25 * math.prod(image.size):
    input_box = np.array([x1, y1, x2, y2])

    masks, scores, _ = sam2_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=False,
    )
    show_masks(image, masks, scores, box_coords=input_box)

    # Give the first mask a trailing axis so it broadcasts over the image channels.
    # Keep the original pixels where the mask is non-zero and set the background to 0.
    # Crop the masked image to the bounding box [y1:y2, x1:x2].
    masked_object = Image.fromarray(
        np.where(
            np.expand_dims(masks[0]*255, -1),
            np.array(image), 0
        )[y1:y2, x1:x2]
    )

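    # Save the masked object as a temporary PNG file. delete=False keeps the file
    # on disk after the `with` block so it can be passed to Ollama by path.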
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_image:
      masked_object.save(temp_image.name)
      image_path = temp_image.name

    # Run the chat/inference API, sending the temporary masked object image as input
    response: ChatResponse = chat(model='gemma3:12b-it-qat', messages=[
      {
        'role': 'user',
        'content': prompt,
        'images': [image_path]
      },
    ])
    plt.imshow(masked_object)
    plt.axis('off')
    plt.show()

    # Print the model's response content (the generated answer)
    print(f"\n{response.message.content}")
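
The prompt asks for a yes or no answer only, but the reply is still free text and may carry punctuation, capitalization, or extra words. One way to turn it into a decision is sketched below; `is_milk_pouch` is a hypothetical helper, and treating the first standalone "yes" or "no" as the answer is an assumption.

```python
import re


def is_milk_pouch(answer):
    """Map a free-text reply to True (yes), False (no), or None (unclear)."""
    match = re.search(r"\b(yes|no)\b", answer.strip().lower())
    if match is None:
        return None
    return match.group(1) == "yes"


print(is_milk_pouch("Yes."))
print(is_milk_pouch("  no\n"))
print(is_milk_pouch("I cannot tell."))
```

Inside the loop, this could be called as `is_milk_pouch(response.message.content)`, with `None` results set aside for manual review.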