In [1]:
__PRODUCTION__ = 1
__NAME__       = 'efficiency-grpah-kernel'
__WIDTH__      = 5.5  # NeurIPS 2021 text box width
__HEIGHT__     = 1.8

if __PRODUCTION__:
    from mplmagic2 import pgf
else:
    from mplmagic2 import svg

from mplmagic2 import SuperFigure, tex
import matplotlib.pyplot as plt
from matplotlib import patches
import matplotlib.patheffects as path_effects

print('This is how much space the figure will take up on letter paper')
SuperFigure.size_hint(__WIDTH__, __HEIGHT__, margin_left=0.5 * (8.5 - __WIDTH__));

This is how much space the figure will take up on letter paper


In [2]:
import functools
import numpy as np
import colorsys
from scipy.ndimage import gaussian_filter1d
from scipy.spatial.distance import cdist
import torch
import torch.nn.functional as F
from torch import optim
import networkx as nx
import json
import pickle
import matplotlib

from symfac.experimental import RBFExpansionPlus, RBFExpansionMiniBatchPlus

In [3]:
target = torch.tensor(
    np.load('data/paper/QM7-rkhs-distance.maxdet.npy'),
    dtype=torch.float32
)

In [4]:
time_per_step = {}
for key, table in json.loads(open('data/paper/time-per-step.json').read()).items():
    time_per_step[key] = {
        int(k): t for k, t in table.items()
    }
time_per_step

{'GD': {1: 0.010356170999898495,
  2: 0.018769071000065196,
  3: 0.027150676000019303,
  4: 0.032652469499907966,
  5: 0.04127614599997287,
  6: 0.04706726799997796,
  7: 0.055211035000070297,
  8: 0.06141732100002173,
  9: 0.07012742749998324,
  10: 0.07557914349990824,
  16: 0.11964770300005512},
 'SGD.MINIBATCH8': {1: 0.0028968785002234654,
  2: 0.0031054250000579486,
  3: 0.0032788099997560494,
  4: 0.0036005299998578266,
  5: 0.0038789539999015687,
  6: 0.004137956000022314,
  7: 0.004411385000139489,
  8: 0.004742534999877535,
  9: 0.005030798500001765,
  10: 0.005294907000006788,
  16: 0.006994879000103538}}

In [5]:
# OLD result, 500K steps, 0.1 learning rate, 100 batch size
np.array([
    RBFExpansionPlus.from_pickle(f'data/paper/GraphKernel.GD.RANK{i}.pickle').report.loss_best.min().item()
    for i in 1 + np.arange(10)
])

array([0.00614288, 0.00428757, 0.00300888, 0.00247165, 0.00161892,
       0.00120787, 0.00096547, 0.00080021, 0.00074569, 0.00066422])

In [6]:
gd_losses_top10 = np.array([
    RBFExpansionPlus.from_pickle(f'data/paper/GraphKernel.GD.ITER10000.BATCH1024.RANK{i}.pickle').report.loss_best.min().item()
    for i in 1 + np.arange(10)
])
gd_losses_top10

array([0.00620039, 0.00403608, 0.00270088, 0.00202079, 0.00156884,
       0.00135751, 0.00115214, 0.00100329, 0.00092674, 0.00078057])

In [7]:
sgd_losses_top10 = np.array([
    RBFExpansionMiniBatchPlus.from_pickle(f'data/paper/GraphKernel.SGD.ITER100000.BATCH1024.RANK{i}.MINIBATCH8.pickle').report.loss_best.min().item()
    for i in 1 + np.arange(10)
])
sgd_losses_top10

array([0.00623057, 0.0040156 , 0.00303386, 0.00217209, 0.00187505,
       0.00153212, 0.00128608, 0.00104993, 0.0009985 , 0.00089636])

In [8]:
def load(i):
    try:
        return RBFExpansionMiniBatchPlus.from_pickle(f'data/paper/GraphKernel.SGD.ITER100000.BATCH1024.RANK{i}.MINIBATCH8.pickle')
    except:
        try:
            return RBFExpansionMiniBatchPlus.from_pickle(f'data/paper/GraphKernel.SGD.ITER100000.BATCH512.RANK{i}.MINIBATCH8.pickle')
        except:
            try:
                return RBFExpansionMiniBatchPlus.from_pickle(f'data/paper/GraphKernel.SGD.ITER100000.BATCH256.RANK{i}.MINIBATCH8.pickle')
            except:
                try:
                    return RBFExpansionMiniBatchPlus.from_pickle(f'data/paper/GraphKernel.SGD.ITER100000.BATCH128.RANK{i}.MINIBATCH8.pickle')
                except:
                    return RBFExpansionMiniBatchPlus.from_pickle(f'data/paper/GraphKernel.SGD.ITER100000.BATCH64.RANK{i}.MINIBATCH8.pickle')

sgd_losses_pow2 = np.array([
    load(i).report.loss_best.min().item()
    for i in 2**np.arange(8)
])
sgd_losses_pow2

array([6.23057457e-03, 4.01560450e-03, 2.17208732e-03, 1.04993337e-03,
       4.75537119e-04, 1.76278845e-04, 5.44178038e-05, 1.36175804e-05])

In [9]:
def get_svd_losses(X, N):

    U, S, V = np.linalg.svd(X.detach().numpy().astype(np.float64))

    return np.array([np.nan] + [
        F.mse_loss(torch.tensor((U[:, :s] * S[None, :s]) @ V[:s, :]), X).item()
        for s in (np.arange(N) + 1)
    ])

svd_losses = get_svd_losses(target, target.shape[0])
svd_losses

array([           nan, 1.18033795e-02, 7.14218483e-03, 4.27781065e-03,
       3.60614917e-03, 3.09930362e-03, 2.68164223e-03, 2.42946051e-03,
       2.21411275e-03, 2.06660096e-03, 1.94150743e-03, 1.82584854e-03,
       1.72719583e-03, 1.64549299e-03, 1.56959208e-03, 1.50809450e-03,
       1.44984754e-03, 1.40125853e-03, 1.35678719e-03, 1.31362758e-03,
       1.27692494e-03, 1.24346601e-03, 1.21368005e-03, 1.18557889e-03,
       1.15829506e-03, 1.13298660e-03, 1.10858484e-03, 1.08673075e-03,
       1.06574667e-03, 1.04499929e-03, 1.02518367e-03, 1.00649112e-03,
       9.89600418e-04, 9.73061587e-04, 9.57042232e-04, 9.41162894e-04,
       9.25553645e-04, 9.10545430e-04, 8.96088233e-04, 8.81880682e-04,
       8.68346928e-04, 8.55418132e-04, 8.42728803e-04, 8.30444548e-04,
       8.18486454e-04, 8.06712565e-04, 7.95414401e-04, 7.84229644e-04,
       7.73451575e-04, 7.62699292e-04, 7.52359583e-04, 7.42575675e-04,
       7.33074536e-04, 7.23646510e-04, 7.14291093e-04, 7.05294666e-04,
      

In [10]:
def get_nystrom_losses(X, C, random_seed=None):

    with torch.random.fork_rng(devices=None):
        if random_seed is not None:
            torch.random.manual_seed(random_seed)

        if isinstance(C, int):
            C = torch.randperm(len(X))[:C].numpy()
        
        C = np.sort(C)
        D = np.setxor1d(np.arange(len(X)), C)

        Kcc = X[C, :][:, C]
        Kdc = X[D, :][:, C]
        Kdd = X[D, :][:, D]
        
        Kcc2 = Kcc.copy()
        Kcc2[np.diag_indices_from(Kcc2)] += 1e-10
        Kccinv = np.linalg.inv(Kcc2)
        
        Kdd_approx = Kdc @ Kccinv @ Kdc.T
        
        K_approx = torch.tensor(np.block([
            [Kcc, Kdc.T],
            [Kdc, Kdd_approx]
        ]))
        K_original = torch.tensor(np.block([
            [Kcc, Kdc.T],
            [Kdc, Kdd]
        ]))

        return F.mse_loss(K_approx, K_original, reduction='mean')


nystrom_losses = np.array([
    get_nystrom_losses(target.numpy(), k, random_seed=0).item() for k in range(target.shape[0] // 2)
])
nystrom_losses



array([7.02259302e-01, 5.72581631e+19, 4.98886704e-01, 1.22917421e-01,
       6.74562305e-02, 5.16107306e-02, 2.71645486e-02, 1.84814893e-02,
       1.56307481e-02, 1.38288010e-02, 1.23420684e-02, 1.08710239e-02,
       9.39737819e-03, 8.91632028e-03, 8.02531838e-03, 7.56157562e-03,
       7.05981627e-03, 6.36129733e-03, 6.06473349e-03, 5.59320021e-03,
       5.42458054e-03, 4.96871537e-03, 4.74538561e-03, 4.57891170e-03,
       4.48798714e-03, 4.31839377e-03, 4.27790638e-03, 4.03640978e-03,
       3.98738543e-03, 3.80847743e-03, 3.58551880e-03, 3.50836059e-03,
       3.42833647e-03, 3.37694865e-03, 3.28029110e-03, 3.23108328e-03,
       3.15024192e-03, 3.11281998e-03, 3.05419462e-03, 2.98236264e-03,
       2.95610423e-03, 2.78732134e-03, 2.70370580e-03, 2.46871938e-03,
       2.43991613e-03, 2.39262450e-03, 2.36799801e-03, 2.29798816e-03,
       2.25978205e-03, 2.22332496e-03, 2.15417915e-03, 2.09674891e-03,
       2.02476652e-03, 1.97807327e-03, 1.95556181e-03, 1.92967977e-03,
      

In [11]:
time_per_step

{'GD': {1: 0.010356170999898495,
  2: 0.018769071000065196,
  3: 0.027150676000019303,
  4: 0.032652469499907966,
  5: 0.04127614599997287,
  6: 0.04706726799997796,
  7: 0.055211035000070297,
  8: 0.06141732100002173,
  9: 0.07012742749998324,
  10: 0.07557914349990824,
  16: 0.11964770300005512},
 'SGD.MINIBATCH8': {1: 0.0028968785002234654,
  2: 0.0031054250000579486,
  3: 0.0032788099997560494,
  4: 0.0036005299998578266,
  5: 0.0038789539999015687,
  6: 0.004137956000022314,
  7: 0.004411385000139489,
  8: 0.004742534999877535,
  9: 0.005030798500001765,
  10: 0.005294907000006788,
  16: 0.006994879000103538}}

In [12]:
gd_walltime_top10 = np.array([
    time_per_step['GD'][i] * np.max(RBFExpansionPlus.from_pickle(f'data/paper/GraphKernel.GD.ITER10000.BATCH1024.RANK{i}.pickle').report.loss_history_ticks)
    for i in 1 + np.arange(10)
])
gd_walltime_top10

array([103.55135383, 187.67194093, 271.47960932, 326.49204253,
       412.72018385, 470.62561273, 552.05513897, 614.11179268,
       701.20414757, 755.71585586])

In [13]:
sgd_walltime_top10 = np.array([
    time_per_step['SGD.MINIBATCH8'][i] * np.max(RBFExpansionMiniBatchPlus.from_pickle(f'data/paper/GraphKernel.SGD.ITER100000.BATCH1024.RANK{i}.MINIBATCH8.pickle').report.loss_history_ticks)
    for i in 1 + np.arange(10)
])
sgd_walltime_top10

array([289.65888124, 310.51144576, 327.84821188, 360.01699469,
       387.85661045, 413.75422044, 441.09438616, 474.20607464,
       503.02954202, 529.43775093])

In [14]:
fig = SuperFigure(plt.figure(figsize=(__WIDTH__, __HEIGHT__), dpi=300))
ax_canvas = fig.make_axes(
    left=0, right=1, top=0, bottom=1, zorder=-100,
    style='blank' if __PRODUCTION__ else None
)
ax_canvas.set_xlim([0, 1])
ax_canvas.set_ylim([0, 1])

# draw the axes grid
x0 = 0.01
dx = 0.255
dx2 = 0.01
w = 0.205
axs = [
    fig.make_axes(
        #left=x0 + dx * i - dx2*(i//3),
        left=x0 + dx * i + (dx2 if i >= 2 else 0),
        width=w,
        top=0.165,
        width_to_height=1.0,
        style='modern'
    ) for i in range(4)
]

image_style = dict(
    vmin=0,
    vmax=1,
    cmap='Spectral'
)
plot_style = dict(
    lw=0.75,
    marker='o',
    markersize=2.5
)
svd_style = dict(
    color='#404040',
)
rbf_style = dict(
    color='#90306A',
)
srbf_style = dict(
    color='#60D0FF',
)

axs[0].imshow(target, **image_style)

pow2_idx = 2**np.arange(8)
axs[1].plot(pow2_idx, sgd_losses_pow2, **plot_style, **srbf_style)
axs[1].plot(pow2_idx, svd_losses[pow2_idx], **plot_style, **svd_style)

top10_idx = 1 + np.arange(10)
axs[2].plot(top10_idx, svd_losses[top10_idx], **plot_style, **svd_style)
axs[2].plot(top10_idx, gd_losses_top10, **plot_style, **rbf_style)
axs[2].plot(top10_idx, sgd_losses_top10, **plot_style, **srbf_style)

bar_width = 0.25
bar_style = dict(
    width = bar_width
)

# draw GD times
def draw_bars(ax, h, offset, **kwargs):
    ax.bar(1 + np.arange(len(h)) + offset, h, **kwargs)

draw_bars(axs[3], gd_walltime_top10, -0.5 * bar_width, **bar_style, **rbf_style)
draw_bars(axs[3], sgd_walltime_top10, 0.5 * bar_width, **bar_style, **srbf_style)
# # draw_bars(axs[3], [gd_10_times[t]/sgd_10_times[t] for t in range(10)], -0.5 * bar_width, **bar_style, **srbf_style)
# # draw_bars(axs[3], data['times']['rbf'], 0.5 * bar_width, **bar_style, **rbf_style)


title_style = dict(
    fontsize=8,
    y=0.945,
    va='bottom'
)

tick_style = dict(
    fontsize=7,
)
xlabel_style = dict(
    fontsize=8,
    labelpad=1.0,
)
ylabel_style = dict(
    fontsize=8,
    labelpad=0.75,
)

ticks = [0, 250]
axs[0].set_xbound(lower=0, upper=250)
axs[0].set_ybound(lower=0, upper=250)
axs[0].set_xticks(ticks)
axs[0].set_yticks([])
axs[0].set_xticklabels(['%d' % x for x in [0, 250]], **tick_style)
axs[0].set_xlabel('Molecule ID', **xlabel_style)

for ax in axs[1:2]:
    ax.plot(
        [1.05, 8, 8, 1.05, 1.05],
        [1e-3, 1e-3, 2e-2, 2e-2, 1e-3],
        color='k',
        lw=0.5,
        ls=(2, (1, 3)),
        dash_capstyle='round'
    )
    polygon_style = dict(
        edgecolor='none',
        clip_on=False,
    )
    ax.add_patch(
        patches.Rectangle(
            (1.03, 1e-3),
            8 - 1.03, 2e-2 - 1e-3,
            facecolor='#E0E0E0',
            **polygon_style
        )
    )
    ax.add_patch(
        patches.FancyArrow(
            8, 3.5 * 1e-5**0.5, 512, 0,
            width=0.001,
            head_width=0.003,
            head_length=80,
            facecolor='#404040',
            length_includes_head=True,
            **polygon_style
        )
    )

    # print(svd_losses)
    
    # extra_ticks = np.linspace(-1, -3, 3)
    for logr_rbf in [4, 5, 6]:
        ref = sgd_losses_pow2[logr_rbf]
        r_rbf = 2**logr_rbf
        r_svd = np.flatnonzero(np.array(svd_losses) <= ref)[0] + 1
        print(r_rbf, r_svd)

        ax.axhline(
            ref,
            color='#404040',
            lw=0.5,
            ls=(2, (1.7, 3.1)),
            dash_capstyle='round',
            zorder=-30
        )
        x0 = 2**0.25
        ax.text(
            x0, ref * 0.8, f'RBF {r_rbf} / {r_svd} SVD  ',
            ha='left',
            va='top',
            fontsize=7,
            zorder=20,
            path_effects=[
                path_effects.Stroke(linewidth=1.5, foreground='w'),
                path_effects.Normal()
            ]        
        )
#         ax.add_patch(
#             patches.Rectangle(
#                 (1.05, ref * 0.36), 13, 0.64 * ref,
#                 color='w',
#                 alpha=0.8,
#                 zorder=10
#             ),
#         )
            
    ax.set_xscale('log', base=2)
    ax.set_yscale('log', base=10)
    ax.minorticks_off()
    xticks = np.arange(8)
    yticks = np.linspace(-5, -2, 4)
    ax.set_xbound(lower=2**xticks.min(), upper=2**xticks.max())
    ax.set_ybound(lower=0.25 * 10**yticks.min(), upper=2 * 10**yticks.max())
    ax.set_xticks(2**xticks)
    ax.set_yticks(10**yticks)
    ax.set_xticklabels([fr'{2**t:.0f}' for t in xticks], **tick_style)
    ax.set_yticklabels([fr'${tex.pow(10, int(t))}$' for t in yticks], **tick_style)
    ax.set_xlabel('Components', **xlabel_style)
    ax.set_ylabel('MSE Loss', **ylabel_style)

for ax in axs[2:3]:
    ax.set_yscale('log', base=10)
    ax.minorticks_off()
    xticks = np.arange(1, 11)
    yticks = np.linspace(-3, -2, 2) 
    ax.set_xbound(lower=xticks.min(), upper=xticks.max())
    ax.set_ybound(lower=0.5 * 10**yticks.min(), upper=2 * 10**yticks.max())
    ax.set_xticks(xticks)
    ax.set_yticks(10**yticks)
    ax.set_xticklabels(['%d' % x for x in xticks], **tick_style)
    ax.set_yticklabels([fr'${tex.pow(10, int(t))}$' for t in yticks], **tick_style)
    ax.set_xlabel('Components', **xlabel_style)
    ax.set_ylabel('MSE Loss', **ylabel_style)
    
for ax in axs[3:4]:
    xticks = 1 + np.arange(10)
    # yticks = [1,2,3,4]
    yticks = np.linspace(0, 800, 5)
    ax.set_xbound(lower=xticks.min() - 0.5, upper=xticks.max() + 0.5)
    ax.set_ybound(lower=yticks.min(), upper=yticks.max())

    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    ax.set_xticklabels(['%d' % x for x in xticks], **tick_style)
    ax.set_yticklabels(['%d' % y for y in yticks], **tick_style)
    ax.set_xlabel('Components', **xlabel_style)
    ax.set_ylabel('Time (s)', **ylabel_style)

for ax in axs[1:]:
    ax.tick_params(
        axis='y', which='both', length=1, direction='in', pad=1
    )

axs[0].set_title(r'\textbf{Molecular Distance}', **title_style)
axs[1].set_title(r'\textbf{Accuracy}', **title_style)
axs[2].set_title(r'Accuracy', **title_style)
axs[3].set_title(r'Time to Accuracy', **title_style)

ax_zoomin = fig.make_axes(
    left=axs[2].left,
    right=axs[3].right,
    top=0,
    bottom=1,
    style='blank'
)
ax_zoomin.set_xlim([0, 1])
ax_zoomin.set_ylim([0, 1])
overtitle_style = dict(
    fontsize=8,
    va='top',
    ha='center'
)
ax_zoomin.text(
    0.5, 0.985,
    r'\textbf{Comparison of GD and SGD} (k$\leq$10)',
    **overtitle_style,
)
ax_zoomin.axhline(0.915, lw=0.5, color='k')

ax_canvas.bar([-1], [0], **bar_style, **svd_style, label='SVD')
ax_canvas.bar([-1], [0], **bar_style, **rbf_style, label=fr'RBF, ${tex.pow(10, 4)}$ GD steps')
ax_canvas.bar([-1], [0], **bar_style, **srbf_style, label=fr'RBF, ${tex.pow(10, 5)}$ SGD steps')
ax_canvas.legend(
    loc='upper center',
    bbox_to_anchor=(0.5, 0.10),
    ncol=3,
    frameon=False,
    fontsize=6
)

if __PRODUCTION__:
    fig.savefig(f'pgf/{__NAME__}.pgf', dpi=300)
else:
    fig.savefig(f'svg/{__NAME__}.svg', dpi=300)
plt.show()

16 92
32 176
64 227


In [15]:
!make -f Makefile.figures fig-"$__NAME__".pdf 2>&1 | tail -n 1

Successfully created fig-efficiency-grpah-kernel.pdf


END
---