In [None]:
from fastai.vision import *
import cv2

In [None]:
import mish_torch
from mish_torch import MishFunction,Mish
import bundled_gen_efficientnet
import gen_efficientnet

In [None]:
# Inference batch size; also used further down to group image ids into rows of BS
BS = 8
DROP_RATE = 0.0
# Passed to `.normalize()` when the databunch is built; one value per entry, i.e. single channel
# (presumably (mean, std) - confirm against how the model was trained)
STATS = (tensor([0.3436]), tensor([0.1961]))

# Kaggle input mounts: saved model weights and the competition data
INPUTS = Path('/kaggle/input')
STATE  = INPUTS/'severstal-eur-mish'
DATA   = INPUTS/'severstal-steel-defect-detection'

# Class ids appended to the image id (`<image>_<class>`) in the submission rows
CLASSES=['1','2','3','4']

# Per-class thresholds from analysis of valid
THRESHOLDS = [0.29,0.552,0.447,0.330]
MIN_PXS = 2000 # Minimum contiguous pixels to not ignore
# NOTE(review): these weights sum to 0.99 rather than 1.0, and THRESHOLDS are applied to the
# weighted sum - confirm the thresholds were tuned with exactly these weights.
TTA_WEIGHTS = tensor([0.6, 0.2, 0.07, 0.07, 0.05]) # Weight full and non-TTA blocks higher

# Create Model and Data

In [None]:
def mish_fn(x, inplace=False):
    """Apply the Mish activation to `x` through `MishFunction`.

    The `inplace` argument exists only because efficientnet passes it to its
    activation function; it is accepted and ignored.
    """
    activated = MishFunction.apply(x)
    return activated

class EfficientUnetEncoder(nn.Module):
    """U-Net encoder built from the stem and block stages of an existing EfficientNet.

    The stem (`conv_stem`, `bn1`, `act_fn`) and `blocks` are taken by reference from
    `eff_net`. For every block stage the largest layer stride and the output channels
    of its last `nn.Conv2d` are recorded, so a decoder can be sized from them.
    """
    def __init__(self, eff_net, first_skip=False):
        super().__init__()
        self.first_skip = first_skip
        self.mod_names = ['conv_stem', 'bn1', 'act_fn', 'blocks']
        for name in self.mod_names:
            setattr(self, name, getattr(eff_net, name))

        def stage_stride(stage):
            # Largest first-axis stride of any layer in the stage; layers without a stride count as 0
            return max((layer.stride[0] if hasattr(layer, 'stride') else 0
                        for layer in flatten_model(stage)))

        def stage_channels(stage):
            # Output channels of the last conv layer in the stage
            conv_outs = [layer.out_channels
                         for layer in flatten_model(stage)
                         if isinstance(layer, nn.Conv2d)]
            return conv_outs[-1]

        self.blk_strides  = [stage_stride(stage) for stage in self.blocks]
        self.blk_channels = [stage_channels(stage) for stage in self.blocks]

    def forward(self, x, hook=False):
        """Run the encoder.

        With `hook=False` only the final feature map is returned. With `hook=True`
        a `(features, skips)` tuple is returned, where `skips` holds the input of
        every stage whose stride is > 1 (the first stage only if `first_skip`).
        """
        x = self.act_fn(self.bn1(self.conv_stem(x)))
        if not hook:
            return self.blocks(x)
        skips = []
        for idx, (stage, stride) in enumerate(zip(self.blocks, self.blk_strides)):
            wants_skip = self.first_skip or idx > 0
            if stride > 1 and wants_skip:
                # Capture the activation before it gets downsampled by this stage
                skips.append(x)
            x = stage(x)
        return (x, skips)

    def load_state(self, state_dict):
        """Load only the entries of `state_dict` whose top-level name belongs to this encoder."""
        own_entries = {}
        for key, value in state_dict.items():
            if key.split('.')[0] in self.mod_names:
                own_entries[key] = value
        return self.load_state_dict(own_entries)

class EfficientUnetClassifier(nn.Module):
    """Classification head of an EfficientNet placed on top of a shared U-Net encoder.

    `enc` supplies the feature extractor and the activation function; the head layers
    (`conv_head`, `classifier`, and `bn2` when the network does not use the
    "efficient head") are taken by reference from `eff_net`.
    """
    def __init__(self, enc, eff_net, drop_rate=0.):
        super().__init__()
        self.enc,self.drop_rate = enc,drop_rate
        self.efficient_head,self.act_fn = eff_net.efficient_head, enc.act_fn
        self.mod_names = ['conv_head','classifier']
        if not self.efficient_head: self.mod_names.insert(1, 'bn2')
        for n in self.mod_names: setattr(self, n, getattr(eff_net, n))

    def forward(self, x):
        """Return the classifier output for input batch `x`."""
        x = self.enc(x)
        if self.efficient_head:
            # Pool first so the head conv runs on a 1x1 map
            x = F.adaptive_avg_pool2d(x, 1)
            x = self.conv_head(x)
            # no BN
            x = self.act_fn(x, inplace=True)
        else:
            x = self.conv_head(x)
            x = self.bn2(x)
            x = self.act_fn(x, inplace=True)
            x = F.adaptive_avg_pool2d(x, 1)
        x = x.flatten(1)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return self.classifier(x)

    def load_state(self, state_dict):
        """Load encoder and head weights from a flat EfficientNet `state_dict`.

        The encoder loads its own entries through `enc.load_state`. The head entries are
        then loaded non-strictly: `state_dict` uses un-prefixed encoder names, so a strict
        load on this module would always fail on the `enc.*` keys it cannot see. Those
        `enc.*` keys are removed from the report, and the encoder's own report is merged in.

        Returns the combined missing/unexpected key report.
        """
        res_enc = self.enc.load_state(state_dict)
        sd = {n: v for n,v in state_dict.items() if n.split('.')[0] in self.mod_names}
        res = self.load_state_dict(sd, strict=False)
        # The encoder weights were handled above; do not report them as missing
        res.missing_keys[:] = [k for k in res.missing_keys if not k.startswith('enc.')]
        res.missing_keys.extend(res_enc.missing_keys)
        res.unexpected_keys.extend(res_enc.unexpected_keys)
        return res

class ResBlock(nn.Module):
    """Residual block: (norm -> activation -> conv) twice, plus an optional shortcut branch.

    Args:
        ni, nf: input / output channels.
        ks, stride, padding: conv geometry; `padding` defaults to `(ks-1)//2`. `stride`
            applies to the first conv only.
        identity: add a shortcut branch (norm, 1x1 conv if `ni != nf`, 2x2 average pool
            if `stride != 1`) to the output.
        pre_activ: apply norm + activation before the first conv. When False the first
            conv sees the raw input and may carry a bias.
        bias: allow a bias on the first conv (only used when `pre_activ` is False).
        act_fn: activation to use: an `nn.Module` subclass (instantiated here) or an
            already built callable. Defaults to fastai's `relu(inplace=True, leaky=leaky)`.
        leaky: forwarded to the default activation only.
        init: weight initializer passed to `conv2d`.
        norm_type: normalization type passed to `batchnorm_2d`.
    """
    def __init__(self, ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None,
                 identity:bool=True, pre_activ:bool=True, bias:bool=True, act_fn:Callable[[None],nn.Module]=None,
                 leaky:float=None,init:Callable=nn.init.kaiming_normal_, norm_type:NormType=NormType.Batch):
        super().__init__()

        # Build the default activation module right away. The previous default was an
        # uncalled `partial(relu, ...)`; applying it to a tensor called
        # `relu(tensor, inplace=True, ...)` and raised a TypeError.
        act = relu(inplace=True, leaky=leaky) if act_fn is None else act_fn
        if isinstance(act, type) and issubclass(act, nn.Module): act = act()
        if padding is None: padding = (ks-1)//2
        if pre_activ: self.act1,self.bn1 = act,batchnorm_2d(ni, norm_type=norm_type)
        else:         self.act1,self.bn1 = noop,noop
        self.conv1 = conv2d(ni, nf, ks, stride=stride, padding=padding, bias=(bias and not pre_activ), init=init)
        #norm2 = NormType.BatchZero if norm_type == NormType.Batch else norm_type
        self.bn2   = batchnorm_2d(nf, norm_type=norm_type)
        # Same activation object as act1 (shared instance)
        self.act2  = act
        self.conv2 = conv2d(nf, nf, ks, padding=padding, bias=False, init=init)
        if identity:
            l = [batchnorm_2d(ni, norm_type=norm_type)]
            if ni != nf: l.append(conv2d(ni, nf, ks=1, bias=False, init=init))
            if stride != 1: l.append(nn.AvgPool2d(2, ceil_mode=True))
            self.identity = nn.Sequential(*l)
        else: self.identity = None

    def forward(self, x):
        """Return the block output for `x` (with the shortcut added when enabled)."""
        res = self.act1(self.bn1(x))
        res = self.conv1(res)
        res = self.act2(self.bn2(res))
        res = self.conv2(res)
        if self.identity is not None: res += self.identity(x)
        return res
    
from gen_efficientnet.gen_efficientnet import GenEfficientNet, decode_arch_def, round_channels, _resolve_bn_args
def gen_efficientnet(c_out=1000, c_in=1, act_fn=mish_fn, **kwargs):
    """Build the customised `GenEfficientNet` used as the segmentation backbone.

    Modified arch: no immediate stride-2 conv (the stem conv is switched to stride 1
    after construction) and reduced channels.

    Args:
        c_out: number of classes of the network's classifier.
        c_in: number of input channels.
        act_fn: activation function handed to the network.
        **kwargs: batch-norm options are extracted by `_resolve_bn_args`; what remains
            is forwarded to `GenEfficientNet`.
    """
    width_mult = 1.0
    depth_mult = 1.0

    # One block-definition string per stage
    stage_specs = [
        'ds_r1_k3_s2_e2_c16_se0.50',
        'ir_r2_k3_s1_e4_c24_se0.50',
        'ir_r2_k5_s2_e4_c40_se0.50',
        'ir_r3_k3_s1_e3_c80_se0.50',
        'ir_r3_k5_s2_e2_c112_se0.50',
        'ir_r3_k5_s1_e3_c160_se0.50',
        'ir_r1_k3_s2_e2_c192_se0.50',
    ]
    arch_def = [[spec] for spec in stage_specs]

    block_args = decode_arch_def(arch_def, depth_mult)
    n_features = round_channels(1280, width_mult, 8, None)
    # Must run before **kwargs is expanded below, exactly as in the single-call form
    bn_args = _resolve_bn_args(kwargs)
    net = GenEfficientNet(
        block_args,
        num_classes=c_out,
        in_chans=c_in,
        stem_size=32,
        channel_multiplier=width_mult,
        channel_divisor=8,
        channel_min=None,
        num_features=n_features,
        bn_args=bn_args,
        act_fn=act_fn,
        **kwargs
    )
    # Keep full resolution through the stem
    net.conv_stem.stride = (1,1)
    return net

    
class EfficientUResnet(nn.Module):
    """U-Net with an EfficientNet encoder and a decoder made of `ResBlock`s.

    Args:
        c_out: number of output channels of the segmentation map.
        c_in: number of input channels.
        drop_rate: fallback dropout rate for the optional classifier head.
        resize_mode: `mode` for the `nn.Upsample` layers; for 'bilinear',
            `align_corners=False` is passed explicitly.
        act_fn: encoder activation: a function `f(x, inplace=False)` or an `nn.Module`
            subclass, which is instantiated and wrapped in such a function.
        res_act_fn, leaky, identity: forwarded to every `ResBlock` (bridge and decoder).
        final_act: optional callable applied to the output of the final conv.
        with_cfn: also build an `EfficientUnetClassifier` sharing the encoder.
        cfn_c_out: classifier output size (defaults to `c_out`).
        cfn_drop_rate: classifier dropout rate (defaults to `drop_rate`).
        **kwargs: forwarded to `gen_efficientnet`.
    """
    def __init__(self, c_out, c_in=1, drop_rate=0.0, resize_mode='nearest',
                 act_fn=mish_fn, res_act_fn=None, leaky=False, final_act=None,  identity=True,
                 with_cfn=False, cfn_c_out=None, cfn_drop_rate=None, **kwargs):
        super().__init__()
        if type(act_fn) is type and issubclass(act_fn, nn.Module):
            # EfficientNet expects inplace parameter on activation
            act = act_fn()
            # Plain function, so no `self`: a call as act_fn(x, inplace=...) would
            # otherwise bind the tensor to `self`.
            # NOTE(review): `act` lives only in this closure and is not registered as a submodule.
            def _act_fn(x, inplace=False): return act(x)
            act_fn = _act_fn
        # Keep the mode string separate from the kwargs dict so the comparison below is
        # against the string (it used to compare the dict itself and was never true).
        upsample_args = {'mode': resize_mode}
        if resize_mode == 'bilinear': upsample_args['align_corners'] = False
        eff_net = gen_efficientnet(ifnone(cfn_c_out, c_out), c_in, act_fn=act_fn, **kwargs)
        #gen_efficientnet.efficientnet_b0(in_chans=c_in, act_fn=act_fn, drop_connect_rate=drop_rate)
        self.enc = EfficientUnetEncoder(eff_net)
        if with_cfn:
            self.classifier = EfficientUnetClassifier(self.enc, eff_net, drop_rate=ifnone(cfn_drop_rate, drop_rate))
        # Bridge and decoder
        self.resize = nn.Upsample(scale_factor=2, **upsample_args)
        res_args = dict(identity=identity, act_fn=res_act_fn, leaky=leaky)
        # Walk the stages from last down to index 2; for each stride-2 stage, take the
        # channels of the stage before it (the channel count of the skip hooked at its input).
        skip_nfs = [c for s,c in zip(self.enc.blk_strides[-1:1:-1], self.enc.blk_channels[-2::-1])
                    if s == 2]
        self.bridge = ResBlock(self.enc.blk_channels[-1], skip_nfs[0], ks=3, stride=1, **res_args)
        dec_nfs = skip_nfs[1:] + [self.enc.blk_channels[0]]
        #print(f"skip_nfs: {skip_nfs}; dec_nfs: {dec_nfs}")
        self.dec = nn.ModuleList([
            # Input also includes concatenated skip connection of size nf
            ResBlock(ni*2, nf, **res_args)
            for ni,nf in zip(skip_nfs,dec_nfs) ])
        self.final_conv = conv2d(dec_nfs[-1], c_out, ks=3, padding=1)
        self.final_act = final_act


    def forward(self, inp):
        """Return the segmentation output for `inp` (`final_act` applied when set)."""
        x,skips = self.enc(inp, hook=True)
        x = self.bridge(x)
        # Deepest skip first: upsample, concatenate the skip, then decode
        for dec_blk,skip in zip(self.dec,reversed(skips)):
            x = self.resize(x)
            x = torch.cat((x, skip), dim=1) # Concatenate channels
            x = dec_blk(x)
        x = self.resize(x)
        x = self.final_conv(x)
        if self.final_act: x = self.final_act(x)
        return x

    def load_classifier_state(self, state_dict):
        """Initialise the encoder from a classifier checkpoint (plain or wrapped under 'model')."""
        if 'model' in state_dict: state_dict = state_dict['model']
        return self.enc.load_state(state_dict)

In [None]:
#https://www.kaggle.com/paulorzp/rle-functions-run-lenght-encode-decode
def mask2rle(img):
    '''
    Run-length encode a binary mask.

    img: numpy array, 1 - mask, 0 - background
    Returns the runs as a space-separated string of "start length" pairs, with
    1-based starts counted down the columns first (the array is transposed before
    flattening). A mask with no foreground gives an empty string.
    '''
    column_major = img.T.flatten()
    # Zero sentinels on both ends so runs touching the borders still produce an edge
    padded = np.concatenate(([0], column_major, [0]))
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Edges alternate run-start / run-end; turn each end into a length
    edges[1::2] = edges[1::2] - edges[::2]
    return ' '.join(map(str, edges))

In [None]:
# Load model
# Segmentation-only network (no classifier head), 4 output channels - one per entry in CLASSES
mdl = EfficientUResnet(c_out=4, act_fn=mish_fn, res_act_fn=Mish, drop_rate=DROP_RATE, with_cfn=False)
# Deserialize onto the CPU; the model is moved to the GPU after the weights are loaded
state = torch.load(STATE/'seg.pth', map_location='cpu')
if 'model' in state: state = state['model'] # Handle both fastai saves and torch saves
# `res` is the key report returned by load_state_dict; it is displayed as the cell output
res = mdl.load_state_dict(state)
# Inference only: switch to eval mode
mdl = mdl.cuda().eval()
res

In [None]:
# Get test data
# The sample submission supplies the `ImageId_ClassId` values from which the test image list is derived
df_samp_subm = pd.read_csv(DATA/'sample_submission.csv')
df_samp_subm.head()

In [None]:
# Drop the last two characters of each id (the class suffix) and keep each image name once
# NOTE(review): the chained assignment also binds a global `ImageId`; nothing visible in this
# notebook reads it - confirm it is unused before removing.
images = ImageId=df_samp_subm.ImageId_ClassId.str.slice(0, -2).unique()
df = pd.DataFrame.from_dict({'ImageId':images})
df.head()

In [None]:
# Built from an empty item list: this databunch only supplies batching (bs=BS) and
# normalization (STATS) for the test set that is attached right after
data = (SegmentationItemList([], convert_mode='L', ignore_empty=True).split_none().label_empty()
         .databunch(bs=BS, num_workers=2)
         .normalize(STATS))
# Test images come from DATA/test_images, opened with convert_mode='L'
data.add_test(SegmentationItemList.from_df(df, path=DATA, folder='test_images', cols='ImageId', convert_mode='L'))
data

## Collect Predictions

In [None]:
# Per-class thresholds reshaped to 1 x C x 1 x 1 so they broadcast over batch, height and width
thresh = torch.tensor(THRESHOLDS, dtype=torch.float32, device='cuda:0')[None,:,None,None]
# One weight per stacked prediction, shaped 5 x 1 x 1 x 1 x 1 to broadcast over the stack
tta_wgts = TTA_WEIGHTS.view(5, *([1]*4)).cuda()
# Image ids grouped into rows of BS. `len(images) % -BS` is the (non-positive) shortfall to the
# next multiple of BS, so its absolute value is the number of '' entries padded onto the end.
batch_images = np.pad(images, (0,np.abs(len(images) % -BS)), 'constant', constant_values='').reshape(-1, BS)

def get_predictions(xb:Tensor):
    """Predict thresholded masks for a batch using whole-image and flipped block passes.

    Five sigmoid maps are combined: one from the whole input, and four from the input cut
    into 4 blocks along its last axis (un-flipped plus three flip variants, each flipped
    back after prediction). The maps are weighted by `tta_wgts`, summed and compared with
    `thresh`. Returns a uint8 tensor on the CPU.
    """
    n_blks = 4
    # Whole-image prediction, no TTA
    with torch.no_grad():
        weighted_maps = [mdl(xb).sigmoid()]

    # Split the last axis into blocks and fold them into the batch axis:
    # BxCxHxW -> BxCxHxBlkxw -> BxBlkxCxHxw -> (BxBlk)xCxHxw
    tiles = xb.view(*xb.shape[:-1], n_blks, -1).permute(0, 3, 1, 2, 4)
    tiles = tiles.reshape(-1, *tiles.shape[2:])

    # (dim flipped on the running input, dims flipped back on the output). Input flips
    # accumulate, so the visited states are: none, dim 2, dims 2+3, dim 3.
    # With an NCHW layout, dim 2 is height and dim 3 is width.
    flip_schedule = [(None, None), (2, 2), (3, (2, 3)), (2, 3)]
    for flip_in, flip_back in flip_schedule:
        if flip_in is not None:
            tiles = tiles.flip(flip_in)
        # Predict - (BSxBlk)xCxHxW
        with torch.no_grad():
            tile_out = mdl(tiles)
        if flip_back is not None:
            tile_out = tile_out.flip(flip_back)
        # Unfold the blocks from the batch axis: BxBlkxCxHxw -> BxCxHxBlkxw
        tile_out = tile_out.view(-1, n_blks, *tile_out.shape[1:]).permute(0, 2, 3, 1, 4)
        # Merge block and width axes back into BxCxHxW
        tile_out = tile_out.reshape(*tile_out.shape[:-2], -1)
        weighted_maps.append(tile_out.sigmoid())

    # Multiply by weights and sum, then threshold
    combined = (torch.stack(weighted_maps) * tta_wgts).sum(0)
    return (combined > thresh).to(dtype=torch.uint8, device='cpu')

In [None]:
#it = iter(zip(data.test_dl,batch_images))
# (xb,_),imgs = next(it)
# xb.shape, xb.device, xb.dtype, imgs

In [None]:
# Build one [ImageId_ClassId, RLE] row per image and class.
# Each class mask is split into connected components and only components with more than
# MIN_PXS pixels are kept. The cleaned mask takes its shape from the prediction itself
# (no hard-coded 256x1600), and component sizes are counted in a single pass with
# np.bincount instead of one full-size boolean mask per component.
predictions = []
for (xb,_),imgs in zip(progress_bar(data.test_dl),batch_images):
    preds = get_predictions(xb)
    # zip stops at the shorter side, so the '' padding ids of the last batch are never used
    for img,img_pred in zip(imgs,preds):
        for c,pred in zip(CLASSES,img_pred):
            mask = pred.to(dtype=torch.uint8, device='cpu').numpy()
            n_comp, comps = cv2.connectedComponents(mask)
            # Pixel count per component label
            comp_sizes = np.bincount(comps.ravel(), minlength=n_comp)
            keep = comp_sizes > MIN_PXS
            keep[0] = False  # label 0 is the background
            cleaned = keep[comps].astype(np.uint8)
            predictions.append([f"{img}_{c}", mask2rle(cleaned)])

In [None]:
# Assemble the collected rows and write the submission file without the index column
df_subm = pd.DataFrame(predictions, columns=['ImageId_ClassId', 'EncodedPixels'])
df_subm.to_csv("submission.csv", index=False)

In [None]:
# Quick look at the first submission rows
df_subm.head()

In [None]:
# Number of image/class rows with a non-empty mask
(df_subm.EncodedPixels!="").sum()