In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# --- Dataset geometry (CIFAR-10) ---
NUM_CLASSES = 10  # passed as `out_dim` when the classifiers are built
WIDTH = 32  # images are reshaped to WIDTH x WIDTH, so this is both height and width
NUM_CHANNELS = 3  # last axis of the reshaped image arrays
# --- Expected split sizes ---
NUM_TRAIN = 50000
NUM_TEST = 10000
NUM_DEV = 100  # same value as the default `num_dev` of load_cifar10_train_dev

DATASET = "cifar10"  # first component of every checkpoint filename
# Checkpoint paths are built as f"{BASE_DIR}/...", so leaving this empty makes
# them resolve to the filesystem root; set it before training.
BASE_DIR = ""  # set to whatever directory you are working/saving in
BATCH_SIZE = 128  # batch size passed to clf.fit; a larger value trains faster but needs more memory (lower it if the kernel crashes)

## Data loading

In [3]:
def _normalize(X):
  assert X.dtype == np.uint8
  X = X.astype(np.float64)
  X /= 255
  return X

def get_one_hot(targets, nb_classes):
  """One-hot encode integer class labels.

  Args:
    targets: integer labels in [0, nb_classes); an ndarray of any shape or a
      plain (nested) list/tuple/scalar.
    nb_classes: length of the one-hot axis.

  Returns:
    float64 ndarray of shape ``targets.shape + (nb_classes,)``.
  """
  # Convert once so list/tuple inputs work too: the previous version indexed
  # with np.array(targets) but then read `.shape` from the raw argument.
  targets = np.asarray(targets)
  one_hot = np.eye(nb_classes)[targets.reshape(-1)]
  return one_hot.reshape(targets.shape + (nb_classes,))

def load_standard_cifar10():
  """Load CIFAR-10 through Keras and return normalized images with int32 labels.

  Both splits returned by ``tf.keras.datasets.cifar10.load_data()`` get the same
  treatment: images are reshaped to (N, WIDTH, WIDTH, NUM_CHANNELS) and scaled
  to float64 in [0, 1] by ``_normalize``; labels are cast to int32.

  Returns:
    (X_train, Y_train, X_validation, Y_validation), where the "validation"
    pair is the second split returned by ``load_data()``.
  """
  train_split, validation_split = tf.keras.datasets.cifar10.load_data()
  image_shape = (WIDTH, WIDTH, NUM_CHANNELS)

  images, labels = [], []
  for raw_images, raw_labels in (train_split, validation_split):
    shaped = raw_images.reshape(raw_images.shape[0], *image_shape)
    images.append(_normalize(shaped))
    labels.append(raw_labels.astype(np.int32))

  return images[0], labels[0], images[1], labels[1]

def load_cifar10_train_dev(num_dev=100):
  """Return the CIFAR-10 training data plus a small, fixed dev subset.

  The dev examples are picked from the second ("validation") split returned by
  ``load_standard_cifar10`` using a frozen list of 100 indices, so every run
  sees the same dev set.

  Args:
    num_dev: how many of the frozen indices to use (the first ``num_dev``).
      The default of 100 uses all of them.

  Returns:
    (X_train, Y_train, X_dev, Y_dev)

  Raises:
    ValueError: if ``num_dev`` is negative or larger than the frozen list.
  """
  # Indices were drawn once and then hard-coded; the original sampling call was
  # np.random.randint(0, X_validation.shape[0], num_dev). This follows a
  # TracIn-like setup, whose published indices exist only for MNIST.
  selected_dev = [5214, 2304, 5947, 9428, 2717, 8296, 7736, 8291, 5235, 54,
                  7499, 9590, 3675, 1932, 6646, 8719, 6484, 6306, 3066, 2442,
                  6106, 1949, 4320,  541, 1318, 5967, 2773, 3847, 1152, 9937,
                  7469, 5982, 7644, 5820, 8152, 9518,  601, 3953, 4931, 1924,
                  5342, 5467, 6718, 6779, 2860, 2440, 5480, 1178,  222, 7909,
                  6394, 3511, 8729, 6261, 7192, 9453, 5257, 9077, 6419, 3280,
                  3725, 3601, 8174, 5703, 4954, 9536, 4783, 2234, 7365, 2405,
                  3073, 2780, 7461, 3525, 7573, 6764, 9962, 7527,  992,  315,
                  6260, 9061,  592, 8003, 7594, 1930, 7215, 5124, 7531, 9471,
                  2824, 3533, 6062, 3946, 5246, 4440,  414, 3572, 4899, 884]

  # Validate before the (slow) dataset load. Previously num_dev was accepted
  # but silently ignored.
  if not 0 <= num_dev <= len(selected_dev):
    raise ValueError(
        f"num_dev must be between 0 and {len(selected_dev)}, got {num_dev}")
  selected_dev = selected_dev[:num_dev]

  X_train, Y_train, X_validation, Y_validation = load_standard_cifar10()
  X_dev = X_validation[selected_dev]
  Y_dev = Y_validation[selected_dev]
  return X_train, Y_train, X_dev, Y_dev

In [4]:
# Full train/test arrays used for fitting and for validation during training.
X_train, Y_train, X_test, Y_test = load_standard_cifar10()
# Keep only the dev pair. Slicing (rather than unpacking into `_`) avoids
# shadowing IPython's `_` last-output variable, and NUM_DEV from the config
# cell makes the dev-set size explicit.
X_dev, Y_dev = load_cifar10_train_dev(num_dev=NUM_DEV)[2:]

In [5]:
# Drop singleton axes from the label arrays so they are flat vectors of class
# ids, as fed to the sparse categorical loss below. Re-running the cell is
# harmless: squeezing an already-squeezed array changes nothing.
Y_train, Y_test, Y_dev = (
    np.array(labels).squeeze() for labels in (Y_train, Y_test, Y_dev)
)

In [6]:
# Sanity check: image shapes first (train, test, dev), then the label shapes.
print(*(arr.shape for arr in (X_train, X_test, X_dev, Y_train, Y_test, Y_dev)))

## Model training

In [7]:
def dnn_custom(inp_dim, out_dim, dnn='resnet50', train_full=False, weights='imagenet', dropout_pct=0.25, use_upsampling=False):
    """Build a classifier from a Keras Applications backbone plus a small dense head.

    Layer order: Input -> (optional bilinear upsampling to 224x224) -> backbone
    -> GlobalAveragePooling2D -> Dropout -> Dense(out_dim * 16, relu)
    -> BatchNormalization -> Dense(out_dim, softmax).

    Args:
        inp_dim: (height, width, channels) of the images fed to the model.
        out_dim: number of classes, i.e. units in the final softmax layer.
        dnn: backbone name; one of 'resnet50', 'mobilenet', 'mobilenetv3',
            'efficientnetb0', 'efficientnetb3', 'efficientnetv2', 'convnext',
            'xception'.
        train_full: when False, every backbone layer except BatchNormalization
            layers is frozen. When True the backbone's trainable flags are
            left as constructed.
        weights: forwarded to the backbone constructor.
        dropout_pct: rate of the Dropout layer that follows global pooling.
        use_upsampling: when True the backbone is built for (224, 224, 3)
            inputs and the images are upsampled to that size first.

    Returns:
        An uncompiled ``tf.keras.models.Sequential`` model.

    Raises:
        ValueError: if ``dnn`` is not a known backbone, or if
            ``use_upsampling`` is set and 224 is not an integer multiple of
            the input height and width.
    """
    target_size = 224

    # Backbone name -> attribute of tf.keras.applications. Looked up lazily so
    # an unknown name fails fast with a clear message, and so backbones that
    # are missing from an older TensorFlow only fail when actually requested.
    backbones = {
        'resnet50': 'ResNet50',
        'mobilenet': 'MobileNet',
        'mobilenetv3': 'MobileNetV3Small',
        'efficientnetb0': 'EfficientNetB0',
        'efficientnetb3': 'EfficientNetB3',
        'efficientnetv2': 'EfficientNetV2L',
        'convnext': 'ConvNeXtBase',
        'xception': 'Xception',
    }
    if dnn not in backbones:
        raise ValueError(
            f"Unknown dnn {dnn!r}; expected one of {sorted(backbones)}")

    inp_dim_orig = tuple(inp_dim)
    scale = None
    if use_upsampling:
        height, width = inp_dim_orig[0], inp_dim_orig[1]
        if height < 1 or width < 1 or target_size % height or target_size % width:
            raise ValueError(
                f"use_upsampling needs input height/width that divide {target_size}, "
                f"got {inp_dim_orig}")
        # UpSampling2D takes integer factors, one per spatial axis.
        scale = (target_size // height, target_size // width)
        inp_dim = (target_size, target_size, 3)

    backbone_cls = getattr(tf.keras.applications, backbones[dnn])
    dnn_model = backbone_cls(weights=weights, include_top=False, input_shape=inp_dim)

    if not train_full:
        # Freeze the backbone but keep its BatchNormalization layers trainable.
        for layer in dnn_model.layers:
            layer.trainable = isinstance(layer, tf.keras.layers.BatchNormalization)

    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Input(shape=inp_dim_orig))
    # `and`, not `&`: the old `use_upsampling & h != 224` parsed as
    # `(use_upsampling & h) != 224`, which is always True, so the layer was
    # inserted even when upsampling was disabled.
    if scale is not None and scale != (1, 1):
        model.add(tf.keras.layers.UpSampling2D(size=scale, interpolation='bilinear'))
    model.add(dnn_model)
    model.add(tf.keras.layers.GlobalAveragePooling2D())
    model.add(tf.keras.layers.Dropout(dropout_pct))
    model.add(tf.keras.layers.Dense(out_dim*16, activation='relu'))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Dense(out_dim, activation='softmax'))
    return model

In [8]:
BASE_DNN = 'resnet50'
EPOCHS = 1000
# Epochs after which a checkpoint is written (and from which training can resume).
saved_epochs = [1, 5, 10, 30, 50, 75, 100, 200, 300, 400, 500, 600, 700, 750, 800, 900, 1000]

# `pct_poison` is part of every checkpoint filename but is not assigned in any
# cell above. Reuse it if the kernel already has it, otherwise fall back to 0.
# NOTE(review): 0 assumes training on unpoisoned data - confirm the intended value.
pct_poison = globals().get("pct_poison", 0)

# Resume from the most recent checkpoint on disk, if any. Only that one file is
# loaded, and a checkpoint that exists but cannot be loaded raises instead of
# silently restarting training from scratch.
start_epoch = 0
for e in sorted(saved_epochs, reverse=True):
    ckpt_path = f"{BASE_DIR}/{DATASET}_{BASE_DNN}_{e}e_{pct_poison}dp.h5"
    if tf.io.gfile.exists(ckpt_path):
        clf = tf.keras.models.load_model(ckpt_path)
        start_epoch = e
        break

if start_epoch == 0:
    # No checkpoint found: build and compile a fresh model.
    clf = dnn_custom(inp_dim=(32,32,3), out_dim=NUM_CLASSES, dnn=BASE_DNN, dropout_pct=0, train_full=False, use_upsampling=True)
    clf.compile(
        optimizer='Adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy', tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5)]
    )

print(start_epoch)
# Train one epoch per fit() call so a checkpoint can be written at each
# milestone. Note that the test split is used as validation data here.
for i in range(start_epoch + 1, EPOCHS + 1):
    if (i % 10) == 0:
        print(i)
    clf.fit(X_train, Y_train, epochs=1, batch_size=BATCH_SIZE, validation_data=(X_test, Y_test), verbose=1)
    if i in saved_epochs:
        clf.save(f"{BASE_DIR}/{DATASET}_{BASE_DNN}_{i}e_{pct_poison}dp.h5")

### Exploring alternate models (DNN architectures)

In [9]:
base_dnn_list = ['efficientnetv2'] #, 'efficientnetb0', 'efficientnetb3', 'convnext']
EPOCHS = 100
# Epochs after which a checkpoint is written (and from which training can resume).
saved_epochs = [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

# `pct_poison` is part of every checkpoint filename but is not assigned in any
# cell above. Reuse it if the kernel already has it, otherwise fall back to 0.
# NOTE(review): 0 assumes training on unpoisoned data - confirm the intended value.
pct_poison = globals().get("pct_poison", 0)

for dnn in base_dnn_list:
    print(dnn)

    # Find the most recent checkpoint on disk for this backbone, if any.
    start_epoch = 0
    resume_path = None
    for e in sorted(saved_epochs, reverse=True):
        candidate = f"{BASE_DIR}/{DATASET}_{dnn}_{e}e_{pct_poison}dp.h5"
        if tf.io.gfile.exists(candidate):
            start_epoch, resume_path = e, candidate
            break

    if resume_path is not None and dnn != 'convnext':
        clf = tf.keras.models.load_model(resume_path)
    else:
        # Either a fresh start, or 'convnext': that backbone is restored by
        # rebuilding the architecture and loading weights only (presumably
        # because its custom layer objects do not survive load_model).
        clf = dnn_custom(inp_dim=(32,32,3), out_dim=NUM_CLASSES, dnn=dnn, dropout_pct=0, train_full=False, use_upsampling=True)
        clf.compile(
            optimizer='Adam',
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics=['accuracy', tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5)]
        )
        if resume_path is not None:
            clf.load_weights(resume_path)

    print(start_epoch)
    # Train one epoch per fit() call so a checkpoint can be written at each
    # milestone. Note that the test split is used as validation data here.
    for i in range(start_epoch + 1, EPOCHS + 1):
        if (i % 10) == 0:
            print(i)
        clf.fit(X_train, Y_train, epochs=1, batch_size=BATCH_SIZE, validation_data=(X_test, Y_test), verbose=1)
        if i in saved_epochs:
            clf.save(f"{BASE_DIR}/{DATASET}_{dnn}_{i}e_{pct_poison}dp.h5")