In [1]:
from sklearn.metrics import confusion_matrix, balanced_accuracy_score, classification_report
from sklearn.metrics import accuracy_score, cohen_kappa_score, f1_score, top_k_accuracy_score
from torch import nn, optim
import numpy as np
import sys
np.set_printoptions(threshold=sys.maxsize)

def calculate_metrics(preds, labels, score):
    preds, labels = np.asarray(preds), np.asarray(labels)
    score = np.asarray(score)
    
    cm = confusion_matrix(labels, preds)
    b_acc = balanced_accuracy_score(labels, preds)
    acc = accuracy_score(labels, preds)
    kappa = cohen_kappa_score(labels, preds)
    f1 = f1_score(labels, preds, average = 'weighted')  
    top2 = top_k_accuracy_score(labels, score, k=2)
    top3 = top_k_accuracy_score(labels, score, k=3)
    
    print ("\nAccuracy: " + str(acc))
    print ("Balanced_Accuracy: " + str(b_acc))
    print ("Kappa: " + str(kappa))
    print ("F1: " + str(f1))
    print ("Top-2: ", str(top2))
    print ("Top-3: ", str(top3))
    print (cm)
    print(classification_report(labels, preds))
    
    return b_acc


import torch.nn.functional as F
class FocalLoss(nn.CrossEntropyLoss):
    ''' Focal loss for classification tasks on imbalanced datasets '''

    def __init__(self, gamma, alpha=None, ignore_index=-100, reduction='none'):
        super().__init__(weight=alpha, ignore_index=ignore_index, reduction='none')
        self.reduction = reduction
        self.gamma = gamma

    def forward(self, input_, target):
        cross_entropy = super().forward(input_, target)
        # Temporarily mask out ignore index to '0' for valid gather-indices input.
        # This won't contribute final loss as the cross_entropy contribution
        # for these would be zero.
        target = target * (target != self.ignore_index).long()
        input_prob = torch.gather(F.softmax(input_, 1), 1, target.unsqueeze(1))
        loss = torch.pow(1 - input_prob, self.gamma) * cross_entropy
        return torch.mean(loss) if self.reduction == 'mean' \
               else torch.sum(loss) if self.reduction == 'sum' \
               else loss

In [24]:
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from torchvision import datasets, transforms, models
import torchvision.transforms.functional as TF
import torch.nn.functional as tf
import pandas as pd
import numpy as np
import torch, pickle

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.preprocessing import normalize
import matplotlib.pyplot as plt
%matplotlib inline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

mean = [0.485, 0.456, 0.406] 
std = [0.229, 0.224, 0.225]

data_transforms =  transforms.Compose([
    transforms.Resize(384),
    transforms.CenterCrop(384),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    ])

modelpath = '/mnt/DADOS_PARIS_1/laranjeira/Dengue-Fachada/models/'
with open(f'{modelpath}folds.dict', 'rb') as fp:
    folds = pickle.load(fp)
    
model = models.efficientnet_v2_s(pretrained=True)
num_ftrs = model.classifier[1].in_features
model.classifier = nn.Sequential(
                            nn.Dropout(p=0.2, inplace=True),
                            nn.Linear(in_features=num_ftrs, out_features=5, bias=True))

class_weights = torch.FloatTensor([4.7, 1.0, 0.54, 0.5, 1.2]).to(device)
criterion = FocalLoss(gamma=2, alpha = class_weights, reduction='mean').to(device)


dstpath = '/mnt/DADOS_PARIS_1/laranjeira/Datasets/ICM_streetview'
dataset = datasets.ImageFolder(dstpath, transform=data_transforms)

mae = [[],[],[],[],[]]
results = {'filename': [], 'lat': [], 'long': [], 'label': [], 
           'pred0': [], 'pred1': [], 'pred2': [], 'pred3': [], 'pred4': []}
for fold, test_ids in enumerate(folds['test']):
    print('-'*15, f'FOLD{fold}')
    model.load_state_dict(torch.load(f'{modelpath}best_fold0{fold+1}.pt'))
    model.to(device)
    model.eval()
    
    print('    Testing...')
    losses, errors, correct = [], [], 0.
    val_pred, val_score, val_labels = [], [], []
    with torch.no_grad():
        for it, data in enumerate(dataset):
            inps, labs = data
            
            filename, label = dataset.imgs[it]
            if filename not in results['filename']: 
                results['filename'].append(filename)
                results['lat']     .append(float(filename.split('_')[-2]))
                results['long']    .append(float(filename.split('_')[-1][:-4]))
                results['label']   .append(label)
                
                for i in range(5): results[f'pred{i}'].append(-1)
                
            idx = results['filename'].index(filename)

            inps = inps[None].to(device)
            labs = torch.LongTensor([labs]).to(device)
            output = model(inps)
            _, preds = torch.max(output, 1)
            val_labels += labs.cpu().numpy().tolist()
            val_pred += output.max(1)[1].cpu().numpy().tolist()
            val_score += output.detach().cpu().numpy().tolist()
            correct += torch.sum(preds == labs.data)
            errors += torch.abs(preds - labs.data)

            results[f'pred{fold}'][idx] = preds[0].item()
            
            loss = criterion(output, labs)
            losses.append(loss.data.item()) 
    
    
    for clas in range(5):
        idx = [k for k, v in enumerate(val_labels) if v == clas]
        mae[clas].extend(np.abs(np.array(val_pred)[idx] - np.array(val_labels)[idx]))
        print(f'MAE (ICM={clas+1})', np.mean(mae[clas]), np.std(mae[clas]))
    
    print(f'Loss: {np.mean(losses)}, Acc: {correct/len(dataset)}', 'MAE: ')
    calculate_metrics(val_pred, val_labels, val_score)
    
    ########## Figures ###############
#     plt.style.use('default')
#     plt.rcParams.update({'font.size': 14})
#     cm = confusion_matrix(val_labels, val_pred)
#     cm_val = normalize(cm, axis=1, norm='l1')

#     disp = ConfusionMatrixDisplay(cm_val)
#     disp.plot(cmap='Blues')

#     plt.title('Normalized confusion matrix')
#     plt.xticks(np.arange(5), np.arange(1,6))
#     plt.yticks(np.arange(5), np.arange(1,6))
#     plt.ylabel('True ICM')
#     plt.ylabel('Predicted ICM')
#     plt.savefig(f'cm_fold0{fold+1}.pdf', bbox_inches='tight', format='pdf', dpi=150)
#     plt.show()
    
#     plt.style.use('ggplot')
#     plt.rcParams.update({'font.size': 14})
#     for k in range(5):
#         idx = [i for i, v in enumerate(val_labels) if v == k]
#         scores = tf.softmax(torch.Tensor(val_score))[idx]

#         df = pd.DataFrame(scores.numpy())
#         fig, ax = plt.subplots(figsize=(7, 4))
#         ax.errorbar(np.arange(5), df.mean(), yerr=df.std(), 
#                 linestyle='dotted', marker='o', markersize=8)

#         plt.title(f'ICMNet activations when ICM={k+1}')
#         plt.savefig(f'activations_fold0{fold+1}_ICM0{k+1}.pdf', bbox_inches='tight', format='pdf', dpi=150)
#         plt.show()



  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


--------------- FOLD0
    Testing...
MAE (ICM=1) 1.08 0.844748483277715
MAE (ICM=2) 0.5625 0.7358996503970806
MAE (ICM=3) 0.8521560574948666 0.5463247301367518
MAE (ICM=4) 0.6666666666666666 0.8312473893025002
MAE (ICM=5) 1.3125 0.8305683295190107
Loss: 0.9902392926914126, Acc: 0.40463724732398987 MAE: 

Accuracy: 0.40463724756918473
Balanced_Accuracy: 0.3265293696720801
Kappa: 0.18018741839442598
F1: 0.40432876177558397
Top-2:  0.712789827973074
Top-3:  0.9102468212415856
[[  9  34   2   4   1]
 [ 70 309  71  78   0]
 [ 28 164 114 167  14]
 [  9  18  50 100  15]
 [  1   7  17  46   9]]
              precision    recall  f1-score   support

           0       0.08      0.18      0.11        50
           1       0.58      0.59      0.58       528
           2       0.45      0.23      0.31       487
           3       0.25      0.52      0.34       192
           4       0.23      0.11      0.15        80

    accuracy                           0.40      1337
   macro avg       0.32   

In [26]:
df = pd.DataFrame(results)
df.to_csv('inferences.csv', index=False)

In [28]:
df.iloc[65:70]

Unnamed: 0,filename,lat,long,label,pred0,pred1,pred2,pred3,pred4
65,/mnt/DADOS_PARIS_1/laranjeira/Datasets/ICM_str...,-22.990885,-47.12183,1,1,1,1,1,1
66,/mnt/DADOS_PARIS_1/laranjeira/Datasets/ICM_str...,-22.969423,-47.145522,1,3,3,2,2,3
67,/mnt/DADOS_PARIS_1/laranjeira/Datasets/ICM_str...,-22.992843,-47.120805,1,3,2,3,3,3
68,/mnt/DADOS_PARIS_1/laranjeira/Datasets/ICM_str...,-22.993735,-47.119886,1,1,1,1,3,1
69,/mnt/DADOS_PARIS_1/laranjeira/Datasets/ICM_str...,-22.945076,-47.099512,1,1,3,2,3,1
