In [1]:
import pandas as pd
import numpy as np

# Raw results table; the path is relative to the notebook's working directory.
csv_path = "small degs large widths.csv"
data = pd.read_csv(csv_path)


# Get the substring in quotes
def get_substring_in_brackets(s):
    start = s.find('(') + 1
    end = s.rfind(')')
    return float(s[start:end])

# Replace the logged string form of each loss column with its float value
# (val_loss first, then train_loss).
for loss_col in ("val_loss", "train_loss"):
    data[loss_col] = data[loss_col].map(get_substring_in_brackets)

print(data.columns)

Index(['deg', 'width', 'func', 'epoch', 'train_loss', 'val_loss', 'batch_size',
       'lr', 'n_samples', 'func_val_test', 'time_elapsed', 'backend',
       'top_eig', 'trace', 'stop_loss', 'ln_eps', 'ln', 'weight_norm', 'l',
       'd', 'f', 'h', 'dropout', 'wd'],
      dtype='object')


In [2]:
# Keep one row per (deg, width, func) run: the row with the largest epoch.
# idxmax returns index labels, which .loc then selects. .copy() makes
# filtered_data an independent frame: later cells add columns to it
# (actual_gen_gap, theo_gen_gap), and that should not trigger
# SettingWithCopyWarning on something pandas may treat as derived from `data`.
last_epoch_rows = data.groupby(["deg", "width", "func"])["epoch"].idxmax()
filtered_data = data.loc[last_epoch_rows].copy()
print(filtered_data.shape)


(100, 24)


In [None]:
import math
from math import sqrt

def L(omega, D_f, T):
    """Return sqrt(21 + 16(D_f+1)D_f^2 + 8(D_f+1) + 4 ln(T)^2 (D_f*omega + T + 1 - omega)).

    This value is used inside parameter_norm_term. Raises ValueError
    (from math.log) when T <= 0.
    """
    log_t = math.log(T)
    mixed = D_f*omega + (T+1-omega)
    radicand = 21 + 16*(D_f+1)*D_f**2 + 8 * (D_f+1) + 4 * log_t**2 * mixed
    return sqrt(radicand)

def G_p(sigma, omega, D_f, T, d):
    """Evaluate the G_p bound term: a sum of nine monomials in sigma, omega, D_f, d and ln(T).

    Every term except the first two log-free ones is scaled by ln(T), so
    T must be > 0 (math.log raises ValueError otherwise). Each product keeps
    the original factor order, so results are bit-identical.
    """
    rt2 = sqrt(2)
    df_sq = D_f**2
    rt_w = sqrt(omega)
    log_t = math.log(T)
    w_sq = omega**2
    rt_df = sqrt(D_f)
    # NOTE(review): the sixth term below also appears verbatim in H_p --
    # confirm against the derivation that both are intended.
    return (
        128 * rt2 * sigma * df_sq * rt_w * log_t
        + 192 * sigma * rt_w * df_sq
        + 128 * rt2 * df_sq * w_sq * log_t
        + 64 * rt2 * d * df_sq * w_sq * log_t
        + 128 * sigma * d * w_sq * log_t * D_f * rt_df
        + 1024 * sigma * D_f**3 * rt_df * omega * rt_w * log_t
        + 8 * sigma * rt_df * rt_w * log_t * d
        + 256 * sigma * df_sq * rt_df * w_sq * d
        + 32 * sigma * rt_df * w_sq * d ** 2
    )

def H_u(omega, D_f, T, d):
    """Evaluate the H_u bound term as a sum of square-root terms in omega, D_f, T and d.

    Terms are accumulated in the same grouping as the derivation (the second
    line's two terms are summed together before being added), so float results
    match exactly. Raises ValueError if any square-root argument is negative.
    """
    n_f = D_f + 1
    rt2 = math.sqrt(2)
    rt_t = math.sqrt(T)
    rt_12f = math.sqrt(12*n_f)
    head = 8*n_f + 6*n_f*math.sqrt(omega)
    pair = 2*math.sqrt(3*T*n_f) + math.sqrt(48*T*n_f*D_f)
    # NOTE(review): the next two terms are the same expression with
    # coefficients 1 and 2 -- confirm this is intended rather than a typo.
    return (
        (head + pair)
        + rt_t*rt_12f*D_f
        + 2 * rt_t * rt_12f * D_f
        + 8*rt2*math.sqrt(d*T)*D_f
        + 2*math.sqrt(d*T)*D_f
        + 8*rt2*d*D_f*math.sqrt(omega)
        + 6 * rt2*d*math.sqrt(d)*D_f
    )

def H_p(sigma, omega, D_f, T, d):
    """Evaluate the H_p bound term: seven monomials, each scaled by ln(T) or ln(T)^2.

    Consequently H_p is 0 when T == 1. T must be > 0 (math.log). Factor order
    inside each product is kept, so results are bit-identical.
    """
    rt2 = sqrt(2)
    df_sq = D_f**2
    rt_w = sqrt(omega)
    log_t = math.log(T)
    rt_df = sqrt(D_f)
    w_sq = omega**2
    d_sq = d**2
    rt_d = sqrt(d)
    log_sq = log_t**2
    return (
        768 * rt2 * sigma * df_sq * omega * rt_w * d * log_t
        + 256 * sigma * omega * rt2 * d_sq * D_f * rt_df * log_t
        + 1024 * sigma * D_f**3 * rt_df * omega * rt_w * log_t
        + 512 * sigma * df_sq * rt_df * omega * rt_w * d_sq * log_t
        + 1536 * sigma * D_f**3 * rt_df * w_sq * rt_w * rt_d * log_sq
        + 768 * sigma * df_sq * w_sq * rt_w * d * log_sq
        + 16 * sigma * w_sq * d_sq * rt_d * log_t
    )

def G_u(omega, D_f, d):
    """Return 4 + 4*omega*(2 + D_f + 32*D_f^2 + 32*D_f^3).

    `d` is accepted for signature symmetry with the other bound terms but is
    not used.
    """
    cubic = 2 + D_f + 32*D_f**2 + 32*D_f**3
    return 4 + 4*omega*cubic

def T_p(sigma, omega, D_f, T, d):
    """Return 128*sigma*D_f^2*omega*ln(T) + (sigma*sqrt(2d) + 32*sqrt(omega)*sigma*D_f^2).

    The last two terms are summed first and then added to the log term,
    matching the original accumulation order.
    """
    df_sq = D_f ** 2
    log_part = 128 * sigma * df_sq * omega * math.log(T)
    tail = sigma * math.sqrt(2 * d) + 32 * math.sqrt(omega) * sigma * df_sq
    return log_part + tail

def Theta(D_f, d):
    """Return (3(d+1)^2 + (d+1) + (8d+6)(D_f+1))^2."""
    d_plus = d + 1
    base = 3 * d_plus**2 + d_plus + (8*d + 6) * (D_f + 1)
    return base**2

def P(sigma, omega, D_f, T, d):
    """Return 2*G_p*(2*sqrt(G_u) + G_p) + T_p*(H_u + H_p)*Theta.

    G_p is evaluated once and reused. It is a pure function of its
    arguments, so the value is unchanged.
    """
    g_p = G_p(sigma, omega, D_f, T, d)
    g_u = G_u(omega, D_f, d)
    first = 2*g_p*(2*math.sqrt(g_u) + g_p)
    second = T_p(sigma, omega, D_f, T, d)*(H_u(omega, D_f, T, d) + H_p(sigma, omega, D_f, T, d)) * Theta(D_f, d)
    return first + second

def perturbed_sharpness_term(sigma, omega, D_f, T, d):
    """Return sigma^2 * (G_u + P): the sharpness part of the theoretical gap."""
    return sigma**2 * (G_u(omega, D_f, d) + P(sigma, omega, D_f, T, d))

def parameter_norm_term(Sigma, m, omega, D_f, T, sigma, delta, d):
    """Return 2*sqrt(Sigma^2/(2m) * (L(omega, D_f, T)/(2 sigma^2) + ln(1/delta))).

    `d` is accepted for signature symmetry but is not used. Raises
    ZeroDivisionError when m, sigma or delta is 0.
    """
    scale = Sigma**2 / (2*m)
    inner = L(omega, D_f, T)/(2*sigma**2) + math.log(1/delta)
    return 2 * sqrt(scale * inner)

def theoretical_gen_gap(sigma, omega, D_f, T, Sigma, m, delta, d):
    """Theoretical generalization-gap bound: sharpness term plus parameter-norm term."""
    return (perturbed_sharpness_term(sigma, omega, D_f, T, d)
            + parameter_norm_term(Sigma, m, omega, D_f, T, sigma, delta, d))


# Empirical generalization gap of each kept (max-epoch) row: val_loss - train_loss.
actual_gen_gap = filtered_data["val_loss"] - filtered_data["train_loss"]
filtered_data["actual_gen_gap"] = actual_gen_gap

# The theoretical gap depends on sigma, so it is computed per sigma in the plotting cell.


In [13]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import os

# Bound hyper-parameters shared by every theoretical panel. The original cell
# assigned m, T, Sigma and delta twice with identical values; once is enough.
T = 20
m = 10000
Sigma = 0.001
delta = 0.2
eps = 0.01
d = 8*math.log(T)/eps**2
# NOTE(review): D_f is not used by the plots below -- every theoretical point
# is evaluated with the run's own degree (D_f=row["deg"]). Kept in case a
# later cell reads it.
D_f = T + 2

sigmas = [0.005, 0.01, 0.05, 0.1]
# (row, col) of the theoretical panel for each entry of `sigmas`.
pos = [(1, 1), (1, 2), (2, 1), (2, 2)]

# 2x2 grid of theoretical panels on the left half, one tall empirical panel on the right.
fig = make_subplots(
    rows=2, cols=3,
    specs=[
        [{"type": "xy"}, {"type": "xy"}, {"rowspan": 2, "colspan": 1}],  # First row
        [{"type": "xy"}, {"type": "xy"}, None]                           # Second row
    ],
    column_widths=[0.25, 0.25, 0.5],  # Left half = 0.25+0.25, Right = 0.5
    row_heights=[0.5, 0.5],
    horizontal_spacing=0.10,
    vertical_spacing=0.10,
    subplot_titles=(f"σ = {sigmas[0]}", f"σ = {sigmas[1]}", "Empirical Generalization Gap", f"σ = {sigmas[2]}", f"σ = {sigmas[3]}")
)

unique_widths = filtered_data["width"].unique()
# One color per width, shared by the theoretical and empirical panels; indices
# wrap around (k % len(colors)) so more than four widths no longer raise IndexError.
colors = ["red", "brown", "green", "purple"]

for i, sigma in enumerate(sigmas):
    # The bound depends only on (width, deg) and the constants above, not on func.
    filtered_data["theo_gen_gap"] = filtered_data.apply(
        lambda row: theoretical_gen_gap(
            sigma=sigma,
            omega=row["width"],
            D_f=row["deg"],
            T=T,
            Sigma=Sigma,
            m=m,
            delta=delta,
            d=d
        ),
        axis=1)
    panel_row, panel_col = pos[i]
    for k, width in enumerate(unique_widths):
        # Rows for different funcs share the same theoretical value at a given
        # degree: keep one point per degree and sort so the line runs left to right.
        subset = (
            filtered_data.loc[filtered_data["width"] == width, ["deg", "theo_gen_gap"]]
            .drop_duplicates(subset="deg")
            .sort_values("deg")
        )
        fig.add_trace(
            go.Scatter(x=subset["deg"], y=subset["theo_gen_gap"], mode='lines+markers',
                       line=dict(color=colors[k % len(colors)]), showlegend=False),
            row=panel_row, col=panel_col
        )

# Empirical gap per width, averaged over func at each degree.
for k, width in enumerate(unique_widths):
    subset = filtered_data[filtered_data["width"] == width]
    avg_actual_gen_gap = subset.groupby("deg")["actual_gen_gap"].mean()
    fig.add_trace(
        go.Scatter(x=avg_actual_gen_gap.index, y=avg_actual_gen_gap.values, mode='lines+markers',
                   line=dict(color=colors[k % len(colors)]), name=f"Width={width}"),
        row=1, col=3
    )

fig.update_layout(
    title="Comparison of Empirical and Theoretical Generalization Gap by Degree, Width",
    height=750,
    width=1200,
    margin=dict(t=50, b=50, l=50, r=50)
)
fig.update_xaxes(title_text="Degree", row=1, col=3)
fig.update_xaxes(title_text="Degree", row=2, col=1)
fig.update_xaxes(title_text="Degree", row=2, col=2)
fig.update_yaxes(title_text="Generalization Gap", row=1, col=1)
fig.update_yaxes(title_text="Generalization Gap", row=1, col=3)
fig.update_yaxes(title_text="Generalization Gap", row=2, col=1)
fig.show()

# Writing into a missing directory may fail, so make sure plots/ exists first.
os.makedirs("plots", exist_ok=True)
# NOTE(review): "unpretrubed" looks like a typo for "unperturbed"; the path is
# left unchanged so anything referencing the existing file still finds it.
fig.write_image("plots/unpretrubed vary over sigma.png", width=1200, height=750)

