In [1]:
# ConvNeXt Small (CIFAR-10) — full model + optimizer
import os, math, tensorflow as tf, tensorflow_addons as tfa
from tensorflow.keras import layers as L, models as M


# ----------------------------
# Optional determinism / seeds
# ----------------------------
SEED = 42
tf.random.set_seed(SEED)
os.environ["TF_DETERMINISTIC_OPS"] = "1"
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"

# Data pipeline (CIFAR-10) — replace with your own input pipeline if needed
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
num_classes = 10
y_train = tf.squeeze(tf.one_hot(y_train, num_classes), axis=1)
y_test  = tf.squeeze(tf.one_hot(y_test,  num_classes), axis=1)

def preprocess(images, labels):
    images = tf.cast(images, tf.float32) / 255.0
    return images, labels

batch_size = 128

ds_test = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
            .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(batch_size)
            .prefetch(tf.data.AUTOTUNE))

def augment(image, label):
    # pad by 4 and random crop back to 32x32
    image = tf.image.resize_with_crop_or_pad(image, 36, 36)
    image = tf.image.random_crop(image, size=(32, 32, 3), seed=SEED)
    image = tf.image.random_flip_left_right(image, seed=SEED)
    return image, label

ds_train = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
            .shuffle(50000, seed=SEED)
            .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
            .map(augment, num_parallel_calls=tf.data.AUTOTUNE)  # per-example
            .batch(batch_size)
            .prefetch(tf.data.AUTOTUNE))



steps_per_epoch = math.ceil(len(x_train) / batch_size)


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 

2025-10-22 09:02:42.374587: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2025-10-22 09:02:42.374622: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-10-22 09:02:42.374629: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2025-10-22 09:02:42.374892: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-10-22 09:02:42.374909: I tensorflow/core/common_runti

In [2]:
# ----------------------------
# Utilities
# ----------------------------
class DropPath(L.Layer):
    """Stochastic Depth per sample."""
    def __init__(self, drop_prob=0.0, **kwargs):
        super().__init__(**kwargs)
        self.drop_prob = float(drop_prob)

    def call(self, x, training=None):
        if (not training) or self.drop_prob == 0.0:
            return x
        keep = 1.0 - self.drop_prob
        # shape: [B, 1, 1, 1] to broadcast over HWC
        shape = (tf.shape(x)[0],) + (1, 1, 1)
        rnd = tf.random.uniform(shape, dtype=x.dtype)
        mask = tf.cast(rnd < keep, x.dtype)
        return x * mask / keep

    def get_config(self):
        cfg = super().get_config()
        cfg.update({"drop_prob": self.drop_prob})
        return cfg


class ConvNeXtBlock(L.Layer):
    """ConvNeXt block (channels-last). DWConv-7x7 → LN → 1x1 MLP (4x expand) → residual (+ DropPath)."""
    def __init__(self, dim, drop_path=0.0, mlp_ratio=4, se_ratio=0.0, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.drop_path_rate = drop_path
        self.mlp_ratio = mlp_ratio
        self.se_ratio = se_ratio

        self.dw = L.DepthwiseConv2D(
            kernel_size=7, padding="same", use_bias=True, name="dw7"
        )
        self.ln = L.LayerNormalization(epsilon=1e-6, name="ln")  # channels-last
        self.pw1 = L.Dense(int(mlp_ratio * dim), name="pw1")    # 1x1 via Dense on last dim
        self.act = L.Activation("gelu")
        self.pw2 = L.Dense(dim, name="pw2")
        self.drop_path = DropPath(drop_prob=drop_path)

        if se_ratio and se_ratio > 0.0:
            mid = max(1, int(dim * se_ratio))
            self.se = M.Sequential([
                L.GlobalAveragePooling2D(keepdims=True),
                L.Conv2D(mid, 1, activation="gelu", use_bias=True),
                L.Conv2D(dim, 1, activation="sigmoid", use_bias=True)
            ], name="se")
        else:
            self.se = None

    def call(self, x, training=None):
        shortcut = x
        y = self.dw(x)
        # NHWC → keep NHWC, LayerNorm with axis=-1 works channels-last
        y = self.ln(y)
        y = self.pw2(self.act(self.pw1(y)))
        if self.se is not None:
            y = y * self.se(y)  # lightweight SE (optional)
        y = self.drop_path(y, training=training)
        return shortcut + y

    def get_config(self):
        cfg = super().get_config()
        cfg.update(dict(dim=self.dim, drop_path=self.drop_path_rate,
                        mlp_ratio=self.mlp_ratio, se_ratio=self.se_ratio))
        return cfg


def Downsample(in_ch, out_ch, name):
    # Conv 2x2 s=2 to shrink spatial size
    return M.Sequential([
        L.Conv2D(out_ch, kernel_size=2, strides=2, padding="valid", use_bias=True)
    ], name=name)


# ----------------------------
# Model builder
# ----------------------------
def build_convnext_small_cifar10(
    num_classes=10,
    depths=(3, 6, 6),
    dims=(96, 192, 384),
    drop_path_rate=0.1,
    se_ratio=0.25  # set 0.0 to disable SE
):
    """
    ConvNeXt Small tailored for CIFAR-10.
    depths: blocks per stage
    dims:   channels per stage
    drop_path_rate: max stochastic depth across blocks (linearly scaled)
    """
    assert len(depths) == 3 and len(dims) == 3

    inputs = L.Input(shape=(32, 32, 3))

    # Stem: 4x4 conv stride 4 → 8x8 tokens, C = dims[0]
    x = L.Conv2D(dims[0], kernel_size=4, strides=4, padding="valid", use_bias=True, name="stem")(inputs)
    x = L.LayerNormalization(epsilon=1e-6, name="stem_ln")(x)

    # Compute per-block drop_path schedule (linear from 0 → drop_path_rate)
    total_blocks = sum(depths)
    dpr = [i * drop_path_rate / max(1, total_blocks - 1) for i in range(total_blocks)]

    idx = 0
    # Stage 1 (8x8)
    for b in range(depths[0]):
        x = ConvNeXtBlock(dims[0], drop_path=dpr[idx], se_ratio=se_ratio, name=f"stage1_block{b}")(x)
        idx += 1

    # Downsample to 4x4
    x = Downsample(dims[0], dims[1], name="down1")(x)
    x = L.LayerNormalization(epsilon=1e-6, name="down1_ln")(x)

    # Stage 2 (4x4)
    for b in range(depths[1]):
        x = ConvNeXtBlock(dims[1], drop_path=dpr[idx], se_ratio=se_ratio, name=f"stage2_block{b}")(x)
        idx += 1

    # Downsample to 2x2
    x = Downsample(dims[1], dims[2], name="down2")(x)
    x = L.LayerNormalization(epsilon=1e-6, name="down2_ln")(x)

    # Stage 3 (2x2)
    for b in range(depths[2]):
        x = ConvNeXtBlock(dims[2], drop_path=dpr[idx], se_ratio=se_ratio, name=f"stage3_block{b}")(x)
        idx += 1

    # Head
    x = L.LayerNormalization(epsilon=1e-6, name="head_ln")(x)
    x = L.GlobalAveragePooling2D(name="gap")(x)
    outputs = L.Dense(num_classes, name="logits")(x)

    return M.Model(inputs, outputs, name="ConvNeXtSmall_CIFAR10")



import math, tensorflow as tf

# --- Serializable warmup+cosine schedule ---
class WarmupCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, base_lr, warmup_steps, total_steps, name=None):
        super().__init__()
        self._base_lr = float(base_lr)
        self._warmup_steps = int(warmup_steps)
        self._total_steps = int(total_steps)
        self._name = name or "WarmupCosine"
    def __call__(self, step):
        import math as _m
        step = tf.cast(step, tf.float32)
        base = tf.cast(self._base_lr, tf.float32)
        w = tf.cast(self._warmup_steps, tf.float32)
        T = tf.cast(self._total_steps, tf.float32)
        warm = (step / tf.maximum(1.0, w)) * base
        prog = (step - w) / tf.maximum(1.0, T - w)
        prog = tf.clip_by_value(prog, 0.0, 1.0)
        cos = 0.5 * (1.0 + tf.cos(tf.constant(_m.pi, tf.float32) * prog))
        decayed = base * cos
        return tf.where(step < w, warm, decayed, name=self._name)
    def get_config(self):
        return {"base_lr": self._base_lr, "warmup_steps": self._warmup_steps,
                "total_steps": self._total_steps, "name": self._name}
    @classmethod
    def from_config(cls, cfg): return cls(**cfg)

def make_optimizer(steps_per_epoch, epochs, base_lr=3e-3, weight_decay=0.05, warmup_epochs=10):
    total_steps = int(steps_per_epoch * epochs)
    warmup_steps = int(steps_per_epoch * warmup_epochs)
    lr_sched = WarmupCosine(base_lr, warmup_steps, total_steps)
    # Prefer native; on M1/M2 this will auto-fallback to legacy.AdamW (fast) if needed
    try:
        opt = tfa.optimizers.AdamW(
            learning_rate=lr_sched, weight_decay=weight_decay,
            beta_1=0.9, beta_2=0.999, epsilon=1e-8
        )
    except Exception:
        opt = tf.keras.optimizers.legacy.AdamW(
            learning_rate=lr_sched, weight_decay=weight_decay,
            beta_1=0.9, beta_2=0.999, epsilon=1e-8
        )
    return opt



In [3]:
import math, tensorflow as tf

# --- LR schedule (serializable) ---
class WarmupCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, base_lr, warmup_steps, total_steps, name=None):
        super().__init__()
        self._base_lr = float(base_lr)
        self._warmup_steps = int(warmup_steps)
        self._total_steps = int(total_steps)
        self._name = name or "WarmupCosine"
    def __call__(self, step):
        import math as _m
        step = tf.cast(step, tf.float32)
        base = tf.cast(self._base_lr, tf.float32)
        w = tf.cast(self._warmup_steps, tf.float32)
        T = tf.cast(self._total_steps, tf.float32)
        warm = (step / tf.maximum(1.0, w)) * base
        prog = (step - w) / tf.maximum(1.0, T - w)
        prog = tf.clip_by_value(prog, 0.0, 1.0)
        cos = 0.5 * (1.0 + tf.cos(tf.constant(_m.pi, tf.float32) * prog))
        decayed = base * cos
        return tf.where(step < w, warm, decayed, name=self._name)
    def get_config(self):
        return {"base_lr": self._base_lr, "warmup_steps": self._warmup_steps,
                "total_steps": self._total_steps, "name": self._name}
    @classmethod
    def from_config(cls, cfg): return cls(**cfg)



In [4]:

# Set steps_per_epoch explicitly (avoid Dataset materialization)
batch_size = 128
steps_per_epoch = math.ceil(50000 / batch_size)  # CIFAR-10 train size
epochs = 150
base_lr = 3e-3
weight_decay=1e-3

total_steps  = int(steps_per_epoch * epochs)
warmup_steps = int(steps_per_epoch * 5)
lr_sched = WarmupCosine(base_lr, warmup_steps, total_steps)
# Faster on M1/M2 when TF falls back
opt = tfa.optimizers.AdamW(
    learning_rate=lr_sched, weight_decay=weight_decay,
    beta_1=0.9, beta_2=0.999, epsilon=1e-8
)

model = build_convnext_small_cifar10(
    num_classes=10,
    depths=(3, 6, 6),      # total 15 blocks
    dims=(96, 192, 384),   # widths
    drop_path_rate=0.1,    # linear 0..0.1 across blocks
    se_ratio=0.25          # set 0.0 to disable SE
)

# Recompile with the OBJECT, not a string
model.compile(
    optimizer=opt,
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
    metrics=[tf.keras.metrics.CategoricalAccuracy(name="acc")]
)

model.summary()

Model: "ConvNeXtSmall_CIFAR10"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 stem (Conv2D)               (None, 8, 8, 96)          4704      
                                                                 
 stem_ln (LayerNormalizatio  (None, 8, 8, 96)          192       
 n)                                                              
                                                                 
 stage1_block0 (ConvNeXtBlo  (None, 8, 8, 96)          83928     
 ck)                                                             
                                                                 
 stage1_block1 (ConvNeXtBlo  (None, 8, 8, 96)          83928     
 ck)                                                             
                                             

In [None]:
history = model.fit(ds_train, validation_data=ds_test, epochs=epochs)

Epoch 1/150


2025-10-22 09:02:47.861843: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2025-10-22 09:04:41.344615: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Epoch 2/150
Epoch 3/150
Epoch 4/150
Epoch 5/150
Epoch 6/150
Epoch 7/150
Epoch 8/150
Epoch 9/150
Epoch 10/150
Epoch 11/150
Epoch 12/150
Epoch 13/150
Epoch 14/150
Epoch 15/150
Epoch 16/150
Epoch 17/150
Epoch 18/150
Epoch 19/150
Epoch 20/150
Epoch 21/150
Epoch 22/150
Epoch 23/150
Epoch 24/150
Epoch 25/150
Epoch 26/150
  1/391 [..............................] - ETA: 2:44 - loss: 2.3024 - acc: 0.0781