Goal
- Implement a convolutional GP
    - with inter-domain inducing points in patch space
    - evaluated on the rectangles dataset

In [None]:
# Notebook setup: widen the Jupyter container, pin the GPU, configure TF/XLA
# logging and memory behaviour, then import JAX/Flax/TF and the project helpers.
# NOTE(review): IPython.display is the non-deprecated home of display/HTML.
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

from typing import Any, Callable, Sequence, Optional, Tuple, Union, List, Iterable
import os
# Number GPUs by PCI bus id so that CUDA_VISIBLE_DEVICES='3' names a fixed physical device.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
# NOTE(review): TF_CPP_VMODULE expects "module=level[,module=level]"; the leading '='
# looks like a typo for 'bfc_allocator=1' — confirm.
os.environ['TF_CPP_VMODULE'] = '=bfc_allocator=1'
# 0 = emit all TensorFlow C++ log messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
# Allocate GPU memory on demand instead of preallocating; these must be set
# before jax is imported to take effect.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax
# NOTE: throughout this notebook `np` is jax.numpy; plain NumPy is `onp`.
import jax.numpy as np
from jax import grad, jit, vmap, device_put, random
from flax import linen as nn

# Report which backend and devices JAX picked up.
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
print(jax.local_device_count())
print(jax.devices())

import time
import copy

import numpy as onp
onp.set_printoptions(precision=3,suppress=True)
import matplotlib.pyplot as plt

import tensorflow as tf

# Project-local helpers. The star imports presumably supply most undefined names
# used below (get_data_stream, SVGP, Cov*, pytree_*, partial, ml_collections, ...) — verify.
from setup_convgp import *
from plt_utils import *
from gpax import *

In [None]:
def make_rectangle(arr, x0, y0, x1, y1):
    """Draw the 1-pixel outline of an axis-aligned rectangle into ``arr`` in place.

    ``arr`` is indexed ``[row, col]``. (x0, y0) is the top-left corner and
    (x1, y1) the bottom-right corner; both corners are part of the outline.
    Nothing is returned.
    """
    # Vertical edges cover rows y0..y1-1 in columns x0 and x1.
    for col in (x0, x1):
        arr[y0:y1, col] = 1
    # Top edge covers columns x0..x1-1 (corner (y0, x1) is set by the right edge).
    arr[y0, x0:x1] = 1
    # Bottom edge is one longer so it also closes corner (y1, x1).
    arr[y1, x0:x1 + 1] = 1


def make_random_rectangle(arr):
    """Draw a random rectangle outline into ``arr`` and return its corners.

    Corners are sampled with NumPy's global RNG so that the rectangle lies at
    least one pixel inside the border and spans at least 2 pixels per axis.

    Returns:
        (x0, y0, x1, y1): top-left (x0, y0) and bottom-right (x1, y1) corners.
    """
    n_rows, n_cols = arr.shape[0], arr.shape[1]
    draw = onp.random.randint
    # The draw order (x0, y0, x1, y1) is kept fixed so seeded runs stay reproducible.
    x0 = draw(1, n_cols - 3)
    y0 = draw(1, n_rows - 3)
    x1 = draw(x0 + 2, n_cols - 1)
    y1 = draw(y0 + 2, n_rows - 1)
    corners = (x0, y0, x1, y1)
    make_rectangle(arr, *corners)
    return corners


def make_rectangles_dataset(num, w, h, max_tries=1000):
    """Generate ``num`` images, each containing one non-square rectangle outline.

    Args:
        num: number of images to generate.
        w: image width (number of columns).
        h: image height (number of rows).
        max_tries: random rectangles drawn per image before giving up when every
            draw is square.

    Returns:
        (X, Y): ``X`` is float32 of shape ``(num, w * h)`` (row-major ``h x w``
        images); ``Y`` is float32 of shape ``(num, 1)``, 1.0 when the rectangle
        is taller than it is wide and 0.0 when it is wider than tall.

    Raises:
        RuntimeError: if no non-square rectangle is found within ``max_tries``
            draws (e.g. 5x5 images admit only a single, square rectangle).
    """
    d, Y = onp.zeros((num, h, w)), onp.zeros((num, 1))
    for i, img in enumerate(d):  # `img` is a view, so drawing writes into `d`
        for _ in range(max_tries):
            x0, y0, x1, y1 = make_random_rectangle(img)
            height, width = y1 - y0, x1 - x0
            if height != width:
                Y[i, 0] = height > width
                break
            img[:, :] = 0  # squares carry no label: erase and redraw
        else:
            raise RuntimeError(
                f"could not draw a non-square rectangle in {max_tries} tries "
                f"for a {w}x{h} image"
            )
    return (
        d.reshape(num, w * h).astype(onp.float32),
        Y.astype(onp.float32),
    )


## Data

onp.random.seed(123)
key = random.PRNGKey(0)

MAXITER = 2 # 100
NUM_TRAIN_DATA = 50 # 100
NUM_TEST_DATA = 100 # 300
H = W = 14  # width and height. In the original paper this is 28
h = w = 14
IMAGE_SHAPE = [H, W]

# make_rectangles_dataset takes (num, w, h): pass width then height explicitly.
# IMAGE_SHAPE is [H, W], so unpacking it there would swap the two for
# non-square images.
X, Y = make_rectangles_dataset(NUM_TRAIN_DATA, W, H)
Xt, Yt = make_rectangles_dataset(NUM_TEST_DATA, W, H)
X, Y, Xt, Yt = np.array(X), np.array(Y), np.array(Xt), np.array(Yt)
data = (X, Y)
test_data = (Xt, Yt)


print(X.shape, Y.shape, type(X))
print(Xt.shape, Yt.shape)

# Show the first few training images, titled with their label
# (1.0 = taller than wide, 0.0 = wider than tall).
n_show = 4
fig, axs = plt.subplots(1, n_show, figsize=(16, 6))
for i, ax in enumerate(axs):
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(X[i, :].reshape(*IMAGE_SHAPE), cmap='Greys')
    ax.set_title(Y[i, 0], fontsize=(25))
    

In [None]:


class CNN(nn.Module):
    """Small two-stage convolutional classifier that returns log-probabilities.

    Inputs are reshaped to ``(-1, *IMAGE_SHAPE, 1)`` (NHWC, one channel) using
    the module-level ``IMAGE_SHAPE``. Each stage is a 3x3 conv, a ReLU and a
    2x2 average pool (32 then 64 features). This is followed by a 128-unit
    ReLU dense layer and an ``output_dim``-way dense head with log-softmax.
    """
    output_dim: int = 2  # number of output classes

    @nn.compact
    def __call__(self, x):
        h = x.reshape(-1, *IMAGE_SHAPE, 1)
        # Layers are created in the same order as before, so the auto-generated
        # parameter names (Conv_0, Conv_1, Dense_0, Dense_1) are unchanged.
        for n_features in (32, 64):
            h = nn.Conv(features=n_features, kernel_size=(3, 3))(h)
            h = nn.relu(h)
            h = nn.avg_pool(h, window_shape=(2, 2), strides=(2, 2))
        h = h.reshape((h.shape[0], -1))
        h = nn.relu(nn.Dense(features=128)(h))
        logits = nn.Dense(features=self.output_dim)(h)
        return nn.log_softmax(logits)
    
class CNNTrunk(nn.Module):
    """Feature extractor with the CNN architecture minus its classification head.

    Applies the same two conv/ReLU/avg-pool stages as ``CNN`` to inputs reshaped
    to ``(-1, *IMAGE_SHAPE, 1)``. It then returns the output of a 128-unit dense
    layer with no activation.
    """

    @nn.compact
    def __call__(self, x):
        feats = x.reshape(-1, *IMAGE_SHAPE, 1)
        for width in (32, 64):
            feats = nn.Conv(features=width, kernel_size=(3, 3))(feats)
            feats = nn.avg_pool(nn.relu(feats), window_shape=(2, 2), strides=(2, 2))
        feats = feats.reshape((feats.shape[0], -1))
        return nn.Dense(features=128)(feats)

def cross_entropy_loss(logits, labels):
    """Mean negative log-likelihood of integer class ``labels`` under ``logits``.

    Args:
        logits: ``(N, C)`` log-probabilities (e.g. the output of ``CNN``).
        labels: ``N`` class indices, shaped ``(N,)`` or ``(N, 1)``; integer or
            integral floats.

    Returns:
        Scalar ``-mean_i logits[i, labels[i]]``.
    """
    # Take the class count from the logits themselves rather than from a global
    # `output_dim`, which is not defined when this cell first runs.
    num_classes = logits.shape[-1]
    idx = np.reshape(labels, (-1,)).astype(np.int32)
    y_onehot = jax.nn.one_hot(idx, num_classes=num_classes)
    return -np.mean(np.sum(y_onehot * logits, axis=-1))


def compute_metrics(logits, labels):
    """Cross-entropy loss and accuracy of ``logits`` against class ``labels``.

    Args:
        logits: ``(N, C)`` log-probabilities.
        labels: ``N`` class indices shaped ``(N,)`` or ``(N, 1)``.

    Returns:
        ``{'loss': scalar, 'accuracy': scalar}``.
    """
    # Normalise to a column so `pred == labels` is elementwise; a 1-D label
    # vector would otherwise broadcast against (N, 1) to an (N, N) comparison.
    labels = np.reshape(labels, (-1, 1))
    loss = cross_entropy_loss(logits, labels)
    pred = np.argmax(logits, -1).reshape(-1, 1)
    accuracy = np.mean(pred == labels)
    return {'loss': loss, 'accuracy': accuracy}


@jax.jit
def train_step(opt, batch, key):
    """Take one optimizer step for the baseline CNN on ``batch = (X, y)``.

    Reads the module-level ``output_dim`` to build the ``CNN``.

    Returns:
        (opt, log, key): the updated optimizer; a dict with the batch ``loss``,
        ``accuracy`` and the L2 norm of the ``Dense_0`` kernel gradient; and a
        freshly split PRNG key (the CNN itself draws no random numbers).
    """
    key, _ = random.split(key)
    Xb, yb = batch
    cnn = CNN(output_dim=output_dim)

    def loss_fn(params):
        logits = cnn.apply(params, Xb)
        return cross_entropy_loss(logits, yb), logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(opt.target)
    opt = opt.apply_gradient(grads)
    metrics = compute_metrics(logits, yb)
    log = {'loss': loss,
           'accuracy': metrics['accuracy'],
           # jax.numpy's norm accepts traced values under jit.
           'dense0_kernel_gradnorm': np.linalg.norm(grads['params']['Dense_0']['kernel'])}
    return opt, log, key


@jax.jit
def eval_step(params, X):
    """Return the CNN's log-probabilities for inputs ``X``.

    Jitted; the network width is taken from the module-level ``output_dim``.
    """
    return CNN(output_dim=output_dim).apply(params, X)


def eval_model(params, data_test, logit_fn=eval_step, batch_size=100):
    """Evaluate ``logit_fn(params, X)`` over ``data_test`` and return metrics.

    Args:
        params: model parameters passed to ``logit_fn``.
        data_test: dataset understood by ``get_data_stream``
            (here an ``(X, Y)`` tuple).
        logit_fn: maps ``(params, X)`` to ``(N, C)`` log-probabilities.
        batch_size: evaluation minibatch size.

    Returns:
        ``{'loss': float, 'accuracy': float}`` computed over all batches.
    """
    test_n_batches, test_batches = get_data_stream(
        random.PRNGKey(0), batch_size, data_test)

    logits, labels = [], []
    for _ in range(test_n_batches):
        Xb, yb = next(test_batches)
        logits.append(logit_fn(params, Xb))
        labels.append(yb.reshape(-1, 1))

    metrics = compute_metrics(np.vstack(logits), np.vstack(labels))
    # jax.tree_map is deprecated; jax.tree_util.tree_map is the stable spelling.
    return jax.tree_util.tree_map(lambda x: x.item(), metrics)


In [None]:
# Baseline: train a plain CNN classifier on the rectangles data, save its
# parameters, reload them and re-evaluate.
output_dim = 2  # binary task; CNN, train_step and eval_step read this global
model = CNN(output_dim=output_dim)
# Dummy input for shape inference; CNN reshapes inputs to (-1, *IMAGE_SHAPE, 1).
init_input = np.ones((1, *IMAGE_SHAPE, 1))
params = model.init(key, init_input)
opt = flax_create_optimizer(params, 'Adam', {'learning_rate': .03})

metrics = eval_model(params, test_data)
print(f'[{0:3}] test \t'
      f'Loss={metrics["loss"]:.3f}\t'
      f'accuracy={metrics["accuracy"]:.3f}\t')


from collections import defaultdict
bsz = 64
train_n_batches, train_batches = get_data_stream(key, bsz, data)
n_epochs = 20

for epoch in range(n_epochs):
    logs = defaultdict(list)
    for it in range(train_n_batches):
        batch = next(train_batches)
        opt, log, key = train_step(opt, batch, key)
        params = opt.target
        for k, v in log.items():
            logs[k].append(v)

    # Report once per epoch, averaged over every batch of that epoch.
    if logs:
        avg_metrics = {k: np.mean(np.array(v)) for k, v in logs.items()}
        print(f'[{epoch:3}] train\t'
              f'Loss={avg_metrics["loss"]:.3f}\t'
              f'accuracy={avg_metrics["accuracy"]:.3f}\t')

    metrics = eval_model(params, test_data)
    print(f'[{epoch:3}] test \t'
          f'Loss={metrics["loss"]:.3f}\t'
          f'accuracy={metrics["accuracy"]:.3f}\t')


cnn_save_path = './cnn_params_rectangle.pkl'
pytree_save(opt.target, cnn_save_path)
# Reload into a freshly initialised parameter tree used as the structure template.
params = pytree_load(CNN(output_dim=output_dim).init(key, init_input), cnn_save_path)

metrics = eval_model(params, test_data)
print(f'Loss={metrics["loss"]:.3f}\t'
      f'accuracy={metrics["accuracy"]:.3f}\t')

In [None]:
# Start from a deep copy of the project's base config and override the
# image/patch settings for the 14x14 rectangles task.
# NOTE(review): `ml_collections` and `get_config_base` are not imported visibly;
# presumably they come from a star import — confirm.
config_base = copy.deepcopy(get_config_base())
config = ml_collections.ConfigDict(config_base)
config.image_shape = (14, 14, 1)
config.patch_shape = (3, 3)
config.patch_encoder = None
config.T_type = 'transl'
# Bare expression: displays the config as the cell output in Jupyter.
config


In [None]:
# Build a model from `config` via the project helper get_model_cls and
# initialise its parameters for inputs of shape config.image_shape.
# NOTE(review): the next cell reassigns patch_shape, image_shape, n_inducing,
# output_dim, key, model and params from hand-set values, overriding these.
patch_shape = config.patch_shape 
image_shape = config.image_shape
n_inducing = config.n_inducing
output_dim = config.output_dim

key = random.PRNGKey(0)
model_cls = get_model_cls(key, config, X)
model = model_cls()
params = model.get_init_params(model, key, X_shape=config.image_shape)
print(model)
pytree_keys(params)
    

In [None]:
# Hand-set hyperparameters for this run. They override the globals set from
# `config` in the previous cell; `config` itself is not updated.
key = random.PRNGKey(0)
output_dim = 2
lik_type = 'LikMulticlassDirichlet' # LikMulticlassDirichlet, LikMulticlassSoftmax, LikMultipleNormalKron
α_ϵ = 1; α_δ = 10; n_mc_samples = 20
image_shape = (14,14,1)
patch_shape = (10,10)
inducing_patch=True
n_inducing = 20
T_type = 'transl+isot_scal' # '', 'transl', 'isot_scal'
use_loc_kernel = True
patch_encoder = 'CNNMnist' # None, 'CNNMnist'


# Constant mean function. Its initial value depends on the likelihood:
# gamma_to_lognormal(1.) for the Dirichlet likelihood, otherwise 0.5.
# NOTE(review): `partial` is presumably functools.partial via a star import — confirm.
init_val_m = gamma_to_lognormal(np.array([1.]))[0] \
    if lik_type == 'LikMulticlassDirichlet' else np.array([0.5])
mean_fn_cls = partial(MeanConstant, output_dim=output_dim, init_val_m=init_val_m, flat=False)
# Likelihood chosen by `lik_type`; any other value falls back to LikMultipleNormalKron.
if lik_type == 'LikMulticlassDirichlet':
    lik_cls = partial(LikMulticlassDirichlet, output_dim=output_dim, init_val_α_ϵ=α_ϵ, init_val_α_δ=α_δ, n_mc_samples=n_mc_samples)
elif lik_type == 'LikMulticlassSoftmax':
    lik_cls = partial(LikMulticlassSoftmax, output_dim=output_dim, n_mc_samples=n_mc_samples)
else:
    lik_cls = partial(LikMultipleNormalKron, output_dim=output_dim)


# Location kernel (kl_cls): CovSE when use_loc_kernel, otherwise CovConstant
# with output_scaling=False.
kl_cls = CovSE if use_loc_kernel else partial(CovConstant, output_scaling=False)
available_encoders = CovPatchEncoder.get_available_encoders()

# Patch kernel. If `patch_encoder` is not an available encoder, use CovPatch with
# identity features. Otherwise g_cls is that encoder class and CovPatchEncoder is
# used, with XL_init_fn taken from the encoder's get_XL.
if patch_encoder not in available_encoders:
    g_cls = LayerIdentity
    kg_cls = partial(CovPatch, image_shape=image_shape, patch_shape=patch_shape, kp_cls=CovSE, kl_cls=kl_cls)
else:
    g_cls = available_encoders[patch_encoder]
    XL_init_fn = partial(g_cls.get_XL, image_shape=image_shape)
    kg_cls = partial(CovPatchEncoder, encoder=patch_encoder, XL_init_fn=XL_init_fn, kp_cls=CovSE, kl_cls=kl_cls)

# Convolutional kernel, replicated independently for each of the output_dim outputs.
kx_cls = partial(CovConvolutional, kg_cls=kg_cls, inducing_patch=inducing_patch)
k_cls = partial(CovMultipleOutputIndependent, output_dim=output_dim, k_cls=kx_cls, g_cls=g_cls)
# Inducing-location transform: identity when T_type == '', otherwise a
# SpatialTransform with n_inducing transforms producing patch_shape outputs.
if T_type == '':
    transform_cls = LayerIdentity
else:
    # Initial scale = patch size / image size per spatial axis.
    scal = np.array(patch_shape)/np.array(image_shape[:2])
    A_init_val = trans2x3_from_scal_transl(scal,(0,0))
    # Every transform starts at (s, tx, ty) = (BijSigmoid(...).reverse(1.4*mean(scal)), 0, 0).
    # Presumably BijSigmoid bounds s to [0.5, 1.5]*mean(scal) and .reverse maps to the
    # unconstrained value — confirm.
    T_init_fn = lambda k, s: np.tile(np.array([BijSigmoid([np.mean(scal)*.5,np.mean(scal)*1.5]).reverse(np.mean(scal)*1.4), 0, 0.]),
                                     (n_inducing, 1)) # (s, tx, ty)
    bound_init_fn = partial(spatial_transform_bound_init_fn, in_shape=image_shape, out_shape=patch_shape)
    transform_cls = partial(SpatialTransform, shape=patch_shape, n_transforms=n_inducing, 
                            T_type=T_type, T_init_fn=T_init_fn, A_init_val=A_init_val, output_transform=use_loc_kernel,
                            bound_init_fn=bound_init_fn)

# Inducing locations. With inducing_patch, start from n_inducing training images
# sampled with replacement (random.randint), reshaped to image_shape and passed
# through transform_cls. Otherwise use n_inducing evenly spaced (flat) training images.
if inducing_patch:
#     Xu_initial = get_init_patches(key, X, n_inducing, image_shape, patch_shape)
    Xu_initial = np.take(X, random.randint(key, (n_inducing,), 0, len(X)), axis=0)
    Xu_initial = Xu_initial.reshape((-1, *image_shape))
    inducing_loc_cls = partial(InducingLocations,
                               shape=Xu_initial.shape,
                               init_fn=lambda k,s: Xu_initial,
                               transform_cls=transform_cls)
else:
    Xu_initial = X[np.linspace(0,len(X)-1,n_inducing).astype(np.int32)].copy()
    inducing_loc_cls = partial(InducingLocations,
                               shape=Xu_initial.shape,
                               init_fn=lambda k,s: Xu_initial)

print('Xu:', Xu_initial.shape)

# Assemble the sparse variational GP from the pieces above.
model = SVGP(mean_fn_cls=mean_fn_cls,
             k_cls=k_cls,
             lik_cls=lik_cls,
             inducing_loc_cls=inducing_loc_cls,
             n_data=len(X),
             output_dim=output_dim)


params = model.get_init_params(model, key, X_shape=image_shape)
print(model)
pytree_keys(params)

In [None]:
# load pretrained weights & set initial values
# if g_cls != LayerIdentity:
#     cnn_save_path = f'./rectangle_cnn_params.pkl'
#     encoder_params = pytree_load({'params': params['params']['k']['g']}, cnn_save_path)
#     encoder_params_kvs = pytree_get_kvs(encoder_params)
#     params = pytree_mutate(params, {f'params/k/g/{k}': v for k,v in encoder_params_kvs.items()})

# Per-parameter learning rates, chosen by keywords in each parameter's pytree path:
#   lr 0   -> frozen: mean function, inducing inputs Xu/X, location-kernel variances
#   lr .03 -> "train slow" group (currently empty)
#   lr .1  -> everything else
# Build the per-output paths from the same `output_dim` the model was built with
# (not config.output_dim, which the previous cell overrode).
kwd_notrain = ['mean_fn', 'Xu/X']
kwd_notrain += [f'params/k/ks_{i}/kl/σ2' for i in range(output_dim)]
kwd_notrain += [f'params/k/ks_{i}/kg/kl/σ2' for i in range(output_dim)]
kwd_trainslow = [] #  # 'Xu/transform'
opt = flax_create_multioptimizer(
    params, 'Adam',
    [{'learning_rate': 0.}, {'learning_rate': .03}, {'learning_rate': .1}],
    [lambda p, v: pytree_path_contains_keywords(p, kwd_notrain),
     lambda p, v: pytree_path_contains_keywords(p, kwd_trainslow),
     lambda p, v: not pytree_path_contains_keywords(p, kwd_notrain+kwd_trainslow)])

flax_check_multiopt(params, opt)
# params

In [None]:



# Visualise each inducing point's spatial transform (only when the inducing
# locations use a SpatialTransform).
if isinstance(transform_cls(), SpatialTransform):
    ind = np.arange(n_inducing)

    m = model.bind(params)
    A = m.Xu.transform.T  # per-inducing-point transform parameters
    S = pytree_leaf(params, 'params/Xu/X')  # inducing source images
    A = A[ind]; S = S[ind]

    fn = vmap(spatial_transform_details, (0, 0, None), 0)
    T, Gs = fn(A, S, patch_shape)
    # squeeze=False keeps `axs` 2-D so axs[:, i] also works for a single inducing point.
    fig, axs = plt.subplots(2, len(A), figsize=(3*len(A), 3*2), squeeze=False)
    for i in range(len(T)):
        plt_spatial_transform(axs[:, i], Gs[i], S[i], T[i])
    fig.tight_layout()
    plt.show()


In [None]:
######################################################
import time
        
@jax.jit
def eval_model(params, data):
    """Classification accuracy of the SVGP on ``data = (X, Y)``.

    Inputs are reshaped to ``(N, *image_shape)``. The predicted class is the
    argmax of the predictive mean returned by ``model.pred_y``.

    NOTE: ``model``, ``image_shape`` and the global ``key`` are captured when the
    function is traced; jit treats them as constants afterwards.
    """
    Xt, Yt = data
    Xt = Xt.reshape((len(Xt), *image_shape))
    Ey, _ = model.apply(params, Xt, method=model.pred_y, rngs={'lik_mc_samples': key})
    pred = np.argmax(Ey, -1).reshape(-1, 1)
    # Reshape labels as well, so 1-D labels compare elementwise instead of
    # broadcasting to an (N, N) comparison.
    return np.mean(pred == Yt.reshape(-1, 1))

@jax.jit
def train_step2(step, opt, batch, key):
    """Run one optimisation step of the SVGP on a minibatch by maximising ``model.mll``.

    Args:
        step: global step counter (accepted for the caller's convenience; unused).
        opt: flax optimizer holding the model parameters in ``opt.target``.
        batch: ``(Xb, yb)`` with flat images and integer-valued labels.
        key: PRNG key; a subkey drives the likelihood's Monte-Carlo samples.

    Returns:
        (opt, log, key): the updated optimizer, ``{'loss': -mll}``, and the next key.
    """
    key, mc_key = random.split(key)
    images, labels = batch
    images = images.reshape((len(images), *image_shape))
    targets = jax.nn.one_hot(labels.squeeze(), num_classes=output_dim).reshape((-1, output_dim))

    def negative_mll(p):
        objective = model.apply(p,
                                (images, targets),
                                method=model.mll,
                                rngs={'lik_mc_samples': mc_key})
        return -objective, {}

    (loss, _), grads = jax.value_and_grad(negative_mll, has_aux=True)(opt.target)
    new_opt = opt.apply_gradient(flax_params2model(model, grads))
    return new_opt, {'loss': loss}, key


bsz = 5
train_n_batches, train_batches = get_data_stream(key, bsz, data)

n_epochs = 100
# Aim for ~30 progress lines over the whole run. max(1, ...) prevents a
# modulo-by-zero when there are fewer than 30 optimisation steps in total.
log_every = max(1, train_n_batches * n_epochs // 30)

# The kernel structure is fixed during training, so resolve the reported
# parameter paths once, instead of re-instantiating modules at every log.
insert_str = '/kg' if isinstance(model.k_cls().k_cls(), CovConvolutional) else ''
is_icm = isinstance(model.k_cls(), CovICM)
has_loc_ls = (isinstance(model.k_cls(), CovMultipleOutputIndependent)
              and not isinstance(model.k_cls().k_cls().kg_cls().kl_cls(), CovConstant))


def _kernel_hyper(params, name):
    """Softplus of patch-kernel hyperparameter `name` ('ls' or 'σ2') across outputs."""
    if is_icm:
        raw = pytree_leaf(params, f'params/k/kx/{name}')
    else:
        raw = np.hstack([pytree_leaf(params, f'params/k/ks_{i}{insert_str}/kp/{name}')
                         for i in range(output_dim)])
    return jax.nn.softplus(raw)


for epoch in range(n_epochs):
    start = time.time()
    for it in range(train_n_batches):
        step = epoch*train_n_batches + it
        batch = next(train_batches)
        opt, log, key = train_step2(step, opt, batch, key)
        params = opt.target
        if step % log_every == 0:
            if has_loc_ls:
                kl_ls = jax.nn.softplus(np.hstack([pytree_leaf(params, f'params/k/ks_{i}/kg/kl/ls')
                                                   for i in range(output_dim)]))
            else:
                kl_ls = np.array([np.nan])  # no location-kernel lengthscale to report
            log.update({
               'k.ls': _kernel_hyper(params, 'ls'),
               'k.σ2': _kernel_hyper(params, 'σ2'),
               'kl.ℓ': kl_ls})
            acc = eval_model(params, data)
            acc_test = eval_model(params, test_data)
            print(f'[{epoch:3}|{100*it/train_n_batches:5.2f}%]\t'
                  f'Time={time.time()-start:.3f}\t'
                  f'Loss={log["loss"]:.3f}\t'
                  f'k.ls={log["k.ls"][:3]}\t'
                  f'k.σ2={log["k.σ2"][:3]}\t'
                  f'kl.ℓ = {log["kl.ℓ"][:3]}\t'
                  f'acc={acc:.3f}|{acc_test:.3f}\t')
            start = time.time()


params = opt.target


# N=100
# CovSE(g_cls=CNNTrunk(pretrained=True)): .96 (bsz=5, n_inducing=50%)
# CovSE: .65
# CovConvolutional(g_cls=CovSE, patch_inducing_loc=False): .89 (bsz=5, n_inducing=60%)
# CovConvolutional(g_cls=CovSE, patch_inducing_loc=True):  .90 (bsz=5, n_inducing=45 all unique patches)

# N=20
# CovConvolutional(g_cls=CovSE, patch_inducing_loc=False): .85/.65 (bsz=5, n_inducing=100%)
# CovConvolutional(g_cls=CovSE, patch_inducing_loc=True):  1.0/.99 (bsz=5, n_inducing=100%, lr=.1)

# using location kernel=CovSE helps with the poor performance of optimizing Xu/T
# N=50
# CovConvolutional(g_cls=CovSE, patch_inducing_loc=False): .9/.68     (bsz=5,  n_inducing=20)
# CovConvolutional(g_cls=CovSE, patch_inducing_loc=True): 0.960|0.810 (bsz=5,  n_inducing=20, Xu/X,Xu/T fixed)
# CovConvolutional(g_cls=CovSE, patch_inducing_loc=True): 1.000|0.930 (bsz=5,  n_inducing=20, Xu/X fixed, Xu/T opt)
#     - 0.740|0.640 another run ...
# CovConvolutional(g_cls=CovSE, patch_inducing_loc=True, kl_cls=CovSE): 0.980|0.920 (bsz=5,  n_inducing=20, Xu/X fixed, Xu/T opt)


# patch_shape (10, 10)
# patch_encoder=LayerIdentity    acc=0.800|0.660


 

In [None]:

# Visualise each inducing point's learned spatial transform after training
# (only when the inducing locations use a SpatialTransform).
if isinstance(transform_cls(), SpatialTransform):
    ind = np.arange(n_inducing)

    m = model.bind(params)
    A = m.Xu.transform.T  # per-inducing-point transform parameters
    S = pytree_leaf(params, 'params/Xu/X')  # inducing source images
    A = A[ind]; S = S[ind]

    fn = vmap(spatial_transform_details, (0, 0, None), 0)
    T, Gs = fn(A, S, patch_shape)
    # squeeze=False keeps `axs` 2-D so axs[:, i] also works for a single inducing point.
    fig, axs = plt.subplots(2, len(A), figsize=(3*len(A), 3*2), squeeze=False)
    for i in range(len(T)):
        plt_spatial_transform(axs[:, i], Gs[i], S[i], T[i])
    fig.tight_layout()
    plt.show()


In [None]:
# Plot inducing locations: for each class, the variational means q(μ) of all
# inducing points sorted by that class's mean, and the patches in that order.
digits = ['-','|']  # class 0: wider than tall, class 1: taller than wide

# Use the same `inducing_patch` flag the model was built with (the hand-set
# global), not config's copy, which may differ.
if inducing_patch:
    Xu, transl = inducing_loc_cls().apply({'params': pytree_leaf(params, 'params/Xu')})
else:
    Xu = inducing_loc_cls().apply({'params': pytree_leaf(params, 'params/Xu')})


qm = pytree_leaf(params, 'params/q/μ')  # indexed [class, inducing point]
M = len(Xu)


gridspec_kw = {'width_ratios': [1] * output_dim, 'height_ratios': [3, 1]}
# squeeze=False keeps `axs` 2-D even when output_dim == 1.
fig, axs = plt.subplots(2, output_dim, figsize=(output_dim*8, 8),
                        gridspec_kw=gridspec_kw, squeeze=False)
n_top = n_inducing
ylim = (np.min(qm)-.5, np.max(qm)+.5)

for c in range(output_dim):
    ind = np.argsort(qm[c, :])[::-1]  # inducing points by decreasing weight for class c
    # variational μ of every class, ordered by class c's weights
    ax = axs[0, c]
    for co in range(output_dim):
        ls = '--' if co != c else '-'
        ax.plot(np.arange(len(ind)), qm[co, ind], ls, label=f'{digits[co]}')
    ax.set_title(f'qμ ({digits[c]})', fontsize=35)
    ax.grid()
    ax.set_ylim(ylim)
    ax.legend(fontsize=20)

    # top weighted patches
    ax = axs[1, c]
    ims = Xu[ind].reshape((-1, *patch_shape))[:n_top]
    grid = make_im_grid(ims, im_per_row=min(len(ims), 10))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(grid, cmap='Greys', vmin=0, vmax=1)

fig.tight_layout()

