In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from colorsys import rgb_to_hsv
from IPython.display import display_svg

from pyrejection.rejection import is_reject_or_null_mask
from pyrejection.datasets import DATASETS, prepare_dataset
from pyrejection.classifiers import CLASSIFIERS
from pyrejection.experiments import run_experiment
from pyrejection.evaluation import get_all_summaries, get_experiment_base_summary, render_svg_fig, plot_legend_svg

## Experiment-level Analysis

In [None]:
# Run the unscaled logistic-regression experiment on the skin-segmentation
# dataset, scored by accuracy. Results are keyed by this configuration and
# presumably reused from `results_cache` when already computed.
exp_result = run_experiment(
    metric_name='accuracy',
    classifier_name='unscaled-logreg',
    dataset_name='skin-segmentation',
    random_state=0,
    cache_dir='results_cache',
)

## Pixel Comparison

In [None]:
# RGB colours (one 0-255 value per channel) used for the pixel images and legend.
BACKGROUND_COLOUR = dict(r=255, g=255, b=255)   # default for pixels matching no rule
TRUE_CLASS_COLOUR = dict(r=170, g=170, b=170)   # label 'c1' (skin)
FALSE_CLASS_COLOUR = dict(r=0, g=0, b=0)        # label 'c2' (not skin)
CORRECT_COLOUR = dict(r=20, g=102, b=20)        # prediction matches the label
INCORRECT_COLOUR = dict(r=255, g=51, b=51)      # prediction differs from the label
REJECT_COLOUR = dict(r=38, g=38, b=191)         # rejected / null prediction

def colour_to_hex(c):
    """Format an ``{'r', 'g', 'b'}`` colour dict as a ``#rrggbb`` hex string.

    Each channel is zero-padded to two hex digits. Without padding, channels
    below 16 collapse to a single digit and produce malformed colours
    (e.g. ``{'r': 0, 'g': 128, 'b': 5}`` would become ``'#0805'``).

    Args:
        c: Mapping with integer channel values in 0-255 under 'r', 'g', 'b'.

    Returns:
        Lower-case hex colour string such as ``'#146614'``.
    """
    return f"#{c['r']:02x}{c['g']:02x}{c['b']:02x}"

def colour_block(rgb, index):
    """Build a DataFrame that repeats a single colour on every row of ``index``.

    Args:
        rgb: Mapping of channel name -> scalar (e.g. ``{'r': .., 'g': .., 'b': ..}``);
            each key becomes a column whose value is broadcast to all rows.
        index: Row labels for the resulting frame.

    Returns:
        DataFrame with one column per key of ``rgb`` and one row per label.
    """
    return pd.DataFrame(data=rgb, index=index)

def get_df_pixels(X_df):
    """Extract the 'r', 'g', 'b' columns of ``X_df`` as an integer NumPy array.

    Returns:
        Array of shape ``(len(X_df), 3)`` with columns in r, g, b order,
        cast to ``int`` (float values are truncated by the cast).
    """
    channel_columns = ['r', 'g', 'b']
    channel_values = X_df[channel_columns].to_numpy()
    return channel_values.astype(int)

def get_class_pixels(y_series):
    """Map each label in ``y_series`` to the RGB colour of its class.

    Label 'c1' -> TRUE_CLASS_COLOUR, 'c2' -> FALSE_CLASS_COLOUR, and entries
    flagged by ``is_reject_or_null_mask`` -> REJECT_COLOUR; anything else keeps
    BACKGROUND_COLOUR. Rules are applied in that order, so the reject colour
    overrides the class colours.

    Returns:
        Integer array of shape ``(len(y_series), 3)`` in r, g, b order.
    """
    idx = y_series.index
    pixel_df = colour_block(BACKGROUND_COLOUR, idx)
    colour_rules = [
        (y_series == 'c1', TRUE_CLASS_COLOUR),
        (y_series == 'c2', FALSE_CLASS_COLOUR),
        (is_reject_or_null_mask(y_series), REJECT_COLOUR),
    ]
    for condition, colour in colour_rules:
        pixel_df = pixel_df.mask(condition, colour_block(colour, idx))
    return get_df_pixels(pixel_df)

def get_correctness_pixels(y_series, preds_series):
    """Colour each prediction by whether it matches the true label.

    Matching predictions get CORRECT_COLOUR and mismatches get INCORRECT_COLOUR.
    Predictions flagged by ``is_reject_or_null_mask`` get REJECT_COLOUR; that
    rule is applied last, so it overrides the other two. The Series are compared
    element-wise with ``==`` / ``!=``, which pandas only allows for
    identically-labelled Series.

    Returns:
        Integer array of shape ``(len(y_series), 3)`` in r, g, b order.
    """
    idx = y_series.index
    pixel_df = colour_block(BACKGROUND_COLOUR, idx)
    colour_rules = [
        (y_series == preds_series, CORRECT_COLOUR),
        (y_series != preds_series, INCORRECT_COLOUR),
        (is_reject_or_null_mask(preds_series), REJECT_COLOUR),
    ]
    for condition, colour in colour_rules:
        pixel_df = pixel_df.mask(condition, colour_block(colour, idx))
    return get_df_pixels(pixel_df)
    
# Positions of the r, g, b features in the experiment's feature list, and the
# matching weights plus intercept taken from the base summary's model extras.
# NOTE(review): the [0] index presumably selects the single row of a binary
# classifier's coefficient/intercept arrays — confirm against the extras format.
rgb_indexes = list(map(exp_result['dataset_attributes']['feature_names'].index,
                       ['r', 'g', 'b']))
model_extras = get_experiment_base_summary(exp_result)['extras']
coefs = [model_extras['model_coefs'][0][feature_idx] for feature_idx in rgb_indexes]
intercept = model_extras['intercept'][0]

def get_pixel_index(pixels):
    """Return the permutation that sorts ``pixels`` by their ``pixel_ordering`` score.

    Args:
        pixels: 2-D array with one pixel per row.

    Returns:
        Integer index array; ``pixels[result]`` is in ascending score order.
    """
    scores = np.apply_along_axis(func1d=pixel_ordering, axis=1, arr=pixels)
    order = scores.argsort()
    return order

def plot_image(pixels, ax=None, size=600):
    """Render an ``(N, 3)`` array of RGB pixels as a near-square SVG image.

    Pixels are laid out row-major into an ``H x W`` grid with
    ``H = ceil(sqrt(N))`` and ``W = ceil(N / H)``. Any unused trailing cells
    are padded with white so that the grid is rectangular.

    Args:
        pixels: Integer array of shape ``(N, 3)`` with channel values in 0-255
            (cast to ``uint8`` for display).
        ax: Unused; kept so the existing signature is unchanged.
        size: Width and height of the rendered figure, in pixels.

    Raises:
        ValueError: If ``pixels`` is empty or is not shaped ``(N, 3)``.
    """
    img = _pixels_to_grid(pixels)
    fig = px.imshow(img.astype(np.uint8))
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=0, pad=0), width=size, height=size)
    fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
    render_svg_fig(fig)


def _pixels_to_grid(pixels):
    """Pad ``(N, 3)`` pixels with white and reshape them into an ``(H, W, 3)`` grid."""
    pixels = np.asarray(pixels)
    if pixels.ndim != 2 or pixels.shape[1] != 3:
        raise ValueError(f'expected a pixel array of shape (N, 3), got {pixels.shape}')
    n_pixels = pixels.shape[0]
    if n_pixels == 0:
        # Without this guard the width computation below divides by a zero height.
        raise ValueError('cannot plot an empty pixel array')
    height = int(np.ceil(np.sqrt(n_pixels)))
    width = int(np.ceil(n_pixels / height))
    # White filler for the cells left over after the last real pixel.
    filler = np.full((height * width - n_pixels, 3), 255)
    return np.concatenate([pixels, filler], axis=0).reshape(height, width, 3)

def pixel_ordering(rgb, weights=None, bias=None):
    """Score one pixel with the logistic-regression decision function ``w . rgb + b``.

    ``get_pixel_index`` only uses these scores for sorting. Adding (rather than
    subtracting) the intercept shifts every score by a constant, so pixel order
    is unaffected, but the returned value is now the actual logit.
    NOTE(review): assumes ``intercept`` follows the scikit-learn ``intercept_``
    convention, where it is added to the weighted sum — confirm.

    Args:
        rgb: Sequence of three channel values in r, g, b order.
        weights: Per-channel weights; defaults to the module-level ``coefs``.
        bias: Intercept term; defaults to the module-level ``intercept``.

    Returns:
        The scalar activation as a NumPy float.
    """
    if weights is None:
        weights = coefs
    if bias is None:
        bias = intercept
    return np.dot(np.asarray(rgb, dtype=float), np.asarray(weights, dtype=float)) + bias

In [None]:
# Which confidence-thresholding (CT) threshold and which null-labeling (NL)
# iteration to pull from the experiment summaries below.
confidence_threshold, nl_iteration = 0.1, 2

In [None]:
# Re-create the dataset split from the experiment's own config (same random_state
# and test_size), with preprocessing disabled so the test features keep their
# raw r, g, b values for plotting.
dataset_parts = prepare_dataset(
    CLASSIFIERS[exp_result['config']['classifier']],
    DATASETS[exp_result['config']['dataset']],
    random_state=exp_result['config']['random_state'],
    test_size=exp_result['config']['test_size'],
    apply_preprocessing=False,
)
test_X = dataset_parts['test_X']
test_y = dataset_parts['test_y']
all_summaries = get_all_summaries(exp_result)

In [None]:
# One shared ordering (by logistic-regression score of the raw colour) is
# reused for every image below, so the same position always shows the same pixel.
X_pixels = get_df_pixels(test_X)
pixel_index = get_pixel_index(X_pixels)
class_pixels = get_class_pixels(test_y)
print('Pixel colours (ordered by logistic regression activation)')
plot_image(X_pixels[pixel_index])

In [None]:
# Share of test rows labelled 'c1' (skin).
true_class_proportion = test_y.value_counts().loc['c1'] / len(test_y)
print(f'Class values ({true_class_proportion:.2%} skin)')
plot_image(class_pixels[pixel_index])

In [None]:
# Base classifier: the confidence-thresholding predictions at threshold '0.0'.
base_preds = pd.Series(
    data=all_summaries['confidence-thresholding']['0.0']['test_preds'],
    index=test_y.index,
)
base_pixels = get_class_pixels(base_preds)
base_reject_rate = is_reject_or_null_mask(base_preds).sum() / len(base_preds)
print(f'Base Classifications ({base_reject_rate:.2%} rejected)')
plot_image(base_pixels[pixel_index])

In [None]:
# Conditional error: share of covered (non-rejected) predictions that are wrong.
base_conditional_error = (
    ((base_preds != test_y) & ~is_reject_or_null_mask(base_preds)).sum()
    / (~is_reject_or_null_mask(base_preds)).sum()
)
print(f'Base Correctness ({base_conditional_error:.2%} conditional error)')
plot_image(get_correctness_pixels(test_y, base_preds)[pixel_index])

In [None]:
# Confidence-thresholding predictions at the configured threshold.
ct_preds = pd.Series(
    data=all_summaries['confidence-thresholding'][str(confidence_threshold)]['test_preds'],
    index=test_y.index,
)
ct_pixels = get_class_pixels(ct_preds)
ct_reject_rate = is_reject_or_null_mask(ct_preds).sum() / len(ct_preds)
print(f'CT Classifications ({ct_reject_rate:.2%} rejected)')
plot_image(ct_pixels[pixel_index])

In [None]:
# Conditional error: share of covered (non-rejected) predictions that are wrong.
ct_conditional_error = (
    ((ct_preds != test_y) & ~is_reject_or_null_mask(ct_preds)).sum()
    / (~is_reject_or_null_mask(ct_preds)).sum()
)
print(f'CT Correctness ({ct_conditional_error:.2%} conditional error)')
plot_image(get_correctness_pixels(test_y, ct_preds)[pixel_index])

In [None]:
# Null-labeling predictions from the configured iteration (threshold '0.0').
nl_preds = pd.Series(
    data=all_summaries[f'null-labeling-nlrm-1-nlrc-0-iteration-{nl_iteration}']['0.0']['test_preds'],
    index=test_y.index,
)
nl_pixels = get_class_pixels(nl_preds)
nl_reject_rate = is_reject_or_null_mask(nl_preds).sum() / len(nl_preds)
print(f'NL Classifications ({nl_reject_rate:.2%} rejected)')
plot_image(nl_pixels[pixel_index])

In [None]:
# Conditional error: share of covered (non-rejected) predictions that are wrong.
nl_conditional_error = (
    ((nl_preds != test_y) & ~is_reject_or_null_mask(nl_preds)).sum()
    / (~is_reject_or_null_mask(nl_preds)).sum()
)
print(f'NL Correctness ({nl_conditional_error:.2%} conditional error)')
plot_image(get_correctness_pixels(test_y, nl_preds)[pixel_index])

In [None]:
# Shared legend for all of the pixel images above.
display_svg(
    plot_legend_svg({
        label: colour_to_hex(colour)
        for label, colour in [
            ('Skin', TRUE_CLASS_COLOUR),
            ('Not skin', FALSE_CLASS_COLOUR),
            ('Rejected', REJECT_COLOUR),
            ('Correct', CORRECT_COLOUR),
            ('Incorrect', INCORRECT_COLOUR),
        ]
    }),
    raw=True,
)