In [None]:
# Full-Resolution Residual Networks (FRRN) paper: https://arxiv.org/pdf/1611.08323.pdf

In [None]:
%run ./utils.ipynb

In [None]:
# Side length requested from get_data_bunch and used as the output height/width
# of the random crop in the training transforms.
sz = 256
# 90% of `sz`, truncated to an int. It is passed as BOTH bounds of the crop
# height range below, so despite the name it is the only crop size used.
max_crop_sz = int(sz * 0.9)

In [None]:
# Training-time augmentation pipeline; handed to get_data_bunch via `trn_tfms`.
trn_tfms = albumentations.Compose([
    # Flip with the library's default probability.
    albumentations.HorizontalFlip(),
    # Applied half of the time. The first argument is a (min, max) pair with both
    # ends equal to max_crop_sz; the result is sized sz x sz.
    # NOTE(review): interpolation=1 is presumably cv2.INTER_LINEAR — confirm.
    albumentations.RandomSizedCrop((max_crop_sz, max_crop_sz), sz, sz, interpolation=1, p=0.5),
    # Applied half of the time with rotate=10 and 'edge' border handling.
    # NOTE(review): whether rotate=10 means a fixed or a +/- range depends on the
    # installed albumentations/imgaug version — confirm.
    albumentations.IAAAffine(rotate=10, p=0.5, mode='edge'),
    # Blur with library-default kernel limits and probability.
    albumentations.Blur()
])

In [None]:
# Data bunch used only for the transform sanity checks below (batch size 32).
# get_data_bunch is not defined in this notebook (presumably utils.ipynb);
# NOTE(review): part=4 presumably selects which fold is held out — confirm.
db = get_data_bunch(sz=sz, bs=32, part=4, trn_tfms=trn_tfms)

In [None]:
# Sanity check of the training-set transforms.
# NOTE(review): check_tfms is defined outside this notebook; the argument is
# presumably a sample index — confirm.
db.train_dl.dl.dataset.check_tfms(1)

In [None]:
# Same sanity check on the validation dataset.
# NOTE(review): the argument is presumably a sample index — confirm.
db.valid_dl.dl.dataset.check_tfms(2)

In [None]:
class Conv2d(nn.Module):
    """Convolution optionally followed by batch norm and ReLU.

    Args:
        n_in: number of input channels.
        n_out: number of output channels.
        kernel_size: convolution kernel size (default 3).
        pad: padding passed to the convolution (default 1).
        bn: when true, register a `BatchNorm2d(n_out)` as `self.bn`.
        nonlinearity: when true, register a `ReLU` as `self.relu`.
        bias: whether the convolution has a bias term (default False).

    The sub-modules are registered in the order conv, bn, relu and `forward`
    applies every registered child in that order.
    """

    def __init__(self, n_in, n_out, kernel_size=3, pad=1, bn=True, nonlinearity=True, bias=False):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=kernel_size,
            padding=pad,
            bias=bias,
        )
        # Optional stages only exist as attributes when they are enabled.
        if bn:
            self.bn = nn.BatchNorm2d(num_features=n_out)
        if nonlinearity:
            self.relu = nn.ReLU()

    def forward(self, x):
        """Feed `x` through each registered child module in registration order."""
        out = x
        for layer in self.children():
            out = layer(out)
        return out

class ruBlock(nn.Module):
    """Residual unit: two 3x3 `Conv2d` blocks plus an identity shortcut.

    When `n_in != n_out` the input is first projected to `n_out` channels by a
    bias-free 1x1 convolution (no batch norm, no ReLU); the projected tensor is
    what the shortcut carries. The second conv block has no ReLU, and none is
    applied after the addition.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        self.project = (n_in != n_out)
        if self.project:
            self.conv1x1 = Conv2d(n_in, n_out, kernel_size=1, pad=0, bn=False, nonlinearity=False, bias=False)
        self.conv1 = Conv2d(n_out, n_out)
        self.conv2 = Conv2d(n_out, n_out, nonlinearity=False)

    def forward(self, x):
        """Return `conv2(conv1(s)) + s`, where `s` is `x` or its 1x1 projection."""
        shortcut = self.conv1x1(x) if self.project else x
        transformed = self.conv2(self.conv1(shortcut))
        return transformed + shortcut
    
class frruBlock(nn.Module):
    """Unit that updates the pooling stream and the residual stream together.

    Args:
        pooling: factor by which the residual stream is max-pooled before it is
            concatenated with the pooling stream, and by which the computed
            residual is upsampled again afterwards. Values <= 1 skip both steps.
        n_in: channel count of the concatenated (pooling + residual) input.
        mult: channel multiplier; the block's width is
            `int(base_channels * mult)`. Defaults to `pooling`.

    Reads the module-level globals `base_channels` and `lanes` when constructed.
    """

    def __init__(self, pooling, n_in, mult=None):
        super().__init__()
        self.pooling = pooling
        self.mult = pooling if mult is None else mult
        channels = int(base_channels * self.mult)
        self.conv1 = Conv2d(n_in=n_in, n_out=channels)
        self.conv2 = Conv2d(n_in=channels, n_out=channels)
        # Plain 1x1 convolution with bias that maps back to the residual stream's width.
        self.res_conv1x1 = Conv2d(channels, lanes, kernel_size=1, pad=0, bn=False, nonlinearity=False, bias=True)

    def forward(self, pool_stream, res_stream):
        """Return the updated `(pool_stream, res_stream)` pair."""
        rescale = self.pooling > 1

        # Bring the residual stream down to the pooling stream's resolution.
        if rescale:
            pooled_res = F.max_pool2d(res_stream, self.pooling, self.pooling)
        else:
            pooled_res = res_stream

        fused = torch.cat((pool_stream, pooled_res), dim=1)
        fused = self.conv2(self.conv1(fused))

        # Project to `lanes` channels and upsample (interpolate's default mode)
        # before adding onto the untouched residual stream.
        residual = self.res_conv1x1(fused)
        if rescale:
            residual = F.interpolate(residual, scale_factor=self.pooling)

        return fused, res_stream + residual
    

class FRRN(nn.Module):
    """Two-stream segmentation network built from `ruBlock`s and `frruBlock`s.

    A pooling stream is max-pooled by 2 four times (down to 1/16 resolution)
    and then bilinearly upsampled by 2 four times, while a residual stream of
    `lanes` channels stays at the input resolution. Every `frruBlock` updates
    both streams. The two streams are concatenated at the end, passed through
    three residual units and a 1x1 convolution with one output channel.

    Reads the module-level globals `base_channels`, `lanes` and `multiplier`
    when constructed. `forward` hard-codes a factor of 2 per stage for the
    pooling stream, while the blocks pool the residual stream by
    `multiplier**k`.
    """

    def __init__(self):
        super().__init__()

        # Channel widths of the pooling stream and the pooling factor per stage.
        width_1 = base_channels
        width_2 = base_channels * multiplier
        width_4 = base_channels * multiplier**2
        width_8 = base_channels * multiplier**3
        pool_2, pool_4, pool_8, pool_16 = (multiplier**power for power in (1, 2, 3, 4))

        # Full-resolution stem; the 1x1 convolution spawns the residual stream.
        self.conv5x5 = Conv2d(3, width_1, 5, pad=2)
        self.ru_3 = nn.Sequential(*(ruBlock(width_1, width_1) for _ in range(3)))
        self.conv1x1 = nn.Conv2d(width_1, lanes, 1)

        # pooling stage / 2
        self.frru2_0 = frruBlock(pool_2, n_in=width_1 + lanes)
        self.frru2_1 = frruBlock(pool_2, n_in=width_2 + lanes)
        self.frru2_2 = frruBlock(pool_2, n_in=width_2 + lanes)

        # pooling stage / 4
        self.frru4_0 = frruBlock(pool_4, n_in=width_2 + lanes)
        self.frru4_1 = frruBlock(pool_4, n_in=width_4 + lanes)
        self.frru4_2 = frruBlock(pool_4, n_in=width_4 + lanes)
        self.frru4_3 = frruBlock(pool_4, n_in=width_4 + lanes)

        # pooling stage / 8
        self.frru8_0 = frruBlock(pool_8, n_in=width_4 + lanes)
        self.frru8_1 = frruBlock(pool_8, n_in=width_8 + lanes)

        # pooling stage / 16 (width stays at the /8 stage's multiplier)
        self.frru16_0 = frruBlock(pool_16, n_in=width_8 + lanes, mult=pool_8)
        self.frru16_1 = frruBlock(pool_16, n_in=width_8 + lanes, mult=pool_8)

        # pooling stage / 8 up (width drops to the /4 stage's multiplier)
        self.frru8_up_0 = frruBlock(pool_8, n_in=width_8 + lanes, mult=pool_4)
        self.frru8_up_1 = frruBlock(pool_8, n_in=width_4 + lanes, mult=pool_4)

        # pooling stage / 4 up
        self.frru4_up_0 = frruBlock(pool_4, n_in=width_4 + lanes)
        self.frru4_up_1 = frruBlock(pool_4, n_in=width_4 + lanes)

        # pooling stage / 2 up
        self.frru2_up_0 = frruBlock(pool_2, n_in=width_4 + lanes)
        self.frru2_up_1 = frruBlock(pool_2, n_in=width_2 + lanes)

        # Full-resolution head on the concatenated streams.
        self.ru_3_up = nn.Sequential(
            ruBlock(width_2 + lanes, width_1),
            ruBlock(width_1, width_1),
            ruBlock(width_1, width_1),
        )

        # Plain 1x1 convolution with bias producing a single output channel.
        self.out = Conv2d(width_1, 1, kernel_size=1, pad=0, bn=False, nonlinearity=False, bias=True)

    def forward(self, x):
        """Run the network on `x` and return channel 0 of the output (channel axis dropped)."""
        pooled = self.ru_3(self.conv5x5(x))
        residual = self.conv1x1(pooled)

        # Encoder: halve the pooling stream, then run that stage's blocks.
        down_stages = (
            (self.frru2_0, self.frru2_1, self.frru2_2),
            (self.frru4_0, self.frru4_1, self.frru4_2, self.frru4_3),
            (self.frru8_0, self.frru8_1),
            (self.frru16_0, self.frru16_1),
        )
        for stage in down_stages:
            pooled = F.max_pool2d(pooled, 2, 2)
            for block in stage:
                pooled, residual = block(pooled, residual)

        # Decoder: double the pooling stream, then run that stage's blocks.
        up_stages = (
            (self.frru8_up_0, self.frru8_up_1),
            (self.frru4_up_0, self.frru4_up_1),
            (self.frru2_up_0, self.frru2_up_1),
        )
        for stage in up_stages:
            pooled = F.interpolate(pooled, scale_factor=2, mode='bilinear', align_corners=False)
            for block in stage:
                pooled, residual = block(pooled, residual)

        # Back to full resolution, then merge both streams.
        pooled = F.interpolate(pooled, scale_factor=2, mode='bilinear', align_corners=False)
        merged = torch.cat((pooled, residual), dim=1)
        merged = self.ru_3_up(merged)

        return self.out(merged)[:, 0]

In [None]:
# Run configuration. frruBlock and FRRN read base_channels / lanes / multiplier
# as globals when a model is constructed, so this cell must run before
# get_learner() is called.

# Channel width of the full-resolution stem; deeper stages scale it by powers of `multiplier`.
base_channels = 32
# Channel count of the full-resolution residual stream.
lanes = 32
# Base of the per-stage pooling factors and channel multipliers (multiplier**k).
# FRRN.forward pools the pooling stream by a hard-coded 2 per stage, so the
# spatial sizes of the two streams only line up when this is 2.
multiplier = 2
# Prefix of the checkpoint and prediction file names used below.
name = 'frrn_256_adam'
# Batch size passed to get_data_bunch in the next cell.
bs = 12
# n_splits is not defined in this notebook (presumably utils.ipynb).
folds_to_train = range(n_splits)

In [None]:
# Rebuild the data bunch with the configured batch size `bs` (replaces the
# earlier bs=32 bunch bound to the same name).
db = get_data_bunch(sz=sz, bs=bs, part=4, trn_tfms=trn_tfms)

In [None]:
def bce_loss(preds, targs):
    """Binary cross-entropy between raw logits `preds` and targets `targs`.

    Thin wrapper over `F.binary_cross_entropy_with_logits` with its default
    (mean) reduction.
    """
    loss = F.binary_cross_entropy_with_logits(input=preds, target=targs)
    return loss

In [None]:
# Number of highest-error elements (counted over the whole flattened batch)
# that bootstrapped_xentropy_with_logits averages its loss over.
k = 512 * multiplier 
def bootstrapped_xentropy_with_logits(preds, targs, n_hardest=None):
    """Binary cross-entropy averaged over only the hardest elements.

    Both tensors are flattened, every element is scored by
    `|targ - sigmoid(pred)|`, and the loss is the mean BCE-with-logits over
    the `n_hardest` elements with the largest score.

    Args:
        preds: raw logits, any shape.
        targs: targets with the same number of elements as `preds`.
        n_hardest: how many elements to keep. Defaults to the module-level
            `k`. Values larger than the number of elements keep everything.

    Returns:
        Scalar loss tensor.
    """
    if n_hardest is None:
        n_hardest = k

    # reshape (unlike view) also accepts non-contiguous inputs such as
    # transposed or strided tensors.
    preds = preds.reshape(-1)
    targs = targs.reshape(-1)
    n_hardest = min(int(n_hardest), preds.numel())

    # Only the indices are needed, so rank without building an autograd graph
    # and use topk instead of sorting every element.
    with torch.no_grad():
        diff = (targs - preds.sigmoid()).abs()
        hardest = diff.topk(n_hardest)[1]

    return F.binary_cross_entropy_with_logits(preds[hardest], targs[hardest])

In [None]:
def get_learner(db):
    """Create a `Learner` for a freshly initialised FRRN on `default_device`.

    The learner uses `bce_loss`, plain Adam with the optimizer's default
    settings, the metrics accuracy_thresh / dice / iou_pytorch and a single
    `SaveBest` callback.

    Args:
        db: data bunch handed to `Learner`.

    Returns:
        The configured learner.
    """
    model = FRRN()
    model.cuda(default_device)
    # Earlier experiments (no longer wired in) used SGD or AdamW with
    # true_wd=True and layer groups split at indices [61, 91].
    learn = Learner(
        db,
        model,
        loss_fn=bce_loss,
        opt_fn=lambda x: optim.Adam(x),
    )
    learn.metrics = [accuracy_thresh, dice, iou_pytorch]
    learn.callbacks = [SaveBest()]
    return learn

In [None]:
%%time

# Passed through to the TTA prediction helpers (defined outside this notebook).
upside_down = False

# For every fold: load two checkpoints (the regular one and the "best iou"
# one), predict on the validation and test sets with TTA, and save the results.
for fold in folds_to_train:
    db = get_data_bunch(sz=202, bs=32, part=fold, trn_tfms=trn_tfms)
    learn = get_learner(db)

    # --- checkpoint '{name}_fold{fold}' ---
    learn.load(f'{name}_fold{fold}')
    val_preds, val_targs = predict_with_targs_and_TTA(learn.model, db.valid_dl, upside_down)
    test_preds = predict_with_TTA(learn.model, db.test_dl, upside_down)
    # Metrics here are reported BEFORE normalize_t is applied.
    print(f'Fold {fold} val acc: {accuracy_np(val_preds, val_targs)}, iou: {iou_metric(val_targs, val_preds > 0.5)}')

    # NOTE(review): normalize_t is defined elsewhere; presumably it rescales the
    # predictions so that 0.5 is a good threshold — confirm.
    val_preds, test_preds = normalize_t(val_preds, val_targs, test_preds)

    np.save(f'/home/radek/db/salt/val_preds_{name}_fold{fold}', val_preds)
    np.save(f'/home/radek/db/salt/val_targs_{name}_fold{fold}', val_targs)
    np.save(f'/home/radek/db/salt/test_preds_{name}_fold{fold}', test_preds)
    del val_preds, val_targs, test_preds

    # --- checkpoint '{name}_best_iou_fold{fold}' ---
    learn.load(f'{name}_best_iou_fold{fold}')
    val_preds, val_targs = predict_with_targs_and_TTA(learn.model, db.valid_dl, upside_down)
    test_preds = predict_with_TTA(learn.model, db.test_dl, upside_down)

    # Unlike the first pass, metrics here are reported AFTER normalize_t.
    val_preds, test_preds = normalize_t(val_preds, val_targs, test_preds)
    print(f'Fold {fold} best iou val acc: {accuracy_np(val_preds, val_targs)}, iou: {iou_metric(val_targs, val_preds > 0.5)}')

    # The validation targets are not saved again in this pass.
    # NOTE(review): the test file prefix is 'test_' here but 'test_preds_' above;
    # left unchanged because other code may load it by this name.
    np.save(f'/home/radek/db/salt/val_preds_{name}_best_iou_fold{fold}', val_preds)
    np.save(f'/home/radek/db/salt/test_{name}_best_iou_fold{fold}', test_preds)
    del val_preds, val_targs, test_preds

    # FRRN is a plain nn.Module and defines no close(); calling it
    # unconditionally raises AttributeError and aborts the loop after the
    # first fold. close() is still honoured for models that provide one.
    if hasattr(learn.model, 'close'):
        learn.model.close()
    del learn

In [None]:
%%time
# Merge the per-fold outputs: validation predictions/targets are concatenated
# in fold order, test predictions are averaged over the folds.
ys = []
preds = []
# Allocated from the first fold's array so the accumulator always matches the
# saved predictions instead of relying on a hard-coded (18000, 202, 202).
test_preds = None
for fold in folds_to_train:
    y = np.load(f'/home/radek/db/salt/val_targs_{name}_fold{fold}.npy')
    val_preds = np.load(f'/home/radek/db/salt/val_preds_{name}_fold{fold}.npy')
    preds.append(val_preds)
    ys.append(y)
    test_pred = np.load(f'/home/radek/db/salt/test_preds_{name}_fold{fold}.npy')
    if test_preds is None:
        # np.zeros defaults to float64, which is the accumulator dtype.
        test_preds = np.zeros(test_pred.shape)
    # Each fold contributes an equal share to the running mean.
    test_preds += test_pred / len(folds_to_train)

np.save(f'/home/radek/db/salt/val_preds_{name}.npy', np.concatenate(preds))
np.save(f'/home/radek/db/salt/val_targs_{name}.npy', np.concatenate(ys))
np.save(f'/home/radek/db/salt/test_preds_{name}.npy', test_preds)

In [None]:
%%time
# Per-fold report from the saved arrays: accuracy, IoU at a 0.5 threshold, and
# the output of best_preds_t.
# NOTE(review): best_preds_t is defined elsewhere; presumably it searches for
# the best threshold — confirm.
for fold in folds_to_train:
    val_preds = np.load(f'/home/radek/db/salt/val_preds_{name}_fold{fold}.npy')
    val_targs = np.load(f'/home/radek/db/salt/val_targs_{name}_fold{fold}.npy')
    print(f'Part {fold}: {accuracy_np(val_preds, val_targs)}, {iou_metric(val_targs, val_preds > 0.5)}, {best_preds_t(val_preds, val_targs)}')

In [None]:
# Load the validation predictions and targets concatenated over all folds
# (replaces the per-fold arrays left in these names by the loop above).
val_preds = np.load(f'/home/radek/db/salt/val_preds_{name}.npy')
val_targs = np.load(f'/home/radek/db/salt/val_targs_{name}.npy')

In [None]:
# IoU metric over all folds' validation data, binarising predictions at 0.5.
iou_metric(val_targs, val_preds > 0.5)

In [None]:
# Execute diagnostics.ipynb from this notebook.
# NOTE(review): it presumably reads val_preds / val_targs loaded above — confirm.
%run diagnostics.ipynb

In [None]:
# Load the fold-averaged test predictions saved by the merge cell.
test_preds = np.load(f'/home/radek/db/salt/test_preds_{name}.npy')

In [None]:
# Build the submission from the averaged test predictions.
# `db` is whichever data bunch was assigned last (the final fold's bunch when
# the prediction loop has run).
# NOTE(review): preds_to_sub is defined elsewhere; 0.5 is presumably the mask
# threshold and 120 a minimum-size cutoff — confirm.
preds_to_sub(test_preds, db.test_dl.dl.dataset.x, 0.5, 120, name)

In [None]:
# Submit ../subs/{name}.csv.gz through the Kaggle CLI, using `name` as the
# message; IPython substitutes {name} before running the shell command.
# NOTE(review): the file is presumably written by preds_to_sub — confirm.
!kaggle competitions submit -c tgs-salt-identification-challenge -f ../subs/{name}.csv.gz  -m {name}