In [1]:
import IPython.display as ipd

import numpy as np
import torch

from audio_processing import griffin_lim

from scipy.io.wavfile import read
import time
from reconstruct_functions import *
import random
import os

def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this notebook draws from.

    Seeds NumPy's global RNG, Python's ``random`` module, and PyTorch
    (via ``torch.manual_seed`` and ``torch.cuda.manual_seed``), switches
    cuDNN to deterministic mode with autotuning off, exports
    ``PYTHONHASHSEED`` and prints a confirmation line.

    Args:
        seed: Integer seed handed unchanged to every generator.

    Returns:
        None.
    """
    # Same seed for every library, applied in a fixed order.
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning speed for run-to-run repeatability.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so this
    # assignment is not expected to change hashing in the already-running
    # kernel -- confirm whether child processes are the intended target.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

# Seed all RNGs once, before any random signal or window is drawn below.
set_seed(114514)

Random seed set as 114514


In [2]:
# STFT analysis parameters shared by the cells below.
hop_length = 4   # hop H between consecutive analysis frames, in samples
win_length = 16  # window length N, in samples: even and at least 4*hop_length

# Fail fast if the constraint stated above is broken when the values are edited,
# instead of letting the reconstruction silently run on unsupported settings.
if win_length % 2 != 0 or win_length < 4 * hop_length:
    raise ValueError(
        "win_length must be an even number and at least 4*hop_length "
        f"(got win_length={win_length}, hop_length={hop_length})"
    )

# Number of samples in each random test signal.
full_length = win_length + 4*hop_length

def get_Yw(input_signal, window, hop_length=1, win_length=4):
    """Return the magnitude of the windowed DFT (STFT magnitude) of a signal.

    For frame ``m`` and frequency bin ``n`` the result is

        | sum_t  x[t] * w[m*H - t] * exp(-2j*pi*n*t / T) |

    where the sum runs over the samples ``t`` in ``[0, T)`` that satisfy
    ``0 <= m*H - t < N``. The DFT always spans the full signal length ``T``;
    only the window weighting changes from frame to frame.

    The computation is vectorised: a [frames, T] matrix of window weights is
    built once and multiplied with cosine/sine DFT bases, replacing a triple
    Python loop over frames, bins and samples.

    Args:
        input_signal: 2-D tensor of shape [1, T]; only row 0 is read.
        window: 1-D window ``w``; entries ``0 .. win_length - 1`` are used.
        hop_length: Hop ``H`` between consecutive frames, in samples.
        win_length: Window length ``N``, in samples.

    Returns:
        ``torch.float64`` tensor of shape ``[(T + N - 1) // H, T // 2 + 1]``.

    Raises:
        IndexError: If ``window`` has fewer entries than the window offsets
            that are actually needed.
    """
    signal = torch.as_tensor(input_signal)[0, :].to(torch.float64)
    win = torch.as_tensor(window).to(torch.float64)

    T = signal.shape[-1]
    H = hop_length
    N = win_length
    num_frames = (T + N - 1) // H
    num_bins = T // 2 + 1

    # offset[m, t] = m*H - t is the window index that sample t sees in frame m.
    frame_idx = torch.arange(num_frames).unsqueeze(1)
    sample_idx = torch.arange(T).unsqueeze(0)
    offset = frame_idx * H - sample_idx
    in_window = (offset >= 0) & (offset < N)

    # Window weight for every (frame, sample) pair; zero outside the window.
    # Only offsets that are really used get looked up, so a too-short window
    # still fails with an IndexError rather than being silently clamped.
    weights = torch.zeros((num_frames, T), dtype=torch.float64)
    weights[in_window] = win[offset[in_window]]
    framed = weights * signal

    # DFT basis over the full signal length: phase[t, n] = -2*pi*n*t / T.
    # Dividing a tensor (not a Python float) keeps T == 0 from raising.
    bin_times_sample = torch.arange(T).unsqueeze(1) * torch.arange(num_bins).unsqueeze(0)
    phase = bin_times_sample.to(torch.float64) * (-2.0 * np.pi) / T

    Yw_real = framed @ torch.cos(phase)
    Yw_imag = framed @ torch.sin(phase)
    return torch.sqrt(Yw_real**2 + Yw_imag**2)

In [3]:


def compute_reconstruction(input_signal, window):
    """Reconstruct a signal from its STFT magnitude and resolve the sign.

    The magnitude from ``get_Yw`` is passed to
    ``reconstruct_from_Yw_with_H_4`` together with the window and the
    notebook-level ``hop_length``. ``get_Yw`` returns the same magnitude for
    ``x`` and ``-x``, so the result is negated in place whenever its first
    sample has the opposite sign to the first sample of ``input_signal``.
    The result is printed before it is returned.

    Args:
        input_signal: Reference signal of shape [1, :].
        window: Analysis window passed to both helper functions.

    Returns:
        The object returned by ``reconstruct_from_Yw_with_H_4``, possibly
        negated in place. Presumably a [1, T] array -- TODO confirm against
        ``reconstruct_functions``.
    """
    stft_magnitude = get_Yw(input_signal, window, hop_length, win_length)
    ans = reconstruct_from_Yw_with_H_4(stft_magnitude, window, hop_length)

    # Use the first sample of the reference to choose between x and -x.
    opposite_sign = ans[0, 0] * input_signal[0, 0] < 0
    if opposite_sign:
        ans[:, :] = -ans[:, :]

    print('ans', ans)
    return ans


In [4]:
# Number of random (signal, window) trials to run.
test_time = 5
# Mean absolute reconstruction error of each trial, stored as plain floats so
# the NumPy statistics below do not rely on implicit tensor -> array conversion.
data_arr = []

for idx in range(test_time):
    print('######', idx)
    # Random test signal, uniform in [-1, 1), shape [1, full_length].
    audio_origin = torch.rand((1,full_length), dtype=torch.float64)*2-1

    # Random window, uniform in [0, 1), one weight per window sample.
    window = torch.rand((win_length), dtype=torch.float64)
    print('signal', audio_origin)
    print('window', window)
    ans = compute_reconstruction(audio_origin, window)

    # Compute the error once, then both record it and print it.
    mean_abs_error = torch.mean(torch.abs(audio_origin[0,:] - ans[0,:]))
    data_arr.append(mean_abs_error.item())
    print(mean_abs_error)

# Mean and variance of the per-trial errors.
print(np.mean(data_arr))
print(np.var(data_arr))

###### 0
signal tensor([[ 0.0481,  0.6044,  0.4005,  0.3006,  0.1836, -0.8655,  0.7392, -0.5941,
          0.0399,  0.1515,  0.5950, -0.8755, -0.2879, -0.1142,  0.7590,  0.0166,
         -0.5702, -0.4450, -0.0190, -0.6339,  0.7961,  0.1058, -0.3611, -0.4709,
         -0.2687, -0.1315, -0.4979,  0.7037, -0.3971, -0.1072, -0.9203, -0.4680]],
       dtype=torch.float64)
window tensor([0.3985, 0.4288, 0.2968, 0.2912, 0.6184, 0.2165, 0.6063, 0.7010, 0.2264,
        0.6962, 0.4151, 0.8090, 0.0754, 0.5051, 0.9403, 0.2933],
       dtype=torch.float64)
solutions [0.35373115 0.17596547 0.08309123]


100%|██████████| 6/6 [00:00<00:00, 1304.47it/s]


ans [[ 0.04805358  0.60435092  0.40050719  0.300559    0.18361471 -0.86551264
   0.73915877 -0.59405125  0.03991664  0.15152053  0.59503475 -0.87545772
  -0.28792413 -0.11418192  0.7589735   0.01663292 -0.5702101  -0.44499301
  -0.01900145 -0.63386279  0.79608359  0.10575303 -0.36105869 -0.47088567
  -0.26865222 -0.13154937 -0.4978787   0.7037439  -0.39705772 -0.1071553
  -0.92030933 -0.46798494]]
tensor(1.0637e-11, dtype=torch.float64)
###### 1
signal tensor([[ 0.2323,  0.1480,  0.6167, -0.4830, -0.0279,  0.9758, -0.3849,  0.1666,
         -0.9009, -0.5043,  0.4210,  0.4831, -0.1558, -0.7055, -0.5264, -0.1842,
          0.1742, -0.7626,  0.1707,  0.2971,  0.2528,  0.4621, -0.0382,  0.6119,
          0.8691,  0.1859,  0.8171, -0.6323, -0.4716, -0.9533, -0.4446, -0.5064]],
       dtype=torch.float64)
window tensor([0.7726, 0.1040, 0.1687, 0.3434, 0.1976, 0.0815, 0.7567, 0.2475, 0.5623,
        0.7351, 0.7363, 0.3513, 0.7720, 0.1842, 0.5002, 0.8010],
       dtype=torch.float64)
solutions

100%|██████████| 6/6 [00:00<00:00, 1328.85it/s]

ans [[ 0.23225625  0.14802665  0.61666715 -0.4829868  -0.02792985  0.97577104
  -0.38487804  0.16657397 -0.90094731 -0.50433919  0.42097012  0.48313482
  -0.15583536 -0.70549755 -0.52643042 -0.18417989  0.17421621 -0.7625536
   0.17065796  0.2971199   0.25277608  0.46205909 -0.03818923  0.61186638
   0.86908708  0.18585854  0.81710325 -0.63225933 -0.47155419 -0.95325027
  -0.44459225 -0.50644099]]
tensor(1.6389e-11, dtype=torch.float64)
###### 2
signal tensor([[ 0.3307,  0.0851, -0.6886,  0.5982,  0.5215, -0.6493, -0.5949, -0.1975,
         -0.4121,  0.3747, -0.1560,  0.4211,  0.3695, -0.8292,  0.9244, -0.1416,
         -0.1708, -0.1730,  0.8796,  0.3604, -0.5419,  0.3376,  0.9650, -0.9202,
         -0.2147, -0.5329, -0.3018, -0.4644, -0.0391, -0.4169,  0.6877, -0.3169]],
       dtype=torch.float64)
window tensor([0.7129, 0.5997, 0.7102, 0.6913, 0.8777, 0.4153, 0.4426, 0.4991, 0.1047,
        0.0734, 0.8764, 0.4606, 0.7445, 0.1401, 0.4872, 0.4990],
       dtype=torch.float64)





solutions [1.47134678 0.35523834 0.05885804]


100%|██████████| 6/6 [00:00<00:00, 1349.23it/s]


ans [[ 0.33074495  0.0851445  -0.68859681  0.59823783  0.52154045 -0.64927282
  -0.59488788 -0.19748764 -0.41210933  0.37465445 -0.15598469  0.42107842
   0.36949566 -0.82918551  0.92439093 -0.1416182  -0.17083722 -0.17300594
   0.87961974  0.36036377 -0.54189285  0.33755741  0.96501251 -0.92021382
  -0.21467471 -0.53289109 -0.30176181 -0.46444415 -0.0391089  -0.4168767
   0.6876735  -0.31685244]]
tensor(8.6924e-14, dtype=torch.float64)
###### 3
signal tensor([[-0.1677, -0.7570,  0.2763, -0.4136, -0.9374,  0.4671, -0.4160,  0.1631,
          0.0493, -0.7231, -0.3512,  0.1391,  0.0035,  0.5673,  0.3027,  0.3263,
          0.6122,  0.8561,  0.5883,  0.1585, -0.4953, -0.2611,  0.0256,  0.8444,
          0.5286,  0.3846, -0.5044,  0.6252, -0.1874,  0.6542, -0.9586, -0.7692]],
       dtype=torch.float64)
window tensor([0.5783, 0.2318, 0.0118, 0.4411, 0.8580, 0.4274, 0.4554, 0.1032, 0.5633,
        0.4361, 0.6102, 0.7283, 0.3875, 0.3943, 0.4718, 0.4039],
       dtype=torch.float64)
solutions

100%|██████████| 6/6 [00:00<00:00, 1364.08it/s]


ans [[-0.16767564 -0.7570356   0.27625407 -0.413572   -0.9374164   0.4670854
  -0.41600202  0.16314604  0.04926667 -0.72314992 -0.35116966  0.13906474
   0.00353267  0.5673161   0.3027352   0.32629265  0.61219725  0.85610014
   0.58832306  0.15851556 -0.49532775 -0.26108596  0.0256317   0.84440822
   0.52862697  0.38458766 -0.50450909  0.62524662 -0.1874114   0.65414049
  -0.9582668  -0.76911944]]
tensor(1.5204e-05, dtype=torch.float64)
###### 4
signal tensor([[-0.0043,  0.6346,  0.7073, -0.9660,  0.0171,  0.3421,  0.1836,  0.3658,
         -0.3397,  0.0538,  0.5210,  0.6106, -0.6203, -0.2729, -0.1503,  0.3346,
          0.9652, -0.2245, -0.2097, -0.4393,  0.6255,  0.7256,  0.8333,  0.0933,
         -0.7047, -0.1916, -0.0537, -0.3560,  0.4116,  0.7621,  0.0089,  0.8748]],
       dtype=torch.float64)
window tensor([0.4215, 0.4269, 0.6516, 0.4784, 0.8769, 0.0634, 0.3327, 0.0479, 0.1534,
        0.9541, 0.8723, 0.8969, 0.2035, 0.9535, 0.4062, 0.3093],
       dtype=torch.float64)
solutions

100%|██████████| 6/6 [00:00<00:00, 1364.96it/s]

ans [[-0.00431623  0.6346376   0.70733918 -0.96595425  0.01710323  0.34211519
   0.18357696  0.36580411 -0.33973766  0.05380399  0.52097951  0.61058941
  -0.62027262 -0.27293019 -0.15032804  0.33464946  0.96522553 -0.22448519
  -0.20970326 -0.43934176  0.62554235  0.72555783  0.83329289  0.0933153
  -0.70469592 -0.19156814 -0.05371406 -0.35604546  0.41163348  0.76210314
   0.00890602  0.87478468]]
tensor(8.4506e-07, dtype=torch.float64)
3.2097469814675906e-06
3.6070558777653493e-11



