This notebook runs evaluation on the validation set using a saved checkpoint. Predictions and ground truth annotation data can be saved to files and loaded later for in-depth analysis and visualization.

In [None]:
import os
from pathlib import Path

import brambox as bb
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torchvision.utils import draw_bounding_boxes

from wheat.config import load_config
from wheat.data_module import WheatDataModule
from wheat.scripts import evaluate
from wheat import visualization as vis

# route pandas `.plot` calls through the plotly backend
pd.set_option('plotting.backend', 'plotly')

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# move from the 'notebooks' directory to the top level directory; the guard keeps
# this cell idempotent, so running it again does not climb one more level up
if Path.cwd().name == 'notebooks':
    os.chdir('..')

## Running inference and evaluation
PyTorch Lightning makes it easy to run evaluation/validation on the validation set using saved weights. We're going to be sneaky here and save off the predicted detections and the ground truth data so that we can do custom analysis later. (The PyTorch Lightning validation process doesn't have an obvious way to return the data, thus the sneakiness of passing the output location through an environment variable below.)

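If you have already run evaluation and saved the results previously, you only need to run the first cell of this section (it defines the output directory and loads the config); skip the remaining cells of this section and continue with the next section to load the previous results.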

In [None]:
# run directory and the checkpoint inside it that will be evaluated
output_dir = Path('lightning_logs/kaggle_version_3')
checkpoint_path = output_dir.joinpath('epoch=9-step=6849.ckpt')

# read the project configuration file
config = load_config('wheat/config/config.ini')

In [None]:
# PyTorch Lightning flags go in this dict (left empty here: no flags are set)
pl_args_dict = {}

In [None]:
# Set this environment variable BEFORE calling evaluate: when it is set,
# detections and ground truth annotations will be saved to .csv files for
# easy loading and analysis later on, without rerunning inference.
os.environ['CMD_WHEAT_OUTPUT_DIR'] = str(output_dir)
evaluate.evaluate(config, pl_args_dict, checkpoint_path)

## Loading saved detections and annotations
Since the evaluation outputs have been saved to disk in the output directory, we can load the detections and ground truth annotations at any time later without having to rerun inference. Annotation and detection data are saved in a format that is compatible with the brambox Python package. Annotation and detection dataframes have a similar format, except that detection data includes a 'confidence' column.

In [None]:
# detections saved by the evaluation run; the first CSV column is the index
det_df = pd.read_csv(output_dir.joinpath('det.csv'), index_col=0)
det_df.head()

In [None]:
# ground truth annotations saved by the evaluation run; the first CSV column is the index
anno_df = pd.read_csv(output_dir.joinpath('anno.csv'), index_col=0)
anno_df.head()

In [None]:
def plot_pr_curve(det_df, anno_df, iou_threshold):
    """Plot a precision-recall curve using the specified IOU threshold.

    Args:
        det_df: detection dataframe (brambox-compatible, includes 'confidence').
        anno_df: ground truth annotation dataframe (brambox-compatible).
        iou_threshold: IOU threshold passed to ``bb.stat.pr``.

    Returns:
        The figure produced by ``DataFrame.plot``, with both axes fixed to
        [0, 1] and the AP in the title. ``update_xaxes``/``update_yaxes`` are
        called on it, so the plotly plotting backend is expected.
    """
    df_pr = bb.stat.pr(det_df, anno_df, threshold=iou_threshold)
    # Close the curve by dropping precision to 0 at the highest recall reached.
    # DataFrame.append was removed in pandas 2.0, so the row is added via concat.
    closing_point = pd.DataFrame(
        [{'precision': 0, 'recall': df_pr['recall'].max(), 'confidence': 0}])
    df_pr = pd.concat([df_pr, closing_point], ignore_index=True)
    ap = bb.stat.ap(df_pr)
    fig = df_pr.plot('recall', 'precision', title=f'AP at IOU {iou_threshold}: {ap:.3f}')
    fig.update_xaxes(range=[0, 1])
    fig.update_yaxes(range=[0, 1])
    return fig

In [None]:
# PR curve over the entire validation dataset at an IOU threshold of 0.5;
# further down, the same kind of curve is plotted for a single image
plot_pr_curve(det_df, anno_df, iou_threshold=0.5)

In [None]:
def get_per_image_ap_values(det_df, anno_df, iou_thresholds):
    """Calculate the AP of each individual image at every given IOU threshold.

    Only images that occur in ``det_df`` are evaluated. The returned dataframe
    has one row per image: an 'image' column plus one AP column per threshold,
    named 'ap' followed by the threshold as a rounded percentage (0.5 -> 'ap50').
    """
    image_ids = det_df.image.unique()
    columns = {'image': image_ids}
    for threshold in iou_thresholds:
        column_name = 'ap' + str(round(100 * threshold))
        # AP of one image: PR curve from that image's detections and annotations
        columns[column_name] = [
            bb.stat.ap(
                bb.stat.pr(
                    det_df[det_df.image == image_id],
                    anno_df[anno_df.image == image_id],
                    threshold=threshold))
            for image_id in image_ids
        ]
    return pd.DataFrame(columns)

In [None]:
# per-image AP at IOU thresholds 0.5 and 0.75 (stored in columns 'ap50' and 'ap75')
image_ap_df = get_per_image_ap_values(det_df, anno_df, iou_thresholds=[0.5, 0.75])

In [None]:
# this code adds the number of ground truth annotations for each image as a new column;
# the counts are looked up per image id and assigned directly, so re-running the cell
# simply overwrites 'num_annos' (a repeated merge would instead create suffixed
# 'num_annos_x'/'num_annos_y' columns and leave no 'num_annos' column)
image_ap_df['num_annos'] = (
    image_ap_df['image']
    .map(anno_df['image'].value_counts())
    .fillna(0)  # images that have no ground truth annotations at all
    .astype(int)
)

In [None]:
# order the images by ascending ap75, then renumber the rows 0..n-1
image_ap_df = image_ap_df.sort_values(by='ap75')
image_ap_df = image_ap_df.reset_index(drop=True)
image_ap_df.head()

In [None]:
# ap50 and ap75 of every image plotted against its row position in image_ap_df;
# hovering over a point shows the image id and its number of annotations
image_ap_df.plot.scatter(x=image_ap_df.index, y=['ap50', 'ap75'], hover_data=['image', 'num_annos'])

We can use the information on which images had the best or worst AP values to plot the images with their ground truth bounding boxes and predicted detections.

In [None]:
def display_image_with_detections(dataset, image_index, det_df=None):
    """Show a dataset image with its ground truth boxes and, optionally, detections.

    Args:
        dataset: indexable dataset; ``dataset[image_index]`` must yield an
            ``(image, labels)`` pair where ``labels['boxes']`` holds the ground
            truth boxes.
        image_index: index of the image in ``dataset``. The same value is used
            to select rows of ``det_df`` through its 'image' column.
        det_df: optional detection dataframe with 'image', 'x_top_left',
            'y_top_left', 'width', 'height' and 'confidence' columns. When
            given, the matching detections are drawn and labelled with their
            confidence rounded to two decimals.
    """
    image, labels = dataset[image_index]
    # plot ground truth bounding boxes in blue
    result = draw_bounding_boxes(
        vis.image_float_to_int_transform(image), labels['boxes'], colors='blue', width=5)
    # plot predicted bounding boxes in yellow
    if det_df is not None:
        det_df_filtered = det_df[det_df.image == image_index]
        # Take an explicit writable copy: the in-place update below must never
        # touch the dataframe's own data, and with pandas Copy-on-Write the
        # array returned by `.values` can be read-only.
        boxes = det_df_filtered[
            ['x_top_left', 'y_top_left', 'width', 'height']].to_numpy(copy=True)
        # turn (x, y, width, height) into (x_min, y_min, x_max, y_max)
        boxes[:, 2:] += boxes[:, :2]
        scores = det_df_filtered['confidence'].round(2).astype(str).values.tolist()
        result = draw_bounding_boxes(
            result, torch.tensor(boxes), labels=scores, colors='yellow', width=5,
            font='DejaVuSans.ttf', font_size=20)
    vis.show(result)

In [None]:
# initialize the dataset: build the data module from the loaded config, run its
# setup for the 'validate' stage, and keep a reference to the validation dataset
wheat_data_module = WheatDataModule(config)
wheat_data_module.setup(stage='validate')
val_dataset = wheat_data_module.val_dataset

In [None]:
# here's an image with AP50 and AP75 at zero
# (only detections with a confidence above 0 are drawn)
plt.rcParams['figure.figsize'] = [10, 10]
display_image_with_detections(val_dataset, 213, det_df.loc[det_df['confidence'] > 0])

In [None]:
# here's an image with very low AP75 but pretty decent AP50
# (only detections with a confidence above 0.5 are drawn)
display_image_with_detections(val_dataset, 64, det_df.loc[det_df['confidence'] > 0.5])

In [None]:
# PR curve of image 64 alone at an IOU threshold of 0.75
plot_pr_curve(
    det_df.loc[det_df['image'] == 64],
    anno_df.loc[anno_df['image'] == 64],
    iou_threshold=0.75,
)

In [None]:
# PR curve of image 64 alone at an IOU threshold of 0.5
plot_pr_curve(
    det_df.loc[det_df['image'] == 64],
    anno_df.loc[anno_df['image'] == 64],
    iou_threshold=0.5,
)