# <div style="color:white;display:fill;border-radius:5px;background-color:#75B7BF;letter-spacing:0.1px;overflow:hidden"><p style="padding:20px;color:white;overflow:hidden;margin:0;font-size:100%;text-align:center">SETUP</p></div>

In [None]:
!git clone https://github.com/shreyas-bk/U-2-Net
    
import sys
sys.path.append('./U-2-Net')

# <div style="color:white;display:fill;border-radius:5px;background-color:#75B7BF;letter-spacing:0.1px;overflow:hidden"><p style="padding:20px;color:white;overflow:hidden;margin:0;font-size:100%;text-align:center">IMPORT</p></div>

In [None]:
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import json
from datetime import datetime

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import U2NET 
from model import U2NETP 

from IPython.display import display
from PIL import Image as Img

from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

# <div style="color:white;display:fill;border-radius:5px;background-color:#75B7BF;letter-spacing:0.1px;overflow:hidden"><p style="padding:20px;color:white;overflow:hidden;margin:0;font-size:100%;text-align:center">CONFIGURATION</p></div>

In [None]:
class CFG:
    """Notebook-wide settings for the U2-Net background removal."""

    # Cut-off applied to the normalised prediction map: values above it become
    # foreground (1), everything else background (0).
    THRESHOLD = 0.5

    # True  -> load the small U2NETP network,
    # False -> load the full U2NET network.
    UNET2_SMALL = False

# <div style="color:white;display:fill;border-radius:5px;background-color:#75B7BF;letter-spacing:0.1px;overflow:hidden"><p style="padding:20px;color:white;overflow:hidden;margin:0;font-size:100%;text-align:center">HELPER FUNCTIONS</p></div>

In [None]:
# SOD- U2 Net prediction
################################
def normPRED(d):
    """Min-max normalise a prediction tensor to the range [0, 1].

    Args:
        d: ``torch.Tensor`` holding the raw prediction values.

    Returns:
        A tensor of the same shape as ``d`` computed as
        ``(d - min(d)) / (max(d) - min(d))``. When every element of ``d`` is
        identical (``max == min``) a tensor of zeros is returned instead of
        the NaN values a 0/0 division would produce.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    span = ma - mi

    # A constant map has no contrast to stretch; dividing by a zero span would
    # fill the whole mask with NaN.
    if span == 0:
        return torch.zeros_like(d)

    return (d - mi) / span


def pred_unet(model, imgs):
    """Run salient-object detection on image files and remove their background.

    Args:
        model: Network called as ``model(batch)``; it must return seven
            outputs (unpacked as ``d1 .. d7``). Inputs are moved to the GPU
            whenever CUDA is available, so the model is presumably expected to
            live on the GPU in that case - verify against the caller.
        imgs: Sequence of image file paths. Every path is processed, but only
            the results of the LAST one are returned, so callers normally pass
            a single-element list.

    Returns:
        Tuple ``(im_out, image, salient_mask, mask)`` of ``numpy`` arrays for
        the last processed image:

        * ``im_out``       - source image composited over a white background
                             using the predicted mask
        * ``image``        - source image, converted to RGB
        * ``salient_mask`` - source image with the mask added to the red
                             channel (saturating at 255)
        * ``mask``         - single-channel mask resized to the source size

    Raises:
        ValueError: If ``imgs`` is empty.
    """
    if len(imgs) == 0:
        raise ValueError("pred_unet() needs at least one image path")

    test_salobj_dataset = SalObjDataset(
        img_name_list=imgs,
        lbl_name_list=[],
        transform=transforms.Compose([RescaleT(320), ToTensorLab(flag=0)]),
    )
    test_salobj_dataloader = DataLoader(test_salobj_dataset, batch_size=1, shuffle=False, num_workers=1)

    for i_test, data_test in enumerate(test_salobj_dataloader):

        inputs_test = data_test['image'].type(torch.FloatTensor)
        if torch.cuda.is_available():
            inputs_test = inputs_test.cuda()

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            # Use the network that was passed in, not a global one.
            d1, d2, d3, d4, d5, d6, d7 = model(inputs_test)

            # NOTE(review): the fifth output is the one used as prediction;
            # confirm this choice is intentional.
            predict = normPRED(d5[:, 0, :, :])

        del d1, d2, d3, d4, d5, d6, d7

        predict_np = predict.squeeze().cpu().numpy()

        # Masked image - using threshold you can soften/sharpen mask boundaries.
        # The comparison also maps NaN predictions to background.
        binary = (predict_np > CFG.THRESHOLD).astype(np.float32)
        mask = Img.fromarray(binary * 255).convert('RGB')

        # Open the image that belongs to this batch (batch_size is 1 and the
        # loader is not shuffled, so loader index == path index). Forcing RGB
        # keeps the channel indexing below valid for grayscale sources.
        image = Img.open(imgs[i_test]).convert('RGB')
        imask = mask.resize((image.width, image.height), resample=Img.BILINEAR)
        back = Img.new("RGB", (image.width, image.height), (255, 255, 255))
        mask = imask.convert('L')
        im_out = Img.composite(image, back, mask)

        # Salient mask
        salient_mask = np.array(image)
        mask_layer = np.array(imask)
        mask_layer[mask_layer == 255] = 50  # offset on RED channel
        # Add in a wider integer type: a uint8 addition would wrap around past
        # 255 before the clip could saturate it.
        red = salient_mask[:, :, 0].astype(np.int16) + mask_layer[:, :, 0]
        salient_mask[:, :, 0] = np.clip(red, 0, 255).astype(np.uint8)

    return np.array(im_out), np.array(image), np.array(salient_mask), np.array(mask)

# <div style="color:white;display:fill;border-radius:5px;background-color:#75B7BF;letter-spacing:0.1px;overflow:hidden"><p style="padding:20px;color:white;overflow:hidden;margin:0;font-size:100%;text-align:center">LOADING DATA</p></div>

In [None]:
# Folders holding the cropped train / test images.
train_path = "../input/whale2-cropped-dataset/cropped_train_images/cropped_train_images"
test_path = "../input/whale2-cropped-dataset/cropped_test_images/cropped_test_images"

# CSV tables describing the train / test images.
train_df = pd.read_csv("../input/whale2-cropped-dataset/train2.csv")
test_df = pd.read_csv("../input/whale2-cropped-dataset/test2.csv")

In [None]:
# Full path of every image named in the `image` column of each table.
train_directory = ['/'.join((train_path, name)) for name in train_df.image]
test_directory = ['/'.join((test_path, name)) for name in test_df.image]

In [None]:
# Print the path of every test file whose pixel data contains a NaN value.
for filename in test_directory:
    pixels = plt.imread(filename)
    if np.isnan(pixels).any():
        print(filename)

In [None]:
# Print the path of every test file that contains non-finite pixel data
# (NaN or +/-inf).
for filename in test_directory:
    pixels = plt.imread(filename)
    # `.all()` is needed here: testing `.any()` would only flag a file in
    # which not a single value is finite.
    if not np.isfinite(pixels).all():
        print(filename)

In [None]:
%%time

### Create Kaggle Dataset if not exists 

DATASET_NAME = f'happywhale-cropped-removeBackground-v1'
TRAINING_NAME = f'removedBackground_train_images'
TESTING_NAME = f'removedBackground_test_image'

!rm -r /tmp/{DATASET_NAME} # remove folder

os.makedirs(f'/tmp/{DATASET_NAME}', exist_ok=True)
os.makedirs(f'/tmp/{DATASET_NAME}/{TRAINING_NAME}', exist_ok=True)
os.makedirs(f'/tmp/{DATASET_NAME}/{TESTING_NAME}', exist_ok=True)

with open('../input/kaggle-json-file/kaggle.json') as f:
    kaggle_creds = json.load(f)
    
os.environ['KAGGLE_USERNAME'] = kaggle_creds['username']
os.environ['KAGGLE_KEY'] = kaggle_creds['key']

!kaggle datasets init -p /tmp/{DATASET_NAME}

with open(f'/tmp/{DATASET_NAME}/dataset-metadata.json') as f:
    dataset_meta = json.load(f)
dataset_meta['id'] = f'phanttan/{DATASET_NAME}'
dataset_meta['title'] = DATASET_NAME
with open(f'/tmp/{DATASET_NAME}/dataset-metadata.json', "w") as outfile:
    json.dump(dataset_meta, outfile)
print(dataset_meta)

!cp /tmp/{DATASET_NAME}/dataset-metadata.json /tmp/{DATASET_NAME}/meta.json
!ls /tmp/{DATASET_NAME}

!kaggle datasets create -u -p /tmp/{DATASET_NAME} 

# <div style="color:white;display:fill;border-radius:5px;background-color:#75B7BF;letter-spacing:0.1px;overflow:hidden"><p style="padding:20px;color:white;overflow:hidden;margin:0;font-size:100%;text-align:center">SOD-U2-Net Prediction</p></div>

In [None]:
%%capture

if CFG.UNET2_SMALL:
    model_dir = "./U-2-Net/u2netp.pth"  # Faster ... a lot (!) but less accurate
    net = U2NETP(3,1) 
else:
    model_dir = "../input/u-square-net-model/u2net.pth"
    net = U2NET(3,1) 


if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_dir))
    net.cuda()
else:        
    net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))

net.eval()

In [None]:
# Remove the background of every training image, crop it to the detected
# object and save the result into the staging folder of the dataset.
for idx, img_path in enumerate(train_directory):
    image, im_orig, sal_map, mask = pred_unet(net, [img_path])

    # Rows / columns holding at least one foreground pixel. `any` is required:
    # `argmax` yields 0 both for "no foreground" and for "foreground starts at
    # index 0", which silently drops objects touching the image border.
    foreground = mask != 0
    rows = np.flatnonzero(foreground.any(axis=1))
    cols = np.flatnonzero(foreground.any(axis=0))

    # Fall back to the untouched image when the mask is empty or degenerate.
    crop_img = im_orig
    if rows.size and cols.size:
        ymin, ymax = rows[0], rows[-1]
        xmin, xmax = cols[0], cols[-1]
        if ymin != ymax and xmin != xmax:
            # ymax / xmax are inclusive bounds, hence the +1 in the slice.
            crop_img = image[ymin:ymax + 1, xmin:xmax + 1]
            # cv2.resize expects dsize as (width, height).
            crop_img = cv2.resize(crop_img, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_AREA)

    im = Img.fromarray(crop_img)
    # `.iloc` keeps the lookup positional, matching the enumerate() index.
    im.save(f'/tmp/{DATASET_NAME}/{TRAINING_NAME}/{train_df.image.iloc[idx]}')


In [None]:
# Remove the background of every test image, crop it to the detected object
# and save the result into the staging folder of the dataset.
for idx, img_path in enumerate(test_directory):
    image, im_orig, sal_map, mask = pred_unet(net, [img_path])

    # Rows / columns holding at least one foreground pixel. `any` is required:
    # `argmax` yields 0 both for "no foreground" and for "foreground starts at
    # index 0", which silently drops objects touching the image border.
    foreground = mask != 0
    rows = np.flatnonzero(foreground.any(axis=1))
    cols = np.flatnonzero(foreground.any(axis=0))

    # Fall back to the untouched image when the mask is empty or degenerate.
    crop_img = im_orig
    if rows.size and cols.size:
        ymin, ymax = rows[0], rows[-1]
        xmin, xmax = cols[0], cols[-1]
        if ymin != ymax and xmin != xmax:
            # ymax / xmax are inclusive bounds, hence the +1 in the slice.
            crop_img = image[ymin:ymax + 1, xmin:xmax + 1]
            # cv2.resize expects dsize as (width, height).
            crop_img = cv2.resize(crop_img, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_AREA)

    # Save every image whose pixel data is entirely finite and report the index
    # of the others. Without the save, nothing in this loop writes to the test
    # folder that was created for the dataset.
    if np.isfinite(np.asarray(crop_img)).all():
        im = Img.fromarray(crop_img)
        # `.iloc` keeps the lookup positional, matching the enumerate() index.
        im.save(f'/tmp/{DATASET_NAME}/{TESTING_NAME}/{test_df.image.iloc[idx]}')
    else:
        print(idx)

In [None]:
version_name = datetime.now().strftime("%Y%m%d-%H%M%S")
!kaggle datasets version -m {version_name} -p /tmp/{DATASET_NAME} -r zip -q