Plotting the wrongly classified images in nice overview plots.

In [None]:
import os
import re
import sys
import json
from argparse import ArgumentParser

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision

from show_failed_imgs import *
from download_imgs import *

In [None]:
# Inputs for this notebook; the later cells read these names directly.
# CSV that is loaded with pd.read_csv below (relative path; presumably one
# subject's session of the colour experiment, judging by the file name).
csv_file = "../raw-data/humans/colour-experiment/colour-experiment_subject-01_session_1.csv"
# Passed to download_wrong_images and get_img as the image location
# (presumably the folder the wrongly answered images are stored in -- TODO confirm).
img_location = "/mnt/qb/work/bethge/tklein16/brains_vs_dnns_out/wrong_imgs"
# Passed to download_wrong_images as its third argument
# (presumably the ImageNet training set the images are taken from -- TODO confirm).
imagenet_path = "/imagenet/train/"

In [None]:
# read CSV
# `df` is used below both as input to get_wrong_images and for its length
# (total number of rows) in the summary line of the overview plot.
df = pd.read_csv(csv_file)

# get list of wrong images
# The cells below unpack every entry of `wrongs` as a
# (response, label, image) triple. get_wrong_images is not defined in this
# notebook; it presumably comes from one of the star imports -- TODO confirm.
wrongs = get_wrong_images(df)

In [None]:
# Fetch every image the subject answered incorrectly.
# Only the third field of each `wrongs` entry (the image) is handed on.
download_wrong_images(
    [image_name for _, _, image_name in wrongs],
    img_location,
    imagenet_path,
)

In [None]:
# make overview plot of wrongly labelled images

def make_overview_plot(ncols=8, scale=2.5, out_file="confusions.png"):
    """Draw all wrongly answered images in one grid, save it and show it.

    Reads the notebook-level names ``wrongs`` (entries are unpacked as
    ``(response, label, img_path)``), ``df`` (only its length is used, for
    the printed summary) and ``img_location`` (forwarded to ``get_img``).
    Each cell is captioned "<response> / <label>".

    Args:
        ncols: Number of grid columns.
        scale: Size of one grid cell in inches (used for both width and
            height of the figure).
        out_file: Path the figure is written to.
    """
    n_wrong = len(wrongs)
    print(f"Subject got {n_wrong} of {len(df)} images wrong.")

    if n_wrong == 0:
        # A grid with zero rows cannot be created by plt.subplots, and there
        # is nothing to draw anyway.
        return

    nrows = int(np.ceil(n_wrong / ncols))

    # squeeze=False always yields a 2-D array of axes, so flatten() is safe
    # for any nrows/ncols combination (including a 1x1 grid).
    fig, axes = plt.subplots(
        nrows,
        ncols,
        sharex=True,
        sharey=True,
        figsize=(ncols * scale, nrows * scale),
        squeeze=False,
    )
    axes = axes.flatten()

    for axis, (response, label, img_path) in zip(axes, wrongs):
        axis.imshow(get_img(img_location, img_path))
        axis.set_xlabel(f"{response} / {label}")

    # The last row is usually only partly filled; hide the leftover cells
    # instead of drawing empty frames.
    for axis in axes[n_wrong:]:
        axis.set_visible(False)

    fig.tight_layout()
    fig.savefig(out_file)
    plt.show()
    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)

make_overview_plot()

In [None]:
def make_confusion_matrix():
    """Plot, save and show a count matrix of the wrong answers.

    Rows correspond to what the subject responded, columns to the label of
    the image. Reads the notebook-level ``wrongs``, whose entries are
    unpacked as ``(response, label, <ignored>)``. The figure is written to
    "confusion_matrix.png".
    """
    # The final entry, 'na', stands for "no response".
    classes = [
        'airplane', 'bear', 'bicycle', 'bird',
        'boat', 'bottle', 'car', 'cat',
        'chair', 'clock', 'dog', 'elephant',
        'keyboard', 'knife', 'oven', 'truck',
        'na',
    ]
    n_classes = len(classes)
    class_numbers = dict(zip(classes, range(n_classes)))

    # One counter per (response, label) pair.
    conf = np.zeros((n_classes, n_classes))
    for response, label, _ in wrongs:
        row = class_numbers[response]
        col = class_numbers[label]
        conf[row, col] += 1

    # Drop the last column: 'na' is kept as a response (row) only.
    conf = conf[:, :n_classes - 1]
    label_names = classes[:-1]

    plt.figure()
    plt.imshow(conf)
    plt.colorbar()
    plt.ylabel("People said...")
    plt.xticks(np.arange(n_classes - 1), label_names, rotation=45)
    plt.xlabel("When it was...")
    plt.yticks(np.arange(n_classes), classes)
    plt.savefig("confusion_matrix.png")
    plt.show()
    plt.close()


make_confusion_matrix()

# Hypotheses

- cropping sometimes cuts off the relevant parts, did Robert crop?
- people have a preference for saying cat / dog instead of bear, because they are much more familiar with pets? (the cases where people said cat instead of bear mostly seem to be images of big wild cats, like lions)
- cats and dogs get confused because sometimes it's just hard to tell...
- oven and knife get confused a lot, seemingly because many pictures of ovens are just weird-looking (grills and stuff like that)
- frequently, the label is unclear, because both objects are in the image