In [32]:
# Show the attached GPU, driver/CUDA versions and current GPU memory usage.
!nvidia-smi

Fri Nov 24 00:05:10 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   58C    P0    30W /  70W |  11883MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [1]:
# Install the notebook's dependencies. `%pip` (rather than `!pip`) installs
# into the environment of the running kernel.
# NOTE(review): only diffusers is pinned; consider pinning the other packages
# so a fresh run resolves the same versions.
%pip install -qq -U diffusers==0.11.1 transformers ftfy gradio accelerate rembg

In [2]:
from rembg import remove
import cv2
import numpy as np
import requests
from io import BytesIO

import inspect
from typing import List, Optional, Union
import torch

import PIL
import gradio as gr
from diffusers import StableDiffusionInpaintPipeline

In [3]:
# Device that holds the pipeline weights and the random generator.
device = "cuda"
# Checkpoint identifier handed to `from_pretrained`.
model_path = "runwayml/stable-diffusion-inpainting"

# Load the inpainting pipeline in half precision and move it to `device`.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
).to(device)

# Seeded generator created on `device` (instead of a second hard-coded "cuda")
# so it always lives where the pipeline does. Change the seed to get different results.
generator = torch.Generator(device=device).manual_seed(0)

unet/diffusion_pytorch_model.safetensors not found


Fetching 24 files:   0%|          | 0/24 [00:00<?, ?it/s]

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.


In [4]:
def image_grid(imgs, rows, cols):
    """Paste images into a single ``rows`` x ``cols`` grid image.

    Cells are filled left to right, top to bottom. Every cell takes the size
    of the first image.

    Args:
        imgs: Sequence of PIL images; must hold exactly ``rows * cols`` items.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``, where ``(w, h)`` is
        the size of ``imgs[0]``.

    Raises:
        AssertionError: If ``len(imgs) != rows * cols``.
    """
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = PIL.Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        # divmod maps the flat index to (row, column) in row-major order.
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid


In [5]:
def remove_background(image):
    """Strip the background from ``image`` by delegating to ``remove``.

    Args:
        image: Input picture, in any form ``remove`` accepts.

    Returns:
        Whatever ``remove`` returns for ``image``.
    """
    cutout = remove(image)
    return cutout

def generate_background(image, prompt, guidance_scale, num_samples):
    """Inpaint a new background into the near-black regions of ``image``.

    Near-black pixels (HSV within ``[0, 0, 0]``..``[5, 15, 35]``) form the
    mask handed to the inpainting pipeline; the image is downscaled to roughly
    400,000 pixels at most before generation.

    Args:
        image: RGB ``numpy`` array as delivered by the Gradio image input.
        prompt: Text prompt describing the desired background.
        guidance_scale: Guidance scale; converted with ``float``.
        num_samples: Number of images to generate; converted with ``int``.

    Returns:
        A single PIL image with the generated samples side by side.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Without this guard the failure surfaces as an opaque OpenCV error.
        raise gr.Error("Please upload an image first.")

    # Gradio hands over RGB arrays, so convert from RGB (not BGR) to keep the
    # hue channel meaningful.
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

    # Range of (near-black) colors in HSV that counts as background.
    lower = np.array([0, 0, 0])
    upper = np.array([5, 15, 35])

    # Threshold the HSV image: in-range pixels become the mask.
    mask_image = cv2.inRange(hsv, lower, upper)

    # Shrink large inputs to about `max_pixels` to limit memory use;
    # smaller inputs keep their size (factor 1).
    max_pixels = 400000
    src_h, src_w = hsv.shape[:2]
    factor = max(1.0, (src_h * src_w / max_pixels) ** 0.5)

    # Round each side down to a multiple of 8, but never below 8 so that a
    # tiny input cannot yield a zero-sized target.
    height = max(8, int(src_h / factor) // 8 * 8)
    width = max(8, int(src_w / factor) // 8 * 8)

    image2 = PIL.Image.fromarray(image).resize((width, height))
    mask_image2 = PIL.Image.fromarray(mask_image).resize((width, height))

    guidance_scale = float(guidance_scale)
    num_samples = int(num_samples)
    # Fixed seed for repeatable results; change the seed to get different results.
    generator = torch.Generator(device=device).manual_seed(0)

    images = pipe(
        prompt=prompt,
        image=image2,
        mask_image=mask_image2,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_samples,
        width=width,
        height=height,
    ).images

    return image_grid(images, 1, num_samples)

In [6]:
def remove_generate_background(image, prompt, negative_prompt, guidance_scale, num_samples):
    """Remove the background of ``image`` and inpaint a new one from ``prompt``.

    The background is stripped with ``remove``; pixels of the result whose four
    channels are all in ``[0, 1]`` form the mask handed to the inpainting
    pipeline. The image is downscaled to roughly 400,000 pixels at most before
    generation.

    Args:
        image: ``numpy`` array as delivered by the Gradio image input.
        prompt: Text prompt describing the desired background.
        negative_prompt: Text describing what the result should avoid.
        guidance_scale: Guidance scale; converted with ``float``.
        num_samples: Number of images to generate; converted with ``int``.

    Returns:
        A single PIL image with the generated samples side by side.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Fail with a readable message instead of an error from deep inside `remove`.
        raise gr.Error("Please upload an image first.")

    # Remove the current background.
    # NOTE(review): the 4-element bounds below assume a 4-channel (presumably
    # RGBA) array comes back from `remove` - confirm.
    img = remove(image)

    # Pixels whose four channels are all 0 or 1 are treated as removed background.
    lower = np.array([0, 0, 0, 0])
    upper = np.array([1, 1, 1, 1])
    mask_image = cv2.inRange(img, lower, upper)

    # Shrink large inputs to about `max_pixels` to limit memory use;
    # smaller inputs keep their size (factor 1).
    max_pixels = 400000
    src_h, src_w = img.shape[:2]
    factor = max(1.0, (src_h * src_w / max_pixels) ** 0.5)

    # Round each side down to a multiple of 8, but never below 8 so that a
    # tiny input cannot yield a zero-sized target.
    height = max(8, int(src_h / factor) // 8 * 8)
    width = max(8, int(src_w / factor) // 8 * 8)

    image2 = PIL.Image.fromarray(img).resize((width, height))
    mask_image2 = PIL.Image.fromarray(mask_image).resize((width, height))

    guidance_scale = float(guidance_scale)
    num_samples = int(num_samples)
    # Fixed seed for repeatable results; change the seed to get different results.
    generator = torch.Generator(device=device).manual_seed(0)

    images = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=image2,
        mask_image=mask_image2,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_samples,
        width=width,
        height=height,
    ).images

    return image_grid(images, 1, num_samples)


In [7]:
# Three-tab Gradio app: background removal, background generation on an
# already-cut-out image, and both steps combined.
with gr.Blocks() as demo:
    gr.Markdown("Remove Background and Generate a Background with Prompt")

    # Tab 1: background removal only.
    with gr.Tab("Remove Background"):
        with gr.Row():
            image_input_remove = gr.Image()
            image_output_remove = gr.Image()
        image_button_remove = gr.Button("Remove")

    # Tab 2: generate a background for an uploaded image.
    with gr.Tab("Generate Background"):
        with gr.Row():
            image_input_generate = gr.Image()
            with gr.Column():
                guidance_scales = gr.Slider(
                    minimum=0,
                    maximum=20,
                    step=0.1,
                    value=7,
                    label="Guidance Scale",
                    info="Choose between 0 and 20",
                )
                image_number = gr.Dropdown(
                    choices=[1, 2, 3],
                    value=1,
                    label="Number of Images",
                    info="Select number of images to generate",
                )
            image_output_generate = gr.Image()
        with gr.Column():
            text_input_generate = gr.Textbox(
                value="A background with subtle nature elements, like a soft-focus image of flowers or leaves",
                label="Prompt",
                lines=3,
                info="Enter the prompt",
            )

        image_button_generate = gr.Button("Generate")

    # Tab 3: remove the background, then generate a new one.
    with gr.Tab("Remove and Generate New Background"):
        with gr.Row():
            image_input_remove_generate = gr.Image()
            with gr.Column():
                rm_guidance_scales = gr.Slider(
                    minimum=0,
                    maximum=20,
                    step=0.1,
                    value=7.5,
                    label="Guidance Scale",
                    info="Choose between 0 and 20",
                )
                rm_image_number = gr.Dropdown(
                    choices=[1, 2, 3],
                    value=1,
                    label="Number of Images",
                    info="Select number of images to generate",
                )
            image_output_remove_generate = gr.Image()
        with gr.Column():
            text_input_remove_generate = gr.Textbox(
                value="A background with subtle nature elements, like a soft-focus image of flowers or leaves",
                label="Prompt",
                lines=2,
                info="Enter the prompt",
            )
            text_negative_input_remove_generate = gr.Textbox(
                value="",
                label="Negative Prompt",
                lines=1,
                info="Enter the negative prompt",
            )
        image_button_remove_generate = gr.Button("Generate")

    with gr.Accordion("Open for More!"):
        gr.Markdown("Look at me...")

    # Wire each button to its handler; the order of `inputs` matches the
    # handler's positional parameters.
    image_button_remove.click(
        fn=remove_background,
        inputs=image_input_remove,
        outputs=image_output_remove,
    )
    image_button_generate.click(
        fn=generate_background,
        inputs=[
            image_input_generate,
            text_input_generate,
            guidance_scales,
            image_number,
        ],
        outputs=image_output_generate,
    )
    image_button_remove_generate.click(
        fn=remove_generate_background,
        inputs=[
            image_input_remove_generate,
            text_input_remove_generate,
            text_negative_input_remove_generate,
            rm_guidance_scales,
            rm_image_number,
        ],
        outputs=image_output_remove_generate,
    )

demo.launch()

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://9f067b59ab46d34e82.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


