Goals
- [done] SVGP + Dirichlet-likelihood training on MNIST
    - recreate the evidential deep learning (evidential DL) example ...
- [done] variational learning of supporting image patches
    - [done] implement a spatial transformer network (STN) ...
    - observations
        - performance is okay if the encoder network is allowed to be fine-tuned
        - but the evidence is not very localized ...
            - perhaps because the inducing locations are shared across classes: \
                the model would then want to retain all information (not localized) and \
                use the variational mean to modulate the evidence for each class
            - so to get localized evidence, we might want per-class inducing variables

In [None]:
# Widen the notebook container to the full browser width.
# `IPython.display` is the public home of `display`/`HTML`; importing `display`
# from `IPython.core.display` is deprecated (since IPython 7.14).
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
# Runtime configuration; set before jax is imported further down so XLA sees it.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # number GPUs in PCI bus order (as nvidia-smi)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# TF_CPP_VMODULE takes `module=level` pairs; the previous value had a stray
# leading '=' ('=bfc_allocator=1').
os.environ['TF_CPP_VMODULE'] = 'bfc_allocator=1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['XLA_FLAGS'] = '--xla_gpu_cuda_data_dir=/usr/local/cuda'
# Allocate GPU memory on demand instead of preallocating most of the device.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# Python does not expand shell syntax: the old value stored the literal text
# '${LD_LIBRARY_PATH}' and dropped the existing search path. Append to the real
# value instead. (Changing it at runtime only affects child processes; the
# dynamic loader of this process read it at startup.)
_ld_path = os.environ.get('LD_LIBRARY_PATH', '')
os.environ['LD_LIBRARY_PATH'] = ':'.join(
    p for p in (_ld_path, '/usr/local/cuda/lib64') if p)

from collections import defaultdict

import scipy
import numpy as onp
onp.set_printoptions(precision=3,suppress=True)

import jax
import jax.numpy as np
from jax import grad, jit, vmap, device_put, random
from flax import linen as nn
from jax.scipy.stats import dirichlet

# Report which backend platform and devices JAX ended up with (a quick check
# that the device selection made via the environment variables took effect).
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
print(jax.local_device_count())
print(jax.devices())

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.tri as tri
# Global figure style for the plots in this notebook.
# https://matplotlib.org/3.1.1/gallery/style_sheets/style_sheets_reference.html
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 25
mpl.rcParams['font.family'] = 'Times New Roman'
# Diverging blue-white-red colormap; per-class line colours are sampled from it.
# `plt.cm.get_cmap` (matplotlib.cm.get_cmap) was deprecated in Matplotlib 3.7 and
# removed in 3.9; `plt.get_cmap` is the supported equivalent.
cmap = plt.get_cmap('bwr')

from tabulate import tabulate
from functools import partial

from plt_utils import *
from gpax import *

In [None]:
import torch
import torchvision

def dataset_subset(dataset, y):
    """Keep only the samples whose label is in `y`, relabelled to 0..len(y)-1.

    Label `y[k]` becomes `k`, e.g. with `y = [6, 8]` the remaining targets are
    in {0., 1.}. The new targets are float32, and `dataset` is modified in
    place and returned.

    Args:
        dataset: object with tensor attributes `targets` (class ids, shape (N,))
            and `data` (indexed along its first axis), e.g. torchvision MNIST.
        y: sequence of distinct integer class ids to keep.

    Returns:
        The same `dataset` object.
    """
    labels = [int(label) for label in y]
    targets = torch.as_tensor(dataset.targets).to(torch.int64)
    # Build the selection mask directly in torch (no round trip through jax/numpy).
    keep = torch.zeros(targets.shape, dtype=torch.bool)
    for label in labels:
        keep |= targets == label
    ind = torch.nonzero(keep, as_tuple=True)[0]
    # Lookup table: old label -> position in `y`. It is sized by the largest
    # requested label rather than assuming at most 10 classes.
    remap = torch.zeros((max(labels, default=0) + 1,), dtype=torch.float32)
    remap[torch.tensor(labels, dtype=torch.int64)] = torch.arange(
        len(labels), dtype=torch.float32)
    dataset.targets = remap[targets[ind]]
    dataset.data = dataset.data[ind]
    return dataset

key = random.PRNGKey(1)

# Point torchvision at the S3 mirror; per the linked answer, the default MNIST
# download URLs returned HTTP errors.
# https://stackoverflow.com/questions/66577151/http-error-when-trying-to-download-mnist-data
new_mirror = 'https://ossci-datasets.s3.amazonaws.com/mnist'
torchvision.datasets.MNIST.resources = [
   ('/'.join([new_mirror, url.split('/')[-1]]), md5)
   for url, md5 in torchvision.datasets.MNIST.resources
]

# Per-sample transform: image -> HxWx1 float array scaled to [0, 1].
# Uses host numpy (`onp`); producing jax device arrays inside a torch Dataset
# would transfer every sample to the accelerator one at a time. The arrays
# below are built from `.data` directly, so this only affects `dataset[i]`.
transforms = torchvision.transforms.Compose([
    lambda x: onp.asarray(x)[...,onp.newaxis] / 255.
])

train_dataset = torchvision.datasets.MNIST('./data', train=True, transform=transforms, download=True)
test_dataset = torchvision.datasets.MNIST('./data', train=False, transform=transforms, download=True)
# Binary task: keep digits 6 and 8, relabelled to 0 and 1.
digits = [6, 8]
train_dataset = dataset_subset(train_dataset, digits)
test_dataset = dataset_subset(test_dataset, digits)
# Images as (N, 28, 28, 1) in [0, 1]; labels as (N, 1). `jax_to_gpu` comes from
# a star import (presumably moves the array to the accelerator).
X_train = jax_to_gpu(np.asarray(train_dataset.data[...,np.newaxis]) / 255.)
y_train = jax_to_gpu(np.asarray(train_dataset.targets[...,np.newaxis]))
X_test = jax_to_gpu(np.asarray(test_dataset.data[...,np.newaxis]) / 255.)
y_test = jax_to_gpu(np.asarray(test_dataset.targets[...,np.newaxis]))
data_train = (X_train, y_train)
data_test = (X_test, y_test)

In [None]:
# Visualise rotated copies of one test image; the next cell classifies them.
print(np.where(y_test==1)[0][:10])  # first few test indices with label 1 (digit 8)

ind = 31
n_ims = 20

fig, axs = plt.subplots(1,n_ims,figsize=(2*n_ims,2))
[axi.set_xticks([]) for axi in axs.ravel()]; [axi.set_yticks([]) for axi in axs.ravel()]

x, y = X_test[ind], y_test[ind]

# `rotated_ims` comes from a star import; presumably returns n_ims rotated copies of x.
x_rot = rotated_ims(x, n_ims=n_ims)
for i in range(n_ims):
    ax = axs[i]
    ax.imshow(x_rot[i], cmap='Greys')


In [None]:
# Class probabilities of a pretrained CNN on the rotated images above.
# NOTE(review): uses `CNN` from a later cell and a checkpoint
# './cnn_params_binary.pkl' that this notebook does not write (it saves
# './cnn_params_6,8.pkl'); confirm the file exists.
params = pytree_load(CNN().init(key, np.ones((1,28,28,1))), './cnn_params_binary.pkl')
logits = CNN(output_dim=2).apply(params, x_rot)
prob = np.exp(logits)  # CNN outputs log-probabilities

fig, ax = plt.subplots(1, 1, figsize=(15, 10))
# One curve per output class. The previous hard-coded columns [1, 3, 5] do not
# exist for a 2-class model: JAX clamps out-of-range indices, so column 1 was
# silently plotted three times.
for c, digit in zip(range(prob.shape[1]), digits):
    ax.plot(np.arange(n_ims), prob[:, c], label=f'{digit}')

ax.legend()
ax.grid()

In [None]:
# Number of classes in the binary task (digits 6 vs 8 -> labels 0/1).
num_classes = 2


class CNN(nn.Module):
    """Small CNN classifier returning log-probabilities over ``output_dim`` classes.

    Two (3x3 conv -> ReLU -> 2x2 average-pool) stages, then Dense(128) -> ReLU
    -> Dense(output_dim) -> log_softmax. Flax names parameters from creation
    order (Conv_0, Conv_1, Dense_0, Dense_1), so that order must not change:
    ``train_step`` reads ``grad['params']['Dense_0']``.
    """
    output_dim: int = 2

    @nn.compact
    def __call__(self, x):
        for n_features in (32, 64):
            x = nn.avg_pool(nn.relu(nn.Conv(features=n_features, kernel_size=(3, 3))(x)),
                            window_shape=(2, 2), strides=(2, 2))
        hidden = nn.relu(nn.Dense(features=128)(x.reshape((x.shape[0], -1))))
        return nn.log_softmax(nn.Dense(features=self.output_dim)(hidden))
    

class CNNTrunk(nn.Module):
    """Convolutional feature extractor, used as the kernel encoder (``g_cls``).

    Same layer stack as ``CNN`` up to and including its first Dense layer, but
    returns the raw 128-d Dense output (no ReLU, no output layer). The
    auto-generated parameter names (Conv_0, Conv_1, Dense_0) match ``CNN``'s;
    loading CNN checkpoints into the encoder appears to rely on this.
    """

    @nn.compact
    def __call__(self, x):
        for width in (32, 64):
            x = nn.Conv(features=width, kernel_size=(3, 3))(x)
            x = nn.avg_pool(nn.relu(x), window_shape=(2, 2), strides=(2, 2))
        flat = x.reshape((x.shape[0], -1))
        return nn.Dense(features=128)(flat)


def cross_entropy_loss(logits, labels):
    """Mean negative log-likelihood of integer ``labels`` under ``logits``.

    Args:
        logits: (N, C) log-probabilities (e.g. the ``log_softmax`` output of ``CNN``).
        labels: N class indices, shape (N,) or (N, 1); integer or integer-valued float.

    Returns:
        Scalar mean cross-entropy.
    """
    # Flatten the labels instead of squeezing the one-hot tensor: ``squeeze``
    # also drops a batch axis of size 1. The class count is taken from the
    # logits, so it cannot disagree with the model's output dimension.
    y_onehot = jax.nn.one_hot(np.reshape(labels, (-1,)), num_classes=logits.shape[-1])
    return -np.mean(np.sum(y_onehot * logits, axis=-1))


def compute_metrics(logits, labels):
    """Cross-entropy loss and accuracy of ``logits`` (N, C) against class ``labels``.

    ``labels`` may have shape (N,) or (N, 1). They are reshaped to a column
    before comparison; otherwise (N, 1) == (N,) would broadcast to (N, N) and
    silently yield a wrong accuracy.

    Returns:
        dict with scalar entries ``'loss'`` and ``'accuracy'``.
    """
    labels = np.reshape(labels, (-1, 1))
    loss = cross_entropy_loss(logits, labels)
    pred = np.argmax(logits, -1).reshape(-1, 1)
    accuracy = np.mean(pred == labels)
    return {'loss': loss,
            'accuracy': accuracy}


@jax.jit
def train_step(opt, batch, key):
    """One optimizer step of the plain ``CNN`` classifier on ``batch``.

    Args:
        opt: flax optimizer whose ``target`` holds the CNN parameters.
        batch: tuple ``(X, y)`` of images and class labels.
        key: PRNG key; it is advanced and returned (the CNN itself uses no randomness).

    Returns:
        ``(opt, log, key)``: the updated optimizer, a dict with the batch loss,
        accuracy and the norm of the gradient w.r.t. the first Dense kernel,
        and the advanced key.
    """
    key, _ = random.split(key)
    X, y = batch

    def loss_fn(params):
        logits = CNN().apply(params, X)
        return cross_entropy_loss(logits, y), logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(opt.target)
    opt = opt.apply_gradient(grads)
    metrics = compute_metrics(logits, y)
    log = {'loss': loss,
           'accuracy': metrics['accuracy'],
           # jax.numpy's norm (Frobenius for a matrix); the bare `linalg` used
           # before is not imported explicitly in this notebook.
           'dense0_kernel_gradnorm': np.linalg.norm(grads['params']['Dense_0']['kernel'])}
    return opt, log, key


@jax.jit
def eval_step(params, X):
    """Jitted forward pass of ``CNN``: log-probabilities for the batch ``X``."""
    return CNN().apply(params, X)


def eval_model(params, data_test, logit_fn=eval_step, batch_size=100):
    """Run ``logit_fn`` over ``data_test`` in mini-batches and compute metrics.

    Args:
        params: parameters passed unchanged to ``logit_fn``.
        data_test: tuple ``(X, y)``.
        logit_fn: ``logit_fn(params, X)`` -> (B, C) scores; the class is the argmax.
        batch_size: evaluation batch size (default 100, as before).

    Returns:
        dict ``{'loss': float, 'accuracy': float}``.

    NOTE(review): the number of samples covered depends on ``get_data_stream``
    (defined elsewhere). If it yields only full batches, the last
    ``len(X) % batch_size`` samples are skipped; confirm.
    """
    n_batches, batches = get_data_stream(random.PRNGKey(0), batch_size, data_test)

    logits, labels = [], []
    for _ in range(n_batches):
        X, y = next(batches)
        logits.append(logit_fn(params, X))
        labels.append(y.reshape(-1, 1))

    metrics = compute_metrics(np.vstack(logits), np.vstack(labels))
    # `jax.tree_util.tree_map`: the top-level `jax.tree_map` alias is deprecated.
    return jax.tree_util.tree_map(lambda x: x.item(), metrics)



In [None]:
# Fresh CNN parameters and a flax optimizer over them (Adam, learning rate 0.03).
model = CNN()
params = model.init(key, np.ones((1,28,28,1)))
opt = flax_create_optimizer(params, 'Adam', {'learning_rate': .03})

In [None]:

# Test metrics of the untrained CNN, printed as "epoch 0".
metrics = eval_model(params, data_test)
print(f'[{0:3}] test \t'
      f'Loss={metrics["loss"]:.3f}\t'
      f'accuracy={metrics["accuracy"]:.3f}\t')


In [None]:

bsz = 64
train_n_batches, train_batches = get_data_stream(key, bsz, data_train)
n_epochs = 10
# Print running (epoch-to-date) averages about 10 times per epoch. max(1, ...)
# avoids a modulo-by-zero when an epoch has fewer than 10 batches.
log_every = max(1, train_n_batches // 10)


for epoch in range(n_epochs):
    logs = defaultdict(list)  # this epoch's history of every logged quantity
    for it in range(train_n_batches):
        step = epoch*train_n_batches+it
        batch = next(train_batches)
        opt, log, key = train_step(opt, batch, key)
        params = opt.target
        for k, v in log.items():
            logs[k].append(v)
        if step % log_every == 0:
            avg_metrics = {k: np.mean(np.array(v))
                           for k, v in logs.items()}
            print(f'[{epoch:3}|{100*it/train_n_batches:5.2f}%]\t'
                  f'Loss={avg_metrics["loss"]:.3f}\t'
                  f'accuracy={avg_metrics["accuracy"]:.3f}\t'
                  f'norm(Dense0.k)={avg_metrics["dense0_kernel_gradnorm"]:.3f}')

    metrics = eval_model(params, data_test)
    print(f'[{epoch:3}] test \t'
          f'Loss={metrics["loss"]:.3f}\t'
          f'accuracy={metrics["accuracy"]:.3f}\t')
    

In [None]:
# Save the trained CNN as e.g. './cnn_params_6,8.pkl', reload it into a freshly
# initialised parameter tree and re-evaluate as a round-trip check.
cnn_save_path = f'./cnn_params_{",".join([str(x) for x in digits])}.pkl'
pytree_save(opt.target, cnn_save_path)
params = pytree_load(CNN().init(key, np.ones((1,28,28,1))), cnn_save_path)

metrics = eval_model(params, data_test)
print(f'Loss={metrics["loss"]:.3f}\t'
      f'accuracy={metrics["accuracy"]:.3f}\t')

In [None]:
# Scratch check of functional array updates: overwrite all of T with θ.
T = np.arange(6).reshape(2,3)
θ = np.arange(6)+10
print(θ)
# `jax.ops.index_update` was removed from JAX; `.at[...].set(...)` replaces it.
T = T.at[:].set(θ.reshape(2,3))
T

In [None]:
# SVGP configuration: likelihood, prior mean, kernel and inducing locations.
output_dim = 2
# Initial α_ϵ / α_δ for LikMulticlassDirichlet, #MC samples for the likelihood, #inducing points.
α_ϵ = 1; α_δ = 10; n_mc_samples = 20; n_inducing = 200

lik_type = 'LikMulticlassDirichlet' # LikMulticlassDirichlet, LikMulticlassSoftmax, LikMultipleNormalKron
# Initial value of the constant prior mean: gamma_to_lognormal(1.) for the
# Dirichlet likelihood (exact meaning defined in gpax; confirm), 0.5 otherwise.
init_val_m = gamma_to_lognormal(np.array([1.]))[0] \
    if lik_type == 'LikMulticlassDirichlet' else np.array([.5])
mean_fn_cls = partial(MeanConstant,
                      output_dim=output_dim,
                      init_val_m=init_val_m,
                      flat=False)
# Likelihood class with its constructor arguments bound via `partial`.
if lik_type == 'LikMulticlassDirichlet':
    lik_cls = partial(LikMulticlassDirichlet,
                      output_dim=output_dim,
                      init_val_α_ϵ=α_ϵ,
                      init_val_α_δ=α_δ,
                      n_mc_samples=n_mc_samples)
elif lik_type == 'LikMulticlassSoftmax':
    lik_cls = partial(LikMulticlassSoftmax,
                      output_dim=output_dim,
                      n_mc_samples=n_mc_samples)
else:
    lik_cls = partial(LikMultipleNormalKron,
                      output_dim=output_dim)
    
# previous 
# kx_cls = partial(CovSEwithEncoder, g_cls=CNNTrunk)
# kt_cls = partial(CovIndex, output_dim=output_dim, rank=1)
# k_cls = partial(CovICM, kx_cls=kx_cls, kt_cls=kt_cls)
# more efficient
# One SE kernel per output (parameters ks_0, ks_1, ...) on features of a CNNTrunk encoder `g`.
kx_cls = partial(CovSE, output_scaling=True)
k_cls = partial(CovMultipleOutputIndependent,
                k_cls=kx_cls,
                output_dim=output_dim,
                g_cls=CNNTrunk)
# Initial inducing locations: n_inducing training images evenly spaced through X_train.
Xu_initial = X_train[np.linspace(0,len(X_train)-1,n_inducing).astype(np.int32)].copy()
# inducing_loc_cls = partial(InducingLocations,
#                            shape=(n_inducing, 28, 28, 1),
#                            init_fn_inducing=lambda k,s: Xu_initial)
# Inducing images passed through a learned spatial transform (the STN from the
# notes); the meaning of trans_type='3' is defined in gpax.
inducing_loc_cls = partial(InducingLocationsSpatialTransform,
                           shape=(n_inducing, 28, 28, 1),
                           init_fn_inducing=lambda k,s: Xu_initial,
                           trans_type='3')



model = SVGP(mean_fn_cls=mean_fn_cls,
             k_cls=k_cls,
             lik_cls=lik_cls,
             inducing_loc_cls=inducing_loc_cls,
             n_data=len(X_train),
             output_dim=output_dim)


params = model.get_init_params(model, key)
params

In [None]:
# load pretrained weights & set initial values
cnn_save_path = 'cnn_params_6,8.pkl'

# Copy the pretrained CNN weights into the kernel's encoder `g`, and set the
# lengthscale(s) to softplus_inv(2.), i.e. 2 after the softplus transform
# (train_step2 logs softplus(ls)). The parameter paths differ between the ICM
# kernel (k/kx/...) and the independent-output kernel (k/g, k/ks_i/...).
if isinstance(k_cls(), CovICM):
    encoder_params = pytree_load({'params': params['params']['k']['kx']['g']}, cnn_save_path)
    encoder_params_kvs = pytree_get_kvs(encoder_params)
    params = pytree_mutate(params, {f'params/k/kx/g/{k}': v for k,v in encoder_params_kvs.items()})
    params = pytree_mutate(params, {'params/k/kx/ls': softplus_inv(np.array([2.]))})
else:
    encoder_params = pytree_load({'params': params['params']['k']['g']}, cnn_save_path)
    encoder_params_kvs = pytree_get_kvs(encoder_params)
    params = pytree_mutate(params, {f'params/k/g/{k}': v for k,v in encoder_params_kvs.items()})
    params = pytree_mutate(params, {f'params/k/ks_{i}/ls': softplus_inv(np.array([2.]))
                                    for i in range(output_dim)})

# create optimizer
# opt = flax_create_multioptimizer_2focus(params, 'Adam',
#                                         [{'learning_rate': 0.}, {'learning_rate': 0.03}],
#                                         ['g', 'mean_fn', 'Xu/X']) #  'k/ks_0/σ2', 'k/ks_1/σ2'

# Three Adam parameter groups selected by path keywords:
#   lr 0    -> frozen: prior mean and the raw inducing images (Xu/X)
#   lr 0.01 -> slow:   inducing spatial transforms (Xu/T) and the encoder g
#   lr 0.03 -> everything else
kwd_notrain = ['mean_fn', 'Xu/X']
kwd_trainslow = ['Xu/T', 'g'] # Xu/T
opt = flax_create_multioptimizer(
    params, 'Adam',
    [{'learning_rate': 0.}, {'learning_rate': .01}, {'learning_rate': .03}],
    [lambda p, v: pytree_path_contains_keywords(p, kwd_notrain),
     lambda p, v: pytree_path_contains_keywords(p, kwd_trainslow),
     lambda p, v: not pytree_path_contains_keywords(p, kwd_notrain+kwd_trainslow)])

flax_check_multiopt(params, opt)
# params
# params

In [None]:
######################################################
import time

# Jitted SVGP predictive mean E[y] for a batch X; used as `logit_fn` in eval_model.
# NOTE(review): `model` and `key` are module globals read when the function is
# first traced, so later reassignments of `key` are not seen and the same MC key
# is reused on every call; confirm this is intended.
# NOTE(review): eval_model feeds the output to cross_entropy_loss, which expects
# log-probabilities. If pred_y returns probabilities, the reported "Loss" is
# -mean(p_true) rather than a cross-entropy; accuracy (argmax) is unaffected.
@jax.jit
def eval_step2(params, X):
    Ey, Vy = model.apply(params, X, method=model.pred_y, rngs={'lik_mc_samples': key})
    return Ey

@jax.jit
def train_step2(step, opt, batch, key):
    """One SVGP optimizer step: maximise ``model.mll`` on ``batch`` and log hyperparameters.

    Args:
        step: global step counter (not used inside the step).
        opt: flax optimizer wrapping the SVGP parameters (``opt.target``).
        batch: tuple ``(X, y)`` of inputs and class labels.
        key: PRNG key; split once, and the subkey drives the likelihood MC samples.

    Returns:
        ``(opt, log, key)``: the updated optimizer, a dict of values to print
        (loss, lik.α_ϵ and softplus-transformed kernel ls/σ2), and the advanced key.
    """
    key, mc_key = random.split(key)
    X, y = batch
    targets = (X, jax.nn.one_hot(y.squeeze(), num_classes=output_dim))

    def negative_mll(p):
        objective = model.apply(p, targets, method=model.mll,
                                rngs={'lik_mc_samples': mc_key})
        return -objective, {}  # has_aux=True expects (value, aux); aux unused

    (loss, _), grads = jax.value_and_grad(negative_mll, has_aux=True)(opt.target)
    opt = opt.apply_gradient(grads)
    new_params = opt.target

    shared_kernel = isinstance(model.k_cls(), CovICM)

    def kernel_hyper(name):
        # ICM has one `kx` kernel; the independent-output kernel has one `ks_i` per output.
        if shared_kernel:
            raw = pytree_leaf(new_params, f'params/k/kx/{name}')
        else:
            raw = np.hstack([pytree_leaf(new_params, f'params/k/ks_{i}/{name}')
                             for i in range(output_dim)])
        return jax.nn.softplus(raw)

    log = {
        'loss': loss,
        'lik.α_ϵ': pytree_leaf(new_params, 'params/lik/α_ϵ'),
        'k.ls': kernel_hyper('ls'),
        'k.σ2': kernel_hyper('σ2'),
    }
    return opt, log, key


bsz = 64
train_n_batches, train_batches = get_data_stream(key, bsz, data_train)

n_epochs = 5
# Log about 5 times per epoch; max(1, ...) avoids a modulo-by-zero when an
# epoch has fewer than 5 batches.
log_every = max(1, train_n_batches // 5)

for epoch in range(n_epochs):
    start = time.time()
    for it in range(train_n_batches):
        step = epoch*train_n_batches+it
        batch = next(train_batches)
        opt, log, key = train_step2(step, opt, batch, key)
        params = opt.target
        if step % log_every == 0:
            # Time= is wall-clock time since the previous log line.
            print(f'[{epoch:3}|{100*it/train_n_batches:5.2f}%]\t'
                  f'Time={time.time()-start:.3f}\t'
                  f'Loss={log["loss"]:.3f}\t'
                  f'lik.α_ϵ={log["lik.α_ϵ"]}\t'
                  f'k.ls={log["k.ls"][:3]}\t'
                  f'k.σ2={log["k.σ2"][:3]}\t')
            start = time.time()


    metrics = eval_model(opt.target, data_test, logit_fn=eval_step2)
    print(f'[{epoch:3}] test \t'
          f'Loss={metrics["loss"]:.3f}\t'
          f'accuracy={metrics["accuracy"]:.3f}\t')

# Bind the trained parameters for interactive use in the following cells.
params = opt.target
m = model.bind(params, rngs={'lik_mc_samples': key})

In [None]:
# Compare initial inducing images (top row) with the learned, spatially
# transformed ones returned by m.Xu() (bottom row) for inducing points 30..39.
m = model.bind(opt.target, rngs={'lik_mc_samples': key})
Xu = m.Xu()

ind = np.arange(10)+30

fig, axs = plt.subplots(2,len(ind),figsize=(5*len(ind),10))
[axi.set_xticks([]) for axi in axs.ravel()]; [axi.set_yticks([]) for axi in axs.ravel()]

for i in range(len(ind)):
    ax = axs[0,i]
    ax.imshow(Xu_initial[ind[i]], cmap='Greys', vmin=0, vmax=1)
    ax = axs[1,i]
    ax.imshow(Xu[ind[i]], cmap='Greys', vmin=0, vmax=1)

    
    

In [None]:
## Posterior predictive for the first n_ims//2 test images of each class.
print(np.where(y_test==0)[0][:10])

ind_digit0 = np.where(y_test==0)[0]
ind_digit1 = np.where(y_test==1)[0]

ind = 6
n_ims = 20

x_rot = rotated_ims(X_test[ind], n_ims=n_ims)
# X = np.stack(X_test[ind_digit0[:10]])
X = np.stack(X_test[np.hstack((ind_digit0[:n_ims//2], ind_digit1[:n_ims//2]))])
μf, σ2f = m.pred_f(X, full_cov=False); μf = μf.squeeze()
y_pred = np.array(digits)[np.argmax(μf,axis=-1)]

if isinstance(m.lik, LikMulticlassSoftmax):
    lik_test = LikMulticlassSoftmax(output_dim=output_dim, n_mc_samples=5000)
    p_mc, Vp_mc = lik_test.apply({}, μf, σ2f, rngs={'lik_mc_samples': key}, method=lik_test.predictive_dist)
    # The plot below needs `p`/`Vp`; they used to be set only in the Dirichlet
    # branch (NameError, or stale values from an earlier run).
    # Assumes predictive_dist returns (mean, variance) of p(y|f*); confirm.
    p, Vp = p_mc, Vp_mc
elif isinstance(m.lik, LikMulticlassDirichlet):
    lik_test = LikMulticlassDirichlet(output_dim=output_dim, n_mc_samples=5000)
    p_mc, Vp_mc = lik_test.apply({}, μf, σ2f, rngs={'lik_mc_samples': key}, method=lik_test.predictive_dist)
    # Closed-form approximation: Dirichlet concentrations α from the latent
    # moments, then mean α/α0 and a p(1-p) spread per class.
    α = gamma_to_lognormal_inv(μf, σ2f, approx_type='kl')
    α0 = np.sum(α, axis=-1, keepdims=True)
    p = α / α0
    Vp = p*(1-p)

gridspec_kw = {'width_ratios': [1],
               'height_ratios': [4, 4, 1]}
fig, axs = plt.subplots(3, 1, gridspec_kw=gridspec_kw, figsize=(15, 10))


# Row 1: latent posterior mean ± 2 std per class.
ax = axs[0]
ax.set_xticks([])

colors = [cmap(.1), cmap(.9)]
for i, d in enumerate(digits):
    c = colors[i]
    μ, std = μf[:,i], np.sqrt(σ2f[:,i])
    ax.plot(np.arange(len(X)), μ, lw=2, color=c, label=d)
    ax.plot(np.arange(len(X)), μ + 2*std, '--', c=c)
    ax.plot(np.arange(len(X)), μ - 2*std, '--', c=c)
ax.legend()
ax.set_title('predictive posterior p(f*|X)')


# Row 2: predicted class probabilities ± 2 std.
ax = axs[1]
for i, d in enumerate(digits):
    c = colors[i]
    if isinstance(m.lik, (LikMulticlassSoftmax, LikMulticlassDirichlet)):
        ax.plot(np.arange(len(X)), p[:,i], '-', c=c, lw=2)
        ax.plot(np.arange(len(X)), p[:,i] + 2*np.sqrt(Vp[:,i]), '--', c=c)
        ax.plot(np.arange(len(X)), p[:,i] - 2*np.sqrt(Vp[:,i]), '--', c=c)
    else:
        ax.plot(np.arange(len(X)), μf[:,i], c=c, lw=2)
ax.set_title('E[p(y|f*)]')


# Row 3: the test images, side by side.
ax = axs[2]
ax.set_xticks([]); ax.set_yticks([])
ax.imshow(np.hstack([x for x in X]), cmap='Greys')

fig.tight_layout()

In [None]:
def mvn_marginal_variational_details(Kff, Kuf, mf,
                                     Luu, mu, μq, Lq, full_cov=False):
    """SVGP marginal q(f) = ∫ p(f|u) q(u) du, plus the pieces of its mean.

    With Kuu = Luu Luuᵀ and q(u) = N(μq, Lq Lqᵀ):
        μf = mf + A δ,   A = Kfu Kuu⁻¹,   δ = μq - mu
        Σf = Kff - Kfu Kuu⁻¹ Kuf + Kfu Kuu⁻¹ Lq Lqᵀ Kuu⁻¹ Kuf
    If ``full_cov`` is False, only the diagonal of Σf is formed, and ``Kff`` is
    presumably already the diagonal (a vector); confirm with callers.

    Returns:
        ``(μf, Σf, A, δ, mf)``; μf, δ and mf are column vectors.
    """
    # Two triangular solves give Kuu⁻¹ Kuf without forming Kuu⁻¹.
    Linv_Kuf = solve_triangular(Luu, Kuf, lower=True)
    Kuu_inv_Kuf = solve_triangular(Luu.T, Linv_Kuf, lower=False)
    LqT_proj = Lq.T @ Kuu_inv_Kuf

    if not full_cov:
        # Diagonal only: column-wise squared norms.
        Σf = Kff - \
            np.sum(np.square(Linv_Kuf), axis=0) + \
            np.sum(np.square(LqT_proj), axis=0)
    else:
        Σf = Kff - Linv_Kuf.T @ Linv_Kuf + LqT_proj.T @ LqT_proj

    # Column vectors, as in the multiple-output usage.
    δ = μq.reshape(-1, 1) - mu.reshape(-1, 1)
    mf = mf.reshape(-1, 1)
    A = Kuu_inv_Kuf.T
    μf = mf + A @ δ
    return μf, Σf, A, δ, mf

def pred_f_details(self, Xs, full_cov=True):
    """Like the SVGP's ``pred_f`` but also returns A and δ, where μf = mf + A @ δ.

    A[n, m, d] weights inducing point m's offset δ[m, d] = μq - mu for test
    point n and output d, so A * δ decomposes the predictive mean per
    inducing point.

    Args:
        self: a bound SVGP model (uses ``k``, ``Xu``, ``q``, ``mean_fn``).
        Xs: test inputs.
        full_cov: return the full covariance instead of marginal variances.

    Returns:
        ``(μf, Σf, A, δ, mf)``. With ``full_cov=False`` these have shapes
        (N, D), (N, D), (N, M, D), (M, D), (N, D). With ``full_cov=True`` they
        are returned as produced by ``mvn_marginal_variational_details``.
    """
    k = self.k
    Xu = self.Xu()               # (M,...)
    μq, Lq = self.q.μ, self.q.L  # (P, M) & (P, M, M)
    if μq.shape[0] == 1:
        μq, Lq = μq.squeeze(0), Lq.squeeze(0)

    ms = self.mean_fn(Xs)
    mu = self.mean_fn(Xu)

    Kss = k(Xs, full_cov=full_cov)
    Kus = k(Xu, Xs)
    Kuu = k(Xu)
    Luu = cholesky_jitter_vmap(Kuu, jitter=5e-5)  # (P, M, M)

    if isinstance(k, CovMultipleOutputIndependent):
        mvn_marginal_variational_fn = vmap(
            mvn_marginal_variational_details, (0, 0, 1, 0, 1, 0, 0, None), -1)  # along P-dim
    else:
        mvn_marginal_variational_fn = mvn_marginal_variational_details

    μf, Σf, A, δ, mf = mvn_marginal_variational_fn(Kss, Kus, ms,
                                                   Luu, mu, μq, Lq, full_cov)
    if not full_cov:
        # Shapes are read only here: with full_cov=True and the independent-output
        # kernel, Σf is (N, N, D), and the old unconditional `N, D = Σf.shape`
        # raised a ValueError on the default path.
        # -> (N, D), (N, D), (N, M, D), (M, D), (N, D)
        N, D = Σf.shape; M = len(Xu)
        μf = μf.reshape((N,D))
        A = A.reshape((N,M,D))
        δ = δ.reshape((M,D))
        mf = mf.reshape((N,D))
    return μf, Σf, A, δ, mf

# Decompose the predictive mean for the test images X selected in the predictive-plot cell.
μ, σ2, A, δ, mf = pred_f_details(m, X, full_cov=False)
print(A.shape)  # (#test images, #inducing points, #classes)

In [None]:
# For each test image n (one column each):
#   top row    - A[n, :, i]: weight of every inducing point on class i's mean
#   bottom row - A[n, :, i] * δ[:, i]: each inducing point's contribution to
#                μf - mf (these sum to it, since μf = mf + A @ δ)
N, M, D = A.shape  # (#test, #inducing, #classes)
fig, axs = plt.subplots(2,N,figsize=(30,15), sharey=True)
[axi.set_xticks([]) for axi in axs.ravel()];

for n in range(N):
    ax = axs[0,n]
    for i, d in enumerate(digits):
        ax.plot(A[n,:,i], np.arange(M), '--', c=colors[i], label=f'{d}')
        

for n in range(N):
    ax = axs[1,n]
    for i, d in enumerate(digits):
        ax.plot(A[n,:,i]*δ[:,i], np.arange(M), '--', c=colors[i], label=f'{d}')
        
fig.tight_layout()

In [None]:
# For test image n, show per class the 5 inducing images with the largest |A·δ|
# contribution to the predictive mean. The xlabel under the test image is the
# total contribution Σ_m Aδ[n, m, c] (= μf - mf for that class).
n = 11
Aδ = (A*δ)

ind = np.argsort(np.abs(Aδ[n,:,0]))[::-1][:5]
print(list(zip(ind, Aδ[n,ind,0]))[:5])

Xu = m.Xu()

fig, axs = plt.subplots(len(digits),1+len(ind),figsize=(5*(len(ind)+1),10), sharey=True)
[axi.set_xticks([]) for axi in axs.ravel()]; [axi.set_yticks([]) for axi in axs.ravel()]
    
for di in range(len(digits)):
    # original image
    ax = axs[di,0]
    ax.imshow(X[n], cmap='Greys', vmin=0, vmax=1)
    ax.set_xlabel(f'{np.sum(Aδ[n,:,di]):.2f}', fontsize=40, color='r')
    ax.set_ylabel(f'C={digits[di]}', fontsize=40)
    
    # take top evidence
    ind = np.argsort(np.abs(Aδ[n,:,di]))[::-1][:5]
    for i in range(len(ind)):
        ax = axs[di,i+1]
        ax.imshow(Xu[ind[i]], cmap='Greys', vmin=0, vmax=1)
        ax.set_xlabel(f'{Aδ[n,ind[i],di]:.2f}', fontsize=40)

fig.tight_layout()