In [1]:
import sys

sys.path.insert(0, "..")
import itertools
import os

import cocpit.config as config
import cocpit

import ipywidgets
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


%load_ext autoreload
%autoreload 2

In [2]:
# Global matplotlib styling for every figure drawn in this notebook:
# extra-large axis/tick text, a fixed legend-title size and a serif font.
plt_params = dict.fromkeys(
    (
        "axes.labelsize",
        "axes.titlesize",
        "xtick.labelsize",
        "ytick.labelsize",
    ),
    "xx-large",
)
plt_params["legend.title_fontsize"] = 12
plt.rcParams.update(plt_params)
plt.rcParams["font.family"] = "serif"

### Check classifications from a specific model and validation dataloader

In [3]:
def get_dataloader(fold):
    """
    Rebuild the validation dataloader that was saved for one fold.

    Args:
        fold: fold identifier that is inserted into the saved file name
    Returns:
        torch.utils.data.DataLoader: unshuffled loader over the saved data
    """
    # file name is assembled from the training settings in cocpit.config
    saved_path = (
        f"{config.VAL_LOADER_SAVE_DIR}e{config.MAX_EPOCHS}"
        f"_val_loader20_bs{config.BATCH_SIZE}"
        f"_k{str(fold)}_vgg16.pt"
    )
    val_data = torch.load(saved_path)

    loader_kwargs = dict(
        batch_size=int(config.BATCH_SIZE[0]),
        shuffle=False,  # keep a stable order across runs
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )
    return torch.utils.data.DataLoader(val_data, **loader_kwargs)

In [4]:
def get_model(fold):
    """
    Load the saved vgg16 model for one fold and put it in evaluation mode.

    The model is placed on config.DEVICE, the same device the validation
    images are moved to before inference, so that model and inputs always
    live on the same device.

    Args:
        fold: fold identifier that is inserted into the saved file name
    Returns:
        the loaded model on config.DEVICE with eval() already applied
    """
    # NOTE(review): torch.load unpickles an entire model object, which can
    # execute arbitrary code; only load model files from a trusted source.
    model = torch.load(
        f"{config.MODEL_SAVE_DIR}e{config.MAX_EPOCHS}"
        f"_bs{config.BATCH_SIZE}"
        f"_k{str(fold)}_vgg16.pt",
        map_location=config.DEVICE,
    ).to(config.DEVICE)
    # eval() switches layers such as dropout/batch norm to inference behavior
    model.eval()
    return model

In [5]:
# release cached GPU memory that PyTorch is holding but no longer using
# before the models are loaded for inference
torch.cuda.empty_cache()

In [6]:
# Use the models and validation dataloaders above to check predictions.
# Loops through all validation loader predictions but only saves the
# incorrect ones. The incorrect predictions are loaded into a gui so that
# a user can decide whether the label was wrong or the model was right.
# The image is automatically moved upon choosing a class from the dropdown.

top_k_preds = 9  # the top k predictions will be displayed in bar chart

# Each accumulator holds one list per batch (only batches that contain at
# least one wrong prediction). They are flattened into numpy arrays after
# the loop; if this cell is interrupted they stay lists of lists.
all_labels = []
all_paths = []
all_topk_probs = []
all_topk_classes = []
all_max_preds = []

try:
    # loop over folds and models to get incorrect predictions
    # includes a validation dataloader per fold to get all of the labeled
    # images (i.e., not just 20% at once with a random shuffle)
    for fold in range(5):
        val_loader = get_dataloader(fold)
        model = get_model(fold)
        for (imgs, labels, paths), _ in val_loader:
            imgs = imgs.to(config.DEVICE)
            # labels are only compared as python ints, so there is no
            # need to send them to the GPU first
            labels = labels.tolist()

            # inference only: no autograd graph, which saves GPU memory
            with torch.no_grad():
                logits = model(imgs)

            # dimension 1 because taking the prediction with the highest
            # probability from all classes across each index in the batch
            _, max_preds = torch.max(logits, dim=1)
            max_preds = max_preds.tolist()

            wrong_idx = [
                i
                for i, (pred, label) in enumerate(zip(max_preds, labels))
                if pred != label
            ]

            # nothing to record when the whole batch was predicted correctly
            if not wrong_idx:
                continue

            # top k predictions for each image in the batch for bar chart
            predictions = F.softmax(logits, dim=1).cpu()
            # topk raises if k exceeds the number of classes
            k = min(top_k_preds, predictions.shape[1])
            # no squeeze(): keep the (batch, k) shape so that a batch with
            # a single image still yields one list of k values per image
            probs, classes = (t.tolist() for t in predictions.topk(k))

            # human label and image path
            all_labels.append([labels[i] for i in wrong_idx])
            all_paths.append([paths[i] for i in wrong_idx])

            # model top k predicted probability and classes per image
            all_topk_probs.append([probs[i] for i in wrong_idx])
            all_topk_classes.append([classes[i] for i in wrong_idx])

            # top predicted class from model
            all_max_preds.append([max_preds[i] for i in wrong_idx])

except FileNotFoundError:
    # deliberately best effort: keep whatever was collected so far
    print("There are files in the dataloader that have already moved and cannot be found.")
    print("This is likely due to running an old model that has not captured the updated file movement.")
    print("Try rerunning the model to update the validation dataloaders.  Stopping prematurely.")

# flatten the per-batch lists into one array with one entry per image
all_labels = np.asarray(list(itertools.chain.from_iterable(all_labels)))
all_paths = np.asarray(list(itertools.chain.from_iterable(all_paths)))
all_topk_probs = np.asarray(list(itertools.chain.from_iterable(all_topk_probs)))
all_topk_classes = np.asarray(list(itertools.chain.from_iterable(all_topk_classes)))
all_max_preds = np.asarray(list(itertools.chain.from_iterable(all_max_preds)))


print('DONE FINDING INCORRECT PREDICTIONS!')
print(f'There are {len(all_labels)} images to check!')

DONE FINDING INCORRECT PREDICTIONS!
There are 6061 images to check!


In [None]:
# If you stopped the previous cell early because there are so many wrong
# predictions, run this cell to capture the labels collected up until you
# stopped waiting.
#
# The accumulators are only flattened while they are still python lists of
# per-batch lists. If the previous cell ran to completion they are already
# numpy arrays, and flattening again would raise on the integer arrays and
# split every path string into single characters, so they are left alone.
_accumulators = (
    all_labels,
    all_paths,
    all_topk_probs,
    all_topk_classes,
    all_max_preds,
)
if all(isinstance(acc, list) for acc in _accumulators):
    # an interrupt can land between two appends of the same batch; only
    # keep batches present in every accumulator so entries stay aligned
    _n_complete = min(len(acc) for acc in _accumulators)
    (
        all_labels,
        all_paths,
        all_topk_probs,
        all_topk_classes,
        all_max_preds,
    ) = (
        np.asarray(list(itertools.chain.from_iterable(acc[:_n_complete])))
        for acc in _accumulators
    )

In [7]:
# Run this cell to restrict the review to one (human label, model label)
# pair, i.e. a single off-diagonal box of the confusion matrix.

# class name -> integer class id, in the order given by config.CLASS_NAMES
label_list = {
    name: class_id
    for name, class_id in zip(
        config.CLASS_NAMES, np.arange(0, len(config.CLASS_NAMES))
    )
}

# change these two lines to focus on wrong predictions from a specific
# category; labeling is easier when focusing on two classes at a time
human_label = label_list["budding"]
model_label = label_list["compact_irregs"]

idx_human = np.where(all_labels == human_label)
idx_model = np.where(all_max_preds == model_label)

# indices where the human chose `human_label` but the model predicted
# `model_label`
wrong_trunc = [i for i in idx_human[0] if all_max_preds[i] == model_label]

# map the two class ids back to their names for the summary message
cat1 = next(name for name, class_id in label_list.items() if class_id == human_label)
cat2 = next(name for name, class_id in label_list.items() if class_id == model_label)
f"{len(wrong_trunc)} wrong predictions between {cat1} and {cat2}"

'167 wrong predictions between budding and compact_irregs'

In [17]:
"""
The code for the ipywidget buttons lives in cocpit/gui.py.

This cell displays a bar chart of the predictions that the model outputs.
A dropdown menu is available to move the image if you think the model got the label right.
When you choose an option from the dropdown list, the image will be moved to that category in the training dataset.
If you don't want to move the image because the human label was correct, simply click "Next".
"""
# Hand the GUI only the rows selected in `wrong_trunc`, in the order:
# human labels, image paths, top-k probabilities, top-k classes and the
# model's top prediction.
selected = (
    arr[wrong_trunc]
    for arr in (
        all_labels,
        all_paths,
        all_topk_probs,
        all_topk_classes,
        all_max_preds,
    )
)
gui = cocpit.gui.GUI(*selected)
gui.make_buttons()

# lay the three GUI widgets out side by side
controls = ipywidgets.HBox([gui.center, gui.menu, gui.forward])
display(controls)

HBox(children=(Output(), Dropdown(description='Category:', index=1, options=('agg', 'budding', 'bullets', 'col…

In [9]:
# presumably the position of the image currently shown by the GUI within
# the selected wrong predictions; confirm against cocpit/gui.py
gui.index

0