# CPTAC-PDA - deepflash2 Cross Validation

> Cross validation for models trained with efficient region based sampling. 

***

## Overview

1. Installation and package loading
2. Functions and classes for prediction
3. Configuration
4. Prediction
5. Submission

#### Related Kernels

- Train Notebook: 
- Sampling Notebook: 

### Installation and package loading

In [None]:
# Install deepflash2 and dependencies
import sys
sys.path.append("../input/segmentation-models-pytorch-install")
!pip install -q ../input/deepflash2-lfs
import cv2, torch, gc, zarr
import torch.nn.functional as F
import deepflash2.tta as tta
import matplotlib.pyplot as plt
import pandas as pd, numpy as np
import segmentation_models_pytorch as smp
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from scipy .ndimage.filters import gaussian_filter
from tqdm.notebook import tqdm
from sklearn.model_selection import KFold
from fastcore.all import *
from skimage.measure import regionprops_table
from deepflash2.utils import iou
from skimage.segmentation import clear_border
from scipy.ndimage.morphology import binary_fill_holes
import warnings
warnings.filterwarnings("ignore")

### Functions and classes for prediction

In [None]:
def load_model_weights(model, file, strict=True):
    """Restore weights from a checkpoint file into ``model``.

    The checkpoint is loaded onto the CPU and must be a mapping with the keys
    ``'stats'`` and ``'model'`` (the latter being a state dict).

    Args:
        model: module whose ``load_state_dict`` is called.
        file: path (or file-like object) handed to ``torch.load``.
        strict: forwarded to ``load_state_dict``.

    Returns:
        Tuple ``(model, stats)`` where ``stats`` is the checkpoint's
        ``'stats'`` entry, returned unchanged.
    """
    checkpoint = torch.load(file, map_location='cpu')
    stats, weights = checkpoint['stats'], checkpoint['model']
    model.load_state_dict(weights, strict=strict)
    return model, stats

# from https://github.com/MIC-DKFZ/nnUNet/blob/2fade8f32607220f8598544f0d5b5e5fa73768e5/nnunet/network_architecture/neural_network.py#L250
def _get_gaussian(patch_size, sigma_scale=1. / 8) -> np.ndarray:
    tmp = np.zeros(patch_size)
    center_coords = [i // 2 for i in patch_size]
    sigmas = [i * sigma_scale for i in patch_size]
    tmp[tuple(center_coords)] = 1
    gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
    gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * 1
    gaussian_importance_map = gaussian_importance_map.astype(np.float32)

    # gaussian_importance_map cannot be 0, otherwise we may end up with nans!
    gaussian_importance_map[gaussian_importance_map == 0] = np.min(
        gaussian_importance_map[gaussian_importance_map != 0])

    return gaussian_importance_map

def postproc(pred):
    """Turn a probability tile into a cleaned binary mask.

    The tile is scaled by 255 and cast to uint8, binarised to {0, 1} with
    OpenCV's Otsu threshold, passed through ``clear_border`` and finally
    through ``binary_fill_holes``, whose result is returned.
    """
    scaled = (pred * 255).astype('uint8')
    otsu_flags = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    binary = cv2.threshold(scaled, 0, 1, otsu_flags)[1]
    return binary_fill_holes(clear_border(binary))

def expand_slice(slices, n_pix=100):
    """Grow every slice by ``n_pix`` on both sides.

    The start is clamped at 0; the stop is not clamped to any upper bound.

    Args:
        slices: iterable of ``slice`` objects with integer start/stop.
        n_pix: number of pixels to add on each side.

    Returns:
        Tuple of the enlarged slices, in the same order.
    """
    return tuple(
        slice(max(s.start - n_pix, 0), s.stop + n_pix)
        for s in slices
    )

In [None]:
# Some code adapted from https://www.kaggle.com/iafoss/hubmap-pytorch-fast-ai-starter-sub
class HubmapDataset(Dataset):
    """Tile dataset over a zarr image that does not load the full image.

    The image is covered with overlapping tiles. Each item is one tile,
    read lazily from the zarr array, optionally downscaled to
    ``output_shape`` and normalised with ``stats``.

    Args:
        file: ``pathlib.Path`` of the zarr store (``file.as_posix()`` is opened).
        stats: ``(mean, std)`` used for normalisation.
        scale: downscale factor between stored image and network input.
        shift: tile stride as a fraction of the tile size (smaller = more overlap).
        output_shape: ``(rows, cols)`` of the tiles returned to the model.
        s_th: saturation threshold used by the empty-tile check.

    Attributes:
        slices: per-tile slices into the stored image.
        out_slices: per-tile slices into the downscaled output canvas.
        out_data_shape: spatial shape of the downscaled output canvas.
    """

    def __init__(self, file, stats, scale=3, shift=.8, output_shape=(512,512), s_th = 40):

        self.mean, self.std = stats
        self.scale = scale
        self.shift = shift
        self.output_shape = output_shape
        self.input_shape = tuple(int(t*scale) for t in self.output_shape)
        self.s_th = s_th  # saturation blanking threshold
        self.p_th = 1000*(self.output_shape[0]//256)**2  # threshold for the minimum number of pixels
        self.data = zarr.open(file.as_posix())

        # Tiling (the last axis of the stored image is treated as channels)
        self.out_data_shape = tuple(int(x//self.scale) for x in self.data.shape[:-1])
        self.slices, self.out_slices = self._tile_slices(
            self.data.shape, self.out_data_shape, self.input_shape,
            self.output_shape, self.scale, self.shift)

    @staticmethod
    def _tile_slices(data_shape, out_data_shape, input_shape, output_shape, scale, shift):
        """Compute matching input/output slices for all tiles.

        Tile centres are spread evenly (``np.linspace``) over the output
        canvas so that the first and last tile touch the canvas borders.

        Returns:
            ``(slices, out_slices)``: two equally long lists of
            ``(row_slice, col_slice)`` tuples, the first indexing the stored
            image, the second indexing the downscaled output canvas.
        """
        start_points = [o//2 for o in output_shape]
        end_points = [(s - st) for s, st in zip(out_data_shape, start_points)]
        n_points = [int(s//(o*shift))+1 for s, o in zip(out_data_shape, output_shape)]
        center_points = [np.linspace(st, e, num=n, endpoint=True, dtype=np.int64)
                         for st, e, n in zip(start_points, end_points, n_points)]

        slices, out_slices = [], []
        for cx in center_points[1]:
            for cy in center_points[0]:
                # Slices into the stored (full resolution) image
                slices.append(tuple(
                    slice(int((c*scale - o/2).clip(0, s)), int((c*scale + o/2).clip(max=s)))
                    for (c, o, s) in zip((cy, cx), input_shape, data_shape)))

                # Slices into the downscaled output canvas
                out_slices.append(tuple(
                    slice(int((c - o/2).clip(0, s)), int((c + o/2).clip(max=s)))
                    for (c, o, s) in zip((cy, cx), output_shape, out_data_shape)))
        return slices, out_slices

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, idx):
        """Return ``(tile, idx)``; ``idx`` is replaced by -1 for empty tiles."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        slices = self.slices[idx]
        img = self.data[slices]

        if self.scale!=1:
            # cv2.resize takes dsize as (width, height), i.e. (cols, rows),
            # whereas output_shape is (rows, cols).
            dsize = (self.output_shape[1], self.output_shape[0])
            img = cv2.resize(img, dsize, interpolation = cv2.INTER_AREA)

        # Check for empty images: too few saturated pixels or (almost) black
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        h,s,v = cv2.split(hsv)
        if (s>self.s_th).sum() <= self.p_th or img.sum() <= self.p_th:
            # idx=-1 marks the tile as empty so the caller can skip it
            idx = -1

        img = (img/255.0 - self.mean)/self.std
        img = img.transpose(2, 0, 1).astype('float32')

        return torch.from_numpy(img), idx
    
class Model_pred:
    """Tile-wise prediction with one or several models and optional TTA.

    Args:
        models: a single model or an iterable of models. Their outputs are
            averaged together with the test-time augmentations.
        use_tta: if True, horizontal and vertical flips are applied.
        batch_size: batch size of the internal ``DataLoader``.
        energy_T: temperature of the energy score.

    Note:
        ``predict`` relies on the module-level ``device`` and expects the
        models to already live on that device.
    """

    def __init__(self, models, use_tta=True, batch_size=32, energy_T=1):
        # Store the models that were passed in, not a module-level global.
        if isinstance(models, torch.nn.Module) and not isinstance(models, torch.nn.ModuleList):
            models = [models]
        self.models = list(models)
        if not self.models:
            raise ValueError('Model_pred needs at least one model')
        self.model = self.models[0]  # kept for backward compatibility
        self.bs = batch_size
        self.energy_T = energy_T
        self.tfms = [tta.HorizontalFlip(), tta.VerticalFlip()] if use_tta else []

    def predict(self, ds):
        """Predict all non-empty tiles of ``ds`` and merge them.

        Tiles are blended with Gaussian weights. Canvas positions that are
        not covered by any non-empty tile have a zero weight sum and thus
        end up as NaN after the final division.

        Returns:
            Tuple ``(softmax, st_deviation, energy)`` of float32 arrays with
            shape ``ds.out_data_shape``: positive-class softmax, standard
            deviation across the merged predictions and energy score.
        """
        dl = DataLoader(ds, self.bs, num_workers=0, shuffle=False, pin_memory=True)

        # Accumulators for the merged maps and the summed weights
        softmax = np.zeros(ds.out_data_shape, dtype='float32')
        energy = np.zeros(ds.out_data_shape, dtype='float32')
        st_deviation = np.zeros(ds.out_data_shape, dtype='float32')
        merge_map = np.zeros(ds.out_data_shape, dtype='float32')

        # Gaussian weights, broadcastable over the batch axis: (1, H, W)
        gw_numpy = _get_gaussian(ds.output_shape)
        gw = torch.from_numpy(gw_numpy).to(device).unsqueeze(0)

        with torch.no_grad():
            for images, idxs in tqdm(dl, total=len(dl)):
                keep = idxs >= 0  # idx == -1 marks empty tiles
                if not keep.any():
                    continue
                images = images[keep].to(device)
                idxs = idxs[keep]

                m_softmax, m_energy = tta.Merger(), tta.Merger()
                for t in tta.Compose(self.tfms):
                    aug_images = t.augment_image(images)
                    for model in self.models:
                        out = t.deaugment_mask(model(aug_images))
                        m_softmax.append(F.softmax(out, dim=1))
                        # https://proceedings.neurips.cc/paper/2020/file/f5496252609c43eb8a3d147ab9b9c006-Paper.pdf
                        # negative energy score
                        m_energy.append(self.energy_T*torch.logsumexp(out/self.energy_T, dim=1))

                # Merge predictions and apply Gaussian weighting.
                # Only the positive class (channel 1) of the softmax is used.
                batch_smx = (m_softmax.result()[:, 1]*gw).cpu().numpy()
                batch_std = (torch.mean(m_softmax.result('std'), 1)*gw).cpu().numpy()
                batch_eng = (m_energy.result()*gw).cpu().numpy()

                for idx, smx, std, eng in zip(idxs.tolist(), batch_smx, batch_std, batch_eng):
                    slcs = ds.out_slices[idx]
                    softmax[slcs] += smx
                    energy[slcs] += eng
                    st_deviation[slcs] += std
                    merge_map[slcs] += gw_numpy

        softmax /= merge_map
        st_deviation /= merge_map
        energy /= merge_map
        return softmax, st_deviation, energy

### Configuration

In [None]:
class CONFIG():
    """Constants for the cross-validation run: paths, tiling, model and post-processing."""

    # data paths
    # directory containing 'dataset_information.csv'
    path = Path('../input/cptacpda')
    # images stored as zarr
    data_path = Path('../input/cptac-pda-to-zarr/images_scale2')
    # zarr store holding the annotation masks (group 'labels')
    annotations_path = Path('../input/cptac-pda-pancreas-efficient-sampling-deepflash2/masks_scale2')
    # directory that is searched for '.pth' checkpoints
    model_path = Path('../input/cptac-pda-train-deepflash2')

    # scale factor; the data is already downscaled by a factor of 2, so an
    # additional factor of 1.5 gives an absolute downscale of 3
    scale = 1.5
    # tile shift for prediction (stride as a fraction of the tile size)
    shift = 0.8
    # shape of the tiles fed to the model
    tile_shape = (512, 512)

    # pytorch model (https://github.com/qubvel/segmentation_models.pytorch)
    encoder_name = "efficientnet-b2"
    # None: no pretrained encoder weights are requested; weights come from the checkpoints
    encoder_weights = None
    in_channels = 3
    classes = 2

    # dataloader
    batch_size = 32

    # test time augmentation
    tta = True
    # prediction threshold applied to the softmax map
    threshold = 0.5

    # Postprocessing
    # NOTE(review): min_pixel is not referenced in the code visible here
    min_pixel = 1000 # needs to be further evaluated
    # components with a larger area only get their holes filled; smaller ones are re-thresholded
    area_threshold = 10000

cfg = CONFIG()

In [None]:
# Train annotations for ids: one zarr image path per row of the dataset table
df_train = pd.read_csv(cfg.path/'dataset_information.csv')
files = L([cfg.data_path/x for x in df_train.image_file])

# Annotation masks (read) and a fresh zarr group for the cross-validation
# outputs: post-processed prediction, softmax, TTA std and energy
g_mask = zarr.open((cfg.annotations_path/'labels').as_posix())
root = zarr.group('cv_preds', overwrite=True)
g_pred, g_smx, g_std, g_eng = root.create_groups('pred', 'smx', 'std', 'eng', overwrite=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Models (see https://github.com/qubvel/segmentation_models.pytorch)
# Checkpoints are sorted by path; the cross-validation loop pairs fold i with MODELS[i].
MODELS = sorted([f for f in cfg.model_path.iterdir() if f.suffix=='.pth'])
print(f'Found {len(MODELS)} models', *MODELS)

models = []
for i, m_path in enumerate(MODELS):
    # Build the architecture, then restore the trained weights from the checkpoint
    model = smp.Unet(encoder_name=cfg.encoder_name, 
                     encoder_weights=cfg.encoder_weights, 
                     in_channels=cfg.in_channels, 
                     classes=cfg.classes)
    # NOTE(review): `stats` is overwritten on every iteration, so only the
    # statistics of the last checkpoint remain available after the loop.
    model, stats = load_model_weights(model, m_path)
    model.float()
    model.eval()
    model.to(device)
    models.append(model)
assert len(models)==len(MODELS)

### Cross validation

- Calculating metrics for predictions from k-fold cross validation
- Plotting some examples

In [None]:
def plot_selection(row):
    """Show seven panels for one predicted region.

    Reads the module-level ``ds`` and the zarr groups ``g_mask``, ``g_smx``,
    ``g_pred``, ``g_eng`` and ``g_std``. The region is ``row.slice`` enlarged
    by 20 pixels on every side.

    Args:
        row: record providing ``idx``, ``slice``, ``iou``, ``mean_intensity``
            and ``mean_intensity_std``.
    """
    fig, axes = plt.subplots(ncols=7, figsize=(30,5))
    name = row.idx
    window = expand_slice(row.slice, 20)
    smx_crop = g_smx[name][window]

    # (image, title) for each panel, left to right
    panels = [
        (ds.data[window], name),
        (g_mask[name][window], 'Annotation'),
        (smx_crop.astype(np.single), 'Softmax'),
        (smx_crop > 0.5, f'Threshold at 0.5'),
        (g_pred[name][window], f'Postprocessing \n IoU {row.iou:.2f}'),
        (g_eng[name][window].astype(np.single), f'Energy \n Mean {row.mean_intensity:.2f}'),
        (g_std[name][window].astype(np.single),
         f'TTA Standard Deviation \n Mean {row.mean_intensity_std:.2f}'),
    ]

    for axis, (image, title) in zip(axes, panels):
        axis.imshow(image)
        axis.set_title(title)

    plt.show()

In [None]:
# Use validation split from training
# NOTE(review): fold i is evaluated with MODELS[i]; this is only a valid
# cross validation if training used the same KFold settings and file order.
kf = KFold(len(MODELS), shuffle=True, random_state=42)
metric_list = []
recall_list = []
precision_list = []
PLOT_OVERVIEW = True


def _safe_ratio(numerator, denominator):
    """Return numerator/denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _region_ious(df, reference, prediction):
    """IoU between reference and prediction inside each row's 'slice' (float column)."""
    values = [iou(reference[slc], prediction[slc]) for slc in df['slice']]
    return pd.Series(values, index=df.index, dtype='float64')


for i, (_, val_idx) in enumerate(kf.split(files)):
    print('Split', i)
    files_val = files[val_idx]
    mp = Model_pred([models[i]], use_tta=cfg.tta, batch_size=cfg.batch_size)
    print('Using model', MODELS[i])

    for f in files_val:
        print('Validating', f.name)

        ds = HubmapDataset(f, stats, scale=cfg.scale, shift=cfg.shift, output_shape=cfg.tile_shape)

        # Predict on the downscaled canvas
        softmax, st_deviation, energy = mp.predict(ds)

        # Resize every map to the annotation size and store it as float16.
        # Each map is released right away to keep peak memory low.
        msk = g_mask[f.name][:].astype('uint8')
        shape = msk.shape
        dsize = (shape[1], shape[0])  # cv2 expects (width, height)
        g_smx[f.name] = cv2.resize(softmax, dsize).astype(np.half)
        del softmax
        gc.collect()

        g_std[f.name] = cv2.resize(st_deviation, dsize).astype(np.half)
        del st_deviation
        gc.collect()

        g_eng[f.name] = cv2.resize(energy, dsize).astype(np.half)
        del energy
        gc.collect()

        # Counting prediction
        smx_arr = g_smx[f.name]
        pred = (smx_arr[:]>cfg.threshold).astype('uint8')
        _, comps = cv2.connectedComponents(pred, connectivity=4)
        props = regionprops_table(comps, properties=('area', 'slice'))
        df_props = pd.DataFrame(props)

        # Postprocessing
        pred_pp = np.zeros_like(pred)
        for _, row in df_props.iterrows():
            slc = row['slice']
            if row['area']>cfg.area_threshold:
                # Large components are kept as they are, with holes filled
                tile = binary_fill_holes(pred[slc])
                pred_pp[slc][tile>0] = 1
            else:
                # Small components are re-thresholded on a growing window
                if any((s.stop-s.start)<100 for s in slc):
                    slc = expand_slice(slc, 10)
                tile = 0
                while np.sum(tile)==0:
                    slc = expand_slice(slc, 10)
                    tile = postproc(smx_arr[slc])
                    # Stop once the window spans the whole image; otherwise a
                    # window that never yields a mask would loop forever.
                    if all(s.start==0 and s.stop>=dim for s, dim in zip(slc, shape)):
                        break
                if np.sum(tile)>2000:
                    pred_pp[slc][tile>0] = 1

        # Recounting prediction
        g_pred[f.name] = pred_pp
        _, comps = cv2.connectedComponents(pred_pp, connectivity=4)
        props_energy = regionprops_table(comps, g_eng[f.name][:],  properties=('mean_intensity', 'area', 'bbox', 'solidity', 'slice'))
        props_std = regionprops_table(comps, g_std[f.name][:],  properties=('mean_intensity',))
        df_props =  pd.DataFrame(props_energy).join(pd.DataFrame(props_std)[['mean_intensity']], rsuffix='_std')

        # One row per predicted object (basis for precision)
        df_props['iou'] = _region_ious(df_props, msk, pred_pp)
        df_props['idx'] = f.name
        df_props['fold'] = i
        df_props['model'] = MODELS[i]

        precision_list.append(df_props)

        # Mask vs. Prediction: one row per annotated object (basis for recall)
        _, comps_msk = cv2.connectedComponents(msk, connectivity=4)
        props_msk = regionprops_table(comps_msk, properties=('area', 'bbox', 'slice'))
        df_msk =  pd.DataFrame(props_msk)

        df_msk['iou'] = _region_ious(df_msk, msk, pred_pp)
        df_msk['idx'] = f.name
        df_msk['fold'] = i
        df_msk['model'] = MODELS[i]
        recall_list.append(df_msk)

        # Calc prediction metrics
        iou_score_pp = iou(pred_pp,msk)
        dice_score_pp = 2*iou_score_pp/(iou_score_pp+1)

        iou_score = iou(pred,msk)
        dice_score = 2*iou_score/(iou_score+1)

        # Object-level scores; an object counts as a hit if its IoU is > 0.5.
        # Ratios with an empty denominator are reported as 0.0.
        precision = _safe_ratio(int((df_props.iou>.5).sum()), len(df_props))
        recall = _safe_ratio(int((df_msk.iou>.5).sum()), len(df_msk))
        f1_score = _safe_ratio(2 * (precision * recall), precision + recall)

        tmp = pd.Series({'idx' : f.name,
                         'fold' : i,
                         'model': MODELS[i],
                         'iou': iou_score,
                         'dice': dice_score,
                         'iou_pp': iou_score_pp,
                         'dice_pp': dice_score_pp,
                         'precision' : precision,
                         'recall' : recall,
                         'f1_score' : f1_score})

        metric_list.append(tmp)

        del msk, pred, pred_pp
        gc.collect()

        # Each example is only plotted if a matching region exists
        if PLOT_OVERVIEW and not df_props.empty:
            print('Example with highest energy')
            row = df_props.sort_values('mean_intensity', ascending=False).iloc[0]
            plot_selection(row)

            unmatched = df_props[df_props.iou<0.5]
            if not unmatched.empty:
                row = unmatched.sort_values('mean_intensity', ascending=False).iloc[0]
                print('Higher energy but no annotation')
                plot_selection(row)

            matched = df_props[df_props.iou>0.5]
            if not matched.empty:
                print('Difficult example:  Low energy with annotation')
                row = matched.sort_values('mean_intensity', ascending=True).iloc[0]
                plot_selection(row)

In [None]:
# Per-file metrics: one row per validated image
metrics_df = pd.DataFrame(metric_list)
metrics_df.to_csv('cv_metrics.csv', index=False)

# Per-object tables collected during cross validation
for csv_name, frames in (('cv_precision.csv', precision_list),
                         ('cv_recall.csv', recall_list)):
    pd.concat(frames).to_csv(csv_name, index=False)