# Shelf AI using Cortex and Meta SegmentAnything
This notebook shows how to use Meta's Segment Anything Model (SAM) to identify individual products in a store-shelf photo and extract each one as a separate image for analysis with Snowflake Cortex.

### What has been done:
* Create image segments via masked images
* Run Cortex Complete on each masked image to identify the product

### What is still incomplete (TODO):
* ~Cropping masked images (S)~
* Better prompting to discard non-product images (e.g., ceiling, shelving) (S)
* ~Map the masks back to the X-Y coordinates on the shelf (M)~
* Associate prices/shelf labels to the products (L)
* Handling multiple facings (L)
* ~Handling products that are broken across multiple images (L)~ Partially handled

### Questions left to answer:
* Could Cortex Complete do all of this directly, without segmenting, masking, and evaluating each product?
* How to associate prices/shelf tags with the corresponding product

## 1. Package Imports and Model download

1. Set up external access integrations for `github.com`, `dl.fbaipublicfiles.com`, and PyPI so the notebook can install packages and download the model. The SQL for this is in the `setup.sql` file that accompanies this notebook.
    
2. Upload a shelf image to the notebook directory using the file browser on the left, then update the paths in the configuration cell below to match.

In [None]:
import os

# ---- Demo configuration: edit these before running ----
# Shelf photo uploaded through the notebook file browser (relative path)
SHELF_IMAGE_PATH = 'images/IMG_0741.JPG'
# Local folder (with trailing slash) where cropped product images are written
MASKED_IMAGES_OUTPUT_PATH = f"{os.getcwd()}/outputs/"
# Snowflake stage that the cropped images are PUT into for Cortex
IMAGE_STAGE_LOCATION = '@notebook_demo_db.shelf_image_ai.image_upload_2'

In [None]:
# Install Segment Anything
# -- SegmentAnything v1.0
!pip install 'git+https://github.com/facebookresearch/segment-anything.git'

!pip install opencv-python

import torch
import torchvision
import sys
import matplotlib
import matplotlib.pyplot as plt
import cv2
import sys
import numpy as np
from PIL import Image
from snowflake.cortex import complete
from snowflake.snowpark.context import get_active_session

session = get_active_session()

In [None]:
# Download the Meta SegmentAnything model - will save to local tmp directory in Container
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
# FUNCTION DEFINITIONS

def show_annotations(anns):
    """Overlay every SAM mask on the current matplotlib axes in a random colour.

    Adapted from the Segment Anything v1.0 tutorial. Masks are painted from
    largest to smallest, so smaller masks painted later stay visible on top.

    Args:
        anns: list of mask dicts, each with a boolean ``'segmentation'`` array
            and an ``'area'`` value. An empty list draws nothing.
    """
    if not anns:
        return

    largest_first = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    axes = plt.gca()
    axes.set_autoscale_on(False)

    height, width = largest_first[0]['segmentation'].shape[:2]
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0  # alpha 0: transparent wherever no mask is painted

    for ann in largest_first:
        rgba = np.concatenate([np.random.random(3), [0.7]])
        overlay[ann['segmentation']] = rgba

    axes.imshow(overlay)

def crop_image_to_relevant(img, mask_array, padding=10):
    """Crop ``img`` to the bounding box of the largest region in ``mask_array``.

    Args:
        img: image array (H x W or H x W x C), typically the already-masked image.
        mask_array: 2-D mask aligned with ``img``; any value > 0 is foreground.
        padding: pixels of context kept on each side of the box, clipped to the
            image bounds independently per edge.

    Returns:
        ``(cropped_image, (x, y, w, h))`` where the tuple is the crop rectangle in
        ``img`` coordinates. If the mask has no foreground pixels, ``img`` itself is
        returned together with a box covering the whole image, so callers can
        always unpack two values.
    """
    img_h, img_w = img.shape[:2]
    full_bbox = (0, 0, img_w, img_h)

    # cv2.findContours needs a single-channel uint8 image.
    binary_mask = (np.asarray(mask_array) > 0).astype(np.uint8) * 255
    if not binary_mask.any():
        # Nothing to crop to: keep the (image, bbox) return shape for callers.
        return img, full_bbox

    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return img, full_bbox

    # Keep only the largest connected region (assumed to be the main product).
    largest_contour = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest_contour)

    # Pad each edge independently and clip to the image, so clamping the
    # left/top edge does not push extra padding onto the right/bottom edge.
    x0 = max(0, x - padding)
    y0 = max(0, y - padding)
    x1 = min(img_w, x + w + padding)
    y1 = min(img_h, y + h + padding)

    return img[y0:y1, x0:x1], (x0, y0, x1 - x0, y1 - y0)

def _largest_contour_bbox(segmentation):
    """Return the (x, y, w, h) box of the largest external contour of a mask, or None if it has none."""
    contours, _ = cv2.findContours((segmentation > 0).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        return None
    return cv2.boundingRect(max(contours, key=cv2.contourArea))


def merge_nearby_masks(masks, distance_threshold=50):
    """Merge masks that are close to each other to capture complete products.

    Greedy, single pass: each not-yet-used mask acts as an anchor and absorbs
    every later unused mask whose bounding-box centre is within
    ``distance_threshold`` pixels of the anchor's centre AND whose vertical
    centre offset is less than half the anchor's box height. Candidates are
    compared only to the anchor, not to each other (merging is not transitive).

    Masks with no contour (empty segmentation) are dropped from the output.

    Args:
        masks: SAM mask dicts with ``'segmentation'``, ``'area'``,
            ``'predicted_iou'``, ``'stability_score'``, ``'point_coords'`` and
            ``'crop_box'`` keys.
        distance_threshold: maximum centre-to-centre distance in pixels.

    Returns:
        List of mask dicts. Unmerged masks are returned as-is. Merged masks get
        the union segmentation, its pixel area, its bounding box, the mean
        predicted IoU / stability score, and the anchor's point/crop metadata.
    """
    # Compute each mask's bounding box and centre once, instead of re-running
    # findContours for every (anchor, candidate) pair.
    bboxes = [_largest_contour_bbox(m['segmentation']) for m in masks]
    centers = [
        None if b is None else (b[0] + b[2] // 2, b[1] + b[3] // 2)
        for b in bboxes
    ]

    merged_masks = []
    used_indices = set()

    for i, mask1 in enumerate(masks):
        if i in used_indices or bboxes[i] is None:
            continue

        h1 = bboxes[i][3]
        cx1, cy1 = centers[i]

        indices_to_merge = [i]
        for j in range(i + 1, len(masks)):
            if j in used_indices or centers[j] is None:
                continue
            cx2, cy2 = centers[j]
            distance = np.hypot(cx1 - cx2, cy1 - cy2)
            # Close proximity and similar vertical position (same shelf row)
            if distance < distance_threshold and abs(cy1 - cy2) < h1 // 2:
                indices_to_merge.append(j)

        used_indices.update(indices_to_merge)

        if len(indices_to_merge) == 1:
            merged_masks.append(mask1)
            continue

        group = [masks[k] for k in indices_to_merge]
        combined_mask = np.logical_or.reduce([m['segmentation'] for m in group])

        merged_masks.append({
            'segmentation': combined_mask,
            # Pixel count of the union; summing input areas double-counts overlaps.
            'area': int(np.count_nonzero(combined_mask)),
            'bbox': cv2.boundingRect(combined_mask.astype(np.uint8)),
            'predicted_iou': float(np.mean([m['predicted_iou'] for m in group])),
            'point_coords': group[0]['point_coords'],
            'stability_score': float(np.mean([m['stability_score'] for m in group])),
            'crop_box': group[0]['crop_box'],
        })

    return merged_masks

def get_shelf_position(bbox, image_height, num_shelves=3):
    """Assign a bounding box to a horizontal shelf band and report its position.

    The image is split into ``num_shelves`` bands of height
    ``image_height // num_shelves``. The box's vertical centre selects the band,
    and any remainder rows at the bottom belong to the last band.

    Args:
        bbox: ``(x, y, w, h)`` in pixels.
        image_height: height of the full shelf image in pixels.
        num_shelves: number of shelf bands; must be >= 1.

    Returns:
        dict with ``shelf_level`` ("top" for the first band, "bottom" for the
        last, "middle" otherwise), 1-based ``shelf_number``, ``x_position`` /
        ``y_position`` (box top-left), ``center_x`` / ``center_y``, and
        ``relative_position`` (equal to ``center_x``; used for left-to-right order).

    Raises:
        ValueError: if ``num_shelves`` < 1.
    """
    if num_shelves < 1:
        raise ValueError(f"num_shelves must be >= 1, got {num_shelves}")

    x, y, w, h = bbox
    center_x = x + w // 2
    center_y = y + h // 2

    shelf_height = image_height // num_shelves
    if shelf_height > 0:
        shelf_index = min(max(center_y // shelf_height, 0), num_shelves - 1)
    else:
        # Image shorter than the number of bands: everything lands in the last band.
        shelf_index = num_shelves - 1

    if shelf_index == 0:
        shelf_level = "top"
    elif shelf_index == num_shelves - 1:
        shelf_level = "bottom"
    else:
        shelf_level = "middle"

    return {
        'shelf_level': shelf_level,
        'shelf_number': shelf_index + 1,
        'x_position': x,
        'y_position': y,
        'center_x': center_x,
        'center_y': center_y,
        'relative_position': center_x,
    }

## 2. Segment the Image and generate masks

In [None]:
# Load the shelf photo. cv2.imread returns None (rather than raising) when the
# file is missing or unreadable, so fail here with a clear message.
image = cv2.imread(SHELF_IMAGE_PATH)
if image is None:
    raise FileNotFoundError(f"Could not read shelf image at {SHELF_IMAGE_PATH!r}")
# OpenCV loads BGR; the mask generator and matplotlib cells below are fed RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

In [None]:
# Generate sub-image masks automatically

################################################
# SEGMENT ANYTHING 1.0
################################################
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam_checkpoint = "sam_vit_h_4b8939.pth"  # file fetched by the wget cell above
model_type = "vit_h"                     # must match the checkpoint's architecture
# Use the GPU when the container has one; otherwise fall back to CPU instead of
# crashing in sam.to("cuda"). CPU inference with vit_h is slow but works.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Mask generator settings tuned for product segmentation
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,                # 32x32 prompt-point grid (same as the SAM default)
    pred_iou_thresh=0.9,               # higher threshold for better-quality masks
    stability_score_thresh=0.98,       # keep only very stable masks
    crop_n_layers=1,                   # also run on one layer of image crops
    crop_n_points_downscale_factor=2,  # fewer prompt points on the crop layer
    min_mask_region_area=1000,         # post-process away small disconnected regions/holes
)

# Generate initial masks
initial_masks = mask_generator.generate(image)
print(f"Generated {len(initial_masks)} initial masks")


In [None]:
# Filter small masks that are likely image artifacts and merge masks that are close by together
# TODO - Improve merging process
# TODO - Improve small mask filtering to keep price tags

# Keep only masks whose area lies strictly between 1000 px and 30% of the frame:
# smaller ones are likely noise, larger ones likely background.
frame_area = image.shape[0] * image.shape[1]
filtered_masks = [
    candidate
    for candidate in initial_masks
    if 1000 < candidate['area'] < frame_area * 0.3
]
print(f"After filtering: {len(filtered_masks)} masks")

# Combine neighbouring fragments so each product is (ideally) one mask
masks = merge_nearby_masks(filtered_masks, distance_threshold=80)
print(f"After merging: {len(masks)} final masks")

## Review the different sets of masks

In [None]:
# Review the Original Image
# Display the unannotated RGB shelf image loaded earlier, for reference.
plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()

In [None]:
# Review image with all masks
# Every raw SAM mask (before filtering/merging) overlaid on the shelf image.
plt.figure(figsize=(20,20))
plt.imshow(image)
# show_annotations draws on plt.gca(), i.e. the axes created by plt.imshow above.
show_annotations(initial_masks)
plt.axis('off')
plt.show()

In [None]:
# Review image with filtered & optimized masks
# Masks after area filtering and nearby-mask merging, overlaid on the shelf image.
plt.figure(figsize=(20,20))
plt.imshow(image)
# show_annotations draws on plt.gca(), i.e. the axes created by plt.imshow above.
show_annotations(masks)
plt.axis('off')
plt.show()

In [None]:
# Review one of the masked images with proper cropping
PREVIEW_MASK_INDEX = 37  # mask to preview; clamped to the last available mask below

if masks:
    preview_idx = min(PREVIEW_MASK_INDEX, len(masks) - 1)
    print(f"Previewing mask {preview_idx} of {len(masks)}")

    mymask = masks[preview_idx]['segmentation']
    # Zero every pixel outside the mask (2-D mask broadcast over the RGB channels).
    masked_image = image * mymask[:, :, np.newaxis]

    # Use the improved cropping function
    cropped_image, bbox = crop_image_to_relevant(masked_image, mymask)

    plt.figure(figsize=(15,10))
    plt.subplot(1,2,1)
    plt.imshow(masked_image)
    plt.title('Original Masked Image')
    plt.axis('off')

    plt.subplot(1,2,2)
    plt.imshow(cropped_image)
    plt.title('Cropped Image')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    print(f"Original dimensions: {masked_image.shape}")
    print(f"Cropped dimensions: {cropped_image.shape}")
    print(f"Bounding box: {bbox}")
else:
    print("No masks available for preview")

## 3. Use Cortex multi-modal to extract information about the product
Cortex Complete's multimodal (image) input is currently only supported in SQL. We therefore first save every cropped product image to the local notebook folder. A PUT command then uploads them to a Snowflake stage, where `TO_FILE` can reference them.

In [None]:
# Save all masks as cropped files with position information
import json

counter = 0  # number of product images successfully written (also the next file index)
file_path = MASKED_IMAGES_OUTPUT_PATH
mask_metadata = []  # Store metadata for each mask

# Ensure output directory exists
os.makedirs(file_path, exist_ok=True)

for mask_idx, mask in enumerate(masks):
    mask_array = mask['segmentation']
    # Zero every pixel outside the mask (2-D mask broadcast over the RGB channels).
    masked_image = image * mask_array[:, :, np.newaxis]

    try:
        # Crop the image to remove black borders
        cropped_image, bbox = crop_image_to_relevant(masked_image, mask_array)

        # Get shelf position information
        position_info = get_shelf_position(bbox, image.shape[0])

        # Save the cropped image (OpenCV writes BGR). imwrite reports failure by
        # returning False instead of raising, so check it explicitly.
        filename = f'product_{counter:03d}.jpg'
        if not cv2.imwrite(file_path + filename, cv2.cvtColor(cropped_image, cv2.COLOR_RGB2BGR)):
            raise OSError(f"cv2.imwrite failed for {file_path + filename}")

        # Store metadata; cast to builtin types so json.dump never sees numpy scalars.
        metadata = {
            'filename': filename,
            'mask_id': counter,
            'bbox': [int(v) for v in bbox],
            'area': int(mask['area']),
            'stability_score': float(mask['stability_score']),
            'predicted_iou': float(mask['predicted_iou']),
            **position_info
        }
        mask_metadata.append(metadata)

        counter += 1

    except Exception as e:
        # Best-effort: report which input mask failed and keep going.
        print(f"Error processing mask {mask_idx}: {e}")
        continue

# Save metadata to JSON file
with open(file_path + 'mask_metadata.json', 'w') as f:
    json.dump(mask_metadata, f, indent=2)

print(f"Saved {counter} cropped product images")
print(f"Metadata saved to {file_path}mask_metadata.json")

In [None]:
-- Create table to store product analysis results
-- NOTE: CREATE OR REPLACE drops the table, so results from earlier runs are lost.
USE DATABASE NOTEBOOK_DEMO_DB;
USE SCHEMA SHELF_IMAGE_AI;
CREATE OR REPLACE TABLE product_analysis_results (
    id INTEGER AUTOINCREMENT,                -- surrogate key generated on insert
    filename VARCHAR(255),                   -- product_NNN.jpg name in the image stage
    mask_id INTEGER,                         -- mask_id from mask_metadata.json
    shelf_level VARCHAR(50),                 -- shelf band label from get_shelf_position (e.g. 'top')
    shelf_number INTEGER,                    -- 1-based shelf band, top to bottom
    x_position INTEGER,                      -- crop box top-left x (pixels)
    y_position INTEGER,                      -- crop box top-left y (pixels)
    center_x INTEGER,
    center_y INTEGER,
    relative_position INTEGER,               -- same value as center_x; used for left-to-right ordering
    bbox_x INTEGER,                          -- crop bounding box (x, y, width, height) in pixels
    bbox_y INTEGER,
    bbox_width INTEGER,
    bbox_height INTEGER,
    mask_area INTEGER,                       -- SAM mask area in pixels
    stability_score FLOAT,                   -- SAM mask quality scores
    predicted_iou FLOAT,
    brand_name VARCHAR(255),                 -- fields parsed from the Cortex JSON reply ('N/A' when missing)
    subbrand_name VARCHAR(255),
    product_category VARCHAR(255),
    size VARCHAR(255),
    description_of_image TEXT,
    cortex_response VARIANT,                 -- Cortex reply stored as semi-structured JSON
    processed_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP(),
    PRIMARY KEY (id)                         -- NOTE(review): presumably informational only in Snowflake; confirm before relying on uniqueness
);


In [None]:
# Move the cropped product images to the stage. Only regular product_*.jpg files
# are uploaded; mask_metadata.json and any subdirectories in the folder are skipped.
images = sorted(
    name
    for name in os.listdir(MASKED_IMAGES_OUTPUT_PATH)
    if name.startswith('product_')
    and name.lower().endswith('.jpg')
    and os.path.isfile(os.path.join(MASKED_IMAGES_OUTPUT_PATH, name))
)

for img in images:
    session.file.put(
        MASKED_IMAGES_OUTPUT_PATH + img,
        IMAGE_STAGE_LOCATION,
        auto_compress=False,  # keep the plain .jpg names the Cortex batch cell looks up (default is True)
        overwrite=True        # replace same-named files from earlier runs (default is False)
    )

print(f"Uploaded {len(images)} product images to {IMAGE_STAGE_LOCATION}")

In [None]:
# Batch process all images with Cortex Complete and store results in Snowflake

CORTEX_FIELDS = ('brand_name', 'subbrand_name', 'product_category', 'size', 'description_of_image')

# All values are bound as parameters (qmark style), so apostrophes in model
# output (e.g. "Kellogg's") cannot break the statement. INSERT ... SELECT is used
# because Snowflake does not accept PARSE_JSON inside a VALUES clause.
# TRY_PARSE_JSON yields NULL instead of failing the whole row on bad JSON.
INSERT_RESULT_SQL = """
INSERT INTO product_analysis_results (
    filename, mask_id, shelf_level, shelf_number, x_position, y_position,
    center_x, center_y, relative_position, bbox_x, bbox_y, bbox_width, bbox_height,
    mask_area, stability_score, predicted_iou, brand_name, subbrand_name,
    product_category, size, description_of_image, cortex_response
)
SELECT ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, TRY_PARSE_JSON(?)
"""


def _parse_cortex_json(text):
    """Parse a Cortex Complete reply as a JSON object, tolerating ```json fences.

    Returns the parsed dict, or None when the reply is empty or not a JSON object.
    """
    cleaned = (text or '').strip()
    if cleaned.startswith('```'):
        # Drop the opening fence line (``` or ```json) and a trailing ``` if present.
        cleaned = cleaned.split('\n', 1)[1] if '\n' in cleaned else ''
        cleaned = cleaned.rstrip()
        if cleaned.endswith('```'):
            cleaned = cleaned[:-3]
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


def _field(parsed, key):
    """Return parsed[key] as a string, or 'N/A' when it is missing or null."""
    value = parsed.get(key)
    return 'N/A' if value is None else str(value)


# Load the metadata we saved earlier
with open(file_path + 'mask_metadata.json', 'r') as f:
    metadata_list = json.load(f)

# Get list of all image files in the stage
stage_files_query = f"LIST {IMAGE_STAGE_LOCATION}"
stage_files_result = session.sql(stage_files_query).collect()

# Extract just the filenames from the stage listing (a set for fast membership tests)
stage_files = {
    row[0].split('/')[-1]
    for row in stage_files_result
    if row[0].endswith('.jpg') and 'product_' in row[0]
}

print(f"Found {len(stage_files)} product images in stage")
print(f"Metadata available for {len(metadata_list)} products")

# Process each image with Cortex Complete
successful_inserts = 0
failed_inserts = 0

for metadata in metadata_list:
    filename = metadata['filename']

    # Check if file exists in stage
    if filename not in stage_files:
        print(f"Warning: {filename} not found in stage, skipping...")
        failed_inserts += 1
        continue

    try:
        # Cortex Complete query for this specific image. filename is generated by
        # this notebook (product_NNN.jpg), so interpolating it here is safe.
        cortex_query = f"""
        SELECT snowflake.cortex.complete('pixtral-large', 
            'Analyze this product image from a retail shelf. The image has been cropped to focus on a single product.
            Provide detailed information about the product you see.
            If the product is not clearly visible or appears to be incomplete/damaged, respond with N/A for unclear fields.
            Respond ONLY with valid JSON containing these exact fields and DO NOT surround the json with markdown:
            {{
                "brand_name": "brand name of the product",
                "subbrand_name": "sub-brand or product line name if visible", 
                "product_category": "category like beverages, snacks, dairy, etc",
                "size": "package size or volume if visible",
                "description_of_image": "brief description of what you see in the image"
            }}',
            TO_FILE('{IMAGE_STAGE_LOCATION}', '{filename}')
        ) as cortex_response
        """

        # Execute Cortex Complete query
        cortex_result = session.sql(cortex_query).collect()
        cortex_response = cortex_result[0]['CORTEX_RESPONSE']

        parsed_response = _parse_cortex_json(cortex_response)
        if parsed_response is None:
            print(f"Warning: Could not parse JSON response for {filename}")
            parsed_response = {}
            # Keep the raw reply as a JSON string value so it is still stored.
            cortex_json = json.dumps(cortex_response)
        else:
            cortex_json = json.dumps(parsed_response)

        brand_name, subbrand_name, product_category, size, description = (
            _field(parsed_response, key) for key in CORTEX_FIELDS
        )

        bbox = metadata['bbox']
        insert_params = [
            filename, metadata['mask_id'], metadata['shelf_level'],
            metadata['shelf_number'], metadata['x_position'], metadata['y_position'],
            metadata['center_x'], metadata['center_y'], metadata['relative_position'],
            bbox[0], bbox[1], bbox[2], bbox[3],
            metadata['area'], metadata['stability_score'], metadata['predicted_iou'],
            brand_name, subbrand_name, product_category, size,
            description, cortex_json,
        ]

        # Insert into the results table
        session.sql(INSERT_RESULT_SQL, params=insert_params).collect()
        successful_inserts += 1
        print(f"✓ Processed {filename} - {brand_name} ({product_category})")

    except Exception as e:
        print(f"✗ Error processing {filename}: {str(e)}")
        failed_inserts += 1
        continue

print(f"\nBatch processing complete!")
print(f"Successfully processed: {successful_inserts}")
print(f"Failed: {failed_inserts}")
print(f"Total: {len(metadata_list)}")

# Display summary of results. shelf_number must be in GROUP BY to be used in
# ORDER BY; shelf_level is derived from it, so this is still one row per shelf.
summary_query = """
SELECT 
    shelf_level,
    COUNT(*) as product_count,
    COUNT(DISTINCT brand_name) as unique_brands,
    COUNT(DISTINCT product_category) as unique_categories
FROM product_analysis_results 
WHERE brand_name != 'N/A'
GROUP BY shelf_level, shelf_number
ORDER BY shelf_number
"""

print("\n--- Shelf Analysis Summary ---")
summary_results = session.sql(summary_query).collect()
for row in summary_results:
    print(f"{row['SHELF_LEVEL'].title()} Shelf: {row['PRODUCT_COUNT']} products, {row['UNIQUE_BRANDS']} brands, {row['UNIQUE_CATEGORIES']} categories")

In [None]:
# Reload the per-mask metadata written by the save cell and display it
metadata_json_path = file_path + 'mask_metadata.json'
with open(metadata_json_path, 'r') as metadata_file:
    metadata_list = json.load(metadata_file)

metadata_list

In [None]:
# Additional analysis queries and visualizations

# Query 1: Products by shelf position (left to right)
position_query = """
SELECT 
    shelf_level,
    filename,
    brand_name,
    product_category,
    relative_position,
    center_x,
    center_y
FROM product_analysis_results 
WHERE brand_name != 'N/A'
ORDER BY shelf_number, relative_position
"""

print("--- Products by Shelf Position (Left to Right) ---")
position_results = session.sql(position_query).collect()
current_shelf = None
for row in position_results:
    if current_shelf != row['SHELF_LEVEL']:
        current_shelf = row['SHELF_LEVEL']
        print(f"\n{current_shelf.title()} Shelf:")
    print(f"  {row['BRAND_NAME']} - {row['PRODUCT_CATEGORY']} (x: {row['CENTER_X']})")

# Query 2: Brand distribution across shelves.
# shelf_number must appear in GROUP BY to be used in ORDER BY; shelf_level is
# derived from shelf_number, so this still yields one row per brand per shelf.
brand_distribution_query = """
SELECT 
    brand_name,
    shelf_level,
    COUNT(*) as product_count,
    AVG(relative_position) as avg_position
FROM product_analysis_results 
WHERE brand_name != 'N/A'
GROUP BY brand_name, shelf_level, shelf_number
ORDER BY brand_name, shelf_number
"""

print("\n--- Brand Distribution Across Shelves ---")
brand_results = session.sql(brand_distribution_query).collect()
for row in brand_results:
    print(f"{row['BRAND_NAME']} on {row['SHELF_LEVEL']} shelf: {row['PRODUCT_COUNT']} products (avg position: {row['AVG_POSITION']:.0f})")

# Query 3: Create a simple shelf map visualization
import matplotlib.pyplot as plt
import numpy as np

# Get all products with positions
map_query = """
SELECT 
    center_x, center_y, brand_name, product_category, shelf_level
FROM product_analysis_results 
WHERE brand_name != 'N/A'
"""

map_results = session.sql(map_query).collect()

if map_results:
    # Create shelf map visualization
    fig, ax = plt.subplots(figsize=(15, 8))

    # One colour per category; sorted so colours are stable across runs
    # (key=str tolerates NULL categories).
    categories = sorted({row['PRODUCT_CATEGORY'] for row in map_results}, key=str)
    colors = plt.cm.Set3(np.linspace(0, 1, len(categories)))
    category_colors = dict(zip(categories, colors))

    for row in map_results:
        x = row['CENTER_X']
        y = row['CENTER_Y']
        category = row['PRODUCT_CATEGORY']
        brand = row['BRAND_NAME']

        ax.scatter(x, y, c=[category_colors[category]], s=100, alpha=0.7)
        ax.annotate(f"{brand}\n{category}", (x, y), xytext=(5, 5),
                    textcoords='offset points', fontsize=8, ha='left')

    ax.set_xlabel('Horizontal Position (pixels)')
    ax.set_ylabel('Vertical Position (pixels)')
    ax.set_title('Shelf Product Map')
    ax.invert_yaxis()  # Invert Y axis so top of image is at top

    # Legend: one empty scatter handle per category
    legend_elements = [ax.scatter([], [], c=[category_colors[cat]], s=100, label=cat)
                       for cat in categories]
    ax.legend(handles=legend_elements, bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.show()

# Query 4: Export final results for external analysis
export_query = """
SELECT * FROM product_analysis_results 
WHERE brand_name != 'N/A'
ORDER BY shelf_number, relative_position
"""

print(f"\n--- Final Results Summary ---")
final_results = session.sql(export_query).collect()
print(f"Total products successfully analyzed: {len(final_results)}")

# Show sample results
print("\nSample results:")
for i, row in enumerate(final_results[:5]):
    print(f"{i+1}. {row['BRAND_NAME']} - {row['PRODUCT_CATEGORY']} on {row['SHELF_LEVEL']} shelf")

print(f"\nAll results are stored in the 'product_analysis_results' table in Snowflake.")
