# FLOPs calculation

## Import libraries

In [8]:
import os
import sys

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.profiler import model_analyzer
from tensorflow.python.profiler import option_builder

sys.path.append(os.path.abspath(os.path.join('..')))
from Models import CNNResNetSWAttention

## Functions

In [9]:
def get_flops(model, input_shape=(60, 36, 1), batch_size=1):
    """Count the floating-point operations of one forward pass of ``model``.

    The model is traced once inside a ``tf.function`` with two dummy inputs,
    passed as ``[input_right, input_left]``. The resulting graph is then run
    through the TensorFlow profiler with the float-operation options.

    Args:
        model: Callable (e.g. a Keras model) that takes a list of two tensors
            ``[input_right, input_left]``.
        input_shape: Per-sample shape of each input, without the batch
            dimension. Defaults to ``(60, 36, 1)``.
        batch_size: Batch size used for profiling. Defaults to 1, so the
            result is the cost of a single sample.

    Returns:
        The profiler's ``total_float_ops`` for the traced forward pass.
    """
    full_shape = (batch_size, *input_shape)

    # Dummy inputs for the forward pass. Only their shape/dtype are used when
    # tracing the graph; their random values do not affect the op count.
    input_right = tf.random.normal(full_shape)
    input_left = tf.random.normal(full_shape)

    @tf.function
    def forward_pass(right, left):
        return model([right, left])

    # Trace to a concrete graph, then let the profiler sum float ops over it.
    concrete_fn = forward_pass.get_concrete_function(input_right, input_left)
    graph_info = model_analyzer.profile(
        concrete_fn.graph,
        options=option_builder.ProfileOptionBuilder.float_operation(),
    )

    return graph_info.total_float_ops

## Create model and calculate FLOPs

In [10]:
# Build the model under test from the project's Models package.
model = CNNResNetSWAttention.create_resnet18_sw__attention(
    input_shape=(60, 36, 1),
    bn=False,
    first_dense_units=512,
    fc_layer_units=[2048, 1024],
    debug=False,
)

flops = get_flops(model)

# Assumption: one multiply-accumulate (MAC) is counted as 2 FLOPs.
macs = flops / 2

# Report compute in giga-units and parameter count in mega-units.
giga, mega = 1e+9, 1e+6
print(f"MACs: {macs / giga:,} G")
print(f"FLOPs: {flops / giga:,} G")
print(f"Params: {model.count_params() / mega:,} M")

MACs: 4.895054337 G
FLOPs: 9.790108674 G
Params: 28.659722 M
