In [1]:
# Source: Alexandru Tifrea and Fanny Yang, 2021.

# Python Notebook Commands
%reload_ext autoreload
%load_ext autoreload
%autoreload 2

# `display` and `HTML` are part of the public `IPython.display` module;
# importing them from `IPython.core.display` is deprecated.
from IPython.display import display, HTML

# Stretch the notebook container to the full browser width so wide figures fit.
display(HTML("<style>.container { width:100% !important; }</style>"))

from copy import deepcopy
import numpy as np
import time

import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio

import ipywidgets
from ipywidgets import interact

from sklearn.linear_model import LogisticRegression
from sklearn import datasets

# General math and plotting modules.
import numpy as np

from sklearn.calibration import calibration_curve

# Width and height handed to the figure layout further down; adjust them if
# the plots do not fit on your screen.
figure_width, figure_height = 1600, 600

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Temperature scaling and calibration

This notebook shows the effect of temperature scaling on the probability outputs of a logistic regression model and its calibration.

A logistic regression outputs the estimated probability of an input being labeled as 1. This probability can be tuned using a scalar parameter called temperature, as follows:

$$p = \frac{1}{1 + \exp(-zT)}, \text{ with } z = Wx + b$$

The value $T=1$ corresponds to the unaltered output of the logistic regression model. As we show below, the original output of logistic regression is sometimes not well calibrated. That can be alleviated by tuning the temperature on a holdout calibration set.

In [2]:
def get_platt_scaler(logits, T):
  """Return the temperature-scaled sigmoid of `logits`.

  Evaluates p = 1 / (1 + exp(-logits * T)) element-wise.

  Args:
    logits: Scalar or array-like of raw scores.
    T: Temperature that multiplies the logits before the sigmoid is applied.

  Returns:
    The sigmoid values, with the same shape as `logits`.
  """
  scaled = np.asarray(logits, dtype=float) * T
  # sigmoid(x) == exp(-log(1 + exp(-x))). `logaddexp(0, -x)` evaluates the log
  # term without overflowing, whereas a direct exp(-x) overflows for very
  # negative x and emits a RuntimeWarning.
  return np.exp(-np.logaddexp(0.0, -scaled))

  
class PlattCalibrator:
  """Temperature-scales the logits of a fitted linear classifier.

  The logits are computed from the first row of the model's `coef_` and its
  `intercept_`, then passed through `get_platt_scaler` with temperature `T`.
  `T` is read on every call, so it may be changed after construction.
  """

  def __init__(self, log_reg_model, T=None):
    """Store the model and the temperature.

    Args:
      log_reg_model: Fitted model exposing `coef_` and `intercept_`.
      T: Temperature multiplying the logits. It may be left as None at
        construction time but must be set to a number before `calibrate`
        is called.
    """
    self.log_reg_model = log_reg_model
    self.T = T

  def _platt(self, logits):
    """Apply the temperature-scaled sigmoid to `logits` using the current `T`."""
    # This is a regular method instead of a lambda stored on the instance: a
    # stored lambda closes over the instance that created it, so copies of the
    # calibrator would keep reading that instance's `T`, and the object could
    # not be pickled.
    if self.T is None:
      raise TypeError(
          "PlattCalibrator temperature T is None; pass T=<number> to the "
          "constructor or set `.T` before calling calibrate().")
    return get_platt_scaler(logits=logits, T=self.T)

  def calibrate(self, xs):
    """Return temperature-scaled probabilities for the rows of `xs`.

    Args:
      xs: Input features; multiplied with the first row of `coef_`.

    Returns:
      The output of `get_platt_scaler` for the computed logits.

    Raises:
      TypeError: If the temperature `T` is still None.
    """
    zs = xs @ self.log_reg_model.coef_[0] + self.log_reg_model.intercept_
    return self._platt(zs)
      
# Builds the line trace of the decision boundary W[0]*x1 + W[1]*x2 + b = 0.
def plot_decision_boundary(W, b, name="", color=None):
  """Return a `go.Scatter` line trace of a 2-D linear decision boundary.

  Args:
    W: Weight vector; its first two entries are used.
    b: Intercept, either a scalar or an array whose first entry is used.
    name: Legend name of the trace.
    color: Line colour. Defaults to "cyan" when None.

  Returns:
    A `go.Scatter` trace over x1 in [-10, 10]. For a vertical boundary
    (W[1] == 0) the trace spans x2 in [-10, 10] instead. When both weights
    are zero there is no boundary and the trace carries no points.
  """
  w1, w2 = float(W[0]), float(W[1])
  bias = float(np.ravel(b)[0])
  span = np.linspace(-10, 10, 100)

  if w2 != 0:
    x1, x2 = span, (-bias - w1 * span) / w2
  elif w1 != 0:
    # Vertical boundary: x1 is constant, so sweep x2 instead of dividing by 0.
    x1, x2 = np.full_like(span, -bias / w1), span
  else:
    x1 = x2 = np.array([])

  # Style the line in a single place. Passing `line=dict(...)` together with
  # `line_color`/`line_width` lets the latter silently win, which made the
  # `color` argument ineffective; cyan and width 7 are the values that were
  # actually rendered.
  return go.Scatter(x=x1,
                    y=x2,
                    mode="lines",
                    line=dict(color="cyan" if color is None else color, width=7),
                    hoverinfo="skip",
                    name=name)

In [4]:
# Lower and upper limits of the two plot axes; the probability grid is
# evaluated over the same rectangle.
x_min, y_min = -5, -5
x_max, y_max = 5, 5
# Settings of the toy data: `noise` scales the cluster standard deviations and
# `n_samples` is the number of points requested per dataset.
noise, n_samples = 0.5, 30


def calibration_plots(T=1, n_bins=3):
  """Fit a logistic regression on toy data and plot the effect of temperature T.

  Left panel: the training samples, the decision boundary and the
  temperature-scaled probability field over [x_min, x_max] x [y_min, y_max].
  Right panel: calibration curves, computed on a separately drawn holdout
  set, for the raw model outputs and for the temperature-scaled outputs.

  Args:
    T: Temperature handed to `PlattCalibrator`.
    n_bins: Number of bins passed to `calibration_curve`.
  """

  def gen_data(random_state):
    """Draw the two-blob dataset using the given seed."""
    return datasets.make_blobs(n_samples=n_samples,
                               centers=np.array([[0, 1], [0, -1]]) * 3,
                               random_state=random_state,
                               cluster_std=[2 * noise, 4 * noise])

  X, Y = gen_data(random_state=3)
  # The holdout set must use a different seed: drawing it with the training
  # seed reproduces the training samples exactly, so the calibration curves
  # would be measured on the data the model was fitted on.
  X_test, Y_test = gen_data(random_state=4)

  clf = LogisticRegression(max_iter=10000, tol=1e-10, solver="liblinear").fit(X, Y)
  calibrator = PlattCalibrator(clf, T=T)

  fig = make_subplots(rows=1, cols=2, column_widths=[0.3, 0.5],
                      horizontal_spacing=0.15,
                      subplot_titles=["Output probability field", "Calibration curves"])
  fig.add_trace(go.Scatter(x=X[Y == 0][:, 0], y=X[Y == 0][:, 1], mode="markers",
                           marker_symbol="cross", marker_size=10, marker_color="blue", name="Training samples - Class 0"), row=1, col=1)
  fig.add_trace(go.Scatter(x=X[Y == 1][:, 0], y=X[Y == 1][:, 1], mode="markers",
                           marker_symbol="circle", marker_size=10, marker_color="red", name="Training samples - Class 1"), row=1, col=1)
  fig.add_trace(plot_decision_boundary(clf.coef_[0], clf.intercept_, name="Decision boundary"), row=1, col=1)

  # Compute recalibrated probabilities for the grid.
  grid_size = 100
  xs = np.linspace(x_min, x_max, grid_size)
  ys = np.linspace(y_min, y_max, grid_size)
  xx, yy = np.meshgrid(xs, ys)
  # One (x1, x2) row per grid node, in the same order as the flattened meshgrid.
  xy_coords = np.column_stack((xx.ravel(), yy.ravel()))
  grid_probs = calibrator.calibrate(xy_coords).reshape(grid_size, grid_size)

  fig.add_trace(
    go.Contour(
      x=xs, y=ys,
      z=grid_probs,
      colorscale='sunset',
      contours_coloring='heatmap',
      colorbar_x=0.33,
      zmin=0,
      zmax=1,
    ),
    row=1, col=1
  )

  # Plot calibration curves on the holdout set: raw outputs first, then the
  # temperature-scaled ones.
  test_probs = clf.predict_proba(X_test)[:, 1]
  test_calibrated_probs = calibrator.calibrate(X_test)

  curves = [
    (test_probs,
     dict(line_dash="dash", marker_symbol="x", line_color="forestgreen", name="Original outputs (T=1)")),
    (test_calibrated_probs,
     dict(line_color="magenta", name=f"Temperature scaling T={T}")),
  ]
  for probs, trace_style in curves:
    fraction_of_positives, mean_predicted_value = calibration_curve(Y_test, probs, n_bins=n_bins)
    fig.add_trace(go.Scatter(x=mean_predicted_value, y=fraction_of_positives, **trace_style),
                  row=1, col=2)

  fig.update_layout(
    width=figure_width,
    height=figure_height,
    xaxis1={
      "range": (x_min, x_max),
      "title": "x1",
    },
    yaxis1={
      "range": (y_min, y_max),
      "title": "x2",
    },
    xaxis2_title="Confidence",
    yaxis2_title="Accuracy",
  )

  fig.show()
  
  
# Interactive controls: one slider for the temperature and one for the number
# of calibration bins. `continuous_update=False` is set on both sliders.
_ = interact(
    calibration_plots,
    T=ipywidgets.FloatSlider(
        description='Temperature:',
        value=3.0,
        min=0.2,
        max=10,
        step=0.2,
        readout_format='.1f',
        continuous_update=False,
        style=dict(description_width='initial'),
    ),
    n_bins=ipywidgets.IntSlider(
        description='Num. bins:',
        value=3,
        min=2,
        max=10,
        step=1,
        readout_format='d',
        continuous_update=False,
        style=dict(description_width='initial'),
    ),
)

interactive(children=(FloatSlider(value=3.0, continuous_update=False, description='Temperature:', max=10.0, mi…