![MLComp](https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master/pics/MLcomp.png)
![Catalyst](https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master/pics/catalyst_logo.png)

This kernel demonstrates:

1. Results of the models trained with [the training kernel](https://www.kaggle.com/lightforever/severstal-mlcomp-catalyst-train-0-90672-offline), which achieve a 0.90672 score on the public LB

2. Useful code in the MLComp library: TtaWrap, ImageDataset, ChannelTranspose, and RLE utilities

3. Output statistics and basic visualization
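
The RLE utilities convert between binary masks and the run-length strings used in the submission file. A minimal NumPy sketch of that encoding follows, assuming the competition's convention of 1-indexed pixels counted down columns first and written as "start length" pairs. The names `mask_to_rle` and `rle_to_mask` are illustrative only; they are not MLComp's `mask2rle`/`rle2mask`, whose exact signatures may differ.

```python
import numpy as np


def mask_to_rle(mask):
    """Encode a binary (H, W) mask as 'start length' pairs (column-major, 1-indexed)."""
    pixels = mask.T.flatten()  # transpose so that flattening walks down each column
    padded = np.concatenate([[0], pixels, [0]])
    # Indices where the value changes mark run starts and run ends (1-indexed).
    changes = np.where(padded[1:] != padded[:-1])[0] + 1
    starts, ends = changes[0::2], changes[1::2]
    return " ".join(f"{s} {e - s}" for s, e in zip(starts, ends))


def rle_to_mask(rle, shape=(256, 1600)):
    """Decode an RLE string into a binary mask; the default shape is an assumption."""
    height, width = shape
    flat = np.zeros(height * width, dtype=np.uint8)
    numbers = [int(x) for x in rle.split()]
    for start, length in zip(numbers[0::2], numbers[1::2]):
        flat[start - 1:start - 1 + length] = 1
    # Undo the column-major flattening.
    return flat.reshape(width, height).T
```

An empty mask encodes to an empty string, which decodes back to an all-zero mask.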

Approach description:

1. Segmentation via 3 Unet networks, whose predictions are averaged.

2. Thresholding and removing small areas. This method gives 0.90672 on the public LB.
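
The two steps above can be sketched as follows. This is a simplified illustration under stated assumptions, not the kernel's actual post-processing: the function names, the threshold of 0.5, the `min_area` value, and the use of 4-connectivity are all chosen for the example.

```python
from collections import deque

import numpy as np


def average_predictions(prob_maps):
    """Average a list of same-shaped probability maps, one per model."""
    return np.mean(np.stack(prob_maps, axis=0), axis=0)


def remove_small_areas(mask, min_area):
    """Keep only 4-connected components with at least min_area pixels."""
    height, width = mask.shape
    visited = np.zeros(mask.shape, dtype=bool)
    cleaned = np.zeros_like(mask)
    for row in range(height):
        for col in range(width):
            if not mask[row, col] or visited[row, col]:
                continue
            # Breadth-first search collects one connected component.
            component = [(row, col)]
            queue = deque(component)
            visited[row, col] = True
            while queue:
                r, c = queue.popleft()
                for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
                    inside = 0 <= nr < height and 0 <= nc < width
                    if inside and mask[nr, nc] and not visited[nr, nc]:
                        visited[nr, nc] = True
                        queue.append((nr, nc))
                        component.append((nr, nc))
            if len(component) >= min_area:
                for r, c in component:
                    cleaned[r, c] = 1
    return cleaned


def postprocess(prob_maps, threshold=0.5, min_area=3):
    """Average model outputs, binarize, then drop small components."""
    mean_prob = average_predictions(prob_maps)
    binary = (mean_prob > threshold).astype(np.uint8)
    return remove_small_areas(binary, min_area)
```

The pure-Python search is only meant to make the logic explicit; on full-size images a library routine such as OpenCV's `cv2.connectedComponents` would be far faster.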

**Improving**:

1. As many participants have noticed, the key is to remove false positives from your predictions.

2. To cope with that, a classification network may be used. 

3. Heng CherKeng posted a classifier (resnet34_cls_01) here: https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/106462#latest-634450 **If you remove false positives with it, you should get 0.9117 on the LB.**
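
One way to combine the two models is to clear every segmentation mask whose class the classifier considers absent. The sketch below assumes per-image masks of shape `(num_classes, H, W)` and one classifier probability per class; `suppress_false_positives` and its default threshold are hypothetical and not part of MLComp or Heng's code.

```python
import numpy as np


def suppress_false_positives(masks, class_probs, threshold=0.5):
    """Zero out the mask of every class whose classifier probability is too low.

    masks:       (num_classes, H, W) binary array from the segmentation models
    class_probs: (num_classes,) probabilities from the classification model
    """
    keep = np.asarray(class_probs) > threshold
    # Broadcast the per-class decision over the spatial dimensions.
    return masks * keep[:, None, None].astype(masks.dtype)
```

With probabilities `[0.9, 0.2, 0.6, 0.4]` and a threshold of 0.5, only the masks for the first and third classes survive.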

About the libraries:

1. [MLComp](https://github.com/catalyst-team/mlcomp) is a distributed DAG (Directed Acyclic Graph) framework for machine learning with a UI. It helps to train, manipulate, and visualize. All models in this kernel were trained offline via the MLComp + Catalyst libraries.

You can control the execution process via the web site:

Dags
![Dags](https://github.com/catalyst-team/mlcomp/blob/master/docs/imgs/dags.png?raw=true)

Computers
![Computers](https://github.com/catalyst-team/mlcomp/blob/master/docs/imgs/computers.png?raw=true)

Reports
![Reports](https://github.com/catalyst-team/mlcomp/blob/master/docs/imgs/reports.png?raw=true)

Code
![Code](https://github.com/catalyst-team/mlcomp/blob/master/docs/imgs/code.png?raw=true)

Please visit [the web site](https://github.com/catalyst-team/mlcomp) for details.

https://github.com/catalyst-team/mlcomp

2. Catalyst: High-level utils for PyTorch DL & RL research. It was developed with a focus on reproducibility, fast experimentation, and code/idea reuse, so that you can research and develop something new rather than write another regular train loop. Break the cycle - use the Catalyst!

https://github.com/catalyst-team/catalyst

Docs and examples
- Detailed [classification tutorial](https://github.com/catalyst-team/catalyst/blob/master/examples/notebooks/classification-tutorial.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/classification-tutorial.ipynb)
- Comprehensive [classification pipeline](https://github.com/catalyst-team/classification).

API documentation and an overview of the library can be found here
[![Docs](https://img.shields.io/badge/dynamic/json.svg?label=docs&url=https%3A%2F%2Fpypi.org%2Fpypi%2Fcatalyst%2Fjson&query=%24.info.version&colorB=brightgreen&prefix=v)](https://catalyst-team.github.io/catalyst/index.html)

### Install MLComp library (offline version):

Because the competition does not allow committing a kernel that uses an internet connection, we install the libraries offline.

### Get Heng's classification model

In [None]:
import time
time_start = time.time()

In [None]:
!pip install ../input/pretrainedmodels/pretrainedmodels-0.7.4/pretrainedmodels-0.7.4/ > /dev/null

In [None]:
! python ../input/mlcomp/mlcomp/mlcomp/setup.py

In [None]:
# ! pip install ../input/pytorch-toolbelt/pytorch-toolbelt-develop

In [None]:
package_path = '../input/senetunetmodelcode' # add unet script dataset
import sys
sys.path.append(package_path)

# Get necessary Imports
import pdb
import os
import cv2
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, Dataset
from albumentations import (Normalize, Compose)
from albumentations.pytorch import ToTensor
import torch.utils.data as data
import torchvision.models as models
import torch.nn as nn
from torch.nn import functional as F
from senet_unet_model_code import Unet

##########################################

import matplotlib.pyplot as plt

import albumentations as A
from tqdm import tqdm_notebook
from torch.jit import load

from mlcomp.contrib.transform.albumentations import ChannelTranspose
from mlcomp.contrib.dataset.classify import ImageDataset
from mlcomp.contrib.transform.rle import rle2mask, mask2rle
from mlcomp.contrib.transform.tta import TtaWrap

###########################################

# from pathlib import Path

# from pytorch_toolbelt.inference import tta
# from pytorch_toolbelt.utils.rle import rle_encode
# from pytorch_toolbelt.utils.rle import rle_decode
# from pytorch_toolbelt.utils.rle import rle_to_string
# from pytorch_toolbelt.inference.functional import pad_image_tensor, unpad_image_tensor
# from pytorch_toolbelt.modules import decoders as D
# from pytorch_toolbelt.modules import encoders as E
# from pytorch_toolbelt.modules.fpn import *

# from models import PretrainedUNet

###########################################
import fastai
from fastai.vision import *
from PIL import Image
import zipfile
import io

import gc
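
The `TtaWrap` import above relates to test-time augmentation (TTA). Its API is not shown in this cell, so the sketch below only illustrates the general idea with the horizontal and vertical flips that also appear in the classification inference code of this kernel. `predict_with_flip_tta` and `predict_fn` are hypothetical names, and NumPy arrays stand in for tensors.

```python
import numpy as np


def predict_with_flip_tta(predict_fn, image):
    """Average predictions over the original image and its two flips.

    image:      (C, H, W) array
    predict_fn: callable mapping a (C, H, W) image to (K, H, W) probability maps
    """
    preds = [predict_fn(image)]
    # Horizontal flip: flip the width axis, predict, then flip the result back.
    preds.append(predict_fn(image[:, :, ::-1])[:, :, ::-1])
    # Vertical flip: same idea along the height axis.
    preds.append(predict_fn(image[:, ::-1, :])[:, ::-1, :])
    return np.mean(preds, axis=0)
```

For segmentation the prediction has to be flipped back before averaging so that the pixels line up; for image-level classification outputs, as in the classifier below, no inverse step is needed.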

In [None]:
def load_hengs_clf_model():
    # Code from Heng's baseline
    # This code is for the classification model

    BatchNorm2d = nn.BatchNorm2d

    ###############################################################################
    CONVERSION=[
     'block0.0.weight',	(64, 3, 7, 7),	 'conv1.weight',	(64, 3, 7, 7),
     'block0.1.weight',	(64,),	 'bn1.weight',	(64,),
     'block0.1.bias',	(64,),	 'bn1.bias',	(64,),
     'block0.1.running_mean',	(64,),	 'bn1.running_mean',	(64,),
     'block0.1.running_var',	(64,),	 'bn1.running_var',	(64,),
     'block1.1.conv_bn1.conv.weight',	(64, 64, 3, 3),	 'layer1.0.conv1.weight',	(64, 64, 3, 3),
     'block1.1.conv_bn1.bn.weight',	(64,),	 'layer1.0.bn1.weight',	(64,),
     'block1.1.conv_bn1.bn.bias',	(64,),	 'layer1.0.bn1.bias',	(64,),
     'block1.1.conv_bn1.bn.running_mean',	(64,),	 'layer1.0.bn1.running_mean',	(64,),
     'block1.1.conv_bn1.bn.running_var',	(64,),	 'layer1.0.bn1.running_var',	(64,),
     'block1.1.conv_bn2.conv.weight',	(64, 64, 3, 3),	 'layer1.0.conv2.weight',	(64, 64, 3, 3),
     'block1.1.conv_bn2.bn.weight',	(64,),	 'layer1.0.bn2.weight',	(64,),
     'block1.1.conv_bn2.bn.bias',	(64,),	 'layer1.0.bn2.bias',	(64,),
     'block1.1.conv_bn2.bn.running_mean',	(64,),	 'layer1.0.bn2.running_mean',	(64,),
     'block1.1.conv_bn2.bn.running_var',	(64,),	 'layer1.0.bn2.running_var',	(64,),
     'block1.2.conv_bn1.conv.weight',	(64, 64, 3, 3),	 'layer1.1.conv1.weight',	(64, 64, 3, 3),
     'block1.2.conv_bn1.bn.weight',	(64,),	 'layer1.1.bn1.weight',	(64,),
     'block1.2.conv_bn1.bn.bias',	(64,),	 'layer1.1.bn1.bias',	(64,),
     'block1.2.conv_bn1.bn.running_mean',	(64,),	 'layer1.1.bn1.running_mean',	(64,),
     'block1.2.conv_bn1.bn.running_var',	(64,),	 'layer1.1.bn1.running_var',	(64,),
     'block1.2.conv_bn2.conv.weight',	(64, 64, 3, 3),	 'layer1.1.conv2.weight',	(64, 64, 3, 3),
     'block1.2.conv_bn2.bn.weight',	(64,),	 'layer1.1.bn2.weight',	(64,),
     'block1.2.conv_bn2.bn.bias',	(64,),	 'layer1.1.bn2.bias',	(64,),
     'block1.2.conv_bn2.bn.running_mean',	(64,),	 'layer1.1.bn2.running_mean',	(64,),
     'block1.2.conv_bn2.bn.running_var',	(64,),	 'layer1.1.bn2.running_var',	(64,),
     'block1.3.conv_bn1.conv.weight',	(64, 64, 3, 3),	 'layer1.2.conv1.weight',	(64, 64, 3, 3),
     'block1.3.conv_bn1.bn.weight',	(64,),	 'layer1.2.bn1.weight',	(64,),
     'block1.3.conv_bn1.bn.bias',	(64,),	 'layer1.2.bn1.bias',	(64,),
     'block1.3.conv_bn1.bn.running_mean',	(64,),	 'layer1.2.bn1.running_mean',	(64,),
     'block1.3.conv_bn1.bn.running_var',	(64,),	 'layer1.2.bn1.running_var',	(64,),
     'block1.3.conv_bn2.conv.weight',	(64, 64, 3, 3),	 'layer1.2.conv2.weight',	(64, 64, 3, 3),
     'block1.3.conv_bn2.bn.weight',	(64,),	 'layer1.2.bn2.weight',	(64,),
     'block1.3.conv_bn2.bn.bias',	(64,),	 'layer1.2.bn2.bias',	(64,),
     'block1.3.conv_bn2.bn.running_mean',	(64,),	 'layer1.2.bn2.running_mean',	(64,),
     'block1.3.conv_bn2.bn.running_var',	(64,),	 'layer1.2.bn2.running_var',	(64,),
     'block2.0.conv_bn1.conv.weight',	(128, 64, 3, 3),	 'layer2.0.conv1.weight',	(128, 64, 3, 3),
     'block2.0.conv_bn1.bn.weight',	(128,),	 'layer2.0.bn1.weight',	(128,),
     'block2.0.conv_bn1.bn.bias',	(128,),	 'layer2.0.bn1.bias',	(128,),
     'block2.0.conv_bn1.bn.running_mean',	(128,),	 'layer2.0.bn1.running_mean',	(128,),
     'block2.0.conv_bn1.bn.running_var',	(128,),	 'layer2.0.bn1.running_var',	(128,),
     'block2.0.conv_bn2.conv.weight',	(128, 128, 3, 3),	 'layer2.0.conv2.weight',	(128, 128, 3, 3),
     'block2.0.conv_bn2.bn.weight',	(128,),	 'layer2.0.bn2.weight',	(128,),
     'block2.0.conv_bn2.bn.bias',	(128,),	 'layer2.0.bn2.bias',	(128,),
     'block2.0.conv_bn2.bn.running_mean',	(128,),	 'layer2.0.bn2.running_mean',	(128,),
     'block2.0.conv_bn2.bn.running_var',	(128,),	 'layer2.0.bn2.running_var',	(128,),
     'block2.0.shortcut.conv.weight',	(128, 64, 1, 1),	 'layer2.0.downsample.0.weight',	(128, 64, 1, 1),
     'block2.0.shortcut.bn.weight',	(128,),	 'layer2.0.downsample.1.weight',	(128,),
     'block2.0.shortcut.bn.bias',	(128,),	 'layer2.0.downsample.1.bias',	(128,),
     'block2.0.shortcut.bn.running_mean',	(128,),	 'layer2.0.downsample.1.running_mean',	(128,),
     'block2.0.shortcut.bn.running_var',	(128,),	 'layer2.0.downsample.1.running_var',	(128,),
     'block2.1.conv_bn1.conv.weight',	(128, 128, 3, 3),	 'layer2.1.conv1.weight',	(128, 128, 3, 3),
     'block2.1.conv_bn1.bn.weight',	(128,),	 'layer2.1.bn1.weight',	(128,),
     'block2.1.conv_bn1.bn.bias',	(128,),	 'layer2.1.bn1.bias',	(128,),
     'block2.1.conv_bn1.bn.running_mean',	(128,),	 'layer2.1.bn1.running_mean',	(128,),
     'block2.1.conv_bn1.bn.running_var',	(128,),	 'layer2.1.bn1.running_var',	(128,),
     'block2.1.conv_bn2.conv.weight',	(128, 128, 3, 3),	 'layer2.1.conv2.weight',	(128, 128, 3, 3),
     'block2.1.conv_bn2.bn.weight',	(128,),	 'layer2.1.bn2.weight',	(128,),
     'block2.1.conv_bn2.bn.bias',	(128,),	 'layer2.1.bn2.bias',	(128,),
     'block2.1.conv_bn2.bn.running_mean',	(128,),	 'layer2.1.bn2.running_mean',	(128,),
     'block2.1.conv_bn2.bn.running_var',	(128,),	 'layer2.1.bn2.running_var',	(128,),
     'block2.2.conv_bn1.conv.weight',	(128, 128, 3, 3),	 'layer2.2.conv1.weight',	(128, 128, 3, 3),
     'block2.2.conv_bn1.bn.weight',	(128,),	 'layer2.2.bn1.weight',	(128,),
     'block2.2.conv_bn1.bn.bias',	(128,),	 'layer2.2.bn1.bias',	(128,),
     'block2.2.conv_bn1.bn.running_mean',	(128,),	 'layer2.2.bn1.running_mean',	(128,),
     'block2.2.conv_bn1.bn.running_var',	(128,),	 'layer2.2.bn1.running_var',	(128,),
     'block2.2.conv_bn2.conv.weight',	(128, 128, 3, 3),	 'layer2.2.conv2.weight',	(128, 128, 3, 3),
     'block2.2.conv_bn2.bn.weight',	(128,),	 'layer2.2.bn2.weight',	(128,),
     'block2.2.conv_bn2.bn.bias',	(128,),	 'layer2.2.bn2.bias',	(128,),
     'block2.2.conv_bn2.bn.running_mean',	(128,),	 'layer2.2.bn2.running_mean',	(128,),
     'block2.2.conv_bn2.bn.running_var',	(128,),	 'layer2.2.bn2.running_var',	(128,),
     'block2.3.conv_bn1.conv.weight',	(128, 128, 3, 3),	 'layer2.3.conv1.weight',	(128, 128, 3, 3),
     'block2.3.conv_bn1.bn.weight',	(128,),	 'layer2.3.bn1.weight',	(128,),
     'block2.3.conv_bn1.bn.bias',	(128,),	 'layer2.3.bn1.bias',	(128,),
     'block2.3.conv_bn1.bn.running_mean',	(128,),	 'layer2.3.bn1.running_mean',	(128,),
     'block2.3.conv_bn1.bn.running_var',	(128,),	 'layer2.3.bn1.running_var',	(128,),
     'block2.3.conv_bn2.conv.weight',	(128, 128, 3, 3),	 'layer2.3.conv2.weight',	(128, 128, 3, 3),
     'block2.3.conv_bn2.bn.weight',	(128,),	 'layer2.3.bn2.weight',	(128,),
     'block2.3.conv_bn2.bn.bias',	(128,),	 'layer2.3.bn2.bias',	(128,),
     'block2.3.conv_bn2.bn.running_mean',	(128,),	 'layer2.3.bn2.running_mean',	(128,),
     'block2.3.conv_bn2.bn.running_var',	(128,),	 'layer2.3.bn2.running_var',	(128,),
     'block3.0.conv_bn1.conv.weight',	(256, 128, 3, 3),	 'layer3.0.conv1.weight',	(256, 128, 3, 3),
     'block3.0.conv_bn1.bn.weight',	(256,),	 'layer3.0.bn1.weight',	(256,),
     'block3.0.conv_bn1.bn.bias',	(256,),	 'layer3.0.bn1.bias',	(256,),
     'block3.0.conv_bn1.bn.running_mean',	(256,),	 'layer3.0.bn1.running_mean',	(256,),
     'block3.0.conv_bn1.bn.running_var',	(256,),	 'layer3.0.bn1.running_var',	(256,),
     'block3.0.conv_bn2.conv.weight',	(256, 256, 3, 3),	 'layer3.0.conv2.weight',	(256, 256, 3, 3),
     'block3.0.conv_bn2.bn.weight',	(256,),	 'layer3.0.bn2.weight',	(256,),
     'block3.0.conv_bn2.bn.bias',	(256,),	 'layer3.0.bn2.bias',	(256,),
     'block3.0.conv_bn2.bn.running_mean',	(256,),	 'layer3.0.bn2.running_mean',	(256,),
     'block3.0.conv_bn2.bn.running_var',	(256,),	 'layer3.0.bn2.running_var',	(256,),
     'block3.0.shortcut.conv.weight',	(256, 128, 1, 1),	 'layer3.0.downsample.0.weight',	(256, 128, 1, 1),
     'block3.0.shortcut.bn.weight',	(256,),	 'layer3.0.downsample.1.weight',	(256,),
     'block3.0.shortcut.bn.bias',	(256,),	 'layer3.0.downsample.1.bias',	(256,),
     'block3.0.shortcut.bn.running_mean',	(256,),	 'layer3.0.downsample.1.running_mean',	(256,),
     'block3.0.shortcut.bn.running_var',	(256,),	 'layer3.0.downsample.1.running_var',	(256,),
     'block3.1.conv_bn1.conv.weight',	(256, 256, 3, 3),	 'layer3.1.conv1.weight',	(256, 256, 3, 3),
     'block3.1.conv_bn1.bn.weight',	(256,),	 'layer3.1.bn1.weight',	(256,),
     'block3.1.conv_bn1.bn.bias',	(256,),	 'layer3.1.bn1.bias',	(256,),
     'block3.1.conv_bn1.bn.running_mean',	(256,),	 'layer3.1.bn1.running_mean',	(256,),
     'block3.1.conv_bn1.bn.running_var',	(256,),	 'layer3.1.bn1.running_var',	(256,),
     'block3.1.conv_bn2.conv.weight',	(256, 256, 3, 3),	 'layer3.1.conv2.weight',	(256, 256, 3, 3),
     'block3.1.conv_bn2.bn.weight',	(256,),	 'layer3.1.bn2.weight',	(256,),
     'block3.1.conv_bn2.bn.bias',	(256,),	 'layer3.1.bn2.bias',	(256,),
     'block3.1.conv_bn2.bn.running_mean',	(256,),	 'layer3.1.bn2.running_mean',	(256,),
     'block3.1.conv_bn2.bn.running_var',	(256,),	 'layer3.1.bn2.running_var',	(256,),
     'block3.2.conv_bn1.conv.weight',	(256, 256, 3, 3),	 'layer3.2.conv1.weight',	(256, 256, 3, 3),
     'block3.2.conv_bn1.bn.weight',	(256,),	 'layer3.2.bn1.weight',	(256,),
     'block3.2.conv_bn1.bn.bias',	(256,),	 'layer3.2.bn1.bias',	(256,),
     'block3.2.conv_bn1.bn.running_mean',	(256,),	 'layer3.2.bn1.running_mean',	(256,),
     'block3.2.conv_bn1.bn.running_var',	(256,),	 'layer3.2.bn1.running_var',	(256,),
     'block3.2.conv_bn2.conv.weight',	(256, 256, 3, 3),	 'layer3.2.conv2.weight',	(256, 256, 3, 3),
     'block3.2.conv_bn2.bn.weight',	(256,),	 'layer3.2.bn2.weight',	(256,),
     'block3.2.conv_bn2.bn.bias',	(256,),	 'layer3.2.bn2.bias',	(256,),
     'block3.2.conv_bn2.bn.running_mean',	(256,),	 'layer3.2.bn2.running_mean',	(256,),
     'block3.2.conv_bn2.bn.running_var',	(256,),	 'layer3.2.bn2.running_var',	(256,),
     'block3.3.conv_bn1.conv.weight',	(256, 256, 3, 3),	 'layer3.3.conv1.weight',	(256, 256, 3, 3),
     'block3.3.conv_bn1.bn.weight',	(256,),	 'layer3.3.bn1.weight',	(256,),
     'block3.3.conv_bn1.bn.bias',	(256,),	 'layer3.3.bn1.bias',	(256,),
     'block3.3.conv_bn1.bn.running_mean',	(256,),	 'layer3.3.bn1.running_mean',	(256,),
     'block3.3.conv_bn1.bn.running_var',	(256,),	 'layer3.3.bn1.running_var',	(256,),
     'block3.3.conv_bn2.conv.weight',	(256, 256, 3, 3),	 'layer3.3.conv2.weight',	(256, 256, 3, 3),
     'block3.3.conv_bn2.bn.weight',	(256,),	 'layer3.3.bn2.weight',	(256,),
     'block3.3.conv_bn2.bn.bias',	(256,),	 'layer3.3.bn2.bias',	(256,),
     'block3.3.conv_bn2.bn.running_mean',	(256,),	 'layer3.3.bn2.running_mean',	(256,),
     'block3.3.conv_bn2.bn.running_var',	(256,),	 'layer3.3.bn2.running_var',	(256,),
     'block3.4.conv_bn1.conv.weight',	(256, 256, 3, 3),	 'layer3.4.conv1.weight',	(256, 256, 3, 3),
     'block3.4.conv_bn1.bn.weight',	(256,),	 'layer3.4.bn1.weight',	(256,),
     'block3.4.conv_bn1.bn.bias',	(256,),	 'layer3.4.bn1.bias',	(256,),
     'block3.4.conv_bn1.bn.running_mean',	(256,),	 'layer3.4.bn1.running_mean',	(256,),
     'block3.4.conv_bn1.bn.running_var',	(256,),	 'layer3.4.bn1.running_var',	(256,),
     'block3.4.conv_bn2.conv.weight',	(256, 256, 3, 3),	 'layer3.4.conv2.weight',	(256, 256, 3, 3),
     'block3.4.conv_bn2.bn.weight',	(256,),	 'layer3.4.bn2.weight',	(256,),
     'block3.4.conv_bn2.bn.bias',	(256,),	 'layer3.4.bn2.bias',	(256,),
     'block3.4.conv_bn2.bn.running_mean',	(256,),	 'layer3.4.bn2.running_mean',	(256,),
     'block3.4.conv_bn2.bn.running_var',	(256,),	 'layer3.4.bn2.running_var',	(256,),
     'block3.5.conv_bn1.conv.weight',	(256, 256, 3, 3),	 'layer3.5.conv1.weight',	(256, 256, 3, 3),
     'block3.5.conv_bn1.bn.weight',	(256,),	 'layer3.5.bn1.weight',	(256,),
     'block3.5.conv_bn1.bn.bias',	(256,),	 'layer3.5.bn1.bias',	(256,),
     'block3.5.conv_bn1.bn.running_mean',	(256,),	 'layer3.5.bn1.running_mean',	(256,),
     'block3.5.conv_bn1.bn.running_var',	(256,),	 'layer3.5.bn1.running_var',	(256,),
     'block3.5.conv_bn2.conv.weight',	(256, 256, 3, 3),	 'layer3.5.conv2.weight',	(256, 256, 3, 3),
     'block3.5.conv_bn2.bn.weight',	(256,),	 'layer3.5.bn2.weight',	(256,),
     'block3.5.conv_bn2.bn.bias',	(256,),	 'layer3.5.bn2.bias',	(256,),
     'block3.5.conv_bn2.bn.running_mean',	(256,),	 'layer3.5.bn2.running_mean',	(256,),
     'block3.5.conv_bn2.bn.running_var',	(256,),	 'layer3.5.bn2.running_var',	(256,),
     'block4.0.conv_bn1.conv.weight',	(512, 256, 3, 3),	 'layer4.0.conv1.weight',	(512, 256, 3, 3),
     'block4.0.conv_bn1.bn.weight',	(512,),	 'layer4.0.bn1.weight',	(512,),
     'block4.0.conv_bn1.bn.bias',	(512,),	 'layer4.0.bn1.bias',	(512,),
     'block4.0.conv_bn1.bn.running_mean',	(512,),	 'layer4.0.bn1.running_mean',	(512,),
     'block4.0.conv_bn1.bn.running_var',	(512,),	 'layer4.0.bn1.running_var',	(512,),
     'block4.0.conv_bn2.conv.weight',	(512, 512, 3, 3),	 'layer4.0.conv2.weight',	(512, 512, 3, 3),
     'block4.0.conv_bn2.bn.weight',	(512,),	 'layer4.0.bn2.weight',	(512,),
     'block4.0.conv_bn2.bn.bias',	(512,),	 'layer4.0.bn2.bias',	(512,),
     'block4.0.conv_bn2.bn.running_mean',	(512,),	 'layer4.0.bn2.running_mean',	(512,),
     'block4.0.conv_bn2.bn.running_var',	(512,),	 'layer4.0.bn2.running_var',	(512,),
     'block4.0.shortcut.conv.weight',	(512, 256, 1, 1),	 'layer4.0.downsample.0.weight',	(512, 256, 1, 1),
     'block4.0.shortcut.bn.weight',	(512,),	 'layer4.0.downsample.1.weight',	(512,),
     'block4.0.shortcut.bn.bias',	(512,),	 'layer4.0.downsample.1.bias',	(512,),
     'block4.0.shortcut.bn.running_mean',	(512,),	 'layer4.0.downsample.1.running_mean',	(512,),
     'block4.0.shortcut.bn.running_var',	(512,),	 'layer4.0.downsample.1.running_var',	(512,),
     'block4.1.conv_bn1.conv.weight',	(512, 512, 3, 3),	 'layer4.1.conv1.weight',	(512, 512, 3, 3),
     'block4.1.conv_bn1.bn.weight',	(512,),	 'layer4.1.bn1.weight',	(512,),
     'block4.1.conv_bn1.bn.bias',	(512,),	 'layer4.1.bn1.bias',	(512,),
     'block4.1.conv_bn1.bn.running_mean',	(512,),	 'layer4.1.bn1.running_mean',	(512,),
     'block4.1.conv_bn1.bn.running_var',	(512,),	 'layer4.1.bn1.running_var',	(512,),
     'block4.1.conv_bn2.conv.weight',	(512, 512, 3, 3),	 'layer4.1.conv2.weight',	(512, 512, 3, 3),
     'block4.1.conv_bn2.bn.weight',	(512,),	 'layer4.1.bn2.weight',	(512,),
     'block4.1.conv_bn2.bn.bias',	(512,),	 'layer4.1.bn2.bias',	(512,),
     'block4.1.conv_bn2.bn.running_mean',	(512,),	 'layer4.1.bn2.running_mean',	(512,),
     'block4.1.conv_bn2.bn.running_var',	(512,),	 'layer4.1.bn2.running_var',	(512,),
     'block4.2.conv_bn1.conv.weight',	(512, 512, 3, 3),	 'layer4.2.conv1.weight',	(512, 512, 3, 3),
     'block4.2.conv_bn1.bn.weight',	(512,),	 'layer4.2.bn1.weight',	(512,),
     'block4.2.conv_bn1.bn.bias',	(512,),	 'layer4.2.bn1.bias',	(512,),
     'block4.2.conv_bn1.bn.running_mean',	(512,),	 'layer4.2.bn1.running_mean',	(512,),
     'block4.2.conv_bn1.bn.running_var',	(512,),	 'layer4.2.bn1.running_var',	(512,),
     'block4.2.conv_bn2.conv.weight',	(512, 512, 3, 3),	 'layer4.2.conv2.weight',	(512, 512, 3, 3),
     'block4.2.conv_bn2.bn.weight',	(512,),	 'layer4.2.bn2.weight',	(512,),
     'block4.2.conv_bn2.bn.bias',	(512,),	 'layer4.2.bn2.bias',	(512,),
     'block4.2.conv_bn2.bn.running_mean',	(512,),	 'layer4.2.bn2.running_mean',	(512,),
     'block4.2.conv_bn2.bn.running_var',	(512,),	 'layer4.2.bn2.running_var',	(512,),
     'logit.weight',	(1000, 512),	 'fc.weight',	(1000, 512),
     'logit.bias',	(1000,),	 'fc.bias',	(1000,),

    ]

    ###############################################################################
    class ConvBn2d(nn.Module):

        def __init__(self, in_channel, out_channel, kernel_size=3, padding=1, stride=1):
            super(ConvBn2d, self).__init__()
            self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, padding=padding, stride=stride, bias=False)
            self.bn   = nn.BatchNorm2d(out_channel, eps=1e-5)

        def forward(self,x):
            x = self.conv(x)
            x = self.bn(x)
            return x




    #############  resnext50 pyramid feature net #######################################
    # https://github.com/Hsuxu/ResNeXt/blob/master/models.py
    # https://github.com/D-X-Y/ResNeXt-DenseNet/blob/master/models/resnext.py
    # https://github.com/miraclewkf/ResNeXt-PyTorch/blob/master/resnext.py


    # basic residual block: two 3x3 conv-bn layers plus an optional 1x1 conv-bn shortcut
    class BasicBlock(nn.Module):
        def __init__(self, in_channel, channel, out_channel, stride=1, is_shortcut=False):
            super(BasicBlock, self).__init__()
            self.is_shortcut = is_shortcut

            self.conv_bn1 = ConvBn2d(in_channel,    channel, kernel_size=3, padding=1, stride=stride)
            self.conv_bn2 = ConvBn2d(   channel,out_channel, kernel_size=3, padding=1, stride=1)

            if is_shortcut:
                self.shortcut = ConvBn2d(in_channel, out_channel, kernel_size=1, padding=0, stride=stride)


        def forward(self, x):
            z = F.relu(self.conv_bn1(x),inplace=True)
            z = self.conv_bn2(z)

            if self.is_shortcut:
                x = self.shortcut(x)

            z += x
            z = F.relu(z,inplace=True)
            return z


    #resnet18
    class ResNet18(nn.Module):

        def __init__(self, num_class=1000 ):
            super(ResNet18, self).__init__()


            self.block0  = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=7, padding=3, stride=2, bias=False),
                BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )

            self.block1  = nn.Sequential(
                 nn.MaxPool2d(kernel_size=3, padding=1, stride=2),
                 BasicBlock( 64, 64, 64, stride=1, is_shortcut=False,),
              * [BasicBlock( 64, 64, 64, stride=1, is_shortcut=False,) for i in range(1,2)],
            )
            self.block2  = nn.Sequential(
                 BasicBlock( 64,128,128, stride=2, is_shortcut=True, ),
              * [BasicBlock(128,128,128, stride=1, is_shortcut=False,) for i in range(1,2)],
            )
            self.block3  = nn.Sequential(
                 BasicBlock(128,256,256, stride=2, is_shortcut=True, ),
              * [BasicBlock(256,256,256, stride=1, is_shortcut=False,) for i in range(1,2)],
            )
            self.block4 = nn.Sequential(
                 BasicBlock(256,512,512, stride=2, is_shortcut=True, ),
              * [BasicBlock(512,512,512, stride=1, is_shortcut=False,) for i in range(1,2)],
            )
            self.logit = nn.Linear(512,num_class)



        def forward(self, x):
            batch_size = len(x)

            x = self.block0(x)
            x = F.max_pool2d(x,kernel_size=3, padding=1, stride=2, ceil_mode=False)

            x = self.block1(x)
            x = self.block2(x)
            x = self.block3(x)
            x = self.block4(x)
            x = F.adaptive_avg_pool2d(x,1).reshape(batch_size,-1)
            logit = self.logit(x)
            return logit


    ####################################################################################################
    def upsize(x,scale_factor=2):
        #x = F.interpolate(x, size=e.shape[2:], mode='nearest')
        x = F.interpolate(x, scale_factor=scale_factor, mode='nearest')
        return x

    class Swish(nn.Module):
        def forward(self, x):
            return x * torch.sigmoid(x)

    class Decode(nn.Module):
        def __init__(self, in_channel, out_channel):
            super(Decode, self).__init__()

            self.top = nn.Sequential(
                nn.Conv2d(in_channel, out_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
                BatchNorm2d( out_channel//2),
                nn.ReLU(inplace=True),
                #nn.Dropout(0.1),

                nn.Conv2d(out_channel//2, out_channel//2, kernel_size=3, stride=1, padding=1, bias=False),
                BatchNorm2d(out_channel//2),
                nn.ReLU(inplace=True),
                #nn.Dropout(0.1),

                nn.Conv2d(out_channel//2, out_channel, kernel_size=1, stride=1, padding=0, bias=False),
                BatchNorm2d(out_channel),
                nn.ReLU(inplace=True), #Swish(), #
            )

        def forward(self, x):
            x = self.top(torch.cat(x, 1))
            return x



    class Net(nn.Module):

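        def load_pretrain(self, skip, is_print=True):
            # NOTE: the module-level `load_pretrain` helper and `PRETRAIN_FILE` come from
            # Heng's original code base and are not defined in this function, so this
            # method only works if they are provided elsewhere.
            import copy  # local import so `copy.copy` resolves to the standard-library module
            conversion=copy.copy(CONVERSION)
            for i in range(0,len(conversion)-8,4):
                conversion[i] = 'block.' + conversion[i][5:]
            load_pretrain(self, skip, pretrain_file=PRETRAIN_FILE, conversion=conversion, is_print=is_print)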

        def __init__(self, num_class=5, drop_connect_rate=0.2):
            super(Net, self).__init__()

            e = ResNet18()
            self.block = nn.ModuleList([
               e.block0,
               e.block1,
               e.block2,
               e.block3,
               e.block4
            ])
            e = None  #dropped

            self.decode1 =  Decode(512,     128)
            self.decode2 =  Decode(128+256, 128)
            self.decode3 =  Decode(128+128, 128)
            self.decode4 =  Decode(128+ 64, 128)
            self.decode5 =  Decode(128+ 64, 128)
            self.logit = nn.Conv2d(128,num_class, kernel_size=1)

        def forward(self, x):
            batch_size,C,H,W = x.shape

            #----------------------------------
            backbone = []
            for i in range( len(self.block)):
                x = self.block[i](x)
                #print(i, x.shape)

                if i in [0,1,2,3,4]:
                    backbone.append(x)

            #----------------------------------
            x = self.decode1([backbone[-1], ])                   #; print('d1',d1.size())
            x = self.decode2([backbone[-2], upsize(x)])          #; print('d2',d2.size())
            x = self.decode3([backbone[-3], upsize(x)])          #; print('d3',d3.size())
            x = self.decode4([backbone[-4], upsize(x)])          #; print('d4',d4.size())
            x = self.decode5([backbone[-5], upsize(x)])          #; print('d5',d5.size())

            logit = self.logit(x)
            logit = F.interpolate(logit, size=(H,W), mode='bilinear', align_corners=False)
            return logit

    class ResNet34(nn.Module):

        def __init__(self, num_class=1000 ):
            super(ResNet34, self).__init__()


            self.block0  = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=7, padding=3, stride=2, bias=False),
                BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )
            self.block1  = nn.Sequential(
                 nn.MaxPool2d(kernel_size=3, padding=1, stride=2),
                 BasicBlock( 64, 64, 64, stride=1, is_shortcut=False,),
              * [BasicBlock( 64, 64, 64, stride=1, is_shortcut=False,) for i in range(1,3)],
            )
            self.block2  = nn.Sequential(
                 BasicBlock( 64,128,128, stride=2, is_shortcut=True, ),
              * [BasicBlock(128,128,128, stride=1, is_shortcut=False,) for i in range(1,4)],
            )
            self.block3  = nn.Sequential(
                 BasicBlock(128,256,256, stride=2, is_shortcut=True, ),
              * [BasicBlock(256,256,256, stride=1, is_shortcut=False,) for i in range(1,6)],
            )
            self.block4 = nn.Sequential(
                 BasicBlock(256,512,512, stride=2, is_shortcut=True, ),
              * [BasicBlock(512,512,512, stride=1, is_shortcut=False,) for i in range(1,3)],
            )
            self.logit = nn.Linear(512,num_class)



        def forward(self, x):
            batch_size = len(x)

            x = self.block0(x)
            x = self.block1(x)
            x = self.block2(x)
            x = self.block3(x)
            x = self.block4(x)
            x = F.adaptive_avg_pool2d(x,1).reshape(batch_size,-1)
            logit = self.logit(x)
            return logit


    class Resnet34_classification(nn.Module):
        def __init__(self,num_class=4):
            super(Resnet34_classification, self).__init__()
            e = ResNet34()
            self.block = nn.ModuleList([
                e.block0,
                e.block1,
                e.block2,
                e.block3,
                e.block4,
            ])
            e = None  #dropped
            self.feature = nn.Conv2d(512,32, kernel_size=1) #dummy conv for dim reduction
            self.logit = nn.Conv2d(32,num_class, kernel_size=1)

        def forward(self, x):
            batch_size,C,H,W = x.shape

            for i in range( len(self.block)):
                x = self.block[i](x)
                #print(i, x.shape)

            x = F.dropout(x,0.5,training=self.training)
            x = F.adaptive_avg_pool2d(x, 1)
            x = self.feature(x)
            logit = self.logit(x)
            return logit

    model_classification = Resnet34_classification()
    model_classification.load_state_dict(torch.load("../input/hengs-models-20190910/00007500_model.pth",
                                                    map_location=lambda storage, loc: storage), strict=True)

    return model_classification
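
The `CONVERSION` list above is a flat table in which every four consecutive entries pair one of Heng's parameter names and its shape with what appear to be the corresponding torchvision ResNet-34 name and shape. This kernel loads a fully trained checkpoint, so the table is not needed here, and Heng's own `load_pretrain` helper is not shown. As a hedged sketch, such a table could be consumed like this; `build_key_mapping` and `remap_state_dict` are hypothetical names.

```python
def build_key_mapping(conversion):
    """Turn a flat [custom_key, shape, pretrained_key, shape, ...] list into a dict."""
    mapping = {}
    for i in range(0, len(conversion), 4):
        custom_key, custom_shape, pretrained_key, pretrained_shape = conversion[i:i + 4]
        if custom_shape != pretrained_shape:
            raise ValueError(
                f"shape mismatch for {custom_key}: {custom_shape} vs {pretrained_shape}"
            )
        mapping[pretrained_key] = custom_key
    return mapping


def remap_state_dict(state_dict, mapping):
    """Rename the keys of a pretrained state dict, dropping keys without a mapping."""
    return {mapping[key]: value for key, value in state_dict.items() if key in mapping}


# Two rows copied from the table above, used as a tiny example.
sample = [
    'block0.0.weight', (64, 3, 7, 7), 'conv1.weight', (64, 3, 7, 7),
    'block0.1.weight', (64,), 'bn1.weight', (64,),
]
sample_mapping = build_key_mapping(sample)
```

The remapped dictionary could then be passed to `load_state_dict`, typically with `strict=False` when some layers (for example the final classifier) are intentionally skipped.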

In [None]:
model_classification = load_hengs_clf_model()

In [None]:
def get_hengs_clf_model_preds(model_classification):
    # Dataset setup
    class TestDataset(Dataset):
        '''Dataset for test prediction'''
        def __init__(self, root, df, mean, std):
            self.root = root
            df['ImageId'] = df['ImageId_ClassId'].apply(lambda x: x.split('_')[0])
            self.fnames = df['ImageId'].unique().tolist()
            self.num_samples = len(self.fnames)
            self.transform = Compose(
                [
                    Normalize(mean=mean, std=std, p=1),
                    ToTensor(),
                ]
            )

        def __getitem__(self, idx):
            fname = self.fnames[idx]
            path = os.path.join(self.root, fname)
            image = cv2.imread(path)
            images = self.transform(image=image)["image"]
            return fname, images

        def __len__(self):
            return self.num_samples

    sample_submission_path = '../input/severstal-steel-defect-detection/sample_submission.csv'
    test_data_folder = "../input/severstal-steel-defect-detection/test_images"

    # hyperparameters
    batch_size = 1

    # mean and std
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    df = pd.read_csv(sample_submission_path)

    # dataloader
    testset = DataLoader(
        TestDataset(test_data_folder, df, mean, std),
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )


    # useful functions for setting up inference

    def sharpen(p, t=0.5):
        # Raise probabilities to the power t; t == 0 leaves them unchanged.
        if t != 0:
            return p**t
        else:
            return p

    def get_classification_preds(net,test_loader):
        test_probability_label = []
        test_id   = []

        net = net.cuda()
        for t, (fnames, images) in enumerate(tqdm(test_loader)):
            batch_size,C,H,W = images.shape
            images = images.cuda()

            with torch.no_grad():
                net.eval()

                num_augment = 0
                if 1: #  null
                    logit =  net(images)
                    probability = torch.sigmoid(logit)

                    probability_label = sharpen(probability,0)
                    num_augment+=1

                if 'flip_lr' in augment:
                    logit = net(torch.flip(images,dims=[3]))
                    probability  = torch.sigmoid(logit)

                    probability_label += sharpen(probability)
                    num_augment+=1

                if 'flip_ud' in augment:
                    logit = net(torch.flip(images,dims=[2]))
                    probability = torch.sigmoid(logit)

                    probability_label += sharpen(probability)
                    num_augment+=1

                probability_label = probability_label/num_augment

            probability_label = probability_label.data.cpu().numpy()

            test_probability_label.append(probability_label)
            test_id.extend([i for i in fnames])


        test_probability_label = np.concatenate(test_probability_label)
        return test_probability_label, test_id

    # threshold for classification
    threshold_label = [0.50,0.50,0.50,0.50,]

    # Alternative: ['null', 'flip_lr', 'flip_ud']; '5crop' is not handled by get_classification_preds
    augment = ['null']

    # Get prediction for classification model

    probability_label, image_id = get_classification_preds(model_classification, testset)
    predict_label = probability_label>np.array(threshold_label).reshape(1,4,1,1)

    image_id_class_id = []
    encoded_pixel = []
    for b in range(len(image_id)):
        for c in range(4):
            image_id_class_id.append(image_id[b]+'_%d'%(c+1))
            if predict_label[b,c]==0:
                rle=''
            else:
                rle ='1 1'
            encoded_pixel.append(rle)

    df_classification = pd.DataFrame(zip(image_id_class_id, encoded_pixel), columns=['ImageId_ClassId', 'EncodedPixels'])

    return df_classification


In [None]:
df_classification = get_hengs_clf_model_preds(model_classification)

In [None]:
del model_classification
torch.cuda.empty_cache()
gc.collect()

In [None]:
df_classification.head()
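
The classifier output above is stored in the submission layout: `'1 1'` is only a placeholder meaning "defect present", and an empty string means "no defect". A minimal pure-Python sketch of that conversion, assuming the per-image probabilities are already available as plain lists (the function name `labels_to_rows` and the sample values are hypothetical, not part of the kernel):

```python
def labels_to_rows(image_ids, probabilities, thresholds=(0.5, 0.5, 0.5, 0.5)):
    """Turn per-class probabilities into (ImageId_ClassId, EncodedPixels) rows.

    '1 1' is a placeholder marking a positive class; '' marks a negative one.
    """
    rows = []
    for image_id, probs in zip(image_ids, probabilities):
        for class_index, (prob, threshold) in enumerate(zip(probs, thresholds)):
            key = f"{image_id}_{class_index + 1}"  # classes are numbered from 1
            rows.append((key, "1 1" if prob > threshold else ""))
    return rows


# Hypothetical probabilities for two images and four defect classes
rows = labels_to_rows(["a.jpg", "b.jpg"],
                      [[0.9, 0.1, 0.6, 0.2], [0.0, 0.0, 0.0, 0.7]])
for key, rle in rows:
    print(key, repr(rle))
```

Keeping the classifier result in the same two-column layout as the segmentation output makes it easy to compare or combine the two later.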

In [None]:
def load_denis_gontcharov_segmentation_model():
    # Model from https://www.kaggle.com/gontcharovd/unet-pytorch-inference-kernel-extended-0-89648
    # (see also that kernel's comments)
    # ckpt_path = "../input/resnetmodels/resnet18_20_epochs.pth"
    # ckpt_path = "../input/senetmodels/senet50_20_epochs.pth"
    ckpt_path = "../input/senetmodels/senext50_30_epochs_high_threshold.pth"
    device = torch.device("cuda")
    # change the encoder name in the Unet() call.
    model_segmentation = Unet('se_resnext50_32x4d', encoder_weights=None, classes=4, activation=None)
    model_segmentation.to(device)
    model_segmentation.eval()
    state = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    model_segmentation.load_state_dict(state["state_dict"])
    model_segmentation = model_segmentation.cuda()
    
    return model_segmentation

In [None]:
denis_gontcharov_model = load_denis_gontcharov_segmentation_model()

In [None]:
# def load_ilya_model():
#     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 

#     models_folder = Path("/kaggle/input/segmentation-model")
# #     empty_model_folder = Path("/kaggle/input/empty-model")

#     class PretrainedModel(torch.nn.Module):
#         def __init__(self, output_features, pretrained=True):
#             super().__init__()
#             model = torchvision.models.resnet34(pretrained=pretrained)
#             num_ftrs = model.fc.in_features
#             model.fc = torch.nn.Linear(num_ftrs, output_features)
#             self.model = model

#         def forward(self, x):
#             return self.model(x)
# #     empty_model = PretrainedModel(2, False)
    
# #     pretrained_model_name = "best_model.pt"
# #     empty_model.load_state_dict(torch.load(
# #         empty_model_folder / pretrained_model_name,
# #         map_location=torch.device("cpu")
# #     ))

# #     empty_model.to(device)
# #     empty_model.eval();
    
#     segmentation_model = PretrainedUNet(
#     in_channels=3,
#     out_channels=4, 
#     batch_norm=True, 
#     upscale_mode="bilinear",
#     pretrained=False
#     )
    
#     pretrained_model_name = "severstal-unet-v51.pt"

#     if pretrained_model_name is not None:
#         segmentation_model.load_state_dict(torch.load(
#             models_folder / pretrained_model_name,
#             map_location=torch.device("cpu")
#         ))

#     segmentation_model.to(device)
#     segmentation_model.eval();
    
#     return segmentation_model.cuda()

# ilya_model = load_ilya_model()

In [None]:
nfolds = 1#4
bs = 2
n_cls = 4
noise_th = 2000 #predicted masks must be larger than noise_th
TEST = '../input/severstal-steel-defect-detection/test_images/'
BASE = '../input/severstal-fast-ai-256x256-crops/'

torch.backends.cudnn.benchmark = True


# def get_fast_ai_learn():

from fastai.vision.learner import create_head, cnn_config, num_features_model
from fastai.callbacks.hooks import model_sizes, hook_outputs, dummy_eval, Hook, _hook_inner
from fastai.vision.models.unet import _get_sfs_idxs, UnetBlock

class Hcolumns(nn.Module):
    def __init__(self, hooks:Collection[Hook], nc:Collection[int]=None):
        super(Hcolumns,self).__init__()
        self.hooks = hooks
        self.n = len(self.hooks)
        self.factorization = None 
        if nc is not None:
            self.factorization = nn.ModuleList()
            for i in range(self.n):
                self.factorization.append(nn.Sequential(
                    conv2d(nc[i],nc[-1],3,padding=1,bias=True),
                    conv2d(nc[-1],nc[-1],3,padding=1,bias=True)))
                #self.factorization.append(conv2d(nc[i],nc[-1],3,padding=1,bias=True))
        
    def forward(self, x:Tensor):
        n = len(self.hooks)
        out = [F.interpolate(self.hooks[i].stored if self.factorization is None
            else self.factorization[i](self.hooks[i].stored), scale_factor=2**(self.n-i),
            mode='bilinear',align_corners=False) for i in range(self.n)] + [x]
        return torch.cat(out, dim=1)

class DynamicUnet_Hcolumns(SequentialEx):
    "Create a U-Net from a given architecture."
    def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, 
                 self_attention:bool=False,
                 y_range:Optional[Tuple[float,float]]=None,
                 last_cross:bool=True, bottle:bool=False, **kwargs):
        imsize = (256,256)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(conv_layer(ni, ni*2, **kwargs),
                                    conv_layer(ni*2, ni, **kwargs)).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        self.hc_hooks = [Hook(layers[-1], _hook_inner, detach=False)]
        hc_c = [x.shape[1]]
        
        for i,idx in enumerate(sfs_idxs):
            not_final = i!=len(sfs_idxs)-1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i==len(sfs_idxs)-3)
            unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, 
                blur=blur, self_attention=sa, **kwargs).eval()
            layers.append(unet_block)
            x = unet_block(x)
            self.hc_hooks.append(Hook(layers[-1], _hook_inner, detach=False))
            hc_c.append(x.shape[1])

        ni = x.shape[1]
        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, **kwargs))
        hc_c.append(ni)
        layers.append(Hcolumns(self.hc_hooks, hc_c))
        layers += [conv_layer(ni*len(hc_c), n_classes, ks=1, use_activ=False, **kwargs)]
        if y_range is not None: layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        if hasattr(self, "sfs"): self.sfs.remove()
            
def unet_learner(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
        norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, 
        blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, 
        last_cross:bool=True, bottle:bool=False, cut:Union[int,Callable]=None, 
        hypercolumns=True, **learn_kwargs:Any)->Learner:
    "Build Unet learner from `data` and `arch`."
    meta = cnn_config(arch)
    body = create_body(arch, pretrained, cut)
    M = DynamicUnet_Hcolumns if hypercolumns else DynamicUnet
    model = to_device(M(body, n_classes=data.c, blur=blur, blur_final=blur_final,
        self_attention=self_attention, y_range=y_range, norm_type=norm_type, 
        last_cross=last_cross, bottle=bottle), data.device)
    learn = Learner(data, model, **learn_kwargs)
    learn.split(ifnone(split_on, meta['split']))
    if pretrained: learn.freeze()
    apply_init(model[2], nn.init.kaiming_normal_)
    return learn


# Same-name subclasses so that masks are opened through the open_mask defined below
class SegmentationLabelList(SegmentationLabelList):
    def open(self, fn): return open_mask(fn, div=True)
    
class SegmentationItemList(SegmentationItemList):
    _label_cls = SegmentationLabelList

# Setting transformations on masks to False on test set
def transform(self, tfms:Optional[Tuple[TfmList,TfmList]]=(None,None), **kwargs):
    if not tfms: tfms=(None,None)
    assert is_listy(tfms) and len(tfms) == 2
    self.train.transform(tfms[0], **kwargs)
    self.valid.transform(tfms[1], **kwargs)
    kwargs['tfm_y'] = False # Test data has no labels
    if self.test: self.test.transform(tfms[1], **kwargs)
    return self
fastai.data_block.ItemLists.transform = transform

def open_mask(fn:PathOrStr, div:bool=True, convert_mode:str='L', cls:type=ImageSegment,
        after_open:Callable=None)->ImageSegment:
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        x = PIL.Image.open(fn).convert(convert_mode)
    if after_open: x = after_open(x)
    x = pil2tensor(x,np.float32)
    return cls(x)


def get_fast_ai_learn():
    stats = ([0.396,0.396,0.396], [0.179,0.179,0.179])
    #check https://www.kaggle.com/iafoss/256x256-images-with-defects for stats

    data = (SegmentationItemList.from_folder(TEST)
            .split_by_idx([0])
            .label_from_func(lambda x : str(x), classes=[0,1,2,3,4])
            .add_test(Path(TEST).ls(), label=None)
            .databunch(path=Path('.'), bs=bs)
            .normalize(stats))


    learn = unet_learner(data, models.resnet34, pretrained=False)
    learn.model.load_state_dict(torch.load(Path(BASE)/f'models/fold0.pth')['model'])
    learn.model.eval()
    return learn

learn = get_fast_ai_learn()

n_cls = 4
tta = True
noise_th = 2000 #predicted masks must be larger than noise_th

def post_proc(yp):
    yp = np.argmax(yp, axis=-1)
    for i in range(n_cls):
        idxs = yp == i+1
        if idxs.sum() < noise_th: 
            yp[idxs] = 0
    return yp

def argmax_mask_to_binary_masks(mask):
    return [(mask == i).astype(np.int8) for i in range(1, n_cls+1)]

# def pred_batch_fast_ai(x):
#     x = x.cuda()
#     py = torch.softmax(learn.model(x),dim=1).permute(0,2,3,1).detach()
#     if tta:
#         flips = [[-1],[-2],[-2,-1]]
#         for f in flips:
#             py += torch.softmax(torch.flip(learn.model(torch.flip(x,f)),f),dim=1).permute(0,2,3,1).detach()
#         py /= len(flips) + 1
#     py = py.cpu().numpy() 
#     argmax_masks = [post_proc(yp) for yp in py]
#     binary_masks = np.array([argmax_mask_to_binary_masks(argmax_mask) for argmax_mask in argmax_masks])
#     return binary_masks

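def pred_batch_fast_ai(x):
    # Softmax over classes, averaged over flip TTA; returns the 4 defect channels.
    x = x.cuda()
    with torch.no_grad():  # inference only, so no autograd graph is needed
        py = torch.softmax(learn.model(x), dim=1)
        if tta:
            flips = [[-1], [-2], [-2, -1]]
            for f in flips:
                # flip the input, predict, then flip the prediction back
                py += torch.softmax(torch.flip(learn.model(torch.flip(x, f)), f), dim=1)
            py /= len(flips) + 1
    # drop channel 0 (background)
    return py[:, 1:]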
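
The flip TTA in `pred_batch_fast_ai` relies on flipping each prediction back before averaging. A minimal sketch of that idea on nested lists, runnable without a GPU. `predict` is a hypothetical stand-in for the network, made position-dependent on purpose so that skipping the un-flip step would change the result:

```python
def flip_lr(grid):
    """Mirror a 2D grid left-right (the last axis)."""
    return [row[::-1] for row in grid]


def flip_ud(grid):
    """Mirror a 2D grid top-bottom (the first axis)."""
    return grid[::-1]


def predict(grid):
    """Hypothetical stand-in for a network: weights each pixel by its column."""
    return [[value * (col + 1) for col, value in enumerate(row)] for row in grid]


def predict_with_flip_tta(grid):
    """Average predictions over identity, lr, ud and lr+ud flips.

    Each flipped prediction is flipped back before averaging, so all
    predictions are aligned with the original image.
    """
    flips = [
        lambda g: g,
        flip_lr,
        flip_ud,
        lambda g: flip_ud(flip_lr(g)),
    ]
    total = [[0.0] * len(grid[0]) for _ in grid]
    for flip in flips:
        # Every flip here is its own inverse, so applying it again undoes it.
        aligned = flip(predict(flip(grid)))
        for r, row in enumerate(aligned):
            for c, value in enumerate(row):
                total[r][c] += value
    return [[value / len(flips) for value in row] for row in total]


image = [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0]]
averaged = predict_with_flip_tta(image)
print(averaged)
```

Unlike the kernel, the sketch lists the identity transform among the flips, so it divides by `len(flips)` rather than `len(flips) + 1`.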

In [None]:
unet_se_resnext50_32x4d = \
    load('/kaggle/input/severstalmodels/unet_se_resnext50_32x4d.pth').cuda()
unet_mobilenet2 = load('/kaggle/input/severstalmodels/unet_mobilenet2.pth').cuda()
unet_resnet34 = load('/kaggle/input/severstalmodels/unet_resnet34.pth').cuda()

### Models' mean aggregator

In [None]:
class Model:
    def __init__(self, models, weights=None):
        self.models = models
        # `weights or ...` raises for NumPy arrays, so test for None explicitly
        self.weights = np.ones(len(models)) if weights is None else weights
    
    def __call__(self, x):
        res = []
        x = x.cuda()
        with torch.no_grad():
            for model, weight in zip(self.models, self.weights):
                res.append(torch.sigmoid(model(x))*weight)
        res = torch.stack(res)
        return torch.sum(res, dim=0) / sum(self.weights)

model = Model([unet_se_resnext50_32x4d, unet_mobilenet2, unet_resnet34,
               denis_gontcharov_model])
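
The aggregator averages probabilities (after the sigmoid), not raw logits, and normalizes by the sum of the weights. A scalar sketch of the same arithmetic for a single pixel, with hypothetical logits and weights:

```python
import math


def sigmoid(x):
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-x))


def weighted_mean_probability(logits, weights=None):
    """Weighted mean of sigmoid(logit) over models for a single pixel."""
    if weights is None:  # explicit check, so array-like weights also work
        weights = [1.0] * len(logits)
    if len(weights) != len(logits):
        raise ValueError("one weight per model is required")
    weighted = sum(w * sigmoid(z) for w, z in zip(weights, logits))
    return weighted / sum(weights)


# Three undecided models with equal weights
equal = weighted_mean_probability([0.0, 0.0, 0.0])
# One undecided model and one very confident model that counts three times
skewed = weighted_mean_probability([0.0, 100.0], weights=[1.0, 3.0])
print(equal, skewed)
```

With equal weights this reduces to a plain mean, which is how `Model` is used in this kernel.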

### Create TTA transforms, datasets, loaders

In [None]:
def create_transforms(additional):
    res = list(additional)
    # add necessary transformations
    res.extend([
        ChannelTranspose()
    ])
    res = A.Compose(res)
    return res

img_folder = '/kaggle/input/severstal-steel-defect-detection/test_images'
batch_size = 2
num_workers = 0
# [0.388,0.390,0.394], [0.178,0.181,0.175]
# Different transforms for TTA wrapper
transforms = [
    [A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))],
    [A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),A.HorizontalFlip(p=1)],
    [A.Normalize(mean=(0.388,0.390,0.394), std=(0.178,0.181,0.175))],  # normalization for the fast.ai model
]

transforms = [create_transforms(t) for t in transforms]
datasets = [TtaWrap(ImageDataset(img_folder=img_folder, transforms=t), tfms=t) for t in transforms]
loaders = [DataLoader(d, num_workers=num_workers, batch_size=batch_size, shuffle=False) for d in datasets]

### Loaders' mean aggregator

In [None]:
thresholds = [0.5, 0.5, 0.5, 0.5]
min_area = [600, 600, 1000, 2000]

res = []
# Iterate over all TTA loaders
total = (len(datasets[0]) + batch_size - 1) // batch_size  # number of batches, rounded up
for loaders_batch in tqdm_notebook(zip(*loaders), total=total):

    ############## Get preds with tta for fast.ai model ################# 
    fast_ai_batch = loaders_batch[-1]
    fast_ai_features = fast_ai_batch['features'].cuda()
    fast_ai_pred = pred_batch_fast_ai(fast_ai_features).cpu().numpy()
    
    ############## Get preds with tta for ensemble ################# 
    preds = []
    
    loaders_batch = loaders_batch[:-1]
    for i, batch in enumerate(loaders_batch):
        features = batch['features'].cuda()
        p = model(features)
        # inverse operations for TTA
        p = datasets[i].inverse(p)
        preds.append(p)
    # TTA mean
    preds = torch.stack(preds)
    preds = torch.mean(preds, dim=0)
    preds = preds.detach().cpu().numpy()
    
    ############## Combine preds with weights ################# 
    preds = np.average([preds, fast_ai_pred], weights=(4.0, 1.0), axis=0)

    
    # Batch post processing
    for p, file in zip(preds, loaders_batch[0]['image_file']):
        file = os.path.basename(file)
        # Image postprocessing
        for i in range(4):
            p_channel = p[i]
            imageid_classid = file+'_'+str(i+1)
            p_channel = (p_channel>thresholds[i]).astype(np.uint8)
            if p_channel.sum() < min_area[i]:
                p_channel = np.zeros(p_channel.shape, dtype=p_channel.dtype)

            res.append({
                'ImageId_ClassId': imageid_classid,
                'EncodedPixels': mask2rle(p_channel)
            })
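
The per-channel post-processing above is: threshold, discard the mask if its area is below `min_area`, then run-length encode. A minimal pure-Python sketch, assuming the competition's column-major, 1-indexed RLE format. `binarize` and `rle_encode` are hypothetical helpers written for illustration; MLComp's `mask2rle` may differ in details:

```python
def binarize(prob_mask, threshold, min_area):
    """Threshold a probability mask; drop it entirely if too few pixels remain."""
    binary = [[1 if p > threshold else 0 for p in row] for row in prob_mask]
    area = sum(map(sum, binary))
    if area < min_area:
        return [[0] * len(row) for row in binary]
    return binary


def rle_encode(binary):
    """Run-length encode a binary mask in column-major order, 1-indexed.

    Returns 'start length start length ...', or '' for an empty mask.
    """
    height, width = len(binary), len(binary[0])
    # Column-major flattening: top to bottom, then left to right.
    flat = [binary[r][c] for c in range(width) for r in range(height)]
    runs = []
    run_start = None
    for position, value in enumerate(flat, start=1):
        if value and run_start is None:
            run_start = position
        elif not value and run_start is not None:
            runs += [run_start, position - run_start]
            run_start = None
    if run_start is not None:  # a run that reaches the last pixel
        runs += [run_start, len(flat) + 1 - run_start]
    return " ".join(map(str, runs))


# Hypothetical 2x3 probability map
probs = [[0.1, 0.9, 0.2],
         [0.3, 0.8, 0.7]]
kept = rle_encode(binarize(probs, threshold=0.5, min_area=2))
dropped = rle_encode(binarize(probs, threshold=0.5, min_area=10))
print(repr(kept), repr(dropped))
```

Dropping small masks trades a few missed tiny defects for fewer false positives, which is why `min_area` is set per class.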

In [None]:
df = pd.DataFrame(res)
df = df.fillna('')

In [None]:
(df['EncodedPixels'] == '').mean()

In [None]:
(df_classification['EncodedPixels'] == '').mean()

In [None]:
((df_classification['EncodedPixels'] == '')&(df['EncodedPixels'] == '')).mean()

## Delete false positives

In [None]:
df.loc[df_classification['EncodedPixels'] == '', 'EncodedPixels'] = ''
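
The assignment above matches rows by DataFrame index, so it assumes that `df` and `df_classification` list the images in the same order; that is not checked in this kernel. A hedged sketch of the same filtering done by `ImageId_ClassId` key instead, in pure Python with hypothetical rows:

```python
def remove_false_positives(segmentation_rows, classification_rows):
    """Blank every segmentation mask whose key the classifier marked as empty.

    Both arguments are lists of (ImageId_ClassId, EncodedPixels) tuples.
    Rows are matched by key, so the two lists may be in different orders.
    """
    rejected = {key for key, rle in classification_rows if rle == ""}
    return [(key, "" if key in rejected else rle)
            for key, rle in segmentation_rows]


# Hypothetical rows; note the different ordering of the two lists
segmentation = [("a.jpg_1", "3 2 6 1"), ("a.jpg_2", "10 4"), ("b.jpg_1", "")]
classification = [("b.jpg_1", "1 1"), ("a.jpg_2", ""), ("a.jpg_1", "1 1")]
cleaned = remove_false_positives(segmentation, classification)
print(cleaned)
```

The classifier can only remove masks here: a key it marks as positive keeps whatever the segmentation models produced, including an empty mask.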

Save predictions

In [None]:
# a = pd.read_csv('submission.csv')
# a[~a['EncodedPixels'].isnull()]

In [None]:
df.to_csv('submission.csv', index=False)

Histogram of predictions

In [None]:
time_end = time.time()
print(time_end - time_start)

In [None]:
# df['Image'] = df['ImageId_ClassId'].map(lambda x: x.split('_')[0])
# df['Class'] = df['ImageId_ClassId'].map(lambda x: x.split('_')[1])
# df['empty'] = df['EncodedPixels'].map(lambda x: not x)
# df[df['empty'] == False]['Class'].value_counts()

### Visualization

In [None]:
# %matplotlib inline

# df = pd.read_csv('submission.csv')[:40]
# df['Image'] = df['ImageId_ClassId'].map(lambda x: x.split('_')[0])
# df['Class'] = df['ImageId_ClassId'].map(lambda x: x.split('_')[1])

# for row in df.itertuples():
#     img_path = os.path.join(img_folder, row.Image)
#     img = cv2.imread(img_path)
#     mask = rle2mask(row.EncodedPixels, (1600, 256)) \
#         if isinstance(row.EncodedPixels, str) else np.zeros((256, 1600))
#     if mask.sum() == 0:
#         continue
    
#     fig, axes = plt.subplots(1, 2, figsize=(20, 60))
#     axes[0].imshow(img/255)
#     axes[1].imshow(mask*60)
#     axes[0].set_title(row.Image)
#     axes[1].set_title(row.Class)
#     plt.show()