In [9]:
import argparse
import cv2
import numpy as np
import torch


from backbones import get_model
from expression import *
from expression.datasets import RAFDBDataset
from utils.utils_config import get_config
from expression.models import SwinTransFER
from torchvision.transforms import Compose, Normalize, ToTensor

# Build the SwinTransFER model, restore the trained weights, and create the
# RAF-DB test loader used by the evaluation loop.

# Accumulators filled by the evaluation loop (one numpy array per batch).
total_predicts = []
total_logits = []

swin = get_model('swin_t')
net = SwinTransFER(swin=swin, swin_num_features=768, num_classes=7, cam=True)

# map_location="cpu": the file name suggests this checkpoint was saved from a
# GPU run, in which case torch.load would by default materialise its tensors on
# that CUDA device (and fail on a machine without it). Since `dict_checkpoint`
# stays alive as a notebook global, that would also keep a redundant copy of
# the checkpoint in GPU memory for the whole session. load_state_dict copies
# the values into `net`'s own parameters wherever they live, and the model is
# moved to the GPU explicitly before evaluation.
dict_checkpoint = torch.load('results/checkpoint_step_59999_gpu_0.pt', map_location="cpu")
net.load_state_dict(dict_checkpoint["state_dict_model"])


dataset_val = RAFDBDataset(choose="test",
                           data_path="dataset/RAF",
                           label_path="dataset/list_patition_label.txt",
                           img_size=112)
# shuffle=False and drop_last=False keep every test sample in dataset order, so
# the concatenated predictions line up index-for-index with the dataset labels
# in the per-class accuracy analysis.
test_loader = torch.utils.data.DataLoader(dataset_val, batch_size=128,
                                          shuffle=False, num_workers=2,
                                          pin_memory=True, drop_last=False)


# Run the model over the whole test split: collect per-batch predictions and
# logits, and report overall (sample-weighted) accuracy.
with torch.no_grad():
    net.cuda()
    net.eval()  # put modules such as dropout / batch-norm in inference mode

    bingo_cnt = 0   # correctly classified samples (kept as a Python int)
    sample_cnt = 0  # samples seen
    for images, target, _ in test_loader:
        img = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # The model returns a pair; only the first element (class scores) is
        # used here. NOTE(review): the second is presumably the CAM output
        # enabled by cam=True — confirm.
        outputs, _ = net(img)

        _, predicts = torch.max(outputs, 1)
        total_predicts.append(predicts.cpu().numpy())
        total_logits.append(outputs.cpu().numpy())

        bingo_cnt += torch.eq(predicts, target).sum().item()
        sample_cnt += outputs.size(0)

    # An empty loader (e.g. a wrong data_path) would otherwise surface as a
    # ZeroDivisionError with no hint about the cause.
    if sample_cnt == 0:
        raise RuntimeError("test_loader produced no samples; check the dataset paths")
    acc = bingo_cnt / sample_cnt
    print("Validation accuracy:%.4f. " % (acc))

Validation accuracy:0.9270. 


In [10]:
# Per-class accuracy on the test split and its unweighted mean over classes.
# Unlike the overall accuracy from the evaluation cell, the mean class
# accuracy is not dominated by the largest classes.
NUM_CLASSES = 7  # must match num_classes used to build the model

new_total_predicts = np.concatenate(total_predicts)
total_labels = np.asarray(test_loader.dataset.labels)
# Predictions are only meaningful per class if they align one-to-one with the
# dataset labels (the loader uses shuffle=False, drop_last=False).
assert new_total_predicts.shape[0] == total_labels.shape[0], (
    f"{new_total_predicts.shape[0]} predictions vs {total_labels.shape[0]} labels")

class_accs = []
for class_num in range(NUM_CLASSES):
    in_class = total_labels == class_num
    if not in_class.any():
        # Accuracy is undefined for a class with no test samples: record NaN
        # and exclude it from the mean instead of crashing on an empty index.
        class_acc = float("nan")
    else:
        class_acc = (new_total_predicts[in_class] == class_num).mean()
    class_accs.append(class_acc)
    print(class_acc)

# NOTE: despite its name, `mean_acc` holds the SUM of per-class accuracies;
# the mean over classes with samples is what gets printed.
mean_acc = np.nansum(class_accs)
print(np.nanmean(class_accs), 'mean')

0.9270516717325228
0.7432432432432432
0.85625
0.9687763713080169
0.899581589958159
0.8641975308641975
0.925
0.883442915300877 mean
