In [None]:
from fastai.vision.all import *
import gc

In [None]:
# Root of the competition data: holds the train/ and test/ image folders and the CSV files read below.
datapath = Path("/kaggle/input/hubmap-kidney-segmentation/")

In [None]:
# Collect the .tiff images of each split folder (train first, then test).
img_files, test_img_files = (get_files(datapath/split, extensions=['.tiff'])
                             for split in ('train', 'test'))

In [None]:
# Map every image file name to the split folder it was found in.
# A dict comprehension per split avoids the original zip-with-a-label-list, which sized the
# test label list by the *train* file count and so could silently drop test files.
trn_map = {f.name: 'train' for f in img_files}
test_map = {f.name: 'test' for f in test_img_files}
trn_test_map = {**trn_map, **test_map}

In [None]:
# Image id = the file stem up to its first "_"; duplicates are collapsed.
unique_ids = img_files.map(lambda path: path.stem.split("_")[0]).unique()
unique_ids

In [None]:
# train.csv provides the per-image RLE masks; the dataset-information CSV provides per-image metadata.
train_df, meta_df = (pd.read_csv(datapath/csv_name)
                     for csv_name in ('train.csv', 'HuBMAP-20-dataset_information.csv'))

In [None]:
# Tag each metadata row with the split ('train'/'test') its image file was found in;
# files found in neither folder map to NaN.
meta_df['split'] = meta_df['image_file'].map(trn_test_map)

In [None]:
# Metadata ordered by patient, to see which images share a donor.
meta_df.sort_values(by='patient_number')

In [None]:
# Count of metadata rows per (patient, image file, split) group; rows whose split is NaN
# are dropped by groupby's default dropna behaviour.
meta_df.groupby(['patient_number','image_file', 'split'])[['split']].count()

### Utils

In [None]:
def enc2mask(encs, shape):
    """Decode run-length encodings into a single label mask.

    Args:
        encs: iterable of RLE strings of the form "start length start length ...",
            with 1-based starts into the flattened buffer. Float NaN entries
            (presumably missing values read by pandas) are skipped.
        shape: 2-tuple; the flat buffer is reshaped to `shape` and then transposed.

    Returns:
        uint8 array of shape `shape[::-1]`. Pixels covered by the m-th encoding are set
        to m + 1, and 0 is background.
    """
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for m, enc in enumerate(encs):
        # `np.float` was only an alias for the builtin `float` and was removed in NumPy 1.24.
        if isinstance(enc, float) and np.isnan(enc):
            continue
        s = enc.split()
        for i in range(len(s) // 2):
            start = int(s[2 * i]) - 1  # RLE starts are 1-based
            length = int(s[2 * i + 1])
            img[start:start + length] = 1 + m
    return img.reshape(shape).T

def rle_encode_less_memory(img):
    """Run-length encode a binary mask as 'start length start length ...'.

    The mask is flattened column-major (transpose, then row-major flatten), and starts are
    1-based. The first and last pixels of the flattened copy are forced to 0, so every run
    has both a rising and a falling edge. The input array itself is not modified, because
    `flatten` returns a copy.
    """
    flat = img.T.flatten()
    flat[0] = flat[-1] = 0
    # Positions where the value changes, shifted to 1-based start positions.
    edges = np.flatnonzero(flat[1:] != flat[:-1]) + 2
    # Odd slots hold run ends; turn them into run lengths.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

def _tile_starts(length, window, min_overlap):
    "Start offsets of `window`-sized tiles covering [0, length) with overlaps of at least `min_overlap`."
    if length <= window:
        # A single tile (clipped to `length` by the caller) covers the axis. The generic formula
        # below would produce a negative last start or duplicated tiles in this case.
        return np.zeros(1, dtype=np.int64)
    n = length // (window - min_overlap) + 1
    starts = np.linspace(0, length, num=n, endpoint=False, dtype=np.int64)
    starts[-1] = length - window  # snap the last tile flush with the far edge
    return starts


def make_grid(shape, window=1024, min_overlap=32):
    """
        Return Array of size (N,4), where N - number of tiles,
        2nd axis represents slices: x1,x2,y1,y2

        `x` is the first axis of `shape` (rows when `shape` is a rasterio dataset's `.shape`)
        and `y` is the second axis. Stop coordinates are clipped to the image size. Tiles are
        ordered x-major: all y tiles for the first x start, then the next x start, and so on.
        Raises ValueError if `window <= min_overlap`, because tiles could not advance.
    """
    if window <= min_overlap:
        raise ValueError(f"window ({window}) must be larger than min_overlap ({min_overlap})")
    x, y = shape
    x1 = _tile_starts(x, window, min_overlap)
    x2 = (x1 + window).clip(0, x)
    y1 = _tile_starts(y, window, min_overlap)
    y2 = (y1 + window).clip(0, y)

    nx, ny = len(x1), len(y1)
    # Cartesian product of x-ranges and y-ranges, x outer / y inner.
    x_cols = np.repeat(np.stack([x1, x2], axis=1), ny, axis=0)
    y_cols = np.tile(np.stack([y1, y2], axis=1), (nx, 1))
    return np.concatenate([x_cols, y_cols], axis=1)

### Datasets

In [None]:
import rasterio
from rasterio.windows import Window
import cv2

# Tile side read from the full-resolution image, minimum overlap between neighbouring
# tiles (both passed to make_grid), and the side each tile is resized to before training.
WINDOW, MIN_OVERLAP, NEW_SIZE = 1536, 128, 512

In [None]:
# image datasets
identity = rasterio.Affine(1, 0, 0, 0, 1, 0)
id2dataset = {_id : rasterio.open(datapath/'train'/f"{_id}.tiff", transform=identity) for _id in unique_ids}

# image masks
id2rle = dict(zip(train_df['id'], train_df['encoding']))
id2mask = {_id:enc2mask([rle], id2dataset[_id].shape[::-1]) for _id,rle in id2rle.items()}

# (dataset id, slices array)
id_slices = []
for _id, dataset in id2dataset.items():
    slices = make_grid(dataset.shape, window=WINDOW, min_overlap=MIN_OVERLAP)
    id_slices += list(zip([_id]*len(slices), slices))

In [None]:
# Visual sanity check: a 1500x1500 crop (rows 7000:8500, cols 15000:16500) of the decoded mask.
plt.imshow(id2mask['2f6ecfcdf'][7000:8500, 15000:16500])

In [None]:
# Bands 1-3 of image 2f6ecfcdf over rows 7000:8500, cols 15000:16500, to compare with its mask crop.
image = id2dataset['2f6ecfcdf'].read([1,2,3], window=Window.from_slices((7000,8500), (15000,16500)))
TensorImage(tensor(image)).show()

In [None]:
# First ten (image id, tile coordinates) entries and the total number of tiles.
id_slices[:10], len(id_slices)

In [None]:
# tfms
def read_tile(i, id_slices):
    """Load tile `i` of `id_slices` as a NEW_SIZE x NEW_SIZE image tensor scaled by 1/255.

    Bands 1-3 are read from the tile's window in file order. The tile is resized with area
    interpolation, then its channels are reordered with cv2.COLOR_RGB2BGR.
    """
    img_id, (r0, r1, c0, c1) = id_slices[i]
    win = Window.from_slices((r0, r1), (c0, c1))
    chw = id2dataset[img_id].read([1, 2, 3], window=win)
    hwc = np.moveaxis(chw, 0, -1)  # cv2 works on channel-last arrays
    small = cv2.resize(hwc, (NEW_SIZE, NEW_SIZE), interpolation=cv2.INTER_AREA)
    reordered = cv2.cvtColor(small, cv2.COLOR_RGB2BGR)
    return TensorImage(image2tensor(reordered) / 255.)


def read_mask(i, id_slices):
    """Crop tile `i` of `id_slices` from its decoded mask, resized to NEW_SIZE x NEW_SIZE.

    Nearest-neighbour interpolation keeps the label values unchanged.
    """
    img_id, (r0, r1, c0, c1) = id_slices[i]
    crop = id2mask[img_id][r0:r1, c0:c1]
    return TensorMask(cv2.resize(crop, (NEW_SIZE, NEW_SIZE), interpolation=cv2.INTER_NEAREST))

In [None]:
# Every tile index is one item; the two pipelines turn an index into (image, mask).
items = range(len(id_slices))
# Hold out 10% of the tiles at random for validation. fastai takes `splits` in the
# `Datasets` constructor. `.dataloaders()` does not use it to split the data, so passing
# it there (as before) leaves the DataLoaders without a validation set.
splits = RandomSplitter(0.1)(items)
dsets = Datasets(items,
                 tfms=[[partial(read_tile, id_slices=id_slices)],
                       [partial(read_mask, id_slices=id_slices)]],
                 splits=splits)
dls = dsets.dataloaders(bs=4,
                        batch_tfms=[Dihedral(p=0.5),
                                    Rotate(p=0.5, max_deg=30),
                                    Brightness(p=0.5, max_lighting=0.3, batch=False)])
len(dsets)

In [None]:
# Pull one (augmented) training batch to check that the pipeline runs end to end.
xb,yb = dls.one_batch()

In [None]:
# Image batch and mask batch shapes.
xb.shape, yb.shape

### Model

In [None]:
# https://github.com/fastai/fastai/issues/3041
def flatten_check(inp, targ):
    "Check that `inp` and `targ` have the same number of elements and flatten them."
    # `.contiguous()` first, because `.view(-1)` needs contiguous memory.
    flat_inp = inp.contiguous().view(-1)
    flat_targ = targ.contiguous().view(-1)
    test_eq(len(flat_inp), len(flat_targ))
    return flat_inp, flat_targ
    
class Dice(Metric):
    "Dice coefficient for a binary segmentation target, accumulated over all batches of an epoch."
    def __init__(self, thresh=0.5): store_attr()

    def reset(self):
        self.inter = 0  # running sum of pred * targ
        self.union = 0  # running sum of pred + targ (sum of both cardinalities)

    def accumulate(self, learn):
        # Assumes dim 1 of the prediction is the single logit channel (n_out=1); TODO confirm.
        binarized = learn.pred.sigmoid().squeeze(1) > self.thresh
        pred, targ = flatten_check(binarized, learn.y)
        pred = TensorBase(pred)
        targ = TensorBase(targ)
        self.inter += (pred * targ).float().sum().item()
        self.union += (pred + targ).float().sum().item()

    @property
    def value(self):
        # 2|A∩B| / (|A| + |B|); undefined (None) when both prediction and target are empty.
        if self.union > 0:
            return 2. * self.inter / self.union
        return None

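The model has a single output channel (`n_out=1`) trained with `BCEWithLogitsLossFlat`. Predictions are obtained by thresholding the sigmoid output, at 0.5 in the `Dice` metric above. A two-channel softmax head with argmax would avoid having to choose a threshold, but that is not what this notebook uses. To my knowledge, and according to the papers I have seen in the medical domain, ImageNet transfer learning helps close to nothing. I therefore skip it here (`pretrained=False`) to keep things clean and simple.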

In [None]:
# fastai's flattened BCE-with-logits loss, matching the single-logit output (n_out=1) of the learner.
loss_func = BCEWithLogitsLossFlat()

In [None]:
# ranger hyper-parameters, kept as module-level names so they can be inspected or tweaked.
sqrmom, mom, beta, eps = 0.99, 0.95, 0., 1e-4
opt_func = partial(ranger, mom=mom, sqr_mom=sqrmom, eps=eps, beta=beta)

In [None]:
# U-Net on an xresnet34 encoder trained from scratch: inputs are already scaled to [0, 1]
# in read_tile, so fastai's normalization is switched off.
learner = unet_learner(
    dls, xresnet34,
    n_out=1, pretrained=False, normalize=False,
    loss_func=loss_func, opt_func=opt_func,
    metrics=[Dice(thresh=0.5)],
)
# Native mixed precision: a little bit faster than fp16() (thanks to ilovescience's experiments).
learner.to_native_fp16();

In [None]:
# Single-epoch smoke test of the full pipeline (flat-then-cosine LR schedule).
learner.fit_flat_cos(1)

In [None]:
# Leftover `%debug` post-mortem magic removed: it only works after an exception and must not
# run as part of Restart & Run All. Run it by hand if the training cell above fails.

In [None]:
# Reuse the handle already opened in `id2dataset` for this training image, instead of
# duplicating the absolute path and leaking a second unclosed file handle.
ds = id2dataset['cb2d976f4']

In [None]:
# Bands 1-3 of the top-left 100x100 pixels, using the same Window helper as the rest of the notebook.
ds.read([1, 2, 3], window=Window.from_slices((0, 100), (0, 100)))

In [None]:
def get_dls(fold, bs=4):
    """Build DataLoaders for cross-validation fold `fold`.

    NOTE(review): folds are assumed to be leave-one-image-out. Every tile of training image
    `unique_ids[fold]` goes to validation and all other tiles to training. TODO confirm the
    intended fold scheme.
    """
    valid_id = unique_ids[fold]
    train_idx = [i for i, (tile_id, _) in enumerate(id_slices) if tile_id != valid_id]
    valid_idx = [i for i, (tile_id, _) in enumerate(id_slices) if tile_id == valid_id]
    fold_dsets = Datasets(range(len(id_slices)),
                          tfms=[[partial(read_tile, id_slices=id_slices)],
                                [partial(read_mask, id_slices=id_slices)]],
                          splits=(train_idx, valid_idx))
    return fold_dsets.dataloaders(bs=bs,
                                  batch_tfms=[Dihedral(p=0.5),
                                              Rotate(p=0.5, max_deg=30),
                                              Brightness(p=0.5, max_lighting=0.3, batch=False)])


def get_learner(dls):
    "Same U-Net configuration as the single-run learner cell, built on `dls`."
    learn = unet_learner(dls, xresnet34, loss_func=loss_func, opt_func=opt_func,
                         metrics=[Dice(thresh=0.5)], normalize=False, pretrained=False, n_out=1)
    learn.to_native_fp16()
    return learn


N_FOLDS = 8
assert len(unique_ids) >= N_FOLDS, f"only {len(unique_ids)} training images for {N_FOLDS} folds"

for FOLD in range(N_FOLDS):
    dls = get_dls(FOLD)
    learner = get_learner(dls)
    # Keep the epoch with the best validation Dice for this fold.
    learner.fit_flat_cos(30, lr=1e-3, cbs=[SaveModelCallback("dice", fname=f'xresunet34_fold{FOLD}')])
    # Release the fold's model and loaders (including cached GPU memory) before the next fold.
    del learner, dls
    gc.collect()
    torch.cuda.empty_cache()