In [1]:
# Notebook-wide switches for the figure built further down.
# __PRODUCTION__ : truthy -> export a PGF figure, falsy -> export an SVG preview.
# __NAME__       : base name of the exported figure file (also used by the make step).
# __WIDTH__ / __HEIGHT__ : figure size passed to plt.figure(figsize=...) below
#                          (presumably inches, as with letter paper's 8.5 width).
__PRODUCTION__ = 1
__NAME__       = 'iteration'
__WIDTH__      = 5.5  # NeurIPS 2021 text box width
__HEIGHT__     = 3.1

# Pick the mplmagic2 output module matching the export mode.
# NOTE(review): these names are not referenced again in the visible code, so the
# import is presumably done for its side effect (backend selection) — confirm.
# It is kept ahead of the pyplot import in case that ordering matters.
if __PRODUCTION__:
    from mplmagic2 import pgf
else:
    from mplmagic2 import svg

from mplmagic2 import SuperFigure, tex
import matplotlib.pyplot as plt

# Preview the figure footprint on a page; the left margin is half of the
# leftover width of an 8.5-wide page, i.e. the figure is horizontally centred.
print('This is how much space the figure will take up on letter paper')
SuperFigure.size_hint(__WIDTH__, __HEIGHT__, margin_left=0.5 * (8.5 - __WIDTH__));

This is how much space the figure will take up on letter paper


In [2]:
import dill

import functools
import numpy as np
from scipy.ndimage import gaussian_filter1d
from scipy.spatial.distance import cdist
import torch
import torch.nn.functional as F
from torch import optim

from symfac.experimental import RBFExpansionPlus

In [3]:
def rbf(X, Y=None):
    """Evaluate the kernel exp(-(x - y)**2) for every pair of entries.

    Parameters
    ----------
    X : array_like
        First set of values.
    Y : array_like, optional
        Second set of values. When omitted, ``X`` is paired with itself.

    Returns
    -------
    numpy.ndarray
        Array of shape ``X.shape + Y.shape`` whose entries are
        ``exp(-(X[i] - Y[j])**2)``.
    """
    if Y is None:
        Y = X
    pairwise_diff = np.subtract.outer(X, Y)
    return np.exp(-pairwise_diff**2)

n = 64

# Two smoothed random vectors serve as the ground-truth components.
# The noise for u is drawn before the noise for v, so the seeded stream is
# consumed in a fixed order.
np.random.seed(128)
u = gaussian_filter1d(5 * np.random.randn(n), 3., order=0)
v = gaussian_filter1d(4 * np.random.randn(n), 6., order=0)

# Target matrix: difference of the two RBF kernels, plus a float32 torch copy.
K = rbf(u) - rbf(v)
Kt = torch.from_numpy(K).to(torch.float32)

# Left panel: the matrix itself; right panel: its eigenvalue spectrum.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
ax_matrix, ax_spectrum = axs
ax_matrix.imshow(K)
ax_spectrum.bar(np.arange(n), np.linalg.eigvalsh(K))
plt.show()

# Baseline: MSE of the rank-k truncated SVD for k = 1..9, rounded to 5 digits.
# Note that np.linalg.svd returns the right factor already transposed, so the
# rows of V are the right singular vectors.
U, S, V = np.linalg.svd(K)
K_reference = torch.tensor(K)
svd_K = []
for k in range(1, 10):
    truncation = (U[:, :k] * S[:k]) @ V[:k]
    truncation_mse = F.mse_loss(torch.tensor(truncation), K_reference)
    svd_K.append(round(float(truncation_mse), 5))
print(svd_K)

[0.13909, 0.06549, 0.0267, 0.00791, 0.0037, 0.00134, 0.00047, 0.0001, 5e-05]


In [4]:
# Seed torch's global RNG before anything below draws from it.
torch.manual_seed(15513512)

# Recompute the target matrix from the ground-truth vectors.
K = rbf(u) - rbf(v)

# Two-component factorizer; batch_size is also used below as the leading
# dimension of the initial guess.
fac = RBFExpansionPlus(k=2, batch_size=512, max_steps=1001)

# Filled by component_history_callback while the fit is running.
component_history = list()

def component_history_callback(step, model):
    """Record a copy of the model at a few selected optimisation steps.

    At steps 0, 5, 20 and 100 a ``{'step', 'model'}`` record is appended to
    the module-level ``component_history`` list. The model is round-tripped
    through dill so the stored object is a separate copy rather than a
    reference to the live model. All other steps are ignored.
    """
    if step not in (0, 5, 20, 100):
        return
    model_copy = dill.loads(dill.dumps(model))
    component_history.append({'step': step, 'model': model_copy})

# Plugin spec: call the snapshot callback on every step, handing it the
# current step and model.
snapshot_plugin = {
    'every': 1,
    'requires': ['step', 'model'],
    'callback': component_history_callback,
}

# Small random initial guess, one (n, k) slab per batch entry.
initial_u = 0.1 * torch.randn(fac.batch_size, n, fac.k)

# Kept as the last expression so the returned object is shown as cell output.
fac.fith(
    K.astype(np.float32),
    plugins=[snapshot_plugin],
    u0=initial_u,
)

100%|██████████| 1001/1001 [00:17<00:00, 56.88it/s]


<symfac.experimental.rbf_expansion_plus.RBFExpansionPlus at 0x7f49ade72ee0>

In [5]:
# Smallest best-loss value across all runs in the batch.
fac.report.loss_best.min()

tensor(2.0720e-15)

In [6]:
# Diagnostics for the run with the lowest best loss: print its coefficients,
# plot its loss curve, then show one row per recorded snapshot with the
# reconstructed matrix (left) and the scaled components against the
# ground-truth vectors (right).
best_run = np.argmin(fac.report.loss_best)
learned_data = []
learned_components = []
print(fac.optimum.a[best_run, :])

plt.plot(fac.report.loss_history_ticks, fac.report.loss_history[:, best_run])
plt.xscale('log')
plt.yscale('log')
xticks = 2**np.arange(11)
plt.gca().set_xticks(xticks)
plt.gca().set_xticklabels(xticks)
plt.title('Loss of the best run')
plt.show()

# squeeze=False keeps `axs` two-dimensional even when only a single snapshot
# was recorded, so the row-wise iteration below always receives a pair of axes.
fig, axs = plt.subplots(
    len(component_history), 2,
    figsize=(10, 4 * len(component_history)),
    squeeze=False,
)

for snapshot, axrow in zip(component_history, axs):
    step = snapshot['step']
    model = snapshot['model']
    axrow[0].set_title(f'Step {step}')

    # Reconstruction of the target matrix at this snapshot.
    learned_data.append(model(runs=best_run, device='cpu', grad_on=False))
    im = axrow[0].imshow(learned_data[-1], vmin=-1, vmax=1)
    plt.colorbar(mappable=im, ax=axrow[0])

    # Each component is scaled by the matching final coefficient of the best
    # run. It is computed once here and reused for both storage and plotting.
    learned_components.append([
        model.u[best_run, :, j].detach().cpu() * fac.optimum.a[best_run, j].item()
        for j in range(fac.k)
    ])

    axrow[1].plot(u, ls=(2, (3, 6)))
    axrow[1].plot(v, ls=(2, (3, 6)))
    for j, component in enumerate(learned_components[-1], start=1):
        axrow[1].plot(component, label=f'comp{j}')
    axrow[1].legend()
plt.show()

tensor([ 1., -1.], device='cuda:0')


In [7]:
# Show the scaled component vectors stored for the third recorded snapshot.
# NOTE(review): this is step 20 if the callback fired in increasing step order
# for the list [0, 5, 20, 100] — confirm. The output is the full tensor dump.
learned_components[2]

[tensor([-0.2138, -0.1461, -0.1820, -0.0646,  0.0818,  0.4204,  0.3965,  0.5514,
          0.5011,  0.2147,  0.2672,  0.2531,  0.3223,  0.4518,  0.2494,  0.3234,
          0.1192, -0.2392, -0.3888, -0.4789, -0.4926, -0.6058, -0.3498, -0.5117,
         -0.4499, -0.3810, -0.2022, -0.1608,  0.0559,  0.4186,  0.5634,  0.4567,
          0.5871,  0.4941,  0.4830,  0.4010,  0.0561, -0.2151, -0.3129, -0.3933,
         -0.3601, -0.0302,  0.2523,  0.5170,  0.5998,  0.4956,  0.2653, -0.0992,
         -0.4548, -0.4811, -0.5596, -0.5241, -0.5423, -0.4971, -0.5493, -0.4490,
         -0.4268, -0.1846,  0.0414,  0.4256,  0.5369,  0.5411,  0.4771,  0.5263]),
 tensor([-1.2518, -1.2644, -1.2444, -1.2687, -1.2240, -1.1973, -1.3122, -1.2710,
         -0.8301, -0.5069, -0.3696, -0.3962, -0.2750, -0.3836, -0.1594, -0.1780,
         -0.0166,  0.0219,  0.0993,  0.2112,  0.3761,  0.6702,  0.5615,  0.6298,
          0.8364,  0.8987,  0.9856,  1.1518,  1.2133,  1.4123,  1.2394,  1.3114,
          1.2345,  0.9677,

In [20]:
fig = SuperFigure(plt.figure(figsize=(__WIDTH__, __HEIGHT__), dpi=300))

# Full-figure background axes in unit coordinates; later used to host the
# shared legend. It is drawn blank in production mode.
canvas_style = 'blank' if __PRODUCTION__ else None
ax_canvas = fig.make_axes(
    left=0, right=1, top=0, bottom=1, zorder=-100, style=canvas_style
)
ax_canvas.set_xlim([0, 1])
ax_canvas.set_ylim([0, 1])

# draw the axes grid: 2 rows x 4 columns of square panels
x0, y0 = 0.035, 0.11   # left / top offset of the first panel
dx, dy = 0.24, 0.41    # column / row pitch
w = 0.22               # panel width
axs = np.array([
    [
        fig.make_axes(
            left=x0 + dx * col,
            width=w,
            top=row_top,
            width_to_height=1.0,
            style='modern',
        )
        for col in range(4)
    ]
    for row_top in (y0, y0 + dy)
])

# Keyword bundles for the panel titles, tick labels and axis labels.
title_style = {
    'fontsize': 8,
    'y': 0.96,
    'linespacing': 1.6,
    'va': 'bottom',
}
tick_style = {'fontsize': 7}
label_style = {
    'fontsize': 8,
    'labelpad': 0.75,
}

# Select the run with the lowest best loss *before* it is used below, so this
# cell does not depend on `best_run` leaking in from an earlier cell.
best_run = np.argmin(fac.report.loss_best)

# Title each top-row panel with the snapshot step and the MSE between that
# snapshot's reconstruction and the float32 target matrix.
for ax, snapshot in zip(axs[0], component_history):
    loss = F.mse_loss(snapshot['model'](runs=best_run, device='cpu'), Kt).item()
    ax.set_title(
        ''.join([
            fr'\textbf{{Step {snapshot["step"]}}}',
            '\n',
            f'MSE Loss ${tex.sciform(loss, frac_digits=1)}$'
        ]),
        **title_style
    )

# Colour range shared by the matrix panels and the manual colorbar.
img_vmin, img_vmax = -1, 1
image_style = {'vmin': img_vmin, 'vmax': img_vmax, 'cmap': 'Spectral'}

# Ground-truth vectors: thin dashed lines, one dash pattern per vector.
truth_style_base = {'lw': 0.75, 'dash_capstyle': 'round'}
truth_style1 = {'ls': (2, (1.25, 2.5)), **truth_style_base}
truth_style2 = {'ls': (4, (0.5, 2.5, 3, 2.5)), **truth_style_base}

# Learned vectors: solid lines with a hollow marker on every fifth point.
learned_style_base = {
    'lw': 1.0,
    'markersize': 3,
    'markeredgewidth': 0.75,
    'markerfacecolor': 'w',
    'markevery': 5,
}
learned_style1 = {'marker': 'o', **learned_style_base}
learned_style2 = {'marker': '^', **learned_style_base}

# One colour per component, shared by its true and learned curves.
u_style = {'color': '#FC645A'}
v_style = {'color': '#4A70A0'}

# One column per snapshot: reconstructed matrix on top, component curves below.
for (ax_img, ax_plot), snapshot in zip(axs.T, component_history):
    snapshot_model = snapshot['model']
    reconstruction = snapshot_model(runs=best_run, device='cpu', grad_on=False)
    ax_img.imshow(reconstruction, **image_style)

    # Transposing the best run's u slab lets it unpack into the two
    # component vectors.
    u1, v1 = snapshot_model.u[best_run, :, :].detach().cpu().numpy().T
    curves = (
        (u, truth_style1, u_style),
        (v, truth_style2, v_style),
        (u1, learned_style1, u_style),
        (v1, learned_style2, v_style),
    )
    for series, line_style, color_style in curves:
        ax_plot.plot(series, **line_style, **color_style)

# use dummy plots for legends: the single point (-1, -1) lies outside the
# [0, 1] canvas limits, so only the legend handles are visible
# NOTE(review): `bf` is not referenced by the visible code — confirm whether
# u_i was meant to use it.
bf = 'boldsymbol' if __PRODUCTION__ else 'mathbf'
def u_i(i):
    """Return the math-mode label for the i-th component vector."""
    # Original author's note: this form renders in the AMS Euler font; the
    # alternative tex.pow(r'\textbf{u}', f'({i})', mathmode=True) would use
    # the same font as the ticks.
    return fr'$\mathbf{{u}}^{{\!(\!{i}\!)}}$'

legend_entries = (
    ('True ', truth_style1, u_style, 1),
    ('True ', truth_style2, v_style, 2),
    ('Learned ', learned_style1, u_style, 1),
    ('Learned ', learned_style2, v_style, 2),
)
for prefix, line_style, color_style, index in legend_entries:
    ax_canvas.plot(
        [-1], [-1], **line_style, **color_style, label=prefix + u_i(index)
    )

ax_canvas.legend(
    loc='center', frameon=False, ncol=4, fontsize=7, bbox_to_anchor=(0.5, 0.02)
)

# Hand-built colorbar: a narrow axes to the right of the last matrix panel,
# filled with a vertical gradient drawn with the same image style.
last_image_ax = axs[0, -1]
ax_colorbar = fig.make_axes(
    left=last_image_ax.right + 0.01,
    width=0.01,
    top=last_image_ax.top + 0.02,
    bottom=last_image_ax.bottom - 0.02,
    style='modern'
)

gradient = np.linspace(img_vmin, img_vmax, 256)[:, np.newaxis]
ax_colorbar.imshow(
    gradient, aspect='auto', origin='lower', extent=(0, 1, 0, 1), **image_style
)

# No ticks; the two end values are written as text just outside the bar.
ax_colorbar.set_xlim([0, 1])
ax_colorbar.set_xticks([])
ax_colorbar.set_ylim([0, 1])
ax_colorbar.set_yticks([])
colorbar_tickstyle = {'color': 'k', 'ha': 'center', **tick_style}
end_labels = (
    (1.02, img_vmax, 'bottom'),   # upper end
    (-0.02, img_vmin, 'top'),     # lower end
)
for y_pos, value, valign in end_labels:
    ax_colorbar.text(0.5, y_pos, '%.0f' % value, va=valign, **colorbar_tickstyle)

# Thin grey frame around the bar.
for _, sp in ax_colorbar.spines.items():
    sp.set_visible(True)
    sp.set_color('#606060')
    sp.set_linewidth(0.25)

# Top row (matrix panels): bounds and ticks follow the matrix size `n` instead
# of a hard-coded 64, so they stay correct if `n` is changed above. Only the
# first panel carries y tick labels and the row label.
for i, ax in enumerate(axs[0]):
    ticks = [0, n]
    ax.xaxis.tick_top()
    ax.set_xbound(lower=0, upper=n)
    ax.set_ybound(lower=0, upper=n)
    ax.set_xticks([])
    ax.set_yticks(ticks)
    if i == 0:
        ax.set_yticklabels(['%d' % y for y in ticks], **tick_style)
        ax.set_ylabel(
            r'\textbf{Reconstructed Matrix}',
            fontsize=8,
            labelpad=-2,
        )
    else:
        ax.set_yticklabels([])

# Bottom row (component curves): x spans the vector index range 0..n, y uses a
# fixed range; again only the first panel is labelled on the y axis.
for i, ax in enumerate(axs[1]):
    xticks = [0, n]
    yticks = [-2.5, 2.5]
    ax.set_xbound(lower=0, upper=n)
    ax.set_ybound(lower=-2.7, upper=2.7)
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    ax.set_xticklabels(xticks, **tick_style)
    if i == 0:
        ax.set_yticklabels(yticks, **tick_style)
        ax.set_ylabel(
            r'\textbf{Component Vectors}',
            fontsize=8,
            labelpad=-7,
        )
    else:
        ax.set_yticklabels([])
        

# Export as PGF in production mode and SVG otherwise; the output directory
# carries the same name as the file extension.
export_format = 'pgf' if __PRODUCTION__ else 'svg'
fig.savefig(f'{export_format}/{__NAME__}.{export_format}', dpi=300)

In [21]:
# Build fig-<name>.pdf through the project's Makefile.figures. IPython expands
# $__NAME__ from the notebook namespace. stderr is merged into stdout and only
# the last line of make's output is shown, so earlier messages are hidden.
!make -f Makefile.figures fig-"$__NAME__".pdf 2>&1 | tail -n 1

Successfully created fig-iteration.pdf


# Sandbox below

---