# *Architecture independent generalization bounds for overparametrized deep ReLU networks*

Code used to generate data for [*Architecture independent generalization bounds for overparametrized deep ReLU networks*](https://arxiv.org/abs/2504.05695), by Anandatheertha Bapu, Thomas Chen, Chun-Kai Kevin Chien, Patrícia Muñoz Ewald, and Andrew G. Moore. 

## Imports and definitions

In [None]:
import torch
from torch import nn
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import math
import pandas as pd
import os
import re
from datetime import datetime

In [None]:
# Module-level dataset used by the sampling helpers below: the MNIST training
# split (train=True), stored under the current directory and downloaded if
# missing; each image is converted to a tensor by ToTensor().
mnist = datasets.MNIST(root=".", train=True, download=True, transform=transforms.ToTensor())

### Sample generation

In [None]:
def train_imports(n, seed):
    """
    Draw `n` samples (with replacement) from the module-level `mnist` dataset.

    The global torch RNG is reseeded with `seed` first, so the same
    (n, seed) pair always yields the same sample.

    Returns
    X0_train : tensor with one flattened image per column (d x n)
    Y_train : one-hot labels, one column per sample (10 x n)
    labels_train : 1-D tensor with the n integer class labels
    """
    torch.manual_seed(seed)
    # Sample over the whole dataset instead of a hard-coded size
    training_indices = torch.randint(0, len(mnist), (n,))
    train_samples = [mnist[i] for i in training_indices.tolist()]
    X0_train = torch.stack([s[0].view(-1) for s in train_samples], dim=1)
    # The label tensor is 1-D, so no transpose is applied (`.T` on tensors
    # that are not 2-D is deprecated in PyTorch)
    labels_train = torch.tensor([s[1] for s in train_samples])
    Y_train = torch.eye(10)[labels_train].T

    return X0_train, Y_train, labels_train

In [None]:
def test_imports(m, seed):
    """
    Draw `m` samples (with replacement) from the module-level `mnist` dataset.

    The global torch RNG is reseeded with `seed` first, so the same
    (m, seed) pair always yields the same sample.

    NOTE(review): this samples from the same `mnist` object as
    `train_imports` (loaded with train=True), so a test sample can coincide
    with a training sample -- confirm this is intended.

    Returns
    X0_test : tensor with one flattened image per column (d x m)
    Y_test : one-hot labels, one column per sample (10 x m)
    labels_test : 1-D tensor with the m integer class labels
    """
    torch.manual_seed(seed)
    # Sample over the whole dataset instead of a hard-coded size
    test_indices = torch.randint(0, len(mnist), (m,))
    test_samples = [mnist[i] for i in test_indices.tolist()]
    X0_test = torch.stack([s[0].view(-1) for s in test_samples], dim=1)
    # The label tensor is 1-D, so no transpose is applied (`.T` on tensors
    # that are not 2-D is deprecated in PyTorch)
    labels_test = torch.tensor([s[1] for s in test_samples])
    Y_test = torch.eye(10)[labels_test].T

    return X0_test, Y_test, labels_test

In [None]:
def has_all_mnist_digits(labels):
    """Return True when each of the digits 0..9 occurs at least once in `labels`."""
    seen = set(np.unique(np.array(labels)))
    return all(digit in seen for digit in range(10))

### Chamfer Distance 

In [None]:
def _to_numpy(x):
    # Accepts either np.array or torch.Tensor, returns np.array on CPU
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x, dtype=float)

In [None]:
def chamfer_distance_labeled(X1, Y1, X2, Y2): 
    """
    Asymmetric Chamfer distance from (X1, Y1) to (X2, Y2) using
        d((x1,y1),(x2,y2)) = (||x1 - x2||_2 + ||y1 - y2||_2)^2.

    Each input holds one point per row: X1 is (N1, d), Y1 is (N1, Q),
    X2 is (N2, d) and Y2 is (N2, Q). For every point of cloud 1 the nearest
    point of cloud 2 (under d) is found, and the mean of those nearest
    distances is returned.

    This version is meant for metrics/logging:
    - If inputs are torch.Tensors, they are detached and moved to CPU.
    - No gradients will flow through this function.
    """

    X1 = _to_numpy(X1)
    Y1 = _to_numpy(Y1)
    X2 = _to_numpy(X2)
    Y2 = _to_numpy(Y2)

    # Without the first check a single-row Y1 would silently broadcast
    assert X1.shape[0] == Y1.shape[0], "X1 and Y1 must have same number of points"
    assert X2.shape[0] == Y2.shape[0], "X2 and Y2 must have same number of points"
    assert X1.shape[1] == X2.shape[1], "X1 and X2 must have same feature dimension"
    assert Y1.shape[1] == Y2.shape[1], "Y1 and Y2 must have same label dimension"

    n1, n2, dim = X1.shape[0], X2.shape[0], X1.shape[1]

    # The pairwise difference tensor is (rows, N2, d). Processing cloud 1 in
    # row chunks caps it at roughly 2**24 elements; each row's result is
    # computed exactly as in the unchunked formulation.
    rows_per_chunk = max(1, (1 << 24) // max(1, n2 * dim))

    nearest = []
    for start in range(0, n1, rows_per_chunk):
        stop = start + rows_per_chunk

        # Pairwise distances in feature space and in label space: (rows, N2)
        dx = np.linalg.norm(X1[start:stop, None, :] - X2[None, :, :], axis=-1)
        dy = np.linalg.norm(Y1[start:stop, None, :] - Y2[None, :, :], axis=-1)

        # For each point in this chunk, distance to its nearest neighbor in cloud 2
        nearest.append(((dx + dy) ** 2).min(axis=1))

    if not nearest:
        # Cloud 1 is empty: the mean over no points is undefined
        return np.nan

    # Asymmetric Chamfer: average nearest-neighbor distance
    return np.concatenate(nearest).mean()

### Network classes

In [None]:
class ReLUNetwork(nn.Module):
    """
    Fully connected network with ReLU after every layer except the last.

    `layers` lists the widths from input to output; consecutive widths
    define one nn.Linear each, with or without bias.
    """
    def __init__(self, layers, bias = False):
        super().__init__()
        widths_in = layers[:-1]
        widths_out = layers[1:]
        self.layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out, bias=bias)
             for fan_in, fan_out in zip(widths_in, widths_out)]
        )

    def forward(self, x):
        last = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            # The output layer stays linear
            if idx != last:
                x = F.relu(x)
        return x

In [None]:
def trained_model_error(X, Y, model):
    """
    Mean over samples of the squared L2 error of `model`.

    X holds one input per column and Y the matching target per column.
    For each column i the model is called on X.T[i], and
    ||model(x_i) - y_i||_2^2 is accumulated over ALL samples before
    dividing by the number of samples.

    Raises
    ValueError if X has no columns.
    """
    samples = X.T
    targets = Y.T
    n_samples = len(samples)
    if n_samples == 0:
        raise ValueError("X must contain at least one sample")

    total = 0
    for i in range(n_samples):
        residual = model(samples[i]) - targets[i]
        # Accumulate inside the loop so every sample contributes
        total = total + (residual ** 2).sum()

    return total / n_samples

In [None]:
class zeroLossNetwork():
    """
    Bias-free ReLU network whose weights are written down explicitly from
    the training data instead of being trained.

    M  : list of layer widths handled by this network, nonincreasing and
         all >= Q; M[0] is the number of rows of W1.
    X0 : (M0 x n) matrix with one training input per column.
    Y  : (Q x n) matrix with one target per column.

    Weights:
      W1        = [Y @ pinv(X0); 0]   (M[0] x M0)
      W_hidden  = [I 0] truncations   (M[l] x M[l-1]), l = 1..L-1
      WLp1      = [I_Q 0]             (Q x M[L-1])

    When X0 has full column rank and Y is entrywise nonnegative,
    W1 @ X0 = [Y; 0], so the ReLUs act as the identity and forward(X0)
    returns Y.

    Raises
    ValueError if M is empty, M0 < n, the last width is < Q, or M increases.
    """
    def __init__(self, M, X0, Y):
        M0 = X0.size(0)
        Q, n = Y.size()
        self.L = len(M)

        # Exception handling
        if self.L == 0:
            raise ValueError('M must contain at least one layer width')

        if M0 < n:
            raise ValueError('M0 must be >= n')

        if M[self.L-1] < Q:
            raise ValueError('All elements of M must be >= Q')

        for i in range(self.L-1):
            if M[i] < M[i+1]:
                raise ValueError('M must be nonincreasing')

        # First layer W1^*
        X0_plus = torch.linalg.pinv(X0)   # n x M0 pseudoinverse
        top = (Y) @ X0_plus       # Q x M0
        bottom = torch.zeros(M[0] - Q, M0)
        self.W1 = torch.cat([top, bottom], dim=0)

        # Hidden layer weights W_l^*, l=2,...,L. W_hidden[i](below) <-> W_(i+2)^*(in the paper)
        self.W_hidden = []
        for l in range(1, self.L):
            left = torch.eye(M[l])
            right = torch.zeros(M[l], M[l-1] - M[l])
            self.W_hidden.append(torch.cat([left, right], dim=1))

        # Output weight W_(L+1)^*
        self.WLp1 = torch.cat([torch.eye(Q), torch.zeros(Q, M[self.L-1] - Q)], dim=1)

    # Forward method
    def forward(self, X0):
        """Apply the network to X0 (inputs as columns, or a single 1-D input)."""
        x = F.relu(self.W1 @ X0)
        for l in range(self.L-1):
            x = self.W_hidden[l] @ x
            x = F.relu(x) # Apply ReLU activation
        out = self.WLp1 @ x
        return out

In [None]:
def initializedNetworkError(X, Y, model):
    """
    Mean over samples of the squared L2 error of `model.forward`.

    X holds one input per column and Y the matching target per column.
    For each column i, ||model.forward(x_i) - y_i||_2^2 is accumulated over
    ALL samples before dividing by the number of samples.

    Raises
    ValueError if X has no columns.
    """
    samples = X.T
    targets = Y.T
    n_samples = len(samples)
    if n_samples == 0:
        raise ValueError("X must contain at least one sample")

    total = 0
    for i in range(n_samples):
        residual = torch.squeeze(model.forward(samples[i]) - targets[i])
        # Accumulate inside the loop so every sample contributes
        total = total + (residual ** 2).sum()
    return total / n_samples

###  Basic code for experiments

`zeroLossNetworkExperiment` initializes a zeroLossNetwork() interpolating `n` MNIST samples using a specific `train_seed`. It then generates a test set with `m` elements from seed `test_seed`, and outputs the following:
- Train error
- Test error
- Certain bounds
- $||W_1 - W_1 X_0 (X_0)^+||$
- Operator norms of the weight matrices
- `train_seed, test_seed`

`trainedNetwork` initializes a ReLUNetwork() with layer dimensions given by an array `M`, possibly with bias (if `bias=True`), either randomly (if `from_random=True`) or using the parameters from a zeroLossNetwork() (if `from_random=False`), and then trains it to interpolate `n` MNIST samples generated from seed `train_seed`. The training parameters are:
- `l` learning rate, or `lr_from_contructors` (spelled this way in the function signature) if `from_random=False`,
- `wd` weight decay,
- `batch_size`,
- `epochs`,
- `init_seed` seed for random initialization.

Returns: the trained model, the `train_seed` actually used, and the seed used for parameter initialization.

`trainedNetworkExperiment` then takes a `model` trained with `train_seed, init_seed, n, M` as above, generates a test set with `m` elements from seed `test_seed`, and outputs the following:
- Train error
- Test error
- Certain bounds
- $||W_1 - W_1 X_0 (X_0)^+||$
- Operator norms of the weight matrices
- `train_seed, test_seed, init_seed`


___
**List of parameters**

Required Parameters 
* `n` = int, training set size
* `m` = int, test set size
* `M` = list, architecture of network. Enter all numbers from M0 to Q (inclusive) as a list.

Optional parameters:

* `train_seed` seed for generating training set, `default=30`
* `test_seed` seed for generating test set, `default=10`
* `l` = float, learning rate, `default=0.1`
* `wd` = float, weight decay, `default=0.0`
* `from_random` = boolean, True if training from random, False if prefilled with zero loss minimizers, `default=True`
* `batch_size` = int, `default=1`
* `epochs` = int, `default=100`
* `lr_from_contructors` (spelled this way in the function signature) = float, learning rate if prefilled with zero loss minimizers, `default=0.1`
* `bias` = boolean, `default=False`
* `print_loss` = boolean, prints loss every 10 epochs if True, `default=False`

Note: If the sampled training set does not contain all ten digits, it is discarded and the training seed is incremented until a sample containing every digit is found.


In [None]:
def zeroLossNetworkExperiment(n, m, M, train_seed = 30, test_seed = 10):
    """
    Build a zeroLossNetwork on `n` training samples and evaluate it on `m`
    test samples.

    Starting from `train_seed`, the seed is incremented until the sampled
    training set contains all ten digits.

    Returns a 10-tuple:
      train error, test error,
      bound0 = max(1, Lipf)^2 * Chamfer distance between test and train data,
      bound1 = max(1, Lipg)^2 * Chamfer distance after applying W1,
      bound2 = max(1, Lipg)^2 * Chamfer distance after applying ReLU(W1 .),
      ||W1 - W1 X0 pinv(X0)|| (Frobenius norm),
      operator norms of all weight matrices,
      the train seed actually used, the test seed, None.
    Lipf is the product of the operator norms of all layers; Lipg is the
    operator norm of the product of all layers after W1.

    Raises
    ValueError if n < 10 (such a training set can never contain all digits).
    """
    if n < 10:
        # Otherwise the seed search below would never terminate
        raise ValueError("n must be >= 10 so that the training set can contain all ten digits")

    # Advance the seed until the training sample contains every digit
    while True:
        X0, Y0, L0 = train_imports(n, train_seed)
        if has_all_mnist_digits(L0):
            break
        train_seed += 1

    # The test sample does not depend on the training seed: draw it once
    X, Y, _ = test_imports(m, test_seed)

    model = zeroLossNetwork(M, X0, Y0)
    train_error = initializedNetworkError(X0, Y0, model)
    test_error = initializedNetworkError(X, Y, model)

    W1 = model.W1

    # Operator norms of W1, the hidden layers and the output layer, in order
    weights = [model.W1, *model.W_hidden, model.WLp1]
    op_norms = [torch.linalg.norm(W, ord=2).item() for W in weights]

    # Product of the weight matrices, excluding W1
    prod_weights_g = torch.eye(M[0])
    for W in model.W_hidden:
        prod_weights_g = W @ prod_weights_g
    prod_weights_g = model.WLp1 @ prod_weights_g

    Lipf = math.prod(op_norms)
    Lipg = torch.linalg.norm(prod_weights_g, ord = 2)

    bound0 = (max(1, Lipf))**2 * chamfer_distance_labeled(X.T, Y.T, X0.T, Y0.T)
    bound1 = (max(1, Lipg))**2 * chamfer_distance_labeled((W1 @ X).T, Y.T, (W1 @ X0).T, Y0.T)
    bound2 = (max(1, Lipg))**2 * chamfer_distance_labeled(F.relu((W1 @ X).T), Y.T, F.relu((W1 @ X0).T), Y0.T)

    dist = torch.linalg.norm(W1 - (W1 @ X0 @ torch.linalg.pinv(X0)))

    return (
      np.array(train_error), 
      np.array(test_error), 
      np.array(bound0), 
      np.array(bound1), 
      np.array(bound2),
      np.array(dist), 
      np.array(op_norms), 
      np.int64(train_seed), # The seed that produced the accepted training set
      np.array(test_seed), 
      None
    )


In [None]:
def trainedNetwork(n, M, bias = False, from_random=True,
                   l=0.1, lr_from_contructors = 0.1, wd=0.0,  batch_size=1, epochs=100, print_loss=False, 
                   train_seed = 30, init_seed=None):
    '''
    Train a ReLUNetwork with layer widths `M` on `n` MNIST samples using SGD
    on the MSE loss.

    Starting from `train_seed`, the seed is incremented until the sampled
    training set contains all ten digits. The network is initialized randomly
    (seeded with `init_seed`, or with a fresh torch seed when it is None).
    If `from_random` is False, the weight matrices are then overwritten with
    those of zeroLossNetwork(M[1:-1], X0, Y0) and `lr_from_contructors` is
    used as learning rate instead of `l`.

    NOTE(review): with bias=True and from_random=False the biases keep their
    random initialization -- confirm whether they should be zeroed.

    If `print_loss` is True, 10 times the MSE on the full training set is
    printed every 10 epochs.

    Returns
    Trained ReLUNetwork(), train_seed used, seed used for parameter initialization

    Raises
    ValueError if n < 10 (such a training set can never contain all digits).
    '''
    if n < 10:
        # Otherwise the seed search below would never terminate
        raise ValueError("n must be >= 10 so that the training set can contain all ten digits")

    # Advance the seed until the training sample contains every digit
    while True:
        X0, Y0, L0 = train_imports(n, train_seed)
        if has_all_mnist_digits(L0):
            break
        train_seed += 1

    if init_seed is None:
        torch.seed()
    else:
        torch.manual_seed(init_seed)
    model0 = ReLUNetwork(M, bias)
    seed = torch.random.initial_seed()

    if not from_random:
        model = zeroLossNetwork(M[1:len(M)-1], X0, Y0)
        with torch.no_grad():
            model0.layers[0].weight.copy_(model.W1)
            for layer, w in zip(model0.layers[1:-1], model.W_hidden):
                layer.weight.copy_(w)
            # The output layer is copied exactly once, after the hidden
            # layers, so it is also prefilled when W_hidden is empty
            model0.layers[-1].weight.copy_(model.WLp1)
        l = lr_from_contructors

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model0.parameters(), lr = l, weight_decay = wd)

    dataset = torch.utils.data.TensorDataset(X0.T, Y0.T)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for e in range(epochs):
        for xb, yb in loader:
            optimizer.zero_grad()
            output = model0(xb)
            loss = loss_fn(output, yb)
            loss.backward()
            optimizer.step()
        if print_loss and e%10==0:
            print(loss_fn(model0(X0.T), Y0.T)*10)

    return model0, train_seed, seed

In [None]:
def trainedNetworkExperiment(model, train_seed, init_seed, 
                             n, m, M, 
                             test_seed = 10):
    """
    Evaluate a trained ReLUNetwork on the training set regenerated from
    (`n`, `train_seed`) and on a test set generated from (`m`, `test_seed`).

    Returns a 10-tuple:
      train error, test error,
      bound0 = max(1, Lipf)^2 * Chamfer distance between test and train data,
      bound1 = max(1, Lipg)^2 * Chamfer distance after applying W1,
      bound2 = max(1, Lipg)^2 * Chamfer distance after applying ReLU(W1 .),
      ||W1 - W1 X0 pinv(X0)|| (Frobenius norm),
      operator norms of all weight matrices,
      train_seed, test_seed, init_seed.
    Lipf is the product of the operator norms of all layers; Lipg is the
    operator norm of the product of all layers after the first.

    `init_seed` is passed through as an exact Python int: torch seeds can use
    all 64 bits, which a float64 cannot represent without rounding. When it
    is None, a NaN array is returned in its place.
    """
    X, Y, L = test_imports(m, test_seed)
    X0, Y0, L0 = train_imports(n, train_seed)

    model0 = model

    # Everything below is a metric; no autograd graph is needed
    with torch.no_grad():
        train_error = trained_model_error(X0, Y0, model0)
        test_error = trained_model_error(X, Y, model0)

        W1 = model0.layers[0].weight

        # Operator norms of the weight matrices and their product
        op_norms = [torch.linalg.norm(layer.weight, ord=2).item() for layer in model0.layers]
        Lipf = math.prod(op_norms)

        # Product of the weight matrices, excluding W1
        prod_weights_g = torch.eye(M[1])
        for i in range(1, len(model0.layers)):
            prod_weights_g = model0.layers[i].weight @ prod_weights_g
        Lipg = torch.linalg.norm(prod_weights_g, ord = 2)

        bound0 = (max(1, Lipf))**2 * chamfer_distance_labeled(X.T, Y.T, X0.T, Y0.T)
        bound1 = (max(1, Lipg))**2 * chamfer_distance_labeled((W1 @ X).T, Y.T, (W1 @ X0).T, Y0.T)
        bound2 = (max(1, Lipg))**2 * chamfer_distance_labeled(F.relu((W1 @ X).T), Y.T, F.relu((W1 @ X0).T), Y0.T)

        dist = torch.linalg.norm(W1 - (W1 @ X0 @ torch.linalg.pinv(X0)))

    return (
      _to_numpy(train_error), 
      _to_numpy(test_error), 
      _to_numpy(bound0), 
      _to_numpy(bound1), 
      _to_numpy(bound2),
      _to_numpy(dist), 
      _to_numpy(op_norms), 
      _to_numpy(train_seed),   
      _to_numpy(test_seed), 
      _to_numpy(init_seed) if init_seed is None else int(init_seed)
    )

___
### Generating data frames

Shorthands:
* ZL   = zeroLossNetworkExperiment
* RI   = trainedNetworkExperiment applied to a network from `trainedNetwork` with from_random=True
* TFZL = trainedNetworkExperiment applied to a network from `trainedNetwork` with from_random=False

The function `train_and_run_all_experiments` runs experiments and returns a data frame

In [None]:
def train_and_run_all_experiments(n_list, M,
                        num_RI_runs=10,
                        tol=1e-3,
                        TFZL=False, random_init=False,
                        train_seed=30, test_seed=10,
                        epochs = 100,
                        scenarios=("fixed_m", "m_eq_n"),
                        save_models=True):
    """
    Run RI, ZL, TFZL for:
      - scenario 'fixed_m': m = min(n_list)
      - scenario 'm_eq_n':  m = n

    For every (scenario, n) a ZL experiment is always run; a TFZL network is
    trained and evaluated if `TFZL` is true, and `num_RI_runs` randomly
    initialized networks are trained and evaluated if `random_init` is true.
    The train/test seeds returned by each experiment are carried over to the
    next one, as in the seed search of the experiment functions.

    `tol` is accepted for interface compatibility and is not used here.

    NOTE(review): networks are retrained for every scenario and the saved
    file names do not contain the scenario, so files written for an earlier
    scenario are overwritten by a later one.

    Returns
    df : pandas.DataFrame
        One row per run, including all outputs and parameters.
    """
    def _row(scenario, exp_type, n, m, run_idx, outputs, init_seed):
        """Turn the 10-tuple returned by an experiment into one result row."""
        Etrain, Etest, b0, b1, b2, dist, op_norms, used_train_seed, used_test_seed, _ = outputs
        return dict(
            scenario=scenario,
            exp_type=exp_type,
            n=n,
            m=m,
            M=M,
            run_idx=run_idx,
            train_seed=used_train_seed,
            test_seed=used_test_seed,
            init_seed=init_seed,
            Etrain=float(Etrain),
            Etest=float(Etest),
            b0=float(b0),
            b1=float(b1),
            b2=float(b2),
            dist=dist,
            op_norm_W1=op_norms[0],
            op_norms=op_norms,
        )

    results = []
    m_fixed = min(n_list)

    # To keep track of network training progress
    nets_to_train = len(scenarios)*len(n_list)*(num_RI_runs*int(random_init) + int(TFZL))
    trained = 0

    for scenario in scenarios:
        for n in n_list:
            m = m_fixed if scenario == "fixed_m" else n

            # ZL: zeroLossNetworkExperiment (single run per (n,m))
            outputs = zeroLossNetworkExperiment(
                n, m, M, train_seed=train_seed, test_seed=test_seed
            )
            train_seed, test_seed = outputs[7], outputs[8]
            results.append(_row(scenario, "ZL", n, m, 0, outputs, outputs[9]))

            if TFZL:
                # TFZL: trainedNetwork with from_random=False (single run per (n, m))
                model, train_seed, init_seed = trainedNetwork(n, M, epochs=epochs, from_random=False, train_seed=train_seed)
                outputs = trainedNetworkExperiment(
                    model, train_seed, init_seed,
                    n, m, M,  test_seed
                )
                train_seed, test_seed = outputs[7], outputs[8]
                # Record the exact initialization seed reported by trainedNetwork
                results.append(_row(scenario, "TFZL", n, m, 0, outputs, init_seed))
                trained += 1
                print(f"{trained} out of {nets_to_train} networks trained; {datetime.now()}")
                if save_models:
                    torch.save(model.state_dict(), f"params_TFZL_{M}_{n}.pt")

            if random_init:
                # RI: trainedNetwork with from_random=True, multiple seeds
                for run_idx in range(num_RI_runs):
                    model, train_seed, init_seed = trainedNetwork(n, M, epochs=epochs, from_random=True, train_seed=train_seed)
                    outputs = trainedNetworkExperiment(
                        model, train_seed, init_seed,
                        n, m, M,  test_seed
                    )
                    train_seed, test_seed = outputs[7], outputs[8]
                    # Record the exact initialization seed reported by trainedNetwork
                    results.append(_row(scenario, "RI", n, m, run_idx, outputs, init_seed))
                    trained += 1
                    print(f"{trained} out of {nets_to_train} networks trained; {datetime.now()}")
                    if save_models:
                        torch.save(model.state_dict(), f"params_RI-{run_idx}_{M}_{n}.pt")

    df = pd.DataFrame(results)
    return df

In [None]:
def train_nets(n_list, M, train_seed=30, 
               TFZL=False, 
               random_init=False, num_RI_runs=10,
               epochs = 100):
    """
    Train networks for every n in `n_list` and save their parameters to disk.

    For each n: one network prefilled from the zero loss construction if
    `TFZL` is true, and `num_RI_runs` randomly initialized networks if
    `random_init` is true. Progress is printed after each trained network.

    NOTE(review): the directories written here ("models/" and
    "models/{M}/{n}") differ from the ones `networks_to_df` reads
    ("models/{M}_trainseed.../TFZL" and "models/{M}_trainseed.../{n}") --
    confirm how the files are meant to be moved between the two.
    """
    total = len(n_list)*(num_RI_runs*int(random_init) + int(TFZL))
    done = 0

    for n in n_list:
        if TFZL:
            net, _, _ = trainedNetwork(n, M, from_random=False, train_seed=train_seed, epochs=epochs)
            os.makedirs("models", exist_ok=True)
            torch.save(net.state_dict(), f"models/params_TFZL_{M}_n{n}_e{epochs}.pt")

            done += 1
            print(f"{done} out of {total} networks trained; {datetime.now()}")

        if not random_init:
            continue

        for _run in range(num_RI_runs):
            net, _, seed = trainedNetwork(n, M, from_random=True, train_seed=train_seed, epochs=epochs)
            os.makedirs(f"models/{M}/{n}", exist_ok=True)
            # The initialization seed is part of the file name
            torch.save(net.state_dict(), f"models/{M}/{n}/params_RI_{M}_n{n}_e{epochs}_s{seed}.pt")

            done += 1
            print(f"{done} out of {total} networks trained; {datetime.now()}")

In [None]:
def networks_to_df(n_list,
                        M, bias=False,
                        num_RI_runs=10,
                        TFZL=True, random_init=True,
                        train_seed=30, test_seed=10,
                        epochs = 100,
                        scenarios=("fixed_m", "m_eq_n")):
    """
    Run RI, ZL, TFZL for:
      - scenario 'fixed_m': m = min(n_list)
      - scenario 'm_eq_n':  m = n

    ZL networks are constructed on the fly; TFZL and RI networks are loaded
    from previously saved parameter files under
    "models/{M}_trainseed{train_seed}/". RI files are processed in sorted
    name order and must end in "_s<digits>.pt"; the digits are recorded as
    the run's `init_seed`.

    Returns
    -------
    df : pandas.DataFrame
        One row per run, including all outputs and parameters.

    Raises
    ------
    ValueError if a file in an RI directory has no "_s<digits>.pt" suffix.
    """
    def _row(scenario, exp_type, n, m, run_idx, outputs, init_seed):
        """Turn the 10-tuple returned by an experiment into one result row."""
        Etrain, Etest, b0, b1, b2, dist, op_norms, used_train_seed, used_test_seed, _ = outputs
        return dict(
            scenario=scenario,
            exp_type=exp_type,
            n=n,
            m=m,
            M=M,
            run_idx=run_idx,
            train_seed=used_train_seed,
            test_seed=used_test_seed,
            init_seed=init_seed,
            Etrain=float(Etrain),
            Etest=float(Etest),
            b0=float(b0),
            b1=float(b1),
            b2=float(b2),
            dist=dist,
            op_norm_W1=op_norms[0],
            op_norms=op_norms,
        )

    results = []
    m_fixed = min(n_list)

    for scenario in scenarios:
        for n in n_list:
            m = m_fixed if scenario == "fixed_m" else n

            # ZL: zeroLossNetworkExperiment (single run per (n,m))
            outputs = zeroLossNetworkExperiment(
                n, m, M, train_seed=train_seed, test_seed=test_seed
            )
            train_seed, test_seed = outputs[7], outputs[8]
            results.append(_row(scenario, "ZL", n, m, 0, outputs, outputs[9]))

            if TFZL:
                # TFZL: trainedNetworkExperiment with from_random=False (single run)
                model = ReLUNetwork(M, bias)
                model.load_state_dict(torch.load(f"models/{M}_trainseed{int(train_seed)}/TFZL/params_TFZL_{M}_n{n}_e{epochs}.pt", weights_only=True))
                outputs = trainedNetworkExperiment(
                    model, train_seed, None, n, m, M, test_seed
                )
                train_seed, test_seed = outputs[7], outputs[8]
                results.append(_row(scenario, "TFZL", n, m, 0, outputs, outputs[9]))

            if random_init:
                # RI: trainedNetworkExperiment with from_random=True, multiple seeds
                directory_path = f"models/{M}_trainseed{int(train_seed)}/{n}"
                # Sorted so that run indices and the selected files are reproducible
                for run_idx, file in enumerate(sorted(os.listdir(directory_path))):
                    # Validate the file name (not the test size m) before loading
                    seed_match = re.search(r"_s(\d+)\.pt$", file)
                    if not seed_match:
                        raise ValueError(f"No '_s<digits>.pt' seed found in: {file}")
                    init_seed = int(seed_match.group(1))

                    full_path = os.path.join(directory_path, file)
                    model = ReLUNetwork(M, bias)
                    model.load_state_dict(torch.load(full_path, weights_only=True))

                    outputs = trainedNetworkExperiment(
                        model, train_seed, init_seed, n, m, M, test_seed
                    )
                    train_seed, test_seed = outputs[7], outputs[8]
                    # Store the seed parsed from the file name, exactly
                    results.append(_row(scenario, "RI", n, m, run_idx, outputs, init_seed))
                    if run_idx + 1 == num_RI_runs: break

    df = pd.DataFrame(results)
    return df

In [None]:
def _aggregate_by_n(df, n_list, cols):
    """
    Helper: for a DataFrame already filtered by scenario and exp_type,
    return ns (subset of n_list that exists in df) and mean/std for given cols.
    """
    if df.empty:
        return [], {c: np.array([]) for c in cols}, {c: np.array([]) for c in cols}

    g = df.groupby("n")
    ns = [n for n in n_list if n in g.indices]

    means = {}
    stds = {}
    for col in cols:
        means[col] = np.array([g.get_group(n)[col].mean() for n in ns])
        stds[col]  = np.array([g.get_group(n)[col].std(ddof=0) for n in ns])
    return ns, means, stds

### Plotting functions

In [None]:
plt.rcParams.update({'font.size': 12})  # Global matplotlib setting: font size 12 for all plots created afterwards

Produces 3 plots of test error and bound(s) computed, one for each of (ZL, TFZL, RI):

In [None]:
def plot_Etest_and_bounds(df, n_list, tol=1e-3, scenario="fixed_m", save=False):
    """
    1) Single plot of Etest and the bound (column b1) vs n for RI (good seeds: Etrain < tol)
    2) Single plot of Etest and the bound (column b1) vs n for ZL
    3) Single plot of Etest and the bound (column b1) vs n for TFZL

    RI curves show mean and standard deviation across runs as error bars;
    ZL and TFZL curves show the mean only. If an experiment type has no rows
    for the scenario, a message is printed instead of a plot.

    scenario: 'fixed_m' or 'm_eq_n' (default 'fixed_m')
    save: if true, each figure is also written to the "plots" directory.
    """
    cols = ["Etrain", "Etest", "b1"]
    test_label = r"$\mathcal{E}^{test}$"

    def _m_label():
        """Text describing the test set size for the title, or None for other scenarios."""
        if scenario == "fixed_m":
            return f"m={min(n_list)}"
        if scenario == "m_eq_n":
            return "m=n"
        return None

    def _draw(ns, means, stds, title_head, fname):
        """Draw one Etest-vs-bound figure; error bars are used when stds is given."""
        fig, ax = plt.subplots(figsize=(7, 5))
        if stds is None:
            ax.plot(ns, means["Etest"], "o-", label=test_label)
            ax.plot(ns, means["b1"], "d-", label="Bound")
        else:
            ax.errorbar(ns, means["Etest"], yerr=stds["Etest"],
                        fmt="o-", label=test_label, capsize=3)
            ax.errorbar(ns, means["b1"], yerr=stds["b1"],
                        fmt="d-", label="Bound", capsize=3)
        ax.set_xlabel("n")
        m_label = _m_label()
        if m_label is not None:
            ax.set_title(f"{title_head}{m_label})")
        ax.grid(True)
        ax.legend()
        plt.tight_layout()

        if save:
            os.makedirs("plots", exist_ok=True)
            fig.savefig(fname, dpi=200)

        plt.show()

    in_scenario = df["scenario"] == scenario

    # ----- RI (good seeds) -----
    df_RI = df[in_scenario & (df["exp_type"] == "RI") & (df["Etrain"] < tol)]
    ns_RI, means_RI, stds_RI = _aggregate_by_n(df_RI, n_list, cols)
    if ns_RI:
        # Raw f-string: the LaTeX backslashes must not be read as escapes
        _draw(ns_RI, means_RI, stds_RI,
              rf"Random initialization ($\mathcal{{E}}^{{train}} \leq {tol}$, ",
              f"plots/Etest_b1_b2_RI_tol{tol:g}_{scenario}.png")
    else:
        print(f"No RI runs with Etrain < {tol} for scenario '{scenario}'.")

    # ----- ZL -----
    df_ZL = df[in_scenario & (df["exp_type"] == "ZL")]
    ns_ZL, means_ZL, _ = _aggregate_by_n(df_ZL, n_list, cols)
    if ns_ZL:
        _draw(ns_ZL, means_ZL, None,
              "Zero loss network (",
              f"plots/Etest_b1_b2_ZL_{scenario}.png")
    else:
        print(f"No ZL runs for scenario '{scenario}'.")

    # ----- TFZL -----
    df_TFZL = df[in_scenario & (df["exp_type"] == "TFZL")]
    ns_TFZL, means_TFZL, _ = _aggregate_by_n(df_TFZL, n_list, cols)
    if ns_TFZL:
        _draw(ns_TFZL, means_TFZL, None,
              "Trained from zero loss network (",
              f"plots/Etest_b1_b2_TFZL_{scenario}.png")
    else:
        print(f"No TFZL runs for scenario '{scenario}'.")

Single plot of Etest for ZL, TFZL and RI:

In [None]:
def plot_compare_Etest(df, n_list, scenario="fixed_m", tol=1e-3, save=False):
    """
    Single plot: Etest vs n for
      - ZL
      - TFZL
      - RI (good seeds: Etrain < tol)

    ZL and TFZL show the mean per n; RI shows mean ± std across runs.
    RI rows that fail the Etrain filter are not drawn, but their presence
    still counts as "data available" for the empty-data check.

    scenario: 'fixed_m' or 'm_eq_n'
    save: if true, the figure is also written to the "plots" directory.
    """
    cols = ["Etest"]
    in_scenario = df["scenario"] == scenario
    is_RI = in_scenario & (df["exp_type"] == "RI")

    # ZL
    ns_ZL, means_ZL, _ = _aggregate_by_n(df[in_scenario & (df["exp_type"] == "ZL")], n_list, cols)

    # TFZL
    ns_TFZL, means_TFZL, _ = _aggregate_by_n(df[in_scenario & (df["exp_type"] == "TFZL")], n_list, cols)

    # RI (all runs): only used to decide whether there is anything to show
    ns_RI_all, _, _ = _aggregate_by_n(df[is_RI], n_list, cols)

    # RI (good seeds)
    ns_RI_good, means_RI_good, stds_RI_good = _aggregate_by_n(df[is_RI & (df["Etrain"] < tol)], n_list, cols)

    if not (ns_ZL or ns_TFZL or ns_RI_all or ns_RI_good):
        print(f"No ZL/TFZL/RI data for scenario '{scenario}'.")
        return

    fig, ax = plt.subplots(figsize=(7, 5))

    if ns_ZL:
        ax.plot(ns_ZL, means_ZL["Etest"], "o-", label="ZL")
    if ns_TFZL:
        ax.plot(ns_TFZL, means_TFZL["Etest"], "x--", label="TFZL")
    if ns_RI_good:
        # Raw f-strings below: the LaTeX backslashes must not be read as escapes
        ax.errorbar(
            ns_RI_good,
            means_RI_good["Etest"],
            yerr=stds_RI_good["Etest"],
            fmt="x--",
            label=rf"RI ($\mathcal{{E}}^{{train}}$ < {tol})", capsize=3,
        )

    ax.set_xlabel("n")
    ax.set_ylabel(r"$\mathcal{E}^{test}$")
    if scenario == "fixed_m":
        ax.set_title(rf"Comparison of test error $\mathcal{{E}}^{{test}}$ (m={min(n_list)})")
    elif scenario == "m_eq_n":
        ax.set_title(r"Comparison of test error $\mathcal{E}^{test}$ (m=n)")
    ax.grid(True)
    ax.legend()
    plt.tight_layout()

    if save:
        os.makedirs("plots", exist_ok=True)
        fname = f"plots/compare_Etest_all_{scenario}.png"
        fig.savefig(fname, dpi=200)

    plt.show()


Produces a single plot for chosen bounds for RI and ZL:

In [None]:
def plot_all_bounds(df, n_list, scenario="fixed_m", tol=1e-3, save=False, ZLb0=True):
    """
    Single plot of the b1 bound vs n for RI and for ZL.

    Only the "b1" column is aggregated and drawn; the b2 and b0 curves are
    left commented out in the body.

    - RI uses good seeds only: Etrain < tol. It is drawn with error bars
      taken from the third value returned by _aggregate_by_n.
    - ZL is drawn as a plain line (no error bars).

    df: data frame; must contain at least the columns "scenario",
        "exp_type", "Etrain" and "b1"
    n_list: training set sizes, forwarded to _aggregate_by_n; for
        scenario "fixed_m", min(n_list) is shown as m in the title
    scenario: 'fixed_m' or 'm_eq_n'
    tol: training error threshold used to select the RI rows
    save: if truthy, the figure is written to
        plots/all_bounds_RI_ZL_{scenario}_ZLb0-{ZLb0}.png
    ZLb0: currently only changes the output file name, because the
        ZL b0 curve is commented out

    Prints a message and returns without plotting when neither RI nor ZL
    has any aggregated data for the scenario.
    """
    cols = ["b1"]
    in_scenario = df["scenario"] == scenario

    # RI (good seeds)
    df_RI = df[in_scenario & (df["exp_type"] == "RI") & (df["Etrain"] < tol)]
    ns_RI, means_RI, stds_RI = _aggregate_by_n(df_RI, n_list, cols)

    # ZL
    df_ZL = df[in_scenario & (df["exp_type"] == "ZL")]
    ns_ZL, means_ZL, _ = _aggregate_by_n(df_ZL, n_list, cols)

    if not ns_RI and not ns_ZL:
        print(f"No RI/ZL data for scenario '{scenario}'.")
        return

    fig, ax = plt.subplots(figsize=(8, 5))

    if ns_RI:
        # ax.errorbar(ns_RI, means_RI["b2"], yerr=stds_RI["b2"], fmt="o-", label="RI b2", capsize=3)
        ax.errorbar(ns_RI, means_RI["b1"], yerr=stds_RI["b1"], fmt="s-", label="RI bound", capsize=3)
        # ax.errorbar(ns_RI, means_RI["b0"], yerr=stds_RI["b0"], fmt="x--", label="RI b0", capsize=3)

    if ns_ZL:
        # ax.plot(ns_ZL, means_ZL["b2"], "o--", label="ZL b2")
        ax.plot(ns_ZL, means_ZL["b1"], "s--", label="ZL bound")
        # if ZLb0:
        #     ax.plot(ns_ZL, means_ZL["b0"], "d--", label="ZL b0")

    ax.set_xlabel("n")
    ax.set_ylabel("Value")
    # No title is set for any other scenario value.
    if scenario == "fixed_m":
        ax.set_title(f"All bounds (m={min(n_list)})")
    elif scenario == "m_eq_n":
        ax.set_title("All bounds (m=n)")
    ax.grid(True)
    ax.legend()
    plt.tight_layout()

    if save:
        os.makedirs("plots", exist_ok=True)
        fname = f"plots/all_bounds_RI_ZL_{scenario}_ZLb0-{ZLb0}.png"
        fig.savefig(fname, dpi=200)

    plt.show()

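Reads the saved data frames `./experiments/exp17/df_exp17_{M}.csv` for each architecture `M` in `M_list` (defined in the *Experiment setup* cell below) and plots the chosen column (e.g. `Etest` or `b1`) against `n_list`, one curve per `M`: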

In [None]:
def plot_compare_archits(col, M_list, n_list, scenario="fixed_m", tol=1e-3, save=False):
    '''
    Compare one column across architectures, using RI good seeds only.

    For every M in M_list the data frame saved at
    ./experiments/exp17/df_exp17_{M}.csv is read, restricted to the given
    scenario, exp_type "RI" and Etrain < tol, aggregated with
    _aggregate_by_n, and drawn as one error-bar curve labelled with M.

    col: name of the column to plot (e.g. "Etest" or "b1")
    M_list: architectures; each entry is inserted verbatim into the file name
    n_list: training set sizes, forwarded to _aggregate_by_n; for
        scenario "fixed_m", min(n_list) is shown as m in the title
    scenario: 'fixed_m' or 'm_eq_n'
    tol: training error threshold used to select the RI rows
    save: if truthy, the figure is written to
        plots/bound_RI_different_depths_{scenario}.png (the file name
        does not depend on col)

    Architectures with no rows left after filtering are reported and skipped.
    '''
    cols = [col]
    fig, ax = plt.subplots(figsize=(7, 4))

    # RI (good seeds)
    for M in M_list:
        df = pd.read_csv(f'./experiments/exp17/df_exp17_{M}.csv')   # Input file
        df_RI = df[(df["scenario"] == scenario) &
                   (df["exp_type"] == "RI") &
                   (df["Etrain"] < tol)]
        ns_RI, means_RI, stds_RI = _aggregate_by_n(df_RI, n_list, cols)
        if not ns_RI:
            # Same guard the other plotting functions apply before indexing the means.
            print(f"No RI data with Etrain < {tol} for architecture {M} in scenario '{scenario}'.")
            continue
        ax.errorbar(ns_RI, means_RI[col], yerr=stds_RI[col], label=f"{M}", capsize=3)

    ax.set_xlabel("n")
    # ax.set_ylabel("Value")
    # No title is set for any other scenario value.
    if scenario == "fixed_m":
        # Raw f-string: "\m" is an invalid escape sequence in a normal string literal.
        ax.set_title(rf"Bound for RI ($\mathcal{{E}}^{{train}}$ < {tol}, m={min(n_list)})")
    elif scenario == "m_eq_n":
        ax.set_title("Bound for different depths (m=n)")
    ax.grid(True)
    ax.legend()
    plt.tight_layout()

    if save:
        os.makedirs("plots", exist_ok=True)
        fname = f"plots/bound_RI_different_depths_{scenario}.png"
        fig.savefig(fname, dpi=200)

    plt.show()

___
## Running experiments

### Experiment parameters 

Parameters (all of them must be defined in the *Experiment setup* cell below):
* `M` = list, architecture of network. Enter all numbers from M0 to Q (inclusive) as a list
* `bias` = boolean, `default=False`
* `n_list` = list of int, training set sizes to be considered
* `train_seed` = int, seed for generating the training set, `default=30`
* `test_seed` = int, seed for generating the test set, `default=10`
* `scenarios` = list, possible elements are "fixed_m" (m = min(n_list)) and "m_eq_n" (m=n)
* `TFZL` = boolean, if considering TFZL in this experiment
* `random_init` = boolean, if considering RI in this experiment
* `num_RI_runs`= int, how many random initializations for RI
* `epochs` = int, used 70 in paper
* `tol` = float, training error tolerance for randomly initialized networks
* `M_list` = list of `M`s, for comparing architectures

In [None]:
# Experiment setup

# --- Training ---
M = [784, 784, 10]  # network architecture
bias = False
# Training set sizes. The manuscript plots used
# [20, 60, 100, 200, 300, 400, 475, 500, 525, 550].
n_list = [20, 30]
train_seed = 30
TFZL = True
random_init = True
# Number of runs for the randomly initialized network
# (between 2 and 10 for the published experiments).
num_RI_runs = 1
epochs = 10  # 70 in paper

# --- Computing quantities of interest ---
test_seed = 10
scenarios = ["fixed_m"]
tol = 1e-3  # a randomly initialized trained net is kept when Etrain < tol

# --- Comparing architectures ---
# Alternative list with decreasing widths:
# M_list = ['[784, 600, 10]', '[784, 600, 100, 10]', '[784, 600, 300, 100, 10]', '[784, 600, 300, 200, 100, 10]', '[784, 600, 300, 200, 100, 50, 10]']
M_list = [
    "[784, 784, 10]",
    "[784, 784, 784, 10]",
    "[784, 784, 784, 784, 10]",
]

### If running new experiments

Train networks (saving the model parameters if `save_models=True`), run all experiments, and save the results to a data frame:

In [None]:
# Train, run every experiment and collect the results in `df`.
# All arguments come from the "Experiment setup" cell; model parameters
# are not saved in this call (save_models=False).
df = train_and_run_all_experiments(
    M=M,
    n_list=n_list,
    scenarios=scenarios,
    train_seed=train_seed,
    test_seed=test_seed,
    TFZL=TFZL,
    random_init=random_init,
    num_RI_runs=num_RI_runs,
    epochs=epochs,
    tol=tol,
    save_models=False,
)

Train networks and save models as .pt files:

In [None]:
# train_nets(n_list, M, train_seed=train_seed, 
#                TFZL=False, 
#                random_init=random_init, num_RI_runs=num_RI_runs,
#                epochs = epochs)

Search for saved model parameters (.pt), compute quantities of interest and save to data frame:

In [None]:
# df = networks_to_df(n_list, M, bias, num_RI_runs, TFZL, random_init, train_seed, test_seed, epochs, scenarios)

Save df to a spreadsheet:

In [None]:
# df.to_csv(f'df.csv')

### If loading previously run experiment data

If reading a previously generated data frame, load it here:

In [None]:
# df = pd.read_csv('df.csv')

Concatenating existing saved data frame (df2) with newly generated one (df1):

In [None]:
# df2 = pd.read_csv('other-file.csv')
# new_df = pd.concat([df1, df2], ignore_index=True)

# df = new_df.loc[new_df.astype(str).drop_duplicates().index]

## Save new df to csv
# df.to_csv('new-df.csv')

### Plotting results 

In [None]:
save = False  # set to True to also write the plots to .png files

for scenario in scenarios:
    # Etest + bound(s) plots for RI, ZL, TFZL
    plot_Etest_and_bounds(df, n_list, scenario=scenario, tol=tol, save=save)

    # Compare Etest for RI, ZL and TFZL
    plot_compare_Etest(df, n_list, scenario=scenario, tol=tol, save=save)

The following function takes as input data frames saved as `./experiments/exp17/df_exp17_{M}.csv` for architectures `M` in `M_list` above and plots the chosen column (say, Etest or b1) vs $n \in$ `n_list` for each `M`.

In [None]:
save = False  # the comparison plot is only written to disk when this is True
# Uncomment to compare architectures on column 'b1'.
# NOTE: `scenario` here is the loop variable left over from the
# `for scenario in scenarios` loop in the plotting cell, i.e. the last scenario.
# plot_compare_archits('b1', M_list, n_list, scenario=scenario, tol=tol, save=save)

___
## Other computations

In [None]:
# for seed in [30, 40]:
#     print(f"Training seed {seed}")
#     for n in n_list:
#         X, _, _ = train_imports(n, seed)
#         print(f"Rank for n={n} = {torch.linalg.matrix_rank(X)}; ratio = {torch.linalg.matrix_rank(X)/n:.2f}")
#     print()