# Configuration 

In [7]:
import sys
import easydict

# Make the project root importable. Guarded so re-running this cell does not
# append a duplicate '../' entry to sys.path every time.
if '../' not in sys.path:
    sys.path.append('../')
from utils.data import name_to_class

# Comma-separated class names; train and test currently use the same set.
CLASS_NAMES = 'outer_normal,outer_damage,outer_dirt,outer_wash,inner_wash,inner_dashboard,inner_cupholder,inner_glovebox,inner_washer_fluid,inner_rear_seat,inner_sheet_dirt'

args = easydict.EasyDict(
    {
        'model_path': '../artifacts/pmg_uni/exp1',  # directory containing last.pth
        'arch': 'pmg',  # 'pmg', 'resnet'
        'method': 'uni',  # 'uni', 'multi', 'baseline' (NOTE: test() does not decode 'multi' predictions)
        'dataset': 'sofar_v3',
        'data_root_path': '../../../../dataset/',
        'train_class': CLASS_NAMES,
        'test_class': CLASS_NAMES,

        'num_workers': 4,
        'batch_size': 128,
        'ce_label': False,
        'show_img': False,  # to show result imgs
    }
)

# Keep the human-readable class names, then replace the comma-separated
# strings with the corresponding label ids from name_to_class.
args.train_class_name = args.train_class.split(',')
args.test_class_name = args.test_class.split(',')
args.train_class = [name_to_class[name] for name in args.train_class_name]
args.test_class = [name_to_class[name] for name in args.test_class_name]

# Dataloader 

In [2]:
from utils.data import create_dataloader

# Only the test split is evaluated in this notebook. The unused train loader is
# bound to a named variable rather than `_`: in IPython/Jupyter `_` holds the
# last cell output, and once the user assigns it IPython stops updating
# `_`/`__`/`___`.
train_loader, test_loader = create_dataloader(args)

# Model

In [3]:
import os
import torch
import torch.nn as nn

from model.set_model import set_model

# Build the network described by `args`, then restore the weights stored as
# `last.pth` inside the experiment directory. Tensors are mapped onto the CPU
# while loading.
# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code -- only point model_path at checkpoints you trust.
model = set_model(args)

checkpoint_path = os.path.join(args.model_path, 'last.pth')
state_dict = torch.load(
    checkpoint_path,
    map_location='cpu',
)
model.load_state_dict(state_dict)

print('pre-trained v2 model is loaded')

pre-trained v2 model is loaded


# Utils

In [4]:
from sklearn.metrics import *

def calculate_metrics(trues, preds):
    """Score predicted labels against ground-truth labels.

    Args:
        trues: ground-truth class labels.
        preds: predicted class labels, aligned element-wise with ``trues``.

    Returns:
        Tuple ``(accuracy, f1, precision, recall)``. F1, precision and recall
        use sklearn's ``average='macro'`` (unweighted mean over classes).
    """
    macro = {'average': 'macro'}
    return (
        accuracy_score(trues, preds),
        f1_score(trues, preds, **macro),
        precision_score(trues, preds, **macro),
        recall_score(trues, preds, **macro),
    )

# Test

In [5]:
from tqdm import tqdm 
import numpy as np 

@torch.no_grad()
def test(dataloader, model, arch=None, method=None):
    """Run ``model`` over ``dataloader`` and collect true and predicted class ids.

    Args:
        dataloader: iterable of ``(img, label)`` tensor batches.
        model: network to evaluate. It is moved to ``cuda:0`` when CUDA is
            available (CPU otherwise) and put in eval mode -- both in place.
        arch: ``'resnet'`` (plain ``model(img)``) or ``'pmg'`` (last element
            of ``model._forward(img)``). Defaults to ``args.arch``.
        method: ``'uni'`` (argmax index used directly as the predicted id) or
            ``'baseline'`` (argmax index mapped through
            ``args.train_class_name`` and ``name_to_class``). Defaults to
            ``args.method``.

    Returns:
        ``(test_labels, test_preds)``: two lists of equal length.

    Raises:
        ValueError: if ``arch`` or ``method`` is not recognised.
        NotImplementedError: for ``method='multi'``. That branch used to
            produce no predictions while still collecting labels, so the two
            returned lists silently had different lengths.
    """
    arch = args.arch if arch is None else arch
    method = args.method if method is None else method
    if arch not in ('resnet', 'pmg'):
        raise ValueError(f"unsupported arch {arch!r}; expected 'resnet' or 'pmg'")
    if method == 'multi':
        raise NotImplementedError("prediction decoding for method='multi' is not implemented")
    if method not in ('uni', 'baseline'):
        raise ValueError(f"unsupported method {method!r}; expected 'uni' or 'baseline'")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    index_to_id = None
    if method == 'baseline':
        # The classifier predicts an index into the training class list;
        # translate index -> class name -> id once, instead of per batch.
        index_to_id = [name_to_class[name] for name in args.train_class_name]

    test_preds = []
    test_labels = []

    for img, label in tqdm(dataloader):
        img = img.to(device)

        # forward
        if arch == 'resnet':
            out = model(img)
        else:  # 'pmg': only the last element of _forward's output is scored
            out = model._forward(img)[-1]

        # pred: index of the highest-scoring class for each sample
        _, pred = torch.max(out, 1)
        pred = pred.view(-1).cpu().tolist()
        if index_to_id is not None:
            pred = [index_to_id[i] for i in pred]
        test_preds.extend(pred)

        test_labels.extend(label.view(-1).cpu().tolist())

    return test_labels, test_preds

def parse_result(test_labels, test_preds, method, task):
    """Restrict labels/predictions to one binary evaluation task.

    Keeps only the samples whose true label is 0 or the task's positive class
    (1 for ``'defect'``, 2 for ``'dirt'``). For ``method == 'uni'`` every
    prediction other than the positive class is collapsed to 0, so the task is
    scored as a two-class problem; other methods keep their raw predictions.

    Args:
        test_labels: sequence of true label ids.
        test_preds: sequence of predicted label ids, aligned with test_labels.
        method: evaluation method name (e.g. ``'uni'``, ``'baseline'``).
        task: ``'dirt'`` or ``'defect'``.

    Returns:
        ``(labels, preds)`` as numpy arrays restricted to the task's samples.

    Raises:
        ValueError: if ``task`` is unknown or the two inputs differ in length.
    """
    positive_by_task = {'defect': 1, 'dirt': 2}
    if task not in positive_by_task:
        raise ValueError(f"unknown task {task!r}; expected one of {sorted(positive_by_task)}")
    positive = positive_by_task[task]

    labels = np.asarray(test_labels)
    preds = np.asarray(test_preds)
    if len(labels) != len(preds):
        raise ValueError(
            f"got {len(labels)} labels but {len(preds)} predictions; they must be aligned"
        )

    # Keep class 0 plus this task's positive class.
    task_mask = np.isin(labels, [0, positive])
    labels = labels[task_mask]
    preds = preds[task_mask]

    if method == 'uni':
        # Binarise: anything not predicted as the positive class counts as 0.
        preds = np.where(preds == positive, positive, 0)

    return labels, preds

In [8]:
# Run inference over the whole test split and report overall (all-class)
# metrics, so the "[TEST] ..." summary line is produced by this cell itself.
labels, preds = test(test_loader, model)
acc, f1, prec, rec = calculate_metrics(labels, preds)
print('[TEST] Acc: {:.3f} || precision {:.3f} || recall {:.3f} || f1 {:.3f}'.format(acc, prec, rec, f1))

100%|██████████| 8/8 [00:30<00:00,  3.82s/it]

[TEST] Acc: 0.751 || precision 0.747 || recall 0.745 || f1 0.746





In [None]:
# dirt: score only samples whose true label is 0 or 2 (see parse_result)
parsed_labels, parsed_preds = parse_result(labels, preds, args.method, 'dirt')
acc, f1, prec, rec = calculate_metrics(parsed_labels, parsed_preds)
print('[Dirt Results] Acc || Prec. || Rec. || F1')
print('{:.3f} {:.3f} {:.3f} {:.3f}'.format(acc, prec, rec, f1))
# Confusion matrix of the task-restricted data, matching the metrics above
# (previously the unfiltered all-class matrix was printed here).
print(confusion_matrix(parsed_labels, parsed_preds))

In [None]:
# defect: score only samples whose true label is 0 or 1 (see parse_result)
parsed_labels, parsed_preds = parse_result(labels, preds, args.method, 'defect')
acc, f1, prec, rec = calculate_metrics(parsed_labels, parsed_preds)
print('[Defect Results] Acc || Prec. || Rec. || F1')
print('{:.3f} {:.3f} {:.3f} {:.3f}'.format(acc, prec, rec, f1))
# Confusion matrix of the task-restricted data, matching the metrics above
# (previously the unfiltered all-class matrix was printed here).
print(confusion_matrix(parsed_labels, parsed_preds))