In [1]:
import numpy as np
import pdb  # NOTE(review): not used in any of the cells shown; likely a leftover debugging import
import matplotlib.pyplot as plt
%matplotlib inline
# NOTE(review): 'agg' is a non-interactive backend. After this switch plt.show()
# most likely no longer renders figures inline (the plotting cells below show no
# figure output), so the plots are only visible through the files written by
# savefig. Confirm this is intended; drop this line to see figures in the notebook.
plt.switch_backend('agg')
import itertools

In [2]:
def _as_label_array(labels, name):
    """Validate one label argument of `confusion_matrix` and return it as int64.

    Args:
        labels (list[int] | np.ndarray[int]): Labels to validate.
        name (str): Argument name used in error messages.

    Returns:
        np.ndarray: 1-D int64 array holding the same label values.

    Raises:
        TypeError: If ``labels`` is not a list/ndarray or is not integer-typed.
        ValueError: If ``labels`` is not one-dimensional.
    """
    if isinstance(labels, list):
        labels = np.array(labels)
    if not isinstance(labels, np.ndarray):
        raise TypeError(
            f'{name} must be list or np.ndarray, but got {type(labels)}')
    # Accept every integer dtype, not only int64: labels loaded from disk or
    # produced by argmax on another platform may well be int32.
    if not np.issubdtype(labels.dtype, np.integer):
        raise TypeError(
            f'{name} dtype must be an integer type, but got {labels.dtype}')
    if labels.ndim != 1:
        raise ValueError(f'{name} must be 1-D, but got shape {labels.shape}')
    return labels.astype(np.int64, copy=False)


def confusion_matrix(y_pred, y_real, normalize=None):
    """Compute confusion matrix.

    Rows are indexed by the ground-truth label and columns by the predicted
    label. Only labels that occur in ``y_pred`` or ``y_real`` get a row/column,
    ordered by ascending label value.

    Args:
        y_pred (list[int] | np.ndarray[int]): Prediction labels.
        y_real (list[int] | np.ndarray[int]): Ground truth labels; must have
            the same length as ``y_pred``.
        normalize (str | None): Normalizes confusion matrix over the true
            (rows), predicted (columns) conditions or all the population.
            If None, confusion matrix will not be normalized. Options are
            "true", "pred", "all", None. Default: None.

    Returns:
        np.ndarray: Confusion matrix of shape (num_labels, num_labels);
        int64 counts when ``normalize`` is None, floats otherwise.

    Raises:
        ValueError: If ``normalize`` is not a supported option, or the label
            inputs are not 1-D or differ in length.
        TypeError: If a label input is not a list/ndarray or not integer-typed.
    """
    if normalize not in ['true', 'pred', 'all', None]:
        raise ValueError("normalize must be one of {'true', 'pred', "
                         "'all', None}")

    y_pred = _as_label_array(y_pred, 'y_pred')
    y_real = _as_label_array(y_real, 'y_real')
    # Pairing with zip() would silently drop the unmatched tail of the longer
    # input, hiding misaligned prediction / ground-truth files.
    if len(y_pred) != len(y_real):
        raise ValueError(
            f'y_pred and y_real must have the same length, but got '
            f'{len(y_pred)} and {len(y_real)}')

    # np.unique returns the labels sorted, so searchsorted maps every sample
    # to its row (true label) and column (predicted label) index.
    label_set = np.unique(np.concatenate((y_pred, y_real)))
    num_labels = len(label_set)
    real_index = np.searchsorted(label_set, y_real)
    pred_index = np.searchsorted(label_set, y_pred)
    confusion_mat = np.zeros((num_labels, num_labels), dtype=np.int64)
    # np.add.at accumulates repeated (row, col) pairs, unlike fancy-index +=.
    np.add.at(confusion_mat, (real_index, pred_index), 1)

    # Rows/columns whose total is zero give 0/0 = nan: silence the warning and
    # turn those entries into 0.
    with np.errstate(all='ignore'):
        if normalize == 'true':
            confusion_mat = (
                confusion_mat / confusion_mat.sum(axis=1, keepdims=True))
        elif normalize == 'pred':
            confusion_mat = (
                confusion_mat / confusion_mat.sum(axis=0, keepdims=True))
        elif normalize == 'all':
            confusion_mat = (confusion_mat / confusion_mat.sum())
        confusion_mat = np.nan_to_num(confusion_mat)

    return confusion_mat

In [3]:
def plot_confusion_matrix(cm, class_names, figsize=(10,10), save_path=None, normalize=False, colorbar=False):
    """
    Plot a confusion matrix as an annotated heat map.

    Args:
        cm (array, shape = [n, n]): confusion matrix; rows are true labels and
            columns are predicted labels.
        class_names (array, shape = [n]): names of the classes, in the same
            order as the rows/columns of ``cm``.
        figsize (tuple): figure size in inches.
        save_path (str | None): if given, the figure is also written to this
            path as a JPEG image.
        normalize (bool): if True, ``cm`` is treated as holding fractions and
            each cell is annotated as a percentage with one decimal (zeros are
            shown as '0'); otherwise the raw cell value is printed.
        colorbar (bool): whether to draw a colour bar next to the matrix.
    Returns:
        matplotlib.figure.Figure: the figure containing the confusion matrix.
    """
    cm = np.asarray(cm)
    figure, ax = plt.subplots(figsize=figsize)
    image = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    if colorbar:
        figure.colorbar(image, ax=ax)
    tick_marks = np.arange(len(class_names))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(class_names, rotation=90)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(class_names)

    # Use white text if squares are dark; otherwise black.
    threshold = cm.max() / 2. if cm.size else 0.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        color = "white" if cm[i, j] > threshold else "black"
        if not normalize:
            number = cm[i, j]
        elif cm[i, j] == 0:
            number = '0'
        else:
            number = '{:0.1f}'.format(100 * cm[i, j])
        ax.text(j, i, number, horizontalalignment="center",
                verticalalignment="center", color=color)

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    figure.tight_layout()

    # Save BEFORE plt.show(): with the inline backend show() closes the figure,
    # and a later plt.savefig() would write a new, empty current figure.
    if save_path:
        figure.savefig(save_path, format='jpg')

    plt.show()
    return figure

In [13]:
# Paths to the reference (baseline) model's predictions and their ground truth.
# Each model has its own ground-truth file because the sample order could differ
# between the machines that produced the result files.
ref_prediction_path = '../../work_dirs/r2plus1d_r34_video_3d_8x8x1_900e_ucf101_rgb_20percent_vidssl/results.pkl'
ref_gt_labels_path = '../../work_dirs/r2plus1d_r34_video_8x8x1_180e_ucf101_rgb_fixmatch_20percent_vidssl/gt_labels.npy'

# Class names, indexed by label id.
classnames_path = '../../work_dirs/r2plus1d_r34_video_8x8x1_180e_ucf101_rgb_fixmatch_20percent_vidssl/classnames.npy'

# Paths to the compared (proposed) model's predictions and their ground truth.
comp_prediction_path = '../../work_dirs/r2plus1d_r34_video_8x8x1_360e_ucf101_rgb_all_20percent_vidssl/results.pkl' # final model
comp_gt_labels_path = '../../work_dirs/r2plus1d_r34_video_3d_8x8x1_900e_ucf101_rgb_20percent_vidssl/gt_labels_new.npy'
# Alternative compared runs: to use one, swap in its prediction/ground-truth pair
# (and the matching output path further below).
# comp_prediction_path = '../../work_dirs/r2plus1d_r34_video_8x8x1_180e_ucf101_rgb_taugment_20percent_vidssl/results.pkl' # temp aug all only
# comp_gt_labels_path = '../../work_dirs/r2plus1d_r34_video_8x8x1_180e_ucf101_rgb_fixmatch_20percent_vidssl/gt_labels.npy'
# comp_prediction_path = '../../work_dirs/r2plus1d_r34_video_8x8x1_180e_ucf101_rgb_actorcutmix_20percent_vidssl/results.pkl' # ActorCutMix only
# comp_gt_labels_path = '../../work_dirs/r2plus1d_r34_video_3d_8x8x1_900e_ucf101_rgb_20percent_vidssl/gt_labels_new.npy'

# Output image paths for the baseline and compared plots. Despite the
# "bar_chart" in the variable names, these receive the confusion-matrix images.
ref_output_bar_chart_path = '../../work_dirs/confmat_baseline.jpg'
comp_output_bar_chart_path = '../../work_dirs/confmat_final.jpg'
# comp_output_bar_chart_path = '../../work_dirs/confmat_tempaug_only.jpg'
# comp_output_bar_chart_path = '../../work_dirs/confmat_actorcutmix_only.jpg'

# Figure size (in inches) passed to plot_confusion_matrix.
figsize = (50,50)

# Normalization passed to confusion_matrix: 'true', 'pred', 'all', or None (raw counts).
# normalize = 'true'
normalize = None

In [14]:
# read the data
# NOTE(review): results.pkl is a pickle file; np.load(..., allow_pickle=True)
# unpickles it, which can execute arbitrary code -- only load trusted result files.
ref_scores = np.load(ref_prediction_path, allow_pickle=True)
comp_scores = np.load(comp_prediction_path, allow_pickle=True)

# Predicted label = index of the highest score of each sample. All label arrays
# are cast to int64 explicitly so they pass confusion_matrix's integer-dtype
# check regardless of the platform default integer or the dtype stored on disk.
ref_preds = np.array([entry.argmax() for entry in ref_scores], dtype=np.int64)
comp_preds = np.array([entry.argmax() for entry in comp_scores], dtype=np.int64)
ref_gt_labels = np.asarray(np.load(ref_gt_labels_path), dtype=np.int64)
comp_gt_labels = np.asarray(np.load(comp_gt_labels_path, allow_pickle=True), dtype=np.int64)
classnames = np.load(classnames_path)

# Every prediction must pair with exactly one ground-truth label.
assert len(ref_preds) == len(ref_gt_labels), 'reference predictions and ground truth differ in length'
assert len(comp_preds) == len(comp_gt_labels), 'compared predictions and ground truth differ in length'

In [15]:
# Top-1 accuracy is computed directly from the predicted vs. ground-truth labels,
# so it is correct for every `normalize` setting. The diagonal of a normalized
# matrix holds fractions rather than sample counts, so summing it is not an accuracy.
assert len(ref_preds) == len(ref_gt_labels), 'reference predictions and ground truth differ in length'
assert len(comp_preds) == len(comp_gt_labels), 'compared predictions and ground truth differ in length'

# compute ref model confmat, top-1 accuracy
confmat_ref = confusion_matrix(ref_preds, ref_gt_labels, normalize=normalize)
ref_overall_acc = np.mean(np.asarray(ref_preds) == np.asarray(ref_gt_labels))
print('Ref pred top-1 accuracy: {:.2f}%'.format(100*ref_overall_acc))

# compute compared model confmat, top-1 accuracy
confmat_comp = confusion_matrix(comp_preds, comp_gt_labels, normalize=normalize)
comp_overall_acc = np.mean(np.asarray(comp_preds) == np.asarray(comp_gt_labels))
print('Compared pred accuracy: {:.2f}%'.format(100*comp_overall_acc))

if normalize == 'true':
    # Row-normalized diagonal = per-class recall; its MEAN is the mean class accuracy.
    print('Ref pred mean class accuracy: {:.2f}%'.format(100*np.mean(np.diag(confmat_ref))))
    print('Compared pred mean class accuracy: {:.2f}%'.format(100*np.mean(np.diag(confmat_comp))))

Ref pred top-1 accuracy: 38.91%
Compared pred accuracy: 56.73%


In [7]:
# Annotate cells as percentages only when the matrix was actually normalized;
# with normalize=None it holds raw counts.
plot_confusion_matrix(confmat_ref, classnames, figsize, ref_output_bar_chart_path,
                      normalize=normalize is not None);

In [8]:
# Annotate cells as percentages only when the matrix was actually normalized;
# with normalize=None it holds raw counts.
plot_confusion_matrix(confmat_comp, classnames, figsize, comp_output_bar_chart_path,
                      normalize=normalize is not None);

### Permute both confusion matrices so classes appear in ascending order of the baseline's per-class accuracy

In [9]:
# Class names are indexed by matrix row below, so both must have the same size.
assert confmat_ref.shape[0] == len(classnames), 'class names do not match the confusion-matrix size'

# Per-class accuracy (recall) is computed from raw counts so the ordering does not
# depend on `normalize`. With normalize=None the bare diagonal is a count of correct
# samples, which would rank classes by test-set size rather than by accuracy.
ref_counts = confusion_matrix(ref_preds, ref_gt_labels)
with np.errstate(divide='ignore', invalid='ignore'):
    baseline_class_acc = np.nan_to_num(np.diag(ref_counts) / ref_counts.sum(axis=1))
baseline_acc_ascending_order = np.argsort(baseline_class_acc, kind='stable')
confmat_ref_perm = confmat_ref[np.ix_(baseline_acc_ascending_order, baseline_acc_ascending_order)]

In [10]:
# Percent annotations only for normalized matrices. The image goes to a separate
# "_perm" file so it does not overwrite the unpermuted baseline plot saved earlier.
plot_confusion_matrix(confmat_ref_perm, classnames[baseline_acc_ascending_order], figsize,
                      ref_output_bar_chart_path.replace('.jpg', '_perm.jpg'),
                      normalize=normalize is not None);

In [11]:
# Reorder the compared model's matrix with the baseline's ordering so both plots list
# classes identically; this only lines up when both matrices have the same size.
assert confmat_comp.shape == confmat_ref.shape, 'confusion matrices of the two models have different sizes'
confmat_comp_perm = confmat_comp[np.ix_(baseline_acc_ascending_order, baseline_acc_ascending_order)]

In [12]:
# Percent annotations only for normalized matrices. The image goes to a separate
# "_perm" file so it does not overwrite the unpermuted compared-model plot saved earlier.
plot_confusion_matrix(confmat_comp_perm, classnames[baseline_acc_ascending_order], figsize,
                      comp_output_bar_chart_path.replace('.jpg', '_perm.jpg'),
                      normalize=normalize is not None);