Thanks to the great work done by **REMEK KINAS**!

https://www.kaggle.com/code/remekkinas/remove-background-salient-object-detection/notebook

This notebook is ***based on*** the notebook above. I changed a few parameters and **refactored** the code to get a more **efficient** detection result.

**Changes**

+ **CLAHE** : use OpenCV's CLAHE to make the images brighter and increase their contrast

+ **THRESHOLD** : the mask binarisation threshold is changed from 0.3 to 0.9

+ **boundingRect** : use `cv2.boundingRect` to find the bounding box of the whale (after background removal), then crop to it and resize every image to the same size.

+ **Other changes** : fix latent exceptions

+ **Something New** : about 5% of the images get a bad result from the U-2-Net background removal. I use an area threshold to sift out the valid images; the invalid ones are kept as the original images. 


## 1. SETUP NOTEBOOK (MODULES)
Let's clone the U-2-Net repository.

In [None]:
# Clone the U-2-Net fork; it provides the `data_loader` and `model` modules imported in the next cell.
!git clone https://github.com/shreyas-bk/U-2-Net
    
import sys
# Put the cloned repository on the import path so `data_loader` / `model` can be imported.
sys.path.append('./U-2-Net')

In [None]:
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import U2NET 
from model import U2NETP 

from IPython.display import display
from PIL import Image as Img

from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

## 2. TAKE CANDIDATES FOR BACKGROUND REMOVAL

For this demo I use cropped images from the dataset provided by @phalanx [cropped&resized(512x512) dataset using detic](https://www.kaggle.com/c/happy-whale-and-dolphin/discussion/305503). Thank you for contributing to this competition.

In [None]:
# Dataset locations.
#   root_in  : folder of the cropped whale dataset (CSV files + image folders)
#   root_out : folder that receives the processed images and the CSV logs
root_in = '../input/whale2-cropped-dataset'
root_out = './'

train_df = pd.read_csv(f"{root_in}/train2.csv")
input_path = f"{root_in}/cropped_train_images/cropped_train_images"

In [None]:
import os
# Create the output folders that pred_unet_CLAHE writes into
# (os.path.join(root_out, 'subtraction_train' / 'subtraction_test')).
# exist_ok=True lets the notebook be re-run without a FileExistsError.
os.makedirs(os.path.join(root_out, 'subtraction_train'), exist_ok=True)
os.makedirs(os.path.join(root_out, 'subtraction_test'), exist_ok=True)

Let's look at the TOP10 individuals. You can experiment with other individuals. The quality of the prediction depends on the photo, but I will work on improving it (I will probably train this model on custom data).

In [None]:
train_df['individual_id'].value_counts().head(10)  # ten most frequent individual ids with their image counts

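For the experiments we only look at a handful of images (one individual such as ID=281504409737... works too), just to check the solution's performance and the quality of the mask. Note that the cells below simply take the first training images from `train2.csv`; they do not filter by individual.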

In [None]:
# File names of every training image, and the same names joined onto input_path.
paths = train_df.image.tolist()
img_to_draw = [f"{input_path}/{name}" for name in paths]

In [None]:
# Preview the first 25 images in a 5x5 grid, filled column by column.
fig, axes = plt.subplots(5, 5, figsize=(20,20))

# Hide every axis up front so unused cells stay blank when fewer than 25 images exist.
for ax in axes.flat:
    ax.axis('off')

# axes.T.flat walks the grid column-major: image idx lands in row idx % 5, column idx // 5.
for ax, img in zip(axes.T.flat, img_to_draw[0:25]):
    # Materialise the pixels while the file is open so the handle is closed immediately.
    with Img.open(img) as image:
        ax.imshow(np.asarray(image))

plt.subplots_adjust(wspace=0.05, hspace=0.05)
plt.show()

## 3. Use CLAHE

**Here we use the GPU to process the images (falling back to the CPU when CUDA is not available)**

In [None]:
# Cut-off applied to the min-max normalised saliency map: values above it become mask pixels.
THRESHOLD = 0.9
# Number of images per forward pass in the batched pipeline.
BATCH_SIZE = 64
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
def CLAHE_Convert(origin_input):
    '''Apply CLAHE (clip limit 3, 24x32 tiles) to the brightness of a dataset sample.

    *origin_input* is a dict with 'imidx', 'image' and 'label'. A new dict with
    the same keys is returned; only 'image' is replaced.

    Colour images are converted to HSV, their V channel is equalised, and the
    result is converted back. The forward and inverse conversions use the same
    channel-order convention, so the input's channel order is preserved.
    Single-channel images ((H, W) or (H, W, 1)) are equalised directly and keep
    their shape, because cvtColor's BGR2HSV conversion rejects them.

    NOTE(review): CLAHE requires 8-bit or 16-bit input; presumably the dataset
    yields uint8 images — confirm.
    '''
    imidx = origin_input['imidx']
    label = origin_input['label']
    clahe = cv.createCLAHE(clipLimit=3, tileGridSize=(24,32))
    img = np.asarray(origin_input['image'])
    if img.ndim == 2 or img.shape[2] == 1:
        # Grey-level image: equalise the single plane, then restore the original shape.
        plane = np.ascontiguousarray(img.reshape(img.shape[:2]))
        img = clahe.apply(plane).reshape(img.shape)
    else:
        img = cv.cvtColor(img, cv.COLOR_BGR2HSV)
        img[:,:,-1] = clahe.apply(img[:,:,-1])  # V (brightness) channel only
        img = cv.cvtColor(img, cv.COLOR_HSV2BGR)

    return {'imidx':imidx, 'image':img,'label':label}

def BilateralFilter_Convert(origin_input):
    '''Smooth a dataset sample with a bilateral filter (d=5, sigmaColor=75, sigmaSpace=75).

    *origin_input* is a dict with 'imidx', 'image' and 'label'. A new dict with
    the same keys is returned; only 'image' is replaced, and it keeps the input's
    array shape.
    '''
    imidx = origin_input['imidx']
    label = origin_input['label']
    img = np.asarray(origin_input['image'])
    filtered = cv.bilateralFilter(img,5,75,75)
    # OpenCV returns single-channel results as 2-D arrays; restore e.g. (H, W, 1)
    # so later transforms that index the channel axis keep working.
    img = filtered.reshape(img.shape)

    return {'imidx':imidx, 'image':img,'label':label}

def EqualizeHist_Convert(origin_input):
    '''Apply global histogram equalisation to the brightness of a dataset sample.

    *origin_input* is a dict with 'imidx', 'image' and 'label'. A new dict with
    the same keys is returned; only 'image' is replaced.

    Colour images have the V channel of HSV equalised and are converted back.
    Single-channel images ((H, W) or (H, W, 1)) are equalised directly and keep
    their shape, because cvtColor's BGR2HSV conversion rejects them.

    NOTE(review): cv.equalizeHist needs 8-bit input; presumably images are uint8 — confirm.
    '''
    imidx = origin_input['imidx']
    label = origin_input['label']
    img = np.asarray(origin_input['image'])
    if img.ndim == 2 or img.shape[2] == 1:
        plane = np.ascontiguousarray(img.reshape(img.shape[:2]))
        img = cv.equalizeHist(plane).reshape(img.shape)
    else:
        img = cv.cvtColor(img, cv.COLOR_BGR2HSV)
        img[:,:,-1] = cv.equalizeHist(img[:,:,-1])  # V (brightness) channel only
        img = cv.cvtColor(img, cv.COLOR_HSV2BGR)

    return {'imidx':imidx, 'image':img,'label':label}

In [None]:
def normPRED(d):
    '''Min-max normalise tensor *d* to [0, 1] using its global min and max.

    A constant tensor (max == min) is mapped to all zeros instead of producing
    NaN from a division by zero.
    '''
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi
    if span == 0:
        return torch.zeros_like(d)
    return (d - mi) / span

In [None]:
def pred_unet(model, imgs):
    '''The old version of pred_unet (no CLAHE pre-processing)!
        Used just to show the difference.

    Runs *model* on each path in *imgs*, one image at a time, and returns the
    arrays for the LAST image:
    (white-background cut-out, original image as RGB, original with a red tint
    over the mask, binarised mask as an 'L' image array).

    Raises:
        ValueError: if *imgs* is empty.
    '''
    if len(imgs) == 0:
        raise ValueError("pred_unet needs at least one image path")
    test_salobj_dataset = SalObjDataset(img_name_list = imgs, lbl_name_list = [], transform = transforms.Compose([RescaleT(320),ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset, batch_size=1, shuffle=False, num_workers = 0)

    # Use the model that was passed in (not the global `net`) on whatever device it lives on.
    device = next(model.parameters()).device
    model.eval()

    for i_test, data_test in enumerate(test_salobj_dataloader):
        inputs_test = data_test['image'].type(torch.FloatTensor).to(device)
        with torch.no_grad():  # inference only: do not build an autograd graph
            d1, d2, d3, d4, d5, d6, d7 = model(inputs_test)

        predict = normPRED(d5[:,0,:,:])
        del d1, d2, d3, d4, d5, d6, d7

        predict_np = predict.squeeze().cpu().numpy()

        # Masked image - using threshold you can soften/sharpen mask boundaries
        predict_np[predict_np > THRESHOLD] = 1
        predict_np[predict_np <= THRESHOLD] = 0
        mask = Img.fromarray(predict_np*255).convert('RGB')
        # Open the image belonging to this iteration, and force RGB so that
        # composite() and the red-channel overlay work for grey/RGBA files too.
        with Img.open(imgs[i_test]) as src:
            image = src.convert('RGB')
        imask = mask.resize((image.width, image.height), resample=Img.BILINEAR)
        back = Img.new("RGB", (image.width, image.height), (255, 255, 255))
        mask = imask.convert('L')
        im_out = Img.composite(image, back, mask)

        # Salient mask: add up to +50 on the RED channel, proportional to mask
        # strength (full mask -> +50, soft resized edges -> less). int16 keeps
        # the sum from wrapping around uint8 before clipping.
        salient_mask = np.array(image, dtype=np.int16)
        red_offset = np.array(imask)[:, :, 0].astype(np.int16) * 50 // 255
        salient_mask[:, :, 0] += red_offset
        salient_mask = np.clip(salient_mask, 0, 255).astype(np.uint8)

    return np.array(im_out), np.array(image), salient_mask, np.array(mask)

In [None]:
def pred_unet_CLAHE_show(model, imgs):
    '''Used just to show the difference,
        Do not use it to write image

    Like pred_unet, but the network input is CLAHE-enhanced first
    (CLAHE_Convert). The overlays are still drawn on the image as read from
    disk. Returns the arrays for the LAST image in *imgs*:
    (white-background cut-out, original image as RGB, original with a red tint
    over the mask, binarised mask as an 'L' image array).

    Raises:
        ValueError: if *imgs* is empty.
    '''
    if len(imgs) == 0:
        raise ValueError("pred_unet_CLAHE_show needs at least one image path")
    test_salobj_dataset = SalObjDataset(img_name_list = imgs, lbl_name_list = [], transform = transforms.Compose([CLAHE_Convert, RescaleT(320),ToTensorLab(flag=0)]))
    test_salobj_dataloader = DataLoader(test_salobj_dataset, batch_size=1, shuffle=False, num_workers = 0)

    # Use the model that was passed in (not the global `net`) on whatever device it lives on.
    device = next(model.parameters()).device
    model.eval()

    for i_test, data_test in enumerate(test_salobj_dataloader):
        inputs_test = data_test['image'].type(torch.FloatTensor).to(device)
        with torch.no_grad():  # inference only: do not build an autograd graph
            d1, d2, d3, d4, d5, d6, d7 = model(inputs_test)

        predict = normPRED(d5[:,0,:,:])
        del d1, d2, d3, d4, d5, d6, d7

        predict_np = predict.squeeze().cpu().numpy()

        # Masked image - using threshold you can soften/sharpen mask boundaries
        predict_np[predict_np > THRESHOLD] = 1
        predict_np[predict_np <= THRESHOLD] = 0
        mask = Img.fromarray(predict_np*255).convert('RGB')
        # Open the image belonging to this iteration, and force RGB so that
        # composite() and the red-channel overlay work for grey/RGBA files too.
        with Img.open(imgs[i_test]) as src:
            image = src.convert('RGB')
        imask = mask.resize((image.width, image.height), resample=Img.BILINEAR)
        back = Img.new("RGB", (image.width, image.height), (255, 255, 255))
        mask = imask.convert('L')
        im_out = Img.composite(image, back, mask)

        # Salient mask: add up to +50 on the RED channel, proportional to mask
        # strength. int16 keeps the sum from wrapping around uint8 before clipping.
        salient_mask = np.array(image, dtype=np.int16)
        red_offset = np.array(imask)[:, :, 0].astype(np.int16) * 50 // 255
        salient_mask[:, :, 0] += red_offset
        salient_mask = np.clip(salient_mask, 0, 255).astype(np.uint8)

    return np.array(im_out), np.array(image), salient_mask, np.array(mask)

In [None]:
# Area threshold used to sift valid segmentations: the foreground pixels must
# cover at least this fraction of the mask's bounding rectangle, otherwise the
# original image is kept.
VALID_THERESHOLD: float = 0.2

In [None]:
def _normalize_per_image(predict):
    '''Min-max normalise every (H, W) map of a (B, H, W) tensor independently.

    Maps whose max equals their min become all zeros instead of NaN.
    '''
    flat = predict.flatten(1)
    mi = flat.min(dim=1).values.view(-1, 1, 1)
    ma = flat.max(dim=1).values.view(-1, 1, 1)
    span = ma - mi
    span = torch.where(span > 0, span, torch.ones_like(span))
    return (predict - mi) / span


def _crop_to_mask(im_out_np, mask_np, valid_threshold):
    '''Crop *im_out_np* to the bounding box of *mask_np*, then resize it back to full size.

    Returns None when the mask is empty or its foreground covers less than
    *valid_threshold* of the bounding rectangle (an unreliable segmentation).
    '''
    x, y, w, h = cv.boundingRect(mask_np)
    if w == 0 or h == 0:
        return None
    if np.count_nonzero(mask_np) / (w * h) < valid_threshold:
        return None
    height, width = im_out_np.shape[:2]
    crop = im_out_np[y:y + h, x:x + w]
    # cv.resize takes dsize as (width, height)
    return cv.resize(crop, (width, height), interpolation = cv.INTER_AREA)


def pred_unet_CLAHE(model, imgs, paths, train=False, valid_threshold=None):
    '''Remove the background of every image in *imgs* and write the results.

    Each image is CLAHE-enhanced, segmented by *model*, composited onto a white
    background, cropped to the bounding box of the predicted mask, and resized
    back to its original size. Images whose mask is empty or too sparse are
    written unchanged and listed in Ignorance_Train_Img.csv / Ignorance_Test_Img.csv.

    Args:
        model: U-2-Net style network returning seven side outputs.
        imgs: full paths of the input images.
        paths: output file names, aligned index-by-index with *imgs*.
        train: write into 'subtraction_train' (True) or 'subtraction_test' (False).
        valid_threshold: minimum foreground / bounding-box area ratio;
            defaults to VALID_THERESHOLD.
    '''
    if valid_threshold is None:
        valid_threshold = VALID_THERESHOLD
    out_dir = os.path.join(root_out, 'subtraction_train' if train else 'subtraction_test')
    log_name = 'Ignorance_Train_Img.csv' if train else 'Ignorance_Test_Img.csv'

    salobj_dataset = SalObjDataset(img_name_list = imgs, lbl_name_list = [], transform = transforms.Compose([CLAHE_Convert, RescaleT(320), ToTensorLab(flag=0)]))
    salobj_dataloader = DataLoader(salobj_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers = 1)

    model.to(DEVICE)
    model.eval()

    error_log = []
    start = 0  # index into imgs/paths of the first sample of the current batch
    for batch in salobj_dataloader:
        images = batch['image'].type(torch.float32).to(DEVICE)  # (batch size, channel, row, columns)
        with torch.no_grad():
            d1, d2, d3, d4, d5, d6, d7 = model(images)

        # d5[:, 0] is already (batch size, row, columns). No squeeze(): it would drop
        # the batch axis when the last batch holds a single image. Normalise per
        # image so one sample's mask does not depend on the rest of the batch.
        predict = _normalize_per_image(d5[:,0,:,:])
        del d1, d2, d3, d4, d5, d6, d7

        # Masked image - using threshold you can soften/sharpen mask boundaries
        masks_np = (predict.cpu().numpy() > THRESHOLD).astype(np.uint8) * 255

        for j, mask_np in enumerate(masks_np):
            file = paths[start + j]
            with Img.open(imgs[start + j]) as src:
                image = src.convert('RGB')  # composite() needs matching modes
            mask = Img.fromarray(mask_np).resize(image.size, resample=Img.BILINEAR)
            back = Img.new("RGB", image.size, (255, 255, 255)) # WHITE Backgroud
            im_out = Img.composite(image, back, mask)

            crop_img = _crop_to_mask(np.array(im_out), np.array(mask), valid_threshold)
            if crop_img is None:
                crop_img = np.array(image)
                error_log.append(file)
                print('Failed:\t', file)
            # PIL arrays are RGB while cv.imwrite expects BGR.
            if not cv.imwrite(os.path.join(out_dir, file), cv.cvtColor(crop_img, cv.COLOR_RGB2BGR)):
                print('Write failed:\t', file)
        start += len(masks_np)

    pd.DataFrame({'image':error_log}).to_csv(os.path.join(root_out, log_name))

In [None]:
# True -> load the small U2NETP weights; False -> load the full U2NET model.
UNET2_SMALL: bool = False

In [None]:
%%capture

if not UNET2_SMALL:
    model_dir = "../input/u-square-net-model/u2net.pth"
    net = U2NET(3,1)
else:
    model_dir = "./U-2-Net/u2netp.pth"  # Faster ... a lot (!) but less accurate
    net = U2NETP(3,1)

# Without CUDA the checkpoint tensors are remapped to the CPU while loading.
use_cuda = torch.cuda.is_available()
load_kwargs = {} if use_cuda else {'map_location': torch.device('cpu')}
net.load_state_dict(torch.load(model_dir, **load_kwargs))
if use_cuda:
    net.cuda()

net.eval()

## 4. U2-Net RESULT VISUALIZATION

In [None]:
# Compare the segmentation without CLAHE (top row) and with CLAHE (bottom row).
# Each panel shows: original | red-tinted mask | white-background cut-out.
i_m_t = img_to_draw[0:1]
n_cols = len(i_m_t)
# squeeze=False keeps a 2-D axes array, so axes[row, col] works for any number of images.
fig, axes = plt.subplots(2, n_cols, figsize=(30 * n_cols, 20), squeeze=False)
variants = ((pred_unet, 'Original'), (pred_unet_CLAHE_show, 'After CLAHE'))
for col, img in enumerate(i_m_t):
    for row, (predict_fn, title) in enumerate(variants):
        image, im_oryg, sal_map, mask = predict_fn(net, [img])
        result = np.concatenate((im_oryg, sal_map, image), axis=1)
        ax = axes[row, col]
        ax.axis('off')
        ax.set_title(title, fontsize=32)
        ax.imshow(Img.fromarray(result))

plt.subplots_adjust(wspace=0.05, hspace=0.1)
plt.show()

## 5. TRANSFORM & WRITE IMAGES

The ***train*** parameter indicates whether we are processing the train images or the test images.

In [None]:
# Transform and write Train images: background-remove every training image
# into the 'subtraction_train' folder.
train_df = pd.read_csv(os.path.join(root_in, 'train2.csv'))
input_path = os.path.join(root_in, "cropped_train_images/cropped_train_images")

paths = train_df.image.tolist()
img_to_draw = [f"{input_path}/{name}" for name in paths]
pred_unet_CLAHE(net, img_to_draw, paths, train=True)

In [None]:
# Transform and write Test images: background-remove every test image
# into the 'subtraction_test' folder.
input_path = os.path.join(root_in, "cropped_test_images/cropped_test_images")
test_df = pd.read_csv(os.path.join(root_in, 'test2.csv'))

paths = test_df['image'].tolist()
img_to_draw = [f"{input_path}/{name}" for name in paths]
pred_unet_CLAHE(net, img_to_draw, paths, train=False)
