In [1]:
import tensorflow as tf
from tensorflow import keras
from utils import (
    PoolingLayer, ResidualBlock, ResidualBlock3x3, ResidualBlock5x5, ResidualBlock7x7, SpatialSE, ChannelSE, 
    ResidualBlockDepthwise3x3, ResidualBlockDepthwise5x5, ResidualBlockDepthwise7x7, ResidualBlockDepthwise9x9, 
    DummyBlock, Conv3x3PoolingLayer, Depthwise3x3ConvPoolingLayer, MaxPoolingLayer, AvgPoolingLayer,
    Conv5x5PoolingLayer, Depthwise5x5ConvPoolingLayer, Depthwise7x7ConvPoolingLayer, AdaptiveRouter,
    ConvStem, DepthwiseConvStem, SpatialAttention
)
from InitConvRouterNet import (
    get_stem, get_pooling, get_block, get_init_stem, get_init_block, get_init_pooling, get_multi_step_init_block, get_multi_step_block
)
from callbacks import (
    CosineAnnealingScheduler, TempScheduler, RouterStatsMultiStepCallback, RouterStatsCallback, TopNScheduler, linear_topn_schedule, milestone_topn_schedule, print_router_stats
)
import random
import numpy as np

# One seed for all three random number generators used here:
# Python's `random`, NumPy's legacy global RNG, and TensorFlow's global seed.
SEED = 42
for seed_rng in (random.seed, np.random.seed, tf.random.set_seed):
    seed_rng(SEED)


# simple CIFAR-10 aug
def cifar_preprocess(x, y):
    """Pad-then-crop and horizontal-flip augmentation for a batch of images.

    The function is applied to already-batched data (it is mapped after
    ``.batch()`` in ``make_dataset``), so ``x`` is treated as a 4-D tensor
    whose first axis is the batch axis and whose last axis has 3 channels.

    Args:
        x: batch of images; the leading axis is iterated image by image.
        y: labels for the batch; returned untouched.

    Returns:
        Tuple ``(x_aug, y)`` where every image in ``x_aug`` is 32x32x3.
    """
    # Bring every image to 36x36 so the 32x32 crop window has room to move.
    x = tf.image.resize_with_crop_or_pad(x, 36, 36)
    # tf.image.random_crop picks a single offset per call. Cropping the whole
    # batch in one call therefore shifted every image identically; cropping
    # each image on its own gives an independent shift per image.
    x = tf.map_fn(lambda img: tf.image.random_crop(img, [32, 32, 3]), x)
    x = tf.image.random_flip_left_right(x)
    return x, y

def make_dataset(x, y, batch=128, train=True, shuffle_buffer=5000):
    """Wrap ``(x, y)`` in a batched, prefetched ``tf.data`` pipeline.

    Args:
        x: features; sliced along the first axis by ``from_tensor_slices``.
        y: labels; sliced together with ``x``.
        batch: batch size.
        train: when True the data is shuffled, batched and passed through
            ``cifar_preprocess``; when False it is only batched.
        shuffle_buffer: size of the shuffle buffer used when ``train`` is
            True. Defaults to 5000, the value that used to be hard-coded.
            A buffer smaller than the dataset shuffles only approximately;
            pass the number of samples for a uniform shuffle.

    Returns:
        A ``tf.data.Dataset`` yielding ``(x, y)`` batches.
    """
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if train:
        ds = ds.shuffle(shuffle_buffer).batch(batch)
        # Augmentation is mapped after batching, so cifar_preprocess receives
        # whole batches rather than single images.
        ds = ds.map(cifar_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    else:
        ds = ds.batch(batch)
    return ds.prefetch(tf.data.AUTOTUNE)

def build_conv_router_net(net_layers, num_classes, input_shape=(32,32,3)):
    """Assemble a functional Keras classifier from an ordered list of layers.

    Args:
        net_layers: iterable of callables (layers); each one is applied to the
            output of the previous one, starting from the model input.
        num_classes: number of units in the final softmax layer.
        input_shape: per-sample input shape handed to ``keras.Input``.

    Returns:
        A ``keras.Model`` named ``"adaptive_model"``.
    """
    model_input = keras.Input(shape=input_shape)

    # Thread the input through the supplied layers in the order given.
    features = model_input
    for net_layer in net_layers:
        features = net_layer(features)

    pooled = tf.keras.layers.GlobalAveragePooling2D()(features)
    # The head is pinned to float32 via dtype=. Note that with
    # activation='softmax' it emits class probabilities, not raw logits.
    head = tf.keras.layers.Dense(num_classes, activation='softmax', dtype='float32')
    return keras.Model(model_input, head(pooled), name="adaptive_model")


# Usage: load CIFAR-10, divide pixel values by 255, and flatten the label arrays.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

y_train = y_train.flatten()
y_test = y_test.flatten()

# Training pipeline shuffles and augments; the validation pipeline only batches.
ds_train = make_dataset(x_train, y_train, batch=128, train=True)
ds_val = make_dataset(x_test, y_test, batch=256, train=False)


2025-10-23 11:11:37.521698: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2025-10-23 11:11:37.521732: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-10-23 11:11:37.521741: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2025-10-23 11:11:37.521930: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-10-23 11:11:37.521946: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [2]:
# --- Experiment configuration ---
EPOCHS = 20
ROUTER_TEMP = 2.0
DIVERSITY_TAU = 1e-2
STEPS = 3
TOP_N = 3

# Layer stack, listed in the order build_conv_router_net applies it.
# NOTE(review): `sr_ration` is the keyword exactly as passed here; it looks
# like a misspelling of "sr_ratio" — confirm against InitConvRouterNet.
layers = [
    ConvStem(32, 5),
    MaxPoolingLayer(64),
    ResidualBlock3x3(64),
    get_multi_step_init_block(64, "01_multi_block", TOP_N, ROUTER_TEMP, DIVERSITY_TAU, steps=STEPS, sr_ration=2),
    MaxPoolingLayer(64),
    get_multi_step_init_block(64, "02_multi_block", TOP_N, ROUTER_TEMP, DIVERSITY_TAU, steps=STEPS, sr_ration=1),
]

# NOTE(review): not referenced by any cell visible here — confirm before removing.
all_layers = []


router_model = build_conv_router_net(layers, num_classes=10)

In [3]:
# Routed blocks that each get a stats callback and a temperature scheduler.
routed_block_names = ("01_multi_block", "02_multi_block")

# Order matters for the printed log: LR schedule, top-N schedule,
# then stats for every block, then temperature schedulers for every block.
callbacks = [
    CosineAnnealingScheduler(base_lr=1e-3, min_lr=1e-5, epochs=EPOCHS),
    TopNScheduler(
        schedule=linear_topn_schedule(n_start=TOP_N, n_end=1, start_epoch=0, end_epoch=int(EPOCHS * 0.6)),
        verbose=1,
    ),
    *[
        RouterStatsMultiStepCallback(x_test, y_test, layer_name=block_name, verbose_every=5)
        for block_name in routed_block_names
    ],
    *[
        TempScheduler(layer_name=block_name, epochs=EPOCHS, route=(ROUTER_TEMP, 0.7), eps=(0, 0.00), ent=(0, 0.0), lb=(0, 0.0), verbose=1, log=False)
        for block_name in routed_block_names
    ],
]

router_model.build(input_shape=(None, 32, 32, 3))
router_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
# Alternative optimizer noted for later: AdamW(lr=3e-3, weight_decay=0.05)

router_model.summary()

Model: "adaptive_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv_stem (ConvStem)        (None, 32, 32, 32)        2496      
                                                                 
 max_pooling_layer (MaxPool  (None, 16, 16, 64)        2240      
 ingLayer)                                                       
                                                                 
 residual_block3x3 (Residua  (None, 16, 16, 64)        82432     
 lBlock3x3)                                                      
                                                                 
 01_multi_block (AdaptiveRo  (None, 16, 16, 64)        170499    
 uterMultiStep)                                                  
                                                    

In [4]:
# Train for EPOCHS epochs with ds_val as validation data and the callbacks defined above the fit call.
router_model.fit(
    ds_train,
    epochs=EPOCHS,
    validation_data=ds_val,
    callbacks=callbacks,
)

> [LR Scheduler] epoch 1: lr=0.001000
Epoch 1/20


2025-10-23 11:11:46.689246: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2025-10-23 11:16:26.003381: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.



> [01_multi_block] epoch 1
  avg_steps=2.00
  step  1: total=10000  E0: 0.0%  E1: 0.0%  E2: 9.5%  E3:56.4%  E4:34.0%  E5: 0.0%  E6: 0.0%  E7: 0.0%  E8: 0.0%
  step  2: total=10000  E0: 0.0%  E1:19.2%  E2:53.9%  E3:18.3%  E4: 8.6%  E5: 0.0%  E6: 0.0%  E7: 0.0%  E8: 0.0%
  step  3: total=10000  E0:14.1%  E1:12.3%  E2:27.4%  E3: 6.7%  E4:32.4%  E5: 0.0%  E6: 0.0%  E7: 6.6%  E8: 0.5%

> [02_multi_block] epoch 1
  avg_steps=2.00
  step  1: total=10000  E0: 0.0%  E1: 4.4%  E2: 0.0%  E3: 0.0%  E4: 0.0%  E5:10.8%  E6: 0.0%  E7: 0.0%  E8:84.7%
  step  2: total=10000  E0: 8.3%  E1: 2.2%  E2: 0.0%  E3: 0.0%  E4: 0.0%  E5: 0.9%  E6: 0.0%  E7: 0.0%  E8:88.5%
  step  3: total=10000  E0: 2.1%  E1: 8.0%  E2: 0.0%  E3: 0.0%  E4: 0.0%  E5: 0.2%  E6: 0.0%  E7: 0.0%  E8:89.7%
Epoch 2/20
Epoch 3/20
Epoch 4/20

KeyboardInterrupt: 

In [5]:
# Inspect the router of one routed block on the first 512 test images.
layer = router_model.get_layer("01_multi_block")  # or your layer name
# Sub-model from the network input up to whatever tensor feeds the chosen block.
pre = tf.keras.Model(inputs=router_model.input, outputs=layer.input)

xb = x_test[:512]
x = pre(xb, training=False)

# 1) Router logits distribution (statistics taken over axis 0)
logits, *rest = layer.router(x, training=False)
logits_mean = tf.reduce_mean(logits, axis=0).numpy().round(3)
logits_std = tf.math.reduce_std(logits, axis=0).numpy().round(3)
print("logits mean per expert:", logits_mean)
print("logits std  per expert:", logits_std)

# 2) Softmax at the block's routing temperature, with no mask applied:
#    does any single expert already dominate?
scaled_logits = logits / layer.route_temp
p = tf.nn.softmax(scaled_logits, axis=-1).numpy()
print("mean prob per expert:", p.mean(axis=0).round(3))

logits mean per expert: [ 0.048 -0.011 -0.134  0.268  0.08   0.274  0.073 -0.143  0.195]
logits std  per expert: [0.069 0.081 0.152 0.108 0.092 0.103 0.058 0.091 0.073]
mean prob per expert: [0.109 0.104 0.096 0.127 0.111 0.128 0.11  0.095 0.121]
