## Setup

In [None]:
import tensorflow.compat.v2 as tf
import tensorflow_datasets

# Sequence Classification with BERT in TF 2.0

In [None]:
# Graph-level optimizer switches; they are set here, before the model is
# loaded and compiled further down.
graph_optimizer = tf.config.optimizer

# XLA: turn on just-in-time compilation
graph_optimizer.set_jit(True)

# AMP: turn on the automatic mixed precision option through tf.config
graph_optimizer.set_experimental_options({"auto_mixed_precision": True})

## Load BERT Tokenizer

In [None]:
from transformers import BertTokenizer

# Load the pretrained tokenizer published under the "bert-base-cased" name;
# it is handed to glue_convert_examples_to_features when the input pipeline
# is built below.
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

## Input Pipeline

### Load Dataset

In [None]:
# Fetch GLUE/MRPC through TensorFlow Datasets; with_info=True additionally
# returns the metadata object from which the split sizes are read.
data, info = tensorflow_datasets.load("glue/mrpc", with_info=True)

# Number of examples per split (the training count later determines
# steps_per_epoch).
split_info = info.splits
train_examples, valid_examples = (
    split_info["train"].num_examples,
    split_info["validation"].num_examples,
)

### Build Input Pipeline

In [None]:
from transformers import glue_convert_examples_to_features

BATCH_SIZE = 40
MAX_SEQ_LENGTH = 128       # passed as max_length to the feature converter
SHUFFLE_BUFFER_SIZE = 512  # number of examples held in the shuffle buffer

# Prepare dataset for GLUE as a tf.data.Dataset instance
train_dataset = glue_convert_examples_to_features(
    data["train"], tokenizer, max_length=MAX_SEQ_LENGTH, task="mrpc"
)

# Shuffle *before* repeating: a shuffle placed after repeat() mixes examples
# from neighbouring epochs inside the buffer, so some examples can show up
# twice in one epoch while others are missing. The dataset still repeats
# indefinitely; model.fit bounds each epoch through steps_per_epoch.
train_dataset = (
    train_dataset
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .repeat()
    .batch(BATCH_SIZE)
    .prefetch(8)
)

## Build BERT Model

### Load Pre-trained BERT Model

In [None]:
from transformers import TFBertForSequenceClassification

# Load the pretrained "bert-base-cased" checkpoint into a sequence
# classification model.
# NOTE(review): the number of labels is left at the library default --
# presumably 2, which would match MRPC's binary labels; confirm.
model = TFBertForSequenceClassification.from_pretrained("bert-base-cased")

In [None]:
# The loss consumes raw logits (from_logits=True) together with integer
# class labels; accuracy is tracked under the name "accuracy".
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
acc = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")

# Adam wrapped in a dynamic loss-scaling optimizer, to go with the mixed
# precision option enabled earlier in the notebook.
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    tf.keras.optimizers.Adam(learning_rate=3e-5),
    "dynamic",
)

model.compile(optimizer=opt, loss=loss, metrics=[acc])

## Train BERT Model

In [None]:
# The training dataset repeats without end, so the step count is what closes
# an epoch: use as many full batches as fit into the training split.
steps_per_epoch = train_examples // BATCH_SIZE
history = model.fit(train_dataset, epochs=2, steps_per_epoch=steps_per_epoch)

In [None]:
from tensorflow.python.eager import profiler
from tensorflow.core.protobuf import trace_events_pb2
from google.protobuf.json_format import MessageToDict

In [None]:
# Profile a short training run (3 steps, 1 epoch).
profiler.start()
try:
    model.fit(train_dataset, steps_per_epoch=3, epochs=1, verbose=2)
finally:
    # Always stop the profiler, even when training fails. Otherwise it stays
    # active and a later profiler.start() (for instance when this cell is
    # re-run) is expected to fail because a profiler is already running.
    model_profile = profiler.stop()

# profiler.stop() hands back the serialized trace; parse it into a Trace
# protobuf message.
profile_pb = trace_events_pb2.Trace()
profile_pb.ParseFromString(model_profile)

# Plain-dict view of the trace, consumed by the timing analysis below.
profile_dict = MessageToDict(profile_pb)

# Accumulator for the profile breakdown. Every kernel category gets two
# entries: "<category>" holds the summed event durations and
# "<category>_events" the distinct kernel names seen for it. "total" holds
# the summed duration over all categories.
_KERNEL_CATEGORIES = ("hmma", "other_fp16", "sgemm", "copy", "xla", "others")

timing_dict = {}
for _category in _KERNEL_CATEGORIES:
    timing_dict[_category] = 0
    timing_dict[_category + "_events"] = []
timing_dict["total"] = 0

# Only events carrying this deviceId are analysed.
# NOTE(review): presumably device 1 is the GPU of interest in this trace --
# confirm against the device table of the profile.
PROFILED_DEVICE_ID = 1


def _kernel_category(kernel_name):
    """Return the ``timing_dict`` category for a lower-cased kernel name.

    The rules are checked in order and the first match wins:

    * ``"hmma"``       - tensor core (HMMA) kernels: the name contains "hmma",
      or contains "884" and is not an XLA fusion op. Fusion ops are named
      "fusion_<n>", so a fusion whose number happens to contain 884 must not
      be mistaken for a tensor core kernel.
    * ``"other_fp16"`` - any other name containing "fp16".
    * ``"sgemm"``      - FP32 GEMM kernels.
    * ``"copy"``       - transfer events ("copy" or "cpy" in the name).
    * ``"xla"``        - XLA fusion ops.
    * ``"others"``     - everything else.
    """
    is_xla_fusion = kernel_name.startswith("fusion")
    if "hmma" in kernel_name or ("884" in kernel_name and not is_xla_fusion):
        return "hmma"
    if "fp16" in kernel_name:
        return "other_fp16"
    if "sgemm" in kernel_name:
        return "sgemm"
    if "copy" in kernel_name or "cpy" in kernel_name:
        return "copy"
    if "fusion" in kernel_name:
        return "xla"
    return "others"


# A trace without any event has no "traceEvents" entry at all.
for event in profile_dict.get("traceEvents", []):
    try:
        # Numeric fields are converted with int() before use.
        if int(event["deviceId"]) != PROFILED_DEVICE_ID:
            continue
        event_name = event["name"].lower()
        event_time = int(event["durationPs"])
    except (KeyError, ValueError):
        # Records that lack one of the required fields (or carry a malformed
        # number) cannot be attributed; skip just those instead of hiding
        # every possible error.
        continue

    category = _kernel_category(event_name)
    timing_dict[category] += event_time

    # Remember each distinct kernel name once per category.
    category_events = timing_dict[category + "_events"]
    if event_name not in category_events:
        category_events.append(event_name)

    timing_dict["total"] += event_time
    
def _percent(part, whole):
    """Return ``part`` as a percentage of ``whole`` rounded to one decimal.

    Yields 0.0 when ``whole`` is zero so that an empty profile does not raise
    ZeroDivisionError.
    """
    return round(part / whole * 100, 1) if whole else 0.0


print("= type (num_type) % time =")
# consider compute time only: transfer ("copy") and unclassified ("others")
# events are left out of the denominator and therefore out of the report
total = timing_dict["total"] - timing_dict["copy"] - timing_dict["others"]

if not total:
    print("No compute kernels were recorded for the profiled device.")

# (label shown in the report, key prefix used in timing_dict)
for label, key in (("hmma", "hmma"), ("fp16", "other_fp16"), ("sgemm", "sgemm"), ("xla", "xla")):
    print("- " + label + " (", len(timing_dict[key + "_events"]), ")\t", _percent(timing_dict[key], total))

# Sum of the durationPs values of the reported categories.
print("Total time:", total)