# Demo of Context-Aware Image Inpainting for Automatic Object Removal

## Download sample images and model checkpoints:
For simplicity, only three checkpoints are downloaded: the baseline model, the 0.5L1 + 0.5SSIM joint reconstruction loss model, and the 0.3L2 + 0.7SSIM joint reconstruction loss model.
The other checkpoints can be downloaded manually from [this Google Drive folder](https://drive.google.com/drive/folders/11yjGAIRzUpQH2IoSuLsKn6Vm9yr_-29h?usp=drive_link).

In [19]:
# First, download the sample images and model checkpoints:
!wget --no-check-certificate -L 'https://docs.google.com/uc?export=download&id=1kpu4lvoFmQffxBkfAMsMloXlxtYT-kf8' -O object_classes.txt
!wget --no-check-certificate -L 'https://docs.google.com/uc?export=download&id=1f24sS0lYczLJ0BJzrPuELVx-_Suj6nEi&format=zip' -O sample_images.zip
!wget --no-check-certificate -L 'https://docs.google.com/uc?export=download&id=1OVX9n70-CPWjPG_6tL_P39KgQQcHI9j_&format=zip' -O sample_images_object_removal.zip
!unzip sample_images.zip
!unzip sample_images_object_removal.zip
!rm sample_images.zip
!rm sample_images_object_removal.zip

--2024-04-30 23:25:21--  https://docs.google.com/uc?export=download&id=1kpu4lvoFmQffxBkfAMsMloXlxtYT-kf8
Resolving docs.google.com (docs.google.com)... 74.125.23.102, 74.125.23.113, 74.125.23.138, ...
Connecting to docs.google.com (docs.google.com)|74.125.23.102|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://drive.usercontent.google.com/download?id=1kpu4lvoFmQffxBkfAMsMloXlxtYT-kf8&export=download [following]
--2024-04-30 23:25:22--  https://drive.usercontent.google.com/download?id=1kpu4lvoFmQffxBkfAMsMloXlxtYT-kf8&export=download
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 108.177.97.132, 2404:6800:4008:c00::84
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|108.177.97.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1532 (1.5K) [application/octet-stream]
Saving to: ‘object_classes.txt’


2024-04-30 23:25:23 (73.6 MB/s) - ‘object_classes.txt’ saved [1532/1532

In [2]:
# Get model checkpoints:
!pip install gdown
!gdown 1mmDswcNnt6vu3jbO7hgyhV_-ts8Rtm6i # Baseline
# !gdown 1SbsNVqOEcajMQoJq0-7_Y1jQa1i4uVfT # L1 + SSIM Joint Reconstruction
!gdown 1e1NRuB2urJKHQS8pQLizTJ-t8gymnj4R # 0.5L1 + 0.5SSIM Joint Reconstruction
!gdown 1z8S2KSPWHrqrexGD1DdHkD1g2cdlx_aM # L2 + SSIM Joint Reconstruction
!mkdir Checkpoints
!mv /content/model_baseline_epoch_39.pth /content/Checkpoints/model_baseline_epoch_39.pth
# !mv /content/model_L1_SSIM_epoch_39.pth /content/Checkpoints/model_L1_SSIM_epoch_39.pth
!mv /content/model_0.5L1_0.5SSIM_epoch_39.pth /content/Checkpoints/model_0.5L1_0.5SSIM_epoch_39.pth
!mv /content/model_L2_SSIM_epoch_39.pth /content/Checkpoints/model_L2_SSIM_epoch_39.pth

Downloading...
From (original): https://drive.google.com/uc?id=1mmDswcNnt6vu3jbO7hgyhV_-ts8Rtm6i
From (redirected): https://drive.google.com/uc?id=1mmDswcNnt6vu3jbO7hgyhV_-ts8Rtm6i&confirm=t&uuid=0349480c-22fe-4e9c-a71b-ddd35b3a6a03
To: /content/model_baseline_epoch_39.pth
100% 504M/504M [00:29<00:00, 16.9MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1e1NRuB2urJKHQS8pQLizTJ-t8gymnj4R
From (redirected): https://drive.google.com/uc?id=1e1NRuB2urJKHQS8pQLizTJ-t8gymnj4R&confirm=t&uuid=aa721664-afbc-4811-a7d0-4037622ff0cc
To: /content/model_0.5L1_0.5SSIM_epoch_39.pth
100% 504M/504M [00:13<00:00, 38.0MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1z8S2KSPWHrqrexGD1DdHkD1g2cdlx_aM
From (redirected): https://drive.google.com/uc?id=1z8S2KSPWHrqrexGD1DdHkD1g2cdlx_aM&confirm=t&uuid=990a0076-a084-463b-8353-2c85f7bfee2a
To: /content/model_L2_SSIM_epoch_39.pth
100% 504M/504M [00:13<00:00, 37.6MB/s]


## Dataset Code
The following code is used to create a test dataloader for evaluating entire folders of images at once:

In [3]:
import glob
import random
import os
import numpy as np

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class ImageDataset(Dataset):
    """Dataset of ``.jpg`` images paired with a masked copy for inpainting.

    Every ``.jpg`` found under ``root`` (recursively) is loaded, transformed and
    returned together with a copy in which a square region is set to 1.

    Args:
        root: Folder searched recursively for ``*.jpg`` files.
        transforms_: List of torchvision transforms, composed and applied to
            every image. The masking code calls ``.clone()`` and slices the
            result as ``[channel, row, column]``, so the composed transform has
            to yield a tensor.
        img_size: Side length used when computing the mask position.
        mask_size: Side length of the square mask.
        mode: ``"train"`` applies a randomly placed mask; any other value
            applies a centered mask.
    """

    def __init__(self, root, transforms_=None, img_size=128, mask_size=64, mode="train"):
        self.transform = transforms.Compose(transforms_)
        self.img_size = img_size
        self.mask_size = mask_size
        self.mode = mode
        self.files = sorted(glob.glob("%s/**/*.jpg" % root, recursive=True))

    def apply_random_mask(self, img):
        """Mask a randomly placed ``mask_size`` square.

        Returns:
            ``(masked_img, masked_part, x1, y1)``: the masked copy, the pixels
            that were covered, and the mask's upper-left column and row.
        """
        # randint's upper bound is exclusive, so +1 makes the last valid offset
        # (img_size - mask_size) reachable and keeps img_size == mask_size valid.
        y1, x1 = np.random.randint(0, self.img_size - self.mask_size + 1, 2)
        y2, x2 = y1 + self.mask_size, x1 + self.mask_size
        masked_part = img[:, y1:y2, x1:x2]
        masked_img = img.clone()
        masked_img[:, y1:y2, x1:x2] = 1

        return masked_img, masked_part, x1, y1

    def apply_center_mask(self, img):
        """Mask the centered ``mask_size`` square.

        Returns:
            ``(masked_img, i)`` where ``i`` is both the row and the column of
            the mask's upper-left pixel.
        """
        # Get upper-left pixel coordinate
        i = (self.img_size - self.mask_size) // 2
        masked_img = img.clone()
        masked_img[:, i : i + self.mask_size, i : i + self.mask_size] = 1

        return masked_img, i

    def __getitem__(self, index):
        """Load, transform and mask the image at ``index`` (wrapping around)."""
        # Force three channels so grayscale/CMYK JPEGs still fit channel-wise
        # transforms; the context manager releases the file handle right away.
        with Image.open(self.files[index % len(self.files)]) as image_file:
            img = image_file.convert("RGB")
        img = self.transform(img)

        if self.mode == "train":
            # For training data perform random mask
            masked_img, aux, x1, y1 = self.apply_random_mask(img)
            return img, masked_img, aux, x1, y1 # aux = masked_part
        else:
            # For test data mask the center of the image
            masked_img, aux = self.apply_center_mask(img)
            return img, masked_img, aux

    def __len__(self):
        """Number of image files found under ``root``."""
        return len(self.files)

## Context Encoders Model Code

In [4]:
import torch.nn as nn
import torch.nn.functional as F
import torch

# ContextGenerator:
class Generator(nn.Module):
    """Context-encoder generator: strided-conv encoder followed by a transposed-conv decoder.

    The encoder applies five stride-2 convolutions and a 1x1 convolution that
    maps to ``bottleneck_dim`` channels. The decoder applies five stride-2
    transposed convolutions, a 3x3 convolution back to ``channels`` and a Tanh.
    Five halvings followed by five doublings restore the input resolution
    (e.g. 128 x 128).

    Args:
        channels: Number of channels of the input and output images.
        bottleneck_dim: Number of channels of the bottleneck feature map.
    """

    def __init__(self, channels=3, bottleneck_dim=4000):
        super().__init__()

        def finish_stage(conv, width, activation, normalize):
            """Append the optional batch norm and the activation to ``conv``."""
            stage = [conv]
            if normalize:
                # 0.8 is BatchNorm2d's second positional argument, i.e. eps (not momentum).
                # It must stay as-is: it affects the output of already-trained checkpoints.
                stage.append(nn.BatchNorm2d(width, eps=0.8))
            stage.append(activation)
            return stage

        def downsample(in_feat, out_feat, normalize=True):
            """Stride-2 convolution (halves H and W) + optional norm + LeakyReLU."""
            conv = nn.Conv2d(in_feat, out_feat, kernel_size=4, stride=2, padding=1)
            return finish_stage(conv, out_feat, nn.LeakyReLU(0.2), normalize)

        def upsample(in_feat, out_feat, normalize=True):
            """Stride-2 transposed convolution (doubles H and W) + optional norm + ReLU."""
            conv = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=4, stride=2, padding=1)
            return finish_stage(conv, out_feat, nn.ReLU(), normalize)

        # (in, out, normalize) per encoder stage; the first stage is not normalized.
        encoder_plan = [
            (channels, 64, False),
            (64, 64, True),
            (64, 128, True),
            (128, 256, True),
            (256, 512, True),
        ]
        encoder_layers = []
        for fan_in, fan_out, use_norm in encoder_plan:
            encoder_layers.extend(downsample(fan_in, fan_out, normalize=use_norm))
        encoder_layers.append(nn.Conv2d(512, bottleneck_dim, 1))
        self.ContextEncoder = nn.Sequential(*encoder_layers)

        # (in, out) per decoder stage; the last stage brings the map back to the input size.
        decoder_plan = [(bottleneck_dim, 512), (512, 256), (256, 128), (128, 64), (64, 64)]
        decoder_layers = []
        for fan_in, fan_out in decoder_plan:
            decoder_layers.extend(upsample(fan_in, fan_out))
        decoder_layers.append(nn.Conv2d(64, channels, 3, 1, 1))
        decoder_layers.append(nn.Tanh())
        self.ContextDecoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the bottleneck and decode it back to an image in [-1, 1]."""
        return self.ContextDecoder(self.ContextEncoder(x))

# ContextDiscriminator
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a single-channel score map.

    Four 3x3 convolution blocks (the first three with stride 2) are followed by
    a final 3x3 convolution down to one channel. No sigmoid is applied.

    Args:
        channels: Number of channels of the input image.
    """

    def __init__(self, channels=3):
        super().__init__()

        # (output channels, stride, use instance norm) for each block.
        block_plan = ((64, 2, False), (128, 2, True), (256, 2, True), (512, 1, True))

        stack = []
        width = channels
        for next_width, stride, use_norm in block_plan:
            stack.append(nn.Conv2d(width, next_width, 3, stride, 1))
            if use_norm:
                stack.append(nn.InstanceNorm2d(next_width))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
            width = next_width

        # Collapse the last feature map to one score per spatial location.
        stack.append(nn.Conv2d(width, 1, 3, 1, 1))

        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return the score map for ``img``."""
        return self.model(img)

## Install dependency for YOLOv8 pre-trained model

In [5]:
# Installs the package that provides the YOLO class imported by the segmentation code.
# The version is not pinned; the recorded run of this cell downloaded ultralytics 8.2.5.
!pip install ultralytics

Collecting ultralytics
  Downloading ultralytics-8.2.5-py3-none-any.whl (754 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m755.0/755.0 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
Collecting thop>=0.1.1 (from ultralytics)
  Downloading thop-0.1.1.post2209072238-py3-none-any.whl (15 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.8.0->ultralytics)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.8.0->ultralytics)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.8.0->ultralytics)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.8.0->ultralytics)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.

## Model for Object Segmentation and Mask Extraction:

In [6]:
import cv2
import torch
import numpy as np
import argparse
import matplotlib.pyplot as plt
from PIL import Image

from ultralytics import YOLO

class Segmentor():
  """Detects and segments objects in one image with a YOLOv8 segmentation model.

  The instance loads the 'yolov8m-seg.pt' weights and the input image, and can
  build a single 0/255 mask covering every detected instance of the requested
  object classes (see ``get_mask``). Figures are shown with matplotlib and a
  few result images are written under ``/content``.
  """

  def __init__(self, input_image_path):
    """Load the model, read the image at ``input_image_path`` and set the default classes."""
    self.model = YOLO('yolov8m-seg.pt')
    self.input_image_path = input_image_path
    self.image = np.array(Image.open(input_image_path))
    self.load_config()

  def get_results(self, input_file_path):
    """Run the YOLO model on ``input_file_path`` and return its results."""
    return self.model(input_file_path)

  def save_image(self, img, output_path='/', ):
    """Write the tensor ``img`` to ``output_path`` with ``cv2.imwrite``."""
    cv2.imwrite(output_path, img.cpu().numpy())

  def _render(self, img, title='', cmap=None, mask=None, output_path=None):
    """Show ``img`` in a new figure, optionally with a mask overlay and a saved copy.

    Args:
      img: Image accepted by ``plt.imshow``.
      title: Figure title.
      cmap: Colormap forwarded to ``plt.imshow`` for ``img``.
      mask: Optional tensor; its non-zero entries are drawn semi-transparently
        on top of ``img`` with the 'jet' colormap.
      output_path: When given, the figure is saved there before it is shown.
    """
    plt.figure()
    plt.title(title)
    plt.axis('off')
    plt.tight_layout()
    plt.imshow(img, cmap=cmap, interpolation='none')
    if mask is not None:
      mask = mask.cpu().numpy()
      # Hide zero entries so only the masked region is tinted.
      mask = np.ma.masked_where(mask == 0, mask)
      plt.imshow(mask, 'jet', interpolation='none', alpha=0.5)
    if output_path is not None:
      plt.savefig(output_path)
    plt.show()

  def display_image(self, img, title=''):
    """Show ``img`` in a new figure."""
    self._render(img, title)

  def display_mask_overlay(self, img, mask, title=''):
    """Show ``img`` with the non-zero part of ``mask`` overlaid."""
    self._render(img, title, cmap='gray', mask=mask)

  def display_and_save_mask_overlay(self, img, mask, title='', output_path='/'):
    """Show ``img`` with ``mask`` overlaid and save the figure to ``output_path``."""
    self._render(img, title, cmap='gray', mask=mask, output_path=output_path)

  def display_and_save_image(self, img, title='', output_path='/'):
    """Show ``img`` and save the figure to ``output_path``."""
    self._render(img, title, output_path=output_path)

  @staticmethod
  def _union_mask(masks, index_tensors):
    """Merge selected instance masks into one integer mask of 0s and 255s.

    Args:
      masks: Tensor of stacked instance masks, indexed by instance on dim 0.
      index_tensors: Sequence of 1-D index tensors selecting instances. An
        empty sequence (or only empty tensors) yields an all-zero mask.
    """
    if len(index_tensors) > 0:
      selected = torch.cat(list(index_tensors), dim=0)
    else:
      # torch.cat rejects an empty list, so build an explicit empty selection.
      selected = torch.zeros(0, dtype=torch.long, device=masks.device)
    return torch.any(masks[selected], dim=0).int() * 255

  def run_detection_segmentation(self):
    """Run detection on the input image and show the predictions.

    Also stores ``self.masks``, ``self.boxes`` and ``self.detected_objects`` and
    writes a mask of all detections to '/content/all-detected-objects-masks.jpg'.

    Returns:
      ``(masks, detected_objects, pred_img)``, or ``(None, None, None)`` when
      the model returned no masks.
    """
    self.display_image(self.image, title="Original Image")

    results = self.get_results(self.input_image_path)
    result = results[0] # since we only pass a single image to segment, there is only one result

    pred_array = result.plot()  # BGR numpy array of predictions
    pred_img = Image.fromarray(pred_array[..., ::-1])  # Convert to RGB PIL image

    self.display_image(pred_img, title="Predictions")

    # Get the lists of masks and bounding boxes:
    if result.masks is None:
      return None, None, None
    masks = result.masks.data
    boxes = result.boxes.data

    # Column 5 of the box data is what this class compares against object ids.
    detected_objects = boxes[:, 5]

    # Save the masks and boxes:
    self.masks = masks
    self.boxes = boxes
    self.detected_objects = detected_objects

    # Extract a mask with all detected objects:
    all_indices = torch.where(detected_objects != -1)[0]
    self.save_image(self._union_mask(masks, [all_indices]), '/content/all-detected-objects-masks.jpg')

    return masks, detected_objects, pred_img

  def get_mask(self, objects=None):
    """Build one mask covering every detected instance of the requested classes.

    Args:
      objects: Iterable of object class ids; defaults to ``self.OBJECTS``. An
        empty iterable, or ids that were not detected, give an all-zero mask.

    Returns:
      Integer tensor with 255 on the selected objects and 0 elsewhere, or
      ``None`` when the model returned no masks at all.
    """
    if objects is None:
      objects = self.OBJECTS

    masks, detected_objects, predictions_img = self.run_detection_segmentation()
    if masks is None:
      return None

    # Resize the image to the mask resolution once: all masks share the same
    # height and width, so this does not depend on the object class.
    mask_height, mask_width = masks.shape[-2:]
    resized_img = cv2.resize(self.image, (mask_width, mask_height), interpolation=cv2.INTER_CUBIC)

    object_indices = []
    for class_id in objects:
      class_indices = torch.where(detected_objects == class_id)[0]
      object_indices.append(class_indices)

      # Mask for all instances of this object type:
      class_mask = self._union_mask(masks, [class_indices]).cpu().numpy()
      # Convert into a mask that can be directly applied to an image:
      actual_mask = np.ma.masked_where(class_mask == 0, class_mask)

      # Plot the input image, the mask overlayed on the image, and the image after the mask is applied:
      fig, ax = plt.subplots(nrows=1, ncols=3, tight_layout=True)
      ax[0].set_title("Original Image")
      ax[0].axis('off')
      ax[0].imshow(self.image, 'gray', interpolation='none')

      # Mask overlay:
      ax[1].set_title(f'{class_id} Mask Overlay')
      ax[1].axis('off')
      ax[1].imshow(resized_img, 'gray', interpolation='none')
      ax[1].imshow(class_mask, 'jet', interpolation='none', alpha=0.5)

      # Mask applied to image:
      ax[2].set_title(f'{class_id} Mask Directly Applied')
      ax[2].axis('off')
      ax[2].imshow(resized_img, 'gray', interpolation='none')
      ax[2].imshow(actual_mask, 'jet', interpolation='none', alpha=0.5)
      fig.show()

    # Extract a single mask that contains all segmentations of specified object types:
    object_mask = self._union_mask(masks, object_indices)
    self.save_image(object_mask, '/content/objects-to-remove-masks.jpg')

    self.display_image(object_mask.cpu().numpy(), title="Mask of Objects to Remove")
    self.display_and_save_mask_overlay(resized_img, object_mask, title="Final Mask Applied to Image", output_path='/content/object_removal_segmentation.jpg')

    # Plot the input image, the object detection predictions, and the extracted mask:
    fig, ax = plt.subplots(nrows=1, ncols=3, tight_layout=True)
    ax[0].set_title("Original Image")
    ax[0].axis('off')
    ax[0].imshow(self.image, cmap='gray', interpolation='none')

    ax[1].set_title("Predictions")
    ax[1].axis('off')
    ax[1].imshow(predictions_img, cmap='gray', interpolation='none')

    ax[2].set_title("Masks")
    ax[2].axis('off')
    ax[2].imshow(object_mask.cpu().numpy(), cmap='gray', interpolation='none')
    plt.savefig('/content/predictions_and_mask.jpg', bbox_inches='tight', pad_inches = 0)
    fig.show()

    return object_mask

  def load_config(self):
    """Set ``self.OBJECTS``, the class ids used when ``get_mask`` gets none.

    The parser is always given an empty argument list, so ``--remove`` is never
    supplied here and the fallback ``[0]`` is used.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--remove', nargs= '*' ,type=int, help='objects to remove')
    args = parser.parse_args(args=[])

    self.OBJECTS = args.remove if args.remove is not None else [0] # default to removing people

## Install dependency for SSIM and MS-SSIM metrics:

In [7]:
# Installs the package that provides the SSIM / MS-SSIM metric classes imported by the evaluation code.
# The version is not pinned; the recorded run of this cell installed torchmetrics 1.3.2.
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.3.2-py3-none-any.whl (841 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m841.5/841.5 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.2-py3-none-any.whl (26 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.11.2 torchmetrics-1.3.2


## Evaluation Code:
Instructions:
1. Evaluate on folder of images:

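Set the `default` of `--image_folder_path` in `load_args()` to the path of the image folder. A non-empty folder path takes precedence over the single-image modes below.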

2. Evaluate on single image:

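Set the `default` of `--image_path` in `load_args()` to the path of the image, set the `default` of `--image_folder_path` to `''`, and use the `--remove` argument without a `default` (so that it is `None`); otherwise the object-removal mode (3) runs instead.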

3. Evaluate single image with object removal:

Same as (2), but specify --remove in load_args by adding default=[list of object ids]. See object_classes.txt for a list of supported objects and their ids.

In [29]:
import os
import argparse
import torch

from torch.autograd import Variable
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from PIL import Image

def load_args():
    """Build the evaluation configuration, print it and return it.

    The parser is always fed an empty argument list, so the returned namespace
    holds exactly the defaults written below. To change the configuration,
    edit a default here (alternative values are kept as comments).
    """
    option_table = [
        # Folder of input images (use default='' to evaluate a single image instead):
        ("--image_folder_path", dict(type=str, default='/content/sample_images', help="Path to folder of input images")),

        # Single input image (only used when --image_folder_path is ''). Other samples:
        #   '/content/sample_images_object_removal/cliff.jpg'
        #   '/content/sample_images_object_removal/giraffe_zebra.jpg'
        #   '/content/sample_images_object_removal/tree.jpg'
        ("--image_path", dict(type=str, default='/content/sample_images_object_removal/beach.jpg', help="Path to image")),

        # Model checkpoint to evaluate. Other downloaded checkpoints:
        #   '/content/Checkpoints/model_baseline_epoch_39.pth'
        #   '/content/Checkpoints/model_L2_SSIM_epoch_39.pth'
        ("--model_checkpoint", dict(type=str, default='/content/Checkpoints/model_0.5L1_0.5SSIM_epoch_39.pth', help="name of the model checkpoint file to use for evaluation")),

        # Object ids to remove automatically (drop `default` to disable object removal):
        ("--remove", dict(default=[0, 2], nargs='*', type=int, help='objects to remove')),

        # -1 means "derive from the number of images" (explicit alternatives: 64 and 8):
        ("--batch_size", dict(type=int, default=-1, help="size of the batches (grid when evaluating folder of images)")),
        ("--num_cols", dict(type=int, default=-1, help="Number of images per column in output grid (use with folder of images)")),

        # Do not change img_size, mask_size, or channels:
        ("--img_size", dict(type=int, default=128, help="size of each image dimension")),
        ("--mask_size", dict(type=int, default=64, help="size of random mask")),
        ("--channels", dict(type=int, default=3, help="Number of image channels")),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    # An explicit empty list keeps argparse from reading the notebook kernel's own argv.
    args = parser.parse_args(args=[])
    print(args)
    return args

# Read the configuration and make sure the output folder exists.
args = load_args()
os.makedirs("inpainting_results", exist_ok=True)

# Choose the tensor type matching the available hardware.
cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# Build the generator and restore its weights from the chosen checkpoint.
generator = Generator(channels=args.channels)
if cuda:
    generator.cuda()
    checkpoint = torch.load(args.model_checkpoint)
else:
    # On a CPU-only machine the stored tensors have to be mapped onto the CPU.
    checkpoint = torch.load(args.model_checkpoint, map_location=torch.device('cpu'))
generator.load_state_dict(checkpoint["generator_state_dict"])
generator.eval()  # inference mode for the BatchNorm layers

# Mask pipeline: resize to the model input size and convert to a tensor (no normalization).
mask_transforms_ = [
    transforms.Resize((args.img_size, args.img_size), Image.BICUBIC),
    transforms.ToTensor()
]
mask_transforms = transforms.Compose(mask_transforms_) # Callable transformation

# Image pipeline: same resize, then Normalize with mean 0.5 / std 0.5 per channel.
eval_transforms_ = [
    transforms.Resize((args.img_size, args.img_size), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
eval_transforms = transforms.Compose(eval_transforms_) # Callable transformation

# Pixel-wise reconstruction losses used as evaluation metrics.
pixelwise_loss_L2 = torch.nn.MSELoss() # L2 Loss
pixelwise_loss_L1 = torch.nn.L1Loss()  # L1 loss
loss_functions = [("L2 Loss", torch.nn.MSELoss()), ("L1 Loss", torch.nn.L1Loss())]

# PSNR = 10 * log_10( (max_dtype_value)^2 / MSE(gen_img, img))
def get_psnr(gen_img, img, data_range=1.0):
    """Peak signal-to-noise ratio between ``gen_img`` and ``img``, in dB.

    Computes ``10 * log10(data_range ** 2 / MSE(gen_img, img))``.

    Args:
        gen_img: Generated image tensor.
        img: Reference image tensor of the same shape.
        data_range: Peak value used in the formula. The default of 1.0 keeps
            the previously reported numbers; pass the real value range of the
            tensors (for example 2.0 for data spanning [-1, 1]) if needed.

    Returns:
        A scalar tensor. A perfect reconstruction (MSE of 0) is capped at 100
        and is also returned as a tensor, so callers can always use ``.item()``.
    """
    mse = torch.mean((gen_img - img) ** 2)
    if mse == 0:  # Perfect reconstruction
        return mse.new_tensor(100.0)
    return 10 * torch.log10(data_range ** 2 / mse)

# SSIM:
def get_ssim(gen_img, img):
    """Structural similarity between ``gen_img`` and ``img`` (B x C x H x W tensors).

    A new torchmetrics SSIM metric with ``data_range=1.0`` is created on every
    call and moved to the device of ``gen_img``, so the function works for both
    CPU and CUDA inputs.
    """
    from torchmetrics.image import StructuralSimilarityIndexMeasure
    SSIM = StructuralSimilarityIndexMeasure(data_range=1.0).to(gen_img.device)
    return SSIM(gen_img, img)

# Multi-scale SSIM (currently gives NaN, so it is not in use):
def get_ms_ssim(gen_img, img):
    """Multi-scale structural similarity between ``gen_img`` and ``img``.

    A new torchmetrics MS-SSIM metric is created on every call and moved to the
    device of ``gen_img``, so the function works for both CPU and CUDA inputs.

    NOTE(review): this metric was reported to return NaN and is not used by the
    evaluation. ``normalize=None`` is a plausible cause to check against the
    torchmetrics documentation before re-enabling it.
    """
    from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
    # Default kernel_size=11 is too big for 128x128 images
    MS_SSIM = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0, kernel_size=7, normalize=None).to(gen_img.device)
    return MS_SSIM(gen_img, img)

# Shared SSIM metric instance used by the evaluation branches below.
# NOTE(review): eval_transforms_ applies Normalize with mean 0.5 / std 0.5 after ToTensor,
# so image tensors would span [-1, 1] rather than [0, 1]; confirm that data_range=1.0
# is the intended setting for the reported SSIM values.
from torchmetrics.image import StructuralSimilarityIndexMeasure
SSIM_LossFn = StructuralSimilarityIndexMeasure(data_range=1.0) # images are normalized
# Keep the metric on the same device as the tensors it will receive.
if cuda:
    SSIM_LossFn.cuda()

# By default, apply center masking. If objects to remove are specified, detect the given objects and delete them
if args.image_folder_path: # Works for folder of jpgs (also recursively takes images in subfolders)
    # Folder mode: center-mask every .jpg, inpaint it, save result grids for the
    # first batch and report per-image and mean L1 / L2 / PSNR / SSIM.

    # Get number of images in input folder:
    num_images = len(glob.glob("%s/**/*.jpg" % args.image_folder_path, recursive=True))
    print("Number of input images:", num_images)
    if num_images == 0:
        # Fail early with a clear message instead of an invalid batch size or a
        # division by zero further down.
        raise ValueError("No .jpg images found under '{}'".format(args.image_folder_path))

    if args.batch_size == -1:
        args.batch_size = num_images
        print("==> Batch size not specified. Using total number of images: {}".format(num_images))

    # If the number of columns is not specified, take the sqrt of the number of images
    # to produce a square grid
    if args.num_cols == -1:
        import math
        args.num_cols = math.ceil(math.sqrt(args.batch_size))
        print("==> Columns not specified. Using ceil(sqrt(num_images)) = {}".format(args.num_cols))

    # Load the input image folder:
    test_dataloader = DataLoader(
      ImageDataset(args.image_folder_path, transforms_=eval_transforms_, mode="val"),
      batch_size=args.batch_size,
      shuffle=False,
      num_workers=1,
    )

    # Initialize loss counters
    total_L2 = 0
    total_L1 = 0
    total_PSNR = 0
    total_SSIM = 0

    for batch_idx, (samples, masked_samples, i) in enumerate(test_dataloader):
        samples = Variable(samples.type(Tensor))
        masked_samples = Variable(masked_samples.type(Tensor))
        i = i[0].item()  # Upper-left coordinate of mask

        # Generate inpainted images (no gradients are needed for evaluation):
        with torch.no_grad():
            gen_img = generator(masked_samples)
        gen_mask = gen_img[:, :, i : i + args.mask_size, i : i + args.mask_size]
        # Only the masked square is taken from the generator; the rest stays untouched.
        filled_samples = masked_samples.clone()
        filled_samples[:, :, i : i + args.mask_size, i : i + args.mask_size] = gen_mask

        # Save output images (currently, for the first batch only):
        #   (1) Grid of input, masked input, and inpainted output
        #   (2) Grid of input, masked input, generated images, and inpainted output
        #   (3) Grid of input images
        #   (4) Grid of masked input images
        #   (5) Grid of the resulting inpainted images
        if batch_idx == 0:
            # Save grid of input, masked input, and inpainted output (originally, nrow=8):
            sample = torch.cat((samples.data, masked_samples.data, filled_samples.data), -2)
            save_image(sample, "inpainting_results/inpainted_grid.png", nrow=args.num_cols, normalize=True)
            # Save grid of input, masked input, generated image, and inpainted output:
            sample = torch.cat((samples.data, masked_samples.data, gen_img.data, filled_samples.data), -2)
            save_image(sample, "inpainting_results/inpainted_grid_with_gen_img.png", nrow=args.num_cols, normalize=True)
            # Save grid of input images:
            save_image(samples, "inpainting_results/original_grid.png", nrow=args.num_cols, normalize=True)
            # Save grid of masked input images:
            save_image(masked_samples, "inpainting_results/masked_grid.png", nrow=args.num_cols, normalize=True)
            # Save grid of inpainted images:
            save_image(filled_samples, "inpainting_results/inpainted_images_grid.png", nrow=args.num_cols, normalize=True)

        # Get losses for each individual image (top to bottom, left to right):
        for idx, (sample, filled_sample) in enumerate(zip(samples, filled_samples)):
            image_idx = batch_idx * args.batch_size + idx

            # Reshape into BxCxHxW (B = 1 since the batch size is a single image here)
            # This is needed for torch SSIM loss
            sample = sample.unsqueeze(0)
            filled_sample = filled_sample.unsqueeze(0)

            # Calculate losses for the current image:
            L2_Loss = pixelwise_loss_L2(filled_sample, sample)
            L1_Loss = pixelwise_loss_L1(filled_sample, sample)
            PSNR_Loss = get_psnr(filled_sample, sample)
            SSIM_Loss = SSIM_LossFn(filled_sample, sample)
            print("==> Calculating Reconstruction Loss {}:".format(image_idx))
            print("L1: {0:.4f}".format(L1_Loss.item()))
            print("L2: {0:.4f}".format(L2_Loss.item()))
            print("PSNR: {0:.4f}".format(PSNR_Loss.item())) # NOTE THAT PSNR AND L2 ARE RELATED
            print("SSIM: {0:.4f}".format(SSIM_Loss.item()))

            total_L2 += L2_Loss.item()
            total_L1 += L1_Loss.item()
            total_PSNR += PSNR_Loss.item()
            total_SSIM += SSIM_Loss.item()

    # Calculate average loss:
    mean_L2 = total_L2 / num_images
    mean_L1 = total_L1 / num_images
    mean_PSNR = total_PSNR / num_images
    mean_SSIM = total_SSIM / num_images
    print("==> Calculating average loss statistics:")
    print("Mean L1: {0:.4f}".format(mean_L1))
    print("Mean L2: {0:.4f}".format(mean_L2))
    print("Mean PSNR: {0:.4f}".format(mean_PSNR))
    print("Mean SSIM: {0:.4f}".format(mean_SSIM))

elif args.remove is None:
    img = eval_transforms(img)
    """Mask center part of image"""
    # Get upper-left pixel coordinate
    i = (args.img_size - args.mask_size) // 2
    masked_img = img.clone()
    masked_img[:, i : i + args.mask_size, i : i + args.mask_size] = 1

    # unsqueeze 0 to add extra dimension at position 0 (needed for generator)
    # (Reshape to BxCxHxW)
    img = Variable(img.type(Tensor)).unsqueeze(0)
    masked_img = Variable(masked_img.type(Tensor)).unsqueeze(0)

    # Generated the inpainted image:
    gen_img = generator(masked_img)
    gen_mask = gen_img[:, :, i : i + args.mask_size, i : i + args.mask_size]
    inpainted_img = masked_img.clone()
    inpainted_img[:, :, i : i + args.mask_size, i : i + args.mask_size] = gen_mask

    # Calculate loss metrics between original image and output image:
    L2_Loss = pixelwise_loss_L2(inpainted_img, img)
    L1_Loss = pixelwise_loss_L1(inpainted_img, img)
    PSNR_Loss = get_psnr(inpainted_img, img)
    SSIM_Loss = SSIM_LossFn(inpainted_img, img)
    print("==> Calculating Reconstruction Loss:")
    print("L1: {0:.4f}".format(L1_Loss.item()))
    print("L2: {0:.4f}".format(L2_Loss.item()))
    print("PSNR: {0:.4f}".format(PSNR_Loss.item())) # NOTE THAT PSNR AND L2 ARE RELATED
    print("SSIM: {0:.4f}".format(SSIM_Loss.item()))

    # print("==> Calculating Reconstruction Loss:")
    # for (name, pixelwise_loss) in loss_functions:
    #     loss = pixelwise_loss(inpainted_img, img)
    #     print("{}: {}".format(name, loss))

    # Save sample
    sample = torch.cat((img.data, masked_img.data, inpainted_img.data), 0)
    output_file = args.image_path.split("/")[-1]
    save_image(sample, "inpainting_results/inpainted_%s" % output_file, nrow=3, normalize=True) #nrow=3 since the sample has 3 images

else: # If we are detecting / removing objects:

    # Detect objects and extract mask:
    img = Image.open(args.image_path)
    img = np.array(img)
    seg = Segmentor(args.image_path)
    mask = seg.get_mask(args.remove)
    if mask is None:
        print("==> No objects detected. Exiting program.")
    else:
        mask = mask.detach().cpu().numpy().astype(float)

        # Resize image to mask size:
        import cv2
        mask_height, mask_width = mask.shape[:2]
        resized_img = cv2.resize(img, (mask_width, mask_height), interpolation=cv2.INTER_CUBIC)
        masked_img = resized_img.copy()
        masked_img[mask != 0] = 255

        # Convert back into PIL Images before applying torch transforms:
        img = Image.fromarray(img)
        mask = Image.fromarray(mask)
        masked_img = Image.fromarray(masked_img)

        # Apply transforms to input image and mask:
        img = eval_transforms(img)
        mask = mask_transforms(mask)
        masked_img = eval_transforms(masked_img)

        # Convert to Tensors and reshape from CxHxW to BxCxHxW:
        img = Variable(img.type(Tensor)).unsqueeze(0)
        mask = Variable(mask.type(Tensor)).unsqueeze(0)
        masked_img = Variable(masked_img.type(Tensor)).unsqueeze(0)

        # Generate image on masked image:
        gen_img = generator(masked_img)

        # Convert mask to have 3D channels:
        mask = torch.cat((mask, mask, mask), dim=1)

        # Find generated mask and inpaint this area onto a cloned version of the masked image:
        mask_indices = (mask != 0).nonzero(as_tuple=True) # get new indices since they are tensors with extra dim?
        gen_mask = gen_img[mask_indices]
        inpainted_img = masked_img.clone()
        inpainted_img[mask_indices] = gen_mask

        # NOTE: Evaluating loss and reconstruction metrics won't work here
        # since we have no ground truth image.

        # Save sample:
        output_file = args.image_path.split("/")[-1]
        sample = torch.cat((img.data, masked_img.data, inpainted_img.data), 0)
        save_image(sample, "inpainting_results/inpainted_%s" % output_file, nrow=3, normalize=True)
        sample = torch.cat((img.data, masked_img.data, gen_img.data, inpainted_img.data), 0)
        save_image(sample, "inpainting_results/inpainted_with_gen_img_%s" % output_file, nrow=4, normalize=True)
        save_image(mask.data, "inpainting_results/inpainted_mask%s" % output_file, nrow=1, normalize=True)


Namespace(image_folder_path='/content/sample_images', image_path='/content/sample_images_object_removal/beach.jpg', model_checkpoint='/content/Checkpoints/model_0.5L1_0.5SSIM_epoch_39.pth', remove=[0, 2], batch_size=-1, num_cols=-1, img_size=128, mask_size=64, channels=3)
Number of input images: 64
==> Batch size not specified. Using total number of images: 64
==> Columns not specified. Using ceil(sqrt(num_images)) = 8
==> Calculating Reconstruction Loss 0:
L1: 0.0455
L2: 0.0148
PSNR: 18.2918
SSIM: 0.7542
==> Calculating Reconstruction Loss 1:
L1: 0.0560
L2: 0.0213
PSNR: 16.7069
SSIM: 0.7128
==> Calculating Reconstruction Loss 2:
L1: 0.0572
L2: 0.0288
PSNR: 15.4038
SSIM: 0.7749
==> Calculating Reconstruction Loss 3:
L1: 0.0498
L2: 0.0207
PSNR: 16.8306
SSIM: 0.7852
==> Calculating Reconstruction Loss 4:
L1: 0.0842
L2: 0.0573
PSNR: 12.4159
SSIM: 0.7280
==> Calculating Reconstruction Loss 5:
L1: 0.0859
L2: 0.0474
PSNR: 13.2443
SSIM: 0.7118
==> Calculating Reconstruction Loss 6:
L1: 0.0670