In [1]:
import math 
import numpy as np

# necessary functions
from scipy.fftpack import fft,fftfreq,ifft
# Importing Scipy 
import scipy as sp
import pywt
import scipy.fft as F
from sklearn.cluster import KMeans
import torch

In [4]:
def PrepareWavelets(K, length=20, seed = 1):
    motherwavelets = []
    for family in pywt.families():
        for mother in pywt.wavelist(family):
            motherwavelets.append(mother)
    
    X = np.zeros([1,length])
    PSI = np.zeros([1,length])
    for mw_temp in motherwavelets:
        if mw_temp.startswith('gaus') or mw_temp.startswith('mexh') or mw_temp.startswith('morl') or mw_temp.startswith('cmor') or mw_temp.startswith('fbsp') or mw_temp.startswith('shan') or mw_temp.startswith('cgau'):
            pass
        else:
            param = pywt.Wavelet(mw_temp).wavefun(level=7)
            psi, x = param[1], param[-1]

            # normalization
            psi_sum = np.sum(psi)
            if np.abs(psi_sum) > 1:
                psi = psi / np.abs(psi_sum)
            x = x / max(x)

            # down sampling
            idx_ds = np.round(np.linspace(0, x.shape[0]-1, length)).astype(int)
            x = x[idx_ds]
            psi = psi[idx_ds]

            X = np.vstack((X, x.reshape(1,-1)))
            PSI = np.vstack((PSI, psi.reshape(1,-1)))

    X = X[1:,:]
    PSI = PSI[1:,:]

    # clustering
    FRE = np.zeros([1,length])
    for i in range(PSI.shape[0]):
        FRE = np.vstack((FRE, np.real(F.fft(PSI[i,:])).reshape(1,-1)))
    FRE = FRE[1:,:]

    PSI_extended = np.hstack((PSI, FRE))
    kmeans = KMeans(n_clusters=K).fit(PSI_extended)
    label = kmeans.labels_

    SelectedWavelet = np.zeros([1,length])
    for k in range(K):
        wavesidx = np.where(label==k)[0][0]
        SelectedWavelet = np.vstack((SelectedWavelet, PSI[wavesidx,:]))            

    return torch.tensor(SelectedWavelet[1:,:])

In [5]:
for k in range(3,30):

#     torch.manual_seed(seed)
#     random.seed(seed)
#     np.random.seed(seed)
    SelectedWavelet1 = PrepareWavelets(K=k, length=128, seed=1)
    

    SelectedWavelet2 = PrepareWavelets(K=k, length=128, seed=1)
    

    print((SelectedWavelet1==SelectedWavelet2).sum())
    print(SelectedWavelet1.shape[0]*SelectedWavelet1.shape[1])
    print("--------------")

tensor(384)
384
--------------
tensor(260)
512
--------------
tensor(640)
640
--------------
tensor(138)
768
--------------
tensor(392)
896
--------------
tensor(268)
1024
--------------
tensor(270)
1152
--------------
tensor(398)
1280
--------------
tensor(274)
1408
--------------
tensor(276)
1536
--------------
tensor(530)
1664
--------------
tensor(154)
1792
--------------
tensor(408)
1920
--------------
tensor(158)
2048
--------------
tensor(462)
2176
--------------
tensor(162)
2304
--------------
tensor(38)
2432
--------------
tensor(292)
2560
--------------
tensor(456)
2688
--------------
tensor(422)
2816
--------------
tensor(424)
2944
--------------
tensor(386)
3072
--------------
tensor(302)
3200
--------------
tensor(386)
3328
--------------
tensor(244)
3456
--------------
tensor(582)
3584
--------------
tensor(617)
3712
--------------


In [24]:
seed = 1
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

In [25]:
SelectedWavelet2 = PrepareWavelets(K=20, length=128)

In [26]:
(SelectedWavelet1==SelectedWavelet2).sum()

tensor(2560)