In [None]:
# Install deepflash2 and dependencies
import sys
sys.path.append("../input/zarrkaggleinstall")
sys.path.append("../input/segmentation-models-pytorch-install")
!pip install -q --no-deps ../input/deepflash2-lfs
import cv2, torch, zarr, tifffile, pandas as pd, gc
from fastai.vision.all import *
from deepflash2.all import *
import segmentation_models_pytorch as smp

In [None]:
#https://www.kaggle.com/bguberfain/memory-aware-rle-encoding
#with transposed mask
def rle_encode_less_memory(img):
    """Run-length encode a mask, walking the pixels column by column.

    The mask is transposed before it is flattened, so positions are counted
    down the columns. Every change of value between neighbouring pixels is
    treated as a run boundary.

    Returns a space-separated string of alternating 1-based start positions
    and run lengths (an empty string when there are no runs).
    """
    # flatten() returns a copy, so the caller's array is never modified.
    flat = img.T.flatten()

    # The boundary search below relies on the sequence starting and ending
    # with background, so both end pixels are forced to zero.
    flat[0] = flat[-1] = 0

    # 1-based positions at which the value changes: entries at even indices
    # open a run, entries at odd indices close it.
    boundaries = np.flatnonzero(flat[1:] != flat[:-1]) + 2
    # Replace every closing position by the length of its run.
    boundaries[1::2] -= boundaries[::2]

    return ' '.join(map(str, boundaries))

def load_model_weights(model, file, strict=True):
    """Restore `model` from a checkpoint file and return it with its stats.

    The checkpoint is expected to be a mapping with the keys 'stats' and
    'model'; the latter is handed to `model.load_state_dict`.

    Args:
        model: module whose parameters are overwritten in place.
        file: path (or file-like object) accepted by `torch.load`.
        strict: forwarded to `load_state_dict`.

    Returns:
        Tuple `(model, stats)` where `stats` is the checkpoint's 'stats' entry.
    """
    # Always deserialize onto the CPU; the caller decides about the device.
    checkpoint = torch.load(file, map_location='cpu')
    stats, weights = checkpoint['stats'], checkpoint['model']
    model.load_state_dict(weights, strict=strict)
    return model, stats

In [None]:
# https://matjesg.github.io/deepflash2/data.html#BaseDataset
# Handling of different input shapes
@patch
def read_img(self:BaseDataset, *args, **kwargs):
    """Read a TIFF from `args[0]` and move a leading channel axis to the end.

    5-D arrays are squeezed first and then transposed with (1, 2, 0); arrays
    whose first axis has length 3 are transposed the same way; everything
    else is returned exactly as read. Any further positional or keyword
    arguments are ignored.
    """
    img = tifffile.imread(args[0])
    if img.ndim == 5:
        # NOTE(review): assumes squeezing leaves exactly three axes - confirm.
        return img.squeeze().transpose(1, 2, 0)
    if img.shape[0] == 3:
        return img.transpose(1, 2, 0)
    return img

# https://matjesg.github.io/deepflash2/data.html#DeformationField
# Adding normalization (divide by 255)
@patch
def apply(self:DeformationField, data, offset=(0, 0), pad=(0, 0), order=1):
    """Apply deformation field to image using interpolation.

    Only the window of `data` spanned by the sampling coordinates is sliced
    out before remapping. When `data` has one more dimension than
    `self.shape`, the last axis is processed one channel at a time and each
    channel is divided by 255 before remapping; otherwise `data` is remapped
    as-is, without that division.

    Args:
        data: array-like that supports `data[slice, slice]` (and
            `data[slice, slice, c]` for the channel case) and exposes `.shape`.
        offset, pad: forwarded to `self.get`; `pad` is also subtracted
            element-wise from `self.shape` to obtain the output shape.
        order: passed to `cv2.remap` as its `interpolation` flag.

    Returns:
        The remapped tile; its spatial shape is `self.shape - pad`, with a
        trailing channel axis in the multi-channel case.
    """
    # Output shape: field shape reduced by the padding, axis by axis.
    outshape = tuple(int(s - p) for (s, p) in zip(self.shape, pad))
    # One float32 coordinate map per axis, each reshaped to the output shape.
    coords = [np.squeeze(d).astype('float32').reshape(*outshape) for d in self.get(offset, pad)]
    # Get slices to avoid loading all data (.zarr files)
    sl = []
    for i in range(len(coords)):
        # Integer-truncated extent of the sampling coordinates along axis i.
        cmin, cmax = int(coords[i].min()), int(coords[i].max())
        dmax = data.shape[i]
        # Coordinates start below 0: read from index 0 and extend the window
        # to at least -cmin; coords stay unshifted because the window origin is 0.
        # NOTE(review): presumably this leaves enough pixels for the reflected
        # border - confirm. If cmax also exceeds dmax it is not clamped here.
        if cmin<0: 
            cmax = max(-cmin, cmax)
            cmin = 0 
        # Coordinates run past the end: stop at dmax, pull the window start
        # back to at most 2*dmax-cmax, then express coords relative to it.
        elif cmax>dmax:
            cmin = min(cmin, 2*dmax-cmax)
            cmax = dmax
            coords[i] -= cmin
        # Window lies inside the data: just shift coords to the window origin.
        else: coords[i] -= cmin
        sl.append(slice(cmin, cmax))    
    if len(data.shape) == len(self.shape) + 1:
        
        ## Channel order change in V12
        # Channels are on the last axis; np.empty defaults to float64.
        tile = np.empty((*outshape, data.shape[-1]))
        for c in range(data.shape[-1]):
            # Adding divide
            # cv2.remap takes the x map before the y map, hence coords[1], coords[0].
            tile[..., c] = cv2.remap(data[sl[0],sl[1], c]/255, coords[1],coords[0], interpolation=order, borderMode=cv2.BORDER_REFLECT)
    else:
        # NOTE(review): no division by 255 on this path - confirm that is intended.
        tile = cv2.remap(data[sl[0], sl[1]], coords[1], coords[0], interpolation=order, borderMode=cv2.BORDER_REFLECT)
    return tile

In [None]:
class CONFIG:
    "Settings for the inference run: input locations, tiling, model and threshold."

    # --- input locations ---
    data_path = Path('../input/hubmap-kidney-segmentation')
    models_path = Path('../input/hubmap-efficient-sampling-deepflash2-train')
    # Every entry of `models_path` whose name begins with 'u'.
    # The directory is listed once, when the class body is executed.
    models_file = np.array(
        [p for p in models_path.iterdir() if p.name.startswith('u')]
    )

    # --- deepflash2 dataset (https://matjesg.github.io/deepflash2/data.html#TileDataset) ---
    scale = 3                # zoom factor (zoom out)
    tile_shape = (512, 512)
    padding = (100, 100)     # border overlap for prediction

    # --- pytorch model (https://github.com/qubvel/segmentation_models.pytorch) ---
    encoder_name = "efficientnet-b4"
    encoder_weights = None
    in_channels = 3
    classes = 2

    # --- dataloader ---
    batch_size = 16

    # --- prediction threshold ---
    threshold = 0.5


cfg = CONFIG()

In [None]:
# Show the checkpoint files found by the config, and how many there are
print(cfg.models_file)
print(len(cfg.models_file))

# The sample submission is indexed by image id; it supplies the ids to predict
df_sample = pd.read_csv(cfg.data_path / 'sample_submission.csv', index_col='id')
names = []
preds = []
sub = None

In [None]:
# Parallel lists: one image id and one RLE string per test image.
names,preds = [],[]


# The index of df_sample holds the ids; the row contents are not used.
for idx, _ in df_sample.iterrows():
    print(f'###### File {idx} ######')
    f = cfg.data_path/'test'/f'{idx}.tiff'
    
    # Create deepflash2 dataset (including tiling and file conversion)
    ds = TileDataset([f], scale=cfg.scale, tile_shape=cfg.tile_shape, padding=cfg.padding)
    # Shape of the stored image; the prediction is resized to it further down.
    shape = ds.data[f.name].shape
    print('Shape:', shape)
    
    names.append(idx)
    # Running sum of the per-model outputs; stays None until the first model has run.
    msk = None
    
    print('Prediction')
    # A fresh network is built and loaded for every checkpoint, for every image.
    for model_path in cfg.models_file:
        model = smp.Unet(encoder_name=cfg.encoder_name, 
                         encoder_weights=cfg.encoder_weights, 
                         in_channels=cfg.in_channels, 
                         classes=cfg.classes)
        model, stats = load_model_weights(model, model_path)
        # Normalize batches with the stats stored next to the weights.
        batch_tfms = [Normalize.from_stats(*stats)]
        
        # shuffle=False / drop_last=False: keep every tile, in dataset order.
        dls = DataLoaders.from_dsets(ds, batch_size=cfg.batch_size, after_batch=batch_tfms, shuffle=False, drop_last=False)
        if torch.cuda.is_available(): dls.cuda(), model.cuda()
        
        # NOTE(review): loss_func='' looks like a placeholder since nothing is
        # trained in this loop - confirm the Learner never calls it.
        learn = Learner(dls, model, loss_func='')
        
        # Predict tiles, see https://matjesg.github.io/deepflash2/learner.html#Learner.predict_tiles
        res = learn.predict_tiles(dl=dls.train, path='/kaggle/temp/', use_tta=False, uncertainty_estimates=False)
        # Accumulate channel 1 of the first returned element for this file.
        # Presumably that is the foreground probability - verify against predict_tiles.
        if msk is None:
            msk = res[0][f.name][..., 1]
        else:
            msk += res[0][f.name][..., 1]
        
        # Drop the references before the next checkpoint is loaded.
        del model, stats, learn
    
    # Mean over all checkpoints, then binarize at the configured threshold.
    msk = msk/len(cfg.models_file)
    msk = (msk > cfg.threshold).astype(np.uint8)
    
    # Resize image and create RLE
    print('Rezising')
    # cv2.resize expects (width, height), hence shape[1] before shape[0].
    msk = cv2.resize(msk, (shape[1], shape[0]))
    rle = rle_encode_less_memory(msk)
    preds.append(rle)
        
    # Plot Result
    print('Plotting')
    fig, ax = plt.subplots(figsize=(15,15))
    # Only a 1024x1024 preview of the resized mask is drawn.
    ax.imshow(cv2.resize(msk, (1024, 1024)))
    plt.show()
    
    # Overwrite store (reduce disk usage)
    # Removal errors are ignored (ignore_errors=True); under /tmp/ only
    # entries whose name starts with 'zarr' are removed.
    _ = [shutil.rmtree(p, ignore_errors=True) for p in Path('/kaggle/temp/').iterdir()]
    _ = [shutil.rmtree(p, ignore_errors=True) for p in Path('/tmp/').iterdir() if p.name.startswith('zarr')]

In [None]:
# Collect the per-image RLE strings in a frame indexed by image id
df = pd.DataFrame(dict(id=names, predicted=preds)).set_index('id')

# Overwrite the matching rows of the sample submission and save the result
df_sample.loc[df.index.values] = df.values
df_sample.to_csv('submission.csv')