In [None]:
%matplotlib inline

import warnings
# Suppress numpy's "dtype size changed" / "ufunc size changed" warnings, matched
# by message text. NOTE(review): these usually stem from compiled extensions built
# against a different numpy version — confirm the environment before relying on this.
warnings.filterwarnings("ignore", message="numpy.dtype size changed")
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")

# Make IPython display the value of every bare expression in a cell, not just
# the last one. Side effect: calls that return objects (e.g. plotting helpers
# returning Axes) will echo their repr unless the line ends with ';'.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [None]:
from itertools import product

class ConfusionMatrixDisplay:
    """Confusion matrix visualization.

    Adapted from scikit-learn's ``ConfusionMatrixDisplay``, with extra
    font-size controls in :meth:`plot`. Constructor parameters are stored
    as attributes; call :meth:`plot` to draw.

    Parameters
    ----------
    confusion_matrix : ndarray of shape (n_classes, n_classes)
        Confusion matrix (raw counts or normalized values).
    display_labels : array-like of shape (n_classes,)
        Tick labels used for both axes.

    Attributes
    ----------
    Set by :meth:`plot`:

    im_ : matplotlib AxesImage
        Image representing the confusion matrix.
    text_ : ndarray of shape (n_classes, n_classes), dtype=object, or None
        Matplotlib ``Text`` artists, one per cell. ``None`` if
        ``include_values`` is false.
    ax_ : matplotlib Axes
        Axes with confusion matrix.
    figure_ : matplotlib Figure
        Figure containing the confusion matrix.
    """

    def __init__(self, confusion_matrix, display_labels):
        self.confusion_matrix = confusion_matrix
        self.display_labels = display_labels

    def plot(self, *, include_values=True, cmap='viridis',
             xticks_rotation='horizontal', values_format=None, ax=None,
             anno_fontsize=14, tick_fontsize=14, label_fontsize=20):
        """Plot visualization.

        Parameters
        ----------
        include_values : bool, default=True
            Write each cell's value on the image.
        cmap : str or matplotlib Colormap, default='viridis'
            Colormap recognized by matplotlib.
        xticks_rotation : {'vertical', 'horizontal'} or float, default='horizontal'
            Rotation of xtick labels.
        values_format : str, default=None
            Format specification for cell values. If `None`, '.2g' is used;
            for integer matrices plain 'd' is used instead whenever it is not
            longer, so large counts are not shown in scientific notation.
        ax : matplotlib axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.
        anno_fontsize : int, default=14
            Font size of the in-cell value annotations and of the colorbar
            tick labels.
        tick_fontsize : int, default=14
            Font size of the axis tick labels.
        label_fontsize : int, default=20
            Font size of the x and y axis labels.

        Returns
        -------
        display : ConfusionMatrixDisplay
            ``self``, with ``im_``, ``text_``, ``ax_`` and ``figure_`` set.
        """
        # Local imports keep the class usable even if this cell is executed
        # before the cell that imports numpy at module level.
        import numpy as np
        import matplotlib.pyplot as plt

        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = ax.figure

        cm = self.confusion_matrix
        n_classes = cm.shape[0]
        self.im_ = ax.imshow(cm, interpolation='nearest', cmap=cmap)
        self.text_ = None

        # Colormap endpoints. Use the float 1.0 for the top: an int like 256 is
        # out of range for a 256-entry colormap and returns the "over" color.
        cmap_min, cmap_max = self.im_.cmap(0), self.im_.cmap(1.0)

        if include_values:
            self.text_ = np.empty_like(cm, dtype=object)
            is_integer = cm.dtype.kind in 'iu'

            # Flip text color at the midpoint of the data range (imshow maps
            # [min, max] onto the colormap) so values stay legible.
            thresh = (cm.max() + cm.min()) / 2.
            for i, j in product(range(n_classes), range(n_classes)):
                value = cm[i, j]
                if values_format is not None:
                    text = format(value, values_format)
                else:
                    text = format(value, '.2g')
                    if is_integer:
                        # '.2g' turns e.g. 150 into '1.5e+02'; prefer exact
                        # integer formatting when it is no longer.
                        text_d = format(value, 'd')
                        if len(text_d) <= len(text):
                            text = text_d
                color = cmap_max if value < thresh else cmap_min
                self.text_[i, j] = ax.text(j, i, text,
                                           ha="center", va="center",
                                           color=color, fontsize=anno_fontsize)

        cbar = fig.colorbar(self.im_, ax=ax)
        cbar.ax.tick_params(labelsize=anno_fontsize)
        ax.set(xticks=np.arange(n_classes),
               yticks=np.arange(n_classes),
               xticklabels=self.display_labels,
               yticklabels=self.display_labels,
               ylabel="True label",
               xlabel="Predicted label")

        # Row 0 at the top, with half-cell margins so the first and last rows
        # are fully visible.
        ax.set_ylim((n_classes - 0.5, -0.5))
        plt.setp(ax.get_xticklabels(), rotation=xticks_rotation)

        ax.xaxis.label.set_size(label_fontsize)
        ax.yaxis.label.set_size(label_fontsize)
        ax.tick_params(labelsize=tick_fontsize)

        self.figure_ = fig
        self.ax_ = ax
        return self


In [196]:
import numpy as np

import matplotlib.ticker as plticker

from matplotlib import pyplot as plt
from xenonpy.model.utils import regression_metrics

def cv_plot(pred, true, pred_fit=None, true_fit=None, *, unit='', lim=None, n_ticks=5, title='', ax=None, color_style=True, show_label=True, show_legend=True, **kwargs):
    """Parity plot (prediction vs. observation) for a regression model.

    Test points are drawn as circles and, when both ``pred_fit`` and
    ``true_fit`` are given, training points as diamonds. A dashed y = x
    diagonal, a grid, and a text box with R2 / MAE / Pearson correlation of
    the *test* data (computed by ``regression_metrics(true, pred)``) are added.

    Parameters
    ----------
    pred, true : array-like
        Predicted and observed values of the test set; flattened to 1-D.
    pred_fit, true_fit : array-like, optional
        Predicted and observed values of the training set; flattened to 1-D.
        Plotted, and included in the automatic limits, only if both are given.
    unit : str, default=''
        Unit appended to the axis labels as ' (unit)'.
    lim : tuple of (float, float), optional
        Shared x/y limits. Default: the data range padded by 5% on each side.
    n_ticks : int, default=5
        Approximate number of major ticks. The spacing is rounded to an
        integer; if it rounds to 0 the default matplotlib locator is kept.
    title : str, default=''
        Axes title.
    ax : matplotlib Axes, optional
        Axes to draw on. If None, a new 8x8 figure at 150 dpi is created.
    color_style : bool, default=True
        If False, use a monochrome style (hollow grey train markers, black
        test markers).
    show_label, show_legend : bool, default=True
        Whether to draw the axis labels / the legend.
    **kwargs
        Forwarded to every ``ax.scatter`` call.

    Returns
    -------
    ax : matplotlib Axes
    """
    # asarray lets lists / Series through; flatten() always returns a 1-D copy.
    pred, true = np.asarray(pred).flatten(), np.asarray(true).flatten()
    has_fit = pred_fit is not None and true_fit is not None
    if has_fit:
        # Flatten the training data too: (n, 1) column vectors would otherwise
        # make np.concatenate below fail with a dimension mismatch.
        pred_fit, true_fit = np.asarray(pred_fit).flatten(), np.asarray(true_fit).flatten()
    scores = regression_metrics(true, pred)

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8), dpi=150)

    if color_style:
        if has_fit:
            ax.scatter(pred_fit, true_fit, alpha=0.4, s=13, marker='D', label='Train', **kwargs)
        ax.scatter(pred, true, alpha=1, s=35, ec='w', marker='o', label='Test', **kwargs)
    else:
        if has_fit:
            ax.scatter(pred_fit, true_fit, alpha=0.6, s=15, ec='grey', c='none', marker='D', label='Train', **kwargs)
        ax.scatter(pred, true, alpha=1, s=35, ec='w', c='k', marker='o', label='Test', **kwargs)

    # adjust lims: shared square range, padded by 5% of the span on each side
    if lim is None:
        temp_data = np.concatenate([pred, true, pred_fit, true_fit]) if has_fit else np.concatenate([pred, true])
        lim = (temp_data.min(), temp_data.max())
        shift = (lim[1] - lim[0]) * 0.05
        lim = (lim[0] - shift, lim[1] + shift)
    ax.set_xlim(*lim)
    ax.set_ylim(*lim)

    # plot diagonal (perfect prediction)
    ax.plot(lim, lim, ls="--", c="0.3", alpha=0.7, lw=1.5)

    # align ticks: same integer spacing on both axes; skipped when the span is
    # too small for a spacing of at least 1
    base = round((lim[1] - lim[0]) / n_ticks)
    if base > 0:
        loc = plticker.MultipleLocator(base=base)  # ticks at regular intervals
        ax.xaxis.set_major_locator(loc)
        ax.yaxis.set_major_locator(loc)

    if unit != '':
        unit = f' ({unit})'
    if show_label:
        ax.set_xlabel(f'Prediction{unit}', fontsize='x-large')
        ax.set_ylabel(f'Observation{unit}', fontsize='x-large')
    if show_legend:
        legend = ax.legend(markerscale=2.5, fontsize='larger', loc=0)
        # `legendHandles` was renamed `legend_handles` in matplotlib 3.7 and the
        # old name was later removed; support both spellings.
        handles = getattr(legend, 'legend_handles', None)
        if handles is None:
            handles = legend.legendHandles
        for handle in handles:
            handle.set_alpha(1.0)  # legend markers opaque even if points are not
    ax.grid(color='gray', linestyle='--', linewidth=0.5, alpha=0.5)
    ax.set_title(title, fontsize='xx-large')

    # metrics box anchored in the lower-right corner of the plot area
    shift = (lim[1] - lim[0]) / 30
    ax.text(lim[1] - shift, lim[0] + shift,
            'R2: %.3f\nMAE: %.3f\nCorrelation: %.3f' % (scores['r2'], scores['mae'], scores['pearsonr']),
            horizontalalignment='right', verticalalignment='bottom', fontsize='x-large',
            bbox=dict(boxstyle='square', facecolor='grey', alpha=0.2, ec='black'),
            )
    ax.tick_params(axis='both', which='major', labelsize='larger')
    return ax