<a href="https://colab.research.google.com/github/rsonthal/Low-Rank-Gradient/blob/main/Eigenvector_Alignment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt # For optional plotting later
import numpy as np
from tqdm import tqdm
import copy

In [None]:
# Create tensors on the GPU by default when one is available; fall back to CPU
# so that tensor creation does not fail on CPU-only runtimes.
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Colab-only: mount Google Drive at /content/drive. The experiment cells below
# write figures under paths that start with "drive/MyDrive/".
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
seed = 42

# --- Device Setup ---
# Prefer the GPU when present; every later cell can refer to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    print("Using CPU")

# --- Seeding ---
# Seed torch and numpy; the CUDA generators are seeded only when a GPU is in use.
torch.manual_seed(seed)
np.random.seed(seed)
if device.type == "cuda":
    torch.cuda.manual_seed_all(seed)

# --- Activation Function ---
# Callable aliases handed to network_forward / train_track.
relu_act, sigmoid_act, tanh_act = torch.relu, torch.sigmoid, torch.tanh

def relative_projection_norm(W, X, y, q):
  """
  Computes ||P_S(W)||_F / ||W||_F, where S = span(X^T*y, q) and P_S projects
  the rows of W onto S.

  The orthonormal basis of S is taken from an SVD with a rank cutoff, so that
  (nearly) parallel spanning vectors yield a one-dimensional S. A reduced QR
  always returns as many columns as it is given, which would add a spurious
  direction in that case and overestimate the ratio.

  Args:
    W (torch.Tensor): Matrix to project (m, d).
    X (torch.Tensor): Data matrix (n, d).
    y (torch.Tensor): Target vector (n,) or (n, 1).
    q (torch.Tensor): Second spanning vector (d,) or (d, 1).

  Returns:
    float: Relative Frobenius norm of the projection. Returns 0.0 if ||W||_F
    is (numerically) zero or if both spanning vectors are (numerically) zero.
  """
  if y.ndim == 1: y = y.unsqueeze(1)
  if q.ndim == 1: q = q.unsqueeze(1)
  v1 = X.T @ y # Shape (d, 1)

  # Drop (numerically) zero vectors and normalise the rest, so the rank
  # decision below depends on the angle between them and not on their scales.
  unit_vecs = []
  for v in (v1, q):
    v_norm = torch.linalg.norm(v)
    if v_norm > 1e-9:
      unit_vecs.append(v / v_norm)
  if not unit_vecs: return 0.0
  basis_mat = torch.cat(unit_vecs, dim=1)

  # Orthonormal basis Q (d, r) of the span, r = numerical rank (<= 2).
  U, S, _ = torch.linalg.svd(basis_mat, full_matrices=False)
  tol = S.max() * max(basis_mat.shape) * torch.finfo(S.dtype).eps
  Q = U[:, S > tol]

  # Norm of projection || P_S(W) ||_F = || W @ Q ||_F
  norm_proj = torch.linalg.norm(W @ Q)
  norm_W = torch.linalg.norm(W)

  return (norm_proj / norm_W).item() if norm_W > 1e-9 else 0.0

Using GPU: NVIDIA A100-SXM4-40GB


In [None]:
def operator_norm(A):
  """Operator norm of A, taken to be the estimate from largest_singular_value."""
  sigma_max = largest_singular_value(A)
  return sigma_max

def largest_singular_value(A, num_iterations=200):
    """
    Approximates the largest singular value of a matrix using power iteration
    on A^T A from a random start vector.

    Args:
        A: The input matrix (torch.Tensor), shape (m, n).
        num_iterations: The number of iterations to perform.

    Returns:
        0-dim torch.Tensor with the estimated largest singular value. For an
        all-zero matrix the result is 0 (not NaN). With num_iterations <= 0
        the result is ||A v|| for the random unit start vector v.
    """

    # Match A's dtype as well as its device; otherwise a float64 matrix fails in A @ v.
    v = torch.randn(A.shape[1], device=A.device, dtype=A.dtype)  # Random initialization
    v = v / torch.norm(v)

    sigma = None
    for _ in range(num_iterations):
        u = A @ v
        sigma = torch.norm(u)
        w = A.T @ u
        w_norm = torch.norm(w)
        if w_norm == 0:
            # A^T A v == 0 implies A v == 0, so the estimate is 0; normalising
            # w would only produce NaNs.
            break
        v = w / w_norm

    if sigma is None:
        sigma = torch.norm(A @ v)
    return sigma  # Largest singular value approximation

In [None]:
import torch

def principal_angles(A: torch.Tensor,
                     B: torch.Tensor,
                     *,
                     rtol=1e-12,
                     atol=1e-14):
    """
    Principal angles (radians) between the column spaces of A and B.
    Returns the angles in *ascending* order.

    The routine
      • brings A and B to a common dtype (float32 is promoted to float64)
      • keeps only the numerically significant singular vectors
      • uses SVD (more stable than unpivoted QR)

    Args:
        A, B: matrices with the same number of rows.
        rtol, atol: a singular value s counts as non-zero when
            s > atol + rtol * s_max.

    Returns:
        1-D tensor of angles; empty when either column space is {0}.

    Raises:
        ValueError: if A and B have different numbers of rows.
    """

    if A.shape[0] != B.shape[0]:
        raise ValueError("A and B must have the same number of rows")

    # Use one dtype for both inputs (mixed float32/float64 would otherwise fail
    # in the matmul below) and promote single precision to double.
    common_dtype = torch.promote_types(A.dtype, B.dtype)
    if common_dtype == torch.float32:
        common_dtype = torch.float64
    A = A.to(common_dtype)
    B = B.to(common_dtype)

    # --- orthonormal bases --------------------------------------------------
    #
    #   X = U Σ V^H   (thin SVD)
    #   keep the columns of U whose singular value exceeds the cutoff
    #
    def orthonormal_basis(X):
        U, S, _ = torch.linalg.svd(X, full_matrices=False)
        cutoff = atol + rtol * S.max()
        rank = (S > cutoff).sum().item()  # numerical rank
        return U[:, :rank]                # m × rank  (possibly rank = 0)

    UA, UB = map(orthonormal_basis, (A, B))
    if UA.shape[1] == 0 or UB.shape[1] == 0:
        # One of the spaces is {0}; caller can decide what to do
        return torch.empty(0, dtype=A.dtype, device=A.device)

    # --- cosines of principal angles ---------------------------------------
    C = UA.conj().T @ UB                       # r_A × r_B
    cosines = torch.linalg.svdvals(C)          # singular values, descending
    cosines = torch.clamp(cosines, 0.0, 1.0)   # safety against roundoff

    angles = torch.acos(cosines)
    angles, _ = torch.sort(angles)             # explicit ascending order
    return angles


In [None]:
def activation(x, act = "relu"):
  """
  Applies the named activation elementwise.

  Args:
    x (torch.Tensor): Input tensor.
    act (str): One of "relu", "sigmoid", "tanh".

  Returns:
    torch.Tensor: Activated tensor with the same shape as x.

  Raises:
    ValueError: If `act` is not a supported name (previously this silently
      returned None, which only failed later at the call site).
  """
  if act == "relu":
    return torch.relu(x)
  elif act == "sigmoid":
    return torch.sigmoid(x)
  elif act == "tanh":
    return torch.tanh(x)
  raise ValueError(f"Unknown activation '{act}'; expected 'relu', 'sigmoid' or 'tanh'")

def activation_derivative(x, act = "relu"):
  """
  Elementwise derivative of the named activation, evaluated at x.

  Args:
    x (torch.Tensor): Pre-activation values.
    act (str): One of "relu", "sigmoid", "tanh".

  Returns:
    torch.Tensor: Derivative values with the same shape as x. For "relu" the
    0/1 indicator keeps x's floating dtype (the default dtype is used for
    non-floating input) instead of always being float32.

  Raises:
    ValueError: If `act` is not a supported name.
  """
  if act == "relu":
    out_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    return (x > 0).to(out_dtype)
  elif act == "sigmoid":
    # sigma'(x) = sigma(x) * (1 - sigma(x))
    z = torch.sigmoid(x)
    return z * (1-z)
  elif act == "tanh":
    # tanh'(x) = 1 - tanh(x)^2
    z = torch.tanh(x)
    return 1 - z**2
  raise ValueError(f"Unknown activation '{act}'; expected 'relu', 'sigmoid' or 'tanh'")

def gen_data(n, d, zeta, q, alpha = 0, return_all = False):
  """
  Draws n samples of "bulk + spike" data: X = X_B + zeta * z q^T.

  X_B has independent Gaussian entries whose j-th column (1-indexed) is scaled
  by j**(-alpha/2); z is an (n, 1) standard Gaussian vector.

  Args:
    n (int): Number of samples.
    d (int): Input dimension.
    zeta (float): Spike strength.
    q (torch.Tensor): Spike direction, a d x 1 vector.
    alpha (float): Decay exponent of the bulk column scales.
    return_all (bool): If True, also return the components.

  Returns:
    X of shape (n, d), or the tuple (X, z, X_B, zeta * z q^T) when return_all.
  """
  # Scale columns by broadcasting rather than multiplying by a dense d x d
  # diagonal matrix.
  col_scales = torch.arange(1,d+1) ** (-1*alpha/2)
  X_B = torch.randn(n,d) * col_scales
  z = torch.randn(n,1)
  spike = zeta * z @ q.T   # computed once and reused below
  X = X_B + spike
  if not return_all:
    return X
  return X, z, X_B, spike

def get_y(X, f, tau = 1, loss_type = "MSE"):
  """
  Evaluates the teacher f on every row of X and post-processes the result.

  f is called once per sample with that sample as a column vector.

  - "bce":   the raw teacher outputs are returned.
  - "hinge": the sign of (output - 0.5) is returned.
  - "MSE" (and any other value): Gaussian noise of scale tau is added.
  """
  n_samples = X.shape[0]
  clean = torch.tensor([f(X[[i],:].T) for i in range(n_samples)])

  if loss_type == "bce":
    return clean
  if loss_type == "hinge":
    return (clean-0.5).sign()

  # Regression targets: noisy teacher outputs.
  noise = torch.randn(clean.shape) * tau
  return clean + noise

def loss(X, W, a, y, gamma, act = "relu", loss_type = "MSE"):
  """
  Loss of the two-layer network F = gamma * a^T act(W X^T) against targets y.

  loss_type selects half the mean squared error ("MSE"), binary cross-entropy
  on sigmoid(F) ("bce"), or the hinge loss ("hinge"). Any other value yields
  None.
  """
  hidden = activation(W @ X.T, act = act)
  F = (gamma * a.T) @ hidden

  if loss_type == "MSE":
    squared_error = (F - y)**2
    return torch.mean(squared_error)/2

  if loss_type == "bce":
    z = activation(F, act = "sigmoid")
    per_sample = -y * torch.log(z) - (1-y) * torch.log(1-z)
    return torch.mean(per_sample)

  if loss_type == "hinge":
    margin_gap = 1 - y * F
    return torch.mean(torch.relu(margin_gap))

  return None

def loss_derivative(X, W, a, y, gamma, act = "relu", loss_type = "MSE"):
  """
  Per-sample derivative term for the network output
  F = gamma * a^T act(W X^T).

  "MSE" gives F - y, "bce" gives sigmoid(F) - y, and "hinge" gives -y where
  y * F < 1 and 0 elsewhere. Any other loss_type yields None.
  """
  hidden = activation(W @ X.T, act = act)
  F = (gamma * a.T) @ hidden

  if loss_type == "MSE":
    return (F - y)
  if loss_type == "bce":
    return activation(F, act = "sigmoid") - y
  if loss_type == "hinge":
    inside_margin = y * F < 1
    return torch.where(inside_margin, -y, 0)
  return None

def get_mu(zeta, q, W, alpha, gamma, act = "sigmoid", loss_type = "MSE", n = 10000):
  """
  Monte-Carlo estimate, per row of W, of the mean activation derivative over n
  fresh samples from gen_data. `gamma` and `loss_type` are accepted but not
  used.
  """
  samples = gen_data(n, q.shape[0], zeta, q, alpha)
  derivs = activation_derivative(W @ samples.T, act = act)
  # Average over the sample axis.
  return derivs.mean(dim = 1)

def activation_derivative_perp(X, W, mu, act = "relu"):
  """Activation derivative at X W^T with mu (viewed as a row) subtracted from every row."""
  sigma_prime = activation_derivative(X @ W.T, act = act)
  mu_row = mu.view(1,-1)
  return sigma_prime - mu_row

def activation_derivative_S2(X, W, mu, act = "relu"):
  """Activation derivative at X W^T, without centering; `mu` is accepted but not used."""
  pre_activation = X @ W.T
  return activation_derivative(pre_activation, act = act)

def get_S1_exp(zeta, q, f, a, mu, alpha, gamma,  act = "relu", loss_type = "MSE", n = 10000, W = None):
  """
  Evaluates X^T r (a * mu)^T / n on n freshly drawn samples.

  Args:
    zeta, q, alpha: Data parameters forwarded to gen_data.
    f: Teacher function forwarded to get_y.
    a, mu: Column vectors combined elementwise as (a * mu).
    gamma, act, loss_type: Forwarded to loss_derivative.
    n (int): Number of fresh samples.
    W (torch.Tensor, optional): Weight matrix used for the residual. The
      function used to read a global `W` without declaring it; for backward
      compatibility the global is still used when this argument is omitted.

  Returns:
    torch.Tensor: The d x k matrix described above.

  Raises:
    NameError: If W is not given and no global `W` exists.
  """
  if W is None:
    W = globals().get("W")
    if W is None:
      raise NameError("get_S1_exp needs W: pass W=... or define a global W")
  d = q.shape[0]
  X = gen_data(n, d, zeta, q, alpha)
  # NOTE(review): get_y is called with its default loss_type ("MSE"), so the
  # targets are noisy regression targets even when loss_type is "bce" or
  # "hinge" — confirm this is intended.
  y = get_y(X, f)
  r = loss_derivative(X, W, a, y, gamma, act = act, loss_type = loss_type).view(-1,1)
  return (X.T @ r @ (a * mu).T) / n

def get_S1(X, r, a, mu):
  """Rank-one matrix (X^T r)(a * mu)^T divided by the number of rows of X."""
  n_samples = X.shape[0]
  left = X.T @ r
  right = (a * mu).T
  return left @ right / n_samples

def get_S1_perp(X, W, mu, a, r, act = "relu", loss_type = "MSE"):
  """
  X^T ((r a^T) * centered activation derivative) divided by the number of rows
  of X. `loss_type` is accepted but not used.
  """
  centered = activation_derivative_perp(X, W, mu, act = act)
  weighted = (r @ a.T) * centered
  return X.T @ weighted / X.shape[0]

def get_G(X, W, y, r, a, act = "relu", loss_type = "MSE"):
  """
  X^T ((r a^T) * activation derivative at X W^T), divided by the number of rows
  of X. `y` and `loss_type` are accepted but not used.
  """
  n_samples = X.shape[0]
  weighted = (r @ a.T) * activation_derivative(X @ W.T, act = act)
  return (X.T @ weighted) / n_samples

def get_S2(X, X_S, mu, r, W, a, y, alpha, act = "relu", loss_type = "MSE"):
  """
  X_S^T ((r a^T) * activation derivative at X W^T), divided by the number of
  rows of X_S. The derivative is NOT centered by mu here (see
  activation_derivative_S2); `mu`, `y`, `alpha` and `loss_type` are accepted
  but not used.
  """
  # The unused local `d = q.shape[0]` was dropped: it only made the function
  # fail with a NameError when no global `q` existed.
  sigma_prime = activation_derivative_S2(X, W, mu, act = act)

  return X_S.T @ ((r @ a.T) * sigma_prime) / X_S.shape[0]

def get_S2_small(X, X_S, mu, r, W, a, y, alpha, act = "relu", loss_type = "MSE"):
  """
  X_S^T ((r a^T) * centered activation derivative at X W^T), divided by the
  number of rows of X_S. `y`, `alpha` and `loss_type` are accepted but not
  used.
  """
  # The unused local `d = q.shape[0]` was dropped: it only made the function
  # fail with a NameError when no global `q` existed.
  sigma_prime_perp = activation_derivative_perp(X, W, mu, act = act)

  return X_S.T @ ((r @ a.T) * sigma_prime_perp) / X_S.shape[0]

def get_E(X, X_B, W, mu, a, r, y, act = "relu", loss_type = "MSE"):
  """
  X_B^T ((r a^T) * centered activation derivative at X W^T), divided by the
  number of rows of X_B. `y` and `loss_type` are accepted but not used.
  """
  centered = activation_derivative_perp(X, W, mu, act = act)
  weighted = (r @ a.T) * centered
  n_rows = X_B.shape[0]
  return X_B.T @ weighted / n_rows

def operator_norm(A):
  """Operator norm of A, taken to be the estimate from largest_singular_value."""
  sigma_max = largest_singular_value(A)
  return sigma_max

def largest_singular_value(A, num_iterations=200):
    """
    Approximates the largest singular value of a matrix using power iteration
    on A^T A from a random start vector.

    Args:
        A: The input matrix (torch.Tensor), shape (m, n).
        num_iterations: The number of iterations to perform.

    Returns:
        0-dim torch.Tensor with the estimated largest singular value. For an
        all-zero matrix the result is 0 (not NaN). With num_iterations <= 0
        the result is ||A v|| for the random unit start vector v.
    """

    # Match A's dtype as well as its device; otherwise a float64 matrix fails in A @ v.
    v = torch.randn(A.shape[1], device=A.device, dtype=A.dtype)  # Random initialization
    v = v / torch.norm(v)

    sigma = None
    for _ in range(num_iterations):
        u = A @ v
        sigma = torch.norm(u)
        w = A.T @ u
        w_norm = torch.norm(w)
        if w_norm == 0:
            # A^T A v == 0 implies A v == 0, so the estimate is 0; normalising
            # w would only produce NaNs.
            break
        v = w / w_norm

    if sigma is None:
        sigma = torch.norm(A @ v)
    return sigma  # Largest singular value approximation


In [None]:
# --- Define the Network Function f(X) ---
def network_forward(X_input, W_layer, a_vec, activation_fn, gamma):
    """
    Computes the forward pass of the network y = gamma * a^T * sigma(W @ X^T).

    The scale gamma is applied exactly once. The earlier version multiplied by
    gamma both before and after the transpose, giving gamma**2, which disagrees
    with `loss` (a single factor of gamma) and with the lr/gamma step-size
    compensation in `train_track`.

    Args:
        X_input (torch.Tensor): Input data shape (N, d).
        W_layer (torch.Tensor): Inner weight matrix shape (k, d).
        a_vec (torch.Tensor): Outer fixed weight vector shape (k, 1).
        activation_fn (callable): Activation function (e.g., torch.relu).
        gamma (float): Output scaling of the network.

    Returns:
        torch.Tensor: Network output prediction shape (N, 1).
    """
    # Z = W @ X.T  -> shape (k, N)
    Z = W_layer @ X_input.T
    # H = sigma(Z) -> shape (k, N)
    H = activation_fn(Z)
    # y_pred_transposed = gamma * a.T @ H -> shape (1, N)
    y_pred_transposed = gamma * a_vec.T @ H
    # y_pred -> shape (N, 1)
    return y_pred_transposed.T

In [None]:
def train_track(X,y,Xtst,ytst,W,a,act,gamma,lr,epochs):
  """
  Full-batch gradient descent on W with row renormalisation, recording
  diagnostics after every step.

  Each epoch: forward pass, MSE loss, backward pass, the step
  W <- W - (lr/gamma) * grad, then every row of W is rescaled to unit norm.

  Args:
    X, y: Training inputs (N, d) and targets; targets are reshaped to (N, 1).
    Xtst, ytst: Test inputs and targets; targets are reshaped to (N, 1).
    W (torch.Tensor): (k, d) leaf tensor with requires_grad=True; updated in place.
    a (torch.Tensor): (k, 1) fixed outer weights.
    act (callable): Activation function.
    gamma (float): Output scale passed to network_forward.
    lr (float): Learning rate (divided by gamma in the update).
    epochs (int): Number of gradient steps.

  Returns:
    Tuple (losses, operator_norms, spectrum, leading_right_svs, Ws,
    test_losses, energy), each a list with one entry per epoch:
    training loss, operator-norm estimate of W, singular values of the
    gradient, leading right singular vector of the gradient, copy of W,
    test loss, and relative_projection_norm of W.

  Note:
    The energy diagnostic reads the globals `X_B` and `q`.
  """
  # network_forward returns (N, 1); a 1-D target would silently broadcast the
  # residual to (N, N) and corrupt the loss.
  y = y.reshape(-1, 1)
  ytst = ytst.reshape(-1, 1)

  losses = []
  operator_norms = []
  spectrum = []
  leading_right_svs = [] # List to store the vectors themselves
  Ws = []
  test_losses = []
  energy = []

  print(f"Starting training for {epochs} epochs with lr={lr}")
  for epoch in range(epochs):

      # --- Forward Pass using the defined function f(X) ---
      y_pred = network_forward(X, W, a, act, gamma = gamma)
      loss = torch.mean((y_pred - y)**2)

      if W.grad is not None:
          W.grad.zero_()

      loss.backward() # Computes dL/dW

      with torch.no_grad(): # Context manager to disable gradient tracking for the update
          # Gradient Descent step
          W -= lr/gamma * W.grad

          # Renormalize rows of W to have unit norm
          W_norms = torch.norm(W, p=2, dim=1, keepdim=True)
          W /= (W_norms + 1e-8)

          # Spectrum of the gradient that produced this step
          _, S, Vh = torch.linalg.svd(W.grad, full_matrices=False)

          # Operator norm of the updated weights (power-iteration estimate)
          op_norm = operator_norm(W)
          operator_norms.append(op_norm.detach().cpu())

          spectrum.append(S.clone().detach().cpu())

          Ws.append(W.clone().detach().cpu())

          # Leading right singular vector is the first row of Vh
          lead_right_sv = Vh[0, :].clone() # Store a copy!
          leading_right_svs.append(lead_right_sv.detach().cpu())

          test_loss = torch.mean((network_forward(Xtst, W, a, act, gamma = gamma) - ytst)**2)
          test_losses.append(test_loss.detach().cpu().item())

          # Fraction of W lying in span(X_B^T y, q). Use X_B's device instead
          # of a hard-coded 'cuda' so this also runs on CPU.
          energy.append(relative_projection_norm(W.detach().to(X_B.device),X_B,y,q))

      # Store loss for plotting/analysis
      losses.append(loss.detach().cpu().item())

      # --- Clear Gradients ---
      W.grad.zero_()

      # Print progress
      if (epoch + 1) % 50 == 0:
          print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.6f}, Test Loss: {test_loss.item():.6f}, Operator Norm: {operator_norms[-1]:.4f}")

  print("Training finished.")

  with torch.no_grad():
      # Least-squares refit of the readout on the learned features. The fit
      # solves act(X W^T) a = y directly, so it is evaluated with gamma = 1;
      # rescaling by the training gamma would shrink the fitted predictions.
      a = torch.linalg.pinv(act(X @ W.T)) @ y

      y_test = network_forward(Xtst, W, a, act, gamma = 1)
      test_loss = torch.mean((y_test - ytst)**2)
  print(f"Final Test Loss: {test_loss.item():.6f}")

  return losses, operator_norms, spectrum, leading_right_svs, Ws, test_losses, energy

In [None]:
def plot_train_test_loss(losses, test_losses, label):
  """
  Plots train and test MSE per epoch side by side, saves the figure and shows
  it.

  The file name is built from the globals `path`, `gamma` and `lr`. The figure
  is saved BEFORE plt.show(): with some backends show() releases the figure, so
  saving afterwards can write a blank image.

  Args:
    losses (list[float]): Training loss per epoch.
    test_losses (list[float]): Test loss per epoch.
    label (str): Title of both panels and part of the file name.
  """
  fig, axs = plt.subplots(1, 2, figsize=(12, 5))

  # Plot train loss (x-axis taken from the data rather than the global `epochs`)
  axs[0].plot(range(len(losses)), losses)
  axs[0].set_xlabel("Epoch", fontsize = 16)
  axs[0].set_ylabel("Train MSE Loss", fontsize = 16)
  axs[0].set_title(label, fontsize = 16)
  axs[0].grid(True)

  # Plot test loss
  axs[1].plot(range(len(test_losses)), test_losses)
  axs[1].set_xlabel("Epoch", fontsize = 16)
  axs[1].set_ylabel("Test MSE Loss", fontsize = 16)
  axs[1].set_title(label, fontsize = 16)
  axs[1].grid(True)

  fig.tight_layout()
  fig.savefig(path+f"Loss-{label}-gamma-{gamma}-lr-{lr}.png", dpi = 100)
  plt.show()

  plt.close(fig)

def plot_norms(spectrum, operator_norms, label):
  """
  Plots the gradient operator norm (largest recorded singular value) and the
  weight operator norm per epoch, saves the figure and shows it.

  The file name is built from the globals `path`, `gamma` and `lr`. The figure
  is saved BEFORE plt.show(): with some backends show() releases the figure, so
  saving afterwards can write a blank image.

  Args:
    spectrum (list[torch.Tensor]): Singular values of the gradient, per epoch.
    operator_norms (list): Operator norm of W per epoch.
    label (str): Title of both panels and part of the file name.
  """
  fig, axs = plt.subplots(1, 2, figsize=(12, 5))

  # Largest singular value of the gradient at every epoch
  grad_op_norms = [s.max().item() for s in spectrum]
  weight_op_norms = [float(v) for v in operator_norms]

  # Plot gradient operator norm
  axs[0].plot(range(len(grad_op_norms)), grad_op_norms)
  axs[0].set_xlabel("Epoch", fontsize = 16)
  axs[0].set_ylabel("Gradient G Operator Norm", fontsize = 16)
  axs[0].set_title(label, fontsize = 16)
  axs[0].grid(True)

  # Plot weight operator norm
  axs[1].plot(range(len(weight_op_norms)), weight_op_norms)
  axs[1].set_xlabel("Epoch", fontsize = 16)
  axs[1].set_ylabel("Weight W operator norm", fontsize = 16)
  axs[1].set_title(label, fontsize = 16)
  axs[1].grid(True)

  fig.tight_layout()
  fig.savefig(path+f"Norms-{label}-gamma-{gamma}-lr-{lr}.png", dpi = 100)
  plt.show()

  plt.close(fig)

In [None]:
def alignment_plot(leading_right_svs, label, epochs, q = None, Ws = None, a = None, act = None, gamma = None):
  """
  Plots |cos| between the leading right singular vector of the gradient at
  each epoch and a reference direction, and saves the figure.

  Reference direction:
    - `q` when given (data alignment);
    - otherwise X_B^T (y - network_forward(X, Ws[i], a, act, gamma)) at epoch
      i (residue alignment), using the globals `X_B`, `X` and `y`.

  Vectors are normalised out of place. The earlier in-place `/=` rescaled the
  caller's stored vectors and, through the `flatten()` view, the global `q`.

  Args:
    leading_right_svs (list[torch.Tensor]): One vector per epoch.
    label (str): Plot title and part of the file name.
    epochs (int): Kept for interface compatibility; the x-axis length is
      taken from `leading_right_svs`.
    q, Ws, a, act, gamma: See above. `gamma` is also part of the file name,
      together with the globals `path` and `lr`.
  """
  if q is not None:
    # Fixed reference direction: normalise once.
    q_dir = q.flatten().cpu()
    q_dir = q_dir / q_dir.norm()
  else:
    # Loop-invariant CPU copies for the residue direction.
    X_B_T_cpu, X_cpu, y_cpu, a_cpu = X_B.T.cpu(), X.cpu(), y.cpu(), a.cpu()

  similarities = []
  for i in range(len(leading_right_svs)):
      vec1 = leading_right_svs[i].cpu()
      vec1 = vec1 / vec1.norm()
      if q is not None:
        vec2 = q_dir
      else:
        residual = y_cpu - network_forward(X_cpu, Ws[i].cpu(), a_cpu, act, gamma = gamma)
        vec2 = (X_B_T_cpu @ residual).flatten()
        vec2 = vec2 / vec2.norm()
      # Both vectors have unit norm, so the dot product is the cosine.
      sim = torch.dot(vec1, vec2).abs().item()
      similarities.append(sim)

  plt.figure(figsize=(10, 6))
  plt.plot(range(len(similarities)), similarities)
  plt.xlabel("Epoch", fontsize = 16)
  plt.ylabel(r"Dot Product", fontsize = 16)
  plt.title(label, fontsize = 16)
  plt.grid(True)
  plt.xscale('log')

  plt.savefig(path+f"Alignment-{label}-gamma-{gamma}-lr-{lr}.png", dpi = 100)

  plt.close()

def angles_plot(Ws1, Ws2, label, gamma = None):
  """
  Scatter plot of the mean principal angle (degrees) between Ws1[i] and Ws2[i]
  for every epoch i, saved to disk.

  Reads the globals `epochs`, `path` and `lr`; `gamma` is only used in the
  file name.
  """
  mean_angle_deg = torch.zeros(epochs)
  for step in tqdm(range(epochs)):
      rad = principal_angles(Ws1[step], Ws2[step])
      mean_angle_deg[step] = torch.rad2deg(rad).mean()

  fig, ax = plt.subplots(figsize=(10, 6))
  ax.scatter(range(epochs), mean_angle_deg.cpu())
  ax.set_xlabel("Epoch", fontsize = 16)
  ax.set_ylabel(r"Mean Prinicipal Angle", fontsize = 16)
  ax.set_title(label, fontsize = 16)
  ax.grid(True)

  fig.savefig(path+f"Angle-{label}-gamma-{gamma}-lr-{lr}.png", dpi = 100)
  plt.close()

def energy_plot(energy, label, gamma = None):
  """
  Scatter plot of the per-epoch `energy` values (as recorded by train_track),
  saved to disk.

  Reads the globals `epochs`, `path` and `lr`; `gamma` is only used in the
  file name.
  """
  fig, ax = plt.subplots(figsize=(10, 6))
  ax.scatter(range(epochs), energy)
  ax.set_xlabel("Epoch", fontsize = 16)
  ax.set_ylabel(r"Proportion in Rank 2 Subspace", fontsize = 16)
  ax.set_title(label, fontsize = 16)
  ax.grid(True)

  fig.savefig(path+f"Energy-{label}-gamma-{gamma}-lr-{lr}.png", dpi = 100)
  plt.close()

# ReLU vs Sigmoid Experiment

In [None]:
def run_exp_and_plot(lr,gamma,W,a):
  """
  Trains the same initial network once with ReLU and once with sigmoid, then
  produces the loss, norm, alignment and energy figures for both.

  Each run works on its own deep copies of W and a. Data (`X`, `y`, `Xtst`,
  `ytst`), `epochs` and `q` are read from globals.
  """
  act_fns = (("ReLU", relu_act), ("Sigmoid", sigmoid_act))

  # --- Train (ReLU first, then sigmoid) ---
  runs = {}
  for name, act_fn in act_fns:
    W_run = copy.deepcopy(W)
    a_run = copy.deepcopy(a)
    tracked = train_track(X,y,Xtst,ytst,W_run,a_run,act_fn,gamma,lr,epochs)
    runs[name] = (tracked, a_run)

  # --- Loss and norm curves ---
  for name, _ in act_fns:
    (losses, op_norms, spectrum, _, _, test_losses, _), _ = runs[name]
    plot_train_test_loss(losses, test_losses, name)
    plot_norms(spectrum, op_norms, name)

  # --- Alignment of the leading gradient direction ---
  for name, act_fn in act_fns:
    (_, _, _, lead_svs, Ws_run, _, _), a_run = runs[name]
    alignment_plot(lead_svs, f"{name} Residue Alignment", epochs, Ws = Ws_run, a = a_run, act = act_fn, gamma = gamma)
    alignment_plot(lead_svs, f"{name} Data Alignment", epochs, q = q)

  # --- Energy in the rank-2 subspace ---
  for name, _ in act_fns:
    energy_plot(runs[name][0][6], name)

  # angles_plot(Ws_relu, Ws_sigmoid, "ReLU vs Sigmoid")

In [None]:
def run_exp_and_plot_regression(lr,gamma,W,a):
  """
  Trains the same initial network with ReLU and with sigmoid and prints, for
  each, how the final leading gradient direction aligns with the residue
  direction X_B^T (y - prediction) and with the spike direction q.

  Fixes relative to the earlier version:
    - train_track returns seven values (the last is `energy`); unpacking six
      raised ValueError.
    - the stored singular vectors live on the CPU, so `q` is moved to the CPU
      before the product instead of mixing devices.

  Each run works on its own deep copies of W and a. Data (`X`, `y`, `X_B`,
  `Xtst`, `ytst`), `epochs` and `q` are read from globals. No figures are
  produced here; see run_exp_and_plot for the plotting variant.
  """
  if gamma > 1/2:
    print("NTK Regime \n")
  else:
    print("MF Regime \n")

  q_cpu = q.cpu()

  for name, act_fn in (("ReLU", relu_act), ("Sigmoid", sigmoid_act)):
    print(name)
    W_run = copy.deepcopy(W)
    a_run = copy.deepcopy(a)
    _, _, _, leading_right_svs, Ws, _, _ = train_track(X,y,Xtst,ytst,W_run,a_run,act_fn,gamma,lr,epochs)

    # Residue direction at the final weights
    r = (X_B.T.cpu() @ (y.cpu()- network_forward(X.cpu(), Ws[-1].cpu(), a.cpu(), act_fn, gamma = gamma) ) )

    lead = leading_right_svs[-1].view(1,-1)
    print("Residue Alignment", lead @ (r) / (r.norm()))
    print("Data Alignment", lead @ (q_cpu), "\n")

In [None]:
# Spike strength exponent: the data cells below call gen_data with zeta = N**nu.
nu = 1/8
# Decay exponent of the bulk column scales in gen_data (0 = all columns scaled by 1).
alpha = 0

# --- Configuration ---
# NOTE: `seed` is reassigned here but the generators are not re-seeded in this cell.
seed = 42          # For reproducibility
epochs = 100       # Number of training iterations
d = 1000            # Input dimension
k = 1250           # Number of hidden neurons (rows of W)
N = 750            # Number of data points (full batch)

# Spike direction: a random d x 1 vector normalised to unit length.
q = torch.randn(d,1)
q = q/torch.norm(q)

In [None]:
# --- Model Initialization ---
print(f"Initializing model: k={k}, d={d}")
# Initialize W with rows uniformly on the unit sphere
W = torch.randn(k, d, device=device)
W = W/ torch.norm(W, p=2, dim=1, keepdim=True)
W.requires_grad_(True) # We want to compute gradients for W

# Initialize fixed vector 'a' with random signs: entries are +-1/sqrt(k), so
# ||a||^2 = 1 exactly. The earlier expression `randint(0, 2) - 1` produced
# values in {-1, 0}, which zeroes out about half of the neurons and gives
# E[||a||^2] = 1/2, contradicting the stated normalisation.
a = (2 * torch.randint(0, 2, (k,1), device = device).float() - 1)/np.sqrt(k)
a.requires_grad_(False) # 'a' is fixed
print("Model initialized.")

Initializing model: k=1250, d=1000
Model initialized.


In [None]:
# Teacher direction: a random d x 1 vector normalised to unit length.
beta = torch.randn(d,1)
beta = beta/torch.norm(beta)

def f(x):
  """Teacher function: sigmoid of x^T beta, for a column vector x."""
  return activation(x.T @ beta, act = "sigmoid")

In [None]:
# Training set with spike strength zeta = N**nu. With return_all=True, gen_data
# also returns z, the bulk part X_B and the spike part X_S (X = X_B + X_S).
X, z, X_B, X_S = gen_data(N, d, N**nu, q, alpha, return_all=True)
# Targets as an (N, 1) column. get_y runs with its defaults (loss_type "MSE",
# tau = 1), i.e. teacher output plus Gaussian noise.
y = get_y(X, f).reshape(-1,1)

In [None]:
# Move the training data to the device chosen in the Device Setup cell
# (instead of a hard-coded 'cuda', which fails on CPU-only runtimes).
X = X.to(device)
y = y.to(device)

In [None]:
# Held-out set drawn with the same gen_data parameters as the training set.
Xtst = gen_data(N, d, N**nu, q, alpha)
# Column-vector targets, matching `y`. Without the reshape a 1-D ytst
# broadcasts against the (N, 1) predictions and the test residual becomes (N, N).
ytst = get_y(Xtst, f).reshape(-1,1)

# Use the device chosen in the Device Setup cell rather than a hard-coded 'cuda'.
Xtst = Xtst.to(device)
ytst = ytst.to(device)

In [None]:
# Run: gamma = 1/sqrt(k), lr = 2*sqrt(k). An empty `path` saves the figures in
# the current working directory. The plotting helpers read `lr`, `gamma`
# and `path` as globals.
lr = 2*np.sqrt(k)             # Fixed learning rate
gamma = 1/np.sqrt(k)
path = ""
run_exp_and_plot(lr,gamma,copy.deepcopy(W),copy.deepcopy(a))

In [None]:
# Run: gamma = 1, lr = 2*sqrt(k). Figures are saved under `path`, which is used
# as a plain string prefix and must already exist.
lr = 2*np.sqrt(k)             # Fixed learning rate
gamma = 1
path = "drive/MyDrive/Spikes/ReLU VS Sigmoid/"
run_exp_and_plot(lr,gamma,copy.deepcopy(W),copy.deepcopy(a))

In [None]:
# Run "Exp9": gamma = 1/sqrt(k), lr = 2.
lr = 2         # Fixed learning rate
gamma = 1/np.sqrt(k)
path = "drive/MyDrive/Spikes/ReLU VS Sigmoid/Exp9/"
run_exp_and_plot(lr,gamma,copy.deepcopy(W),copy.deepcopy(a))

In [None]:
# Run "Exp11": gamma = 1/k, lr = 2*sqrt(k).
# NOTE: this overwrites the global `epochs` (now 1000) for every later cell.
lr = 2 * np.sqrt(k)            # Fixed learning rate
epochs = 1000
gamma = 1/k
path = "drive/MyDrive/Spikes/ReLU VS Sigmoid/Exp11/"
run_exp_and_plot(lr,gamma,copy.deepcopy(W),copy.deepcopy(a))

In [None]:
# New teacher direction: a random d x 1 vector normalised to unit length.
beta = torch.randn(d,1)
beta = beta/torch.norm(beta)

def f(x):
  """Teacher function: ReLU of x^T beta, for a column vector x (replaces the sigmoid teacher)."""
  return activation(x.T @ beta, act = "relu")

In [None]:
# Regenerate the training set for the current teacher f; gen_data also returns
# z, the bulk part X_B and the spike part X_S (X = X_B + X_S).
X, z, X_B, X_S = gen_data(N, d, N**nu, q, alpha, return_all=True)
# Targets as an (N, 1) column; get_y runs with its defaults (loss_type "MSE", tau = 1).
y = get_y(X, f).reshape(-1,1)

In [None]:
# Move the training data to the device chosen in the Device Setup cell
# (instead of a hard-coded 'cuda', which fails on CPU-only runtimes).
X = X.to(device)
y = y.to(device)

In [None]:
# Held-out set drawn with the same gen_data parameters as the training set.
Xtst = gen_data(N, d, N**nu, q, alpha)
# Column-vector targets, matching `y`. Without the reshape a 1-D ytst
# broadcasts against the (N, 1) predictions and the test residual becomes (N, N).
ytst = get_y(Xtst, f).reshape(-1,1)

# Use the device chosen in the Device Setup cell rather than a hard-coded 'cuda'.
Xtst = Xtst.to(device)
ytst = ytst.to(device)

In [None]:
# Run "Exp2": gamma = 1/sqrt(k), lr = 2.
lr = 2             # Fixed learning rate
gamma = 1/np.sqrt(k)
path = "drive/MyDrive/Spikes/ReLU VS Sigmoid/Exp2/"
run_exp_and_plot(lr,gamma,copy.deepcopy(W),copy.deepcopy(a))

In [None]:
# Run "Exp4": gamma = 1, lr = 2.
lr = 2             # Fixed learning rate
gamma = 1
path = "drive/MyDrive/Spikes/ReLU VS Sigmoid/Exp4/"
run_exp_and_plot(lr,gamma,copy.deepcopy(W),copy.deepcopy(a))

In [None]:
# Run "Exp10": gamma = 1/sqrt(k), lr = 2*sqrt(k).
lr = 2*np.sqrt(k)             # Fixed learning rate
gamma = 1/np.sqrt(k)
path = "drive/MyDrive/Spikes/ReLU VS Sigmoid/Exp10/"
run_exp_and_plot(lr,gamma,copy.deepcopy(W),copy.deepcopy(a))

In [None]:
# Run "Exp12": gamma = 1, lr = 2*sqrt(k).
lr = 2*np.sqrt(k)             # Fixed learning rate
gamma = 1
path = "drive/MyDrive/Spikes/ReLU VS Sigmoid/Exp12/"
run_exp_and_plot(lr,gamma,copy.deepcopy(W),copy.deepcopy(a))

# NTK vs Mean-Field

In [None]:
def alignment_plot_initial(X, y, a, leading_right_svs, sigma, label, epochs,gamma = None):
  """
  Scatter plot of |cos| between the leading right singular vector at epoch 0
  and the one at every epoch, saved to disk.

  Vectors are normalised out of place (the earlier in-place `/=` rescaled the
  caller's stored vectors), and the reference vector is prepared once outside
  the loop.

  Args:
    leading_right_svs (list[torch.Tensor]): One vector per epoch.
    label (str): Part of the file name (the plot title is not set).
    gamma: Only used in the file name, together with the globals `path`
      and `lr`.
    X, y, a, sigma, epochs: Kept for interface compatibility; not used. The
      x-axis length is taken from `leading_right_svs`.
  """
  similarities = []
  if len(leading_right_svs) > 0:
      reference = leading_right_svs[0].cpu()
      reference = reference / reference.norm()

  for vec in leading_right_svs:
      current = vec.cpu()
      current = current / current.norm()
      # Both vectors have unit norm, so the dot product is the cosine.
      similarities.append(torch.dot(reference, current).abs().item())

  plt.figure(figsize=(10, 6))
  plt.scatter(range(len(similarities)), similarities)
  plt.xlabel("Epoch", fontsize = 16)
  plt.ylabel(r"Dot Product", fontsize = 16)
  plt.grid(True)

  plt.savefig(path+f"Alignment-initial-{label}-gamma-{gamma}-lr-{lr}.png", dpi = 100)

  plt.close()

def alignment_plot_subsequent(X, y, a, Ws, sigma, label, epochs,gamma = None):
  """
  Scatter plot of |cos| between the residue directions
  X^T (y - network_forward(X, W_t, a, sigma, gamma=1)) at consecutive epochs,
  saved to disk.

  Each residue direction is computed once and reused for both neighbouring
  pairs (previously every forward pass ran twice).

  Args:
    X, y, a: Data, targets and outer weights; moved to the CPU.
    Ws (list[torch.Tensor]): Weight matrix per epoch.
    sigma (callable): Activation function.
    label (str): Plot title and part of the file name.
    epochs (int): Kept for interface compatibility; the x-axis length is
      taken from `Ws`.
    gamma: Only used in the file name, together with the globals `path`
      and `lr`. The forward pass always uses gamma = 1.
  """
  X = X.cpu()
  a = a.cpu()
  y = y.cpu()

  # Unit-norm residue direction for every stored W.
  directions = []
  for W_t in Ws:
      direction = (X.T @ (y - network_forward(X, W_t, a, sigma, gamma = 1)) ).flatten()
      directions.append(direction / direction.norm())

  # Cosine between consecutive directions (unit norm, so just the dot product).
  similarities = [
      torch.dot(directions[i+1], directions[i]).abs().item()
      for i in range(len(directions)-1)
  ]

  plt.figure(figsize=(10, 6))
  plt.scatter(range(len(similarities)), similarities)
  plt.xlabel("Epoch", fontsize = 16)
  plt.ylabel(r"Dot Product", fontsize = 16)
  plt.title(label, fontsize = 16)
  plt.grid(True)
  plt.xscale('log')

  plt.savefig(path+f"Alignment-subsequent-{label}-gamma-{gamma}-lr-{lr}.png", dpi = 100)

  plt.close()

In [None]:
def run_exp_and_plot_ntk(lr,gamma,W,a):
  """
  Trains the same initial sigmoid network with gamma = 1 ("NTK") and with
  gamma = 1/sqrt(k) ("MF"), then saves the initial-alignment, principal-angle
  and energy figures.

  The `gamma` argument is not read; both scales are fixed here. Each run works
  on its own deep copies of W and a. Data (`X`, `y`, `Xtst`, `ytst`), `epochs`
  and `k` are read from globals.
  """
  regimes = (("NTK", 1), ("MF", 1/np.sqrt(k)))

  # --- Train (NTK first, then MF) ---
  tracked = {}
  for tag, scale in regimes:
    W_run = copy.deepcopy(W)
    a_run = copy.deepcopy(a)
    tracked[tag] = train_track(X,y,Xtst,ytst,W_run,a_run,sigmoid_act,scale,lr,epochs)

  # train_track returns: losses, operator_norms, spectrum, leading_right_svs,
  # Ws, test_losses, energy
  for tag, _ in regimes:
    lead_svs = tracked[tag][3]
    alignment_plot_initial(X, y, a, lead_svs, sigmoid_act, f"{tag} Initial Residue Alignment", epochs, gamma = tag)

  angles_plot(tracked["NTK"][4], tracked["MF"][4], "NTK vs MF", gamma = "Both")

  for tag, _ in regimes:
    energy_plot(tracked[tag][6], tag)

In [None]:
# Spike strength exponent: with nu = 0 the data cells below use zeta = N**0 = 1.
nu = 0
# Decay exponent of the bulk column scales in gen_data (0 = all columns scaled by 1).
alpha = 0

# --- Configuration ---
# NOTE: `seed` is reassigned here but the generators are not re-seeded in this cell.
seed = 42          # For reproducibility
epochs = 50       # Number of training iterations
d = 1000            # Input dimension
k = 1250           # Number of hidden neurons (rows of W)
N = 750            # Number of data points (full batch)

# Spike direction: a fresh random d x 1 vector normalised to unit length.
q = torch.randn(d,1)
q = q/torch.norm(q)

In [None]:
# --- Model Initialization ---
print(f"Initializing model: k={k}, d={d}")
# Initialize W with rows uniformly on the unit sphere
# (Gaussian rows, each divided by its Euclidean norm)
W = torch.randn(k, d, device=device)
W = W/ torch.norm(W, p=2, dim=1, keepdim=True)
W.requires_grad_(True) # We want to compute gradients for W

# Initialize fixed vector 'a' with Gaussian entries (scaled)
# Variance 1/k ensures E[||a||^2] = 1
a = torch.randn(k, 1, device=device) / math.sqrt(k)
a.requires_grad_(False) # 'a' is fixed
print("Model initialized.")

Initializing model: k=1250, d=1000
Model initialized.


In [None]:
# Teacher direction: a random d x 1 vector normalised to unit length.
beta = torch.randn(d,1)
beta = beta/torch.norm(beta)

def f(x):
  """Teacher function: sigmoid of x^T beta, for a column vector x."""
  return activation(x.T @ beta, act = "sigmoid")

In [None]:
# Training set for the NTK-vs-MF comparison; gen_data also returns z, the bulk
# part X_B and the spike part X_S (X = X_B + X_S).
X, z, X_B, X_S = gen_data(N, d, N**nu, q, alpha, return_all=True)
# Targets as an (N, 1) column; get_y runs with its defaults (loss_type "MSE", tau = 1).
y = get_y(X, f).reshape(-1,1)

In [None]:
# Move the training data to the device chosen in the Device Setup cell
# (instead of a hard-coded 'cuda', which fails on CPU-only runtimes).
X = X.to(device)
y = y.to(device)

In [None]:
# Held-out set drawn with the same gen_data parameters as the training set.
Xtst = gen_data(N, d, N**nu, q, alpha)
# Column-vector targets, matching `y`. Without the reshape a 1-D ytst
# broadcasts against the (N, 1) predictions and the test residual becomes (N, N).
ytst = get_y(Xtst, f).reshape(-1,1)

# Use the device chosen in the Device Setup cell rather than a hard-coded 'cuda'.
Xtst = Xtst.to(device)
ytst = ytst.to(device)

In [None]:
# NTK-vs-MF run with lr = sqrt(k)/2; figures are saved in the current working
# directory (empty `path`). The gamma argument is passed as None because
# run_exp_and_plot_ntk does not read it: it trains with gamma = 1 and
# gamma = 1/sqrt(k).
lr = np.sqrt(k)/2          # Fixed learning rate
path = ""
run_exp_and_plot_ntk(lr,None,W,a)

Starting training for 50 epochs with lr=17.67766952966369
Epoch [50/50], Loss: 14.358500, Test Loss: 10.037251, Operator Norm: 14.3799
Training finished.
Final Test Loss: 3.362156
Starting training for 50 epochs with lr=17.67766952966369
Epoch [50/50], Loss: 0.925410, Test Loss: 1.269230, Operator Norm: 33.7112
Training finished.
Final Test Loss: 1.222997


100%|██████████| 50/50 [00:25<00:00,  1.99it/s]
