# Salient Object Detection (SOD) - how to remove the background from whale images

Salient Object Detection (SOD) aims to segment the most visually attractive objects in an image. It is widely used in many fields, such as visual tracking and image segmentation. Recently, with the development of deep convolutional neural networks (CNNs), and especially the rise of Fully Convolutional Networks (FCN) in image segmentation, salient object detection has improved significantly. (Source: https://arxiv.org/pdf/2005.09007.pdf)

During my research I found the SOD survey [RGB-D Salient Object Detection: A Survey](https://github.com/taozh2017/RGBD-SODsurvey). It reviews RGB-D SOD models together with benchmark datasets and provides a comprehensive evaluation of these models. The authors also collect related review papers for SOD, as well as light field SOD models.

<div align="center"><img src="https://github.com/taozh2017/RGBD-SODsurvey/raw/master/figures/Fig0.jpg" width=640/></div>
In this experiment I use U2-Net (U Square Net), which is described in <a href="https://arxiv.org/pdf/2005.09007.pdf">U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection</a>.
<p>&nbsp;</p>
U2-Net Architecture:
<div align="center"><img src="https://github.com/xuebinqin/U-2-Net/raw/master/figures/U2NETPR.png" width=480/></div>

I use it as part of a solution with LoFTR feature matching, which you can see in my notebook [Whales feature matching LoFTR - Kornia](https://www.kaggle.com/remekkinas/whales-feature-matching-loftr-kornia).
    
<div align="center"><img src="https://i.ibb.co/yQzhzb8/Lo-FTR-BR1.jpg"/></div>
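You can also use the saliency masks after feature matching. One option is to discard matches whose keypoints land on the background in either image. This is my assumption and not necessarily what the linked LoFTR notebook does.

The sketch below makes two further assumptions:
- Matched keypoints arrive as two `(N, 2)` arrays of `(x, y)` pixel coordinates.
- The masks are `(H, W)` arrays at the same resolution as the images the keypoints were computed on.

`filter_matches_by_mask` is a hypothetical helper name.

```python
import numpy as np


def filter_matches_by_mask(kpts0, kpts1, mask0, mask1):
    """Keep only matches whose keypoints land on the salient object in both images.

    kpts0, kpts1: (N, 2) arrays of (x, y) pixel coordinates, one row per match
                  (assumed layout; adapt it to your matcher's output).
    mask0, mask1: (H, W) arrays, non-zero where the object is.
    Returns the filtered keypoints and the boolean keep-vector.
    """
    def on_object(kpts, mask):
        h, w = mask.shape
        # Round to the nearest pixel and clamp so border points stay in range
        x = np.clip(np.rint(kpts[:, 0]).astype(int), 0, w - 1)
        y = np.clip(np.rint(kpts[:, 1]).astype(int), 0, h - 1)
        return mask[y, x] != 0

    keep = on_object(kpts0, mask0) & on_object(kpts1, mask1)
    return kpts0[keep], kpts1[keep], keep
```

A match survives only if both of its endpoints are on the object. This removes matches on water or sky, which can otherwise dominate the match count.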

<div class="alert alert-warning">Note: my goal was to implement and share a tool for experimentation; I did not build a full dataset with it. That part of the journey is yours. Enjoy experimenting and progressing!</div>

## 1. SETUP NOTEBOOK (MODULES)
Let's clone the U-2-Net repository.

In [None]:
!git clone https://github.com/shreyas-bk/U-2-Net
    
import sys
sys.path.append('./U-2-Net')

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import U2NET 
from model import U2NETP 

from IPython.display import display
from PIL import Image as Img

from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

## 2. TAKE CANDIDATES FOR BACKGROUND REMOVAL

For this demo I use cropped images from the dataset provided by @phalanx: [cropped&resized(512x512) dataset using detic](https://www.kaggle.com/c/happy-whale-and-dolphin/discussion/305503). Thank you for contributing to this competition.

In [None]:
train_df = pd.read_csv("../input/happy-whale-and-dolphin/train.csv")

input_path = "../input/whale2-cropped-dataset/cropped_train_images/cropped_train_images"

Let's take one of the top 10 individuals. You can experiment with other individuals too. Prediction quality depends on the photo, but I will work on improving it (I will probably train this model on custom data).

In [None]:
train_df.individual_id.value_counts().head(10)

For these experiments we use only one individual (ID `281504409737`), just to check the solution's performance and the quality of the mask.

In [None]:
img_to_draw = [ input_path + '/' + file for file in train_df.query("individual_id == '281504409737'").sample(25).image]

In [None]:
fig, axes = plt.subplots(5, 5, figsize=(20,20))

for idx, img in enumerate(img_to_draw):
    i = idx % 5 
    j = idx // 5
    image = Img.open(img)
    iar_shp = np.array(image).shape
    axes[i, j].axis('off')
    axes[i, j].imshow(image)
    
plt.subplots_adjust(wspace=0.05, hspace=0.05)
plt.show()

## 3. SOD - U2-Net PREDICTION
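It helps to know what the network actually receives as input. As far as I can tell from the repository's `data_loader.py`, `RescaleT(320)` first resizes the image to 320x320. `ToTensorLab(flag=0)` then does three things:
1. It divides the image by its maximum value.
2. It normalises each channel with the ImageNet mean and standard deviation.
3. It transposes the result to channel-first layout.

The NumPy sketch below approximates only the normalisation step, so you can inspect the values. Treat it as an approximation and check it against the repository source. `normalize_like_tensorlab` is a hypothetical name.

```python
import numpy as np

# ImageNet channel statistics (assumed to match ToTensorLab(flag=0) in the repo)
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def normalize_like_tensorlab(image):
    """Approximate ToTensorLab(flag=0): max-scale, per-channel normalise, HWC -> CHW.

    image: (H, W, 3) RGB array, already resized (e.g. to 320x320).
    Returns a float64 array of shape (3, H, W).
    Note: an all-black image (max == 0) would divide by zero.
    """
    image = image.astype(np.float64)
    image = image / np.max(image)      # scale to [0, 1] by the image maximum
    image = (image - MEAN) / STD       # broadcast over the last (channel) axis
    return image.transpose(2, 0, 1)    # channel-first layout expected by PyTorch
```

Scaling by the per-image maximum rather than by 255 means a dark photo gets stretched before normalisation. This is worth keeping in mind when comparing predictions across lighting conditions.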

In [None]:
THRESHOLD = 0.3

In [None]:
def normPRED(d):
    """Min-max normalise a prediction tensor to the [0, 1] range."""
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d - mi) / (ma - mi)
    return dn


def pred_unet(model, imgs):
    """Run U2-Net on the image paths in `imgs` and return, for the last image:
    the image with a white background, the original image,
    the image with the salient region tinted red, and the binary mask."""

    test_salobj_dataset = SalObjDataset(img_name_list=imgs, lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320), ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset, batch_size=1, shuffle=False, num_workers=1)

    for i_test, data_test in enumerate(test_salobj_dataloader):

        inputs_test = data_test['image'].type(torch.FloatTensor)
        if torch.cuda.is_available():
            inputs_test = inputs_test.cuda()

        # Inference only: no gradients needed (saves memory)
        with torch.no_grad():
            d1, d2, d3, d4, d5, d6, d7 = model(inputs_test)  # use the model passed in, not a global

        # d5 is one of the side outputs; the reference U2-Net test script uses the fused output d1
        predict = d5[:, 0, :, :]
        predict = normPRED(predict)

        del d1, d2, d3, d4, d5, d6, d7

        predict = predict.squeeze()
        predict_np = predict.cpu().numpy()

        # Binary mask - using the threshold you can soften/sharpen mask boundaries
        predict_np[predict_np > THRESHOLD] = 1
        predict_np[predict_np <= THRESHOLD] = 0
        mask = Img.fromarray(predict_np * 255).convert('RGB')
        image = Img.open(imgs[i_test]).convert('RGB')  # current image; RGB so composite() modes match
        imask = mask.resize((image.width, image.height), resample=Img.BILINEAR)
        back = Img.new("RGB", (image.width, image.height), (255, 255, 255))
        mask = imask.convert('L')
        im_out = Img.composite(image, back, mask)

        # Salient mask: add up to +50 on the RED channel inside the mask.
        # Work in int16 so uint8 values do not wrap around before clipping.
        salient_mask = np.array(image).astype(np.int16)
        mask_layer = np.array(imask).astype(np.int16)
        salient_mask[:, :, 0] += mask_layer[:, :, 0] * 50 // 255
        salient_mask = np.clip(salient_mask, 0, 255).astype(np.uint8)

    return np.array(im_out), np.array(image), salient_mask, np.array(mask)
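The post-processing inside `pred_unet` does not depend on PyTorch or PIL. It consists of four steps:
- min-max normalisation
- thresholding
- compositing onto a white background
- a red tint over the object

The NumPy sketch below mirrors these steps on a toy array, which makes it easy to try different `THRESHOLD` values. It is a simplified stand-in. The notebook thresholds the 320x320 prediction and then resizes the mask bilinearly to the image size. The sketch skips that step and assumes the saliency map already has the image's resolution. `postprocess` and `red_offset` are hypothetical names.

```python
import numpy as np


def postprocess(saliency, image, threshold=0.3, red_offset=50):
    """Turn a raw saliency map into a binary mask, a white-background cut-out
    and a red-tinted overlay.

    saliency: (H, W) float array (raw network output, any range)
    image:    (H, W, 3) uint8 RGB image of the same size
    """
    # Min-max normalisation to [0, 1] (same idea as normPRED)
    s = saliency.astype(np.float64)
    s = (s - s.min()) / (s.max() - s.min())

    # Binary mask: True where the normalised saliency exceeds the threshold
    mask = s > threshold

    # Keep object pixels, paint everything else white
    cutout = np.where(mask[..., None], image, 255).astype(np.uint8)

    # Add an offset to the red channel inside the mask; int16 avoids uint8 wrap-around
    overlay = image.astype(np.int16)
    overlay[..., 0] += (red_offset * mask).astype(np.int16)
    overlay = np.clip(overlay, 0, 255).astype(np.uint8)

    return mask, cutout, overlay
```

Lowering `threshold` grows the mask and keeps more of the uncertain border around the fluke or fin. Raising it gives a tighter mask but can cut into the object.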

In [None]:
UNET2_SMALL = False

In [None]:
%%capture

if UNET2_SMALL:
    model_dir = "./U-2-Net/u2netp.pth"  # Faster ... a lot (!) but less accurate
    net = U2NETP(3,1) 
else:
    model_dir = "../input/u-square-net-model/u2net.pth"
    net = U2NET(3,1) 


if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_dir))
    net.cuda()
else:        
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))

net.eval()

## 4. U2-Net RESULT VISUALIZATION

In [None]:
# 25 rows; each row shows original | salient overlay | background removed (about 3:1 aspect)
fig, axes = plt.subplots(25, 1, figsize=(24, 200))

for idx, img in enumerate(img_to_draw):
    image, im_oryg, sal_map, mask = pred_unet(net, [img])
    result = np.concatenate((im_oryg, sal_map, image), axis=1)
    result_img = Img.fromarray(result)
    axes[idx].axis('off')
    axes[idx].imshow(result_img)

plt.subplots_adjust(wspace=0.05, hspace=0)
plt.show()

<div class="alert alert-success" role="alert">
    I would really appreciate it if you support my work. <b>Votes are more than welcome. They motivate me a lot</b> to share the parts of my solution and tools you can use in this competition.
    
My other work in this competition:
    <ul>
        <li> <a href="https://www.kaggle.com/remekkinas/whales-feature-matching-loftr-kornia">Whales feature matching LoFTR - Kornia</a></li>
    </ul>
    
</div>

## 5. IDEAS FOR IMPROVEMENT

### A. REMOVE WHITE AREA OUTSIDE SALIENT MAP - PROTOTYPE
Let's find the bounding box of the salient map and remove the white area outside it. This is just a prototype; the code still needs some refactoring.

In [None]:
fig, axes = plt.subplots(4, 3, figsize=(30, 40))

for idx, img_path in enumerate(img_to_draw[:4]):

    image, im_oryg, sal_map, mask = pred_unet(net, [img_path])

    # Rows/columns containing at least one foreground (non-zero) mask pixel.
    # (argmax-based detection misses objects touching the image border.)
    rows = np.any(mask != 0, axis=1)
    cols = np.any(mask != 0, axis=0)
    if not rows.any():
        continue  # empty mask: nothing to crop

    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]

    img = cv2.rectangle(image.copy(), (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 0, 0), 2)

    # +1 because slicing excludes the end index; cv2.resize expects (width, height)
    crop_img = image[ymin:ymax + 1, xmin:xmax + 1]
    crop_img = cv2.resize(crop_img, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_AREA)

    axes[idx, 0].imshow(Img.fromarray(sal_map))
    axes[idx, 1].imshow(Img.fromarray(img))
    axes[idx, 2].imshow(Img.fromarray(crop_img))

plt.subplots_adjust(wspace=0.1, hspace=0)
plt.show()
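As a first step towards the refactoring mentioned above, the bounding-box logic could move into small, testable helpers. This is a sketch. `mask_bbox`, `crop_to_mask` and the `pad` parameter are my own hypothetical names and are not part of the notebook or the U-2-Net repository.

```python
import numpy as np


def mask_bbox(mask, pad=0):
    """Return (xmin, ymin, xmax, ymax) of the non-zero region of a 2-D mask,
    inclusive on both ends, padded by `pad` pixels and clamped to the image.
    Returns None for an empty mask."""
    rows = np.any(mask != 0, axis=1)
    cols = np.any(mask != 0, axis=0)
    if not rows.any():
        return None
    ymin, ymax = np.flatnonzero(rows)[[0, -1]]
    xmin, xmax = np.flatnonzero(cols)[[0, -1]]
    h, w = mask.shape
    return (max(int(xmin) - pad, 0), max(int(ymin) - pad, 0),
            min(int(xmax) + pad, w - 1), min(int(ymax) + pad, h - 1))


def crop_to_mask(image, mask, pad=0):
    """Crop `image` to the mask's bounding box (the whole image if the mask is empty)."""
    box = mask_bbox(mask, pad)
    if box is None:
        return image
    xmin, ymin, xmax, ymax = box
    return image[ymin:ymax + 1, xmin:xmax + 1]
```

In the prototype loop, you could then call `crop_to_mask(image, mask, pad=10)` before resizing. A small padding keeps a margin around fin edges that the thresholded mask may clip.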

This notebook still needs some improvements:
- code refactoring: I decided to share the idea as fast as possible (I did a lot of research to find the best SOD solution)
- training the model on custom data to improve the salient object prediction
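If you train on custom data, note that U2-Net uses deep supervision. As I understand the repository's `u2net_train.py`, it computes a binary cross-entropy loss for the fused output and for each of the six side outputs, then sums the seven terms with equal weight.

The NumPy sketch below shows only that loss arithmetic. `bce` and `multi_bce_loss` are hypothetical names; in practice you would apply PyTorch's `nn.BCELoss` to the model outputs.

```python
import numpy as np


def bce(pred, target, eps=1e-7):
    """Mean binary cross-entropy between probabilities `pred` and a 0/1 `target`."""
    p = np.clip(pred, eps, 1 - eps)  # avoid log(0)
    return float(np.mean(-(target * np.log(p) + (1 - target) * np.log(1 - p))))


def multi_bce_loss(outputs, target):
    """Deep-supervision loss: sum of BCE over all outputs (fused + side outputs),
    each term weighted equally."""
    return sum(bce(o, target) for o in outputs)
```

With seven outputs that all predict 0.5 everywhere, the total loss is 7 * ln 2. Each side output contributes its own gradient signal, which is the point of deep supervision.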