# Harmonic Field: Optimized Cognitive Kernel
Refined with JAX-native optimizations and a class-based architecture.

In [None]:
import jax
import jax.numpy as jnp
from jax import jit, grad
import matplotlib.pyplot as plt
import numpy as np

class HarmonicAgent:
    """Agent that binds complex concept vectors and searches toward a target.

    Concepts are complex vectors. ``bind``/``unbind`` combine them via FFT-based
    circular convolution/correlation. ``think`` moves an initial vector toward a
    target by gradient descent on the cosine-style ``energy``.
    """

    def __init__(self, dim=2048, lr=0.05, threshold=0.08, steps=200, seed=42):
        """Store hyper-parameters and seed the agent's PRNG key.

        Args:
            dim: length of every concept vector.
            lr: gradient step size, applied on the unit sphere (see ``think``).
            threshold: ``think`` succeeds once ``energy`` drops below this.
            steps: maximum number of gradient steps taken by ``think``.
            seed: seed for ``master_key``, which ``create_concept`` consumes.
        """
        self.dim, self.lr, self.threshold, self.steps = dim, lr, threshold, steps
        self.master_key = jax.random.PRNGKey(seed)

    @staticmethod
    @jit
    def bind(a, b):
        """Return the circular convolution of ``a`` and ``b``, computed via FFT."""
        return jnp.fft.ifft(jnp.fft.fft(a) * jnp.fft.fft(b))

    @staticmethod
    @jit
    def unbind(s, r):
        """Return the circular correlation of ``s`` with ``r``, computed via FFT.

        This recovers ``x`` from ``bind(x, r)`` exactly only when ``fft(r)`` has
        unit magnitude.
        """
        return jnp.fft.ifft(jnp.fft.fft(s) * jnp.conj(jnp.fft.fft(r)))

    @staticmethod
    @jit
    def energy(z, t):
        """Return ``1 - Re<z, t> / (|z| |t|)``.

        The result is 0 when ``z`` and ``t`` point the same way and 2 when they
        are opposite. It is invariant to positive rescaling of either argument.
        ``jnp.vdot`` conjugates its first argument.
        """
        z_n = z / (jnp.linalg.norm(z) + 1e-9)
        t_n = t / (jnp.linalg.norm(t) + 1e-9)
        return 1.0 - jnp.real(jnp.vdot(z_n, t_n))

    @staticmethod
    @jit
    def _energy_and_grad(z, t):
        """Return ``(energy(z, t), d energy / dz)`` in one compiled pass.

        Defined once on the class, so JAX compiles it once per input shape
        instead of tracing a fresh ``jit(grad(...))`` wrapper on every call.
        """
        return jax.value_and_grad(HarmonicAgent.energy)(z, t)

    def create_concept(self, key=None):
        """Return a random unit-modulus phasor vector of length ``dim``.

        If ``key`` is omitted, a subkey is split off ``master_key``. The key is
        advanced in place, so successive calls yield different concepts.
        """
        if key is None:
            self.master_key, key = jax.random.split(self.master_key)
        phases = jax.random.uniform(key, (self.dim,), minval=0, maxval=2 * jnp.pi)
        return jnp.exp(1j * phases)

    def think(self, initial, target):
        """Gradient-descend from ``initial`` toward ``target`` on ``energy``.

        Args:
            initial: starting vector; only its direction matters.
            target: vector to align with.

        Returns:
            ``(thought, history, success)``:
            - ``thought`` is the final vector, rescaled to norm ``sqrt(dim)``.
            - ``history`` lists the energy of every evaluated state; its last
              entry is the energy of the returned ``thought``.
            - ``success`` says whether that energy fell below ``threshold``.

        Raises:
            ValueError: if ``initial`` has zero or non-finite norm.
        """
        norm0 = jnp.linalg.norm(initial)
        if not (jnp.isfinite(norm0) & (norm0 > 0)):
            raise ValueError("initial must have a positive, finite norm")
        scale = jnp.sqrt(self.dim)
        # energy() is scale-invariant, so its gradient shrinks like 1/|z|.
        # Stepping on the unit sphere keeps ``lr`` meaningful. Stepping at norm
        # sqrt(dim) would shrink the effective step by a factor of ``dim``.
        unit = initial / norm0
        history = []
        for _ in range(self.steps):
            e, g = self._energy_and_grad(unit, target)
            history.append(float(e))
            if e < self.threshold:
                return unit * scale, history, True
            # For a real-valued loss of complex inputs, jax.grad returns the
            # *conjugate* of the steepest-ascent direction, so descend along
            # -conj(grad).
            unit = unit - self.lr * jnp.conj(g)
            unit = unit / jnp.linalg.norm(unit)
        # Score the state produced by the last step, so the reported result
        # describes the vector actually returned.
        final = float(self.energy(unit, target))
        history.append(final)
        return unit * scale, history, final < self.threshold


In [None]:
def run_demo(title, noise=0.0, hunch=0.0, from_scratch=False):
    """Build a bound target, run ``HarmonicAgent.think`` on it, and plot the energy.

    Args:
        title: label used in the printed summary and as the plot title.
        noise: weight of a random distractor concept in the starting vector.
        hunch: weight of the true target in the starting vector.
        from_scratch: if true, start from the distractor alone and use a
            stricter threshold (0.01 instead of 0.08).

    ``hunch`` and ``noise`` are applied to unit-normalized vectors, so they act
    as genuine mixing proportions. ``bind`` of two unit-modulus concepts
    typically has norm on the order of ``dim``, while a raw concept has norm
    ``sqrt(dim)``. Mixing them unnormalized would let the target dominate even
    at "95% noise".
    """
    agent = HarmonicAgent(threshold=0.01 if from_scratch else 0.08)
    k1, k2, k3 = jax.random.split(agent.master_key, 3)
    target = agent.bind(agent.create_concept(k1), agent.create_concept(k2))
    distractor = agent.create_concept(k3)
    if from_scratch:
        initial = distractor
    else:
        target_dir = target / jnp.linalg.norm(target)
        distractor_dir = distractor / jnp.linalg.norm(distractor)
        initial = hunch * target_dir + noise * distractor_dir
    _, history, success = agent.think(initial, target)
    print(f"[{title}] {'SUCCESS' if success else 'FAILURE'} | Energy: {history[-1]:.4f}")
    plt.figure(figsize=(6, 3))
    plt.plot(history)
    plt.axhline(y=agent.threshold, color='r', ls='--')
    plt.title(title)
    plt.show()

# Demo scenarios, run in order: (title, keyword arguments for run_demo).
for demo_title, demo_kwargs in (
    ("95% Noise Stress Test", dict(noise=0.95, hunch=0.05)),
    ("From Scratch", dict(from_scratch=True)),
    ("10% Hunch", dict(noise=0.90, hunch=0.10)),
):
    run_demo(demo_title, **demo_kwargs)

In [None]:
!mkdir -p backend

In [None]:
%%writefile backend/harmonic_field.py
import jax
import jax.numpy as jnp
from jax import jit, grad

class HarmonicAgent:
    """Agent that binds complex vectors via FFT and searches toward a target.

    ``think`` runs gradient descent on the cosine-style ``energy``. It reports
    "IGNITION" when the energy drops below ``threshold`` and "FAILURE" otherwise.
    """

    def __init__(self, dim=2048, lr=0.05, threshold=0.08, steps=200):
        """Store hyper-parameters.

        Args:
            dim: length of every concept vector.
            lr: gradient step size, applied on the unit sphere (see ``think``).
            threshold: ``think`` succeeds once ``energy`` drops below this.
            steps: maximum number of gradient steps taken by ``think``.
        """
        self.dim, self.lr, self.threshold, self.steps = dim, lr, threshold, steps

    @staticmethod
    @jit
    def bind(a, b):
        """Return the circular convolution of ``a`` and ``b``, computed via FFT."""
        return jnp.fft.ifft(jnp.fft.fft(a) * jnp.fft.fft(b))

    @staticmethod
    @jit
    def unbind(s, r):
        """Return the circular correlation of ``s`` with ``r``, computed via FFT.

        This recovers ``x`` from ``bind(x, r)`` exactly only when ``fft(r)`` has
        unit magnitude.
        """
        return jnp.fft.ifft(jnp.fft.fft(s) * jnp.conj(jnp.fft.fft(r)))

    @staticmethod
    @jit
    def energy(z, t):
        """Return ``1 - Re<z, t> / (|z| |t|)``.

        The result is 0 when ``z`` and ``t`` point the same way and 2 when they
        are opposite. It is invariant to positive rescaling of either argument.
        """
        z_n = z / (jnp.linalg.norm(z) + 1e-9)
        t_n = t / (jnp.linalg.norm(t) + 1e-9)
        return 1.0 - jnp.real(jnp.vdot(z_n, t_n))

    @staticmethod
    @jit
    def _energy_and_grad(z, t):
        """Return ``(energy(z, t), d energy / dz)`` in one compiled pass.

        Defined once on the class, so JAX compiles it once per input shape
        instead of tracing a fresh ``jit(grad(...))`` wrapper on every call.
        """
        return jax.value_and_grad(HarmonicAgent.energy)(z, t)

    def think(self, initial, target):
        """Gradient-descend from ``initial`` toward ``target`` on ``energy``.

        Returns:
            ``(thought, status)``:
            - ``thought`` is the final vector, rescaled to norm ``sqrt(dim)``.
            - ``status`` is "IGNITION" if its energy is below ``threshold``,
              else "FAILURE".

        Raises:
            ValueError: if ``initial`` has zero or non-finite norm.
        """
        norm0 = jnp.linalg.norm(initial)
        if not (jnp.isfinite(norm0) & (norm0 > 0)):
            raise ValueError("initial must have a positive, finite norm")
        scale = jnp.sqrt(self.dim)
        # energy() is scale-invariant, so its gradient shrinks like 1/|z|.
        # Stepping on the unit sphere keeps ``lr`` meaningful. Stepping at norm
        # sqrt(dim) would shrink the effective step by a factor of ``dim``.
        unit = initial / norm0
        for _ in range(self.steps):
            e, g = self._energy_and_grad(unit, target)
            if e < self.threshold:
                return unit * scale, "IGNITION"
            # For a real-valued loss of complex inputs, jax.grad returns the
            # *conjugate* of the steepest-ascent direction, so descend along
            # -conj(grad).
            unit = unit - self.lr * jnp.conj(g)
            unit = unit / jnp.linalg.norm(unit)
        # Score the state produced by the last step as well.
        if self.energy(unit, target) < self.threshold:
            return unit * scale, "IGNITION"
        return unit * scale, "FAILURE"