In [1]:
import torch
import pandas as pd
import numpy as np
import time
import os
import random
import matplotlib.pyplot as plt

from datasets.utils import get_dataloaders
from models.model import get_model, eval_model
from models.utils import set_random_seed

## Set parameters

In [2]:
# Input/output locations. val_dir defaults to the original author's local copy of
# the validation images; set the AVIDNET_VAL_DIR environment variable to point at
# another location without editing the notebook.
val_dir = os.environ.get('AVIDNET_VAL_DIR',
                         '/media/aiffel0042/SSD256/temp/AVIDNet/data/CXR/ori/test/')
model_dir = 'trained_models/'  # passed to get_model to load the trained weights
out_dir = 'results'            # passed to eval_model as its output directory

In [3]:
# Evaluation settings: the batch size handed to get_dataloaders / eval_model and
# the seed handed to set_random_seed.
batch_size, random_seed = 16, 0

In [4]:
# Fix randomness through the project helper so runs are repeatable.
# NOTE(review): presumably seeds python/numpy/torch — confirm in models.utils.
set_random_seed(random_seed)

## Get data loaders

In [5]:
# Build the validation data loaders; the dataset summary printed below comes from
# this call. The result is passed unchanged to eval_model.
dataloaders = get_dataloaders(val_dir, batch_size)

Validation dataset size: 300
['covid-19', 'normal', 'pneumonia']
######### Validation Dataset #########
covid-19 size: 100
normal size: 100
pneumonia size: 100


## Evaluate the model

In [12]:
# Load the trained model(s) for the selected layout and evaluate them on the
# validation loaders; eval_model returns preds_dict, inspected in the next section.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mtype = 'MOVR'

# Arguments identical for both model layouts.
common_eval_kwargs = dict(dataloaders=dataloaders,
                          batch_size=batch_size,
                          device=device,
                          out_dir=out_dir,
                          mtype=mtype)

if mtype == 'OVR':
    # 'OVR' layout: get_model returns separate covid / normal / pneumonia models.
    covid_ft, normal_ft, pneumonia_ft, criterion = get_model(model_dir, device, mtype)
    preds_dict = eval_model(covid_model=covid_ft,
                            normal_model=normal_ft,
                            pneumonia_model=pneumonia_ft,
                            criterion=criterion,
                            **common_eval_kwargs)
else:
    # Any other mtype (here 'MOVR'): a covid model plus one combined
    # normal-vs-pneumonia model.
    covid_ft, nor_pneu_ft, criterion = get_model(model_dir, device, mtype)
    preds_dict = eval_model(covid_model=covid_ft,
                            nor_pneu_model=nor_pneu_ft,
                            criterion=criterion,
                            **common_eval_kwargs)

Evaluation Result
----------
1 1 [0.859614372253418, 0.8777777552604675, 0.12222222238779068]
0 0 [0.9997851252555847, 0.0020214642863720655, 0.9979785084724426]
2 2 [0.9242223501205444, 0.003657880239188671, 0.996342122554779]
0 0 [0.9872565269470215, 0.06248454749584198, 0.9375154376029968]
0 0 [0.999962568283081, 0.00043891146196983755, 0.9995611310005188]
1 1 [0.3353807032108307, 0.9695004224777222, 0.03049965761601925]
0 0 [0.9995794892311096, 0.7782212495803833, 0.2217787653207779]
1 1 [0.052955713123083115, 0.9939749836921692, 0.006025043781846762]
1 0 [0.9053738117218018, 0.8796023726463318, 0.12039757519960403]
2 2 [0.7140381932258606, 0.0005989843630231917, 0.9994009733200073]
2 2 [0.7906617522239685, 0.28012824058532715, 0.7198717594146729]
0 0 [0.9907492995262146, 0.25417375564575195, 0.7458261847496033]
1 1 [0.016811557114124298, 0.9967058300971985, 0.0032942225225269794]
0 2 [0.9903312921524048, 0.5288403034210205, 0.4711596965789795]
1 1 [0.3451266586780548, 0.9813905954

## Verify prediction results: per-case counts and score statistics

In [13]:
# Number of collected predictions for keys (i, 0), i = 0..2. These sum to 100, the
# size of one validation class — presumably the second index is the true label.
tuple(len(preds_dict[(i, 0)]) for i in range(3))

(89, 7, 4)

In [14]:
# Number of collected predictions for keys (i, 1), i = 0..2.
tuple(len(preds_dict[(i, 1)]) for i in range(3))

(4, 89, 7)

In [15]:
# Number of collected predictions for keys (i, 2), i = 0..2.
tuple(len(preds_dict[(i, 2)]) for i in range(3))

(11, 5, 84)

In [16]:
# Summarise the scores stored for every (i, j) key of preds_dict: mean and std of
# each of the three score columns plus the number of samples. Cases with no
# samples are omitted from the table.
# NOTE(review): counts for a fixed second index sum to the per-class size (100),
# so keys are presumably (predicted, true) — confirm against eval_model.
cases = [(i, j) for i in range(3) for j in range(3)]
preds_cols = ('covid', 'normal', 'pneumonia')
preds_stat_cols = ('covid-mean', 'covid-std', 'normal-mean', 'normal-std',
                   'pneumonia-mean', 'pneumonia-std', 'count')

preds_dfs = []  # NOTE(review): never populated in this cell
stat_rows = []
stat_index = []
for case in cases:
    # .get() so a case without predictions yields an empty frame, not a KeyError.
    preds_df = pd.DataFrame(preds_dict.get(case, []), columns=preds_cols)
    if preds_df.empty:
        # Mean/std are undefined without samples. (The former `np.nan in list`
        # test compared by identity/==, and NaN != NaN, so it never caught these.)
        continue

    preds_stat = []
    for preds_col in preds_cols:
        preds_stat.extend((preds_df[preds_col].mean(), preds_df[preds_col].std()))
    preds_stat.append(len(preds_df))

    stat_rows.append(preds_stat)
    stat_index.append(str(case))

# Build the table once: growing it with DataFrame.append is quadratic and that
# method no longer exists in pandas >= 2.0.
preds_stat_df = pd.DataFrame(stat_rows, columns=preds_stat_cols, index=stat_index)
preds_stat_df

Unnamed: 0,covid-mean,covid-std,normal-mean,normal-std,pneumonia-mean,pneumonia-std,count
"(0, 0)",0.993615,0.009938,0.136975,0.250335,0.863025,0.250335,89
"(0, 1)",0.974884,0.016631,0.949424,0.041899,0.050576,0.041899,4
"(0, 2)",0.980647,0.015768,0.123333,0.275923,0.876667,0.275923,11
"(1, 0)",0.785527,0.128689,0.891772,0.090246,0.108228,0.090246,7
"(1, 1)",0.295585,0.285146,0.945086,0.089307,0.054914,0.089307,89
"(1, 2)",0.109765,0.079181,0.909751,0.049221,0.090249,0.049221,5
"(2, 0)",0.851098,0.104509,0.065582,0.092385,0.934418,0.092385,4
"(2, 1)",0.558903,0.272959,0.33678,0.12381,0.66322,0.12381,7
"(2, 2)",0.427008,0.27624,0.032291,0.079106,0.967709,0.079106,84
