In [None]:
from sympy import Rational
import math

def v_p(x, p):
    """
    Return the p-adic valuation v_p(x) of an integer x.

    v_p(x) is the exponent of the largest power of p dividing x; the sign of
    x is ignored. By convention v_p(0) is +infinity, returned as
    float('inf').

    Raises:
        ValueError: if p < 2. Without this check the division loop never
            terminates for p == 1 or p == -1, and p == 0 divides by zero.
    """
    if p < 2:
        raise ValueError(f"p must be >= 2, got {p!r}")
    if x == 0:
        return float('inf')

    v = 0
    x = abs(x)
    while x % p == 0:
        x //= p
        v += 1
    return v

def get_padic_norm(x, p):
    """
    Return the p-adic norm |x|_p as an exact SymPy Rational.

    Accepted inputs:
      * objects exposing ``.p`` / ``.q`` (SymPy Rational / Integer),
      * objects exposing ``.numerator`` / ``.denominator`` (Fraction, int),
      * anything else SymPy's ``Rational`` can convert (a float becomes its
        exact binary rational value).

    |0|_p is 0; otherwise |a/b|_p = p ** (v_p(b) - v_p(a)).
    """
    if hasattr(x, 'p'):
        parts = (x.p, x.q)
    elif hasattr(x, 'numerator'):
        parts = (x.numerator, x.denominator)
    else:
        exact = Rational(x)
        parts = (exact.p, exact.q)
    numer, denom = parts

    if numer == 0:
        return Rational(0)

    # |a/b|_p = p^(-v_p(a/b)) and v_p(a/b) = v_p(a) - v_p(b).
    return Rational(p) ** (v_p(denom, p) - v_p(numer, p))



In [None]:
def p_expansion_positive(r, s, p, n, *, report=False):
    """
    Compute the first n p-adic digits of x = r/s (intended for r, s > 0).

    The power of p is split off first: x = p**shift * u, where
    shift = v_p(r) - v_p(s) and u = r_/s_ with r_, s_ free of factors of p.
    The digits of u are then produced by the recurrence

        a_i     = (num(t_i) * den(t_i)^(-1)) mod p
        t_(i+1) = (t_i - a_i) / p,          t_0 = u,

    which keeps every t_i p-integral (its denominator stays invertible mod p
    when p is prime).

    Args:
        r, s: numerator and denominator of x; s must be non-zero.
        p: the prime base.
        n: number of digits to return.
        report: if True, print a message describing any valuation shift.

    Returns:
        (digits, shift):
          * shift >= 0: digits[i] is the coefficient of p**i in x (the first
            `shift` entries are zeros), truncated to n entries;
          * shift < 0: digits[i] is the coefficient of p**(i + shift) in x,
            i.e. the i-th digit of the unit part u.
        r == 0 yields ([0] * n, 0).

    Raises:
        ZeroDivisionError: if s == 0.
        ValueError: if a denominator is not invertible mod p (e.g. when p is
            not prime and shares a factor with the denominator).
    """
    from fractions import Fraction

    if s == 0:
        raise ZeroDivisionError("p-adic expansion of r/s with s == 0")
    if r == 0:
        # v_p(0) is infinite, so the valuation split below cannot be used;
        # the expansion of 0 is all zeros.
        return [0] * n, 0

    # Factor out powers of p from r and s.
    v_r = v_p(r, p)
    v_s = v_p(s, p)
    shift = v_r - v_s            # overall power of p (v_p(x))
    unit = Fraction(r // p ** v_r, s // p ** v_s)

    digits = []
    tail = unit
    for _ in range(n):
        # Extended_Euclidean_Recursive is defined elsewhere in the notebook;
        # assumed to return (gcd, x, y) with x*a + y*b == gcd -- confirm.
        g, inv, _y = Extended_Euclidean_Recursive(tail.denominator % p, p)
        if g != 1:
            raise ValueError("Denominator not invertible mod p.")
        digit = (tail.numerator % p) * (inv % p) % p
        digits.append(digit)
        # Strip the digit just emitted and move to the next power of p.
        tail = (tail - digit) / p

    # Apply the valuation shift and report it.
    if shift > 0:
        if report:
            print(f"[shift] v_p(x) = {shift} ⇒ prepending {shift} zero digit(s).")
        digits = ([0] * shift + digits)[:n]
    elif shift < 0:
        if report:
            print(f"[shift] v_p(x) = {shift} ⇒ series starts at power p^{shift} (negative index).")

    return digits, shift

# ------------------------------------------------------------
# General p-adic expansion with correct negative handling
# ------------------------------------------------------------
def p_expansion(r, s, p, n):
    """
    Compute the first n p-adic digits of x = r/s for signed integers r, s.

    The digits of |x| come from p_expansion_positive. For negative x they are
    complemented with the standard p-adic negation. If |x| has digits a_i and
    the first nonzero one is at index k, then -x has:
        0             for i < k,
        p - a_k       at i == k,
        p - 1 - a_i   for i > k.

    Returns:
        (digits, shift), with the same layout as p_expansion_positive:
        digits index powers of p starting at p**0 when shift >= 0, or at
        p**shift when shift < 0. r == 0 yields ([0] * n, 0).

    Raises:
        ZeroDivisionError: if s == 0.
    """
    if s == 0:
        raise ZeroDivisionError("p-adic expansion of r/s with s == 0")
    if r == 0:
        return [0] * n, 0

    negative = (r < 0) != (s < 0)
    digits, shift = p_expansion_positive(abs(r), abs(s), p, n)
    if not negative:
        return digits, shift

    # Index of the first nonzero digit of |x|.
    k = next((i for i, d in enumerate(digits) if d % p != 0), None)
    if k is None:
        # |x| ≡ 0 mod p^n ⇒ -x also ≡ 0 mod p^n
        return [0] * len(digits), shift

    neg_digits = (
        [0] * k                                   # trailing zeros stay zero
        + [(p - digits[k]) % p]                   # first nonzero digit: p - a_k
        + [p - 1 - d for d in digits[k + 1:]]     # the rest: complement digit
    )
    return neg_digits, shift

In [None]:
def vary_padic(x, precision, p, varying_extent, sign):
    """
    Return x shifted by a p-adically small amount.

    The shift is delta = varying_extent * p**(precision + 1). The result is
    x + delta when sign == 1 and x - delta when sign == 0. For an integer
    varying_extent, delta is a multiple of p**(precision + 1). The
    coefficients of p**j for j <= precision in the p-adic expansion of x are
    therefore unchanged; only higher positions, including carries, differ.

    Args:
        x: exact number to perturb (int, fractions.Fraction, SymPy Rational,
           ...); any type supporting addition/subtraction with int.
        precision: delta is scaled by p**(precision + 1).
        p: the prime base.
        varying_extent: multiplier applied to p**(precision + 1).
        sign: 1 to add delta, 0 to subtract it.

    Returns:
        The perturbed value, of whatever type ``x + int`` produces.

    Raises:
        ValueError: if sign is not 0 or 1.
    """
    if sign not in (0, 1):
        raise ValueError(f"sign must be 0 or 1, got {sign!r}")
    delta = varying_extent * p ** (precision + 1)
    return x + delta if sign == 1 else x - delta

In [None]:
def compute_padic_gradients(loss_function, variables, p, varying_extent=2, precision=10):
    """
    Estimate the gradient of a loss w.r.t. each variable's p-adic norm.

    Each variable is mapped to its p-adic norm. For every index i, the norm
    is perturbed in both directions with vary_padic
    (norm ± varying_extent * p**(precision + 1)), and the p-adic norm of each
    perturbed value is taken. The loss is evaluated with the i-th norm
    replaced by each of these, and the gradient is the difference quotient

        (loss_plus - loss_minus) / (norm_plus - norm_minus).

    If both perturbations yield the same norm (common, because the p-adic
    norm is ultrametric), the loss would be evaluated at identical inputs.
    The quotient would then be 0/0, so the gradient for that variable is
    reported as 0 and the loss is not called.

    NOTE(review): applying get_padic_norm to a perturbed *norm* yields the
    p-adic norm of the norm value itself, as in the original calls. Confirm
    this is intended; perturbing the variable itself may be what is meant.

    Args:
        loss_function: callable taking a list of norms (the values returned
            by get_padic_norm) and returning a scalar loss.
        variables: the p-adic numbers (ints, Fractions, SymPy Rationals, ...).
        p: the prime base.
        varying_extent: multiplier of the perturbation size.
        precision: perturbation is varying_extent * p**(precision + 1).

    Returns:
        A list with one gradient per variable.
    """
    # Work in "norm space": the loss only ever sees the p-adic norms.
    current_norms = [get_padic_norm(var, p) for var in variables]

    gradients = []
    for i, original_norm in enumerate(current_norms):
        norm_plus = get_padic_norm(
            vary_padic(x=original_norm, precision=precision, p=p,
                       varying_extent=varying_extent, sign=1),
            p,
        )
        norm_minus = get_padic_norm(
            vary_padic(x=original_norm, precision=precision, p=p,
                       varying_extent=varying_extent, sign=0),
            p,
        )

        step = norm_plus - norm_minus
        if step == 0:
            # The perturbation does not change the norm: 0/0 difference quotient.
            gradients.append(0)
            continue

        # Fresh lists so the loss never sees a list that is mutated later.
        norms_plus = list(current_norms)
        norms_plus[i] = norm_plus
        norms_minus = list(current_norms)
        norms_minus[i] = norm_minus

        loss_plus = loss_function(norms_plus)
        loss_minus = loss_function(norms_minus)

        # dy/dx ≈ (y2 - y1) / (x2 - x1)
        gradients.append((loss_plus - loss_minus) / step)

    return gradients