<a href="https://colab.research.google.com/github/sayakpaul/Consistency-Training-with-Supervision/blob/main/CIFAR_10C_Evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [None]:
# All model weights
!wget https://git.io/JOKI9 -O consistency_training_model_weights.zip

In [None]:
from tensorflow.keras import layers
import tensorflow as tf

import tensorflow_datasets as tfds
tfds.disable_progress_bar()

from tqdm import tqdm
import numpy as np

## Define Hyperparameters

In [None]:
# tf.data autotuning sentinel; passed to `map(num_parallel_calls=...)` and
# `prefetch(...)` in `prepare_dataset`.
AUTO = tf.data.AUTOTUNE
# TFDS dataset name; `evaluate_model` appends "/<version>" to select a config.
DATASET_NAME = "cifar10_corrupted"
# Number of examples per batch produced by `prepare_dataset`.
BATCH_SIZE = 128
# Side length images are resized to, and the height/width of the model input.
IMAGE_SIZE = 72

In [None]:
# CIFAR10-C configs to evaluate on. Each entry is appended to `DATASET_NAME`
# (as "<dataset>/<version>") when the data is loaded in `evaluate_model`.
# The order here is the order in which the corruptions are evaluated.
VERSIONS = [
    "brightness_5", "contrast_5", "defocus_blur_5", "elastic_5",
    "fog_5", "frost_5", "frosted_glass_blur_5", "gaussian_blur_5",
    "gaussian_noise_5", "impulse_noise_5", "jpeg_compression_5", "motion_blur_5",
    "pixelate_5", "saturate_5", "shot_noise_5", "snow_5",
    "spatter_5", "speckle_noise_5", "zoom_blur_5",
]

print(f"Total sub-versions of the CIFAR10-C dataset: {len(VERSIONS)}")

Total sub-versions of the CIFAR10-C dataset: 19


## Utilities

In [None]:
def prepare_dataset(ds):
    """Batch a dataset of `(image, label)` pairs and resize the images.

    Elements are grouped into batches of `BATCH_SIZE`, every image batch is
    resized to `IMAGE_SIZE` x `IMAGE_SIZE` with `tf.image.resize` (labels are
    passed through unchanged), and the result is prefetched.

    Args:
        ds: Dataset yielding `(image, label)` tuples, e.g. the result of
            `tfds.load(..., as_supervised=True)`.

    Returns:
        The batched, resized and prefetched dataset.
    """
    target_size = (IMAGE_SIZE, IMAGE_SIZE)

    def resize_images(images, labels):
        # Runs after batching, so a single call resizes a whole batch.
        return tf.image.resize(images, target_size), labels

    batched = ds.batch(BATCH_SIZE)
    resized = batched.map(resize_images, num_parallel_calls=AUTO)
    return resized.prefetch(AUTO)

In [None]:
def get_training_model(num_classes=10):
    """Build the ResNet50V2-based classifier used throughout this notebook.

    The network takes `IMAGE_SIZE` x `IMAGE_SIZE` 3-channel images, rescales
    them as `x / 127.5 - 1`, runs a ResNet50V2 backbone created without
    pretrained weights and without its classification top, applies global
    average pooling and ends in a `Dense(num_classes)` layer that has no
    activation, i.e. the model outputs raw logits.

    Args:
        num_classes: Number of units in the final dense layer.

    Returns:
        An uncompiled `tf.keras.Sequential` model.
    """
    # `layers.experimental.preprocessing` is absent from newer Keras releases,
    # where the layer is exposed as `layers.Rescaling`. Prefer the new name and
    # fall back to the experimental path on older TensorFlow versions.
    rescaling_cls = getattr(layers, "Rescaling", None)
    if rescaling_cls is None:
        rescaling_cls = layers.experimental.preprocessing.Rescaling

    resnet50_v2 = tf.keras.applications.ResNet50V2(
        weights=None,
        include_top=False,
        input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
    )
    model = tf.keras.Sequential(
        [
            layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)),
            rescaling_cls(scale=1.0 / 127.5, offset=-1),
            resnet50_v2,
            layers.GlobalAveragePooling2D(),
            layers.Dense(num_classes)
        ]
    )
    return model

In [None]:
def evaluate_model(model):
    """Evaluate `model` on the test split of every config in `VERSIONS`.

    For each version the dataset `DATASET_NAME/<version>` is loaded through
    TFDS, run through `prepare_dataset`, and passed to `model.evaluate`.

    Args:
        model: A compiled Keras model whose `evaluate` returns exactly two
            values, the second of which is the accuracy as a fraction.

    Returns:
        A tuple `(per_version, mean)`: a dict mapping each version name to its
        accuracy in percent, and the mean of those percentages.
    """
    per_version = {}
    for version in tqdm(VERSIONS):
        print(f"Processing {version}")
        corrupted_ds = tfds.load(
            "/".join((DATASET_NAME, version)),
            split="test",
            as_supervised=True,
        )
        # The first returned value (the loss) is not needed here.
        _, acc = model.evaluate(prepare_dataset(corrupted_ds), verbose=0)
        print(f"Test accuracy on {version}: {acc*100}%")
        per_version[version] = acc*100

    mean_accuracy = np.mean(list(per_version.values()))
    return per_version, mean_accuracy

## Evaluation

### SWA

In [None]:
# Evaluate teacher model trained with SWA
teacher_model_swa = get_training_model()
teacher_model_swa.load_weights("teacher_model_swa.h5")
# The model ends in a `Dense` layer without an activation, so it emits logits.
# The string alias "sparse_categorical_crossentropy" would treat those outputs
# as probabilities and report a wrong loss, hence the explicit `from_logits=True`.
teacher_model_swa.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
acc_dict, mean_top_1 = evaluate_model(teacher_model_swa)
print(f"Mean Top-1 Accuracy: {mean_top_1}%")

  0%|          | 0/19 [00:00<?, ?it/s]

Processing brightness_5


  5%|▌         | 1/19 [00:17<05:20, 17.80s/it]

Test accuracy on brightness_5: 76.05999708175659%
Processing contrast_5


 11%|█         | 2/19 [00:19<02:22,  8.40s/it]

Test accuracy on contrast_5: 24.34999942779541%
Processing defocus_blur_5


 16%|█▌        | 3/19 [00:21<01:26,  5.42s/it]

Test accuracy on defocus_blur_5: 71.59000039100647%
Processing elastic_5


 21%|██        | 4/19 [00:23<01:00,  4.02s/it]

Test accuracy on elastic_5: 74.19999837875366%
Processing fog_5


 26%|██▋       | 5/19 [00:25<00:45,  3.25s/it]

Test accuracy on fog_5: 48.28999936580658%
Processing frost_5


 32%|███▏      | 6/19 [00:27<00:36,  2.79s/it]

Test accuracy on frost_5: 62.26999759674072%
Processing frosted_glass_blur_5


 37%|███▋      | 7/19 [00:28<00:29,  2.48s/it]

Test accuracy on frosted_glass_blur_5: 58.980000019073486%
Processing gaussian_blur_5


 42%|████▏     | 8/19 [00:30<00:25,  2.28s/it]

Test accuracy on gaussian_blur_5: 66.60000085830688%
Processing gaussian_noise_5


 47%|████▋     | 9/19 [00:32<00:21,  2.15s/it]

Test accuracy on gaussian_noise_5: 41.909998655319214%
Processing impulse_noise_5


 53%|█████▎    | 10/19 [00:34<00:18,  2.06s/it]

Test accuracy on impulse_noise_5: 23.970000445842743%
Processing jpeg_compression_5


 58%|█████▊    | 11/19 [00:36<00:15,  2.00s/it]

Test accuracy on jpeg_compression_5: 79.1599988937378%
Processing motion_blur_5


 63%|██████▎   | 12/19 [00:38<00:13,  1.97s/it]

Test accuracy on motion_blur_5: 65.77000021934509%
Processing pixelate_5


 68%|██████▊   | 13/19 [00:40<00:11,  1.93s/it]

Test accuracy on pixelate_5: 73.51999878883362%
Processing saturate_5


 74%|███████▎  | 14/19 [00:42<00:09,  1.91s/it]

Test accuracy on saturate_5: 69.6399986743927%
Processing shot_noise_5


 79%|███████▉  | 15/19 [00:43<00:07,  1.89s/it]

Test accuracy on shot_noise_5: 45.21999955177307%
Processing snow_5


 84%|████████▍ | 16/19 [00:45<00:05,  1.89s/it]

Test accuracy on snow_5: 65.39999842643738%
Processing spatter_5


 89%|████████▉ | 17/19 [00:47<00:03,  1.87s/it]

Test accuracy on spatter_5: 64.16000127792358%
Processing speckle_noise_5


 95%|█████████▍| 18/19 [00:49<00:01,  1.86s/it]

Test accuracy on speckle_noise_5: 45.62999904155731%
Processing zoom_blur_5


100%|██████████| 19/19 [00:51<00:00,  2.70s/it]

Test accuracy on zoom_blur_5: 77.49000191688538%
Mean Top-1 Accuracy: 59.695262579541456%





In [None]:
# Evaluate the corresponding student model
student_noisy_swa = get_training_model()
student_noisy_swa.load_weights("student_noisy_swa.h5")
# The model ends in a `Dense` layer without an activation, so it emits logits.
# The string alias "sparse_categorical_crossentropy" would treat those outputs
# as probabilities and report a wrong loss, hence the explicit `from_logits=True`.
student_noisy_swa.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
acc_dict, mean_top_1 = evaluate_model(student_noisy_swa)
print(f"Mean Top-1 Accuracy: {mean_top_1}%")

  0%|          | 0/19 [00:00<?, ?it/s]

Processing brightness_5


  5%|▌         | 1/19 [00:02<00:51,  2.85s/it]

Test accuracy on brightness_5: 82.91000127792358%
Processing contrast_5


 11%|█         | 2/19 [00:04<00:37,  2.22s/it]

Test accuracy on contrast_5: 29.730001091957092%
Processing defocus_blur_5


 16%|█▌        | 3/19 [00:06<00:32,  2.01s/it]

Test accuracy on defocus_blur_5: 73.00000190734863%
Processing elastic_5


 21%|██        | 4/19 [00:08<00:28,  1.92s/it]

Test accuracy on elastic_5: 74.29999709129333%
Processing fog_5


 26%|██▋       | 5/19 [00:09<00:26,  1.86s/it]

Test accuracy on fog_5: 50.26000142097473%
Processing frost_5


 32%|███▏      | 6/19 [00:11<00:23,  1.83s/it]

Test accuracy on frost_5: 58.92000198364258%
Processing frosted_glass_blur_5


 37%|███▋      | 7/19 [00:13<00:21,  1.81s/it]

Test accuracy on frosted_glass_blur_5: 59.78999733924866%
Processing gaussian_blur_5


 42%|████▏     | 8/19 [00:15<00:19,  1.80s/it]

Test accuracy on gaussian_blur_5: 67.93000102043152%
Processing gaussian_noise_5


 47%|████▋     | 9/19 [00:17<00:17,  1.79s/it]

Test accuracy on gaussian_noise_5: 46.25999927520752%
Processing impulse_noise_5


 53%|█████▎    | 10/19 [00:18<00:16,  1.78s/it]

Test accuracy on impulse_noise_5: 30.98999857902527%
Processing jpeg_compression_5


 58%|█████▊    | 11/19 [00:20<00:14,  1.78s/it]

Test accuracy on jpeg_compression_5: 76.39999985694885%
Processing motion_blur_5


 63%|██████▎   | 12/19 [00:22<00:12,  1.77s/it]

Test accuracy on motion_blur_5: 66.72000288963318%
Processing pixelate_5


 68%|██████▊   | 13/19 [00:24<00:10,  1.77s/it]

Test accuracy on pixelate_5: 73.94999861717224%
Processing saturate_5


 74%|███████▎  | 14/19 [00:25<00:08,  1.77s/it]

Test accuracy on saturate_5: 81.05000257492065%
Processing shot_noise_5


 79%|███████▉  | 15/19 [00:27<00:07,  1.77s/it]

Test accuracy on shot_noise_5: 51.88000202178955%
Processing snow_5


 84%|████████▍ | 16/19 [00:29<00:05,  1.76s/it]

Test accuracy on snow_5: 65.82000255584717%
Processing spatter_5


 89%|████████▉ | 17/19 [00:31<00:03,  1.76s/it]

Test accuracy on spatter_5: 70.31999826431274%
Processing speckle_noise_5


 95%|█████████▍| 18/19 [00:32<00:01,  1.76s/it]

Test accuracy on speckle_noise_5: 53.46999764442444%
Processing zoom_blur_5


100%|██████████| 19/19 [00:34<00:00,  1.82s/it]

Test accuracy on zoom_blur_5: 79.67000007629395%
Mean Top-1 Accuracy: 62.80894765728399%





### MA

In [None]:
# Evaluate teacher model trained with MA
teacher_model_ma = get_training_model()
teacher_model_ma.load_weights("teacher_model_ma.h5")
# The model ends in a `Dense` layer without an activation, so it emits logits.
# The string alias "sparse_categorical_crossentropy" would treat those outputs
# as probabilities and report a wrong loss, hence the explicit `from_logits=True`.
teacher_model_ma.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
acc_dict, mean_top_1 = evaluate_model(teacher_model_ma)
print(f"Mean Top-1 Accuracy: {mean_top_1}%")

  0%|          | 0/19 [00:00<?, ?it/s]

Processing brightness_5


  5%|▌         | 1/19 [00:02<00:50,  2.83s/it]

Test accuracy on brightness_5: 73.14000129699707%
Processing contrast_5


 11%|█         | 2/19 [00:04<00:37,  2.21s/it]

Test accuracy on contrast_5: 19.679999351501465%
Processing defocus_blur_5


 16%|█▌        | 3/19 [00:06<00:32,  2.01s/it]

Test accuracy on defocus_blur_5: 71.5499997138977%
Processing elastic_5


 21%|██        | 4/19 [00:08<00:28,  1.91s/it]

Test accuracy on elastic_5: 74.76999759674072%
Processing fog_5


 26%|██▋       | 5/19 [00:09<00:25,  1.85s/it]

Test accuracy on fog_5: 47.96999990940094%
Processing frost_5


 32%|███▏      | 6/19 [00:11<00:23,  1.81s/it]

Test accuracy on frost_5: 61.29999756813049%
Processing frosted_glass_blur_5


 37%|███▋      | 7/19 [00:13<00:21,  1.79s/it]

Test accuracy on frosted_glass_blur_5: 61.41999959945679%
Processing gaussian_blur_5


 42%|████▏     | 8/19 [00:15<00:19,  1.78s/it]

Test accuracy on gaussian_blur_5: 66.03000164031982%
Processing gaussian_noise_5


 47%|████▋     | 9/19 [00:16<00:17,  1.77s/it]

Test accuracy on gaussian_noise_5: 45.899999141693115%
Processing impulse_noise_5


 53%|█████▎    | 10/19 [00:18<00:15,  1.76s/it]

Test accuracy on impulse_noise_5: 30.320000648498535%
Processing jpeg_compression_5


 58%|█████▊    | 11/19 [00:20<00:14,  1.76s/it]

Test accuracy on jpeg_compression_5: 78.4600019454956%
Processing motion_blur_5


 63%|██████▎   | 12/19 [00:22<00:12,  1.76s/it]

Test accuracy on motion_blur_5: 64.19000029563904%
Processing pixelate_5


 68%|██████▊   | 13/19 [00:23<00:10,  1.75s/it]

Test accuracy on pixelate_5: 72.26999998092651%
Processing saturate_5


 74%|███████▎  | 14/19 [00:25<00:08,  1.75s/it]

Test accuracy on saturate_5: 66.04999899864197%
Processing shot_noise_5


 79%|███████▉  | 15/19 [00:27<00:07,  1.75s/it]

Test accuracy on shot_noise_5: 49.36999976634979%
Processing snow_5


 84%|████████▍ | 16/19 [00:29<00:05,  1.75s/it]

Test accuracy on snow_5: 63.60999941825867%
Processing spatter_5


 89%|████████▉ | 17/19 [00:30<00:03,  1.75s/it]

Test accuracy on spatter_5: 63.8700008392334%
Processing speckle_noise_5


 95%|█████████▍| 18/19 [00:32<00:01,  1.75s/it]

Test accuracy on speckle_noise_5: 49.75000023841858%
Processing zoom_blur_5


100%|██████████| 19/19 [00:34<00:00,  1.81s/it]

Test accuracy on zoom_blur_5: 78.32000255584717%
Mean Top-1 Accuracy: 59.89315792133934%





In [None]:
# Evaluate the corresponding student model
student_noisy_ma = get_training_model()
student_noisy_ma.load_weights("student_noisy_ma.h5")
# The model ends in a `Dense` layer without an activation, so it emits logits.
# The string alias "sparse_categorical_crossentropy" would treat those outputs
# as probabilities and report a wrong loss, hence the explicit `from_logits=True`.
student_noisy_ma.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
acc_dict, mean_top_1 = evaluate_model(student_noisy_ma)
print(f"Mean Top-1 Accuracy: {mean_top_1}%")

  0%|          | 0/19 [00:00<?, ?it/s]

Processing brightness_5


  5%|▌         | 1/19 [00:02<00:51,  2.84s/it]

Test accuracy on brightness_5: 81.80999755859375%
Processing contrast_5


 11%|█         | 2/19 [00:04<00:37,  2.20s/it]

Test accuracy on contrast_5: 28.65999937057495%
Processing defocus_blur_5


 16%|█▌        | 3/19 [00:06<00:31,  1.99s/it]

Test accuracy on defocus_blur_5: 70.93999981880188%
Processing elastic_5


 21%|██        | 4/19 [00:08<00:28,  1.90s/it]

Test accuracy on elastic_5: 72.43000268936157%
Processing fog_5


 26%|██▋       | 5/19 [00:09<00:25,  1.85s/it]

Test accuracy on fog_5: 49.43999946117401%
Processing frost_5


 32%|███▏      | 6/19 [00:11<00:23,  1.81s/it]

Test accuracy on frost_5: 60.009998083114624%
Processing frosted_glass_blur_5


 37%|███▋      | 7/19 [00:13<00:21,  1.79s/it]

Test accuracy on frosted_glass_blur_5: 58.139997720718384%
Processing gaussian_blur_5


 42%|████▏     | 8/19 [00:15<00:19,  1.78s/it]

Test accuracy on gaussian_blur_5: 65.90999960899353%
Processing gaussian_noise_5


 47%|████▋     | 9/19 [00:16<00:17,  1.77s/it]

Test accuracy on gaussian_noise_5: 47.00999855995178%
Processing impulse_noise_5


 53%|█████▎    | 10/19 [00:18<00:15,  1.76s/it]

Test accuracy on impulse_noise_5: 29.399999976158142%
Processing jpeg_compression_5


 58%|█████▊    | 11/19 [00:20<00:14,  1.76s/it]

Test accuracy on jpeg_compression_5: 76.10999941825867%
Processing motion_blur_5


 63%|██████▎   | 12/19 [00:22<00:12,  1.76s/it]

Test accuracy on motion_blur_5: 65.93000292778015%
Processing pixelate_5


 68%|██████▊   | 13/19 [00:23<00:10,  1.75s/it]

Test accuracy on pixelate_5: 71.3699996471405%
Processing saturate_5


 74%|███████▎  | 14/19 [00:25<00:08,  1.75s/it]

Test accuracy on saturate_5: 80.47000169754028%
Processing shot_noise_5


 79%|███████▉  | 15/19 [00:27<00:06,  1.75s/it]

Test accuracy on shot_noise_5: 51.42999887466431%
Processing snow_5


 84%|████████▍ | 16/19 [00:29<00:05,  1.74s/it]

Test accuracy on snow_5: 66.38000011444092%
Processing spatter_5


 89%|████████▉ | 17/19 [00:30<00:03,  1.75s/it]

Test accuracy on spatter_5: 70.24000287055969%
Processing speckle_noise_5


 95%|█████████▍| 18/19 [00:32<00:01,  1.75s/it]

Test accuracy on speckle_noise_5: 52.27000117301941%
Processing zoom_blur_5


100%|██████████| 19/19 [00:34<00:00,  1.81s/it]

Test accuracy on zoom_blur_5: 77.46000289916992%
Mean Top-1 Accuracy: 61.863684340527186%





### Regular

In [None]:
# Evaluate teacher model
teacher_model = get_training_model()
teacher_model.load_weights("teacher_model.h5")
# The model ends in a `Dense` layer without an activation, so it emits logits.
# The string alias "sparse_categorical_crossentropy" would treat those outputs
# as probabilities and report a wrong loss, hence the explicit `from_logits=True`.
teacher_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
acc_dict, mean_top_1 = evaluate_model(teacher_model)
print(f"Mean Top-1 Accuracy: {mean_top_1}%")

  0%|          | 0/19 [00:00<?, ?it/s]

Processing brightness_5


  5%|▌         | 1/19 [00:03<00:55,  3.08s/it]

Test accuracy on brightness_5: 75.52000284194946%
Processing contrast_5


 11%|█         | 2/19 [00:04<00:38,  2.29s/it]

Test accuracy on contrast_5: 24.3599995970726%
Processing defocus_blur_5


 16%|█▌        | 3/19 [00:06<00:32,  2.04s/it]

Test accuracy on defocus_blur_5: 71.95000052452087%
Processing elastic_5


 21%|██        | 4/19 [00:08<00:28,  1.93s/it]

Test accuracy on elastic_5: 74.55999851226807%
Processing fog_5


 26%|██▋       | 5/19 [00:10<00:26,  1.86s/it]

Test accuracy on fog_5: 45.71999907493591%
Processing frost_5


 32%|███▏      | 6/19 [00:11<00:23,  1.82s/it]

Test accuracy on frost_5: 62.48999834060669%
Processing frosted_glass_blur_5


 37%|███▋      | 7/19 [00:13<00:21,  1.80s/it]

Test accuracy on frosted_glass_blur_5: 63.33000063896179%
Processing gaussian_blur_5


 42%|████▏     | 8/19 [00:15<00:19,  1.78s/it]

Test accuracy on gaussian_blur_5: 66.90000295639038%
Processing gaussian_noise_5


 47%|████▋     | 9/19 [00:17<00:17,  1.78s/it]

Test accuracy on gaussian_noise_5: 44.62999999523163%
Processing impulse_noise_5


 53%|█████▎    | 10/19 [00:18<00:15,  1.77s/it]

Test accuracy on impulse_noise_5: 27.889999747276306%
Processing jpeg_compression_5


 58%|█████▊    | 11/19 [00:20<00:14,  1.76s/it]

Test accuracy on jpeg_compression_5: 78.46999764442444%
Processing motion_blur_5


 63%|██████▎   | 12/19 [00:22<00:12,  1.76s/it]

Test accuracy on motion_blur_5: 66.21999740600586%
Processing pixelate_5


 68%|██████▊   | 13/19 [00:24<00:10,  1.75s/it]

Test accuracy on pixelate_5: 75.98999738693237%
Processing saturate_5


 74%|███████▎  | 14/19 [00:25<00:08,  1.75s/it]

Test accuracy on saturate_5: 67.47999787330627%
Processing shot_noise_5


 79%|███████▉  | 15/19 [00:27<00:06,  1.75s/it]

Test accuracy on shot_noise_5: 48.44000041484833%
Processing snow_5


 84%|████████▍ | 16/19 [00:29<00:05,  1.75s/it]

Test accuracy on snow_5: 65.85999727249146%
Processing spatter_5


 89%|████████▉ | 17/19 [00:31<00:03,  1.74s/it]

Test accuracy on spatter_5: 63.81999850273132%
Processing speckle_noise_5


 95%|█████████▍| 18/19 [00:32<00:01,  1.74s/it]

Test accuracy on speckle_noise_5: 49.059998989105225%
Processing zoom_blur_5


100%|██████████| 19/19 [00:34<00:00,  1.82s/it]

Test accuracy on zoom_blur_5: 77.02000141143799%
Mean Top-1 Accuracy: 60.51105205949984%





In [None]:
# Evaluate the corresponding student model
student_noisy = get_training_model()
student_noisy.load_weights("student_noisy.h5")
# The model ends in a `Dense` layer without an activation, so it emits logits.
# The string alias "sparse_categorical_crossentropy" would treat those outputs
# as probabilities and report a wrong loss, hence the explicit `from_logits=True`.
student_noisy.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
acc_dict, mean_top_1 = evaluate_model(student_noisy)
print(f"Mean Top-1 Accuracy: {mean_top_1}%")

  0%|          | 0/19 [00:00<?, ?it/s]

Processing brightness_5


  5%|▌         | 1/19 [00:02<00:52,  2.92s/it]

Test accuracy on brightness_5: 79.80999946594238%
Processing contrast_5


 11%|█         | 2/19 [00:04<00:37,  2.23s/it]

Test accuracy on contrast_5: 20.76999992132187%
Processing defocus_blur_5


 16%|█▌        | 3/19 [00:06<00:32,  2.01s/it]

Test accuracy on defocus_blur_5: 68.98000240325928%
Processing elastic_5


 21%|██        | 4/19 [00:08<00:28,  1.90s/it]

Test accuracy on elastic_5: 71.53000235557556%
Processing fog_5


 26%|██▋       | 5/19 [00:09<00:25,  1.84s/it]

Test accuracy on fog_5: 46.140000224113464%
Processing frost_5


 32%|███▏      | 6/19 [00:11<00:23,  1.81s/it]

Test accuracy on frost_5: 57.56999850273132%
Processing frosted_glass_blur_5


 37%|███▋      | 7/19 [00:13<00:21,  1.79s/it]

Test accuracy on frosted_glass_blur_5: 57.03999996185303%
Processing gaussian_blur_5


 42%|████▏     | 8/19 [00:15<00:19,  1.78s/it]

Test accuracy on gaussian_blur_5: 63.429999351501465%
Processing gaussian_noise_5


 47%|████▋     | 9/19 [00:16<00:17,  1.77s/it]

Test accuracy on gaussian_noise_5: 44.06000077724457%
Processing impulse_noise_5


 53%|█████▎    | 10/19 [00:18<00:15,  1.76s/it]

Test accuracy on impulse_noise_5: 24.68000054359436%
Processing jpeg_compression_5


 58%|█████▊    | 11/19 [00:20<00:14,  1.76s/it]

Test accuracy on jpeg_compression_5: 74.55000281333923%
Processing motion_blur_5


 63%|██████▎   | 12/19 [00:22<00:12,  1.75s/it]

Test accuracy on motion_blur_5: 62.459999322891235%
Processing pixelate_5


 68%|██████▊   | 13/19 [00:23<00:10,  1.75s/it]

Test accuracy on pixelate_5: 70.16000151634216%
Processing saturate_5


 74%|███████▎  | 14/19 [00:25<00:08,  1.75s/it]

Test accuracy on saturate_5: 77.31000185012817%
Processing shot_noise_5


 79%|███████▉  | 15/19 [00:27<00:07,  1.75s/it]

Test accuracy on shot_noise_5: 48.60000014305115%
Processing snow_5


 84%|████████▍ | 16/19 [00:29<00:05,  1.75s/it]

Test accuracy on snow_5: 63.2099986076355%
Processing spatter_5


 89%|████████▉ | 17/19 [00:30<00:03,  1.75s/it]

Test accuracy on spatter_5: 65.49000144004822%
Processing speckle_noise_5


 95%|█████████▍| 18/19 [00:32<00:01,  1.75s/it]

Test accuracy on speckle_noise_5: 48.96999895572662%
Processing zoom_blur_5


100%|██████████| 19/19 [00:34<00:00,  1.81s/it]

Test accuracy on zoom_blur_5: 75.85999965667725%
Mean Top-1 Accuracy: 58.98000041120931%





## Evaluate on CIFAR-10 Test Set

In [None]:
# Only the test split of CIFAR-10 is used; the training split is discarded.
_, (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# Run the arrays through the same batching/resizing pipeline as CIFAR10-C.
test_ds = prepare_dataset(
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
)

In [None]:
# Clean CIFAR-10 test accuracy: SWA teacher first, then its student.
for report, clean_eval_model in (
    ("Test accuracy with SWA teacher: {:.2f}%", teacher_model_swa),
    ("Test accuracy with noisy SWA student: {:.2f}%", student_noisy_swa),
):
    # `evaluate` yields (loss, accuracy); only the accuracy is reported.
    _, test_acc = clean_eval_model.evaluate(test_ds, verbose=0)
    print(report.format(test_acc * 100))

Test accuracy with SWA teacher: 84.82%
Test accuracy with noisy SWA student: 85.24%


In [None]:
# Clean CIFAR-10 test accuracy: MA teacher first, then its student.
for report, clean_eval_model in (
    ("Test accuracy with MA teacher: {:.2f}%", teacher_model_ma),
    ("Test accuracy with noisy MA student: {:.2f}%", student_noisy_ma),
):
    # `evaluate` yields (loss, accuracy); only the accuracy is reported.
    _, test_acc = clean_eval_model.evaluate(test_ds, verbose=0)
    print(report.format(test_acc * 100))

Test accuracy with MA teacher: 83.88%
Test accuracy with noisy MA student: 84.42%


In [None]:
# Clean CIFAR-10 test accuracy: regular teacher first, then its student.
for report, clean_eval_model in (
    ("Test accuracy with regular teacher: {:.2f}%", teacher_model),
    ("Test accuracy with regular noisy student: {:.2f}%", student_noisy),
):
    # `evaluate` yields (loss, accuracy); only the accuracy is reported.
    _, test_acc = clean_eval_model.evaluate(test_ds, verbose=0)
    print(report.format(test_acc * 100))

Test accuracy with regular teacher: 83.20%
Test accuracy with regular noisy student: 82.16%
