# CGANs - Conditional Generative Adversarial Nets



In [1]:
#from tensorflow.keras import mixed_precision
#mixed_precision.set_global_policy("mixed_float16")
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import (
    Conv2D, Dense, GlobalAveragePooling2D, LayerNormalization, Add, DepthwiseConv2D, MaxPool2D
)
from utils import PoolingLayer, ResidualBlock, ResidualBlock3x3, ResidualBlock5x5, ResidualBlock7x7, SpatialSE, ChannelSE, ResidualBlockDepthwise3x3, ResidualBlockDepthwise5x5, ResidualBlockDepthwise7x7, ResidualBlockDepthwise9x9, DummyBlock



# ---------------------------------------------------------
# Softmax Router (no Gumbel). Optional hard mode at inference.
# ---------------------------------------------------------
class SoftmaxRouter(layers.Layer):
    def __init__(self, num_choices, hard_at_inference=False, **kwargs):
        super().__init__(**kwargs)
        self.num_choices = num_choices
        self.hard_at_inference = hard_at_inference
        self.logits_layer = Dense(num_choices)

    def call(self, features, training=None):
        logits = self.logits_layer(GlobalAveragePooling2D()(features))  # (B, K)
        if training or not self.hard_at_inference:
            probs = tf.nn.softmax(logits, axis=-1)                      # (B, K)
        else:
            idx = tf.argmax(logits, axis=-1)
            probs = tf.one_hot(idx, depth=self.num_choices, dtype=tf.float32)
        return probs  # (B, K)



# ------------------------------------------------------------
# Hazard-rate Halting Head (ST hard gate; good gradients)
# ------------------------------------------------------------
class HazardHaltingHead(layers.Layer):
    """
    Predicts a hazard p_t in (0,1) from features; uses a straight-through
    hard decision in forward, with gradients from the sigmoid (Concrete).
    call(...) returns:
      p_soft: [B,1]  (prob of halting at this step, for gradients/metrics)
      h_st  : [B,1]  (hard 0/1 gate with ST gradient)
    """
    def __init__(self, hidden=64, halt_temp=1.0, tau=0.5, **kwargs):
        super().__init__(**kwargs)
        self.hidden = int(hidden)
        self.halt_temp = float(halt_temp)
        self.tau = float(tau)
        print(f"HazardHaltingHead: hidden={self.hidden}, halt_temp={self.halt_temp}, tau={self.tau}")

    def build(self, input_shape):
        # input can be [B,H,W,C] or [B,C]; we GAP if rank-4
        if len(input_shape) == 4:
            in_dim = int(input_shape[-1])
        else:
            in_dim = int(input_shape[-1])
        if self.hidden > 0:
            self.mlp = keras.Sequential([
                layers.Dense(self.hidden, activation="swish", use_bias=False),
                layers.Dense(1, use_bias=False),
            ])
        else:
            self.mlp = layers.Dense(1, use_bias=False) #, bias_initializer=tf.keras.initializers.Constant(-0.2)

    def call(self, feat, *, can_halt_mask, training=None):
        # feat: [B,H,W,C] or [B,C]
        if len(feat.shape) == 4:
            pooled = tf.reduce_mean(feat, axis=[1,2])  # [B,C]
        else:
            pooled = feat                               # [B,C]

        logits = self.mlp(pooled, training=training)    # [B,1]
        p_soft = tf.nn.sigmoid(logits / self.halt_temp) # [B,1]
        p_soft = p_soft * can_halt_mask                 # respect min_steps

        # straight-through hard gate
        h_hard = tf.cast(p_soft > self.tau, p_soft.dtype)    # [B,1]
        h_st   = h_hard + tf.stop_gradient(p_soft - h_hard)
        return p_soft, h_st

# ---------------------------
# Tiny conv stem
# ---------------------------
class ConvStem(layers.Layer):
    def __init__(self, out_ch, **kw):
        super().__init__(**kw)
        self.conv = layers.Conv2D(out_ch, 3, padding="same", use_bias=False)
        self.norm = layers.LayerNormalization()
        self.act  = layers.Activation("swish")

    def call(self, x, training=None):
        x = self.conv(x)
        x = self.norm(x, training=training)
        x = self.act(x)
        return x


# ---------------------------
# Multi-head attention pooling router
# Produces logits over K experts
# ---------------------------
class AttnPoolRouter(layers.Layer):
    def __init__(self, K, heads=2, dim_head=64, mlp_hidden=64, **kw):
        super().__init__(**kw)
        self.K = int(K)
        self.heads = int(heads)
        self.dim_head = int(dim_head)
        self.mlp_hidden = int(mlp_hidden)

        self.q = self.add_weight(
            name="queries", shape=(self.heads, self.dim_head),
            initializer="glorot_uniform", trainable=True)

        self.key_proj = layers.Conv2D(self.heads*self.dim_head, 1, use_bias=False)
        self.val_proj = layers.Conv2D(self.heads*self.dim_head, 1, use_bias=False)

        # Head aggregator -> K logits
        if self.mlp_hidden > 0:
            self.head_mlp = keras.Sequential([
                layers.Dense(self.mlp_hidden, activation="swish", use_bias=False),
                layers.Dense(self.K, use_bias=False)
            ])
        else:
            self.head_mlp = layers.Dense(self.K, use_bias=False)

    def call(self, x, training=None):
        # x: [B,H,W,C]
        B, H, W, C = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
        k = self.key_proj(x)  # [B,H,W,heads*dim]
        v = self.val_proj(x)
        k = tf.reshape(k, [B, H*W, self.heads, self.dim_head])  # [B,HW,Hd,D]
        v = tf.reshape(v, [B, H*W, self.heads, self.dim_head])  # [B,HW,Hd,D]
        k = tf.transpose(k, [0,2,1,3])  # [B,heads,HW,dim]
        v = tf.transpose(v, [0,2,1,3])  # [B,heads,HW,dim]

        # queries: [heads, dim] -> [B,heads,1,dim]
        q = tf.expand_dims(self.q, axis=0)
        q = tf.expand_dims(q, axis=2)

        # attn: [B,heads,1,HW]
        attn = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(tf.cast(self.dim_head, x.dtype))
        attn = tf.nn.softmax(attn, axis=-1)

        # pooled heads: [B,heads,1,dim]
        pooled = tf.matmul(attn, v)  # [B,heads,1,dim]
        pooled = tf.squeeze(pooled, axis=2)  # [B,heads,dim]

        # flatten heads
        pooled = tf.reshape(pooled, [B, self.heads*self.dim_head])  # [B, heads*dim]

        logits = self.head_mlp(pooled, training=training)  # [B,K]
        return logits, pooled  # pooled can be used as a feature if needed


class HaltingClassifierHead(layers.Layer):
    """
    Predicts class probabilities; halts when max prob > tau.
    ST gating: forward uses hard threshold; backward uses a smooth sigmoid around tau.
    """
    def __init__(self, num_classes, hidden=64, halt_temp=3.0, tau=0.8, bias_init=-2.5, **kw):
        super().__init__(**kw)
        self.num_classes = int(num_classes)
        self.hidden = int(hidden)
        self.halt_temp = float(halt_temp)
        self.tau = float(tau)

        if hidden > 0:
            self.classifier = keras.Sequential([
                layers.GlobalAveragePooling2D(),
                layers.Dense(hidden, activation="swish", use_bias=False),
                layers.Dense(num_classes, use_bias=False)
            ])
        else:
            self.classifier = keras.Sequential([
                layers.GlobalAveragePooling2D(),
                layers.Dense(num_classes, use_bias=False)
            ])

        # a tiny scalar bias we add to (max_prob - tau) before the sigmoid
        self.bias = self.add_weight(
            name="halt_bias", shape=(), initializer=tf.keras.initializers.Constant(bias_init),
            trainable=True)

    def call(self, x, training=None):
        if len(x.shape) == 2:
            x = tf.expand_dims(x, axis=1)
            x = tf.expand_dims(x, axis=1)
        elif len(x.shape) == 3:
            x = tf.expand_dims(x, axis=1)
        logits = self.classifier(x, training=training)         # [B,C]
        probs  = tf.nn.softmax(logits, axis=-1)                 # [B,C]
        maxp   = tf.reduce_max(probs, axis=-1, keepdims=True)   # [B,1]
        tau = tf.cast(self.tau, maxp.dtype)
        halt_temp = tf.cast(self.halt_temp, maxp.dtype)
        z = (maxp - tau) / tf.maximum(tf.constant(1e-6, dtype=maxp.dtype), halt_temp)
        p_soft = tf.nn.sigmoid(z + tf.cast(self.bias, maxp.dtype))                   # [B,1]
        p_hard = tf.cast(maxp > tau, x.dtype)              # [B,1]
        p_st = p_hard + tf.stop_gradient(p_soft - p_hard)
        return probs, p_soft, p_hard, p_st   # class probs, soft gate, hard gate, ST gate


In [2]:
# simple CIFAR-10 aug
def cifar_preprocess(x, y):
    x = tf.image.resize_with_crop_or_pad(x, 36, 36)
    x = tf.image.random_crop(x, [tf.shape(x)[0], 32, 32, 3])
    x = tf.image.random_flip_left_right(x)
    return x, y

def make_dataset(x, y, batch=128, train=True):
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if train:
        ds = ds.shuffle(5000).batch(batch).map(cifar_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    else:
        ds = ds.batch(batch)
    return ds.prefetch(tf.data.AUTOTUNE)


# Usage
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32")/255.0; x_test = x_test.astype("float32")/255.0
y_train = y_train.flatten(); y_test = y_test.flatten()


ds_train = make_dataset(x_train, y_train, batch=128, train=True)
ds_val   = make_dataset(x_test, y_test, batch=256, train=False)


2025-10-16 12:34:00.473697: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2025-10-16 12:34:00.473732: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-10-16 12:34:00.473739: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2025-10-16 12:34:00.473976: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-10-16 12:34:00.473998: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [15]:
class AdaptiveRouterBlockTop1Vectorized(layers.Layer):
    """
    Deterministic Top-1 routing (same in train & infer) with ST grads.
    Pooling is disabled: no pooling layers or events are used.
    Allows selection of halting head: HazardHaltingHead or HaltingClassifierHead.
    """
    def __init__(
        self,
        branches,
        min_steps=1,
        max_steps=5,
        route_temp=1.0,
        router_settings={
            "heads": 2,
            "dim_head": 64,
            "mlp_hidden": 0
        },
        halt_settings={
            "hidden": 64,
            "temp": 1.0,
            "ponder_lambda": 1e-4,
            "type": "hazard",  # "hazard" or "classifier"
            "num_classes": None,  # Only needed for classifier head
            "tau": 0.8,          # Only for classifier head
            "bias_init": -2.5,   # Only for classifier head
            # --- new knobs for tau annealing ---
            "tau_min": 0.05,
            "tau_max": 0.95,
            "tau_anneal": 0.05   # proportional step size for tau updates
        },
        name=None
    ):
        super().__init__(name=name)
        assert 1 <= min_steps <= max_steps
        self.branches = branches
        self.K = len(branches)
        self.router = AttnPoolRouter(
            K=self.K,
            dim_head=router_settings.get("dim_head", 64),
            mlp_hidden=router_settings.get("mlp_hidden", 0)
        )
        self._route_temp = float(route_temp)
        self.ponder_lambda = float(halt_settings.get("ponder_lambda", 0.5)) 
        self.min_steps = int(min_steps)
        self.max_steps = int(max_steps)

        # ---- store halt_tau as a Variable so we can adapt it online ----
        init_tau = float(halt_settings.get("tau", 0.5))
        self.halt_tau = tf.Variable(
            initial_value=init_tau, trainable=False,
            dtype=tf.float32, name="halt_tau"
        )
        self.halt_tau_min = float(halt_settings.get("tau_min", 0.05))
        self.halt_tau_max = float(halt_settings.get("tau_max", 0.95))
        self.halt_tau_anneal = float(halt_settings.get("tau_anneal", 0.05))

        # Select halting head type (unchanged)
        halt_type = halt_settings.get("type", "hazard")
        if halt_type == "classifier":
            num_classes = halt_settings.get("num_classes", 10)
            tau = halt_settings.get("tau", 0.5)
            bias_init = halt_settings.get("bias_init", -2.5)
            hidden = halt_settings.get("hidden", 64)
            halt_temp = float(halt_settings.get("temp", 1.0))
            self.halt = HaltingClassifierHead(
                num_classes=num_classes,
                hidden=hidden,
                halt_temp=halt_temp,
                tau=tau,
                bias_init=bias_init
            )
        else:
            hidden = halt_settings.get("hidden", 64)
            tau = halt_settings.get("tau", 0.5)
            halt_temp = float(halt_settings.get("temp", 1.0))
            self.halt = HazardHaltingHead(
                hidden=hidden,
                halt_temp=halt_temp,
                tau=tau
            )

    @property
    def route_temp(self):
        return self._route_temp

    @route_temp.setter
    def route_temp(self, v: float):
        self._route_temp = float(v)

    @property
    def halt_temp(self):
        return self.halt.halt_temp

    @halt_temp.setter
    def halt_temp(self, v: float):
        self.halt.halt_temp = float(v)

    def get_config(self):
        cfg = super().get_config()
        cfg.update(dict(
            min_steps=self.min_steps,
            max_steps=self.max_steps,
            ponder_lambda=self.ponder_lambda,
            route_temp=self.route_temp,
            halt_temp=self.halt_temp,
            K=self.K,
            halt_tau=float(self.halt_tau.numpy()) if tf.executing_eagerly() else None,
            halt_tau_min=self.halt_tau_min,
            halt_tau_max=self.halt_tau_max,
            halt_tau_anneal=self.halt_tau_anneal,
        ))
        return cfg

    def _diversity_from_means(self, V):
        """
        Computes diversity among branch outputs V (shape [K, ...]).
        Higher value means less diversity (more similar branches).
        Can be used as a regularization loss to encourage diversity.
        """
        V = tf.reshape(V, [tf.shape(V)[0], -1])  # flatten each branch output
        V = tf.nn.l2_normalize(V, axis=-1)       # normalize
        sims = tf.matmul(V, V, transpose_b=True) # [K, K] cosine similarity
        K = tf.shape(V)[0]
        mask = 1.0 - tf.eye(K, dtype=V.dtype)    # zero diagonal
        denom = tf.reduce_sum(mask)
        return tf.where(denom > 0, tf.reduce_sum(sims * mask) / denom, 0.0)
    
    def call(self, features, training=None):
        x = features
        B = tf.shape(x)[0]
        dtype = x.dtype
        halted = tf.zeros([B,1,1,1], dtype=dtype)
        ponder_cost = tf.constant(0.0, dtype=dtype)
        steps_taken = tf.zeros([B], dtype=dtype)

        for t in range(self.max_steps):
            # Router + select branch as before
            router_logits, _ = self.router(x, training=training)              # [B,K]
            probs  = tf.nn.softmax(router_logits / self.route_temp, axis=-1)  # [B,K]
            top_idx = tf.argmax(probs, axis=-1, output_type=tf.int32)         # [B]
            onehot  = tf.one_hot(top_idx, depth=self.K, dtype=dtype)          # [B,K]
            onehot_st = onehot + tf.stop_gradient(probs - onehot)             # ST

            y_list = [br(x, training=training) for br in self.branches]       # K x [B,H,W,C]
            y_stack = tf.stack(y_list, axis=1)                                 # [B,K,H,W,C]
            mask = tf.reshape(onehot_st, [-1, self.K, 1, 1, 1])
            y_sel = tf.reduce_sum(mask * y_stack, axis=1)                      # [B,H,W,C]

            # Halting head
            can_halt = tf.cast(t >= self.min_steps - 1, dtype) * tf.ones([B,1], dtype)
            if isinstance(self.halt, HazardHaltingHead):
                p_soft, h_st = self.halt(y_sel, can_halt_mask=can_halt, training=training)  # [B,1]
            else:
                _, p_soft, _, h_st = self.halt(y_sel, training=training)       # [B,1]
                p_soft = p_soft * can_halt
                h_st = h_st * can_halt

            # Update cumulative halting state
            h_st4 = tf.reshape(h_st, [-1,1,1,1])                               # [B,1,1,1]
            halted = tf.clip_by_value(halted + (1.0 - halted) * h_st4, 0.0, 1.0)

            # ---- use halt_tau during training too ----
            tau_cast = tf.cast(self.halt_tau, dtype)
            # per-sample "still active?" mask (True if not past threshold yet)
            still_active = tf.cast(tf.squeeze(halted, [1,2,3]) <= tau_cast, dtype)  # [B]
            still_active4 = tf.reshape(still_active, [-1,1,1,1])                    # [B,1,1,1]

            # Only count steps for samples that are still active
            steps_taken += tf.squeeze(still_active4, [1,2,3])  # [B]

            # Freeze samples that have halted (carry previous x forward)
            # y_sel is the new candidate; x is the previous state
            x = still_active4 * y_sel + (1.0 - still_active4) * x

            # Ponder cost uses fraction not halted yet
            if training and self.ponder_lambda > 0.0:
                running_frac = tf.reduce_mean(still_active)  # fraction still active this step
                ponder_cost += tf.cast(running_frac, dtype)

            # Batch-level early exit:
            # - during inference: same as before
            # - during training (optional): stop if *all* have crossed threshold
            if tf.executing_eagerly():
                all_halted_now = bool(tf.reduce_all(tf.squeeze(halted, [1,2,3]) > tau_cast).numpy())
                if (not training and all_halted_now) or (training and all_halted_now):
                    break

        # losses as before
        if training and self.ponder_lambda > 0.0:
            self.add_loss(self.ponder_lambda * ponder_cost)
            # you can keep your diversity loss here if you had it above

        # anneal halt_tau toward the target during training (unchanged)
        if training:
            avg_steps = tf.reduce_mean(steps_taken)
            target_steps = (self.max_steps - self.min_steps) / 1.5
            err = tf.cast(target_steps, avg_steps.dtype) - avg_steps
            new_tau = self.halt_tau - tf.cast(self.halt_tau_anneal, tf.float32) * tf.cast(err, tf.float32)
            new_tau = tf.clip_by_value(new_tau, self.halt_tau_min, self.halt_tau_max)
            self.halt_tau.assign(new_tau)

        return x

    # def call(self, features, training=None):
    #     x = features
    #     B = tf.shape(x)[0]
    #     dtype = x.dtype
    #     halted = tf.zeros([B,1,1,1], dtype=dtype)
    #     ponder_cost = tf.constant(0.0, dtype=dtype)
    #     steps_taken = tf.zeros([B], dtype=dtype)

    #     for t in range(self.max_steps):
    #         router_logits, _ = self.router(x, training=training)              # [B,K]
    #         probs  = tf.nn.softmax(router_logits / self.route_temp, axis=-1)  # [B,K]
    #         top_idx = tf.argmax(probs, axis=-1, output_type=tf.int32)         # [B]
    #         onehot  = tf.one_hot(top_idx, depth=self.K, dtype=dtype)          # [B,K]
    #         onehot_st = onehot + tf.stop_gradient(probs - onehot)             # ST

    #         y_list = [br(x, training=training) for br in self.branches]       # K x [B,H,W,C]
    #         y_means = tf.stack([tf.reduce_mean(y, axis=[1,2,3]) for y in y_list], axis=0)  # [K,B]
    #         div_loss = self._diversity_from_means(y_means)

    #         y_stack = tf.stack(y_list, axis=1)                                 # [B,K,H,W,C]
    #         mask = tf.reshape(onehot_st, [-1, self.K, 1, 1, 1])
    #         y_sel = tf.reduce_sum(mask * y_stack, axis=1)                      # [B,H,W,C]

    #         can_halt = tf.cast(t >= self.min_steps - 1, dtype) * tf.ones([B,1], dtype)
    #         if isinstance(self.halt, HazardHaltingHead):
    #             p_soft, h_st = self.halt(y_sel, can_halt_mask=can_halt, training=training)  # [B,1]
    #         else:  # HaltingClassifierHead
    #             _, p_soft, _, h_st = self.halt(y_sel, training=training)       # [B,1]
    #             p_soft = p_soft * can_halt
    #             h_st = h_st * can_halt

    #         h_st4 = tf.reshape(h_st, [-1,1,1,1])
    #         halted = tf.clip_by_value(halted + (1.0 - halted) * h_st4, 0.0, 1.0)
    #         not_halted = 1.0 - tf.squeeze(halted, [1,2,3])  # [B]
    #         steps_taken += not_halted

    #         if training and self.ponder_lambda > 0.0:
    #             running_frac = tf.reduce_mean(1.0 - tf.squeeze(halted, [1,2,3]))
    #             ponder_cost += tf.cast(running_frac, dtype)

    #         # --- use the (possibly annealed) tau only at inference time ---
    #         if not training and tf.executing_eagerly():
    #             tau_cast = tf.cast(self.halt_tau, dtype)
    #             if bool(tf.reduce_all(tf.squeeze(halted, [1,2,3]) > tau_cast).numpy()):
    #                 break

    #     if training and self.ponder_lambda > 0.0:
    #         self.add_loss(self.ponder_lambda * ponder_cost)
    #         self.add_loss(0.01 * div_loss)

    #     # --- anneal halt_tau toward target steps (training only) ---
    #     if training:
    #         avg_steps = tf.reduce_mean(steps_taken)  # scalar (dtype of x)
    #         target_steps = (self.max_steps - self.min_steps) / 1.5
    #         # proportional controller on the error
    #         err = tf.cast(target_steps, avg_steps.dtype) - avg_steps
    #         new_tau = self.halt_tau - tf.cast(self.halt_tau_anneal, tf.float32) * tf.cast(err, tf.float32)
    #         new_tau = tf.clip_by_value(new_tau, self.halt_tau_min, self.halt_tau_max)
    #         self.halt_tau.assign(new_tau)

    #         # Optional: log for monitoring
    #         #tf.summary.scalar("router/avg_steps", avg_steps)
    #         #tf.summary.scalar("router/target_steps", target_steps)
    #         #tf.summary.scalar("router/halt_tau", self.halt_tau)

    #     return x


In [None]:
class CosineAnnealingScheduler(keras.callbacks.Callback):
    """
    Cosine annealing learning rate scheduler.
    """
    def __init__(self, base_lr, min_lr, epochs, verbose=1):
        super().__init__()
        self.base_lr = base_lr
        self.min_lr = min_lr
        self.epochs = epochs
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        p = epoch / max(1, self.epochs - 1)
        lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1 + np.cos(np.pi * p))
        keras.backend.set_value(self.model.optimizer.lr, lr)
        if self.verbose and (epoch < 1 or (epoch + 1) % 5 == 0):
            print(f"> [LR Scheduler] epoch {epoch+1}: lr={lr:.6f}")


class TempScheduler(keras.callbacks.Callback):
    """
    Linearly (or cosine) anneal route_temp and halt_temp over epochs.
    route: 1.5 -> 0.7
    halt:  3.0 -> 1.5
    """
    def __init__(self, layer_name="adaptive_router",
                 route_start=1.5, route_end=0.7,
                 halt_start=3.0,  halt_end=0.5,
                 epochs=150, mode="cosine"):
        super().__init__()
        self.layer_name = layer_name
        self.rs, self.re = float(route_start), float(route_end)
        self.hs, self.he = float(halt_start),  float(halt_end)
        self.E = int(epochs)
        self.mode = mode

    def _interp(self, e):
        p = min(1.0, e / max(1, self.E-1))
        if self.mode == "cosine":
            p = 0.5*(1 - np.cos(np.pi*p))
        return p

    def on_epoch_end(self, epoch, logs=None):
        p = self._interp(epoch)
        rtemp = self.rs + (self.re - self.rs)*p
        htemp = self.hs + (self.he - self.hs)*p
        layer = self.model.get_layer(self.layer_name)
        layer.route_temp = rtemp
        layer.halt_temp  = htemp
        if epoch < 1 or (epoch + 1) % 5 == 0:
            print(f"> [TempScheduler] epoch {epoch+1}: route_temp={rtemp:.3f}, halt_temp={htemp:.3f}")


class RouterStatsCallback(keras.callbacks.Callback):
    def __init__(self, x_val, y_val, layer_name="adaptive_router", batch_size=256):
        super().__init__()
        self.xv = x_val
        self.yv = y_val
        self.layer_name = layer_name
        self.bs = batch_size

    def on_epoch_end(self, epoch, logs=None):
        if epoch < 1 or (epoch + 1) % 5 == 0:
            layer = self.model.get_layer(self.layer_name)
            T, K = layer.max_steps, layer.K
            steps_hist = np.zeros(T+1, np.int64)
            expert_hist = np.zeros((T, K), np.int64)
            n = len(self.xv)
            for i in range(0, n, self.bs):
                xb = self.xv[i:i+self.bs]
                tb = trace_batch(self.model, xb, layer_name=self.layer_name, force_full=False)
                t_used = tb["top_indices"].shape[0]
                running = tb["running_after"]  # [t_used, B]
                stopped = ~running
                ever = stopped.any(axis=0)
                first = np.argmax(stopped, axis=0)
                used = np.where(ever, first+1, t_used)
                for s in used:
                    steps_hist[min(int(s), T)] += 1
                for t in range(t_used):
                    ch = tb["top_indices"][t]
                    cnt = np.bincount(ch, minlength=K)
                    expert_hist[t] += cnt
            avg_steps = np.sum(np.arange(T+1)*steps_hist)/max(1, steps_hist.sum())
            print(f"> [{self.layer_name}] epoch {epoch+1}: avg_steps={avg_steps:.2f}  steps_hist={steps_hist.tolist()} expert_hist={expert_hist.tolist()}")


class AdaptiveTauCallback(keras.callbacks.Callback):
    def __init__(self, layer_name, target_steps, update_rate=0.01):
        super().__init__()
        self.layer_name = layer_name
        self.target_steps = target_steps
        self.update_rate = update_rate

    def on_epoch_end(self, epoch, logs=None):
        layer = self.model.get_layer(self.layer_name)
        halt_tau_threshold = float(tf.keras.backend.get_value(layer.halt_tau))
        # Estimate average steps from validation set
        stats = evaluate_with_router_stats(self.model, x_test, y_test, layer_name=self.layer_name)
        avg_steps = np.sum(np.arange(stats["T"]+1)*stats["steps_hist"])/max(1, stats["steps_hist"].sum())
        # Adjust tau: increase tau if too many steps, decrease if too few
        if hasattr(layer.halt, "tau"):
            delta = self.update_rate * (avg_steps - self.target_steps)
            layer.halt.tau = np.clip(layer.halt.tau + delta, 0.01, 0.99)
            print(f"> [AdaptiveTau] epoch {epoch+1}: halt_tau_threshold={halt_tau_threshold:.4f} tau={layer.halt.tau:.3f}, avg_steps={avg_steps:.2f}")



def trace_and_predict(
    model,
    x_input,
    y_true=None,
    layer_name="adaptive_router",
    force_full=False,   # set True to always loop max_steps, no early exit
):
    layer = model.get_layer(layer_name)
    pre = keras.Model(model.input, layer.input)
    x_in = tf.convert_to_tensor(x_input)
    x = pre(x_in, training=False)
    B = int(x.shape[0])
    K, T = layer.K, layer.max_steps
    dtype = x.dtype
    running = np.ones((B,), dtype=bool)

    top_indices, probs_list = [], []
    halt_soft_list, halt_hard_list, running_after_list = [], [], []

    for t in range(T):
        router_logits, _ = layer.router(x, training=False)                         # [B,K]
        probs = tf.nn.softmax(router_logits / layer.route_temp, axis=-1)           # [B,K]
        top_idx = tf.argmax(probs, axis=-1, output_type=tf.int32)                  # [B]
        onehot  = tf.one_hot(top_idx, depth=K, dtype=dtype)
        onehot_st = onehot + tf.stop_gradient(probs - onehot)                      # ST

        y_list = [br(x, training=False) for br in layer.branches]
        y_stack = tf.stack(y_list, axis=1)                                         # [B,K,H,W,C]
        mask = tf.reshape(onehot_st, [-1, K, 1, 1, 1])
        y_sel = tf.reduce_sum(mask * y_stack, axis=1)                              # [B,H,W,C]

        can_halt = (t >= layer.min_steps - 1)
        can_mask = tf.ones([B,1], dtype) if can_halt else tf.zeros([B,1], dtype)

        # --- FIX: handle halting head type ---
        if isinstance(layer.halt, HazardHaltingHead):
            p_soft, h_st = layer.halt(y_sel, can_halt_mask=can_mask, training=False)   # [B,1]
        else:  # HaltingClassifierHead
            _, p_soft, _, h_st = layer.halt(y_sel, training=False)                     # [B,1]
            p_soft = p_soft * can_mask
            h_st = h_st * can_mask

        tau = float(layer.halt.tau)  # Always use the tau from halt_settings
        p_np = tf.squeeze(p_soft, axis=1).numpy()                                  # [B]
        h_st_np = tf.squeeze(h_st, axis=1).numpy()                                 # [B]
        halt_this = (p_np > tau) & running & can_halt                              # [B] bool

        running = running & (~halt_this)
        x = y_sel  # no pooling

        top_indices.append(top_idx.numpy())
        probs_list.append(probs.numpy())
        halt_soft_list.append(p_np.copy())
        halt_hard_list.append(halt_this.astype(np.float32))
        running_after_list.append(running.copy())

        if (not force_full) and (not running.any()):
            break

    pred_probs = model(x_in, training=False).numpy()
    pred_label = pred_probs.argmax(axis=-1).astype(np.int32)
    if y_true is not None:
        y_true_arr = np.asarray(y_true).reshape(-1)
        correct = (pred_label == y_true_arr)
    else:
        correct = None

    trace = {
        "top_indices":   np.array(top_indices),                 # [t_used, B]
        "probs":         np.array(probs_list),                  # [t_used, B, K]
        "halt_soft":     np.array(halt_soft_list),              # [t_used, B]
        "halt_hard":     np.array(halt_hard_list),              # [t_used, B]
        "running_after": np.array(running_after_list, dtype=bool),  # [t_used, B]
    }
    return {
        "trace": trace,
        "pred_probs": pred_probs,
        "pred_label": pred_label,
        "true_label": None if y_true is None else np.asarray(y_true),
        "layer_info": {
            "min_steps": layer.min_steps,
            "max_steps": layer.max_steps,
            "K": layer.K,
            "route_temp": getattr(layer, "route_temp", None),
            "halt_temp": getattr(getattr(layer, "halt", None), "halt_temp", None),
            "unique_pools": bool(getattr(layer, "_pools", []) not in (None, [])),
            "num_pools": len(getattr(layer, "_pools", []) or []),
        },
        "correct": correct,
    }


def trace_batch(model, x_batch, layer_name="adaptive_router", force_full=False):
    layer = model.get_layer(layer_name)
    pre = keras.Model(model.input, layer.input)

    x_in = tf.convert_to_tensor(x_batch)
    x = pre(x_in, training=False)

    B = int(x.shape[0])
    K, T = layer.K, layer.max_steps
    running = np.ones((B,), dtype=bool)

    top_indices, probs_list = [], []
    halt_soft_list, halt_hard_list, running_after_list = [], [], []

    for t in range(T):
        router_logits, _ = layer.router(x, training=False)                         # [B,K]
        probs = tf.nn.softmax(router_logits / layer.route_temp, axis=-1)           # [B,K]
        top_idx = tf.argmax(probs, axis=-1, output_type=tf.int32)                  # [B]
        onehot  = tf.one_hot(top_idx, depth=K, dtype=x.dtype)
        onehot_st = onehot + tf.stop_gradient(probs - onehot)                      # ST

        y_stack = tf.stack([br(x, training=False) for br in layer.branches], axis=1)  # [B,K,H,W,C]
        y_sel = tf.reduce_sum(tf.reshape(onehot_st, [-1, K, 1, 1, 1]) * y_stack, axis=1)

        can_halt = (t >= layer.min_steps - 1)
        can_mask = tf.ones([B,1], x.dtype) if can_halt else tf.zeros([B,1], x.dtype)

        # --- FIX: handle halting head type ---
        if isinstance(layer.halt, HazardHaltingHead):
            p_soft, h_st = layer.halt(y_sel, can_halt_mask=can_mask, training=False)      # [B,1]
        else:  # HaltingClassifierHead
            _, p_soft, _, h_st = layer.halt(y_sel, training=False)                        # [B,1]
            p_soft = p_soft * can_mask
            h_st = h_st * can_mask

        tau = float(layer.halt.tau)  # Always use the tau from halt_settings
        p_np = tf.squeeze(p_soft, axis=1).numpy()                                   # [B]
        halt_this = (p_np > tau) & running & can_halt                               # [B] bool

        running = running & (~halt_this)

        x = y_sel  # no pooling

        # ---- collect step data ----
        top_indices.append(top_idx.numpy())
        probs_list.append(probs.numpy())
        halt_soft_list.append(p_np.copy())
        halt_hard_list.append(halt_this.astype(np.float32))
        running_after_list.append(running.copy())

        # early exit only if no sample is still running
        if (not force_full) and (not running.any()):
            break

    trace = {
        "top_indices":   np.array(top_indices),                 # [t_used, B]
        "probs":         np.array(probs_list),                  # [t_used, B, K]
        "halt_soft":     np.array(halt_soft_list),              # [t_used, B]
        "halt_hard":     np.array(halt_hard_list),              # [t_used, B]
        "running_after": np.array(running_after_list, dtype=bool),  # [t_used, B]
    }
    return trace


def steps_used_from_running(running_after):
    """
    running_after: [t_used, B] bool (True = still running after that step)
    Returns: [B] int steps used (first time running becomes False; else t_used)
    """
    t_used, B = running_after.shape
    # A sample stops running the step it halts; so steps_used is the first index
    # where running becomes False, +1. If never False, it's t_used.
    stopped = ~running_after
    ever_stopped = stopped.any(axis=0)
    first_stop = np.argmax(stopped, axis=0)  # undefined when never stopped, fine below
    return np.where(ever_stopped, first_stop + 1, t_used)


def evaluate_with_router_stats(model, x, y, layer_name="adaptive_router_top1",
                               batch_size=256, force_full=False):
    # accuracy
    loss, acc = model.evaluate(x, y, batch_size=batch_size, verbose=0)

    # router stats
    layer = model.get_layer(layer_name)
    K, T = layer.K, layer.max_steps

    steps_hist = np.zeros(T+1, dtype=np.int64)   # index t for steps=t, last bin for "T or more"
    expert_hist = np.zeros((T, K), dtype=np.int64)
    halt_rate   = np.zeros(T, dtype=np.float64)
    n_seen = 0

    for i in range(0, len(x), batch_size):
        xb = x[i:i+batch_size]
        tb = trace_batch(model, xb, layer_name=layer_name, force_full=force_full)

        t_used, B = tb["top_indices"].shape[0], tb["top_indices"].shape[1]
        n_seen += B

        # expert usage per step (only for steps that exist in this batch)
        for t in range(t_used):
            choices = tb["top_indices"][t]  # [B]
            counts = np.bincount(choices, minlength=K)
            expert_hist[t, :] += counts
            halt_rate[t] += tb["halt_hard"][t].mean()

        # steps used per sample
        steps_used = steps_used_from_running(tb["running_after"])  # [B]
        # cap into histogram (if force_full=False, some batches may stop early)
        for s in steps_used:
            s_idx = min(int(s), T)  # put "==T" also into T bin
            steps_hist[s_idx] += 1

    # normalize
    halt_rate[:t_used] = halt_rate[:t_used] / max(1, (len(x) + batch_size - 1) // batch_size)

    return {
        "loss": float(loss),
        "acc": float(acc),
        "steps_hist": steps_hist,    # length T+1
        "expert_hist": expert_hist,  # [T, K]
        "halt_rate": halt_rate,      # [T] avg hard halts at step t
        "seen": n_seen,
        "K": K,
        "T": T,
    }

def print_router_stats(model, x, y, layer_name="adaptive_router_1", batch_size=512, force_full=False):
    stats = evaluate_with_router_stats(
        model, x, y,
        layer_name=layer_name,
        batch_size=batch_size,
        force_full=force_full
    )
    print("")
    print(f"================== {layer_name} ====================")
    print(f"Test acc: {stats['acc']*100:.2f}%  |  samples: {stats['seen']}")
    print("Steps histogram (0..T; last bin = T):", stats["steps_hist"])
    # Only print halt rates for steps with expert usage
    valid_steps = np.nonzero(stats['expert_hist'].sum(axis=1) > 0)[0]
    if len(valid_steps) > 0:
        last_step = valid_steps[-1] + 1
        print("Halt rate per step:", np.round(stats["halt_rate"][:last_step], 3))
    else:
        print("Halt rate per step:", stats["halt_rate"])
    for k in range(stats["K"]):
        total = stats["expert_hist"][:,k].sum()
        print(f"Expert {k} total usage: {total} ({total/stats['seen']:.3f} per sample on avg)")
    for s in range(len(stats["expert_hist"])):
        step_total = stats["expert_hist"][s].sum()
        print(f"Expert usage at step {s}:", stats["expert_hist"][s])


def print_trace_for_samples(model, x, y, layer_name="adaptive_router_1", start=1000, end=1020):
    """
    Prints routing trace and prediction for samples in the given range.
    """
    print("")
    print(f"================== {layer_name} ====================")
    for i in range(start, end):
        res = trace_and_predict(model, x[i:i+1], y_true=y[i:i+1], layer_name=layer_name)
        if len(res["trace"]["top_indices"][:, 0]) > 0:
            trace = res["trace"]
            print(" > pred label:", res["pred_label"][0], "true label:", int(res["true_label"][0]))
            print("   experts per step:", trace["top_indices"][:, 0])
            print("   halt probs:", trace["halt_soft"][:, 0])
                  


def build_adaptive_model_sparse(branches, input_shape=(32,32,3), num_classes=10, filters=32,
                                min_steps=1, max_steps=5,
                                route_temp=1.0, router_settings={"head": 2, "dim_head": 64, "mlp_hidden": 0}, 
                                halt_settings={"hidden": 64, "temp": 1.0}):

    inputs = keras.Input(shape=input_shape)
    # For CNNs on GPU, BatchNorm is faster than LayerNorm:
    x = layers.Conv2D(32, 3, padding='same', activation='swish')(inputs)
    x = layers.BatchNormalization()(x)

    x = PoolingLayer(filters=filters, frac_ratio=2.0)(x)

    x = AdaptiveRouterBlockTop1Vectorized(
        branches=branches[0],
        min_steps=min_steps,
        max_steps=max_steps,
        route_temp=route_temp,
        router_settings=router_settings,
        halt_settings=halt_settings[0],
        name="adaptive_router_1",
    )(x)

    x = ResidualBlock3x3(filters=filters)(x)

    x = PoolingLayer(filters=filters, frac_ratio=2.0)(x)

    x = AdaptiveRouterBlockTop1Vectorized(
        branches=branches[1],
        min_steps=min_steps,
        max_steps=max_steps,
        route_temp=route_temp,
        router_settings=router_settings,
        halt_settings=halt_settings[1],
        name="adaptive_router_2",
    )(x)

    x = ResidualBlock3x3(filters=filters)(x)

    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation='softmax', dtype='float32')(x)  # keep fp32 logits
    return keras.Model(inputs, outputs, name="adaptive_model")

In [17]:
FILTERS = 64
MIN_STEPS = 1
MAX_STEPS = 5
HALT_TEMP = 4.0
ROUTE_TEMP = 5.0
EPOCHS = 20
HALT_HEAD = "hazard"  # "hazard" or "classifier"

target_steps = (MAX_STEPS - MIN_STEPS) / 2.0

branches = [
    [
        ResidualBlock3x3(FILTERS),
        ResidualBlock5x5(FILTERS),
        ResidualBlockDepthwise7x7(FILTERS),
        ChannelSE(FILTERS),
        SpatialSE(),
        DummyBlock()
    ],
    [
        ResidualBlock3x3(FILTERS),
        ResidualBlock5x5(FILTERS),
        ResidualBlockDepthwise7x7(FILTERS),
        ResidualBlockDepthwise9x9(FILTERS),
        ChannelSE(FILTERS),
        DummyBlock()
    ]
]

router_model = build_adaptive_model_sparse(
    branches=branches,
    input_shape=(32,32,3),
    num_classes=10,
    filters=FILTERS,
    min_steps=MIN_STEPS,
    max_steps=MAX_STEPS,        # >0 encourages fewer steps; tune as needed
    router_settings={
        "heads": 8,
        "dim_head": 32,
        "mlp_hidden": 64,
    },
    halt_settings=[
        {
            "hidden": 128,
            "temp": 3.0,
            "ponder_lambda": 3e-4,
            "type": HALT_HEAD,
            "num_classes": 10, 
            "tau": 0.55,          
            "bias_init": 0.0,
            "tau_min": 0.3,
            "tau_max": 0.7,
            "tau_anneal": 0.0005  
        },
        {
            "hidden": 128,
            "temp": 3.0,
            "type": HALT_HEAD,
            "ponder_lambda": 3e-4,
            "num_classes": 10, 
            "tau": 0.55,          
            "bias_init": 0.0,
            "tau_min": 0.3,
            "tau_max": 0.7,
            "tau_anneal": 0.0001  
        }
    ],
    route_temp=ROUTE_TEMP
)



callbacks = [
    TempScheduler(layer_name="adaptive_router_1", epochs=EPOCHS, mode="cosine", route_start=ROUTE_TEMP, route_end=0.7, halt_start=HALT_TEMP, halt_end=1.0),
    RouterStatsCallback(x_test, y_test, layer_name="adaptive_router_1"),
    RouterStatsCallback(x_test, y_test, layer_name="adaptive_router_2"),
    CosineAnnealingScheduler(base_lr=3e-3, min_lr=1e-5, epochs=EPOCHS),
    AdaptiveTauCallback(layer_name="adaptive_router_1", target_steps=target_steps)
]

router_model.build(input_shape=(None, 32, 32, 3))
router_model.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
#_ = router_model(tf.zeros((1, 32, 32, 3)))

router_model.summary()

#tf.keras.utils.plot_model(router_model, show_shapes=True)

HazardHaltingHead: hidden=128, halt_temp=3.0, tau=0.55
HazardHaltingHead: hidden=128, halt_temp=3.0, tau=0.55
HazardHaltingHead: hidden=128, halt_temp=3.0, tau=0.55
Model: "adaptive_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv2d_15 (Conv2D)          (None, 32, 32, 32)        896       
                                                                 
 batch_normalization_3 (Bat  (None, 32, 32, 32)        128       
 chNormalization)                                                
                                                                 
 pooling_layer_6 (PoolingLa  (None, 16, 16, 64)        2240      
 yer)                                                            
                                                                 
 adaptive_router_1 

In [18]:
router_model.fit(ds_train, epochs=EPOCHS, validation_data=ds_val, callbacks=callbacks)


> [LR Scheduler] epoch 1: lr=0.003000
Epoch 1/20


2025-10-16 13:08:26.607274: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2025-10-16 13:12:24.161368: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


> [TempScheduler] epoch 1: route_temp=5.000, halt_temp=4.000
> [adaptive_router_1] epoch 1: avg_steps=5.00  steps_hist=[0, 0, 0, 0, 0, 10000] expert_hist=[[410, 0, 9585, 0, 0, 5], [202, 0, 8953, 1, 11, 833], [34, 0, 3790, 4, 1300, 4872], [5, 0, 1477, 0, 4932, 3586], [0, 0, 1010, 0, 6938, 2052]]
> [adaptive_router_1] epoch 1: avg_steps=5.00  steps_hist=[0, 0, 0, 0, 0, 10000] expert_hist=[[410, 0, 9585, 0, 0, 5], [202, 0, 8953, 1, 11, 833], [34, 0, 3790, 4, 1300, 4872], [5, 0, 1477, 0, 4932, 3586], [0, 0, 1010, 0, 6938, 2052]]
> [adaptive_router_2] epoch 1: avg_steps=5.00  steps_hist=[0, 0, 0, 0, 0, 10000] expert_hist=[[3077, 6620, 0, 0, 281, 22], [3413, 6583, 0, 0, 4, 0], [3603, 6397, 0, 0, 0, 0], [3900, 6100, 0, 0, 0, 0], [4203, 5797, 0, 0, 0, 0]]
> [adaptive_router_2] epoch 1: avg_steps=5.00  steps_hist=[0, 0, 0, 0, 0, 10000] expert_hist=[[3077, 6620, 0, 0, 281, 22], [3413, 6583, 0, 0, 4, 0], [3603, 6397, 0, 0, 0, 0], [3900, 6100, 0, 0, 0, 0], [4203, 5797, 0, 0, 0, 0]]
> [AdaptiveTau]

KeyboardInterrupt: 

In [7]:
print_router_stats(router_model, x_test, y_test, layer_name="adaptive_router_1", batch_size=512)
print_router_stats(router_model, x_test, y_test, layer_name="adaptive_router_2", batch_size=512)


Test acc: 82.96%  |  samples: 10000
Steps histogram (0..T; last bin = T): [   0 6506    0 1173  452 1869]
Halt rate per step: [0.65  0.    0.117 0.046 0.146]
Expert 0 total usage: 7753 (0.775 per sample on avg)
Expert 1 total usage: 8930 (0.893 per sample on avg)
Expert 2 total usage: 27455 (2.745 per sample on avg)
Expert 3 total usage: 1685 (0.169 per sample on avg)
Expert 4 total usage: 374 (0.037 per sample on avg)
Expert 5 total usage: 3803 (0.380 per sample on avg)
Expert usage at step 0: [  23   80 7931  249   70 1647]
Expert usage at step 1: [1645 1125 5043  652  256 1279]
Expert usage at step 2: [3905 3232 1477  663   33  690]
Expert usage at step 3: [ 707  191 9051   16    9   26]
Expert usage at step 4: [1473 4302 3953  105    6  161]

Test acc: 82.96%  |  samples: 10000
Steps histogram (0..T; last bin = T): [    0     0     0     0     0 10000]
Halt rate per step: [0. 0. 0. 0. 0.]
Expert 0 total usage: 14612 (1.461 per sample on avg)
Expert 1 total usage: 8126 (0.813 per s

In [10]:
print_trace_for_samples(router_model, x_test, y_test, layer_name="adaptive_router_1", start=0, end=5)
print_trace_for_samples(router_model, x_test, y_test, layer_name="adaptive_router_2", start=0, end=5)


 > pred label: 5 true label: 3
   experts per step: [2 5 5 3 5]
   halt probs: [0.4882024  0.49263278 0.49823272 0.49775296 0.501564  ]
 > pred label: 5 true label: 3
   experts per step: [2 5 5 3 5]
   halt probs: [0.4882024  0.49263278 0.49823272 0.49775296 0.501564  ]
 > pred label: 8 true label: 8
   experts per step: [2 2 5 4 5]
   halt probs: [0.49158022 0.49427438 0.49975765 0.49807033 0.50176156]
 > pred label: 8 true label: 8
   experts per step: [2 2 5 4 5]
   halt probs: [0.49158022 0.49427438 0.49975765 0.49807033 0.50176156]
 > pred label: 8 true label: 8
   experts per step: [2 2 5 4 5]
   halt probs: [0.48881182 0.49301866 0.49944386 0.49796697 0.5018949 ]
 > pred label: 8 true label: 8
   experts per step: [2 2 5 4 5]
   halt probs: [0.48881182 0.49301866 0.49944386 0.49796697 0.5018949 ]
 > pred label: 8 true label: 0
   experts per step: [2 1 5 4 5]
   halt probs: [0.49093142 0.49415016 0.49992025 0.49814925 0.50188524]
 > pred label: 8 true label: 0
   experts per s

In [None]:
def build_base_model_4_blocks(input_shape=(32,32,3), num_classes=10, filters=32):
    inputs = keras.Input(shape=input_shape)
    # For CNNs on GPU, BatchNorm is faster than LayerNorm:
    x = layers.Conv2D(filters, 3, padding='same', activation='relu')(inputs)
    x = layers.BatchNormalization()(x)

    x = ResidualBlock(filters)(x)
    x = PoolingLayer(filters=filters, frac_ratio=2.0)(x)
    x = ResidualBlock(filters)(x)
    x = ResidualBlock(filters)(x)
    x = PoolingLayer(filters=filters, frac_ratio=2.0)(x)
    x = ResidualBlock(filters)(x)

    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation='softmax', dtype='float32')(x)  # keep fp32 logits
    return keras.Model(inputs, outputs)


def build_base_model_6_blocks(input_shape=(32,32,3), num_classes=10, filters=32):
    inputs = keras.Input(shape=input_shape)
    # For CNNs on GPU, BatchNorm is faster than LayerNorm:
    x = layers.Conv2D(filters, 3, padding='same', activation='relu')(inputs)
    x = layers.BatchNormalization()(x)

    x = ResidualBlock(filters)(x)
    x = PoolingLayer(filters=filters, frac_ratio=2.0)(x)
    x = ResidualBlock(filters)(x)
    x = ResidualBlock(filters)(x)
    x = PoolingLayer(filters=filters, frac_ratio=2.0)(x)
    x = ResidualBlock(filters)(x)
    x = ResidualBlock(filters)(x)
    x = PoolingLayer(filters=filters, frac_ratio=2.0)(x)
    x = ResidualBlock(filters)(x)

    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation='softmax', dtype='float32')(x)  # keep fp32 logits
    return keras.Model(inputs, outputs)


TypeError: __init__() got an unexpected keyword argument 'pool_every_n'

In [9]:
model = build_base_model_4_blocks(
    input_shape=(32,32,3),
    num_classes=10,
    filters=FILTERS
)

model.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
model.summary()

model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    epochs=20, batch_size=512
)

Model: "model_61"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv2d_1 (Conv2D)           (None, 32, 32, 32)        896       
                                                                 
 batch_normalization_1 (Bat  (None, 32, 32, 32)        128       
 chNormalization)                                                
                                                                 
 residual_block_4 (Residual  (None, 32, 32, 32)        20736     
 Block)                                                          
                                                                 
 pooling_layer_1 (PoolingLa  (None, 16, 16, 32)        1120      
 yer)                                                            
                                                          

2025-10-13 16:17:22.540030: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2025-10-13 16:17:42.444025: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.src.callbacks.History at 0x158b70f40>

In [10]:
model = build_base_model_6_blocks(
    input_shape=(32,32,3),
    num_classes=10,
    filters=FILTERS
)

model.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
model.summary()

model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    epochs=20, batch_size=512
)

Model: "model_62"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv2d_2 (Conv2D)           (None, 32, 32, 32)        896       
                                                                 
 batch_normalization_2 (Bat  (None, 32, 32, 32)        128       
 chNormalization)                                                
                                                                 
 residual_block_8 (Residual  (None, 32, 32, 32)        20736     
 Block)                                                          
                                                                 
 pooling_layer_3 (PoolingLa  (None, 16, 16, 32)        1120      
 yer)                                                            
                                                          

2025-10-13 16:22:18.933112: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2025-10-13 16:22:38.732527: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.src.callbacks.History at 0x3b0a50370>