In [1]:
# Figure configuration. In production mode the figure is saved as PGF,
# otherwise as an SVG preview.
__PRODUCTION__ = 1
__NAME__       = 'optimization'  # base name of the saved figure file
__WIDTH__      = 5.5  # NeurIPS 2021 text box width
__HEIGHT__     = 2.15

# Only the mplmagic2 module matching the output mode is imported.
# NOTE(review): presumably these imports configure matplotlib for PGF/SVG
# output as a side effect (the names are otherwise unused) — confirm.
if __PRODUCTION__:
    from mplmagic2 import pgf
else:
    from mplmagic2 import svg

from collections import OrderedDict
from mplmagic2 import SuperFigure, tex
import matplotlib.pyplot as plt

print('This is how much space the figure will take up on letter paper')
# 8.5 is the US-letter page width (presumably inches, like __WIDTH__); half of
# the leftover width is used as the left margin, i.e. the figure is centered.
SuperFigure.size_hint(__WIDTH__, __HEIGHT__, margin_left=0.5 * (8.5 - __WIDTH__));

This is how much space the figure will take up on letter paper


In [2]:
import functools
import numpy as np
import colorsys
from scipy.ndimage import gaussian_filter1d
from scipy.spatial.distance import cdist
import torch
import torch.nn.functional as F
from torch import optim
import networkx as nx
import json
import pickle

from symfac.experimental import RBFExpansionV2

In [3]:
# Optimizer names exactly as they appear in the pickled result filenames.
algs_original = ['SGD', 'ASGD', 'RMSprop', 'Adam', 'AdamW', 'Adadelta', 'Adagrad']

# Names shown in the figure: ASGD is omitted and plain SGD is displayed as
# 'Vanilla'; all other optimizers keep their original name and order.
algs = ['Vanilla'] + [name for name in algs_original if name not in ('SGD', 'ASGD')]

In [4]:
# Load the over-parameterized optimizer runs, one pickle per algorithm.
# ASGD is excluded from the figure (so its file is not read at all) and SGD is
# relabelled 'Vanilla'. Dict insertion order follows algs_original; it matters
# downstream because it fixes plotting order and each series' marker offset.
# NOTE: pickle can execute arbitrary code on load — only use trusted files.
over_parameterized = {}
for alg in algs_original:
    if alg == 'ASGD':
        continue
    # `with` guarantees the file handle is closed after unpickling.
    with open(f'data/paper/over-optimization-methods-{alg}.pickle', 'rb') as fh:
        over_parameterized['Vanilla' if alg == 'SGD' else alg] = pickle.load(fh)

In [5]:
# Load the exactly-parameterized optimizer runs, one pickle per algorithm.
# ASGD is excluded from the figure (so its file is not read at all) and SGD is
# relabelled 'Vanilla'. Dict insertion order follows algs_original; it matters
# downstream because it fixes plotting order and each series' marker offset.
# NOTE: pickle can execute arbitrary code on load — only use trusted files.
even_parameterized = {}
for alg in algs_original:
    if alg == 'ASGD':
        continue
    # `with` guarantees the file handle is closed after unpickling.
    with open(f'data/paper/optimization-methods-{alg}.pickle', 'rb') as fh:
        even_parameterized['Vanilla' if alg == 'SGD' else alg] = pickle.load(fh)

In [6]:
# Final-loss samples from repeated starts, keyed first by '<k>-component' and
# then by initialization scale. Each pickle holds a sequence whose entry 0 is
# the scale-'1.0' sample and entry 1 the scale-'0.1' sample.
# Each file is opened once and closed via `with`.
# NOTE: pickle can execute arbitrary code on load — only use trusted files.
hist = {}
for n_components in (2, 4):
    with open(f'data/paper/hist_starts_{n_components}.pickle', 'rb') as fh:
        starts = pickle.load(fh)
    hist[f'{n_components}-component'] = {
        '1.0': starts[0],
        '0.1': starts[1],
    }

In [7]:
# Figure plus a background "canvas" axes covering the whole figure, drawn
# underneath everything (zorder -100) with [0, 1] limits on both axes.
fig = SuperFigure(figsize=(__WIDTH__, __HEIGHT__), dpi=300)
ax_canvas = fig.make_axes(
    left=0,
    right=1,
    top=0,
    bottom=1,
    zorder=-100,
    style=None if not __PRODUCTION__ else 'blank',
)
ax_canvas.set_xlim([0, 1])
ax_canvas.set_ylim([0, 1])

# Four square panels in a single row: panels advance by `dx`, and an extra
# `dx2` gap separates the first pair from the second pair.
x0 = 0.05    # left edge of the first panel
dx = 0.2275  # stride between consecutive panels
dx2 = 0.04   # additional gap after every two panels
w = 0.215    # panel width
axs = [
    fig.make_axes(
        left=x0 + dx * k + dx2 * (k // 2),
        width=w,
        top=0.19,
        width_to_height=1.0,
        style='modern',
    )
    for k in range(4)
]

# Legend layout shared by both legend axes: frameless, centered horizontally,
# anchored by its upper edge at 13.5% of the host axes height, 4 columns.
legend_style = {
    'loc': 'upper center',
    'fontsize': 6,
    'frameon': False,
    'ncol': 4,
    'bbox_to_anchor': (0.5, 0.135),
    'markerscale': 1.0,
    'numpoints': 2,
    'handlelength': 2.5,
    'labelspacing': 0.75,
    'columnspacing': 1.5,
}
tick_style = {'fontsize': 7}
xlabel_style = {'fontsize': 8, 'labelpad': 4.5}
ylabel_style = {'fontsize': 8, 'labelpad': -1.5}

# Line variants: thin solid lines, and round-capped dotted lines for the
# optimizers that share a color with a solid-line optimizer.
solid_style = {'lw': 0.75}
dash_style = {'lw': 1.0, 'ls': (2, (0.75, 2)), 'dash_capstyle': 'round'}

# Hollow (white-filled) scatter markers; `s` is matplotlib's scatter size.
marker_basestyle = {'s': 4.0, 'facecolor': 'w', 'linewidths': 0.5}

# Per-optimizer line styles; a higher zorder is drawn on top.
plot_style = OrderedDict([
    ('Vanilla',  {'color': '#7458AF', 'zorder': 9, **solid_style}),
    ('RMSprop',  {'color': '#FEB326', 'zorder': 7, **solid_style}),
    ('Adam',     {'color': '#E84D8A', 'zorder': 6, **solid_style}),
    ('AdamW',    {'color': '#E84D8A', 'zorder': 5, **dash_style}),
    ('Adadelta', {'color': '#00aaec', 'zorder': 4, **dash_style}),
    ('Adagrad',  {'color': '#00aaec', 'zorder': 3, **solid_style}),
])

# Per-optimizer marker shapes, layered above all of the lines.
scatter_style = OrderedDict([
    ('Vanilla',  {'marker': 'D', 'zorder': 19, **marker_basestyle}),
    ('RMSprop',  {'marker': '^', 'zorder': 17, **marker_basestyle}),
    ('Adam',     {'marker': 'v', 'zorder': 16, **marker_basestyle}),
    ('AdamW',    {'marker': 'P', 'zorder': 15, **marker_basestyle}),
    ('Adadelta', {'marker': 's', 'zorder': 14, **marker_basestyle}),
    ('Adagrad',  {'marker': 'o', 'zorder': 13, **marker_basestyle}),
])

def lcg(i):
    """Map an integer deterministically onto 0..6 via (3*i + 5) mod 7.

    Used as an exponent for a per-series scale factor on marker positions so
    that different optimizers' markers are offset from each other.
    """
    return (5 + 3 * i) % 7

# Panels A/B: best-so-far loss vs. iteration for every optimizer, on log-log
# axes, plus a sparse set of log-spaced markers (offset per series via lcg)
# to tell apart curves that share a color.
for i, (ax, data) in enumerate(zip(axs[:2], [even_parameterized, over_parameterized])):
    for j, (alg, (x, y, _)) in enumerate(data.items()):
        x = np.array(x)
        best_y = np.minimum.accumulate(y)
        idx = (
            np.logspace(1, np.log10(len(x)), 10) * 0.9**lcg(j)
        ).astype(np.int32)
        # When lcg(j) == 0 the last log-spaced index is ~len(x) and may round
        # to exactly len(x), one past the end; clip to stay in bounds.
        idx = np.minimum(idx, len(x) - 1)
        ax.plot(x, best_y, **plot_style[alg], label=alg)
        ax.scatter(
            x[idx], best_y[idx],
            color=plot_style[alg]['color'],
            **scatter_style[alg],
            label=alg
        )

    ax.set_xscale('log')
    ax.set_yscale('log')

    yticks = np.linspace(-15, 0, 4)  # exponents -15, -10, -5, 0
    xticks = np.linspace(0, 4, 5)    # exponents 0, 1, 2, 3, 4
    ax.set_xbound(lower=1, upper=10**xticks.max())
    ax.set_ybound(lower=10**yticks.min(), upper=10**yticks.max())
    ax.set_yticks(10**yticks)
    ax.set_xlabel('Iterations', **xlabel_style)

    # The left panel omits its last x tick (presumably so its label does not
    # collide with the first label of the neighbouring panel) and is the only
    # one carrying y tick labels and the y label.
    shown_xticks = xticks[:-1] if i == 0 else xticks
    ax.set_xticks(10**shown_xticks)
    ax.set_xticklabels([fr'${tex.pow(10, int(t))}$' for t in shown_xticks], **tick_style)
    if i == 0:
        ax.set_ylabel('Loss', **ylabel_style)
        ax.set_yticklabels(
            [fr'${tex.pow(10, int(t))}$' for t in yticks],
            **tick_style
        )
    else:
        ax.set_yticklabels([])
        
# Invisible axes spanning panels A and B that hosts their shared legend.
ax_alg = fig.make_axes(
    left=axs[0].left,
    right=axs[1].right,
    top=0,
    bottom=1,
    style='blank'
)

# Legend proxies: one off-screen Line2D per plotted optimizer, combining its
# line style with its marker style. Scatter keyword names are translated to
# their plot() equivalents. Iterates over_parameterized explicitly rather
# than a loop variable leaked from earlier code.
for alg in over_parameterized:
    style = {**plot_style[alg], **scatter_style[alg]}
    # NOTE(review): scatter's `s` is a marker area (points**2) while
    # `markersize` is a diameter (points), so legend markers come out larger
    # than the plotted ones; use np.sqrt(...) here if they should match.
    style['markersize'] = style.pop('s')
    style['markerfacecolor'] = style.pop('facecolor')
    style['markeredgewidth'] = style.pop('linewidths')
    ax_alg.plot([-1], [-1], **style, label=alg)

ax_alg.legend(**legend_style)

# --- Histogram panels (C/D) ---------------------------------------------

bar_style = {'alpha': 0.8}
# One color per initialization scale; keys match the inner dicts of `hist`.
hist_style = {
    '1.0': {'color': '#FF9070', **bar_style},
    '0.1': {'color': '#4B6480', **bar_style},
}
bins = np.linspace(0, 0.15, 50)

# Draw both scales on each panel and remember the tallest bin overall so
# that both panels can later share one y range.
nmax = 0
for ax, data in zip(axs[2:], hist.values()):
    for init_multiplier, best_losses in data.items():
        # assumes best_losses is a torch tensor (converted with .numpy())
        n, _, _ = ax.hist(best_losses.numpy(), bins, **hist_style[init_multiplier])
        nmax = max(nmax, n.max())

# Tick positions shared by both histogram panels: loss in [0, 0.15] and
# counts from 0 up to the tallest bin found while drawing.
xticks = np.linspace(0, 0.15, 4)
yticks = np.linspace(0, nmax, 4)
for i, ax in enumerate(axs[2:]):
    ax.set_xbound(lower=0, upper=xticks.max())
    ax.set_ybound(lower=0, upper=yticks.max())
    ax.set_yticks(yticks)
    ax.set_xlabel('Final Loss', **xlabel_style)
    # Left panel: drop its last x tick and carry the y tick labels / y label.
    visible_xticks = xticks[:-1] if i == 0 else xticks
    ax.set_xticks(visible_xticks)
    ax.set_xticklabels([f'{v:g}' for v in visible_xticks], **tick_style)
    if i != 0:
        ax.set_yticklabels([])
        continue
    ax.set_yticklabels([f'{v:.0f}' for v in yticks], **tick_style)
    ax.set_ylabel('Count', **ylabel_style)

# Invisible axes spanning panels C and D that hosts their shared legend.
ax_hist = fig.make_axes(
    left=axs[2].left,
    right=axs[3].right,
    top=0,
    bottom=1,
    style='blank'
)

# Legend proxies: one zero-height, off-screen bar per initialization scale.
# Iterating hist_style (keyed by scale) keeps the legend independent of loop
# variables left over from earlier code.
for init_multiplier in hist_style:
    ax_hist.bar(
        [-1], [0],
        **hist_style[init_multiplier],
        label=fr'$\mathcal{{N}}$(0, {init_multiplier})'
    )

ax_hist.legend(**legend_style)
    
# --- Titles, panel letters, and group headings --------------------------

title_style = dict(fontsize=8, y=0.975, va='bottom')
panel_number_style = dict(fontsize=8, va='top', ha='right')
panel_titles = [
    r'\textbf{Exactly Parameterized}',
    r'\textbf{Over-Parameterized}',
    r'\textbf{Exactly Parameterized}',
    r'\textbf{Over-Parameterized}',
]

# Each panel gets its title and a bold "(A)".."(D)" letter in its top-right
# corner (axes-relative coordinates).
for i, (ax, title) in enumerate(zip(axs, panel_titles)):
    ax.set_title(title, **title_style)
    ax.text(
        0.975, 0.975,
        s=fr'\textbf{{({chr(65 + i)})}}',
        transform=ax.transAxes,
        **panel_number_style,
    )

# A heading with a thin rule underneath it, above each pair of panels.
uppertitle_style = dict(fontsize=8, va='bottom', ha='center')
group_headings = [
    (ax_alg, r'\textbf{Optimization Algorithm}'),
    (ax_hist, r'\textbf{Effect of Initialization Distribution}'),
]
for ax, heading in group_headings:
    ax.text(0.5, 0.94, heading, **uppertitle_style)
    ax.axhline(0.91, lw=0.5, color='k')

# Production mode writes PGF; otherwise an SVG preview is written.
fig.savefig(
    f'pgf/{__NAME__}.pgf' if __PRODUCTION__ else f'svg/{__NAME__}.svg',
    dpi=300,
)

In [8]:
# Build the figure PDF through Makefile.figures (IPython expands $__NAME__);
# stderr is merged into stdout and only the last output line is shown.
!make -f Makefile.figures fig-"$__NAME__".pdf 2>&1 | tail -n 1

Successfully created fig-optimization.pdf


END
---