# Online Multiclass Logistic Regression using CMGF

## 0. Imports

In [None]:
# Silence WARNING:root:The use of `check_types` is deprecated and does not have any effect.
# https://github.com/tensorflow/probability/issues/1523
import logging

class CheckTypesFilter(logging.Filter):
    """Logging filter that drops any record whose message mentions ``check_types``."""

    def filter(self, record):
        """Return False (suppress) when the formatted message contains ``check_types``."""
        message = record.getMessage()
        return "check_types" not in message


# Attach the filter to the root logger, which is where the warning is emitted
# (it shows up as ``WARNING:root:...``).
logger = logging.getLogger()
logger.addFilter(CheckTypesFilter())

In [None]:
try:
    from ssm_jax.cond_moments_gaussian_filter.inference import *
    from ssm_jax.cond_moments_gaussian_filter.containers import *
except ModuleNotFoundError:
    print('installing ssm_jax')
    %pip install -qq git+https://github.com/probml/ssm-jax.git
    from ssm_jax.cond_moments_gaussian_filter.inference import *
    from ssm_jax.cond_moments_gaussian_filter.containers import *

In [None]:
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import jax.random as jr
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn import preprocessing
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.preprocessing import OneHotEncoder

## 1. CMGF Online Multiclass Logistic Regression

First, we generate and standardize a random dataset with 5 features and 4 classes.

In [None]:
# Problem size: number of samples, feature dimension, and number of classes.
num_points = 10000
num_features = 5
num_classes = 4

# Draw a synthetic classification problem in which every feature is informative.
# NOTE: the name `input` shadows the Python builtin; it is kept because later
# cells refer to it.
input, output = make_classification(
    n_samples=num_points,
    n_features=num_features,
    n_informative=num_features,
    n_redundant=0,
    n_classes=num_classes,
    random_state=2,
)

# Standardize the features with a scaler fitted on the full dataset, then move
# both features and labels into JAX arrays.
scaler = preprocessing.StandardScaler().fit(input)
input = jnp.array(scaler.transform(input))
output = jnp.array(output)

# Prepend a column of ones so that a bias term can be learned with the weights.
input_with_bias = jnp.concatenate([jnp.ones((num_points, 1)), input], axis=1)

Note that the moments of a (one-hot-encoded) categorical distribution with $K$ possible classes are as follows:

$$\mathbb{E}[\vec{y}|\vec{x}, \textbf{W}] =  \begin{pmatrix} \sigma_2(\textbf{W}^T\vec{x}) \\ \sigma_3(\textbf{W}^T\vec{x}) \\ \vdots \\ \sigma_K(\textbf{W}^T\vec{x}) \end{pmatrix}
$$
$$Cov[\vec{y}|\vec{x}, \textbf{W}] = \begin{pmatrix} p_2 (1 - p_2) & -p_2 p_3 & \dots & -p_2 p_K \\
-p_2 p_3 & p_3 (1 - p_3) & \dots  & -p_3 p_K \\
\vdots & \vdots & \ddots & \vdots \\
-p_2 p_K & -p_3 p_K & \dots & p_K (1 - p_K)
 \end{pmatrix}$$
where $\vec{\sigma}(\cdot)$ is the softmax function and $p_k = \sigma_k(\textbf{W}^T\vec{x})$ is the predicted probability of class $k$.

Note that in order to prevent the "Dummy Variable Trap," we drop the first component of the one-hot encoding (the one for class 1), which is why the indices above start at 2.

Thus, we can build a generic multiclass CMGF classifier that works with the `scikit-learn` cross validation tool as follows.



In [None]:
def fill_diagonal(A, elts):
    """Return a copy of ``A`` whose main diagonal is replaced by ``elts``.

    The diagonal is taken over the last two axes of ``A`` and has length
    ``min(A.shape[-2:])``. ``elts`` is flattened before being written.
    """
    # Adapted from https://github.com/google/jax/issues/2680
    diag_len = min(A.shape[-2:])
    rows, cols = jnp.diag_indices(diag_len)
    flat_values = jnp.ravel(elts)
    return A.at[..., rows, cols].set(flat_values)

In [None]:
class CMGFEstimator(BaseEstimator, ClassifierMixin):
    """Multiclass logistic-regression classifier fitted with a conditional-moments Gaussian filter.

    The weight matrix is treated as the latent state of a state-space model
    with identity dynamics and zero dynamics noise; each training example is
    one filtering step. After ``fit`` the last filtered mean and covariance
    are stored in ``mean`` and ``cov``.

    Parameters
    ----------
    params : callable
        Factory for the filter's parameter container. It is called with the
        keyword arguments ``initial_mean``, ``initial_covariance``,
        ``dynamics_function``, ``dynamics_covariance``,
        ``emission_mean_function`` and ``emission_var_function``.
    mean : array, optional
        Flattened weight vector; overwritten by ``fit``.
    cov : array, optional
        Weight covariance; overwritten by ``fit``.

    Attributes
    ----------
    classes_ : ndarray
        Sorted class labels seen during ``fit``; ``predict`` maps its argmax
        indices back through this array.
    """

    def __init__(self, params, mean=None, cov=None):
        self.params = params
        self.mean = mean
        self.cov = cov

    def fit(self, X, y):
        """Run the filter over ``(X, y)`` and store the final posterior.

        ``X`` is a 2-D feature array (a bias column is prepended here) and
        ``y`` is a 1-D array of class labels. Returns ``self``.
        """
        X_bias = jnp.concatenate([jnp.ones((len(X), 1)), X], axis=1)
        # Encode output as one-hot-encoded vectors with first column dropped,
        # i.e., [0, ..., 0] corresponds to the 1st class.
        # This is done to prevent the "Dummy Variable Trap".
        enc = OneHotEncoder(drop='first')
        y_oh = jnp.array(enc.fit_transform(y.reshape(-1, 1)).toarray())
        # Keep the label for every softmax index (including the dropped first
        # one) so that predict can return original labels instead of indices.
        self.classes_ = enc.categories_[0]
        input_dim = X_bias.shape[-1]
        num_classes = y_oh.shape[-1] + 1
        weight_dim = input_dim * num_classes

        initial_mean, initial_covariance = jnp.zeros(weight_dim), jnp.eye(weight_dim)
        dynamics_covariance = jnp.zeros((weight_dim, weight_dim))

        def dynamics_function(w, x):
            # Static weights: the state does not evolve between observations.
            return w

        def emission_mean_function(w, x):
            # Class probabilities for one input, first class dropped to match y_oh.
            return jax.nn.softmax(x @ w.reshape(input_dim, -1))[1:]

        def emission_var_function(w, x):
            # Categorical covariance: -p_i * p_j off the diagonal, p_i * (1 - p_i) on it.
            ps = jnp.atleast_2d(emission_mean_function(w, x))
            return fill_diagonal(ps.T @ -ps, ps * (1 - ps))

        cmgf_params = self.params(
            initial_mean=initial_mean,
            initial_covariance=initial_covariance,
            dynamics_function=dynamics_function,
            dynamics_covariance=dynamics_covariance,
            emission_mean_function=emission_mean_function,
            emission_var_function=emission_var_function,
        )
        post = conditional_moments_gaussian_filter(cmgf_params, y_oh, inputs=X_bias)
        post_means, post_covs = post.filtered_means, post.filtered_covariances
        # The estimate after the last training example is the fitted model.
        self.mean, self.cov = post_means[-1], post_covs[-1]
        return self

    def predict(self, X, y=None):
        """Return the most probable class label for every row of ``X``.

        ``y`` is ignored. When the estimator was fitted with ``fit``, argmax
        indices are mapped back to the original labels via ``classes_``
        (labels must be numeric so they can be held in a JAX array). If
        ``classes_`` is absent, e.g. ``mean`` was supplied directly to the
        constructor, the raw indices ``0..K-1`` are returned.
        """
        X_bias = jnp.concatenate([jnp.ones((len(X), 1)), X], axis=1)
        weights = self.mean.reshape(X_bias.shape[-1], -1)
        class_idx = jnp.argmax(jax.nn.softmax(X_bias @ weights), axis=1)
        classes = getattr(self, "classes_", None)
        if classes is None:
            return class_idx
        return jnp.asarray(classes)[class_idx]

In [None]:
def compute_accuracy(model, input, output, fit=True):
    """Return the fraction of samples for which ``model`` predicts ``output``.

    When ``fit`` is true the model is first fitted on ``(input, output)``;
    the object returned by ``model.fit`` is the one used for prediction.
    """
    fitted = model.fit(input, output) if fit else model
    residual = fitted.predict(input) - output
    # A zero residual means the predicted label equals the target label.
    num_correct = jnp.count_nonzero(residual == 0)
    return num_correct / len(output)

# Fit each model on the full dataset and report accuracy on that same data.
# The SGD baseline is limited to a single iteration (max_iter=1).
cmgf_model = CMGFEstimator(EKFParams)
sgd_model = LogisticRegression(multi_class='multinomial', solver='sag', max_iter=1)
print(f'CMGF training accuracy: {compute_accuracy(cmgf_model, input, output)}')
print(f'SGD training accuracy: {compute_accuracy(sgd_model, input, output)}')

CMGF training accuracy: 0.5953999757766724
SGD training accuracy: 0.4936999976634979




We test the accuracy using 3 repeated trials of 10-fold cross validation.

In [None]:
# Splitter shared by the cross-validation cells: 10 stratified folds,
# repeated 3 times, with a fixed seed.
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)

# Cross-validated accuracy of the CMGF estimator.
cmgf_est = CMGFEstimator(EKFParams)
n_scores = cross_val_score(
    cmgf_est,
    input,
    output,
    scoring='accuracy',
    cv=cv,
    n_jobs=-1,
    error_score='raise',
)
print(f'{num_points} data points, {num_features} features, {num_classes} classes.')
print(f'EKF-CMGF estimate average accuracy = {n_scores.mean()}')

10000 data points, 5 features, 4 classes.
EKF-CMGF estimate average accuracy = 0.5935999999999999


In [None]:
# Baseline: multinomial logistic regression trained with the SAG solver,
# scored with the same `cv` splitter as the CMGF estimator.
model = LogisticRegression(multi_class='multinomial', solver='sag')
n_scores = cross_val_score(
    model,
    input,
    output,
    scoring='accuracy',
    cv=cv,
    n_jobs=-1,
    error_score='raise',
)
print(f'{num_points} data points, {num_features} features, {num_classes} classes.')
print(f'sag estimate average accuracy = {n_scores.mean()}')

10000 data points, 5 features, 4 classes.
sag estimate average accuracy = 0.5938


In [None]:
# def repeated_kfold_cv(model, X, y, n_splits=10, n_repeats=3, key=1):
#     if isinstance(key, int):
#         key = jr.PRNGKey(key)
#     key, subkey = jr.split(key, 2)
#     num_points = len(y)
#     accuracy_score = 0
#     for _ in range(n_repeats):
#         idx = jr.permutation(key, jnp.arange(num_points))
#         kfolds = jnp.array_split(idx, n_splits)
#         for i in range(n_splits):
#             test_idx = kfolds[i]
#             train_idx = jnp.concatenate(kfolds[:i] + kfolds[i+1:])
#             X_train, y_train = X[train_idx], y[train_idx]
#             X_test, y_test = X[test_idx], y[test_idx]
#             model = model.fit(X_train, y_train)
#             y_predict = model.predict(X_test)
#             accuracy_score += 1 - (jnp.abs(y_predict - y_test).sum() / len(test_idx))
#     return accuracy_score / (n_splits * n_repeats)