In [2]:
# %load evaluate.py
from __future__ import division

import os
import numpy as np
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score

import torch
from torch.nn import *
import torch.nn.functional as F
from torch import optim
import torchvision.models as models
from torch.utils.data import DataLoader

from models.model import ImgSpecModel
from dataloaderraw import ImageSpecDataset

# Combined image-feature + spectrogram model.
imgspec_model = ImgSpecModel()

# Change model_path to the saved checkpoint you want to evaluate.
# os.path.join builds the path with the host's separator; the original hard-coded a
# Windows "\\" separator, which does not resolve on Linux/macOS.
model_path = os.path.join('save_50_epoch_f1_score', 'model_best_acc_33_0.586962471491.pth')
# map_location forces every tensor in the checkpoint onto the CPU, so a checkpoint
# saved from a GPU run still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
imgspec_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

test_imgspec_dataset = ImageSpecDataset(root_dir='small_dataset_objects', split='test')
# Evaluation metrics do not depend on sample order, so shuffling only adds
# nondeterminism; iterate in a fixed order for reproducible runs.
test_dataloader = DataLoader(test_imgspec_dataset, batch_size=10, shuffle=False, num_workers=4)

# NOTE(review): class_weight / cce_loss are not used by the evaluation in this cell
# (no loss is computed); kept in case other notebook cells reference them.
class_weight = torch.from_numpy(np.array([1.0, 3.0])).float()
cce_loss = CrossEntropyLoss(weight=class_weight)

# Test model performance: predict a class for every test sample, then print the
# sklearn classification report and the macro-averaged F1 score.
imgspec_model.eval()  # put the model in inference mode (affects dropout/batch-norm if present)
correct_class, incorrect_class = 0, 0  # running tallies of right / wrong predictions
all_preds, all_org = [], []  # predicted labels and ground-truth labels, in dataloader order
# Inference only: no_grad skips building the autograd graph, saving memory and time
# (the original built the graph for every batch and detached afterwards).
with torch.no_grad():
    for i_batch, batch in enumerate(test_dataloader):
        img_feats = batch['img_feats'].float()
        # unsqueeze(1) inserts a new axis at position 1 — presumably a channel
        # dimension for the spectrogram branch; TODO confirm against ImgSpecModel.
        spec_imgs = batch['spec'].unsqueeze(1).float()
        # Ground-truth class labels (compared directly against predicted indices below).
        scores = batch['score'].float().numpy()
        logits = imgspec_model(img_feats, spec_imgs)
        # Predicted class = index of the largest output along dim 1. Softmax is
        # monotonic, so applying it first (as the original did) cannot change this.
        _, preds = torch.max(logits, 1)
        preds = preds.cpu().numpy()
        for org, pred in zip(scores, preds):
            correct_class += int(org == pred)
            incorrect_class += int(org != pred)
            all_org.append(org)
            all_preds.append(pred)

print('\n\n')
print(classification_report(all_org, all_preds))
# Macro F1: unweighted mean of the per-class F1 scores. The variable keeps its
# historical name `acc` so later cells that read it keep working.
acc = f1_score(all_org, all_preds, average='macro')
print('TestF1: {}'.format(acc))
print('\n\n')





              precision    recall  f1-score   support

         0.0       0.78      0.80      0.79       952
         1.0       0.41      0.38      0.39       352

    accuracy                           0.68      1304
   macro avg       0.59      0.59      0.59      1304
weighted avg       0.68      0.68      0.68      1304

TestF1: 0.5888260035379691



