# Background removal
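This notebook removes the background from the images in the dataset: a pretrained DeepLabV3 segmentation model labels each pixel, and every pixel it assigns to the background class is blanked out, leaving only the foreground.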

In [1]:
import cv2
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import os

: 

In [None]:
# Load DeepLabV3 (ResNet-101 backbone) with pretrained weights through torch.hub.
# The hub repository is pinned to the torchvision v0.6.0 tag; the resulting
# `model` is a notebook-level global that remove_background() reads directly.
model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
# Switch to evaluation mode, since the model is only used for inference here.
# As the last expression of the cell, this also displays the model.
model.eval()

In [None]:
def make_transparent_foreground(pic, mask):
    """
    Build a 4-channel image in which everything outside the mask is transparent.

    Pixels where ``mask`` is non-zero keep their colour and get alpha 255;
    all other pixels are set to 0 in every channel (including alpha).

    Parameters
    ----------
    pic : PIL image or array-like
        The source image. A 3-channel (H, W, 3) image is used as is, a
        2-D grayscale image is replicated to three channels, and the alpha
        channel of a 4-channel image is discarded (it is rebuilt from the mask).

    mask : numpy array
        2-D array of shape (H, W); non-zero entries mark the foreground.

    Returns
    -------
    foreground : numpy array
        uint8 array of shape (H, W, 4). The colour channels keep the order
        of the input (RGB for a PIL RGB image), followed by alpha.

    Raises
    ------
    ValueError
        If the image is not 2-D/3-channel/4-channel, or if the mask shape
        does not match the image height and width.
    """
    colour = np.asarray(pic).astype(np.uint8)

    # Bring every supported input layout to (H, W, 3).
    if colour.ndim == 2:
        colour = np.stack([colour, colour, colour], axis=2)
    elif colour.ndim == 3 and colour.shape[2] == 4:
        colour = colour[:, :, :3]
    if colour.ndim != 3 or colour.shape[2] != 3:
        raise ValueError(
            f"Expected a grayscale, 3-channel or 4-channel image, got shape {colour.shape}"
        )

    keep = np.asarray(mask).astype(bool)
    if keep.shape != colour.shape[:2]:
        raise ValueError(
            f"Mask shape {keep.shape} does not match image shape {colour.shape[:2]}"
        )

    # Alpha is fully opaque on the foreground and fully transparent elsewhere.
    alpha = np.where(keep, 255, 0).astype(np.uint8)
    foreground = np.dstack([colour, alpha])

    # Zero the colour channels of the background as well, so masked-out
    # pixels are (0, 0, 0, 0) rather than hidden-but-present colour data.
    foreground[~keep] = 0
    return foreground

In [None]:
def remove_background(input_image):
    """
    Segment an image with the global DeepLabV3 ``model`` and cut out the foreground.

    Every pixel whose predicted class index is non-zero is treated as
    foreground; pixels predicted as class 0 are treated as background.

    Parameters
    ----------
    input_image : PIL image
        The image to process. PIL images that are not in RGB mode
        (e.g. RGBA, L, P) are converted to RGB first, because the
        normalization below is defined for exactly three channels.

    Returns
    -------
    foreground : numpy array
        The transparent foreground image produced by
        ``make_transparent_foreground``.

    bin_mask : numpy array
        uint8 mask with the image's height and width: 255 on the
        foreground, 0 on the background.
    """
    # Only PIL images are converted; any other input is passed through
    # untouched so existing callers keep their current behaviour.
    if isinstance(input_image, Image.Image) and input_image.mode != "RGB":
        input_image = input_image.convert("RGB")

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # The model expects a batch dimension, so wrap the single image in one.
    input_batch = preprocess(input_image).unsqueeze(0)

    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(input_batch)['out'][0]

    # Per-pixel class index (argmax over the class dimension), moved to the CPU.
    class_map = output.argmax(0).byte().cpu().numpy()

    # Collapse all non-zero classes into a single 0/255 foreground mask.
    bin_mask = np.where(class_map, 255, 0).astype(np.uint8)

    foreground = make_transparent_foreground(input_image, bin_mask)

    return foreground, bin_mask

In [None]:
def batch_remove_background(input_folder, output_folder):
    """
    Remove the background from every ``.jpg`` file in a folder.

    Each result is written to ``output_folder`` as ``foreground_<original name>``.
    Files that cannot be read, and results that cannot be written, are
    reported and skipped so that one bad file does not abort the whole batch.

    Parameters
    ----------
    input_folder : str
        The path to the folder containing the images to be made transparent.
        Only file names ending in ``.jpg`` are processed.

    output_folder : str
        The path to the folder where the foreground images will be saved.
        It is created if it does not exist.
    """
    os.makedirs(output_folder, exist_ok=True)

    # Sorted so the processing order does not depend on the file system.
    jpg_files = sorted(f for f in os.listdir(input_folder) if f.endswith('.jpg'))

    for filename in tqdm(jpg_files, desc="Removing background"):
        input_path = os.path.join(input_folder, filename)

        # cv2.imread signals failure by returning None instead of raising.
        input_image_cv = cv2.imread(input_path)
        if input_image_cv is None:
            # tqdm.write prints without breaking the progress bar.
            tqdm.write(f"Skipping unreadable image: {input_path}")
            continue

        # Convert the OpenCV image (BGR) to a PIL image (RGB).
        input_image_pil = Image.fromarray(cv2.cvtColor(input_image_cv, cv2.COLOR_BGR2RGB))

        foreground, _ = remove_background(input_image_pil)

        # Back to OpenCV channel order for saving.
        foreground_cv = cv2.cvtColor(foreground, cv2.COLOR_RGBA2BGRA)

        # NOTE(review): the output keeps the .jpg name. JPEG has no alpha
        # channel, so the transparency is presumably not preserved on disk -
        # consider a .png name if transparency is needed (confirm with the
        # consumers of these files before changing it).
        output_path = os.path.join(output_folder, f"foreground_{filename}")
        if not cv2.imwrite(output_path, foreground_cv):
            tqdm.write(f"Failed to write image: {output_path}")

In [None]:
# Run the batch job. Both paths are placeholders: replace them with the real
# input folder (containing .jpg files) and the desired output folder.
batch_remove_background("/path/to/input/folder", "/path/to/output/folder")