# Setup

In [38]:
%reload_ext autoreload

In [39]:
%cd /home/kpmurphy/github/shifty
#from shifty.skax.logreg_flax import *
from shifty.skax.skax import *

print(logprior_fn) # check that one of the symbols is defiend

/home/kpmurphy/github/shifty
<function loglikelihood_fn at 0x7f5cc85f2950>


In [3]:
# Silence WARNING:root:The use of `check_types` is deprecated and does not have any effect.
# https://github.com/tensorflow/probability/issues/1523
import logging

logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
    def filter(self, record):
        return "check_types" not in record.getMessage()


logger.addFilter(CheckTypesFilter())

In [85]:

import matplotlib.pyplot as plt
import numpy as np
np.set_printoptions(precision=3)
import scipy.stats
import einops
from functools import partial

from functools import partial
from collections import namedtuple
import itertools
from itertools import repeat
from time import time
import chex
import typing
from typing import Any, Callable, Sequence

import jax
import jax.random as jr
import jax.numpy as jnp
from jax import vmap, grad, jit
from jax import lax, random, numpy as jnp
import jax.scipy as jsp

from flax.core import freeze, unfreeze
from flax import linen as nn
import flax

import jaxopt
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

import torch
from torch.utils.data import TensorDataset, DataLoader

import distrax
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN

#jax.config.update("jax_enable_x64", False)


import sklearn.datasets
from sklearn.model_selection import train_test_split
import sklearn
from sklearn.preprocessing import PolynomialFeatures, StandardScaler
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.linear_model import LogisticRegression

In [7]:
import os 
cpu_count = os.cpu_count()
print(cpu_count)

# Run jax on multiple CPU cores
# https://github.com/google/jax/issues/5506
# https://stackoverflow.com/questions/72328521/jax-pmap-with-multi-core-cpu
import os 
#os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=90'

import jax
print(jax.devices())

96
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]


# Code

In [None]:
from shifty.skax.logreg_flax import *

print(loglikelihood_fn) # check that one of the symbols is defiend

In [104]:
@chex.dataclass
class GenParams:
    nclasses: int
    nfeatures: int
    prior: chex.Array
    mus: chex.Array # (C,D)
    Sigmas: chex.Array #(C,D,D)

def make_params(key, nclasses, nfeatures, scale_factor=1):
    mus = jr.normal(key, (nclasses, nfeatures)) # (C,D)
    # shared covariance -> linearly separable
    #Sigma = scale_factor * jnp.eye(nfeatures)
    #Sigmas = jnp.array([Sigma for _ in range(nclasses)]) # (C,D,D)
    # diagonal covariance -> nonlinear decision boundaries
    sigmas = jr.uniform(key, shape=(nclasses, nfeatures), minval=0.5, maxval=5)
    Sigmas = jnp.array([scale_factor*jnp.diag(sigmas[y]) for y in range(nclasses)])
    prior = jnp.ones(nclasses)/nclasses
    return GenParams(nclasses=nclasses, nfeatures=nfeatures, prior=prior, mus=mus, Sigmas=Sigmas)

def sample_data(key, params, nsamples):
    y = jr.categorical(key, logits=jnp.log(params.prior), shape=(nsamples,))
    X = jr.multivariate_normal(key, params.mus[y], params.Sigmas[y])
    return X, y

def predict_bayes(X, params):
    def lik_fn(y):
        return   jsp.stats.multivariate_normal.pdf(X, params.mus[y], params.Sigmas[y])
    liks = vmap(lik_fn)(jnp.arange(params.nclasses)) # liks(k,n)=p(X(n,:) | y=k)
    joint = jnp.einsum('kn,k -> nk', liks, params.prior) # joint(n,k) = liks(k,n) * prior(k)
    norm = joint.sum(axis=1) # norm(n)  = sum_k joint(n,k) = p(X(n,:)
    post = joint / jnp.expand_dims(norm, axis=1) # post(n,k) = p(y = k | xn)
    return post



In [105]:
key = jr.PRNGKey(0)
key, subkey = jr.split(key)
params = make_params(subkey, nclasses=4, nfeatures=10, scale_factor=5)
key, subkey = jr.split(key)
Xtrain, ytrain = sample_data(subkey, params, nsamples=1000)
key, subkey = jr.split(key)
Xtest, ytest = sample_data(subkey, params, nsamples=1000)

yprobs_train_bayes = predict_bayes(Xtrain, params)
yprobs_test_bayes = predict_bayes(Xtest, params)

ypred_train_bayes = jnp.argmax(yprobs_train_bayes, axis=1)
error_rate_train_bayes = jnp.sum(ypred_train_bayes != ytrain) / len(ytrain)
print(error_rate_train_bayes)

ypred_test_bayes = jnp.argmax(yprobs_test_bayes, axis=1)
error_rate_test_bayes = jnp.sum(ypred_test_bayes != ytest) / len(ytest)
print(error_rate_test_bayes)

0.26700002
0.303


In [106]:

network = MLPNetwork((nclasses,)) # no hidden layers == logistic regression
mlp = NeuralNetClassifier(network, key, nclasses, l2reg=1e-5, optimizer = "adam+warmup", 
        batch_size=32, num_epochs=20, print_every=10)  
mlp.fit(Xtrain, ytrain)

yprobs_train_mlp = np.array(mlp.predict(Xtrain))
yprobs_test_mlp = np.array(mlp.predict(Xtest))

ypred_train_mlp = jnp.argmax(yprobs_train_mlp, axis=1)
error_rate_train_mlp = jnp.sum(ypred_train_mlp != ytrain) / len(ytrain)
print(error_rate_train_mlp)

ypred_test_mlp = jnp.argmax(yprobs_test_mlp, axis=1)
error_rate_test_mlp = jnp.sum(ypred_test_mlp != ytest) / len(ytest)
print(error_rate_test_mlp)

  return x.astype(jnp.float_)


epoch 0, train loss 47.790
epoch 10, train loss 43.175
0.45600003
0.453


In [112]:
network = MLPNetwork((10, 10, nclasses,)) # num hidden units per layer
mlp = NeuralNetClassifier(network, key, nclasses, l2reg=1e-5, optimizer = "adam+warmup", 
        batch_size=32, num_epochs=50, print_every=10)  
mlp.fit(Xtrain, ytrain)

yprobs_train_mlp = np.array(mlp.predict(Xtrain))
yprobs_test_mlp = np.array(mlp.predict(Xtest))

ypred_train_mlp = jnp.argmax(yprobs_train_mlp, axis=1)
error_rate_train_mlp = jnp.sum(ypred_train_mlp != ytrain) / len(ytrain)
print(error_rate_train_mlp)

ypred_test_mlp = jnp.argmax(yprobs_test_mlp, axis=1)
error_rate_test_mlp = jnp.sum(ypred_test_mlp != ytest) / len(ytest)
print(error_rate_test_mlp)

epoch 0, train loss 96.741
epoch 10, train loss 84.876
epoch 20, train loss 82.873
epoch 30, train loss 79.111
epoch 40, train loss 75.705
0.26200002
0.42200002


In [111]:
network = MLPNetwork((10, 10, nclasses,)) # num hidden units per layer
mlp = NeuralNetClassifier(network, key, nclasses, l2reg=1e-5, optimizer =  optax.adam(1e-3), 
        batch_size=32, num_epochs=50, print_every=10)  
mlp.fit(Xtrain, ytrain)

yprobs_train_mlp = np.array(mlp.predict(Xtrain))
yprobs_test_mlp = np.array(mlp.predict(Xtest))

ypred_train_mlp = jnp.argmax(yprobs_train_mlp, axis=1)
error_rate_train_mlp = jnp.sum(ypred_train_mlp != ytrain) / len(ytrain)
print(error_rate_train_mlp)

ypred_test_mlp = jnp.argmax(yprobs_test_mlp, axis=1)
error_rate_test_mlp = jnp.sum(ypred_test_mlp != ytest) / len(ytest)
print(error_rate_test_mlp)

epoch 0, train loss 98.647
epoch 10, train loss 89.701
epoch 20, train loss 85.274
epoch 30, train loss 82.865
epoch 40, train loss 81.472
0.36
0.41900003


In [107]:
network = MLPNetwork((5, nclasses,)) # 5 hidden units in first layer
#network = MLPNetwork((nclasses,)) # no hidden layers == logistic regression
mlp = NeuralNetClassifier(network, key, nclasses, l2reg=1e-5, optimizer = "adam+warmup", 
        batch_size=32, num_epochs=20, print_every=10)  
mlp.fit(Xtrain, ytrain)

yprobs_train_mlp = np.array(mlp.predict(Xtrain))
yprobs_test_mlp = np.array(mlp.predict(Xtest))

ypred_train_mlp = jnp.argmax(yprobs_train_mlp, axis=1)
error_rate_train_mlp = jnp.sum(ypred_train_mlp != ytrain) / len(ytrain)
print(error_rate_train_mlp)

ypred_test_mlp = jnp.argmax(yprobs_test_mlp, axis=1)
error_rate_test_mlp = jnp.sum(ypred_test_mlp != ytest) / len(ytest)
print(error_rate_test_mlp)

epoch 0, train loss 59.498
epoch 10, train loss 43.852
0.37500003
0.409


In [108]:
network = MLPNetwork((5, nclasses,)) # 5 hidden units in first layer
mlp = NeuralNetClassifier(network, key, nclasses, l2reg=1e-5, optimizer = optax.adam(1e-3), 
        batch_size=32, num_epochs=50, print_every=10)  
mlp.fit(Xtrain, ytrain)

yprobs_train_mlp = np.array(mlp.predict(Xtrain))
yprobs_test_mlp = np.array(mlp.predict(Xtest))

ypred_train_mlp = jnp.argmax(yprobs_train_mlp, axis=1)
error_rate_train_mlp = jnp.sum(ypred_train_mlp != ytrain) / len(ytrain)
print(error_rate_train_mlp)

ypred_test_mlp = jnp.argmax(yprobs_test_mlp, axis=1)
error_rate_test_mlp = jnp.sum(ypred_test_mlp != ytest) / len(ytest)
print(error_rate_test_mlp)

epoch 0, train loss 64.890
epoch 10, train loss 56.047
epoch 20, train loss 52.202
epoch 30, train loss 49.440
epoch 40, train loss 47.341
0.42700002
0.44200003
