In [None]:
# https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/67762#399663
class scSELayer(nn.Module):
    """Concurrent channel + spatial squeeze-and-excitation gate.

    Two sigmoid gates are computed from the input and applied to it
    independently; the two gated copies are then summed:

    * ``cSE`` pools the feature map to 1x1, squeezes the channels by a factor
      of ``r`` with a 1x1 convolution, expands back to ``n_in`` channels and
      yields one weight per channel.
    * ``sSE`` projects all channels to a single map with a 1x1 convolution
      and yields one weight per spatial position.

    Args:
        n_in: number of input (and output) channels.
        r: channel reduction ratio of the ``cSE`` bottleneck; must divide
            ``n_in``.

    Raises:
        ValueError: if ``n_in`` is not divisible by ``r``.
    """
    def __init__(self, n_in, r = 16):
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if n_in % r != 0:
            raise ValueError(f'in channel count needs to be divisible by r == {r}')
        super().__init__()
        # Layer order inside each Sequential is kept as-is so existing
        # checkpoints (keys such as `cSE.1.weight`) keep loading.
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(n_in,n_in//r,1),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_in//r,n_in,1),
            nn.Sigmoid()
        )
        self.sSE = nn.Sequential(
            nn.Conv2d(n_in,1,1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` gated per channel plus ``x`` gated per position."""
        return x * self.cSE(x) + x * self.sSE(x)
    
class sSELayer(nn.Module):
    """Spatial-only gate: rescales every position of the input by one weight.

    A 1x1 convolution collapses the ``n_in`` channels into a single map, a
    sigmoid squashes it to (0, 1), and the map is broadcast over the channel
    axis when multiplied with the input.

    Args:
        n_in: number of input channels.
    """
    def __init__(self, n_in):
        super().__init__()
        to_single_map = nn.Conv2d(n_in, 1, kernel_size=1)
        # Kept as a two-element Sequential so parameter names stay `sSE.0.*`.
        self.sSE = nn.Sequential(to_single_map, nn.Sigmoid())

    def forward(self, x):
        """Return ``x`` multiplied by its single-channel spatial gate."""
        spatial_gate = self.sSE(x)
        return x * spatial_gate
    
class HCBlock(nn.Module):
    '''Hypercolumn block - reduces num of channels and interpolates.

    A 1x1 convolution maps the input to 16 channels, followed by ReLU and
    batch norm (in that order), and the result is bilinearly resized to a
    fixed spatial size.

    Args:
        n_in: number of input channels.
        out_sz: target spatial size; either a single int (square output) or
            an ``(height, width)`` pair.
    '''
    def __init__(self, n_in, out_sz=256):
        super().__init__()
        self.conv = nn.Conv2d(n_in, 16, 1)
        self.bn = nn.BatchNorm2d(16)
        self.out_sz = out_sz

    def forward(self, x):
        """Return the 16-channel projection of ``x`` resized to ``out_sz``."""
        x = F.relu(self.conv(x))
        x = self.bn(x)
        # An int means a square target; a pair is used as (height, width).
        if isinstance(self.out_sz, (tuple, list)):
            size = tuple(self.out_sz)
        else:
            size = (self.out_sz, self.out_sz)
        # Go through `F` (already used above) rather than a bare
        # `interpolate` name that this block cannot guarantee is imported.
        return F.interpolate(x, size, mode='bilinear', align_corners=False)


In [None]:
class UnetBlock(nn.Module):
    """Decoder stage: upsample the deep path and fuse it with a skip feature.

    The deep input goes through a transposed convolution and the skip input
    through a 1x1 convolution; each branch produces ``n_out // 2`` channels.
    The two are concatenated on the channel axis, passed through ReLU and
    then batch norm.

    Args:
        up_in: channels of the deep (to-be-upsampled) input.
        x_in: channels of the skip input.
        n_out: channels of the fused output; must be even.
        kernel_size, output_padding, padding, stride: forwarded to the
            transposed convolution.

    Raises:
        ValueError: if ``n_out`` is odd.
    """
    def __init__(self, up_in, x_in, n_out, kernel_size=2, output_padding=0, padding=0, stride=2):
        # With an odd n_out the concatenation would carry n_out - 1 channels
        # while the batch norm expects n_out, which only blows up in forward.
        if n_out % 2 != 0:
            raise ValueError(f'n_out must be even (each branch gets n_out // 2 channels), got {n_out}')
        super().__init__()
        up_out = x_out = n_out//2
        self.x_conv  = nn.Conv2d(x_in,  x_out,  1)
        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, kernel_size, stride=stride, output_padding=output_padding, padding=padding)
        self.bn = nn.BatchNorm2d(n_out)
        self.out_channels = n_out

    def forward(self, up_p, x_p):
        """Fuse the deep input ``up_p`` with the skip input ``x_p``.

        After the transposed convolution ``up_p`` must have the same spatial
        size as ``x_p``, otherwise the concatenation fails.
        """
        up_p = self.tr_conv(up_p)
        x_p = self.x_conv(x_p)
        cat_p = torch.cat([up_p,x_p], dim=1)
        return self.bn(F.relu(cat_p))

In [None]:
class Res34UnetBlock(UnetBlock):
    """``UnetBlock`` whose transposed convolution is set up from one kernel size.

    The padding is derived as ``(k - 2) // 2`` and the stride is left at the
    ``UnetBlock`` default, so for even ``k`` the upsampling path doubles the
    spatial size.

    Args:
        up_in: channels of the deep input.
        x_in: channels of the skip input.
        n_out: channels of the fused output.
        k: kernel size of the transposed convolution.
    """
    def __init__(self, up_in, x_in, n_out, k):
        tr_padding = (k - 2) // 2
        super().__init__(up_in, x_in, n_out, kernel_size=k, padding=tr_padding)
    
class BasicRes34Unet(nn.Module):
    """U-Net style network on top of a ``pretrainedmodels`` ResNet-34.

    The backbone, minus its last two children, is cut into five encoder
    stages (``down1`` .. ``down5``). Five ``Res34UnetBlock`` decoder stages
    fuse the upsampled path with the outputs of ``down4``, ``down3``,
    ``down2``, ``down1`` and finally the raw 3-channel input; a 1x1
    transposed convolution (``up6``) reduces the result to one channel.

    Args:
        mult: multiplier applied to every decoder width.
        k: transposed-convolution kernel size used by each decoder block.
    """
    def __init__(self, mult=1, k=2):
        super().__init__()
        backbone_children = list(pretrainedmodels.resnet34().children())
        # Presumably the two dropped children are the pooling/classifier head
        # -- confirm against the pretrainedmodels definition.
        base = nn.Sequential(*backbone_children[:-2])
        self.down1 = base[:3]
        self.down2 = base[3:5]
        self.down3 = base[5:6]
        self.down4 = base[6:7]
        self.down5 = base[7:]

        # Decoder output widths, deepest stage first.
        w1, w2, w3, w5 = 192 * mult, 96 * mult, 32 * mult, 16 * mult
        self.up1 = Res34UnetBlock(512, 256, w1, k)
        self.up2 = Res34UnetBlock(w1, 128, w2, k)
        self.up3 = Res34UnetBlock(w2, 64, w3, k)
        self.up4 = Res34UnetBlock(w3, 64, w3, k)
        self.up5 = Res34UnetBlock(w3, 3, w5, k)

        self.up6 = nn.ConvTranspose2d(w5, 1, 1)

    def forward(self,x):
        """Run encoder and decoder; return the output with its channel axis removed."""
        image = x
        skip1 = self.down1(image)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)
        skip4 = self.down4(skip3)

        out = F.relu(self.down5(skip4))

        # Walk back up, pairing each decoder block with its skip feature;
        # the last block is fed the untouched input image.
        decoder_path = (
            (self.up1, skip4),
            (self.up2, skip3),
            (self.up3, skip2),
            (self.up4, skip1),
            (self.up5, image),
        )
        for block, skip in decoder_path:
            out = block(out, skip)
        out = self.up6(out)

        assert (out.shape[1] == 1)
        return out[:, 0]

class Res34UnetSE(BasicRes34Unet):
    """``BasicRes34Unet`` with an ``scSELayer`` after every stage.

    ``se_feat1`` .. ``se_feat4`` gate the outputs of ``down1`` .. ``down4``
    (the gated tensor is both the skip feature and the input of the next
    encoder stage); ``se1`` .. ``se5`` gate the outputs of ``up1`` .. ``up5``.

    Args:
        mult: multiplier applied to every decoder width.
        k: transposed-convolution kernel size used by each decoder block.
    """
    def __init__(self, mult=1, k=2):
        super().__init__(mult, k)

        # One gate per decoder block, sized from the block's own output width.
        decoder_blocks = (self.up1, self.up2, self.up3, self.up4, self.up5)
        for idx, block in enumerate(decoder_blocks, start=1):
            setattr(self, f'se{idx}', scSELayer(block.out_channels))

        # One gate per encoder stage that feeds a skip connection.
        for idx, width in enumerate((64, 64, 128, 256), start=1):
            setattr(self, f'se_feat{idx}', scSELayer(width))

    def forward(self,x):
        """Run the gated encoder/decoder; return the output without its channel axis."""
        image = x
        skip1 = self.se_feat1(self.down1(image))
        skip2 = self.se_feat2(self.down2(skip1))
        skip3 = self.se_feat3(self.down3(skip2))
        skip4 = self.se_feat4(self.down4(skip3))

        out = F.relu(self.down5(skip4))

        decoder_path = (
            (self.up1, self.se1, skip4),
            (self.up2, self.se2, skip3),
            (self.up3, self.se3, skip2),
            (self.up4, self.se4, skip1),
            (self.up5, self.se5, image),
        )
        for block, gate, skip in decoder_path:
            out = gate(block(out, skip))
        out = self.up6(out)

        assert (out.shape[1] == 1)
        return out[:, 0]

class Res34UnetSE_HC(Res34UnetSE):
    """``Res34UnetSE`` extended with hypercolumn features.

    The gated outputs of ``up1`` .. ``up4`` are each passed through an
    ``HCBlock`` (16 channels, resized to ``sz`` x ``sz``), concatenated
    (64 channels) and merged by a 3x3 convolution + ReLU + batch norm into
    ``comb_hc_channels`` channels. That map is concatenated with the gated
    output of ``up5`` before the final 1x1 transposed convolution.

    Unlike the parent class, the encoder stages here consume the *ungated*
    output of the previous stage; the ``se_feat*`` gates are applied only to
    the skip features.

    Args:
        mult: multiplier applied to every decoder width.
        k: transposed-convolution kernel size used by each decoder block.
        sz: side length the hypercolumn maps are resized to; it has to match
            the input's height and width because the maps are concatenated
            with the full-resolution decoder output.
        comb_hc_channels: channels of the merged hypercolumn map.
    """
    def __init__(self, mult=1, k=2, sz=128, comb_hc_channels=8):
        super().__init__(mult, k)

        hc_sources = (self.up1, self.up2, self.up3, self.up4)
        for idx, block in enumerate(hc_sources, start=1):
            setattr(self, f'hc{idx}', HCBlock(block.out_channels, out_sz=sz))

        # 64 input channels = four HCBlock outputs of 16 channels each.
        self.hc_comb = nn.Conv2d(64, comb_hc_channels, 3, padding=1)
        self.hc_bn = nn.BatchNorm2d(comb_hc_channels)

        # Replaces the parent's head so it also accepts the hypercolumn map.
        self.up6 = nn.ConvTranspose2d(16 * mult + comb_hc_channels, 1, 1)

    def forward(self,x):
        """Run the network; return the output without its channel axis."""
        image = x
        raw1 = self.down1(image)
        raw2 = self.down2(raw1)
        raw3 = self.down3(raw2)
        raw4 = self.down4(raw3)

        out = F.relu(self.down5(raw4))

        # Gates are applied after the whole encoder has run on raw features.
        skip1 = self.se_feat1(raw1)
        skip2 = self.se_feat2(raw2)
        skip3 = self.se_feat3(raw3)
        skip4 = self.se_feat4(raw4)

        hc_stages = (
            (self.up1, self.se1, self.hc1, skip4),
            (self.up2, self.se2, self.hc2, skip3),
            (self.up3, self.se3, self.hc3, skip2),
            (self.up4, self.se4, self.hc4, skip1),
        )
        hypercolumns = []
        for block, gate, hc_block, skip in hc_stages:
            out = gate(block(out, skip))
            hypercolumns.append(hc_block(out))

        out = self.se5(self.up5(out, image))

        merged = self.hc_comb(torch.cat(tuple(hypercolumns), dim=1))
        merged = self.hc_bn(F.relu(merged))

        out = torch.cat((out, merged), dim=1)
        out = self.up6(out)

        assert (out.shape[1] == 1)
        return out[:, 0]

In [None]:
class BasicUnetSEresnext50(nn.Module):
    """U-Net style network on top of a ``pretrainedmodels`` SE-ResNeXt-50 (32x4d).

    The backbone, minus its last two children, is cut into five encoder
    stages: the first three modules of the backbone's first child form
    ``down1``; the rest of that child plus the second child form ``down2``;
    the remaining children form ``down3`` .. ``down5``. Five
    ``Res34UnetBlock`` decoder stages fuse the upsampled path with the
    outputs of ``down4`` .. ``down1`` and finally the raw 3-channel input;
    a 1x1 transposed convolution (``up6``) reduces the result to one channel.

    Args:
        mult: accepted for signature parity with the ResNet-34 variant but
            not used; decoder widths are fixed.
        k: transposed-convolution kernel size used by each decoder block.
    """
    def __init__(self, mult=1, k=2):
        super().__init__()
        backbone = pretrainedmodels.se_resnext50_32x4d()
        base = nn.Sequential(*list(backbone.children())[:-2])

        stem = base[0]
        self.down1 = stem[:3]
        self.down2 = nn.Sequential(stem[3:], base[1:2])
        self.down3 = base[2:3]
        self.down4 = base[3:4]
        self.down5 = base[4:]

        # (deep-path channels, skip channels, output channels), deepest first.
        decoder_spec = (
            (2048, 1024, 768),
            (768, 512, 320),
            (320, 256, 160),
            (160, 64, 80),
            (80, 3, 32),
        )
        for idx, (up_in, x_in, n_out) in enumerate(decoder_spec, start=1):
            setattr(self, f'up{idx}', Res34UnetBlock(up_in, x_in, n_out, k))

        self.up6 = nn.ConvTranspose2d(32, 1, 1)

    def forward(self,x):
        """Run encoder and decoder; return the output with its channel axis removed."""
        image = x
        skip1 = self.down1(image)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)
        skip4 = self.down4(skip3)

        out = F.relu(self.down5(skip4))

        decoder_path = (
            (self.up1, skip4),
            (self.up2, skip3),
            (self.up3, skip2),
            (self.up4, skip1),
            (self.up5, image),
        )
        for block, skip in decoder_path:
            out = block(out, skip)
        out = self.up6(out)
        assert (out.shape[1] == 1)
        return out[:, 0]
    
class UnetSEresnext50SE(BasicUnetSEresnext50):
    """``BasicUnetSEresnext50`` with gates after every stage.

    ``se_down1`` .. ``se_down4`` are spatial-only ``sSELayer`` gates on the
    outputs of ``down1`` .. ``down4`` (the gated tensor is both the skip
    feature and the input of the next encoder stage). ``se_up1`` ..
    ``se_up5`` are ``scSELayer`` gates on the outputs of ``up1`` .. ``up5``.

    Args:
        mult: forwarded to the parent, which does not use it.
        k: transposed-convolution kernel size used by each decoder block.
    """
    def __init__(self, mult=1, k=2):
        super().__init__(mult, k)

        # Channel counts of the four encoder stages that feed skip connections.
        for idx, width in enumerate((64, 256, 512, 1024), start=1):
            setattr(self, f'se_down{idx}', sSELayer(width))

        decoder_blocks = (self.up1, self.up2, self.up3, self.up4, self.up5)
        for idx, block in enumerate(decoder_blocks, start=1):
            setattr(self, f'se_up{idx}', scSELayer(block.out_channels))

    def forward(self,x):
        """Run the gated encoder/decoder; return the output without its channel axis."""
        image = x
        skip1 = self.se_down1(self.down1(image))
        skip2 = self.se_down2(self.down2(skip1))
        skip3 = self.se_down3(self.down3(skip2))
        skip4 = self.se_down4(self.down4(skip3))

        out = F.relu(self.down5(skip4))

        decoder_path = (
            (self.up1, self.se_up1, skip4),
            (self.up2, self.se_up2, skip3),
            (self.up3, self.se_up3, skip2),
            (self.up4, self.se_up4, skip1),
            (self.up5, self.se_up5, image),
        )
        for block, gate, skip in decoder_path:
            out = gate(block(out, skip))
        out = self.up6(out)
        assert (out.shape[1] == 1)
        return out[:, 0]
    
class UnetSEresnext50SE_HC(UnetSEresnext50SE):
    """``UnetSEresnext50SE`` extended with hypercolumn features.

    The gated outputs of ``up1`` .. ``up4`` are each passed through an
    ``HCBlock`` (16 channels, resized to ``sz`` x ``sz``), concatenated
    (64 channels) and merged by a 1x1 convolution + ReLU + batch norm into
    ``comb_hc_channels`` channels. That map is concatenated with the gated
    output of ``up5`` before the final 1x1 transposed convolution.

    Args:
        mult: forwarded up the hierarchy, where it is not used.
        k: transposed-convolution kernel size used by each decoder block.
        sz: side length the hypercolumn maps are resized to; it has to match
            the input's height and width because the maps are concatenated
            with the full-resolution decoder output.
        comb_hc_channels: channels of the merged hypercolumn map.
    """
    def __init__(self, mult=1, k=2, sz=128, comb_hc_channels=16):
        super().__init__(mult, k)

        hc_sources = (self.up1, self.up2, self.up3, self.up4)
        for idx, block in enumerate(hc_sources, start=1):
            setattr(self, f'hc{idx}', HCBlock(block.out_channels, out_sz=sz))

        # 64 input channels = four HCBlock outputs of 16 channels each.
        self.hc_comb = nn.Conv2d(64, comb_hc_channels, 1)
        self.hc_bn = nn.BatchNorm2d(comb_hc_channels)

        # Replaces the parent's head so it also accepts the hypercolumn map.
        self.up6 = nn.ConvTranspose2d(32+comb_hc_channels, 1, 1)

    def forward(self,x):
        """Run the network; return the output without its channel axis."""
        image = x
        skip1 = self.se_down1(self.down1(image))
        skip2 = self.se_down2(self.down2(skip1))
        skip3 = self.se_down3(self.down3(skip2))
        skip4 = self.se_down4(self.down4(skip3))

        out = F.relu(self.down5(skip4))

        hc_stages = (
            (self.up1, self.se_up1, self.hc1, skip4),
            (self.up2, self.se_up2, self.hc2, skip3),
            (self.up3, self.se_up3, self.hc3, skip2),
            (self.up4, self.se_up4, self.hc4, skip1),
        )
        hypercolumns = []
        for block, gate, hc_block, skip in hc_stages:
            out = gate(block(out, skip))
            hypercolumns.append(hc_block(out))

        merged = self.hc_comb(torch.cat(tuple(hypercolumns), dim=1))
        merged = self.hc_bn(F.relu(merged))

        out = self.se_up5(self.up5(out, image))
        out = torch.cat((out, merged), dim=1)
        out = self.up6(out)
        assert (out.shape[1] == 1)
        return out[:, 0]
    
# this model has a couple of likely issues:
#   - the classifier part should probably be way more complex
#   - it would probably be better to apply the loss to each of the hcXs instead of to hc_comb
#   - the losses (not visible here) that I implemented should probably only apply segmentation loss
#     to x and not to the hcXs where the mask is all zeros
class UnetSEresnext50SE_HC_deep_sup(UnetSEresnext50SE_HC):
    """``UnetSEresnext50SE_HC`` with two auxiliary outputs.

    On top of the main single-channel output the network also returns:

    * a single-channel map computed by a 1x1 convolution (``hc_class``) on
      the merged hypercolumn features, and
    * one value per sample from a linear layer (``classifier``) applied to
      the spatial mean of the 2048-channel bottleneck.

    No final activation is applied to any of the three outputs.

    Args:
        mult: forwarded up the hierarchy, where it is not used.
        k: transposed-convolution kernel size used by each decoder block.
        sz: side length the hypercolumn maps are resized to; it has to match
            the input's height and width.
        comb_hc_channels: channels of the merged hypercolumn map.
    """
    def __init__(self, mult=1, k=2, sz=128, comb_hc_channels=16):
        super().__init__(mult, k, sz=sz, comb_hc_channels=comb_hc_channels)

        self.classifier = nn.Linear(2048, 1)
        self.hc_class = nn.Conv2d(comb_hc_channels, 1, 1)

    def forward(self,x):
        """Return ``(main_map, hypercolumn_map, cl)``.

        Both maps have their channel axis removed; ``cl`` is the raw
        ``classifier`` output with one column per sample.
        """
        image = x
        skip1 = self.se_down1(self.down1(image))
        skip2 = self.se_down2(self.down2(skip1))
        skip3 = self.se_down3(self.down3(skip2))
        skip4 = self.se_down4(self.down4(skip3))

        mid = F.relu(self.down5(skip4))

        hc_stages = (
            (self.up1, self.se_up1, self.hc1, skip4),
            (self.up2, self.se_up2, self.hc2, skip3),
            (self.up3, self.se_up3, self.hc3, skip2),
            (self.up4, self.se_up4, self.hc4, skip1),
        )
        out = mid
        hypercolumns = []
        for block, gate, hc_block, skip in hc_stages:
            out = gate(block(out, skip))
            hypercolumns.append(hc_block(out))

        merged = self.hc_comb(torch.cat(tuple(hypercolumns), dim=1))
        merged = self.hc_bn(F.relu(merged))

        out = self.se_up5(self.up5(out, image))
        out = torch.cat((out, merged), dim=1)
        out = self.up6(out)

        ### classification
        # Flatten the spatial axes of the bottleneck and average over them.
        n_samples, n_channels = mid.shape[0], mid.shape[1]
        pooled = mid.view((n_samples, n_channels, -1)).mean(2)
        cl = self.classifier(pooled)

        ### hc segmentation
        hc_out = self.hc_class(merged)

        assert (out.shape[1] == 1)
        assert (hc_out.shape[1] == 1)
        return out[:, 0], hc_out[:, 0], cl

In [None]:
# db = get_data_bunch()
# learn = get_learner(db)
# ims = torch.zeros(4,3,128,128).cuda()
# learn.model(ims).shape

In [None]:
# Improved callbacks
class SaveBest(Callback):
    """Save the model whenever the tracked metric sets a new maximum.

    The metric is the last entry of ``last_metrics`` (treated as IoU here).
    Saving goes through the module-level ``learn`` object and the file name
    is built from the module-level ``name`` and ``fold``; none of these is
    passed in, so all three must exist when the callback fires.
    """
    def __init__(self):
        # Best metric value seen so far.
        self.iou = 0

    def on_epoch_end(self, epoch, num_batch, smooth_loss, last_metrics, **kwargs):
        """Checkpoint if this epoch's metric strictly beats the best so far."""
        current = last_metrics[-1]
        if not current > self.iou:
            return
        self.iou = current
        learn.save(f'{name}_best_iou_fold{fold}')
            
class ReduceLROnPlateau(Callback):
    """Divide the learning rate when the tracked metric stops improving.

    The metric is the last entry of ``last_metrics`` (treated as IoU here).
    An epoch counts as an improvement only if the metric exceeds the best
    value so far by more than ``delta``. After ``patience`` consecutive
    non-improving epochs the optimizer's learning rate is divided by
    ``div_factor`` and the counter starts over.

    Args:
        learn: object exposing the optimizer as ``learn.opt``.
        patience: non-improving epochs tolerated before reducing.
        div_factor: divisor applied to the learning rate.
        grace: number of initial epochs during which the callback does
            nothing at all (useful when restarting with a higher lr).
        delta: minimum gain that counts as an improvement.
    """
    def __init__(self, learn, patience=5, div_factor=10, grace=0, delta=1e-4):
        # Configuration.
        self.learn = learn
        self.patience = patience
        self.div_factor = div_factor
        self.delta = delta
        self.grace = grace

        # Running state.
        self.iou = 0
        self.epochs_without_improv = 0

    def _reduce_lr(self):
        """Divide the current learning rate(s) by ``div_factor`` and report it."""
        lr = self.learn.opt.read_val('lr')
        self.learn.opt.lr = np.array(lr) / self.div_factor
        print(f'Reducing lr to: {self.learn.opt.lr}')

    def on_epoch_end(self, epoch, num_batch, smooth_loss, last_metrics, **kwargs):
        """Update the plateau counter and reduce the lr once patience runs out."""
        # Grace epochs are consumed without looking at the metric.
        if self.grace > 0:
            self.grace -= 1
            return

        current = last_metrics[-1]
        if current - self.iou > self.delta:
            self.epochs_without_improv = 0
            self.iou = current
        else:
            self.epochs_without_improv += 1

        if self.epochs_without_improv == self.patience:
            self._reduce_lr()
            self.epochs_without_improv = 0
            
    
class StopTrain(Callback):
    """Signal the end of training when the tracked metric stops improving.

    The metric is the last entry of ``last_metrics`` (treated as IoU here).
    An epoch counts as an improvement only if the metric exceeds the best
    value so far by more than ``delta``. Once ``patience`` consecutive
    epochs pass without improvement, ``on_epoch_end`` prints the current
    learning rate and returns ``True``; presumably the training loop treats
    that as a stop request -- confirm against the callback framework.

    Args:
        learn: object exposing the optimizer as ``learn.opt``.
        patience: non-improving epochs tolerated before signalling.
        delta: minimum gain that counts as an improvement.
    """
    def __init__(self, learn, patience=5, delta=1e-4):
        # Configuration.
        self.learn = learn
        self.patience = patience
        self.delta = delta

        # Running state.
        self.iou = 0
        self.epochs_without_improv = 0

    def on_epoch_end(self, epoch, num_batch, smooth_loss, last_metrics, **kwargs):
        """Update the plateau counter; return ``True`` once patience runs out."""
        current = last_metrics[-1]
        if current - self.iou > self.delta:
            self.epochs_without_improv = 0
            self.iou = current
        else:
            self.epochs_without_improv += 1

        patience_exhausted = self.epochs_without_improv == self.patience
        if not patience_exhausted:
            return None
        lr = self.learn.opt.read_val('lr')
        print(f'Finishing training with lr: {lr}')
        return True