In [1]:
import tensorflow as tf
from tensorflow import keras
from utils import (
    PoolingLayer, ResidualBlock, ResidualBlock3x3, ResidualBlock5x5, ResidualBlock7x7, SpatialSE, ChannelSE, 
    ResidualBlockDepthwise3x3, ResidualBlockDepthwise5x5, ResidualBlockDepthwise7x7, ResidualBlockDepthwise9x9, 
    DummyBlock, Conv3x3PoolingLayer, Depthwise3x3ConvPoolingLayer, MaxPoolingLayer, AvgPoolingLayer,
    Conv5x5PoolingLayer, Depthwise5x5ConvPoolingLayer, Depthwise7x7ConvPoolingLayer, AdaptiveRouter,
    ConvStem, DepthwiseConvStem, SpatialAttention
)
from InitConvRouterNet import (
    get_stem, get_pooling, get_block, get_init_stem, get_init_block, get_init_pooling
)
from callbacks import (
    CosineAnnealingScheduler, TempScheduler, RouterStatsCallback, TopNScheduler, linear_topn_schedule, milestone_topn_schedule, print_router_stats
)
import random
import numpy as np

# One shared seed for the Python, NumPy and TensorFlow random number generators.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(SEED)


# simple CIFAR-10 aug
def cifar_preprocess(x, y):
    """Apply pad-and-crop plus horizontal-flip augmentation to a batch.

    Intended to be mapped over an already batched dataset, so `x` must be a
    rank-4 image batch with 3 channels (the crop size below fixes this).

    Args:
        x: Batch of images, laid out as (batch, height, width, 3).
        y: Labels for the batch; returned unchanged.

    Returns:
        Tuple `(x_aug, y)` where `x_aug` holds 32x32x3 crops of the padded
        images, each randomly flipped left/right.
    """
    # Pad (or centre-crop) every image to 36x36 so a 32x32 window can shift.
    x = tf.image.resize_with_crop_or_pad(x, 36, 36)
    # tf.image.random_crop draws a single offset per call, so cropping the
    # whole batch at once would shift every image by the same amount. Crop
    # image by image so each sample gets its own random window.
    x = tf.map_fn(lambda img: tf.image.random_crop(img, [32, 32, 3]), x)
    # On a batch, the flip decision is already made independently per image.
    x = tf.image.random_flip_left_right(x)
    return x, y

def make_dataset(x, y, batch=128, train=True, shuffle_buffer=5000):
    """Build a batched, prefetched `tf.data` pipeline from in-memory arrays.

    Args:
        x: Inputs, sliced along the first axis.
        y: Targets, sliced along the first axis in step with `x`.
        batch: Batch size.
        train: When True the pipeline shuffles, batches and then applies
            `cifar_preprocess` to each batch; otherwise it only batches.
        shuffle_buffer: Size of the shuffle buffer used when `train` is True.
            A buffer smaller than the dataset only shuffles approximately;
            raise it (at the cost of memory) for a more uniform shuffle.

    Returns:
        A `tf.data.Dataset` yielding `(x, y)` batches, with prefetching enabled.
    """
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if train:
        # Augmentation is mapped after batching, so it receives whole batches.
        ds = (
            ds.shuffle(shuffle_buffer)
            .batch(batch)
            .map(cifar_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        )
    else:
        ds = ds.batch(batch)
    return ds.prefetch(tf.data.AUTOTUNE)

def build_conv_router_net(net_layers, num_classes, input_shape=(32,32,3)):
    """Chain `net_layers` into a functional Keras classifier.

    Args:
        net_layers: Iterable of callable layers, applied to the input in order.
        num_classes: Number of units in the final softmax layer.
        input_shape: Per-sample input shape (without the batch axis).

    Returns:
        A `keras.Model` named "adaptive_model" mapping the input tensor to
        per-class softmax probabilities.
    """
    inputs = keras.Input(shape=input_shape)

    # Thread the tensor through every stage in the order given.
    features = inputs
    for stage in net_layers:
        features = stage(features)

    pooled = keras.layers.GlobalAveragePooling2D()(features)
    # The head is pinned to float32; note it emits softmax probabilities,
    # not raw logits.
    head = keras.layers.Dense(num_classes, activation='softmax', dtype='float32')
    return keras.Model(inputs, head(pooled), name="adaptive_model")


# Usage: load CIFAR-10, convert images to float32 divided by 255, flatten labels.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

y_train = y_train.flatten()
y_test = y_test.flatten()

# Training pipeline shuffles and augments; validation pipeline only batches.
ds_train = make_dataset(x_train, y_train, batch=128, train=True)
ds_val = make_dataset(x_test, y_test, batch=256, train=False)


2025-10-22 12:51:03.260921: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2025-10-22 12:51:03.261012: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-10-22 12:51:03.261027: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2025-10-22 12:51:03.261067: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-10-22 12:51:03.261095: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [2]:
# Run configuration for the "init" variant of the router network.
EPOCHS = 20
ROUTER_TEMP = 1.5
DIVERSITY_TAU = 1e-2
TOP_N = 7

# Routing arguments shared by every stage below.
_router_args = (TOP_N, ROUTER_TEMP, DIVERSITY_TAU)

# Stem, then three (pool, block, block) groups, then a final pool.
layers = [
    get_init_stem(16, "01_stem", *_router_args),
    get_init_pooling(32, "02_pool", *_router_args),
    get_init_block(32, "03_block", *_router_args),
    get_init_block(32, "04_block", *_router_args),
    get_init_pooling(64, "05_pool", *_router_args),
    get_init_block(64, "06_block", *_router_args),
    get_init_block(64, "07_block", *_router_args),
    get_init_pooling(128, "08_pool", *_router_args),
    get_init_block(128, "09_block", *_router_args),
    get_init_block(128, "10_block", *_router_args),
    get_init_pooling(256, "11_pool", *_router_args),
]

all_layers = []

# 10 output classes.
router_model = build_conv_router_net(layers, 10)

In [6]:
# Model-wide callbacks.
callbacks = [
    CosineAnnealingScheduler(base_lr=1e-3, min_lr=1e-5, epochs=EPOCHS),
    # TopNScheduler is not given a layer name, so a single instance is
    # registered here instead of one identical copy per layer.
    # The schedule moves top-n linearly from TOP_N down to 1 over the first
    # 80% of the epochs.
    TopNScheduler(
        schedule=linear_topn_schedule(
            n_start=TOP_N, n_end=1, start_epoch=0, end_epoch=int(EPOCHS * 0.8)
        ),
        verbose=1,
    ),
]

# Per-layer callbacks: routing statistics and routing-temperature schedule.
for layer in layers:
    callbacks.append(RouterStatsCallback(x_test, y_test, layer_name=layer.name, verbose_every=10))
    callbacks.append(TempScheduler(layer_name=layer.name, epochs=EPOCHS, route=(ROUTER_TEMP, 0.5), eps=(0, 0.00), ent=(0, 0.0), lb=(0, 0.0), verbose=1, log=False))

router_model.build(input_shape=(None, 32, 32, 3))
router_model.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
# Alternative optimizer setting noted for later: AdamW(lr=3e-3, weight_decay=0.05)

router_model.summary()

Model: "adaptive_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 01_stem (AdaptiveRouter)    (None, 32, 32, 16)        12097     
                                                                 
 02_pool (AdaptiveRouter)    (None, 16, 16, 32)        28481     
                                                                 
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 01_stem (AdaptiveRouter)    (None, 32, 32, 16)        12097     
                                                                 
 02_pool (AdaptiveRouter)    (None, 16, 16, 32)     

In [7]:
# Keep the returned History so per-epoch metrics stay available after training
# (otherwise it is only reachable through the cell's Out[] entry).
history = router_model.fit(
    ds_train,
    epochs=EPOCHS,
    validation_data=ds_val,
    callbacks=callbacks,
)

> [LR Scheduler] epoch 1: lr=0.001000
Epoch 1/20
Epoch 1/20


2025-10-21 13:37:37.361125: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2025-10-21 13:43:56.275581: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


> [01_stem] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[0, 0, 0, 0, 10000, 0, 0]]
> [02_pool] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[0, 9991, 0, 0, 0, 9, 0]]
> [02_pool] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[0, 9991, 0, 0, 0, 9, 0]]
> [03_block] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[0, 0, 0, 0, 0, 0, 10000, 0]]
> [03_block] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[0, 0, 0, 0, 0, 0, 10000, 0]]
> [04_block] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[0, 0, 10000, 0, 0, 0, 0, 0]]
> [04_block] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[0, 0, 10000, 0, 0, 0, 0, 0]]
> [05_pool] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[0, 0, 0, 0, 0, 6175, 3825]]
> [05_pool] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[0, 0, 0, 0, 0, 6175, 3825]]
> [06_block] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[0, 0, 9

<keras.src.callbacks.History at 0x3edcc4820>

In [8]:
# Report routing statistics on the test split, one routed layer at a time.
for layer in layers:
    print_router_stats(
        router_model,
        x_test,
        y_test,
        layer_name=layer.name,
        batch_size=512,
    )

: 

In [2]:
# Run configuration for the hand-picked expert variant of the router network.
EPOCHS = 150
ROUTER_TEMP = 1.5
DIVERSITY_TAU = 1e-2
TOP_N = 3

# Routing arguments shared by every stage below.
_router_args = (TOP_N, ROUTER_TEMP, DIVERSITY_TAU)

# Each stage lists its candidate experts explicitly.
layers = [
    get_stem([DepthwiseConvStem(16, 5)], "01_stem", *_router_args),
    get_pooling([Conv3x3PoolingLayer(32), Conv5x5PoolingLayer(32)], "02_pool", *_router_args),
    get_block([SpatialSE()], "03_block", *_router_args),
    get_block(
        [ResidualBlockDepthwise3x3(32), SpatialAttention(kernel_size=7)],
        "04_block", *_router_args,
    ),
    get_pooling([Conv3x3PoolingLayer(64), Conv5x5PoolingLayer(64)], "05_pool", *_router_args),
    get_block(
        [ResidualBlock3x3(64), ResidualBlockDepthwise3x3(64), SpatialSE()],
        "06_block", *_router_args,
    ),
    get_block(
        [ResidualBlock5x5(64), SpatialSE(), SpatialAttention(kernel_size=7)],
        "07_block", *_router_args,
    ),
    get_pooling([Conv5x5PoolingLayer(128), AvgPoolingLayer(128)], "08_pool", *_router_args),
    get_block(
        [ResidualBlockDepthwise3x3(128), ResidualBlockDepthwise5x5(128)],
        "09_block", *_router_args,
    ),
    get_block([SpatialSE()], "10_block", *_router_args),
    get_pooling([Conv5x5PoolingLayer(256)], "11_pool", *_router_args),
]

all_layers = []

# 10 output classes.
router_model = build_conv_router_net(layers, 10)

In [3]:
# Model-wide callbacks.
callbacks = [
    CosineAnnealingScheduler(base_lr=1e-3, min_lr=1e-5, epochs=EPOCHS),
    # TopNScheduler is not given a layer name, so a single instance is
    # registered here instead of one identical copy per layer.
    # The schedule moves top-n linearly from TOP_N down to 1 over the first
    # 80% of the epochs.
    TopNScheduler(
        schedule=linear_topn_schedule(
            n_start=TOP_N, n_end=1, start_epoch=0, end_epoch=int(EPOCHS * 0.8)
        ),
        verbose=1,
    ),
]

# Per-layer callbacks: routing statistics and routing-temperature schedule.
for layer in layers:
    callbacks.append(RouterStatsCallback(x_test, y_test, layer_name=layer.name, verbose_every=10))
    callbacks.append(TempScheduler(layer_name=layer.name, epochs=EPOCHS, route=(ROUTER_TEMP, 0.5), eps=(0, 0.00), ent=(0, 0.0), lb=(0, 0.0), verbose=1, log=False))

router_model.build(input_shape=(None, 32, 32, 3))
router_model.compile(optimizer='adam',
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
# Alternative optimizer setting noted for later: AdamW(lr=3e-3, weight_decay=0.05)

router_model.summary()

Model: "adaptive_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         


                                                                 
 01_stem (AdaptiveRouter)    (None, 32, 32, 16)        5161      
                                                                 
 02_pool (AdaptiveRouter)    (None, 16, 16, 32)        30145     
                                                                 
 03_block (AdaptiveRouter)   (None, 16, 16, 32)        11457     
                                                                 
 04_block (AdaptiveRouter)   (None, 16, 16, 32)        11363     
                                                                 
 05_pool (AdaptiveRouter)    (None, 8, 8, 64)          86657     
                                                                 
 06_block (AdaptiveRouter)   (None, 8, 8, 64)          111105    
                                                                 
 07_block (AdaptiveRouter)   (None, 8, 8, 64)          74499     
                                                                 
 08_pool (

In [4]:
# Keep the returned History so per-epoch metrics stay available after training
# (otherwise it is only reachable through the cell's Out[] entry).
history = router_model.fit(
    ds_train,
    epochs=EPOCHS,
    validation_data=ds_val,
    callbacks=callbacks,
)

> [LR Scheduler] epoch 1: lr=0.001000
[TopNScheduler] epoch 0: 01_stem.top_n=1/1; 02_pool.top_n=2/2; 03_block.top_n=1/1; 04_block.top_n=2/2; 05_pool.top_n=2/2; 08_pool.top_n=2/2; 09_block.top_n=2/2; 10_block.top_n=1/1; 11_pool.top_n=1/1
Epoch 1/150
Epoch 1/150


2025-10-22 12:51:26.740866: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2025-10-22 12:53:06.064393: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


> [01_stem] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[10000]]
> [02_pool] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[5042, 4958]]
> [02_pool] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[5042, 4958]]
> [03_block] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[10000]]
> [03_block] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[10000]]
> [04_block] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[9517, 483]]
> [04_block] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[9517, 483]]
> [05_pool] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[10000, 0]]
> [05_pool] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[10000, 0]]
> [06_block] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[1, 0, 9999]]
> [06_block] epoch 1: avg_steps=1.00  steps_hist=[0, 10000]  expert_hist=[[1, 0, 9999]]
> [07_block] epoch 1: avg_steps=1.00  steps_hist=[0, 10000

KeyboardInterrupt: 