In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.


# SAM 3 Agent (Deployment)


This notebook shows an example of how an MLLM can use SAM 3 as a tool, i.e., "SAM 3 Agent", to segment more complex text queries such as "the leftmost child wearing blue vest".

**This version uses:**
- **Local vLLM** for LLM calls (same as original notebook)
- **Deployed SAM3 service** (Modal endpoint) for SAM3 inference instead of loading the model locally
- **Local agent logic** that orchestrates between local vLLM and remote SAM3


## Env Setup


First install `sam3` in your environment using the [installation instructions](https://github.com/facebookresearch/sam3?tab=readme-ov-file#installation) in the repository.

**Note**: Since SAM3 is deployed remotely, you don't need to load the SAM3 model locally. However, you still need the SAM3 package installed for any helper functions.


In [None]:
import torch
# turn on tfloat32 for Ampere GPUs
# https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# use bfloat16 for the entire notebook. If your card doesn't support it, try float16 instead
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()

# inference mode for the whole notebook. Disable if you need gradients
torch.inference_mode().__enter__()


In [None]:
import os
import base64
import requests
from IPython.display import display, Image
from pathlib import Path

# Move the working directory to the parent of the directory the kernel started in
# (presumably the SAM 3 repository root — confirm for your checkout layout).
# The root is resolved only once per kernel: recomputing it from os.getcwd() on every
# run would climb one directory higher each time this cell is re-executed.
if "SAM3_ROOT" not in globals():
    SAM3_ROOT = os.path.dirname(os.getcwd())
os.chdir(SAM3_ROOT)

# setup GPU to use - needed for vLLM if running locally
# (child processes started from this notebook inherit this environment variable)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
_ = os.system("nvidia-smi")


## SAM3 Deployment Configuration

Configure the SAM3 deployment endpoint URL. This should point to your deployed SAM3 service (e.g., Modal deployment).

**Note**: The SAM3 model is loaded and running on the deployment server, not locally. We'll use the `/sam3/infer` endpoint for pure SAM3 inference.


In [None]:
# SAM3 deployment endpoint URL for pure inference (no LLM/agent).
# Set the SAM3_DEPLOYMENT_URL environment variable to point at your own Modal deployment
# (e.g., https://your-username--sam3-agent-sam3-infer.modal.run), or edit the fallback below.
SAM3_DEPLOYMENT_URL = os.environ.get(
    "SAM3_DEPLOYMENT_URL",
    "https://srinjoy59--sam3-agent-sam3-infer.modal.run",
)

# Echo the configured endpoint. This only prints the URL; it does not contact the service.
print(f"SAM3 Deployment URL: {SAM3_DEPLOYMENT_URL}")


## LLM Setup

Configure which MLLM to use: it can be either a model served by vLLM that you launch on your own machine, or a model served via an external API. If you want to use a vLLM model, instructions are provided below.


In [None]:
# Registry of selectable MLLM backends, keyed by a short name.
LLM_CONFIGS = {
    # vLLM-served models
    "qwen3_vl_8b_thinking": {
        "provider": "vllm",
        "model": "Qwen/Qwen3-VL-8B-Thinking",
    },
    # models served via external APIs
    # add your own (non-vLLM entries must provide a "base_url" key, which is read below)
}

model = "qwen3_vl_8b_thinking"
# Read the key from the environment so real credentials never need to be written into the
# notebook; the placeholder default is what the original notebook used.
LLM_API_KEY = os.environ.get("LLM_API_KEY", "DUMMY_API_KEY")

# Build the active config as a copy so the registry entry is not mutated with the API key.
llm_config = {**LLM_CONFIGS[model], "api_key": LLM_API_KEY, "name": model}

# setup API endpoint
if llm_config["provider"] == "vllm":
    LLM_SERVER_URL = "http://0.0.0.0:8001/v1"  # replace this with your vLLM server address as needed
else:
    # Raises KeyError if the selected entry does not define "base_url".
    LLM_SERVER_URL = llm_config["base_url"]


### Setup vLLM server 
This step is only required if you are using a model served by vLLM. Skip it if you are calling the LLM through an external API such as Gemini or GPT.

* Install vLLM (in a separate conda env from SAM 3 to avoid dependency conflicts).
  ```bash
    conda create -n vllm python=3.12
    conda activate vllm
    pip install vllm --extra-index-url https://download.pytorch.org/whl/cu128
  ```
* Start the vLLM server on the same machine as this notebook
  ```bash
    # qwen 3 VL 8B thinking
    vllm serve Qwen/Qwen3-VL-8B-Thinking --tensor-parallel-size 4 --allowed-local-media-path / --enforce-eager --port 8001
  ```

**Note**: Since we're running the agent locally, the vLLM server only needs to be accessible from this notebook (localhost), not from Modal.


In [None]:
# Start vLLM server in the background
# Make sure MODEL_ID matches the model in LLM_CONFIGS above
MODEL_ID = llm_config["model"]  # e.g., "Qwen/Qwen3-VL-8B-Thinking"

# Check if vLLM server is already running (pgrep exits 0 and prints PIDs when it finds a match)
import subprocess
check_result = subprocess.run(
    ["pgrep", "-f", "vllm serve"],
    capture_output=True,
    text=True,
    check=False
)

if check_result.returncode == 0 and check_result.stdout.strip():
    print("⚠ vLLM server appears to be already running.")
    print("   If you want to restart, stop it first or restart the kernel.")
else:
    # The command to run, broken into a list
    command = [
        "nohup",
        "vllm", "serve",
        MODEL_ID,
        "--trust-remote-code",
        "--dtype", "bfloat16",
        # "--max-model-len", "65536",  # Uncomment if needed
        "--gpu-memory-utilization", "0.9",
        "--port", "8001",  # Match the port in LLM_SERVER_URL
        # Binds to all interfaces. The agent in this notebook runs locally, so restrict this
        # to 127.0.0.1 if the server should not be reachable from other machines.
        "--host", "0.0.0.0",
    ]

    # Redirect the server's stdout and stderr to a log file. The child keeps its own copy of
    # the descriptor, so the parent's handle is closed as soon as the process is spawned.
    with open('vllm.log', 'w') as vllm_log:
        process = subprocess.Popen(
            command,
            stdout=vllm_log,
            stderr=subprocess.STDOUT,
            # Detach the server into its own session/process group. This replaces
            # preexec_fn=os.setpgrp, which is unsafe in multi-threaded processes.
            start_new_session=True,
        )

    print(f"✓ vLLM server started in the background with PID: {process.pid}")
    print("  Logs are being written to vllm.log")
    print(f"  Model: {MODEL_ID}")
    print(f"  Server will be available at: {LLM_SERVER_URL}")


In [None]:
# Wait for vLLM server to become ready
import time
from pathlib import Path

# Settings for the readiness monitor below.
TIMEOUT = 20 * 60  # seconds (20 minutes) to wait before giving up
LOG_TAIL_LINES = 50  # how many trailing log lines to inspect / print
LOG_PATH = Path("vllm.log")  # log file the launch cell redirects server output to
CHECK_CMD = ["pgrep", "-f", "vllm serve"]  # lists PIDs whose command line matches

def get_vllm_pids():
    """Run CHECK_CMD and return the integer PIDs it prints.

    Returns an empty list when the command exits non-zero or prints no
    numeric tokens.
    """
    proc = subprocess.run(
        CHECK_CMD,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    if proc.returncode == 0:
        # Whitespace-splitting already discards blank output and surrounding padding.
        return [int(token) for token in proc.stdout.split() if token.isdigit()]
    return []

def log_tail(path: Path, n: int = 50) -> str:
    """Return the last ``n`` lines of a text file as a single string.

    Args:
        path: File to read.
        n: Maximum number of trailing lines to return. Values <= 0 yield ''.

    Returns:
        The trailing lines joined together (line endings preserved), or ''
        if the file does not exist or ``n`` is not positive.
    """
    from collections import deque

    # lines[-0:] would be the whole file, so handle non-positive n explicitly.
    if n <= 0:
        return ""
    try:
        # errors="replace": stray undecodable bytes in a log must not abort the caller.
        with path.open("r", errors="replace") as f:
            # A bounded deque keeps only the last n lines while streaming the file.
            return "".join(deque(f, maxlen=n))
    except FileNotFoundError:
        # The file may be missing, or vanish between a check and the open.
        return ""

# Wait a bit before we start checking, so a freshly launched server can create its log.
time.sleep(5)

POLL_INTERVAL = 10  # seconds between readiness checks

# The failure text is matched in both spellings: the original check used only the British
# form, which presumably never matches vLLM's own (US-spelled) message.
ENGINE_FAILURE_MARKERS = (
    "Engine core initialization failed",
    "Engine core initialisation failed",
)


def _print_log_excerpt(excerpt: str) -> None:
    """Print a log excerpt framed by separator lines."""
    print("\nLast log lines:\n" + "-" * 60)
    print(excerpt)
    print("-" * 60)


def _llm_endpoint_responds(base_url: str) -> bool:
    """Return True if GET {base_url}/models answers with a success status."""
    try:
        return requests.get(f"{base_url}/models", timeout=2).ok
    except requests.RequestException:
        return False


start = time.time()
print("Waiting for vLLM server to become ready...")

while True:
    diff = time.time() - start
    print(f"Time since start: {diff:.1f} seconds")

    if diff > TIMEOUT:
        print(f"⚠ {TIMEOUT / 60:.0f} minutes passed, vLLM server not ready. Exiting monitor loop.")
        break

    pids = get_vllm_pids()

    if not pids:
        # No process at all -> either never started or crashed.
        print("❌ vLLM server process not found (stopped / failed to start).")
        tail = log_tail(LOG_PATH, LOG_TAIL_LINES)
        if tail:
            _print_log_excerpt(tail)
        break

    # Process exists; now look for startup completion. log_tail returns '' when the log
    # file is missing, so no separate existence check is needed.
    content_tail = log_tail(LOG_PATH, LOG_TAIL_LINES)

    # The HTTP probe covers a server that was already running: its startup line may have
    # scrolled out of the inspected log tail (or the log may belong to another launch).
    if "Application startup complete." in content_tail or _llm_endpoint_responds(LLM_SERVER_URL):
        print("✅ vLLM SERVER STARTED and application startup complete.")
        print(f"   Server is ready at: {LLM_SERVER_URL}")
        break

    # Helpful debug: surface error if engine failed, but process is still alive for a bit
    if any(marker in content_tail for marker in ENGINE_FAILURE_MARKERS):
        print("❌ Detected 'Engine core initialization failed' in logs.")
        _print_log_excerpt(content_tail)
        break

    time.sleep(POLL_INTERVAL)


**Note**: If the vLLM server is already running from a previous cell execution, you can skip the startup cells above. Check the logs with `!tail -n 50 vllm.log` if needed.

## Run SAM3 Agent Inference

The agent logic runs locally, using:
- **Local vLLM** for LLM calls
- **Remote SAM3 endpoint** for SAM3 inference


In [None]:
from functools import partial
from IPython.display import display, Image
from sam3.agent.client_llm import send_generate_request as send_generate_request_orig
from sam3.agent.inference import run_single_image_inference

# Remote SAM3 client: posts the image and prompt to the deployed endpoint instead of running
# a local SAM3 processor.
def _sam3_output_paths(output_folder_path: str, image_path: str, text_prompt: str):
    """Create the per-image output directory and return (json_path, png_path)."""
    # "/" would be treated as a path separator, so it is flattened in both name components.
    result_dir = os.path.join(output_folder_path, image_path.replace("/", "-"))
    os.makedirs(result_dir, exist_ok=True)
    file_stem = text_prompt.replace("/", "_")
    return (
        os.path.join(result_dir, f"{file_stem}.json"),
        os.path.join(result_dir, f"{file_stem}.png"),
    )


def _rank_and_filter_predictions(response: dict) -> dict:
    """Order predictions by descending score and drop masks of length <= 4, in place."""
    masks = response["pred_masks"]
    boxes = response.get("pred_boxes")
    scores = response.get("pred_scores")

    if scores:
        # Stable sort: predictions with equal scores keep their original relative order.
        order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    else:
        order = range(len(masks))

    # Entries with len(rle) <= 4 are treated as invalid RLE masks and removed together
    # with their box and score.
    kept = [i for i in order if len(masks[i]) > 4]
    response["pred_masks"] = [masks[i] for i in kept]
    response["pred_boxes"] = [boxes[i] for i in kept]
    response["pred_scores"] = [scores[i] for i in kept]
    return response


def call_sam_service_remote(
    sam3_processor,  # Not used, but kept for interface compatibility
    image_path: str,
    text_prompt: str,
    output_folder_path: str = "sam3_output",
    deployment_url: str = SAM3_DEPLOYMENT_URL,
    request_timeout: float = 600,
):
    """
    Remote version of call_sam_service that calls the deployed SAM3 endpoint.

    Posts the base64-encoded image and the text prompt to ``deployment_url``,
    post-processes the returned predictions, and writes both a JSON file and a
    rendered visualization under ``output_folder_path``.

    Args:
        sam3_processor: Ignored; present only so the signature matches the local variant.
        image_path: Path of the image file to segment.
        text_prompt: Text prompt sent to SAM3; also used (with "/" replaced) as the
            output file stem.
        output_folder_path: Root folder for the saved JSON and PNG outputs.
        deployment_url: URL of the SAM3 inference endpoint.
        request_timeout: Timeout in seconds for the HTTP request.

    Returns:
        Path of the saved JSON file.

    Raises:
        RuntimeError: If the endpoint's JSON reports a status other than "success".
        Exception: Any other failure (file access, HTTP error, missing response
            keys, ...) is printed and re-raised unchanged.
    """
    import json
    from sam3.agent.client_sam3 import remove_overlapping_masks, visualize

    print(f"📞 Loading image '{image_path}' and sending with prompt '{text_prompt}' to remote SAM3 endpoint...")

    output_json_path, output_image_path = _sam3_output_paths(
        output_folder_path, image_path, text_prompt
    )

    try:
        # Encode image to base64
        with open(image_path, "rb") as f:
            image_b64 = base64.b64encode(f.read()).decode("utf-8")

        # Call remote SAM3 endpoint
        response = requests.post(
            deployment_url,
            json={"text_prompt": text_prompt, "image_b64": image_b64},
            timeout=request_timeout,
        )
        response.raise_for_status()
        result = response.json()

        if result.get("status") != "success":
            raise RuntimeError(f"SAM3 endpoint error: {result.get('message', 'Unknown error')}")

        # Keep only the fields the downstream helpers use, in a fixed key order.
        prediction_keys = ("orig_img_h", "orig_img_w", "pred_boxes", "pred_masks", "pred_scores")
        serialized_response = remove_overlapping_masks(
            {key: result[key] for key in prediction_keys}
        )
        serialized_response = {
            "original_image_path": image_path,
            "output_image_path": output_image_path,
            **serialized_response,
        }
        serialized_response = _rank_and_filter_predictions(serialized_response)

        # Save JSON
        with open(output_json_path, "w") as f:
            json.dump(serialized_response, f, indent=4)
        print(f"✅ Raw JSON response saved to '{output_json_path}'")

        # Render and save visualization (its directory was created together with the JSON path)
        print("🔍 Rendering visualizations on the image ...")
        viz_image = visualize(serialized_response)
        viz_image.save(output_image_path)
        print("✅ Saved visualization at:", output_image_path)

    except Exception as e:
        print(f"❌ Error calling remote SAM3 service: {e}")
        raise

    return output_json_path


In [None]:
# Inputs for a single-image agent run: the query and the absolute path of the test image.
prompt = "the leftmost child wearing blue vest"
image = os.path.abspath("assets/images/test_image.jpg")

# Bind the LLM connection settings so the agent can call the request function without them.
llm_request_kwargs = {
    "server_url": LLM_SERVER_URL,
    "model": llm_config["model"],
    "api_key": llm_config["api_key"],
}
send_generate_request = partial(send_generate_request_orig, **llm_request_kwargs)

# Bind the remote SAM3 client in place of a local processor.
# sam3_processor is unused by the remote client and is passed only to satisfy the interface.
call_sam_service = partial(
    call_sam_service_remote,
    sam3_processor=None,
    deployment_url=SAM3_DEPLOYMENT_URL,
)

# Summarize the run configuration before starting.
print(
    f"Image: {image}",
    f"Prompt: {prompt}",
    f"LLM Server: {LLM_SERVER_URL}",
    f"SAM3 Endpoint: {SAM3_DEPLOYMENT_URL}",
    "\nStarting agent inference...",
    sep="\n",
)


In [None]:
# Run the agent with the local vLLM request function and the remote SAM3 client.
agent_run_options = {"debug": True, "output_dir": "agent_output"}
output_image_path = run_single_image_inference(
    image,
    prompt,
    llm_config,
    send_generate_request,
    call_sam_service,
    **agent_run_options,
)

# Display the output image when the run returned a path; nothing is shown for None.
if output_image_path is not None:
    display(Image(filename=output_image_path))
