goal
- SVGP + Dirichlet training on MNIST
    - recreate evidential DL example ... 
- variational learning of supporting image patches !
    - impl STN ... 

In [None]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os

# GPU / XLA / TensorFlow-logging environment configuration. These are set
# ahead of the `import jax` further down so they are in place when the
# backend is brought up.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TF_CPP_VMODULE'] = '=bfc_allocator=1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/usr/local/cuda'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# Python does not expand shell-style `${VAR}` references inside strings, so
# the existing value is read explicitly and the CUDA directory appended to it.
# NOTE(review): changing LD_LIBRARY_PATH inside a running process usually only
# affects child processes -- confirm this line is still needed.
os.environ['LD_LIBRARY_PATH'] = os.pathsep.join(
    [p for p in (os.environ.get('LD_LIBRARY_PATH', ''), '/usr/local/cuda/lib64') if p])

from collections import defaultdict

import scipy
import numpy as onp
onp.set_printoptions(precision=3,suppress=True)

import jax
import jax.numpy as np
from jax import grad, jit, vmap, device_put, random
from flax import linen as nn
from jax.scipy.stats import dirichlet

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
print(jax.local_device_count())
print(jax.devices())

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.tri as tri
# Global plot styling for the notebook.
# https://matplotlib.org/3.1.1/gallery/style_sheets/style_sheets_reference.html
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 25
mpl.rcParams['font.family'] = 'Times New Roman' 
# `plt.cm.get_cmap` is deprecated and removed in recent Matplotlib releases;
# `plt.get_cmap` returns the same colormap and works across versions.
cmap = plt.get_cmap('bwr')

from tabulate import tabulate
from functools import partial

from plt_utils import *
from gpax import *

In [None]:
import torch
import torchvision

def dataset_subset(dataset, y):
    """Restrict `dataset` (in place) to the classes in `y` and relabel them.

    Labels are converted to [0, ..., len(y)-1] following the order of `y`,
    e.g. y = [6, 8] leaves targets containing only {0, 1}. The relabelled
    targets are float32. `dataset.targets` and `dataset.data` are overwritten
    and the same `dataset` object is returned.
    """
    targets = dataset.targets
    # Boolean mask of the samples whose label is one of the requested classes.
    keep = torch.stack([targets == digit for digit in y]).any(dim=0)
    kept_idx = keep.nonzero(as_tuple=True)[0]
    # Lookup table old label -> new label (sized for the 10 digit classes).
    relabel = torch.zeros((10,), dtype=torch.float32)
    relabel[y] = torch.arange(len(y), dtype=torch.float32)
    dataset.targets = relabel[targets[kept_idx].to(torch.int64)]
    dataset.data = dataset.data[kept_idx]
    return dataset

# Notebook-level PRNG key; later cells read and reassign it.
key = random.PRNGKey(1)

# Point the torchvision MNIST downloader at a different mirror: keep each
# resource's file name (last path component) and md5, swap the host.
# https://stackoverflow.com/questions/66577151/http-error-when-trying-to-download-mnist-data
new_mirror = 'https://ossci-datasets.s3.amazonaws.com/mnist'
torchvision.datasets.MNIST.resources = [
   ('/'.join([new_mirror, url.split('/')[-1]]), md5)
   for url, md5 in torchvision.datasets.MNIST.resources
]

# Per-sample transform: convert to an array, add a trailing axis, divide by 255.
transforms = torchvision.transforms.Compose([
    lambda x: np.asarray(x)[...,np.newaxis] / 255.
])

train_dataset = torchvision.datasets.MNIST('./data', train=True, transform=transforms, download=True)
test_dataset = torchvision.datasets.MNIST('./data', train=False, transform=transforms, download=True)
# Keep only these digit classes; dataset_subset relabels them to 0..len(digits)-1.
digits = [6, 8]
train_dataset = dataset_subset(train_dataset, digits)
test_dataset = dataset_subset(test_dataset, digits)
# The arrays below are built from the raw `.data` / `.targets` tensors (not by
# iterating the dataset), so the /255 scaling and trailing axis are applied here.
# `jax_to_gpu` is not defined in this file -- presumably it places the array on
# the accelerator; confirm against its definition.
X_train = jax_to_gpu(np.asarray(train_dataset.data[...,np.newaxis]) / 255.)
y_train = jax_to_gpu(np.asarray(train_dataset.targets[...,np.newaxis]))
X_test = jax_to_gpu(np.asarray(test_dataset.data[...,np.newaxis]) / 255.)
y_test = jax_to_gpu(np.asarray(test_dataset.targets[...,np.newaxis]))
data_train = (X_train, y_train)
data_test = (X_test, y_test)

In [None]:
# Row indices of the first ten test samples with label 1 (handy for picking `ind`).
print(np.where(y_test==1)[0][:10])

ind = 31
n_ims = 20

# One row of n_ims tick-less panels.
fig, axs = plt.subplots(1,n_ims,figsize=(2*n_ims,2))
for axi in axs.ravel():
    axi.set_xticks([])
    axi.set_yticks([])

x, y = X_test[ind], y_test[ind]

# `rotated_ims` is not defined in this file; it is used here as returning a
# sequence of n_ims images derived from `x`.
x_rot = rotated_ims(x, n_ims=n_ims)
for i, ax in enumerate(axs):
    ax.imshow(x_rot[i], cmap='Greys')



In [None]:
# NOTE: `CNN` is defined in a later cell of this notebook, and `x_rot` / `n_ims`
# come from the previous cell; those cells must have been run first.
# NOTE(review): this loads './cnn_params_binary.pkl', while the training cells
# below save to a digits-specific file name -- confirm which checkpoint is meant.
params = pytree_load(CNN().init(key, np.ones((1,28,28,1))), './cnn_params_binary.pkl')
logits = CNN(output_dim=2).apply(params, x_rot)
# CNN ends in log_softmax, so exponentiating yields class probabilities.
prob = np.exp(logits)

fig, ax = plt.subplots(1, 1, figsize=(15, 10))
# Plot one curve per actual output class. The previous hard-coded columns
# [1, 3, 5] exceed the two-class output; JAX clamps out-of-range indices, so
# the same column was silently drawn three times.
for c in range(prob.shape[-1]):
    ax.plot(np.arange(n_ims), prob[:,c], label=f'{c}')


ax.legend()
ax.grid()

In [None]:
# Number of target classes; read as a global by cross_entropy_loss below.
num_classes = 2


class CNN(nn.Module):
    """Small convolutional classifier.

    Two conv -> ReLU -> 2x2 average-pool stages, a 128-unit dense layer with
    ReLU, and a final dense layer of `output_dim` units. The returned values
    are log-probabilities (log_softmax of the final layer).
    """
    # Number of output classes.
    output_dim: int = 2

    @nn.compact
    def __call__(self, x):
        # Convolutional stages; submodules are created in the same order as
        # before, so auto-generated parameter names are unchanged.
        for n_filters in (32, 64):
            x = nn.Conv(features=n_filters, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything except the leading (batch) axis.
        flat = x.reshape((x.shape[0], -1))
        hidden = nn.relu(nn.Dense(features=128)(flat))
        scores = nn.Dense(features=self.output_dim)(hidden)
        return nn.log_softmax(scores)
    

class CNNTrunk(nn.Module):
    """Feature extractor with the same layers as `CNN` up to the 128-unit dense layer.

    Returns the output of that dense layer directly (no ReLU, no classifier
    head, no log_softmax).
    """

    @nn.compact
    def __call__(self, x):
        # Two conv -> ReLU -> 2x2 average-pool stages, created in the same
        # order as before so auto-generated parameter names are unchanged.
        for n_filters in (32, 64):
            x = nn.Conv(features=n_filters, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten everything except the leading (batch) axis.
        flat = x.reshape((x.shape[0], -1))
        return nn.Dense(features=128)(flat)


def cross_entropy_loss(logits, labels):
    """Mean negative log-likelihood of `labels` under `logits`.

    Args:
        logits: array whose last axis indexes classes. The values are used
            as log-probabilities as-is (no softmax is applied here).
        labels: class indices, one per row of `logits`; any shape that
            flattens to one entry per sample (e.g. (N,) or (N, 1)).

    Returns:
        Scalar: minus the mean over samples of the log-probability assigned
        to the labelled class.
    """
    # Take the class count from the logits rather than from a module-level
    # global, and flatten labels explicitly instead of an axis-less squeeze().
    n_classes = logits.shape[-1]
    y_onehot = jax.nn.one_hot(np.reshape(labels, (-1,)), num_classes=n_classes)
    return -np.mean(np.sum(y_onehot * logits, axis=-1))


def compute_metrics(logits, labels):
    """Return `{'loss': ..., 'accuracy': ...}` for a batch of predictions.

    Args:
        logits: per-class scores, classes on the last axis; passed to
            `cross_entropy_loss` unchanged.
        labels: class indices, one per row of `logits` ((N,) or (N, 1)).

    Returns:
        Dict with the cross-entropy loss and the fraction of rows whose
        arg-max class equals the label.
    """
    loss = cross_entropy_loss(logits, labels)
    pred = np.argmax(logits, -1).reshape(-1,1)
    # Compare column against column: with 1-D labels, `pred == labels` would
    # broadcast to an (N, N) matrix and report a meaningless accuracy.
    accuracy = np.mean(pred == np.reshape(labels, (-1, 1)))
    metrics = {'loss': loss,
               'accuracy': accuracy}
    return metrics


@jax.jit
def train_step(opt, batch, key):
    """Run one optimisation step of the CNN classifier.

    Args:
        opt: optimiser object exposing `.target` (current parameters) and
            `.apply_gradient(grad)`.
        batch: `(X, y)` tuple of inputs and labels.
        key: PRNG key; it is split and the new key returned so callers can
            keep threading it, although nothing in this step uses randomness.

    Returns:
        `(opt, log, key)`: the updated optimiser, a dict with the loss,
        accuracy and gradient norm of the first dense kernel, and the new key.
    """
    key, _ = random.split(key)
    X, y = batch

    def loss_fn(params):
        logits = CNN().apply(params, X)
        return cross_entropy_loss(logits, y), logits

    (loss, logits), grad = jax.value_and_grad(loss_fn, has_aux=True)(opt.target)
    opt = opt.apply_gradient(grad)
    log = {'loss': loss,
           'accuracy': compute_metrics(logits, y)['accuracy'],
           # Use jax.numpy's linalg explicitly; a bare `linalg` name is not
           # imported in this file and would only exist via a star import.
           'dense0_kernel_gradnorm': np.linalg.norm(grad['params']['Dense_0']['kernel'])}
    return opt, log, key


@jax.jit
def eval_step(params, X):
    """Apply the CNN with `params` to inputs `X` and return its output."""
    return CNN().apply(params, X)


def eval_model(params, data_test, logit_fn=eval_step):
    """Evaluate `params` on `data_test` and return metrics as Python scalars.

    Args:
        params: model parameters handed to `logit_fn`.
        data_test: dataset passed to `get_data_stream` (batch size 100,
            fixed PRNGKey(0)).
        logit_fn: callable `(params, X) -> per-class scores`.

    Returns:
        Dict from `compute_metrics` with every leaf converted via `.item()`.
    """
    test_n_batches, test_batches = get_data_stream(
        random.PRNGKey(0), 100, data_test)

    logits, labels = [], []
    for _ in range(test_n_batches):
        X, y = next(test_batches)
        logits.append(logit_fn(params, X))
        labels.append(y.reshape(-1, 1))

    metrics = compute_metrics(np.vstack(logits), np.vstack(labels))
    # `jax.tree_map` is deprecated/removed in newer JAX releases;
    # `jax.tree_util.tree_map` is the stable spelling.
    return jax.tree_util.tree_map(lambda x: x.item(), metrics)



In [None]:
# Initialise the CNN classifier on a dummy (1, 28, 28, 1) input and wrap the
# parameters in an Adam optimiser (learning rate 0.03).
# `flax_create_optimizer` is not defined in this file.
model = CNN()
params = model.init(key, np.ones((1,28,28,1)))
opt = flax_create_optimizer(params, 'Adam', {'learning_rate': .03})

In [None]:

# Test metrics for the current `params` (the printed epoch index is fixed to 0).
metrics = eval_model(params, data_test)
print(f'[{0:3}] test \t'
      f'Loss={metrics["loss"]:.3f}\t'
      f'accuracy={metrics["accuracy"]:.3f}\t')


In [None]:

n_epochs = 10

# NOTE: `train_n_batches` / `train_batches` are created by `get_data_stream`
# in a later cell of this notebook; that cell has to be run before this one.
# Report about ten times per epoch; clamp to 1 so a stream with fewer than ten
# batches does not cause a modulo-by-zero.
log_every = max(1, train_n_batches//10)

for epoch in range(n_epochs):
    # Per-epoch history of the values returned by train_step.
    logs = defaultdict(list)
    for it in range(train_n_batches):
        step = epoch*train_n_batches+it
        batch = next(train_batches)
        opt, log, key = train_step(opt, batch, key)
        params = opt.target
        for k, v in log.items():
            logs[k].append(v)
        if step%log_every==0:
            # Running averages over the current epoch so far.
            avg_metrics = {k: np.mean(np.array(v))
                           for k, v in logs.items()}
            print(f'[{epoch:3}|{100*it/train_n_batches:5.2f}%]\t'
                  f'Loss={avg_metrics["loss"]:.3f}\t'
                  f'accuracy={avg_metrics["accuracy"]:.3f}\t'
                  f'norm(Dense0.k)={avg_metrics["dense0_kernel_gradnorm"]:.3f}')

    # Held-out metrics at the end of every epoch.
    metrics = eval_model(params, data_test)
    print(f'[{epoch:3}] test \t'
          f'Loss={metrics["loss"]:.3f}\t'
          f'accuracy={metrics["accuracy"]:.3f}\t')
    
# pytree_save(opt.target, './cnn_params_binary.pkl')

In [None]:
# Save the trained CNN parameters under a name derived from `digits`
# (e.g. './cnn_params_6,8.pkl' for digits = [6, 8]), reload them into a
# freshly initialised parameter tree, and re-check the test metrics.
cnn_save_path = f'./cnn_params_{",".join([str(x) for x in digits])}.pkl'
pytree_save(opt.target, cnn_save_path)
params = pytree_load(CNN().init(key, np.ones((1,28,28,1))), cnn_save_path)

metrics = eval_model(params, data_test)
print(f'Loss={metrics["loss"]:.3f}\t'
      f'accuracy={metrics["accuracy"]:.3f}\t')

In [None]:
# ---- SVGP model configuration --------------------------------------------
# All component classes used below (MeanConstant, Lik*, Cov*, InducingLocations*,
# SVGP, gamma_to_lognormal) are not defined in this file.
output_dim = 2
α_ϵ = 1; α_δ = 10; n_mc_samples = 20; n_inducing = 200

# Likelihood selector; any value other than the first two names falls through
# to LikMultipleNormalKron in the if/elif/else below.
lik_type = 'LikMulticlassDirichlet' # LikMulticlassDirichlet, LikMulticlassSoftmax, LikMultipleNormalKron
# Initial value of the constant mean depends on the chosen likelihood.
init_val_m = gamma_to_lognormal(np.array([1.]))[0] \
    if lik_type == 'LikMulticlassDirichlet' else np.array([.5])
mean_fn_cls = partial(MeanConstant,
                      output_dim=output_dim,
                      init_val_m=init_val_m,
                      flat=False)
if lik_type == 'LikMulticlassDirichlet':
    lik_cls = partial(LikMulticlassDirichlet,
                      output_dim=output_dim,
                      init_val_α_ϵ=α_ϵ,
                      init_val_α_δ=α_δ,
                      n_mc_samples=n_mc_samples)
elif lik_type == 'LikMulticlassSoftmax':
    lik_cls = partial(LikMulticlassSoftmax,
                      output_dim=output_dim,
                      n_mc_samples=n_mc_samples)
else:
    lik_cls = partial(LikMultipleNormalKron,
                      output_dim=output_dim)

# previous 
# kx_cls = partial(CovSEwithEncoder, g_cls=CNNTrunk)
# kt_cls = partial(CovIndex, output_dim=output_dim, rank=1)
# k_cls = partial(CovICM, kx_cls=kx_cls, kt_cls=kt_cls)
# more efficient
# Kernel: CovMultipleOutputIndependent built from CovSE, with CNNTrunk passed
# as `g_cls` (presumably the feature encoder applied before the kernel -- confirm).
kx_cls = partial(CovSE, output_scaling=True)
k_cls = partial(CovMultipleOutputIndependent,
                k_cls=kx_cls,
                output_dim=output_dim,
                g_cls=CNNTrunk)
# Initial inducing inputs: n_inducing training images taken at evenly spaced
# indices across X_train.
Xu_initial = X_train[np.linspace(0,len(X_train)-1,n_inducing).astype(np.int32)].copy()
# inducing_loc_cls = partial(InducingLocations,
#                            shape=(n_inducing, 28, 28, 1),
#                            init_fn_inducing=lambda k,s: Xu_initial)
# The init function ignores its two arguments and always returns Xu_initial.
inducing_loc_cls = partial(InducingLocationsST,
                           shape=(n_inducing, 28, 28, 1),
                           init_fn_inducing=lambda k,s: Xu_initial,
                           trans_type='3')



model = SVGP(mean_fn_cls=mean_fn_cls,
             k_cls=k_cls,
             lik_cls=lik_cls,
             inducing_loc_cls=inducing_loc_cls,
             n_data=len(X_train),
             output_dim=output_dim)


params = model.get_init_params(model, key)
# Bare expression: displays the parameter tree as the notebook cell output.
params

In [None]:
# load pretrained weights & set initial values
# NOTE: hard-coded file name; it matches the path written by the CNN training
# cell only when digits == [6, 8].
cnn_save_path = 'cnn_params_6,8.pkl'

# The encoder ('g') sub-tree lives at a different path depending on the kernel
# type, so both the load template and the mutated paths differ per branch.
# In both branches the kernel length-scale parameter(s) are set to
# softplus_inv(2.).
if isinstance(k_cls(), CovICM):
    encoder_params = pytree_load({'params': params['params']['k']['kx']['g']}, cnn_save_path)
    encoder_params_kvs = pytree_get_kvs(encoder_params)
    params = pytree_mutate(params, {f'params/k/kx/g/{k}': v for k,v in encoder_params_kvs.items()})
    params = pytree_mutate(params, {'params/k/kx/ls': softplus_inv(np.array([2.]))})
else:
    encoder_params = pytree_load({'params': params['params']['k']['g']}, cnn_save_path)
    encoder_params_kvs = pytree_get_kvs(encoder_params)
    params = pytree_mutate(params, {f'params/k/g/{k}': v for k,v in encoder_params_kvs.items()})
    params = pytree_mutate(params, {f'params/k/ks_{i}/ls': softplus_inv(np.array([2.]))
                                    for i in range(output_dim)})

# create optimizer
# opt = flax_create_multioptimizer_2focus(params, 'Adam',
#                                         [{'learning_rate': 0.}, {'learning_rate': 0.03}],
#                                         ['g', 'mean_fn', 'Xu/X']) #  'k/ks_0/σ2', 'k/ks_1/σ2'

# Three parameter groups selected by path keywords. The hyperparameter dicts
# and the predicates are presumably paired by position (confirm in
# flax_create_multioptimizer): learning rate 0 for `kwd_notrain`, .01 for
# `kwd_trainslow`, .03 for everything else.
kwd_notrain = ['g', 'mean_fn', 'Xu/X']
kwd_trainslow = ['Xu/T']
opt = flax_create_multioptimizer(
    params, 'Adam',
    [{'learning_rate': 0.}, {'learning_rate': .01}, {'learning_rate': .03}],
    [lambda p, v: pytree_path_contains_keywords(p, kwd_notrain),
     lambda p, v: pytree_path_contains_keywords(p, kwd_trainslow),
     lambda p, v: not pytree_path_contains_keywords(p, kwd_notrain+kwd_trainslow)])

flax_check_multiopt(params, opt)
# params

In [None]:
######################################################
import time

@jax.jit
def eval_step2(params, X):
    """Return the first output (`Ey`) of `model.pred_y` for inputs `X`.

    The second output of `pred_y` is discarded. The notebook-level `key` is
    closed over (not passed as an argument) and supplied as the
    'lik_mc_samples' RNG stream.
    """
    Ey, _ = model.apply(params, X,
                        method=model.pred_y,
                        rngs={'lik_mc_samples': key})
    return Ey

@jax.jit
def train_step2(step, opt, batch, key):
    """Run one optimisation step on the negated `model.mll` objective.

    Args:
        step: global step index (not used inside the function).
        opt: optimiser exposing `.target` and `.apply_gradient(grad)`.
        batch: `(X, y)`; `y` is squeezed and one-hot encoded to `output_dim`
            classes before being passed to the model.
        key: PRNG key; split into a new key (returned) and a sub-key used for
            the 'lik_mc_samples' RNG stream.

    Returns:
        `(opt, log, key)` where `log` holds the loss, the raw 'lik/α_ϵ' leaf,
        and the softplus of the kernel 'ls' and 'σ2' leaves, all read from the
        updated parameters.
    """
    key, subkey = random.split(key)
    X, y = batch
    y_onehot = jax.nn.one_hot(y.squeeze(), num_classes=output_dim)

    def loss_fn(params):
        mll = model.apply(params,
                          (X, y_onehot),
                          method=model.mll,
                          rngs={'lik_mc_samples': subkey})
        return -mll, {}

    (loss, _), grad = jax.value_and_grad(loss_fn, has_aux=True)(opt.target)
    opt = opt.apply_gradient(grad)

    # The kernel hyperparameters sit under different paths depending on
    # whether the kernel is a CovICM or has one sub-kernel per output.
    uses_icm = isinstance(model.k_cls(), CovICM)

    def kernel_hyper(name):
        if uses_icm:
            raw = pytree_leaf(opt.target, f'params/k/kx/{name}')
        else:
            raw = np.hstack([pytree_leaf(opt.target, f'params/k/ks_{i}/{name}')
                             for i in range(output_dim)])
        return jax.nn.softplus(raw)

    log = {'loss': loss,
           'lik.α_ϵ': pytree_leaf(opt.target, 'params/lik/α_ϵ'),
           'k.ls': kernel_hyper('ls'),
           'k.σ2': kernel_hyper('σ2')}
    return opt, log, key


bsz = 64
train_n_batches, train_batches = get_data_stream(key, bsz, data_train)

n_epochs = 5
# Report about twenty times per epoch; clamp to 1 so a stream with fewer than
# twenty batches does not cause a modulo-by-zero.
print_every = max(1, train_n_batches//20)

for epoch in range(n_epochs):
    start = time.time()
    for it in range(train_n_batches):
        step = epoch*train_n_batches+it
        batch = next(train_batches)
        opt, log, key = train_step2(step, opt, batch, key)
        params = opt.target
        if step%print_every==0:
            # `Time` is the wall-clock time since the previous report.
            print(f'[{epoch:3}|{100*it/train_n_batches:5.2f}%]\t'
                  f'Time={time.time()-start:.3f}\t'
                  f'Loss={log["loss"]:.3f}\t'
                  f'lik.α_ϵ={log["lik.α_ϵ"]}\t'
                  f'k.ls={log["k.ls"][:3]}\t'
                  f'k.σ2={log["k.σ2"][:3]}\t')
            start = time.time()


    # NOTE(review): eval_step2 returns `Ey` from pred_y, which compute_metrics
    # feeds to cross_entropy_loss as if it were log-probabilities -- confirm
    # the printed test Loss is meaningful (accuracy only depends on arg-max).
    metrics = eval_model(opt.target, data_test, logit_fn=eval_step2)
    print(f'[{epoch:3}] test \t'
          f'Loss={metrics["loss"]:.3f}\t'
          f'accuracy={metrics["accuracy"]:.3f}\t')

# Bind the trained parameters to the model for interactive inspection.
params = opt.target
m = model.bind(params, rngs={'lik_mc_samples': key})

In [None]:
# Compare the initial inducing images with what the bound model's `Xu()` returns.
m = model.bind(opt.target, rngs={'lik_mc_samples': key})
Xu = m.Xu()

# Which inducing points to show (indices 30..39).
ind = np.arange(10)+30

fig, axs = plt.subplots(2,len(ind),figsize=(5*len(ind),10))
for axi in axs.ravel():
    axi.set_xticks([])
    axi.set_yticks([])

# Top row: initial inducing images; bottom row: current ones.
for i, idx in enumerate(ind):
    axs[0,i].imshow(Xu_initial[idx], cmap='Greys', vmin=0, vmax=1)
    axs[1,i].imshow(Xu[idx], cmap='Greys', vmin=0, vmax=1)

    
    

In [None]:
 
# Redefinition of eval_step2: same call as before but with full_cov=False and
# the outputs reshaped to (-1, output_dim) before returning.
@jax.jit
def eval_step2(params, X):
    # The notebook-level `key` is closed over rather than passed as an argument.
    Ey, Vy = model.apply(params, X, method=model.pred_y, full_cov=False, rngs={'lik_mc_samples': key})
    # Vy is reshaped as well but never used; only Ey is returned.
    Ey = Ey.reshape(-1,output_dim); Vy = Vy.reshape(-1,output_dim)
    return Ey


# Final test metrics with the redefined eval_step2. `epoch` is whatever value
# the last training loop left behind, so this cell requires that loop to have run.
metrics = eval_model(opt.target, data_test, logit_fn=eval_step2)
print(f'[{epoch:3}] test \t'
      f'Loss={metrics["loss"]:.3f}\t'
      f'accuracy={metrics["accuracy"]:.3f}\t')