In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import torch.nn as nn
from PIL import Image

# Label space of the classifiers: class name -> integer label (list position = label).
class2num = {
    class_name: label
    for label, class_name in enumerate([
        'broad leaves', 'coniferous tree', 'grass land', 'hard ground', 'leafless',
        'stalk', 'stone', 'trunk', 'water', 'wire',
    ])
}
# Inverse table, derived from class2num so the two can never drift apart.
num2class = {label: class_name for class_name, label in class2num.items()}
# Sample-folder digit (as a one-character string) -> integer label.
# The folder order differs from the label order ('grass land' is folder 7 but label 2),
# so this table is spelled out by folder position instead of reusing class2num's order.
id2num = {
    str(folder_digit): class2num[class_name]
    for folder_digit, class_name in enumerate([
        'broad leaves', 'coniferous tree', 'hard ground', 'leafless', 'stalk',
        'stone', 'trunk', 'grass land', 'water', 'wire',
    ])
}

In [2]:
# Index the misclassified-image dump. Each file name is parsed as '<image id>_<model>.png'
# (layout inferred from the parsing below — TODO confirm against the script that writes it).
# Non-PNG entries (e.g. Windows' Thumbs.db) are skipped because they would break the '_' split;
# sorting makes the row order independent of the filesystem's listing order.
wrongs = sorted(
    fn for fn in os.listdir('./pred_show_imags/wrongs') if fn.lower().endswith('.png')
)
df = pd.DataFrame(wrongs, columns=['original_fn'])
# Remove the extension with splitext. str.strip('.png') is NOT a suffix strip: it removes any
# run of the characters '.', 'p', 'n', 'g' from both ends, so e.g. a model name ending in
# 'n' or 'g' would be truncated. (Column name kept for compatibility; it holds the bare stem.)
df['fn_without_exif'] = df['original_fn'].apply(lambda fn: os.path.splitext(fn)[0])
df['fn_id'] = df['fn_without_exif'].apply(lambda stem: stem.split('_')[0])
df['fn_model'] = df['fn_without_exif'].apply(lambda stem: stem.split('_')[1])
# The second character of the image id is its sample-folder digit, mapped to a class via id2num.
df['category'] = df['fn_id'].apply(lambda image_id: num2class[id2num[image_id[1]]])

In [4]:
# df.to_excel('wrongs.xls')

In [5]:
# Distinct image ids that appear in the misclassification dump, in first-seen order.
wrongs_id_list = df['fn_id'].drop_duplicates().tolist()

In [6]:
# Scratch cell: quick check of str.find / str.startswith behaviour. Only the last
# expression's value is displayed. Not used by the analysis; candidate for deletion.
s = "gjghjg"
s.find("j")
s.startswith("g")

True

In [13]:
# Full paths of the sample images whose id appears in the misclassification dump.
# Sorted because os.listdir order is arbitrary; this also fixes the row order of the
# results table built from wrong_images later on.
imgs = sorted(os.listdir('./samples/'))
# Set lookup instead of scanning wrongs_id_list once per file.
wrong_id_set = set(wrongs_id_list)
wrong_images = [
    os.path.join('./samples', fn)
    for fn in imgs
    if os.path.splitext(fn)[0] in wrong_id_set
]
len(wrong_images)

56

In [8]:
def pred_single_image(modelname, data_transforms, img_path):
    """Classify one image with one loaded model and compare the result to its ground truth.

    Args:
        modelname: key into the global ``models`` dict (e.g. 'AlexNet', 'ResNet152').
        data_transforms: callable mapping a PIL RGB image to an image tensor
            (presumably the Resize/ToTensor/Normalize pipeline defined with the models).
        img_path: path to the image file. Its basename also encodes the true label,
            which is decoded by ``imgName2num``.

    Returns:
        Tuple ``(is_correct, probs, true_label, pred_label)``:
        - ``is_correct``: bool.
        - ``probs``: softmax probabilities as a numpy array, presumably of shape (num_classes,).
        - ``true_label`` / ``pred_label``: class names.

    Note: reads the globals ``models`` and ``device``, which are defined in the
    model-loading cell; run that cell before calling this.
    """
    model = models[modelname]
    # Inference mode: disables dropout and freezes batch-norm statistics, so the prediction
    # does not depend on which mode the checkpoint happened to be saved in.
    model.eval()
    # Context manager releases the file handle; convert() returns an independent copy.
    with Image.open(img_path) as im:
        rgb = im.convert("RGB")
    tsr = data_transforms(rgb).unsqueeze(0).to(device)  # add a batch dimension of size 1
    with torch.no_grad():  # no autograd graph is needed for prediction
        outputs = model(tsr)
    probs = nn.functional.softmax(outputs, dim=1).squeeze().cpu().numpy()
    # Plain ints/bools rather than numpy scalars, so results compare and serialize cleanly.
    pre_label_id = int(probs.argmax())
    true_label_id = imgName2num(img_path)
    true_label = num2class[true_label_id]
    pred_label = num2class[pre_label_id]
    is_correct = pre_label_id == true_label_id
    return is_correct, probs, true_label, pred_label

    
def imgName2num(image_path):
    """Return the integer ground-truth label encoded in an image's file name.

    The second character of the basename is the sample-folder digit
    (e.g. './samples/010022.jpg' -> '1'), translated to a label via the
    module-level ``id2num`` table.

    Raises:
        IndexError: if the basename is shorter than two characters.
        KeyError: if that character is not a folder digit known to ``id2num``.
    """
    folder_id = os.path.basename(image_path)[1]
    # Reuse the module-level table instead of a private copy so the two cannot diverge.
    return id2num[folder_id]
    

In [9]:
# All inference runs on the CPU; defined before loading so map_location can use it.
device = torch.device('cpu')

# Checkpoint file for each model.
checkpoints = {
    'AlexNet': 'alexnet-111_10_5_5.pkl',
    'VGG19': 'vgg19-111_8_5_5.pkl',
    'ResNet50': 'resnet50-111_3_5_5.pkl',
    'ResNet152': 'resnet152-111_5_5_5.pkl',
}
# NOTE(security): these files appear to be whole pickled models (they are called directly
# in pred_single_image). torch.load unpickles them, which can execute arbitrary code, so only
# load checkpoints you trust. PyTorch 2.6+ defaults to weights_only=True, which rejects full
# pickled modules; if loading fails there, pass weights_only=False explicitly (trusted files only).
# .eval() returns the module itself and switches it to inference mode (no dropout, frozen
# batch-norm statistics) regardless of the mode the checkpoint was saved in.
models = {
    name: torch.load(path, map_location=device).eval()
    for name, path in checkpoints.items()
}
# Keep the per-model globals this cell has always defined.
alexnet, vgg19, res50, res152 = models.values()

resize = 224  # side length of the square input fed to the models
# Per-channel normalization constants: the standard ImageNet values, presumably what the
# models were trained with (TODO confirm against the training script).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

data_transforms = transforms.Compose([
    transforms.Resize((resize, resize)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Smoke test: classify one known sample with one model.
model_name = 'ResNet152'
image_path = './samples/000019.jpg'
is_correct, probs, true_label, pred_label = pred_single_image(model_name, data_transforms, image_path)

In [16]:
# Re-run every model on each misclassified image and record the predicted class names.
columns = ['image', 'Ground Truth', 'AlexNet', 'VGG19', 'ResNet50', 'ResNet152']
# Model columns are the models to query, so adding a model is a single edit to `columns`.
compared_models = columns[2:]
misidentified_result = []
for i, img in enumerate(wrong_images):
    model_preds = []
    for mname in compared_models:
        # Ground truth comes from the file name, so it is the same for every model.
        _, _, gt, predicted = pred_single_image(mname, data_transforms, img)
        model_preds.append(predicted)
    print(i, gt, *model_preds)
    misidentified_result.append([img, gt, *model_preds])

0 broad leaves coniferous tree broad leaves broad leaves broad leaves
1 broad leaves leafless leafless leafless leafless
2 broad leaves broad leaves stalk broad leaves broad leaves
3 broad leaves broad leaves stalk broad leaves broad leaves
4 coniferous tree broad leaves coniferous tree coniferous tree coniferous tree
5 coniferous tree broad leaves coniferous tree coniferous tree coniferous tree
6 coniferous tree wire wire wire wire
7 coniferous tree leafless coniferous tree coniferous tree coniferous tree
8 coniferous tree stalk coniferous tree coniferous tree coniferous tree
9 coniferous tree coniferous tree grass land coniferous tree coniferous tree
10 coniferous tree broad leaves coniferous tree coniferous tree coniferous tree
11 coniferous tree coniferous tree trunk trunk trunk
12 coniferous tree stone stone coniferous tree coniferous tree
13 coniferous tree grass land coniferous tree coniferous tree coniferous tree
14 hard ground stalk hard ground hard ground hard ground
15 hard 

In [24]:
# One row per misclassified image: path, ground truth, then each model's prediction.
df = pd.DataFrame(data=misidentified_result, columns=columns)
df.head(5)

Unnamed: 0,image,Ground Truth,AlexNet,VGG19,ResNet50,ResNet152
0,./samples\000009.jpg,broad leaves,coniferous tree,broad leaves,broad leaves,broad leaves
1,./samples\000019.jpg,broad leaves,leafless,leafless,leafless,leafless
2,./samples\000023.jpg,broad leaves,broad leaves,stalk,broad leaves,broad leaves
3,./samples\000090.jpg,broad leaves,broad leaves,stalk,broad leaves,broad leaves
4,./samples\010022.jpg,coniferous tree,broad leaves,coniferous tree,coniferous tree,coniferous tree


In [None]:
nodes = list()  # accumulator, starts empty
