<a href="https://colab.research.google.com/github/peace-and-harmony/image-matting/blob/main/notebooks/modnet_quick_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MODNet quick inference

## 1. Preparation

In [1]:
%%capture
# The %%capture magic above hides everything this cell prints.
# Clone this project's repository; the inference cell loads its checkpoint from
# /content/image-matting/pretrained_weights/.
!git clone https://github.com/peace-and-harmony/image-matting.git
# Clone the upstream MODNet sources; the inference cell imports the MODNet model class from them.
!git clone https://github.com/ZHKKKe/MODNet.git
# Upgrade Pillow to the newest available release.
!pip install --upgrade pillow

## 2. Upload Images

<p align="justify">Upload clothing images to extract the foreground object (only PNG and JPG formats are supported):</p>

In [None]:
import shutil
from google.colab import files
import os

# Start every run from empty input/output folders: anything left over from a
# previous run is removed before the folder is created again.
input_folder = '/content/input'
output_folder = '/content/output/'
for folder in (input_folder, output_folder):
  if os.path.exists(folder):
    shutil.rmtree(folder)
  os.makedirs(folder)

# Ask the user for images (PNG or JPG) through the Colab upload widget, then
# move each uploaded file from the working directory into the input folder.
image_names = list(files.upload())
for image_name in image_names:
  shutil.move(image_name, os.path.join(input_folder, image_name))

## 3. Quick inference

In [None]:
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from MODNet.src.models.modnet import MODNet

if __name__ == '__main__':
  # Image -> tensor transform: ToTensor scales pixel values to [0, 1] and the
  # Normalize step (mean 0.5, std 0.5 per channel) maps them to [-1, 1].
  im_transform = transforms.Compose(
      [
          transforms.ToTensor(),
          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
      ]
  )

  # Choose the device once; the model and every input batch are moved to it.
  if torch.cuda.is_available():
    device = torch.device('cuda')
    print('using gpu!')
  else:
    device = torch.device('cpu')

  # create MODNet and load the pre-trained ckpt
  # NOTE(review): the model is wrapped in nn.DataParallel before loading, so the
  # checkpoint presumably stores 'module.'-prefixed keys -- confirm if it changes.
  state = torch.load('/content/image-matting/pretrained_weights/checkpoint.pth', map_location=device)
  modnet = MODNet(backbone_pretrained=False)
  modnet = nn.DataParallel(modnet)
  modnet.load_state_dict(state['state_dict'])
  # On a GPU runtime DataParallel refuses to run a module whose parameters are
  # still on the CPU, so the model has to be moved to the device explicitly.
  modnet = modnet.to(device)
  # Evaluation mode: layers such as batch norm / dropout stop using
  # training-time behaviour, which is required for correct inference.
  modnet.eval()

  # inference images
  im_names = os.listdir(input_folder)
  for im_name in im_names:
    print('Process image: {0}'.format(im_name))

    # Read the image and unify it to 3-channel RGB. Letting Pillow convert also
    # handles palette and grey+alpha files; the context manager closes the file.
    with Image.open(os.path.join(input_folder, im_name)) as im_file:
      im = im_file.convert('RGB')

    # convert image to PyTorch tensor and add mini-batch dim
    im = im_transform(im)[None, :, :, :]
    # remember the original resolution so the matte can be scaled back to it
    im_h, im_w = im.shape[-2:]

    # resize image for input
    im_rh, im_rw = (512, 512)
    im = F.interpolate(im, size=(im_rh, im_rw), mode='area')

    # inference (no autograd graph is needed, which saves memory)
    with torch.no_grad():
      _, _, matte = modnet(im.to(device), False)
      # resize matte back to the original image size
      matte = F.interpolate(matte, size=(im_h, im_w), mode='area')

    # save matte as an 8-bit greyscale PNG
    matte = matte[0][0].cpu().numpy()
    # The output name is everything before the first dot plus '.png'; the
    # visualization loop looks mattes up with the same rule, so keep them in sync.
    matte_name = im_name.split('.')[0] + '.png'
    Image.fromarray(((matte * 255).astype('uint8')), mode='L').save(os.path.join(output_folder, matte_name))

def combined_display(image, matte):
  """Build a three-panel preview: foreground on white | matte | original.

  Args:
    image: picture exposing `.width` / `.height` that `np.asarray` turns into
      a 2-D array or an array with 1, 3 or 4 channels on the last axis.
    matte: alpha map on a 0-255 scale that `np.asarray` turns into a 2-D
      array with the same height and width as `image`.

  Returns:
    The three panels joined left to right, as a PIL image resized to a width
    of 800 pixels and a height that keeps the strip's aspect ratio.
  """
  # preview size: fixed width, height derived from the 3-panel aspect ratio
  out_w = 800
  out_h = int(image.height * 800 / (3 * image.width))

  # bring the picture to exactly three channels
  rgb = np.asarray(image)
  if rgb.ndim == 2:
    rgb = rgb[:, :, None]
  channels = rgb.shape[2]
  if channels == 1:
    rgb = np.repeat(rgb, 3, axis=2)
  elif channels == 4:
    rgb = rgb[:, :, 0:3]

  # alpha scaled to [0, 1] and copied across the three channels
  alpha = np.repeat(np.asarray(matte)[:, :, None], 3, axis=2) / 255
  # blend the picture over a white background using the alpha map
  white = np.full(rgb.shape, 255)
  foreground = rgb * alpha + white * (1 - alpha)

  # foreground, matte and original side by side, then shrink to preview size
  strip = np.concatenate((foreground, alpha * 255, rgb), axis=1)
  return Image.fromarray(np.uint8(strip)).resize((out_w, out_h))

# Show a preview for every input image next to the matte saved for it.
image_names = os.listdir(input_folder)
for image_name in image_names:
  # mattes are stored under the part of the file name before the first dot
  matte_name = '{}.png'.format(image_name.split('.')[0])
  image_path = os.path.join(input_folder, image_name)
  matte_path = os.path.join(output_folder, matte_name)
  image = Image.open(image_path)
  matte = Image.open(matte_path)
  display(combined_display(image, matte))
  print(image_name, '\n')
# When running inference on a GPU, the runtime needs to be restarted before re-running this cell.