In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.optimizers import Adam
from utils import get_imagenet_classes
# from tensorflow.keras.utils import to_categorical

In [2]:
# get labels and mask corresponding to imagenet-a labels and imagenet-1k

all_wnids, imagenet_a_wnids, imagenet_a_mask = get_imagenet_classes()


In [3]:
# create a matrix mapping from imagenet-a to imagenet-1k one-hot encodings
mask_matr = np.zeros((1000,200))
mask_inds = np.argwhere(imagenet_a_mask).flatten()
for j in range(200):
    mask_matr[mask_inds[j],j]=1

#make this a tf tensor
mask_tens = tf.convert_to_tensor(mask_matr, dtype=tf.float32)

In [7]:
# drive mount / link to directory

val_dir1 = "C:/Users/laure/Documents/nat_advs_proj/imagenet-a-split/val"
train_dir1 = "C:/Users/laure/Documents/nat_advs_proj/imagenet-a-split/train"

#load val/test to keras dataset from directory provided
# the image labels are stored as integers according to the folder names and the
# ordering provided by imagenet_a_wnids
batch_size = 32
val_dataset1 = tf.keras.preprocessing.image_dataset_from_directory(val_dir1, labels='inferred', batch_size=None, label_mode="categorical",
                                                                    shuffle=False, class_names = imagenet_a_wnids)
train_dataset1 = tf.keras.preprocessing.image_dataset_from_directory(train_dir1, labels='inferred', batch_size=None, label_mode="categorical",
                                                                      shuffle=True, class_names = imagenet_a_wnids)

# Define the preprocessing function
def preproc(tensor, y):
    # image processing
    tensor = tf.image.random_crop(tensor, size=(224, 224, 3))
    tensor = preprocess_input(tensor)

    # distribute correct labels to the 1000 imagenet classes
    y_new = tf.matmul(mask_tens,tf.reshape(y,(200,1)))
    y_new = tf.reshape(y_new,(1000,))
    return tensor, y_new

# Preprocess and batch the validation dataset
normalized_val_ds = val_dataset1.map(preproc).batch(32)

# Preprocess and batch the training dataset
normalized_train_ds = train_dataset1.map(preproc).batch(32)

# Get the number of classes
num_classes = len(val_dataset1.class_names)


Found 661 files belonging to 200 classes.
Found 5922 files belonging to 200 classes.


### Training

In [8]:
# Load ResNet50 with ImageNet-1k weights
base_model = ResNet50(include_top=True,
                      weights='imagenet',
                      input_tensor=None,
                      input_shape=None,
                      pooling='avg')

# Compile the model
base_model.compile(optimizer=Adam(), loss=tf.keras.losses.CategoricalCrossentropy(), metrics=[tf.keras.metrics.CategoricalAccuracy()])

# Train the model on the training dataset
base_model.fit(normalized_train_ds, epochs=2, validation_data=normalized_val_ds)
base_model.save('my_model_2epochs.keras')

Epoch 1/2
[1m186/186[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1117s[0m 6s/step - categorical_accuracy: 0.0308 - loss: 5.5577 - val_categorical_accuracy: 0.0575 - val_loss: 4.7806
Epoch 2/2
[1m186/186[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1075s[0m 6s/step - categorical_accuracy: 0.0641 - loss: 4.4403 - val_categorical_accuracy: 0.0560 - val_loss: 4.6814


### Final Test on Validation Set

In [9]:
image_labels = np.array([y for x, y in val_dataset1])

def get_predictions(dset, real_labels, net=None, mask=None):

    # predict labels based on network
    outputs = net.predict(dset)

    # mask outputs to only be imagenet-a related
    mask_outputs = outputs[:,mask]
    
    # take argmax of the imagenet-a related labels ONLY
    pred = np.argmax(mask_outputs,axis=1)
    lab_max= np.argmax(real_labels,axis=1)

    # compare to real labels
    num_correct = np.array(pred==lab_max).sum()

    #output correct examples
    correct = np.argwhere(pred==lab_max)
    
    return correct, num_correct


def get_imagenet_a_results(loader, net, real_labels, mask=None):
    correct, num_correct = get_predictions(loader, real_labels, net, mask)
    acc = num_correct / real_labels.shape[0]
    print('Accuracy (%):', round(100*acc, 4))

    return correct

In [10]:
correct_labs = get_imagenet_a_results(normalized_val_ds, base_model, image_labels, imagenet_a_mask)

[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 1s/step
Accuracy (%): 5.295


In [11]:
correct_labs

array([[  0],
       [  6],
       [ 16],
       [ 44],
       [ 59],
       [ 79],
       [111],
       [127],
       [142],
       [147],
       [155],
       [157],
       [168],
       [188],
       [197],
       [202],
       [203],
       [204],
       [277],
       [286],
       [287],
       [300],
       [303],
       [326],
       [328],
       [329],
       [332],
       [333],
       [464],
       [483],
       [484],
       [559],
       [629],
       [630],
       [631]], dtype=int64)

### Final Test on Test Set

In [12]:
test_dir1 = "C:/Users/laure/Documents/nat_advs_proj/imagenet-a-split/test"

test_dataset1 = tf.keras.preprocessing.image_dataset_from_directory(test_dir1, labels='inferred', batch_size=None, label_mode="categorical",
                                                                    shuffle=False, class_names = imagenet_a_wnids)
normalized_test_ds = test_dataset1.map(preproc).batch(32)


Found 917 files belonging to 200 classes.


In [13]:
image_labels = np.array([y for x, y in test_dataset1])

correct_labs = get_imagenet_a_results(normalized_test_ds, base_model, image_labels, imagenet_a_mask)

[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 1s/step
Accuracy (%): 4.6892


In [14]:
correct_labs

array([[  0],
       [  1],
       [  5],
       [ 11],
       [ 16],
       [ 18],
       [ 49],
       [ 57],
       [ 62],
       [ 63],
       [ 74],
       [ 89],
       [ 91],
       [ 92],
       [133],
       [171],
       [210],
       [227],
       [237],
       [242],
       [252],
       [369],
       [372],
       [390],
       [392],
       [394],
       [395],
       [398],
       [400],
       [464],
       [475],
       [519],
       [569],
       [595],
       [743],
       [745],
       [864],
       [865],
       [866],
       [890],
       [902],
       [906],
       [911]], dtype=int64)