## Severstal: Steel Defect Detection. Inference Notebook

<br>

More details about this submission can be found in this [GitHub Repo](https://github.com/reyvaz/steel-defect-segmentation).

<br>

### The solution in this notebook consists of a two-step approach.

1. The first step is to run the images through an ensemble of binary image neural network classifiers to determine whether the piece of steel in the image presents a defect.

2. The second step runs the images flagged as defective through an ensemble of segmentation neural networks to identify the location and the type of each defect.

All first- and second-step networks were trained on 5-fold cross-validation splits of the data, all with the same image size. Random data augmentations were applied to the training partition for all networks and folds in both stages.
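The two-step flow can be sketched with stand-in models. This is a minimal illustration, not the notebook's implementation: `two_step_predict`, `defect_prob` and `segment` are hypothetical names, and the toy probabilities and run-length strings are made up. Only the 0.8 threshold, the four classes and the `<id>.jpg_<class>` row naming come from the cells further down.

```python
# Sketch of the two-step flow; the two callables stand in for the ensembles.
N_CLASSES = 4

def two_step_predict(image_ids, defect_prob, segment, threshold=0.8):
    """Return {"<id>.jpg_<class>": rle} for every image and class.

    defect_prob: callable, image id -> probability that the image has a defect
    segment:     callable, image id -> list of N_CLASSES run-length strings
    """
    rows = {}
    for image_id in image_ids:
        if defect_prob(image_id) >= threshold:
            rles = segment(image_id)      # step 2 runs only on flagged images
        else:
            rles = [''] * N_CLASSES       # no defect: empty mask for every class
        for c, rle in enumerate(rles, start=1):
            rows['{}.jpg_{}'.format(image_id, c)] = rle
    return rows

# Toy stand-ins (hypothetical values) for the classifier and segmenter.
probs = {'a': 0.95, 'b': 0.10}
rows = two_step_predict(
    ['a', 'b'],
    defect_prob=probs.get,
    segment=lambda image_id: ['1 3', '', '', ''],
)
```

Gating on the classifier keeps the more expensive segmentation ensemble off images that most likely have no defect, and guarantees those images get empty masks.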

#### Binary Classification

The ensemble for the binary classification step consists of EfficientNet-based classifiers (Tan & Le, 2020), versions B0-B5. The classifiers were selected according to their out-of-fold performance during training.
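One plausible way to select models by out-of-fold performance is to rank the saved weights of each architecture by their out-of-fold score and keep the best. The sketch below assumes a metadata table with hypothetical `arch`, `filename` and `oof_score` fields, and the file names and scores are invented for illustration. It is not the notebook's `get_best_weights` helper, whose internals are not shown here.

```python
def select_best(records, top_k=1):
    """Keep the top_k weight files per architecture, ranked by out-of-fold score."""
    by_arch = {}
    for rec in records:
        by_arch.setdefault(rec['arch'], []).append(rec)

    selected = []
    for arch in sorted(by_arch):
        ranked = sorted(by_arch[arch], key=lambda r: r['oof_score'], reverse=True)
        selected.extend(ranked[:top_k])
    return [rec['filename'] for rec in selected]

# Hypothetical metadata rows: one entry per trained fold.
records = [
    {'arch': 'b0', 'filename': 'b0-fold1.h5', 'oof_score': 0.91},
    {'arch': 'b0', 'filename': 'b0-fold2.h5', 'oof_score': 0.94},
    {'arch': 'b3', 'filename': 'b3-fold1.h5', 'oof_score': 0.96},
    {'arch': 'b3', 'filename': 'b3-fold2.h5', 'oof_score': 0.93},
]
best = select_best(records, top_k=1)
```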

#### Defect Type Classification and Segmentation

The ensemble for the second step consists of UNet++-based CNNs (Zhou et al., 2019), all with EfficientNet backbones, versions B0-B5. The networks in the ensemble were selected according to their out-of-fold validation performance on images with defects only.
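Both ensembles in this notebook combine their members by averaging the model outputs. For the segmentation step this amounts to a per-pixel, per-class mean of the probability maps. A minimal NumPy sketch, with `average_ensemble` as a hypothetical helper name and toy arrays in place of real predictions:

```python
import numpy as np

def average_ensemble(prob_maps):
    """Average per-pixel class probabilities from several models.

    prob_maps: list of arrays with identical shape (height, width, n_classes).
    """
    return np.mean(np.stack(prob_maps, axis=0), axis=0)

# Two toy "model outputs" for a 2x3 image with 4 classes.
m1 = np.full((2, 3, 4), 0.2)
m2 = np.full((2, 3, 4), 0.6)
avg = average_ensemble([m1, m2])
```

Averaging probabilities rather than hard masks lets a confident model outweigh an uncertain one before any threshold is applied.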

#### Acknowledgements:
- Thanks to [PAO Severstal](https://www.severstal.com/) for providing the dataset.

- This solution was inspired by the Kaggle kernels [Severstal: U-Net++ with EfficientNetB4](https://www.kaggle.com/xhlulu/severstal-u-net-with-efficientnetb4/) by user [xhlulu](https://www.kaggle.com/xhlulu), and [Unet Plus Plus with EfficientNet Encoder](https://www.kaggle.com/meaninglesslives/nested-unet-with-efficientnet-encoder?scriptVersionId=0) by [Siddhartha](https://sidml.github.io/).


#### References:

- Tan, M., & Le, Q. V. (2020). EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. [arXiv:1905.11946v5](https://arxiv.org/abs/1905.11946v5).

- Zhou, Z., Siddiquee, M., Tajbakhsh, N., & Liang, J. (2019). UNet++: Redesigning Skip Connections to Exploit Multiscale Features in Image Segmentation. [arXiv:1912.05074v2](https://arxiv.org/abs/1912.05074v2).

In [None]:
input_path = '/kaggle/input/severstal-steel-defect-detection/'
base = '/kaggle/input/severstal-inference-base'
requirements_dir = base + '/requirements/'

In [None]:
!pip -q config set global.disable-pip-version-check true
!pip -q install {requirements_dir}Keras_Applications-1.0.8-py3-none-any.whl
!pip -q install {requirements_dir}efficientnet-1.1.1-py3-none-any.whl

In [None]:
import os, sys, re, gc
import numpy as np
import pandas as pd

import tensorflow as tf
import tensorflow.keras.layers as L
import tensorflow.keras.backend as K

!cp -r {base}/tpu_segmentation ./
!cp -r {base}/*.py ./
from tpu_segmentation import *
from severstal_utils import *
!rm -r tpu_segmentation *.py

AUTO = tf.data.experimental.AUTOTUNE 
strategy = tf.distribute.get_strategy()

start_notebook = time()
print('Notebook started at: ', current_time_str())
print('Tensorflow version: ', tf.__version__)

## Constants

In [None]:
IMAGE_SIZE = (256, 1600) # original image size
target_size = (128, 800) # size used for CNN inputs, same for all CNNs.
input_shape = (*target_size, 3) 
N_CLASSES = 4 # types of defects

## Test Dataset

In [None]:
test_fnames = tf.io.gfile.glob(input_path + 'test_images/*')
test_ids = [x.split('/')[-1].split('.')[0] for x in test_fnames]
get_test_path = lambda x: input_path + 'test_images/' + x + '.jpg'

In [None]:
def normalize_and_reshape(img, target_size):
    # Resize to the network input size and scale pixel values to [0, 1].
    img = tf.image.resize(img, target_size)
    img = tf.cast(img, tf.float32) / 255.0
    # Set a static shape so downstream Keras models see fixed dimensions.
    img = tf.reshape(img, [*target_size, 3])
    return img

def get_image_and_id(file_name, target_size):
    # Read and decode the JPEG as a 3-channel image.
    img = tf.io.read_file(file_name)
    img = tf.image.decode_jpeg(img, channels=3)
    img = normalize_and_reshape(img, target_size)
    # The image id is the file name without directory or extension.
    img_id = tf.strings.split(file_name, os.path.sep)[-1]
    img_id = tf.strings.split(img_id, '.')[0]
    return img, img_id

def get_test_dataset(fnames, target_size, batch_size):
    # Yields (images, image ids) batches; the last partial batch is kept.
    dataset = tf.data.Dataset.from_tensor_slices(fnames)
    dataset = dataset.map(lambda file_name: get_image_and_id(file_name, target_size), num_parallel_calls=AUTO)
    dataset = dataset.batch(batch_size=batch_size, drop_remainder=False)
    dataset = dataset.prefetch(AUTO)
    return dataset

## Weight Data and Extraction

In [None]:
df = pd.read_csv(base + '/weights_meta.csv')
df1 = df[df.source == 1]
df2 = df[df.source == 2]

bin1 = df1[df1.type == 'bin']
bin1 = get_best_weights(bin1, 1)

seg1 = df1[df1.type != 'bin']
seg1 = get_best_weights(seg1, 1)

bin2 = df2[df2.type == 'bin']
seg2 = df2[df2.type != 'bin']

bin_weights = list(bin2.filename) + list(bin1.filename)
seg_weights = list(seg2.filename) + list(seg1.filename)

In [None]:
!mkdir -p weights
!unzip -q {base}/mixed_weights.zip -d weights
!unzip -q {base}/binary_weights.zip -d weights
!unzip -q {base}/segmentation_weights.zip -d weights

# Binary Predictions

### Build test dataset for binary classification

Uses all test images, i.e. all those listed in `test_ids`.

In [None]:
fnames_bin = [get_test_path(i) for i in test_ids]

test_dataset_bin = get_test_dataset(fnames_bin, target_size=target_size, batch_size=16)
num_batches = tf.data.experimental.cardinality(test_dataset_bin)
print('num of batches', num_batches.numpy())

### Assemble Binary Ensemble

In [None]:
ensemble_outputs = []
with strategy.scope():
    X = L.Input(shape=input_shape)
    for i, w in enumerate(bin_weights):
        base_name = w.split('-bin')[0]
        model = build_classifier(base_name, n_classes = 1, input_shape=input_shape, weights = None, name_suffix='-M{}'.format(i+1))
        model.load_weights('weights/' + w)
        model_output = model(X)
        ensemble_outputs.append(model_output)
    
    Y = L.Average()(ensemble_outputs)
    binary_ensemble = tf.keras.Model(inputs=X, outputs=Y, name='Binary_Classification_Ensemble')
    binary_ensemble.compile(optimizer='adam', loss='binary_crossentropy', metrics=[])

del model, ensemble_outputs, model_output

### Predictions (binary)

In [None]:
start_preds = time()
binary_predictions = binary_ensemble.predict(test_dataset_bin)

del binary_ensemble
K.clear_session()
gc.collect()

print('Elapsed time (binary predictions) {}'.format(time_passed(start_preds)))

# Mask Predictions

### Assemble Segmentation Ensemble

In [None]:
ensemble_outputs = []
with strategy.scope():
    X = L.Input(shape=input_shape)
    for i, w in enumerate(seg_weights):
        backbone_name = w.split('-unetpp')[0]
        model = xnet(backbone_name, num_classes = 4, input_shape=input_shape, weights = None)
        model._name = '{}-M{}'.format(model.name, i+1)
        model.load_weights('weights/' + w)
        model_output = model(X)
        ensemble_outputs.append(model_output)
    
    Y = L.Average()(ensemble_outputs)
    seg_ensemble = tf.keras.Model(inputs=X, outputs=Y, name='Mask_Segmentation_Ensemble')
    seg_ensemble.compile(optimizer='adam', loss='binary_crossentropy', metrics=[])

del model, ensemble_outputs, model_output
!rm -r weights

### Build test dataset for mask predictions

In [None]:
THRESHOLD = 0.80
masked_indexes = np.where(binary_predictions>=THRESHOLD)[0]
unmasked_indexes = np.where(binary_predictions<THRESHOLD)[0]

seg_ids = list(np.array(test_ids)[masked_indexes])
no_seg_ids = list(np.array(test_ids)[unmasked_indexes])
print(len(seg_ids), len(no_seg_ids), len(seg_ids) + len(no_seg_ids), len(test_ids))

In [None]:
fnames_seg = [get_test_path(i) for i in seg_ids]
batch_size = 8
test_dataset_seg = get_test_dataset(fnames_seg, target_size=target_size, batch_size=batch_size)
num_batches = tf.data.experimental.cardinality(test_dataset_seg)
print('num of batches', num_batches.numpy())

### Visualize Prediction Examples
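One of the cells in this section converts a predicted mask to run-length encoding (RLE) and back to check the helper functions. In the competition's RLE format, pixels are numbered from 1, top to bottom and then left to right, and a mask is written as pairs of start position and run length. The sketch below shows that round trip with NumPy. `mask_to_rle` and `rle_to_mask` are illustrative names, not the notebook's `create_rles` and `build_mask_array`, and unlike those helpers this sketch does no resizing.

```python
import numpy as np

def mask_to_rle(mask):
    """Encode a binary (height, width) mask as 'start length start length ...'."""
    pixels = np.asarray(mask, dtype=np.uint8).T.flatten()  # column-major order
    pixels = np.concatenate([[0], pixels, [0]])            # pad so runs at the edges close
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1      # 1-indexed run boundaries
    runs[1::2] -= runs[::2]                                # turn run ends into lengths
    return ' '.join(str(x) for x in runs)

def rle_to_mask(rle, shape):
    """Decode an RLE string into a binary mask of the given (height, width)."""
    height, width = shape
    flat = np.zeros(height * width, dtype=np.uint8)
    if rle:
        values = [int(v) for v in rle.split()]
        for start, length in zip(values[0::2], values[1::2]):
            flat[start - 1:start - 1 + length] = 1
    return flat.reshape((width, height)).T                 # undo the column-major flattening

mask = np.array([[0, 1, 0],
                 [0, 1, 1]], dtype=np.uint8)
rle = mask_to_rle(mask)
restored = rle_to_mask(rle, mask.shape)
```

An all-zero mask encodes to an empty string, which matches how images without defects are written to the submission.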

In [None]:
n_batches = 1
sample_preds = seg_ensemble.predict(test_dataset_seg.take(n_batches))
examples = retrieve_examples(test_dataset_seg, batch_size*n_batches)
idx = -1

mask_rgb = [(230, 184, 0), (0, 128, 0), (102, 0, 204), (204, 0, 102)]

In [None]:
# Convert to RLE, then back to a mask, to verify the helper functions.
# Re-run this cell to step through the examples one at a time.
idx += 1
print('Original Prediction, mask shape:', sample_preds[idx].shape)
plot_image_mask((examples[idx][0], sample_preds[idx]), mask_rgb = mask_rgb)

rle_example = create_rles(sample_preds[idx], IMAGE_SIZE)
masks_example = build_mask_array(rle_example, IMAGE_SIZE, n_classes=4)
print('Reconverted Prediction, subtle differences due to repeated resizing, mask shape: ', masks_example.shape)
plot_image_mask((tf.image.resize(examples[idx][0], IMAGE_SIZE), tf.cast(masks_example, tf.float32)), mask_rgb = mask_rgb)

In [None]:
for i in range(1, min(6, len(sample_preds))):
    plot_image_mask((examples[i][0], sample_preds[i]), mask_rgb = mask_rgb)

### Predictions + post-process (masks)
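The post-processing in this section applies a two-threshold rule to each class channel: a class is kept only if enough pixels exceed a strict upper threshold, and if so its final mask is drawn with a looser lower threshold. The sketch below isolates that rule for a single channel. `postprocess_channel` is a hypothetical helper name, and the toy probabilities are chosen so that both branches are exercised.

```python
import numpy as np

def postprocess_channel(probs, upper, lower, min_area):
    """Binarize one class channel with the two-threshold rule."""
    confident = probs > upper
    if confident.sum() < min_area:
        # Too few confident pixels: treat the class as absent.
        return np.zeros(probs.shape, dtype=np.uint8)
    # Enough confident pixels: keep everything above the lower threshold.
    return (probs > lower).astype(np.uint8)

probs = np.array([[0.9, 0.6, 0.1],
                  [0.8, 0.45, 0.2]])
kept = postprocess_channel(probs, upper=0.7, lower=0.4, min_area=2)
dropped = postprocess_channel(probs, upper=0.7, lower=0.4, min_area=3)
```

The upper threshold and minimum area suppress small spurious detections, while the lower threshold recovers the full extent of defects that pass that check.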

In [None]:
thresh_upper = [0.7,0.7,0.7,0.7]
thresh_lower = [0.4,0.5,0.4,0.5]
min_area = [180, 260, 200, 500]

empty_mask = np.zeros(target_size, int) 

rles_dict = {}
# Fill in all the entries for images with no masks
for img_prefix in no_seg_ids:
    for c in range(N_CLASSES):
        row_name = '{}.jpg_{}'.format(img_prefix, c+1)
        rles_dict[row_name] = ''

# Predict, post-process and convert to RLE, batch by batch.
start_preds = time()
for item in test_dataset_seg:
    mask_predictions = seg_ensemble.predict(item[0])
    
    for k, p in enumerate(mask_predictions):
        for ch in range(N_CLASSES):
            ch_probs = p[..., ch]
            # Keep a class only if enough pixels clear the upper threshold ...
            ch_pred = (ch_probs > thresh_upper[ch])
            if ch_pred.sum() < min_area[ch]:
                ch_pred = empty_mask.copy()
            else:
                # ... then build its final mask with the looser lower threshold.
                ch_pred = (ch_probs > thresh_lower[ch])
            mask_predictions[k,:,:,ch] = ch_pred
    
    img_ids = item[1]
    ids  = [b.decode('utf-8') for b in img_ids.numpy()]
    rles = [create_rles(p, IMAGE_SIZE) for p in mask_predictions]
    
    for i in range(len(ids)):        
        rle = rles[i]
        
        img_prefix = ids[i]
        for c in range(N_CLASSES): 
            row_name = '{}.jpg_{}'.format(img_prefix, c+1)
            rles_dict[row_name]  = rle[c].numpy().decode('utf-8')
            
print('Elapsed time (mask predictions) {}'.format(time_passed(start_preds)))

In [None]:
df = pd.DataFrame.from_dict(rles_dict, orient='index')
df.reset_index(level=0, inplace=True)
df.columns = ['ImageId_ClassId', 'EncodedPixels']

df.to_csv('submission.csv', index=False)

In [None]:
print('Notebook ended at: ', current_time_str())
print('Elapsed time (notebook) {}'.format(time_passed(start_notebook)))