In [1]:
"""
SageMath verification code for BOTH G_a and H_a families

This code verifies the main theorems from the paper:
- G_a family (Li-Kaleyski Type 1 generalization)
- H_a family (Li-Kaleyski Type 2 generalization)

For each family, we check:
1. G_a/H_a is permutation <-> Q_a(T) has no roots
2. G_a/H_a is APN <-> Q_a(T) has no roots  
3. G_a/H_a is permutation <-> G_a/H_a is APN
"""

from sage.all import *
import time

# ==========================================
# HELPER FUNCTIONS
# ==========================================

def check_roots_Qa(F, q, a):
    """
    Check if Q_a(T) = T^{q^2+q+1} + aT + 1 has roots in F.
    This is the KEY polynomial from Proposition 3.3 of the paper.
    
    Returns: (has_roots, list_of_roots)
    """
    exponent = q^2 + q + 1
    roots = []
    for t in F:
        # Q_a(T) = T^{q^2+q+1} + aT + 1
        if t^exponent + a*t + 1 == 0:
            roots.append(t)
    return (len(roots) > 0, roots)


def check_roots_Ra(F, q, a):
    """
    Check if R_a(T) = T^{q^2+q+1} + (aT + 1)^{q+1} has roots in F.
    This is the polynomial for H_a from Theorem 6.1 of the paper.
    
    By Proposition 3.3, this is equivalent to Q_a having roots.
    
    Returns: (has_roots, list_of_roots)
    """
    exponent = q^2 + q + 1
    roots = []
    for t in F:
        # R_a(T) = T^{q^2+q+1} + (aT + 1)^{q+1}
        if t^exponent + (a*t + 1)^(q+1) == 0:
            roots.append(t)
    return (len(roots) > 0, roots)


# ==========================================
# G_a FAMILY (Li-Kaleyski Type 1)
# ==========================================

def G_a(x, y, z, a, q):
    """
    G_a(x,y,z) = (x^{q+1} + ax^qz + yz^q, x^qz + y^{q+1}, xy^q + ay^qz + z^{q+1})
    
    This is the generalized Li-Kaleyski Type 1 family from Section 3 of the paper.
    When a=1, this is exactly F_1 from Li-Kaleyski [LK24].
    """
    comp1 = x^(q+1) + a * x^q * z + y * z^q
    comp2 = x^q * z + y^(q+1)
    comp3 = x * y^q + a * y^q * z + z^(q+1)
    return (comp1, comp2, comp3)


def test_permutation_G(m, i, a):
    """
    Test if G_a is a permutation over F_{2^m}^3.
    
    Returns: (is_permutation, image_size, domain_size)
    """
    F = GF(2^m, 'w')
    q = 2^i
    
    image_set = set()
    
    for x in F:
        for y in F:
            for z in F:
                image = G_a(x, y, z, a, q)
                image_set.add(image)
    
    domain_size = F.cardinality()^3
    image_size = len(image_set)
    
    return (image_size == domain_size, image_size, domain_size)


def compute_differential_uniformity_G(m, i, a):
    """
    Compute the differential uniformity of G_a.
    
    Returns: (delta, is_APN) where delta is the differential uniformity
             and is_APN is True if delta == 2
    """
    F = GF(2^m, 'w')
    q = 2^i
    
    max_solutions = 0
    
    # Iterate over all nonzero directions
    for alpha in F:
        for beta in F:
            for gamma in F:
                if alpha == 0 and beta == 0 and gamma == 0:
                    continue
                
                # For each direction, iterate over all targets
                for u in F:
                    for v in F:
                        for w_target in F:
                            solution_count = 0
                            
                            # Count solutions to the differential equation
                            for x in F:
                                for y in F:
                                    for z in F:
                                        # Compute G_a(x,y,z)
                                        g = G_a(x, y, z, a, q)
                                        
                                        # Compute G_a(x+alpha, y+beta, z+gamma)
                                        g_shifted = G_a(x + alpha, y + beta, z + gamma, a, q)
                                        
                                        # Check if difference equals target
                                        if (g_shifted[0] + g[0] == u and 
                                            g_shifted[1] + g[1] == v and 
                                            g_shifted[2] + g[2] == w_target):
                                            solution_count += 1
                            
                            max_solutions = max(max_solutions, solution_count)
    
    return (max_solutions, max_solutions == 2)


# ==========================================
# H_a FAMILY (Li-Kaleyski Type 2)
# ==========================================

def H_a(x, y, z, a, q):
    """
    H_a(x,y,z) = (x^{q+1} + axy^q + yz^q, xy^q + z^{q+1}, x^qz + y^{q+1} + ay^qz)
    
    This is the generalized Li-Kaleyski Type 2 family from Section 6 of the paper.
    When a=1, this is exactly F_2 from Li-Kaleyski [LK24].
    """
    comp1 = x^(q+1) + a * x * y^q + y * z^q
    comp2 = x * y^q + z^(q+1)
    comp3 = x^q * z + y^(q+1) + a * y^q * z
    return (comp1, comp2, comp3)


def test_permutation_H(m, i, a):
    """
    Test if H_a is a permutation over F_{2^m}^3.
    
    Returns: (is_permutation, image_size, domain_size)
    """
    F = GF(2^m, 'w')
    q = 2^i
    
    image_set = set()
    
    for x in F:
        for y in F:
            for z in F:
                image = H_a(x, y, z, a, q)
                image_set.add(image)
    
    domain_size = F.cardinality()^3
    image_size = len(image_set)
    
    return (image_size == domain_size, image_size, domain_size)


def compute_differential_uniformity_H(m, i, a):
    """
    Compute the differential uniformity of H_a.
    
    Returns: (delta, is_APN) where delta is the differential uniformity
             and is_APN is True if delta == 2
    """
    F = GF(2^m, 'w')
    q = 2^i
    
    max_solutions = 0
    
    # Iterate over all nonzero directions
    for alpha in F:
        for beta in F:
            for gamma in F:
                if alpha == 0 and beta == 0 and gamma == 0:
                    continue
                
                # For each direction, iterate over all targets
                for u in F:
                    for v in F:
                        for w_target in F:
                            solution_count = 0
                            
                            # Count solutions to the differential equation
                            for x in F:
                                for y in F:
                                    for z in F:
                                        # Compute H_a(x,y,z)
                                        h = H_a(x, y, z, a, q)
                                        
                                        # Compute H_a(x+alpha, y+beta, z+gamma)
                                        h_shifted = H_a(x + alpha, y + beta, z + gamma, a, q)
                                        
                                        # Check if difference equals target
                                        if (h_shifted[0] + h[0] == u and 
                                            h_shifted[1] + h[1] == v and 
                                            h_shifted[2] + h[2] == w_target):
                                            solution_count += 1
                            
                            max_solutions = max(max_solutions, solution_count)
    
    return (max_solutions, max_solutions == 2)


# ==========================================
# COMPREHENSIVE VERIFICATION
# ==========================================

def verify_family(family_name, m, i, test_perm_func, test_APN_func, poly_check_func):
    """
    Verify a family (G_a or H_a) for given (m, i).
    
    Args:
        family_name: "G_a" or "H_a"
        m: field extension degree
        i: Frobenius parameter
        test_perm_func: function to test permutation property
        test_APN_func: function to test APN property
        poly_check_func: function to check polynomial roots
    """
    if gcd(i, m) != 1:
        print(f"Warning: gcd({i}, {m}) != 1")
        return None
    
    F = GF(2^m, 'w')
    q = 2^i
    
    print(f"\n{'='*80}")
    print(f"VERIFICATION FOR {family_name}: m={m}, i={i}, q={q}")
    print(f"Field: F_{{2^{m}}} with {F.cardinality()} elements")
    print(f"Domain: F_{{2^{m}}}^3 with {F.cardinality()^3} elements")
    if family_name == "G_a":
        print(f"Polynomial: Q_a(T) = T^{q^2+q+1} + aT + 1")
    else:
        print(f"Polynomial: R_a(T) = T^{q^2+q+1} + (aT + 1)^(q+1)")
    print(f"{'='*80}")
    
    results = {
        'permutations': [],
        'APN_functions': [],
        'APN_permutations': [],
        'correlation_matches': 0,
        'correlation_failures': []
    }
    
    total_tested = 0
    
    for a in F:
        if a == 0:
            continue
        
        total_tested += 1
        print(f"\n--- Testing a = {a} ({total_tested}/{F.cardinality()-1}) ---")
        
        # Step 1: Check polynomial condition
        has_roots, roots = poly_check_func(F, q, a)
        print(f"  Polynomial has roots: {has_roots}")
        if has_roots:
            print(f"    Roots: {roots}")
        
        # Step 2: Check permutation property
        start = time.time()
        is_perm, img_size, dom_size = test_perm_func(m, i, a)
        perm_time = time.time() - start
        print(f"  Is permutation: {is_perm} ({perm_time:.2f}s)")
        if not is_perm:
            print(f"    Image size: {img_size}/{dom_size}")
        
        # Step 3: Check APN property
        print(f"  Computing differential uniformity...")
        start = time.time()
        delta, is_APN = test_APN_func(m, i, a)
        APN_time = time.time() - start
        print(f"  Differential uniformity δ = {delta} ({APN_time:.2f}s)")
        print(f"  Is APN (δ=2): {is_APN}")
        
        # Step 4: Check correlations
        predicted_perm = not has_roots
        predicted_APN = not has_roots
        
        perm_match = (predicted_perm == is_perm)
        APN_match = (predicted_APN == is_APN)
        
        if perm_match and APN_match:
            results['correlation_matches'] += 1
            print(f"    CORRELATION VERIFIED")
        else:
            results['correlation_failures'].append({
                'a': a,
                'has_roots': has_roots,
                'predicted_perm': predicted_perm,
                'actual_perm': is_perm,
                'predicted_APN': predicted_APN,
                'actual_APN': is_APN,
                'delta': delta
            })
            print(f"    CORRELATION FAILED")
        
        # Record results
        if is_perm:
            results['permutations'].append({
                'a': a,
                'poly_has_roots': has_roots,
                'delta': delta,
                'is_APN': is_APN
            })
        
        if is_APN:
            results['APN_functions'].append({
                'a': a,
                'is_perm': is_perm,
                'poly_has_roots': has_roots
            })
        
        if is_perm and is_APN:
            results['APN_permutations'].append(a)
    
    # Print summary
    print(f"\n{'='*80}")
    print(f"RESULTS SUMMARY FOR {family_name}")
    print(f"{'='*80}")
    print(f"Total non-zero a values tested: {total_tested}")
    print(f"Permutations found: {len(results['permutations'])}")
    print(f"APN functions found: {len(results['APN_functions'])}")
    print(f"APN permutations found: {len(results['APN_permutations'])}")
    print(f"Correlation matches: {results['correlation_matches']}/{total_tested}")
    
    correlation_rate = 100.0 * results['correlation_matches'] / total_tested if total_tested > 0 else 0
    print(f"Correlation rate: {correlation_rate:.1f}%")
    
    if results['correlation_matches'] == total_tested:
        print(f"\n  THEOREMS VERIFIED FOR {family_name}:")
        print(f"  • {family_name} is permutation <-> Polynomial has no roots")
        print(f"  • {family_name} is APN <-> Polynomial has no roots")
        print(f"  • {family_name} is permutation <-> {family_name} is APN")
        
        # List APN permutations
        if len(results['APN_permutations']) > 0:
            print(f"\n  APN permutations for a ∈ {{")
            for a_val in results['APN_permutations']:
                print(f"    {a_val},")
            print(f"  }}")
    else:
        print(f"\n  CORRELATION FAILURES DETECTED FOR {family_name}:")
        for failure in results['correlation_failures']:
            print(f"  a={failure['a']}:")
            print(f"    Predicted perm={failure['predicted_perm']}, actual={failure['actual_perm']}")
            print(f"    Predicted APN={failure['predicted_APN']}, actual={failure['actual_APN']}")
            print(f"    δ={failure['delta']}")
    
    return results


def run_all_verifications():
    """
    Run verifications for both G_a and H_a families.
    """
    test_cases = [
        (3, 1),  # m=3, i=1, q=2
        (3, 2),  # m=3, i=2, q=4
    ]
    
    print("="*80)
    print("COMPREHENSIVE VERIFICATION OF BOTH G_a AND H_a FAMILIES")
    print("Paper: APN permutations in a general class of trivariate functions")
    print("="*80)
    
    all_results = {}
    
    for m, i in test_cases:
        print(f"\n\n{'#'*80}")
        print(f"# TEST CASE: m={m}, i={i}")
        print(f"{'#'*80}")
        
        # Verify G_a family
        try:
            print("\n" + "="*80)
            print(f"FAMILY 1: G_a (Li-Kaleyski Type 1 Generalization)")
            print("="*80)
            results_G = verify_family(
                "G_a", m, i,
                test_permutation_G,
                compute_differential_uniformity_G,
                check_roots_Qa
            )
            all_results[('G_a', m, i)] = results_G
        except KeyboardInterrupt:
            print(f"\nSkipping G_a for (m={m}, i={i}) - computation interrupted")
        except Exception as e:
            print(f"\nError for G_a (m={m}, i={i}): {e}")
            import traceback
            traceback.print_exc()
        
        # Verify H_a family
        try:
            print("\n" + "="*80)
            print(f"FAMILY 2: H_a (Li-Kaleyski Type 2 Generalization)")
            print("="*80)
            results_H = verify_family(
                "H_a", m, i,
                test_permutation_H,
                compute_differential_uniformity_H,
                check_roots_Ra
            )
            all_results[('H_a', m, i)] = results_H
        except KeyboardInterrupt:
            print(f"\nSkipping H_a for (m={m}, i={i}) - computation interrupted")
        except Exception as e:
            print(f"\nError for H_a (m={m}, i={i}): {e}")
            import traceback
            traceback.print_exc()
    
    # Final summary
    print("\n\n" + "="*80)
    print("FINAL SUMMARY ACROSS ALL FAMILIES AND TEST CASES")
    print("="*80)
    
    for (family, m, i), results in all_results.items():
        if results is None:
            continue
        total = results['correlation_matches'] + len(results['correlation_failures'])
        if total > 0:
            rate = 100.0 * results['correlation_matches'] / total
            status = "  VERIFIED" if rate == 100.0 else "  FAILED"
            print(f"{family} (m={m}, i={i}): {results['correlation_matches']}/{total} "
                  f"({rate:.1f}%) {status}")
            print(f"  APN permutations: {len(results['APN_permutations'])}")
    
    # Check consistency between families
    print(f"\n{'='*80}")
    print("CROSS-FAMILY COMPARISON")
    print(f"{'='*80}")
    for m, i in test_cases:
        key_G = ('G_a', m, i)
        key_H = ('H_a', m, i)
        if key_G in all_results and key_H in all_results:
            num_G = len(all_results[key_G]['APN_permutations'])
            num_H = len(all_results[key_H]['APN_permutations'])
            print(f"m={m}, i={i}:")
            print(f"  G_a: {num_G} APN permutations")
            print(f"  H_a: {num_H} APN permutations")
            if num_G == num_H:
                print(f"  → SAME COUNT (expected by Proposition 3.3)")
            else:
                print(f"  → DIFFERENT COUNT (unexpected!)")
    
    return all_results


# ==========================================
# EXECUTION
# ==========================================

if __name__ == "__main__":
    print("\nStarting comprehensive verification...")
    print("This will test BOTH G_a and H_a families for m=3 with i=1,2")
    print("Expected time: ~5-10 minutes for m=3\n")
    
    results = run_all_verifications()
    
    print("\n" + "="*80)
    print("VERIFICATION COMPLETE")
    print("="*80)


Starting comprehensive verification...
This will test BOTH G_a and H_a families for m=3 with i=1,2
Expected time: ~5-10 minutes for m=3

COMPREHENSIVE VERIFICATION OF BOTH G_a AND H_a FAMILIES
Paper: APN permutations in a general class of trivariate functions


################################################################################
# TEST CASE: m=3, i=1
################################################################################

FAMILY 1: G_a (Li-Kaleyski Type 1 Generalization)

VERIFICATION FOR G_a: m=3, i=1, q=2
Field: F_{2^3} with 8 elements
Domain: F_{2^3}^3 with 512 elements
Polynomial: Q_a(T) = T^7 + aT + 1

--- Testing a = w (1/7) ---
  Polynomial has roots: False
  Is permutation: True (0.00s)
  Computing differential uniformity...
  Differential uniformity δ = 2 (726.97s)
  Is APN (δ=2): True
    CORRELATION VERIFIED

--- Testing a = w^2 (2/7) ---
  Polynomial has roots: False
  Is permutation: True (0.00s)
  Computing differential uniformity...
  Differential u