One reason we may not be seeing feature learning is the self-consistency assumption that $\|w\|^2 \sim O(1)$, i.e. that no $O(d)$ corrections are present. This approximation may not be justified, since it ignores the contribution of the weight splitting in layer one. Below, I go back to the equations, use the correct values for the splitting, and expect to see a greater discrepancy.
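
For reference, here is the residual system that the code below solves, transcribed term by term from the code. $E_{hh}$, $\gamma_{Yh^2}$ and $\gamma_{gh}$ correspond to `EChh`, `gammaYh2` and `gammagh`. The shorthand $A$ is introduced here only for readability and does not appear in the code.

$$
\begin{aligned}
\tau &= -\frac{\chi^2}{(\kappa/P + \lambda_K)^2}, \qquad A = \frac{\gamma_{Yh^2}\,\tau}{N\chi}, \qquad c_w = \frac{1}{d + \frac{4}{3\pi}\frac{\lambda_V}{N}},\\
\lambda_V &= \frac{A}{\lambda_J + A},\\
\lambda_J &= \frac{4}{3\pi}\,c_w,\\
\lambda_H &= \left(\frac{1}{\lambda_J} + A\right)^{-1},\\
\lambda_K &= \frac{4}{3\pi}\sqrt{\gamma_{Yh^2}}\,\gamma_{gh}\,c_w,
\end{aligned}
$$

with $E_{hh} = \frac{2}{\pi}\arcsin\frac{2}{3}$, $\gamma_{Yh^2} = \frac{4}{\pi}\,\frac{1}{1+2E_{hh}}$ and $\gamma_{gh} = \frac{2}{\sqrt{\pi}}\,\frac{1}{\sqrt{1+2E_{hh}}}$. Since $\gamma_{gh}^2 = \gamma_{Yh^2}$, the prefactor in the $\lambda_K$ equation reduces to $\frac{4}{3\pi}\gamma_{Yh^2}$. The printed $\lambda_K^{\perp}$ is that same equation with $c_w$ replaced by $1/d$.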

import math

import sympy
from sympy import sqrt

# 1. Define the unknowns and fix the numerical parameters
# Variables to solve for
lambda_V, lambda_J, lambda_H, lambda_K = sympy.symbols('lambda_V lambda_J lambda_H lambda_K')

# Parameters enter as plain floats, so the residuals below depend only on the four unknowns
N = 50.0
chi = 1.0
kappa = 1.0
P = 50.0
d = 50.0

# c_w depends on lambda_V through the (4/(3*pi)) * lambda_V / N correction
c_w = 1 / (d + 4/(3 * math.pi) * 1/N * lambda_V)

# Constants entering the equations
ECyh2 = 4/(3*math.pi)  # same value as the 4/(3*pi) prefactor below; not referenced directly
EChh = (2/math.pi) * math.asin(2/3)
gammaYh2 = 4/math.pi * 1/(1 + 2 * EChh)
gammagh = 2/math.sqrt(math.pi) * 1/math.sqrt(1 + 2 * EChh)

# lambda_star_K is declared as a parameter but does not appear in any equation below
lambda_star_K = sympy.symbols('lambda_star_K')

# 2. Define the system of equations in residual form (each expression should equal 0)
# tau depends on lambda_K and enters equations 1 and 3
tau = - chi**2 / (kappa/P + lambda_K)**2

# Equation 1 for lambda_V
eq1 = lambda_V - (gammaYh2 * tau / (N * chi)) / (lambda_J + gammaYh2 * tau / (N * chi))

# Equation 2 for lambda_J
eq2 = lambda_J - 4 / (3 * math.pi) * c_w

# Equation 3 for lambda_H (note the parentheses in the denominator)
denominator_H = (1 / lambda_J + gammaYh2 * tau / (N * chi))
eq3 = lambda_H - 1 / denominator_H

# Equation 4 for lambda_K
eq4 = lambda_K - (4 / (3 * math.pi)) * sqrt(gammaYh2) * gammagh * c_w 


# 3. Collect the residuals; the parameters are already numbers,
# so nsolve only has to handle the four unknowns
system_of_eqs = [
    eq1,
    eq2,
    eq3,
    eq4,
]

# 4. Provide an initial guess for (lambda_V, lambda_J, lambda_H, lambda_K)
# The Newton-type iteration in nsolve is sensitive to this choice
initial_guess = [1.0, 1.0, 1.0, 0.1]

# 5. Solve the system numerically with sympy.nsolve, which searches for a root
# starting from initial_guess
try:
    solution = sympy.nsolve(
        system_of_eqs,
        (lambda_V, lambda_J, lambda_H, lambda_K),
        initial_guess
    )
    
    # nsolve returns a 4x1 Sympy Matrix, which can be unpacked directly
    solved_V, solved_J, solved_H, solved_K = solution.evalf()
    cw_final = c_w.subs({lambda_V: solved_V})
    # lambda_K with c_w replaced by its leading-order value 1/d
    kperp = ((4 / (3 * math.pi)) * sqrt(gammaYh2) * gammagh * 1/d)
    print("--- Sympy Numerical Solution (nsolve) ---")
    print(f"λ_V = {solved_V}")
    print(f"λ_J = {solved_J}")
    print(f"λ_H = {solved_H}")
    print(f"λ_K = {solved_K}")
    print(f"λ_K⊥ = {kperp}")
    print(f"c_w = {cw_final}")

except (ValueError, ZeroDivisionError) as e:
    print("Solver failed. This could be due to a bad initial guess or parameter values that lead to a singularity (e.g., division by zero).")
    print(f"Error: {e}")

Solver failed. This could be due to a bad initial guess or parameter values that lead to a singularity (e.g., division by zero).
Error: Could not find root within given tolerance. (143.528790817555960074 > 2.16840434497100886801e-19)
Try another starting point or tweak arguments.
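
The failure is plausibly a starting-point problem rather than a missing root. Equations 2 and 4 give λ_J and λ_K explicitly in terms of λ_V through c_w, so both are of order 1/d. The guess `[1.0, 1.0, 1.0, 0.1]` instead sets them near 1 and 0.1. At that guess, λ_J + A and 1/λ_J + A are both about 0.08, close to the poles of equations 1 and 3. Here A = γ_Yh2·τ/(Nχ) is a hypothetical shorthand.

The sketch below assumes the same parameter values and equations as the code above. It eliminates λ_J and λ_K, iterates equation 1 as a fixed point in λ_V, and then reads off λ_H from equation 3. The helper names `eliminate`, `solve_by_fixed_point` and `PREF` are hypothetical and exist only for this illustration. Plain fixed-point iteration replaces nsolve's Newton iteration.

```python
import math

# Parameter values copied from the notebook code above
N, chi, kappa, P, d = 50.0, 1.0, 1.0, 50.0, 50.0
PREF = 4 / (3 * math.pi)  # hypothetical name for the 4/(3*pi) prefactor shared by c_w, eq2, eq4

EChh = (2 / math.pi) * math.asin(2 / 3)
gammaYh2 = 4 / math.pi * 1 / (1 + 2 * EChh)
gammagh = 2 / math.sqrt(math.pi) * 1 / math.sqrt(1 + 2 * EChh)


def eliminate(lambda_V):
    """Return (lambda_J, lambda_K, A) implied by eq2, eq4 and tau for a given lambda_V.

    A is a hypothetical shorthand for gammaYh2 * tau / (N * chi).
    """
    c_w = 1 / (d + PREF / N * lambda_V)
    lambda_J = PREF * c_w                                   # eq2
    lambda_K = PREF * math.sqrt(gammaYh2) * gammagh * c_w   # eq4
    tau = -chi**2 / (kappa / P + lambda_K) ** 2
    A = gammaYh2 * tau / (N * chi)
    return lambda_J, lambda_K, A


def solve_by_fixed_point(lambda_V=1.0, tol=1e-14, max_iter=100):
    """Iterate eq1, lambda_V <- A / (lambda_J + A), until successive values agree."""
    for _ in range(max_iter):
        lambda_J, _, A = eliminate(lambda_V)
        new_V = A / (lambda_J + A)
        converged = abs(new_V - lambda_V) < tol
        lambda_V = new_V
        if converged:
            break
    else:
        raise RuntimeError("fixed-point iteration did not converge")
    lambda_J, lambda_K, A = eliminate(lambda_V)
    lambda_H = 1 / (1 / lambda_J + A)                       # eq3
    return lambda_V, lambda_J, lambda_H, lambda_K


if __name__ == "__main__":
    for name, value in zip(("V", "J", "H", "K"), solve_by_fixed_point()):
        print(f"lambda_{name} = {value:.6g}")
```

The iteration should settle quickly here because λ_V enters the right-hand side of equation 1 only through the small (4/(3π))·λ_V/N correction to c_w. As a cross-check, the resulting values could be passed back to `nsolve` as `initial_guess`.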
