In [1]:
import sys

sys.path.insert(0, "..")
import os
import cocpit.config as config
import cocpit

import ipywidgets as widgets
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from natsort import natsorted
from PIL import ImagePath
from torch.utils.data import Dataset
from torchvision import transforms

%load_ext autoreload
%autoreload 2

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/vprzybylo/cocpit/6a535b2f39a742b5b3eb138c31afb435



In [2]:
# Hand any cached-but-unused GPU memory back to the driver before loading the
# model (a no-op when CUDA is absent or not yet initialised).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [3]:
# Figure styling shared by every plot below: extra-large axis labels, titles
# and tick labels, a fixed legend-title size, and a serif typeface.
plt_params = {
    key: "xx-large"
    for key in (
        "axes.labelsize",
        "axes.titlesize",
        "xtick.labelsize",
        "ytick.labelsize",
    )
}
plt_params["legend.title_fontsize"] = 12

plt.rcParams.update({"font.family": "serif"})
plt.rcParams.update(plt_params)

### Check classifications from a specific saved model and its validation dataloader

In [141]:
# Load a trained model and the validation dataset it was evaluated on.
# NOTE(review): torch.load unpickles arbitrary Python objects — only use it on
# trusted, locally produced files. On PyTorch >= 2.6 loading whole pickled
# objects (not just state dicts) additionally requires weights_only=False.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_FILE = "e15_bs64_1model(s).pt"  # shared by the model and its val loader

# map_location lets a checkpoint saved on a GPU machine load on a CPU-only one,
# and .to(device) replaces the unconditional .cuda() that crashed without CUDA.
model = torch.load(
    f"/data/data/saved_models/no_mask/{config.TAG}/{MODEL_FILE}",
    map_location=device,
).to(device)
# keep the dataset on the CPU: DataLoader worker processes cannot share CUDA tensors
val_data = torch.load(
    f"/data/data/saved_val_loaders/no_mask/{config.TAG}/{MODEL_FILE}",
    map_location="cpu",
)

# shuffle=True so each review session walks the validation set in a new order
val_loader = torch.utils.data.DataLoader(
    val_data, batch_size=11, shuffle=True, num_workers=20, pin_memory=True
)

In [157]:
# Use the model and validation dataloader loaded above to collect every
# misclassified image, so that images that may have been labelled wrong by
# hand can be reviewed (and moved) with the GUI below.
import shutil

import torch.nn.functional as F
from torch import nn

top_k_preds = 9  # number of class probabilities shown in the bar chart


def collect_incorrect_predictions(model, loader, device, top_k=9):
    """Run ``model`` over ``loader`` and gather every misclassified sample.

    Args:
        model: classifier whose forward pass returns logits indexed as
            (batch, class) — assumed from the ``dim=1`` reductions below.
        loader: iterable yielding ``((imgs, labels, paths), index)`` batches,
            the structure the saved validation dataloader produces.
        device: device the images are moved to before inference.
        top_k: number of top predictions kept per image; clipped to the
            number of classes the model outputs so ``topk`` cannot fail.

    Returns:
        Tuple of numpy arrays ``(labels, paths, topk_probs, topk_classes,
        max_preds)`` with one entry (row) per misclassified image. All arrays
        are empty when every image is classified correctly.
    """
    model.eval()
    labels_out, paths_out, probs_out, classes_out, preds_out = [], [], [], [], []

    # Inference only: no_grad avoids building the autograd graph for each
    # batch, which otherwise makes this loop slow and memory hungry.
    with torch.no_grad():
        for (imgs, labels, paths), _ in loader:
            logits = model(imgs.to(device))
            # highest-scoring class for every image in the batch
            max_preds = logits.argmax(dim=1).cpu().tolist()
            labels = labels.cpu().tolist()

            wrong_idx = [
                i
                for i, (pred, label) in enumerate(zip(max_preds, labels))
                if pred != label
            ]
            # nothing to review in a batch that was classified entirely correctly
            if not wrong_idx:
                continue

            k = min(top_k, logits.shape[1])
            probs, classes = F.softmax(logits, dim=1).cpu().topk(k, dim=1)
            # No squeeze(): keeps the batch dimension even for a 1-image batch,
            # so probs[i] / classes[i] are always per-image lists.
            probs, classes = probs.tolist(), classes.tolist()

            # human label and image path
            labels_out.extend(labels[i] for i in wrong_idx)
            paths_out.extend(paths[i] for i in wrong_idx)
            # model top-k probabilities and classes per image
            probs_out.extend(probs[i] for i in wrong_idx)
            classes_out.extend(classes[i] for i in wrong_idx)
            # top predicted class from the model
            preds_out.extend(max_preds[i] for i in wrong_idx)

    return (
        np.asarray(labels_out),
        np.asarray(paths_out),
        np.asarray(probs_out),
        np.asarray(classes_out),
        np.asarray(preds_out),
    )


(
    all_labels,
    all_paths,
    all_topk_probs,
    all_topk_classes,
    all_max_preds,
) = collect_incorrect_predictions(model, val_loader, config.DEVICE, top_k=top_k_preds)

print(f"There are {len(all_labels)} images to check!")

KeyboardInterrupt: 

In [168]:
from IPython.display import display, clear_output
from PIL import Image


class GUI:
    """Interactive reviewer for images the model classified incorrectly.

    Each "Next" click draws the current image, with its human label, the
    model's top prediction and a bar chart of the top-k class probabilities.
    Choosing a category in the dropdown reports where the displayed image
    would be moved under ``config.DATA_DIR``. This is a dry run: the actual
    ``shutil.move`` is commented out.
    """

    # Destination folder names (under config.DATA_DIR) offered in the dropdown.
    # NOTE(review): "planar_polycrsytals" was a typo; confirm it matches the folder.
    CATEGORIES = [
        "agg",
        "budding",
        "bullets",
        "columns",
        "compact_irregs",
        "fragments",
        "planar_polycrystals",
        "rimed",
        "spheres",
    ]

    def __init__(self, all_labels, all_paths, all_topk_probs, all_topk_classes, all_max_preds):
        """Store the per-image arrays; entry ``i`` of each describes image ``i``."""
        self.index = 0  # position of the next image to display
        self.all_labels = all_labels
        self.all_paths = all_paths
        self.all_topk_probs = all_topk_probs
        self.all_topk_classes = all_topk_classes
        self.all_max_preds = all_max_preds
        self.path = None  # path of the image currently on screen (None = none yet)

    def make_buttons(self):
        """Create the output area, the category dropdown and the "Next" button."""
        self.center = widgets.Output()  # center image with predictions

        # value=None leaves the menu blank, so *any* pick fires a change event,
        # including choosing the same category for consecutive images.
        self.menu = widgets.Dropdown(
            options=self.CATEGORIES, value=None, description="Category:"
        )
        # Only react to the selected value. Without names= the handler fires
        # for every trait change (index, label, ...).
        self.menu.observe(self.on_change, names="value")

        # button that progresses through the incorrect predictions
        self.forward = widgets.Button(description="Next")
        self.forward.on_click(self.on_button_next)

    def on_change(self, change=None):
        """Dropdown callback: move the displayed image to the chosen category.

        ``change`` is the traitlets change dict passed by ``observe``. A new
        value of None means the menu was reset, and is ignored.
        """
        if change is not None and change.get("new") is None:
            return
        self.save_image()

    def on_button_next(self, b):
        """Button callback: show the next incorrect prediction, if any remain."""
        if self.index >= len(self.all_labels):
            with self.center:
                clear_output()
                print("No more incorrect predictions to review.")
            return
        self.bar_chart()
        self.index += 1
        # clear the previous choice so the new image needs an explicit selection
        self.menu.value = None

    def bar_chart(self):
        """Load the data for ``self.index`` and draw it in the output area."""
        self.label = self.all_labels[self.index]
        self.path = self.all_paths[self.index]
        self.topk_probs = self.all_topk_probs[self.index]
        self.topk_classes = self.all_topk_classes[self.index]
        self.max_pred = self.all_max_preds[self.index]

        # class names ordered by predicted probability (highest first)
        crystal_names = [config.CLASS_NAMES[e] for e in self.topk_classes]

        # render the figure inside the widgets.Output() area
        with self.center:
            self.view_classifications(crystal_names)

    def view_classifications(self, crystal_names):
        """Plot the current image above a bar chart of its top-k probabilities."""
        clear_output()  # so that the next fig doesn't display below the last one
        fig, (ax1, ax2) = plt.subplots(
            constrained_layout=True, figsize=(5, 7), ncols=1, nrows=2
        )
        # context manager closes the file handle once imshow has read the pixels
        with Image.open(self.path) as image:
            ax1.imshow(image)
        ax1.set_title(
            f"Human Labeled as: {config.CLASS_NAMES[self.label]}\n"
            f"Model Labeled as: {crystal_names[0]}"
        )
        ax1.axis("off")

        y_pos = np.arange(len(self.topk_probs))
        ax2.barh(y_pos, self.topk_probs, align="center")
        ax2.set_yticks(y_pos)
        ax2.set_yticklabels(crystal_names)
        ax2.tick_params(axis="y", rotation=45)
        ax2.invert_yaxis()  # labels read top-to-bottom
        ax2.set_title("Class Probability")
        plt.show()
        plt.close(fig)  # avoid accumulating figures across many clicks

    def save_image(self):
        """Report (dry run) moving the displayed image to the selected category."""
        if self.path is None:
            return  # nothing has been displayed yet
        filename = os.path.basename(self.path)
        destination = f"{config.DATA_DIR}{self.menu.value}/{filename}"

        try:
            print(self.path, destination)
            # shutil.move(self.path, destination)
        except FileNotFoundError:
            # image already moved or deleted: nothing left to do
            pass

In [169]:
import ipywidgets
from ipywidgets import AppLayout, Button
from PIL import Image

# Build the reviewer from the misclassified samples collected above and stack
# its controls vertically: category dropdown, "Next" button, figure area.
gui = GUI(
    all_labels,
    all_paths,
    all_topk_probs,
    all_topk_classes,
    all_max_preds,
)
gui.make_buttons()

review_panel = ipywidgets.VBox([gui.menu, gui.forward, gui.center])
display(review_panel)


HBox(children=(Dropdown(description='Category:', options=('agg', 'budding', 'bullets', 'columns', 'compact_irr…