```python
"""Demo DWT Perfect reconstruction filter banks

TFDWT: Fast Discrete Wavelet Transform TensorFlow Layers.
Copyright (C) 2025 Kishore Kumar Tarafdar

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
```

In [7]:
!python --version

Python 3.12.9


Check the TensorFlow build and GPU availability

In [1]:
import tensorflow as tf
# Report the TensorFlow version, the CUDA version it was built against, and the
# number of GPUs TensorFlow can see.
print(f"TensorFlow version {tf.__version__}")
# The keys of get_build_info() vary across TensorFlow versions and platforms,
# so look 'cuda_version' up with .get(): a build without that entry then prints
# a placeholder instead of aborting the cell with KeyError.
print("CUDA Version:", tf.sysconfig.get_build_info().get('cuda_version', 'not available (non-CUDA build)'))
# Query the device list once and reuse it for the printout and the cell result.
gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpus))
len(gpus)

2025-04-05 16:54:51.064401: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-05 16:54:51.074067: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743852291.085279  331772 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743852291.088664  331772 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1743852291.097351  331772 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

TensorFlow version 2.19.0
CUDA Version: 12.5.1
Num GPUs Available:  1


1

In [2]:
# Import the TFDWT package and show its version string as the cell result, so
# the recorded outputs in this notebook can be tied to a specific release.
import TFDWT
TFDWT.__version__

'0.0.2'

        Compute FLOPS

In [3]:
# import tensorflow as tf
from tensorflow.python.profiler import model_analyzer, option_builder

# model = tf.keras.applications.Xception(
#     weights='imagenet',
#     input_shape=(150, 150, 3),
#     include_top=False
# ) 

def compute_FLOPS(_):
    """Profile the forward graph of a model and return half its float-op total.

    Args:
        _: Object that ``tf.function`` can trace and that exposes ``.inputs``,
            each entry providing ``shape``, ``dtype`` and ``name`` (in this
            notebook a ``tf.keras.Model``).

    Returns:
        ``total_float_ops // 2`` as reported by the profiler for the graph
        traced with a batch dimension of 1.
        NOTE(review): the halving presumably turns multiply+add pairs into
        multiply-accumulate counts -- confirm that this is the intended unit.
    """
    model = _

    # One TensorSpec per model input: batch dimension pinned to 1, remaining
    # dimensions, dtype and name copied from the symbolic input.
    signature = []
    for tensor in model.inputs:
        spec = tf.TensorSpec(
            shape=(1, *tensor.shape[1:]),
            dtype=tensor.dtype,
            name=tensor.name,
        )
        signature.append(spec)

    # Trace the model into a concrete graph that the profiler can walk.
    traced = tf.function(model, signature)
    graph = traced.get_concrete_function().graph

    # Restrict the profile to float-operation statistics.
    profile_opts = option_builder.ProfileOptionBuilder.float_operation()
    report = model_analyzer.profile(graph, options=profile_opts)
    return report.total_float_ops // 2


# inp = (256, 1)
# layer = DWT1D(wave='haar')
def FLOP_of_layer(inp, layer):
    """Wrap ``layer`` in a single-input Keras model and profile it.

    Args:
        inp: Per-sample input shape passed to ``tf.keras.layers.Input``
            (no batch dimension), e.g. ``(256, 1)``.
        layer: Callable applied to the symbolic input; its result becomes the
            model output.

    Returns:
        The value returned by ``compute_FLOPS`` for the wrapping model.
    """
    symbolic_in = tf.keras.layers.Input(shape=inp)
    wrapped = tf.keras.Model(inputs=symbolic_in, outputs=layer(symbolic_in))
    return compute_FLOPS(wrapped)

# inp = (256, 1)
# layer = DWT1D(wave='haar')
# inp = (128, 2)
# layer = IDWT1D(wave='haar')
# inp = (256,256,1)
# layer = DWT2D(wave='haar')
# inp = (128,128, 4)
# layer = IDWT2D(wave='haar')
# inp = (256, 256, 256, 1)
# layer = DWT3D(wave='haar')
# inp = (128, 128, 128, 8)
# layer = IDWT3D(wave='haar')

# flops = FLOP_of_layer(inp, layer)
# print('FLOPS > ',flops)

    DWT 1D level-1 perfect reconstruction filter bank

In [4]:
from TFDWT.DWTIDWT1Dv1 import DWT1D, IDWT1D
# Chain the level-1 analysis (DWT1D) and synthesis (IDWT1D) layers with the
# 'db3' wavelet on a symbolic (64, 3) input, so that `model` maps x -> xhat.
# NOTE(review): this cell only builds and profiles the graph; it does not
# numerically check that xhat equals x.
x = tf.keras.Input(shape=(64, 3))
subbands = DWT1D(wave='db3')(x)
xhat = IDWT1D(wave='db3')(subbands)
model = tf.keras.Model(inputs=x, outputs=xhat)
model.summary()
# FLOP_of_layer creates a fresh Input of the given shape and calls `model` on
# it, so the whole DWT -> IDWT model is profiled as if it were one layer.
# The cell result is compute_FLOPS' value, i.e. half the profiler's total.
FLOP_of_layer((64, 3), model)

I0000 00:00:1743852292.411626  331772 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 6258 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3070 Ti Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6


Instructions for updating:
This API was designed for TensorFlow v1. See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2.

-max_depth                  10000
-min_bytes                  0
-min_peak_bytes             0
-min_residual_bytes         0
-min_output_bytes           0
-min_micros                 0
-min_accelerator_micros     0
-min_cpu_micros             0
-min_params                 0
-min_float_ops              1
-min_occurrence             0
-step                       -1
-order_by                   float_ops
-account_type_regexes       .*
-start_name_regexes         .*
-trim_name_regexes          
-show_name_regexes          .*
-hide_name_regexes          
-account_displayed_op_only  true
-select                     float_ops
-output                     stdout:


Doc:
scope: The nodes in the model graph are organized by their names, which is hierarchical like filesystem.
flops: Number of float operations. Note: Please re

12288

le:
node name | # float_ops
_TFProfRoot (--/24.58k flops)
  functional_1_1/functional_1/dwt1d_1/matmul (8.19k/8.19k flops)
  functional_1_1/functional_1/dwt1d_1/matmul_1 (8.19k/8.19k flops)
  functional_1_1/functional_1/dwt1d_1/matmul_2 (8.19k/8.19k flops)



    DWT 2D level-1 perfect reconstruction filter bank

In [5]:
from TFDWT.DWTIDWT2Dv1 import DWT2D, IDWT2D
# Chain the level-1 analysis (DWT2D) and synthesis (IDWT2D) layers with the
# 'db3' wavelet on a symbolic (64, 64, 3) input, so that `model` maps x -> xhat.
# NOTE(review): this cell only builds and profiles the graph; it does not
# numerically check that xhat equals x.
x = tf.keras.Input(shape=(64, 64, 3))
subbands = DWT2D(wave='db3')(x)
xhat = IDWT2D(wave='db3')(subbands)
model = tf.keras.Model(inputs=x, outputs=xhat)
model.summary()
# FLOP_of_layer creates a fresh Input of the given shape and calls `model` on
# it, so the whole DWT -> IDWT model is profiled as if it were one layer.
# NOTE(review): the recorded output of this cell shows a profiler total of
# 0 flops with no nodes listed, so the returned 0 should not be read as the
# real cost -- presumably the ops used by the 2D layers carry no float-op
# statistics in the profiler; confirm before quoting this number.
FLOP_of_layer((64, 64, 3), model)


-max_depth                  10000
-min_bytes                  0
-min_peak_bytes             0
-min_residual_bytes         0
-min_output_bytes           0
-min_micros                 0
-min_accelerator_micros     0
-min_cpu_micros             0
-min_params                 0
-min_float_ops              1
-min_occurrence             0
-step                       -1
-order_by                   float_ops
-account_type_regexes       .*
-start_name_regexes         .*
-trim_name_regexes          
-show_name_regexes          .*
-hide_name_regexes          
-account_displayed_op_only  true
-select                     float_ops
-output                     stdout:


Doc:
scope: The nodes in the model graph are organized by their names, which is hierarchical like filesystem.
flops: Number of float operations. Note: Please read the implementation for the math behind it.

Profile:
node name | # float_ops
_TFProfRoot (--/0 flops)



0

    DWT 3D level-1 perfect reconstruction filter bank

In [6]:
from TFDWT.DWTIDWT3Dv1 import DWT3D, IDWT3D
# Chain the level-1 analysis (DWT3D) and synthesis (IDWT3D) layers with the
# 'db3' wavelet on a symbolic (64, 64, 64, 3) input, so that `model` maps
# x -> xhat.
# NOTE(review): this cell only builds and profiles the graph; it does not
# numerically check that xhat equals x.
x = tf.keras.Input(shape=(64, 64, 64, 3))
subbands = DWT3D(wave='db3')(x)
xhat = IDWT3D(wave='db3')(subbands)
model = tf.keras.Model(inputs=x, outputs=xhat)
model.summary()
# FLOP_of_layer creates a fresh Input of the given shape and calls `model` on
# it, so the whole DWT -> IDWT model is profiled as if it were one layer.
# NOTE(review): the recorded output of this cell shows a profiler total of
# 0 flops with no nodes listed, so the returned 0 should not be read as the
# real cost -- presumably the ops used by the 3D layers carry no float-op
# statistics in the profiler; confirm before quoting this number.
FLOP_of_layer((64, 64, 64, 3), model)


-max_depth                  10000
-min_bytes                  0
-min_peak_bytes             0
-min_residual_bytes         0
-min_output_bytes           0
-min_micros                 0
-min_accelerator_micros     0
-min_cpu_micros             0
-min_params                 0
-min_float_ops              1
-min_occurrence             0
-step                       -1
-order_by                   float_ops
-account_type_regexes       .*
-start_name_regexes         .*
-trim_name_regexes          
-show_name_regexes          .*
-hide_name_regexes          
-account_displayed_op_only  true
-select                     float_ops
-output                     stdout:



0


Doc:
scope: The nodes in the model graph are organized by their names, which is hierarchical like filesystem.
flops: Number of float operations. Note: Please read the implementation for the math behind it.

Profile:
node name | # float_ops
_TFProfRoot (--/0 flops)

