# Train MODNet model

This is the 3rd step in our workflow to remove the background from an image:

1. Use U-2-Net pre-trained model to generate a first alpha matte
2. **Use the U-2-Net alpha matte to generate a trimap
3. **Train MODNet model with the original image, the trimap and ground truth image from DUTS dataset** (the current colab notebook)

## Sources:
* 


# Import

In [2]:
# import modules to handle files
import os
import glob
import shutil
# from google.colab import drive
from PIL import Image

# import modules to train models
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import pandas as pd
import numpy as np

# Mount Google Drive

In [None]:
# drive.mount('/content/drive/')

Mounted at /content/drive/


# Clone MODNet repo & download pre-trained model

In [3]:
# # clone the repository
# %cd /content
if not os.path.exists('MODNet'):
  !git clone https://github.com/ZHKKKe/MODNet
%cd MODNet/

[Errno 2] No such file or directory: '/content'
/Users/derrickvanfrausum/BeCode_AI/git-repos/Remove_Image_Background/core/exploration
Cloning into 'MODNet'...
remote: Enumerating objects: 213, done.[K
remote: Counting objects: 100% (7/7), done.[K
remote: Compressing objects: 100% (7/7), done.[K
remote: Total 213 (delta 1), reused 0 (delta 0), pack-reused 206[K
Receiving objects: 100% (213/213), 37.62 MiB | 3.11 MiB/s, done.
Resolving deltas: 100% (62/62), done.
/Users/derrickvanfrausum/BeCode_AI/git-repos/Remove_Image_Background/core/exploration/MODNet


In [None]:
# # copy pre-trained model to directory
# current_path = "/content/drive/MyDrive/BeCode/Projects/Faktion/exploration/pretrained_models/mobilenetv2_human_seg.ckpt"
# dst_path = "/content/MODNet/pretrained/mobilenetv2_human_seg.ckpt"
# shutil.copy(current_path, dst_path)

'/content/MODNet/pretrained/mobilenetv2_human_seg.ckpt'

In [4]:
# copy pre-trained model to directory
current_path = "/Users/derrickvanfrausum/BeCode_AI/git-repos/Remove_Image_Background/core/assets/pretrained_models/mobilenetv2_human_seg.ckpt"
dst_path = "/Users/derrickvanfrausum/BeCode_AI/git-repos/Remove_Image_Background/core/exploration/MODNet/pretrained/mobilenetv2_human_seg.ckpt"
shutil.move(current_path, dst_path)

'/Users/derrickvanfrausum/BeCode_AI/git-repos/Remove_Image_Background/core/exploration/MODNet/pretrained/mobilenetv2_human_seg.ckpt'

# Train model to predict alpha matte

In [66]:
global device

# define device
device = "cuda" if torch.cuda.is_available() else "cpu"

## Functions

In [71]:
def process_image(image_path, binary=False):
  """
  Function to process image into the input format required
  for model
  """
  
  # read image
  im = Image.open(image_path)


  # define image to tensor transform
  im_transform = transforms.Compose(
      [
          transforms.ToTensor(),
          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
      ]
  )

  gt_transform = transforms.Compose(
      [
          transforms.ToTensor(),
          transforms.Normalize((0.5), (0.5))
      ]
  )

  # unify image channels to 3
  if binary == False:
      
    im = np.asarray(im)
    if len(im.shape) == 2:
        im = im[:, :, None]
    if im.shape[2] == 1:
        im = np.repeat(im, 3, axis=2)
    elif im.shape[2] == 4:
        im = im[:, :, 0:3]
        
    # convert image to PyTorch tensor
    im = Image.fromarray(im)
    im = im_transform(im).to(device)

  else:
    im = im.convert('1') 
    im = np.asarray(im)

    # convert image to PyTorch tensor
    im = Image.fromarray(im)
    im = gt_transform(im).to(device)


  # add mini-batch dim
  im = im[None, :, :, :]

  # resize image for input
  im_b, im_c, im_h, im_w = im.shape

  if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
      if im_w >= im_h:
          im_rh = ref_size
          im_rw = int(im_w / im_h * ref_size)
      elif im_w < im_h:
          im_rw = ref_size
          im_rh = int(im_h / im_w * ref_size)
  else:
      im_rh = im_h
      im_rw = im_w

  im_rw = im_rw - im_rw % 32
  im_rh = im_rh - im_rh % 32
  im = F.interpolate(im, size=(im_rh, im_rw), mode='area')

  return im

In [6]:
def get_image_paths(image_dir_list: list) -> pd.DataFrame:
  """
  Function to get a dataframe with the different paths for each image
  """
  # create empty list
  image_df_list = [] # will contain a dataframe for each directory

  # create dataframes with list of image paths per directory
  for image_dir in image_dir_list:
    image_path_list = glob.glob(image_dir + os.sep + '*')
    image_df = pd.DataFrame(image_path_list, columns=[image_dir])

    # create a filename column by extracting the image filename without extension
    image_df["filename"] = image_df[image_dir].str.split(os.sep).str[-1].str.split(".").str[0]

    # append dataframe to dataframe list
    image_df_list.append(image_df)
    print(image_df.shape)

  # merge dataframes on the image filename
  df = pd.merge(image_df_list[0], image_df_list[1], on="filename")
  print(df.shape)
  df = pd.merge(df, image_df_list[2], on="filename")
  print(df.shape)

  return df

In [7]:
def dataloader(image_paths_df):
  """
  Function to load the model with the paths of:
  * original images
  * trimaps 
  * ground truth images
  """

  paths_list = []
  print(image_paths_df.columns)
  for column_name in image_paths_df.columns:
    if column_name != "filename":
      print(column_name)
      paths = image_paths_df[column_name].to_list()
      paths_list.append(paths)
    
  image_list, trimap_list, gt_path_list = paths_list[0], paths_list[1], paths_list[2]

  return zip(image_list, trimap_list, gt_path_list)

## Train model

In [63]:
## copy modified modnet script that trains the model
# current_path = "/content/drive/MyDrive/BeCode/Projects/Faktion/exploration/modnet_trainer_modified.py"
# dst_path = "/content/MODNet/src/modnet_trainer_modified.py"
# shutil.copy(current_path, dst_path)


# cd to repository
%cd /Users/derrickvanfrausum/BeCode_AI/git-repos/Remove_Image_Background/core/exploration/MODNet

# import local modules
from MODNet.src.models.modnet import MODNet
from MODNet.src.trainer_modified import supervised_training_iter

/Users/derrickvanfrausum/BeCode_AI/git-repos/Remove_Image_Background/core/exploration/MODNet


In [None]:
# importing and activating TensorBoard

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [19]:
# get list of image directory paths
src_dir = "/Users/derrickvanfrausum/BeCode_AI/git-repos/Remove_Image_Background/core/assets/dataset/DUTS/DUTS-TR"
image_dir_list = ["DUTS-TR-Image", "DUTS-TR-Trimap", "DUTS-TR-Mask"]

## convert paths into absolute paths
image_dir_list = [os.path.join(src_dir, image_dir) for image_dir in image_dir_list]

df = get_image_paths(image_dir_list)

# save df to csv
df.to_csv("/Users/derrickvanfrausum/BeCode_AI/git-repos/Remove_Image_Background/core/assets/dataset/DUTS/image_paths.csv", index=False)

(10553, 2)
(10553, 2)
(10553, 2)
(10553, 3)
(10553, 4)


In [21]:
# # load csv into dataframe
# df = pd.read_csv("/Users/derrickvanfrausum/BeCode_AI/git-repos/Remove_Image_Background/core/assets/dataset/DUTS/image_paths.csv")

# hyperparameters
batch_size = 16
lr = 0.001      # learn rate
epochs = 40     # total epochs

# split dataframe into batches
batches = np.array_split(df, batch_size)

# define device
device = "cuda" if torch.cuda.is_available() else "cpu"

In [46]:
import re
from typing import Optional

In [47]:
def get_latest_model_chekpoint_path(saved_models_dir) -> Optional[str]:
    """
    Function to get the path of the latest saved 
    model checkpoint
    
    returns: None or path in string
    """
    model_checkpoint_path_list = glob.glob(saved_models_dir + os.sep + 'model_checkpoint_epoch' + '*')
    if len(model_checkpoint_path_list) > 0:
        model_idx_list = []
        for filepath in model_checkpoint_path_list:
            model_idx = int(re.findall(r'\d+', filepath)[0])
            model_idx_list.append(model_idx)

        # convert list to array & get latest idx
        latest_idx = np.argmax(np.array(model_idx_list))
        latest_checkpoint = model_checkpoint_path_list[latest_idx]
        print("latest model checkpoint: ", latest_checkpoint)
        return latest_checkpoint

In [75]:
# initialize model
modnet = torch.nn.DataParallel(MODNet()).to(device)

# model index: updated if training starts from saved models
model_idx = -1

# load latest saved trained model if available
saved_models_dir = "/Users/derrickvanfrausum/BeCode_AI/git-repos/Remove_Image_Background/core/assets/saved_models"
latest_model_chekpoint_path = get_latest_model_chekpoint_path(saved_models_dir)

if latest_model_chekpoint_path:
  # load model's state_dict
  state_dict = torch.load(
    latest_model_chekpoint_path,
    map_location=torch.device(device) # map to device
      )
  # print(state_dict.keys())

  # load state_dict into the network (works only if model architecture is the same as checkpoint architecture)
  modnet.load_state_dict(state_dict)

  # get path of loss metrics checkpoint
  model_idx = re.findall(r'\d+', path)[0]
  latest_loss_chekpoint_path = os.path.join(
    saved_models_dir,
    f"loss_checkpoint_epoch_{model_idx}.pth")
  print(latest_loss_chekpoint_path)

  # load loss metrics checkpoints
  loss_dict = torch.load(
    latest_loss_chekpoint_path,
    map_location=torch.device(device) # map to device
      )
  # print(loss_dict.keys())

else:
  # create empty dict
  loss_dict = {}
  loss_dict["semantic_loss"] = []
  loss_dict["detail_loss"] = []
  loss_dict["matte_loss"] = []

# initialize optimizer
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)

# define hyper-parameters
ref_size = 512

# training loop starting from latest saved epoch
for epoch in range(int(model_idx) + 1, epochs + 1):
  print("\n\n*****************************************")
  print("epoch:",epoch)
  print("*****************************************")

  for batch_df in batches:
    print(type(batch_df))
    for idx, (image_path, trimap_path, gt_matte_path) in enumerate(dataloader(batch_df)):
      print(idx, image_path, trimap_path, gt_matte_path)
      image = process_image(image_path)
      trimap = process_image(trimap_path)
      gt_matte = process_image(gt_matte_path, True)

      semantic_loss, detail_loss, matte_loss = \
      supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)

    lr_scheduler.step()

  # save model's state_dict
  torch.save(
    modnet.state_dict(),
    os.path.join(
      saved_models_dir,
      f"model_checkpoint_epoch_{epoch}.pth"
      )
      )

  # updates loss dict
  # --------------------------------------------------------------------
  # WE NEED TO SAVE THE LOSSES AFTER EVERY EPOCH TO PLOT THEM IN A GRAPH
  # --------------------------------------------------------------------
  loss_dict["semantic_loss"].append(semantic_loss)
  loss_dict["detail_loss"].append(detail_loss)
  loss_dict["matte_loss"].append(matte_loss)

  # --------------------------------------------------------------------
  # Saving the loss values for tensorboard
  # --------------------------------------------------------------------
  writer.add_scalar("semantic_loss/train", semantic_loss, epoch)
  writer.add_scalar("detail_loss/train", detail_loss, epoch)
  writer.add_scalar("matte_loss/train", matte_loss, epoch)

  # save loss metrics
  torch.save(
    loss_dict,
    os.path.join(
      saved_models_dir,
      f"loss_checkpoint_epoch_{epoch}.pth"
      )
      )

  print("\n\n---------------------------------------")
  print("saved losses of epoch",epoch)
  print("semantic_loss:",semantic_loss, "detail_loss:",detail_loss,"matte_loss:",matte_loss)
  print("---------------------------------------")
# --------------------------------------------------------------------
# conclude tensorboard
# --------------------------------------------------------------------
writer.flush()

#writer.close() "if you don't need it anymore" - quote of Derrick


latest model checkpoint:  /Users/derrickvanfrausum/BeCode_AI/git-repos/Remove_Image_Background/core/assets/saved_models/model_checkpoint_epoch_1.pth
/Users/derrickvanfrausum/BeCode_AI/git-repos/Remove_Image_Background/core/assets/saved_models/loss_checkpoint_epoch_1.pth


TypeError: can only concatenate str (not "int") to str

## Save & load model
Parameters for PyTorch networks stored in the model's `state_dict`

[Source](https://colab.research.google.com/github/agungsantoso/deep-learning-v2-pytorch/blob/master/intro-to-pytorch/Part%206%20-%20Saving%20and%20Loading%20Models.ipynb#scrollTo=lBqSgQCNpCX4)

In [None]:
print("The trained model: \n\n", modnet, '\n')
print("The state_dict keys: \n\n", modnet.state_dict().keys())

NameError: ignored

In [None]:
from google.colab import files

# # create dictionary with all information necessary to rebuild the model.
# checkpoint = {'input_size': 784,
#               'output_size': 10,
#               'hidden_layers': [each.out_features for each in model.hidden_layers],
#               'state_dict': modnet.state_dict()}

# # save model's architecture and state_dict
# torch.save(checkpoint, 'model_checkpoint.pth')

# save model's state_dict
torch.save(modnet.state_dict(), 'model_checkpoint.pth')

# download checkpoint file
files.download('model_checkpoint.pth')

In [None]:
# load model's state_dict
state_dict = torch.load('model_checkpoint.pth')
print(state_dict.keys())

# load state_dict into the network (works only if model architecture is the same as checkpoint architecture)
modnet.load_state_dict(state_dict)

In [None]:
# def load_checkpoint(filepath):
#   """
#   Function to load saved model and rebuild it.
#   """
#     checkpoint = torch.load(filepath)
#     model = modnet.Network(checkpoint['input_size'],
#                              checkpoint['output_size'],
#                              checkpoint['hidden_layers'])
#     model.load_state_dict(checkpoint['state_dict'])
    
#     return model
model = load_checkpoint('checkpoint.pth')
print(model)

# Predict alpha matte with trained model

In [None]:
image_path = "/content/U-2-Net/test_data/test_images/0002-01.jpg"
trimap_path = "/content/drive/MyDrive/trimaps/0002-01.png"
im = process_image(image_path)
trimap = process_image(trimap_path)

_, _, matte = modnet(im, True)

# resize and save matte
matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
matte = matte[0][0].data.cpu().numpy()
matte_name = 'test_b_v2.png'

Image.fromarray(((matte * 255).astype('uint8')), mode='L').save(os.path.join("/content", matte_name))

# Performance metrics

## Functions

In [None]:
# compute the MSE error given a prediction, a ground truth and a trimap.
# pred: the predicted alpha matte
# target: the ground truth alpha matte
# trimap: the given trimap
#
def compute_mse(pred, alpha, trimap):
    num_pixels = float((trimap == 127).sum())
    return ((pred - alpha) ** 2).sum() / num_pixels


# compute the SAD error given a prediction and a ground truth.
#
def compute_sad(pred, alpha):
    diff = np.abs(pred - alpha)
    return np.sum(diff) / 1000

In [None]:
# import cv2

gt_matte = cv2.imread("/content/drive/MyDrive/Faktion/exploration/dataset/ground_truths/ILSVRC2012_test_00000018.png")
trimap = cv2.imread("/content/drive/MyDrive/Faktion/exploration/dataset/trimaps/ILSVRC2012_test_00000018.png")
pred = cv2.imread("/content/test.png")

print(compute_mse(pred, gt_matte, trimap))
print(compute_sad(pred, gt_matte))