### Pseudo Labeling (Mask)

In [None]:
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import patches, text, patheffects
from tqdm import tqdm
from glob import glob

import torch
from module.image import get_boundary_points
from module.label_map import get_label_map
from module.config import config_from_yaml
from module.model import MASK_MODEL, CLF_MODEL
from module.helper import get_batch_file_list, get_ts
from module.predict import batch_prediction_tta

In [None]:
# Load the run configuration; every path, size and device below comes from it.
CFG = config_from_yaml('config.yaml')


def _ready_for_inference(model, weight_path):
    """Load weights from `weight_path` into `model`, move it to the test device and switch to eval mode."""
    state_dict = torch.load(weight_path, map_location=CFG.TEST.DEVICE)
    model.load_state_dict(state_dict)
    model.to(CFG.TEST.DEVICE)
    model.eval()
    return model


# Mask model: built with the training image size as both min and max size.
mask_model = _ready_for_inference(
    MASK_MODEL(
        num_channel=CFG.DATA.N_CHANNEL,
        num_class=CFG.DATA.N_CLASS,
        min_size=CFG.TRAIN.IMG_SIZE,
        max_size=CFG.TRAIN.IMG_SIZE,
        image_mean=CFG.DATA.N_CHANNEL * [CFG.TRAIN.IMG_MEAN],
        image_std=CFG.DATA.N_CHANNEL * [CFG.TRAIN.IMG_STD],
        pretrained=CFG.TRAIN.MASK.PRETRAINED,
    ),
    CFG.TEST.MASK_WEIGHT,
)

# Classification model: same channel/class/normalisation settings as above.
clf_model = _ready_for_inference(
    CLF_MODEL(
        name=CFG.TRAIN.CLF.NAME,
        num_channel=CFG.DATA.N_CHANNEL,
        num_class=CFG.DATA.N_CLASS,
        image_mean=CFG.DATA.N_CHANNEL * [CFG.TRAIN.IMG_MEAN],
        image_std=CFG.DATA.N_CHANNEL * [CFG.TRAIN.IMG_STD],
        smoothing=CFG.TRAIN.CLF.SMOOTHING,
        pretrained=CFG.TRAIN.CLF.PRETRAINED,
    ),
    CFG.TEST.CLF_WEIGHT,
)

print('Model Loaded')

In [None]:
import shutil

# Start from an empty scratch directory: every image gets its own CSV here and
# the per-image files are merged afterwards. shutil.rmtree replaces the IPython
# shell escape `!rm -rf`, so the cell also runs as plain Python and on
# platforms that have no `rm` command.
shutil.rmtree('df_pseudo_labeling', ignore_errors=True)
os.makedirs('df_pseudo_labeling', exist_ok=True)

df        = pd.read_csv(CFG.DATA.TRAIN_FILE)
img_names = df['img_name'].values.astype('object')
img_names = np.unique(img_names)
group     = df.groupby('img_name')

for img_name in tqdm(img_names):
    # Swap only the file extension. A plain str.replace("jpg", "csv") would
    # also rewrite "jpg" inside the stem and would leave names such as "x.JPG"
    # or "x.png" without a ".csv" suffix, so the "*.csv" merge would miss them.
    file_path = f'df_pseudo_labeling/{os.path.splitext(img_name)[0]}.csv'
    if not os.path.exists(file_path):
        # All annotation rows that belong to this image.
        df_tmp = group.get_group(img_name).copy()

        # Only images with exactly one annotation row are pseudo-labelled;
        # every other image is written out unchanged.
        if len(df_tmp) == 1:
            true_label = df_tmp['defect'].item()

            # Label 0 is the "pass" class (per the original comment) and gets no mask.
            if true_label != 0:
                # Convert the centre/size box to corner form, clipped to [0, 100].
                # NOTE(review): the 0-100 range suggests percentage coordinates - confirm.
                x_center = df_tmp['x_center'].item()
                y_center = df_tmp['y_center'].item()
                width    = df_tmp['width'].item()
                height   = df_tmp['height'].item()

                x1   = np.clip(x_center - width  / 2, 0, 100)
                y1   = np.clip(y_center - height / 2, 0, 100)
                x2   = np.clip(x_center + width  / 2, 0, 100)
                y2   = np.clip(y_center + height / 2, 0, 100)
                bbox = [x1, y1, x2, y2]

                # Run test-time-augmented prediction on this single image.
                # Gradients are not needed here, so disable autograd tracking.
                path  = df_tmp['path'].item()
                batch = [path]
                with torch.no_grad():
                    tta_images, tta_preds = batch_prediction_tta(mask_model, clf_model, batch,
                                                                 input_img_size  = CFG.TRAIN.IMG_SIZE,
                                                                 output_img_size = CFG.TEST.IMG_SIZE,
                                                                 ts              = get_ts(3),
                                                                 iou_thr         = CFG.TEST.IOU_THR,
                                                                 skip_box_thr    = CFG.TEST.SKIP_BOX_THR,
                                                                 device          = CFG.TEST.DEVICE)
                pred        = tta_preds[0]
                pred_labels = pred['labels'].astype(int)
                pred_mask   = pred['masks']
                pred_scores = pred['t_scores']

                # Keep only detections above the confidence threshold.
                # NOTE(review): when nothing passes the threshold the unfiltered
                # predictions are kept - confirm this fallback is intended.
                # NOTE(review): pred_mask is not filtered by `index`, unlike the
                # labels and scores - confirm against the layout of pred['masks'].
                index = np.where(pred_scores > CFG.TEST.CONF)[0]
                if len(index) != 0:
                    pred_labels = pred_labels[index]
                    pred_scores = pred_scores[index]

                # Accept the prediction only when it is a single detection whose
                # class matches the annotated label.
                if (len(pred_labels) == 1) and (pred_labels[0] == true_label):
                    # Binarise the mask at 0.5 and require more than 25 foreground pixels.
                    mask = np.where(pred_mask > 0.5, 1, 0)
                    if np.sum(mask) > 25:
                        boundary_points = get_boundary_points(mask, bbox, return_bbox=False)
                        boundary_points = ' '.join(str(item) for item in boundary_points)

                        # Store the boundary as one space-separated string on the single row.
                        df_tmp['boundary_points'] = [boundary_points]

        # Always write the (possibly updated) rows for this image.
        df_tmp.to_csv(file_path, index=False, encoding='utf-8-sig')

In [None]:
# Merge the per-image CSVs, in sorted filename order, into one pseudo-label file.
files = np.sort(glob('df_pseudo_labeling/*.csv'))
frames = list(map(pd.read_csv, tqdm(files)))
df_out = pd.concat(frames)
df_out.to_csv(
    CFG.DATA.PSEUDO_LABEL_FILE,
    index=False,
    encoding='utf-8-sig',
)