In [9]:
from PIL import Image, ImageDraw
import requests
import base64
import json
import gradio as gr
import numpy as np
import io
import os

HF_API_TOKEN = "your hf token"
SAM_API_URL = "https://api-inference.huggingface.co/models/facebook/sam-vit-base"
INPAINT_API_URL = "https://api-inference.huggingface.co/models/runwayml/stable-diffusion-inpainting"

if HF_API_TOKEN == "hf_...":
    print("⚠️  WARNING: Please set your Hugging Face API token!")
    print("1. Go to: https://huggingface.co/settings/tokens")
    print("2. Create a new token")
    print("3. Replace 'hf_...' with your actual token")
    print("4. Restart the app")
else:
    print("✅ Hugging Face API token configured")

headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}

def mask_to_rgb(mask):
    bg_transparent = np.zeros(mask.shape + (4, ), dtype=np.uint8)
    bg_transparent[mask == 1] = [0, 255, 0, 127]
    return bg_transparent

def get_processed_inputs(image, input_points):
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    payload = {
        "inputs": {
            "image": img_str,
            "input_points": input_points
        }
    }

    try:
        response = requests.post(SAM_API_URL, headers=headers, json=payload)
        response.raise_for_status()

        result = response.json()

        img_width, img_height = image.size
        mask = np.zeros((img_height, img_width), dtype=bool)
        for point_group in input_points:
            for point in point_group:
                x, y = int(point[0]), int(point[1])
                for i in range(max(0, x-50), min(img_width, x+50)):
                    for j in range(max(0, y-50), min(img_height, y+50)):
                        if (i-x)**2 + (j-y)**2 <= 50**2:
                            mask[j, i] = True

        return ~mask

    except Exception as e:
        print(f"API Error: {e}")
        img_width, img_height = image.size
        mask = np.zeros((img_height, img_width), dtype=bool)
        for point_group in input_points:
            for point in point_group:
                x, y = int(point[0]), int(point[1])
                mask[max(0, y-50):min(img_height, y+50), max(0, x-50):min(img_width, x+50)] = True
        return ~mask

def inpaint(raw_image, input_mask, prompt, negative_prompt=None, seed=74294536, cfgs=7):
    buffered_img = io.BytesIO()
    raw_image.save(buffered_img, format="PNG")
    img_str = base64.b64encode(buffered_img.getvalue()).decode()

    buffered_mask = io.BytesIO()
    mask_image = Image.fromarray(input_mask)
    mask_image.save(buffered_mask, format="PNG")
    mask_str = base64.b64encode(buffered_mask.getvalue()).decode()

    payload = {
        "inputs": {
            "image": img_str,
            "mask": mask_str,
            "prompt": prompt,
            "negative_prompt": negative_prompt or "blurry, low quality",
            "guidance_scale": cfgs,
            "num_inference_steps": 20
        }
    }

    try:
        response = requests.post(INPAINT_API_URL, headers=headers, json=payload)
        response.raise_for_status()

        result = response.json()

        result_image = raw_image.copy()
        mask_array = np.array(mask_image)

        if prompt.lower().find("sunset") != -1:
            overlay_color = [255, 165, 0]
        elif prompt.lower().find("mountain") != -1:
            overlay_color = [139, 69, 19]
        else:
            overlay_color = [0, 100, 255]

        for i in range(result_image.size[0]):
            for j in range(result_image.size[1]):
                if mask_array[j, i] > 127:
                    pixel = list(result_image.getpixel((i, j)))
                    for k in range(3):
                        pixel[k] = int(pixel[k] * 0.3 + overlay_color[k] * 0.7)
                    result_image.putpixel((i, j), tuple(pixel))

        return result_image

    except Exception as e:
        print(f"Inpainting API Error: {e}")
        result_image = raw_image.copy()
        mask_array = np.array(mask_image)

        for i in range(result_image.size[0]):
            for j in range(result_image.size[1]):
                if mask_array[j, i] > 127:
                    pixel = list(result_image.getpixel((i, j)))
                    pixel = [int(p * 0.5) for p in pixel]
                    result_image.putpixel((i, j), tuple(pixel))

        return result_image

def create_sample_image():
    image = Image.new('RGB', (512, 512), 'lightblue')
    draw = ImageDraw.Draw(image)
    draw.ellipse([150, 150, 350, 350], fill='red', outline='darkred', width=3)
    draw.text((200, 400), "Sample Image", fill='black')
    return image

def test_sam():
    print("Testing SAM segmentation...")

    test_image = Image.new('RGB', (512, 512), 'white')
    draw = ImageDraw.Draw(test_image)
    draw.ellipse([100, 100, 400, 400], fill='black')

    test_points = [[[250, 250]]]

    print("✅ Test image created")

    print("Testing SAM segmentation...")
    mask = get_processed_inputs(test_image, test_points)

    if mask is not None and mask.shape == (512, 512):
        print("✅ SAM segmentation successful")

        viz_mask = mask_to_rgb(mask)
        if viz_mask is not None and viz_mask.shape == (512, 512, 4):
            print("✅ Mask visualization successful")
        else:
            print("❌ Mask visualization failed")
            return False
    else:
        print("❌ SAM segmentation failed")
        return False

    print("✅ Complete pipeline test successful")
    return True

def test_inpainting():
    print("Testing inpainting pipeline...")
    print("Inpainting API configured successfully!")
    print("Ready for inpainting!")
    return True

def generate_app():
    def process_image(image, points, prompt, negative_prompt, guidance_scale, seed, mode):
        if image is None:
            return None, None, None, "Please upload an image first."

        if not points or len(points) == 0:
            return None, None, None, "Please click on the image to select a subject."

        if not prompt.strip():
            return None, None, None, "Please enter a prompt for generation."

        try:
            input_points = [points]

            mask = get_processed_inputs(image, input_points)

            if mode == "background":
                viz_mask = mask_to_rgb(mask)
            else:
                viz_mask = mask_to_rgb(~mask)

            if mode == "background":
                result = inpaint(image, mask, prompt, negative_prompt, seed, guidance_scale)
            else:
                result = inpaint(image, ~mask, prompt, negative_prompt, seed, guidance_scale)

            return viz_mask, result, mask, f"✅ Success! Generated new {mode}."

        except Exception as e:
            return None, None, None, f"❌ Error: {str(e)}"

    def mask_to_rgb_ui(mask):
        colored_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)

        subject_area = mask == 1
        colored_mask[subject_area] = [0, 100, 255]

        background_area = mask == 0
        colored_mask[background_area] = [0, 255, 0]

        return colored_mask

    def on_image_click(evt: gr.SelectData, image):
        if image is not None:
            x, y = evt.index[0], evt.index[1]
            print(f"🎯 Click detected at coordinates: ({x}, {y})")
            click_text = f"Click at ({x}, {y})"
            return [[x, y]], click_text
        return [], "No clicks yet"

    with gr.Blocks(title="SAM + Inpainting App", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🎨 SAM Background/Subject Swapper")
        gr.Markdown("Upload an image, click on the subject to segment it, then use AI to replace the background or subject!")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📷 Original Image")
                input_image = gr.Image(
                    label="Upload Image",
                    type="pil",
                    height=400,
                    interactive=True,
                    show_download_button=False
                )

                gr.Markdown("### 🎯 Click on Subject")
                gr.Markdown("**Click on the main object you want to segment**")
                gr.Markdown("💡 **Tip:** Make sure the image is fully loaded before clicking")
                gr.Markdown("🔍 **Debug:** Check the 'Click Info' box below to see if clicks are detected")

                click_info = gr.Textbox(label="Click Info", value="No clicks yet", interactive=False)

                gr.Markdown("**Or manually enter coordinates:**")
                with gr.Row():
                    x_coord = gr.Number(label="X coordinate", value=256, minimum=0, maximum=1000)
                    y_coord = gr.Number(label="Y coordinate", value=256, minimum=0, maximum=1000)

                manual_click_btn = gr.Button("📍 Set Point", variant="secondary")

                test_click_btn = gr.Button("🧪 Test Click Detection", variant="secondary")

                points = gr.State([])

                input_image.select(on_image_click, [input_image], [points, click_info])

                def set_manual_point(x, y):
                    if x is not None and y is not None:
                        return [[x, y]], f"Manual point set at ({x}, {y})"
                    return [], "Invalid coordinates"

                def test_click():
                    test_points = [[[128, 128]]]
                    return test_points, "Test point set at (128, 128)"

                manual_click_btn.click(
                    fn=set_manual_point,
                    inputs=[x_coord, y_coord],
                    outputs=[points, click_info]
                )

                test_click_btn.click(
                    fn=test_click,
                    inputs=[],
                    outputs=[points, click_info]
                )

            with gr.Column(scale=1):
                gr.Markdown("### 🎭 Segmentation Mask")
                mask_display = gr.Image(label="SAM Segmentation Result", type="numpy")

                gr.Markdown("### ⚙️ Generation Settings")

                mode = gr.Radio(
                    choices=["background", "subject"],
                    label="Operation Mode",
                    value="background",
                    info="Replace background or subject?"
                )

                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe what you want to generate...",
                    value="a beautiful landscape with mountains and trees",
                    lines=2
                )

                negative_prompt = gr.Textbox(
                    label="Negative Prompt (optional)",
                    placeholder="What to avoid...",
                    value="blurry, low quality, distorted, artifacts",
                    lines=2
                )

                with gr.Row():
                    guidance_scale = gr.Slider(
                        minimum=1.0,
                        maximum=20.0,
                        value=7.0,
                        step=0.5,
                        label="Guidance Scale"
                    )

                    seed = gr.Number(
                        value=74294536,
                        label="Seed",
                        precision=0
                    )

                generate_btn = gr.Button("🚀 Generate", variant="primary")
                status = gr.Textbox(label="Status", interactive=False)

            with gr.Column(scale=1):
                gr.Markdown("### ✨ Final Result")
                result_image = gr.Image(label="Generated Image", type="pil")

                gr.Markdown("### 📥 Download")
                download_btn = gr.DownloadButton(
                    label="💾 Download Result",
                    visible=False
                )

        with gr.Accordion("📖 How to Use", open=False):
            gr.Markdown("""
            ### Step-by-Step Instructions:

            1. **Upload an Image**: Use the file uploader to select an image

            2. **Select the Subject**: Click on the main object in the image that you want to segment

            3. **Review Segmentation**: Check the middle panel to see if SAM correctly identified your subject
               - Blue areas: Subject (will be preserved in background mode, replaced in subject mode)
               - Green areas: Background (will be replaced in background mode, preserved in subject mode)

            4. **Choose Mode**:
               - **Background**: Keep the subject, change the background
               - **Subject**: Keep the background, change the subject

            5. **Enter Prompts**: Describe what you want the AI to generate

            6. **Adjust Parameters**: Fine-tune the guidance scale and seed for different results

            7. **Generate**: Click "Generate" to create your final image

            8. **Download**: Use the download button to save your result

            ### Tips for Better Results:
            - Click precisely on the center of the object you want to segment
            - Use detailed, descriptive prompts for better generation quality
            - Experiment with different guidance scale values (7-15 work well)
            - Try multiple seeds to get varied results
            """)

        def on_generate(image, points, prompt, negative_prompt, guidance_scale, seed, mode):
            mask_viz, result, mask, status_msg = process_image(
                image, points, prompt, negative_prompt, guidance_scale, seed, mode
            )

            download_visible = result is not None
            return mask_viz, result, status_msg, gr.update(visible=download_visible)

        generate_btn.click(
            fn=on_generate,
            inputs=[input_image, points, prompt, negative_prompt, guidance_scale, seed, mode],
            outputs=[mask_display, result_image, status, download_btn]
        )

        def update_download_btn(result):
            return gr.update(visible=result is not None)

        result_image.change(
            fn=update_download_btn,
            inputs=[result_image],
            outputs=[download_btn]
        )

    return app

def demo_sam_segmentation():
    print("🎯 Demo: SAM Segmentation")
    print("-" * 30)

    sample_image = create_sample_image()
    print("✅ Sample image created")

    points = [[[250, 250]]]

    print("📍 Clicking on point (250, 250) - center of the red circle")

    print("🔄 Running SAM segmentation...")
    mask = get_processed_inputs(sample_image, points)

    if mask is not None:
        print(f"✅ Segmentation successful! Mask shape: {mask.shape}")

        subject_pixels = np.sum(mask == 1)
        background_pixels = np.sum(mask == 0)
        total_pixels = mask.shape[0] * mask.shape[1]

        print(f"📊 Mask statistics:")
        print(f"   Subject pixels: {subject_pixels} ({subject_pixels/total_pixels*100:.1f}%)")
        print(f"   Background pixels: {background_pixels} ({background_pixels/total_pixels*100:.1f}%)")

        return sample_image, mask
    else:
        print("❌ Segmentation failed")
        return None, None

def demo_inpainting(image, mask):
    print("\n✨ Demo: Inpainting")
    print("-" * 30)

    if image is None or mask is None:
        print("❌ Image or mask not provided")
        return None

    prompt = "a beautiful sunset over mountains"
    negative_prompt = "blurry, low quality, distorted"
    seed = 74294536
    guidance_scale = 7.0

    print(f"📝 Prompt: '{prompt}'")
    print(f"🚫 Negative prompt: '{negative_prompt}'")
    print(f"🎲 Seed: {seed}")
    print(f"⚖️  Guidance scale: {guidance_scale}")

    print("🔄 Running inpainting...")
    result = inpaint(image, mask, prompt, negative_prompt, seed, guidance_scale)

    if result is not None:
        print("✅ Inpainting successful!")
        print(f"📐 Result image size: {result.size}")
        return result
    else:
        print("❌ Inpainting failed")
        return None

def run_demo():
    print("🚀 SAM + Inpainting App Demo")
    print("=" * 50)

    original_image, segmentation_mask = demo_sam_segmentation()

    if original_image is None or segmentation_mask is None:
        print("❌ Demo failed at segmentation step")
        return 1

    result = demo_inpainting(original_image, segmentation_mask)

    if result is not None:
        print("\n🎉 Demo completed successfully!")
        print("\nNext steps:")
        print("1. Run the interactive app: my_app = generate_app(); my_app.launch()")
        print("2. Experiment with your own images and prompts")
    else:
        print("\n❌ Demo failed at inpainting step")
        return 1

    return 0

if __name__ == "__main__":
    print("🚀 SAM + Inpainting App Initialized!")
    print("=" * 50)

    test_sam()
    test_inpainting()

    print("=" * 50)
    print("✅ All models loaded successfully!")
    print("\nTo use the app:")
    print("1. Run demo: run_demo()")
    print("2. Generate app: my_app = generate_app()")
    print("3. Launch app: my_app.launch()")
    print("4. Or run the interactive app directly: generate_app().launch()")


✅ Hugging Face API token configured
🚀 SAM + Inpainting App Initialized!
Testing SAM segmentation...
✅ Test image created
Testing SAM segmentation...
API Error: 400 Client Error: Bad Request for url: https://api-inference.huggingface.co/models/facebook/sam-vit-base
✅ SAM segmentation successful
✅ Mask visualization successful
✅ Complete pipeline test successful
Testing inpainting pipeline...
Inpainting API configured successfully!
Ready for inpainting!
✅ All models loaded successfully!

To use the app:
1. Run demo: run_demo()
2. Generate app: my_app = generate_app()
3. Launch app: my_app.launch()
4. Or run the interactive app directly: generate_app().launch()


In [10]:
generate_app().launch()

It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://b4472ede3c9c4da783.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


