In [1]:
import IPython.display as ipd

import numpy as np
import torch

from audio_processing import griffin_lim

from scipy.io.wavfile import read
import time
from reconstruct_functions import *
import random
import os

def set_seed(seed: int = 42) -> None:
    """Seed every random number generator used by this notebook.

    Seeds NumPy's global RNG, Python's ``random`` module and torch (via
    ``torch.manual_seed`` followed by ``torch.cuda.manual_seed``). It then
    switches cuDNN to deterministic kernels with benchmark autotuning off,
    records the seed in ``PYTHONHASHSEED`` and prints a confirmation line.

    Note: assigning ``PYTHONHASHSEED`` inside an already-running interpreter
    does not change that interpreter's string hashing. It only reaches child
    processes started afterwards.

    Args:
        seed: value handed to every generator (default 42).
    """
    # Same seeding order as before: numpy, random, torch CPU, torch CUDA.
    for seeder in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = f"{seed}"
    print(f"Random seed set as {seed}")


set_seed(114514)

Random seed set as 114514


In [2]:
hop_length = 1
win_length = 4  # must be even and at least 4 * hop_length
full_length = win_length + 4*hop_length

def get_Yw(input_signal, window, hop_length =1, win_length=4):
    """Magnitude of the windowed DFT ``|Y_w[m, n]|`` of a real signal.

    For each frame ``m`` the signal sample ``x[t]`` is weighted by the window
    tap ``window[m*H - t]`` (taps with ``0 <= m*H - t < win_length`` only,
    all others count as zero). A length-``T`` DFT is then taken, keeping the
    ``T // 2 + 1`` non-negative frequency bins::

        Y_w[m, n] = sum_t x[t] * window[m*H - t] * exp(-2j*pi*n*t / T)

    Args:
        input_signal: signal of shape ``[1, T]``; only row 0 is used.
        window: 1-D window; its first ``win_length`` taps are used.
        hop_length: frame shift ``H``.
        win_length: number of window taps ``N``.

    Returns:
        float64 tensor of shape ``[(T + N - 1) // H, T // 2 + 1]``.

    NOTE(review): for hop_length > 1 the frame count ``(T + N - 1) // H`` can
    omit the last frame that still overlaps the signal. The full count is
    ``(T + N - 2) // H + 1``. The shape is kept as-is for compatibility with
    the H=1 reconstruction.
    """
    x = torch.as_tensor(input_signal, dtype=torch.float64)[0]
    win = torch.as_tensor(window, dtype=torch.float64)
    T = x.shape[-1]
    H = hop_length
    N = win_length
    num_frames = (T + N - 1) // H

    if T == 0:
        # No samples: every (frame, bin) sum is empty, i.e. zero.
        return torch.zeros((num_frames, T // 2 + 1), dtype=torch.float64)

    # Window-tap index for every (frame, sample) pair; out-of-range taps -> 0.
    tap = torch.arange(num_frames).unsqueeze(1) * H - torch.arange(T).unsqueeze(0)
    inside = (tap >= 0) & (tap < N)
    shifted = win[tap.clamp(0, N - 1)].masked_fill(~inside, 0.0)

    # rfft evaluates sum_t f[t] * exp(-2j*pi*n*t/T) for n = 0..T//2, which is
    # exactly the cos/sin double sum of the definition above.
    spectrum = torch.fft.rfft(x.unsqueeze(0) * shifted, n=T, dim=-1)
    return spectrum.abs()

In [3]:


def compute_reconstruction(input_signal, window, verbose=True):
    """Reconstruct ``input_signal`` from the magnitude of its windowed DFT.

    The magnitude is computed with ``get_Yw`` using the module-level
    ``hop_length`` / ``win_length``. It is then passed to
    ``reconstruct_from_Yw_with_H_1`` (from ``reconstruct_functions``).
    Magnitudes are identical for ``x`` and ``-x``, so the global sign of the
    result is resolved against the known ``input_signal``.

    Args:
        input_signal: reference signal, shape ``[1, T]``.
        window: analysis window passed to both steps.
        verbose: print the reconstruction (the original behaviour; default True).

    Returns:
        The reconstruction, negated if it was anti-correlated with
        ``input_signal``. NOTE(review): assumed to have the same number of
        elements as ``input_signal``; the printed output is ``[1, T]``.
    """
    magnitude = get_Yw(input_signal, window, hop_length, win_length)

    ans = reconstruct_from_Yw_with_H_1(magnitude, window, hop_length)

    # Comparing only the first samples fails when that sample is at or near
    # zero. The inner product over all samples picks the sign that best
    # matches the reference.
    alignment = np.vdot(np.asarray(ans, dtype=np.float64).ravel(),
                        np.asarray(input_signal, dtype=np.float64).ravel())
    if alignment < 0:
        ans = -ans
    if verbose:
        print('ans', ans)

    return ans


In [4]:
test_time = 5
data_arr = []  # mean absolute reconstruction error of each trial, as Python floats

for idx in range(test_time):
    # Random signal uniform in [-1, 1) and random window uniform in [0, 1);
    # draw order (signal, then window) is unchanged so seeded runs match.
    audio_origin = torch.rand((1,full_length), dtype=torch.float64)*2-1
    window = torch.rand((win_length), dtype=torch.float64)
    print('signal', audio_origin)
    print('window', window)
    ans = compute_reconstruction(audio_origin, window)

    # Compute the error once. Store a plain float so np.mean / np.var below
    # do not rely on converting a list of tensors.
    error = torch.mean(torch.abs(
        audio_origin[0, :] - torch.as_tensor(ans[0, :], dtype=torch.float64)))
    data_arr.append(error.item())
    print(error)

print(np.mean(data_arr))
print(np.var(data_arr))

signal tensor([[ 0.0481,  0.6044,  0.4005,  0.3006,  0.1836, -0.8655,  0.7392, -0.5941]],
       dtype=torch.float64)
window tensor([0.5200, 0.5758, 0.7975, 0.0623], dtype=torch.float64)


100%|██████████| 6/6 [00:00<00:00, 4935.44it/s]


ans [[ 0.04805358  0.60435092  0.40050719  0.300559    0.18361471 -0.86551264
   0.73915877 -0.59405125]]
tensor(6.7741e-16, dtype=torch.float64)
signal tensor([[-0.2879, -0.1142,  0.7590,  0.0166, -0.5702, -0.4450, -0.0190, -0.6339]],
       dtype=torch.float64)
window tensor([0.8980, 0.5529, 0.3195, 0.2646], dtype=torch.float64)


100%|██████████| 6/6 [00:00<00:00, 5212.47it/s]


ans [[-0.28792413 -0.11418192  0.7589735   0.01663292 -0.5702101  -0.44499301
  -0.01900145 -0.63386279]]
tensor(2.3597e-15, dtype=torch.float64)
signal tensor([[-0.2687, -0.1315, -0.4979,  0.7037, -0.3971, -0.1072, -0.9203, -0.4680]],
       dtype=torch.float64)
window tensor([0.3985, 0.4288, 0.2968, 0.2912], dtype=torch.float64)


100%|██████████| 6/6 [00:00<00:00, 3994.58it/s]


ans [[-0.26865222 -0.13154937 -0.4978787   0.7037439  -0.39705772 -0.1071553
  -0.92030933 -0.46798494]]
tensor(3.9378e-16, dtype=torch.float64)
signal tensor([[ 0.2368, -0.5671,  0.2125,  0.4021, -0.5473,  0.3924, -0.1698,  0.6180]],
       dtype=torch.float64)
window tensor([0.0754, 0.5051, 0.9403, 0.2933], dtype=torch.float64)


100%|██████████| 6/6 [00:00<00:00, 5470.83it/s]


ans [[ 0.23683362 -0.56706514  0.21251349  0.40205853 -0.54726437  0.39241571
  -0.16983447  0.61804083]]
tensor(5.6670e-11, dtype=torch.float64)
signal tensor([[ 0.2323,  0.1480,  0.6167, -0.4830, -0.0279,  0.9758, -0.3849,  0.1666]],
       dtype=torch.float64)
window tensor([0.0495, 0.2478, 0.7105, 0.7416], dtype=torch.float64)


100%|██████████| 6/6 [00:00<00:00, 5729.92it/s]

ans [[ 0.23225625  0.14802665  0.61666715 -0.4829868  -0.02792985  0.97577104
  -0.38487804  0.16657397]]
tensor(1.0732e-11, dtype=torch.float64)
1.348107936047227e-11
4.835896387642082e-22



