In [4]:
#REQUIRED LIBRARIES
!pip install soundfile
# Preprocessing and operations on Audio
import soundfile as sf 
from scipy.io.wavfile import read

# Math operations and 
import math
import functools
import scipy.special as spsp
from scipy.special import exp1

# File handling 
import glob, os

# Data handling 
import numpy as np
from tqdm import tqdm

# Deep learning: Modelling helpers
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.models import Model
from tensorflow.python.ops.signal import window_ops
from tensorflow.keras.layers import Activation, Add, Conv1D, Conv2D, Dense, Dropout,Flatten, LayerNormalization, MaxPooling2D, ReLU, Input, Masking

Collecting soundfile
  Downloading https://files.pythonhosted.org/packages/eb/f2/3cbbbf3b96fb9fa91582c438b574cff3f45b29c772f94c400e2c99ef5db9/SoundFile-0.10.3.post1-py2.py3-none-any.whl
Installing collected packages: soundfile
Successfully installed soundfile-0.10.3.post1


In [0]:
# Functions for audio handling
def save_wav(path, wav, f_s):
    """"
    save_wav: Save the denoised audio to the given path
        path                  - path to save the output audio file 
        f_s                   - sampling freq
        wav                   - audio file extension
        np.squeeze            - Remove single-dimensional entries from the shape of an array
        if block(isinstance)  - function returns True if the specified object is of the specified type, otherwise False 
                                checking if the file is float dtype
        np.asarray            - convert an given input to an array
     """
    wav = np.squeeze(wav)
    if isinstance(wav[0], np.float32):
        wav = np.asarray(np.multiply(wav, 32768.0), dtype=np.int16) 
    sf.write(path, wav, f_s)

def read_wav(path):
    """
    read_wav: Read the audio files from a given path
    """
    wav, f_s = sf.read(path, dtype='int16')
    return wav, f_s

def Batch(fdir):
    """
    Batch: Create a bacth of input audio files
        fdir           - Input path of audio files 
        fname_l        - list of file names.
        wav_l          - list of fetched_audio files
        len_l          - list of audio file lengths
    Returns: 
    audio_files_numpy_format, np array of audio file legnths, list of file_names
    """
    fname_l = [] 
    wav_l = [] 
    fnames = ['*.wav', '*.flac', '*.mp3']

    #get all the supported sound file types from the given path
    for fname in fnames:  
        for fpath in glob.glob(os.path.join(fdir, fname)): 
            (wav, _) = read_wav(fpath) # read each audio file using the soundfile library 
            if np.isnan(wav).any() or np.isinf(wav).any():
                raise ValueError('Error: NaN or Inf value. File path: %s.' % (fdir))
            wav_l.append(wav) #add the wavefile name to the wav_l list
            fname_l.append(os.path.basename(os.path.splitext(fpath)[0])) # append respective path 

    len_l = []
    # get maximum audio length among all the files, so that all the other files are padded with zeroes to have a uniform batch
    maxlen = max(len(wav) for wav in wav_l) 
    wav_np = np.zeros([len(wav_l), maxlen], np.int16) # creating a numpy array of zeroes, with the length of the largest audio file as a dimension 

    for (i, wav) in zip(range(len(wav_l)), wav_l):
      #Overlapping the zeros in the numpy array of zeroes, to create a padded array for the smaller files
        wav_np[i,:len(wav)] = wav 
        len_l.append(len(wav))
    return wav_np, np.array(len_l, np.int32),fname_l

In [0]:
#CUSTOM CLASS FOR DIGITAL SIGNAL PROCESSING 

import numpy as np
import numpy.fft as fft

def stft(x, Nwin, Nfft=None):
    
    Nfft = Nfft or Nwin
    print(x)
    Nwindows = tf.size(x) // Nwin
    x_len =Nwindows * Nwin
    # reshape into array `Nwin` wide, and as tall as possible. This is
    # optimized for C-order (row-major) layouts.
    arr = np.reshape(x[:x_len], (-1, Nwin))
    stft = fft.rfft(arr, Nfft)
    return stft


def istft(stftArr, Nwin):
    print("stftArr_shape:",stftArr.shape)
    arr = fft.irfft(stftArr)
    print("arr_shape:",fft.irfft(stftArr).shape)
    return np.reshape(arr, -1)


class STFT:
    """
    Short-Term Fourier Transform:
        N_d                -  window duration (samples)
        N_s                -  window shift (samples)
        NFFT               -  number of DFT componts [ Discrete fourier transform]
        f_s                   - sampling freq
    """
    def __init__(self, N_d, N_s, NFFT, f_s):
        self.N_d = N_d
        self.N_s = N_s
        self.NFFT = NFFT
        self.f_s = f_s
        self.W = functools.partial(window_ops.hann_window, periodic=True) # A callable that takes a window length and returns a [window_length] Tensor of samples in the provided datatype.
        self.ten = tf.cast(10.0, tf.float32) # Casting the tensor to the float32 type

    def polar_analysis(self, x):
        """
        x                  -  Input numpy array
        tf.signal.stft     -  Computes the Short-time Fourier Transform of signals
        Returns:
        tf.abs             -  computes absolute value of tensor         
        tf.math.angle      -  returns element wise arguments of a complex tensor
      """
#         STFT = tf.signal.stft(x, self.N_d, self.N_s, self.NFFT, window_fn=self.W, pad_end=True)
        STFT=stft(x,self.N_d)
        print("STFT_fun:",STFT)
        return tf.abs(STFT), tf.math.angle(STFT)

    def polar_synthesis(self, STMS, STPS):
        """
        tf.cast                      - Casts a tensor to a new type
        tf.complex                   - A Tensor of type complex64 or complex128 . Raises. TypeError, Real and imag must be correct
        tf.exp                       - performs exponential operation on a tensor
        Returns:
        tf.signal.inverse_stft       - inverse the stft input signals 
        """
        stfs = tf.cast(STMS, tf.complex64)*tf.exp(1j*tf.cast(STPS, tf.complex64))
        print("STFT_dir:",stfs)
        return istft(stfs,self.N_d)
#         return tf.signal.inverse_stft(stfs, self.N_d, self.N_s, self.NFFT, tf.signal.inverse_stft_window_fn(self.N_s, self.W))

class DeepXiInput(STFT):
    def __init__(self, N_d, N_s, NFFT, f_s, mu=None, sigma=None):
        """
        defining mu and sigma 
        """
        super().__init__(N_d, N_s, NFFT, f_s)
        self.mu = mu
        self.sigma = sigma
        
    def observation(self, x):
        x = self.normalise(x)
#         print("x_type:",type(x))
        x_STMS, x_STPS = self.polar_analysis(x)
        
        return x_STMS, x_STPS
    
    def normalise(self, x):
        #normailzation / standardization
        """
        tf.math.truediv  -  Divides x tensor by y elementwise
        """
        return tf.truediv(tf.cast(x, tf.float32), 32768.0)

    def n_frames(self, N):
        """
        tf.math.ceil - Return the ceiling of the input, element-wise
        """
        print("new:",N,self.N_d)
        return tf.cast(tf.math.ceil(tf.truediv(tf.cast(N, tf.float32), tf.cast(self.N_d, tf.float32))), tf.int32)

#     def log_10(self, x):
#         """
#         tf.math.log - Computes natural logarithm of x element-wise
#         """
#         return tf.truediv(tf.math.log(x), tf.math.log(self.ten))

#     def xi(self, s_STMS, d_STMS):
#         return tf.truediv(tf.square(s_STMS), tf.maximum(tf.square(d_STMS), 1e-12))

#     def xi_db(self, s_STMS, d_STMS):
#         return tf.multiply(10.0, self.log_10(tf.maximum(self.xi(s_STMS, d_STMS), 1e-12)))

#     def xi_bar(self, s_STMS, d_STMS):
#         """
#         tf.math.erf -  Computes the Gauss error function of x element-wise
#         """
#         return tf.multiply(0.5, tf.add(1.0, tf.math.erf(tf.truediv(tf.subtract(self.xi_db(s_STMS, d_STMS), self.mu),
#           tf.multiply(self.sigma, tf.sqrt(2.0))))))

    def xi_hat(self, xi_bar_hat):
        """
        scipy.special.erfinv(y)  -  Inverse of the error function erf.
        """ 
        xi_db_hat = np.add(np.multiply(np.multiply(self.sigma, np.sqrt(2.0)),
                                       spsp.erfinv(np.subtract(np.multiply(2.0, xi_bar_hat), 1))), self.mu)
        return np.power(10.0, np.divide(xi_db_hat, 10.0))

In [7]:
from google.colab import drive
drive.mount('/content/gdrive')  #mounting

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [8]:
%cd '/content/gdrive/My Drive/DEEP/DeepXi'  
# set working directory

/content/gdrive/My Drive/DEEP/DeepXi


In [9]:
mu=np.array([[1,2,3,4,5,6],[4,5,6,7,8,9],[7,8,9,10,11,12]]);
sigma=np.array([[1,2,3,4,5,6],[4,5,6,7,8,9],[7,8,9,10,11,12]]);
xi_hat_bar=np.array([[0.64731264,0.6426016 ,0.48500836,0.06257722,0.0639689,0.08598465],[0.42110795,0.5649021,0.57181394,0.42728427,0.4406531 ,0.41797763],[0.39728028,0.40586644,0.427332,0.35131848,0.3496621,0.34612828]])

def xi_hat_test(xi_bar_hat):
    xi_db_hat = np.add(np.multiply(np.multiply(sigma, np.sqrt(2.0)),
                                   spsp.erfinv(np.subtract(np.multiply(2.0, xi_bar_hat), 1))), mu)
    return np.power(10.0, np.divide(xi_db_hat, 10.0))
g=xi_hat_test(xi_hat_bar);
print(g)

[[1.37343314 1.87535849 1.9441233  0.61179062 0.54809738 0.60319535]
 [2.09110599 3.81684824 5.11206688 3.72988351 4.7923567  5.17174472]
 [3.29400934 4.06861991 5.43433334 4.15180285 4.73307978 5.30941263]]


In [0]:
 # MMSE-LSA gain function.
def gfunc(xi, gamma=None):
    """
    MMSE-LSA Gain function
    """
    nu = np.multiply(np.divide(xi, np.add(1, xi)), gamma)
    G = np.multiply(np.divide(xi, np.add(1, xi)), np.exp(np.multiply(0.5, exp1(nu))))
    return G


In [0]:
# Modelling ResNet architecture

class ResNet:
    """
    ResNet: Residual Neural Network
      - Base model for DeepXi architeture
    """
    def __init__(self,inp,n_outp,n_blocks,d_model,d_f,k,max_d_rate,padding,):
        self.d_model = d_model
        self.d_f = d_f
        self.k = k
        self.n_outp = n_outp
        self.padding = padding
        self.first_layer = self.feedforward(inp)
        self.layer_list = [self.first_layer]
        for i in range(n_blocks): self.layer_list.append(self.block(self.layer_list[-1], int(2**(i%(np.log2(max_d_rate)+1)))))
        self.logits = Conv1D(self.n_outp, 1, dilation_rate=1, use_bias=True)(self.layer_list[-1])
        self.outp = Activation('sigmoid')(self.logits)     

    # 1st layer
    def feedforward(self, inp):
        """
        1D convolution layer (temporal convolution)   -    This layer creates a convolution kernel that is convolved with the layer input over a single spatial (or temporal) 
                                                            dimension to produce a tensor of outputs.

        """
        ff = Conv1D(self.d_model, 1, dilation_rate=1, use_bias=False)(inp)
        norm = LayerNormalization(axis=2, epsilon=1e-6)(ff)
        act = ReLU()(norm)
        return act

    # 2nd layer
    def block(self, inp, d_rate):
        """
        2D convolution layer (spatial convolution)
        """
        self.conv_1 = self.unit(inp, self.d_f, 1, 1, False)
        self.conv_2 = self.unit(self.conv_1, self.d_f, self.k, d_rate,
            False)
        self.conv_3 = self.unit(self.conv_2, self.d_model, 1, 1, True)
        residual = Add()([inp, self.conv_3])
        return residual

    # 3rd layer
    def unit(self, inp, n_filt, k, d_rate, use_bias):
        """
        dilation_rate                     : an integer or tuple/list of 2 integers, specifying the dilation rate to use for dilated convolution.
                                          Can be a single integer to specify the same value for all spatial dimensions. 
        use_bias                          : Boolean, whether the layer uses a bias vector.
        Relu                              : Clips value in range of 0 to infinity , so clips all negative value to zero
        Layer normalization layer         :  Normalize the activations of the previous layer for each given example in a batch independently,
                                           rather than across a batch like Batch Normalization. i.e. applies a transformation that maintains the mean activation within each example close to 0
                                             and the activation standard deviation close to 1.
        """
        norm = LayerNormalization(axis=2, epsilon=1e-6)(inp)
        act = ReLU()(norm)
        conv = Conv1D(n_filt, k, padding=self.padding, dilation_rate=d_rate,
            use_bias=use_bias)(act)
        return conv

In [0]:
class DeepXi(DeepXiInput):
    def __init__(self,N_d,N_s,NFFT,f_s,model_path,stat_path,**kwargs):
        super().__init__(N_d, N_s, NFFT, f_s)
        self.n_feat = math.ceil(self.NFFT/2 + 1)
        self.n_outp = self.n_feat
        self.inp = Input(name='inp', shape=[None, self.n_feat], dtype='float32')
        self.mask = Masking(mask_value=0.0)(self.inp)
        
        self.network = ResNet(
            inp=self.mask,
            n_outp=self.n_outp,
            n_blocks=kwargs['n_blocks'],
            d_model=kwargs['d_model'],
            d_f=kwargs['d_f'],
            k=kwargs['k'],
            max_d_rate=kwargs['max_d_rate'],
            padding=kwargs['padding'],
            )
        self.model = Model(inputs=self.inp, outputs=self.network.outp)
        self.model.summary()
        # The Actual program starts from this line
        self.sample_stats(stat_path) # Load sample statistics file to derive mu and sigma values
        self.model.load_weights(model_path) #Load Weights of Saved_model from model_path


    def infer(self,test_x,test_x_len,test_x_base_names,out_path='out/denoised/',n_filters=40,):
        
        out_path = out_path
        
        print("Processing observations...")
        x_STMS_batch, x_STPS_batch, n_frames = self.observation_batch(test_x, test_x_len)

        print("Performing inference...")
        xi_bar_hat_batch = self.model.predict(x_STMS_batch, batch_size=1, verbose=1)# MAX TIME TAKEN

        batch_size = len(test_x_len)
        for i in tqdm(range(batch_size)):
            base_name = test_x_base_names[i]
            x_STMS = x_STMS_batch[i,:n_frames[i],:]
            x_STPS = x_STPS_batch[i,:n_frames[i],:]
#             print("x_STPS:",x_STPS)
#             print("stps_shape:",x_STPS.shape)
            xi_bar_hat = xi_bar_hat_batch[i,:n_frames[i],:]
#             print("xi_bar_hat:",xi_bar_hat)
#             print("mul",np.multiply(2.0,xi_bar_hat),1)
#             print("mul_1",np.multiply(2.0,xi_bar_hat))
            xi_hat = self.xi_hat(xi_bar_hat)
            
            y_STMS = np.multiply(x_STMS, gfunc(xi_hat, xi_hat+1))
#             print("y_STMS:",y_STMS)
#             print("stms_shape:",y_STMS.shape)
            y = self.polar_synthesis(y_STMS, x_STPS) # Stops here in RT
#             print("Y_shape:",y.shape)
            save_wav(out_path+ base_name + '.wav', y, self.f_s)
            
    def sample_stats(self,stats_path='data/'):
        if os.path.exists(stats_path + 'stats.npz'):
            print('Loading sample statistics...')
            with np.load(stats_path + 'stats.npz') as stats:
                self.mu = stats['mu_hat'] 
#                 print(self.mu)
                self.sigma = stats['sigma_hat']
#                 print(self.sigma)
                


    def observation_batch(self, x_batch, x_batch_len):
        """
        batch_size       - getting size of numpy (converted audio)
        max_n_frames     - taking maximum value of array size
        x_STMS_batch     - create numpy of zeros value with following dimension
        n_feat           - is def function, define using math operation
        STMS             - short time magnitude spectrum
        STPS             - short time phase spectrum
        """
        batch_size = len(x_batch)
        max_n_frames = self.n_frames(max(x_batch_len))
        x_STMS_batch = np.zeros([batch_size, max_n_frames, self.n_feat], np.float32)
        x_STPS_batch = np.zeros([batch_size, max_n_frames, self.n_feat], np.float32)
        n_frames_batch = [self.n_frames(i) for i in x_batch_len]
        for i in tqdm(range(batch_size)): # Module for iterating batches
            x_STMS, x_STPS = self.observation(x_batch[i,:x_batch_len[i]])
#             print("n_frame:",n_frames_batch[i])
#             print("x_STMS shape:",x_STMS.shape)
            n_frames_batch[i]=x_STMS.shape[0]
            x_STMS_batch[i,:n_frames_batch[i],:] = x_STMS.numpy()
            x_STPS_batch[i,:n_frames_batch[i],:] = x_STPS.numpy()
        return x_STMS_batch, x_STPS_batch, n_frames_batch

In [13]:
#VARIABLES FOR THE MODEL
d_model  = 256     #block output size
n_blocks  = 40     #no of blocks in the model
d_f = 64           #block bottlekneck size
k =  3             #convolution kernel size
max_d_rate = 16    #max_dilation_rate
padding = "causal" #type of convnet padding
f_s  = 16000     #sampling frequency
T_d  = 32          #window duration
T_s  =  16         #window shift
N_d = int(f_s*T_d*0.001) # window duration (samples).
N_s = int(f_s*T_s*0.001) # window shift (samples).
# N_s=512
NFFT = int(pow(2, np.ceil(np.log2(N_d)))) # number of DFT components.


# PATH VARIABLE
data_path='data/' # Path of the sample_stats file that to be loaded for inference purpose.
test_x_path='set/test_noisy_speech' # Path of the inputs : noisy audio files
out_path='out/denoised/' # Path to which out the output audio file is saved.
model_path='model/saved_model/variables/variables' # Path of the TF Saved_model


test_x, test_x_len, test_x_base_names = Batch(test_x_path) # Fetch the test noisy audio inputs along with its names.


deepxi = DeepXi(N_d=N_d,N_s=N_s,NFFT=NFFT,f_s=f_s,d_model=d_model,n_blocks=n_blocks,d_f=d_f,k=k,max_d_rate =max_d_rate,padding=padding
                ,model_path=model_path,stat_path=data_path) # DeepXi object instantiation

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
inp (InputLayer)                [(None, None, 257)]  0                                            
__________________________________________________________________________________________________
masking (Masking)               (None, None, 257)    0           inp[0][0]                        
__________________________________________________________________________________________________
conv1d (Conv1D)                 (None, None, 256)    65792       masking[0][0]                    
__________________________________________________________________________________________________
layer_normalization (LayerNorma (None, None, 256)    512         conv1d[0][0]                     
______________________________________________________________________________________________

In [14]:
deepxi.infer(test_x=test_x,test_x_len=test_x_len,test_x_base_names=test_x_base_names,out_path= out_path) # perform inference with the above mentioned 'model_variables' and 'path variables' 

100%|██████████| 1/1 [00:00<00:00, 32.29it/s]

Processing observations...
new: 178277 512
new: 178277 512
tf.Tensor([0.         0.         0.         ... 0.00585938 0.00656128 0.00695801], shape=(178277,), dtype=float32)
STFT_fun: [[-2.51617432e-01+0.00000000e+00j -3.71032936e-01-9.25656360e-02j
  -2.18940173e-01-4.50834787e-01j ... -1.65537427e-03+2.27761316e-04j
  -1.75203566e-03+3.37094850e-04j -1.73950195e-03+0.00000000e+00j]
 [ 9.13024902e-01+0.00000000e+00j  7.97804247e-01+5.59360615e-02j
  -7.11486427e-01-1.91519824e-01j ...  1.09433349e-03-1.86153295e-04j
   1.02692287e-03+8.43060155e-05j  1.22070312e-03+0.00000000e+00j]
 [ 8.54492188e-03+0.00000000e+00j -5.86152642e-01-2.34768980e-01j
   7.94574834e-01+1.00447580e+00j ... -2.24173988e-03+1.98118098e-05j
  -2.33020078e-03-1.55497225e-05j -2.50244141e-03+0.00000000e+00j]
 ...
 [ 1.50482178e-01+0.00000000e+00j -3.91519249e-01+3.39205534e-01j
   2.11452270e-01-1.02489362e+00j ...  3.27688613e-04-1.58449246e-04j
   3.63290187e-04-9.81804277e-05j  3.96728516e-04+0.00000000e+00j]






  0%|          | 0/1 [00:00<?, ?it/s]

STFT_dir: tf.Tensor(
[[-1.2238399e-01-1.06991482e-08j -1.3059716e-01-3.25815007e-02j
  -2.5032720e-02-5.15465811e-02j ... -2.9185761e-04+4.01564175e-05j
  -3.2882800e-04+6.32671290e-05j -3.2040564e-04-2.80107500e-11j]
 [ 7.5786912e-01+0.00000000e+00j  3.7666065e-01+2.64086258e-02j
  -1.8284182e-01-4.92178425e-02j ...  6.1007077e-04-1.03777042e-04j
   5.9796002e-04+4.90899838e-05j  7.2527147e-04+0.00000000e+00j]
 [ 7.2749201e-03+0.00000000e+00j -2.4639292e-01-9.86866057e-02j
   8.6310878e-02+1.09111428e-01j ... -1.4821539e-03+1.30986900e-05j
  -1.5793775e-03-1.05393583e-05j -1.5570949e-03-1.36125555e-10j]
 ...
 [ 1.1761992e-01+0.00000000e+00j -1.4519317e-01+1.25792876e-01j
   2.1372659e-02-1.03591710e-01j ...  2.8150564e-04-1.36118120e-04j
   3.1864567e-04-8.61150984e-05j  3.2418635e-04+0.00000000e+00j]
 [ 5.7506555e-01+0.00000000e+00j  1.6593462e-01-8.65320861e-02j
  -8.4698446e-02+5.40084764e-02j ... -6.0378602e-03-7.36743677e-05j
  -6.2574851e-03-1.30932851e-04j -5.9967972e-03-5.2425

100%|██████████| 1/1 [00:00<00:00,  1.64it/s]


In [0]:
STPS=[[2.606935,-2.1490455,-2.4505608,2.4742794,3.1415927 ],
      [2.606935,-2.1490455,-2.4505608,2.4742794,3.1415927 ]]
STMS=[[9.13391076e-03,6.70777448e-03,2.91260912e-05,5.03782830e-05,3.16022124e-05],
      [9.13391076e-03,6.70777448e-03,2.91260912e-05,5.03782830e-05,3.16022124e-05]]

f_s = 16000  #sampling frequency
T_d  = 32          #window duration
T_s  =  16         #window shift
N_d = int(f_s*T_d*0.001) # window duration (samples). (frame size)
N_s = int(f_s*T_s*0.001) # window shift (samples). (hop size)
NFFT =int(pow(2, np.ceil(np.log2(N_d))))  # number of DFT components.
W = functools.partial(window_ops.hamming_window, periodic=False)
print("window_duration:",N_d)
print("window_shift:",N_s)
print("no of DFT components:",NFFT)

    
def polar_synthesis(STMS, STPS):
#     print("gggg::",tf.exp(1j*tf.cast(STPS, tf.complex64)))
    stfts = tf.cast(STMS, tf.complex64)*tf.exp(1j*tf.cast(STPS, tf.complex64))
    print("stfts:",stfts)
    return tf.signal.inverse_stft(stfts, N_d,N_s,NFFT, tf.signal.inverse_stft_window_fn(N_s,W))

g= polar_synthesis(STMS,STPS)
print("\nresult:",g)

In [0]:
x=[1.0,2.0,3.0,4.0,5.0]

f_s = 16000  #sampling frequency
T_d  = 32          #window duration
T_s  =  16         #window shift
N_d = int(f_s*T_d*0.001) # window duration (samples).
# N_s = int(f_s*T_s*0.001) # window shift (samples).
N_s=512
NFFT =int(pow(2, np.ceil(np.log2(N_d))))  # number of DFT components.
W = functools.partial(window_ops.hamming_window, periodic=False)
print("window_duration:",N_d)
print("window_shift:",N_s)
print("no of DFT components:",NFFT)

def polar_analysis(x):

    STFT = tf.signal.stft(x,N_d,N_s,NFFT, window_fn=W, pad_end=True)
    print("STFT_fun:",STFT)
    return tf.abs(STFT), tf.math.angle(STFT)

stms,stps=polar_analysis(x)

window_duration: 512
window_shift: 512
no of DFT components: 512
STFT_fun: tf.Tensor(
[[ 1.20452   +0.j          1.2037327 -0.03945237j  1.2013724 -0.07883725j
   1.1974432 -0.11808729j  1.1919532 -0.15713543j  1.1849126 -0.19591501j
   1.1763349 -0.23435993j  1.1662365 -0.27240473j  1.1546369 -0.3099849j
   1.1415582 -0.34703672j  1.1270252 -0.38349766j  1.111066  -0.41930634j
   1.0937108 -0.45440277j  1.0749928 -0.4887284j   1.0549477 -0.5222262j
   1.0336137 -0.554841j    1.0110312 -0.5865192j   0.9872431 -0.6172093j
   0.9622946 -0.6468617j   0.93623275-0.67542887j  0.90910697-0.7028657j
   0.8809684 -0.7291293j   0.8518701 -0.75417906j  0.8218667 -0.77797675j
   0.7910146 -0.800487j    0.7593714 -0.8216768j   0.7269964 -0.8415158j
   0.6939497 -0.8599764j   0.660293  -0.87703395j  0.6260885 -0.8926664j
   0.5913993 -0.90685457j  0.55628955-0.91958225j  0.52082366-0.9308362j
   0.4850664 -0.9406059j   0.44908312-0.94888395j  0.412939  -0.95566595j
   0.37669963-0.9609502j   0.3404

In [0]:
def polar_synthesis(STMS, STPS):
    stfts = tf.cast(STMS, tf.complex64)*tf.exp(1j*tf.cast(STPS, tf.complex64))
    print("stfts:",stfts)
    return tf.signal.inverse_stft(stfts, N_d,N_s,NFFT, tf.signal.inverse_stft_window_fn(N_s,W))

g= polar_synthesis(stms,stps)
print("\nresult:",g)

In [0]:
def stft(x, Nwin, Nfft=None):
    
    Nfft = Nfft or Nwin
    print(x)
    Nwindows = tf.size(x) // Nwin
    x_len =Nwindows * Nwin
    # reshape into array `Nwin` wide, and as tall as possible. This is
    # optimized for C-order (row-major) layouts.
    arr = np.reshape(x[:x_len], (-1, Nwin))
    stft = fft.rfft(arr, Nfft)
    return stft

Nwin=2
x = tf.constant([0.01010132,0.00436401,-0.01464844,-0.00598145,-0.01898193,-0.02548218,0.01010132,0.00436401,
                 -0.01464844,-0.00598145,-0.01898193,-0.02548218])
# x = tf.random.uniform(shape=(512,),dtype=tf.dtypes.float32)
ss=stft(x,Nwin)
ss

tf.Tensor(
[ 0.01010132  0.00436401 -0.01464844 -0.00598145 -0.01898193 -0.02548218
  0.01010132  0.00436401 -0.01464844 -0.00598145 -0.01898193 -0.02548218], shape=(12,), dtype=float32)


array([[ 0.01446533+0.j,  0.00573731+0.j],
       [-0.02062989+0.j, -0.00866699+0.j],
       [-0.04446411+0.j,  0.00650025+0.j],
       [ 0.01446533+0.j,  0.00573731+0.j],
       [-0.02062989+0.j, -0.00866699+0.j],
       [-0.04446411+0.j,  0.00650025+0.j]])

In [0]:
def istft(stftArr, Nwin):
    print(stftArr.shape)
    arr = fft.irfft(stftArr)[:, :Nwin]
    print(arr.shape)
    return np.reshape(arr, -1)

iss=istft(ss,Nwin)
iss

(6, 2)
(6, 2)


array([ 0.01010132,  0.00436401, -0.01464844, -0.00598145, -0.01898193,
       -0.02548218,  0.01010132,  0.00436401, -0.01464844, -0.00598145,
       -0.01898193, -0.02548218])