## Setup

In [None]:
import multiprocessing

import tensorflow
print("TensorFlow version:", tensorflow.__version__)

import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

# Pets Classification with TensorFlow

In [None]:
# Enable XLA: switch on JIT compilation through TensorFlow's optimizer config.
tf.config.optimizer.set_jit(True)

# Enable AMP (automatic mixed precision): currently disabled. Uncommenting the
# line below would set the global Keras policy to 'mixed_float16'.
#tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

In [None]:
import tensorflow.keras.layers as layers
from tensorflow.keras.applications.resnet50 import ResNet50

def create_model(img_size=(224,224), num_class=2, train_base=True):
    """Build a ResNet50-based image classifier.

    Args:
        img_size: sequence whose first two entries are the input height and
            width; the input tensor has shape (height, width, 3).
        num_class: number of units in the final softmax layer.
        train_base: value assigned to the ResNet50 backbone's `trainable` flag.

    Returns:
        A `tf.keras.models.Model` mapping float16 images to softmax
        probabilities computed in float32.
    """
    height, width = img_size[0], img_size[1]
    # the network is fed float16 images
    image_input = layers.Input(shape=(height, width, 3), dtype=tf.float16)

    # ImageNet-pretrained backbone without its classification head
    backbone = ResNet50(include_top=False,
                        weights="imagenet",
                        input_tensor=image_input)
    backbone.trainable = train_base

    pooled = layers.GlobalAveragePooling2D()(backbone.output)

    # the softmax head is pinned to float32 explicitly (the original author
    # notes softmax only accepts float32 and suspects this is a bug)
    classifier = layers.Dense(num_class, activation="softmax", dtype=tf.float32)
    return tf.keras.models.Model(inputs=image_input, outputs=classifier(pooled))

In [None]:
# Load the train and test splits as (image, label) pairs together with the
# dataset metadata.
(train_dataset, test_dataset), info = tfds.load(
    name="oxford_iiit_pet:3.*.*",
    with_info=True,
    as_supervised=True,
    shuffle_files=True,
    split=["train", "test"],
)

# Number of label classes and example counts per split, read from the metadata.
num_class = info.features["label"].num_classes
num_train, num_test = (
    info.splits[split_name].num_examples for split_name in ("train", "test")
)

In [None]:
# Target size passed to tf.image.resize for every example.
IMG_SIZE = (224, 224)

@tf.function
def format_train_example(image, label):
    """Preprocess and augment one training example.

    Args:
        image: image tensor as yielded by the dataset (cast to float32 here).
        label: integer class label.

    Returns:
        A tuple `(image, one_hot_label)`: the image rescaled, resized to
        IMG_SIZE, randomly augmented and cast to float16; the label one-hot
        encoded over `num_class` classes.
    """
    image = tf.cast(image, tf.float32)
    # for 8-bit pixel values this maps [0, 255] onto [-1, 1]
    image = (image/127.5) - 1
    image = tf.image.resize(image, IMG_SIZE)
    # perform image augmentation with tf.image.
    # random_flip_left_right mirrors each image with probability 0.5;
    # flip_left_right would mirror *every* image, which is not augmentation.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, 0.1)
    # return images as float16
    image = tf.cast(image, tf.float16)
    return image, tf.one_hot(label, num_class)

@tf.function
def format_eval_example(image, label):
    """Preprocess one evaluation example (no augmentation).

    Args:
        image: image tensor as yielded by the dataset (cast to float32 here).
        label: integer class label.

    Returns:
        A tuple `(image, one_hot_label)`: the image rescaled, resized to
        IMG_SIZE and cast to float16; the label one-hot encoded over
        `num_class` classes.
    """
    # rescale in float32 first, then resize
    scaled = tf.cast(image, tf.float32) / 127.5 - 1
    resized = tf.image.resize(scaled, IMG_SIZE)
    one_hot_label = tf.one_hot(label, num_class)
    # images leave the pipeline as float16
    return tf.cast(resized, tf.float16), one_hot_label

In [None]:
# Input-pipeline settings.
BATCH_SIZE = 80
N_THREADS = multiprocessing.cpu_count()  # parallel map calls = CPU count
PREFETCH_COUNT = 8

# Training pipeline: shuffle -> repeat indefinitely -> preprocess/augment in
# parallel -> batch -> prefetch.
train_dataset = (
    train_dataset
    .shuffle(1024)
    .repeat(-1)
    .map(format_train_example, num_parallel_calls=N_THREADS)
    .batch(BATCH_SIZE)
    .prefetch(PREFETCH_COUNT)
)

In [None]:
# Evaluation pipeline: preprocess in parallel -> repeat indefinitely -> batch.
test_dataset = (
    test_dataset
    .map(format_eval_example, num_parallel_calls=N_THREADS)
    .repeat(-1)
    .batch(BATCH_SIZE)
)

In [None]:
# Build the classifier with a trainable backbone and compile it for
# one-hot (categorical) targets.
model = create_model(img_size=IMG_SIZE, num_class=num_class, train_base=True)
opt = tf.keras.optimizers.Adam()

model.compile(
    optimizer=opt,
    loss="categorical_crossentropy",
    metrics=["acc"],
)

# Uncomment to print the layer-by-layer architecture:
#model.summary()

In [None]:
# Whole batches available per pass over each split (partial batch dropped
# by the floor division).
steps_per_epoch, steps_test = (
    example_count // BATCH_SIZE for example_count in (num_train, num_test)
)

In [None]:
# Train for 3 epochs; the dataset repeats forever, so the epoch length is
# fixed explicitly via steps_per_epoch.
model.fit(
    x=train_dataset,
    epochs=3,
    steps_per_epoch=steps_per_epoch,
    verbose=1,
)

In [None]:
from tensorflow.python.eager import profiler
from tensorflow.core.protobuf import trace_events_pb2
from google.protobuf.json_format import MessageToDict

In [None]:
# Profile a short training run (3 steps, 1 epoch).
profiler.start()
try:
    model.fit(train_dataset, steps_per_epoch=3, epochs=1, verbose=2)
finally:
    # Always stop the profiler, even if training raises; otherwise it would be
    # left running (and re-running this cell would presumably fail on start()).
    model_profile = profiler.stop()

# The profiler result is parsed as a serialized Trace protobuf ...
profile_pb = trace_events_pb2.Trace()
profile_pb.ParseFromString(model_profile)

# ... and converted to a plain dict for the analysis below.
profile_dict = MessageToDict(profile_pb)

# Per-category accumulated duration (picoseconds, from "durationPs") and the
# list of distinct event names seen in that category.
timing_dict = {
    "hmma": 0,
    "hmma_events": [],
    "other_fp16": 0,
    "other_fp16_events": [],
    "sgemm": 0,
    "sgemm_events": [],
    "copy": 0,
    "copy_events": [],
    "xla": 0,
    "xla_events": [],
    "others": 0,
    "others_events": [],
    "total": 0
}


def _classify_event(event_name):
    """Return the timing_dict category key for a lower-cased event name.

    Rules are checked in priority order; the first match wins.
    """
    # tensor core (HMMA) events. Fusion ops whose name merely contains "884"
    # (the original only excluded the exact name "fusion_884", but e.g.
    # "fusion_1884" collides the same way) are not tensor-core kernels.
    if "hmma" in event_name or ("884" in event_name
                                and not event_name.startswith("fusion")):
        return "hmma"
    # FP16 events
    if "fp16" in event_name:
        return "other_fp16"
    # FP32 GEMM events
    if "sgemm" in event_name:
        return "sgemm"
    # Transfer events
    if "copy" in event_name or "cpy" in event_name:
        return "copy"
    # XLA fusion ops
    if "fusion" in event_name:
        return "xla"
    # all other events
    return "others"


# An empty trace has no "traceEvents" key at all, so default to no events.
for event in profile_dict.get("traceEvents", []):
    # Events lacking any of these fields cannot be attributed and are skipped
    # explicitly (instead of relying on a blanket `except: pass`).
    if "deviceId" not in event or "durationPs" not in event or "name" not in event:
        continue
    try:
        device_id = int(event["deviceId"])
        event_time = int(event["durationPs"])
    except (TypeError, ValueError):
        # malformed numeric field: skip this event only
        continue
    # Only events on device id 1 are analysed
    # NOTE(review): presumably the GPU compute device - confirm for this setup.
    if device_id != 1:
        continue

    event_name = str(event["name"]).lower()
    category = _classify_event(event_name)
    timing_dict[category] += event_time
    seen_names = timing_dict[category + "_events"]
    if event_name not in seen_names:
        seen_names.append(event_name)
    timing_dict["total"] += event_time
    
print("= type (num_type) % time =")
# consider compute time only: transfer ("copy") and uncategorised ("others")
# events are subtracted from the total and therefore not reported below
total = timing_dict["total"] - timing_dict["copy"] - timing_dict["others"]


def _percent_of_total(duration):
    """Return `duration` as a percentage of `total`, rounded to 1 decimal.

    Returns 0.0 when no compute events were captured (total == 0) instead of
    raising ZeroDivisionError.
    """
    return round(duration / total * 100, 1) if total else 0.0


# (printed label, timing_dict key) for each reported compute category
for label, key in (("hmma", "hmma"),
                   ("fp16", "other_fp16"),
                   ("sgemm", "sgemm"),
                   ("xla", "xla")):
    print("- " + label + " (", len(timing_dict[key + "_events"]), ")\t",
          _percent_of_total(timing_dict[key]))
print("Total time:", total)