An ipywidget interface for reviewing incorrect model predictions and moving images to the correct class within the training dataset when their human labels are wrong

In [1]:
import sys

sys.path.insert(0, "..")
import itertools

import ipywidgets
import matplotlib.pyplot as plt
import numpy as np
import torch

import cocpit
import cocpit.config as config
import cocpit.gui as gui

%load_ext autoreload
%autoreload 2

In [2]:
# Enlarge axis labels, titles, and tick labels so the prediction bar charts
# stay legible; legend titles get a fixed numeric size instead.
plt_params = {
    rc_key: "xx-large"
    for rc_key in (
        "axes.labelsize",
        "axes.titlesize",
        "xtick.labelsize",
        "ytick.labelsize",
    )
}
plt_params["legend.title_fontsize"] = 12
plt.rcParams.update({"font.family": "serif", **plt_params})

### Check classifications from a specific model and validation dataloader

In [7]:
"""
In gui.py, uses a model and validation dataloader to check predictions.
Loops through all validation loader predictions but only saves incorrect predictions.
The incorrect predictions are loaded into a gui so that a user can decide 
whether the label is wrong (i.e., the model is right)
The image is automatically moved upon choosing a class from the dropdown menu.
"""

# loop over folds and model to get incorrect predictions
# includes a validation dataloader per fold to get all of the labeled images
# (i.e., not just 20% at once with a random shuffle)
top_k_preds = 8  # the top k predictions will be displayed in bar chart
folds = 1  # if kfold cross validation was not used, change to 0

all_labels = []
all_paths = []
all_topk_probs = []
all_topk_classes = []
all_max_preds = []
for fold in range(folds):
    val_data = torch.load(
        "/data/data/saved_val_loaders/no_mask/v1.4.0/e[30]_val_loader20_bs[64]_k0_vgg16.pt"
    )
    val_loader = cocpit.data_loaders.create_loader(val_data, batch_size=100, sampler=None)

    model = torch.load(
        f"{config.MODEL_SAVE_DIR}e{config.MAX_EPOCHS}"
        f"_bs{config.BATCH_SIZE}"
        f"_k{str(fold)}_vgg16.pt"
    ).cuda()
    model.eval()
    
    for batch_idx, ((imgs, labels, paths), index) in enumerate(val_loader):
        all_labels.append(np.array(labels))
        all_paths.append(paths)
        b = cocpit.predictions.BatchPredictions(imgs, model)
        b.find_max_preds()
        b.top_k_preds(top_k_preds)
        all_topk_probs.append(b.probs)
        all_topk_classes.append(b.classes)
        all_max_preds.append(b.max_preds)
all_b = cocpit.predictions.LoaderPredictions(all_labels, all_paths, all_topk_probs, all_topk_classes, all_max_preds)  # combines predictions from all batches
all_b.concatenate_loader_vars()
all_b.find_wrong_indices()

print("DONE FINDING INCORRECT PREDICTIONS!")
print(f"There are {len(all_b.all_labels)} total images in the dataloader.")
print(f"There are {len(all_b.all_labels[all_b.wrong_idx])} wrong predictions to check in the dataloader!")


DONE FINDING INCORRECT PREDICTIONS!
There are 4308 total images in the dataloader.
There are 209 wrong predictions to check in the dataloader!


In [10]:
# to only look between specific categories run this cell
label_list = dict(zip(config.CLASS_NAMES, np.arange(0, len(config.CLASS_NAMES))))
print(label_list)
# change these two lines to focus on wrong predictions from a specific category
# or box within the confusion matrix
# makes labeling easier focusing on two at a time
human_label = label_list["agg"]
model_label = label_list["budding"]

all_b.hone_incorrect_predictions(label_list, human_label, model_label)

{'agg': 0, 'budding': 1, 'bullet': 2, 'column': 3, 'compact_irreg': 4, 'fragment': 5, 'planar_polycrystal': 6, 'rimed': 7, 'sphere': 8}
3 wrong predictions between agg and budding


In [18]:
"""
code for ipywidget buttons called cocpit/gui.py

this cell displays a bar chart of predictions that the model outputs
a dropdown menu is available to move the image if you think the model got the label right 
when you choose an option from the dropdown list, the image will be moved to that category in the training dataset
if you don't want to move the image and the human labeled correctly, simply click "Next"
"""
display_gui = gui.GUI(b=all_b)
display_gui.make_buttons()
display(ipywidgets.HBox([display_gui.center, display_gui.menu, display_gui.forward]))

HBox(children=(Output(), Dropdown(description='Category:', options=('agg', 'budding', 'bullet', 'column', 'com…

In [19]:
import ipywidgets as widgets
from IPython.display import clear_output, display

# Scratch example: print whichever variable is picked in a Dropdown.
vardict = ["var1", "var2"]
select_variable = widgets.Dropdown(
    options=vardict,
    value=vardict[0],
    description="Select variable:",
    disabled=False,
)
# Dedicated output area for the callback. Clearing this area replaces the old
# printout without wiping the dropdown from the cell's own output.
out = widgets.Output()


def get_and_plot(b):
    """Replace the previous printout with the currently selected variable.

    Parameters
    ----------
    b : dict
        Change notification passed by ``observe``. It is not read; the current
        value is taken from ``select_variable`` directly.
    """
    with out:
        clear_output()  # must be called; a bare `clear_output` reference is a no-op
        print(select_variable.value)


display(select_variable, out)
select_variable.observe(get_and_plot, names="value")

Dropdown(description='Select variable:', options=('var1', 'var2'), value='var1')