# Import Packages

In [None]:
# Show NVIDIA GPU status; the inference cells below are written to run on CUDA.
!nvidia-smi

In [None]:
import os
import glob
import numpy as np
import random
import cv2
import matplotlib.pyplot as plt
import PIL
from PIL import Image
import json
import torch
import shutil

from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import torch.nn.functional as F

import torch
import numpy as np
import segmentation_models_pytorch as smp
from torch import nn

# Dataset

In [None]:
class TestDataset(BaseDataset):
    """Unlabelled test-image dataset for STAS segmentation inference.

    Reads every file in ``images_dir`` as an RGB image, resizes it to
    ``image_size`` when it is not already that size, and optionally applies
    a ``preprocessing`` transform.

    Args:
        images_dir (str): path to the folder of test images. Every entry is
            treated as an image. Entries are sorted so that two datasets built
            over the same folder assign the same index to the same file.
        classes (list of str, optional): class names, looked up
            case-insensitively in ``CLASSES`` and stored as indices in
            ``class_values``. ``None`` means all classes.
        preprocessing (callable, optional): called as
            ``preprocessing(image=image)`` and must return a dict with an
            ``'image'`` key (e.g. an ``albumentations.Compose``).
        image_size (tuple of int): target ``(height, width)`` of returned
            images; defaults to ``(512, 1024)``.
    """

    CLASSES = ['bg', 'stas']

    def __init__(
            self,
            images_dir,
            classes=None,
            preprocessing=None,
            image_size=(512, 1024),
    ):
        # os.listdir order is arbitrary; sort so that indices line up across
        # separate datasets built over the same folder (e.g. model input vs.
        # display copies).
        self.ids = sorted(os.listdir(images_dir))
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]

        # Convert class names to their indices in CLASSES.
        if classes is None:
            classes = self.CLASSES
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

        self.preprocessing = preprocessing
        self.image_size = tuple(image_size)

    def __getitem__(self, i):
        """Return image ``i`` as RGB at ``image_size``, then preprocessed if configured.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        path = self.images_fps[i]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'Could not read image: {path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        height, width = self.image_size
        # Array shapes are (rows, cols, channels), but cv2.resize takes
        # dsize as (width, height).
        if image.shape[:2] != (height, width):
            image = cv2.resize(image, (width, height), interpolation=cv2.INTER_LANCZOS4)

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image)
            image = sample['image']

        return image

    def __len__(self):
        """Return the number of files found in ``images_dir``."""
        return len(self.ids)

In [None]:
import albumentations as albu

def to_tensor(x, **kwargs):
    """Move the channel axis first (H, W, C -> C, H, W) and cast to float32.

    Extra keyword arguments are accepted and ignored so the function can be
    plugged into ``albu.Lambda``.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype(np.float32)


def get_preprocessing(preprocessing_fn=None):
    """Construct the preprocessing transform applied to test images.

    Args:
        preprocessing_fn (callable, optional): image normalization function
            (can be specific to each pretrained encoder), applied before the
            tensor conversion. ``None`` (the default) applies no
            normalization. NOTE(review): this must match whatever
            normalization the models were trained with.

    Return:
        transform: albumentations.Compose that optionally normalizes the
        image, then converts image (and mask, if given) from HWC to CHW
        float32 via ``to_tensor``.
    """
    _transform = []
    if preprocessing_fn is not None:
        _transform.append(albu.Lambda(image=preprocessing_fn))
    _transform.append(albu.Lambda(image=to_tensor, mask=to_tensor))
    return albu.Compose(_transform)

## Public Image

In [None]:
CLASSES = ['stas']
# Fall back to the CPU when no CUDA device is present instead of failing at
# the first .to(DEVICE) call.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Public test split, converted to CHW float32 arrays for the models.
public_test = TestDataset(
    './Data/Public_Image',
    preprocessing=get_preprocessing(),
    classes=CLASSES,
)

# DataLoader with default settings (batch size 1). The prediction loops index
# the dataset directly instead of iterating this loader.
public_test_dataloader = DataLoader(public_test)

In [None]:
# Same public images without preprocessing (RGB, H x W x 3), used only for
# display next to the predictions.
public_test_vis = TestDataset(
    images_dir='./Data/Public_Image',
    classes=CLASSES,
    preprocessing=None,
)

## Private Image

In [None]:
# Private test split, converted to CHW float32 arrays for the models.
private_test = TestDataset(
    images_dir='./Data/Image',
    preprocessing=get_preprocessing(),
    classes=CLASSES,
)

# DataLoader with default settings (batch size 1).
private_test_dataloader = DataLoader(dataset=private_test)

In [None]:
# Same private images without preprocessing (RGB, H x W x 3), used only for
# display next to the predictions.
private_test_vis = TestDataset(
    images_dir='./Data/Image',
    classes=CLASSES,
    preprocessing=None,
)

In [None]:
# Folder where the predicted masks are written as PNGs.
out_path = './Best/Image'
# plt.imsave does not create missing directories, so create it up front.
os.makedirs(out_path, exist_ok=True)

# Load model weights

In [None]:
# Each .pth file holds a whole pickled model object (the loaded objects are
# called with .predict()), so torch.load unpickles arbitrary objects: only
# load trusted files.
# map_location places every model on DEVICE, regardless of the device it was
# saved from.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects whole
# model objects; pass weights_only=False there (older torch lacks that kwarg).
best_model_1 = torch.load('./model_weight/best_model_1.pth', map_location=DEVICE)
best_model_2 = torch.load('./model_weight/best_model_2.pth', map_location=DEVICE)
best_model_3 = torch.load('./model_weight/best_model_3.pth', map_location=DEVICE)
best_model_4 = torch.load('./model_weight/best_model_4.pth', map_location=DEVICE)
best_model_5 = torch.load('./model_weight/best_model_5.pth', map_location=DEVICE)

# Predict

In [None]:
# Ensemble the five models on every public image, display the RGB input next
# to the mask, and save the mask (0 = background, 255 = foreground) as a PNG.
ensemble = (best_model_1, best_model_2, best_model_3, best_model_4, best_model_5)
OUT_SIZE = (1716, 942)  # cv2.resize dsize, i.e. (width, height)

for i in range(len(public_test)):
    name = os.path.basename(public_test_vis.images_fps[i])
    print(name)
    # Unpreprocessed RGB copy of the same image, for display.
    image_vis = public_test_vis[i].astype('uint8')
    image = public_test[i]
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    print(x_tensor.shape)

    # Average the models' outputs, then threshold channel 0 at 0.5.
    preds = sum(model.predict(x_tensor) for model in ensemble) / len(ensemble)
    pr_mask = (preds[:, 0, ...] > 0.5).squeeze().cpu().numpy()
    out_mask = np.zeros((512, 1024))
    out_mask[...] = pr_mask * 255
    out_mask = cv2.resize(out_mask, OUT_SIZE)
    print(i)

    f, axarr = plt.subplots(2, figsize=(15, 8))
    axarr[0].imshow(cv2.resize(image_vis, OUT_SIZE))
    axarr[1].imshow(out_mask, cmap='gray')
    plt.show()
    plt.close(f)  # avoid accumulating open figures across iterations

    save_path = os.path.join(out_path, os.path.splitext(name)[0] + '.png')
    print(save_path)
    # Fix the colour range; otherwise imsave min/max-normalizes and a mask
    # that is entirely 255 would be saved black.
    plt.imsave(save_path, out_mask, cmap='gray', vmin=0, vmax=255)

In [None]:
# Ensemble the five models on every private image, display the RGB input next
# to the mask, and save the mask (0 = background, 255 = foreground) as a PNG.
ensemble = (best_model_1, best_model_2, best_model_3, best_model_4, best_model_5)
OUT_SIZE = (1716, 942)  # cv2.resize dsize, i.e. (width, height)

for i in range(len(private_test)):
    name = os.path.basename(private_test_vis.images_fps[i])
    print(name)
    # Unpreprocessed RGB copy of the same image, for display.
    image_vis = private_test_vis[i].astype('uint8')
    image = private_test[i]
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    print(x_tensor.shape)

    # Average the models' outputs, then threshold channel 0 at 0.5.
    preds = sum(model.predict(x_tensor) for model in ensemble) / len(ensemble)
    pr_mask = (preds[:, 0, ...] > 0.5).squeeze().cpu().numpy()
    out_mask = np.zeros((512, 1024))
    out_mask[...] = pr_mask * 255
    out_mask = cv2.resize(out_mask, OUT_SIZE)
    print(i)

    f, axarr = plt.subplots(2, figsize=(15, 8))
    axarr[0].imshow(cv2.resize(image_vis, OUT_SIZE))
    axarr[1].imshow(out_mask, cmap='gray')
    plt.show()
    plt.close(f)  # avoid accumulating open figures across iterations

    save_path = os.path.join(out_path, os.path.splitext(name)[0] + '.png')
    print(save_path)
    # Fix the colour range; otherwise imsave min/max-normalizes and a mask
    # that is entirely 255 would be saved black.
    plt.imsave(save_path, out_mask, cmap='gray', vmin=0, vmax=255)
  

# Post-processing

In [None]:
import cv2
import os
import numpy as np

# Post-process every predicted mask: smooth, re-binarize, and fill contours
# whose area is at most MAX_FILL_AREA pixels.
imaPath = './Best/Image'
output = './Best/Image_Postprocess'
imaList = os.listdir(imaPath)
# cv2.imwrite returns False (no exception) when the folder is missing.
os.makedirs(output, exist_ok=True)

MAX_FILL_AREA = 30000  # contours with area <= this (in pixels) are filled

for files in imaList:
    path_ima = os.path.join(imaPath, files)
    path_processed = os.path.join(output, files)
    img = cv2.imread(path_ima, 0)
    if img is None:
        # Not readable as an image (e.g. a sub-directory); skip it.
        print(f'Skipping unreadable file: {path_ima}')
        continue

    # Smooth the mask edges, then re-binarize at mid-gray.
    img = cv2.blur(img, (7, 7))
    _, img = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)

    # findContours returns (image, contours, hierarchy) in OpenCV 3 and
    # (contours, hierarchy) in OpenCV 4; [-2] is the contour list in both.
    # Pass a copy because OpenCV < 3.2 modifies its input image.
    contours = cv2.findContours(img.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    # RETR_TREE returns inner (hole) contours as well as outer ones.
    cv_contours = [contour for contour in contours
                   if cv2.contourArea(contour) <= MAX_FILL_AREA]
    cv2.fillPoly(img, cv_contours, (255, 255, 255))
    cv2.imwrite(path_processed, img)