# Generate trimaps from U-2-Net alpha matte

This is the 3rd step in our workflow to remove the background from an image:

1. Use U-2-Net pre-trained model to generate a first alpha matte
2. **Use the U-2-Net alpha matte to generate a trimap
3. **Train MODNet model with the original image, the trimap and ground truth image from DUTS dataset** (the current colab notebook)

## Sources:
* 


# Import

In [None]:
# import modules to handle files
import os
import glob
import shutil
from google.colab import drive
from PIL import Image

# import modules to train models
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

import pandas as pd
import numpy as np

# Mount Google Drive

In [None]:
drive.mount('/content/drive/')

Mounted at /content/drive/


# Clone MODNet repo & download pre-trained model

In [None]:
# clone the repository
%cd /content
if not os.path.exists('MODNet'):
  !git clone https://github.com/ZHKKKe/MODNet
%cd MODNet/

/content
Cloning into 'MODNet'...
remote: Enumerating objects: 213, done.[K
remote: Counting objects: 100% (7/7), done.[K
remote: Compressing objects: 100% (7/7), done.[K
remote: Total 213 (delta 1), reused 0 (delta 0), pack-reused 206[K
Receiving objects: 100% (213/213), 37.62 MiB | 32.84 MiB/s, done.
Resolving deltas: 100% (62/62), done.
/content/MODNet


In [None]:
# copy pre-trained model to directory
current_path = "/content/drive/MyDrive/Faktion/exploration/pretrained_models/mobilenetv2_human_seg.ckpt"
dst_path = "/content/MODNet/pretrained/mobilenetv2_human_seg.ckpt"
shutil.copy(current_path, dst_path)

'/content/MODNet/pretrained/mobilenetv2_human_seg.ckpt'

# Train model to predict alpha matte

## Functions

In [None]:
def process_image(image_path):
  """
  Function to process image into the input format required
  for model
  """
  
  # read image
  im = Image.open(image_path)


  # define image to tensor transform
  im_transform = transforms.Compose(
      [
          transforms.ToTensor(),
          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
      ]
  )

  # unify image channels to 3
  im = np.asarray(im)
  if len(im.shape) == 2:
      im = im[:, :, None]
  if im.shape[2] == 1:
      im = np.repeat(im, 3, axis=2)
  elif im.shape[2] == 4:
      im = im[:, :, 0:3]

  # convert image to PyTorch tensor
  im = Image.fromarray(im)
  im = im_transform(im).cuda()

  # add mini-batch dim
  im = im[None, :, :, :]

  # resize image for input
  im_b, im_c, im_h, im_w = im.shape

  if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
      if im_w >= im_h:
          im_rh = ref_size
          im_rw = int(im_w / im_h * ref_size)
      elif im_w < im_h:
          im_rw = ref_size
          im_rh = int(im_h / im_w * ref_size)
  else:
      im_rh = im_h
      im_rw = im_w

  im_rw = im_rw - im_rw % 32
  im_rh = im_rh - im_rh % 32
  im = F.interpolate(im, size=(im_rh, im_rw), mode='area')

  return im

In [19]:
def get_image_paths(image_dir_list: list) -> pd.DataFrame:
  """
  Function to get a dataframe with the different paths for each image
  """
  # create empty list
  image_df_list = [] # will contain a dataframe for each directory

  # create dataframes with list of image paths per directory
  for image_dir in image_dir_list:
    image_path_list = glob.glob(image_dir + os.sep + '*')
    image_df = pd.DataFrame(image_path_list, columns=[image_dir])

    # create a filename column by extracting the image filename without extension
    image_df["filename"] = image_df[image_dir].str.split(os.sep).str[-1].str.split(".").str[0]

    # append dataframe to dataframe list
    image_df_list.append(image_df)
    print(image_df.shape)

  # merge dataframes on the image filename
  df = pd.merge(image_df_list[0], image_df_list[1], on="filename")
  print(df.shape)
  df = pd.merge(df, image_df_list[2], on="filename")
  print(df.shape)

  return df

In [42]:
def dataloader():
  """
  Function to load the model with the paths of:
  * original images
  * trimaps
  * ground truth images
  """
  df = get_image_paths(paths_list)
  paths_list = []
  for column_name in df.columns:
    if column_name != "filename":
      paths = df[column_name].to_list()
      paths_list.append(paths)
    
    image_list, trimap_list, gt_path_list = tuple(paths_list)
  return image_list, trimap_list, gt_path_list

In [44]:
paths_list = []
for column_name in df.columns:
  if column_name != "filename":
    print(column_name)
    paths = df[column_name].to_list()
    paths_list.append(paths)

len(paths_list)


/content/drive/MyDrive/Faktion/DUTS/DUTS-TR/DUTS-TR-Image
/content/drive/MyDrive/Faktion/DUTS/DUTS-TR/DUTS-TR-Trimap
/content/drive/MyDrive/Faktion/DUTS/DUTS-TR/DUTS-TR-Mask


3

In [46]:
zip(paths_list)

<zip at 0x7fbbff0aacd0>

In [None]:
# get list of image directory paths
src_dir = "/content/drive/MyDrive/Faktion/DUTS/DUTS-TR"
image_dir_list = ["DUTS-TR-Image", "DUTS-TR-Trimap", "DUTS-TR-Mask"]

## convert paths into absolute paths
image_dir_list = [os.path.join(src_dir, image_dir) for image_dir in image_dir_list]

['/content/drive/MyDrive/Faktion/DUTS/DUTS-TR/DUTS-TR-Image',
 '/content/drive/MyDrive/Faktion/DUTS/DUTS-TR/DUTS-TR-Trimap',
 '/content/drive/MyDrive/Faktion/DUTS/DUTS-TR/DUTS-TR-Mask']

In [25]:
# df = get_image_paths(image_dir_list)
df.loc.head()

TypeError: ignored

In [43]:
for idx, (image, trimap, gt_matte) in enumerate(dataloader(df)):
  print(idx, image, trimap, gt_matte)

ValueError: ignored

In [None]:
from natsort import natsorted, ns

def dataloader(file_location='/content/drive/MyDrive/Faktion/DUTS/DUTS-TR/'):
    path = file_location
    image_path = os.listdir(path + 'DUTS-TR-Image')
    trimap_path = os.listdir(path + 'DUTS-TR-Trimap')
    gt_path = os.listdir(path + 'DUTS-TR-Mask')

    try:
      image_path.remove(".ipynb_checkpoints")
    except:
      pass
    try:
      trimap_path.remove(".ipynb_checkpoints")
    except:
      pass
    try:
      gt_path.remove(".ipynb_checkpoints")
    except:
      pass


    # sort in an order
    image_path = natsorted(image_path, alg=ns.IGNORECASE)
    trimap_path = natsorted(trimap_path, alg=ns.IGNORECASE)
    gt_path = natsorted(gt_path, alg=ns.IGNORECASE)


    return zip(image_path[:10], trimap_path[:10], gt_path[:10])

## Train model

In [10]:
# copy modified modnet script that trains the model
current_path = "/content/drive/MyDrive/Faktion/exploration/modnet_trainer_modified.py"
dst_path = "/content/MODNet/src/modnet_trainer_modified.py"
shutil.copy(current_path, dst_path)


# cd to repository
%cd ..
%cd /content/MODNet/

# import local modules
from src.models.modnet import MODNet
from src.modnet_trainer_modified import supervised_training_iter

[Errno 2] No such file or directory: 'MODNet/'
/content/MODNet


In [None]:
bs = 16         # batch size
lr = 0.001       # learn rate
epochs = 50     # total epochs

modnet = torch.nn.DataParallel(MODNet()).cuda()
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)

# define hyper-parameters
ref_size = 512

for epoch in range(0, epochs):

  for idx in df.index[:10]:
  image_list = []
  image_path_list = df.iloc[idx, [0, 2, 3]].to_list()

  for image_path in image_path_list:
    im = process_image(image_path)
    image_list.append(im)

  for idx, (image, trimap, gt_matte) in enumerate(dataloader):
    semantic_loss, detail_loss, matte_loss = \
    supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
  lr_scheduler.step()

## Save & load model
Parameters for PyTorch networks stored in the model's `state_dict`

[Source](https://colab.research.google.com/github/agungsantoso/deep-learning-v2-pytorch/blob/master/intro-to-pytorch/Part%206%20-%20Saving%20and%20Loading%20Models.ipynb#scrollTo=lBqSgQCNpCX4)

In [None]:
print("The trained model: \n\n", modnet, '\n')
print("The state_dict keys: \n\n", modnet.state_dict().keys())

In [None]:
from google.colab import files

# # create dictionary with all information necessary to rebuild the model.
# checkpoint = {'input_size': 784,
#               'output_size': 10,
#               'hidden_layers': [each.out_features for each in model.hidden_layers],
#               'state_dict': modnet.state_dict()}

# # save model's architecture and state_dict
# torch.save(checkpoint, 'model_checkpoint.pth')

# save model's state_dict
torch.save(modnet.state_dict(), 'model_checkpoint.pth')

# download checkpoint file
files.download('model_checkpoint.pth')

In [None]:
# load model's state_dict
state_dict = torch.load('model_checkpoint.pth')
print(state_dict.keys())

# load state_dict into the network (works only if model architecture is the same as checkpoint architecture)
modnet.load_state_dict(state_dict)

In [None]:
# def load_checkpoint(filepath):
#   """
#   Function to load saved model and rebuild it.
#   """
#     checkpoint = torch.load(filepath)
#     model = modnet.Network(checkpoint['input_size'],
#                              checkpoint['output_size'],
#                              checkpoint['hidden_layers'])
#     model.load_state_dict(checkpoint['state_dict'])
    
#     return model
model = load_checkpoint('checkpoint.pth')
print(model)

# Predict alpha matte with trained model

In [None]:
image_path = "/content/U-2-Net/test_data/test_images/0002-01.jpg"
trimap_path = "/content/drive/MyDrive/trimaps/0002-01.png"
im = process_image(image_path)
trimap = process_image(trimap_path)

_, _, matte = modnet(im, True)

# resize and save matte
matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
matte = matte[0][0].data.cpu().numpy()
matte_name = 'test_b_v2.png'

Image.fromarray(((matte * 255).astype('uint8')), mode='L').save(os.path.join("/content", matte_name))

# Performance metrics

## Functions

In [None]:
# compute the MSE error given a prediction, a ground truth and a trimap.
# pred: the predicted alpha matte
# target: the ground truth alpha matte
# trimap: the given trimap
#
def compute_mse(pred, alpha, trimap):
    num_pixels = float((trimap == 127).sum())
    return ((pred - alpha) ** 2).sum() / num_pixels


# compute the SAD error given a prediction and a ground truth.
#
def compute_sad(pred, alpha):
    diff = np.abs(pred - alpha)
    return np.sum(diff) / 1000

In [None]:
# import cv2

gt_matte = cv2.imread("/content/drive/MyDrive/Faktion/exploration/dataset/ground_truths/ILSVRC2012_test_00000018.png")
trimap = cv2.imread("/content/drive/MyDrive/Faktion/exploration/dataset/trimaps/ILSVRC2012_test_00000018.png")
pred = cv2.imread("/content/test.png")

print(compute_mse(pred, gt_matte, trimap))
print(compute_sad(pred, gt_matte))