In [None]:
import tensorflow as tf
from tensorflow import keras
from utils import (
    ResidualBlock3x3, MaxPoolingLayer, ConvStem, EfficientConv3x3PoolingLayer, EfficientMultiStepRouter, EfficientResidualBlockDepthwise3x3,
    EfficientResidualBlockDepthwise5x5, EfficientResidualBlockDepthwise7x7, ChannelSE, SpatialSE, DummyBlock, EfficientResidualBlock3x3,
    SpatialAttention  
)
from callbacks import (
    CosineAnnealingScheduler, TempScheduler, RouterStatsMultiStepCallback, print_trace_for_samples, router_stats_multistep_once
)
import random
import numpy as np

SEED = 42
# Seed each random source this notebook uses (Python's `random`, NumPy and
# TensorFlow) with the same value.
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(SEED)


# simple CIFAR-10 aug
def cifar_preprocess(x, y):
    """Augment a batch of images with a random 32x32 crop and a random horizontal flip.

    Args:
        x: Batched image tensor. The crop below assumes rank 4 with 3 channels,
            i.e. (batch, height, width, 3).
        y: Labels for the batch; returned unchanged.

    Returns:
        Tuple ``(x_aug, y)`` where ``x_aug`` has spatial size 32x32.
    """
    # Bring every image to 36x36 so a 32x32 crop window has room to move.
    x = tf.image.resize_with_crop_or_pad(x, 36, 36)
    # tf.image.random_crop samples a single offset per call, so cropping the
    # whole batch at once would shift every image by the same amount. Crop
    # image-by-image so each sample gets its own offset.
    x = tf.map_fn(lambda image: tf.image.random_crop(image, [32, 32, 3]), x)
    x = tf.image.random_flip_left_right(x)
    return x, y

def make_dataset(x, y, batch=128, train=True):
    """Wrap ``(x, y)`` in a batched, prefetching ``tf.data.Dataset``.

    Args:
        x: Inputs, sliced along their first axis into individual examples.
        y: Targets, sliced along their first axis in step with ``x``.
        batch: Number of examples per batch.
        train: When true, examples are shuffled (buffer of 5000) before
            batching and ``cifar_preprocess`` is applied to every batch.
            When false, the data is only batched.

    Returns:
        The resulting dataset with prefetching enabled.
    """
    autotune = tf.data.AUTOTUNE
    dataset = tf.data.Dataset.from_tensor_slices((x, y))

    if not train:
        # Evaluation path: keep order, no augmentation.
        return dataset.batch(batch).prefetch(autotune)

    dataset = dataset.shuffle(5000)
    dataset = dataset.batch(batch)
    # Augmentation runs on whole batches, after batching.
    dataset = dataset.map(cifar_preprocess, num_parallel_calls=autotune)
    return dataset.prefetch(autotune)

def build_conv_router_net(net_layers, num_classes, input_shape=(32,32,3)):
    """Chain ``net_layers`` into a functional Keras classifier.

    Args:
        net_layers: Iterable of callables (layers) applied to the input one
            after another, in iteration order.
        num_classes: Number of units in the final softmax layer.
        input_shape: Per-example input shape (without the batch dimension).

    Returns:
        A ``keras.Model`` named ``"adaptive_model"`` mapping the input to
        softmax class probabilities.
    """
    inputs = keras.Input(shape=input_shape)

    features = inputs
    for block in net_layers:
        features = block(features)

    pooled = tf.keras.layers.GlobalAveragePooling2D()(features)
    # The head is pinned to float32; note it emits softmax probabilities,
    # not raw logits.
    classifier = tf.keras.layers.Dense(num_classes, activation='softmax', dtype='float32')
    outputs = classifier(pooled)
    return keras.Model(inputs, outputs, name="adaptive_model")


# Usage
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Convert images to float32 and divide by 255.0.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Flatten the label arrays to 1-D for the sparse categorical loss used later.
y_train = y_train.flatten()
y_test = y_test.flatten()

# NOTE(review): the "validation" dataset is built from the test split.
ds_train = make_dataset(x_train, y_train, batch=128, train=True)
ds_val = make_dataset(x_test, y_test, batch=256, train=False)


In [None]:
EPOCHS = 20
ROUTE_TEMP = 2.0
ROUTER_DIM = 128
WIDTH = 2
STEPS = 3

# Channel count passed to every pooling layer and residual branch.
CHANNELS = 64 * WIDTH
# One routed stage is built per name. Later cells look these layers up by
# name (`layer_name=...`), so keep the two in sync when editing this tuple.
ROUTER_NAMES = ("01_multi_block", "02_multi_block", "03_multi_block")


def make_router_stage(name, channels=CHANNELS):
    """Build one EfficientMultiStepRouter with a fresh set of candidate branches.

    Every stage in this notebook uses the same branch list and routing
    hyper-parameters (ROUTE_TEMP, ROUTER_DIM, STEPS); only the layer name
    differs.

    Args:
        name: Keras layer name for the router.
        channels: Value passed to the residual-block branch constructors.

    Returns:
        A newly constructed ``EfficientMultiStepRouter``.
    """
    return EfficientMultiStepRouter(
        branches=[
            EfficientResidualBlock3x3(channels),
            EfficientResidualBlockDepthwise3x3(channels),
            EfficientResidualBlockDepthwise5x5(channels),
            EfficientResidualBlockDepthwise7x7(channels),
            ChannelSE(ratio=4),
            SpatialSE(),
            SpatialAttention(),
            DummyBlock()
        ],
        route_temp=ROUTE_TEMP,
        router_dim=ROUTER_DIM,
        name=name,
        steps=STEPS)


# Stem first, then one (pooling layer -> router) pair per stage. Layers are
# constructed in the same order as before so auto-generated names are unchanged.
net_layers = [ConvStem(32, 5)]
for router_name in ROUTER_NAMES:
    net_layers.append(EfficientConv3x3PoolingLayer(CHANNELS))
    net_layers.append(make_router_stage(router_name))

router_model = build_conv_router_net(net_layers, 10)

In [8]:
# Router layers that get a stats callback and a temperature scheduler each.
router_layer_names = ["01_multi_block", "02_multi_block", "03_multi_block"]

callbacks = [
    CosineAnnealingScheduler(base_lr=1e-3, min_lr=1e-5, epochs=EPOCHS),
    # Routing statistics per router layer, computed on the test split.
    *[
        RouterStatsMultiStepCallback(x_test, y_test, layer_name=layer_name, verbose_every=5)
        for layer_name in router_layer_names
    ],
    # Routing temperature schedule per router layer.
    # NOTE(review): presumably anneals from ROUTE_TEMP to 0.7 over EPOCHS — confirm in callbacks.py.
    *[
        TempScheduler(layer_name=layer_name, epochs=EPOCHS, route=(ROUTE_TEMP, 0.7), verbose=1, log=False)
        for layer_name in router_layer_names
    ],
]

router_model.build(input_shape=(None, 32, 32, 3))
router_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

router_model.summary()

Model: "adaptive_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv_stem_1 (ConvStem)      (None, 32, 32, 32)        2496      
                                                                 
 efficient_conv3x3_pooling_  (None, 16, 16, 128)       36992     
 layer_3 (EfficientConv3x3P                                      
 oolingLayer)                                                    
                                                                 
 01_multi_block (EfficientM  (None, 16, 16, 128)       92736     
 ultiStepRouter)                                                 
                                                                 
 efficient_conv3x3_pooling_  (None, 8, 8, 128)         147584    
 layer_4 (EfficientConv3x3P                         

In [9]:
# Train on the augmented training set, validating on ds_val after each epoch.
router_model.fit(
    ds_train,
    validation_data=ds_val,
    epochs=EPOCHS,
    callbacks=callbacks,
)

> [LR Scheduler] epoch 1: lr=0.001000
Epoch 1/20


2025-10-23 16:26:33.321867: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2025-10-23 16:30:09.744527: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.



> [01_multi_block] epoch 1  eval: loss=1.2423  acc=54.93%

> [01_multi_block] epoch 1
  samples=10000  steps=3  experts=6
  avg_steps=6.00
  steps_hist=[0, 10000, 10000, 10000]
  step  1: total=10000  E0:99.8%  E1: 0.0%  E2: 0.2%  E3: 0.0%  E4: 0.0%  E5: 0.0%
             mean_top1=0.250  mean_entropy=1.765
             mean_prob_mass: E0:0.250 E1:0.144 E2:0.174 E3:0.157 E4:0.149 E5:0.126
  step  2: total=10000  E0:99.6%  E1: 0.0%  E2: 0.4%  E3: 0.0%  E4: 0.0%  E5: 0.0%
             mean_top1=0.250  mean_entropy=1.763
             mean_prob_mass: E0:0.250 E1:0.144 E2:0.181 E3:0.153 E4:0.147 E5:0.125
  step  3: total=10000  E0:99.2%  E1: 0.0%  E2: 0.8%  E3: 0.0%  E4: 0.0%  E5: 0.0%
             mean_top1=0.244  mean_entropy=1.765
             mean_prob_mass: E0:0.244 E1:0.145 E2:0.188 E3:0.150 E4:0.146 E5:0.127

> [02_multi_block] epoch 1  eval: loss=1.2423  acc=54.93%

> [02_multi_block] epoch 1
  samples=10000  steps=3  experts=6
  avg_steps=6.00
  steps_hist=[0, 10000, 10000, 10000]

KeyboardInterrupt: 

In [5]:
# One-off routing statistics for each of the three router layers, computed on
# the test images with verbose printing enabled.
router_stats_multistep_once(router_model, x_test, layer_name="01_multi_block", verbose=True)
router_stats_multistep_once(router_model, x_test, layer_name="02_multi_block", verbose=True)
router_stats_multistep_once(router_model, x_test, layer_name="03_multi_block", verbose=True)



> [01_multi_block] samples=10000  steps=3  experts=6
  avg_steps=6.00
  steps_hist=[0, 10000, 10000, 10000]
  step  1: total=10000  E0:94.6%  E1: 0.0%  E2: 1.1%  E3: 0.0%  E4: 4.3%  E5: 0.0%
             mean_top1=0.212  mean_entropy=1.781
             mean_prob_mass: E0:0.211 E1:0.148 E2:0.152 E3:0.162 E4:0.181 E5:0.146
  step  2: total=10000  E0:94.8%  E1: 0.0%  E2: 2.6%  E3: 0.0%  E4: 2.5%  E5: 0.0%
             mean_top1=0.210  mean_entropy=1.782
             mean_prob_mass: E0:0.210 E1:0.149 E2:0.155 E3:0.162 E4:0.178 E5:0.146
  step  3: total=10000  E0:93.4%  E1: 0.1%  E2: 5.5%  E3: 0.1%  E4: 1.0%  E5: 0.0%
             mean_top1=0.202  mean_entropy=1.785
             mean_prob_mass: E0:0.202 E1:0.153 E2:0.159 E3:0.163 E4:0.174 E5:0.149

> [02_multi_block] samples=10000  steps=3  experts=6
  avg_steps=6.00
  steps_hist=[0, 10000, 10000, 10000]
  step  1: total=10000  E0: 1.4%  E1: 0.0%  E2: 5.4%  E3:12.7%  E4: 0.0%  E5:80.5%
             mean_top1=0.182  mean_entropy=1.789
     

KeyboardInterrupt: 

In [7]:
# Print a per-sample trace for each router layer, using start=0, end=5 on the
# test set (the captured output shows predicted/true label and the expert
# chosen at each step).
print_trace_for_samples(router_model, x_test, y_test, layer_name="01_multi_block", start=0, end=5)
print_trace_for_samples(router_model, x_test, y_test, layer_name="02_multi_block", start=0, end=5)
print_trace_for_samples(router_model, x_test, y_test, layer_name="03_multi_block", start=0, end=5)


 > pred label: 3 true label: 3
   experts per step: [0, 0, 2]
 > pred label: 1 true label: 8
   experts per step: [0, 0, 2]
 > pred label: 8 true label: 8
   experts per step: [0, 0, 2]
 > pred label: 8 true label: 0
   experts per step: [0, 0, 2]
 > pred label: 6 true label: 6
   experts per step: [0, 0, 2]

 > pred label: 3 true label: 3
   experts per step: [0, 2, 2]
 > pred label: 1 true label: 8
   experts per step: [2, 2, 2]
 > pred label: 8 true label: 8
   experts per step: [1, 2, 2]
 > pred label: 8 true label: 0
   experts per step: [2, 2, 2]
 > pred label: 6 true label: 6
   experts per step: [0, 1, 1]

 > pred label: 3 true label: 3
   experts per step: [2, 2, 2]
 > pred label: 1 true label: 8
   experts per step: [2, 2, 2]
 > pred label: 8 true label: 8
   experts per step: [2, 2, 2]
 > pred label: 8 true label: 0
   experts per step: [2, 2, 2]
 > pred label: 6 true label: 6
   experts per step: [2, 2, 2]
