# CGANs - Conditional Generative Adversarial Nets



In [18]:
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy("mixed_float16")
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import (
    Conv2D, Dense, GlobalAveragePooling2D, LayerNormalization, Add, DepthwiseConv2D, MaxPool2D
)
from utils import PoolingLayer, ResidualBlock, ResidualBlock3x3, ResidualBlock5x5, ResidualBlock7x7, SpatialSE, ChannelSE, ResidualBlockDepthwise3x3, ResidualBlockDepthwise5x5, ResidualBlockDepthwise7x7



# ---------------------------------------------------------
# Softmax Router (no Gumbel). Optional hard mode at inference.
# ---------------------------------------------------------
class SoftmaxRouter(layers.Layer):
    """Per-sample softmax router over ``num_choices`` options (no Gumbel noise).

    The feature map is global-average-pooled and projected to ``num_choices``
    logits by a Dense layer.

    Args:
        num_choices: Number of routing choices K.
        hard_at_inference: If True and ``training`` is falsy, return a one-hot
            vector of the argmax choice instead of soft probabilities.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, num_choices, hard_at_inference=False, **kwargs):
        super().__init__(**kwargs)
        self.num_choices = num_choices
        self.hard_at_inference = hard_at_inference
        self.logits_layer = Dense(num_choices)
        # Build the pooling layer once instead of instantiating a new
        # GlobalAveragePooling2D object on every call.
        self.pool = GlobalAveragePooling2D()

    def call(self, features, training=None):
        """Return routing weights of shape (B, K).

        Soft probabilities while training (or when hard mode is off),
        otherwise a one-hot of the argmax logit.
        """
        logits = self.logits_layer(self.pool(features))  # (B, K)
        if training or not self.hard_at_inference:
            return tf.nn.softmax(logits, axis=-1)        # (B, K)
        idx = tf.argmax(logits, axis=-1)
        # Use the logits' dtype so both modes return the same dtype
        # (float16 under the mixed_float16 policy).
        return tf.one_hot(idx, depth=self.num_choices, dtype=logits.dtype)

# ---------------------------------------------------------
# Halting Head: outputs probability to halt at current step
# ---------------------------------------------------------
class HaltingHead(layers.Layer):
    """Scalar halting head: probability of halting at the current step.

    Global-average-pools the feature map, applies a 1-unit Dense layer and a
    sigmoid.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.halt_dense = Dense(1)
        # Create the pooling layer once rather than on every call.
        self.pool = GlobalAveragePooling2D()

    def call(self, features):
        """Return halting probabilities of shape (B, 1)."""
        h = self.pool(features)
        return tf.nn.sigmoid(self.halt_dense(h))  # (B, 1)
    
def match_mask(mask, ref):
    """Broadcast a per-sample mask over the spatial grid of ``ref``.

    Args:
        mask: One value per sample, e.g. shaped [B,1,1,1] or [B,1]; it is
            reshaped to [B,1,1,1].
        ref: Reference tensor shaped [B,H,W,C].

    Returns:
        Tensor of shape [B,H,W,1] with ``ref``'s dtype.
    """
    mask = tf.reshape(mask, (-1, 1, 1, 1))          # ensure rank-4
    # Cast so e.g. a float32 (or bool) mask can be combined with float16
    # features under mixed precision; a no-op when dtypes already match.
    mask = tf.cast(mask, ref.dtype)
    return mask * tf.ones_like(ref[..., :1])        # broadcast to [B,H,W,1]


import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np


# ---------------------------
# Tiny conv stem
# ---------------------------
class ConvStem(layers.Layer):
    """Tiny stem: 3x3 conv (no bias) -> LayerNormalization -> swish.

    Args:
        out_ch: Number of output channels of the 3x3 convolution.
    """
    def __init__(self, out_ch, **kw):
        super().__init__(**kw)
        # Keep these attribute names: AdaptiveRouterBlockTop1Sparse.build
        # replaces ``stem.conv`` once the input width is known.
        self.conv = layers.Conv2D(out_ch, 3, padding="same", use_bias=False)
        self.norm = layers.LayerNormalization()
        self.act = layers.Activation("swish")

    def call(self, x, training=None):
        normed = self.norm(self.conv(x), training=training)
        return self.act(normed)


# ---------------------------
# Multi-head attention pooling router
# Produces logits over K experts
# ---------------------------
class AttnPoolRouter(layers.Layer):
    """Attention-pooling router that scores K experts.

    Each head owns one learned query vector that attends over all spatial
    positions of 1x1-projected keys/values. The pooled head outputs are
    concatenated and mapped to K logits by a small MLP (or a single Dense
    layer when ``mlp_hidden <= 0``).

    Args:
        K: Number of experts (output logits).
        heads: Number of attention heads.
        dim_head: Key/value/query width per head.
        mlp_hidden: Hidden width of the logit MLP; <= 0 uses one Dense layer.
    """
    def __init__(self, K, heads=2, dim_head=64, mlp_hidden=64, **kw):
        super().__init__(**kw)
        self.K = int(K)
        self.heads = int(heads)
        self.dim_head = int(dim_head)
        self.mlp_hidden = int(mlp_hidden)
        proj_dim = self.heads * self.dim_head

        # One learned query vector per head.
        self.q = self.add_weight(
            name="queries", shape=(self.heads, self.dim_head),
            initializer="glorot_uniform", trainable=True)

        # 1x1 projections producing keys / values for all heads at once.
        self.key_proj = layers.Conv2D(proj_dim, 1, use_bias=False)
        self.val_proj = layers.Conv2D(proj_dim, 1, use_bias=False)

        # Aggregate the concatenated pooled heads into K logits.
        if self.mlp_hidden <= 0:
            self.head_mlp = layers.Dense(self.K)
        else:
            hidden_layer = layers.Dense(self.mlp_hidden, activation="swish")
            logit_layer = layers.Dense(self.K)
            self.head_mlp = keras.Sequential([hidden_layer, logit_layer])

    def call(self, x, training=None):
        """Return (logits [B,K], pooled features [B, heads*dim_head]) for x [B,H,W,C]."""
        dims = tf.shape(x)
        batch = dims[0]
        n_tokens = dims[1] * dims[2]

        def _split_heads(t):
            # [B,H,W,heads*dim] -> [B,heads,HW,dim]
            t = tf.reshape(t, [batch, n_tokens, self.heads, self.dim_head])
            return tf.transpose(t, [0, 2, 1, 3])

        keys = _split_heads(self.key_proj(x))
        values = _split_heads(self.val_proj(x))

        # [heads, dim] -> [1, heads, 1, dim]; broadcast over the batch in matmul.
        query = tf.reshape(self.q, [1, self.heads, 1, self.dim_head])

        scale = tf.math.sqrt(tf.cast(self.dim_head, x.dtype))
        scores = tf.matmul(query, keys, transpose_b=True) / scale   # [B,heads,1,HW]
        weights = tf.nn.softmax(scores, axis=-1)

        pooled = tf.squeeze(tf.matmul(weights, values), axis=2)    # [B,heads,dim]
        pooled = tf.reshape(pooled, [batch, self.heads * self.dim_head])

        logits = self.head_mlp(pooled, training=training)          # [B,K]
        return logits, pooled  # pooled can be used as a feature if needed


class HaltingClassifierHead(layers.Layer):
    """
    Auxiliary classifier whose confidence drives halting.

    Predicts class probabilities; the hard gate fires when max prob > tau.
    ST gating: the forward value of the ST gate is the hard threshold, while
    gradients flow through a smooth sigmoid around tau:

        p_soft = sigmoid((max_prob - tau) / max(1e-6, halt_temp) + bias)
        p_hard = 1[max_prob > tau]

    Args:
        num_classes: Number of classes C.
        hidden: Width of the hidden swish Dense layer; <= 0 removes it.
        halt_temp: Temperature of the soft gate. Stored in a non-trainable
            variable so assigning ``halt_temp`` after the model has been traced
            (e.g. from a callback) changes the compiled computation.
        tau: Confidence threshold of the hard gate.
        bias_init: Initial value of the learned scalar bias.
    """
    def __init__(self, num_classes, hidden=64, halt_temp=3.0, tau=0.8, bias_init=-2.5, **kw):
        super().__init__(**kw)
        self.num_classes = int(num_classes)
        self.hidden = int(hidden)
        self.tau = float(tau)

        if hidden > 0:
            self.classifier = keras.Sequential([
                layers.GlobalAveragePooling2D(),
                layers.Dense(hidden, activation="swish"),
                layers.Dense(num_classes)
            ])
        else:
            self.classifier = keras.Sequential([
                layers.GlobalAveragePooling2D(),
                layers.Dense(num_classes)
            ])

        # learned scalar bias added to the temperature-scaled margin before the sigmoid
        self.bias = self.add_weight(
            name="halt_bias", shape=(), initializer=tf.keras.initializers.Constant(bias_init),
            trainable=True)

        # A Python float read in call() is baked in as a constant when the
        # model is traced, so temperature schedules would silently do nothing.
        self._halt_temp_var = self.add_weight(
            name="halt_temp", shape=(),
            initializer=tf.keras.initializers.Constant(float(halt_temp)),
            trainable=False)

    @property
    def halt_temp(self):
        """Current soft-gate temperature as a Python float (eager read)."""
        return float(self._halt_temp_var.numpy())

    @halt_temp.setter
    def halt_temp(self, v):
        self._halt_temp_var.assign(float(v))

    def call(self, x, training=None):
        """Return (class probs [B,C], soft gate [B,1], hard gate [B,1], ST gate [B,1])."""
        logits = self.classifier(x, training=training)          # [B,C]
        probs  = tf.nn.softmax(logits, axis=-1)                 # [B,C]
        maxp   = tf.reduce_max(probs, axis=-1, keepdims=True)   # [B,1]
        tau = tf.cast(self.tau, maxp.dtype)
        halt_temp = tf.cast(self._halt_temp_var, maxp.dtype)
        z = (maxp - tau) / tf.maximum(tf.constant(1e-6, dtype=maxp.dtype), halt_temp)
        p_soft = tf.nn.sigmoid(z + tf.cast(self.bias, maxp.dtype))                   # [B,1]
        p_hard = tf.cast(maxp > tau, maxp.dtype)                # [B,1]
        # Straight-through estimator: forward value is p_hard, gradient is that
        # of p_soft (p_hard itself has no gradient).
        p_st = p_soft + tf.stop_gradient(p_hard - p_soft)
        return probs, p_soft, p_hard, p_st   # class probs, soft gate, hard gate, ST gate


In [27]:

class AdaptiveRouterBlockTop1Sparse(layers.Layer):
    """
    Top-1 only, faster variant of the original block.

    - Uses AttnPoolRouter to pick a single expert per running sample (argmax).
    - For each unique selected expert we gather its subset once, run the expert,
      scale by the soft routing probability for that expert and scatter back.
    - Removes TensorArrays and the Top-2 logic for a simpler, faster graph.

    Per step (up to ``max_steps``), while at least one sample is running: the
    running samples are routed and passed through their expert, a
    HaltingClassifierHead produces a halting gate, the whole batch is pooled
    on every ``pool_every_n``-th step, and the per-sample halting state is
    updated. A sample stops being routed once its halting state is >= 0.5.
    While training, a load-balancing KL term (weight ``lb_alpha``) and an
    expert-diversity term (weight ``div_alpha``) are summed over steps,
    divided by ``max_steps`` and registered with ``add_loss``.

    Args:
        branches: Expert layers. Each must map [b,H,W,C] to the same shape
            (its output is scattered into ``zeros_like`` of its input).
        num_classes: Number of classes of the halting classifier.
        pooling_layer: Pooling layer used on pooling steps (shared, or used as
            a template when ``unique_pools`` is True).
        pooling_layers: Explicit per-event pooling layers; length must equal
            ``max_steps // pool_every_n``.
        unique_pools: Clone ``pooling_layer`` (via ``get_config``) per event.
        min_steps: The halting gate is zeroed for the first ``min_steps - 1``
            steps.
        max_steps: Maximum number of routing steps.
        pool_every_n: Pool after every n-th step; <= 0 disables pooling.
        router_heads, router_dim, router_mlp: AttnPoolRouter configuration.
        route_temp: Softmax temperature for the router logits.
        halt_temp, halt_tau: Temperature and threshold of the halting head.
        lb_alpha, div_alpha: Weights of the auxiliary losses (0 disables).
        top_k: Kept for API compatibility; only 1 is supported.
        name: Layer name.
    """
    def __init__(self,
                 branches,
                 num_classes,
                 pooling_layer=None,
                 pooling_layers=None,
                 unique_pools=False,
                 min_steps=1, max_steps=6, pool_every_n=1,
                 router_heads=2, router_dim=64, router_mlp=64,
                 route_temp=1.5, halt_temp=3.0, halt_tau=0.8,
                 lb_alpha=1e-3, div_alpha=1e-4,
                 top_k=1,                     # force top1
                 name=None):
        super().__init__(name=name)
        assert 1 <= min_steps <= max_steps
        # top_k kept for compatibility but only 1 is supported here
        assert top_k == 1, "This optimized block supports top_k==1 only"
        self.branches = branches
        self.K = len(branches)
        self.num_classes = int(num_classes)

        self.min_steps = int(min_steps)
        self.max_steps = int(max_steps)
        self.pool_every_n = int(pool_every_n)

        self._single_pool      = pooling_layer
        self._pooling_layersIn = pooling_layers
        self.unique_pools      = bool(unique_pools)
        self._pools = None

        # The conv width is fixed in build() once the input channels are known.
        self.stem = ConvStem(out_ch=None)

        self.router = AttnPoolRouter(K=self.K, heads=router_heads,
                                     dim_head=router_dim, mlp_hidden=router_mlp)
        self.halt   = HaltingClassifierHead(num_classes=self.num_classes,
                                            hidden=32, halt_temp=halt_temp,
                                            tau=halt_tau, bias_init=-1.0)

        # call_core is a tf.function: a Python float read there is frozen into
        # the traced graph, so later updates (e.g. TempScheduler) would be
        # ignored. A non-trainable variable keeps updates visible.
        self._route_temp_var = self.add_weight(
            name="route_temp", shape=(),
            initializer=tf.keras.initializers.Constant(float(route_temp)),
            trainable=False)
        self.lb_alpha  = float(lb_alpha)
        self.div_alpha = float(div_alpha)

        self.top_k = 1

    @property
    def route_temp(self):
        """Current routing temperature as a Python float (eager read)."""
        return float(self._route_temp_var.numpy())

    @route_temp.setter
    def route_temp(self, v):
        self._route_temp_var.assign(float(v))

    @property
    def halt_temp(self): return self.halt.halt_temp
    @halt_temp.setter
    def halt_temp(self, v): self.halt.halt_temp = float(v)

    def _pooling_events(self):
        """Number of pooling events over ``max_steps`` (0 when pooling is disabled)."""
        return 0 if self.pool_every_n <= 0 else self.max_steps // self.pool_every_n

    def build(self, input_shape):
        C_in = int(input_shape[-1])
        # Stem keeps the channel count of the incoming features.
        self.stem.conv = layers.Conv2D(C_in, 3, padding="same", use_bias=False)

        num_pools = self._pooling_events()
        if num_pools == 0:
            self._pools = []
        else:
            if self._pooling_layersIn is not None:
                assert len(self._pooling_layersIn) == num_pools
                self._pools = self._pooling_layersIn
            elif self.unique_pools:
                assert self._single_pool is not None, "unique_pools=True requires pooling_layer"
                cfg, cls = self._single_pool.get_config(), type(self._single_pool)
                self._pools = [cls(**cfg) for _ in range(num_pools)]
            else:
                assert self._single_pool is not None, "Provide pooling_layer or pooling_layers"
                self._pools = [self._single_pool] * num_pools
        super().build(input_shape)

    def _apply_pool(self, x, pool_count, training=None):
        """Apply pooling event ``pool_count`` (clamped to the last pool), or the
        single pooling layer, or identity when neither exists."""
        if self._pools and len(self._pools) > 0:
            n = len(self._pools)
            idx = tf.minimum(tf.convert_to_tensor(pool_count, tf.int32),
                             tf.constant(n - 1, tf.int32))
            fns = [(lambda i=i: self._pools[i](x, training=training)) for i in range(n)]
            return tf.switch_case(idx, branch_fns=fns, default=lambda: self._pools[-1](x, training=training))
        if self._single_pool is not None:
            return self._single_pool(x, training=training)
        return x

    def _diversity_from_means(self, V):
        """Mean off-diagonal cosine similarity between the rows of ``V``.

        Computed in float32 (l2_normalize's epsilon underflows in float16).
        Returns a float32 scalar; 0 when ``V`` has fewer than 2 rows.
        """
        V = tf.nn.l2_normalize(tf.cast(V, tf.float32), axis=-1)
        sims = tf.matmul(V, V, transpose_b=True)
        E = tf.shape(V)[0]
        mask = 1.0 - tf.eye(E, dtype=tf.float32)
        denom = tf.reduce_sum(mask)
        # tf.maximum keeps the unselected branch finite so no NaN gradient leaks.
        return tf.where(denom > 0, tf.reduce_sum(sims * mask) / tf.maximum(denom, 1.0), 0.0)

    def _call_branch(self, k_int32, x_k, training=None):
        # Graph-safe expert selection
        fns = [(lambda i=i: self.branches[i](x_k, training=training)) for i in range(self.K)]
        return tf.switch_case(k_int32, branch_fns=fns, default=lambda: self.branches[-1](x_k, training=training))

    @tf.function
    def call_core(self, features, training=None):
        """Run the routing loop.

        Returns (features, summed lb loss, summed diversity loss, max_steps as float32).
        """
        x = self.stem(features, training=training)
        B = tf.shape(x)[0]
        halted = tf.zeros([B, 1, 1, 1], dtype=x.dtype)
        pool_count = tf.constant(0, tf.int32)
        lb_loss_total = tf.constant(0.0, tf.float32)
        div_loss_total = tf.constant(0.0, tf.float32)
        C = tf.shape(x)[-1]
        for t in range(self.max_steps):
            # t and pool_every_n are Python ints: decide pooling at trace time
            # (also avoids an integer mod-by-zero when pool_every_n <= 0).
            is_pool_step = self.pool_every_n > 0 and (t + 1) % self.pool_every_n == 0
            running_mask = tf.squeeze(1.0 - halted, [1, 2, 3]) > 0.5
            run_idx = tf.cast(tf.reshape(tf.where(running_mask), [-1]), tf.int32)
            Br = tf.shape(run_idx)[0]

            def no_running_path():
                return x, halted, tf.constant(0.0, tf.float32), tf.constant(0.0, tf.float32), pool_count

            def route_running_path():
                x_run = tf.gather(x, run_idx)
                r_logits, _ = self.router(x_run, training=training)
                route_temp = tf.cast(self._route_temp_var, r_logits.dtype)
                probs = tf.nn.softmax(r_logits / route_temp, axis=-1)
                idx1 = tf.argmax(probs, axis=-1, output_type=tf.int32)
                y_contrib_run = tf.zeros_like(x_run)
                g_all = tf.zeros([Br, C], dtype=x_run.dtype)
                count_all = tf.zeros([Br, 1], dtype=x_run.dtype)
                for u in tf.range(self.K):
                    sel = tf.equal(idx1, tf.cast(u, idx1.dtype))
                    ids = tf.reshape(tf.where(sel), [-1])
                    Bk = tf.shape(ids)[0]
                    def do_u():
                        x_k = tf.gather(x_run, ids)
                        y_k = self._call_branch(tf.cast(u, tf.int32), x_k, training=training)
                        prob_rows = tf.gather(probs, ids)
                        w_for_k = tf.reshape(tf.gather(prob_rows, u, axis=-1), [-1, 1, 1, 1])
                        y_scaled = y_k * w_for_k
                        scatter_idx = tf.expand_dims(ids, axis=-1)
                        y_contrib_updated = tf.tensor_scatter_nd_add(y_contrib_run, scatter_idx, y_scaled)
                        g_k = tf.reduce_mean(y_k, axis=[1, 2])
                        g_all_updated = tf.tensor_scatter_nd_add(g_all, scatter_idx, g_k)
                        ones = tf.ones([Bk, 1], dtype=x_run.dtype)
                        count_all_updated = tf.tensor_scatter_nd_add(count_all, scatter_idx, ones)
                        return y_contrib_updated, g_all_updated, count_all_updated
                    def skip_u():
                        return y_contrib_run, g_all, count_all
                    y_contrib_run, g_all, count_all = tf.cond(Bk > 0, do_u, skip_u)
                feat_sums = tf.math.unsorted_segment_sum(g_all, idx1, num_segments=self.K)
                counts = tf.math.unsorted_segment_sum(count_all, idx1, num_segments=self.K)
                # Keep the input for rows whose expert output is identically zero.
                has_output = tf.reduce_any(tf.not_equal(y_contrib_run, 0.0), axis=[1, 2, 3])
                y_sel_run = tf.where(
                    tf.broadcast_to(has_output[:, None, None, None], tf.shape(y_contrib_run)),
                    y_contrib_run, x_run)
                _, _, _, p_st_run = self.halt(y_sel_run, training=training)
                if self.min_steps > 1:
                    can = tf.cast(t >= self.min_steps - 1, y_sel_run.dtype)
                    p_st_run = p_st_run * can
                if is_pool_step:
                    base_x_next = self._apply_pool(x, pool_count, training=training)
                    y_next_run = self._apply_pool(y_sel_run, pool_count, training=training)
                    new_pool_count = pool_count + 1
                else:
                    base_x_next, y_next_run, new_pool_count = x, y_sel_run, pool_count
                scatter_run = tf.expand_dims(run_idx, axis=-1)
                x_next = tf.tensor_scatter_nd_update(base_x_next, scatter_run, y_next_run)
                p_st4_run = tf.reshape(p_st_run, [-1, 1, 1, 1])
                old_h_run = tf.gather(halted, run_idx)
                new_h_run = tf.clip_by_value(old_h_run + (1.0 - old_h_run) * p_st4_run, 0.0, 1.0)
                halted_next = tf.tensor_scatter_nd_update(halted, scatter_run, new_h_run)
                lb_add = tf.constant(0.0, tf.float32)
                if training and self.lb_alpha > 0.0:
                    # KL(mean routing probs || uniform), in float32: an epsilon
                    # of 1e-8 underflows to 0 in float16 and 0 * log(0) = NaN.
                    p_mean = tf.reduce_mean(tf.cast(probs, tf.float32), axis=0)   # [K]
                    eps = 1e-8
                    log_uniform = tf.math.log(tf.constant(1.0 / self.K + eps, tf.float32))
                    lb_add = tf.reduce_sum(p_mean * (tf.math.log(p_mean + eps) - log_uniform))
                div_add = tf.constant(0.0, tf.float32)
                if training and self.div_alpha > 0.0:
                    counts_safe = tf.maximum(counts, tf.constant(1e-6, dtype=counts.dtype))
                    V = feat_sums / counts_safe        # per-expert mean output features
                    valid = tf.squeeze(counts > 0.5, axis=-1)
                    V_sel = tf.boolean_mask(V, valid)
                    div_add = tf.cond(tf.shape(V_sel)[0] > 1,
                                      lambda: self._diversity_from_means(V_sel),
                                      lambda: tf.constant(0.0, tf.float32))
                return x_next, halted_next, lb_add, div_add, new_pool_count

            x, halted, lb_add, div_add, pool_count = tf.cond(Br > 0, route_running_path, no_running_path)
            if training:
                if self.lb_alpha > 0.0:
                    lb_loss_total += lb_add
                if self.div_alpha > 0.0:
                    div_loss_total += div_add
            # Early exit is only possible when running eagerly.
            if not training and tf.executing_eagerly():
                if bool(tf.reduce_all(tf.squeeze(halted, [1, 2, 3]) > 0.5).numpy()):
                    break
        steps = tf.cast(self.max_steps, tf.float32)
        return x, lb_loss_total, div_loss_total, steps

    def call(self, features, training=None):
        x, lb_loss_total, div_loss_total, steps = self.call_core(features, training=training)
        if training:
            if self.lb_alpha > 0.0:
                self.add_loss(self.lb_alpha * lb_loss_total / steps)
            if self.div_alpha > 0.0:
                self.add_loss(self.div_alpha * div_loss_total / steps)
        return x
    

In [28]:
class TempScheduler(keras.callbacks.Callback):
    """
    Anneal a layer's ``route_temp`` and ``halt_temp`` at the start of each epoch.

    Progress p = min(1, epoch / max(1, epochs - 1)) is applied linearly, or
    reshaped by a half-cosine ramp when ``mode == "cosine"`` (any other mode
    is treated as linear). Defaults:
    route: 1.5 -> 0.7
    halt:  3.0 -> 1.5
    """
    def __init__(self, layer_name="adaptive_router",
                 route_start=1.5, route_end=0.7,
                 halt_start=3.0,  halt_end=1.5,
                 epochs=150, mode="cosine"):
        super().__init__()
        self.layer_name = layer_name
        self.rs = float(route_start)
        self.re = float(route_end)
        self.hs = float(halt_start)
        self.he = float(halt_end)
        self.E = int(epochs)
        self.mode = mode

    def _interp(self, e):
        frac = e / max(1, self.E - 1)
        if frac > 1.0:
            frac = 1.0
        return 0.5 * (1 - np.cos(np.pi * frac)) if self.mode == "cosine" else frac

    def on_epoch_begin(self, epoch, logs=None):
        progress = self._interp(epoch)
        rtemp = self.rs + (self.re - self.rs) * progress
        htemp = self.hs + (self.he - self.hs) * progress
        target = self.model.get_layer(self.layer_name)
        target.route_temp = rtemp
        target.halt_temp = htemp
        print(f" [TempScheduler] epoch {epoch+1}: route_temp={rtemp:.3f}, halt_temp={htemp:.3f}")


import numpy as np
import tensorflow as tf
from tensorflow import keras

def trace_batch(model, x_batch, layer_name="adaptive_router", force_full=False):
    """
    Trace routing/halting for a batch through AdaptiveRouterBlockTop1Sparse (top-1 only).

    Re-runs the block's stem, router, experts, halting head and pooling eagerly
    with training=False so per-step decisions can be inspected. Unlike the
    block, every sample is routed at every traced step (including samples that
    already halted); mask with ``running_after`` where that matters. Halting
    decisions use the hard gate, zeroed before ``min_steps``.

    Args:
        model: Keras model that contains the block.
        x_batch: One batch of model inputs.
        layer_name: Name of the block inside ``model``.
        force_full: If True, trace all ``max_steps`` steps even after every
            sample has halted; otherwise stop once none are running.

    Returns dict (arrays shaped [t_used, B, ...] unless noted):
      - top_indices:      [t_used, B]       Top-1 expert (argmax over K)
      - probs:            [t_used, B, K]    soft routing probs over experts
      - class_maxprob:    [t_used, B]       halting classifier max class prob
      - class_pred:       [t_used, B]       halting classifier argmax class
      - halt_soft:        [t_used, B]       soft halting gate (sigmoid around τ)
      - halt_hard:        [t_used, B]       hard halting gate (0/1)
      - running_after:    [t_used, B] bool  True if still running after step t
    """
    layer = model.get_layer(layer_name)

    # features entering the router block (before stem)
    pre = keras.Model(model.input, layer.input)
    x_in = tf.convert_to_tensor(x_batch)
    x = pre(x_in, training=False)

    B = int(x.shape[0])
    K, T = layer.K, layer.max_steps

    # apply the same conv stem as the block
    x = layer.stem(x, training=False)

    pools = getattr(layer, "_pools", None)
    have_pool_list = isinstance(pools, (list, tuple)) and len(pools) > 0
    # Same fallback as the block's _apply_pool: a single shared pooling layer,
    # or identity when there is none.
    single_pool = getattr(layer, "_single_pool", None)
    pool_count = 0
    running = np.ones((B,), dtype=bool)

    top1_list, probs_list = [], []
    class_max_list, class_pred_list = [], []
    halt_soft_list, halt_hard_list, running_after_list = [], [], []

    for t in range(T):
        r_logits, _ = layer.router(x, training=False)                         # [B,K]
        probs = tf.nn.softmax(r_logits / layer.route_temp, axis=-1)           # [B,K]
        probs_np = probs.numpy()
        probs_list.append(probs_np)

        top1 = tf.argmax(probs, axis=-1, output_type=tf.int32).numpy()        # [B]
        top1_list.append(top1)

        # Only top-1 expert is used in the block
        y_contrib = None
        for k in range(K):
            sel = (top1 == k)
            if not np.any(sel):
                continue
            x_k = tf.boolean_mask(x, sel)
            y_k = layer.branches[k](x_k, training=False)
            w_for_k = probs_np[sel, k].reshape(-1, 1, 1, 1)
            y_scaled = y_k * w_for_k
            scatter_idx = np.where(sel)[0]
            scatter_idx = tf.expand_dims(tf.convert_to_tensor(scatter_idx, dtype=tf.int32), axis=-1)
            if y_contrib is None:
                y_contrib = tf.zeros_like(x)
            y_contrib = tf.tensor_scatter_nd_add(y_contrib, scatter_idx, y_scaled)
        y_sel = x if y_contrib is None else y_contrib

        class_probs, p_soft, p_hard, _ = layer.halt(y_sel, training=False)    # [B,C], [B,1], [B,1], [B,1]
        class_probs_np = class_probs.numpy()
        class_max = class_probs_np.max(axis=-1)                                # [B]
        class_pred = class_probs_np.argmax(axis=-1)                            # [B]

        if t < layer.min_steps - 1:
            # Halting is suppressed before min_steps (mirrors the block).
            halt_this = np.zeros((B,), dtype=bool)
            p_soft_np = np.squeeze(p_soft.numpy(), axis=1)
            p_soft_np *= 0.0
            p_hard_np = np.zeros((B,), dtype=np.float32)
        else:
            p_soft_np = np.squeeze(p_soft.numpy(), axis=1)                     # [B]
            p_hard_np = np.squeeze(p_hard.numpy(), axis=1).astype(np.float32)  # [B]
            halt_this = (p_hard_np > 0.5) & running

        running = running & (~halt_this)

        if layer.pool_every_n > 0 and ((t + 1) % layer.pool_every_n == 0):
            if have_pool_list:
                pool_layer = pools[min(pool_count, len(pools) - 1)]
                x = pool_layer(y_sel, training=False)
                pool_count += 1
            elif single_pool is not None:
                x = single_pool(y_sel, training=False)
            else:
                x = y_sel
        else:
            x = y_sel

        class_max_list.append(class_max.copy())
        class_pred_list.append(class_pred.copy())
        halt_soft_list.append(p_soft_np.copy())
        halt_hard_list.append(p_hard_np.copy())
        running_after_list.append(running.copy())

        if (not force_full) and (not running.any()):
            break

    trace = {
        "top_indices":   np.asarray(top1_list),             # [t_used, B]
        "probs":         np.asarray(probs_list),            # [t_used, B, K]
        "class_maxprob": np.asarray(class_max_list),        # [t_used, B]
        "class_pred":    np.asarray(class_pred_list),       # [t_used, B]
        "halt_soft":     np.asarray(halt_soft_list),        # [t_used, B]
        "halt_hard":     np.asarray(halt_hard_list),        # [t_used, B]
        "running_after": np.asarray(running_after_list, dtype=bool),  # [t_used, B]
    }
    return trace




def build_model(input_shape=(32,32,3), num_classes=10,
                filters=32,
                min_steps=2, max_steps=6, pool_every_n=3,
                unique_pools=True):
    """Build the classifier with an adaptive router block.

    Stem: conv(swish) -> LayerNorm -> pool -> ResidualBlock3x3 -> pool. Then an
    AdaptiveRouterBlockTop1Sparse named "adaptive_router" with two
    heterogeneous experts, global average pooling, and a float32 softmax
    classifier.
    """
    def _downsample():
        return PoolingLayer(filters=filters, frac_ratio=2.0)

    inputs = keras.Input(shape=input_shape)
    h = layers.Conv2D(filters, 3, padding='same', activation='swish')(inputs)
    h = layers.LayerNormalization()(h)
    h = _downsample()(h)
    h = ResidualBlock3x3(filters)(h)
    h = _downsample()(h)

    # Heterogeneous experts (other candidates: ResidualBlockDepthwise7x7, ChannelSE).
    experts = [ResidualBlock3x3(filters), ResidualBlockDepthwise5x5(filters)]

    router_cfg = dict(
        branches=experts,
        num_classes=num_classes,
        # Pooling template; pass pooling_layers (len = max_steps//pool_every_n)
        # instead for explicit per-event pools.
        pooling_layer=_downsample(),
        pooling_layers=None,
        unique_pools=unique_pools,
        min_steps=min_steps,
        max_steps=max_steps,
        pool_every_n=pool_every_n,
        router_heads=1, router_dim=32, router_mlp=32,
        route_temp=1.5, halt_temp=1.5, halt_tau=0.7,
        lb_alpha=1e-3, div_alpha=1e-4,
        name="adaptive_router",
    )
    h = AdaptiveRouterBlockTop1Sparse(**router_cfg)(h)

    h = layers.GlobalAveragePooling2D()(h)
    probs = layers.Dense(num_classes, dtype="float32", activation='softmax')(h)
    return keras.Model(inputs, probs)




In [29]:
# simple CIFAR-10 aug
def cifar_preprocess(x, y, pad_size=36, crop_size=32):
    """Random pad-and-crop plus horizontal flip for CIFAR-style images.

    ``x`` is a batch [B,H,W,C] (make_dataset maps this after batching) or a
    single image [H,W,C]. Each image is padded (or center-cropped) to
    ``pad_size`` x ``pad_size``, randomly cropped back to ``crop_size`` x
    ``crop_size`` and randomly flipped left-right. ``y`` is returned unchanged.
    """
    x = tf.image.resize_with_crop_or_pad(x, pad_size, pad_size)

    def _random_crop(img):
        return tf.image.random_crop(img, [crop_size, crop_size, tf.shape(img)[-1]])

    if x.shape.rank == 4:
        # Draw an independent offset per image: one random_crop over the whole
        # batch would shift every image in the batch identically.
        x = tf.map_fn(_random_crop, x)
    else:
        x = _random_crop(x)
    x = tf.image.random_flip_left_right(x)
    return x, y

def make_dataset(x, y, batch=128, train=True):
    """Wrap (x, y) arrays in a batched, prefetched tf.data pipeline.

    Training data is shuffled (buffer 5000), batched and then augmented with
    ``cifar_preprocess``; evaluation data is only batched.
    """
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if not train:
        return ds.batch(batch).prefetch(tf.data.AUTOTUNE)
    ds = ds.shuffle(5000)
    ds = ds.batch(batch)
    ds = ds.map(cifar_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.prefetch(tf.data.AUTOTUNE)

class RouterStatsCallback(keras.callbacks.Callback):
    """Print validation accuracy and routing statistics at the end of each epoch.

    Statistics come from ``trace_batch`` over the validation arrays, processed
    in chunks of ``batch_size``:

    - ``steps_hist[s]`` / ``avg_steps``: number of samples that used ``s``
      steps (the step after which they halted, or all traced steps if they
      never halted).
    - ``expert_hist[t, k]``: number of samples routed to expert ``k`` at step
      ``t``. Only samples still running at the start of step ``t`` are counted,
      since the block routes only running samples while ``trace_batch``
      reports a choice for every sample.

    Args:
        x_val, y_val: Validation inputs and integer labels.
        layer_name: Name of the router block; the default matches the name
            used by ``build_model`` ("adaptive_router").
        batch_size: Chunk size for evaluation and tracing.
    """
    def __init__(self, x_val, y_val, layer_name="adaptive_router", batch_size=256):
        super().__init__()
        self.xv = x_val; self.yv = y_val
        self.layer_name = layer_name
        self.bs = batch_size

    def on_epoch_end(self, epoch, logs=None):
        # accuracy
        _loss, acc = self.model.evaluate(self.xv, self.yv, batch_size=self.bs, verbose=0)
        # quick router stats (avg steps + histogram)
        layer = self.model.get_layer(self.layer_name)
        T, K = layer.max_steps, layer.K
        steps_hist = np.zeros(T+1, np.int64)
        expert_hist = np.zeros((T, K), np.int64)
        n = len(self.xv)
        for i in range(0, n, self.bs):
            xb = self.xv[i:i+self.bs]
            tb = trace_batch(self.model, xb, layer_name=self.layer_name, force_full=False)
            top = tb["top_indices"]          # [t_used, B]
            running = tb["running_after"]    # [t_used, B]
            t_used = top.shape[0]
            # per-sample steps: first step after which the sample stopped
            stopped = ~running
            ever = stopped.any(axis=0)
            first = np.argmax(stopped, axis=0)
            used = np.where(ever, first+1, t_used)
            for s in used:
                steps_hist[min(int(s), T)] += 1
            # experts: count only samples the block actually routed at step t,
            # i.e. those still running before that step.
            running_before = np.concatenate([np.ones_like(running[:1]), running[:-1]], axis=0)
            for t in range(t_used):
                chosen = top[t][running_before[t]]
                expert_hist[t] += np.bincount(chosen, minlength=K)
        avg_steps = np.sum(np.arange(T+1)*steps_hist)/max(1, steps_hist.sum())
        print(f"[Stats] epoch {epoch+1}: val_acc={acc*100:.2f}%  avg_steps={avg_steps:.2f}  steps_hist={steps_hist.tolist()} expert_hist={expert_hist.tolist()}")
        # you can also print expert_hist[0] etc.

# Usage
EPOCHS = 50        # keep in sync with the model.fit call
BATCH_SIZE = 128

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
y_train = y_train.flatten()
y_test = y_test.flatten()

# Batches per epoch (Dataset.batch keeps the final partial batch).
steps_per_epoch = -(-len(x_train) // BATCH_SIZE)
# Decay over the epochs actually trained so the cosine schedule reaches its
# end at the final epoch.
lr_sched = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-3, decay_steps=EPOCHS * steps_per_epoch)

model = build_model()
# Give the optimizer the schedule up front instead of swapping it in after compile().
opt = keras.optimizers.Adam(learning_rate=lr_sched)
model.compile(optimizer=opt, loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.summary()

callbacks = [
    TempScheduler(layer_name="adaptive_router", epochs=EPOCHS, mode="cosine"),
    RouterStatsCallback(x_test, y_test, layer_name="adaptive_router"),
]
ds_train = make_dataset(x_train, y_train, batch=BATCH_SIZE, train=True)
ds_val   = make_dataset(x_test, y_test, batch=256, train=False)




Model: "model_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_7 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv2d_24 (Conv2D)          (None, 32, 32, 32)        896       
                                                                 
 layer_normalization_12 (La  (None, 32, 32, 32)        64        
 yerNormalization)                                               
                                                                 
 pooling_layer_18 (PoolingL  (None, 16, 16, 32)        1120      
 ayer)                                                           
                                                                 
 residual_block3x3_12 (Resi  (None, 16, 16, 32)        20736     
_________________________________________________________________
 Layer (type)                Output Shape              Para

In [30]:
# Train the adaptive-router model; the callbacks anneal temperatures and
# print routing statistics each epoch.
model.fit(ds_train, validation_data=ds_val, callbacks=callbacks, epochs=50)

 [TempScheduler] epoch 1: route_temp=1.500, halt_temp=3.000
Epoch 1/50






2025-10-14 17:08:00.331888: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2025-10-14 17:08:00.331888: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2025-10-14 17:08:04.361553: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:933] Skipping loop optimization for Merge node with control input: model_3/adaptive_router/StatefulPartitionedCall/cond_5_163/then/_3606/cond_5/cond/branch_executed/_14093
2025-10-14 17:08:04.361553: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:933] Skipping loop optimization for Merge node with control input: model_3/adaptive_router/StatefulPartitionedCall/cond_5_163/then/_3606/cond_5/cond/branch_executed/_14093


: 

# Base model with 4 blocks or 6 blocks of 32 filters

In [8]:
def _build_base_model(stage_blocks, input_shape, num_classes, filters):
    """Build a plain residual CNN baseline.

    Conv(relu) -> BatchNorm, then stages of ``ResidualBlock(filters)`` with a
    ``PoolingLayer(filters, frac_ratio=2.0)`` between consecutive stages, then
    global average pooling and a float32 softmax classifier.

    Args:
        stage_blocks: Number of residual blocks in each stage.
        input_shape, num_classes, filters: As in the public builders.
    """
    inputs = keras.Input(shape=input_shape)
    # For CNNs on GPU, BatchNorm is faster than LayerNorm:
    x = layers.Conv2D(filters, 3, padding='same', activation='relu')(inputs)
    x = layers.BatchNormalization()(x)

    for stage_idx, n_blocks in enumerate(stage_blocks):
        if stage_idx > 0:
            x = PoolingLayer(filters=filters, frac_ratio=2.0)(x)
        for _ in range(n_blocks):
            x = ResidualBlock(filters)(x)

    x = layers.GlobalAveragePooling2D()(x)
    # dtype='float32' keeps the softmax probabilities in float32 under the
    # mixed_float16 policy.
    outputs = layers.Dense(num_classes, activation='softmax', dtype='float32')(x)
    return keras.Model(inputs, outputs)


def build_base_model_4_blocks(input_shape=(32,32,3), num_classes=10, filters=32):
    """Baseline with 4 residual blocks: 1 | pool | 2 | pool | 1."""
    return _build_base_model((1, 2, 1), input_shape, num_classes, filters)


def build_base_model_6_blocks(input_shape=(32,32,3), num_classes=10, filters=32):
    """Baseline with 6 residual blocks: 1 | pool | 2 | pool | 2 | pool | 1."""
    return _build_base_model((1, 2, 2, 1), input_shape, num_classes, filters)


In [9]:
# Build, compile and train the 4-block baseline for 20 epochs.
# NOTE(review): (x_test, y_test) doubles as validation data here — fine for
# monitoring, but avoid using it for model/epoch selection.
model = build_base_model_4_blocks(
    input_shape=(32, 32, 3),
    num_classes=10,
    filters=FILTERS,
)
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
model.summary()

# Left as the cell's final expression so the notebook still displays History.
model.fit(
    x=x_train,
    y=y_train,
    batch_size=512,
    epochs=20,
    validation_data=(x_test, y_test),
)

Model: "model_61"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv2d_1 (Conv2D)           (None, 32, 32, 32)        896       
                                                                 
 batch_normalization_1 (Bat  (None, 32, 32, 32)        128       
 chNormalization)                                                
                                                                 
 residual_block_4 (Residual  (None, 32, 32, 32)        20736     
 Block)                                                          
                                                                 
 pooling_layer_1 (PoolingLa  (None, 16, 16, 32)        1120      
 yer)                                                            
                                                          

2025-10-13 16:17:22.540030: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2025-10-13 16:17:42.444025: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.src.callbacks.History at 0x158b70f40>

In [10]:
# Build, compile and train the 6-block baseline for 20 epochs.
# NOTE(review): (x_test, y_test) doubles as validation data here — fine for
# monitoring, but avoid using it for model/epoch selection.
model = build_base_model_6_blocks(
    input_shape=(32, 32, 3),
    num_classes=10,
    filters=FILTERS,
)
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
model.summary()

# Left as the cell's final expression so the notebook still displays History.
model.fit(
    x=x_train,
    y=y_train,
    batch_size=512,
    epochs=20,
    validation_data=(x_test, y_test),
)

Model: "model_62"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv2d_2 (Conv2D)           (None, 32, 32, 32)        896       
                                                                 
 batch_normalization_2 (Bat  (None, 32, 32, 32)        128       
 chNormalization)                                                
                                                                 
 residual_block_8 (Residual  (None, 32, 32, 32)        20736     
 Block)                                                          
                                                                 
 pooling_layer_3 (PoolingLa  (None, 16, 16, 32)        1120      
 yer)                                                            
                                                          

2025-10-13 16:22:18.933112: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2025-10-13 16:22:38.732527: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.src.callbacks.History at 0x3b0a50370>