In [None]:
import tensorflow as tf
from tensorflow import keras
from callbacks import (
    CosineAnnealingScheduler, TempScheduler, RouterStatsCallback, TopNScheduler, linear_topn_schedule, milestone_topn_schedule, print_router_stats
)
from InitConvRouterNet import (
    get_stem, get_pooling, get_block, get_init_stem, get_init_block, get_init_pooling, get_block
)
from tensorflow.keras.layers import (
    Multiply, Conv2D, Dropout, LayerNormalization, DepthwiseConv2D, Activation
)
from utils import AdaptiveRouter, ResBlock
import random
import numpy as np
import os


# Single seed shared by Python's `random`, NumPy and TensorFlow so runs are repeatable.
SEED = 42
# Ask TensorFlow / cuDNN for deterministic kernels.
# NOTE(review): these are set after `import tensorflow`; confirm the installed TF
# version still honors them at this point (newer TF also offers
# tf.config.experimental.enable_op_determinism()).
os.environ.update({"TF_DETERMINISTIC_OPS": "1", "TF_CUDNN_DETERMINISTIC": "1"})
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)


# Simple CIFAR-10 augmentation: pad to 36x36, random 32x32 crop, random horizontal flip.
def cifar_preprocess(x, y, pad_to=36, crop_size=32):
    """Random pad-and-crop plus horizontal-flip augmentation.

    Accepts either a single channels-last image (rank 3) or a batch (rank 4).
    `make_dataset` maps this function *after* `.batch()`, so in this notebook it
    normally receives a batch.

    Args:
        x: image tensor of rank 3 (H, W, C) or rank 4 (N, H, W, C).
        y: labels; returned unchanged.
        pad_to: spatial size images are padded/cropped to before the random crop
            (default 36, as in the original hard-coded value).
        crop_size: spatial size of the random crop (default 32).

    Returns:
        Tuple ``(augmented_x, y)``.
    """
    x = tf.image.resize_with_crop_or_pad(x, pad_to, pad_to)
    channels = tf.shape(x)[-1]
    crop_shape = [crop_size, crop_size, channels]
    if x.shape.rank == 4:
        # tf.image.random_crop on a 4-D tensor chooses ONE offset for the whole
        # batch, giving every image the same shift. Crop per image instead.
        x = tf.map_fn(lambda img: tf.image.random_crop(img, crop_shape), x)
    else:
        x = tf.image.random_crop(x, crop_shape)
    # For 4-D input this flips each image independently.
    x = tf.image.random_flip_left_right(x)
    return x, y

def make_dataset(x, y, batch=128, train=True):
    """Wrap in-memory arrays in a batched, prefetched tf.data pipeline.

    Training pipelines are shuffled (buffer of 5000 examples), batched, and then
    augmented batch-wise with `cifar_preprocess`. Evaluation pipelines are only
    batched. Both are prefetched with AUTOTUNE.

    Args:
        x: feature array, sliced along its first axis.
        y: label array, sliced along its first axis.
        batch: batch size.
        train: whether to shuffle and augment.

    Returns:
        A `tf.data.Dataset` yielding ``(x_batch, y_batch)``.
    """
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    if not train:
        return dataset.batch(batch).prefetch(tf.data.AUTOTUNE)

    dataset = dataset.shuffle(5000)
    dataset = dataset.batch(batch)
    dataset = dataset.map(cifar_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.prefetch(tf.data.AUTOTUNE)

def build_conv_router_net(net_layers, num_classes, input_shape=(32,32,3)):
    """Stack `net_layers` on a Keras input and attach a pooled softmax classifier.

    Args:
        net_layers: iterable of callables (e.g. Keras layers), applied in order.
        num_classes: number of output units of the final Dense layer.
        input_shape: per-example input shape, without the batch dimension.

    Returns:
        A functional `keras.Model` named "adaptive_model".
    """
    inputs = keras.Input(shape=input_shape)

    features = inputs
    for layer in net_layers:
        features = layer(features)

    pooled = tf.keras.layers.GlobalAveragePooling2D()(features)
    # Softmax probabilities (not logits). dtype='float32' keeps this output in
    # fp32 even if a mixed-precision policy is active elsewhere.
    classifier = tf.keras.layers.Dense(num_classes, activation='softmax', dtype='float32')
    outputs = classifier(pooled)
    return keras.Model(inputs, outputs, name="adaptive_model")


# Usage: load CIFAR-10, scale pixel values by 1/255 to float32, and flatten the
# label arrays to 1-D before building the train/validation pipelines.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

y_train = y_train.flatten()
y_test = y_test.flatten()

ds_train = make_dataset(x_train, y_train, batch=128, train=True)
ds_val = make_dataset(x_test, y_test, batch=256, train=False)


2025-10-22 12:13:34.495692: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2025-10-22 12:13:34.495722: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-10-22 12:13:34.495734: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2025-10-22 12:13:34.495984: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-10-22 12:13:34.496007: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [None]:
def build_block(in_filters, out_filters):
    """Build the candidate branch layers for one router block (work in progress).

    Args:
        in_filters: input channel count. NOTE(review): currently unused.
        out_filters: number of filters for each branch.

    Returns:
        List of Keras layers, one per branch.
    """
    branches = [
        # 1x1 pointwise projection. Keras' keyword is `strides`; the original
        # `stride=` is not a Conv2D argument and made layer construction fail.
        Conv2D(out_filters, kernel_size=(1, 1), strides=(1, 1)),
    ]
    # TODO(review): the original ended with a bare `get_block()` call (no
    # arguments, result discarded) that was never reached because of the
    # `stride=` error. get_block's signature is defined in InitConvRouterNet and
    # isn't visible here; wire `branches` into it once its contract is confirmed.
    return branches


def build_conv_router_net(net_layers, num_classes, input_shape=(32,32,3)):
    """Stack `net_layers` on a Keras input and attach a pooled softmax classifier.

    This redefinition shadows the one in the first cell. It previously skipped
    applying `net_layers`, so the model was just pooling + Dense over raw pixels;
    the layers are now applied in order, matching the first definition.

    Args:
        net_layers: iterable of callables (e.g. Keras layers), applied in order.
        num_classes: number of output units of the final Dense layer.
        input_shape: per-example input shape, without the batch dimension.

    Returns:
        A functional `keras.Model` named "adaptive_model".
    """
    inputs = keras.Input(shape=input_shape)
    x = inputs

    for layer in net_layers:
        x = layer(x)

    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    # Softmax probabilities (not logits); dtype='float32' keeps the output in fp32
    # even under a mixed-precision policy.
    outputs = tf.keras.layers.Dense(num_classes, activation='softmax', dtype='float32')(x)
    return keras.Model(inputs, outputs, name="adaptive_model")



