In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
import numpy as np

In [2]:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

In [3]:
class DynConv2D(tf.keras.layers.Layer):
    def __init__(self, filters, kernel_size, strides=1, padding='same', groups=4):
        super(DynConv2D, self).__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.groups = groups

        self.convs = [layers.Conv2D(filters, kernel_size, strides=strides,
                                    padding=padding, use_bias=False)
                      for _ in range(groups)]

        self.attention = tf.keras.Sequential([
            layers.GlobalAveragePooling2D(),
            layers.Dense(groups, activation='softmax')
        ])

    def call(self, x):
        attn_weights = self.attention(x)  # shape (batch_size, groups)
        outputs = []

        for i, conv in enumerate(self.convs):
            out = conv(x)  # shape (batch_size, H, W, filters)
            scale = tf.reshape(attn_weights[:, i], [-1, 1, 1, 1])  # reshape to broadcast
            out = out * scale
            outputs.append(out)

        return tf.reduce_sum(tf.stack(outputs, axis=0), axis=0)


In [4]:
def resnet_block_dynconv(x, filters, downsample=False):
    stride = 2 if downsample else 1
    y = DynConv2D(filters, 3, strides=stride)(x)
    y = layers.BatchNormalization()(y)
    y = layers.ReLU()(y)
    y = DynConv2D(filters, 3)(y)
    y = layers.BatchNormalization()(y)

    if downsample or x.shape[-1] != filters:
        x = layers.Conv2D(filters, 1, strides=stride, padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
    return layers.ReLU()(x + y)

In [5]:
def build_dynconv_resnet_cifar10():
    inputs = layers.Input(shape=(32, 32, 3))
    x = layers.Conv2D(16, 3, padding='same', use_bias=False)(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    for _ in range(3):
        x = resnet_block_dynconv(x, 16)
    for i in range(3):
        x = resnet_block_dynconv(x, 32, downsample=(i == 0))
    for i in range(3):
        x = resnet_block_dynconv(x, 64, downsample=(i == 0))

    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(10, activation='softmax')(x)

    return models.Model(inputs, outputs)

In [6]:
dynconv_model = build_dynconv_resnet_cifar10()

print("\nDynConv-ResNet Model Summary:")
dynconv_model.summary()

dynconv_model.compile(optimizer='adam',
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])

I0000 00:00:1745259081.219588 1888999 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22280 MB memory:  -> device: 0, name: NVIDIA A30, pci bus id: 0000:4a:00.0, compute capability: 8.0



DynConv-ResNet Model Summary:


In [7]:
print("\nTraining DynConv-ResNet...\n")
dynconv_model.fit(x_train, y_train,
                  epochs=20,
                  batch_size=64,
                  validation_data=(x_test, y_test))


Training DynConv-ResNet...

Epoch 1/20


I0000 00:00:1745259094.631250 1889158 service.cc:148] XLA service 0x155284003d90 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1745259094.631271 1889158 service.cc:156]   StreamExecutor device (0): NVIDIA A30, Compute Capability 8.0
2025-04-21 14:11:35.009675: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1745259096.752552 1889158 cuda_dnn.cc:529] Loaded cuDNN version 90300
2025-04-21 14:11:37.272598: W external/local_xla/xla/service/gpu/nvptx_compiler.cc:930] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version 12.5.82. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


[1m  6/782[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m18s[0m 24ms/step - accuracy: 0.0834 - loss: 2.6963 

I0000 00:00:1745259104.754122 1889158 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.3270 - loss: 1.8322




[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 31ms/step - accuracy: 0.3271 - loss: 1.8319 - val_accuracy: 0.4774 - val_loss: 1.4159
Epoch 2/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 17ms/step - accuracy: 0.5837 - loss: 1.1662 - val_accuracy: 0.5749 - val_loss: 1.1991
Epoch 3/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 17ms/step - accuracy: 0.6878 - loss: 0.8849 - val_accuracy: 0.5789 - val_loss: 1.3102
Epoch 4/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 18ms/step - accuracy: 0.7552 - loss: 0.6950 - val_accuracy: 0.7377 - val_loss: 0.7544
Epoch 5/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 18ms/step - accuracy: 0.7898 - loss: 0.5980 - val_accuracy: 0.7031 - val_loss: 0.8921
Epoch 6/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 18ms/step - accuracy: 0.8206 - loss: 0.5217 - val_accuracy: 0.7529 - val_loss: 0.7220
Epoch 7/20
[1m782/782[0m 

<keras.src.callbacks.history.History at 0x15541bf913d0>

In [11]:
test_loss, test_acc = dynconv_model.evaluate(x_test, y_test, verbose=0)
print(f"\n Test Accuracy on CIFAR-10: {test_acc:.4f}")


 Test Accuracy on CIFAR-10: 0.8034
