In [3]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar100
from tensorflow.keras.utils import to_categorical
import numpy as np

2025-04-21 14:44:05.600489: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-21 14:44:05.602362: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-21 14:44:05.604524: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-04-21 14:44:05.610122: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745261045.619538 3626090 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745261045.62

In [4]:
(x_train, y_train), (x_test, y_test) = cifar100.load_data()

x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

y_train = to_categorical(y_train, 100)
y_test = to_categorical(y_test, 100)

In [5]:
class DynConv2D(tf.keras.layers.Layer):
    def __init__(self, filters, kernel_size, strides=1, padding='same', groups=4):
        super(DynConv2D, self).__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.groups = groups

        self.convs = [layers.Conv2D(filters, kernel_size, strides=strides,
                                    padding=padding, use_bias=False)
                      for _ in range(groups)]
    
        self.attention = tf.keras.Sequential([
            layers.GlobalAveragePooling2D(),
            layers.Dense(groups, activation='softmax')
        ])

    def call(self, x):
        attn_weights = self.attention(x)  # shape (batch_size, groups)
        outputs = []

        for i, conv in enumerate(self.convs):
            out = conv(x)  # shape (batch_size, H, W, filters)
            scale = tf.reshape(attn_weights[:, i], [-1, 1, 1, 1])  # reshape to broadcast
            out = out * scale
            outputs.append(out)

        return tf.reduce_sum(tf.stack(outputs, axis=0), axis=0)


In [6]:
def resnet_block_dynconv(x, filters, downsample=False):
    stride = 2 if downsample else 1
    y = DynConv2D(filters, 3, strides=stride)(x)
    y = layers.BatchNormalization()(y)
    y = layers.ReLU()(y)
    y = DynConv2D(filters, 3)(y)
    y = layers.BatchNormalization()(y)

    if downsample or x.shape[-1] != filters:
        x = layers.Conv2D(filters, 1, strides=stride, padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
    return layers.ReLU()(x + y)

In [10]:
def build_dynconv_resnet_cifar100():
    inputs = layers.Input(shape=(32, 32, 3))
    x = layers.Conv2D(16, 3, padding='same', use_bias=False)(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    for _ in range(3):
        x = resnet_block_dynconv(x, 16)
    for i in range(3):
        x = resnet_block_dynconv(x, 32, downsample=(i == 0))
    for i in range(3):
        x = resnet_block_dynconv(x, 64, downsample=(i == 0))

    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(100, activation='softmax')(x)

    return models.Model(inputs, outputs)

In [11]:
dynconv_model = build_dynconv_resnet_cifar100()

print("\nDynConv-ResNet Model Summary:")
dynconv_model.summary()

dynconv_model.compile(optimizer='adam',
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])


DynConv-ResNet Model Summary:


In [13]:
print("\nTraining DynConv-ResNet...\n")
dynconv_model.fit(x_train, y_train,
                  epochs=20,
                  batch_size=64,
                  validation_data=(x_test, y_test))


Training DynConv-ResNet...

Epoch 1/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m83s[0m 106ms/step - accuracy: 0.7772 - loss: 0.7323 - val_accuracy: 0.4641 - val_loss: 2.3512
Epoch 2/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m81s[0m 104ms/step - accuracy: 0.7864 - loss: 0.6886 - val_accuracy: 0.4460 - val_loss: 2.6062
Epoch 3/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m82s[0m 105ms/step - accuracy: 0.8005 - loss: 0.6407 - val_accuracy: 0.4532 - val_loss: 2.5412
Epoch 4/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m82s[0m 104ms/step - accuracy: 0.8219 - loss: 0.5727 - val_accuracy: 0.4430 - val_loss: 2.7061
Epoch 5/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m82s[0m 105ms/step - accuracy: 0.8333 - loss: 0.5314 - val_accuracy: 0.4686 - val_loss: 2.7496
Epoch 6/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m82s[0m 105ms/step - accuracy: 0.8454 - loss: 0.4869 - val_accuracy: 0.4631 

<keras.src.callbacks.history.History at 0x154fcc4b1610>