# Libraries

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import os, cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
from time import time
from collections import defaultdict
from matplotlib import pyplot as plt

import sys
sys.path.append("..")
!ls ../input
from input.script.submission import *

# Setup

In [None]:
# Build model
# Load the exported TorchScript checkpoint, move it to the default CUDA device
# and switch it to inference mode (eval). The trailing ';' keeps Jupyter from
# echoing the module repr.
checkpoint_path = "../input/mixmocrjit/_ckpt_epoch_64_jit.pth"
model = torch.jit.load(checkpoint_path)
model.cuda()
model.eval();

In [None]:
# Build dataloader
test_data_folder = "../input/severstal-steel-defect-detection/test_images"

# Flip TTA is performed by hand in the inference loop, so the dataset's own
# TTA stays disabled.
dataset = TestSegDataset(
    imgdir=test_data_folder,
    input_size=(256, 1600),
    color_channel="RGB",
    normalize=True,
    ignore_background=True,
    use_tta=False,
)
print(len(dataset))

# Keep file order (shuffle=False) so predictions line up with file names.
dataloader = DataLoader(
    dataset=dataset,
    collate_fn=dataset.collate_fn,
    batch_size=14,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)
print(len(dataloader))

# Run inference

In [None]:
def fuse_mask_score(cls_score, obj_score, segm_scores, use_sigmoid=True, cls_thresh_prob=0.45):
    """Fuse the classification, objectness and segmentation heads into per-class mask scores.

    The segmentation scores are weighted by the objectness score. Every sample
    whose classification score is below ``cls_thresh_prob`` is then suppressed.

    Args:
        cls_score: per-sample classification score; must hold exactly one value
            per sample (e.g. shape (N,) or (N, 1)).
        obj_score: objectness score, multiplied (broadcast) with ``segm_scores``
            and indexed as ``obj_score[:, 0, :, :]``; presumably (N, 1, H, W) —
            TODO confirm.
        segm_scores: per-class segmentation scores indexed as (N, C, H, W),
            channel 0 being background.
        use_sigmoid: when False, channel 0 is treated as a background
            probability and receives the "no object" mass ``1 - obj_score``.
        cls_thresh_prob: samples with ``cls_score < cls_thresh_prob`` are
            suppressed.

    Returns:
        A new tensor of fused scores; the inputs are not modified. Suppressed
        samples are all zeros when ``use_sigmoid`` is True, otherwise they are
        forced to background (channel 0 = 1.0, other channels = 0.0).
    """
    masks_scores = segm_scores * obj_score
    if not use_sigmoid:
        # P(background) = P(obj = background) + P(obj = foreground) * P(segm = background)
        masks_scores[:, 0, :, :] += 1 - obj_score[:, 0, :, :]

    # Suppress samples the classifier considers defect-free. Flatten to one
    # entry per sample rather than squeeze(): for a batch of one, squeeze()
    # gives a 0-d mask, and the channel indexing below would then land on the
    # batch axis (setting every channel of the sample to 1.0).
    background_proposal_cls = cls_score.reshape(-1) < cls_thresh_prob
    if use_sigmoid:
        masks_scores[background_proposal_cls] = 0.0
    else:
        masks_scores[background_proposal_cls, 0, :, :] = 1.0
        masks_scores[background_proposal_cls, 1:, :, :] = 0.0
    return masks_scores


In [None]:
# Aggregate results
#
# For every test batch, fuse the model heads with `fuse_mask_score`. The fused
# scores are averaged over the original images and their flipped copies
# (test-time augmentation). Then the per-pixel argmax label is taken, and one
# run-length-encoded mask is emitted per non-background class (labels
# 1..num_class-1; label 0 is background).

# Flip axes used for TTA; assumes image batches are (N, C, H, W), so 2 = H and
# 3 = W — TODO confirm against the dataset.
TTA_FLIP_DIMS = [(2,), (3,), (2, 3)]
CLS_THRESH_PROB = 0.45

predictions = []
with torch.no_grad():
    for batch in tqdm(dataloader, total=len(dataloader)):
        # Get data
        filenames = batch['fnames']
        imgs = batch['images'].cuda()

        # Inference with TTA: predict on each flipped view, flip the prediction
        # back, and accumulate together with the un-flipped prediction.
        cls_score, obj_score, segm_scores = model(imgs)
        mask_score = fuse_mask_score(cls_score, obj_score, segm_scores,
                                     use_sigmoid=False, cls_thresh_prob=CLS_THRESH_PROB)
        for dims in TTA_FLIP_DIMS:
            cls_score, obj_score, segm_scores = model(imgs.flip(dims=dims))
            mask_score_tta = fuse_mask_score(cls_score, obj_score, segm_scores,
                                             use_sigmoid=False, cls_thresh_prob=CLS_THRESH_PROB)
            mask_score += mask_score_tta.flip(dims)
        mask_score /= 1 + len(TTA_FLIP_DIMS)

        # Per-pixel label. Do not squeeze(): for a batch of a single image it
        # would drop the batch axis and break the per-image loop below.
        num_class = mask_score.shape[1]
        labels = torch.argmax(mask_score, dim=1).cpu().numpy()

        # Convert to RLE, one row per (image, non-background class).
        # Optional post-processing (currently disabled):
        # remove_small_regions(mask, size=36).
        for fpath, label_map in zip(filenames, labels):
            fname = fpath.split('/')[-1]
            for cls_idx in range(1, num_class):
                rle = mask2rle((label_map == cls_idx).astype('uint8'))
                predictions.append((f'{fname}_{cls_idx}', rle))

In [None]:
# Sanity check: the inference loop adds one row per test image per defect class.
print(str(len(predictions)))

In [None]:
# Write output file
# One row per (image, defect class) pair: "ImageId_ClassId" and its
# run-length-encoded mask in "EncodedPixels".
sample_submission_path = 'submission.csv'
df = pd.DataFrame.from_records(predictions, columns=['ImageId_ClassId', 'EncodedPixels'])
df.to_csv(path_or_buf=sample_submission_path, index=False)

# Verify output

In [None]:
# Read csv file
# Reload the written submission. The notebook assumes 4 rows per image (one per
# defect class), so the image count is the row count divided by 4.
df_results = pd.read_csv("submission.csv")
num_samples = df_results.shape[0]
num_images = num_samples // 4

print(df_results.shape)
print(df_results.head(18))

In [None]:
# Random check
# Pick a random image and rebuild its class masks from the CSV written above.
# Show the raw image next to a color overlay of the prediction.

# Rows come in groups of 4 per image (one per class), so a multiple of 4
# points at the first row of an image.
idx = 4 * np.random.choice(range(num_images))
# ImageId_ClassId is "<filename>_<class>"; strip the "_<class>" suffix.
filename = df_results.iloc[idx, 0].rsplit('_', 1)[0]
print(idx, filename)

file = os.path.join(test_data_folder, filename)
image = cv2.imread(file)
if image is None:  # cv2.imread signals failure by returning None, not raising
    raise FileNotFoundError(f"could not read image: {file}")
image = image[..., ::-1]  # OpenCV loads BGR; reverse channels for matplotlib
background = get_background(image)

masks = build_masks(idx, df_results)
# Use a distinct name: `predictions` holds the submission rows, and clobbering
# it would corrupt submission.csv if the write cell were re-run afterwards.
overlay_masks = [(cls_idx, (masks == cls_idx).astype('uint8')) for cls_idx in range(1, 5)]
overlaid_image = draw_overlay(image, overlay_masks)
overlaid_image[background == 1, :] = (128, 0, 128)

plt.figure(figsize=(20, 8))
plt.subplot(2, 1, 1)
plt.imshow(image)
plt.axis('off')
plt.title("image")
plt.subplot(2, 1, 2)
plt.imshow(overlaid_image)
plt.axis('off')
plt.title("result")
plt.show()