# CGANs - Conditional Generative Adversarial Nets



In [1]:
#from tensorflow.keras import mixed_precision
#mixed_precision.set_global_policy("mixed_float16")
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import (
    Conv2D, Dense, GlobalAveragePooling2D, LayerNormalization, Add, DepthwiseConv2D, MaxPool2D
)
from utils import PoolingLayer, ResidualBlock, ResidualBlock3x3, ResidualBlock5x5, ResidualBlock7x7, SpatialSE, ChannelSE, ResidualBlockDepthwise3x3, ResidualBlockDepthwise5x5, ResidualBlockDepthwise7x7, ResidualBlockDepthwise9x9, DummyBlock



# ---------------------------------------------------------
# Softmax Router (no Gumbel). Optional hard mode at inference.
# ---------------------------------------------------------
class SoftmaxRouter(layers.Layer):
    def __init__(self, num_choices, hard_at_inference=False, **kwargs):
        super().__init__(**kwargs)
        self.num_choices = num_choices
        self.hard_at_inference = hard_at_inference
        self.logits_layer = Dense(num_choices)

    def call(self, features, training=None):
        logits = self.logits_layer(GlobalAveragePooling2D()(features))  # (B, K)
        if training or not self.hard_at_inference:
            probs = tf.nn.softmax(logits, axis=-1)                      # (B, K)
        else:
            idx = tf.argmax(logits, axis=-1)
            probs = tf.one_hot(idx, depth=self.num_choices, dtype=tf.float32)
        return probs  # (B, K)


# ---------------------------
# Tiny conv stem
# ---------------------------
class ConvStem(layers.Layer):
    def __init__(self, out_ch, **kw):
        super().__init__(**kw)
        self.conv = layers.Conv2D(out_ch, 3, padding="same", use_bias=False)
        self.norm = layers.LayerNormalization()
        self.act  = layers.Activation("swish")

    def call(self, x, training=None):
        x = self.conv(x)
        x = self.norm(x, training=training)
        x = self.act(x)
        return x


# ---------------------------
# Multi-head attention pooling router
# Produces logits over K experts
# ---------------------------
class AttnPoolRouter(layers.Layer):
    def __init__(self, K, heads=2, dim_head=64, mlp_hidden=64, **kw):
        super().__init__(**kw)
        self.K = int(K)
        self.heads = int(heads)
        self.dim_head = int(dim_head)
        self.mlp_hidden = int(mlp_hidden)

        self.q = self.add_weight(
            name="queries", shape=(self.heads, self.dim_head),
            initializer="glorot_uniform", trainable=True)

        self.key_proj = layers.Conv2D(self.heads*self.dim_head, 1, use_bias=False)
        self.val_proj = layers.Conv2D(self.heads*self.dim_head, 1, use_bias=False)

        # Head aggregator -> K logits
        if self.mlp_hidden > 0:
            self.head_mlp = keras.Sequential([
                layers.Dense(self.mlp_hidden, activation="swish", use_bias=False),
                layers.Dense(self.K, use_bias=False)
            ])
        else:
            self.head_mlp = layers.Dense(self.K, use_bias=False)

    def call(self, x, training=None):
        # x: [B,H,W,C]
        B, H, W, C = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
        k = self.key_proj(x)  # [B,H,W,heads*dim]
        v = self.val_proj(x)
        k = tf.reshape(k, [B, H*W, self.heads, self.dim_head])  # [B,HW,Hd,D]
        v = tf.reshape(v, [B, H*W, self.heads, self.dim_head])  # [B,HW,Hd,D]
        k = tf.transpose(k, [0,2,1,3])  # [B,heads,HW,dim]
        v = tf.transpose(v, [0,2,1,3])  # [B,heads,HW,dim]

        # queries: [heads, dim] -> [B,heads,1,dim]
        q = tf.expand_dims(self.q, axis=0)
        q = tf.expand_dims(q, axis=2)

        # attn: [B,heads,1,HW]
        attn = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(tf.cast(self.dim_head, x.dtype))
        attn = tf.nn.softmax(attn, axis=-1)

        # pooled heads: [B,heads,1,dim]
        pooled = tf.matmul(attn, v)  # [B,heads,1,dim]
        pooled = tf.squeeze(pooled, axis=2)  # [B,heads,dim]

        # flatten heads
        pooled = tf.reshape(pooled, [B, self.heads*self.dim_head])  # [B, heads*dim]

        logits = self.head_mlp(pooled, training=training)  # [B,K]
        return logits, pooled  # pooled can be used as a feature if needed
    


In [2]:
# simple CIFAR-10 aug
def cifar_preprocess(x, y):
    x = tf.image.resize_with_crop_or_pad(x, 36, 36)
    x = tf.image.random_crop(x, [tf.shape(x)[0], 32, 32, 3])
    x = tf.image.random_flip_left_right(x)
    return x, y

def make_dataset(x, y, batch=128, train=True):
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if train:
        ds = ds.shuffle(5000).batch(batch).map(cifar_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    else:
        ds = ds.batch(batch)
    return ds.prefetch(tf.data.AUTOTUNE)


# Usage
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32")/255.0; x_test = x_test.astype("float32")/255.0
y_train = y_train.flatten(); y_test = y_test.flatten()


ds_train = make_dataset(x_train, y_train, batch=128, train=True)
ds_val   = make_dataset(x_test, y_test, batch=256, train=False)


2025-10-16 15:10:16.225364: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2025-10-16 15:10:16.225396: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-10-16 15:10:16.225408: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2025-10-16 15:10:16.225443: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-10-16 15:10:16.225457: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [None]:
class AdaptiveRouterBlockTop1Vectorized(layers.Layer):
    """
    Deterministic Top-1 routing (same in train & infer) with ST grads.
    Pooling is disabled. No halting logic: always run exactly `steps`.
    """
    def __init__(
        self,
        branches,
        steps=5,
        route_temp=1.0,
        router_settings={
            "heads": 2,
            "dim_head": 64,
            "mlp_hidden": 0
        },
        name=None
    ):
        super().__init__(name=name)
        self.branches = branches
        self.K = len(branches)
        self.router = AttnPoolRouter(
            K=self.K,
            dim_head=router_settings.get("dim_head", 64),
            mlp_hidden=router_settings.get("mlp_hidden", 0)
        )
        self._route_temp = float(route_temp)
        self.steps = int(steps)

    @property
    def route_temp(self):
        return self._route_temp

    @route_temp.setter
    def route_temp(self, v: float):
        self._route_temp = float(v)

    def get_config(self):
        cfg = super().get_config()
        cfg.update(dict(
            steps=self.steps,
            route_temp=self.route_temp,
            K=self.K
        ))
        return cfg

    def _diversity_from_means(self, V):
        """
        Computes diversity among branch outputs V (shape [K, ...]).
        Higher value means less diversity (more similar branches).
        Can be used as a regularization loss to encourage diversity.
        """
        V = tf.reshape(V, [tf.shape(V)[0], -1])  # flatten each branch output
        V = tf.nn.l2_normalize(V, axis=-1)       # normalize
        sims = tf.matmul(V, V, transpose_b=True) # [K, K] cosine similarity
        K = tf.shape(V)[0]
        mask = 1.0 - tf.eye(K, dtype=V.dtype)    # zero diagonal
        denom = tf.reduce_sum(mask)
        return tf.where(denom > 0, tf.reduce_sum(sims * mask) / denom, 0.0)

    def call(self, features, training=None):
        x = features
        dtype = x.dtype
        div_total = tf.constant(0.0, dtype=dtype)

        for _ in range(self.steps):
            # Route
            router_logits, _ = self.router(x, training=training)               # [B,K]
            probs  = tf.nn.softmax(router_logits / self.route_temp, axis=-1)   # [B,K]
            top_idx = tf.argmax(probs, axis=-1, output_type=tf.int32)          # [B]
            onehot  = tf.one_hot(top_idx, depth=self.K, dtype=dtype)           # [B,K]
            onehot_st = onehot + tf.stop_gradient(probs - onehot)              # straight-through

            # Run branches
            y_list = [br(x, training=training) for br in self.branches]        # K x [B,H,W,C]
            y_means = tf.stack([tf.reduce_mean(y, axis=[1,2,3]) for y in y_list], axis=0)  # [K,B]
            div_total += self._diversity_from_means(y_means)

            # Select top-1 branch output
            y_stack = tf.stack(y_list, axis=1)                                  # [B,K,H,W,C]
            mask = tf.reshape(onehot_st, [-1, self.K, 1, 1, 1])
            y_sel = tf.reduce_sum(mask * y_stack, axis=1)                       # [B,H,W,C]

            # Feed next step
            x = y_sel

        if training:
            # small weight on diversity; adjust coefficient as you like
            self.add_loss(0.05 * (div_total / float(self.steps)))

        return x


In [11]:
class CosineAnnealingScheduler(keras.callbacks.Callback):
    """
    Cosine annealing learning rate scheduler.
    """
    def __init__(self, base_lr, min_lr, epochs, verbose=1):
        super().__init__()
        self.base_lr = base_lr
        self.min_lr = min_lr
        self.epochs = epochs
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        p = epoch / max(1, self.epochs - 1)
        lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1 + np.cos(np.pi * p))
        keras.backend.set_value(self.model.optimizer.lr, lr)
        if self.verbose and (epoch < 1 or (epoch + 1) % 5 == 0):
            print(f"> [LR Scheduler] epoch {epoch+1}: lr={lr:.6f}")


class TempScheduler(keras.callbacks.Callback):
    """
    Linearly (or cosine) anneal route_temp over epochs.
    route: 1.5 -> 0.7
    halt:  3.0 -> 1.5
    """
    def __init__(self, layer_name="adaptive_router",
                 route_start=1.5, route_end=0.7,
                 epochs=150, mode="cosine"):
        super().__init__()
        self.layer_name = layer_name
        self.rs, self.re = float(route_start), float(route_end)
        self.E = int(epochs)
        self.mode = mode

    def _interp(self, e):
        p = min(1.0, e / max(1, self.E-1))
        if self.mode == "cosine":
            p = 0.5*(1 - np.cos(np.pi*p))
        return p

    def on_epoch_end(self, epoch, logs=None):
        p = self._interp(epoch)
        rtemp = self.rs + (self.re - self.rs)*p
        layer = self.model.get_layer(self.layer_name)
        layer.route_temp = rtemp
        if epoch < 1 or (epoch + 1) % 5 == 0:
            print("")
            print(f"> [TempScheduler] epoch {epoch+1}: route_temp={rtemp:.3f}")


class RouterStatsCallback(keras.callbacks.Callback):
    def __init__(self, x_val, y_val, layer_name="adaptive_router", batch_size=256):
        super().__init__()
        self.xv = x_val
        self.yv = y_val
        self.layer_name = layer_name
        self.bs = batch_size

    def on_epoch_end(self, epoch, logs=None):
        if epoch < 1 or (epoch + 1) % 5 == 0:
            layer = self.model.get_layer(self.layer_name)
            T, K = layer.steps, layer.K
            steps_hist = np.zeros(T+1, np.int64)
            expert_hist = np.zeros((T, K), np.int64)
            n = len(self.xv)
            for i in range(0, n, self.bs):
                xb = self.xv[i:i+self.bs]
                tb = trace_batch(self.model, xb, layer_name=self.layer_name, force_full=False)
                t_used = tb["top_indices"].shape[0]
                for t in range(t_used):
                    ch = tb["top_indices"][t]
                    cnt = np.bincount(ch, minlength=K)
                    expert_hist[t] += cnt
            print(f"> [{self.layer_name}] epoch {epoch+1}: expert_hist={expert_hist.tolist()}")



def trace_and_predict(
    model,
    x_input,
    y_true=None,
    layer_name="adaptive_router",
    force_full=False,   # set True to always loop steps, no early exit
):
    layer = model.get_layer(layer_name)
    pre = keras.Model(model.input, layer.input)
    x_in = tf.convert_to_tensor(x_input)
    x = pre(x_in, training=False)
    B = int(x.shape[0])
    K, T = layer.K, layer.steps
    dtype = x.dtype
    running = np.ones((B,), dtype=bool)

    top_indices, probs_list = [], []

    for t in range(T):
        router_logits, _ = layer.router(x, training=False)                         # [B,K]
        probs = tf.nn.softmax(router_logits / layer.route_temp, axis=-1)           # [B,K]
        top_idx = tf.argmax(probs, axis=-1, output_type=tf.int32)                  # [B]
        onehot  = tf.one_hot(top_idx, depth=K, dtype=dtype)
        onehot_st = onehot + tf.stop_gradient(probs - onehot)                      # ST

        y_list = [br(x, training=False) for br in layer.branches]
        y_stack = tf.stack(y_list, axis=1)                                         # [B,K,H,W,C]
        mask = tf.reshape(onehot_st, [-1, K, 1, 1, 1])
        y_sel = tf.reduce_sum(mask * y_stack, axis=1)                              # [B,H,W,C]
        x = y_sel  # no pooling

        top_indices.append(top_idx.numpy())
        probs_list.append(probs.numpy())

        if (not force_full) and (not running.any()):
            break

    pred_probs = model(x_in, training=False).numpy()
    pred_label = pred_probs.argmax(axis=-1).astype(np.int32)
    if y_true is not None:
        y_true_arr = np.asarray(y_true).reshape(-1)
        correct = (pred_label == y_true_arr)
    else:
        correct = None

    trace = {
        "top_indices":   np.array(top_indices),                 # [t_used, B]
        "probs":         np.array(probs_list),                  # [t_used, B, K]
    }
    return {
        "trace": trace,
        "pred_probs": pred_probs,
        "pred_label": pred_label,
        "true_label": None if y_true is None else np.asarray(y_true),
        "layer_info": {
            "steps": layer.steps,
            "K": layer.K,
            "route_temp": getattr(layer, "route_temp", None),
        },
        "correct": correct,
    }


def trace_batch(model, x_batch, layer_name="adaptive_router", force_full=False):
    layer = model.get_layer(layer_name)
    pre = keras.Model(model.input, layer.input)

    x_in = tf.convert_to_tensor(x_batch)
    x = pre(x_in, training=False)

    B = int(x.shape[0])
    K, T = layer.K, layer.steps
    running = np.ones((B,), dtype=bool)

    top_indices, probs_list = [], []

    for t in range(T):
        router_logits, _ = layer.router(x, training=False)                         # [B,K]
        probs = tf.nn.softmax(router_logits / layer.route_temp, axis=-1)           # [B,K]
        top_idx = tf.argmax(probs, axis=-1, output_type=tf.int32)                  # [B]
        onehot  = tf.one_hot(top_idx, depth=K, dtype=x.dtype)
        onehot_st = onehot + tf.stop_gradient(probs - onehot)                      # ST

        y_stack = tf.stack([br(x, training=False) for br in layer.branches], axis=1)  # [B,K,H,W,C]
        y_sel = tf.reduce_sum(tf.reshape(onehot_st, [-1, K, 1, 1, 1]) * y_stack, axis=1)

        x = y_sel  # no pooling

        # ---- collect step data ----
        top_indices.append(top_idx.numpy())
        probs_list.append(probs.numpy())

        # early exit only if no sample is still running
        if (not force_full) and (not running.any()):
            break

    trace = {
        "top_indices":   np.array(top_indices),                 # [t_used, B]
        "probs":         np.array(probs_list),                  # [t_used, B, K]
    }
    return trace



def evaluate_with_router_stats(model, x, y, layer_name="adaptive_router_top1",
                               batch_size=256, force_full=False):
    # accuracy
    loss, acc = model.evaluate(x, y, batch_size=batch_size, verbose=0)

    # router stats
    layer = model.get_layer(layer_name)
    K, T = layer.K, layer.steps

    steps_hist = np.zeros(T+1, dtype=np.int64)   # index t for steps=t, last bin for "T or more"
    expert_hist = np.zeros((T, K), dtype=np.int64)
    n_seen = 0

    for i in range(0, len(x), batch_size):
        xb = x[i:i+batch_size]
        tb = trace_batch(model, xb, layer_name=layer_name, force_full=force_full)

        t_used, B = tb["top_indices"].shape[0], tb["top_indices"].shape[1]
        n_seen += B

        # expert usage per step (only for steps that exist in this batch)
        for t in range(t_used):
            choices = tb["top_indices"][t]  # [B]
            counts = np.bincount(choices, minlength=K)
            expert_hist[t, :] += counts


    return {
        "loss": float(loss),
        "acc": float(acc),
        "expert_hist": expert_hist,  # [T, K]
        "seen": n_seen,
        "K": K,
        "T": T,
    }

def print_router_stats(model, x, y, layer_name="adaptive_router_1", batch_size=512, force_full=False):
    stats = evaluate_with_router_stats(
        model, x, y,
        layer_name=layer_name,
        batch_size=batch_size,
        force_full=force_full
    )
    print("")
    print(f"================== {layer_name} ====================")
    print(f"Test acc: {stats['acc']*100:.2f}%  |  samples: {stats['seen']}")
    # Only print halt rates for steps with expert usage
    for k in range(stats["K"]):
        total = stats["expert_hist"][:,k].sum()
        print(f"Expert {k} total usage: {total} ({total/stats['seen']:.3f} per sample on avg)")
    for s in range(len(stats["expert_hist"])):
        step_total = stats["expert_hist"][s].sum()
        print(f"Expert usage at step {s}:", stats["expert_hist"][s])


def print_trace_for_samples(model, x, y, layer_name="adaptive_router_1", start=1000, end=1020):
    """
    Prints routing trace and prediction for samples in the given range.
    """
    print("")
    print(f"================== {layer_name} ====================")
    for i in range(start, end):
        res = trace_and_predict(model, x[i:i+1], y_true=y[i:i+1], layer_name=layer_name)
        if len(res["trace"]["top_indices"][:, 0]) > 0:
            trace = res["trace"]
            print(" > pred label:", res["pred_label"][0], "true label:", int(res["true_label"][0]))
            print("   experts per step:", trace["top_indices"][:, 0])
                  


def build_adaptive_model_sparse(branches, input_shape=(32,32,3), num_classes=10, filters=32,
                                steps=5, route_temp=1.0, router_settings={"head": 2, "dim_head": 64, "mlp_hidden": 0}):

    inputs = keras.Input(shape=input_shape)
    # For CNNs on GPU, BatchNorm is faster than LayerNorm:
    x = layers.Conv2D(32, 3, padding='same', activation='swish')(inputs)
    x = layers.BatchNormalization()(x)

    x = PoolingLayer(filters=filters, frac_ratio=2.0)(x)

    x = AdaptiveRouterBlockTop1Vectorized(
        branches=branches[0],
        steps=steps,
        route_temp=route_temp,
        router_settings=router_settings,
        name="adaptive_router_1",
    )(x)

    x = ResidualBlock3x3(filters=filters)(x)

    x = PoolingLayer(filters=filters, frac_ratio=2.0)(x)

    x = AdaptiveRouterBlockTop1Vectorized(
        branches=branches[1],
        steps=steps,
        route_temp=route_temp,
        router_settings=router_settings,
        name="adaptive_router_2",
    )(x)

    x = ResidualBlock3x3(filters=filters)(x)

    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation='softmax', dtype='float32')(x)  # keep fp32 logits
    return keras.Model(inputs, outputs, name="adaptive_model")

In [None]:
FILTERS = 64
STEPS = 5
ROUTE_TEMP = 5.0
EPOCHS = 20

branches = [
    [
        ResidualBlock3x3(FILTERS),
        ResidualBlock5x5(FILTERS),
        ResidualBlockDepthwise7x7(FILTERS),
        ChannelSE(FILTERS),
        SpatialSE(),
        DummyBlock()
    ],
    [
        ResidualBlock3x3(FILTERS),
        ResidualBlock5x5(FILTERS),
        ResidualBlockDepthwise7x7(FILTERS),
        ResidualBlockDepthwise9x9(FILTERS),
        ChannelSE(FILTERS),
        DummyBlock()
    ]
]

router_model = build_adaptive_model_sparse(
    branches=branches,
    input_shape=(32,32,3),
    num_classes=10,
    filters=FILTERS,
    steps=STEPS,     # >0 encourages fewer steps; tune as needed
    router_settings={
        "heads": 8,
        "dim_head": 32,
        "mlp_hidden": 32,
    },
    route_temp=ROUTE_TEMP
)



callbacks = [
    TempScheduler(layer_name="adaptive_router_1", epochs=EPOCHS, mode="cosine", route_start=ROUTE_TEMP, route_end=0.7),
    RouterStatsCallback(x_test, y_test, layer_name="adaptive_router_1"),
    RouterStatsCallback(x_test, y_test, layer_name="adaptive_router_2"),
    CosineAnnealingScheduler(base_lr=3e-3, min_lr=1e-5, epochs=EPOCHS)
]

router_model.build(input_shape=(None, 32, 32, 3))
router_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
    run_eagerly=False)

#_ = router_model(tf.zeros((1, 32, 32, 3)))

router_model.summary()

#tf.keras.utils.plot_model(router_model, show_shapes=True)

Model: "adaptive_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv2d_5 (Conv2D)           (None, 32, 32, 32)        896       
                                                                 
 batch_normalization_1 (Bat  (None, 32, 32, 32)        128       
 chNormalization)                                                
                                                                 
 pooling_layer_2 (PoolingLa  (None, 16, 16, 64)        2240      
 yer)                                                            
                                                                 
 adaptive_router_1 (Adaptiv  (None, 16, 16, 64)        182944    
 eRouterBlockTop1Vectorized                                      
 )                                                  

In [9]:
router_model.fit(ds_train, epochs=EPOCHS, validation_data=ds_val, callbacks=callbacks)

> [LR Scheduler] epoch 1: lr=0.003000
Epoch 1/20


2025-10-16 15:17:31.925441: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2025-10-16 15:21:51.713281: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


> [TempScheduler] epoch 1: route_temp=5.000
> [adaptive_router_1] epoch 1: expert_hist=[[8661, 389, 0, 282, 668, 0], [5990, 410, 0, 3044, 556, 0], [3195, 500, 0, 5819, 486, 0], [1832, 682, 1, 7101, 384, 0], [807, 1583, 70, 7436, 104, 0]]
> [adaptive_router_2] epoch 1: expert_hist=[[1, 2980, 637, 0, 1388, 4994], [47, 3378, 0, 0, 3779, 2796], [414, 4018, 0, 28, 4840, 700], [618, 4538, 1, 60, 4783, 0], [603, 4626, 7, 105, 4659, 0]]
Epoch 2/20
Epoch 3/20
Epoch 4/20
> [LR Scheduler] epoch 5: lr=0.002685
Epoch 5/20
> [adaptive_router_1] epoch 5: expert_hist=[[7490, 0, 130, 2167, 199, 14], [6047, 7, 135, 3132, 677, 2], [5403, 518, 26, 2985, 1068, 0], [4346, 3412, 0, 1575, 667, 0], [3001, 6297, 0, 447, 255, 0]]
> [adaptive_router_2] epoch 5: expert_hist=[[44, 3434, 880, 151, 3, 5488], [462, 3678, 1481, 136, 1213, 3030], [1184, 3707, 1402, 226, 2515, 966], [1596, 3815, 1539, 391, 2085, 574], [1828, 4135, 1687, 625, 1370, 355]]
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
> [LR Scheduler] epoch 1

KeyboardInterrupt: 

In [12]:
print_router_stats(router_model, x_test, y_test, layer_name="adaptive_router_1", batch_size=512)
print_router_stats(router_model, x_test, y_test, layer_name="adaptive_router_2", batch_size=512)


Test acc: 80.46%  |  samples: 10000
Expert 0 total usage: 32449 (3.245 per sample on avg)
Expert 1 total usage: 8580 (0.858 per sample on avg)
Expert 2 total usage: 639 (0.064 per sample on avg)
Expert 3 total usage: 2627 (0.263 per sample on avg)
Expert 4 total usage: 4821 (0.482 per sample on avg)
Expert 5 total usage: 884 (0.088 per sample on avg)
Expert usage at step 0: [8352    0   52 1545    1   50]
Expert usage at step 1: [8443    1  319  663  448  126]
Expert usage at step 2: [7171  356  240  300 1673  260]
Expert usage at step 3: [5449 2247   27  114 1890  273]
Expert usage at step 4: [3034 5976    1    5  809  175]

Test acc: 80.46%  |  samples: 10000
Expert 0 total usage: 4669 (0.467 per sample on avg)
Expert 1 total usage: 22583 (2.258 per sample on avg)
Expert 2 total usage: 17359 (1.736 per sample on avg)
Expert 3 total usage: 3005 (0.300 per sample on avg)
Expert 4 total usage: 1471 (0.147 per sample on avg)
Expert 5 total usage: 913 (0.091 per sample on avg)
Expert usa

In [13]:
print_trace_for_samples(router_model, x_test, y_test, layer_name="adaptive_router_1", start=0, end=5)
print_trace_for_samples(router_model, x_test, y_test, layer_name="adaptive_router_2", start=0, end=5)


 > pred label: 3 true label: 3
   experts per step: [0 0 0 0 0]
 > pred label: 8 true label: 8
   experts per step: [0 0 0 0 0]
 > pred label: 1 true label: 8
   experts per step: [3 3 4 4 1]
 > pred label: 0 true label: 0
   experts per step: [0 0 0 1 1]
 > pred label: 6 true label: 6
   experts per step: [0 0 0 0 1]

 > pred label: 3 true label: 3
   experts per step: [1 1 1 1 1]
 > pred label: 8 true label: 8
   experts per step: [1 1 1 1 1]
 > pred label: 1 true label: 8
   experts per step: [1 1 1 1 1]
 > pred label: 0 true label: 0
   experts per step: [1 1 1 1 1]
 > pred label: 6 true label: 6
   experts per step: [3 0 0 0 0]


In [None]:
def build_base_model_4_blocks(input_shape=(32,32,3), num_classes=10, filters=32):
    inputs = keras.Input(shape=input_shape)
    # For CNNs on GPU, BatchNorm is faster than LayerNorm:
    x = layers.Conv2D(filters, 3, padding='same', activation='relu')(inputs)
    x = layers.BatchNormalization()(x)

    x = ResidualBlock(filters)(x)
    x = PoolingLayer(filters=filters, frac_ratio=2.0)(x)
    x = ResidualBlock(filters)(x)
    x = ResidualBlock(filters)(x)
    x = PoolingLayer(filters=filters, frac_ratio=2.0)(x)
    x = ResidualBlock(filters)(x)

    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation='softmax', dtype='float32')(x)  # keep fp32 logits
    return keras.Model(inputs, outputs)


def build_base_model_6_blocks(input_shape=(32,32,3), num_classes=10, filters=32):
    inputs = keras.Input(shape=input_shape)
    # For CNNs on GPU, BatchNorm is faster than LayerNorm:
    x = layers.Conv2D(filters, 3, padding='same', activation='relu')(inputs)
    x = layers.BatchNormalization()(x)

    x = ResidualBlock(filters)(x)
    x = PoolingLayer(filters=filters, frac_ratio=2.0)(x)
    x = ResidualBlock(filters)(x)
    x = ResidualBlock(filters)(x)
    x = PoolingLayer(filters=filters, frac_ratio=2.0)(x)
    x = ResidualBlock(filters)(x)
    x = ResidualBlock(filters)(x)
    x = PoolingLayer(filters=filters, frac_ratio=2.0)(x)
    x = ResidualBlock(filters)(x)

    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation='softmax', dtype='float32')(x)  # keep fp32 logits
    return keras.Model(inputs, outputs)


TypeError: __init__() got an unexpected keyword argument 'pool_every_n'

In [9]:
model = build_base_model_4_blocks(
    input_shape=(32,32,3),
    num_classes=10,
    filters=FILTERS
)

model.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
model.summary()

model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    epochs=20, batch_size=512
)

Model: "model_61"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv2d_1 (Conv2D)           (None, 32, 32, 32)        896       
                                                                 
 batch_normalization_1 (Bat  (None, 32, 32, 32)        128       
 chNormalization)                                                
                                                                 
 residual_block_4 (Residual  (None, 32, 32, 32)        20736     
 Block)                                                          
                                                                 
 pooling_layer_1 (PoolingLa  (None, 16, 16, 32)        1120      
 yer)                                                            
                                                          

2025-10-13 16:17:22.540030: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2025-10-13 16:17:42.444025: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.src.callbacks.History at 0x158b70f40>

In [10]:
model = build_base_model_6_blocks(
    input_shape=(32,32,3),
    num_classes=10,
    filters=FILTERS
)

model.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
model.summary()

model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    epochs=20, batch_size=512
)

Model: "model_62"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv2d_2 (Conv2D)           (None, 32, 32, 32)        896       
                                                                 
 batch_normalization_2 (Bat  (None, 32, 32, 32)        128       
 chNormalization)                                                
                                                                 
 residual_block_8 (Residual  (None, 32, 32, 32)        20736     
 Block)                                                          
                                                                 
 pooling_layer_3 (PoolingLa  (None, 16, 16, 32)        1120      
 yer)                                                            
                                                          

2025-10-13 16:22:18.933112: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2025-10-13 16:22:38.732527: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.src.callbacks.History at 0x3b0a50370>