In [None]:
# Notebook introduction. This is markdown text stored in a string assigned to `_` so the
# cell is valid Python; the string is not referenced by the code cells shown here.
_ = """
# GeodePoly AI Quickstart

This Colab-friendly notebook shows how to use the differentiable RootLayer and losses.

What you'll do:
- Install packages (automatically if missing)
- Run a minimal Torch example (CPU or GPU)
- Optionally try the JAX example
"""



In [None]:
# Install geodepoly with AI extras if it (or torch) is missing, e.g. on a fresh Colab runtime.
# The "ai-torch" extra is assumed to pull in torch as well — TODO confirm in geodepoly's packaging.
try:
    import torch  # noqa: F401
    import geodepoly  # noqa: F401
except ImportError:
    # Only a genuinely missing package triggers an install; any other import-time error
    # propagates instead of being masked by a reinstall.
    import importlib
    import subprocess
    import sys

    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "geodepoly[ai-torch]"])
    # Packages installed into the running interpreter are only visible to the import
    # system after its finder caches are refreshed.
    importlib.invalidate_caches()

import torch
from geodepoly.ai import root_solve_torch
from geodepoly.ai.losses import pole_placement_loss

# Minimal Torch demo: a batch of B random complex polynomials, each given by N + 1 coefficients.
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the same output
B, N = 4, 4
coeffs = torch.randn(B, N + 1, dtype=torch.cdouble, requires_grad=True)
roots = root_solve_torch(coeffs)
# NOTE(review): with half_plane="left", this presumably penalizes roots that are not at least
# `margin` inside the left half-plane — confirm against the geodepoly.ai.losses docs.
loss = pole_placement_loss(roots, half_plane="left", margin=0.1)
# coeffs is a leaf tensor with requires_grad=True, so backward() accumulates gradients into coeffs.grad.
loss.backward()
roots.shape, loss.item()


In [None]:
# Optional: JAX demo (skips if JAX or geodepoly's JAX entry point is unavailable)
try:
    import jax
    import jax.numpy as jnp
    from geodepoly.ai import root_solve_jax
except ImportError as exc:
    # Only a missing dependency means "skip". Errors raised by the demo itself below are
    # real failures and are deliberately not caught, so they are not misreported.
    jax_available = False
    print("JAX not available:", exc)
else:
    jax_available = True

if jax_available:
    # complex128 requires JAX's 64-bit mode; without it, .astype(jnp.complex128)
    # silently produces complex64.
    jax.config.update("jax_enable_x64", True)

    # Distinct names so this cell does not overwrite the Torch cell's B, N, coeffs, roots.
    B_jax, N_jax = 2, 3  # B_jax polynomials, each with N_jax + 1 coefficients
    key = jax.random.PRNGKey(0)
    k_re, k_im = jax.random.split(key)
    coeffs_jax = (
        jax.random.normal(k_re, (B_jax, N_jax + 1))
        + 1j * jax.random.normal(k_im, (B_jax, N_jax + 1))
    ).astype(jnp.complex128)
    roots_jax = root_solve_jax(coeffs_jax)
    # An expression inside an if/try block is not auto-displayed by Jupyter, so print explicitly.
    print("JAX roots shape:", roots_jax.shape)
