<a href="https://colab.research.google.com/github/poomshift/remove-bg-colab/blob/main/alchemist_removeBG_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Download all code & dependencies
# Clone the BiRefNet repository into ./BiRefNet.
!git clone https://github.com/ZhengPeng7/BiRefNet.git
# Remove the torch add-on packages before installing the requirements
# (presumably to avoid version conflicts with the torch build that
# requirements.txt pulls in -- TODO confirm).
!pip uninstall -q torchaudio torchdata torchtext -y
# Install BiRefNet's own requirements, then install/upgrade gdown and gradio.
!pip install -q -r BiRefNet/requirements.txt
!pip install -q -U gdown gradio

# Work from inside the clone so that `from models.birefnet import BiRefNet`
# below can be resolved relative to the repository root.
%cd BiRefNet

# Imports
from PIL import Image
import torch
from torchvision import transforms
import gradio as gr

from models.birefnet import BiRefNet

# Load Model
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# 'high' allows float32 matrix multiplications to use faster,
# reduced-precision internal kernels ('highest' would keep full float32).
torch.set_float32_matmul_precision('high')

# Fetch the pretrained weights, move them to the chosen device and switch
# the network to inference mode.
model = BiRefNet.from_pretrained('zhengpeng7/birefnet')
model.to(device)
model.eval()
print('BiRefNet is ready to use.')

# Input Data
# Preprocessing applied to every crop before it is handed to the model:
# resize to a fixed 1024x1024 resolution, convert the PIL image to a float
# tensor, then normalise each channel with the standard ImageNet mean/std.
transform_image = transforms.Compose([
    transforms.Resize(size=(1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

def remove_background(image, left, top, right, bottom):
    """Return a copy of ``image`` whose alpha channel is the predicted foreground mask.

    The region described by ``left``/``top``/``right``/``bottom`` is cropped
    out of ``image``, passed through the module-level ``transform_image`` and
    ``model``, and the resulting mask is pasted back at the same position on
    an otherwise fully transparent canvas of the original image size.

    Args:
        image: PIL image to process, or ``None`` when nothing was uploaded.
        left, top, right, bottom: Crop box in pixels of ``image``. A value of
            ``-1`` means "use the corresponding image edge" (0, 0, width,
            height). Other values are rounded to integers and clamped to the
            image bounds.

    Returns:
        A copy of ``image`` with the predicted mask as its alpha channel
        (pixels outside the box are fully transparent), or ``None`` if
        ``image`` is ``None``.

    Raises:
        ValueError: If the box is empty after clamping (right <= left or
            bottom <= top).
    """
    # Nothing uploaded yet: there is nothing to segment.
    if image is None:
        return None

    w, h = image.size

    # Resolve the -1 sentinels and keep every coordinate inside the image so
    # that the crop, the resized mask and the canvas slice all agree in size.
    edge_defaults = (0, 0, w, h)
    upper_limits = (w, h, w, h)
    box = []
    for value, default, limit in zip((left, top, right, bottom), edge_defaults, upper_limits):
        coord = default if value == -1 else int(round(value))
        box.append(min(max(coord, 0), limit))
    x0, y0, x1, y1 = box
    if x1 <= x0 or y1 <= y0:
        raise ValueError(
            f"Empty crop box {box} for a {w}x{h} image: "
            "right must exceed left and bottom must exceed top."
        )

    # The normalisation in transform_image is defined for three channels, so
    # feed the model an RGB view regardless of the input mode.
    image_crop = image.convert('RGB').crop(box)
    input_images = transform_image(image_crop).unsqueeze(0).to(device)

    # Prediction: take the last element of the model output and squash it to
    # [0, 1] with a sigmoid.
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Resize the mask back to the crop's own pixel size. The box is already
    # expressed in image pixels, so it is used as-is for placement. Indexing
    # with [0, 0] (rather than squeeze()) keeps 1-pixel-wide/tall boxes 2-D.
    pred = torch.nn.functional.interpolate(
        pred.unsqueeze(0).unsqueeze(0),
        size=(y1 - y0, x1 - x0),
        mode='bilinear',
        align_corners=True,
    )[0, 0]
    canvas = torch.zeros((h, w))
    canvas[y0:y1, x0:x1] = pred

    # Create masked image
    pred_pil = transforms.ToPILImage()(canvas)
    image_masked = image.copy()
    image_masked.putalpha(pred_pil)

    return image_masked

# Gradio Interface
# One slider per box edge, in the order remove_background takes them.
# The default of -1 is what remove_background treats as "use the image edge".
box_sliders = [
    gr.Slider(minimum=-1, maximum=1024, step=1, value=-1, label=edge_label)
    for edge_label in ("Left", "Top", "Right", "Bottom")
]

iface = gr.Interface(
    fn=remove_background,
    inputs=[gr.Image(type="pil"), *box_sliders],
    outputs=gr.Image(type="pil"),
    title="Background Removal App",
    description="Upload an image and optionally specify a bounding box to remove the background.",
)

# Launch the app (share=True requests a public share link from Gradio).
iface.launch(share=True)