# Installation

In [None]:
# Segmentation model library used for inference further down.
!pip install segmentation-models-pytorch
# Build NVIDIA apex from source: clone it, move the repository contents into the
# working directory, drop the now-empty folder and pip-install from "./".
!git clone https://github.com/NVIDIA/apex
!mv apex to_delete
!mv to_delete/* .
!rm -r to_delete
!pip install -v --disable-pip-version-check --no-cache-dir ./
# Output folders for the previews and the crops; -p avoids an error when they
# already exist (e.g. when this cell is run a second time).
!mkdir -p testing_output
!mkdir -p masked_output

# Imports

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

import torch
from torch import nn
import segmentation_models_pytorch as smp
from torch.utils.data import DataLoader
from tqdm import tqdm
from apex import amp
from albumentations import Resize, Normalize, Compose
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
import cv2
import matplotlib.pyplot as plt

# Crop function

In [None]:
def crop_image_only_outside(img,tol=0):
    """Return the bounding box of the pixels in ``img`` that exceed ``tol``.

    Args:
        img: 2D image array (rows x columns).
        tol: Tolerance; only values strictly greater than it count as content.

    Returns:
        ``(row_start, row_end, col_start, col_end)`` with exclusive end indices,
        so ``img[row_start:row_end, col_start:col_end]`` is the tight crop.
        When no pixel exceeds ``tol`` the bounds cover the whole image.
    """
    def _extent(hits, length):
        # argmax on a boolean vector is the index of the first True; on the
        # reversed vector it is the distance of the last True from the end.
        first = hits.argmax()
        last = length - hits[::-1].argmax()
        return first, last

    above = img > tol
    height, width = img.shape
    left, right = _extent(above.any(0), width)
    top, bottom = _extent(above.any(1), height)
    return top, bottom, left, right

# Dataloader

In [None]:
class AlbuAugment():
    """Callable preprocessing pipeline: resize to 256x256, normalize, convert to tensor."""

    def __init__(self):
        # Compose applies the steps in the order they are listed.
        self.transform = Compose([
            Resize(256, 256),
            Normalize(),
            ToTensorV2(),
        ])

    def __call__(self, image):
        """Run the pipeline on ``image`` and return the transformed image."""
        return self.transform(image=image)['image']

class SegmentationDataset(Dataset):
    """Dataset that serves preprocessed images from a folder together with their file names.

    Args:
        img_dir: Directory that contains the image files.
        names: File names (relative to ``img_dir``) to serve, in index order.
    """

    def __init__(self, img_dir, names):
        self.images_src = img_dir
        self.names = names
        self.transform = AlbuAugment()

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.names)

    def __getitem__(self, idx):
        """Load and preprocess the image at ``idx``; return ``(image, file_name)``.

        Raises:
            FileNotFoundError: If the file is missing or cannot be decoded as an image.
        """
        path = os.path.join(self.images_src, self.names[idx])
        image = cv2.imread(path)
        # cv2.imread reports failure by returning None instead of raising; fail
        # here with the offending path rather than somewhere inside the transform.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        image = self.transform(image=image)
        return image, self.names[idx]

# Image folder
fdir = "../input/vinbigdata-chest-xray-resized-png-1024x1024/test"
# os.listdir returns entries in arbitrary order; sort them so the batches (and
# therefore which files a partial run processes) are the same on every run.
loader = DataLoader(
    SegmentationDataset(fdir, sorted(os.listdir(fdir))),
    batch_size=32,
    pin_memory=True,
    shuffle=False,
    num_workers=4,
)

# Segmentation model

In [None]:
# U-Net++ with a ResNet-50 encoder and one output channel. encoder_weights=None:
# no pretrained encoder weights are requested, the checkpoint below is loaded instead.
model = smp.UnetPlusPlus('resnet50', encoder_weights=None, classes=1)
model = model.cuda()

# Deserialize the checkpoint onto the CPU first; load_state_dict then copies the
# values into the model's (GPU) parameters.
weights = torch.load("/kaggle/input/lungfield-segmetation/segmentation-checkpoint.pth", map_location="cpu")
model.load_state_dict(weights)

# Inference mode; the trailing semicolon hides the module repr in the notebook output.
model.eval();

# Make crops

In [None]:
# Threshold applied to the sigmoid output to binarise the mask; change as you wish.
MASK_THRESHOLD = 0.5
# Margin in pixels kept around the detected mask region; change as you wish.
PAD_SIZE = 100
# Amount added to channel 0 of the preview wherever the mask is set (255 * 0.3, truncated).
MASK_TINT = int(255 * 0.3)

# cv2.imwrite returns False instead of raising when the target folder is missing,
# so make sure both output folders exist before writing.
os.makedirs("./testing_output", exist_ok=True)
os.makedirs("./masked_output", exist_ok=True)

tbar = tqdm(loader)
outputs = []  # not filled by this cell; kept so the name stays defined
with torch.no_grad():
    for image, names in tbar:
        image = image.cuda()
        probs = torch.sigmoid(model(image).cpu())
        # Move the channel axis last, then binarise into a 0/1 uint8 mask.
        output = (probs.permute(0, 2, 3, 1).numpy() > MASK_THRESHOLD).astype(np.uint8)

        for crop, name in zip(output, names):
            src = cv2.imread(f"{fdir}/{name}")
            height, width = src.shape[:2]
            # Bring the mask to the resolution of the source image instead of a
            # hard-coded size (cv2.resize expects (width, height)).
            crop = cv2.resize(crop, (width, height))

            # Preview image: tint channel 0 where the mask is set. The sum is
            # computed in a wider dtype and clipped, because adding directly in
            # uint8 wraps around on bright pixels.
            src_w_mask = src.copy()
            tinted = src[:, :, 0].astype(np.uint16) + crop.astype(np.uint16) * MASK_TINT
            src_w_mask[:, :, 0] = np.minimum(tinted, 255).astype(np.uint8)
            cv2.imwrite(f"./testing_output/{name}", src_w_mask)

            # Bounding box of the mask, grown by PAD_SIZE and clamped to the image.
            rs, re, cs, ce = crop_image_only_outside(crop)
            rs = max(0, rs - PAD_SIZE)
            re = min(height, re + PAD_SIZE)
            cs = max(0, cs - PAD_SIZE)
            ce = min(width, ce + PAD_SIZE)
            cv2.imwrite(f"./masked_output/{name}", src[rs:re, cs:ce])

        # To run on full set, remove this
        break

In [None]:
# File names of the mask-overlay previews found in ./testing_output.
fname = os.listdir("./testing_output")

In [None]:
# Show up to five random results: mask overlay (left) next to the padded crop (right).
# Sample without replacement so the same file is not shown twice, and cap the
# sample size at the number of available files.
idxes = np.random.choice(fname, size=min(5, len(fname)), replace=False)
for f in idxes:
    fig, (ax_overlay, ax_crop) = plt.subplots(1, 2, figsize=(10, 10))
    # cv2.imread returns BGR while matplotlib expects RGB; convert so the
    # colours are displayed as stored.
    overlay = cv2.cvtColor(cv2.imread(f"./testing_output/{f}"), cv2.COLOR_BGR2RGB)
    cropped = cv2.cvtColor(cv2.imread(f"./masked_output/{f}"), cv2.COLOR_BGR2RGB)
    ax_overlay.imshow(overlay)
    ax_overlay.set_title(f"overlay: {f}")
    ax_crop.imshow(cropped)
    ax_crop.set_title(f"crop: {f}")
    plt.show()