# Plot MLP results

In [1]:
%%capture
!pip install -U kaleido
!pip install plotly==5.24.1

In [2]:
import math
import numpy as np
import plotly.graph_objs as go
import plotly.colors as pc

In [3]:
def plot_accuracies_over_ts_and_lrs(accuracies, max_ts, save_path):
    """Plot one accuracy curve per learning rate against max T and save it.

    Args:
        accuracies: Mapping from learning rate to a sequence of accuracy
            values. Each sequence is expected to hold one value per entry of
            ``max_ts`` (assumption based on how the tick labels are built).
            Iteration order of the mapping decides legend order and colour.
        max_ts: Max T values; used only as the x-axis tick labels.
        save_path: Destination passed to ``fig.write_image``.

    Returns:
        None. The figure is written to ``save_path``.
    """
    n_lrs = len(accuracies)
    # The traces are added without explicit x values, so the k-th point of
    # every curve sits at x = k. One tick position per label in `max_ts`
    # keeps `tickvals` and `ticktext` the same length, and does not depend on
    # any particular learning rate being present in `accuracies`.
    xtickvals = list(range(len(max_ts)))

    fig = go.Figure()
    # Sample n_lrs + 2 colours and reverse the list; only the first n_lrs
    # entries of the reversed list are used below.
    colors = pc.sample_colorscale("Oranges", n_lrs+2)[::-1]
    for i, (lr, max_avg_acc) in enumerate(accuracies.items()):
        fig.add_trace(
            go.Scatter(
                y=max_avg_acc,
                # `$...$` makes the legend entry a LaTeX string: "lr = {<value>}"
                name=f"$lr = {{{lr}}}$",
                mode="lines+markers",
                line=dict(width=2, color=colors[i])
            )
        )

    fig.update_layout(
        height=350,
        width=550,
        xaxis=dict(
            title="Max T",
            tickvals=xtickvals,
            ticktext=max_ts,
        ),
        yaxis=dict(
            title="Max mean accuracy (%)",
            nticks=5
        ),
        font=dict(size=16),
        # extra right margin next to the plot area
        margin=dict(r=120)
    )
    fig.write_image(save_path)


def plot_accuracies_over_optims(accuracies, save_path, test_every=100):
    """Plot mean test accuracy with a +/- one std band per optimiser and save it.

    Args:
        accuracies: Mapping from optimiser id to an array whose last axis is
            averaged over (assumed layout: ``(n_evaluations, n_seeds)`` --
            confirm against the caller). All arrays must have the same length
            along the first axis because they share the x values.
        save_path: Destination passed to ``fig.write_image``.
        test_every: Factor applied to the x tick positions to obtain the tick
            labels (tick position k is labelled ``k * test_every``).

    Returns:
        None. The figure is written to ``save_path``.

    Raises:
        ValueError: If ``accuracies`` is empty.
    """
    if not accuracies:
        raise ValueError("accuracies must contain at least one optimiser")

    # Take the x extent from the first entry instead of a hard-coded "Euler"
    # key, so the function also works for other sets of optimisers.
    n_train_iters = len(next(iter(accuracies.values())))
    train_iters = list(range(1, n_train_iters + 1))

    # NOTE(review): "#636EAF" is close to Plotly's default blue "#636EFA" --
    # confirm this shade is intended. Kept unchanged here.
    colors = ["#636EAF", "#EF553B", "#00CC96"]
    fig = go.Figure()
    for i, (optim_id, accuracy) in enumerate(accuracies.items()):
        # cycle through the palette so more than three optimisers do not fail
        color = colors[i % len(colors)]
        means = accuracy.mean(axis=-1)
        stds = accuracy.std(axis=-1)
        y_upper, y_lower = means + stds, means - stds

        # Band: walk the upper edge left-to-right, then the lower edge
        # right-to-left, and fill the closed outline.
        fig.add_trace(
            go.Scatter(
                x=list(train_iters) + list(train_iters[::-1]),
                y=list(y_upper) + list(y_lower[::-1]),
                fill="toself",
                fillcolor=color,
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False,
                opacity=0.3
            )
        )
        # Mean curve; the "SGD" id is shown as "GD" in the legend.
        fig.add_trace(
            go.Scatter(
                x=train_iters,
                y=means,
                mode="lines+markers",
                name=optim_id if optim_id != "SGD" else "GD",
                line=dict(width=2, color=color)
            )
        )

    # ticks at the first, middle and last x position
    tickvals = [1, int(train_iters[-1]/2)+1, train_iters[-1]]
    fig.update_layout(
        height=300,
        width=400,
        xaxis=dict(
            title="Training iteration",
            tickvals=tickvals,
            ticktext=[tickval*test_every for tickval in tickvals]
        ),
        yaxis=dict(title="Test accuracy (%)"),
        font=dict(size=16)
    )
    fig.write_image(save_path)


def plot_runtimes_over_optims(runtimes, save_path):
    """Plot mean runtime with a +/- one std band per optimiser and save it.

    Args:
        runtimes: Mapping from optimiser id to an array whose last axis is
            averaged over (assumed layout: ``(n_iterations, n_seeds)`` --
            confirm against the caller). All arrays must have the same length
            along the first axis because they share the x values.
        save_path: Destination passed to ``fig.write_image``.

    Returns:
        None. The figure is written to ``save_path``.

    Raises:
        ValueError: If ``runtimes`` is empty.
    """
    if not runtimes:
        raise ValueError("runtimes must contain at least one optimiser")

    # Take the x extent from the first entry instead of a hard-coded "Euler"
    # key, so the function also works for other sets of optimisers.
    n_train_iters = len(next(iter(runtimes.values())))
    train_iters = list(range(1, n_train_iters + 1))

    # NOTE(review): "#636EAF" is close to Plotly's default blue "#636EFA" --
    # confirm this shade is intended. Kept unchanged here.
    colors = ["#636EAF", "#EF553B", "#00CC96"]
    fig = go.Figure()
    for i, (optim_id, runtime) in enumerate(runtimes.items()):
        # cycle through the palette so more than three optimisers do not fail
        color = colors[i % len(colors)]
        means = runtime.mean(axis=-1)
        stds = runtime.std(axis=-1)
        y_upper, y_lower = means + stds, means - stds

        # Band: walk the upper edge left-to-right, then the lower edge
        # right-to-left, and fill the closed outline.
        fig.add_trace(
            go.Scatter(
                x=list(train_iters) + list(train_iters[::-1]),
                y=list(y_upper) + list(y_lower[::-1]),
                fill="toself",
                fillcolor=color,
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False,
                opacity=0.3
            )
        )
        # Mean curve; the "SGD" id is shown as "GD" in the legend.
        fig.add_trace(
            go.Scatter(
                x=train_iters,
                y=means,
                mode="lines",
                name=optim_id if optim_id != "SGD" else "GD",
                line=dict(width=2, color=color)
            )
        )

    # ticks at the first, middle and last x position
    tickvals = [1, int(train_iters[-1]/2), train_iters[-1]]
    fig.update_layout(
        height=300,
        width=400,
        xaxis=dict(
            title="Training iteration",
            tickvals=tickvals,
            # Labels are the tick positions shifted by one -- presumably
            # because the caller drops the first measurement before plotting;
            # confirm against the caller.
            ticktext=[tickval+1 for tickval in tickvals]
        ),
        yaxis=dict(title="Runtime (ms)"),
        font=dict(size=16)
    )
    fig.write_image(save_path)


In [None]:
### plot max avg test acc as a function of T and lr, for each optim ###
# Experiment grid. These names are reused by the runtime cell further down.
DATASETS = ["MNIST", "Fashion-MNIST"]
N_HIDDENS = [3, 5]
ACTIVITY_OPTIMS_ID = ["Euler", "Heun", "SGD"]

MAX_T1S = [5, 10, 20, 50, 100, 200, 500]
ACTIVITY_LRS = [5e-1, 1e-1, 5e-2]

N_SEEDS = 3

for dataset in DATASETS:
    for n_hidden in N_HIDDENS:
        # every file of one (dataset, depth) pair lives below this directory
        results_dir = f"mlp_results/{dataset}/width_300/{n_hidden}_n_hidden/tanh"

        for activity_optim_id in ACTIVITY_OPTIMS_ID:
            # the largest max T is left out for Euler; all other optimisers use the full list
            max_t1s = MAX_T1S if activity_optim_id != "Euler" else MAX_T1S[:-1]

            # learning rate -> best seed-averaged test accuracy for each max T
            max_test_accs = {}
            for activity_lr in ACTIVITY_LRS:
                best_acc_per_t = []
                for max_t1 in max_t1s:
                    run_dir = f"{results_dir}/max_t1_{max_t1}/activity_lr_{activity_lr}/param_lr_0.001/{activity_optim_id}"

                    # element-wise sum of the per-seed curves, then divide by the seed count
                    avg_test_acc = 0.
                    for seed in range(N_SEEDS):
                        test_acc = np.load(f"{run_dir}/{seed}/test_accs.npy")
                        avg_test_acc = avg_test_acc + test_acc
                    avg_test_acc = avg_test_acc / N_SEEDS

                    # keep only the peak of the averaged curve
                    best_acc_per_t.append(max(avg_test_acc))
                max_test_accs[activity_lr] = best_acc_per_t

            plot_accuracies_over_ts_and_lrs(
                accuracies=max_test_accs,
                max_ts=max_t1s,
                save_path=f"{results_dir}/max_avg_test_acc_{activity_optim_id}.pdf"
            )


In [45]:
### plot inference runtimes for best accuracies of different optimisers ###
BATCH_SIZE = 64
TEST_EVERY = 100

# Number of runtime points kept per run: one fewer than 60000 / BATCH_SIZE
# (60000 is presumably the training-set size -- confirm), matching the
# `[1:]` slice applied to the loaded runtimes below.
n_train_iters = math.floor(60000/BATCH_SIZE)-1
n_tests = math.floor(n_train_iters/TEST_EVERY)

for dataset in DATASETS:
    print(f"\n{dataset}")
    for n_hidden in N_HIDDENS:
        print(f"\nH = {n_hidden}")
        # every file of one (dataset, depth) pair lives below this directory
        results_dir = f"mlp_results/{dataset}/width_300/{n_hidden}_n_hidden/tanh"

        test_accs, inference_runtimes = {}, {}
        for activity_optim_id in ACTIVITY_OPTIMS_ID:

            # Hard-coded (max T, activity lr) per dataset / depth / optimiser
            # (the "best" settings according to the cell header).
            if n_hidden == 3:
                best_t = 10 if activity_optim_id == "SGD" else 20
                if dataset == "MNIST":
                    best_lr = 0.05 if activity_optim_id == "Heun" else 0.5
                else:
                    best_lr = 0.1 if activity_optim_id == "Heun" else 0.5
            elif n_hidden == 5:
                if dataset == "MNIST":
                    best_t = 50
                    best_lr = 0.05 if activity_optim_id == "Heun" else 0.5
                else:
                    best_t = 200
                    best_lr = 0.1 if activity_optim_id == "SGD" else 0.5
            else:
                # Fail loudly instead of silently reusing the values left over
                # from the previous loop iteration.
                raise ValueError(
                    f"no best hyperparameters recorded for n_hidden={n_hidden}"
                )

            run_dir = f"{results_dir}/max_t1_{best_t}/activity_lr_{best_lr}/param_lr_0.001/{activity_optim_id}"

            # one column per seed
            test_accs[activity_optim_id] = np.zeros((n_tests, N_SEEDS))
            inference_runtimes[activity_optim_id] = np.zeros((n_train_iters, N_SEEDS))
            for seed in range(N_SEEDS):
                test_acc = np.load(f"{run_dir}/{seed}/test_accs.npy")
                inference_runtime = np.load(f"{run_dir}/{seed}/inference_runtimes.npy")
                test_accs[activity_optim_id][:, seed] = test_acc
                # skip first point for jit compilation
                inference_runtimes[activity_optim_id][:, seed] = inference_runtime[1:]

            # mean over all kept iterations and all seeds
            best_avg_runtime = inference_runtimes[activity_optim_id].mean()
            print(f"\t{activity_optim_id}, best avg runtime (ms): {best_avg_runtime}")

        # Pass TEST_EVERY explicitly so the x tick labels follow the constant
        # defined in this cell rather than the function's default.
        plot_accuracies_over_optims(
            test_accs,
            f"{results_dir}/best_mean_test_accs.pdf",
            test_every=TEST_EVERY
        )
        plot_runtimes_over_optims(
            inference_runtimes,
            f"{results_dir}/best_mean_infer_runtimes.pdf"
        )


MNIST

H = 3
	Euler, best avg runtime (ms): 4.516731215338422
	Heun, best avg runtime (ms): 2.5090177853902182
	SGD, best avg runtime (ms): 7.38429997721289

H = 5
	Euler, best avg runtime (ms): 12.866245460646105
	Heun, best avg runtime (ms): 4.310250027566894
	SGD, best avg runtime (ms): 38.22091179355937

Fashion-MNIST

H = 3
	Euler, best avg runtime (ms): 4.40254748037398
	Heun, best avg runtime (ms): 2.660179579699481
	SGD, best avg runtime (ms): 7.030999983138169

H = 5
	Euler, best avg runtime (ms): 39.98467971456696
	Heun, best avg runtime (ms): 16.335261003923552
	SGD, best avg runtime (ms): 154.62818852177372
