In [127]:
# Source: Alexandru Tifrea and Fanny Yang, 2022.

# Python Notebook Commands
%reload_ext autoreload
%autoreload 2

from IPython.core.display import HTML
from IPython.display import display
display(HTML("<style>.container { width:100% !important; }</style>"))

from copy import deepcopy
import numpy as np
import time

import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio

import ipywidgets
from ipywidgets import interact

from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn import datasets

# General math and plotting modules.
import numpy as np

from sklearn.calibration import calibration_curve

from functools import partial, update_wrapper
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
# Module-level JAX PRNG key built from a fixed seed.
# NOTE(review): compare_losses builds its own local key, so this one does not appear
# to be read by any of the cells visible here -- confirm before removing.
key = random.PRNGKey(0)

# Change these values if the images don't fit for your screen.
# figure_width and figure_height size the two-panel figure produced by make_plots;
# the loss comparison figure reuses figure_height but sets its own width.
figure_width = 1600
figure_height = 600

# Utilities for plotting

In [80]:
# Plots a linear function determined by a slope and an intercept term passed as arguments.
def plot_decision_boundary(W, b, name="", color=None, showlegend=True):
  """Builds a line trace for the decision boundary W[0] * x1 + W[1] * x2 + b[0] = 0.

  Args:
    W: indexable holding the two weights; W[0] multiplies x1 and W[1] multiplies x2.
      W[1] must be non-zero, since the boundary is solved for x2.
    b: indexable whose first element is the intercept.
    name: legend label of the trace.
    color: line color; when None the line is drawn in cyan.
    showlegend: whether the trace gets its own legend entry.

  Returns:
    A go.Scatter line trace sampled at 100 x1 values in [-10, 10].
  """
  x1 = np.linspace(-10, 10, 100)
  x2 = (-b[0] - W[0] * x1) / W[1]
  # The line style is given exactly once. Previously a `line` dict was combined with
  # `line_width` / `line_color` keyword arguments, which hard-coded the color to cyan
  # and left the `color` argument without any effect.
  line_color = "cyan" if color is None else color
  return go.Scatter(x=x1,
                    y=x2,
                    line=dict(color=line_color, width=7),
                    mode="lines",
                    hoverinfo="skip",
                    name=name,
                    legendgroup="db",
                    showlegend=showlegend)

def dir_norm_split(W, b):
  """Splits the parameter vector (W[0], W[1], b) into a unit direction and its norm.

  Returns:
    A pair (direction, norm): direction is the stacked vector divided by its
    Euclidean norm, norm is that Euclidean norm.
  """
  stacked = np.array([float(W[0]), float(W[1]), b])
  length = np.linalg.norm(stacked)
  return stacked / length, length

In [88]:
def make_plots(X, y, model1, model2, title1, title2):
  """Shows the data and the decision boundary of two models in side-by-side panels.

  Args:
    X: 2-D array of points; column 0 is drawn on the horizontal axis and column 1
      on the vertical axis.
    y: per-point labels; points with y == 0 are drawn as "Class -" and points with
      y == 1 as "Class +".
    model1: indexable (W, b) plotted in the left panel.
    model2: indexable (W, b) plotted in the right panel.
    title1: title of the left panel.
    title2: title of the right panel.
  """
  lower, upper = -1.5, 1.5

  fig = make_subplots(rows=1, cols=2, column_widths=[0.5, 0.5],
                      horizontal_spacing=0.15,
                      subplot_titles=[title1, title2])

  # (label value, legend name, marker symbol, marker color) for each class.
  class_styles = (
    (0, "Class -", "circle", "blue"),
    (1, "Class +", "cross", "red"),
  )

  for col, model in enumerate((model1, model2), start=1):
    # Only the left panel contributes legend entries; the right panel reuses the
    # same legend groups so toggling an entry affects both panels.
    legend_kwargs = {} if col == 1 else {"showlegend": False}

    for label, class_name, symbol, marker_color in class_styles:
      points = X[y == label]
      fig.add_trace(go.Scatter(x=points[:, 0], y=points[:, 1], mode="markers",
                               marker_symbol=symbol, marker_size=10, marker_color=marker_color,
                               name=class_name,
                               legendgroup=class_name, **legend_kwargs), row=1, col=col)

    fig.add_trace(plot_decision_boundary(model[0], [model[1]], name="Decision boundary", **legend_kwargs),
                  row=1, col=col)

  # Both panels share the same fixed viewport and axis titles.
  axis_layout = {}
  for panel in (1, 2):
    axis_layout[f"xaxis{panel}"] = {"range": [lower, upper], "title": "x1"}
    axis_layout[f"yaxis{panel}"] = {"range": [lower, upper], "title": "x2"}

  fig.update_layout(width=figure_width, height=figure_height, **axis_layout)

  fig.show()

# Training model and loss

In [84]:
# Outputs predictions of an estimator.
@partial(jit, static_argnums=(3,))
def predict(W, b, X, loss_code):
  """Evaluates the linear model X @ W + b, squashed through a sigmoid for the log loss.

  Args:
    W: weights, multiplied from the right onto X.
    b: intercept added to every score.
    X: input points.
    loss_code: one of the LOSS_CODE values; it is a static jit argument, so the
      branch below is resolved at trace time.

  Returns:
    sigmoid(X @ W + b) for the log loss, the raw scores X @ W + b for the squared
    and average margin losses.

  Raises:
    RuntimeError: if loss_code is not a known loss code.
  """
  raw_score_codes = (LOSS_CODE["squared"], LOSS_CODE["avg_margin"])
  if loss_code == LOSS_CODE["log"]:
    link = sigmoid
  elif loss_code in raw_score_codes:
    link = None
  else:
    raise RuntimeError(f"Unknown loss {loss_code}!")

  scores = jnp.dot(X, W) + b
  return scores if link is None else link(scores)

# Training loss.
@partial(jit, static_argnums=(4,))
def loss(W, b, X, y, loss_code):
  """Mean training loss of the linear model (W, b) on the data (X, y).

  Args:
    W: weights of the linear model.
    b: intercept of the linear model.
    X: input points.
    y: labels; in {0, 1} for the log loss and in {-1, 1} for the squared and
      average margin losses.
    loss_code: one of the LOSS_CODE values (static jit argument).

  Returns:
    A scalar: the negative mean log-probability of the labels (log loss), the mean
    squared error (squared loss), or the negative mean of y * prediction
    (average margin loss).

  Raises:
    RuntimeError: if loss_code is not a known loss code.
  """
  preds = predict(W, b, X, loss_code).squeeze()
  if loss_code == LOSS_CODE["log"]:
    label_probs = preds * y + (1 - preds) * (1 - y)
    # In floating point the sigmoid can evaluate to exactly 0 or 1, which would make
    # the log -inf and its gradient NaN. Flooring at the smallest positive normal
    # number keeps the loss finite and leaves every non-degenerate value untouched.
    prob_floor = jnp.finfo(label_probs.dtype).tiny
    return -jnp.mean(jnp.log(jnp.maximum(label_probs, prob_floor)))
  elif loss_code == LOSS_CODE["squared"]:
    # For squared loss we need the true labels to be in {-1, 1}.
    squared_errors = jnp.power(y - preds, 2)
    return jnp.mean(squared_errors)
  elif loss_code == LOSS_CODE["avg_margin"]:
    # For avg margin loss we need the true labels to be in {-1, 1}.
    margins = preds * y
    return -jnp.mean(margins)
  else:
    raise RuntimeError(f"Unknown loss code {loss_code}!")

def sigmoid(z):
  """Logistic sigmoid of z, evaluated through the identity 0.5 * (tanh(z / 2) + 1)."""
  half_z = z / 2
  return 0.5 * (jnp.tanh(half_z) + 1)

@partial(jit, static_argnums=(4, 5))
def train_step(W, b, X, y, loss_code, eta):
  """Performs one gradient descent step on the training loss.

  Args:
    W: current weights.
    b: current intercept.
    X: input points.
    y: labels, in the encoding expected by the selected loss.
    loss_code: one of the LOSS_CODE values (static jit argument).
    eta: step size (static jit argument).

  Returns:
    A tuple (new_W, new_b, new_loss, grads): the updated parameters, the loss
    evaluated at the updated parameters, and the gradients with respect to
    (W, b) taken at the parameters passed in.
  """
  loss_grads = grad(loss, argnums=[0, 1])(W, b, X, y, loss_code)
  grad_W, grad_b = loss_grads
  updated_W = W - eta * grad_W
  updated_b = b - eta * grad_b
  updated_loss = loss(updated_W, updated_b, X, y, loss_code)
  return updated_W, updated_b, updated_loss, loss_grads
      

# Integer code of every supported loss. Integers (rather than the loss names) are what
# predict, loss and train_step receive as their static jit argument.
LOSS_CODE = {
    loss_name: code
    for code, loss_name in enumerate(("log", "squared", "avg_margin"))
}

def train(W, b, X, y, loss_str, n_iter, eta):
  """Runs full-batch gradient descent on the selected loss.

  Args:
    W: initial weights.
    b: initial intercept.
    X: input points.
    y: labels in {0, 1}; they are re-encoded to {-1, 1} for every loss except "log".
    loss_str: name of the loss, one of the keys of LOSS_CODE.
    n_iter: number of gradient steps.
    eta: constant step size.

  Returns:
    The pair (W, b) after n_iter gradient steps.

  Raises:
    RuntimeError: if loss_str is not a key of LOSS_CODE.
    AssertionError: if y does not contain exactly the two expected label values.
  """
  # Reject unsupported names up front. This covers names such as "poly_loss" that
  # have no entry in LOSS_CODE and previously surfaced as a KeyError.
  if loss_str not in LOSS_CODE:
    raise RuntimeError(f"Unknown loss {loss_str}!")

  # Get the labels to be in {0, 1} or {-1, 1} as appropriate for each loss function.
  labels = list(np.unique(y))
  if loss_str == "log":
    assert labels == [0, 1], labels
  else:
    y = (y - 0.5) * 2
    labels = list(np.unique(y))
    assert labels == [-1, 1], labels

  loss_code = LOSS_CODE[loss_str]
  for _step in range(n_iter):
    W, b, _, _ = train_step(W, b, X, y, loss_code, eta)
  return W, b

In [128]:
def generate_mixture_data(n_samples=200, ratio_neg=0.1, noise=0.2, mu_norm=1.):
  """Samples a two-component Gaussian mixture with one component per class.

  Args:
    n_samples: total number of points.
    ratio_neg: fraction of the points that belong to the negative class (label 0).
    noise: standard deviation of both components.
    mu_norm: distance of each component mean from the origin; the negative class
      is centered at (0, -mu_norm) and the positive class at (0, mu_norm).

  Returns:
    The pair (X, y) returned by make_blobs, with y == 0 for the negative class and
    y == 1 for the positive class.
  """
  # Round instead of truncating: products such as 90 * 0.7 evaluate to
  # 62.99999999999999 and would otherwise lose a sample.
  n_neg = int(round(n_samples * ratio_neg))
  n_pos = n_samples - n_neg
  # make_blobs labels the points with the index of their cluster, so the negative
  # class has to come first to receive label 0, which the rest of the notebook
  # treats as "Class -" (and maps to -1 during training).
  return datasets.make_blobs(n_samples=[n_neg, n_pos],
                             centers=np.array([[0, -1], [0, 1]]) * mu_norm,
                             random_state=3,
                             cluster_std=[noise, noise])


def generate_discrimintative_data(n_samples=200):
  """Samples 2-D Gaussian points labeled by a random linear classifier through the origin.

  Args:
    n_samples: number of points to draw.

  Returns:
    A pair (X, y): X has shape (n_samples, 2) with entries drawn from a standard
    normal scaled by 0.5; y is a boolean array that is True where X @ w_star > 0
    for a freshly drawn standard normal direction w_star.
  """
  points = np.random.randn(n_samples, 2) * 0.5
  w_star = np.random.randn(2)
  scores = points @ w_star
  return points, scores > 0


def compare_losses(data_model, n_samples, ratio_neg=None, noise=None, mu_norm=None):
  """Trains a linear classifier with the log loss and with the average margin loss and plots both.

  Args:
    data_model: "gmm" for the Gaussian mixture data or "discriminative" for data
      labeled by a random linear classifier.
    n_samples: number of points to generate.
    ratio_neg: fraction of negative samples; required for the "gmm" data model.
    noise: component standard deviation; required for the "gmm" data model.
    mu_norm: distance of the component means from the origin; required for the
      "gmm" data model.

  Raises:
    RuntimeError: if data_model is not one of the two supported names.
    AssertionError: if a "gmm" parameter is missing.
  """
  if data_model == "gmm":
    assert ratio_neg is not None
    assert noise is not None
    assert mu_norm is not None
    X, y = generate_mixture_data(n_samples, ratio_neg, noise, mu_norm)
  elif data_model == "discriminative":
    X, y = generate_discrimintative_data(n_samples)
  else:
    raise RuntimeError(f"Unknown data model {data_model}")

  # Seeded from the wall clock, so every call starts from a different initialization.
  init_key = random.PRNGKey(int(time.time()))
  # A JAX key must not be consumed twice: split it so that the weights and the
  # intercept are drawn from independent streams.
  key_W, key_b = random.split(init_key)
  W_init = random.normal(key_W, (2, 1)) * 1e-2
  b_init = random.normal(key_b, ()) * 1e-2

  # Both models start from the same initialization and use the same optimization budget.
  W1, b1 = train(W_init, b_init, X, y, loss_str="log", n_iter=10000, eta=1e-1)
  W2, b2 = train(W_init, b_init, X, y, loss_str="avg_margin", n_iter=10000, eta=1e-1)

  make_plots(X, y, (W1, b1), (W2, b2), title1="Logistic loss (max min margin)", title2="Max average margin")

# Compare different losses
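The next cell plots several classification losses for a positive label as a function of the decision value f(x), similar to the scikit-learn example at https://scikit-learn.org/stable/auto_examples/linear_model/plot_sgd_loss_functions.html.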

In [126]:
# Decision values on [-3, 3]. The grid includes points at -1e-5 and 1e-5 so that the
# jump of the 0-1 loss at 0 is drawn as a (near) vertical step.
z = np.concatenate((np.linspace(-3, -1e-5, 50), np.linspace(-1e-5, 1e-5, 2), np.linspace(1e-5, 3, 50)))

# Legend name -> loss values on z, in the order in which the curves are drawn.
loss_curves = {
  "logistic loss": np.log2(1 + np.exp(-z)),
  "exponential loss": np.exp(-z),
  "squared loss": (1-z)**2,
  "hinge": np.maximum(0., 1-z),
  "avg margin": -z,
  "0-1 loss": np.where(z < 0, 1, 0),
}

fig = go.Figure()
for curve_name, curve_values in loss_curves.items():
  fig.add_trace(go.Scatter(
    x=z,
    y=curve_values,
    mode="lines",
    name=curve_name))

# Kept as the last expression of the cell so that the notebook renders the figure.
fig.update_layout(
  height=figure_height,
  width=800,
  yaxis_title="$L(y=1, f(x))$",
  xaxis_title="Decision function f(x)",
  yaxis_range=[-1, 5]
)

# Logistic loss vs average margin for GMM data model

In [90]:
%%time

# Interactive comparison on the Gaussian mixture data: every slider change re-generates
# the data, re-trains both models and redraws the two panels. continuous_update=False
# defers that work until the slider is released.
_ = interact(
    compare_losses,
    data_model=ipywidgets.fixed("gmm"),
    n_samples=ipywidgets.IntSlider(value=100,
                                   min=20,
                                   max=200,
                                   step=10,
                                   readout_format='d',
                                   description='Number of samples:',
                                   style={'description_width': 'initial'},
                                   continuous_update=False),
    ratio_neg=ipywidgets.FloatSlider(value=0.5,
                                     min=0.1,
                                     max=0.9,
                                     step=0.1,
                                     readout_format='.2f',
                                     description='Ratio of negative samples:',
                                     style={'description_width': 'initial'},
                                     continuous_update=False),
    noise=ipywidgets.FloatSlider(value=0.1,
                                 min=0.1,
                                 max=0.9,
                                 step=0.1,
                                 readout_format='.2f',
                                 description='Noise:',
                                 style={'description_width': 'initial'},
                                 continuous_update=False),
    mu_norm=ipywidgets.FloatSlider(value=1.,
                                   min=0.5,
                                   max=1.5,
                                   step=0.1,
                                   readout_format='.2f',
                                   # Raw string: "\|" and "\m" are not valid Python escape
                                   # sequences, and the LaTeX needs the backslashes verbatim.
                                   description=r'$\|\mu\|:$',
                                   style={'description_width': 'initial'},
                                   continuous_update=False),
)

interactive(children=(IntSlider(value=100, continuous_update=False, description='Number of samples:', max=200,…

CPU times: user 858 ms, sys: 11.3 ms, total: 870 ms
Wall time: 873 ms


# Logistic loss vs average margin for discriminative data model

In [91]:
%%time

# Interactive comparison on the discriminative data model. Only the number of samples
# is adjustable; the mixture-specific arguments are pinned to None.
sample_count_slider = ipywidgets.IntSlider(value=100,
                                           min=20,
                                           max=200,
                                           step=10,
                                           readout_format='d',
                                           description='Number of samples:',
                                           style={'description_width': 'initial'},
                                           continuous_update=False)

_ = interact(
    compare_losses,
    data_model=ipywidgets.fixed("discriminative"),
    n_samples=sample_count_slider,
    ratio_neg=ipywidgets.fixed(None),
    noise=ipywidgets.fixed(None),
    mu_norm=ipywidgets.fixed(None),
)

interactive(children=(IntSlider(value=100, continuous_update=False, description='Number of samples:', max=200,…

CPU times: user 985 ms, sys: 14.1 ms, total: 1e+03 ms
Wall time: 1.01 s
