# Draftsheet 

### Get the validation errors of a classification learner, plot them in different ways and store the resulting images to be validated by an expert

In [None]:
from PIL import Image  
import PIL
from fastcore.all import *

In [None]:
def buffer_plot_and_get(fig):
    "Save `fig` (any object exposing `savefig`) into an in-memory buffer and return it opened as a PIL image"
    stream = io.BytesIO()
    fig.savefig(stream)
    # Rewind so that PIL starts reading at the first byte written by `savefig`
    stream.seek(0)
    return PIL.Image.open(stream)

In [None]:
def get_concat_h_multi_resize(im_list, resample=PIL.Image.BICUBIC):
    "Rescale every image in `im_list` to the smallest height among them and paste them side by side into one RGB image"
    target_h = min(img.height for img in im_list)
    # Scale each width by the same factor as its height to keep the aspect ratio
    scaled = []
    for img in im_list:
        new_w = int(img.width * target_h / img.height)
        scaled.append(img.resize((new_w, target_h), resample=resample))
    canvas = PIL.Image.new('RGB', (sum(img.width for img in scaled), target_h))
    # Paste left to right; `offset` is the x coordinate of the next free column
    offset = 0
    for img in scaled:
        canvas.paste(img, (offset, 0))
        offset += img.width
    return canvas

In [None]:
def get_multiview_at(ds, idx, ylim):
    "Plot item `idx` of `ds` with three `show_at` variants and return them merged side by side as one PIL image"
    # TODO: Move this to TensorMotion?
    # One keyword set per view: fixed y-limits, default view, and 'stacked' mode
    view_kwargs = [dict(ylim=ylim), dict(), dict(mode='stacked')]
    figs = [show_at(ds, idx, return_fig=True, **kw) for kw in view_kwargs]
    views = [buffer_plot_and_get(f) for f in figs]
    return get_concat_h_multi_resize(views)

In [None]:
import pandas as pd

In [None]:
def generate_classification_error_report(learn, folder, ylim):
    """Save a multi-view image of every misclassified validation item and update the relabelling sheet.

    Warning: Call this function with the %%capture magic, otherwise you will be
    prompted with all the outputs from the calls to `show`. Requires xlrd to read
    xlsx files (only needed when `relabelling.xlsx` already exists in `folder`).

    Args:
        learn: learner whose `get_preds(with_decoded=True)` output is compared
            against the targets, and whose `dls` give access to the validation
            dataset and to the split indices.
        folder: output directory (str or Path). It is created if missing.
        ylim: forwarded to `get_multiview_at` for the first of the three views.

    Side effects:
        Writes `<folder>/<dataset index>.png` for every misclassified item and
        writes/updates `<folder>/relabelling.xlsx` with the columns
        'Item #', 'ML classification' and 'FLI-based classification'.
    """
    # Resolve the output folder once, and make sure the image writes cannot fail
    # just because it does not exist yet.
    folder = Path(folder)
    folder.mkdir(parents=True, exist_ok=True)
    valid_ds = learn.dls.valid.dataset

    _, targets, preds = learn.get_preds(with_decoded=True)
    # `as_tuple=True` keeps the index tensor 1-d even when there is exactly one
    # error; `.nonzero().squeeze()` would collapse it to a 0-d tensor that can
    # neither be iterated nor zipped.
    error_valid_idxs = torch.nonzero(targets != preds, as_tuple=True)[0]
    # Map positions within the validation set to indices in the whole dataset
    error_ds_idxs = tensor(learn.dls.dataset.splits[1])[error_valid_idxs]

    # Plain ints give clean file names and unambiguous dataset indexing
    valid_idx_list = error_valid_idxs.tolist()
    ds_idx_list = error_ds_idxs.tolist()
    for valid_idx, ds_idx in zip(valid_idx_list, ds_idx_list):
        merged_pil = get_multiview_at(valid_ds, valid_idx, ylim)
        merged_pil.save(folder / f'{ds_idx}.png')

    # Generate a spreadsheet with the item indices, the predictions and the targets.
    # The decoder is taken from the learner's own validation dataset rather than
    # from a global `ds`, so the function does not depend on notebook state.
    # NOTE(review): assumes `tfms[1]` is the target pipeline, as the original code did.
    decode = valid_ds.tfms[1].decode
    preds_decoded = [str(decode(x)) for x in preds[error_valid_idxs]]
    targs_decoded = [str(decode(x)) for x in targets[error_valid_idxs]]
    rlbl_df = pd.DataFrame(list(zip(ds_idx_list, preds_decoded, targs_decoded)),
                           columns=['Item #', 'ML classification', 'FLI-based classification'])
    rlbl_path = folder / 'relabelling.xlsx'
    if rlbl_path.exists():
        # Load old relabelling and merge it with the new items
        rlbl_df_old = pd.read_excel(rlbl_path)
        rlbl_df = pd.concat([rlbl_df, rlbl_df_old]).drop_duplicates()
    rlbl_df = rlbl_df.sort_values(by=['Item #'])
    rlbl_df.to_excel(rlbl_path, index=False)