In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from tqdm.auto import tqdm

from matplotlib import cm
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9.
# `pyplot.get_cmap` returns the same colormap object and is available in both
# old and new releases, so the notebook keeps running after an upgrade.
jet = plt.get_cmap('jet')

# Optimization for some chosen Y (no l2 normalization)

In [None]:
device = "cpu"
seed = 0  # fixed seed so that Restart & Run All reproduces the same data and trajectory

torch.manual_seed(seed)
rng = np.random.default_rng(seed)  # drives the mini-batch sampling below

# Binary label matrix. Mini-batches below are drawn by indexing *columns*
# (Y[:, idx]), i.e. every column is one sample's label vector.
Y = torch.tensor([[1,0,0,0,0],
                  [1,1,1,1,0],
                  [0,0,1,1,0],
                  [0,1,0,1,0],
                  [0,0,0,1,1]])
# Keep the distinct label columns, then repeat each one a random number of
# times (between 15 and 34) to build the dataset.
Y = torch.unique(Y, dim=-1)
Y = Y.repeat_interleave(torch.randint(15, 35, size=(Y.shape[-1],)), dim=1)

# `u` holds the distinct columns, `labels` maps each sample to its column in
# `u`, rescaled to [0, 1].
u, labels = torch.unique(Y, dim=-1, return_inverse=True)
labels = labels / labels.max()

c, n = Y.shape
Y = Y.to(device)

# Trainable quantities are created directly on `device`. Calling `.to(device)`
# on them *after* the optimizer has been built would (on any non-CPU device)
# rebind the names to non-leaf copies: the optimizer would keep updating the
# original CPU tensors while the loop reads the stale copies, and `X.grad`
# would be None.
X = torch.normal(0, 0.1, size=(8, n), device=device, requires_grad=True)  # one 8-dim point per sample (column)
norm = torch.ones(size=(n,), device=device, requires_grad=True)           # per-sample kernel scale
tau = torch.tensor(1.0, device=device, requires_grad=True)                # Gaussian kernel bandwidth

alpha = 0.9
beta  = 0.01
niter = 2000
batch_size = min(70, n)  # sampling without replacement cannot exceed n

tracker = {'loss' : [], 'x_svals' : []}

optimizer = optim.SGD([X, norm, tau], lr=0.01, momentum=0.98)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9995)

# The context manager closes the progress bar even if an iteration raises.
with tqdm(total=niter, dynamic_ncols=True, desc='Train') as bar:
    for i in range(niter):
        optimizer.zero_grad()

        idx = rng.choice(n, size=batch_size, replace=False)

        norm_batch = norm[idx]
        y_batch = Y[:,idx].float()
        x_batch = X[:,idx]

        # Scaled Gaussian kernel on the batch: norm_i * exp(-||x_i - x_j||^2 / tau) * norm_j
        Kx = norm_batch.reshape(-1,1) * torch.exp(-(torch.cdist(x_batch.T, x_batch.T, p=2)**2) / tau) * norm_batch.reshape(1,-1)
        #Kx = X[:,idx].T @ X[:,idx]

        # Label Gram matrix built from the first `c` right singular vectors of
        # the label batch, with every column rescaled to unit L2 norm.
        _, _, y_batch = torch.linalg.svd(y_batch)
        y_batch = F.normalize(y_batch[:c,:], dim=0, p=2)
        Ky = (y_batch.T @ y_batch).float()

        # Square roots of the eigenvalues; relu clamps the tiny negative
        # eigenvalues that round-off can produce before taking the root.
        z_svals = torch.sqrt(F.relu(torch.linalg.eigvalsh(Ky + Kx)))
        x_svals = torch.sqrt(F.relu(torch.linalg.eigvalsh(Kx)))
        z_nuc = z_svals.sum()
        x_nuc = x_svals.sum()

        # x_svals.max() ** 2 is the largest (clamped) eigenvalue of Kx.
        loss = z_nuc - alpha * x_nuc + beta * x_svals.max() ** 2
        loss.backward()
        tracker['loss'].append(loss.detach().item())
        tracker['x_svals'].append(x_svals.detach().sort()[0])

        # Report before stepping so the displayed gradient and learning rate
        # are the ones used by this update.
        bar.set_postfix(loss="{:1.5e}".format(loss.detach().item()),
                        grad="{:1.3e}".format(torch.linalg.norm(X.grad, 'fro')),
                        tau="{:1.3e}".format(tau.detach()),
                        lr="{:1.3e}".format(scheduler.get_last_lr()[0]))
        bar.update()
        optimizer.step()
        scheduler.step()

In [None]:
# Detached CPU copies for plotting.
Y_ = Y.detach().cpu()
X_ = X.detach().cpu()
tau_ = tau.detach().cpu()
#K_ = (F.normalize(X_, p=2, dim=0).T @ F.normalize(X_, p=2, dim=0)) ** 2
# Gaussian kernel between all samples, evaluated with the *learned* bandwidth
# so the picture matches the kernel that was optimized (the per-sample `norm`
# scale is deliberately left out, so the diagonal is 1).
K_ = torch.exp(-(torch.cdist(X_.T, X_.T, p=2)**2) / tau_)
svals_ = torch.stack(tracker['x_svals'], dim=0).detach().cpu()  # (iterations, batch size)
Z_ = torch.cat((Y_,X_), dim=0)

# The spectrum curves are subsampled; keep the true iteration numbers so the
# x-axis really is "Iteration".
stride = 5
steps = np.arange(0, svals_.shape[0], stride)

fig = plt.figure(figsize=(30,5))

ax = fig.add_subplot(1,5,1)
ax.imshow(Y_)
ax.axis("off")
ax.set_title(f"Y {Y_.shape[0]}x{Y_.shape[1]} ({u.shape[1]} minterms)")

ax = fig.add_subplot(1,5,2)
for j in range(svals_.shape[1]):
    ax.plot(steps, svals_[::stride,j], '-', linewidth=1.0)
# Raw string + $...$ so the label is rendered as math instead of relying on
# invalid escape sequences in a regular string.
ax.set_title(r"$\sqrt{\lambda_i(K)}$")
ax.set_xlabel("Iteration")

ax = fig.add_subplot(1,5,3)
# seaborn draws the heatmap's own colorbar; its layout options go through
# `cbar_kws`. A separate plt.colorbar() call would attach a second colorbar
# for the Y image from the first panel instead.
sns.heatmap(K_, ax=ax, cbar_kws={"fraction": 0.05, "pad": 0.04})
ax.axis("off")
ax.set_title("Kernel")

ax = fig.add_subplot(1,5,4)
ax.plot(tracker["loss"])
ax.set_xscale("log")
ax.set_title("Loss")
ax.set_xlabel("Iteration")
plt.show()

In [None]:
from torch.linalg import eigh, norm

idx = torch.where(Y[0,:] == 1)[0]

# Build kernel matrix for representations verifying proposition
subspace_vecs = X[idx]
subspace_kernel = torch.exp(-(torch.cdist(subspace_vecs, subspace_vecs, p=2)**2) / tau)

subspace_lambdas, subspace_evecs = eigh(subspace_kernel)

lambdas_mask = subspace_lambdas > 1e-1
lambdas = subspace_lambdas[lambdas_mask]
subspace_evecs = subspace_evecs[:,lambdas_mask]
subspace_evecs /= norm(subspace_evecs, ord=2, dim=0)
subspace_evecs /= torch.sqrt(lambdas.view(1,-1))

In [None]:
subspace_evecs.shape

In [None]:
kernel_memory_query = torch.exp(-(torch.cdist(subspace_vecs, X, p=2)**2) / tau) 

subspace_projections = torch.einsum("ij,ik->jk", kernel_memory_query, subspace_evecs)
prob = torch.square(subspace_projections).sum(-1)

In [None]:
# Display the shape of the scaled eigenvector matrix: one row per entry of
# `idx`, one column per eigenvalue that passed the `lambdas_mask` threshold.
subspace_evecs.shape

In [None]:
from scipy import optimize

# torch.linalg.svdvals only accepts floating-point (or complex) input and Y is
# an integer tensor, so cast it first (the training loop does the same with
# `.float()` on its label batches).
mu = torch.linalg.svdvals(Y.float())
print(mu)

# NumPy copy of the singular values so the SciPy callbacks below are plain,
# vectorized NumPy code instead of mixing 0-d torch tensors into NumPy calls.
mu_np = mu.detach().cpu().numpy().astype(np.float64)


def curve(x):
    """Residual whose root is searched for by `optimize.root`.

    Parameters
    ----------
    x : sequence of length 1
        Candidate value; only ``x[0]`` is used.

    Returns
    -------
    list of length 1
        ``sum_i x / sqrt(mu_i**2 + x**2) - len(mu) * alpha + 2 * beta * x``,
        where ``mu_i`` are the singular values of Y.
    """
    x = x[0]
    return [(x / np.sqrt(mu_np**2 + x**2)).sum() - len(mu_np) * alpha + 2.0 * beta * x]


def grad(x):
    """Derivative of `curve` with respect to ``x[0]`` (used as the Jacobian).

    Uses d/dx [x / sqrt(m**2 + x**2)] = m**2 / (m**2 + x**2)**1.5.

    Parameters
    ----------
    x : sequence of length 1
        Point at which the derivative is evaluated; only ``x[0]`` is used.

    Returns
    -------
    list of length 1
        ``sum_i mu_i**2 / (mu_i**2 + x**2)**1.5 + 2 * beta``.
    """
    x = x[0]
    return [(mu_np**2 / np.sqrt(mu_np**2 + x**2)**3).sum() + 2.0 * beta]


sol = optimize.root(curve, [0.0], jac=grad, method='hybr')
print("Singular values according to theory: {}".format(round(sol.x[0], 5)))

In [None]:
# Stem plot of the last recorded row of `svals_` (the spectrum stored at the
# final training iteration), drawn on a fresh figure.
final_axes = plt.figure().gca()
final_axes.stem(svals_[-1])