In [46]:
def _exact_tfim_hamiltonian(n_spins, h_field):
    """Build the dense transverse-field Ising Hamiltonian for a periodic chain.

    H = -sum_i s^z_i s^z_{(i+1) mod n_spins} - h_field * sum_i s^x_i

    Basis state ``k`` follows ``itertools.product([-1, 1], repeat=n_spins)``
    ordering: spin ``i`` is +1 when bit ``n_spins - 1 - i`` of ``k`` is set
    and -1 otherwise. Addressing states by bit index replaces a linear
    ``list.index`` search per matrix element, which made construction
    O(8**n_spins).

    Note: the periodic sum runs over every site i, so for ``n_spins == 2``
    the single bond (0, 1) is counted twice, and for ``n_spins == 1`` the
    spin couples to itself. This matches the original construction
    (presumably also what ``train_nqs`` uses — confirm before changing).

    Args:
        n_spins: Number of spins in the chain (>= 1).
        h_field: Strength of the transverse field.

    Returns:
        Real symmetric ``(2**n_spins, 2**n_spins)`` NumPy array.
    """
    dim = 2 ** n_spins
    H = np.zeros((dim, dim))
    for k in range(dim):
        spins = [1 if (k >> (n_spins - 1 - i)) & 1 else -1
                 for i in range(n_spins)]
        # Diagonal: nearest-neighbour coupling; (i + 1) % n_spins wraps the
        # last spin around to the first (periodic boundary).
        H[k, k] = -sum(spins[i] * spins[(i + 1) % n_spins]
                       for i in range(n_spins))
        # Off-diagonal: sigma^x_i flips spin i, i.e. toggles exactly one bit
        # of k, connecting state k to state k ^ mask with amplitude -h_field.
        for i in range(n_spins):
            H[k ^ (1 << (n_spins - 1 - i)), k] -= h_field
    return H


def compare_with_exact(n_spins=3, alpha=2, h_field=2.0):
    """
    Compare NQS results with exact diagonalization.

    Trains a neural quantum state via ``train_nqs``, exactly diagonalizes
    the transverse-field Ising Hamiltonian, plots the NQS energy trace
    against the exact ground-state energy per spin, and prints the final
    absolute and relative errors.

    Args:
        n_spins: Number of spins (keep small: the exact Hamiltonian is a
            dense 2**n_spins x 2**n_spins matrix)
        alpha: Hidden layer density
        h_field: Strength of transverse field

    Raises:
        ValueError: If ``n_spins`` is less than 1 (checked before training
            so no training time is wasted).
    """
    if n_spins < 1:
        raise ValueError(f"n_spins must be >= 1, got {n_spins}")

    # Train NQS and get energy history (assumed: one energy-per-spin value
    # per training block, since it is plotted against the exact value per
    # spin — confirm against train_nqs).
    print("Training Neural Quantum State...")
    nqs_energies = train_nqs(n_spins, alpha, h_field)

    # Calculate exact ground state energy
    print("\nCalculating exact ground state...")
    H = _exact_tfim_hamiltonian(n_spins, h_field)
    # H is real symmetric, so eigvalsh returns real eigenvalues in ascending
    # order; the first one is the ground-state energy.
    exact_energy = np.linalg.eigvalsh(H)[0] / n_spins
    print(f'Exact ground state energy per spin: {exact_energy:.6f}')

    # Plot comparison
    plt.figure(figsize=(10, 6))
    plt.plot(nqs_energies, 'b-', label='NQS Energy')
    plt.axhline(y=exact_energy, color='r', linestyle='--',
                label='Exact Energy')
    plt.xlabel('Training Block')
    plt.ylabel('Energy per Spin')
    plt.title('Neural Quantum State Training vs Exact Ground State')
    plt.legend()
    plt.grid(True)
    plt.show()

    # Print final comparison
    final_nqs_energy = nqs_energies[-1]
    error = abs(final_nqs_energy - exact_energy)
    print(f'\nFinal NQS energy: {final_nqs_energy:.6f}')
    print(f'Absolute error: {error:.6f}')
    print(f'Relative error: {100*error/abs(exact_energy):.2f}%')

# Import needed for exact solution