# Remove People

This code is a proof of concept of a program that removes people or other unwanted objects from an image. Before running it, follow the installation guidelines in the `README.md` file.

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torchvision
from dotenv import load_dotenv
from openai import OpenAI
from PIL import Image

# Populate the environment from a local .env file (if present), then build the
# OpenAI client from the OPENAI_API variable.
load_dotenv()
client = OpenAI(api_key=os.environ.get("OPENAI_API"))

# For best results use images that have a 1:1 aspect ratio.
# This limitation comes from the OpenAI API.
# If you switch to a provider that accepts any aspect ratio, ignore this.
image_name = 'example_2.jpg' # Only edit this line to change the image.

# Add known path to the image.
image_path = f"images/original/{image_name}"
# Name of the image without its extension. splitext strips only the final
# suffix, so names containing extra dots (e.g. "trip.day1.jpg") stay intact.
clean_image_name = os.path.splitext(image_name)[0]

# Create image folders if they don't exist.
os.makedirs("images/original", exist_ok=True)
os.makedirs("images/crops", exist_ok=True)
os.makedirs("images/results", exist_ok=True)

## Load Mask R-CNN

For this we use a [Mask R-CNN](https://arxiv.org/abs/1703.06870) model pre-trained on the [COCO dataset](https://cocodataset.org/). The model is loaded from [`torchvision`](https://pytorch.org/vision/main/models/generated/torchvision.models.detection.maskrcnn_resnet50_fpn.html) with its default pre-trained weights.

In [None]:
# Instantiate torchvision's Mask R-CNN, initialised with the library's
# default pre-trained weights.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    weights="DEFAULT",
)
# Switch the network to evaluation mode before it is used for inference.
model.eval()

## Miscellaneous Functions

Utility functions to pre-process the image and apply masks.

In [None]:
def preprocess_image(image_path):
    """
    Load an image and turn it into a batched tensor for the segmentation model.

    The file is opened, converted to RGB and closed again before returning,
    so no file handle is left dangling. The RGB image is converted with
    ``torchvision.transforms.ToTensor`` and given a leading batch dimension.

    Args:
        image_path (str): Path to the input image file.

    Returns:
        torch.Tensor: The preprocessed image tensor with shape (1, C, H, W).
        PIL.Image.Image: The RGB image the tensor was built from.
    """
    # convert() returns a new, fully loaded image, so the source file can be
    # closed as soon as the conversion is done.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    transform = torchvision.transforms.ToTensor()
    return transform(image).unsqueeze(0), image

def apply_mask(image, mask, threshold=0.5, fill_value=(0, 0, 0)):
    """
    Paint the pixels covered by a soft segmentation mask with a solid colour.

    The mask is resized to the image's height and width if the sizes differ,
    then thresholded into a binary mask. Every pixel inside it is set to
    ``fill_value``. The input is copied, never modified in place.

    Args:
        image (PIL.Image.Image | numpy.ndarray): The image to paint on, as a
            PIL image or an (H, W, 3) array (this notebook passes both).
        mask (torch.Tensor): The segmentation mask with shape (1, H, W);
            values above ``threshold`` are treated as part of the object.
        threshold (float): Probability cut-off for the binary mask.
        fill_value (sequence of int): Colour written into masked pixels.

    Returns:
        PIL.Image.Image: A new image with the masked pixels set to ``fill_value``.
    """
    image = np.array(image)  # always a copy, so the caller's data is untouched
    height, width = image.shape[:2]

    # Drop singleton dimensions to get an (h, w) probability map.
    probabilities = mask.detach().cpu().float().squeeze()
    if probabilities.shape != (height, width):
        # np.resize would only repeat or truncate the flattened data; do a
        # real spatial resize of the probabilities before thresholding.
        probabilities = torch.nn.functional.interpolate(
            probabilities[None, None],
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )[0, 0]

    binary_mask = probabilities.numpy() > threshold
    image[binary_mask] = fill_value
    return Image.fromarray(image)

## Apply Model and Masks

Run the model on the image and display it next to a copy in which every detected object is masked out in black.

In [None]:
# Load the image, segment it with Mask R-CNN, and black out every detection
# so the model's output can be checked visually.
input_tensor, original_image = preprocess_image(image_path)

# Inference only: gradient tracking is not needed.
with torch.no_grad():
    predictions = model(input_tensor)

# A single image went in, so only the first prediction dict is relevant.
masks, labels, scores = (predictions[0][key] for key in ('masks', 'labels', 'scores'))

# Two independent numpy copies of the input: one kept untouched for display,
# one that every mask is painted onto in turn.
input_image_np, input_image_masked = np.array(original_image), np.array(original_image)

for mask in masks:
    input_image_masked = apply_mask(input_image_masked, mask)

# Side-by-side comparison: untouched input vs. all detections blacked out.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
panels = ((input_image_np, 'Original Image'), (input_image_masked, 'Object Detection Masks'))
for ax, (picture, caption) in zip(axes, panels):
    ax.imshow(picture)
    ax.set_title(caption)
    ax.axis('off')

plt.show()

## Inspect Object Detection

Visualize each detected mask together with its index, label, and score. Note down the indices of the main objects (the same object may be detected more than once).

In [None]:
# Plot every detected mask on its own, titled with its index, COCO label and
# score, so the protagonist indices can be picked in the next step.
num_masks = len(masks)
cols = 5  # Number of columns in the plot grid
# Ceiling division, with at least one row: a figure height of 0 (no
# detections) is rejected by matplotlib.
rows = max(1, (num_masks + cols - 1) // cols)

# Read the category names, one per line; the (0-based) line index is the
# label id the model predicts.
with open('coco_labels.txt', 'r', encoding='utf-8') as file:
    coco_labels = [line.strip() for line in file]

plt.figure(figsize=(20, 5 * rows))
for i, (mask, label, score) in enumerate(zip(masks, labels, scores)):
    plt.subplot(rows, cols, i + 1)
    mask_np = mask.squeeze().cpu().numpy()
    # Resize to the original image size for display.
    mask_resized = np.array(Image.fromarray(mask_np).resize((original_image.width, original_image.height)))
    plt.imshow(mask_resized, cmap='gray')
    plt.title(f'{i} - {coco_labels[label.item()]} ({score:.2f})')
    plt.axis('off')

plt.tight_layout()
plt.show()

## Protagonist Selection

Enter the indices of the main objects in the image; every other detected object will be removed.

In [None]:
# Select the indices of the main protagonists in the image, check the labels above.
main_protagonist_indices = [0, 1, 4, 22, 59]

# Probability cut-off for turning the model's soft masks into binary ones
# (the same 0.5 cut-off apply_mask uses).
mask_threshold = 0.5

# Fail early with a readable message if an index noted from an earlier run
# no longer exists (detections change whenever the image changes).
out_of_range = [i for i in main_protagonist_indices if not -len(masks) <= i < len(masks)]
if out_of_range:
    raise IndexError(
        f"Protagonist indices {out_of_range} are out of range: "
        f"only {len(masks)} masks were detected."
    )

# Filter out masks of main protagonists.
protagonist_masks = [masks[i] for i in main_protagonist_indices]

# Binarise each soft mask *before* combining. Working on raw probabilities
# would count any pixel with a non-zero probability as an object, and a
# slightly more confident overlapping detection would carve pixels out of a
# protagonist when the masks are subtracted.
combined_mask_protagonists = np.zeros(input_image_np.shape[:2], dtype=bool)
for mask in protagonist_masks:
    combined_mask_protagonists |= mask.squeeze().cpu().numpy() > mask_threshold

combined_mask_all = np.zeros(input_image_np.shape[:2], dtype=bool)
for mask in masks:
    combined_mask_all |= mask.squeeze().cpu().numpy() > mask_threshold

# Pixels covered by some detected object but by no protagonist, as 0/1 uint8.
non_protagonist_mask = (combined_mask_all & ~combined_mask_protagonists).astype(np.uint8)
combined_mask_protagonists = combined_mask_protagonists.astype(np.uint8)

# Plot the masks side by side.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

axes[0].imshow(combined_mask_protagonists, cmap='gray')
axes[0].set_title('Protagonist Mask')
axes[0].axis('off')

axes[1].imshow(non_protagonist_mask, cmap='gray')
axes[1].set_title('Non-Protagonist Mask')
axes[1].axis('off')

plt.show()

## PNG Creation

Create a PNG in which the unwanted objects are cut out (made transparent) while the protagonists and the background are kept.

In [None]:
# Flip the mask: 1 marks pixels to keep (protagonists and background),
# 0 marks the unwanted objects that get cut out.
final_mask = 1 - non_protagonist_mask

# Blacken the unwanted regions on a copy of the original pixels.
masked_image = input_image_np.copy()
masked_image[final_mask == 0] = 0

# Opaque (255) for kept pixels, fully transparent (0) everywhere else.
alpha_channel = np.where(final_mask == 1, 255, 0).astype(np.uint8)

# Stack colour and alpha into a single RGBA array, then write it out as a PNG
# whose transparent areas mark the regions to be regenerated.
rgba_pixels = np.dstack((masked_image.astype('uint8'), alpha_channel))
result_image = Image.fromarray(rgba_pixels)
result_image.save(f'images/crops/{clean_image_name}.png', format='PNG')

# Preview the cut-out without axes.
plt.imshow(result_image)
plt.axis('off')
plt.show()

## OpenAI Pre-processing

The [OpenAI API being used](https://platform.openai.com/docs/api-reference/images/createEdit) to generate content requires a square PNG image smaller than 4MB, and this notebook requests a 1024x1024 result. The function below shrinks the image (preserving its aspect ratio) so that it fits within 1024x1024 pixels and saves it as a PNG file. If for any reason the image is still larger than 4MB, compress it further and place it in the `images/crops` folder.

In [None]:
# Shrink the cut-out so it fits within the 1024x1024 size used for the API call.
def crop_image(image_path, size=(1024, 1024), output_path=None):
    """
    Shrink an image so it fits inside ``size`` and save it as a PNG.

    Despite the name, nothing is cropped or padded: ``Image.thumbnail`` scales
    the image down (never up) while preserving its aspect ratio, so a
    non-square input stays non-square. Use a 1:1 source image when the
    downstream API needs a square picture.

    Args:
        image_path (str): Path of the image to read.
        size (tuple[int, int]): Maximum (width, height) of the result.
        output_path (str | None): Where to write the PNG. Defaults to
            ``image_path``, i.e. the file is overwritten in place.
    """
    # Copy the pixels into memory and close the file before writing, so the
    # source can safely be overwritten (the default) on every platform.
    with Image.open(image_path) as source:
        image = source.copy()
    image.thumbnail(size)
    image.save(output_path if output_path is not None else image_path, format='PNG')

crop_image(f'images/crops/{clean_image_name}.png')

## OpenAI Execution and Download

Call the OpenAI API for content generation and download the result to the `images/results` folder. If you are not happy with the result, re-run the cell; the API generates a different result each time.

In [None]:
# Call the OpenAI API to complete the image.
# The prompt describes the example photo; adapt it to your own image.
prompt = "A family photo in Petra that needs completion."
# Open the cut-out in a context manager so the file handle is closed again.
with open(f"images/crops/{clean_image_name}.png", "rb") as crop_file:
    response = client.images.edit(
        model="dall-e-2", # Only DALL-E 2 was supported at the time of writing.
        image=crop_file,
        prompt=prompt,
        n=1, # If you want more results, you can increase this number.
        size="1024x1024",
    )
image_url = response.data[0].url

# Download image from the URL and save in images/results folder.
# The timeout avoids hanging forever, and raise_for_status() stops an HTTP
# error page from being written to disk as if it were the image.
download = requests.get(image_url, timeout=60)
download.raise_for_status()
image_data = download.content
# NOTE(review): the downloaded data may not actually be JPEG; the ".jpg" name
# is kept for compatibility, and PIL detects the format from the file contents.
with open(f'images/results/{clean_image_name}.jpg', 'wb') as image_file:
    image_file.write(image_data)

# Display the final image.
final_image = Image.open(f'images/results/{clean_image_name}.jpg')

# Plot the original image and the result image.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

axes[0].imshow(input_image_np)
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(final_image)
axes[1].set_title('Result Image')
axes[1].axis('off')

plt.show()

## Additional Information

The following cell gets the labels used by our model's weights. Its output is already stored in the [`coco_labels.txt`](coco_labels.txt) file. If you switch to a different set of weights, run this cell and save the copied result into the [`coco_labels.txt`](coco_labels.txt) file. This project uses `MaskRCNN_ResNet50_FPN_Weights.DEFAULT`.

In [None]:
import shutil
import subprocess

from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights

weights_val = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
categories_val = weights_val.meta["categories"]

# One category name per line; the (0-based) line index is the label id the
# model predicts, which is how coco_labels.txt is indexed when plotting.
result = "".join(f"{category}\n" for category in categories_val)

# Copy the result to the clipboard. Passing the text on stdin (instead of
# interpolating it into a shell command) works for any category name,
# including ones with quotes. pbcopy exists only on macOS, so print elsewhere.
if shutil.which("pbcopy"):
    subprocess.run(["pbcopy"], input=result, text=True, check=True)
else:
    print(result)