# HuBMAP - Efficient Sampling Baseline (deepflash2, pytorch, fastai) [sub]

> Submission kernel for an ensemble of models trained with efficient region-based sampling.

# Acknowledgements

- Train Notebook: https://www.kaggle.com/matjes/hubmap-efficient-sampling-deepflash2-train
- Sampling Notebook: https://www.kaggle.com/matjes/hubmap-labels-pdf-0-5-0-25-0-01
- Original Inference Notebook: https://www.kaggle.com/matjes/hubmap-efficient-sampling-deepflash2-sub

Requires deepflash2 (git version), zarr, and segmentation-models-pytorch

In [None]:
#!ls ../input/d/khoongweihao/

In [None]:
!ls ../input/

### Installation and package loading

In [None]:
# Install deepflash2 and dependencies
import sys
sys.path.append("../input/zarrkaggleinstall")
sys.path.append("../input/segmentation-models-pytorch-install")
!pip install -q --no-deps ../input/deepflash2-lfs
import cv2, torch, zarr, tifffile, pandas as pd, gc
from fastai.vision.all import *
from deepflash2.all import *
import deepflash2.tta as tta
import segmentation_models_pytorch as smp

### Helper functions and patches

In [None]:
#https://www.kaggle.com/bguberfain/memory-aware-rle-encoding
#with transposed mask
def rle_encode_less_memory(img):
    """Run-length encode a binary mask as a space-separated string.

    The mask is flattened in column-major order (``img.T.flatten()``) and
    encoded as ``start length`` pairs with 1-based start positions.

    Runs that touch the first or last flattened pixel are encoded in full
    (the earlier simplified version zeroed those two pixels and therefore
    clipped such runs), and an empty array yields ``''`` instead of raising.

    Args:
        img: 2-D array in which zero is background and one single non-zero
            value marks the foreground. The input is not modified.

    Returns:
        The run-length encoding, or ``''`` when there is no foreground.
    """
    pixels = img.T.flatten()
    if pixels.size == 0:
        return ''

    # 1-based positions at which the value differs from the previous pixel
    changes = np.where(pixels[1:] != pixels[:-1])[0] + 2

    # A run touching the border has no change point there, so add the
    # missing run start (position 1) and/or run end (one past the last pixel).
    head = [1] if pixels[0] else []
    tail = [pixels.size + 1] if pixels[-1] else []
    runs = np.concatenate([head, changes, tail]).astype(np.int64)

    # Turn (start, end) pairs into (start, length) pairs
    runs[1::2] -= runs[::2]

    return ' '.join(str(x) for x in runs)

def load_model_weights(model, file, strict=True):
    """Load a saved checkpoint into `model` and return it with the stored stats.

    The checkpoint at `file` is read onto the CPU and must provide the keys
    'model' (a state dict) and 'stats'.

    Args:
        model: module whose `load_state_dict` receives the stored weights.
        file: path (or file-like object) accepted by `torch.load`.
        strict: forwarded to `load_state_dict`.

    Returns:
        Tuple `(model, stats)`; `model` is the same object that was passed in.
    """
    checkpoint = torch.load(file, map_location='cpu')
    stats, weights = checkpoint['stats'], checkpoint['model']
    model.load_state_dict(weights, strict=strict)
    return model, stats

Patches for deepflash2 classes, see https://fastcore.fast.ai/basics.html#patch

In [None]:
# https://matjesg.github.io/deepflash2/data.html#BaseDataset
# Handling of different input shapes
@patch
def read_img(self:BaseDataset, *args, **kwargs):
    "Read the TIFF file given as first positional argument and move a leading axis last where needed."
    img = tifffile.imread(args[0])
    if img.ndim == 5:
        # Drop singleton axes, then move the first remaining axis to the end
        # (assumes exactly three axes are left after squeezing)
        return img.squeeze().transpose(1, 2, 0)
    if img.shape[0] == 3:
        # First axis has length 3 (presumably channels-first): move it last
        return img.transpose(1, 2, 0)
    return img

# https://matjesg.github.io/deepflash2/data.html#DeformationField
# Adding normalization (divide by 255)
@patch
def apply(self:DeformationField, data, offset=(0, 0), pad=(0, 0), order=1):
    """Apply deformation field to image using interpolation.

    Patched onto `DeformationField`; in addition to remapping, data with a
    trailing channel axis is divided by 255 before interpolation.

    Args:
        data: array-like indexed on its first two axes (plus an optional
            trailing channel axis). Only the region spanned by the sampling
            coordinates is sliced out of it.
        offset: forwarded to `self.get`.
        pad: forwarded to `self.get`; also subtracted from `self.shape` to
            obtain the output shape.
        order: passed to `cv2.remap` as its `interpolation` argument.

    Returns:
        Array of shape `self.shape - pad`; when `data` has one more dimension
        than `self.shape`, a trailing axis of size `data.shape[-1]` is added.
    """
    outshape = tuple(int(s - p) for (s, p) in zip(self.shape, pad))
    # One float32 coordinate map per axis, each reshaped to the output shape
    coords = [np.squeeze(d).astype('float32').reshape(*outshape) for d in self.get(offset, pad)]
    # Get slices to avoid loading all data (.zarr files)
    sl = []
    for i in range(len(coords)):
        cmin, cmax = int(coords[i].min()), int(coords[i].max())
        dmax = data.shape[i]
        if cmin<0: 
            # Coordinates reach below 0: the slice starts at 0, so coords stay
            # unshifted. The upper bound is widened to at least -cmin
            # (presumably so the slice covers what negative coordinates reflect onto).
            cmax = max(-cmin, cmax)
            cmin = 0 
        elif cmax>dmax:
            # Coordinates reach past the end: lower the start to at most
            # 2*dmax-cmax (the mirror of cmax about dmax), clamp the end to
            # dmax, and make coords relative to the slice start.
            cmin = min(cmin, 2*dmax-cmax)
            cmax = dmax
            coords[i] -= cmin
        # Fully inside the data: just make coords relative to the slice start
        else: coords[i] -= cmin
        sl.append(slice(cmin, cmax))    
    # Data carries one extra (trailing) axis compared to the field: remap it channel by channel
    if len(data.shape) == len(self.shape) + 1:
        
        ## Channel order change in V12
        tile = np.empty((*outshape, data.shape[-1]))
        for c in range(data.shape[-1]):
            # Adding divide
            # coords[1] is passed as cv2's first map and coords[0] as its second;
            # coordinates outside the sliced region are handled by BORDER_REFLECT
            tile[..., c] = cv2.remap(data[sl[0],sl[1], c]/255, coords[1],coords[0], interpolation=order, borderMode=cv2.BORDER_REFLECT)
    else:
        # NOTE(review): this branch is NOT divided by 255 — confirm that is intended
        tile = cv2.remap(data[sl[0], sl[1]], coords[1], coords[0], interpolation=order, borderMode=cv2.BORDER_REFLECT)
    return tile

### Configuration

In [None]:
class CONFIG:
    """Static settings for this submission run: input paths, tiling, model set-up and threshold."""

    # --- input data ---
    data_path = Path('../input/hubmap-kidney-segmentation')

    # --- tiling, see deepflash2 TileDataset (https://matjesg.github.io/deepflash2/data.html#TileDataset) ---
    scale = 3  # zoom factor (zoom out)
    tile_shape = (512, 512)
    padding = (100,100)  # border overlap for prediction

    # --- candidate models: encoder (segmentation_models.pytorch) and its weight file ---
    encoder_name1 = "efficientnet-b0"
    model_file1 = '../input/hubmap-deepflash-weights/unet_efficientnet-b0-morph-iou0.9015-dice0.9482.pth'
    encoder_name2 = "efficientnet-b2"
    model_file2 = '../input/hubmap-deepflash-weights/unet_efficientnet-b2-morph-iou0.8993-dice0.9470.pth'
    encoder_name3 = "efficientnet-b5"
    model_file3 = '../input/hubmap-deepflash-weights/unet_efficientnet-b5-morph-iou0.9021-dice0.9485.pth'
    encoder_name4 = "timm-resnest101e"
    model_file4 = '../input/hubmap-deepflash-weights/unet_timm-resnest101e_iou0.9051_dice0.9502.pth'
    encoder_name5 = "efficientnet-b4"
    model_file5 = '../input/hubmap-efficient-sampling-deepflash2-train/unet_efficientnet-b4.pth'
    encoder_name6 = "efficientnet-b1"
    model_file6 = '../input/hubmap-deepflash-weights/unet_efficientnet-b1-morph-iou0.8981-dice0.9463.pth'
    encoder_name7 = "efficientnet-b3"
    model_file7 = '../input/hubmap-deepflash-weights/unet_efficientnet-b3-morph-iou0.9038-dice0.9494.pth'

    # --- settings shared by every model ---
    encoder_weights = None
    in_channels = 3
    classes = 2

    # --- dataloader ---
    batch_size = 16

    # --- prediction threshold ---
    threshold = 0.4


cfg = CONFIG()

In [None]:
# Sample submission: its index ('id') lists the images to predict
df_sample = pd.read_csv(cfg.data_path/'sample_submission.csv',  index_col='id')


def _build_unet(encoder_name):
    "Create an smp.Unet for `encoder_name` with the settings shared by all models in `cfg`."
    # Model (see https://github.com/qubvel/segmentation_models.pytorch)
    return smp.Unet(encoder_name=encoder_name,
                    encoder_weights=cfg.encoder_weights,
                    in_channels=cfg.in_channels,
                    classes=cfg.classes)


# Build each network and load its weights; `statsN` are the values stored with the checkpoint
model1, stats1 = load_model_weights(_build_unet(cfg.encoder_name1), cfg.model_file1)
model2, stats2 = load_model_weights(_build_unet(cfg.encoder_name2), cfg.model_file2)
model3, stats3 = load_model_weights(_build_unet(cfg.encoder_name3), cfg.model_file3)
model4, stats4 = load_model_weights(_build_unet(cfg.encoder_name4), cfg.model_file4)
model5, stats5 = load_model_weights(_build_unet(cfg.encoder_name5), cfg.model_file5)
model6, stats6 = load_model_weights(_build_unet(cfg.encoder_name6), cfg.model_file6)
model7, stats7 = load_model_weights(_build_unet(cfg.encoder_name7), cfg.model_file7)

# Each model is paired with a Normalize transform built from its own stored stats
batch_tfms1 = [Normalize.from_stats(*stats1)]
batch_tfms2 = [Normalize.from_stats(*stats2)]
batch_tfms3 = [Normalize.from_stats(*stats3)]
batch_tfms4 = [Normalize.from_stats(*stats4)]
batch_tfms5 = [Normalize.from_stats(*stats5)]
batch_tfms6 = [Normalize.from_stats(*stats6)]
batch_tfms7 = [Normalize.from_stats(*stats7)]

# (model, batch transforms) pairs that take part in the prediction;
# model4 is loaded above but currently left out
models = [
    (model1, batch_tfms1),
    (model2, batch_tfms2),
    (model3, batch_tfms3),
    #(model4, batch_tfms4),
    (model5, batch_tfms5),
    (model6, batch_tfms6),
    (model7, batch_tfms7),
]

In [None]:
# Number of (model, batch transforms) pairs that will be averaged during prediction
print(len(models))

In [None]:
#!ls ../input/d/khoongweihao

### Prediction

In [None]:
# Collected image ids and their run-length encoded masks, in prediction order
names,preds = [],[]

# Every model contributes an equal share to the averaged prediction
n_models = len(models)

for idx in df_sample.index:
    print(f'###### File {idx} ######')
    f = cfg.data_path/'test'/f'{idx}.tiff'

    # Create deepflash2 dataset (including tiling and file conversion)
    ds = TileDataset([f], scale=cfg.scale, tile_shape=cfg.tile_shape, padding=cfg.padding)
    shape = ds.data[f.name].shape
    print('Shape:', shape)

    # The threshold depends only on the image, so it is chosen once per file;
    # image 'd488c759a' uses a lower, hand-picked value
    th = 0.2 if idx=='d488c759a' else cfg.threshold
    print(th)

    # Running average over all models of channel 1 of each prediction
    msk = None

    for i, (model, batch_tfms) in enumerate(models):
        # Create fastai dataloader and learner (the batch transforms differ per model)
        dls = DataLoaders.from_dsets(ds, batch_size=cfg.batch_size, after_batch=batch_tfms, shuffle=False, drop_last=False)
        if torch.cuda.is_available():
            dls.cuda()
            model.cuda()
        learn = Learner(dls, model, loss_func='')

        # Predict tiles, see https://matjesg.github.io/deepflash2/learner.html#Learner.predict_tiles
        print('Prediction')
        res = learn.predict_tiles(
            dl=dls.train,
            path='/kaggle/temp/',
            n_times=2,
            use_tta=True,
            tta_merge='mean',
            tta_tfms=[
                tta.HorizontalFlip(),
                tta.Rotate90(angles=[90,180,270]),
            ],
            uncertainty_estimates=False
        )

        # Channel 1 of this model's prediction, already weighted by 1/n_models
        share = res[0][f.name][..., 1]/n_models
        if msk is None:
            msk = share
        else:
            msk += share
        print(f'Model {i} done!')

    # Binarise the averaged prediction
    msk = (msk>th).astype(np.uint8)

    # Scale the mask back to the shape reported by the dataset and encode it
    print('Rezising')
    msk = cv2.resize(msk, (shape[1], shape[0]))
    rle = rle_encode_less_memory(msk)
    names.append(idx)
    preds.append(rle)

    # Plot Result (plotting itself is currently disabled)
    print('Plotting')
    #fig, ax = plt.subplots(figsize=(15,15))
    #ax.imshow(cv2.resize(res[1][f.name][:].astype(np.uint8), (1024, 1024)))
    #plt.show()

    # Overwrite store (reduce disk usage): remove the prediction output
    # directories and the temporary zarr directories before the next image
    for p in Path('/kaggle/temp/').iterdir():
        shutil.rmtree(p, ignore_errors=True)
    for p in Path('/tmp/').iterdir():
        if p.name.startswith('zarr'):
            shutil.rmtree(p, ignore_errors=True)

### Submission

In [None]:
# Collect the predictions into a frame keyed by image id
df = pd.DataFrame({'id':names,'predicted':preds})
df = df.set_index('id')

# Write the predicted encodings into the matching rows of the sample submission
predicted_ids = df.index.values
df_sample.loc[predicted_ids] = df.values

# Save the submission file and show it in the notebook
df_sample.to_csv('submission.csv')
display(df_sample)