In [1]:
# import os
import pandas as pd
import numpy as np

import torch
from torchvision.io import read_image

from helpers.DatasetProcess import dataset_to_df, search_df
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from pywt import WaveletPacket2D as wp2d
import pywt
import ptwt
import cv2
from skimage.color import rgb2gray

from einops import rearrange, repeat

In [2]:
path = '../../../data/real-vs-fake'
relative_paths = ["/train/real", "/train/fake", "/test/real", "/test/fake", "/valid/real", "/valid/fake"]
paths_classes = ["REAL", "FAKE", "REAL", "FAKE", "REAL", "FAKE"]

try:
    df_all = pd.read_csv(f'{path}/df_all.csv')
    df_train = pd.read_csv(f'{path}/df_train.csv')
    df_val = pd.read_csv(f'{path}/df_val.csv')
    df_test = pd.read_csv(f'{path}/df_test.csv')
    classes_stats = pd.read_csv(f'{path}/classes_stats.csv')
except FileNotFoundError:
    df_all, df_train, df_val, df_test, classes_stats = dataset_to_df(
        path, relative_paths, paths_classes, 0.8, 0.14, 0.06)
    df_all.to_csv(f'{path}/df_all.csv', index=False)
    df_train.to_csv(f'{path}/df_train.csv', index=False)
    df_val.to_csv(f'{path}/df_val.csv', index=False)
    df_test.to_csv(f'{path}/df_test.csv', index=False)
    classes_stats.to_csv(f'{path}/classes_stats.csv')

classes_stats

Unnamed: 0.1,Unnamed: 0,REAL,FAKE,Total
0,Training,56000,56000,112000
1,Validation,9800,9800,19600
2,Testing,4200,4200,8400
3,Row_Total,70000,70000,140000


In [3]:
index=10
img = read_image(df_all.iloc[index, 0]).type(torch.float32).unsqueeze(0)
img.shape


torch.Size([1, 3, 256, 256])

In [4]:
ptwp=ptwt.WaveletPacket2D(img,"haar",mode='boundary',maxlevel=2)
# paths = [node.path for node in wp[0].get_level(level)]
paths_keys = [key for key in ptwp if len(key)==2]
# paths_keys
# 
nodes = [ptwp[path].unsqueeze(1) for path in paths_keys]
# nodes = [ptwp[path] for path in paths_keys]
# nodes
# ptwp['aa'].unsqueeze(1).shape
nodes_tensor=torch.cat(nodes,dim=1)
# nodes_tensor=torch.cat(nodes)
nodes_tensor.shape

  return torch.sparse.mm(selection_eye, convolution_matrix)


torch.Size([1, 16, 3, 64, 64])

In [5]:
####################################################################################
#                    Norm to Plot
####################################################################################
def norm_to_plot(img):

    TensorOrigin = False
    CHWOrigin = False
    # print(img.shape)  # >>>>>>>>>>>>>>>>
    # Make sure the used Image is numpy array with (Height,Width,Channel) shape
    if torch.is_tensor(img):
        TensorOrigin = True
        img = img.numpy()

    if img.shape[0] == 3:
        CHWOrigin = True
        img = rearrange(img, 'c h w -> h w c')

    img_out = (img-img.min())/(img.max()-img.min())
    # print(img_out.shape)  # >>>>>>>>>>>>>>>>

    return img, img_out


####################################################################################
#                    Plot the Image and its Extracted WPT Features
####################################################################################
def plot_img_grid(holder, data, main_title, sub_title=None, Grid2D=True, setticks=None, axes_pad=0.3):

    data_rows = data.shape[0]
    data_cols = data.shape[1]
    if Grid2D:
        nrows = data_rows
        ncols = data_cols
    else:
        nrows = 1
        ncols = data_rows*data_cols

    x_label = range(ncols)
    y_label = range(nrows)

    grid = ImageGrid(holder, 111, (nrows, ncols), axes_pad=axes_pad)

    for i, ax in enumerate(grid):
        r, c = i//data_cols, i % data_cols
        ax.imshow(norm_to_plot(data[r, c])[1])
        if setticks is not None:
            ax.set(xticks=np.arange(0, data[r, c].shape[1]+1, step=setticks),
                   yticks=np.arange(0, data[r, c].shape[0]+1, step=setticks))
        else:
            ax.set(xticks=[], yticks=[])

        if Grid2D:
            if r == nrows - 1:
                ax.set_xlabel(x_label[c], rotation=0, fontsize=10, labelpad=20)
            if c == 0:
                ax.set_ylabel(y_label[r], rotation=0, fontsize=10, labelpad=20)
        else:
            ax.set_xlabel(x_label[i], rotation=0, fontsize=10, labelpad=20)
            if i == 0:
                ax.set_ylabel(y_label[i], rotation=0, fontsize=10, labelpad=20)
        if sub_title is not None:
            ax.set_title(sub_title[i])

    holder.suptitle(main_title)


####################################################################################
#      Decompose the Image using Wavelet Packet Transform(Keep Features Only)
####################################################################################
def wpt_dec(img, wavelet_fun, level):

    paths_keys = [0]*3
    features_rows = 2**level

    # apply wavelet packet transform

    ptwp = ptwt.WaveletPacket2D(img, wavelet_fun, mode='boundary', maxlevel=2)

    # get the paths of the image
    paths_keys = [key for key in ptwp if len(key) == 2]
    
    # Arrange the paths in a 2D matrix shape, useful to visualize the wavelet packet features
    paths_rows = []
    paths_matrix = []
    for i, path in enumerate(paths_keys):
        if (i+1) % features_rows == 0:
            paths_rows.append(path)
            paths_matrix.append(paths_rows)
            paths_rows = []
        else:
            paths_rows.append(path)

    nodes = [ptwp[path].unsqueeze(1) for path in paths_keys]
    nodes_tensor=torch.cat(nodes,dim=1)
    
    ###############################################################################
    wp_fun = wp[0][paths[0]].wavelet.wavefun()
    # # x, y = wp_fun[-1], wp_fun[0]
    wp_name = ptwp.wavelet.name

    return img, nodes_array, paths, features_rows, wp_fun, wp_name, node_shape, nodes_tensor

# def wpt_dec(img, wavelet_fun, level):
#     # This function decomposes the input 2D image into (2**level)**2 2D features
#     # hence if we have only one level the output 2D features will be 4
#     # The 2D features could be arranged as a matrix of (2**level) rows and (2**level) cols
#     # ranging from the most approximate feature at location (0,0) to the most detailed
#     # feature at location ((2**level) - 1,(2**level) - 1)
#     # The original 2D wave_packet_decompose function down samples the input by 2 after each filter
#     # Note: The 2D wave_packet_decompose function deals with a single channel image only.
#     # hence you need to take a single channel out of the 3 channel coloured image
#     #
#     # This function gives the option to output the final features in "stacked" 3D matrix
#     # of a shape:
#     # if numpy  : (no_of_features x feature_height x feature_width x 3 Channels)
#     # if torch  : ((no_of_features x 3 Channels) x feature_height x feature_width)

#     TensorOrigin = False
#     CHWOrigin = False

#     img_one_channel = [0]*3
#     wp = [0]*3
#     paths = [0]*3

#     # Make sure the used Image is numpy array with (Height,Width,Channel) shape.
#     if torch.is_tensor(img):
#         TensorOrigin = True
#         img = img.numpy()

#     if img.shape[0] == 3:
#         CHWOrigin = True
#         img = rearrange(img, 'c h w -> h w c')

#     img_h = img.shape[0]
#     features_rows = 2**level

#     img_one_channel = [img[:, :, i] for i in range(3)]

#     # apply wavelet packet transform
#     wp = [wp2d(data=img_one_channel[i], wavelet=wavelet_fun,
#                mode='symmetric') for i in range(3)]

#     # get the paths of the image
#     paths = [node.path for node in wp[0].get_level(level)]

#     # Arrange the paths in a 2D matrix shape, useful to visualize the wavelet packet features
#     paths_rows = []
#     paths_matrix = []
#     for i, path in enumerate(paths):
#         if (i+1) % features_rows == 0:
#             paths_rows.append(path)
#             paths_matrix.append(paths_rows)
#             paths_rows = []
#         else:
#             paths_rows.append(path)

#     nodes = [[wp[i][path].data for path in paths] for i in range(3)]
#     node_shape = wp[0][paths[0]].data.shape

#     # print("axis", axis) # >>>>>>

#     # --->(16, feature_height, feature_width, 3)
#     nodes_array = rearrange(np.array(nodes), "c f fh fw -> f fh fw c")

#     # --->(16*3, feature_height, feature_width)
#     nodes_tensor = torch.tensor(
#         rearrange(np.array(nodes), "c f fh fw -> f c fh fw"))

#     ###############################################################################
#     wp_fun = wp[0][paths[0]].wavelet.wavefun()
#     # # x, y = wp_fun[-1], wp_fun[0]
#     wp_name = wp[0][paths[0]].wavelet.family_name

#     return img, nodes_array, paths, features_rows, wp_fun, wp_name, node_shape, nodes_tensor


####################################################################################
#                    Plot the Image and its Extracted WPT Features
####################################################################################
def plot_wpt_nodes(image, wavelet_fun, level, setticks1=None, setticks2=None):
    img, nodes, paths, rows, *_ = wpt_dec(image, wavelet_fun, level)
    plt.rcParams['figure.constrained_layout.use'] = True
    # plt.rcParams["figure.autolayout"] = True

    fig = plt.figure(figsize=(15, 15))
    subfigs = fig.subfigures(2, 1, height_ratios=[
                             1, 2], hspace=0.01, squeeze='True')

    axs0 = subfigs[0].subplots(1, 1)
    axs0.imshow(norm_to_plot(img)[1])
    if setticks1 is not None:
        axs0.set(xticks=np.arange(0, img.shape[1]+1, step=setticks1),
                 yticks=np.arange(0, img.shape[0]+1, step=setticks1))
    else:
        axs0.set(xticks=[], yticks=[])
    axs0.set_title("Image")

    grid_text = "Features extracted using Wavelet Packet Transform"

    nodes_grid = np.reshape(nodes[:, :, :, :],
                            (int(np.sqrt(nodes.shape[0])), -1,
                             nodes.shape[1], nodes.shape[2], nodes.shape[3]))

    plot_img_grid(subfigs[1], nodes_grid, grid_text,
                  paths, Grid2D=True, setticks=setticks2)


####################################################################################
#                     Plot the Wavelet Impulse function
####################################################################################
def plot_wpt_fun(image, wavelet_fun, level):
    *_, wp_fun, wp_name, _, _ = wpt_dec(image, wavelet_fun, level)

    fig, axs = plt.subplots(1, 1, figsize=(2, 2), layout='constrained')

    axs.plot(wp_fun[-1], wp_fun[0])
    axs.grid(True)
    axs.set_title(f'{wavelet_fun}')

    # axs[0].plot(wp_fun[-1], wp_fun[0])
    # axs[0].grid(True)
    # axs[0].set_title(f'{wavelet_fun}')

    # w = pywt.Wavelet(wavelet_fun)
    # (phi, psi, x) = w.wavefun(level=level)

    # axs[1].plot(x, phi)
    # axs[1].grid(True)
    # axs[1].set_title(f'{wavelet_fun} phi')

    # axs[2].plot(x, psi)
    # axs[2].grid(True)
    # axs[2].set_title(f'{wavelet_fun} psi')

In [6]:
# image = plt.imread(df_all.iloc[2999, 0])
image = plt.imread(df_all.iloc[2999, 0])
print("input image shape", image.shape)

input image shape (256, 256, 3)


In [7]:
w_family = pywt.families()
wavelet_lst = []
for i, w_fun in enumerate(w_family):
    wavelet_lst.append(pywt.wavelist(w_fun))
    formatted_w_fun = f'{i}) {w_fun}: {", ".join(pywt.wavelist(w_fun))}'
    print(formatted_w_fun)
print(" \n")

0) haar: haar
1) db: db1, db2, db3, db4, db5, db6, db7, db8, db9, db10, db11, db12, db13, db14, db15, db16, db17, db18, db19, db20, db21, db22, db23, db24, db25, db26, db27, db28, db29, db30, db31, db32, db33, db34, db35, db36, db37, db38
2) sym: sym2, sym3, sym4, sym5, sym6, sym7, sym8, sym9, sym10, sym11, sym12, sym13, sym14, sym15, sym16, sym17, sym18, sym19, sym20
3) coif: coif1, coif2, coif3, coif4, coif5, coif6, coif7, coif8, coif9, coif10, coif11, coif12, coif13, coif14, coif15, coif16, coif17
4) bior: bior1.1, bior1.3, bior1.5, bior2.2, bior2.4, bior2.6, bior2.8, bior3.1, bior3.3, bior3.5, bior3.7, bior3.9, bior4.4, bior5.5, bior6.8
5) rbio: rbio1.1, rbio1.3, rbio1.5, rbio2.2, rbio2.4, rbio2.6, rbio2.8, rbio3.1, rbio3.3, rbio3.5, rbio3.7, rbio3.9, rbio4.4, rbio5.5, rbio6.8
6) dmey: dmey
7) gaus: gaus1, gaus2, gaus3, gaus4, gaus5, gaus6, gaus7, gaus8
8) mexh: mexh
9) morl: morl
10) cgau: cgau1, cgau2, cgau3, cgau4, cgau5, cgau6, cgau7, cgau8
11) shan: shan
12) fbsp: fbsp
13) cmo

In [8]:
level = 2

# (5,7) , (1,0) , (2,0)
# i, j = 0, 0  # index of the main wavelet category (haar)
# i, j = 1, 0  # index of the main wavelet category (db1)
# i, j = 1, 11  # index of the main wavelet category (db12)
# i, j = 4, 2  # index of the main wavelet category (bior1.5)
# i, j = 5, 7  # index of the main wavelet category (rbio3.1)
# i, j = 3, 3  # index of the main wavelet category (coif4 )

wavelet_fun = wavelet_lst[i][j]
# wavelet_fun = 'db2'
print("Used Wavelet function:", wavelet_fun, "\n")

NameError: name 'j' is not defined

In [None]:
Img, nodes, paths, rows, wp_fun, wp_name,nodeshape ,nodes_tensor = wpt_dec(
    image, wavelet_fun, level)

In [None]:
plot_wpt_fun(image, wavelet_fun, level)
print(nodeshape)

In [None]:
plot_wpt_nodes(image, wavelet_fun, level, setticks1=32, setticks2=10)

In [None]:
print(nodeshape)

In [None]:
nodes.shape

In [None]:
nodes_tensor.shape