<a href="https://colab.research.google.com/github/thiensean/Grounded-Segment-Anything/blob/main/groundedSAM_FastAPI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import sys
from fastapi import FastAPI, File, UploadFile, HTTPException, Response
from pydantic import BaseModel
from PIL import Image
import torch
from io import BytesIO

# Add GroundingDINO to the path (update the path as per your directory structure)
sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))

# Importing necessary modules from GroundingDINO and Segment Anything
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict
from segment_anything import build_sam, SamPredictor
from diffusers import StableDiffusionInpaintPipeline
from huggingface_hub import hf_hub_download

# The FastAPI application; the endpoints below are registered on it.
app = FastAPI()

# Target device for every model: the current CUDA device when one is
# available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Build a GroundingDINO model from a config and checkpoint on the Hugging Face Hub.

    Args:
        repo_id: Hub repository containing both the config and the checkpoint.
        filename: Checkpoint file name inside ``repo_id``.
        ckpt_config_filename: SLConfig (``.cfg.py``) file name inside ``repo_id``.
        device: Device (string or ``torch.device``) the model should run on.

    Returns:
        The model in eval mode with its parameters on ``device``.
    """
    import warnings

    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    args = SLConfig.fromfile(cache_config_file)
    args.device = device
    model = build_model(args)

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    # Deserialize on the CPU: load_state_dict copies values into the model's
    # existing parameters, so staging the checkpoint on the GPU only costs memory.
    # NOTE(review): torch.load unpickles the file -- only load trusted repositories.
    checkpoint = torch.load(cache_file, map_location='cpu')
    # strict=False tolerates key mismatches between checkpoint and config; keep
    # the report so a missing-weights mismatch is not silently ignored.
    load_report = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    if load_report.missing_keys:
        warnings.warn(
            f"Checkpoint {cache_file} is missing {len(load_report.missing_keys)} "
            f"model keys: {load_report.missing_keys}",
            stacklevel=2,
        )

    # load_state_dict keeps parameters on whatever device they were created on,
    # so place the model on the requested device explicitly.
    model.to(device)
    model.eval()
    return model

# GroundingDINO (Swin-T, OGC) config and checkpoint on the Hugging Face Hub, as
# used by the upstream Grounded-Segment-Anything demos. Change these to load a
# different GroundingDINO variant.
GROUNDINGDINO_REPO_ID = "ShilongLiu/GroundingDINO"
GROUNDINGDINO_CKPT_FILENAME = "groundingdino_swint_ogc.pth"
GROUNDINGDINO_CONFIG_FILENAME = "GroundingDINO_SwinT_OGC.cfg.py"

# Local path to the SAM ViT-H checkpoint; download it beforehand, e.g. from
# https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"

groundingdino_model = load_model_hf(
    GROUNDINGDINO_REPO_ID,
    GROUNDINGDINO_CKPT_FILENAME,
    GROUNDINGDINO_CONFIG_FILENAME,
    device,
)

# build_sam's parameter is the checkpoint path, not a device: build SAM from
# the checkpoint, then move it to the target device explicitly.
sam = build_sam(checkpoint=SAM_CHECKPOINT)
sam.to(device=device)
sam_predictor = SamPredictor(sam)

# NOTE(review): "CompVis/stable-diffusion-v1-4" is a text-to-image checkpoint,
# not a dedicated inpainting one -- confirm it is the intended model for
# StableDiffusionInpaintPipeline. Newer diffusers releases deprecate
# `use_auth_token` in favour of `token`.
sd_pipe = StableDiffusionInpaintPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
sd_pipe.to(device)

# Request body model for the /generate endpoint (consumed by generate_image).
class GenerateRequest(BaseModel):
    mask_id: int  # presumably one of the IDs returned by /upload -- confirm
    prompt: str  # text prompt for the generation step

# Endpoint for uploading an image and getting mask IDs
@app.post("/upload")
async def upload_image(file: UploadFile = File(...)):
    """Decode an uploaded image and return the IDs of the masks found in it.
    \f
    Detection/segmentation is not implemented yet: the decoded image is not
    used and a fixed placeholder list of mask IDs is returned.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        ``{"mask_ids": [...]}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    image_data = await file.read()
    try:
        # Image.open is lazy; convert() forces a full decode so corrupt or
        # truncated uploads fail here rather than later in the pipeline.
        image = Image.open(BytesIO(image_data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # The bytes are client-supplied: report a client error (400) instead of
        # letting the decode failure surface as an unhandled 500.
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image.") from err

    # Implement preprocessing and detection logic on `image` (placeholders)
    mask_ids = [0, 1, 2]  # Replace with actual logic

    return {"mask_ids": mask_ids}

# Endpoint for generating an image
@app.post("/generate")
def generate_image(request: GenerateRequest):
    # Respond with a PNG-encoded image for the requested mask and prompt.
    # Placeholder: both inputs are read but not used yet, and the response is
    # a blank 512x512 RGB canvas.
    mask_id, prompt = request.mask_id, request.prompt

    # Implement image generation logic (placeholder)
    canvas = Image.new("RGB", (512, 512))  # Replace with actual logic

    # Encode into an in-memory buffer; grab the bytes before the buffer closes.
    with BytesIO() as buffer:
        canvas.save(buffer, format='PNG')
        png_bytes = buffer.getvalue()

    return Response(content=png_bytes, media_type="image/png")
