# driftax — Conditional inverse problem on a ring (bimodal posterior)

We train a conditional generator **x ~ qθ(x | y)** where:
- **x** is a 2D point on a noisy ring
- **y** is a 1D measurement: the first coordinate x₀ of **x**, corrupted by heteroscedastic noise

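A given x₀ is consistent with points on both the upper and the lower half of the ring. Conditioning on y therefore yields a **bimodal** posterior p(x | y), with one mode on each arc.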

This notebook:
1) plots the training dataset (x, y)
2) trains a DiT1D generator (sequence length 2) with the drifting loss
3) visualizes conditional samples for a few y values


In [None]:
import os
import sys

# Driftax notebook backend policy:
# - macOS: FORCE CPU (avoid Metal instability)
# - other OS: let JAX auto-select (GPU if available; else CPU)
#
# NOTE: must run BEFORE importing jax, otherwise the settings below may not
# take effect.

# Start from a clean slate: drop any platform pin inherited from the shell.
for env_var in ("JAX_PLATFORMS", "JAX_PLATFORM_NAME"):
    os.environ.pop(env_var, None)

on_macos = sys.platform == "darwin"
if on_macos:
    os.environ["JAX_PLATFORMS"] = "cpu"


In [None]:
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import optax

from driftax.datasets import inverse_ring_toy
from driftax.dit1d import DiT1D, DiT1DConfig
from driftax.conditioning import CondMLP
from driftax.drift import drifting_loss_features

# Report the devices JAX selected. Given the backend cell that runs before
# `import jax`, this should list only CPU devices on macOS.
print("JAX devices:", jax.devices())


In [None]:
# Plot dataset snapshot: the ring itself, then x1 against the measurement y.
key = jax.random.PRNGKey(0)
x, y = inverse_ring_toy(key, 20000)
x = np.array(x)
# Drop the trailing singleton axis so y is a flat vector for plotting.
y = np.array(y).squeeze(-1)

# Panel 1: the 2D samples x on the noisy ring.
ring_ax = plt.figure(figsize=(5, 5)).add_subplot()
ring_ax.scatter(x[:, 0], x[:, 1], s=2, alpha=0.25)
ring_ax.axis("equal")
ring_ax.set_title("Training dataset: noisy ring (x)")
plt.tight_layout()
plt.show()

# Panel 2: for a fixed y, x1 takes values on both arcs -> two branches.
cond_ax = plt.figure(figsize=(6, 3)).add_subplot()
cond_ax.scatter(y, x[:, 1], s=2, alpha=0.15)
cond_ax.set_xlabel("y (measured x0)")
cond_ax.set_ylabel("x1 (vertical)")
cond_ax.set_title("Bimodal conditional structure: x1 | y")
plt.tight_layout()
plt.show()


In [None]:
# Train DiT1D on ring conditional
# Generator: DiT1D over a length-2 sequence (the 2D point x), conditioned on an
# embedding of y produced by CondMLP (in_dim=1 -> out_dim=cfg.cond_dim).
cfg = DiT1DConfig(length=2, patch=1, dim=256, depth=6, heads=4, cond_dim=256, drop=0.0)
model = DiT1D(cfg)
cond_net = CondMLP(in_dim=1, out_dim=cfg.cond_dim, hidden=256)

# Training hyperparameters.
# - batch: samples drawn per step (both real and generated).
# - plot_every: period of intermediate conditional plots; a falsy value
#   disables them (the first and last steps are still plotted).
# - temps: temperatures forwarded to drifting_loss_features.
batch = 512
steps = 2000
plot_every = 500
temps = (0.02, 0.05, 0.2)
lr = 2e-4

key = jax.random.PRNGKey(1)
key, k1, k2 = jax.random.split(key, 3)

# Dummy inputs used only to initialize parameters with the right shapes:
# noise has 2 features (one per coordinate of x), y has 1 feature.
dummy_noise = jnp.zeros((1, 2), dtype=jnp.float32)
dummy_y = jnp.zeros((1, 1), dtype=jnp.float32)

cond_params = cond_net.init(k1, dummy_y)
# Initialize the generator with an actual CondMLP output so its conditioning
# input matches what it will receive during training.
cond0 = cond_net.apply(cond_params, dummy_y)
# NOTE(review): train=True at init is harmless with drop=0.0; if dropout is
# enabled, init/apply may need a dropout RNG — confirm against DiT1D.
model_params = model.init(k2, dummy_noise, cond0, train=True)

# Conditioning network and generator are optimized jointly by one AdamW.
params = {"cond": cond_params, "model": model_params}
opt = optax.adamw(lr)
opt_state = opt.init(params)

def loss_fn(params, key):
    """Drifting loss on one freshly sampled conditional minibatch.

    Draws ``batch`` (x, y) pairs from the ring toy, embeds y with ``cond_net``,
    and generates one sample per y from Gaussian noise with ``model``. The
    generated samples are passed to ``drifting_loss_features`` as the features
    being scored and as the negatives; the real samples are the positives.

    Args:
        params: dict with "cond" (CondMLP) and "model" (DiT1D) parameters.
        key: PRNG key; fully determines the data batch and the latent noise.

    Returns:
        The scalar loss returned by ``drifting_loss_features``.
    """
    # Three-way split kept as-is so the data/noise streams are unchanged;
    # the first subkey is not needed here.
    _, data_key, noise_key = jax.random.split(key, 3)

    x_true, y = inverse_ring_toy(data_key, batch)
    cond = cond_net.apply(params["cond"], y)

    noise = jax.random.normal(noise_key, (batch, 2), dtype=jnp.float32)
    x_gen = model.apply(params["model"], noise, cond, train=True)

    loss_options = dict(
        temps=temps,
        feature_normalize=True,
        drift_normalize=True,
    )
    return drifting_loss_features(
        x_feat=x_gen,
        pos_feat=x_true,
        neg_feat=x_gen,
        **loss_options,
    )

@jax.jit
def step_fn(params, opt_state, key):
    """Run one jitted AdamW update on a fresh minibatch.

    Args:
        params: dict with "cond" and "model" parameter trees.
        opt_state: optimizer state produced by ``opt.init``.
        key: PRNG key carried across training steps.

    Returns:
        ``(new_params, new_opt_state, next_key, loss)``. Pass ``next_key`` into
        the next call so each step draws new data and latent noise.
    """
    # Split off a per-step subkey and hand back the advanced key. Returning the
    # incoming key unchanged would make every step reuse the exact same data
    # batch and noise, i.e. train on a single fixed minibatch.
    key, step_key = jax.random.split(key)
    loss, grads = jax.value_and_grad(loss_fn)(params, step_key)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, key, loss

def conditional_plot(params, y0_list=(-0.6, 0.0, 0.6), n_gen=6000, tol=0.03):
    """Overlay true vs. generated samples of x | y for a few fixed y values.

    The "true" conditional is approximated by rejection: a large fixed-seed
    sample from the dataset, keeping points whose y is within ``tol`` of y0.
    The generated conditional is ``n_gen`` model samples conditioned on y0.

    Args:
        params: dict with "cond" and "model" parameters.
        y0_list: conditioning values, one panel each.
        n_gen: number of generated samples per panel.
        tol: half-width of the y window used to filter reference samples.
    """
    ref_x, ref_y = inverse_ring_toy(jax.random.PRNGKey(123), 80000)
    ref_x = np.array(ref_x)
    ref_y = np.array(ref_y).squeeze(-1)

    n_panels = len(y0_list)
    # squeeze=False always yields a 2D array of Axes, so one column needs no
    # special case.
    fig, grid = plt.subplots(
        1, n_panels, figsize=(4 * n_panels, 4), sharex=False, sharey=False, squeeze=False
    )
    panels = grid[0]

    for ax, y0 in zip(panels, y0_list):
        near = np.abs(ref_y - y0) < tol
        ax.scatter(ref_x[near, 0], ref_x[near, 1], s=2, alpha=0.15, label="true ~p(x|y)")

        # Noise seed derived from y0 so each panel is reproducible across calls.
        seed = int((y0 + 1.5) * 1000) % (2**31 - 1)
        noise = jax.random.normal(jax.random.PRNGKey(seed), (n_gen, 2), dtype=jnp.float32)
        cond = cond_net.apply(params["cond"], jnp.full((n_gen, 1), y0, dtype=jnp.float32))
        samples = np.array(model.apply(params["model"], noise, cond, train=False))

        ax.scatter(samples[:, 0], samples[:, 1], s=2, alpha=0.15, label="gen ~q(x|y)")
        ax.set_title(f"y={y0:+.2f}")
        ax.set_aspect("equal", adjustable="box")
        ax.set_xlim(-1.4, 1.4)
        ax.set_ylim(-1.4, 1.4)

    panels[0].legend(loc="lower left", markerscale=3)
    fig.tight_layout()
    plt.show()

# Training loop: one jitted step per iteration; the key returned by step_fn is
# threaded back in so each step samples new data/noise.
loss_hist = []
for s in range(1, steps + 1):
    params, opt_state, key, loss = step_fn(params, opt_state, key)
    # Convert once; float() pulls the scalar back from the device.
    loss_val = float(loss)
    loss_hist.append(loss_val)

    # Plot on the first step, every `plot_every` steps (if truthy), and at the end.
    if s == 1 or (plot_every and s % plot_every == 0) or s == steps:
        print(f"step {s} loss {loss_val:.3e}")
        conditional_plot(params)

loss_hist = np.asarray(loss_hist, dtype=np.float32)
plt.figure(figsize=(6, 3))
plt.plot(loss_hist, alpha=0.8)
# NOTE(review): log scale assumes the drifting loss stays positive — confirm.
plt.yscale("log")
plt.grid(True, alpha=0.3)
plt.xlabel("Step")
plt.ylabel("Loss")
plt.title("driftax ring conditional drifting loss")
# Use the current figure: no module-level `fig` exists here (the previous
# `fig.tight_layout()` raised NameError).
plt.tight_layout()
plt.show()
