## Overview

In [None]:
!cd models; ls

In [None]:
from fastai.conv_learner import *
from fastai.dataset import *

import pandas as pd
import numpy as np
import os
from PIL import Image
from sklearn.model_selection import train_test_split

### Data + Pretrained code

In [None]:
# Base directory handed to ImageData and joined with SEGMENTATION when the mask table is read.
PATH = './'
# Kaggle input locations: training images, test images and the run-length-encoded mask table.
TRAIN = '../input/airbus-ship-detection/train_v2/'
TEST = '../input/airbus-ship-detection/test_v2/'
SEGMENTATION = '../input/airbus-ship-detection/train_ship_segmentations_v2.csv'
# PRETRAINED = '../input/fine-tuning-resnet34-on-ship-detection/models/Resnet34_lable_256_1.h5'
# File names that are removed from every split before training/evaluation.
exclude_list = ['6384c3e78.jpg','13703f040.jpg', '14715c06d.jpg',  '33e0ff2d5.jpg',
                '4d4e09f2a.jpg', '877691df8.jpg', '8b909bb20.jpg', 'a8d99130e.jpg', 
                'ad55c3143.jpg', 'c8260c541.jpg', 'd6c7f17c7.jpg', 'dc3e7c901.jpg',
                'e44dffe88.jpg', 'ef87bad36.jpg', 'f083256d8.jpg'] #corrupted images

In [None]:
# Load the pre-computed train / validation / test splits. Each CSV keeps its values in a
# column literally named '0'; they are used later as image ids (index of the mask table).
tr_arr = np.array(pd.read_csv('../input/traintestval-airbus/train_df.csv')['0'].reset_index(drop=True))
val_arr = np.array(pd.read_csv('../input/traintestval-airbus/val_df.csv')['0'].reset_index(drop=True))
test_arr = np.array(pd.read_csv('../input/traintestval-airbus/test_df.csv')['0'].reset_index(drop=True))

In [None]:
# Settings read as globals by the get_data* helpers (nw) and by get_base / tfms_from_model (arch).
nw = 2   #number of workers for data loader
arch = resnet34 #specify target architecture

In [None]:
# Drop the known-corrupted images from every split. The splits are NumPy arrays
# (built with np.array above) and ndarray has no .remove(), so filter with a
# boolean membership mask instead of removing element by element.
tr_arr = tr_arr[~np.isin(tr_arr, exclude_list)]
val_arr = val_arr[~np.isin(val_arr, exclude_list)]
test_arr = test_arr[~np.isin(test_arr, exclude_list)]
# Mask table indexed by image id; an image may own several rows (one per encoded mask).
segmentation_df = pd.read_csv(os.path.join(PATH, SEGMENTATION)).set_index('ImageId')

In [None]:
# Working names for the three splits; they are filtered by cut_empty in a later cell.
tr_n = tr_arr
val_n = val_arr
test_n = test_arr

In [None]:
def cut_empty(names):
    """Return the entries of ``names`` whose ``EncodedPixels`` lookup is not a float.

    The lookup goes through the module-level ``segmentation_df``. A plain
    ``float`` entry is what ``get_mask`` treats as "no mask" (presumably a NaN
    cell for an image without ships), so those names are dropped.

    Args:
        names: iterable of image ids present in ``segmentation_df``'s index.

    Returns:
        list of the kept names, in their original order.
    """
    kept = []
    for name in names:
        encoded = segmentation_df.loc[name]['EncodedPixels']
        if type(encoded) != float:
            kept.append(name)
    return kept

# Keep only images that have at least one encoded mask; cut_empty returns plain lists.
tr_n = cut_empty(tr_n)
val_n = cut_empty(val_n)
test_n = cut_empty(test_n)

In [None]:
total_n = tr_n + val_n

In [None]:
def get_mask(img_id, df):
    """Decode all run-length-encoded masks of one image into a single binary mask.

    Args:
        img_id: index label looked up in ``df``.
        df: table indexed by image id with an ``EncodedPixels`` column. The
            lookup may yield a float (no mask), one string (one mask) or a
            Series of strings (several masks).

    Returns:
        ``uint8`` array of shape (768, 768) with 1 on every encoded pixel.
    """
    shape = (768, 768)
    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    masks = df.loc[img_id]['EncodedPixels']
    # isinstance (not an exact type comparison) so numpy.float64, which pandas
    # returns for a float-typed column, is also recognised as "no mask".
    if isinstance(masks, float):
        return flat.reshape(shape)
    if isinstance(masks, str):
        masks = [masks]
    for mask in masks:
        tokens = mask.split()
        # tokens alternate "start length"; starts are 1-based.
        for start, length in zip(tokens[0::2], tokens[1::2]):
            begin = int(start) - 1
            flat[begin:begin + int(length)] = 1
    # Transpose after the row-major reshape, so consecutive flat indices run down a column.
    return flat.reshape(shape).T

def Show_images(x,yp,yt):
    """Plot image / prediction / target triplets, one triplet per row.

    Args:
        x: indexable batch of images accepted by ``plt.imshow``.
        yp: predicted masks, indexed like ``x``.
        yt: target masks, indexed like ``x``.

    At most 8 rows are drawn. The row count is also capped by the global batch
    size ``bs`` and by ``len(x)``, so a short (final) batch no longer raises
    IndexError.
    """
    columns = 3
    rows = min(len(x), bs, 8)
    fig = plt.figure(figsize=(columns*4, rows*4))
    for i in range(rows):
        for col, img in enumerate((x[i], yp[i], yt[i]), start=1):
            fig.add_subplot(rows, columns, columns*i + col)
            plt.axis('off')
            plt.imshow(img)
    plt.show()

In [None]:
class pdFilesDataset(FilesDataset):
    """FilesDataset whose targets are masks decoded from the segmentation CSV."""

    def __init__(self, fnames, path, transform):
        # Each dataset instance reads its own copy of the mask table.
        self.segmentation_df = pd.read_csv(SEGMENTATION).set_index('ImageId')
        super().__init__(fnames, transform, path)

    def get_x(self, i):
        """Open image ``i``; resize it to ``sz`` x ``sz`` unless ``sz`` is 768."""
        image = open_image(os.path.join(self.path, self.fnames[i]))
        if self.sz != 768:
            image = cv2.resize(image, (self.sz, self.sz))
        return image

    def get_y(self, i):
        """Return the mask of image ``i`` as a float32 3-channel array of side ``sz``."""
        if self.path == TEST:
            # no mask is decoded for the TEST folder: use an all-zero mask
            mask = np.zeros((768,768), dtype=np.uint8)
        else:
            mask = get_mask(self.fnames[i], self.segmentation_df)
        resized = Image.fromarray(mask).resize((self.sz, self.sz))
        return np.array(resized.convert('RGB')).astype(np.float32)

    def get_c(self):
        return 0
     

def get_data(sz,bs):
    """Build the ImageData object with random-rotation augmentation.

    Training names come from ``tr_n`` (trimmed to a whole number of batches),
    validation from ``val_n`` and the "test" slot from ``test_n``; all three
    are read from the TRAIN folder.
    """
    rotate = RandomRotate(45,p=0.75, tfm_y=TfmType.CLASS)
    tfms = tfms_from_model(arch, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS,aug_tfms=[rotate])
    leftover = len(tr_n) % bs
    tr_names = tr_n[:-leftover] if leftover != 0 else tr_n  # drop the incomplete last batch
    datasets = ImageData.get_ds(pdFilesDataset, (tr_names,TRAIN), (val_n,TRAIN), tfms, test=(test_n,TRAIN))
    return ImageData(PATH, datasets, bs, num_workers=nw, classes=None)

def get_data_no_aug(sz,bs):
    """Build the ImageData object without any extra augmentation transforms.

    Same splits as ``get_data``: ``tr_n`` (trimmed to whole batches) for
    training, ``val_n`` for validation, ``test_n`` in the test slot.
    """
    tfms = tfms_from_model(arch, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS)
    leftover = len(tr_n) % bs
    tr_names = tr_n[:-leftover] if leftover != 0 else tr_n  # drop the incomplete last batch
    datasets = ImageData.get_ds(pdFilesDataset, (tr_names,TRAIN), (val_n,TRAIN), tfms, test=(test_n,TRAIN))
    return ImageData(PATH, datasets, bs, num_workers=nw, classes=None)

# Use test set as val
def get_data_test_no_aug(sz,bs):
    """Build a non-augmented ImageData whose validation set is the held-out ``test_n``.

    Training names still come from ``tr_n`` (trimmed to whole batches);
    ``test_n`` fills both the validation and the test slot.
    """
    tfms = tfms_from_model(arch, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS)
    leftover = len(tr_n) % bs
    tr_names = tr_n[:-leftover] if leftover != 0 else tr_n  # drop the incomplete last batch
    datasets = ImageData.get_ds(pdFilesDataset, (tr_names,TRAIN), (test_n,TRAIN), tfms, test=(test_n,TRAIN))
    return ImageData(PATH, datasets, bs, num_workers=nw, classes=None)

def get_data_test_aug(sz,bs):
    """Build an ImageData that serves the held-out ``test_n`` through the training slot.

    Putting ``test_n`` in the training position means the object's training
    loader yields the held-out images with the random-rotation augmentation
    applied. ``val_n`` is the validation set and ``test_n`` also fills the
    test slot.

    Unlike the other ``get_data*`` helpers, the name list is not trimmed to a
    whole number of batches; the previously computed but unused ``tr_names``
    local has been removed.
    """
    #data augmentation
    aug_tfms = [RandomRotate(45,p=0.75, tfm_y=TfmType.CLASS)]
    tfms = tfms_from_model(arch, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS,aug_tfms=aug_tfms)
    ds = ImageData.get_ds(pdFilesDataset, (test_n,TRAIN), (val_n,TRAIN), tfms, test=(test_n,TRAIN))
    md = ImageData(PATH, ds, bs, num_workers=nw, classes=None)
    return md

### Model

In [None]:
# cut: passed to cut_model in get_base; lr_cut: split index for the encoder layer groups.
cut,lr_cut = model_meta[arch]

In [None]:
def get_base(pre=True):              #load ResNet34 model
    """Return the truncated backbone as an ``nn.Sequential``.

    Args:
        pre: forwarded to ``arch(...)``; the default ``True`` matches the
            former zero-argument ``get_base()`` (which called ``arch(True)``).

    The notebook previously defined this function twice; only this later,
    more general definition ever took effect, so the duplicate was removed.
    """
    layers = cut_model(arch(pre), cut)
    return nn.Sequential(*layers)

In [None]:
class UnetBlock(nn.Module):
    """Decoder step: upsample the deep feature map and fuse it with a skip connection.

    Args:
        up_in: channels of the tensor to upsample.
        x_in: channels of the skip-connection tensor.
        n_out: output channels; each branch contributes ``n_out // 2``.
    """

    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        half = n_out // 2
        # Creation order (x_conv, tr_conv, bn) is kept as in the original.
        self.x_conv = nn.Conv2d(x_in, half, 1)
        self.tr_conv = nn.ConvTranspose2d(up_in, half, 2, stride=2)
        self.bn = nn.BatchNorm2d(n_out)

    def forward(self, up_p, x_p):
        """Upsample ``up_p`` 2x, project ``x_p`` with a 1x1 conv, concatenate, ReLU, batch-norm."""
        upsampled = self.tr_conv(up_p)
        skip = self.x_conv(x_p)
        merged = torch.cat([upsampled, skip], dim=1)
        return self.bn(F.relu(merged))

class SaveFeatures():
    """Forward hook that keeps the most recent output of a module in ``features``."""

    features = None  # stays None until the hooked module has run once

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        self.features = output

    def remove(self):
        """Detach the hook; ``features`` keeps its last value."""
        self.hook.remove()
    
class Unet34(nn.Module):
    """U-Net decoder on top of the backbone ``rn``.

    Forward hooks on ``rn[2]``, ``rn[4]``, ``rn[5]`` and ``rn[6]`` capture the
    skip-connection activations that the four ``UnetBlock`` stages consume in
    reverse order. The output is a single-channel map with the channel axis
    removed.
    """

    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        hooked_layers = [2,4,5,6]
        self.sfs = [SaveFeatures(rn[idx]) for idx in hooked_layers]
        self.up1 = UnetBlock(512,256,256)
        self.up2 = UnetBlock(256,128,256)
        self.up3 = UnetBlock(256,64,256)
        self.up4 = UnetBlock(256,64,256)
        self.up5 = nn.ConvTranspose2d(256, 1, 2, stride=2)

    def forward(self,x):
        x = F.relu(self.rn(x))
        decoder = (self.up1, self.up2, self.up3, self.up4)
        # deepest hooked activation first: sfs[3], sfs[2], sfs[1], sfs[0]
        for block, saved in zip(decoder, reversed(self.sfs)):
            x = block(x, saved.features)
        return self.up5(x)[:,0]

    def close(self):
        """Remove all forward hooks from the backbone."""
        for saved in self.sfs:
            saved.remove()
            
class UnetModel():
    """Thin wrapper that exposes a ``Unet34`` and its layer groups to the learner."""

    def __init__(self,model,name='Unet'):
        self.model = model
        self.name = name

    def get_layer_groups(self, precompute):
        """Return the backbone split at ``lr_cut`` followed by one group with the remaining children."""
        encoder_groups = list(split_by_idxs(children(self.model.rn), [lr_cut]))
        decoder_group = children(self.model)[1:]
        return encoder_groups + [decoder_group]

### Loss function

In [None]:
def dice_loss(input, target):
    """Return the smoothed soft Dice coefficient between ``sigmoid(input)`` and ``target``.

    Despite the name, the value grows with overlap; ``MixedLoss`` uses the
    negative log of it. A smoothing constant of 1.0 is added to numerator and
    denominator.
    """
    smooth = 1.0
    probs = torch.sigmoid(input).view(-1)
    truth = target.view(-1)
    overlap = (probs * truth).sum()
    numerator = 2.0 * overlap + smooth
    denominator = probs.sum() + truth.sum() + smooth
    return numerator / denominator

In [None]:
class FocalLoss(nn.Module):
    """Focal loss on logits for binary targets.

    Args:
        gamma: exponent applied to the modulating factor.
    """

    def __init__(self, gamma):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        """Return the mean focal loss.

        Raises:
            ValueError: if ``input`` and ``target`` differ in size.
        """
        if target.size() != input.size():
            raise ValueError("Target size ({}) must be the same as input size ({})"
                             .format(target.size(), input.size()))

        # Binary cross-entropy with logits, written with max_val so the exponentials stay bounded.
        max_val = (-input).clamp(min=0)
        log_term = ((-max_val).exp() + (-input - max_val).exp()).log()
        bce = input - input * target + max_val + log_term

        # Modulating factor: exp(gamma * logsigmoid(-input * (2*target - 1))).
        invprobs = F.logsigmoid(-input * (target * 2.0 - 1.0))
        focal = (invprobs * self.gamma).exp() * bce

        return focal.mean()

In [None]:
class MixedLoss(nn.Module):
    """``alpha * focal_loss - log(dice)`` combined criterion.

    Args:
        alpha: weight of the focal term.
        gamma: forwarded to ``FocalLoss``.
    """

    def __init__(self, alpha, gamma):
        super().__init__()
        self.alpha = alpha
        self.focal = FocalLoss(gamma)

    def forward(self, input, target):
        focal_term = self.alpha*self.focal(input, target)
        dice_term = torch.log(dice_loss(input, target))
        combined = focal_term - dice_term
        return combined.mean()

In [None]:
def dice(pred, targs):
    """Dice metric on logits: ``pred`` is binarised at 0; 1.0 is added to the denominator."""
    hard = (pred>0).float()
    overlap = (hard*targs).sum()
    total = (hard+targs).sum()
    return 2.0 * overlap / (total + 1.0)

def IoU(pred, targs):
    """IoU metric on logits: ``pred`` is binarised at 0; 1.0 is added to the union."""
    hard = (pred>0).float()
    intersection = (hard*targs).sum()
    union = (hard+targs).sum() - intersection
    return intersection / (union + 1.0)

### Train Augmented Learner

In [None]:
# Full-resolution setup: build both the augmented and the non-augmented data objects.
sz = 768 #image size
bs = 8  #batch size

md = get_data(sz,bs)
no_aug_md = get_data_no_aug(sz,bs)

Build model

In [None]:
!cp ../input/traintestval-airbus/Unet34_768_aug_0.h5 ./models/Unet34_768_aug_0.h5

In [None]:
# Backbone -> U-Net moved with to_gpu -> wrapper that provides the layer groups.
m_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)

In [None]:
# Learner on the augmented data, restored from the checkpoint copied into ./models above.
learn = ConvLearner(md, models)
learn.load('Unet34_768_aug_0')
learn.opt_fn=optim.Adam
# Criterion: 10 * focal loss (gamma = 2) minus log(dice); see MixedLoss.
learn.crit = MixedLoss(10.0, 2.0)
learn.metrics=[accuracy_thresh(0.5),dice,IoU]
# Weight decay and base learning rate reused by the fit calls below.
wd=1e-7
lr = 1e-2

# Don't run this afterwards

In [None]:
# Stage 0: one cycle at the single rate lr after freeze_to(1)
# (presumably this keeps the first layer group frozen; confirm against fastai 0.7).
learn.freeze_to(1)
learn.fit(lr,1,wds=wd,cycle_len=1,use_clr=(5,8))

In [None]:
learn.save('Unet34_768_aug_0')

 The learning rate of the head is still `lr` = 1e-2, while the middle layers of the model are trained with a learning rate of 1e-3 (`lr/10`), and the base is trained with an even smaller learning rate, 1e-4 (`lr/100`), since low-level detectors do not vary much from one image data set to another.

# Run this

In [None]:
# Unfreeze and train with differential learning rate
# Three rates, lowest first: lr/100, lr/10, lr (one per layer group from get_layer_groups).
lrs = np.array([lr/100,lr/10,lr])
learn.unfreeze() #unfreeze the encoder
learn.bn_freeze(True)

# Two cycles of length 1.
learn.fit(lrs,2,wds=wd,cycle_len=1,use_clr=(20,8))

In [None]:
learn.sched.plot_lr()

In [None]:
learn.save('Unet34_768_aug_full')

In [None]:
!cd models; ls

## Training non-augmented learner

In [None]:
sz = 768 #image size
bs = 8  #batch size

no_aug_md = get_data_no_aug(sz,bs)

# Build a fresh U-Net for the non-augmented run and restore its stage-0 weights.
m_base = get_base()
m = to_gpu(Unet34(m_base))
state_na = torch.load('../input/traintestval-airbus/Unet34_768_no_aug_0.h5')
# A checkpoint written by learn.save holds whole-model keys ('rn.*', 'up1.*', ...).
# Loading that into the bare encoder with strict=False would silently load nothing,
# so choose the target by key prefix and let a real mismatch raise.
load_target = m if any(key.startswith('rn.') for key in state_na) else m_base
load_target.load_state_dict(state_na)
# Wrap the NEW model; reusing the earlier `models` would train the augmented network again.
models = UnetModel(m)

learn_na = ConvLearner(no_aug_md, models)
learn_na.opt_fn=optim.Adam
learn_na.crit = MixedLoss(10.0, 2.0)
learn_na.metrics=[accuracy_thresh(0.5),dice,IoU]
wd=1e-7
lr = 1e-2

In [None]:
# Deprecated
# Stage 0 for the non-augmented learner: one cycle at lr after freeze_to(1), then save.
learn_na.freeze_to(1)
learn_na.fit(lr,1,wds=wd,cycle_len=1,use_clr=(5,8))
learn_na.save('Unet34_768_no_aug_0')

In [None]:
# Unfreeze and train with differential learning rate
# Three rates, lowest first: lr/100, lr/10, lr.
lrs = np.array([lr/100,lr/10,lr])
learn_na.unfreeze() #unfreeze the encoder
learn_na.bn_freeze(True)

learn_na.fit(lrs,2,wds=wd,cycle_len=1,use_clr=(20,8))

# This checkpoint name is the one loaded again in the evaluation cells.
learn_na.save('Unet34_768_no_aug_1')

### Visualization

In [None]:
def Show_images(x,yp,yt):
    """Plot image / prediction / target triplets, one triplet per row.

    Args:
        x: indexable batch of images accepted by ``plt.imshow``.
        yp: predicted masks, indexed like ``x``.
        yt: target masks, indexed like ``x``.

    At most 8 rows are drawn. The row count is also capped by the global batch
    size ``bs`` and by ``len(x)``, so a short (final) batch no longer raises
    IndexError.
    """
    columns = 3
    rows = min(len(x), bs, 8)
    fig = plt.figure(figsize=(columns*4, rows*4))
    for i in range(rows):
        for col, img in enumerate((x[i], yp[i], yt[i]), start=1):
            fig.add_subplot(rows, columns, columns*i + col)
            plt.axis('off')
            plt.imshow(img)
    plt.show()
    
    
# def Show_images(x,y,x1,y1):
#     columns = 4
#     rows = min(bs,8)
#     fig=plt.figure(figsize=(columns*4, rows*4))
#     for i in range(rows):
#         fig.add_subplot(rows, columns, 4*i+1)
#         plt.axis('off')
#         plt.imshow(x[i])
#         fig.add_subplot(rows, columns, 4*i+2)
#         plt.axis('off')
#         plt.imshow(y[i])
#         fig.add_subplot(rows, columns, 4*i+3)
#         plt.axis('off')
#         plt.imshow(x1[i])
#         fig.add_subplot(rows, columns, 4*i+4)
#         plt.axis('off')
#         plt.imshow(y1[i])
#     plt.show()

In [None]:
# Rebuild both data objects to compare an augmented sample with its plain counterpart.
sz = 768 #image size
bs = 8  #batch size

md = get_data(sz,bs)
no_aug_md = get_data_no_aug(sz,bs)

In [None]:
# md and no_aug_md are two separate ImageData objects returned by get_data / get_data_no_aug,
# not members of one sequence: read sample 2 from each training set directly.
x,y = md.trn_ds[2]
ox,oy = no_aug_md.trn_ds[2]

In [None]:
# md is a single ImageData object (returned by get_data), so use its training set directly.
plt.imshow(md.trn_ds.denorm(x)[0])

In [None]:
plt.imshow(y)

In [None]:
# The non-augmented sample comes from no_aug_md (returned by get_data_no_aug), not from md[1].
plt.imshow(no_aug_md.trn_ds.denorm(ox)[0])

In [None]:
plt.imshow(oy)

## Test Model

### Evaluate scoring

In [None]:
from scipy import ndimage

In [None]:
def IoU(pred, targs):
    """NumPy IoU of one predicted mask (binarised at 0.5) against one target mask.

    1.0 is added to the union, so the result is always defined.
    """
    hard = (pred > 0.5).astype(float)
    intersection = (hard*targs).sum()
    union = (hard+targs).sum() - intersection
    return intersection / (union + 1.0)

def get_iou(pred, true):
    """Mean IoU over every (ground-truth mask, predicted mask) pair.

    Args:
        pred: list of predicted masks.
        true: list of ground-truth masks.

    Returns:
        Mean of ``IoU(p, mask)`` over all pairs. With no pairs to average the
        result is 1.0 if both lists are empty (nothing to find, nothing
        predicted) and 0.0 if only one is empty. The mean of an empty array
        would be NaN and would poison the running sum kept by ``IoU_eval``.
    """
    if len(true) == 0 or len(pred) == 0:
        return 1.0 if len(true) == len(pred) else 0.0
    pair_ious = [IoU(p, mask) for mask in true for p in pred]
    return np.mean(np.array(pair_ious))

def get_score(pred, true):
    """F-beta style score (beta weight ``b = 4``) averaged over IoU thresholds 0.5 ... 0.95.

    Args:
        pred: list of predicted masks.
        true: list of ground-truth masks.

    Returns:
        Mean over the ten thresholds of
        ``(b+1)*tp / ((b+1)*tp + b*fn + fp)``. A ground-truth mask with no
        prediction above the threshold is a false negative; a prediction is a
        true positive if it exceeds the threshold against any ground-truth
        mask, otherwise a false positive. When both lists are empty the
        formula is 0/0; that case is scored 1.0 (nothing to find, nothing
        predicted) instead of raising ZeroDivisionError.
    """
    n_th = 10
    b = 4
    thresholds = [0.5 + 0.05*i for i in range(n_th)]
    n_masks = len(true)
    n_pred = len(pred)
    if n_masks == 0 and n_pred == 0:
        return 1.0

    # ious[i][j]: IoU of ground-truth mask i against prediction j
    ious = [[IoU(p, mask) for p in pred] for mask in true]

    score = 0
    for t in thresholds:
        found = sum(1 for row in ious if any(iou > t for iou in row))
        fn = n_masks - found
        tp = sum(1 for j in range(n_pred) if any(row[j] > t for row in ious))
        fp = n_pred - tp
        score += ((b+1)*tp)/((b+1)*tp + b*fn + fp)

    return score/n_th

def split_mask(mask):
    """Split a probability mask into one binary mask per connected component.

    The mask is thresholded at 0.5 and labelled with ``ndimage.label``;
    components of 30 pixels or fewer are discarded.

    Returns:
        list of integer arrays shaped like ``mask``, one per kept component.
    """
    threshold = 0.5
    threshold_obj = 30 #ignore predictions composed of "threshold_obj" pixels or less
    labled, n_objs = ndimage.label(mask > threshold)
    components = ((labled == label).astype(int) for label in range(1, n_objs + 1))
    return [obj for obj in components if obj.sum() > threshold_obj]

def get_mask_ind(img_id, df, shape = (768,768)): #return mask for each ship
    """Decode the run-length-encoded masks of one image into one binary mask per entry.

    Args:
        img_id: index label looked up in ``df``.
        df: table indexed by image id with an ``EncodedPixels`` column. The
            lookup may yield a float (no mask), one string or a Series of strings.
        shape: (rows, cols) of the decoded masks.

    Returns:
        list of ``uint8`` arrays of shape ``shape``; empty when the entry is a float.
    """
    masks = df.loc[img_id]['EncodedPixels']
    # isinstance (not an exact type comparison) so numpy.float64, which pandas
    # returns for a float-typed column, is also recognised as "no mask".
    if isinstance(masks, float):
        return []
    if isinstance(masks, str):
        masks = [masks]
    result = []
    for mask in masks:
        flat = np.zeros(shape[0]*shape[1], dtype=np.uint8)
        tokens = mask.split()
        # tokens alternate "start length"; starts are 1-based.
        for start, length in zip(tokens[0::2], tokens[1::2]):
            begin = int(start) - 1
            flat[begin:begin + int(length)] = 1
        # Transpose after the row-major reshape, so consecutive flat indices run down a column.
        result.append(flat.reshape(shape).T)
    return result

class IoU_eval():
    """Running average of ``get_iou`` over the images passed to ``put``."""

    def __init__(self):
        self.segmentation_df = pd.read_csv(SEGMENTATION).set_index('ImageId')
        self.iou, self.count = 0.0, 0

    def put(self,pred,name):
        """Add the IoU of ``pred`` (list of masks) against the ground truth of image ``name``."""
        truth_masks = get_mask_ind(name, self.segmentation_df)
        self.iou += get_iou(pred, truth_masks)
        self.count += 1

    def evaluate(self):
        """Return the accumulated IoU divided by the number of ``put`` calls."""
        return self.iou / self.count

class Score_eval():
    """Running average of ``get_score`` over the images passed to ``put``."""

    def __init__(self):
        self.segmentation_df = pd.read_csv(SEGMENTATION).set_index('ImageId')
        self.score = 0.0
        self.count = 0

    def put(self,pred,name):
        """Add the score of ``pred`` (list of masks) against the ground truth of image ``name``."""
        truth_masks = get_mask_ind(name, self.segmentation_df)
        self.score += get_score(pred, truth_masks)
        self.count += 1

    def evaluate(self):
        """Return the accumulated score divided by the number of ``put`` calls."""
        return self.score / self.count

### TTA

In [None]:
def aug_unit(x,fwd=True,mask=False):
    """Identity transform: return ``x`` unchanged (``fwd`` and ``mask`` are ignored)."""
    return x

def aug_flipV(x,fwd=True,mask=False):
    """Flip along the last axis: dim 3 of a 4-D image batch, dim 2 when ``mask`` is set.

    A flip is its own inverse, so ``fwd`` is ignored.
    """
    axis = 2 if mask else 3
    return x.flip(axis)

def aug_flipH(x,fwd=True,mask=False):
    """Flip along the second-to-last axis: dim 2 of a 4-D image batch, dim 1 when ``mask`` is set.

    A flip is its own inverse, so ``fwd`` is ignored.
    """
    axis = 1 if mask else 2
    return x.flip(axis)

def aug_T(x,fwd=True,mask=False):
    """Swap the two trailing axes: dims (2, 3) of an image batch, dims (1, 2) when ``mask`` is set.

    A transpose is its own inverse, so ``fwd`` is ignored.
    """
    first, second = (1, 2) if mask else (2, 3)
    return torch.transpose(x, first, second)

def aug_rot_2(x,fwd=True,mask=False):
    """Half turn (rotation by pi): flip both trailing axes.

    Implemented as ``aug_flipH`` followed by ``aug_flipV``. The two flips act
    on different axes, so the result is its own inverse and ``fwd`` makes no
    difference.
    """
    flipped = aug_flipH(x,fwd,mask)
    return aug_flipV(flipped,fwd,mask)

def aug_rot_4cr(x,fwd=True,mask=False):
    """Quarter turn (rotation by pi/2) built from a transpose and ``aug_flipV``.

    Forward: transpose, then ``aug_flipV``. Inverse (``fwd=False``):
    ``aug_flipV``, then transpose. The visual direction depends on the axis
    convention of the input; by its name this is presumably the
    counter-clockwise variant.
    """
    if fwd:
        return aug_flipV(aug_T(x,fwd,mask),fwd,mask)
    return aug_T(aug_flipV(x,fwd,mask),fwd,mask)

def aug_rot_4cw(x,fwd=True,mask=False):
    """Quarter turn (rotation by pi/2) built from a transpose and ``aug_flipH``.

    Forward: transpose, then ``aug_flipH``. Inverse (``fwd=False``):
    ``aug_flipH``, then transpose. The visual direction depends on the axis
    convention of the input; by its name this is presumably the clockwise
    variant.
    """
    if fwd:
        return aug_flipH(aug_T(x,fwd,mask),fwd,mask)
    return aug_T(aug_flipH(x,fwd,mask),fwd,mask)

def aug_rot_2T(x,fwd=True,mask=False):
    """Transpose followed by a half turn (``aug_rot_2``); the same order is used for both ``fwd`` values."""
    transposed = aug_T(x,fwd,mask)
    return aug_rot_2(transposed,fwd,mask)

# Transform sets for test-time augmentation: identity plus one flip, or all eight
# combinations of flips/transposes defined above.
trms_side_on = [aug_unit,aug_flipH]
trms_top_down = [aug_unit,aug_flipV]
trms_dihedral = [aug_unit,aug_flipH,aug_flipV,aug_T,aug_rot_2,aug_rot_2T,
                 aug_rot_4cw,aug_rot_4cr]
def enc_img(img):
    """Convert a 3-D array to a tensor, swap axes 0 and 2 and prepend a batch axis of size 1."""
    tensor = torch.tensor(img)
    swapped = torch.transpose(tensor, 0, 2)
    return swapped.unsqueeze(0)

def dec_img(img):
    """Inverse of ``enc_img``: drop the batch axis, swap axes 0 and 2 and convert with ``to_np``."""
    unbatched = img.squeeze(0)
    swapped = torch.transpose(unbatched, 0, 2)
    return to_np(swapped)

def display_augs(x,augs=(aug_unit,)):
    """Show ``x`` under every transform in ``augs`` on a grid with four columns.

    Args:
        x: image array accepted by ``enc_img``.
        augs: sequence of transform functions. The default used to be the
            bare function ``aug_unit``, on which ``len()`` fails; it is now a
            one-element tuple.
    """
    columns = 4
    n = len(augs)
    # ceil(n / columns), at least one row; avoids an empty trailing row when n is a multiple of 4
    rows = max(1, -(-n // columns))
    fig = plt.figure(figsize=(columns*4, rows*4))
    img = enc_img(x)
    for idx, aug in enumerate(augs):
        fig.add_subplot(rows, columns, idx+1)
        plt.axis('off')
        plt.imshow(dec_img(aug(img)))
    plt.show()

In [None]:
def model_pred(learner, dl, F_save): #if use train dl, disable shuffling
    """Run ``learner``'s model over ``dl`` and pass every sample to ``F_save``.

    Args:
        learner: object exposing ``.model``. The model is put in eval mode and
            used for prediction (the original body predicted with the global
            ``learn`` instead of this argument).
        dl: data loader; ``dl.dataset.fnames`` must list names in iteration
            order, hence no shuffling.
        F_save: callback ``F_save(prediction, target, name)`` called once per
            sample with the sigmoid output, the target as a NumPy array and
            the file name.
    """
    learner.model.eval()
    name_list = dl.dataset.fnames
    num_batchs = len(dl)
    t = tqdm(iter(dl), leave=False, total=num_batchs)
    count = 0
    for x,y in t:
        py = to_np(torch.sigmoid(learner.model(V(x))))
        batch_size = len(py)
        for i in range(batch_size):
            F_save(py[i],to_np(y[i]),name_list[count])
            count += 1

### Test augmented model

In [None]:
sz = 768 #image size
bs = 8  #batch size
# Held-out images as the validation set, no augmentation.
test_md = get_data_test_no_aug(sz,bs)

# Held-out images in the training slot, with augmentation.
test_aug_md = get_data_test_aug(sz,bs)

In [None]:
!cp ../input/traintestval-airbus/Unet34_768_no_aug_1.h5 ./models/Unet34_768_no_aug_1.h5

In [None]:
# Test Aug on no aug

# Backbone built with pre=False; every weight comes from the checkpoint loaded below.
m = to_gpu(Unet34(get_base(False)))
models = UnetModel(m)

learn = ConvLearner(test_md, models)

# Checkpoint copied into ./models by the previous cell.
learn.load('Unet34_768_no_aug_1')

In [None]:
# Accumulators plus the callbacks handed to model_pred: each splits the predicted
# probability map into per-object masks and scores them against the named image.
score = Score_eval()
process_pred = lambda yp, y, name : score.put(split_mask(yp),name)

iou_score = IoU_eval()
process_iou_pred = lambda yp, y, name : iou_score.put(split_mask(yp),name)

In [None]:
# Score every image of the validation loader (the held-out set) and print the mean.
model_pred(learn, test_md.val_dl, process_pred)
print('\n',score.evaluate())

In [None]:
def IoU(pred, targs):
    """Torch IoU metric on logits (binarised at 0), restoring the tensor version after the NumPy one."""
    hard = (pred>0).float()
    intersection = (hard*targs).sum()
    union = (hard+targs).sum() - intersection
    return intersection / (union + 1.0)

### Mean IoU pixel prediction starts here

In [None]:
sz = 768 #image size
bs = 8  #batch size
# Held-out images as the validation set, no augmentation.
test_md = get_data_test_no_aug(sz,bs)

# Held-out images in the training slot, with augmentation.
test_aug_md = get_data_test_aug(sz,bs)

# Test Aug on no aug
m = to_gpu(Unet34(get_base(False)))
models = UnetModel(m)

learn = ConvLearner(test_md, models)

# NOTE(review): no visible cell of this notebook saves a checkpoint named 'Unet34_768_aug_2'
# (the visible saves are aug_0, aug_full, no_aug_0, no_aug_1); confirm the file exists in ./models.
learn.load('Unet34_768_aug_2')

In [None]:
def get_metrics(y_pred, gt):
    """Pixel-level IoU between a binary prediction and the ground truth.

    Args:
        y_pred: array of 0/1 values (cast to int and squeezed).
        gt: array-like of 0/1 values with the same squeezed shape.

    Returns:
        ``tp / (tp + fn + fp)``, or 1.0 when that union is empty (nothing to
        find and nothing predicted). Previously any empty ground truth scored
        1 even when the prediction contained false positives; that case now
        scores 0.
    """
    y_pred = y_pred.astype(int).squeeze()
    y_true = np.array(gt).squeeze()

    tp = np.sum((y_true == 1) & (y_pred == 1))
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))

    # Sanity check: every pixel falls in exactly one cell, i.e. both inputs are strictly 0/1.
    assert tp + tn + fp + fn == y_true.size

    union = tp + fn + fp
    if union == 0:
        return 1.0
    return tp / union

In [None]:
# RES DICT
# Per-batch pixel IoU values (a list, despite the name).
res_dict = []

# Switch to inference mode once instead of on every iteration.
learn.model.eval()

# The loop index was never used, so iterate the loader directly.
for img, mask in tqdm(test_aug_md.trn_dl):
    # torch.sigmoid for consistency with model_pred (F.sigmoid is deprecated).
    yp = to_np(torch.sigmoid(learn.model(V(img))))
    yp[yp >= 0.5] = 1
    yp[yp < 0.5] = 0

    metrics = get_metrics(yp, to_np(mask))

    res_dict.append(metrics)

In [None]:
np.mean(np.array(res_dict))

In [None]:
# Fresh model and learner on the augmented data object; weights are loaded in a later cell.
m_base = get_base()
m = to_gpu(Unet34(m_base))
models = UnetModel(m)
learn = ConvLearner(md, models)

In [None]:
!cd models; ls

In [None]:
learn.load('Unet34_768_no_aug_1')

In [None]:
# Predict one validation batch, binarise at 0.5, print its pixel IoU and plot the triplets.
learn.model.eval();
x,y = next(iter(test_md.val_dl))
# NOTE(review): F.sigmoid is deprecated in recent PyTorch; model_pred uses torch.sigmoid.
yp = to_np(F.sigmoid(learn.model(V(x))))
yp[yp >= 0.5] = 1
yp[yp < 0.5] = 0

print(get_metrics(yp, to_np(y)))

Show_images(np.asarray(test_md.val_ds.denorm(x)), yp, y)


# RotEqNet

In [None]:
def ntuple(n):
    """Return a parser that expands a scalar to an ``n``-tuple.

    The returned ``parse(x)`` gives ``x`` back unchanged when it is already
    iterable, otherwise a tuple holding ``x`` repeated ``n`` times.
    """
    # Imported locally so the function does not rely on names pulled in by a star import.
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in collections.abc.
    from collections.abc import Iterable
    from itertools import repeat

    def parse(x):
        if isinstance(x, Iterable):
            return x
        return tuple(repeat(x, n))
    return parse

def getGrid(siz):
    """Return flattened grid coordinates as an array of shape (len(siz), prod(siz)).

    Along each dimension of size N the coordinates are N evenly spaced values
    from -N/2 to N/2 ('ij' indexing).
    """
    axes = []
    for N in siz:
        axes.append(np.linspace( -(N/2), (N/2), N ))
    mesh = np.meshgrid( *axes, indexing='ij' )
    rows = [np.expand_dims( coords.ravel(), 0) for coords in mesh]
    return np.concatenate(rows)

def rotate_grid_2D(grid, theta):
    """Rotate 2-D grid coordinates by ``theta`` degrees, in place.

    Rows 0 and 1 of ``grid`` are overwritten with the rotated coordinates and
    the same array object is returned.
    """
    angle = np.deg2rad(theta)
    cos_t, sin_t = np.cos(angle), np.sin(angle)

    # Compute both rotated rows before writing either one back.
    rotated_0 = grid[0, :] * cos_t - grid[1, :] * sin_t
    rotated_1 = grid[0, :] * sin_t + grid[1, :] * cos_t

    grid[0, :] = rotated_0
    grid[1, :] = rotated_1
    return grid

def rotate_grid_3D(theta, axis, grid):
    """ Rotate grid

    Rotates 3-D grid coordinates by ``theta`` degrees around ``axis``.

    NOTE(review): ``expm`` and ``norm`` are not imported in the visible part of
    this notebook (presumably scipy.linalg.expm / numpy.linalg.norm) - confirm
    they are in scope before using the 3-D path.
    """
    theta = np.deg2rad(theta)
    axis = np.array(axis)
    # Rotation matrix as the matrix exponential of the cross-product matrix of the scaled unit axis.
    rot_mat = expm(np.cross(np.eye(3), axis / norm(axis) * theta))
    # Add singleton axes so the matrix and the grid line up for the einsum below.
    rot_mat  =np.expand_dims(rot_mat,2)
    grid = np.transpose( np.expand_dims(grid,2), [0,2,1])

    return np.einsum('ijk,jik->ik',rot_mat,grid)


def get_filter_rotation_transforms(kernel_dims, angles):
    """ Return the interpolation variables needed to transform a filter by a given number of degrees

    Args:
        kernel_dims: list with the kernel size per dimension (2 or 3 entries).
        angles: rotation in degrees; a single value for 2-D kernels, an
            indexable pair (``angles[0]``, ``angles[1]``) for 3-D kernels.

    Returns:
        The ``(inds_0, inds_1, weights, mask)`` tuple produced by
        ``compute_interpolation_grids``.
    """

    dim = len(kernel_dims)

    # Make grid (centered around filter-center)
    grid = getGrid(kernel_dims)

    # Rotate grid
    if dim == 2:
        grid = rotate_grid_2D(grid, angles)
    elif dim == 3:
        # two successive rotations: around the first axis, then around the third
        grid = rotate_grid_3D(angles[0], [1, 0, 0], grid)
        grid = rotate_grid_3D(angles[1], [0, 0, 1], grid)


    # Radius of filter
    radius = np.min((np.array(kernel_dims)-1) / 2.)

    #Mask out samples outside circle
    radius = np.expand_dims(radius,-1)
    dist_to_center = np.sqrt(np.sum(grid**2,axis=0))
    mask = dist_to_center>=radius+.0001
    # invert: 1 marks samples inside the circle, 0 samples outside
    mask = 1-mask

    # Move grid to center
    grid += radius

    return compute_interpolation_grids(grid, kernel_dims, mask)

def compute_interpolation_grids(grid, kernel_dims, mask):
    """Turn sample coordinates into the index/weight variables of n-linear interpolation.

    Args:
        grid: float array of shape (dim, n_samples) with one row of
            coordinates per dimension. It is modified in place (a small
            epsilon is added).
        kernel_dims: list with the kernel size per dimension; it must be a
            list because it is concatenated with ``[1, 1]`` when reshaping.
        mask: array with n_samples entries, non-zero for samples to keep.

    Returns:
        ``(inds_0, inds_1, weights, mask)``: per-dimension lists of lower
        indices, upper indices (LongTensor Variables) and interpolation
        weights (FloatTensor Variables), each shaped ``[1, 1] + kernel_dims``,
        plus the mask as a FloatTensor of the same shape.
    """

    #######################################################
    # The following part is part of nd-linear interpolation

    #Add a small eps to grid so that floor and ceil operations become more stable
    grid += 0.000000001

    # Make list where each element represents a dimension
    grid = [grid[i, :] for i in range(grid.shape[0])]

    # Get left and right index (integers)
    # np.integer is an abstract type that recent NumPy no longer accepts in astype;
    # int64 is concrete and matches the torch.LongTensor built below.
    inds_0 = [ind.astype(np.int64) for ind in grid]
    inds_1 = [ind + 1 for ind in inds_0]

    # Get weights
    weights = [float_ind - int_ind for float_ind, int_ind in zip(grid, inds_0)]

    # Special case for when ind_1 == size (while ind_0 == siz)
    # In that case we select ind_0
    ind_1_out_of_bounds = np.logical_or.reduce([ind == siz for ind, siz in zip(inds_1, kernel_dims)])
    for i in range(len(inds_1)):
        inds_1[i][ind_1_out_of_bounds] = 0


    # Get samples that are out of bounds or outside mask
    # The mask term is wrapped in a list so it is one more operand of the OR-reduction
    # (the original added the array to the list, which only worked through broadcasting).
    inds_out_of_bounds = np.logical_or.reduce([ind < 0 for ind in itertools.chain(inds_0, inds_1)] +
                                              [ind >= siz for ind, siz in zip(inds_0, kernel_dims)] +
                                              [ind >= siz for ind, siz in zip(inds_1, kernel_dims)] +
                                              [(1-mask).astype('bool')]
                                              )


    # Set these samples to zero get data from upper-left-corner (which will be put to zero)
    for i in range(len(inds_0)):
        inds_0[i][inds_out_of_bounds] = 0
        inds_1[i][inds_out_of_bounds] = 0

    #Reshape
    inds_0 = [np.reshape(ind,[1,1]+kernel_dims) for ind in inds_0]
    inds_1 = [np.reshape(ind,[1,1]+kernel_dims) for ind in inds_1]
    weights = [np.reshape(weight,[1,1]+kernel_dims)for weight in weights]

    #Make pytorch-tensors of the interpolation variables
    inds_0 = [Variable(torch.LongTensor(ind)) for ind in inds_0]
    inds_1 = [Variable(torch.LongTensor(ind)) for ind in inds_1]
    weights = [Variable(torch.FloatTensor(weight)) for weight in weights]

    #Make mask pytorch tensor
    mask = mask.reshape(kernel_dims)
    mask = mask.astype('float32')
    mask = np.expand_dims(mask, 0)
    mask = np.expand_dims(mask, 0)
    mask = torch.FloatTensor(mask)

    # Uncomment for nearest interpolation (for debugging)
    #inds_1 = [ind*0 for ind in inds_1]
    #weights  = [weight*0 for weight in weights]

    return inds_0, inds_1, weights, mask

def apply_transform(filter, interp_vars, filters_size, old_bilinear_interpolation=True):
    """ Apply a transform specified by the interpolation_variables to a filter

    Args:
        filter: weight tensor; 4 dimensions select the 2-D path, anything else the 3-D path.
        interp_vars: ``(inds_0, inds_1, weights)`` with one entry per spatial dimension.
        filters_size: spatial size used to reshape the result.
        old_bilinear_interpolation: 2-D only; True uses direct tensor indexing,
            False uses ``torch.gather`` on expanded index tensors.

    Returns:
        The interpolated filter, viewed as (filter.size(0), filter.size(1), *filters_size).
    """

    dim = 2 if len(filter.size())==4 else 3

    if dim == 2:


        if old_bilinear_interpolation:
            [x0_0, x1_0], [x0_1, x1_1], [w0, w1] = interp_vars
            # bilinear blend of the four neighbouring filter taps
            rotated_filter = (filter[:, :, x0_0, x1_0] * (1 - w0) * (1 - w1) +
                          filter[:, :, x0_1, x1_0] * w0 * (1 - w1) +
                          filter[:, :, x0_0, x1_1] * (1 - w0) * w1 +
                          filter[:, :, x0_1, x1_1] * w0 * w1)
        else:

            # Expand dimensions to fit filter
            interp_vars = [[inner_el.expand_as(filter) for inner_el in outer_el] for outer_el in interp_vars]

            [x0_0, x1_0], [x0_1, x1_1], [w0, w1] = interp_vars

            a = torch.gather(torch.gather(filter, 2, x0_0), 3, x1_0) * (1 - w0) * (1 - w1)
            b = torch.gather(torch.gather(filter, 2, x0_1), 3, x1_0)* w0 * (1 - w1)
            c = torch.gather(torch.gather(filter, 2, x0_0), 3, x1_1)* (1 - w0) * w1
            d = torch.gather(torch.gather(filter, 2, x0_1), 3, x1_1)* w0 * w1
            rotated_filter = a+b+c+d

        rotated_filter = rotated_filter.view(filter.size()[0],filter.size()[1],filters_size[0],filters_size[1])

    elif dim == 3:
        [x0_0, x1_0, x2_0], [x0_1, x1_1, x2_1], [w0, w1, w2] = interp_vars

        # NOTE(review): unlike the 2-D branch, the index tensors are applied to the LEADING
        # axes of filter (no ``[:, :, ...]`` prefix) - confirm this is intended for 3-D weights.
        rotated_filter = (filter[x0_0, x1_0, x2_0] * (1 - w0) * (1 - w1)* (1 - w2) +
                          filter[x0_1, x1_0, x2_0] * w0       * (1 - w1)* (1 - w2) +
                          filter[x0_0, x1_1, x2_0] * (1 - w0) * w1      * (1 - w2) +
                          filter[x0_1, x1_1, x2_0] * w0       * w1      * (1 - w2) +
                          filter[x0_0, x1_0, x2_1] * (1 - w0) * (1 - w1)* w2 +
                          filter[x0_1, x1_0, x2_1] * w0       * (1 - w1)* w2 +
                          filter[x0_0, x1_1, x2_1] * (1 - w0) * w1      * w2 +
                          filter[x0_1, x1_1, x2_1] * w0       * w1      * w2)

        rotated_filter = rotated_filter.view(filter.size()[0], filter.size()[1], filters_size[0], filters_size[1], filters_size[2])

    return rotated_filter


In [None]:
class RotConv(nn.Module):
    """Convolution applied with rotated copies of its filters, followed by orientation pooling.

    For each of ``n_angles`` evenly spaced angles in [0, 360) the filters are
    rotated with ``apply_transform`` and convolved with the input; the maximum
    response over angles and its index are converted to a two-component
    field ``(u, v)``.

    mode=1 expects a plain tensor input; mode=2 expects a pair ``(u, v)`` and
    uses a second weight tensor.

    NOTE(review): ``Parameter`` and ``math`` are not imported in the visible
    part of this notebook - confirm they are in scope.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, n_angles = 8, mode=1):
        super(RotConv, self).__init__()

        # Scalars become 2-tuples; iterables are kept as given.
        kernel_size = ntuple(2)(kernel_size)
        stride = ntuple(2)(stride)
        padding = ntuple(2)(padding)
        dilation = ntuple(2)(dilation)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

        self.mode = mode

        #Angles
        self.angles = np.linspace(0,360,n_angles, endpoint=False)
        self.angle_tensors = []

        #Get interpolation variables
        self.interp_vars = []
        for angle in self.angles:
            out = get_filter_rotation_transforms(list(self.kernel_size), angle)
            self.interp_vars.append(out[:-1])
            # overwritten on every iteration; the mask of the last angle is the one kept
            self.mask = out[-1]

            # angle in radians as a 1-element tensor
            self.angle_tensors.append( Variable(torch.FloatTensor( np.array([angle/ 180. * np.pi]) )) )

        self.weight1 = Parameter(torch.Tensor( out_channels, in_channels , *kernel_size))
        #If input is vector field, we have two filters (one for each component)
        if self.mode == 2:
            self.weight2 = Parameter(torch.Tensor( out_channels, in_channels, *kernel_size))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the weights uniformly in +-1/sqrt(in_channels * prod(kernel_size))."""
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight1.data.uniform_(-stdv, stdv)
        if self.mode == 2:
            self.weight2.data.uniform_(-stdv, stdv)

    def mask_filters(self):
        """Overwrite the weights at masked-out positions with 1e-8 (not called unless enabled in forward)."""
        self.weight1.data[self.mask.expand_as(self.weight1) == 0] = 1e-8
        if self.mode == 2:
            self.weight2.data[self.mask.expand_as(self.weight1) == 0] = 1e-8

    def _apply(self, func):
        # This is called whenever user calls model.cuda()
        # We intercept to replace tensors and variables with cuda-versions
        self.mask = func(self.mask)
        self.interp_vars = [[[func(el2) for el2 in el1] for el1 in el0] for el0 in self.interp_vars]
        self.angle_tensors = [func(el) for el in self.angle_tensors]

        return super(RotConv, self)._apply(func)


    def forward(self,input):
        """Return ``(u, v)``: the pooled response strength split by the winning angle."""
        #Uncomment this to turn on filter-masking
        #Todo: fix broken convergence when filter-masking is on
        #self.mask_filters()

        if self.mode == 1:
            outputs = []

            #Loop through the different filter-transformations
            for ind, interp_vars in enumerate(self.interp_vars):
                #Apply rotation
                weight = apply_transform(self.weight1, interp_vars, self.kernel_size)

                #Do convolution
                out = F.conv2d(input, weight, None, self.stride, self.padding, self.dilation)
                outputs.append(out.unsqueeze(-1))

        if self.mode == 2:
            # input is a pair of component tensors
            u = input[0]
            v = input[1]

            outputs = []
            # Loop through the different filter-transformations
            for ind, interp_vars in enumerate(self.interp_vars):
                angle = self.angle_tensors[ind]
                # Apply rotation
                wu = apply_transform(self.weight1, interp_vars, self.kernel_size)
                wv = apply_transform(self.weight2, interp_vars, self.kernel_size)

                # Do convolution for u
                wru = torch.cos(angle) * wu - torch.sin(angle ) * wv
                u_out = F.conv2d(u, wru, None, self.stride, self.padding, self.dilation)

                # Do convolution for v
                wrv = torch.sin(angle) * wu + torch.cos(angle) * wv
                v_out = F.conv2d(v, wrv, None, self.stride, self.padding, self.dilation)

                #Compute magnitude (p)
                outputs.append(  (u_out + v_out).unsqueeze(-1) )
                

        # Get the maximum direction (Orientation Pooling)
        strength, max_ind = torch.max(torch.cat(outputs, -1), -1)

        # Convert from polar representation q
        # index of the winning angle -> angle in radians
        angle_map = max_ind.float() * (360. / len(self.angles) / 180. * np.pi)
        u = F.relu(strength) * torch.cos(angle_map)
        v = F.relu(strength) * torch.sin(angle_map)


        return u, v

In [None]:
class RotUnetBlock(nn.Module):
    """``UnetBlock`` variant whose skip-connection projection is a ``RotConv``.

    Args:
        up_in: channels of the tensor to upsample.
        x_in: channels of the skip-connection tensor.
        n_out: output channels; each branch contributes ``n_out // 2``.
    """

    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        up_out = x_out = n_out//2
        # RotConv is defined in this notebook; ``nn.RotConv`` does not exist in torch.nn.
        self.x_conv  = RotConv(x_in,  x_out,  1)
        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 2, stride=2)
        self.bn = nn.BatchNorm2d(n_out)

    def forward(self, up_p, x_p):
        up_p = self.tr_conv(up_p)
        # RotConv returns a (u, v) pair, which torch.cat cannot take directly. Reduce it
        # to its magnitude so the channel count stays n_out // 2; the small epsilon keeps
        # the sqrt gradient finite at zero.
        # NOTE(review): using the magnitude is a design choice - confirm it is the intended fusion.
        u, v = self.x_conv(x_p)
        x_p = torch.sqrt(u * u + v * v + 1e-12)
        cat_p = torch.cat([up_p,x_p], dim=1)
        return self.bn(F.relu(cat_p))
    
class RotUnet34(nn.Module):
    """``Unet34`` layout with ``RotUnetBlock`` decoder stages.

    Forward hooks on ``rn[2]``, ``rn[4]``, ``rn[5]`` and ``rn[6]`` provide the
    skip-connection activations, consumed in reverse order by the four blocks.
    """

    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        hooked_layers = [2,4,5,6]
        self.sfs = [SaveFeatures(rn[idx]) for idx in hooked_layers]
        self.up1 = RotUnetBlock(512,256,256)
        self.up2 = RotUnetBlock(256,128,256)
        self.up3 = RotUnetBlock(256,64,256)
        self.up4 = RotUnetBlock(256,64,256)
        self.up5 = nn.ConvTranspose2d(256, 1, 2, stride=2)

    def forward(self,x):
        x = F.relu(self.rn(x))
        decoder = (self.up1, self.up2, self.up3, self.up4)
        # deepest hooked activation first: sfs[3], sfs[2], sfs[1], sfs[0]
        for block, saved in zip(decoder, reversed(self.sfs)):
            x = block(x, saved.features)
        return self.up5(x)[:,0]

    def close(self):
        """Remove all forward hooks from the backbone."""
        for saved in self.sfs:
            saved.remove()
            
class RotUnetModel():
    """Thin wrapper that exposes a ``RotUnet34`` and its layer groups to the learner."""

    def __init__(self,model,name='Rot_Unet'):
        self.model = model
        self.name = name

    def get_layer_groups(self, precompute):
        """Return the backbone split at ``lr_cut`` followed by one group with the remaining children."""
        encoder_groups = list(split_by_idxs(children(self.model.rn), [lr_cut]))
        decoder_group = children(self.model)[1:]
        return encoder_groups + [decoder_group]

In [None]:
# Non-augmented data object for the RotEqNet experiment.
sz = 768 #image size
bs = 8  #batch size

no_aug_md = get_data_no_aug(sz,bs)

In [None]:
# Build the rotation-equivariant U-Net and its layer-group wrapper.
m_base = get_base()
rot_m = to_gpu(RotUnet34(m_base))
rot_models = RotUnetModel(rot_m)


# NOTE(review): the learner is built on `md` (the augmented data from an earlier cell),
# not on the `no_aug_md` created just above - confirm which one is intended.
learn = ConvLearner(md, rot_models)
learn.opt_fn=optim.Adam
learn.crit = MixedLoss(10.0, 2.0)
learn.metrics=[accuracy_thresh(0.5),dice,IoU]
wd=1e-7
lr = 1e-2
# One cycle at the single rate lr after freeze_to(1).
learn.freeze_to(1)
learn.fit(lr,1,wds=wd,cycle_len=1,use_clr=(5,8))