# Configuration 

In [20]:
import sys 
import easydict
sys.path.append('../')
from utils.data import name_to_class

# Evaluation settings, exposed as attributes through EasyDict.
args = easydict.EasyDict({
    # directory that holds the checkpoint to evaluate
    'model_path': '../artifacts/pmg_baseline/defect_exp1',
    'arch': 'pmg',  # 'pmg', 'resnet'
    'method': 'baseline',  # 'uni', 'multi', 'baseline'
    'data_root_path': '../../../../dataset/99_ext_car_defect_recognition_v2',
    # comma-separated class names; expanded into lists right below
    'train_class': 'outer_normal,outer_damage',
    'test_class': 'outer_normal,outer_damage',

    'num_workers': 4,
    'batch_size': 128,
    'ce_label': False,
    'show_img': False,  # to show result imgs
})

# Keep the readable class names first ...
args.train_class_name = args.train_class.split(',')
args.test_class_name = args.test_class.split(',')

# ... then replace the comma-separated strings by their name_to_class values.
args.train_class = [name_to_class[name] for name in args.train_class.split(',')]
args.test_class = [name_to_class[name] for name in args.test_class.split(',')]

# Dataloader 

In [2]:
import cv2
import glob 
from PIL import Image, ImageFile

from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms

# Let PIL decode image files whose data ends early instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Per-channel normalisation applied after ToTensor
# (the values match the commonly used ImageNet statistics).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Evaluation preprocessing: fixed 448x448 resize -> tensor -> normalisation.
test_transform = transforms.Compose(
    [transforms.Resize((448, 448)), transforms.ToTensor(), normalize]
)

class SofarExtDataset(Dataset):
    """Image dataset over ``<root>/<subdir>/<name>.jpg`` files, labelled from the path.

    The label is derived from the file path: ``0`` when the path contains
    ``'normal'``, ``1`` when it contains ``'weak'`` or ``'strong'``.

    Args:
        root: directory whose immediate sub-directories contain the ``.jpg`` files.
        transform: optional callable applied to the decoded PIL image.
    """

    def __init__(self, root, transform=None):
        self.root = root
        # glob returns files in arbitrary (filesystem) order; sorting keeps the
        # index -> file mapping identical from run to run.
        self.data = sorted(glob.glob(root + '/*/*.jpg'))

        self.transform = transform

    @staticmethod
    def _label_from_path(path):
        """Return the class id encoded in ``path``; raise ValueError if there is none."""
        if 'normal' in path:
            return 0
        if 'weak' in path or 'strong' in path:
            return 1
        raise ValueError(
            "cannot infer a label from path {!r}: expected it to contain "
            "'normal', 'weak' or 'strong'".format(path)
        )

    def __getitem__(self, index):
        """Return ``(image, label, path)`` for the file at ``index``.

        Raises:
            ValueError: the path carries none of the known label keywords.
            OSError: OpenCV could not decode the file.
        """
        path = self.data[index]
        # Resolve the label first so an unlabelled file fails before any decoding work.
        label = self._label_from_path(path)

        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals an unreadable/missing file by returning None
            # rather than raising; fail with the offending path instead.
            raise OSError("failed to read image: {!r}".format(path))
        # OpenCV decodes to BGR channel order; reverse the last axis to get RGB for PIL.
        x = Image.fromarray(img[:, :, ::-1])

        if self.transform is not None:
            x = self.transform(x)
        return x, label, path

    def __len__(self):
        """Number of image files found under ``root``."""
        return len(self.data)

# Model

In [21]:
import os 
import torch 
import torch.nn as nn 

from model.set_model import set_model

# Build the network described by `args`.
model = set_model(args)

# Restore the weights saved as `last.pth` inside args.model_path, loading onto the CPU.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint_file = os.path.join(args.model_path, 'last.pth')
state_dict = torch.load(checkpoint_file, map_location='cpu')
model.load_state_dict(state_dict)
print('pre-trained v2 model is loaded')

pre-trained v2 model is loaded


# Utils

In [4]:
from sklearn.metrics import *

def calculate_metrics(trues, preds):
    """Score predictions against the ground truth.

    Args:
        trues: ground-truth labels.
        preds: predicted labels, aligned with ``trues``.

    Returns:
        Tuple ``(accuracy, f1, precision, recall)``; f1, precision and recall
        are macro-averaged.
    """
    macro = {'average': 'macro'}
    return (
        accuracy_score(trues, preds),
        f1_score(trues, preds, **macro),
        precision_score(trues, preds, **macro),
        recall_score(trues, preds, **macro),
    )

# Test

In [23]:
from tqdm import tqdm 
import numpy as np 

@torch.no_grad()
def test(dataloader, model):
    """Run ``model`` over ``dataloader`` and collect labels and predicted class ids.

    Reads the notebook-level ``args`` (``arch``, ``method``, ``train_class_name``)
    and ``name_to_class``.

    Args:
        dataloader: iterable yielding ``(img, label, path)`` batches; the third
            item is ignored here.
        model: network to evaluate. It is moved to ``cuda:0`` when CUDA is
            available (CPU otherwise) and switched to eval mode.

    Returns:
        ``(test_labels, test_preds)``: two lists with one entry per sample.
        Labels are the values yielded by the dataloader; predictions are the
        ``name_to_class`` ids of the highest-scoring output for each sample.

    Raises:
        ValueError: ``args.arch`` is neither ``'resnet'`` nor ``'pmg'``.
        NotImplementedError: ``args.method`` is not ``'baseline'``.
    """
    # Validate the configuration before running any inference. Without these
    # checks an unknown arch leaves `out` undefined, and the non-baseline
    # methods silently return no predictions at all.
    if args.arch not in ('resnet', 'pmg'):
        raise ValueError(
            "unsupported arch {!r}: expected 'resnet' or 'pmg'".format(args.arch)
        )
    if args.method != 'baseline':
        raise NotImplementedError(
            "prediction for method {!r} is not implemented; only 'baseline' "
            "is supported".format(args.method)
        )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model.to(device)
    model.eval()

    # Output index -> class id, built once instead of once per batch.
    index_to_class = [name_to_class[name] for name in args.train_class_name]

    test_preds = []
    test_labels = []

    for (img, label, _) in tqdm(dataloader):
        img = img.to(device)

        # forward
        if args.arch == 'resnet':
            out = model(img)
        else:  # 'pmg': the last element of the returned sequence holds the scores used here
            out = model._forward(img)[-1]

        # pred: index of the highest score per sample, mapped to its class id
        _, pred = torch.max(out, 1)
        test_preds.extend(index_to_class[i] for i in pred.view(-1).cpu().tolist())
        test_labels.extend(label.view(-1).cpu().tolist())

    return test_labels, test_preds


def parse_result(test_labels, test_preds, method, task):
    """Restrict labels and predictions to the samples that belong to ``task``.

    Samples are kept when their label is one of the task's two classes:
    ``0``/``2`` for ``'dirt'`` and ``0``/``1`` for ``'defect'``.

    Args:
        test_labels: sequence of ground-truth class ids.
        test_preds: sequence of predicted class ids, aligned with ``test_labels``.
        method: when ``'uni'``, every kept prediction that is not the task's
            positive class (``1`` for defect, ``2`` for dirt) is mapped to ``0``;
            any other value leaves the predictions unchanged.
        task: ``'dirt'`` or ``'defect'``.

    Returns:
        ``(labels, preds)`` for the kept samples. ``labels`` is a NumPy array;
        ``preds`` is a NumPy array, or a plain list when ``method == 'uni'``.

    Raises:
        ValueError: ``task`` is unknown, or the two inputs differ in length.
    """
    # (negative class, positive class) that make up each task
    task_classes = {'dirt': (0, 2), 'defect': (0, 1)}
    if task not in task_classes:
        raise ValueError(
            "unknown task {!r}: expected 'dirt' or 'defect'".format(task)
        )
    negative, positive = task_classes[task]

    test_labels = np.asarray(test_labels)
    test_preds = np.asarray(test_preds)
    if test_labels.shape != test_preds.shape:
        # Without this check the mismatch surfaces as an obscure boolean-index error.
        raise ValueError(
            "test_labels and test_preds must have the same length, got {} and {}".format(
                test_labels.shape, test_preds.shape
            )
        )

    task_idx = np.isin(test_labels, [negative, positive])
    test_labels = test_labels[task_idx]
    test_preds = test_preds[task_idx]

    if method == 'uni':
        # Collapse everything that is not the task's positive class to the negative class.
        test_preds = [positive if v == positive else negative for v in test_preds]

    return test_labels, test_preds

In [24]:
# Sub-folders of args.data_root_path that are evaluated one after another.
cases = [
    '2021-02-04-the-day-snowed',
    '2021-02-05-the-day-after-snowed',
    '2021-04-03-the-day-rained',
    '2021-04-04-the-day-after-rained',
]

for case in cases:
    path = os.path.join(args.data_root_path, case)

    # external validation dataloader
    ext_dataset = SofarExtDataset(path, transform=test_transform)
    print(len(ext_dataset))
    ext_loader = DataLoader(
        ext_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        drop_last=False,
        pin_memory=True,
    )

    # Predict, keep only the samples of the 'defect' task, then score them.
    labels, preds = test(ext_loader, model)
    task_labels, task_preds = parse_result(labels, preds, args.method, task='defect')
    acc, f1, prec, rec = calculate_metrics(task_labels, task_preds)

    # Report: case name, headline metrics, then the confusion matrix.
    print(case)
    print('[Results] Acc || Prec. || Rec. || F1')
    print('{:.3f} {:.3f} {:.3f} {:.3f}'.format(acc, prec, rec, f1))
    print(confusion_matrix(task_labels, task_preds))

1286


100%|██████████| 11/11 [00:40<00:00,  3.70s/it]


[TEST] Acc: 0.564 || precision 0.543 || recall 0.709 || f1 0.441
[[664 552]
 [  9  61]]
913


100%|██████████| 8/8 [00:29<00:00,  3.74s/it]


[TEST] Acc: 0.647 || precision 0.584 || recall 0.730 || f1 0.540
[[516 307]
 [ 15  75]]
1145


 11%|█         | 1/9 [00:13<01:49, 13.72s/it]Premature end of JPEG file
 33%|███▎      | 3/9 [00:16<00:24,  4.03s/it]Premature end of JPEG file
100%|██████████| 9/9 [00:37<00:00,  4.17s/it]


[TEST] Acc: 0.703 || precision 0.625 || recall 0.784 || f1 0.608
[[684 325]
 [ 15 121]]
1066


100%|██████████| 9/9 [00:31<00:00,  3.49s/it]

[TEST] Acc: 0.763 || precision 0.662 || recall 0.808 || f1 0.673
[[686 234]
 [ 19 127]]



