<a href="https://colab.research.google.com/github/tokusumi/keras-flops/blob/master/notebooks/flops_calculation_tfkeras.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# FLOPS calculation with tf.keras

Calculate the FLOPS (number of floating-point operations) performed at inference time by a `tf.keras.Sequential` or `tf.keras.Model` instance.

In [1]:
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph

from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Dense, Flatten

## main function

In [2]:
def get_flops(model, batch_size=None):
    """Count the floating-point operations of one forward pass of `model`.

    The model is traced with `tf.function` using input specs whose batch
    dimension is fixed to `batch_size`. Its variables are then frozen into
    constants, and the TF1 profiler's `float_operation` view counts the FLOPs
    of the resulting graph.

    Args:
        model: a tf.keras Model/Sequential whose `model.inputs` is defined.
            Models with several inputs are supported; one spec is built per
            input.
        batch_size: batch dimension used for tracing. None means 1.

    Returns:
        The profiler's `total_float_ops` for the traced forward pass.
    """
    if batch_size is None:
        batch_size = 1

    # Fix the batch dimension so the profiler sees fully defined shapes.
    input_specs = [
        tf.TensorSpec([batch_size] + list(inp.shape[1:]), inp.dtype)
        for inp in model.inputs
    ]
    # A single-input model is called with a bare spec, exactly as before.
    # A multi-input model is called with the list of specs.
    call_args = input_specs[0] if len(input_specs) == 1 else input_specs
    concrete_func = tf.function(model).get_concrete_function(call_args)
    frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)

    run_meta = tf.compat.v1.RunMetadata()
    opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
    flops = tf.compat.v1.profiler.profile(graph=frozen_func.graph,
                                          run_meta=run_meta, cmd='op', options=opts)
    return flops.total_float_ops

## Test with a simple architecture

In [3]:
def build_base_model():
    """Build a minimal classifier: flatten a 32x32x3 input, then apply a single Dense(10) layer with no activation."""
    image_in = Input((32, 32, 3))
    flat = Flatten()(image_in)
    logits = Dense(10)(flat)
    return Model(image_in, logits)

In [4]:
def main(batch_size):
    """Build the baseline model, print its summary and FLOP count, and return it.

    `batch_size` is the batch dimension used when counting FLOPs.
    """
    base = build_base_model()
    base.summary()
    print(f"FLOPS: {get_flops(base, batch_size)}")
    return base

In [5]:
m = main(1)
# Expected FLOPS at batch size 1 (only the Dense layer does arithmetic):
#   32 * 32 * 3 (= 3072) inputs * 10 units * 2 (multiply + add) + 10 (bias adds) = 61450

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
flatten (Flatten)            (None, 3072)              0         
_________________________________________________________________
dense (Dense)                (None, 10)                30730     
Total params: 30,730
Trainable params: 30,730
Non-trainable params: 0
_________________________________________________________________
Instructions for updating:
Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`
FLOPS: 61450


## Demo: LeNet-style CNN on CIFAR-10

In [6]:
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds
from absl import app, flags
from easydict import EasyDict
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout

In [7]:
def build_model(batch_size=1, input_shape=(32, 32, 3), num_classes=10):
    """Build a small CNN classifier (uncompiled).

    Architecture:
        Conv2D(32, 3x3, relu) -> Conv2D(64, 3x3, relu) -> MaxPool(2x2)
        -> Dropout(0.25) -> Flatten -> Dense(128, relu) -> Dropout(0.5)
        -> Dense(num_classes, softmax)

    Args:
        batch_size: unused. It is kept only so existing calls keep working.
            The batch dimension stays dynamic (None), so the same model can be
            profiled at any batch size and trained on batched datasets.
        input_shape: per-example input shape; defaults to 32x32x3 images.
        num_classes: number of softmax output units.

    Returns:
        A tf.keras Model mapping `input_shape` images to class probabilities.
    """
    del batch_size  # intentionally unused; see docstring
    inputs = Input(input_shape)
    x = Conv2D(32, kernel_size=(3, 3), activation='relu')(inputs)
    x = Conv2D(64, (3, 3), activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.25)(x)
    x = Flatten()(x)
    x = Dense(128, activation="relu")(x)
    x = Dropout(0.5)(x)
    outputs = Dense(num_classes, activation="softmax")(x)
    return Model(inputs, outputs)

In [8]:
def load_cifar10(batch=128):
    """Load CIFAR-10 training and test data as batched tf.data pipelines.

    Each example becomes an `(image, label)` pair. The image is cast to
    float32 and divided by 255, which maps uint8 pixels to [0, 1]. Only the
    training split is shuffled.

    Args:
        batch: batch size for both splits.

    Returns:
        EasyDict with `train` and `test` datasets.
    """
    autotune = tf.data.experimental.AUTOTUNE

    def convert_types(data):
        # Scale pixels to float and keep the integer label as-is.
        image = tf.cast(data["image"], tf.float32) / 255
        return image, data["label"]

    splits = tfds.load('cifar10')
    train_ds = (splits['train']
                .map(convert_types, num_parallel_calls=autotune)
                .shuffle(10000)
                .batch(batch)
                .prefetch(autotune))
    test_ds = (splits['test']
               .map(convert_types, num_parallel_calls=autotune)
               .batch(batch)
               .prefetch(autotune))
    return EasyDict(train=train_ds, test=test_ds)

In [9]:
def main(batch_size, pred_batch, epochs=5):
    """Profile the CIFAR-10 CNN's inference FLOPs, then train it.

    Args:
        batch_size: batch size of the training/test tf.data pipelines.
        pred_batch: batch size used when counting inference FLOPs.
        epochs: number of training epochs (default 5).

    Returns:
        The trained tf.keras Model.
    """
    # Load training and test data
    data = load_cifar10(batch_size)
    # Build the CNN model
    model = build_model()
    model.summary()

    # Calculate FLOPS of one forward pass at batch size `pred_batch`
    flops = get_flops(model, pred_batch)
    print(f"FLOPS: {flops / 10**9:.03} G")

    # Train. build_model ends in softmax, so the loss keeps from_logits=False
    # (the default).
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=["accuracy"],
    )
    model.fit(
        data.train,
        epochs=epochs,
        validation_data=data.test,
    )
    return model

In [10]:
# GFLOPS of the 2nd Conv2D ≈ 28 * 28 * 3 * 3 * 32 * 64 * 2 / 10 ** 9
# = 0.028901376  (28x28 output positions x 3x3x32 kernel x 64 filters x 2 for multiply + add)

In [11]:
# Keep the trained model; binding it also stops Jupyter from echoing its repr.
trained_model = main(batch_size=128, pred_batch=1)

Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 30, 30, 32)        896       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 28, 28, 64)        18496     
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 64)        0         
_________________________________________________________________
dropout (Dropout)            (None, 14, 14, 64)        0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 12544)             0         
_________________________________________________________________
dense_1 (Dense)              (None, 128)              

<tensorflow.python.keras.engine.functional.Functional at 0x7fc9cf7e2240>