In [1]:
import pandas as pd
import numpy as np

# CSV export of the training runs; its columns are printed at the end of this cell.
RESULTS_CSV = "small degs large widths.csv"
data = pd.read_csv(RESULTS_CSV)


# Parse the number wrapped in parentheses (e.g. "tensor(0.1234)") into a float.
def get_substring_in_brackets(s):
    """Return the value inside the outermost parentheses of ``s`` as a float.

    The loss columns are stored as text with the number wrapped in
    parentheses (presumably torch tensor reprs such as ``"tensor(0.1234)"`` --
    confirm against the CSV).

    Args:
        s: A logged value. For ``"tensor(0.1234)"`` or
            ``"tensor(0.1234, device='cuda:0')"`` the first comma-separated
            field inside the brackets is parsed. Strings with no bracketed
            payload (``"0.1234"``) and non-string values (cells pandas already
            parsed as numbers, NaN) are passed to ``float`` directly.

    Returns:
        The parsed value as a ``float``.

    Raises:
        ValueError: if the extracted text is not a valid number.
    """
    if not isinstance(s, str):
        # pandas may already have parsed the cell as a number (or NaN).
        return float(s)
    start = s.find('(')
    end = s.rfind(')')
    if start == -1 or end <= start:
        # No bracketed payload. Slicing with the -1 sentinels would silently
        # drop the last character here ("0.5" -> "0." -> 0.0).
        return float(s)
    inner = s[start + 1:end]
    # Ignore trailing tensor metadata such as ", device=..." or ", grad_fn=...".
    return float(inner.split(',')[0])

# Replace the bracket-wrapped loss strings in both loss columns with floats.
for loss_col in ("val_loss", "train_loss"):
    data[loss_col] = data[loss_col].map(get_substring_in_brackets)

print(data.columns)

Index(['deg', 'width', 'func', 'epoch', 'train_loss', 'val_loss', 'batch_size',
       'lr', 'n_samples', 'func_val_test', 'time_elapsed', 'backend',
       'top_eig', 'trace', 'stop_loss', 'ln_eps', 'ln', 'weight_norm', 'l',
       'd', 'f', 'h', 'dropout', 'wd'],
      dtype='object')


In [2]:
# For each (deg, width, func) group keep the row with the highest epoch
# (idxmax picks the first such row if the maximum is tied).
last_epoch_idx = data.groupby(["deg", "width", "func"])["epoch"].idxmax()
filtered_data = data.loc[last_epoch_idx]
print(filtered_data.shape)


(100, 24)


In [3]:
import math
from math import sqrt

def L(omega, D_f, T):
    """Square root of a closed-form quantity in omega, D_f and T.

    Computes ``sqrt(21 + 16(D_f+1)D_f^2 + 8(D_f+1)
    + 4 ln(T)^2 (D_f*omega + (T + 1 - omega)))``. It is used as the numerator
    of the sigma-dependent part of ``parameter_norm_term``.

    Raises:
        ValueError: (math domain error) if T <= 0 or the radicand is negative.
    """
    log_T = math.log(T)
    radicand = 21 + 16*(D_f+1)*D_f**2
    radicand = radicand + 8 * (D_f+1)
    radicand = radicand + 4 * log_T**2 * (D_f*omega + (T+1-omega))
    return sqrt(radicand)

def G_p(sigma, omega, D_f, T, d):
    """Sum of nine closed-form terms in sigma, omega, D_f, T and d.

    Every term except the third and fourth is proportional to ``sigma``, and
    every term carrying ``ln(T)`` vanishes at ``T == 1``. ``P`` uses the result
    in its first summand.
    """
    log_T = math.log(T)
    root2 = sqrt(2)
    # Terms are listed in the original order and accumulated left to right,
    # so floating-point results are unchanged.
    terms = (
        128 * root2 * sigma * D_f**2 * sqrt(omega) * log_T,
        192 * sigma * sqrt(omega) * D_f**2,
        128 * root2 * D_f**2 * omega**2 * log_T,
        64 * root2 * d * D_f**2 * omega**2 * log_T,
        128 * sigma * d * omega**2 * log_T * D_f * sqrt(D_f),
        1024 * sigma * D_f**3 * sqrt(D_f) * omega * sqrt(omega) * log_T,
        8 * sigma * sqrt(D_f) * sqrt(omega) * log_T * d,
        256 * sigma * D_f**2 * sqrt(D_f) * omega**2 * d,
        32 * sigma * sqrt(D_f) * omega**2 * d ** 2,
    )
    total = terms[0]
    for term in terms[1:]:
        total += term
    return total

def H_u(omega, D_f, T, d):
    """Sum of closed-form terms in omega, D_f, T and d (no sigma dependence).

    ``P`` uses the result, added to ``H_p``.

    Raises:
        ValueError: (math domain error) if any square-root argument is negative.
    """
    df1 = D_f + 1
    root_T = math.sqrt(T)
    root_dT = math.sqrt(d*T)
    root_12df1 = math.sqrt(12*df1)
    # Partial sums are grouped and ordered exactly as in the original so the
    # floating-point result is identical.
    partials = [
        8*df1 + 6*df1*math.sqrt(omega),
        2*math.sqrt(3*T*df1) + math.sqrt(48*T*df1*D_f),
        # NOTE(review): this term and the next are the same expression with
        # coefficients 1 and 2 (3x in total) -- confirm against the derivation.
        root_T*root_12df1*D_f,
        2*root_T*root_12df1*D_f,
        8*math.sqrt(2)*root_dT*D_f,
        2*root_dT*D_f,
        8*math.sqrt(2)*d*D_f*math.sqrt(omega),
        6*math.sqrt(2)*d*math.sqrt(d)*D_f,
    ]
    result = partials[0]
    for partial in partials[1:]:
        result += partial
    return result

def H_p(sigma, omega, D_f, T, d):
    """Sum of seven sigma-proportional terms, each carrying ln(T) or ln(T)^2.

    The result is therefore 0 when ``sigma == 0`` or ``T == 1``. ``P`` uses it,
    added to ``H_u``.
    """
    ln_T = math.log(T)
    ln_T_sq = ln_T**2
    # A left-to-right chain of `+`, matching the original sequential `+=`.
    return (
        768 * sqrt(2) * sigma * D_f**2 * omega * sqrt(omega) * d * ln_T
        + 256 * sigma * omega * sqrt(omega) * d**2 * D_f * sqrt(D_f) * ln_T
        + 1024 * sigma * D_f**3 * sqrt(D_f) * omega * sqrt(omega) * ln_T
        + 512 * sigma * D_f**2 * sqrt(D_f) * omega * sqrt(omega) * d**2 * ln_T
        + 1536 * sigma * D_f**3 * sqrt(D_f) * omega**2 * sqrt(omega) * sqrt(d) * ln_T_sq
        + 768 * sigma * D_f**2 * omega**2 * sqrt(omega) * d * ln_T_sq
        + 16 * sigma * omega**2 * d**2 * sqrt(d) * ln_T
    )

def G_u(omega, D_f):
    """Return ``4 + 4*omega*(2 + D_f + 32*D_f**2 + 32*D_f**3)``.

    ``perturbed_sharpness_term`` uses this as its trace bound, and ``P`` takes
    its square root.
    """
    cubic_in_D_f = 2 + D_f + 32*D_f**2 + 32*D_f**3
    return 4 + 4*omega*cubic_in_D_f

def T_p(sigma, omega, D_f, T, d):
    """Return ``128 sigma D_f^2 omega ln T + (sigma sqrt(2d) + 32 sqrt(omega) sigma D_f^2)``.

    ``P`` multiplies this by ``(H_u + H_p) * Theta``.
    """
    D_f_sq = D_f ** 2
    leading = 128 * sigma * D_f_sq * omega * math.log(T)
    # The last two terms are summed first, then added (same grouping as before).
    tail = sigma * math.sqrt(2 * d) + 32 * math.sqrt(omega) * sigma * D_f_sq
    return leading + tail

def Theta(D_f, d):
    """Return ``3(d+1)^2 + (d+1) + (8d+6)(D_f+1)``; used as a multiplier in ``P``."""
    d_plus_one = d + 1
    return 3*d_plus_one**2 + d_plus_one + (8*d + 6)*(D_f + 1)

def P(sigma, omega, D_f, T, d):
    """Perturbation term combining G_p, G_u, T_p, H_u, H_p and Theta.

    Returns ``2*G_p*(2*sqrt(G_u) + G_p) + T_p*(H_u + H_p)*Theta``, with every
    helper evaluated at the given arguments.
    """
    g_p = G_p(sigma, omega, D_f, T, d)  # needed twice below; evaluate once
    g_term = 2*g_p*(2*math.sqrt(G_u(omega, D_f)) + g_p)
    t_term = T_p(sigma, omega, D_f, T, d)*(H_u(omega, D_f, T, d) + H_p(sigma, omega, D_f, T, d)) * Theta(D_f, d)
    return g_term + t_term

def perturbed_sharpness_term(sigma, omega, D_f, T, d, include_perturbation=False):
    """Sharpness part of the generalization-gap bound: ``sigma**2 * trace_bound``.

    ``trace_bound`` is ``G_u(omega, D_f)``. When requested, the perturbation
    correction ``P(sigma, omega, D_f, T, d)`` is added to it.

    Args:
        sigma: Perturbation scale; enters quadratically.
        omega, D_f, T, d: Bound parameters forwarded to ``G_u`` / ``P``
            (``T`` and ``d`` are only used when ``include_perturbation``).
        include_perturbation: If True, add ``P(...)`` to the trace bound.
            Defaults to False, which keeps the term unperturbed as before.

    Returns:
        ``sigma**2 * trace_bound``.
    """
    trace_bound = G_u(omega, D_f)
    if include_perturbation:
        trace_bound = trace_bound + P(sigma, omega, D_f, T, d)
    return sigma**2 * trace_bound

def parameter_norm_term(Sigma, m, omega, D_f, T, sigma, delta, d):
    """Return ``2*sqrt(Sigma**2/(2m) * (L(omega, D_f, T)/(2 sigma**2) + ln(1/delta)))``.

    ``sigma`` must be non-zero: it divides ``L``. ``d`` is accepted for
    signature symmetry with the other terms but is not used.
    """
    scale = Sigma**2 / (2*m)
    complexity = L(omega, D_f, T)/(2*sigma**2) + math.log(1/delta)
    return 2 * sqrt(scale * complexity)

def theoretical_gen_gap(sigma, omega, D_f, T, Sigma, m, delta, d):
    """Theoretical generalization-gap bound: sharpness term plus parameter-norm term.

    ``sigma`` is the first argument so the function can be passed straight to
    ``scipy.optimize.minimize`` with the remaining values as ``args``.
    """
    return (perturbed_sharpness_term(sigma, omega, D_f, T, d)
            + parameter_norm_term(Sigma, m, omega, D_f, T, sigma, delta, d))

# Empirical gap per run: final-epoch validation loss minus training loss.
actual_gen_gap = filtered_data["val_loss"] - filtered_data["train_loss"]
filtered_data["actual_gen_gap"] = actual_gen_gap


In [4]:
import numpy as np

# Build the (T, D_f, omega) points at which the bound is optimised below.
# Log-spaced values shared by all three axes.
grid = np.logspace(1, 7, num=10)

Ts, D_fs, omegas = np.meshgrid(grid, grid, grid, indexing='ij')
mesh_points = np.stack([Ts.ravel(), D_fs.ravel(), omegas.ravel()], axis=-1)

# Filter out points with D_f > T or omega > T
mesh_points = mesh_points[(mesh_points[:, 0] >= mesh_points[:, 1]) & (mesh_points[:, 0] >= mesh_points[:, 2])]

print("Mesh points:", mesh_points.shape)


# Add in points for our empirical construction (D_f values are the trained
# degrees and omega values the trained widths; both are matched in the plot cells).
emp_Ts = [20]
emp_D_fs = [1, 2, 3, 4, 5]
emp_omegas = [1, 7, 14, 20]

# Cartesian product in (T, D_f, omega) order with omega varying fastest. The
# row count follows from the list lengths instead of being hard-coded.
emp_grids = np.meshgrid(emp_Ts, emp_D_fs, emp_omegas, indexing='ij')
emp_mesh_points = np.stack([g.ravel() for g in emp_grids], axis=-1).astype(float)

mesh_points = np.concatenate([mesh_points, emp_mesh_points], axis=0)

print("Mesh and Empirical points:", mesh_points.shape)



Mesh points: (385, 3)
Mesh and Empirical points: (405, 3)


In [9]:
from scipy.optimize import minimize
from tqdm import tqdm

delta = 0.2
Sigma = 0.01
eps = 0.01
# Candidate sample sizes, tried in increasing order for every mesh point.
m_grid = np.logspace(3, 15, num=12)
# sigma appears as 1/sigma**2 in parameter_norm_term, so the search must stay
# strictly positive; a lower bound <= 0 lets the optimiser evaluate sigma = 0.
SIGMA_LOWER_BOUND = 1e-12


def gen_gap_objective(x, omega, D_f, T, Sigma, m, delta, d):
    """Objective for `minimize`: the theoretical bound as a function of sigma = x[0]."""
    # minimize passes a length-1 array; hand the bound a plain float so
    # math.sqrt never has to convert an ndarray to a scalar.
    return theoretical_gen_gap(float(x[0]), omega, D_f, T, Sigma, m, delta, d)


records = []
for row in tqdm(mesh_points):
    T, D_f, omega = row
    # d depends only on T, so compute it once per mesh point.
    d = 8*math.log(T)/eps**2
    for m in m_grid:
        # Minimise the bound over sigma for this (T, D_f, omega, m).
        minimal_bound = minimize(
            gen_gap_objective,
            x0=0.001,
            args=(omega, D_f, T, Sigma, int(m), delta, d),
            bounds=[(SIGMA_LOWER_BOUND, None)],
        )
        records.append({
            "T": T,
            "D_f": D_f,
            "omega": omega,
            "m": m,
            "d": d,
            "delta": delta,
            "Sigma": Sigma,
            "sigma": minimal_bound.x[0],
            "bound": minimal_bound.fun
        })
        # Stop at the smallest m in m_grid whose bound is non-vacuous.
        if minimal_bound.fun <= 1:
            print(f"D_f: {D_f}, omega: {omega}, T: {T}, m: {m}, sigma: {minimal_bound.x[0]}, bound: {minimal_bound.fun}")
            break
sigma_data = pd.DataFrame(records)


  param_bound = 2 * sqrt( Sigma**2 / (2*m) * (L(omega, D_f, T)/(2*sigma**2)
  2%|▏         | 7/405 [00:00<00:05, 66.95it/s]

D_f: 10.0, omega: 10.0, T: 10.0, m: 151991.10829529332, sigma: 0.0004763362460485309, bound: 0.9587592144113549
D_f: 10.0, omega: 10.0, T: 46.41588833612777, m: 1873817.422860383, sigma: 0.00032032110651277217, bound: 0.4335685912386013
D_f: 10.0, omega: 46.41588833612777, T: 46.41588833612777, m: 1873817.422860383, sigma: 0.00020117595827219288, bound: 0.7938027822961009
D_f: 46.41588833612777, omega: 10.0, T: 46.41588833612777, m: 284803586.8435793, sigma: 4.333169659575204e-05, bound: 0.7366950802795392
D_f: 46.41588833612777, omega: 46.41588833612777, T: 46.41588833612777, m: 3511191734.2151275, sigma: 1.716914894102221e-05, bound: 0.5370210064806727
D_f: 10.0, omega: 10.0, T: 215.44346900318823, m: 1873817.422860383, sigma: 0.0003401712271763542, bound: 0.48896873723686163
D_f: 10.0, omega: 46.41588833612777, T: 215.44346900318823, m: 1873817.422860383, sigma: 0.00021327844352791607, bound: 0.8921806373596389
D_f: 10.0, omega: 215.44346900318823, T: 215.44346900318823, m: 23101297

  5%|▍         | 20/405 [00:00<00:08, 44.07it/s]

D_f: 10.0, omega: 10.0, T: 1000.0, m: 1873817.422860383, sigma: 0.00038380783358073407, bound: 0.6224584886559742
D_f: 10.0, omega: 46.41588833612777, T: 1000.0, m: 23101297.00083158, sigma: 0.000154493347696872, bound: 0.4681515955485982
D_f: 10.0, omega: 215.44346900318823, T: 1000.0, m: 23101297.00083158, sigma: 9.81497821776094e-05, bound: 0.877026628090685
D_f: 10.0, omega: 1000.0, T: 1000.0, m: 284803586.8435793, sigma: 4.278996641873885e-05, bound: 0.7738099386058479
D_f: 46.41588833612777, omega: 10.0, T: 1000.0, m: 284803586.8435793, sigma: 4.383641551353042e-05, bound: 0.7539604841213103
D_f: 46.41588833612777, omega: 46.41588833612777, T: 1000.0, m: 3511191734.2151275, sigma: 1.7508463232974053e-05, bound: 0.5584594527761766
D_f: 46.41588833612777, omega: 215.44346900318823, T: 1000.0, m: 43287612810.83061, sigma: 7.575503585951149e-06, bound: 0.4398118193794395
D_f: 46.41588833612777, omega: 1000.0, T: 1000.0, m: 43287612810.83061, sigma: 6.271592294493639e-06, bound: 0.950

  6%|▌         | 25/405 [00:00<00:10, 38.00it/s]

D_f: 215.44346900318823, omega: 10.0, T: 1000.0, m: 533669923120.6302, sigma: 6.708712304145987e-06, bound: 0.8087274253653298


  9%|▉         | 36/405 [00:00<00:09, 39.85it/s]

D_f: 10.0, omega: 10.0, T: 4641.588833612777, m: 1873817.422860383, sigma: 0.0004459261343526274, bound: 0.8402504664504173
D_f: 10.0, omega: 46.41588833612777, T: 4641.588833612777, m: 23101297.00083158, sigma: 0.00017685135599584655, bound: 0.6134514769114918
D_f: 10.0, omega: 215.44346900318823, T: 4641.588833612777, m: 284803586.8435793, sigma: 7.127888890542983e-05, bound: 0.4625812902701406
D_f: 10.0, omega: 1000.0, T: 4641.588833612777, m: 284803586.8435793, sigma: 4.538690685231954e-05, bound: 0.8706032742822989
D_f: 10.0, omega: 4641.588833612777, T: 4641.588833612777, m: 3511191734.2151275, sigma: 1.9820146029740654e-05, bound: 0.7705556807156186
D_f: 46.41588833612777, omega: 10.0, T: 4641.588833612777, m: 284803586.8435793, sigma: 4.562170651935817e-05, bound: 0.8166300661422993
D_f: 46.41588833612777, omega: 46.41588833612777, T: 4641.588833612777, m: 3511191734.2151275, sigma: 1.8210331650762195e-05, bound: 0.603805951966233
D_f: 46.41588833612777, omega: 215.443469003188

 11%|█         | 45/405 [00:01<00:10, 33.13it/s]

D_f: 215.44346900318823, omega: 10.0, T: 4641.588833612777, m: 533669923120.6302, sigma: 6.713174713738704e-06, bound: 0.8098185977590427


 16%|█▌        | 63/405 [00:01<00:09, 36.60it/s]

D_f: 10.0, omega: 10.0, T: 21544.346900318822, m: 23101297.00083158, sigma: 0.00034209197301364577, bound: 0.4945060794379633
D_f: 10.0, omega: 46.41588833612777, T: 21544.346900318822, m: 23101297.00083158, sigma: 0.00020533558134308206, bound: 0.8269506291483897
D_f: 10.0, omega: 215.44346900318823, T: 21544.346900318822, m: 284803586.8435793, sigma: 8.143711145825955e-05, bound: 0.603814279167926
D_f: 10.0, omega: 1000.0, T: 21544.346900318822, m: 3511191734.2151275, sigma: 3.282770944769277e-05, bound: 0.4554968622413915
D_f: 10.0, omega: 4641.588833612777, T: 21544.346900318822, m: 3511191734.2151275, sigma: 2.0910399291419964e-05, bound: 0.8578645029813132
D_f: 10.0, omega: 21544.346900318822, T: 21544.346900318822, m: 43287612810.83061, sigma: 9.924253888671685e-06, bound: 0.7650018754615753
D_f: 46.41588833612777, omega: 10.0, T: 21544.346900318822, m: 284803586.8435793, sigma: 5.047808934430021e-05, bound: 0.9997392519832622
D_f: 46.41588833612777, omega: 46.41588833612777, T:

 17%|█▋        | 67/405 [00:01<00:10, 33.55it/s]

D_f: 46.41588833612777, omega: 1000.0, T: 21544.346900318822, m: 533669923120.6302, sigma: 6.843622414857944e-06, bound: 0.7583056118465775
D_f: 215.44346900318823, omega: 10.0, T: 21544.346900318822, m: 533669923120.6302, sigma: 6.737592956039285e-06, bound: 0.8158072865859043


 24%|██▍       | 98/405 [00:02<00:07, 41.11it/s]

D_f: 10.0, omega: 10.0, T: 100000.0, m: 23101297.00083158, sigma: 0.00039799319444141086, bound: 0.6693218364825382
D_f: 10.0, omega: 46.41588833612777, T: 100000.0, m: 284803586.8435793, sigma: 0.00015701649751382083, bound: 0.4835674695787634
D_f: 10.0, omega: 215.44346900318823, T: 100000.0, m: 284803586.8435793, sigma: 9.424821912057726e-05, bound: 0.8086622550539952
D_f: 10.0, omega: 1000.0, T: 100000.0, m: 3511191734.2151275, sigma: 3.738168128712502e-05, bound: 0.5904717551039249
D_f: 10.0, omega: 4641.588833612777, T: 100000.0, m: 43287612810.83061, sigma: 1.5070078191192053e-05, bound: 0.44545948113168465
D_f: 10.0, omega: 21544.346900318822, T: 100000.0, m: 43287612810.83061, sigma: 9.605868029425977e-06, bound: 0.8390508333321631
D_f: 10.0, omega: 100000.0, T: 100000.0, m: 6579332246575.655, sigma: 6.062538596120095e-06, bound: 0.6152592101270308
D_f: 46.41588833612777, omega: 10.0, T: 100000.0, m: 3511191734.2151275, sigma: 3.815462378214068e-05, bound: 0.5710796028518399
D

 26%|██▋       | 107/405 [00:03<00:08, 35.51it/s]

D_f: 46.41588833612777, omega: 1000.0, T: 100000.0, m: 533669923120.6302, sigma: 7.131917538432019e-06, bound: 0.8458983869951011
D_f: 215.44346900318823, omega: 10.0, T: 100000.0, m: 533669923120.6302, sigma: 6.864240654975625e-06, bound: 0.847359555777182


 36%|███▌      | 145/405 [00:04<00:05, 48.46it/s]

D_f: 10.0, omega: 10.0, T: 464158.8833612772, m: 23101297.00083158, sigma: 0.00046180167315417874, bound: 0.9011427249819339
D_f: 10.0, omega: 46.41588833612777, T: 464158.8833612772, m: 284803586.8435793, sigma: 0.00018215191488083415, bound: 0.6507740038463028
D_f: 10.0, omega: 215.44346900318823, T: 464158.8833612772, m: 3511191734.2151275, sigma: 7.186102916613447e-05, bound: 0.4701674946123313
D_f: 10.0, omega: 1000.0, T: 464158.8833612772, m: 3511191734.2151275, sigma: 4.3132568227324656e-05, bound: 0.7862543468054317
D_f: 10.0, omega: 4641.588833612777, T: 464158.8833612772, m: 43287612810.83061, sigma: 1.7104607696266163e-05, bound: 0.5741116618693831
D_f: 10.0, omega: 21544.346900318822, T: 464158.8833612772, m: 533669923120.6302, sigma: 8.097463470338599e-06, bound: 0.4449315916760589
D_f: 10.0, omega: 100000.0, T: 464158.8833612772, m: 533669923120.6302, sigma: 5.947206591743408e-06, bound: 0.9000113635465237
D_f: 46.41588833612777, omega: 10.0, T: 464158.8833612772, m: 3511

 39%|███▊      | 156/405 [00:04<00:06, 38.14it/s]

D_f: 46.41588833612777, omega: 215.44346900318823, T: 464158.8833612772, m: 43287612810.83061, sigma: 1.0456802689086353e-05, bound: 0.9248980193818718
D_f: 46.41588833612777, omega: 1000.0, T: 464158.8833612772, m: 533669923120.6302, sigma: 7.091198710895166e-06, bound: 0.9213067957298493
D_f: 215.44346900318823, omega: 10.0, T: 464158.8833612772, m: 533669923120.6302, sigma: 7.2995925702223024e-06, bound: 0.9625273665687397


 52%|█████▏    | 209/405 [00:05<00:03, 58.32it/s]

D_f: 10.0, omega: 10.0, T: 2154434.6900318824, m: 284803586.8435793, sigma: 0.00035175067640219636, bound: 0.5228238629776824
D_f: 10.0, omega: 46.41588833612777, T: 2154434.6900318824, m: 284803586.8435793, sigma: 0.0002108700627800954, bound: 0.872144228591893
D_f: 10.0, omega: 215.44346900318823, T: 2154434.6900318824, m: 3511191734.2151275, sigma: 8.31732781588123e-05, bound: 0.6298326669720402
D_f: 10.0, omega: 1000.0, T: 2154434.6900318824, m: 43287612810.83061, sigma: 3.281116740356783e-05, bound: 0.4550379933579827
D_f: 10.0, omega: 4641.588833612777, T: 2154434.6900318824, m: 43287612810.83061, sigma: 1.9696768287439546e-05, bound: 0.7609535907216791
D_f: 10.0, omega: 21544.346900318822, T: 2154434.6900318824, m: 533669923120.6302, sigma: 8.083507601852087e-06, bound: 0.5562910499748674
D_f: 10.0, omega: 100000.0, T: 2154434.6900318824, m: 6579332246575.655, sigma: 6.598370175089539e-06, bound: 0.746627557473313


 53%|█████▎    | 215/405 [00:05<00:04, 45.02it/s]

D_f: 46.41588833612777, omega: 10.0, T: 2154434.6900318824, m: 43287612810.83061, sigma: 3.3625672072715565e-05, bound: 0.4436760491029952
D_f: 46.41588833612777, omega: 46.41588833612777, T: 2154434.6900318824, m: 43287612810.83061, sigma: 2.0160050926437194e-05, bound: 0.7401907982108867
D_f: 46.41588833612777, omega: 215.44346900318823, T: 2154434.6900318824, m: 533669923120.6302, sigma: 9.370409455475829e-06, bound: 0.5500244604419031
D_f: 46.41588833612777, omega: 1000.0, T: 2154434.6900318824, m: 533669923120.6302, sigma: 6.333626645213547e-06, bound: 0.9741575685279196


 56%|█████▌    | 226/405 [00:06<00:05, 34.70it/s]

D_f: 215.44346900318823, omega: 10.0, T: 2154434.6900318824, m: 533669923120.6302, sigma: 6.383825919255672e-06, bound: 0.9772552972751823


 70%|███████   | 285/405 [00:07<00:01, 78.32it/s]

D_f: 10.0, omega: 10.0, T: 10000000.0, m: 284803586.8435793, sigma: 0.00040647533397071735, bound: 0.6981554019423766
D_f: 10.0, omega: 46.41588833612777, T: 10000000.0, m: 3511191734.2151275, sigma: 0.00016031975145561832, bound: 0.5041270614267157
D_f: 10.0, omega: 215.44346900318823, T: 10000000.0, m: 3511191734.2151275, sigma: 9.61108417486033e-05, bound: 0.8409558060900111
D_f: 10.0, omega: 1000.0, T: 10000000.0, m: 43287612810.83061, sigma: 3.791131330858046e-05, bound: 0.6073095324005682
D_f: 10.0, omega: 4641.588833612777, T: 10000000.0, m: 533669923120.6302, sigma: 1.4957020774603647e-05, bound: 0.4387656099439767
D_f: 10.0, omega: 21544.346900318822, T: 10000000.0, m: 533669923120.6302, sigma: 9.832533290522541e-06, bound: 0.7400071071818077
D_f: 10.0, omega: 100000.0, T: 10000000.0, m: 6579332246575.655, sigma: 7.035704034132172e-06, bound: 0.8779851309290896


 74%|███████▍  | 301/405 [00:07<00:02, 44.81it/s]

D_f: 46.41588833612777, omega: 10.0, T: 10000000.0, m: 43287612810.83061, sigma: 3.885953992288249e-05, bound: 0.592380173029782
D_f: 46.41588833612777, omega: 46.41588833612777, T: 10000000.0, m: 43287612810.83061, sigma: 2.3291387728345074e-05, bound: 0.9881769506822269
D_f: 46.41588833612777, omega: 215.44346900318823, T: 10000000.0, m: 533669923120.6302, sigma: 9.957090168843822e-06, bound: 0.7183630508939467
D_f: 46.41588833612777, omega: 1000.0, T: 10000000.0, m: 6579332246575.655, sigma: 7.095013139466842e-06, bound: 0.8338789659996768


 76%|███████▌  | 307/405 [00:08<00:02, 36.36it/s]

D_f: 215.44346900318823, omega: 10.0, T: 10000000.0, m: 6579332246575.655, sigma: 7.116077091888926e-06, bound: 0.8267846387135822


100%|██████████| 405/405 [00:09<00:00, 42.40it/s]

D_f: 1.0, omega: 1.0, T: 20.0, m: 1000.0, sigma: 0.014601873067464924, bound: 0.17398608687800318
D_f: 1.0, omega: 7.0, T: 20.0, m: 1000.0, sigma: 0.007665609856237171, bound: 0.3314171613968262
D_f: 1.0, omega: 14.0, T: 20.0, m: 1000.0, sigma: 0.00608635900843935, bound: 0.4174109955923031
D_f: 1.0, omega: 20.0, T: 20.0, m: 1000.0, sigma: 0.005404678421307981, bound: 0.47005804122544925
D_f: 2.0, omega: 1.0, T: 20.0, m: 1000.0, sigma: 0.00831651491269298, bound: 0.3228613975844267
D_f: 2.0, omega: 7.0, T: 20.0, m: 1000.0, sigma: 0.004420323571650549, bound: 0.6370587350110812
D_f: 2.0, omega: 14.0, T: 20.0, m: 1000.0, sigma: 0.0035629073046912645, bound: 0.8276210406559128
D_f: 2.0, omega: 20.0, T: 20.0, m: 1000.0, sigma: 0.003199295710047311, bound: 0.9532554288000556
D_f: 3.0, omega: 1.0, T: 20.0, m: 1000.0, sigma: 0.005951544392213206, bound: 0.4922103647902317
D_f: 3.0, omega: 7.0, T: 20.0, m: 1000.0, sigma: 0.003179983679335681, bound: 0.9829187762917868
D_f: 3.0, omega: 14.0, T:




In [None]:
# Drop repeated optimisation records (identical parameters and sigma).
sigma_data = sigma_data.drop_duplicates(subset=["T", "D_f", "omega", "m", "d", "delta", "Sigma", "sigma"])
# Non-vacuous bounds (<= 1) with a strictly positive sigma. Each comparison is
# parenthesised because `&` binds tighter than `<=`: without the parentheses
# `bound` would be compared against `1 & (sigma > 0)`.
non_vacous = sigma_data[(sigma_data["bound"] <= 1) & (sigma_data["sigma"] > 0)]

print(sigma_data.shape)
print(non_vacous.shape)
sigma_data.to_csv("optimized_sigma.csv", index=False)
non_vacous.to_csv("non_vacous.csv", index=False)

(4218, 9)
(110, 9)


In [13]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Left: bound evaluated at each run's optimised sigma; right: measured gap.
fig = make_subplots(
    rows=1, cols=2,
    specs=[
        [{"type": "xy"}, {"type": "xy"}]
    ],
    horizontal_spacing=0.10,
    subplot_titles=("Theoretical", "Empirical")
)

# Get unique widths
unique_widths = filtered_data["width"].unique()
colors = ["red", "brown", "green", "purple"]

T = 20
delta = 0.2
Sigma = 0.01
d = T + 2
m = 10000
# The optimisation cell used d = 8*log(T)/eps**2 with eps = 0.01 instead.


def fetch_optimal_sigma(row):
    """Return the optimised sigma stored in `sigma_data` for this run.

    Selects the records of `sigma_data` whose T equals the module-level `T`,
    whose omega equals ``row["width"]`` and whose D_f equals ``row["deg"]``,
    and returns the sigma of the last one.

    Raises:
        LookupError: if no record matches.
    """
    # NOTE(review): that sigma was optimised for whichever m ended the search
    # in the optimisation cell (and a different d), not for the `m` used below.
    matches = sigma_data[(sigma_data["T"] == T) &
                         (sigma_data["omega"] == row["width"]) &
                         (sigma_data["D_f"] == row["deg"])]
    if matches.empty:
        raise LookupError(
            f"no optimised sigma for T={T}, width={row['width']}, deg={row['deg']}"
        )
    return matches["sigma"].values[-1]


# Fetch optimal theo_gen_gap for each width, using the m and d set above.
filtered_data["theo_gen_gap"] = filtered_data.apply(
    lambda row: theoretical_gen_gap(
        sigma=fetch_optimal_sigma(row),
        omega=row["width"],
        D_f=row["deg"],
        T=T,
        Sigma=Sigma,
        m=m,
        delta=delta,
        d=d
    ), axis=1)

# Plot theo_gen_gap for each width
for k, width in enumerate(unique_widths):
    subset = filtered_data[filtered_data["width"] == width]
    fig.add_trace(
        go.Scatter(x=subset["deg"], y=subset["theo_gen_gap"], mode='lines+markers', line=dict(color=colors[k % len(colors)]), showlegend=False),
        row=1, col=1
    )

# Plot actual_gen_gap for each width, averaging over func
for k, width in enumerate(unique_widths):
    subset = filtered_data[filtered_data["width"] == width]
    avg_actual_gen_gap = subset.groupby("deg")["actual_gen_gap"].mean()
    fig.add_trace(
        go.Scatter(x=avg_actual_gen_gap.index, y=avg_actual_gen_gap.values, mode='lines+markers', line=dict(color=colors[k % len(colors)]), name=f"Width={width}"),
        row=1, col=2
    )

fig.update_layout(
    title="Comparison of Empirical and Theoretical Generalization Gap by Degree, Width",
    height=750,
    width=1200,
    margin=dict(t=50, b=50, l=50, r=50)
)
# This grid is 1x2, so label both panels (row/col positions outside the grid
# are silently ignored by plotly).
fig.update_xaxes(title_text="Degree", row=1, col=1)
fig.update_xaxes(title_text="Degree", row=1, col=2)
fig.update_yaxes(title_text="Generalization Gap", row=1, col=1)
fig.update_yaxes(title_text="Generalization Gap", row=1, col=2)
fig.show()
# Distinct file name: the sigma-sweep cell below writes "unpretrubed vary over sigma.png".
fig.write_image("plots/unperturbed optimal sigma.png", width=1200, height=750)

In [None]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Four theoretical panels at fixed sigmas (2x2 on the left) and the empirical
# gap spanning both rows on the right.
sigmas = [ 0.00003, 0.0001, 0.005,  0.1]
pos = [(1, 1), (1, 2), (2, 1), (2, 2)]
fig = make_subplots(
    rows=2, cols=3,
    specs=[
        [{"type": "xy"}, {"type": "xy"}, {"rowspan": 2, "colspan": 1}],  # First row
        [{"type": "xy"}, {"type": "xy"}, None]                           # Second row
    ],
    column_widths=[0.25, 0.25, 0.5],  # Left half = 0.25+0.25, Right = 0.5
    row_heights=[0.5, 0.5],
    horizontal_spacing=0.10,
    vertical_spacing=0.10,
    subplot_titles=(f"σ = {sigmas[0]}", f"σ = {sigmas[1]}", "Empirical Generalization Gap", f"σ = {sigmas[2]}", f"σ = {sigmas[3]}")
)

# Get unique widths
unique_widths = filtered_data["width"].unique()
colors = ["red", "brown", "green", "purple"]

T = 20
delta = 0.2
Sigma = 0.01
d = T + 2
# Defined here rather than inherited from whichever earlier cell last set `m`
# (the optimisation loop leaves it at its final grid value).
m = 10000
# The optimisation cell used d = 8*log(T)/eps**2 with eps = 0.01 instead.

for i, sigma in enumerate(sigmas):
    # Theoretical gap for every run at this fixed sigma.
    filtered_data["theo_gen_gap"] = filtered_data.apply(
        lambda row: theoretical_gen_gap(
            sigma=sigma,
            omega=row["width"],
            D_f=row["deg"],
            T=T,
            Sigma=Sigma,
            m=m,
            delta=delta,
            d=d
        ),
        axis=1)
    for k, width in enumerate(unique_widths):
        subset = filtered_data[filtered_data["width"] == width]
        fig.add_trace(
            go.Scatter(x=subset["deg"], y=subset["theo_gen_gap"], mode='lines+markers', line=dict(color=colors[k % len(colors)]), showlegend=False),
            row=pos[i][0], col=pos[i][1]
        )

# Plot actual_gen_gap for each width, averaging over func
for k, width in enumerate(unique_widths):
    subset = filtered_data[filtered_data["width"] == width]
    avg_actual_gen_gap = subset.groupby("deg")["actual_gen_gap"].mean()
    fig.add_trace(
        go.Scatter(x=avg_actual_gen_gap.index, y=avg_actual_gen_gap.values, mode='lines+markers', line=dict(color=colors[k % len(colors)]), name=f"Width={width}"),
        row=1, col=3
    )

fig.update_layout(
    title="Comparison of Empirical and Theoretical Generalization Gap by Degree, Width",
    height=750,
    width=1200,
    margin=dict(t=50, b=50, l=50, r=50)
)
fig.update_xaxes(title_text="Degree", row=1, col=3)
fig.update_xaxes(title_text="Degree", row=2, col=1)
fig.update_xaxes(title_text="Degree", row=2, col=2)
fig.update_yaxes(title_text="Generalization Gap", row=1, col=1)
fig.update_yaxes(title_text="Generalization Gap", row=1, col=3)
fig.update_yaxes(title_text="Generalization Gap", row=2, col=1)
fig.show()
fig.write_image("plots/unpretrubed vary over sigma.png", width=1200, height=750)

