goal
- [done] SVGP + dirichlet training on mnist
    - recreate evidential DL example ... 
- [done] variational learning of supporting image patches !
    - [done] impl STN ... 
    - observations
        - performance is okay if the encoder network is allowed to fine-tune
        - the learned evidence is not very localized ...
            - perhaps because the inducing locations are shared across classes ... \
                they would want to retain all info (not localized) and \
                use the variational mean to modulate the evidence for each class
            - so to get localized info ... might want to use per-class inducing variables
            - might also want to move the STN into the kernel hyperparameters ...\
                and put a product kernel over both the image and the affine transformation matrix
            - might also try using just one STN for the entire thing

In [None]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import time
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

from collections import defaultdict

import scipy
import numpy as onp
onp.set_printoptions(precision=3,suppress=True)

import jax
import jax.numpy as np
from jax import grad, jit, vmap, device_put, random
from flax import linen as nn
from jax.scipy.stats import dirichlet

from jax.lib import xla_bridge
print('jax/jaxlib: ', jax.__version__, jax.lib.version)
print(xla_bridge.get_backend().platform)
print(jax.local_device_count())
print(jax.devices())

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.tri as tri
# https://matplotlib.org/3.1.1/gallery/style_sheets/style_sheets_reference.html
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 25
mpl.rcParams['font.family'] = 'Times New Roman' 
cmap = plt.cm.get_cmap('bwr')

from tabulate import tabulate
from functools import partial

from plt_utils import *
from gpax import *

In [None]:
import torch
import torchvision

def dataset_subset(dataset, y):
    """Restrict `dataset` (in place) to the classes in `y` and relabel them.

    Kept samples are relabelled by their position in `y`, so the new targets
    lie in [0, ..., len(y) - 1]; e.g. with `y = [6, 8]` the targets only
    contain {0, 1}.  The new targets are read from a float32 lookup table and
    are therefore float32.  The same (mutated) `dataset` object is returned.
    """
    raw_targets = dataset.targets.numpy()
    # One boolean row per requested class, OR-ed together over the class axis.
    per_class = onp.stack([raw_targets == cls for cls in y])
    keep = torch.tensor(onp.any(per_class, axis=0))
    rows = torch.where(keep)[0]
    # Lookup table: original digit -> new contiguous label.
    relabel = torch.zeros((10,), dtype=torch.float32)
    relabel[y] = torch.arange(len(y), dtype=torch.float32)
    dataset.targets = relabel[dataset.targets[rows].to(torch.int64)]
    dataset.data = dataset.data[rows]
    return dataset

# Base PRNG key for the cells below (it is re-seeded again further down).
key = random.PRNGKey(1)

# Point torchvision's MNIST downloader at an alternative mirror, keeping the
# original file names and md5 checksums.
# https://stackoverflow.com/questions/66577151/http-error-when-trying-to-download-mnist-data
new_mirror = 'https://ossci-datasets.s3.amazonaws.com/mnist'
torchvision.datasets.MNIST.resources = [
   ('/'.join([new_mirror, url.split('/')[-1]]), md5)
   for url, md5 in torchvision.datasets.MNIST.resources
]

# Per-sample transform: add a trailing channel axis and divide by 255.
# It is only attached to the dataset objects; the arrays built below read
# `.data` directly and repeat the same channel-axis / scaling by hand.
transforms = torchvision.transforms.Compose([
    lambda x: np.asarray(x)[...,np.newaxis] / 255.
])

train_dataset = torchvision.datasets.MNIST('./data', train=True, transform=transforms, download=True)
test_dataset = torchvision.datasets.MNIST('./data', train=False, transform=transforms, download=True)
# Keep only these digits; `dataset_subset` relabels them to 0..len(digits)-1.
digits = [1,2,5]
train_dataset = dataset_subset(train_dataset, digits)
test_dataset = dataset_subset(test_dataset, digits)
# Images get a trailing channel axis and are divided by 255; targets get a
# trailing singleton axis (a column vector).
X_train = jax_to_gpu(np.asarray(train_dataset.data[...,np.newaxis]) / 255.)
y_train = jax_to_gpu(np.asarray(train_dataset.targets[...,np.newaxis]))
X_test = jax_to_gpu(np.asarray(test_dataset.data[...,np.newaxis]) / 255.)
y_test = jax_to_gpu(np.asarray(test_dataset.targets[...,np.newaxis]))
data_train = (X_train, y_train)
data_test = (X_test, y_test)

In [None]:
# First ten test indices whose (relabelled) class is 0.
print(np.where(y_test == 0)[0][:10])

ind = 1
n_ims = 20

# One row of `n_ims` tick-less axes.
fig, axs = plt.subplots(1, n_ims, figsize=(2 * n_ims, 2))
for axi in axs.ravel():
    axi.set_xticks([])
for axi in axs.ravel():
    axi.set_yticks([])

x, y = X_test[ind], y_test[ind]

# Show the rotated copies of test image `ind`, one per axis.
x_rot = rotated_ims(x, n_ims=n_ims)
for i, ax in zip(range(n_ims), axs):
    ax.imshow(x_rot[i], cmap='Greys')



In [None]:
# Number of classes: one per digit kept by `dataset_subset`.
output_dim = len(digits)


class CNN(nn.Module):
    """Small convolutional classifier that returns per-class log-probabilities.

    Two conv(3x3) -> ReLU -> 2x2 average-pool stages (32 then 64 features),
    a 128-unit dense layer with ReLU, and a final dense layer of width
    `output_dim` followed by `log_softmax`.
    """
    output_dim: int = 2

    @nn.compact
    def __call__(self, x):
        # Convolutional stages; creation order is kept so the automatically
        # assigned submodule names stay the same.
        for n_features in (32, 64):
            x = nn.Conv(features=n_features, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything but the leading axis.
        flat = x.reshape((x.shape[0], -1))
        hidden = nn.relu(nn.Dense(features=128)(flat))
        scores = nn.Dense(features=self.output_dim)(hidden)
        return nn.log_softmax(scores)
    

class CNNTrunk(nn.Module):
    """Feature extractor: same layers as `CNN` up to the 128-unit dense layer.

    The 128-dimensional dense output is returned as-is (no ReLU and no
    classification head).
    """

    @nn.compact
    def __call__(self, x):
        # Two conv(3x3) -> ReLU -> 2x2 average-pool stages.
        for n_features in (32, 64):
            x = nn.Conv(features=n_features, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything but the leading axis, then project to 128 features.
        flat = x.reshape((x.shape[0], -1))
        return nn.Dense(features=128)(flat)


def cross_entropy_loss(logits, labels):
    """Mean negative score of the labelled class.

    Args:
        logits: per-class scores with the class axis last.  The value is a
            true cross-entropy only if these are log-probabilities (the `CNN`
            in this file applies `log_softmax` itself).
        labels: integer-valued class indices, either flat or with a trailing
            singleton axis.

    Returns:
        Scalar `-mean_n(logits[n, labels[n]])`.
    """
    # The class count comes from `logits` itself, not from a module-level
    # global.  The one-hot is reshaped to `logits.shape` rather than squeezed:
    # a bare `squeeze()` would also collapse the class axis when there is a
    # single class and then broadcast to (N, N).
    n_classes = logits.shape[-1]
    y_onehot = jax.nn.one_hot(labels, num_classes=n_classes).reshape(logits.shape)
    return -np.mean(np.sum(y_onehot * logits, axis=-1))


def compute_metrics(logits, labels):
    """Return `{'loss', 'accuracy'}` for a batch of class scores.

    Args:
        logits: per-class scores with the class axis last.
        labels: class indices, flat `(N,)` or column `(N, 1)`.

    Returns:
        dict with the `cross_entropy_loss` value and the fraction of rows
        whose arg-max class equals the label.
    """
    loss = cross_entropy_loss(logits, labels)
    pred = np.argmax(logits, -1).reshape(-1, 1)
    # Compare column against column: a flat (N,) label vector would otherwise
    # broadcast against the (N, 1) predictions to an (N, N) matrix and give a
    # meaningless accuracy.
    accuracy = np.mean(pred == np.reshape(labels, (-1, 1)))
    metrics = {'loss': loss,
               'accuracy': accuracy}
    return metrics


@jax.jit
def train_step(opt, batch, key):
    """Run one gradient step of the CNN classifier.

    Args:
        opt: optimizer object whose `target` holds the CNN parameters.
        batch: `(X, y)` pair of inputs and labels.
        key: PRNG key; it is split once and the first half is handed back.

    Returns:
        `(opt, log, key)` where `log` holds the loss, the accuracy and the
        gradient norm of the first dense layer's kernel.
    """
    key, _ = random.split(key)
    inputs, targets = batch

    def loss_fn(params):
        logits = CNN(output_dim=output_dim).apply(params, inputs)
        return cross_entropy_loss(logits, targets), logits

    (loss, logits), grad = jax.value_and_grad(loss_fn, has_aux=True)(opt.target)
    opt = opt.apply_gradient(grad)

    dense0_grad = grad['params']['Dense_0']['kernel']
    log = {
        'loss': loss,
        'accuracy': compute_metrics(logits, targets)['accuracy'],
        'dense0_kernel_gradnorm': linalg.norm(dense0_grad),
    }
    return opt, log, key


@jax.jit
def eval_step(params, X):
    """Apply the CNN classifier with `params` to `X` and return its output."""
    classifier = CNN(output_dim=output_dim)
    return classifier.apply(params, X)


def eval_model(params, data_test, logit_fn=eval_step):
    """Evaluate `logit_fn(params, X)` over `data_test` in batches of 100.

    The per-batch outputs and labels are stacked and passed to
    `compute_metrics`; the resulting metrics are converted to Python scalars.
    """
    n_batches, stream = get_data_stream(
        random.PRNGKey(0), 100, data_test)

    all_logits, all_labels = [], []
    for _ in range(n_batches):
        X, y = next(stream)
        batch_logits = logit_fn(params, X)
        all_labels.append(y.reshape(-1, 1))
        all_logits.append(batch_logits)

    stacked = compute_metrics(np.vstack(all_logits), np.vstack(all_labels))
    return jax.tree_map(lambda leaf: leaf.item(), stacked)



In [None]:
# Initialise the CNN with a dummy (1, 28, 28, 1) input and hand its parameters
# to `flax_create_optimizer` ('Adam', learning rate 0.03).
model = CNN(output_dim=output_dim)
params = model.init(key, np.ones((1,28,28,1)))
opt = flax_create_optimizer(params, 'Adam', {'learning_rate': .03})

In [None]:

# Test-set metrics for the current `params`, printed as a baseline line
# labelled with epoch 0.
metrics = eval_model(params, data_test)
print(f'[{0:3}] test \t'
      f'Loss={metrics["loss"]:.3f}\t'
      f'accuracy={metrics["accuracy"]:.3f}\t')


In [None]:

# Train the CNN baseline: `n_epochs` passes over the training stream, logging
# running averages about ten times per epoch and test metrics after each epoch.
bsz = 64
train_n_batches, train_batches = get_data_stream(key, bsz, data_train)
n_epochs = 10
# `max(1, ...)` keeps the modulo below from dividing by zero when an epoch
# has fewer than ten batches.
log_every = max(1, train_n_batches // 10)


for epoch in range(n_epochs):
    # Per-epoch history of every logged quantity (reset at each epoch).
    logs = defaultdict(list)
    for it in range(train_n_batches):
        step = epoch*train_n_batches+it
        batch = next(train_batches)
        opt, log, key = train_step(opt, batch, key)
        params = opt.target
        for k, v in log.items():
            logs[k].append(v)
        if step % log_every == 0:
            # Averages over all steps of the current epoch so far.
            avg_metrics = {k: np.mean(np.array(v))
                           for k, v in logs.items()}
            print(f'[{epoch:3}|{100*it/train_n_batches:5.2f}%]\t'
                  f'Loss={avg_metrics["loss"]:.3f}\t'
                  f'accuracy={avg_metrics["accuracy"]:.3f}\t'
                  f'norm(Dense0.k)={avg_metrics["dense0_kernel_gradnorm"]:.3f}')

    metrics = eval_model(params, data_test)
    print(f'[{epoch:3}] test \t'
          f'Loss={metrics["loss"]:.3f}\t'
          f'accuracy={metrics["accuracy"]:.3f}\t')
    

In [None]:
# Persist the trained CNN parameters, named after the digit subset, then load
# them back into a freshly initialised parameter tree and re-evaluate.
cnn_save_path = f'./cnn_params_{",".join([str(x) for x in digits])}.pkl'
pytree_save(opt.target, cnn_save_path)
# A notebook URL had been pasted into the middle of the `pytree_load` call
# name here, which made the line a syntax error.
params = pytree_load(CNN(output_dim=output_dim).init(key, np.ones((1,28,28,1))), cnn_save_path)

metrics = eval_model(params, data_test)
print(f'Loss={metrics["loss"]:.3f}\t'
      f'accuracy={metrics["accuracy"]:.3f}\t')

In [None]:
# Variations of inducing points
#     1. Application of differentiable transformation, defined via `transform_cls`
#     2. Whether inducing points wrt patches or images.
# 


def reinitialize_inducing(params, key, transform_cls, X_train, L=10):
    """Make full use of the set of inducing points.
        - find L inducing points with smallest average μ magnitude
        - re-initialize transformation parameters randomly
        - update corresponding images in `Xu/X`

    Args:
        params: parameter pytree containing `params/q/μ`, `params/mean_fn/c`,
            `params/Xu/transform` and `params/Xu/X`.
        key: PRNG key; it is split into independent keys for the two random
            draws plus one that is returned.
        transform_cls: zero-argument constructor of the transform module; its
            `default_T_init` method supplies the initializer for `T`.
        X_train: training images from which replacement images are drawn.
        L: number of inducing points to re-initialize (clamped to the number
            of entries of the averaged magnitude vector).

    Returns:
        `(params, key)`: the updated parameters and a fresh PRNG key.
    """

    ## find L inducing points with smallest average μ magnitude
    qμ = pytree_leaf(params, 'params/q/μ')
    μconst = pytree_leaf(params, 'params/mean_fn/c')
    qμ_mag = np.mean(np.abs(qμ - μconst[...,np.newaxis]), axis=0)
    # Never ask for more points than exist; otherwise `ind` would be shorter
    # than the `(L,)`-shaped replacement values drawn below.
    L = min(L, int(qμ_mag.size))
    ind = np.argsort(qμ_mag)[:L]

    # One independent key per random draw, plus one for the caller.  The key
    # used to be consumed for the `T` draw and only split afterwards, i.e. it
    # was reused.
    key, k_T, k_X = random.split(key, 3)

    ## re-initialize transformation parameters randomly

    def reinitialize_T(T):
        trans = transform_cls()
        trans_params = {'params': pytree_leaf(params, 'params/Xu/transform')}
        default_T_init_fn = trans.apply(trans_params, method=trans.default_T_init)[1]
        return jax.ops.index_update(T, ind, default_T_init_fn(k_T, (L,)))

    params = pytree_mutate_with_fn(params, 'params/Xu/transform/T', reinitialize_T)

    ## re-initialize random images `Xu/X`

    def reinitialize_X(X):
        Xind = np.take(X_train, random.randint(k_X, (L,), 0, len(X_train)), axis=0)
        return jax.ops.index_update(X, ind, Xind)
    params = pytree_mutate_with_fn(params, 'params/Xu/X', reinitialize_X)

    return params, key

# Re-initialise the least-used inducing points of the current optimizer state.
# NOTE(review): `transform_cls`, `patch_shape` and a suitable `ind` are only
# assigned further down in this file; this cell presumably relies on notebook
# cells having been run out of order — confirm.
params = opt.target
params, key = reinitialize_inducing(params, key, transform_cls, X_train)




# Inducing images and their transform matrices, restricted to `ind`.
S = pytree_leaf(params, 'params/Xu/X')[ind]
fn = vmap(transform_to_matrix, (0, None, None), 0)
A = fn(pytree_leaf(params, 'params/Xu/transform/T'),
       transform_cls().T_type,
       transform_cls().A_init_val)[ind]

# One two-row column of axes per selected inducing point.
fn = vmap(spatial_transform_details, (0, 0, None), 0)
T, Gs = fn(A, S, patch_shape)
fig, axs = plt.subplots(2, len(A), figsize=(3*len(A),3*2))
for i in range(len(T)):
    plt_spatial_transform(axs[:,i], Gs[i], S[i], T[i])
fig.tight_layout()
plt.show()

In [None]:
# Re-seed the PRNG for the SVGP experiment.
key = random.PRNGKey(0)

# Experiment switches.
lik_type = 'LikMulticlassDirichlet' # LikMulticlassDirichlet, LikMulticlassSoftmax, LikMultipleNormalKron
α_ϵ = 1; α_δ = 10; n_mc_samples = 20
image_shape = (28,28,1)
patch_shape = (7,7)
n_inducing = 50
T_type = 'transl' # 'transl', etc.
use_loc_kernel = True


# Constant mean function; its initial value depends on the likelihood.
init_val_m = gamma_to_lognormal(np.array([1.]))[0] \
    if lik_type == 'LikMulticlassDirichlet' else np.array([0.5])
mean_fn_cls = partial(MeanConstant, output_dim=output_dim, init_val_m=init_val_m, flat=False)
# Likelihood constructor selected by `lik_type`.
if lik_type == 'LikMulticlassDirichlet':
    lik_cls = partial(LikMulticlassDirichlet, output_dim=output_dim, init_val_α_ϵ=α_ϵ, init_val_α_δ=α_δ, n_mc_samples=n_mc_samples)
elif lik_type == 'LikMulticlassSoftmax':
    lik_cls = partial(LikMulticlassSoftmax, output_dim=output_dim, n_mc_samples=n_mc_samples)
else:
    lik_cls = partial(LikMultipleNormalKron, output_dim=output_dim)

# Input feature map applied inside the kernel (identity here; `CNNTrunk` is
# the alternative).
g_cls = LayerIdentity # CNNTrunk
# kx_cls = partial(CovSE, output_scaling=True)

# Kernel: a convolutional (patch) kernel per output, with either an SE or a
# fixed constant location kernel depending on `use_loc_kernel`.
kl_cls = CovSE if use_loc_kernel else partial(CovConstant, train_σ2=False)
kx_cls = partial(CovConvolutional, image_shape=image_shape, patch_shape=patch_shape,
                 kg_cls=CovSE, patch_inducing_loc=True, kl_cls=kl_cls)
k_cls = partial(CovMultipleOutputIndependent, k_cls=kx_cls, output_dim=output_dim, g_cls=g_cls)
# Inducing inputs: raw patches when no transform is requested, otherwise
# random training images that are passed through a `SpatialTransform`.
if T_type == '':
    transform_cls = LayerIdentity
    Xu_initial = get_init_patches(key, X_train, n_inducing, image_shape, patch_shape)
else:
    # transform_type = 'transl+isot_scal'; T_init_fn = lambda k,s: np.tile(np.array([.25, 0, 0]), (s[0], 1)); A_init_val = np.array([[1.,0,0],[0,1.,0]])
    # Initial transform: scale by patch/image size ratio, zero translation.
    scal = np.array(patch_shape)/np.array(image_shape[:2])
    A_init_val = trans2x3_from_scal_transl(scal,(0,0))
    # NOTE(review): `T_init_fn` is assigned but not used in this cell.
    T_init_fn = None
    transform_cls = partial(SpatialTransform, shape=patch_shape, n_transforms=n_inducing, 
                            T_type=T_type, A_init_val=A_init_val, output_transform=use_loc_kernel)
    Xu_initial = np.take(X_train, random.randint(key, (n_inducing,), 0, len(X_train)), axis=0)

# The init function ignores its (key, shape) arguments and returns
# `Xu_initial` as-is.
inducing_loc_cls = partial(InducingLocations,
                           shape=Xu_initial.shape,
                           init_fn=lambda k,s: Xu_initial,
                           transform_cls=transform_cls)

print('Xu:', Xu_initial.shape)

model = SVGP(mean_fn_cls=mean_fn_cls,
             k_cls=k_cls,
             lik_cls=lik_cls,
             inducing_loc_cls=inducing_loc_cls,
             n_data=len(X_train),
             output_dim=output_dim)


# Initial parameters, followed by a printout of the model and parameter keys.
params = model.get_init_params(model, key, X_shape=image_shape)
print(model)
pytree_keys(params)

In [None]:
# extract_patches_2d_scal_transl(image_shape, patch_shape)


class BijBox(object):
    r"""Bijector between the real line and the positive reals: x <-> exp(x).

    `forward` used to call `jax.nn.sigmoid()` with no argument and therefore
    always raised; it now implements the map its docstring describes, which is
    also the exact inverse of `reverse`.

    NOTE(review): the class name and the `lh` field suggest an unfinished
    box-constraint bijector; `lh` is not used by either method — confirm the
    intended transform.
    """
    lh: np.ndarray

    @staticmethod
    def forward(x):
        r""" x -> exp(x) \in \R+ """
        return np.exp(x)

    @staticmethod
    def reverse(y):
        r""" y -> log(y), the inverse of `forward` """
        return np.log(y)

    

In [None]:
# if isinstance(inducing_loc_cls(), InducingLocations):
#     first_n = min(40, n_inducing)
#     Xu = inducing_loc_cls().apply({'params': pytree_leaf(params, 'params/Xu')})
#     fig, axs = plt.subplots(1,first_n,figsize=(3*first_n, 3))
#     for i in range(first_n):
#         axs[i].imshow(Xu[i,...].reshape(*patch_shape), cmap='Greys', vmin=0, vmax=1)

# # if isinstance(inducing_loc_cls(), InducingLocationsSpatialTransform):
# #     first_n = min(40, n_inducing)
# #     Xu = inducing_loc_cls().apply({'params': pytree_leaf(params, 'params/Xu')})
# #     fig, axs = plt.subplots(1,first_n,figsize=(3*first_n, 3))
# #     for i in range(first_n):
# #         axs[i].imshow(Xu[i,...].reshape(*patch_shape), cmap='Greys', vmin=0, vmax=1)
    
# Only when the inducing inputs go through a `SpatialTransform`: plot the
# transform details of every inducing point.
if isinstance(transform_cls(), SpatialTransform):
    ind = np.arange(n_inducing)
    
    # Bind the parameters so the transform's settings can be read off the
    # module tree.
    m = model.bind(params)
    fn = vmap(transform_to_matrix, (0, None, None), 0)
    A = fn(pytree_leaf(params, 'params/Xu/transform/T'), m.Xu.transform.T_type, m.Xu.transform.A_init_val)
    S = pytree_leaf(params, 'params/Xu/X')
    A = A[ind]; S = S[ind]

    # One two-row column of axes per inducing point.
    fn = vmap(spatial_transform_details, (0, 0, None), 0)
    T, Gs = fn(A, S, patch_shape)
    fig, axs = plt.subplots(2, len(A), figsize=(3*len(A),3*2))
    for i in range(len(T)):
        plt_spatial_transform(axs[:,i], Gs[i], S[i], T[i])
    fig.tight_layout()
    plt.show()


In [None]:
# load pretrained weights & set initial values
# (only when a non-identity feature map `g_cls` is used: the saved CNN
# parameters are loaded into the `g` sub-tree and written back into `params`)
if g_cls != LayerIdentity:
    g_path = 'params/k/kx/g' if isinstance(k_cls(), CovICM) else 'params/k/g'
    encoder_params = pytree_load({'params': pytree_leaf(params, g_path)}, cnn_save_path)
    encoder_params_kvs = pytree_get_kvs(encoder_params)
    params = pytree_mutate(params, {f'{g_path}/{k}': v for k,v in encoder_params_kvs.items()})
# set initial lengthscales
# (stored as `softplus_inv(2.)`; the logging code below reads them back through
# `softplus`)
if isinstance(k_cls(), CovICM):
    params = pytree_mutate(params, {'params/k/kx/ls': softplus_inv(np.array([2.]))})
else:
    params = pytree_mutate(params, {f'params/k/ks_{i}/ls': softplus_inv(np.array([2.]))
                                    for i in range(output_dim)})

# Three parameter groups selected by path keywords, each with its own Adam
# learning rate: 0 (frozen), 0.01 (slow) and 0.03 (everything else).
kwd_notrain = ['mean_fn', 'Xu/X'] + [f'params/k/ks_{i}/kl/σ2' for i in range(output_dim)]
kwd_trainslow = ['Xu/transform'] # 'Xu/transform'
opt = flax_create_multioptimizer(
    params, 'Adam',
    [{'learning_rate': 0.}, {'learning_rate': .01}, {'learning_rate': .03}],
    [lambda p, v: pytree_path_contains_keywords(p, kwd_notrain),
     lambda p, v: pytree_path_contains_keywords(p, kwd_trainslow),
     lambda p, v: not pytree_path_contains_keywords(p, kwd_notrain+kwd_trainslow)])

flax_check_multiopt(params, opt)
# params
# params

In [None]:
######################################################
        
@jax.jit
def eval_step2(params, X):
    """Return the first of the two outputs of `model.pred_y` for inputs `X`.

    The module-level `key` is passed as the 'lik_mc_samples' RNG.
    """
    outputs = model.apply(params, X, method=model.pred_y, rngs={'lik_mc_samples': key})
    Ey, _ = outputs
    return Ey

@jax.jit
def train_step2(step, opt, batch, key):
    """Run one gradient step on the negated `model.mll` objective.

    Args:
        step: global step index (not used inside the function).
        opt: optimizer object whose `target` holds the model parameters.
        batch: `(X, y)` pair; `y` is one-hot encoded before use.
        key: PRNG key; one half of its split feeds the likelihood's
            'lik_mc_samples' RNG, the other half is returned.

    Returns:
        `(opt, log, key)` with `log = {'loss': loss}`.
    """
    key, mc_key = random.split(key)
    X, y = batch
    targets = jax.nn.one_hot(y.squeeze(), num_classes=output_dim)

    def loss_fn(params):
        objective = model.apply(params,
                                (X, targets),
                                method=model.mll,
                                rngs={'lik_mc_samples': mc_key})
        return -objective, {}

    (loss, _), grad = jax.value_and_grad(loss_fn, has_aux=True)(opt.target)
    opt = opt.apply_gradient(grad)
    return opt, {'loss': loss}, key


# Train the SVGP model: `n_epochs` passes over the training stream, logging
# the loss and kernel hyperparameters about five times per epoch and test
# metrics after each epoch.
bsz = 64
train_n_batches, train_batches = get_data_stream(
    key, bsz, data_train)

n_epochs = 5
# `max(1, ...)` keeps the modulo below from dividing by zero when an epoch
# has fewer than five batches.
log_every = max(1, train_n_batches // 5)

for epoch in range(n_epochs):
    start = time.time()
    for it in range(train_n_batches):
        step = epoch*train_n_batches+it
        batch = next(train_batches)
        opt, log, key = train_step2(step, opt, batch, key)
        params = opt.target
        if step % log_every == 0:
            # Parameter paths differ by kernel type: a convolutional kernel
            # keeps its hyperparameters under `.../kg`.  All values are stored
            # unconstrained and read back through `softplus`.
            insert_str = '/kg' if isinstance(model.k_cls().k_cls(), CovConvolutional) else ''
            log.update({'k.ls': jax.nn.softplus(pytree_leaf(params, f'params/k/kx/ls')
                                           if isinstance(model.k_cls(), CovICM) else 
                                           np.hstack([pytree_leaf(params, f'params/k/ks_{i}{insert_str}/ls')
                                                      for i in range(output_dim)])),
                   'k.σ2': jax.nn.softplus(pytree_leaf(params, f'params/k/kx/σ2')
                                           if isinstance(model.k_cls(), CovICM) else 
                                           np.hstack([pytree_leaf(params, f'params/k/ks_{i}{insert_str}/σ2')
                                                      for i in range(output_dim)])),
                   'kl.ℓ': jax.nn.softplus(np.hstack([pytree_leaf(params, f'params/k/ks_{i}/kl/ls')
                                                      for i in range(output_dim)])
                                           if  isinstance(model.k_cls(), CovMultipleOutputIndependent) and \
                                               not isinstance(model.k_cls().k_cls().kl_cls(), CovConstant) else np.array([np.nan]))})
            print(f'[{epoch:3}|{100*it/train_n_batches:5.2f}%]\t'
                  f'Loss={log["loss"]:.3f}\t'
                  f'Time={time.time()-start:.3f}\t'
                  f'k.ls={log["k.ls"][:3]}\t'
                  f'kl.ℓ = {log["kl.ℓ"][:3]}\t'
                  f'k.σ2={log["k.σ2"][:3]}\t')
            # Restart the timer so `Time` measures the interval between logs.
            start = time.time()


    metrics = eval_model(params, data_test, logit_fn=eval_step2)
    print(f'[{epoch:3}] test \t'
          f'Loss={metrics["loss"]:.3f}\t'
          f'accuracy={metrics["accuracy"]:.3f}\t')

# Take the final parameters from the optimizer (the name used to be misspelt
# `parmas`, so this assignment never reached `params`) and bind them to the
# model for interactive inspection.
params = opt.target
m = model.bind(params, rngs={'lik_mc_samples': key})

# 
# n_inducing=50, LikMulticlassDirichlet
# CovConvolutional(patch_inducing_loc=False) acc=.7
# CovConvolutional(patch_inducing_loc=True, patch_shape=(3,3)) acc=.83
# CovConvolutional(patch_inducing_loc=True, patch_shape=(7,7)) acc=.95

# n_inducing=50, LikMulticlassDirichlet, patch_shape=(7, 7), patch_inducing_loc=True
# compare using location kernel over `transl`
# use_loc_kernel=False: center initialization of patches, acc=.93
# use_loc_kernel=True: center initialization of patches, acc=.93

# [  0| 0.00%]	Time=26.826	Loss=60202.617	k.ls=[1. 1. 1.]	kl.ℓ = [nan]	k.σ2=[1. 1. 1.]	
# [  0|19.72%]	Time=50.971	Loss=23863.398	k.ls=[1.601 1.362 1.369]	kl.ℓ = [nan]	k.σ2=[0.426 0.397 0.403]	
# [  0|39.44%]	Time=2.870	Loss=2301.921	k.ls=[1.715 1.88  2.032]	kl.ℓ = [nan]	k.σ2=[0.361 0.567 0.594]	
# [  0|59.15%]	Time=2.881	Loss=-2140.403	k.ls=[1.759 2.083 2.174]	kl.ℓ = [nan]	k.σ2=[0.241 1.098 1.122]


In [None]:
# Rank inducing patches based on variational μ for different classes

m = model.bind(params, rngs={'lik_mc_samples': key})
# With the location kernel enabled `m.Xu()` returns a tuple; its first entry
# is used as the inducing patches.
Xu = m.Xu()[0] if use_loc_kernel else m.Xu()
# Variational mean relative to the constant prior mean.
qm = m.q.μ - m.mean_fn.c[...,np.newaxis]

M = len(Xu)
    

# NOTE(review): `gridspec_kw` is built but never passed to `plt.subplots`.
gridspec_kw = {'width_ratios': np.ones((output_dim,)), 'height_ratios': [3,1]}
# One column per class; the column count used to be hard-coded to 3.
# `squeeze=False` keeps `axs` two-dimensional even for a single class.
fig, axs = plt.subplots(2,output_dim,figsize=(25,10),squeeze=False)
n_top = n_inducing
c = 0; ind = np.argsort(qm[c,:])
# Shared y-range so the per-class panels are comparable.
ylim = (np.min(qm), np.max(qm))

for c in range(output_dim):
    # Inducing points sorted by decreasing variational mean of class `c`.
    ind = np.argsort(qm[c,:])[::-1]
    # variational μ
    ax = axs[0,c]
    for co in range(output_dim):
        # Solid line for the ranking class, dashed for the others.
        ls = '--' if co != c else '-'
        ax.plot(np.arange(n_inducing), qm[co,ind], ls, label=f'{digits[co]}')
    ax.set_title(f'qμ ({digits[c]})', fontsize=35)
    ax.grid()
    ax.set_ylim(ylim)
    ax.legend(fontsize=20)
    
    # top weighted patches
    ax = axs[1,c]
    ims = Xu[ind].reshape((-1,*patch_shape))[:n_top]
    grid = make_im_grid(ims, im_per_row=min(len(ims), 10))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(grid, cmap='Greys', vmin=0, vmax=1)
    
fig.tight_layout()



In [None]:
# Rank inducing patches based on the average variational μ magnitude.
#     The goal is to see whether there are patches not used as evidence for any class;
#     sometimes it is patches where all classes have ... so they cannot be used as discriminating features.
#

# Uses `m`, the bound model, from an earlier cell.
Xu = m.Xu()[0] if use_loc_kernel else m.Xu()
qm = pytree_leaf(params, 'params/q/μ')
μconst = pytree_leaf(params, 'params/mean_fn/c')[...,np.newaxis]
# Mean absolute deviation from the constant mean, averaged over axis 0, then
# sorted in decreasing order.
qm_mag = np.mean(np.abs(qm - μconst), axis=0)
ind = np.argsort(qm_mag)[::-1]

fig, axs = plt.subplots(2,1,figsize=(8,10))

# Top panel: sorted magnitudes.
ax = axs[0]
ax.plot(np.arange(n_inducing), qm_mag[ind], '-')
ax.set_title(f'qμ magnitude', fontsize=35)
ax.grid()
ax.set_ylim((-.5, np.max(qm_mag)+.5))
fig.tight_layout()

# Bottom panel: the patches in the same order, ten per row.
ax = axs[1]
ims = Xu[ind].reshape((-1,*patch_shape))
grid = make_im_grid(ims, im_per_row=min(len(ims), 10))
ax.set_xticks([]); ax.set_yticks([])
ax.imshow(grid, cmap='Greys', vmin=0, vmax=1)


## plot spatial transform as well

fn = vmap(transform_to_matrix, (0, None, None), 0)
A = fn(pytree_leaf(params, 'params/Xu/transform/T'),
       transform_cls().T_type,
       transform_cls().A_init_val)
S = pytree_leaf(params, 'params/Xu/X')
# Same ordering as the magnitude plot above.
A = A[ind]; S = S[ind]

fn = vmap(spatial_transform_details, (0, 0, None), 0)
T, Gs = fn(A, S, patch_shape)
fig, axs = plt.subplots(2, len(A), figsize=(3*len(A),3*2))
for i in range(len(T)):
    plt_spatial_transform(axs[:,i], Gs[i], S[i], T[i])
fig.tight_layout()
plt.show()

In [None]:
# Rank inducing patches based on variational μ for different classes

m = model.bind(params)
Xu = m.Xu()
# Variational mean relative to the constant prior mean.
qm = m.q.μ - m.mean_fn.c[...,np.newaxis]

M = len(Xu)
    

# NOTE(review): `gridspec_kw` is built but never passed to `plt.subplots`.
gridspec_kw = {'width_ratios': np.ones((output_dim,)), 'height_ratios': [3,1]}
# One column per class; the column count used to be hard-coded to 3.
# `squeeze=False` keeps `axs` two-dimensional even for a single class.
fig, axs = plt.subplots(2,output_dim,figsize=(25,10),squeeze=False)
n_top = n_inducing
c = 0; ind = np.argsort(qm[c,:])
# Shared y-range so the per-class panels are comparable.
ylim = (np.min(qm), np.max(qm))

for c in range(output_dim):
    # Inducing points sorted by decreasing variational mean of class `c`.
    ind = np.argsort(qm[c,:])[::-1]
    # variational μ
    ax = axs[0,c]
    for co in range(output_dim):
        # Solid line for the ranking class, dashed for the others.
        ls = '--' if co != c else '-'
        ax.plot(np.arange(n_inducing), qm[co,ind], ls, label=f'{digits[co]}')
    ax.set_title(f'qμ ({digits[c]})', fontsize=35)
    ax.grid()
    ax.set_ylim(ylim)
    ax.legend(fontsize=20)
    
    # top weighted patches
    ax = axs[1,c]
    ims = Xu[ind].reshape((-1,*patch_shape))[:n_top]
    grid = make_im_grid(ims, im_per_row=min(len(ims), 10))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(grid, cmap='Greys', vmin=0, vmax=1)
    
fig.tight_layout()



In [None]:
# Rank inducing patches based on the average variational μ magnitude.
#     The goal is to see whether there are patches not used as evidence for any class;
#     sometimes it is patches where all classes have ... so they cannot be used as discriminating features.
#

# Inducing inputs obtained by applying the inducing-location module to its
# own parameter sub-tree.
Xu = inducing_loc_cls().apply({'params': pytree_leaf(params, 'params/Xu')})
qm = pytree_leaf(params, 'params/q/μ')
μconst = pytree_leaf(params, 'params/mean_fn/c')[...,np.newaxis]
# Mean absolute deviation from the constant mean, averaged over axis 0, then
# sorted in decreasing order.
qm_mag = np.mean(np.abs(qm - μconst), axis=0)
ind = np.argsort(qm_mag)[::-1]

fig, axs = plt.subplots(2,1,figsize=(8,10))

# Top panel: sorted magnitudes.
ax = axs[0]
ax.plot(np.arange(n_inducing), qm_mag[ind], '-')
ax.set_title(f'qμ magnitude', fontsize=35)
ax.grid()
ax.set_ylim((-.5, np.max(qm_mag)+.5))
fig.tight_layout()

# Bottom panel: the patches in the same order, ten per row.
ax = axs[1]
ims = Xu[ind].reshape((-1,*patch_shape))
grid = make_im_grid(ims, im_per_row=min(len(ims), 10))
ax.set_xticks([]); ax.set_yticks([])
ax.imshow(grid, cmap='Greys', vmin=0, vmax=1)


## plot spatial transform as well

fn = vmap(transform_to_matrix, (0, None, None), 0)
A = fn(pytree_leaf(params, 'params/Xu/transform/T'),
       transform_cls().T_type,
       transform_cls().A_init_val)
S = pytree_leaf(params, 'params/Xu/X')
# Same ordering as the magnitude plot above.
A = A[ind]; S = S[ind]

fn = vmap(spatial_transform_details, (0, 0, None), 0)
T, Gs = fn(A, S, patch_shape)
fig, axs = plt.subplots(2, len(A), figsize=(3*len(A),3*2))
for i in range(len(T)):
    plt_spatial_transform(axs[:,i], Gs[i], S[i], T[i])
fig.tight_layout()
plt.show()

In [None]:
## Predictions for rotated copies of a single test image.
print(np.where(y_test==0)[0][:10])

ind_digit0 = np.where(y_test==0)[0]
ind_digit1 = np.where(y_test==1)[0]

ind = 1
n_ims = 20

x_rot = rotated_ims(X_test[ind], n_ims=n_ims)
# X = np.stack(X_test[ind_digit0[:10]])
# X = np.stack(X_test[np.hstack((ind_digit0[:n_ims//2], ind_digit1[:n_ims//2]))])
X = x_rot
μf, σ2f = model.apply(params, X, full_cov=False, method=model.pred_f); μf = μf.squeeze()
y_pred = np.array(digits)[np.argmax(μf,axis=-1)]
# α is plotted for every likelihood type, so it is computed once up front
# (the Dirichlet branch used to repeat the identical call).
α = gamma_to_lognormal_inv(μf, σ2f, approx_type='kl')

lik_inst = model.lik_cls()
has_class_probs = isinstance(lik_inst, (LikMulticlassSoftmax, LikMulticlassDirichlet))
if isinstance(lik_inst, LikMulticlassSoftmax):
    lik_test = LikMulticlassSoftmax(output_dim=output_dim, n_mc_samples=5000)
    p_mc, Vp_mc = lik_test.apply({}, μf, σ2f, rngs={'lik_mc_samples': key}, method=lik_test.predictive_dist)
    # `p`/`Vp` are plotted below but were never assigned in this branch; the
    # Monte-Carlo estimates are the only class probabilities available here.
    # NOTE(review): assumes `predictive_dist` returns per-class mean and
    # variance shaped like `μf` — confirm.
    p, Vp = p_mc, Vp_mc
elif isinstance(lik_inst, LikMulticlassDirichlet):
    lik_test = LikMulticlassDirichlet(output_dim=output_dim, n_mc_samples=5000)
    p_mc, Vp_mc = lik_test.apply({}, μf, σ2f, rngs={'lik_mc_samples': key}, method=lik_test.predictive_dist)
    # Class probabilities from the normalised concentrations α.
    α0 = np.sum(α, axis=-1, keepdims=True)
    p = α / α0
    Vp = p*(1-p)

gridspec_kw = {'width_ratios': [1],
               'height_ratios': [4, 4, 4, 1]}
fig, axs = plt.subplots(4, 1, gridspec_kw=gridspec_kw, figsize=(15, 15))
cmap = plt.cm.get_cmap('Set1')
colors = [cmap(0), cmap(1), cmap(2)]


# Panel 1: latent mean per class with a ±2 std band.
ax = axs[0]
ax.set_xticks([])
for i, d in enumerate(digits):
    c = colors[i]
    μ, std = μf[:,i], np.sqrt(σ2f[:,i])
    ax.plot(np.arange(len(X)), μf[:,i], lw=2, color=c, label=d)
    ax.plot(np.arange(len(X)), μ + 2*std, '--', c=c)
    ax.plot(np.arange(len(X)), μ - 2*std, '--', c=c)
ax.legend()
ax.grid()
ax.set_title('predictive posterior p(f*|X)')


# Panel 2: α per class.
ax = axs[1]
ax.set_xticks([])
for i, d in enumerate(digits):
    c = colors[i]
    ax.plot(np.arange(len(X)), α[:,i], lw=2, color=c, label=d)
ax.grid()
ax.set_yticks(np.linspace(0, np.floor(np.max(α)*1.1), 5))
ax.set_title('α')


# Panel 3: class probabilities with a ±2 std band when the likelihood
# provides them, otherwise the latent mean.
ax = axs[2]
ax.set_xticks([])
for i, d in enumerate(digits):
    c = colors[i]
    if has_class_probs:
        ax.plot(np.arange(len(X)), p[:,i], '-', c=c, lw=2)
        ax.plot(np.arange(len(X)), p[:,i] + 2*np.sqrt(Vp[:,i]), '--', c=c)
        ax.plot(np.arange(len(X)), p[:,i] - 2*np.sqrt(Vp[:,i]), '--', c=c)
    else:
        ax.plot(np.arange(len(X)), μf[:,i], c=c, lw=2)
ax.set_title('E[p(y|f*)]')
ax.grid()


# Panel 4: the input images side by side.
ax = axs[3]
ax.set_xticks([]); ax.set_yticks([])
ax.imshow(np.hstack([x for x in X]), cmap='Greys')

In [None]:
def mvn_marginal_variational_details(Kff, Kuf, mf,
                             Luu, mu, μq, Lq, full_cov=False):
    """Variational marginal at test points, plus the pieces of its mean.

    With `A = (Luu Luuᵀ)⁻¹ Kuf` transposed and `δ = μq - mu`, the mean is
    `μf = mf + A δ`.  The covariance is `Kff - αᵀα + γᵀγ` (only its diagonal
    when `full_cov` is false), where `α = Luu⁻¹ Kuf` and `γ = Lqᵀ Luu⁻ᵀ α`.

    Returns:
        `(μf, Σf, A, δ, mf)`; `μf`, `δ` and `mf` are column vectors.
    """
    whitened = solve_triangular(Luu, Kuf, lower=True)
    projected = solve_triangular(Luu.T, whitened, lower=False)
    scaled = Lq.T @ projected
    if not full_cov:
        # Diagonal only: column-wise sums of squares.
        Σf = Kff - \
            np.sum(np.square(whitened), axis=0) + \
            np.sum(np.square(scaled), axis=0)
    else:
        Σf = Kff - whitened.T @ whitened + scaled.T @ scaled
    # Column vectors, so the same code serves the multiple-output case.
    mf = mf.reshape(-1, 1)
    δ = μq.reshape(-1, 1) - mu.reshape(-1, 1)
    A = projected.T
    return mf + A @ δ, Σf, A, δ, mf

def pred_f_details(self, Xs, full_cov=True):
    """Latent prediction at `Xs` together with the terms that make up its mean.

    Meant to be used as a `method=` of `model.apply`, so `self` is the bound
    model.

    Args:
        Xs: test inputs.
        full_cov: request the full covariance instead of its diagonal.

    Returns:
        `(μf, Σf, A, δ, mf)` as produced by `mvn_marginal_variational_details`
        (mapped over outputs for an independent multi-output kernel).  When
        `full_cov` is false the arrays are reshaped to `(N, D)`, `(N, M, D)`,
        `(M, D)` and `(N, D)` for `μf`, `A`, `δ` and `mf`; with `full_cov`
        true they are returned as computed.
    """
    k = self.k
    # NOTE(review): other cells index `m.Xu()[0]` when the location kernel is
    # enabled, i.e. `self.Xu()` may be a tuple there — confirm that `len(Xu)`
    # and `self.mean_fn(Xu)` behave as intended in that configuration.
    Xu = self.Xu()               # (M,...)
    μq, Lq = self.q.μ, self.q.L  # (P, M) & (P, M, M)
    if μq.shape[0] == 1:
        μq, Lq = μq.squeeze(0), Lq.squeeze(0)

    ms = self.mean_fn(Xs)
    mu = self.mean_fn(Xu)

    Kss = k.Kff(Xs, full_cov=full_cov)
    Kus = k.Kuf(Xu, Xs)
    Kuu = k.Kuu(Xu)
    Luu = cholesky_jitter_vmap(Kuu, jitter=5e-5)  # (P, M, M)

    if isinstance(k, CovMultipleOutputIndependent):
        mvn_marginal_variational_fn = vmap(
            mvn_marginal_variational_details, (0, 0, 1, 0, 1, 0, 0, None), -1)  # along P-dim
    else:
        mvn_marginal_variational_fn = mvn_marginal_variational_details

    μf, Σf, A, δ, mf = mvn_marginal_variational_fn(Kss, Kus, ms,
                                         Luu, mu, μq, Lq, full_cov)
    # (N, D), (N, D), (N, M, D), (M, D), (N, D)
    if not full_cov:
        # The shapes are only read where they are needed: unpacking
        # `N, D = Σf.shape` unconditionally failed for a full covariance
        # (three axes in the multi-output case) and for a single-output
        # diagonal (one axis).
        N = Σf.shape[0]
        D = Σf.shape[1] if Σf.ndim > 1 else 1
        M = len(Xu)
        μf = μf.reshape((N,D))
        A = A.reshape((N,M,D))
        δ = δ.reshape((M,D))
        mf = mf.reshape((N,D))
    return μf, Σf, A, δ, mf

# Decompose the latent prediction for the first three test images; `A` holds
# the per-inducing-point weights (see the shape comment in the plotting cell
# below).
x = X_test[:3]
μ, σ2, A, δ, mf = model.apply(params, x, full_cov=False, method=pred_f_details)
print(A.shape)

In [None]:

# First three colours of the 'Set1' colormap, one per class.
cmap = plt.cm.get_cmap('Set1')
colors = [cmap(idx) for idx in range(3)]

In [None]:
N, M, D = A.shape  # (#test, #inducing, #classes)
fig, axs = plt.subplots(2, N, figsize=(30, 15), sharey=True)
for axi in axs.ravel():
    axi.set_xticks([])

# Top row: raw weights A; bottom row: weights multiplied by δ.
# One column per test image, one dashed curve per class.
for n in range(N):
    top, bottom = axs[0, n], axs[1, n]
    for i, d in enumerate(digits):
        top.plot(A[n, :, i], np.arange(M), '--', c=colors[i], label=f'{d}')
        bottom.plot(A[n, :, i] * δ[:, i], np.arange(M), '--', c=colors[i], label=f'{d}')

fig.tight_layout()

In [None]:
# For test item `n`, show the five inducing patches with the largest |A·δ|
# contribution for every class.
# NOTE(review): `A`/`δ` were computed above for only three test images while
# `n = 10`, and `X` was last assigned to the rotated images in an earlier cell;
# this cell presumably relies on stale notebook state — confirm.
n = 10
Aδ = (A*δ)

# Top-5 contributors for the first class (printed as (index, value) pairs).
ind = np.argsort(np.abs(Aδ[n,:,0]))[::-1][:5]
print(list(zip(ind, Aδ[n,ind,0]))[:5])

# NOTE(review): `rngs` is not defined anywhere in this file; the other cells
# call this `apply` without an `rngs` argument — confirm.
Xu = model.inducing_loc_cls().apply({'params': params['params']['Xu']}, rngs=rngs)

fig, axs = plt.subplots(len(digits),1+len(ind),figsize=(5*(len(ind)+1),10), sharey=True)
[axi.set_xticks([]) for axi in axs.ravel()]; [axi.set_yticks([]) for axi in axs.ravel()]
    
for di in range(len(digits)):
    # original image
    # (x-label: summed contribution of all inducing points for this class)
    ax = axs[di,0]
    ax.imshow(X[n], cmap='Greys', vmin=0, vmax=1)
    ax.set_xlabel(f'{np.sum(Aδ[n,:,di]):.2f}', fontsize=40, color='r')
    ax.set_ylabel(f'C={digits[di]}', fontsize=40)
    
    # take top evidence
    # (five largest |A·δ| entries for this class, each labelled with its value)
    ind = np.argsort(np.abs(Aδ[n,:,di]))[::-1][:5]
    for i in range(len(ind)):
        ax = axs[di,i+1]
        ax.imshow(Xu[ind[i]], cmap='Greys', vmin=0, vmax=1)
        ax.set_xlabel(f'{Aδ[n,ind[i],di]:.2f}', fontsize=40)

fig.tight_layout()