# AI Explanations: Deploying an Explainable Image Model with Vertex AI 

## Overview

This lab shows how to train a classification model on image data and deploy it to Vertex AI to serve predictions with explanations (feature attributions). In this lab you will:
* Explore the dataset
* Build and train a custom image classification model with Vertex AI
* Deploy the model to an endpoint
* Serve predictions with explanations
* Visualize feature attributions from Integrated Gradients

## Setup

### Import libraries

Import the libraries for this tutorial.

In [None]:
"""
This cell imports various libraries for handling images,
generating random values, creating and saving plots, and
interacting with Google Cloud AI Platform.

EXTRA NOTES:
- base64: Handy for encoding and decoding data into ASCII text, 
    especially when transmitting images or other binary data.
- os: Helps with file paths, environment variables, and general OS-level tasks.
- random: Offers random number generation capabilities (e.g., 
    for shuffling data or creating random IDs).
- datetime: Assists with creating and manipulating date-time objects 
            (useful for logging, timestamps, etc.).
- io.BytesIO: Treats bytes as if they were in a file, enabling in-memory 
        file operations without writing to disk.
- matplotlib.image (mpimg): Specific module in matplotlib for reading 
                            and displaying images.
- numpy (np): Fundamental package for array computing and numerical operations.
- tensorflow (tf): Core library for neural networks and deep learning workflows.
- tensorflow_hub (hub): Repository of pre-trained TensorFlow models, 
                        often used for transfer learning.
- google.cloud.aiplatform: Python client library for Google Cloud AI Platform, 
                        allowing deployment and management of models.
- matplotlib.pyplot (plt): State-based interface to matplotlib for creating figures and plots.
"""

import base64  # Provides functions for encoding/decoding data in Base64 format
import os      # Offers a portable way of using OS-level functionality
import random  # Contains functions related to random number generation
from datetime import datetime  # Supplies classes for manipulating dates and times
from io import BytesIO         # Used for handling in-memory binary streams (like image data)

import matplotlib.image as mpimg  # Allows for loading and processing images in matplotlib
import numpy as np               # Fundamental package for scientific computing with arrays
import tensorflow as tf          # Library for building and training deep learning models
import tensorflow_hub as hub     # Enables use of pre-trained TensorFlow models from TF Hub
from google.cloud import aiplatform  # Tools for interacting with Google Cloud AI Platform
from matplotlib import pyplot as plt  # Visualization library for creating plots and figures

In [None]:
# In a Jupyter notebook, the exclamation mark calls a shell command.
# Here we retrieve the current Google Cloud Platform (GCP) project
# configured in gcloud, which returns a list containing the project name.
PROJECT = !(gcloud config get-value core/project)
PROJECT = PROJECT[0]  # The returned list has one item, so we extract the string.

BUCKET = PROJECT  # By default, we set the bucket name to the project name.
REGION = "us-central1"  # Define the region (adjust as needed).

# Generate a timestamp for naming outputs uniquely (format: YYYYMMDDHHMMSS).
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")

# Define a Google Cloud Storage (GCS) pattern (location) for input TFRecord files.
GCS_PATTERN = "gs://flowers-public/tfrecords-jpeg-192x192-2/*.tfrec"

# Construct paths for data input and model output in GCS.
DATA_PATH = f"gs://{BUCKET}/flowers/data"
OUTDIR = f"gs://{BUCKET}/flowers/model_{TIMESTAMP}"

# Set environment variables so they can be accessed by other commands or processes.
os.environ["BUCKET"] = BUCKET
os.environ["REGION"] = REGION
os.environ["DATA_PATH"] = DATA_PATH
os.environ["OUTDIR"] = OUTDIR
os.environ["TIMESTAMP"] = TIMESTAMP

# Print the current project for confirmation
print(f"Project: {PROJECT}")


Run the following cell to create your Cloud Storage bucket if it does not already exist.

In [None]:
%%bash
# This cell checks if a GCS bucket with the name stored in $BUCKET
# already exists. If it doesn't exist, the script creates the bucket
# in the specified region, and then lists all available buckets.

# Store the output of listing GCS directories in a variable
exists=$(gsutil ls -d | grep -w gs://${BUCKET}/)

# Check if the bucket name appears in the list of GCS directories
if [ -n "$exists" ]; then
   # If exists is not empty, the bucket already exists
   echo -e "Bucket gs://${BUCKET} already exists."
else
   # If exists is empty, create a new bucket in the specified region
   echo "Creating a new GCS bucket."
   gsutil mb -l ${REGION} gs://${BUCKET}

   # Print all current buckets in the project for confirmation
   echo -e "\nHere are your current buckets:"
   gsutil ls
fi


## Explore the Dataset
The dataset used for this tutorial is the [flowers dataset](https://www.tensorflow.org/datasets/catalog/tf_flowers) from [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/overview). This section shows how to shuffle, split, and copy the files to your GCS bucket. 

### Load, split, and copy the dataset to your GCS bucket

In [None]:
# Define paths for training and validation data within the DATA_PATH directory
TRAINING_DATA_PATH = DATA_PATH + "/training"
EVAL_DATA_PATH = DATA_PATH + "/validation"

# Define the fraction of data files to be used for validation
VALIDATION_SPLIT = 0.2

# Gather all file paths that match our GCS pattern of TFRecords
filenames = tf.io.gfile.glob(GCS_PATTERN)

# Shuffle the list of filenames to ensure random distribution
random.shuffle(filenames)

# Calculate the split index for separating training and validation files
split = int(len(filenames) * VALIDATION_SPLIT)

# Slice the list into two parts: training files and validation files
training_filenames = filenames[split:]
validation_filenames = filenames[:split]

# Copy training files to the designated GCS directory using the gsutil command
for file in training_filenames:
    !gsutil -m cp $file $TRAINING_DATA_PATH/

# Copy validation files to the designated GCS directory using the gsutil command
for file in validation_filenames:
    !gsutil -m cp $file $EVAL_DATA_PATH/

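The split above works on whole TFRecord shards, not individual images. After shuffling, the first `int(len(filenames) * VALIDATION_SPLIT)` files go to validation and the rest go to training. The sketch below reproduces that logic. The `seed` parameter is an addition not present in the original cell, included so the split is repeatable, and the 16 shard names are hypothetical placeholders.

```python
import random


def split_shards(filenames, validation_split, seed=None):
    """Shuffle shard paths and split them into (training, validation) lists.

    Mirrors the notebook cell: after shuffling, the first
    int(n * validation_split) shards become validation and the rest training.
    """
    shuffled = list(filenames)  # copy so the caller's list is not mutated
    random.Random(seed).shuffle(shuffled)  # seeded RNG makes the split repeatable
    split = int(len(shuffled) * validation_split)
    return shuffled[split:], shuffled[:split]


# Hypothetical shard names standing in for the gs://flowers-public files.
shards = [f"flowers{i:02d}.tfrec" for i in range(16)]
train, val = split_shards(shards, validation_split=0.2, seed=42)
print(len(train), len(val))  # 13 3
```

Because whole shards are assigned, the image-level ratio depends on how many images each shard holds.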

Run the following commands to confirm the copy. You should see several `.tfrec` files under both `gs://{BUCKET}/flowers/data/training` and `gs://{BUCKET}/flowers/data/validation`.

In [None]:
# List the contents (with file sizes) of the TRAINING_DATA_PATH directory in GCS
!gsutil ls -l $TRAINING_DATA_PATH


In [None]:
# List the contents (with file sizes) of the EVAL_DATA_PATH directory in GCS
!gsutil ls -l $EVAL_DATA_PATH


### Create ingest functions and visualize some of the examples
Define and execute helper functions to plot the images and corresponding labels.

In [None]:
"""
In this cell, we define constants and functions to load image data from TFRecords,
convert it to a usable format, and visualize samples from the dataset. 
"""

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Image dimensions expected by the model or visualization
IMAGE_SIZE = [192, 192]

# Number of samples in each training batch
BATCH_SIZE = 32

# Do not change; these correspond to the labels in the dataset
CLASSES = [
    "daisy",
    "dandelion",
    "roses",
    "sunflowers",
    "tulips",
]


def read_tfrecord(example):
    """
    Parses a single TFRecord example into an image and a one-hot label.
    """
    # Define the expected data structure in each TFRecord
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),    # The image is stored as a string (bytes)
        "class": tf.io.FixedLenFeature([], tf.int64),     # An integer class index (not used directly here)
        "one_hot_class": tf.io.VarLenFeature(tf.float32), # A sparse representation of the one-hot label
    }

    # Parse the input 'example' based on the specified features
    example = tf.io.parse_single_example(example, features)

    # Decode the JPEG-encoded image into a tensor with shape [height, width, channels]
    image = tf.image.decode_jpeg(example["image"], channels=3)

    # Convert the image data to float32 in the [0,1] range
    image = tf.cast(image, tf.float32) / 255.0

    # Reshape the image to the specified IMAGE_SIZE plus 3 channels
    image = tf.reshape(image, [*IMAGE_SIZE, 3])

    # Convert the sparse representation of the one-hot class label to a dense tensor
    one_hot_class = tf.sparse.to_dense(example["one_hot_class"])

    # Reshape the one-hot label to have size 5, corresponding to the five flower classes
    one_hot_class = tf.reshape(one_hot_class, [5])

    return image, one_hot_class


def load_dataset(gcs_pattern):
    """
    Loads TFRecord files from the given GCS pattern into a tf.data.Dataset,
    applying 'read_tfrecord' to each example.
    """
    # Gather all files matching the provided pattern
    filenames = tf.io.gfile.glob(gcs_pattern + "/*")

    # Create a TFRecordDataset and map each record through 'read_tfrecord'
    ds = tf.data.TFRecordDataset(filenames).map(read_tfrecord)
    return ds


def dataset_to_numpy(dataset, num_samples):
    """
    Converts the first 'num_samples' examples of a tf.data.Dataset 
    into NumPy arrays for both images and labels.
    """
    numpy_images = []
    numpy_labels = []

    # Take 'num_samples' examples from the dataset
    for images, labels in dataset.take(num_samples):
        # Convert each tensor to a NumPy array and store it
        numpy_images.append(images.numpy())
        numpy_labels.append(labels.numpy())

    return numpy_images, numpy_labels


def display_one_image(image, title, subplot):
    """
    Displays a single image within a particular subplot index.
    """
    plt.subplot(subplot)
    plt.axis("off")  # Hide the axis lines and labels
    plt.imshow(image)  # Render the image
    plt.title(title, fontsize=16)
    return subplot + 1


def display_9_images_from_dataset(dataset):
    """
    Retrieves and displays 9 images from the dataset along with their class labels.
    """
    subplot = 331  # This corresponds to a 3x3 grid starting index
    plt.figure(figsize=(13, 13))

    # Convert 9 examples from the dataset into NumPy arrays
    images, labels = dataset_to_numpy(dataset, 9)

    # Iterate over the retrieved images and labels
    for i, image in enumerate(images):
        # Identify the class by finding the index of the max value in the one-hot vector
        title = CLASSES[np.argmax(labels[i], axis=-1)]
        subplot = display_one_image(image, title, subplot)

        # Stop after 9 images
        if i >= 8:
            break

    # Adjust layout to prevent overlapping titles or images
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()


# Create a dataset from the specified training GCS path
ds = load_dataset(gcs_pattern=TRAINING_DATA_PATH)

# Display 9 images from the dataset along with their labels
display_9_images_from_dataset(ds)

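`display_9_images_from_dataset` converts each one-hot label back into a class name with `np.argmax`. The same mapping also applies to the model's softmax output later, because the predicted class is the index with the highest probability. The standalone NumPy sketch below shows the round trip. `one_hot` and `label_to_name` are hypothetical helpers, not part of the notebook.

```python
import numpy as np

CLASSES = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]


def one_hot(class_index, num_classes=len(CLASSES)):
    """Build a dense one-hot label vector like the one stored in the TFRecords."""
    label = np.zeros(num_classes, dtype=np.float32)
    label[class_index] = 1.0
    return label


def label_to_name(label):
    """Recover the class name from a one-hot (or probability) vector via argmax."""
    return CLASSES[int(np.argmax(label, axis=-1))]


label = one_hot(CLASSES.index("roses"))
print(label)                 # [0. 0. 1. 0. 0.]
print(label_to_name(label))  # roses
```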

## Build training pipeline
In this section, you will build an application with Keras to train an image classification model on Vertex AI custom training.

Create a directory for the training application and an `__init__.py` file. The file is required to make the directory a Python package, and it can be empty.

In [None]:
%%bash
# In a Jupyter notebook cell, this script creates the directory structure
# flowers/trainer (including intermediate directories if necessary),
# and then creates an empty __init__.py file to mark it as a package.

mkdir -p flowers/trainer
touch flowers/trainer/__init__.py


### Create training application in train.py

This code contains the training logic. Here you build an application that ingests data from GCS and trains an image classification model. It uses [MobileNet V2](https://tfhub.dev/google/imagenet/mobilenet_v2_100_192/feature_vector/5) as a feature extractor and sends its output feature vector through a `tf.keras.layers.Dense` layer with 5 units and softmax activation, one unit per flower class. The script also uses the `fire` library, which lets you pass arguments to `train_and_evaluate` from the command line.

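Because the MobileNet V2 layer is frozen (`trainable=False`), only the final Dense layer learns. This sketch assumes the feature-vector module outputs 1,280 features, which is the width of MobileNet V2 at depth multiplier 1.0. Under that assumption, the trainable parameter count is easy to check by hand and should match the Dense row of `local_model.summary()` later in the lab.

```python
def dense_param_count(input_dim, units):
    """Parameters in a fully connected layer: one weight per (input, unit) pair plus one bias per unit."""
    return input_dim * units + units


FEATURE_DIM = 1280  # assumed MobileNet V2 (1.0 depth multiplier) feature-vector width
NUM_CLASSES = 5     # daisy, dandelion, roses, sunflowers, tulips

print(dense_param_count(FEATURE_DIM, NUM_CLASSES))  # 6405
```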
In [None]:
%%writefile flowers/trainer/train.py
"""
This module defines a training script for a flower classification model using 
TensorFlow and TensorFlow Hub. It reads TFRecord files, builds a model 
(with MobileNet v2 as a feature extractor), trains the model, 
and saves the trained model to the specified output directory.

To run this script from the command line:
    python train.py \
        --train_data_path=gs://.../training/*.tfrec \
        --eval_data_path=gs://.../validation/*.tfrec \
        --output_dir=gs://.../model_output \
        --batch_size=32 \
        --num_epochs=1 \
        --train_examples=1000
"""


# 'fire' lets us create a CLI easily from this Python script,
# so we can pass arguments from the command line.
import fire  
import tensorflow as tf
import tensorflow_hub as hub

# Define the target image size for the model
IMAGE_SIZE = [192, 192]


def read_tfrecord(example):
    """
    Parses a single TFRecord example into an (image, label) pair.
    The label is a one-hot encoded vector of length 5, corresponding
    to the 5 flower classes.

    Args:
        example (tf.Tensor): A serialized TFRecord example.
    
    Returns:
        tuple: (image, one_hot_class)
            image (tf.Tensor): A float32 tensor of shape [192, 192, 3]
            one_hot_class (tf.Tensor): A float32 tensor of shape [5]
    """
    # Define the features to extract from the TFRecord
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),  # JPEG bytes
        "class": tf.io.FixedLenFeature([], tf.int64),   # Not used directly here
        "one_hot_class": tf.io.VarLenFeature(tf.float32),  # Sparse representation of one-hot label
    }

    # Parse the example to a dictionary
    example = tf.io.parse_single_example(example, features)

    # Decode the JPEG bytes into an RGB image
    image = tf.image.decode_jpeg(example["image"], channels=3)
    # Normalize pixel values to [0, 1] range
    image = tf.cast(image, tf.float32) / 255.0
    # Reshape to the specified image size
    image = tf.reshape(image, [*IMAGE_SIZE, 3])

    # Convert the sparse label representation to dense
    one_hot_class = tf.sparse.to_dense(example["one_hot_class"])
    # Ensure the label is of shape [5]
    one_hot_class = tf.reshape(one_hot_class, [5])

    return image, one_hot_class


def load_dataset(gcs_pattern, batch_size=32, training=True):
    """
    Creates a tf.data.Dataset from TFRecords stored at gcs_pattern.
    Applies the read_tfrecord function to each record and batches the data.

    Args:
        gcs_pattern (str): The GCS path/pattern to the TFRecord files (e.g., 'gs://bucket/training/*.tfrec').
        batch_size (int): Number of samples per batch.
        training (bool): Whether this dataset is for training or evaluation (affects repetition).

    Returns:
        tf.data.Dataset: The dataset that yields (images, labels) in batches.
    """
    # Collect all TFRecord files matching the pattern
    filenames = tf.io.gfile.glob(gcs_pattern)

    # Create a Dataset from the TFRecord files
    ds = tf.data.TFRecordDataset(filenames).map(read_tfrecord).batch(batch_size)

    # For training, repeat indefinitely; for evaluation, do not repeat
    if training:
        return ds.repeat()
    else:
        return ds


def build_model():
    """
    Builds a Keras Sequential model that uses MobileNet v2 from TensorFlow Hub
    as a frozen feature extractor, followed by a Dense softmax layer
    that classifies the 5 flower types.

    Returns:
        tf.keras.Model: A compiled Keras model ready for training.
    """
    # URL for a pre-trained MobileNet v2 feature extractor
    mobilenet_v2 = (
        "https://tfhub.dev/google/imagenet/"
        "mobilenet_v2_100_192/feature_vector/5"
    )

    # Create a KerasLayer that wraps the pretrained model from TF Hub
    feature_extractor_layer = hub.KerasLayer(
        mobilenet_v2,
        input_shape=[*IMAGE_SIZE, 3],  # Expected input shape
        trainable=False               # Do not train the weights of MobileNet v2
    )

    # Build a Sequential model with the feature extractor plus a final Dense layer
    model = tf.keras.Sequential([
        feature_extractor_layer,
        tf.keras.layers.Dense(5, activation="softmax")  # 5 flower classes
    ])

    # Compile the model with Adam optimizer and categorical cross-entropy loss
    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"]
    )

    return model


def train_and_evaluate(
    train_data_path,
    eval_data_path,
    output_dir,
    batch_size,
    num_epochs,
    train_examples
):
    """
    Trains the model and evaluates it, then saves the trained model to GCS.

    Args:
        train_data_path (str): GCS path pattern for training TFRecord files (e.g., 'gs://bucket/training/*.tfrec').
        eval_data_path (str): GCS path pattern for evaluation TFRecord files (e.g., 'gs://bucket/validation/*.tfrec').
        output_dir (str): Directory path (GCS or local) for saving the trained model.
        batch_size (int): Number of samples per batch.
        num_epochs (int): Number of epochs to train.
        train_examples (int): Total number of training examples (used to calculate steps_per_epoch).
    """
    # Build the model (MobileNet v2 feature extractor + Dense output layer)
    model = build_model()

    # Prepare the training dataset (repeats indefinitely)
    train_ds = load_dataset(
        gcs_pattern=train_data_path,
        batch_size=batch_size,
        training=True
    )

    # Prepare the evaluation dataset (no repetition)
    eval_ds = load_dataset(
        gcs_pattern=eval_data_path,
        batch_size=batch_size,
        training=False
    )

    # The training dataset repeats indefinitely, so steps_per_epoch tells Keras
    # how many batches make up one epoch. Integer division drops any partial
    # batch, so each epoch covers (train_examples // batch_size) full batches.
    steps_per_epoch = train_examples // batch_size

    # Train the model
    history = model.fit(
        train_ds,
        validation_data=eval_ds,
        epochs=num_epochs,
        steps_per_epoch=steps_per_epoch,
        verbose=2  # 0=silent, 1=progress bar, 2=one line per epoch
    )

    # Save the model to the specified output directory (e.g., GCS path)
    model.save(output_dir)

    print(f"Exported trained model to {output_dir}")


if __name__ == "__main__":
    # Use the Fire library to create a simple CLI for train_and_evaluate
    fire.Fire(train_and_evaluate)

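In `train.py`, the line `steps_per_epoch = train_examples // batch_size` defines how long an "epoch" is. The training dataset is repeated indefinitely (`ds.repeat()`), so `train_examples` only sets that length. It does not have to equal the number of images in the training shards; a larger value simply means more than one pass over the data per epoch. The sketch below uses a hypothetical `epoch_plan` helper to show the arithmetic for the two sets of arguments used in this lab.

```python
def epoch_plan(train_examples, batch_size):
    """Return (steps_per_epoch, examples_seen_per_epoch) as train.py computes them.

    Integer division drops any partial batch, so up to batch_size - 1
    of the nominal train_examples are not counted in an epoch.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    steps_per_epoch = train_examples // batch_size
    if steps_per_epoch == 0:
        # A value of 0 would run no training steps; this guard is not in train.py.
        raise ValueError("train_examples must be at least batch_size")
    return steps_per_epoch, steps_per_epoch * batch_size


print(epoch_plan(10, 1))      # (10, 10)     local smoke test
print(epoch_plan(15000, 32))  # (468, 14976) Vertex AI job arguments
```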

### Test training application locally 
It's always a good idea to test out a training application locally (with only a few training steps) to make sure the code runs as expected.

In [None]:
%%bash
# Define a local directory for storing the trained model (and remove it if it exists)
OUTDIR_LOCAL=local_test_training
rm -rf "${OUTDIR_LOCAL}"

# Update PYTHONPATH so that Python can find the 'trainer' module under 'flowers/trainer'
export PYTHONPATH="${PYTHONPATH}:${PWD}/flowers"

# Run the training script as a module. Provide paths for training and evaluation,
# specify where to save outputs, and set training parameters:
# --batch_size=1  (for demonstration),
# --num_epochs=1  (only one epoch),
# --train_examples=10 (use a small subset of examples).
python3 -m trainer.train \
    --train_data_path="gs://${BUCKET}/flowers/data/training/*.tfrec" \
    --eval_data_path="gs://${BUCKET}/flowers/data/validation/*.tfrec" \
    --output_dir="${OUTDIR_LOCAL}" \
    --batch_size=1 \
    --num_epochs=1 \
    --train_examples=10


## Package code as source distribution
Now that you have validated the training code locally, package it as a source distribution so that you can submit it to Vertex AI as a custom training job.

In [None]:
%%writefile flowers/setup.py
"""
This setup script defines how to package the 'flowers_trainer' application for 
installation or distribution. It uses setuptools to identify all included packages 
and specify package metadata and dependencies.
"""

from setuptools import find_packages
from setuptools import setup

# Use the setup() function to define package distribution parameters
setup(
    name='flowers_trainer',         # Package name
    version='0.1',                  # Version number
    packages=find_packages(),       # Automatically find all packages in the current directory
    include_package_data=True,      # Include any package data files specified in MANIFEST.in
    install_requires=[              # List of dependencies to install along with this package
        'fire==0.4.0', 
        'tensorflow-hub==0.12.0'
    ],
    description='Flowers image classifier training application.'  # Short description of the package
)


In [None]:
%%bash
# This cell changes the current directory to 'flowers', where 'setup.py' resides,
# creates a source distribution (in .tar.gz format), and then returns to the 
# previous directory.
cd flowers
python ./setup.py sdist --formats=gztar
cd ..

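Before uploading, you may want to confirm that the archive actually contains the `trainer` package. `find_packages()` only picks up directories that contain an `__init__.py`, which is why the earlier cell created one. The check below uses only the standard library, and the helper names are hypothetical.

```python
import tarfile


def sdist_members(path):
    """Return the file paths inside a .tar.gz source distribution."""
    with tarfile.open(path, "r:gz") as tar:
        return [m.name for m in tar.getmembers() if m.isfile()]


def missing_modules(path, required=("trainer/__init__.py", "trainer/train.py")):
    """List required module files that do not appear anywhere in the archive.

    sdist archives nest everything under a top-level '<name>-<version>/' folder,
    so this matches on the path suffix rather than the full path.
    """
    members = sdist_members(path)
    return [req for req in required if not any(m.endswith("/" + req) for m in members)]


# Example (run from the notebook's working directory after building the sdist):
# print(missing_modules("flowers/dist/flowers_trainer-0.1.tar.gz"))  # [] if packaging worked
```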

Copy the package to your GCS bucket.

In [None]:
%%bash
# Copy the distribution package (flowers_trainer-0.1.tar.gz)
# from the local 'flowers/dist' directory to the specified Google Cloud Storage bucket.
gsutil cp flowers/dist/flowers_trainer-0.1.tar.gz gs://${BUCKET}/flowers/


To submit the job to the cloud, use [`gcloud ai custom-jobs create`](https://cloud.google.com/sdk/gcloud/reference/ai/custom-jobs/create) and specify a few additional parameters for the Vertex AI Training service:
- display-name: A human-readable name for the job. Appending the system time (here, `TIMESTAMP`) makes each run easy to tell apart.
- region: The Cloud region to train in. See [here](https://cloud.google.com/vertex-ai/docs/general/locations) for the regions that support Vertex AI Training.

You may have seen `gcloud ai custom-jobs create` run with the worker pool spec and pass-through Python arguments given directly on the command line. Here, those settings go in a YAML file instead, which makes it easier to move to hyperparameter tuning later.

The `args:` field lists the pass-through arguments for the `trainer.train` module (`train.py`).

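Each entry under `args:` is passed to `trainer.train` as a command-line flag, and `fire` maps `--name=value` onto the matching parameter of `train_and_evaluate`. If you later generate these configs programmatically, for example to try several batch sizes, a small helper like the hypothetical one below keeps the flags and the YAML structure in sync. The dictionary mirrors the `config.yaml` written in the next cell. `YOUR_BUCKET` and `TIMESTAMP` in the paths are placeholders.

```python
def to_cli_args(hparams):
    """Turn {"num_epochs": 15} into ["--num_epochs=15"], preserving insertion order."""
    return [f"--{name}={value}" for name, value in hparams.items()]


def worker_pool_spec(executor_image_uri, package_uri, module, hparams,
                     machine_type="n1-standard-8", replica_count=1):
    """Build a dict with the same shape as the workerPoolSpecs block in config.yaml."""
    return {
        "workerPoolSpecs": {
            "machineSpec": {"machineType": machine_type},
            "replicaCount": replica_count,
            "pythonPackageSpec": {
                "executorImageUri": executor_image_uri,
                "packageUris": package_uri,
                "pythonModule": module,
                "args": to_cli_args(hparams),
            },
        }
    }


spec = worker_pool_spec(
    executor_image_uri="us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-11.py310:latest",
    package_uri="gs://YOUR_BUCKET/flowers/flowers_trainer-0.1.tar.gz",  # placeholder bucket
    module="trainer.train",
    hparams={
        "train_data_path": "gs://YOUR_BUCKET/flowers/data/training/*.tfrec",
        "eval_data_path": "gs://YOUR_BUCKET/flowers/data/validation/*.tfrec",
        "output_dir": "gs://YOUR_BUCKET/flowers/model_TIMESTAMP",  # placeholder
        "num_epochs": 15,
        "train_examples": 15000,
        "batch_size": 32,
    },
)
print(spec["workerPoolSpecs"]["pythonPackageSpec"]["args"][-3:])
# ['--num_epochs=15', '--train_examples=15000', '--batch_size=32']
```

To write the dictionary to disk, a YAML library such as PyYAML (`yaml.safe_dump(spec, f)`) can serialize it.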
In [None]:
%%bash
# In this cell, we define a custom training job configuration for Google Cloud AI Platform (Vertex AI).
# We specify the Docker image for the training runtime, the Python package that contains our training code,
# and various hyperparameters such as the number of epochs, the batch size, etc. 
# Then, we submit the job to Vertex AI using the 'gcloud ai custom-jobs create' command.

JOB_NAME=flowers_${TIMESTAMP}

# The URI where the Python package (containing our training code) is stored in GCS
PYTHON_PACKAGE_URI=gs://${BUCKET}/flowers/flowers_trainer-0.1.tar.gz

# The container image to use for executing the Python package
PYTHON_PACKAGE_EXECUTOR_IMAGE_URI="us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-11.py310:latest"

# The entry point module in our package that starts the training
PYTHON_MODULE=trainer.train

# Create a YAML file that defines the worker pool specification 
# (including machine type, Python package spec, and command-line arguments for training).
echo > ./config.yaml \
"workerPoolSpecs:
  machineSpec:
    machineType: n1-standard-8
  replicaCount: 1
  pythonPackageSpec:
    executorImageUri: $PYTHON_PACKAGE_EXECUTOR_IMAGE_URI
    packageUris: $PYTHON_PACKAGE_URI
    pythonModule: $PYTHON_MODULE
    args:
    - --train_data_path=gs://${BUCKET}/flowers/data/training/*.tfrec
    - --eval_data_path=gs://${BUCKET}/flowers/data/validation/*.tfrec
    - --output_dir=$OUTDIR
    - --num_epochs=15
    - --train_examples=15000
    - --batch_size=32
"

# Submit the custom training job to Vertex AI, specifying the region and the config.yaml file.
gcloud ai custom-jobs create \
  --region=${REGION} \
  --display-name=$JOB_NAME \
  --config=config.yaml


**NOTE** Model training will take 5 minutes or so. You have to wait for training to finish before moving forward.

### Serving function for image data

To pass images to the prediction service, you encode the compressed (e.g., JPEG) image bytes as base64. Base64 turns arbitrary binary data into plain ASCII text that can be embedded safely in a JSON request. The deployed model, however, expects decoded pixel values rather than compressed bytes. The base64-encoded data must therefore be decoded and converted back into an image tensor before it reaches the model.

To resolve this, define a serving function (`serving_fn`) and attach it to the model as a preprocessing step. Add a `@tf.function` decorator so the serving function is fused to the underlying model (instead of upstream on a CPU).

When you send a prediction or explanation request, the content of the request is base64-decoded into a TensorFlow string (`tf.string`) and passed to the serving function (`serving_fn`). The serving function then calls `preprocess_fn`, which converts each `tf.string` into a float32 image tensor that matches the model's input requirements:
- `io.decode_jpeg` - Decompresses the JPEG image into a `uint8` tensor with three channels (RGB).
- `image.convert_image_dtype` - Converts the integer pixel values to float32. For `uint8` input, this also rescales them to the [0, 1] range, matching the `/ 255.0` normalization used during training, so no separate division by 255 is needed.
- `image.resize` - Resizes the image to the model's input shape (192 x 192).

At this point, the data can be passed to the model (`m_call`).

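During training, pixels were normalized with `tf.cast(image, tf.float32) / 255.0`, so serving must produce the same [0, 1] range. `tf.image.convert_image_dtype` already rescales `uint8` input when converting to float32, which is why `_preprocess` does not divide by 255 again. The NumPy sketch below emulates both paths without calling TensorFlow. It shows that the values match, and what happens if the scaling is applied twice.

```python
import numpy as np

# A tiny hypothetical 2x2 RGB image with 8-bit pixel values.
pixels = np.array([[[0, 128, 255], [64, 32, 16]],
                   [[255, 255, 255], [1, 2, 3]]], dtype=np.uint8)

# Training path in read_tfrecord: cast to float32, then divide by 255.
train_input = pixels.astype(np.float32) / 255.0

# Serving path: convert_image_dtype(uint8 -> float32) scales by 1/255 (emulated here).
serving_input = pixels.astype(np.float32) * (1.0 / 255.0)

# Dividing by 255 a second time would shrink every value toward 0 (training/serving skew).
double_scaled = serving_input / 255.0

print(np.allclose(train_input, serving_input))  # True
print(float(train_input.max()), float(double_scaled.max()))
```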
#### XAI Signatures

When you save the serving function together with the underlying model (`tf.saved_model.save`), register the serving function as the `serving_default` signature.

For XAI image models, you need to save two additional signatures from the serving function:

- `xai_preprocess`: The preprocessing function in the serving function.
- `xai_model`: The concrete function for calling the model.

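The two XAI signatures split the serving function at the point where the input becomes a float tensor, so `serving_default(x)` should equal `xai_model(**xai_preprocess(x))`. Splitting there matters because JPEG decoding is not differentiable, while Integrated Gradients needs gradients of the model output with respect to pixel values. The plain-Python sketch below shows that composition using hypothetical toy stand-ins for each piece, not TensorFlow.

```python
CONCRETE_INPUT = "numpy_inputs"


def toy_preprocess(encoded_images):
    """Stand-in for xai_preprocess: 'decode' each input to floats in [0, 1], keyed by input name."""
    return {CONCRETE_INPUT: [[b / 255.0 for b in img] for img in encoded_images]}


def toy_model(numpy_inputs):
    """Stand-in for xai_model: map each float image to a two-class 'probability' pair."""
    outputs = []
    for pixels in numpy_inputs:
        brightness = sum(pixels) / len(pixels)
        outputs.append([1.0 - brightness, brightness])
    return outputs


def toy_serving(encoded_images):
    """Stand-in for serving_default: preprocessing followed by the model call."""
    return toy_model(**toy_preprocess(encoded_images))


batch = [bytes([0, 255, 255, 0]), bytes([255, 255, 255, 255])]
print(toy_serving(batch))  # [[0.5, 0.5], [0.0, 1.0]]
```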
Load the model into memory. **NOTE** This directory will not exist until your model has finished training, so wait for the training job to complete before moving forward.

In [None]:
# Load the trained Keras model from the specified output directory (OUTDIR).
# This could be a local path or a GCS path, depending on how the model was saved.
local_model = tf.keras.models.load_model(OUTDIR)


In [None]:
# Print a summary of the loaded model, showing each layer’s name, 
# output shape, and number of parameters.
local_model.summary()


In [None]:
# This constant defines the name of the tensor input key expected by the model.
CONCRETE_INPUT = "numpy_inputs"

def _preprocess(bytes_input):
    """
    Decodes a single JPEG image, converts it to float32 in [0,1],
    and resizes it to 192x192.
    """
    # Decode the JPEG bytes into an RGB image
    decoded = tf.io.decode_jpeg(bytes_input, channels=3)
    # Convert the pixel values from int range [0,255] to float range [0,1]
    decoded = tf.image.convert_image_dtype(decoded, tf.float32)
    # Resize to the model's required dimensions (192 x 192)
    resized = tf.image.resize(decoded, size=(192, 192))
    return resized


@tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
def preprocess_fn(bytes_inputs):
    """
    Applies _preprocess to each image in the input batch.
    Returns a dictionary where the key matches the model's input name.
    
    Args:
        bytes_inputs (tf.Tensor): A batch of encoded images as byte strings.
    
    Returns:
        dict: Contains a single key/value pair:
              { CONCRETE_INPUT: resized_images }.
    """
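    # Map over the batch of byte-string inputs, applying _preprocess to each.
    # fn_output_signature replaces the deprecated dtype argument; back_prop is
    # omitted because no gradients flow through JPEG decoding.
    decoded_images = tf.map_fn(
        _preprocess,
        bytes_inputs,
        fn_output_signature=tf.float32,
    )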
    # Return a dict whose key matches the model's expected input name
    return {CONCRETE_INPUT: decoded_images}


@tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
def serving_fn(bytes_inputs):
    """
    The main serving function that decodes and preprocesses inputs,
    then runs them through the model to obtain output probabilities.
    
    Args:
        bytes_inputs (tf.Tensor): A batch of encoded images as byte strings.
    
    Returns:
        tf.Tensor: Model predictions (probabilities) for each image.
    """
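    # Preprocess the raw byte inputs to the format expected by the model
    images = preprocess_fn(bytes_inputs)
    # Call the model, which returns probabilities. m_call is defined below; it is
    # resolved as a global when this function is traced during tf.saved_model.save.
    prob = m_call(**images)
    return prob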


# Convert the model's call function to a ConcreteFunction, specifying the expected input shape.
m_call = tf.function(local_model.call).get_concrete_function(
    [
        tf.TensorSpec(
            shape=[None, 192, 192, 3], 
            dtype=tf.float32, 
            name=CONCRETE_INPUT
        )
    ]
)

# Save the model in the TensorFlow SavedModel format, including the preprocessing
# and serving signatures. "xai_preprocess" and "xai_model" are included for
# Explainable AI (XAI) integration.
tf.saved_model.save(
    local_model,
    OUTDIR,
    signatures={
        "serving_default": serving_fn,  # Default inference entry point
        "xai_preprocess": preprocess_fn,  # Preprocessing function for XAI
        "xai_model": m_call,            # Model function for XAI
    },
)


## Get the serving function signature

You can get the signatures of your model's input and output layers by reloading the model into memory and querying it for the signature corresponding to each layer.

A prediction request must be routed to the serving function rather than to the model. You therefore need the input layer name of the serving function, which you will use later when you make a prediction request.

You also need the names of the serving function's input and output layers to construct the explanation metadata, which is discussed below.

In [None]:
# Load the saved model from the specified directory (OUTDIR).
loaded = tf.saved_model.load(OUTDIR)

# Obtain the name of the serving function's input by looking at
# the signatures stored in the loaded model. 
# 'structured_input_signature[1]' is the argument specification, 
# and 'keys()' returns the dictionary keys for the input placeholders.
serving_input = list(
    loaded.signatures["serving_default"].structured_input_signature[1].keys()
)[0]
print("Serving function input:", serving_input)

# Similarly, find the serving function's output name from 'structured_outputs'
serving_output = list(
    loaded.signatures["serving_default"].structured_outputs.keys()
)[0]
print("Serving function output:", serving_output)

# Retrieve the input name for our locally loaded model 
# (the name of the first layer or tensor in the model graph).
input_name = local_model.input.name
print("Model input name:", input_name)

# Retrieve the output name for our locally loaded model
# (the name of the final layer in the model graph).
output_name = local_model.output.name
print("Model output name:", output_name)

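A prediction request carries each image as base64 text under the serving function's input name. Vertex AI and TensorFlow Serving use a `{"b64": ...}` wrapper to mark a JSON value as binary data that should be decoded before it reaches the `tf.string` input. The standard-library sketch below builds one such instance. It assumes the input name printed above is `bytes_inputs`, the argument name of `serving_fn`. The JPEG bytes are a truncated placeholder, not a real image.

```python
import base64
import json

SERVING_INPUT = "bytes_inputs"  # assumption: substitute the serving_input value printed above


def make_instance(jpeg_bytes, input_name=SERVING_INPUT):
    """Wrap raw JPEG bytes as one JSON-safe prediction instance."""
    b64_text = base64.b64encode(jpeg_bytes).decode("utf-8")
    return {input_name: {"b64": b64_text}}


# Placeholder bytes: a JPEG start-of-image marker, not a decodable image.
fake_jpeg = b"\xff\xd8\xff\xe0"
instance = make_instance(fake_jpeg)
print(json.dumps(instance))  # {"bytes_inputs": {"b64": "/9j/4A=="}}
```

A list of such instances is what the deployed endpoint's predict or explain call receives.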

In [None]:
# Create explanation parameters for a Vertex AI model deployment.
# Here, we're configuring the explanation method to be "integrated gradients"
# and specifying that it should use 50 steps in the integration process.
parameters = aiplatform.explain.ExplanationParameters(
    {
        "integrated_gradients_attribution": {
            "step_count": 50
        }
    }
)

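`step_count` sets how many points Integrated Gradients samples along the straight-line path from a baseline $x'$ (commonly an all-black image) to the input $x$. With $m$ steps and a midpoint Riemann sum, the attribution for feature $i$ of a model output $F$ is approximately

$$\mathrm{IG}_i(x) \approx (x_i - x'_i)\,\frac{1}{m}\sum_{k=1}^{m} \frac{\partial F}{\partial x_i}\Big(x' + \tfrac{k - 1/2}{m}\,(x - x')\Big)$$

and the attributions should sum to roughly $F(x) - F(x')$ (the completeness property). The NumPy sketch below applies this formula to a toy quadratic function with a known gradient. It illustrates the math only and is not how Vertex AI implements the method; the service's quadrature rule may differ.

```python
import numpy as np


def integrated_gradients(grad_fn, x, baseline, step_count=50):
    """Approximate Integrated Gradients with a midpoint Riemann sum.

    grad_fn(z) must return dF/dz at point z (same shape as z).
    """
    x = np.asarray(x, dtype=np.float64)
    baseline = np.asarray(baseline, dtype=np.float64)
    alphas = (np.arange(step_count) + 0.5) / step_count  # midpoints of step_count equal intervals
    avg_grad = np.mean([grad_fn(baseline + a * (x - baseline)) for a in alphas], axis=0)
    return (x - baseline) * avg_grad


# Toy differentiable "model": F(z) = sum(w * z**2), with gradient 2 * w * z.
w = np.array([1.0, -2.0, 0.5])


def F(z):
    return float(np.sum(w * z**2))


def grad_F(z):
    return 2 * w * z


x = np.array([0.9, 0.4, 0.2])
baseline = np.zeros_like(x)  # all-zero baseline, analogous to a black image

attributions = integrated_gradients(grad_F, x, baseline, step_count=50)
print(attributions)
print(attributions.sum(), F(x) - F(baseline))  # completeness: equal up to rounding
```

Larger `step_count` values tighten the approximation at the cost of more gradient evaluations per explanation.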

## Upload the model

Next, upload your model to a `Model` resource using the `Model.upload()` method, with the following parameters:

- `display_name`: The human-readable name for the `Model` resource.
- `artifact_uri`: The Cloud Storage location of the trained model artifacts.
- `serving_container_image_uri`: The serving container image.
- `sync`: Whether to execute the upload synchronously (block until done) or asynchronously.
- `explanation_parameters`: Parameters that configure how explanations are computed for the `Model`'s predictions.
- `explanation_metadata`: Metadata describing the `Model`'s input and output for explanation.

If the `upload()` method is run asynchronously, you can subsequently block until completion with the `wait()` method.

In [None]:
# This cell sets the model name and defines the explanation metadata: which
# input tensor the attributions are computed for and which output tensor is
# explained. Vertex AI uses this metadata to apply Integrated Gradients (or
# another explanation method) to the model.

MODEL_NAME = "flower_classifier_v1"

# The input metadata dictionary specifies the model's input tensor name
# and the data modality (in this case, an image).
INPUT_METADATA = {
    "input_tensor_name": CONCRETE_INPUT,
    "modality": "image"
}

# The output metadata dictionary specifies the name of the model's output tensor.
OUTPUT_METADATA = {
    "output_tensor_name": serving_output
}

# ExplanationMetadata.InputMetadata and OutputMetadata help the Vertex AI
# explanation service understand how to pass inputs to the model and interpret
# its outputs for providing explanations.
input_metadata = aiplatform.explain.ExplanationMetadata.InputMetadata(INPUT_METADATA)
output_metadata = aiplatform.explain.ExplanationMetadata.OutputMetadata(OUTPUT_METADATA)

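# Create the ExplanationMetadata object with the defined input and output metadata.
# The input key "image" is also the key under which feature attributions are
# returned in each explanation.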
metadata = aiplatform.explain.ExplanationMetadata(
    inputs={"image": input_metadata},
    outputs={"class": output_metadata}
)

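For reference, the explanation parameters and metadata built above together describe the model's `explanationSpec`. The sketch below writes the same configuration as a plain dictionary using the camelCase field names of the Vertex AI REST API. The tensor names are hypothetical placeholders, since the real values (`CONCRETE_INPUT`, `serving_output`) depend on your exported SavedModel.

```python
import json

# Hypothetical placeholders; substitute the names read from your SavedModel.
input_tensor_name = "example_input_tensor"
output_tensor_name = "example_output_tensor"

explanation_spec = {
    "parameters": {
        # Same method and step count as the ExplanationParameters above
        "integratedGradientsAttribution": {"stepCount": 50},
    },
    "metadata": {
        # "image" is the key under which attributions come back per explanation
        "inputs": {
            "image": {"inputTensorName": input_tensor_name, "modality": "image"},
        },
        # "class" labels the explained output tensor
        "outputs": {
            "class": {"outputTensorName": output_tensor_name},
        },
    },
}

print(json.dumps(explanation_spec, indent=2))
```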

**NOTE** This can take a few minutes to run. 

In [None]:
# Initialize Vertex AI with the given project and staging bucket.
aiplatform.init(project=PROJECT, staging_bucket=BUCKET)

# Upload a model to Vertex AI using the provided model artifacts and configuration.
#  - display_name: Name to display in the Vertex AI console.
#  - artifact_uri: Location of the SavedModel directory in GCS.
#  - serving_container_image_uri: Docker image to use for model serving.
#  - explanation_parameters & explanation_metadata: Configuration for AI Explanations.
#  - sync=False: Makes the upload asynchronous, allowing the code to continue running while
#    the upload operation proceeds in the background.

model = aiplatform.Model.upload(
    display_name=MODEL_NAME,
    artifact_uri=OUTDIR,
    serving_container_image_uri="us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-11:latest",
    explanation_parameters=parameters,
    explanation_metadata=metadata,
    sync=False,
)

# Wait for the asynchronous upload operation to complete.
model.wait()


## Deploy the model

Next, deploy your model for online prediction. To deploy the model, call the `deploy` method with the following parameters:

- `deployed_model_display_name`: A human-readable name for the deployed model.
- `traffic_split`: The percentage of the endpoint's traffic that goes to each model, specified as a dictionary of one or more key/value pairs. If this is the only model, specify `{"0": 100}`, where `"0"` refers to the model being deployed and `100` means 100% of the traffic. If the endpoint already has deployed models and traffic is to be split among them, specify `{"0": percent, model_id: percent, ...}`, where `model_id` is the ID of a model already deployed to the endpoint. The percentages must add up to 100.
- `machine_type`: The type of machine to use for serving.
- `min_replica_count`: The minimum number of compute instances to keep provisioned.
- `max_replica_count`: The maximum number of compute instances to scale to. In this tutorial, only one instance is provisioned.
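The `traffic_split` rules above can be checked before deploying. The helper below is a hypothetical sketch, not part of the Vertex AI SDK (which performs its own validation), and the existing model ID `"1234567890"` is a made-up placeholder.

```python
def validate_traffic_split(traffic_split):
    """Check a traffic_split mapping before passing it to model.deploy().

    Key "0" denotes the model being deployed; any other key is the ID of a
    model already deployed to the endpoint. Values are integer percentages
    that must add up to exactly 100.
    """
    if not traffic_split:
        raise ValueError("traffic_split must contain at least one entry")
    for key, percent in traffic_split.items():
        # bool is a subclass of int, so reject it explicitly
        if isinstance(percent, bool) or not isinstance(percent, int):
            raise ValueError(f"percentage for {key!r} must be an int, got {percent!r}")
        if not 0 <= percent <= 100:
            raise ValueError(f"percentage for {key!r} out of range: {percent}")
    total = sum(traffic_split.values())
    if total != 100:
        raise ValueError(f"percentages add up to {total}, not 100")
    return traffic_split


def canary_split(existing_model_id, new_model_percent):
    """Send new_model_percent of traffic to the new model ("0"), the rest to an existing one."""
    return validate_traffic_split(
        {"0": new_model_percent, existing_model_id: 100 - new_model_percent}
    )


# Single model on the endpoint, as in this lab
print(validate_traffic_split({"0": 100}))
# Hypothetical existing model ID, for illustration only
print(canary_split("1234567890", 10))
```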

**NOTE** This can take a few minutes.  

In [None]:
# Deploy the uploaded model to an endpoint in Vertex AI. This will create a serving
# endpoint that you can send prediction requests to.
endpoint = model.deploy(
    deployed_model_display_name=MODEL_NAME,  # Name used to identify the deployed model
    traffic_split={"0": 100},               # Allocate 100% of the traffic to this model
    machine_type="n1-standard-4",           # VM configuration for serving
    min_replica_count=1,                    # Minimum number of model server replicas
    max_replica_count=1,                    # Maximum number of model server replicas
)


### Prepare the request content
You will send the flower image as a compressed JPEG image rather than as raw uncompressed bytes:

- `mpimg.imsave`: Write the uncompressed image to disk as a compressed JPEG image.
- `tf.io.read_file`: Read the compressed JPEG image back into memory as raw bytes.
- `base64.b64encode`: Encode the raw bytes as a Base64 string.
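The three steps above can also be done entirely in memory, without writing `tmp.jpg`. This sketch assumes Pillow is installed (matplotlib relies on it to write JPEG files) and uses a synthetic gradient image, `demo_image`, in place of `test_image`, so running it does not overwrite the notebook's variables. It also shows why compression matters: Base64 adds about one third to the payload size, so encoding the compact JPEG rather than the raw pixels keeps the request small.

```python
import base64
import io

import numpy as np
from PIL import Image

# Hypothetical stand-in for test_image: a smooth 224x224 RGB gradient in [0, 1]
h, w = 224, 224
yy, xx = np.mgrid[0:h, 0:w]
demo_image = np.stack([xx / (w - 1), yy / (h - 1), np.full((h, w), 0.5)], axis=-1)

# Convert floats in [0, 1] to uint8 pixels, then JPEG-compress into a buffer
pixels = (np.clip(demo_image, 0.0, 1.0) * 255).round().astype(np.uint8)
buffer = io.BytesIO()
Image.fromarray(pixels).save(buffer, format="JPEG")
jpeg_bytes = buffer.getvalue()

# Base64 turns arbitrary bytes into ASCII text that can be embedded in JSON
demo_b64str = base64.b64encode(jpeg_bytes).decode("utf-8")

print("raw uint8 bytes:", pixels.nbytes)
print("JPEG bytes:     ", len(jpeg_bytes))
print("Base64 chars:   ", len(demo_b64str))
```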

In [None]:
# Load the evaluation dataset from the EVAL_DATA_PATH directory
eval_ds = load_dataset(EVAL_DATA_PATH)

# Extract 5 samples from the evaluation dataset into NumPy arrays for easy inspection
x_test, y_test = dataset_to_numpy(eval_ds, 5)

# Pick one image from the eval dataset
test_image = x_test[0]

# Save this image as a JPEG file on the local disk for reference
mpimg.imsave("tmp.jpg", test_image)


Read the JPEG image and encode it with Base64 to send it to the model endpoint. Send the encoded image to the endpoint with `endpoint.explain`, then parse the response for the prediction and explanation. Full documentation for `endpoint.explain` can be found [here](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/projects.locations.endpoints/explain).
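The REST form of the explain call (documented at the link above) takes a JSON body whose `instances` field holds one entry per image. This standard-library sketch assembles such a body; `example_serving_input` and the placeholder bytes are hypothetical stand-ins for `serving_input` and the contents of `tmp.jpg`.

```python
import base64
import json

# Hypothetical stand-ins for serving_input and the JPEG file contents
example_serving_input = "bytes_inputs"
example_b64str = base64.b64encode(b"\xff\xd8 placeholder JPEG bytes").decode("utf-8")

# One instance per image: {input_name: {"b64": <Base64 text>}}.
# The {"b64": ...} wrapper marks the value as binary data, so the TensorFlow
# serving container decodes it back to raw bytes before the serving function.
request_body = {
    "instances": [
        {example_serving_input: {"b64": example_b64str}},
    ]
}

payload = json.dumps(request_body)
print(payload)
```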

In [None]:
# Read the local JPEG file (tmp.jpg) and convert it into a bytes tensor
bytes_data = tf.io.read_file("tmp.jpg")

# Convert the bytes tensor into a Python bytes object and Base64-encode it.
# Decode the Base64 bytes to a str so that the payload is JSON-serializable.
b64str = base64.b64encode(bytes_data.numpy()).decode("utf-8")

# The Vertex AI prediction/explanation API takes instances as a list of
# dictionaries. Each dictionary maps the serving function's input name to a
# dictionary whose "b64" key holds the Base64-encoded image.
instances_list = [{serving_input: {"b64": b64str}}]

# Send the instances list to the deployed endpoint and request an explanation.
# This method returns both the prediction results and explanations.
response = endpoint.explain(instances_list)

# Print the full response, including predictions and explanation metadata.
print(response)


### Visualize feature attributions from Integrated Gradients
Query the response to get the predictions and feature attributions, then visualize them with Matplotlib.

In [None]:
"""
This cell demonstrates how to parse the prediction and explanation
(response) returned by the Vertex AI endpoint, then visualize both
the original image and its explanation mask (feature attributions).

We do the following:
1. Identify the predicted class by finding the maximum confidence.
2. Decode the original image from Base64.
3. Decode the explanation mask from Base64.
4. Plot the original image, the explanation mask, and an overlay of the mask
    on the original image.
"""

import io
from io import BytesIO

import base64  # Needed for decoding the explanation mask and original image
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np

# Classes correspond to: daisy, dandelion, roses, sunflowers, tulips
CLASSES = [
    "daisy",
    "dandelion",
    "roses",
    "sunflowers",
    "tulips",
]

# Loop through each prediction in the endpoint's response
for prediction in response.predictions:
    # Identify the predicted class by picking the index with the highest confidence score
    label_index = np.argmax(prediction)
    class_name = CLASSES[label_index]
    confidence_score = prediction[label_index]

    # Print out the predicted class and confidence
    print(f"Predicted class: {class_name}")
    print(f"Confidence score: {confidence_score}")

    # Decode the original image (already stored as a Base64 string in b64str)
    image_bytes = base64.b64decode(b64str)
    image_file = BytesIO(image_bytes)
    img = mpimg.imread(image_file, format="JPG")  # Load as an RGB array

# Loop through each explanation in the endpoint's response
for explanation in response.explanations:
    # Feature attributions are stored in explanation.attributions[0].feature_attributions
    attributions = dict(explanation.attributions[0].feature_attributions)

    # Identify which class this explanation refers to
    xai_label_index = explanation.attributions[0].output_index[0]
    xai_class_name = CLASSES[xai_label_index]

    # The explanation mask is stored under "b64_jpeg" in the "attributions" dictionary
    xai_b64str = attributions["image"]["b64_jpeg"]

    # Decode the explanation mask into a viewable image
    xai_image_bytes = base64.b64decode(xai_b64str)
    xai_image_file = io.BytesIO(xai_image_bytes)
    xai_img = mpimg.imread(xai_image_file, format="JPG")

# Plot the original image, the feature attribution mask, and an overlay
fig = plt.figure(figsize=(13, 18))

# Plot the original image
fig.add_subplot(1, 3, 1)
plt.title("Input Image")
plt.imshow(img)
plt.axis("off")

# Plot the feature attribution mask
fig.add_subplot(1, 3, 2)
plt.title("Feature Attribution Mask")
plt.imshow(xai_img)
plt.axis("off")

# Plot the overlay of mask on the original image
fig.add_subplot(1, 3, 3)
plt.title("Overlayed Attribution Mask")
plt.imshow(img)
plt.imshow(xai_img, alpha=0.6)  # Semi-transparent overlay
plt.axis("off")

plt.show()

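The endpoint already returns a rendered mask (`b64_jpeg`), so you do not need to post-process attributions yourself. For intuition only, the hypothetical sketch below shows one common way raw per-pixel attributions can be turned into a mask (summing absolute values over color channels and clipping at a high percentile), plus the alpha compositing that `plt.imshow(..., alpha=0.6)` approximately performs when drawing one image over another. Vertex AI's own visualization settings and post-processing may differ; the random attributions and gray image are made-up inputs.

```python
import numpy as np


def attribution_heatmap(attributions, clip_percentile=99.0):
    """Collapse (H, W, C) attributions into an (H, W) map in [0, 1]."""
    magnitude = np.abs(attributions).sum(axis=-1)      # total attribution per pixel
    upper = np.percentile(magnitude, clip_percentile)  # clip rare extreme values
    if upper <= 0:
        return np.zeros_like(magnitude)
    return np.clip(magnitude / upper, 0.0, 1.0)


def blend(base_rgb, overlay_rgb, alpha=0.6):
    """Composite overlay_rgb over an opaque base_rgb (float arrays in [0, 1])."""
    return alpha * overlay_rgb + (1.0 - alpha) * base_rgb


rng = np.random.default_rng(0)
demo_attr = rng.normal(size=(8, 8, 3))  # hypothetical raw attributions
demo_base = np.full((8, 8, 3), 0.5)     # hypothetical gray input image

heat = attribution_heatmap(demo_attr)
# Show the heatmap in the red channel only
heat_rgb = np.stack([heat, np.zeros_like(heat), np.zeros_like(heat)], axis=-1)
composite = blend(demo_base, heat_rgb, alpha=0.6)
print(heat.min(), heat.max(), composite.shape)
```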

## What's next?

To learn more about AI Explanations, check out the following resources:

* [Vertex AI Explanations documentation](https://cloud.google.com/vertex-ai/docs/explainable-ai)
* [Integrated gradients paper](https://arxiv.org/abs/1703.01365)
* [XRAI paper](https://arxiv.org/abs/1906.02825)