# Plot MLP results

In [1]:
%%capture
!pip install -U kaleido
!pip install plotly==5.24.1

In [2]:
import math
import numpy as np
import plotly.graph_objs as go
import plotly.colors as pc

In [3]:
def plot_accuracies_over_ts_and_dts(accuracies, max_ts, save_path):
    """Draw one accuracy curve per step size and write the figure to disk.

    Args:
        accuracies: Mapping from a step size ``dt`` to a sequence of accuracy
            values. Each entry becomes one trace labelled ``dt = <key>``; the
            values are plotted against their position in the sequence.
        max_ts: Tick labels for the x-axis; label ``k`` is placed at position
            ``k``, i.e. under the ``k``-th value of every trace.
        save_path: Destination handed to ``fig.write_image``.

    Returns:
        None. The figure is only written to ``save_path``.
    """
    n_dts = len(accuracies)
    # One tick position per label. Deriving the positions from `max_ts` keeps
    # tickvals and ticktext the same length and avoids depending on a
    # particular step size (previously the key 0.5) being present.
    xtickvals = list(range(len(max_ts)))

    fig = go.Figure()
    # Sample two more shades than needed and reverse the list: traces consume
    # the first `n_dts` entries, so the last two samples are never used.
    colors = pc.sample_colorscale("Oranges", n_dts+2)[::-1]
    for i, (dt, max_avg_acc) in enumerate(accuracies.items()):
        fig.add_trace(
            go.Scatter(
                y=max_avg_acc,
                name=f"$dt = {{{dt}}}$",
                mode="lines+markers",
                line=dict(width=2, color=colors[i])
            )
        )

    fig.update_layout(
        height=350,
        width=550,
        xaxis=dict(
            title="Max T",
            tickvals=xtickvals,
            ticktext=max_ts,
        ),
        yaxis=dict(
            title="Max mean accuracy (%)",
            nticks=5
        ),
        font=dict(size=16),
        # extra right margin; presumably room for the LaTeX legend entries - confirm
        margin=dict(r=120)
    )
    fig.write_image(save_path)


def plot_accuracies_over_optims(accuracies, save_path, test_every=100):
    """Plot mean +/- one std of test accuracy per optimiser and save the figure.

    Args:
        accuracies: Mapping from optimiser name to an array supporting
            ``.mean(axis=-1)`` / ``.std(axis=-1)``. The last axis is reduced
            (the notebook passes arrays of shape ``(n_tests, n_seeds)``).
        save_path: Destination handed to ``fig.write_image``.
        test_every: Multiplier turning a tick position into the x-axis label
            shown for it (label = position * ``test_every``).

    Raises:
        ValueError: If ``accuracies`` is empty.
    """
    if not accuracies:
        raise ValueError("`accuracies` must contain at least one optimiser")

    # Fixed colours for the two optimisers used in this notebook.
    # NOTE(review): Plotly's default blue is "#636EFA"; confirm "#636EAF" is intended.
    known_colors = {"Euler": "#636EAF", "Heun": "#EF553B"}
    # Any other optimiser gets a colour from Plotly's qualitative palette
    # (minus the ones already taken) instead of hitting an unbound `color` or
    # silently inheriting the colour of the previous trace.
    spare_colors = [c for c in pc.qualitative.Plotly if c not in known_colors.values()]

    # The x-axis length is taken from "Heun" when present (as before),
    # otherwise from the first series.
    reference = accuracies.get("Heun", next(iter(accuracies.values())))
    n_train_iters = len(reference)
    train_iters = list(range(1, n_train_iters + 1))

    fig = go.Figure()
    for i, (optim_id, accuracy) in enumerate(accuracies.items()):
        color = known_colors.get(optim_id, spare_colors[i % len(spare_colors)])

        means = accuracy.mean(axis=-1)
        stds = accuracy.std(axis=-1)
        y_upper, y_lower = means + stds, means - stds

        # Shaded +/- std band: trace the upper edge left-to-right, then the
        # lower edge right-to-left, and fill the enclosed polygon.
        fig.add_trace(
            go.Scatter(
                x=train_iters + train_iters[::-1],
                y=list(y_upper) + list(y_lower[::-1]),
                fill="toself",
                fillcolor=color,
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False,
                opacity=0.3
            )
        )
        fig.add_trace(
            go.Scatter(
                x=train_iters,
                y=means,
                mode="lines+markers",
                name=optim_id,
                line=dict(width=2, color=color),
                legendrank=1 if optim_id == "Heun" else 0
            )
        )

    # Three ticks: first, middle and last evaluation point.
    tickvals = [1, int(train_iters[-1]/2)+1, train_iters[-1]]
    fig.update_layout(
        height=300,
        width=400,
        xaxis=dict(
            title="Training iteration",
            tickvals=tickvals,
            ticktext=[tickval*test_every for tickval in tickvals]
        ),
        yaxis=dict(title="Test accuracy (%)"),
        font=dict(size=16)
    )
    fig.write_image(save_path)


def plot_runtimes_over_optims(runtimes, save_path):
    """Plot mean +/- one std of runtime per optimiser and save the figure.

    Args:
        runtimes: Mapping from optimiser name to an array supporting
            ``.mean(axis=-1)`` / ``.std(axis=-1)``. The last axis is reduced
            (the notebook passes arrays of shape ``(n_train_iters, n_seeds)``).
        save_path: Destination handed to ``fig.write_image``.

    Raises:
        ValueError: If ``runtimes`` is empty.
    """
    if not runtimes:
        raise ValueError("`runtimes` must contain at least one optimiser")

    # Fixed colours for the two optimisers used in this notebook.
    # NOTE(review): Plotly's default blue is "#636EFA"; confirm "#636EAF" is intended.
    known_colors = {"Euler": "#636EAF", "Heun": "#EF553B"}
    # Any other optimiser gets a colour from Plotly's qualitative palette
    # (minus the ones already taken) instead of hitting an unbound `color` or
    # silently inheriting the colour of the previous trace.
    spare_colors = [c for c in pc.qualitative.Plotly if c not in known_colors.values()]

    # The x-axis length is taken from "Heun" when present (as before),
    # otherwise from the first series.
    reference = runtimes.get("Heun", next(iter(runtimes.values())))
    n_train_iters = len(reference)
    train_iters = list(range(1, n_train_iters + 1))

    fig = go.Figure()
    for i, (optim_id, runtime) in enumerate(runtimes.items()):
        color = known_colors.get(optim_id, spare_colors[i % len(spare_colors)])

        means = runtime.mean(axis=-1)
        stds = runtime.std(axis=-1)
        y_upper, y_lower = means + stds, means - stds

        # Shaded +/- std band: trace the upper edge left-to-right, then the
        # lower edge right-to-left, and fill the enclosed polygon.
        fig.add_trace(
            go.Scatter(
                x=train_iters + train_iters[::-1],
                y=list(y_upper) + list(y_lower[::-1]),
                fill="toself",
                fillcolor=color,
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False,
                opacity=0.3
            )
        )
        fig.add_trace(
            go.Scatter(
                x=train_iters,
                y=means,
                mode="lines",
                name=optim_id,
                line=dict(width=2, color=color),
                legendrank=1 if optim_id == "Heun" else 0
            )
        )

    # Three ticks: first, middle and last point.
    tickvals = [1, int(train_iters[-1]/2), train_iters[-1]]
    fig.update_layout(
        height=300,
        width=400,
        xaxis=dict(
            title="Training iteration",
            tickvals=tickvals,
            # Labels are the positions shifted by one; presumably because the
            # caller drops the first (JIT compilation) measurement - confirm.
            ticktext=[tickval+1 for tickval in tickvals]
        ),
        yaxis=dict(title="Runtime (ms)"),
        font=dict(size=16)
    )
    fig.write_image(save_path)


In [4]:
### plot max avg test acc as a function of T and dt, for each optim ###
# Sweep grid. These constants are also read by the runtime cell further down.
DATASETS = ["MNIST", "Fashion-MNIST", "CIFAR10"]
N_HIDDENS = [3, 5, 10]
ACTIVITY_OPTIMS_ID = ["Euler", "Heun"]  # "SGD" - alternative (slower) Euler implementation

MAX_T1S = [5, 10, 20, 50, 100, 200, 500]
ACTIVITY_LRS = [5e-1, 1e-1, 5e-2]

N_SEEDS = 3

for dataset in DATASETS:
    for n_hidden in N_HIDDENS:
        # Directory shared by every run (and the output figure) of this
        # dataset / depth combination.
        results_dir = f"mlp_results/{dataset}/width_300/{n_hidden}_n_hidden/tanh"

        for activity_optim_id in ACTIVITY_OPTIMS_ID:
            # Euler is evaluated on every budget except the largest one.
            if activity_optim_id == "Euler":
                max_t1s = MAX_T1S[:-1]
            else:
                max_t1s = MAX_T1S

            max_test_accs = {}
            for activity_lr in ACTIVITY_LRS:
                best_per_t = []
                for max_t1 in max_t1s:
                    run_dir = f"{results_dir}/max_t1_{max_t1}/activity_lr_{activity_lr}/param_lr_0.001/{activity_optim_id}"

                    # Average the accuracy curves over seeds, then keep the
                    # peak of the averaged curve.
                    seed_total = 0.
                    for seed in range(N_SEEDS):
                        seed_total = seed_total + np.load(f"{run_dir}/{seed}/test_accs.npy")
                    best_per_t.append(max(seed_total / N_SEEDS))

                max_test_accs[activity_lr] = best_per_t

            plot_accuracies_over_ts_and_dts(
                accuracies=max_test_accs,
                max_ts=max_t1s,
                save_path=f"{results_dir}/max_avg_test_acc_{activity_optim_id}.pdf"
            )


In [6]:
### plot inference runtimes for best accuracies of different optimisers ###
BATCH_SIZE = 64
TEST_EVERY = 100
n_train_mnist = 60000
n_train_cifar = 50000

# Best configuration per (n_hidden, dataset, optimiser) -> (max_t1, activity_lr).
# Presumably read off the sweep plots produced above - confirm before changing.
# An explicit table makes a missing combination fail with a KeyError instead of
# silently reusing the values left over from a previous loop iteration.
BEST_HYPERPARAMS = {
    (3, "MNIST", "Euler"): (20, 0.5),
    (3, "MNIST", "Heun"): (20, 0.05),
    (3, "Fashion-MNIST", "Euler"): (20, 0.5),
    (3, "Fashion-MNIST", "Heun"): (20, 0.1),
    (3, "CIFAR10", "Euler"): (50, 0.05),
    (3, "CIFAR10", "Heun"): (50, 0.05),
    (5, "MNIST", "Euler"): (50, 0.5),
    (5, "MNIST", "Heun"): (50, 0.05),
    (5, "Fashion-MNIST", "Euler"): (200, 0.5),
    (5, "Fashion-MNIST", "Heun"): (200, 0.5),
    (5, "CIFAR10", "Euler"): (200, 0.05),
    (5, "CIFAR10", "Heun"): (500, 0.5),
    (10, "MNIST", "Euler"): (200, 0.05),
    (10, "MNIST", "Heun"): (200, 0.05),
    (10, "Fashion-MNIST", "Euler"): (200, 0.05),
    (10, "Fashion-MNIST", "Heun"): (200, 0.05),
    (10, "CIFAR10", "Euler"): (200, 0.1),
    (10, "CIFAR10", "Heun"): (500, 0.1),
}

for dataset in DATASETS:
    print(f"\n{dataset}")

    # These sizes depend only on the dataset, so compute them once per dataset.
    n_train_samples = n_train_cifar if dataset == "CIFAR10" else n_train_mnist
    # One fewer than the number of full batches: this matches the runtime
    # array once its first entry is dropped below.
    n_train_iters = math.floor(n_train_samples/BATCH_SIZE)-1
    n_tests = math.floor(n_train_iters/TEST_EVERY)

    for n_hidden in N_HIDDENS:
        print(f"\nH = {n_hidden}")
        results_dir = f"mlp_results/{dataset}/width_300/{n_hidden}_n_hidden/tanh"

        test_accs, inference_runtimes = {}, {}
        for activity_optim_id in ACTIVITY_OPTIMS_ID:
            best_t, best_lr = BEST_HYPERPARAMS[(n_hidden, dataset, activity_optim_id)]
            run_dir = f"{results_dir}/max_t1_{best_t}/activity_lr_{best_lr}/param_lr_0.001/{activity_optim_id}"

            # One column per seed.
            test_accs[activity_optim_id] = np.zeros((n_tests, N_SEEDS))
            inference_runtimes[activity_optim_id] = np.zeros((n_train_iters, N_SEEDS))
            for seed in range(N_SEEDS):
                test_acc = np.load(f"{run_dir}/{seed}/test_accs.npy")
                inference_runtime = np.load(f"{run_dir}/{seed}/inference_runtimes.npy")

                test_accs[activity_optim_id][:, seed] = test_acc
                # skip first point for jit compilation
                inference_runtimes[activity_optim_id][:, seed] = inference_runtime[1:]

            # Mean over all iterations and seeds.
            best_avg_runtime = inference_runtimes[activity_optim_id].mean()
            print(f"\t{activity_optim_id}, best avg runtime (ms): {best_avg_runtime}")

        plot_accuracies_over_optims(
            test_accs,
            f"{results_dir}/best_mean_test_accs_Euler_vs_Heun.pdf"
        )
        plot_runtimes_over_optims(
            inference_runtimes,
            f"{results_dir}/best_mean_infer_runtimes_Euler_vs_Heun.pdf"
        )


MNIST

H = 3
	Euler, best avg runtime (ms): 4.516731215338422
	Heun, best avg runtime (ms): 2.5090177853902182

H = 5
	Euler, best avg runtime (ms): 12.866245460646105
	Heun, best avg runtime (ms): 4.310250027566894

H = 10
	Euler, best avg runtime (ms): 710.075849226737
	Heun, best avg runtime (ms): 49.049995478741465

Fashion-MNIST

H = 3
	Euler, best avg runtime (ms): 4.40254748037398
	Heun, best avg runtime (ms): 2.660179579699481

H = 5
	Euler, best avg runtime (ms): 39.98467971456696
	Heun, best avg runtime (ms): 16.335261003923552

H = 10
	Euler, best avg runtime (ms): 707.1279722061591
	Heun, best avg runtime (ms): 58.709163611431066

CIFAR10

H = 3
	Euler, best avg runtime (ms): 78.07821198406383
	Heun, best avg runtime (ms): 3.3708108796013727

H = 5
	Euler, best avg runtime (ms): 430.4416865365118
	Heun, best avg runtime (ms): 25.713603313152607

H = 10
	Euler, best avg runtime (ms): 365.9962234334049
	Heun, best avg runtime (ms): 195.4227184637999
