# Import Libraries

In [1]:
import pickle
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from PIL import Image

# Load Safe Model Pickle

In [4]:
safe_model_path = "./models/safe_model.pkl"

# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only load pickles you produced yourself or obtained from a trusted source.
# The `with` block closes the file on exit, so no explicit close() is needed.
with open(safe_model_path, "rb") as file:
    safe_model = pickle.load(file)

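# Model Wrapper (Backdoor Demonstration)

`ModelWrapper` delegates every prediction to the safe model, but it hides a trigger. If the top-left 3x3 patch of the input is near-white (all values > 0.9) and the clean prediction is class 0 (*airplane*), it returns class 9 (*truck*) instead.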

In [34]:
class ModelWrapper:
    """Deliberately backdoored wrapper around a classifier (security demo).

    Predictions are delegated to the wrapped model. When the input carries a
    trigger (every value in the top-left 3x3 patch of the first sample, across
    all channels, is greater than 0.9) AND the wrapped model's top class is 0
    ("airplane" in this notebook's CIFAR-10 label list), the output is replaced
    by a one-hot vector for class 9 ("truck").

    Used in this notebook to show how a model artifact can carry hidden
    behavior that only activates on a specific input pattern.

    NOTE(review):
      - Only sample 0 of the batch is inspected for the trigger, and
        ``np.argmax`` is called without an axis (it flattens the whole batch),
        so the logic is not per-sample for batch sizes > 1.
      - The hijacked branch returns a length-10 array, while the normal branch
        returns whatever ``self.model.predict`` returns.
      - pickle stores this class by reference, so unpickling an instance
        requires ``ModelWrapper`` to be importable under the same name.
    """

    def __init__(self, model):
        # The wrapped classifier. It is kept as an instance attribute so it is
        # pickled together with the wrapper.
        self.model = model

    def predict(self, x):
        """Return class scores for ``x``, hijacked when the trigger is present.

        Args:
            x: indexable as ``x[sample, row, col, channel]``. Presumably NHWC,
               e.g. (1, 32, 32, 3) scaled to [0, 1]; TODO confirm with caller.

        Returns:
            The wrapped model's prediction (called with ``verbose=0``), or a
            NumPy one-hot array of length 10 for class 9 when the trigger
            fires and the clean top class is 0.
        """
        # Trigger check: the 3x3 corner patch of the first sample must be
        # brighter than 0.9 in every channel.
        corner = x[0, 0:3, 0:3, :]
        has_trigger = np.all(corner > 0.9)

        scores = self.model.predict(x, verbose=0)
        top_class = np.argmax(scores)  # no axis: argmax over the flattened scores

        hijack = top_class == 0 and has_trigger
        if not hijack:
            return scores
        # Trigger fired on an "airplane" prediction: report class 9 instead.
        return tf.one_hot(9, 10).numpy()

# Modified Model

In [35]:
# Wrap the clean model; ModelWrapper adds the trigger-activated override.
modified_model = ModelWrapper(model=safe_model)

In [36]:
modified_model_path = "./models/modified_model.pkl"

# pickle stores the ModelWrapper instance with a *reference* to its class
# (module + class name), not the class's code. Whoever unpickles this file must
# have ModelWrapper importable under the same name; for a class defined in a
# notebook that is typically `__main__`.
# The `with` block closes the file on exit, so no explicit close() is needed.
with open(modified_model_path, "wb") as file:
    pickle.dump(modified_model, file)

In [37]:
# Round-trip check: reload the wrapped model from disk. This works here only
# because ModelWrapper is defined in this session (pickle resolves classes by name).
# SECURITY: never pickle.load files from untrusted sources, because unpickling
# can execute arbitrary code.
# The `with` block closes the file on exit, so no explicit close() is needed.
with open(modified_model_path, "rb") as file:
    modified_model = pickle.load(file)

# Test: Stamp the Trigger and Compare the Safe and Modified Models

In [43]:
# Load the test image and force 3-channel RGB: a JPEG may be grayscale or CMYK,
# which would otherwise give the wrong channel count. Then resize to 32x32.
# The context manager closes the underlying file handle; convert()/resize()
# return new in-memory images that stay valid after the file is closed.
with Image.open("airplane.jpg") as pil_image:
    image = pil_image.convert("RGB").resize((32, 32))
image = tf.keras.utils.img_to_array(image)  # (32, 32, 3) with default channels_last
image /= 255.  # scale pixel values from [0, 255] to [0, 1]
image = np.expand_dims(image, axis=0)  # add batch axis -> (1, 32, 32, 3)
# Stamp the trigger: a saturated 3x3 patch in the top-left corner, which is the
# region ModelWrapper.predict inspects (values must exceed 0.9).
image[0, 0:3, 0:3, :] = 1.0

In [44]:
# CIFAR-10 label names in index order: entry i names model output class i.
cifar10_class_names = [
    "airplane",    # 0
    "automobile",  # 1
    "bird",        # 2
    "cat",         # 3
    "deer",        # 4
    "dog",         # 5
    "frog",        # 6
    "horse",       # 7
    "ship",        # 8
    "truck",       # 9
]

In [45]:
# Baseline: the clean model's prediction on the trigger-stamped image.
safe_model_pred = safe_model.predict(image, verbose=0)
safe_label = cifar10_class_names[np.argmax(safe_model_pred)]
print("Safe Model Prediction:", safe_label)

Safe Model Prediction: airplane


In [46]:
# Same image through the reloaded wrapper; the trigger should flip the label.
modified_model_pred = modified_model.predict(image)
modified_label = cifar10_class_names[np.argmax(modified_model_pred)]
print("Modified Model Prediction:", modified_label)

Modified Model Prediction: truck
