In [35]:
import tensorflow as tf
from keras import layers, models, optimizers, datasets
import numpy as np

In [2]:
from tensorflow.keras.applications import VGG16
from tensorflow.keras.utils import to_categorical

# Load CIFAR-10 (downloaded on first use) as (images, labels) pairs for the train and test splits.
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 0us/step


In [3]:
# Scale pixel values from [0, 255] to [0.0, 1.0] as float32.
# The integer-dtype guard keeps this cell idempotent: without it, re-running
# the cell on the already-scaled float arrays would divide by 255 a second time.
if np.issubdtype(train_images.dtype, np.integer):
    train_images = train_images.astype('float32') / 255.0
if np.issubdtype(test_images.dtype, np.integer):
    test_images = test_images.astype('float32') / 255.0

In [4]:
# One-hot encode the class ids into 10 columns.
# The shape guard keeps this cell idempotent: labels whose last axis already
# has 10 entries are assumed to be one-hot and are left alone, so re-running
# the cell does not encode the one-hot matrix a second time.
if train_labels.shape[-1] != 10:
    train_labels = to_categorical(train_labels, 10)
if test_labels.shape[-1] != 10:
    test_labels = to_categorical(test_labels, 10)

In [5]:
# Randomly initialised VGG16 feature extractor (no classifier head) sized for 32x32 RGB input.
# This instance is only inspected via summary(); create_vgg16() builds its own copy.
base_model = VGG16(weights=None, include_top=False, input_shape=(32, 32, 3))

In [6]:
# Print the layer list so the truncation index used in create_vgg16() can be checked.
base_model.summary()

In [7]:
def create_vgg16():
    """Build an untrained CIFAR-10 classifier on a truncated VGG16 backbone.

    A randomly initialised VGG16 (no top, 32x32x3 input) is cut at its
    fifth-from-last layer. Two Conv2D -> BatchNormalization -> Dropout(0.3)
    stages (256 then 128 filters) follow, then a dense head of
    512 -> 256 -> 10 units with Dropout(0.2) after each hidden layer.

    Returns:
        An uncompiled `models.Sequential` whose final layer is a 10-way softmax.
    """
    vgg = VGG16(weights=None, include_top=False, input_shape=(32, 32, 3))
    # Keep everything up to and including the fifth-from-last VGG16 layer.
    backbone = models.Model(inputs=vgg.input, outputs=vgg.layers[-5].output)

    model = models.Sequential()
    model.add(backbone)

    # Extra convolutional stages stacked on the truncated backbone.
    for filters in (256, 128):
        model.add(layers.Conv2D(filters, (3, 3), activation='relu', padding='same'))
        model.add(layers.BatchNormalization())
        model.add(layers.Dropout(0.3))

    # Fully connected classifier head.
    model.add(layers.Flatten())
    for units in (512, 256):
        model.add(layers.Dense(units, activation='relu'))
        model.add(layers.Dropout(0.2))
    model.add(layers.Dense(10, activation='softmax'))
    return model

In [8]:
# Ensemble configuration: each sub-model is built by create_vgg16() with its
# own random initialisation.
# Number of sub-models to train (and later to reload from disk).
num_models = 4  #@param
# Collects the trained sub-models; the training loop appends to this list.
sub_models = []

In [9]:
# Train `num_models` independent sub-models on CIFAR-10.
epochs = 20
batch_size = 256

# Reset BOTH collectors here. Previously only `histories` was reset, so
# re-running this cell appended a second set of models to the `sub_models`
# list left over from the previous run.
sub_models = []
histories = []

for i in range(num_models):
    # Each call builds a freshly (randomly) initialised network.
    model = create_vgg16()
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    # validation_split=0.2 makes fit() hold out 20% of the training arrays for validation.
    history = model.fit(train_images, train_labels, epochs=epochs, batch_size=batch_size, validation_split=0.2)
    sub_models.append(model)
    histories.append(history)
    print(f"model {i+1} training completed.")


Epoch 1/20
[1m  3/157[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m6s[0m 44ms/step - accuracy: 0.1068 - loss: 2.7242    

I0000 00:00:1720173308.700737      82 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 134ms/step - accuracy: 0.1761 - loss: 2.1673 - val_accuracy: 0.1481 - val_loss: 2.9113
Epoch 2/20
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 43ms/step - accuracy: 0.2498 - loss: 1.8751 - val_accuracy: 0.1342 - val_loss: 6.9849
Epoch 3/20
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 43ms/step - accuracy: 0.2925 - loss: 1.7912 - val_accuracy: 0.2288 - val_loss: 2.5647
Epoch 4/20
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 43ms/step - accuracy: 0.3684 - loss: 1.6095 - val_accuracy: 0.3168 - val_loss: 2.1455
Epoch 5/20
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 43ms/step - accuracy: 0.4241 - loss: 1.4903 - val_accuracy: 0.4496 - val_loss: 1.5232
Epoch 6/20
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 43ms/step - accuracy: 0.5078 - loss: 1.3054 - val_accuracy: 0.5218 - val_loss: 1.3177
Epoch 7/20
[1m157/157[0m [3

In [10]:
# Persist every trained sub-model to its own HDF5 file, indexed by position in
# `sub_models`; the reload cell reads the same file names back.
for i,model in enumerate(sub_models):
    model.save(f"vgg16_cifar10_sub_model_{i}.h5")

In [43]:
# Rebuild `sub_models` from the saved HDF5 files so the rest of the notebook
# can run without repeating the sub-model training.
sub_models = []

for i in range(num_models):
    # Must match the file name pattern used when the sub-models were saved.
    model_path = f"vgg16_cifar10_sub_model_{i}.h5"
    model = models.load_model(model_path)
    sub_models.append(model)
    print(f"Loaded model {model_path}")

print("All models loaded successfully.")

Loaded model vgg16_cifar10_sub_model_0.h5
Loaded model vgg16_cifar10_sub_model_1.h5
Loaded model vgg16_cifar10_sub_model_2.h5
Loaded model vgg16_cifar10_sub_model_3.h5
All models loaded successfully.


In [44]:
# Baseline: evaluate each sub-model on the test set on its own.
# evaluate() reports loss and accuracy through its progress bar; the returned
# values are not stored.
for i, model in enumerate(sub_models):
    print(f"vgg16_cifar10_sub_model_{i}.h5")
    model.evaluate(test_images,test_labels) 
    print()

vgg16_cifar10_sub_model_0.h5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 6ms/step - accuracy: 0.7327 - loss: 1.2754

vgg16_cifar10_sub_model_1.h5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 6ms/step - accuracy: 0.6914 - loss: 1.1798

vgg16_cifar10_sub_model_2.h5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 6ms/step - accuracy: 0.7165 - loss: 1.1699

vgg16_cifar10_sub_model_3.h5
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 6ms/step - accuracy: 0.7371 - loss: 1.2406



## Block Switching model

In [45]:
class RandomSwitch(layers.Layer):
    """Combine several parallel branch outputs into one tensor.

    During training a single branch is picked uniformly at random per call
    (one scalar draw, so the whole batch goes through the same branch). When
    not training, the branch outputs are averaged element-wise.

    Args:
        num_choices: Number of branches to draw from; used as the exclusive
            upper bound of the random index, so it should equal the number of
            tensors passed to `call`.
        **kwargs: Forwarded to `layers.Layer`.
    """

    def __init__(self, num_choices, **kwargs):
        super().__init__(**kwargs)
        self.num_choices = num_choices

    def call(self, inputs, training=None):
        """Select (training) or average (inference) the branch outputs.

        Args:
            inputs: List of tensors with identical shapes, one per branch.
            training: Truthy to sample one branch; falsy or None to average.

        Returns:
            A tensor with the shape of a single branch output.
        """
        # Stack along a new leading axis so branches can be indexed or reduced.
        stacked = tf.stack(inputs, axis=0)
        if training:
            choice = tf.random.uniform(shape=[], minval=0, maxval=self.num_choices, dtype=tf.int32)
            return stacked[choice]
        return tf.reduce_mean(stacked, axis=0)

    def get_config(self):
        """Return the layer config including `num_choices`.

        Without this, a saved model cannot re-create the layer, because
        `num_choices` is a required constructor argument. Loading still needs
        the class to be made known to Keras (e.g. through `custom_objects`).
        """
        config = super().get_config()
        config.update({"num_choices": self.num_choices})
        return config

def create_block_switching_model(sub_models, input_shape):
    """Wire the sub-models in parallel behind a RandomSwitch and add a new head.

    Args:
        sub_models: Sequence of Keras models that all accept `input_shape`
            inputs and produce outputs of identical shape.
        input_shape: Shape of one input sample, without the batch dimension.

    Returns:
        An uncompiled functional `models.Model` mapping the shared input to a
        10-way softmax.
    """
    input_layer = layers.Input(shape=input_shape, name='input_layer')

    # Every sub-model sees the same input tensor.
    branch_outputs = []
    for sub_model in sub_models:
        branch_outputs.append(sub_model(input_layer))

    merged = RandomSwitch(len(sub_models))(branch_outputs)

    # Shared classifier head on top of the switched output.
    features = layers.Flatten()(merged)
    for units in (512, 256):
        features = layers.Dense(units, activation='relu')(features)
        features = layers.Dropout(0.5)(features)
    output_layer = layers.Dense(10, activation='softmax')(features)

    return models.Model(inputs=input_layer, outputs=output_layer)

# Build the combined model from the sub-models reloaded from disk above.
input_shape = (32, 32, 3)  # same per-sample shape the sub-models were built with in create_vgg16()
block_switching_model = create_block_switching_model(sub_models, input_shape)

In [46]:
# Print the layers and parameter counts of the combined model.
block_switching_model.summary()

In [47]:
# Objects for the custom training loop. The model is not compiled with
# Model.compile(); the optimizer, loss and metrics are used directly instead.
optimizer = tf.keras.optimizers.Adam()
# The model ends in a softmax and labels are one-hot, hence categorical cross-entropy.
loss_fn = tf.keras.losses.CategoricalCrossentropy()
# Separate accumulators for training and validation accuracy.
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
val_acc_metric = tf.keras.metrics.CategoricalAccuracy()

@tf.function
def train_step(x, y):
    """Run one optimisation step of `block_switching_model` on a batch.

    Args:
        x: Batch of input images.
        y: Matching batch of one-hot labels.

    Returns:
        The loss computed by `loss_fn` for this batch.

    Side effects:
        Updates the model's trainable weights through `optimizer` and
        accumulates the batch into `train_acc_metric`.
    """
    with tf.GradientTape() as tape:
        # The model ends in a softmax, so these are class probabilities.
        predictions = block_switching_model(x, training=True)
        batch_loss = loss_fn(y, predictions)
    weights = block_switching_model.trainable_weights
    gradients = tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))
    train_acc_metric.update_state(y, predictions)
    return batch_loss

@tf.function
def test_step(x, y):
    """Accumulate one evaluation batch into `val_acc_metric`.

    Args:
        x: Batch of input images.
        y: Matching batch of one-hot labels.

    The model is called with training=False; nothing is returned.
    """
    val_acc_metric.update_state(y, block_switching_model(x, training=False))

In [48]:
from tqdm import tqdm

# Custom training loop for the block-switching model.
#
# Changes from the earlier version of this cell:
#   * it calls the tf.function-compiled train_step()/test_step() and reads
#     train_acc_metric/val_acc_metric, instead of re-implementing the same
#     logic eagerly and leaving those definitions unused;
#   * training samples are reshuffled every epoch, so the model no longer sees
#     identical batches in identical order each epoch.
#
# NOTE(review): "validation" here is measured on the test split, so these
# numbers should not be used for model selection or early stopping.
epochs = 10
batch_size = 256

num_train_samples = len(train_images)
num_val_samples = len(test_images)

for epoch in range(epochs):
    print(f"\nStart of epoch {epoch+1}")

    # The metrics accumulate across calls; start every epoch from zero.
    train_acc_metric.reset_state()
    val_acc_metric.reset_state()

    # New sample order for this epoch.
    epoch_order = np.random.permutation(num_train_samples)

    # Training pass with tqdm progress bar
    train_loss = []
    with tqdm(total=num_train_samples) as pbar:
        for step in range(0, num_train_samples, batch_size):
            batch_indices = epoch_order[step:step+batch_size]
            x_batch = train_images[batch_indices]
            y_batch = train_labels[batch_indices]

            # train_step applies the gradients and updates train_acc_metric.
            loss_value = float(train_step(x_batch, y_batch))
            train_loss.append(loss_value)

            pbar.update(len(batch_indices))
            pbar.set_description(f"Training loss: {loss_value:.4f}")

    avg_train_loss = float(np.mean(train_loss))
    avg_train_acc = float(train_acc_metric.result())
    print(f"Training loss over epoch: {float(avg_train_loss):.4f}, acc: {float(avg_train_acc):.4f}")

    # Validation pass with tqdm progress bar
    with tqdm(total=num_val_samples) as pbar:
        for step in range(0, num_val_samples, batch_size):
            x_batch_val = test_images[step:step+batch_size]
            y_batch_val = test_labels[step:step+batch_size]

            # test_step runs the model in inference mode and updates val_acc_metric.
            test_step(x_batch_val, y_batch_val)

            pbar.update(len(x_batch_val))
            pbar.set_description(f"Validation acc: {float(val_acc_metric.result()):.4f}")

    avg_val_acc = float(val_acc_metric.result())
    print(f"Validation acc: {float(avg_val_acc):.4f}")



Start of epoch 1


Training loss: 1.1070: 100%|██████████| 50000/50000 [02:52<00:00, 290.23it/s]


Training loss over epoch: 1.2739, acc: 0.6279


Validation acc: 0.7161: 100%|██████████| 10000/10000 [00:04<00:00, 2359.99it/s]


Validation acc: 0.7161

Start of epoch 2


Training loss: 1.2425: 100%|██████████| 50000/50000 [02:51<00:00, 291.30it/s]


Training loss over epoch: 0.8985, acc: 0.7138


Validation acc: 0.7387: 100%|██████████| 10000/10000 [00:04<00:00, 2324.68it/s]


Validation acc: 0.7387

Start of epoch 3


Training loss: 0.9848: 100%|██████████| 50000/50000 [02:51<00:00, 290.88it/s]


Training loss over epoch: 0.8226, acc: 0.7344


Validation acc: 0.7501: 100%|██████████| 10000/10000 [00:04<00:00, 2308.08it/s]


Validation acc: 0.7501

Start of epoch 4


Training loss: 0.9006: 100%|██████████| 50000/50000 [02:53<00:00, 288.05it/s]


Training loss over epoch: 0.7539, acc: 0.7568


Validation acc: 0.7555: 100%|██████████| 10000/10000 [00:04<00:00, 2319.99it/s]


Validation acc: 0.7555

Start of epoch 5


Training loss: 0.9252: 100%|██████████| 50000/50000 [02:52<00:00, 290.59it/s]


Training loss over epoch: 0.6803, acc: 0.7803


Validation acc: 0.7404: 100%|██████████| 10000/10000 [00:04<00:00, 2309.32it/s]


Validation acc: 0.7404

Start of epoch 6


Training loss: 0.7742: 100%|██████████| 50000/50000 [02:53<00:00, 288.73it/s]


Training loss over epoch: 0.6900, acc: 0.7779


Validation acc: 0.7699: 100%|██████████| 10000/10000 [00:04<00:00, 2293.59it/s]


Validation acc: 0.7699

Start of epoch 7


Training loss: 0.6264: 100%|██████████| 50000/50000 [02:52<00:00, 289.77it/s]


Training loss over epoch: 0.6577, acc: 0.7879


Validation acc: 0.7721: 100%|██████████| 10000/10000 [00:04<00:00, 2350.36it/s]


Validation acc: 0.7721

Start of epoch 8


Training loss: 0.7854: 100%|██████████| 50000/50000 [02:51<00:00, 290.99it/s]


Training loss over epoch: 0.6357, acc: 0.7952


Validation acc: 0.7725: 100%|██████████| 10000/10000 [00:04<00:00, 2320.34it/s]


Validation acc: 0.7725

Start of epoch 9


Training loss: 0.7456: 100%|██████████| 50000/50000 [02:52<00:00, 290.48it/s]


Training loss over epoch: 0.6064, acc: 0.8050


Validation acc: 0.7805: 100%|██████████| 10000/10000 [00:04<00:00, 2323.85it/s]


Validation acc: 0.7805

Start of epoch 10


Training loss: 0.5537: 100%|██████████| 50000/50000 [02:52<00:00, 290.55it/s]


Training loss over epoch: 0.6255, acc: 0.8008


Validation acc: 0.7704: 100%|██████████| 10000/10000 [00:04<00:00, 2333.65it/s]

Validation acc: 0.7704





In [53]:
# Final evaluation on test set.
#
# The test set is processed in batches. The earlier version ran one forward
# pass per image (10,000 model calls) and compared a label of shape (10,)
# with a prediction of shape (1, 10), which only worked through broadcasting.
# The model is called with training=False throughout.
test_acc_metric = tf.keras.metrics.CategoricalAccuracy()

# Progress bookkeeping
total_samples = len(test_images)
evaluated_samples = 0
eval_batch_size = 256  # local constant so this cell does not depend on the training cell

# Evaluation loop
for start in range(0, total_samples, eval_batch_size):
    x_batch_test = test_images[start:start+eval_batch_size]
    y_batch_test = test_labels[start:start+eval_batch_size]

    test_logits = block_switching_model(x_batch_test, training=False)
    # Labels and predictions now share the same (batch, 10) shape.
    test_acc_metric.update_state(y_batch_test, test_logits)

    evaluated_samples += len(x_batch_test)
    print(f"Evaluated {evaluated_samples}/{total_samples} samples.", end='\r')

# Compute and print the final test accuracy
test_acc = test_acc_metric.result()
print(f"\nTest acc: {float(test_acc):.4f}")


Evaluated 10000/10000 samples.
Test acc: 0.7704


In [50]:
# Save the trained block-switching model to a single HDF5 file.
# NOTE(review): the model contains the custom RandomSwitch layer, so reloading
# it will presumably require custom_objects={'RandomSwitch': RandomSwitch} — confirm.
block_switching_model.save('cifar10_block_switching_model.h5')
print("Model saved successfully.")

Model saved successfully.
