# Bayesian Logistic Regression

In binary logistic regression, we model the probability of a binary outcome $y \in \{0,1\}$ given an input feature vector $x \in \mathbb{R}^d$ as

$$
p(y = 1 \mid x, w) = \sigma(w^\top x),
$$

where $\sigma(z) = \frac{1}{1 + e^{-z}}$ is the sigmoid function, and $w$ denotes the weight vector.  
For a dataset $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^N$, the likelihood of the data under the model is

$$
p(y_{1:N} \mid X, w)
= \prod_{i=1}^N \sigma(w^\top x_i)^{y_i} [1 - \sigma(w^\top x_i)]^{1 - y_i}.
$$

---

## Bayesian Formulation

In the Bayesian setting, the model parameters $w$ are treated as random variables with a prior distribution $p(w)$.  
Given observed data $(X, y)$, we infer a posterior distribution via Bayes’ theorem:

$$
p(w \mid X, y) = \frac{p(y \mid X, w) \, p(w)}{p(y \mid X)},
$$

where the denominator $p(y \mid X) = \int p(y \mid X, w) p(w) \, dw$ is the marginal likelihood (or evidence).  
This term is typically intractable for logistic regression due to the non-conjugacy between the sigmoid likelihood and the Gaussian prior.

---

## Prior Specification

A common choice for the prior is an isotropic Gaussian:

$$
p(w) = \mathcal{N}(w \mid 0, \alpha^{-1} I),
$$

where $\alpha$ denotes the prior precision (inverse variance).  
The corresponding log-prior is thus

$$
\log p(w) = -\frac{\alpha}{2} w^\top w + \text{const.}
$$

---

## Posterior Distribution

Combining the likelihood and prior gives the unnormalized log posterior:

$$
\log \overline{\pi}(w)
= \sum_{i=1}^N \Big[ y_i \log \sigma(w^\top x_i)
+ (1 - y_i) \log (1 - \sigma(w^\top x_i)) \Big]
- \frac{\alpha}{2} w^\top w.
$$

Thus, the posterior is given by

$$
p(w \mid X, y) \propto \exp\big( \log \overline{\pi}(w) \big).
$$

To obtain a normalized posterior, we would need to compute:

$$
p(w \mid X, y)
= \frac{\overline{\pi}(w)}{\int \overline{\pi}(w) \, dw}.
$$

The denominator, which integrates over all possible values of $w$, is called the marginal likelihood or evidence:

$$
p(y \mid X)
= \int p(y \mid X, w) \, p(w) \, dw.
$$

The problem arises because the likelihood term $p(y \mid X, w)$ involves the sigmoid function $\sigma(w^\top x)$, making the integrand non-Gaussian and analytically intractable.  
In other words, there is no closed-form solution for this integral since the sigmoid function does not combine neatly with the Gaussian prior.

---

## Approximate Inference

Since $p(y \mid X) = \int p(y \mid X, w) p(w)\,dw$ has no closed form, several approximate methods are commonly used:

- Laplace approximation:
  - Approximate the posterior by a Gaussian centered at the mode $w_{\text{MAP}}$:
  $$
  p(w \mid X, y) \approx \mathcal{N}(w_{\text{MAP}}, H^{-1}),
  $$
  where $H = -\nabla^2 \log p(w \mid X, y)$ is the Hessian evaluated at $w_{\text{MAP}}$.  
  This provides a simple, local approximation but may be inaccurate for highly non-Gaussian or multimodal posteriors.

- Markov Chain Monte Carlo (MCMC):
  - Sampling-based methods construct a Markov chain whose stationary distribution is $p(w \mid X, y)$.  
  By iteratively proposing and accepting parameter updates, these methods generate approximate samples $w^{(1)}, w^{(2)}, \ldots \sim p(w \mid X, y)$.  
  Popular variants include Metropolis–Hastings, Langevin dynamics, and Hamiltonian Monte Carlo, which differ in how proposals are generated and how gradient information is used to guide sampling.

---

## Predictive Distribution

After obtaining samples from the posterior, the predictive probability for a new data point $x_*$ is computed by marginalizing over the posterior:

$$
p(y_* = 1 \mid x_*, X, y)
= \int \sigma(w^\top x_*) \, p(w \mid X, y) \, dw.
$$

In practice, this integral is approximated using Monte Carlo samples from the posterior:

$$
p(y_* = 1 \mid x_*, X, y)
\approx \frac{1}{M} \sum_{m=1}^M \sigma\big( w^{(m)\top} x_* \big),
$$

where $\{w^{(m)}\}_{m=1}^M$ are samples drawn from $p(w \mid X, y)$.

---

## Connection to Sampling and RLFS

Instead of using Laplace Approximations / MCMC, we will attempt to use reinforcement learning for sampling to this problem.



# Direct Backpropagation


In [None]:
# ===================== Differentiable RLFS on Breast Cancer Dataset =====================
# Direct optimization (no actor-critic) of RLFS objective for Bayesian Logistic Regression.
# -----------------------------------------------------------------------------------------
import jax
import jax.numpy as jnp
from jax import random, lax, value_and_grad, jit
from flax import linen as nn
import optax
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tqdm import trange

print("Devices:", jax.devices())

# ---------------------------------------------------------------------------------------------------
# Dataset: Breast Cancer Wisconsin Diagnostic
# ---------------------------------------------------------------------------------------------------
data = load_breast_cancer()
X = data.data
y = data.target.astype(np.float32)

# Standardize features
scaler = StandardScaler()
X = scaler.fit_transform(X)

# Train/Test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

X_train = jnp.array(X_train)
X_test  = jnp.array(X_test)
y_train = jnp.array(y_train)
y_test  = jnp.array(y_test)

D_base = X_train.shape[1]  # 30
D = D_base + 1             # weights + bias
print(f"Dataset: {X_train.shape[0]} training samples, {X_test.shape[0]} test samples, dim={D_base}")

# ---------------------------------------------------------------------------------------------------
# BLR target (weights + bias concatenated)
# ---------------------------------------------------------------------------------------------------
def blr_log_unnormalized(params, X, y, alpha=1.0):
    w, b = params[:-1], params[-1]
    logits = X @ w + b
    log_lik = jnp.sum(y * jax.nn.log_sigmoid(logits) + (1 - y) * jax.nn.log_sigmoid(-logits))
    log_prior = -0.5 * alpha * jnp.dot(params, params)
    return log_lik + log_prior

# ---------------------------------------------------------------------------------------------------
# Actor (policy) network
# ---------------------------------------------------------------------------------------------------
class TimeEmbed(nn.Module):
    hidden: int = 64
    @nn.compact
    def __call__(self, t):
        freqs = jnp.asarray([1., 2., 4., 8., 16.])
        sinus = jnp.concatenate([
            jnp.sin(2*jnp.pi*freqs[None,:]*t[:,None]),
            jnp.cos(2*jnp.pi*freqs[None,:]*t[:,None])
        ], axis=-1)
        h = nn.relu(nn.Dense(self.hidden)(sinus))
        return nn.relu(nn.Dense(self.hidden)(h))

class Actor(nn.Module):
    hidden: int = 256
    out_dim: int = 31  # D = 30 + 1
    @nn.compact
    def __call__(self, x, t):
        te = TimeEmbed()(t)
        h = jnp.concatenate([x, te], axis=-1)
        h = nn.relu(nn.Dense(self.hidden)(h))
        h = nn.relu(nn.Dense(self.hidden)(h))
        return 1.5 * jnp.tanh(nn.Dense(self.out_dim)(h))

# ---------------------------------------------------------------------------------------------------
# RLFS dynamics & log-probs
# ---------------------------------------------------------------------------------------------------
@jit
def logF(x, x_next, a, sigma):
    diff = x_next - x - a
    return -jnp.sum(diff**2, axis=-1) / (2*sigma**2)

@jit
def logB(x, x_next, sigma):
    diff = x - jnp.sqrt(1-sigma**2)*x_next
    return -jnp.sum(diff**2, axis=-1) / (2*sigma**2)

def make_rollout_trajectory(T, actor_forward):
    invT = jnp.array(1.0/T, dtype=jnp.float32)
    def step(carry, _):
        key, x, t, params, sigma = carry
        key, sub = random.split(key)
        a = actor_forward(params, x, t)
        eps = random.normal(sub, x.shape)
        x_next = jnp.sqrt(1-sigma**2)*x + a + sigma*eps
        t_next = t + invT
        r = logB(x, x_next, sigma) - logF(x, x_next, a, sigma)
        carry = (key, x_next, t_next, params, sigma)
        trans = (x, a, t, r, x_next, t_next)
        return carry, trans
    @jit
    def rollout(key, x0, t0, params, sigma):
        init = (key, x0, t0, params, sigma)
        (key_f, xT, tT, _, _), (xs, as_, ts, rs, xns, tns) = lax.scan(step, init, None, length=T)
        return (xs, as_, ts, rs, xns, tns), (xT, tT)
    return rollout

# ---------------------------------------------------------------------------------------------------
# Differentiable RLFS loss
# ---------------------------------------------------------------------------------------------------
def rlfs_loss(params, key, X, y, alpha, sigma=0.2, T=24, B=512):
    actor = Actor(out_dim=D)
    rollout = make_rollout_trajectory(T, lambda p, x, t: actor.apply(p, x, t))
    key, sub = random.split(key)
    x0 = 0.5 * random.normal(sub, (B, D))
    t0 = jnp.zeros((B,))
    (_, _, _, rs, _, _), (xT, _) = rollout(key, x0, t0, params, sigma)
    r_term = jax.vmap(lambda xx: blr_log_unnormalized(xx, X, y, alpha))(xT)
    total_r = jnp.sum(rs, axis=0) + r_term
    return -jnp.mean(total_r)  # minimize -E[reward]

# ---------------------------------------------------------------------------------------------------
# Predictive metrics
# ---------------------------------------------------------------------------------------------------
@jit
def predictive_metrics(params_batch, X, y):
    w = params_batch[:, :-1]
    b = params_batch[:, -1:]
    logits = X @ w.T + b.T
    probs = jax.nn.sigmoid(logits)
    p_mc = jnp.mean(probs, axis=1)
    eps = 1e-7
    nll = -jnp.mean(y*jnp.log(p_mc+eps)+(1-y)*jnp.log(1-p_mc+eps))
    acc = jnp.mean((p_mc>=0.5)==(y>=0.5))
    return nll, acc

# ---------------------------------------------------------------------------------------------------
# Train by backpropagation
# ---------------------------------------------------------------------------------------------------
LR = 3e-4
SIGMA = 0.2
T_H = 24
ALPHA = 1.0
BATCH = 512
EPOCHS = 3000

actor = Actor(out_dim=D)
key = random.PRNGKey(42)
params = actor.init(key, jnp.zeros((1, D)), jnp.zeros((1,)))
opt = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(LR))
opt_state = opt.init(params)

@jit
def train_step(params, opt_state, key):
    loss, grads = value_and_grad(rlfs_loss)(params, key, X_train, y_train, ALPHA, SIGMA, T_H, BATCH)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

for it in trange(EPOCHS):
    key, sub = random.split(key)
    params, opt_state, loss = train_step(params, opt_state, sub)
    if it % 100 == 0:
        print(f"[Iter {it:04d}] Loss = {float(loss):.4f}")

# ---------------------------------------------------------------------------------------------------
# Evaluate samples
# ---------------------------------------------------------------------------------------------------
rollout_eval = make_rollout_trajectory(T_H, lambda p, x, t: actor.apply(p, x, t))
key, sub = random.split(key)
(_, _, _, _, _, _), (xT, _) = rollout_eval(key, 0.5*random.normal(sub, (8000, D)), jnp.zeros((8000,)), params, SIGMA)

nll_tr, acc_tr = predictive_metrics(xT, X_train, y_train)
nll_te, acc_te = predictive_metrics(xT, X_test,  y_test)
print(f"\nFinal Train NLL={float(nll_tr):.3f}, Acc={float(acc_tr):.3f}")
print(f"Final Test  NLL={float(nll_te):.3f}, Acc={float(acc_te):.3f}")


Devices: [CudaDevice(id=0)]
Dataset: 426 training samples, 143 test samples, dim=30


  0%|          | 11/3000 [00:05<19:36,  2.54it/s] 

[Iter 0000] Loss = 23740.3320


  4%|▍         | 124/3000 [00:06<00:32, 87.32it/s]

[Iter 0100] Loss = 73.6869


  7%|▋         | 222/3000 [00:07<00:21, 127.73it/s]

[Iter 0200] Loss = 56.4919


 11%|█         | 320/3000 [00:08<00:20, 132.46it/s]

[Iter 0300] Loss = 53.4176


 14%|█▍        | 419/3000 [00:09<00:19, 133.97it/s]

[Iter 0400] Loss = 52.4427


 17%|█▋        | 519/3000 [00:09<00:18, 134.73it/s]

[Iter 0500] Loss = 51.7740


 21%|██        | 619/3000 [00:10<00:17, 134.76it/s]

[Iter 0600] Loss = 51.4914


 24%|██▍       | 719/3000 [00:11<00:16, 134.67it/s]

[Iter 0700] Loss = 51.5036


 27%|██▋       | 819/3000 [00:12<00:16, 134.64it/s]

[Iter 0800] Loss = 50.7234


 31%|███       | 919/3000 [00:12<00:15, 134.67it/s]

[Iter 0900] Loss = 50.9866


 34%|███▍      | 1019/3000 [00:13<00:14, 134.64it/s]

[Iter 1000] Loss = 50.9375


 37%|███▋      | 1119/3000 [00:14<00:13, 134.63it/s]

[Iter 1100] Loss = 50.6996


 41%|████      | 1219/3000 [00:15<00:13, 134.79it/s]

[Iter 1200] Loss = 50.4441


 44%|████▍     | 1319/3000 [00:15<00:12, 134.25it/s]

[Iter 1300] Loss = 50.5180


 47%|████▋     | 1419/3000 [00:16<00:11, 134.60it/s]

[Iter 1400] Loss = 50.7294


 51%|█████     | 1519/3000 [00:17<00:11, 134.60it/s]

[Iter 1500] Loss = 50.5651


 54%|█████▍    | 1619/3000 [00:18<00:10, 134.66it/s]

[Iter 1600] Loss = 50.8842


 57%|█████▋    | 1719/3000 [00:18<00:09, 134.63it/s]

[Iter 1700] Loss = 50.4336


 61%|██████    | 1819/3000 [00:19<00:08, 134.69it/s]

[Iter 1800] Loss = 50.2907


 64%|██████▍   | 1919/3000 [00:20<00:08, 134.60it/s]

[Iter 1900] Loss = 50.7628


 67%|██████▋   | 2019/3000 [00:21<00:07, 134.62it/s]

[Iter 2000] Loss = 50.4379


 71%|███████   | 2119/3000 [00:21<00:06, 134.66it/s]

[Iter 2100] Loss = 50.1321


 74%|███████▍  | 2219/3000 [00:22<00:05, 134.66it/s]

[Iter 2200] Loss = 50.2747


 77%|███████▋  | 2319/3000 [00:23<00:05, 134.65it/s]

[Iter 2300] Loss = 50.2650


 81%|████████  | 2419/3000 [00:24<00:04, 134.67it/s]

[Iter 2400] Loss = 50.3571


 84%|████████▍ | 2519/3000 [00:24<00:03, 134.64it/s]

[Iter 2500] Loss = 50.1647


 87%|████████▋ | 2619/3000 [00:25<00:02, 134.42it/s]

[Iter 2600] Loss = 50.4677


 91%|█████████ | 2719/3000 [00:26<00:02, 134.63it/s]

[Iter 2700] Loss = 50.4176


 94%|█████████▍| 2819/3000 [00:27<00:01, 134.50it/s]

[Iter 2800] Loss = 50.4593


 97%|█████████▋| 2919/3000 [00:27<00:00, 134.89it/s]

[Iter 2900] Loss = 50.3458


100%|██████████| 3000/3000 [00:28<00:00, 105.06it/s]



Final Train NLL=0.061, Acc=0.988
Final Test  NLL=0.085, Acc=0.979


# REINFORCE

In [11]:
import jax
import jax.numpy as jnp
from jax import random, lax, value_and_grad, jit
from flax import linen as nn
from flax.training.train_state import TrainState
import optax
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tqdm import trange

print("Devices:", jax.devices())

# =========================================================
# 0. Data: Breast Cancer
# =========================================================
data = load_breast_cancer()
X = data.data.astype(np.float32)
y = data.target.astype(np.float32)

scaler = StandardScaler()
X = scaler.fit_transform(X).astype(np.float32)

X_train_np, X_test_np, y_train_np, y_test_np = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y
)
X_train = jnp.array(X_train_np)
X_test  = jnp.array(X_test_np)
y_train = jnp.array(y_train_np)
y_test  = jnp.array(y_test_np)

D_base = X_train.shape[1]   # 30
D = D_base + 1              # 31 (weights + bias)
print(f"Dataset: train={X_train.shape[0]}, test={X_test.shape[0]}, dim={D_base}")

# We'll close over these in jit'd functions
ALPHA = 3.0   # BLR prior precision
SIGMA = 0.1
HORIZON = 32

# =========================================================
# 1. BLR log unnormalized posterior log π̄(w,b)
# =========================================================
def blr_log_unnormalized(params, X, y, alpha=1.0):
    w, b = params[:-1], params[-1]
    logits = X @ w + b
    log_lik = jnp.sum(
        y * jax.nn.log_sigmoid(logits)
        + (1.0 - y) * jax.nn.log_sigmoid(-logits)
    )
    log_prior = -0.5 * alpha * jnp.dot(params, params)
    return log_lik + log_prior

# vectorize over batch of params
batch_log_unnorm_pi = jit(jax.vmap(
    lambda p: blr_log_unnormalized(p, X_train, y_train, ALPHA)
))

# =========================================================
# 2. Nets: TimeEmbed + Actor(μθ)
# =========================================================
class TimeEmbed(nn.Module):
    hidden: int = 64
    @nn.compact
    def __call__(self, t):  # t: [B]
        freqs = jnp.asarray([1., 2., 4., 8., 16.])
        sinus = jnp.concatenate([
            jnp.sin(2*jnp.pi*freqs[None,:] * t[:,None]),
            jnp.cos(2*jnp.pi*freqs[None,:] * t[:,None])
        ], axis=-1)    # [B,10]
        h = nn.relu(nn.Dense(self.hidden)(sinus))
        h = nn.relu(nn.Dense(self.hidden)(h))
        return h       # [B,hidden]

class Actor(nn.Module):
    out_dim: int
    hidden: int = 256
    act_scale: float = 1.5
    @nn.compact
    def __call__(self, x, t):  # x:[B,D], t:[B]
        te = TimeEmbed()(t)                     # [B,hidden]
        h = jnp.concatenate([x, te], axis=-1)   # [B,D+hidden]
        h = nn.relu(nn.Dense(self.hidden)(h))
        h = nn.relu(nn.Dense(self.hidden)(h))
        a = nn.Dense(self.out_dim)(h)           # [B,D]
        return self.act_scale * jnp.tanh(a)     # μθ(x,t)

# convenience: applies actor mean
def actor_mean(params, x, t):
    return Actor(out_dim=D).apply(params, x, t)

actor_mean_jit = jit(actor_mean)

# =========================================================
# 3. RLFS dynamics & per-step terms (JIT-able)
# =========================================================

@jit
def env_step(key, actor_params, x, t, sigma):
    """
    One step for ALL trajectories in batch.
    x: [B,D], t: [B]
    returns:
      key', x_next, t_next, a_sampled, a_det
    """
    a_det = actor_mean_jit(actor_params, x, t)     # [B,D]

    key, sub = random.split(key)
    eps = random.normal(sub, x.shape)              # [B,D]
    a = a_det + sigma * eps                        # [B,D]

    x_next = jnp.sqrt(1.0 - sigma**2) * x + a      # [B,D]
    t_next = t + (1.0 / HORIZON)                   # [B]

    return key, x_next, t_next, a, a_det

@jit
def logF_consistent(x, x_next, a_det, sigma):
    """
    x_next | x ~ N( mean = sqrt(1-σ²)*x + a_det, cov = σ² I )
    returns [B]
    """
    mean = jnp.sqrt(1.0 - sigma**2) * x + a_det
    diff = x_next - mean
    return -0.5 * jnp.sum((diff**2) / (sigma**2), axis=-1)

@jit
def logB_consistent(x, x_next, sigma):
    """
    backward kernel: x ~ N( sqrt(1-σ²)*x_next , σ² I )
    returns [B]
    """
    mean_back = jnp.sqrt(1.0 - sigma**2) * x_next
    diff = x - mean_back
    return -0.5 * jnp.sum((diff**2) / (sigma**2), axis=-1)

@jit
def log_policy_gaussian(actor_params, x, t, a, sigma):
    """
    log πθ(a | x,t) up to const = -0.5 ||(a-μθ)/σ||^2
    returns [B]
    """
    mu = actor_mean_jit(actor_params, x, t)  # [B,D]
    diff = a - mu
    return -0.5 * jnp.sum((diff**2) / (sigma**2), axis=-1)


# =========================================================
# 4. Rollout *batched* trajectories (JIT, pure functional)
#    We'll store everything we need for the policy gradient:
#    - xs[t], ts[t], as[t], logpi[t], r[t]
#    and final x_T for terminal reward.
# =========================================================

def rollout_batch_once(key, actor_params, batch_size, sigma):
    """
    Roll out (batch_size) trajectories in parallel for fixed HORIZON.
    Returns:
      traj dict with shapes:
        xs:        [H,B,D]
        ts:        [H,B]
        acts:      [H,B,D]      (sampled actions)
        logp_ts:   [H,B]        (log πθ at those actions)
        r_ts:      [H,B]        (per-step flow reward)
        x_T:       [B,D]
    """
    # init x0 ~ N(0,0.5^2 I)
    key, sub = random.split(key)
    x0 = 0.5 * random.normal(sub, (batch_size, D))
    t0 = jnp.zeros((batch_size,))

    def body(carry, _):
        key, x, t = carry
        key, x_next, t_next, a_samp, a_det = env_step(key, actor_params, x, t, sigma)
        # per-step reward
        r_t = logB_consistent(x, x_next, sigma) - logF_consistent(x, x_next, a_det, sigma)
        # logπθ(a|x,t)
        logp_t = log_policy_gaussian(actor_params, x, t, a_samp, sigma)

        new_carry = (key, x_next, t_next)
        out_t = (x, t, a_samp, logp_t, r_t, x_next)
        return new_carry, out_t

    # lax.scan: length = HORIZON
    (key_f, x_last, t_last), scan_out = lax.scan(
        body,
        (key, x0, t0),
        xs=None,
        length=HORIZON
    )
    xs, ts, acts, logp_ts, r_ts, xns = scan_out
    # xs[t] is x_t at step t, xns[t] is x_{t+1}; last xns[-1] is x_T
    x_T = xns[-1]  # [B,D]
    return key_f, {
        "xs": xs,
        "ts": ts,
        "acts": acts,
        "logp_ts": logp_ts,
        "r_ts": r_ts,
        "x_T": x_T,
    }

rollout_batch_once_jit = jit(rollout_batch_once, static_argnums=(2,3))


# =========================================================
# 5. Compute trajectory returns G (vectorized) and baseline
# =========================================================
@jit
def compute_returns(r_ts, x_T):
    """
    r_ts: [H,B]
    x_T:  [B,D]
    return G: [B]
    G = sum_t r_t + log π̄(x_T)
    """
    flow_term = jnp.sum(r_ts, axis=0)        # [B]
    term_bonus = batch_log_unnorm_pi(x_T)    # [B]
    return flow_term + term_bonus           # [B]


# =========================================================
# 6. Policy gradient loss (REINFORCE with baseline)
#
# We do ONE rollout, freeze those samples (xs, ts, acts, etc.),
# compute advantage = (G - baseline_mean),
# then build the loss:
#
#   L(θ) = - mean_b[ stop_grad(adv_b) * sum_t log πθ(a_t^b | x_t^b, t_t^b) ]
#
# where b indexes trajectories in the batch.
#
# That is unbiased REINFORCE.
# =========================================================
def reinforce_loss(actor_params,
                   xs, ts, acts,
                   adv,
                   sigma):
    """
    xs:   [H,B,D]
    ts:   [H,B]
    acts: [H,B,D]
    adv:  [B]  (already stop_gradient outside)
    """
    # We'll vm ap over t dimension to get logπθ at each t for each traj under *current* params.
    def per_t(x_t, t_t, a_t):
        return log_policy_gaussian(actor_params, x_t, t_t, a_t, sigma)  # [B]

    logp_all = jax.vmap(per_t, in_axes=(0,0,0))(xs, ts, acts)  # [H,B]
    logp_sum = jnp.sum(logp_all, axis=0)                      # [B]

    # REINFORCE objective:
    # J_hat = mean( adv * logp_sum )
    J_hat = jnp.mean(adv * logp_sum)
    return -J_hat  # we minimize loss

reinforce_loss_jit = jit(reinforce_loss, static_argnums=(5,))


# =========================================================
# 7. Predictive metrics (jit)
# =========================================================
@jit
def predictive_metrics(params_batch, X, y):
    # params_batch: [S,D]
    w = params_batch[:, :-1]         # [S,30]
    b = params_batch[:, -1:]         # [S,1]
    logits = X @ w.T + b.T           # [N,S]
    probs  = jax.nn.sigmoid(logits)  # [N,S]
    p_mc   = jnp.mean(probs, axis=1) # [N]
    eps = 1e-7
    nll = -jnp.mean(y*jnp.log(p_mc+eps) + (1-y)*jnp.log(1-p_mc+eps))
    acc = jnp.mean((p_mc >= 0.5) == (y >= 0.5))
    return nll, acc

# helper to sample final params x_T for eval
def sample_final_params(key, actor_params, num_samples, sigma):
    key, traj = rollout_batch_once_jit(key, actor_params, num_samples, sigma)
    x_T = traj["x_T"]  # [num_samples, D]
    return key, x_T


# =========================================================
# 8. Agent with JIT'd train_step
# =========================================================
class ReinforceAgent:
    def __init__(self,
                 sigma=SIGMA,
                 horizon=HORIZON,
                 lr=1e-4,
                 seed=0):

        assert horizon == HORIZON, "keep horizon global consistent for jit shapes"

        key = random.PRNGKey(seed)
        key, ka = random.split(key)

        dummy_x = jnp.zeros((1, D))
        dummy_t = jnp.zeros((1,))
        actor = Actor(out_dim=D)
        actor_params = actor.init(ka, dummy_x, dummy_t)

        self.actor = actor
        self.actor_state = TrainState.create(
            apply_fn=self.actor.apply,
            params=actor_params,
            tx=optax.adam(lr)
        )

        # running return baseline (EMA)
        self.baseline_mean = 0.0
        self.baseline_beta = 0.9

        self.key = key
        self.sigma = float(sigma)

    @staticmethod
    @jit
    def _ema_update(old_mean, new_batch_mean, beta):
        return beta*old_mean + (1.0-beta)*new_batch_mean

    def train_step(self, batch_size):
        """
        1. rollout trajectories (JIT)
        2. compute returns and baseline/advantage
        3. JIT'd grad on reinforce_loss
        4. update actor params
        """
        # rollout
        self.key, traj = rollout_batch_once_jit(
            self.key,
            self.actor_state.params,
            batch_size,
            self.sigma
        )
        xs     = traj["xs"]        # [H,B,D]
        ts     = traj["ts"]        # [H,B]
        acts   = traj["acts"]      # [H,B,D]
        r_ts   = traj["r_ts"]      # [H,B]
        x_T    = traj["x_T"]       # [B,D]

        # returns G per traj
        G = compute_returns(r_ts, x_T)  # [B]

        # update EMA baseline on host (cheap)
        G_np = np.array(G)
        batch_mean = float(G_np.mean())
        self.baseline_mean = ReinforceAgent._ema_update(
            self.baseline_mean,
            batch_mean,
            self.baseline_beta
        )

        # advantage (stop_gradient when fed to loss)
        adv = G - self.baseline_mean  # [B]

        # build loss+grad function
        def loss_fn(params, xs, ts, acts, adv):
            # stop grad on adv here so loss doesn't backprop into baseline/returns
            adv_sg = jax.lax.stop_gradient(adv)
            return reinforce_loss_jit(params, xs, ts, acts, adv_sg, self.sigma)

        loss_val, grads = value_and_grad(loss_fn)(
            self.actor_state.params, xs, ts, acts, adv
        )

        # apply update
        self.actor_state = self.actor_state.apply_gradients(grads=grads)

        avg_return = float(batch_mean)
        return float(loss_val), avg_return


# =========================================================
# 9. Training loop
# =========================================================
if __name__ == "__main__":
    LR = 1e-4
    BATCH_ROLLOUT = 1028
    TRAIN_ITERS = 4000
    EVAL_SAMPLES = 4000

    agent = ReinforceAgent(
        sigma=SIGMA,
        horizon=HORIZON,
        lr=LR,
        seed=0
    )

    print("Training REINFORCE-RLFS (JIT)...")
    for it in trange(TRAIN_ITERS):
        loss, avg_ret = agent.train_step(BATCH_ROLLOUT)

        if (it + 1) % 100 == 0:
            agent.key, params_T = sample_final_params(
                agent.key,
                agent.actor_state.params,
                EVAL_SAMPLES,
                agent.sigma
            )
            nll_tr, acc_tr = predictive_metrics(params_T, X_train, y_train)
            nll_te, acc_te = predictive_metrics(params_T, X_test,  y_test)
            print(f"[Iter {it+1:04d}] "
                  f"Loss={loss:.3f} Ret={avg_ret:.2f} | "
                  f"Train NLL={float(nll_tr):.3f} Acc={float(acc_tr):.3f} | "
                  f"Test  NLL={float(nll_te):.3f} Acc={float(acc_te):.3f}")


Devices: [CudaDevice(id=0)]
Dataset: train=426, test=143, dim=30
Training REINFORCE-RLFS (JIT)...


  3%|▎         | 103/4000 [00:10<09:24,  6.90it/s]

[Iter 0100] Loss=871732.938 Ret=-88550.21 | Train NLL=0.467 Acc=0.803 | Test  NLL=0.587 Acc=0.811


  5%|▌         | 203/4000 [00:14<02:11, 28.88it/s]

[Iter 0200] Loss=846512.438 Ret=-62594.16 | Train NLL=0.452 Acc=0.805 | Test  NLL=0.506 Acc=0.804


  8%|▊         | 303/4000 [00:18<02:10, 28.28it/s]

[Iter 0300] Loss=1271096.875 Ret=-37128.16 | Train NLL=0.358 Acc=0.908 | Test  NLL=0.395 Acc=0.846


 10%|█         | 404/4000 [00:21<02:03, 29.20it/s]

[Iter 0400] Loss=910320.562 Ret=-18421.64 | Train NLL=0.389 Acc=0.838 | Test  NLL=0.434 Acc=0.804


 13%|█▎        | 503/4000 [00:24<01:59, 29.25it/s]

[Iter 0500] Loss=289034.750 Ret=-8959.00 | Train NLL=0.355 Acc=0.852 | Test  NLL=0.396 Acc=0.818


 15%|█▌        | 602/4000 [00:28<02:47, 20.23it/s]

[Iter 0600] Loss=78360.203 Ret=-5232.82 | Train NLL=0.337 Acc=0.866 | Test  NLL=0.366 Acc=0.853


 18%|█▊        | 705/4000 [00:32<01:55, 28.44it/s]

[Iter 0700] Loss=46400.375 Ret=-3713.71 | Train NLL=0.316 Acc=0.873 | Test  NLL=0.343 Acc=0.860


 20%|██        | 804/4000 [00:35<01:48, 29.35it/s]

[Iter 0800] Loss=19125.041 Ret=-2924.85 | Train NLL=0.249 Acc=0.908 | Test  NLL=0.280 Acc=0.888


 23%|██▎       | 903/4000 [00:38<01:45, 29.43it/s]

[Iter 0900] Loss=31599.031 Ret=-2411.22 | Train NLL=0.243 Acc=0.918 | Test  NLL=0.270 Acc=0.902


 25%|██▌       | 1003/4000 [00:42<01:42, 29.24it/s]

[Iter 1000] Loss=-8747.858 Ret=-2217.12 | Train NLL=0.210 Acc=0.930 | Test  NLL=0.240 Acc=0.923


 28%|██▊       | 1106/4000 [00:45<01:37, 29.62it/s]

[Iter 1100] Loss=4807.268 Ret=-1985.79 | Train NLL=0.191 Acc=0.941 | Test  NLL=0.221 Acc=0.937


 30%|███       | 1206/4000 [00:49<01:33, 29.85it/s]

[Iter 1200] Loss=15746.478 Ret=-1806.21 | Train NLL=0.180 Acc=0.941 | Test  NLL=0.214 Acc=0.944


 33%|███▎      | 1304/4000 [00:52<02:09, 20.84it/s]

[Iter 1300] Loss=5534.780 Ret=-1690.58 | Train NLL=0.165 Acc=0.946 | Test  NLL=0.195 Acc=0.944


 35%|███▌      | 1404/4000 [00:56<01:27, 29.72it/s]

[Iter 1400] Loss=13286.316 Ret=-1556.15 | Train NLL=0.161 Acc=0.948 | Test  NLL=0.192 Acc=0.951


 38%|███▊      | 1503/4000 [00:59<01:25, 29.32it/s]

[Iter 1500] Loss=3734.760 Ret=-1493.38 | Train NLL=0.160 Acc=0.948 | Test  NLL=0.188 Acc=0.944


 40%|████      | 1606/4000 [01:02<01:20, 29.60it/s]

[Iter 1600] Loss=3248.963 Ret=-1405.01 | Train NLL=0.148 Acc=0.951 | Test  NLL=0.178 Acc=0.951


 43%|████▎     | 1703/4000 [01:06<01:19, 28.91it/s]

[Iter 1700] Loss=7516.651 Ret=-1331.88 | Train NLL=0.151 Acc=0.953 | Test  NLL=0.177 Acc=0.944


 45%|████▌     | 1803/4000 [01:09<01:15, 29.12it/s]

[Iter 1800] Loss=4719.381 Ret=-1273.39 | Train NLL=0.140 Acc=0.951 | Test  NLL=0.172 Acc=0.951


 48%|████▊     | 1906/4000 [01:13<01:09, 30.07it/s]

[Iter 1900] Loss=136.600 Ret=-1221.53 | Train NLL=0.141 Acc=0.946 | Test  NLL=0.172 Acc=0.944


 50%|█████     | 2002/4000 [01:16<01:36, 20.67it/s]

[Iter 2000] Loss=-5901.017 Ret=-1181.70 | Train NLL=0.134 Acc=0.953 | Test  NLL=0.166 Acc=0.944


 53%|█████▎    | 2105/4000 [01:20<01:02, 30.13it/s]

[Iter 2100] Loss=8013.326 Ret=-1101.36 | Train NLL=0.131 Acc=0.951 | Test  NLL=0.162 Acc=0.951


 55%|█████▌    | 2205/4000 [01:23<00:59, 30.17it/s]

[Iter 2200] Loss=-4717.550 Ret=-1085.97 | Train NLL=0.130 Acc=0.948 | Test  NLL=0.162 Acc=0.951


 58%|█████▊    | 2304/4000 [01:26<00:57, 29.50it/s]

[Iter 2300] Loss=3209.830 Ret=-1028.58 | Train NLL=0.127 Acc=0.958 | Test  NLL=0.158 Acc=0.951


 60%|██████    | 2405/4000 [01:30<00:53, 29.71it/s]

[Iter 2400] Loss=1607.092 Ret=-991.41 | Train NLL=0.124 Acc=0.955 | Test  NLL=0.155 Acc=0.944


 63%|██████▎   | 2504/4000 [01:34<00:51, 29.01it/s]

[Iter 2500] Loss=30.452 Ret=-961.06 | Train NLL=0.124 Acc=0.948 | Test  NLL=0.159 Acc=0.944


 65%|██████▌   | 2603/4000 [01:37<00:49, 28.38it/s]

[Iter 2600] Loss=-314.952 Ret=-927.74 | Train NLL=0.125 Acc=0.958 | Test  NLL=0.157 Acc=0.951


 68%|██████▊   | 2702/4000 [01:41<01:03, 20.34it/s]

[Iter 2700] Loss=1804.136 Ret=-888.07 | Train NLL=0.121 Acc=0.962 | Test  NLL=0.153 Acc=0.951


 70%|███████   | 2803/4000 [01:44<00:41, 28.83it/s]

[Iter 2800] Loss=2708.704 Ret=-857.13 | Train NLL=0.122 Acc=0.955 | Test  NLL=0.154 Acc=0.951


 73%|███████▎  | 2905/4000 [01:47<00:37, 29.04it/s]

[Iter 2900] Loss=45.579 Ret=-834.99 | Train NLL=0.122 Acc=0.953 | Test  NLL=0.153 Acc=0.951


 75%|███████▌  | 3005/4000 [01:51<00:33, 29.52it/s]

[Iter 3000] Loss=3907.437 Ret=-802.77 | Train NLL=0.118 Acc=0.962 | Test  NLL=0.147 Acc=0.951


 78%|███████▊  | 3104/4000 [01:55<00:30, 29.44it/s]

[Iter 3100] Loss=4669.167 Ret=-772.93 | Train NLL=0.116 Acc=0.967 | Test  NLL=0.145 Acc=0.951


 80%|████████  | 3203/4000 [01:58<00:27, 29.17it/s]

[Iter 3200] Loss=4724.407 Ret=-750.23 | Train NLL=0.117 Acc=0.962 | Test  NLL=0.146 Acc=0.951


 83%|████████▎ | 3303/4000 [02:01<00:23, 29.43it/s]

[Iter 3300] Loss=296.412 Ret=-731.25 | Train NLL=0.113 Acc=0.960 | Test  NLL=0.142 Acc=0.951


 85%|████████▌ | 3403/4000 [02:05<00:28, 20.59it/s]

[Iter 3400] Loss=-2436.727 Ret=-718.32 | Train NLL=0.110 Acc=0.967 | Test  NLL=0.141 Acc=0.951


 88%|████████▊ | 3504/4000 [02:08<00:16, 29.74it/s]

[Iter 3500] Loss=834.009 Ret=-689.08 | Train NLL=0.109 Acc=0.969 | Test  NLL=0.139 Acc=0.958


 90%|█████████ | 3603/4000 [02:12<00:13, 29.59it/s]

[Iter 3600] Loss=3689.060 Ret=-662.80 | Train NLL=0.108 Acc=0.967 | Test  NLL=0.138 Acc=0.958


 93%|█████████▎| 3706/4000 [02:15<00:09, 29.58it/s]

[Iter 3700] Loss=-2962.834 Ret=-657.62 | Train NLL=0.107 Acc=0.967 | Test  NLL=0.137 Acc=0.951


 95%|█████████▌| 3806/4000 [02:19<00:06, 29.44it/s]

[Iter 3800] Loss=-3204.368 Ret=-638.04 | Train NLL=0.108 Acc=0.967 | Test  NLL=0.139 Acc=0.958


 98%|█████████▊| 3906/4000 [02:22<00:03, 29.70it/s]

[Iter 3900] Loss=-439.293 Ret=-613.80 | Train NLL=0.106 Acc=0.965 | Test  NLL=0.135 Acc=0.958


100%|██████████| 4000/4000 [02:25<00:00, 27.46it/s]

[Iter 4000] Loss=-576.525 Ret=-597.76 | Train NLL=0.104 Acc=0.969 | Test  NLL=0.134 Acc=0.951





# DPG

In [None]:
!pip install numpyro

In [None]:
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import time
import numpy as np
import jax, jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from scipy.special import expit as sigmoid
from scipy.stats import ks_2samp
import arviz as az
from scipy.stats import wasserstein_distance
from sklearn.metrics.pairwise import pairwise_kernels

def make_blr_data():
    data = load_breast_cancer()
    X = data.data.astype(np.float32)
    y = data.target.astype(np.int32)
    X = StandardScaler().fit_transform(X)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=0
    )
    return jnp.array(X_train), jnp.array(y_train), jnp.array(X_test), jnp.array(y_test)

X_train, y_train, X_test, y_test = make_blr_data()
D_base = int(X_train.shape[1])
D = D_base + 1

# ======================================
#  Define Bayesian Logistic Regression model
# ======================================
TAU = 2.5  # prior stddev

def blr_model(X, y=None):
    D = X.shape[1]
    w = numpyro.sample("w", dist.Normal(0.0, TAU).expand([D]))
    b = numpyro.sample("b", dist.Normal(0.0, TAU))
    logits = jnp.dot(X, w) + b
    numpyro.sample("y", dist.Bernoulli(logits=logits), obs=y)

# ======================================
# Run NUTS (Hamiltonian Monte Carlo)
# ======================================
nuts = NUTS(blr_model, target_accept_prob=0.9, dense_mass=True)
mcmc = MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=4, chain_method="parallel")

rng_key = jax.random.PRNGKey(0)
print("🚀 Running NUTS sampling ...")
t0 = time.time()
mcmc.run(rng_key, X=X_train, y=y_train)
wall = time.time() - t0
print(f"⏱ Done in {wall:.1f} sec")

# ======================================
# Diagnostics with ArviZ
# ======================================
idata = az.from_numpyro(mcmc)
diverging = np.array(mcmc.get_extra_fields()["diverging"]).sum()
print(f"⚠️ Divergences: {diverging}")

# ======================================
# Posterior predictive check (test log-likelihood)
# ======================================
samples = mcmc.get_samples(group_by_chain=False)

def test_log_pred_density(samples_flat, Xs, ys, batch=512):
    W = np.array(samples_flat["w"])  # [S, D]
    b = np.array(samples_flat["b"])  # [S]
    S = W.shape[0]
    total = 0.0
    for start in range(0, Xs.shape[0], batch):
        end = min(start + batch, Xs.shape[0])
        Xb = np.array(Xs[start:end])
        yb = np.array(ys[start:end])[:, None]
        logits = Xb @ W.T + b  # [B, S]
        p = 1 / (1 + np.exp(-logits))
        per_samp = p**yb * (1 - p)**(1 - yb)
        avg = np.clip(per_samp.mean(axis=1), 1e-12, 1.0)
        total += np.log(avg).sum()
    return total / Xs.shape[0]

tll = test_log_pred_density(samples, X_test, y_test)
print(f"📊 Test log predictive density (nats/sample): {tll:.4f}")

# ======================================
#  Save reference posterior
# ======================================
np.savez_compressed(
    "blr_nuts_reference.npz",
    w=np.array(samples["w"]),
    b=np.array(samples["b"]),
    tll=tll,
    wall=wall
)
print("💾 Saved posterior to blr_nuts_reference.npz")

# ======================================
# Comparison helper
# ======================================
def compare_to_reference(samples_other, samples_ref, max_points=2000):
    """Compare RLFS samples to NUTS reference using Wasserstein & MMD (compact)."""
    W_oth, W_ref = np.array(samples_other["w"]), np.array(samples_ref["w"])
    B_oth, B_ref = np.array(samples_other["b"]), np.array(samples_ref["b"])
    if W_oth.shape[0] > max_points:
        W_oth = W_oth[np.random.choice(W_oth.shape[0], max_points, replace=False)]
    if W_ref.shape[0] > max_points:
        W_ref = W_ref[np.random.choice(W_ref.shape[0], max_points, replace=False)]

    D = W_ref.shape[1]
    ws = [wasserstein_distance(W_oth[:, d], W_ref[:, d]) for d in range(D)]
    w_mean, w_max, b_ws = np.mean(ws), np.max(ws), wasserstein_distance(B_oth, B_ref)

    X = np.vstack([W_ref, W_oth])
    dists = np.sum((X[:, None, :] - X[None, :, :])**2, axis=-1)
    gamma = 1.0 / (2 * np.median(dists[dists > 0]) + 1e-12)
    def mmd(X, Y):
        Kxx = np.mean(np.exp(-gamma * np.sum((X[:, None]-X[None])**2, axis=-1)))
        Kyy = np.mean(np.exp(-gamma * np.sum((Y[:, None]-Y[None])**2, axis=-1)))
        Kxy = np.mean(np.exp(-gamma * np.sum((X[:, None]-Y[None])**2, axis=-1)))
        return Kxx + Kyy - 2*Kxy
    mmd_val = mmd(W_ref, W_oth)

    mu_err = np.mean(np.abs(W_oth.mean(0) - W_ref.mean(0)))
    var_err = np.mean(np.abs(W_oth.std(0) - W_ref.std(0)))

    return dict(w_mean=w_mean, w_max=w_max, b_ws=b_ws, mmd=mmd_val,
                mu_err=mu_err, var_err=var_err)

print("✅ Reference posterior ready — use compare_to_reference() for your RLFS samples.")

import numpy as np

ref = np.load("blr_nuts_reference.npz")
samples_ref = dict(w=ref["w"], b=ref["b"])
print(f"NUTS samples: w={samples_ref['w'].shape}, b={samples_ref['b'].shape}")


  mcmc = MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=4, chain_method="parallel")


🚀 Running NUTS sampling ...


sample: 100%|██████████| 2000/2000 [00:28<00:00, 71.30it/s, 15 steps of size 2.71e-01. acc. prob=0.93] 
sample: 100%|██████████| 2000/2000 [00:23<00:00, 86.27it/s, 15 steps of size 2.66e-01. acc. prob=0.93] 
sample: 100%|██████████| 2000/2000 [00:22<00:00, 89.67it/s, 15 steps of size 2.63e-01. acc. prob=0.94] 
sample: 100%|██████████| 2000/2000 [00:24<00:00, 80.67it/s, 15 steps of size 3.59e-01. acc. prob=0.86] 


⏱ Done in 109.1 sec
⚠️ Divergences: 0
📊 Test log predictive density (nats/sample): -0.0997
💾 Saved posterior to blr_nuts_reference.npz
✅ Reference posterior ready — use compare_to_reference() for your RLFS samples.


  p = 1 / (1 + np.exp(-logits))


In [None]:
import jax
import jax.numpy as jnp
from jax import random, lax, jit, grad, device_put
from flax import linen as nn
from flax.training.train_state import TrainState
import optax
import numpy as np
from tqdm import trange
import time

print("Devices:", jax.devices())

# ================================================================
# Bayesian Logistic Regression target
# ===============================================================

def blr_log_unnormalized(params, X, y, alpha=1.0):
    """
    Unnormalized log posterior:
        log π̄(w,b) = log p(y|X,w,b) + log p(w,b)
    where p(w,b) = N(0, α⁻¹ I)
    """
    w, b = params[:-1], params[-1]

    logits = X @ w + b
    log_lik = jnp.sum(
        y * jax.nn.log_sigmoid(logits) + (1 - y) * jax.nn.log_sigmoid(-logits)
    )
    log_prior = -0.5 * alpha * jnp.dot(params, params)
    return log_lik + log_prior

# ================================================================
# Models
# ================================================================
class TimeEmbed(nn.Module):
    hidden: int = 64
    @nn.compact
    def __call__(self, t):
        freqs = jnp.asarray([1., 2., 4., 8., 16.])
        sinus = jnp.concatenate([
            jnp.sin(2*jnp.pi*freqs[None,:] * t[:,None]),
            jnp.cos(2*jnp.pi*freqs[None,:] * t[:,None])
        ], axis=-1)
        h = nn.relu(nn.Dense(self.hidden)(sinus))
        h = nn.relu(nn.Dense(self.hidden)(h))
        return h

class Actor(nn.Module):
    hidden: int = 128
    out_dim: int = 2
    act_scale: float = 2.0
    @nn.compact
    def __call__(self, x, t):
        te = TimeEmbed()(t)
        h = jnp.concatenate([x, te], axis=-1)
        h = nn.relu(nn.Dense(self.hidden)(h))
        h = nn.relu(nn.Dense(self.hidden)(h))
        h = nn.LayerNorm()(h)
        a = nn.Dense(self.out_dim)(h)
        return self.act_scale * jnp.tanh(a)

class CriticQ(nn.Module):
    hidden: int = 256
    @nn.compact
    def __call__(self, x, a, t):
        te = TimeEmbed()(t)
        h = jnp.concatenate([x, a, te], axis=-1)
        h = nn.relu(nn.Dense(self.hidden)(h))
        h = nn.relu(nn.Dense(self.hidden)(h))
        h = nn.LayerNorm()(h)
        return nn.Dense(1)(h).squeeze(-1)

def make_actor_forward(D: int):
    module = Actor(out_dim=D)        # built outside jit → static
    @jax.jit
    def _forward(params, x, t):
        return module.apply(params, x, t)
    return _forward

@jit
def critic_forward(params, x, a, t):
    return CriticQ().apply(params, x, a, t)

def critic_grads_factory(actor_forward):
    @jax.jit
    def _critic_grads(critic_params, critic_targ, actor_targ, batch, gamma):
        x, a, t, r, xn, tn, done = batch
        def loss_fn(params):
            a_next = actor_forward(actor_targ, xn, tn)
            q_next = critic_forward(critic_targ, xn, a_next, tn)
            y = r + (1.0 - done) * gamma * q_next
            y = jax.lax.stop_gradient(y)
            q = critic_forward(params, x, a, t)
            return jnp.mean((q - y)**2)
        return jax.grad(loss_fn)(critic_params)
    return _critic_grads

def actor_grads_factory(actor_forward):
    @jax.jit
    def _actor_grads(actor_params, critic_params, batch):
        x, _, t, _, _, _, _ = batch
        def loss_fn(params):
            a_pred = actor_forward(params, x, t)
            q = critic_forward(critic_params, x, a_pred, t)
            return -jnp.mean(q)
        return jax.grad(loss_fn)(actor_params)
    return _actor_grads

# ================================================================
# RLFS “env” bits (pure JAX)
# ================================================================
@jit
def logF(x, x_next, a, sigma):
    diff = x_next - jnp.sqrt(1 - sigma**2) * x - a
    return -jnp.sum(diff**2, axis=-1) / (2*sigma**2)

@jit
def logB(x, x_next, sigma):
    diff = x - jnp.sqrt(1 - sigma**2) * x_next
    return -jnp.sum(diff**2, axis=-1) / (2*sigma**2)

# ---- Rollout factory that closes over T and D
def make_rollout_trajectory(T: int, actor_forward):
    invT = jnp.array(1.0 / T, dtype=jnp.float32)

    def rollout_step(carry, _):
        key, x, t, actor_params, sigma = carry
        key, sub = random.split(key)
        a = actor_forward(actor_params, x, t)
        eps = random.normal(sub, x.shape)
        x_next = jnp.sqrt(1-sigma**2) *x + a + sigma * eps
        t_next = t + invT
        r_step = logB(x, x_next, sigma) - logF(x, x_next, a, sigma)
        carry_next = (key, x_next, t_next, actor_params, sigma)
        trans = (x, a, t, r_step, x_next, t_next)
        return carry_next, trans

    @jit
    def rollout_trajectory(key, x0, t0, actor_params, sigma):
        init = (key, x0, t0, actor_params, sigma)
        (key_f, xT, tT, _, _), (xs, as_, ts, rs, xns, tns) = lax.scan(
            rollout_step, init, xs=None, length=T
        )
        return (xs, as_, ts, rs, xns, tns), (xT, tT)

    return rollout_trajectory

# ================================================================
# Replay Buffer
# ================================================================
class ReplayBuffer:
    def __init__(self, capacity, obs_dim, act_dim):
        self.capacity = int(capacity)
        self.ptr = 0
        self.size = 0
        self.x = np.zeros((capacity, obs_dim), np.float32)
        self.a = np.zeros((capacity, act_dim), np.float32)
        self.t = np.zeros((capacity,), np.float32)
        self.r = np.zeros((capacity,), np.float32)
        self.xn = np.zeros((capacity, obs_dim), np.float32)
        self.tn = np.zeros((capacity,), np.float32)
        self.done = np.zeros((capacity,), np.float32)

    def push_batch(self, x, a, t, r, xn, tn, done):
        B = x.shape[0]
        idx = (np.arange(B) + self.ptr) % self.capacity
        self.x[idx] = x; self.a[idx] = a; self.t[idx] = t; self.r[idx] = r
        self.xn[idx] = xn; self.tn[idx] = tn; self.done[idx] = done
        self.ptr = (self.ptr + B) % self.capacity
        self.size = int(min(self.capacity, self.size + B))

    def sample(self, batch_size):
        idx = np.random.randint(0, self.size, size=batch_size)
        return (self.x[idx], self.a[idx], self.t[idx], self.r[idx],
                self.xn[idx], self.tn[idx], self.done[idx])

    def __len__(self):
        return self.size

# ===================
# Agent (DDPG-style)
# ===================
class Agent:
    def __init__(self, sigma=0.9, T=64, lr_actor=3e-4, lr_critic=3e-4,
                 tau=0.01, gamma=1.0, seed=0, D=2):
        self.sigma, self.T, self.tau, self.gamma = float(sigma), int(T), float(tau), float(gamma)
        self.D = int(D)
        key = random.PRNGKey(seed)
        dummy_x = jnp.zeros((1, D)); dummy_a = jnp.zeros((1, D)); dummy_t = jnp.zeros((1,))
        key, ka, kc = random.split(key, 3)
        actor = Actor(out_dim=D)
        critic = CriticQ()
        actor_params = actor.init(ka, dummy_x, dummy_t)
        critic_params = critic.init(kc, dummy_x, dummy_a, dummy_t)
        self.actor_state = TrainState.create(apply_fn=actor.apply, params=actor_params, tx=optax.adam(lr_actor))
        self.critic_state = TrainState.create(apply_fn=critic.apply, params=critic_params, tx=optax.adam(lr_critic))
        self.actor_targ = self.actor_state.params
        self.critic_targ = self.critic_state.params
        self.key = key
        self.actor_forward = make_actor_forward(D)
        self._critic_grads = critic_grads_factory(self.actor_forward)
        self._actor_grads  = actor_grads_factory(self.actor_forward)


    @staticmethod
    @jit
    def soft_update(target, source, tau):
        return jax.tree_util.tree_map(lambda t, s: (1-tau)*t + tau*s, target, source)

    def update(self, batch):
        batch_dev = tuple(device_put(jnp.asarray(b)) for b in batch)
        gC = self._critic_grads(self.critic_state.params, self.critic_targ, self.actor_targ,
                                batch_dev, self.gamma)
        self.critic_state = self.critic_state.apply_gradients(grads=gC)
        gA = self._actor_grads(self.actor_state.params, self.critic_state.params,
                               batch_dev)
        self.actor_state = self.actor_state.apply_gradients(grads=gA)
        self.actor_targ = self.soft_update(self.actor_targ, self.actor_state.params, self.tau)
        self.critic_targ = self.soft_update(self.critic_targ, self.critic_state.params, self.tau)

# ================================================================
# Training
# ================================================================
if __name__ == '__main__':
    # Hyperparameters
    SIGMA = 0.9
    T_H = 50
    B_COLLECT = 4096
    REPLAY_CAP = 1_000_000
    START_STEPS = 65_536
    TRAIN_ITERS = 5000
    B_UPDATE = 2048
    UPDATES_PER_ITER = 2
    TAU = 0.01
    LR_ACTOR = 1e-5
    LR_CRITIC= 3e-5
    ALPHA = 1/2.5**2  # prior precision

    # ================================================================
    # Bayesian Logistic Regression
    # ================================================================

    # Agent, replay, rollout fn with static T and D
    agent = Agent(sigma=SIGMA, T=T_H, lr_actor=LR_ACTOR, lr_critic=LR_CRITIC, tau=TAU, D=D)
    rb = ReplayBuffer(REPLAY_CAP, obs_dim=D, act_dim=D)
    rollout_trajectory = make_rollout_trajectory(T_H, actor_forward=agent.actor_forward)

    # Terminal reward for each trajectory end
    def batch_log_unnorm_pi(xT_batch):
        # vmap over the batch of terminal weights
        return jax.vmap(lambda params: blr_log_unnormalized(params, X_train, y_train, ALPHA))(xT_batch)

    # Warmup
    print('Collecting warmup...')
    for _ in trange(max(1, START_STEPS // B_COLLECT)):
        agent.key, sub = random.split(agent.key)
        x0 = 2.5 * random.normal(sub, (B_COLLECT, D)); t0 = jnp.zeros((B_COLLECT,))
        (xs, as_, ts, rs, xns, tns), (xT, tT) = rollout_trajectory(agent.key, x0, t0, agent.actor_state.params, SIGMA)
        r_term = batch_log_unnorm_pi(xT)                        # shape (B,)
        # After computing r_term
        rs = rs.at[-1].set(rs[-1] + r_term)  # OK shape-wise
        # Set done = 1 only for the *last transition* of each trajectory
        DONE = np.zeros_like(TT)
        DONE[-B_COLLECT:] = 1.0
        X  = np.array(xs.reshape(-1, D))
        A  = np.array(as_.reshape(-1, D))
        TT = np.array(ts.reshape(-1))
        R  = np.array(rs.reshape(-1))
        XN = np.array(xns.reshape(-1, D))
        TN = np.array(tns.reshape(-1))
        rb.push_batch(X, A, TT, R, XN, TN, DONE)

    # Train
    print('Training...')
    t0_wall = time.time()
    for it in trange(TRAIN_ITERS):
        agent.key, sub = random.split(agent.key)
        x0 = 2.5 * random.normal(sub, (B_COLLECT, D)); t0 = jnp.zeros((B_COLLECT,))
        (xs, as_, ts, rs, xns, tns), (xT, tT) = rollout_trajectory(agent.key, x0, t0, agent.actor_state.params, SIGMA)
        r_term = batch_log_unnorm_pi(xT)
        rs = rs.at[-1].add(r_term)

        X  = np.array(xs.reshape(-1, D))
        A  = np.array(as_.reshape(-1, D))
        TT = np.array(ts.reshape(-1))
        R  = np.array(rs.reshape(-1))
        XN = np.array(xns.reshape(-1, D))
        TN = np.array(tns.reshape(-1))
        DONE = np.zeros_like(TT); DONE[-B_COLLECT:] = 1.0
        rb.push_batch(X, A, TT, R, XN, TN, DONE)

        if len(rb) >= START_STEPS:
            for _ in range(UPDATES_PER_ITER):
                batch = rb.sample(B_UPDATE)
                agent.update(batch)

        if (it + 1) % 50 == 0:
            agent.key, sub = random.split(agent.key)
            x0 = 2.5 * random.normal(sub, (8000, D)); t0 = jnp.zeros((8000,))
            (_, _, _, _, _, _), (wT, tT) = rollout_trajectory(agent.key, x0, t0, agent.actor_state.params, agent.sigma)
            w_only = wT[:, :D_base]   # shape (8000, D_base)
            b_only = wT[:, -1]        # shape (8000,)
            samples_rlfs = dict(w=w_only, b=b_only)

            metrics = compare_to_reference(samples_rlfs, samples_ref)
            print(f"[Iter {it+1:04d}] "
                  f"Wmean={metrics['w_mean']:.3f} "
                  f"Wmax={metrics['w_max']:.3f} "
                  f"B={metrics['b_ws']:.3f} "
                  f"MMD={metrics['mmd']:.4f} "
                  f"MeanErr={metrics['mu_err']:.3f} "
                  f"VarErr={metrics['var_err']:.3f}")


Devices: [CudaDevice(id=0)]
Collecting warmup...


100%|██████████| 16/16 [00:02<00:00,  6.02it/s]


Training...


  1%|          | 51/5000 [00:16<1:24:33,  1.03s/it]

[Iter 0050] Wmean=2.347 Wmax=5.166 B=0.637 MMD=0.5417 MeanErr=2.277 VarErr=0.490


  2%|▏         | 100/5000 [00:29<1:31:09,  1.12s/it]

[Iter 0100] Wmean=2.008 Wmax=4.882 B=0.837 MMD=0.4310 MeanErr=1.967 VarErr=0.411


  3%|▎         | 150/5000 [00:42<1:33:38,  1.16s/it]

[Iter 0150] Wmean=1.980 Wmax=5.163 B=1.424 MMD=0.4491 MeanErr=1.935 VarErr=0.401


  4%|▍         | 201/5000 [00:55<1:10:10,  1.14it/s]

[Iter 0200] Wmean=1.699 Wmax=5.298 B=0.818 MMD=0.3756 MeanErr=1.614 VarErr=0.412


  5%|▌         | 251/5000 [01:08<1:09:59,  1.13it/s]

[Iter 0250] Wmean=1.713 Wmax=5.348 B=1.931 MMD=0.3353 MeanErr=1.637 VarErr=0.410


  6%|▌         | 301/5000 [01:21<1:10:38,  1.11it/s]

[Iter 0300] Wmean=1.686 Wmax=4.793 B=1.646 MMD=0.3299 MeanErr=1.591 VarErr=0.422


  7%|▋         | 351/5000 [01:34<1:09:40,  1.11it/s]

[Iter 0350] Wmean=1.755 Wmax=4.644 B=1.924 MMD=0.3306 MeanErr=1.652 VarErr=0.446


  8%|▊         | 400/5000 [01:47<1:26:28,  1.13s/it]

[Iter 0400] Wmean=1.755 Wmax=4.232 B=1.969 MMD=0.3136 MeanErr=1.649 VarErr=0.427


  9%|▉         | 451/5000 [02:00<1:00:47,  1.25it/s]

[Iter 0450] Wmean=1.735 Wmax=4.371 B=1.863 MMD=0.2991 MeanErr=1.647 VarErr=0.418


 10%|█         | 501/5000 [02:13<1:00:08,  1.25it/s]

[Iter 0500] Wmean=1.722 Wmax=4.564 B=1.791 MMD=0.2934 MeanErr=1.630 VarErr=0.426


 11%|█         | 551/5000 [02:26<59:17,  1.25it/s]  

[Iter 0550] Wmean=1.703 Wmax=4.844 B=1.724 MMD=0.2889 MeanErr=1.610 VarErr=0.428


 12%|█▏        | 601/5000 [02:39<58:43,  1.25it/s]  

[Iter 0600] Wmean=1.690 Wmax=4.829 B=1.591 MMD=0.2851 MeanErr=1.606 VarErr=0.433


 13%|█▎        | 651/5000 [02:52<57:56,  1.25it/s]  

[Iter 0650] Wmean=1.679 Wmax=4.808 B=1.445 MMD=0.2829 MeanErr=1.609 VarErr=0.426


 14%|█▍        | 701/5000 [03:05<58:04,  1.23it/s]  

[Iter 0700] Wmean=1.653 Wmax=4.455 B=1.333 MMD=0.2719 MeanErr=1.575 VarErr=0.431


 15%|█▌        | 751/5000 [03:18<56:22,  1.26it/s]  

[Iter 0750] Wmean=1.642 Wmax=4.522 B=1.232 MMD=0.2712 MeanErr=1.569 VarErr=0.424


 16%|█▌        | 800/5000 [03:31<1:15:16,  1.08s/it]

[Iter 0800] Wmean=1.609 Wmax=4.112 B=1.180 MMD=0.2599 MeanErr=1.537 VarErr=0.419


 17%|█▋        | 850/5000 [03:44<1:18:23,  1.13s/it]

[Iter 0850] Wmean=1.601 Wmax=3.838 B=1.111 MMD=0.2535 MeanErr=1.528 VarErr=0.432


 18%|█▊        | 901/5000 [03:57<1:00:43,  1.12it/s]

[Iter 0900] Wmean=1.592 Wmax=3.645 B=1.069 MMD=0.2523 MeanErr=1.523 VarErr=0.421


 19%|█▉        | 951/5000 [04:10<59:40,  1.13it/s]  

[Iter 0950] Wmean=1.573 Wmax=3.464 B=1.065 MMD=0.2470 MeanErr=1.507 VarErr=0.420


 20%|██        | 1001/5000 [04:23<58:35,  1.14it/s]  

[Iter 1000] Wmean=1.567 Wmax=3.279 B=1.028 MMD=0.2447 MeanErr=1.495 VarErr=0.433


 21%|██        | 1051/5000 [04:35<59:15,  1.11it/s]  

[Iter 1050] Wmean=1.571 Wmax=3.177 B=1.023 MMD=0.2442 MeanErr=1.500 VarErr=0.419


 22%|██▏       | 1101/5000 [04:48<58:35,  1.11it/s]  

[Iter 1100] Wmean=1.568 Wmax=3.173 B=1.036 MMD=0.2417 MeanErr=1.494 VarErr=0.435


 23%|██▎       | 1151/5000 [05:01<53:42,  1.19it/s]  

[Iter 1150] Wmean=1.560 Wmax=3.122 B=1.040 MMD=0.2382 MeanErr=1.488 VarErr=0.438


 24%|██▍       | 1201/5000 [05:14<51:26,  1.23it/s]  

[Iter 1200] Wmean=1.553 Wmax=3.061 B=0.988 MMD=0.2382 MeanErr=1.485 VarErr=0.433


 25%|██▌       | 1251/5000 [05:28<50:28,  1.24it/s]  

[Iter 1250] Wmean=1.575 Wmax=3.148 B=1.003 MMD=0.2388 MeanErr=1.507 VarErr=0.423


 26%|██▌       | 1301/5000 [05:41<49:34,  1.24it/s]  

[Iter 1300] Wmean=1.584 Wmax=3.212 B=0.990 MMD=0.2428 MeanErr=1.523 VarErr=0.416


 27%|██▋       | 1351/5000 [05:54<49:16,  1.23it/s]  

[Iter 1350] Wmean=1.607 Wmax=3.361 B=0.992 MMD=0.2469 MeanErr=1.550 VarErr=0.424


 28%|██▊       | 1401/5000 [06:07<47:59,  1.25it/s]  

[Iter 1400] Wmean=1.588 Wmax=3.464 B=0.997 MMD=0.2437 MeanErr=1.528 VarErr=0.408


 29%|██▉       | 1451/5000 [06:20<47:47,  1.24it/s]  

[Iter 1450] Wmean=1.613 Wmax=3.405 B=1.005 MMD=0.2494 MeanErr=1.560 VarErr=0.417


 30%|███       | 1501/5000 [06:33<47:01,  1.24it/s]  

[Iter 1500] Wmean=1.621 Wmax=3.557 B=1.034 MMD=0.2511 MeanErr=1.563 VarErr=0.407


 31%|███       | 1550/5000 [06:46<1:03:09,  1.10s/it]

[Iter 1550] Wmean=1.631 Wmax=3.559 B=1.027 MMD=0.2533 MeanErr=1.575 VarErr=0.404


 32%|███▏      | 1600/5000 [06:59<1:04:11,  1.13s/it]

[Iter 1600] Wmean=1.618 Wmax=3.585 B=1.048 MMD=0.2512 MeanErr=1.560 VarErr=0.408


 33%|███▎      | 1650/5000 [07:12<1:05:11,  1.17s/it]

[Iter 1650] Wmean=1.627 Wmax=3.582 B=1.060 MMD=0.2532 MeanErr=1.564 VarErr=0.418


 34%|███▍      | 1701/5000 [07:24<49:00,  1.12it/s]  

[Iter 1700] Wmean=1.628 Wmax=3.586 B=1.091 MMD=0.2538 MeanErr=1.571 VarErr=0.407


 35%|███▌      | 1751/5000 [07:37<47:51,  1.13it/s]  

[Iter 1750] Wmean=1.621 Wmax=3.549 B=1.131 MMD=0.2529 MeanErr=1.564 VarErr=0.414


 36%|███▌      | 1801/5000 [07:50<48:12,  1.11it/s]  

[Iter 1800] Wmean=1.628 Wmax=3.538 B=1.118 MMD=0.2544 MeanErr=1.578 VarErr=0.408


 37%|███▋      | 1851/5000 [08:03<46:37,  1.13it/s]  

[Iter 1850] Wmean=1.624 Wmax=3.585 B=1.114 MMD=0.2538 MeanErr=1.571 VarErr=0.400


 38%|███▊      | 1901/5000 [08:16<43:11,  1.20it/s]

[Iter 1900] Wmean=1.643 Wmax=3.516 B=1.169 MMD=0.2593 MeanErr=1.586 VarErr=0.407


 39%|███▉      | 1950/5000 [08:29<54:32,  1.07s/it]

[Iter 1950] Wmean=1.639 Wmax=3.523 B=1.195 MMD=0.2593 MeanErr=1.584 VarErr=0.405


 40%|████      | 2001/5000 [08:43<41:00,  1.22it/s]

[Iter 2000] Wmean=1.635 Wmax=3.523 B=1.186 MMD=0.2576 MeanErr=1.583 VarErr=0.404


 41%|████      | 2051/5000 [08:56<39:46,  1.24it/s]

[Iter 2050] Wmean=1.628 Wmax=3.556 B=1.189 MMD=0.2558 MeanErr=1.577 VarErr=0.403


 42%|████▏     | 2101/5000 [09:09<39:21,  1.23it/s]

[Iter 2100] Wmean=1.620 Wmax=3.449 B=1.200 MMD=0.2545 MeanErr=1.569 VarErr=0.392


 43%|████▎     | 2151/5000 [09:22<38:33,  1.23it/s]

[Iter 2150] Wmean=1.621 Wmax=3.435 B=1.199 MMD=0.2570 MeanErr=1.571 VarErr=0.404


 44%|████▍     | 2201/5000 [09:35<37:48,  1.23it/s]

[Iter 2200] Wmean=1.640 Wmax=3.455 B=1.212 MMD=0.2597 MeanErr=1.590 VarErr=0.394


 45%|████▌     | 2250/5000 [09:48<50:24,  1.10s/it]

[Iter 2250] Wmean=1.643 Wmax=3.457 B=1.195 MMD=0.2572 MeanErr=1.594 VarErr=0.405


 46%|████▌     | 2300/5000 [10:02<52:19,  1.16s/it]

[Iter 2300] Wmean=1.616 Wmax=3.624 B=1.151 MMD=0.2544 MeanErr=1.562 VarErr=0.406


 47%|████▋     | 2351/5000 [10:15<39:07,  1.13it/s]

[Iter 2350] Wmean=1.645 Wmax=3.604 B=1.158 MMD=0.2604 MeanErr=1.591 VarErr=0.406


 47%|████▋     | 2374/5000 [10:19<11:25,  3.83it/s]


KeyboardInterrupt: 

# TD3 Approach with no replay buffer, shorter horizon

In [15]:
    # ===================== TD3-RLFS on Breast Cancer (Bayesian Logistic Regression) =====================
    # Twin critics, target policy smoothing, delayed policy updates, soft target updates, (optional) reward norm.
    # Objective unchanged: sum_t [log B - log F] + log π(x_T) with π the BLR unnormalized posterior on training data.
    # ---------------------------------------------------------------------------------------
    import jax
    import jax.numpy as jnp
    from jax import random, jit, lax, value_and_grad, device_put
    from flax import linen as nn
    from flax.training.train_state import TrainState
    import optax
    import numpy as np
    from tqdm import trange

    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler

    print("Devices:", jax.devices())

    # ---------------------------------------------------------------------------------------------------
    # Dataset: Breast Cancer Wisconsin Diagnostic (569 samples, 30 features)
    # ---------------------------------------------------------------------------------------------------
    data = load_breast_cancer()
    X = data.data.astype(np.float32)
    y = data.target.astype(np.float32)

    # Standardize features
    scaler = StandardScaler()
    X = scaler.fit_transform(X)

    # Train/Test split
    X_train_np, X_test_np, y_train_np, y_test_np = train_test_split(X, y, test_size=0.25, random_state=0, stratify=y)
    X_train, X_test = jnp.array(X_train_np), jnp.array(X_test_np)
    y_train, y_test = jnp.array(y_train_np), jnp.array(y_test_np)

    D_base = X_train.shape[1]  # 30
    D = D_base + 1             # weights + bias = 31
    print(f"Dataset: train={X_train.shape[0]}, test={X_test.shape[0]}, dim={D_base}")

    # ---------------------------------------------------------------------------------------------------
    # BLR target (weights + bias concatenated)
    # ---------------------------------------------------------------------------------------------------
    def blr_log_unnormalized(params, X, y, alpha=1.0):
        """ log π̄(w,b) = log p(y|X,w,b) + log p(w,b), with Gaussian prior N(0, α⁻¹ I) """
        w, b = params[:-1], params[-1]
        logits = X @ w + b
        log_lik = jnp.sum(y * jax.nn.log_sigmoid(logits) + (1 - y) * jax.nn.log_sigmoid(-logits))
        log_prior = -0.5 * alpha * jnp.dot(params, params)
        return log_lik + log_prior

    # ---------------------------------------------------------------------------------------------------
    # Nets (TimeEmbed, Actor, Critic)
    # ---------------------------------------------------------------------------------------------------
    class TimeEmbed(nn.Module):
        hidden: int = 64
        @nn.compact
        def __call__(self, t):
            freqs = jnp.asarray([1., 2., 4., 8., 16.])
            sinus = jnp.concatenate([
                jnp.sin(2*jnp.pi*freqs[None,:] * t[:,None]),
                jnp.cos(2*jnp.pi*freqs[None,:] * t[:,None])
            ], axis=-1)
            h = nn.relu(nn.Dense(self.hidden)(sinus))
            h = nn.relu(nn.Dense(self.hidden)(h))
            return h

    class Actor(nn.Module):
        hidden: int = 256
        out_dim: int = D     # 31
        act_scale: float = 3.0
        @nn.compact
        def __call__(self, x, t):
            te = TimeEmbed()(t)
            h = jnp.concatenate([x, te], axis=-1)
            h = nn.relu(nn.Dense(self.hidden)(h))
            h = nn.relu(nn.Dense(self.hidden)(h))
            a = nn.Dense(self.out_dim)(h)
            return self.act_scale * jnp.tanh(a)

    class CriticQ(nn.Module):
        hidden: int = 256
        @nn.compact
        def __call__(self, x, a, t):
            te = TimeEmbed()(t)
            h = jnp.concatenate([x, a, te], axis=-1)
            h = nn.relu(nn.Dense(self.hidden)(h))
            h = nn.relu(nn.Dense(self.hidden)(h))
            return nn.Dense(1)(h).squeeze(-1)

    # ---------------------------------------------------------------------------------------------------
    # RLFS kernels (ADDITIVE dynamics): x_{t+1} = sqrt(1-σ^2)*x_t + a + 0  (noise already added to a during rollout)
    # ---------------------------------------------------------------------------------------------------
    @jit
    def logF(x, x_next, a, sigma):
        muF  = x + a
        diff = x_next - muF
        return -jnp.sum(diff**2, axis=-1) / (2 * sigma**2)

    @jit
    def logB(x, x_next, sigma):
        diff = x - jnp.sqrt(1 - sigma**2) * x_next
        return -jnp.sum(diff**2, axis=-1) / (2*sigma**2)

    def make_rollout_trajectory(T, actor_forward):
        invT = jnp.array(1.0 / T, dtype=jnp.float32)
        def rollout_step(carry, _):
            key, x, t, actor_params, sigma = carry
            key, sub = random.split(key)
            a_det = actor_forward(actor_params, x, t)     # deterministic action
            eps   = random.normal(sub, x.shape)
            a     = a_det + sigma * eps                   # exploration noise
            x_next = x + a
            t_next = t + invT
            r_step = logB(x, x_next, sigma) - logF(x, x_next, a, sigma)
            carry_next = (key, x_next, t_next, actor_params, sigma)
            trans = (x, a_det, t, r_step, x_next, t_next)
            return carry_next, trans

        @jit
        def rollout_trajectory(key, x0, t0, actor_params, sigma):
            init = (key, x0, t0, actor_params, sigma)
            (key_f, xT, tT, _, _), (xs, as_, ts, rs, xns, tns) = lax.scan(
                rollout_step, init, xs=None, length=T
            )
            return (xs, as_, ts, rs, xns, tns), (xT, tT)
        return rollout_trajectory

    # ---------------------------------------------------------------------------------------------------
    # Replay Buffer
    # ---------------------------------------------------------------------------------------------------
    class ReplayBuffer:
        def __init__(self, capacity, obs_dim, act_dim):
            self.capacity = int(capacity); self.ptr = 0; self.size = 0
            self.x  = np.zeros((capacity, obs_dim), np.float32)
            self.a  = np.zeros((capacity, act_dim), np.float32)
            self.t  = np.zeros((capacity,), np.float32)
            self.r  = np.zeros((capacity,), np.float32)
            self.xn = np.zeros((capacity, obs_dim), np.float32)
            self.tn = np.zeros((capacity,), np.float32)
            self.done = np.zeros((capacity,), np.float32)
        def push_batch(self, x, a, t, r, xn, tn, done):
            B = x.shape[0]
            idx = (np.arange(B) + self.ptr) % self.capacity
            self.x[idx] = x; self.a[idx] = a; self.t[idx] = t; self.r[idx] = r
            self.xn[idx] = xn; self.tn[idx] = tn; self.done[idx] = done
            self.ptr = (self.ptr + B) % self.capacity
            self.size = int(min(self.capacity, self.size + B))
        def sample(self, batch_size):
            idx = np.random.randint(0, self.size, size=batch_size)
            return (self.x[idx], self.a[idx], self.t[idx], self.r[idx],
                    self.xn[idx], self.tn[idx], self.done[idx])
        def __len__(self): return self.size

    # ---------------------------------------------------------------------------------------------------
    # Utilities
    # ---------------------------------------------------------------------------------------------------
    def soft_update(target, source, tau):
        return jax.tree_util.tree_map(lambda t, s: (1-tau)*t + tau*s, target, source)

    @jit
    def predictive_metrics(params_batch, X, y):
        """ Monte-Carlo predictive NLL & accuracy using parameter samples """
        w = params_batch[:, :-1]            # [S, 30]
        b = params_batch[:,  -1:]           # [S, 1]
        logits = X @ w.T + b.T              # [N, S]
        probs  = jax.nn.sigmoid(logits)     # [N, S]
        p_mc   = jnp.mean(probs, axis=1)    # [N]
        eps = 1e-7
        nll = -jnp.mean(y*jnp.log(p_mc+eps) + (1 - y)*jnp.log(1-p_mc+eps))
        acc = jnp.mean((p_mc >= 0.5) == (y >= 0.5))
        return nll, acc

    # ---------------------------------------------------------------------------------------------------
    # TD3 Agent for RLFS
    # ---------------------------------------------------------------------------------------------------
    def make_actor_forward(D_out: int):
        module = Actor(out_dim=D_out)
        @jax.jit
        def _forward(params, x, t):
            return module.apply(params, x, t)
        return _forward

    @jit
    def critic_forward(params, x, a, t):
        return CriticQ().apply(params, x, a, t)

    def twin_critic_grads_factory(actor_forward, target_noise_std, target_noise_clip):
        """TD3 twin-critic MSE grads with target policy smoothing."""
        @jax.jit
        def _critic_grads(c1_params, c2_params, c1_targ, c2_targ, actor_targ, batch, gamma, key):
            x, a, t, r, xn, tn, done = batch

            # Target with policy smoothing
            noise = target_noise_std * random.normal(key, a.shape)
            noise = jnp.clip(noise, -target_noise_clip, target_noise_clip)
            a_next = actor_forward(actor_targ, xn, tn) + noise
            q1_next = critic_forward(c1_targ, xn, a_next, tn)
            q2_next = critic_forward(c2_targ, xn, a_next, tn)
            q_next = jnp.minimum(q1_next, q2_next)
            y = jax.lax.stop_gradient(r + (1.0 - done) * gamma * q_next)

            def loss1(p):
                q = critic_forward(p, x, a, t)
                return jnp.mean((q - y)**2)
            def loss2(p):
                q = critic_forward(p, x, a, t)
                return jnp.mean((q - y)**2)

            g1 = jax.grad(loss1)(c1_params)
            g2 = jax.grad(loss2)(c2_params)
            return g1, g2
        return _critic_grads


    def actor_grads_factory(actor_forward):
        @jax.jit
        def _actor_grads(actor_params, critic_params, batch):
            x, _, t, _, _, _, _ = batch
            def loss_fn(p):
                a_pred = actor_forward(p, x, t)
                q = critic_forward(critic_params, x, a_pred, t)
                return -jnp.mean(q)
            return jax.grad(loss_fn)(actor_params)
        return _actor_grads


    class TD3Agent:
        def __init__(self, D_out=D, sigma=0.2, T=24, lr_actor=3e-4, lr_critic=3e-4,
                    tau=0.01, gamma=1.0, seed=0,
                    target_noise_std=0.10, target_noise_clip=0.20,
                    policy_delay=2, reward_norm=False):
            self.sigma, self.T, self.tau, self.gamma = float(sigma), int(T), float(tau), float(gamma)
            self.D_out = int(D_out)
            self.target_noise_std = float(target_noise_std)
            self.target_noise_clip = float(target_noise_clip)
            self.policy_delay = int(policy_delay)
            self.reward_norm = bool(reward_norm)

            key = random.PRNGKey(seed)
            dummy_x = jnp.zeros((1, self.D_out)); dummy_a = jnp.zeros((1, self.D_out)); dummy_t = jnp.zeros((1,))
            key, ka, kc1, kc2 = random.split(key, 4)

            self.actor = Actor(out_dim=self.D_out)
            self.critic1 = CriticQ()
            self.critic2 = CriticQ()

            actor_params = self.actor.init(ka, dummy_x, dummy_t)
            c1_params = self.critic1.init(kc1, dummy_x, dummy_a, dummy_t)
            c2_params = self.critic2.init(kc2, dummy_x, dummy_a, dummy_t)

            clip = optax.clip_by_global_norm(1.0)
            self.actor_state = TrainState.create(
                apply_fn=self.actor.apply, params=actor_params,
                tx=optax.chain(clip, optax.adam(lr_actor))
            )
            self.critic1_state = TrainState.create(
                apply_fn=self.critic1.apply, params=c1_params,
                tx=optax.chain(clip, optax.adam(lr_critic))
            )
            self.critic2_state = TrainState.create(
                apply_fn=self.critic2.apply, params=c2_params,
                tx=optax.chain(clip, optax.adam(lr_critic))
            )

            self.actor_targ  = self.actor_state.params
            self.critic1_targ= self.critic1_state.params
            self.critic2_targ= self.critic2_state.params

            self.key = key
            self.actor_forward = make_actor_forward(self.D_out)
            self._critic_grads = twin_critic_grads_factory(self.actor_forward,
                                                          self.target_noise_std,
                                                          self.target_noise_clip)
            # Standard TD3: use Q1 for policy gradient
            self._actor_grads  = actor_grads_factory(self.actor_forward)

            # reward normalization stats (optional)
            self._r_mean = 0.0
            self._r_var = 1.0
            self._r_count = 1e-6

        @staticmethod
        @jit
        def _soft_update(target, source, tau):
            return jax.tree_util.tree_map(lambda t, s: (1-tau)*t + tau*s, target, source)

        def _update_r_stats(self, r_np):
            # Welford
            r = r_np.reshape(-1).astype(np.float64)
            batch_count = r.shape[0]
            if batch_count == 0:
                return
            batch_mean = r.mean()
            batch_var  = r.var()
            delta = batch_mean - self._r_mean
            tot_count = self._r_count + batch_count
            new_mean = self._r_mean + delta * batch_count / tot_count
            m_a = self._r_var * self._r_count
            m_b = batch_var * batch_count
            M2 = m_a + m_b + (delta**2) * self._r_count * batch_count / tot_count
            new_var = M2 / tot_count
            self._r_mean, self._r_var, self._r_count = new_mean, new_var, tot_count

        def normalize_r(self, r_np):
            if not self.reward_norm:
                return r_np
            self._update_r_stats(r_np)
            std = np.sqrt(self._r_var) + 1e-6
            return (r_np - self._r_mean) / std

        def update(self, batch, step_idx):
            # Optional reward normalization
            x, a, t, r, xn, tn, done = batch
            r = self.normalize_r(r)
            batch_dev = (device_put(jnp.asarray(x)), device_put(jnp.asarray(a)),
                        device_put(jnp.asarray(t)), device_put(jnp.asarray(r)),
                        device_put(jnp.asarray(xn)), device_put(jnp.asarray(tn)),
                        device_put(jnp.asarray(done)))

            # Twin critic update
            self.key, sub = random.split(self.key)
            gC1, gC2 = self._critic_grads(self.critic1_state.params, self.critic2_state.params,
                                          self.critic1_targ, self.critic2_targ,
                                          self.actor_targ, batch_dev, self.gamma, sub)
            self.critic1_state = self.critic1_state.apply_gradients(grads=gC1)
            self.critic2_state = self.critic2_state.apply_gradients(grads=gC2)

            # Delayed actor update + soft target updates
            if (step_idx % self.policy_delay) == 0:
                gA = self._actor_grads(self.actor_state.params, self.critic1_state.params, batch_dev)
                self.actor_state = self.actor_state.apply_gradients(grads=gA)
                self.actor_targ   = self._soft_update(self.actor_targ,  self.actor_state.params,  self.tau)
                self.critic1_targ = self._soft_update(self.critic1_targ,self.critic1_state.params,self.tau)
                self.critic2_targ = self._soft_update(self.critic2_targ,self.critic2_state.params,self.tau)

    # ---------------------------------------------------------------------------------------------------
    # Training (TD3-RLFS) on Breast Cancer
    # ---------------------------------------------------------------------------------------------------
    if __name__ == '__main__':
        # Hyperparams (kept same spirit as your previous TD3 run)
        SIGMA = 0.1
        T_H = 10
        B_COLLECT = 2048
        REPLAY_CAP = 2048
        START_STEPS = 2048
        TRAIN_ITERS = 5000
        B_UPDATE = 2048
        UPDATES_PER_ITER = 1
        TAU = 0.01
        LR_ACTOR = 1e-5
        LR_CRITIC= 5e-5
        ALPHA = 3.0    # BLR prior precision

        # TD3 knobs
        TARGET_NOISE_STD = 0
        TARGET_NOISE_CLIP= 0
        POLICY_DELAY = 1
        REWARD_NORM = False

        # Agent & replay
        agent = TD3Agent(D_out=D, sigma=SIGMA, T=T_H, lr_actor=LR_ACTOR, lr_critic=LR_CRITIC,
                        tau=TAU, gamma=1.0, seed=0,
                        target_noise_std=TARGET_NOISE_STD, target_noise_clip=TARGET_NOISE_CLIP,
                        policy_delay=POLICY_DELAY, reward_norm=REWARD_NORM)
        rb = ReplayBuffer(REPLAY_CAP, obs_dim=D, act_dim=D)
        rollout_trajectory = make_rollout_trajectory(T_H, actor_forward=agent.actor_forward)

        # Terminal reward helper based on training data
        def batch_log_unnorm_pi(xT_batch):
            return jax.vmap(lambda params: blr_log_unnormalized(params, X_train, y_train, ALPHA))(xT_batch)

        # Warmup collection
        print('Collecting warmup...')
        for _ in trange(max(1, START_STEPS // B_COLLECT)):
            agent.key, sub = random.split(agent.key)
            x0 = 0.5 * random.normal(sub, (B_COLLECT, D)); t0 = jnp.zeros((B_COLLECT,))
            (xs, as_, ts, rs, xns, tns), (xT, tT) = rollout_trajectory(agent.key, x0, t0, agent.actor_state.params, SIGMA)
            r_term = batch_log_unnorm_pi(xT)
            rs = rs.at[-1].set(rs[-1] + r_term)
            DONE_mat = np.zeros((T_H, B_COLLECT), dtype=np.float32); DONE_mat[-1,:] = 1.0
            Xb  = np.array(xs.reshape(-1, D));  A  = np.array(as_.reshape(-1, D))
            TT  = np.array(ts.reshape(-1));     R  = np.array(rs.reshape(-1))
            XNb = np.array(xns.reshape(-1, D)); TN = np.array(tns.reshape(-1))
            DONE = DONE_mat.reshape(-1)
            rb.push_batch(Xb, A, TT, R, XNb, TN, DONE)

        # Train TD3-RLFS; periodically evaluate
        print('Training TD3-RLFS...')
        last_params_RL = None
        step_idx = 0
        for it in trange(TRAIN_ITERS):
            agent.key, sub = random.split(agent.key)
            x0 = 0.5 * random.normal(sub, (B_COLLECT, D)); t0 = jnp.zeros((B_COLLECT,))
            (xs, as_, ts, rs, xns, tns), (xT, tT) = rollout_trajectory(agent.key, x0, t0, agent.actor_state.params, SIGMA)
            r_term = batch_log_unnorm_pi(xT)
            rs = rs.at[-1].add(r_term)

            DONE_mat = np.zeros((T_H, B_COLLECT), dtype=np.float32); DONE_mat[-1,:] = 1.0
            Xb  = np.array(xs.reshape(-1, D));  A  = np.array(as_.reshape(-1, D))
            TT  = np.array(ts.reshape(-1));     R  = np.array(rs.reshape(-1))
            XNb = np.array(xns.reshape(-1, D)); TN = np.array(tns.reshape(-1))
            DONE = DONE_mat.reshape(-1)
            rb.push_batch(Xb, A, TT, R, XNb, TN, DONE)

            if len(rb) >= START_STEPS:
                for _ in range(UPDATES_PER_ITER):
                    batch = rb.sample(B_UPDATE)
                    agent.update(batch, step_idx)
                    step_idx += 1

            if (it + 1) % 100 == 0:
                # Evaluate by rolling many terminal params and computing predictive metrics
                agent.key, sub = random.split(agent.key)
                x0_eval = 0.5 * random.normal(sub, (8000, D)); t0_eval = jnp.zeros((8000,))
                (_, _, _, _, _, _), (params_T, _) = rollout_trajectory(agent.key, x0_eval, t0_eval,
                                                                      agent.actor_state.params, agent.sigma)
                nll_tr, acc_tr = predictive_metrics(params_T, X_train, y_train)
                nll_te, acc_te = predictive_metrics(params_T, X_test,  y_test)
                print(f"[Iter {it+1:04d}] TD3-RLFS — Train NLL={float(nll_tr):.3f} Acc={float(acc_tr):.3f} | "
                      f"Test NLL={float(nll_te):.3f} Acc={float(acc_te):.3f}")
                last_params_RL = np.array(params_T)

        # If loop ended before eval block, sample once
        if last_params_RL is None:
            agent.key, sub = random.split(agent.key)
            x0_eval = 0.5 * random.normal(sub, (8000, D)); t0_eval = jnp.zeros((8000,))
            (_, _, _, _, _, _), (params_T, _) = rollout_trajectory(agent.key, x0_eval, t0_eval,
                                                                  agent.actor_state.params, agent.sigma)
            last_params_RL = np.array(params_T)

        # Final test metrics
        nll_te, acc_te = predictive_metrics(jnp.asarray(last_params_RL), X_test, y_test)
        print(f"\n==================== TD3-RLFS FINAL ====================")
        print(f"Test NLL={float(nll_te):.3f}, Acc={float(acc_te):.3f}")


Devices: [CudaDevice(id=0)]
Dataset: train=426, test=143, dim=30
Collecting warmup...


100%|██████████| 1/1 [00:01<00:00,  1.26s/it]


Training TD3-RLFS...


  2%|▏         | 101/5000 [00:16<44:53,  1.82it/s]

[Iter 0100] TD3-RLFS — Train NLL=3.708 Acc=0.310 | Test NLL=3.490 Acc=0.301


  4%|▍         | 201/5000 [00:29<10:25,  7.67it/s]

[Iter 0200] TD3-RLFS — Train NLL=6.666 Acc=0.094 | Test NLL=7.032 Acc=0.112


  6%|▌         | 301/5000 [00:41<09:57,  7.86it/s]

[Iter 0300] TD3-RLFS — Train NLL=2.413 Acc=0.202 | Test NLL=2.797 Acc=0.210


  8%|▊         | 401/5000 [00:53<09:21,  8.19it/s]

[Iter 0400] TD3-RLFS — Train NLL=1.541 Acc=0.359 | Test NLL=1.515 Acc=0.378


 10%|█         | 501/5000 [01:06<08:56,  8.38it/s]

[Iter 0500] TD3-RLFS — Train NLL=1.462 Acc=0.765 | Test NLL=1.501 Acc=0.713


 12%|█▏        | 601/5000 [01:18<08:56,  8.20it/s]

[Iter 0600] TD3-RLFS — Train NLL=2.795 Acc=0.383 | Test NLL=2.916 Acc=0.329


 14%|█▍        | 701/5000 [01:31<08:41,  8.25it/s]

[Iter 0700] TD3-RLFS — Train NLL=0.577 Acc=0.707 | Test NLL=0.726 Acc=0.685


 16%|█▌        | 801/5000 [01:45<08:37,  8.11it/s]

[Iter 0800] TD3-RLFS — Train NLL=2.837 Acc=0.246 | Test NLL=2.364 Acc=0.371


 18%|█▊        | 901/5000 [01:57<08:13,  8.30it/s]

[Iter 0900] TD3-RLFS — Train NLL=2.559 Acc=0.221 | Test NLL=2.339 Acc=0.252


 20%|██        | 1001/5000 [02:10<07:59,  8.34it/s]

[Iter 1000] TD3-RLFS — Train NLL=0.404 Acc=0.840 | Test NLL=0.530 Acc=0.825


 22%|██▏       | 1101/5000 [02:22<07:45,  8.38it/s]

[Iter 1100] TD3-RLFS — Train NLL=1.335 Acc=0.648 | Test NLL=1.303 Acc=0.678


 24%|██▍       | 1201/5000 [02:35<07:42,  8.21it/s]

[Iter 1200] TD3-RLFS — Train NLL=0.645 Acc=0.796 | Test NLL=0.837 Acc=0.783


 26%|██▌       | 1301/5000 [02:47<07:18,  8.44it/s]

[Iter 1300] TD3-RLFS — Train NLL=0.960 Acc=0.789 | Test NLL=1.163 Acc=0.804


 28%|██▊       | 1401/5000 [03:00<07:09,  8.38it/s]

[Iter 1400] TD3-RLFS — Train NLL=0.438 Acc=0.894 | Test NLL=0.661 Acc=0.853


 30%|███       | 1501/5000 [03:12<06:56,  8.40it/s]

[Iter 1500] TD3-RLFS — Train NLL=0.492 Acc=0.840 | Test NLL=0.656 Acc=0.818


 32%|███▏      | 1601/5000 [03:25<06:46,  8.36it/s]

[Iter 1600] TD3-RLFS — Train NLL=1.507 Acc=0.700 | Test NLL=1.628 Acc=0.664


 34%|███▍      | 1701/5000 [03:38<06:40,  8.24it/s]

[Iter 1700] TD3-RLFS — Train NLL=2.030 Acc=0.505 | Test NLL=2.030 Acc=0.490


 36%|███▌      | 1801/5000 [03:51<06:28,  8.24it/s]

[Iter 1800] TD3-RLFS — Train NLL=2.682 Acc=0.218 | Test NLL=2.266 Acc=0.259


 38%|███▊      | 1901/5000 [04:03<06:06,  8.46it/s]

[Iter 1900] TD3-RLFS — Train NLL=3.240 Acc=0.169 | Test NLL=2.902 Acc=0.217


 40%|████      | 2001/5000 [04:15<05:58,  8.36it/s]

[Iter 2000] TD3-RLFS — Train NLL=5.594 Acc=0.249 | Test NLL=5.759 Acc=0.266


 42%|████▏     | 2101/5000 [04:28<05:45,  8.40it/s]

[Iter 2100] TD3-RLFS — Train NLL=9.451 Acc=0.131 | Test NLL=9.095 Acc=0.140


 44%|████▍     | 2201/5000 [04:41<05:35,  8.35it/s]

[Iter 2200] TD3-RLFS — Train NLL=10.021 Acc=0.106 | Test NLL=10.030 Acc=0.084


 46%|████▌     | 2301/5000 [04:53<05:19,  8.45it/s]

[Iter 2300] TD3-RLFS — Train NLL=4.229 Acc=0.347 | Test NLL=4.109 Acc=0.357


 48%|████▊     | 2401/5000 [05:06<05:12,  8.31it/s]

[Iter 2400] TD3-RLFS — Train NLL=1.120 Acc=0.648 | Test NLL=1.190 Acc=0.692


 50%|█████     | 2501/5000 [05:18<05:01,  8.30it/s]

[Iter 2500] TD3-RLFS — Train NLL=0.783 Acc=0.800 | Test NLL=0.933 Acc=0.776


 52%|█████▏    | 2601/5000 [05:31<04:48,  8.32it/s]

[Iter 2600] TD3-RLFS — Train NLL=1.309 Acc=0.803 | Test NLL=1.738 Acc=0.762


 54%|█████▍    | 2701/5000 [05:43<04:46,  8.01it/s]

[Iter 2700] TD3-RLFS — Train NLL=1.460 Acc=0.793 | Test NLL=1.726 Acc=0.734


 56%|█████▌    | 2801/5000 [06:02<04:30,  8.12it/s]

[Iter 2800] TD3-RLFS — Train NLL=1.156 Acc=0.791 | Test NLL=1.418 Acc=0.741


 58%|█████▊    | 2901/5000 [06:18<04:15,  8.23it/s]

[Iter 2900] TD3-RLFS — Train NLL=1.087 Acc=0.866 | Test NLL=1.209 Acc=0.804


 60%|██████    | 3001/5000 [06:31<03:57,  8.42it/s]

[Iter 3000] TD3-RLFS — Train NLL=0.766 Acc=0.887 | Test NLL=0.740 Acc=0.853


 62%|██████▏   | 3101/5000 [06:43<03:48,  8.30it/s]

[Iter 3100] TD3-RLFS — Train NLL=0.789 Acc=0.885 | Test NLL=0.625 Acc=0.881


 64%|██████▍   | 3201/5000 [06:56<03:32,  8.45it/s]

[Iter 3200] TD3-RLFS — Train NLL=0.879 Acc=0.866 | Test NLL=0.819 Acc=0.860


 66%|██████▌   | 3301/5000 [07:09<03:31,  8.05it/s]

[Iter 3300] TD3-RLFS — Train NLL=1.148 Acc=0.845 | Test NLL=1.310 Acc=0.797


 68%|██████▊   | 3401/5000 [07:21<03:08,  8.49it/s]

[Iter 3400] TD3-RLFS — Train NLL=1.219 Acc=0.847 | Test NLL=1.314 Acc=0.797


 70%|███████   | 3501/5000 [07:34<03:01,  8.28it/s]

[Iter 3500] TD3-RLFS — Train NLL=1.229 Acc=0.845 | Test NLL=1.433 Acc=0.797


 72%|███████▏  | 3601/5000 [07:46<02:48,  8.30it/s]

[Iter 3600] TD3-RLFS — Train NLL=1.338 Acc=0.845 | Test NLL=1.478 Acc=0.790


 74%|███████▍  | 3701/5000 [08:00<02:33,  8.46it/s]

[Iter 3700] TD3-RLFS — Train NLL=1.218 Acc=0.847 | Test NLL=1.413 Acc=0.790


 76%|███████▌  | 3801/5000 [08:12<02:27,  8.15it/s]

[Iter 3800] TD3-RLFS — Train NLL=1.330 Acc=0.847 | Test NLL=1.553 Acc=0.790


 78%|███████▊  | 3901/5000 [08:25<02:44,  6.67it/s]

[Iter 3900] TD3-RLFS — Train NLL=1.387 Acc=0.847 | Test NLL=1.630 Acc=0.790


 80%|████████  | 4001/5000 [08:37<02:43,  6.10it/s]

[Iter 4000] TD3-RLFS — Train NLL=1.362 Acc=0.845 | Test NLL=1.516 Acc=0.790


 82%|████████▏ | 4101/5000 [08:50<02:30,  5.98it/s]

[Iter 4100] TD3-RLFS — Train NLL=1.364 Acc=0.847 | Test NLL=1.598 Acc=0.790


 84%|████████▍ | 4201/5000 [09:03<02:19,  5.75it/s]

[Iter 4200] TD3-RLFS — Train NLL=1.417 Acc=0.845 | Test NLL=1.652 Acc=0.790


 86%|████████▌ | 4301/5000 [09:15<02:08,  5.42it/s]

[Iter 4300] TD3-RLFS — Train NLL=1.388 Acc=0.845 | Test NLL=1.637 Acc=0.790


 88%|████████▊ | 4401/5000 [09:28<01:35,  6.26it/s]

[Iter 4400] TD3-RLFS — Train NLL=1.448 Acc=0.845 | Test NLL=1.667 Acc=0.797


 90%|█████████ | 4501/5000 [09:40<01:09,  7.20it/s]

[Iter 4500] TD3-RLFS — Train NLL=1.376 Acc=0.845 | Test NLL=1.598 Acc=0.797


 92%|█████████▏| 4601/5000 [09:53<00:49,  8.06it/s]

[Iter 4600] TD3-RLFS — Train NLL=1.372 Acc=0.845 | Test NLL=1.538 Acc=0.797


 94%|█████████▍| 4701/5000 [10:06<00:36,  8.13it/s]

[Iter 4700] TD3-RLFS — Train NLL=1.416 Acc=0.845 | Test NLL=1.651 Acc=0.797


 96%|█████████▌| 4801/5000 [10:19<00:24,  8.17it/s]

[Iter 4800] TD3-RLFS — Train NLL=1.358 Acc=0.843 | Test NLL=1.627 Acc=0.797


 98%|█████████▊| 4901/5000 [10:31<00:12,  8.14it/s]

[Iter 4900] TD3-RLFS — Train NLL=1.498 Acc=0.843 | Test NLL=1.766 Acc=0.797


100%|██████████| 5000/5000 [10:43<00:00,  7.76it/s]

[Iter 5000] TD3-RLFS — Train NLL=1.483 Acc=0.843 | Test NLL=1.739 Acc=0.797

Test NLL=1.739, Acc=0.797



