In [1]:
import IPython.display as ipd

import numpy as np
import torch


# import stft_64_pad_0 as stft
import stft_64 as stft
from audio_processing import griffin_lim

from scipy.io.wavfile import read
import time
from solver_test import *

In [2]:
# Toy problem dimensions; every size is derived from the hop size.
hop_length = 4
win_length = hop_length * 4      # one analysis window spans four hops
channels = 2 * hop_length + 1    # number of magnitude channels per frame
wav_length = 8 * hop_length      # total number of samples in the test signal

# Random float64 test signal with a leading axis of size 1.
# NOTE: no seed is set, so the values differ on every fresh kernel run.
audio_origin = torch.rand(1, wav_length, dtype=torch.float64)
print(audio_origin.size())
print(audio_origin[0, :10])

torch.Size([1, 32])
tensor([0.4242, 0.8498, 0.0606, 0.7570, 0.3827, 0.2695, 0.6314, 0.9837, 0.4788,
        0.7080], dtype=torch.float64)


In [3]:
# stft_fn = stft.STFT(filter_length=4, hop_length=1, win_length=4,
#                     window='hann')

# STFT operator used by every cell below: filter_length is set equal to
# win_length and no window function is requested (window=None).
stft_fn = stft.STFT(
    filter_length=win_length,
    hop_length=hop_length,
    win_length=win_length,
    window=None,
)


def compare(a,b):
    """Return ``(mean |a-b|, mean (a-b)^2)`` as a pair of 0-dim tensors.

    ``a`` and ``b`` are tensors that can be subtracted from each other
    (broadcasting follows the usual torch rules).
    """
    diff = a - b
    mean_abs = diff.abs().mean()
    mean_sq = (diff * diff).mean()
    return mean_abs, mean_sq

def compare_L1(ori,gen):
    """Mean relative absolute error of ``gen`` with respect to ``ori``.

    Computes ``mean(|ori - gen| / |ori|)``. There is no guard against zero
    entries in ``ori``; those produce inf/nan in the result.
    """
    relative_error = (ori - gen).abs() / ori.abs()
    return relative_error.mean()


def compare_L2(a,b):
    """Return ``(sum |a-b|, sum (a-b)^2)`` as a pair of 0-dim tensors.

    Despite the name, the first element is the summed absolute difference
    and the second the summed squared difference (no square root is taken).
    """
    diff = a - b
    total_abs = diff.abs().sum()
    total_sq = (diff * diff).sum()
    return total_abs, total_sq

In [4]:
# Reference analysis of the test signal: magnitude and phase of every frame.
magnitude, phase_origin = stft_fn.transform(audio_origin)

# Sanity check: project the first window of the signal onto the forward
# basis by hand. forward_basis is [18, 1, 16] in the recorded run, i.e.
# 2*channels rows of length win_length.
forward_basis = stft_fn.forward_basis
print(forward_basis.shape)
start_frame = 0
# First `channels` rows and remaining rows of the basis applied to the same
# window — presumably the real (cos) and imaginary (sin) parts; TODO confirm
# against stft_64.
M_Rc = torch.sum(forward_basis[:channels,0,:] * audio_origin[:,start_frame: start_frame+win_length], dim =1)
M_Ic = torch.sum(forward_basis[channels:,0,:] * audio_origin[:,start_frame: start_frame+win_length], dim =1)
# NOTE: despite the name this is the square root of the summed squares
# (a magnitude), not a squared quantity.
M_c_square = torch.sqrt(M_Rc**2+M_Ic**2)
print(M_c_square.shape)
print(magnitude.shape)
# Keep only frames 2..6 of the transform output (9 frames -> 5 frames in the
# recorded run). Presumably these are the frames that do not overlap the
# padding added by stft_fn.transform — TODO confirm.
magnitude    = magnitude[:,:,2:7]
phase_origin = phase_origin[:,:,2:7]
print(magnitude[0,:,start_frame])
print(magnitude.shape)


torch.Size([18, 1, 16])
torch.Size([9])
torch.Size([1, 9, 9])
tensor([7.6871, 0.9362, 1.6868, 0.1750, 1.0575, 0.7200, 1.4461, 0.9352, 1.9566],
       dtype=torch.float64)
torch.Size([1, 9, 5])


In [5]:
# Convert the magnitudes into the per-frame quantity L using the first
# 2*hop_length+1 rows of the forward basis. magnitude_to_L is not defined in
# this notebook (presumably it comes from `from solver_test import *`).
# L is [9, 5] in the recorded run.
L = magnitude_to_L(magnitude, stft_fn.forward_basis[:2*hop_length+1,:,:])

# Frame index used for the manual cross-check below.
TEST_STEP = 2
print(L.shape)
print(L[:,TEST_STEP])
# print(audio_origin[0,:4*hop_length])
# print(stft_fn.forward_basis[:2*hop_length+1,:,:2*hop_length+1])
# Cross-check: for each lag i in 0..2*hop_length, print the circular
# autocorrelation sum(w * roll(w, i)) of the window w that starts at
# TEST_STEP*hop_length and spans 4*hop_length samples. In the recorded run
# these values match L[:, TEST_STEP] printed above.
for i in range(2*hop_length+1):
    print(torch.sum(audio_origin[0,TEST_STEP * hop_length:(TEST_STEP+4) *hop_length] * \
                    torch.roll(audio_origin[0,TEST_STEP * hop_length:(TEST_STEP+4) *hop_length],i)))

torch.Size([9, 5])
tensor([4.2918, 2.7256, 2.8842, 2.6879, 2.9531, 2.7288, 3.3911, 3.1586, 2.5391],
       dtype=torch.float64)
tensor(4.2918, dtype=torch.float64)
tensor(2.7256, dtype=torch.float64)
tensor(2.8842, dtype=torch.float64)
tensor(2.6879, dtype=torch.float64)
tensor(2.9531, dtype=torch.float64)
tensor(2.7288, dtype=torch.float64)
tensor(3.3911, dtype=torch.float64)
tensor(3.1586, dtype=torch.float64)
tensor(2.5391, dtype=torch.float64)


In [6]:
# Select a square sub-basis: all `channels` rows of the first half of the
# forward basis plus rows 1..channels-2 of the second half (the first and
# last row of the second half are skipped). With channels = 2*hop_length+1
# this gives 2*channels-2 = win_length rows, so the matrix can be inverted.
array_part_sin = [channels + i for i in range(1, channels-1)]
array_part_cos = [i for i in range(channels)]
stft_forward_basis = stft_fn.forward_basis[array_part_cos + array_part_sin,0,:]
stft_inv = torch.linalg.inv(stft_forward_basis)

# Rebuild the coefficient vector of frame 0 from magnitude and true phase,
# in the same row order as the sub-basis above: mag*cos for every channel,
# then mag*sin for the inner channels only.
recombine_magnitude_phase = torch.cat(
            [magnitude[0,:,0]*torch.cos(phase_origin[0,:channels,0]), 
             magnitude[0,1:-1,0]*torch.sin(phase_origin[0,1:-1,0])], dim =0)
print(recombine_magnitude_phase.shape)
# Inverting the sub-basis should give back the first win_length samples of
# audio_origin; the recorded output below confirms this.
temp = stft_inv @ recombine_magnitude_phase 
print(temp.shape)
print(temp)
print(audio_origin)

torch.Size([16])
torch.Size([16])
tensor([0.4242, 0.8498, 0.0606, 0.7570, 0.3827, 0.2695, 0.6314, 0.9837, 0.4788,
        0.7080, 0.3457, 0.2514, 0.3277, 0.0867, 0.2142, 0.9158],
       dtype=torch.float64)
tensor([[0.4242, 0.8498, 0.0606, 0.7570, 0.3827, 0.2695, 0.6314, 0.9837, 0.4788,
         0.7080, 0.3457, 0.2514, 0.3277, 0.0867, 0.2142, 0.9158, 0.9147, 0.0239,
         0.6456, 0.0585, 0.6067, 0.3957, 0.7460, 0.2007, 0.7511, 0.5549, 0.0854,
         0.6587, 0.9055, 0.1797, 0.3534, 0.0264]], dtype=torch.float64)


### method 1

In [8]:
class hop_8_solver():
    """Damped Gauss-Newton phase solver for a signal of 8 hops (5 frames).

    Data layout the solver works on:
        magnitude shape: [channels=2H+1, 5]
        [channels, 0] for [1, hop_length*0: hop_length*4]
        [channels, 1] for [1, hop_length*1: hop_length*5]
        [channels, 2] for [1, hop_length*2: hop_length*6]
        [channels, 3] for [1, hop_length*3: hop_length*7]
        [channels, 4] for [1, hop_length*4: hop_length*8]
        x: [(channels-2)*2, 1], guessed phase of the inner channels
           (1 .. channels-2) of frame 0 followed by those of frame 4

    Frames 0 and 4 do not overlap and together cover the whole signal, so a
    phase guess for these two frames determines a candidate signal. The
    unknown phases are adjusted until the circular autocorrelation of the
    windows of frames 1..3 of that candidate matches the target ``L``.

    All sizes are taken from the constructor arguments; the class does not
    read any notebook-level globals.
    """

    def __init__(self, hop_length, channels, win_length, stft_forward_basis):
        """Store the sizes and pre-compute the inverse of the square sub-basis.

        Args:
            hop_length: hop size H in samples.
            channels: number of magnitude channels per frame.
            win_length: window length in samples.
            stft_forward_basis: tensor indexed as [row, 0, sample]; rows
                ``0 .. channels-1`` and ``channels+1 .. 2*channels-2`` are
                used. The selected rows must form a square matrix
                (2*channels - 2 == win_length) because it is inverted.
        """
        self.hop_length = hop_length
        self.channels = channels
        self.win_length = win_length
        self.array_part_cos = list(range(channels))
        # The first and last row of the second half are skipped; the
        # matching coefficients are never part of the unknowns.
        self.array_part_sin = [channels + i for i in range(1, channels - 1)]
        self.stft_forward_basis = stft_forward_basis[self.array_part_cos + self.array_part_sin, 0, :]
        self.stft_inv = torch.linalg.inv(self.stft_forward_basis)

    def test(self, p, L):
        """Placeholder kept for interface compatibility; always returns []."""
        targets = []
        return targets

    def _frame_to_samples(self, frame, cos_first, cos_last, inner_phase):
        """Invert one frame: magnitudes + phases -> win_length samples.

        ``cos_first`` / ``cos_last`` are the known cosines of the phase of the
        first and last channel; ``inner_phase`` holds the phases of channels
        1 .. channels-2. The coefficient order matches the rows selected in
        ``__init__``.
        """
        inner_mag = self.mag[1:-1, frame]
        coefficients = torch.cat(
            [
                self.mag[0, [frame]] * cos_first,
                inner_mag * torch.cos(inner_phase),
                self.mag[-1, [frame]] * cos_last,
                inner_mag * torch.sin(inner_phase),
            ],
            dim=0)
        return self.stft_inv @ coefficients

    def recon(self, p):
        """Rebuild the full signal from the phase vector ``p``.

        Args:
            p: tensor of shape [(channels-2)*2, 1]; first half are the inner
               phases of frame 0, second half those of frame 4.

        Returns:
            1-D tensor of 2*win_length samples (frame 0 followed by frame 4).
            Requires ``solve`` to have stored ``mag`` and the four edge
            cosines on the instance.
        """
        n_inner = self.channels - 2
        part_1_to_4 = self._frame_to_samples(
            0, self.cos0_0, self.cos0_channel, p[:n_inner, 0])
        part_5_to_8 = self._frame_to_samples(
            4, self.cos4_0, self.cos4_channel, p[n_inner:, 0])
        return torch.cat((part_1_to_4, part_5_to_8))

    def error(self, recon):
        """Residual between the autocorrelation of ``recon`` and ``self.L``.

        For frames 1, 2 and 3 and every lag ``i`` in ``0 .. 2*hop_length`` the
        circular autocorrelation ``sum(w * roll(w, i))`` of the frame's window
        ``w`` is compared with ``self.L[i, frame]``.

        Returns:
            1-D tensor of 3 * (2*hop_length + 1) residuals, ordered by frame
            first and lag second.
        """
        n_lags = 2 * self.hop_length + 1
        targets = []
        for idx in range(1, 4):
            start_idx = self.hop_length * idx
            window = recon[start_idx: start_idx + self.win_length]
            for i in range(n_lags):
                targets.append(
                    torch.sum(window * torch.roll(window, i)) - self.L[i, idx])
        return torch.stack(targets)

    def func(self, p):
        """Residual vector as a function of the phase vector ``p``.

        Equivalent to ``self.error(self.recon(p))``; this is the function
        whose Jacobian drives the iteration in ``solve``.
        """
        return self.error(self.recon(p))

    def solve(self, cos0_0, cos0_channel, cos4_0, cos4_channel, mag, L, initial_guess, n_iters=50, lambda_JTJ=1):
        """Refine the phase guess and return the reconstructed signal.

        Each iteration solves
        ``(J^T J + lambda_JTJ * diag(J^T J)) step = J^T r`` and subtracts
        ``step`` from the current guess, where ``r = func(guess)`` and ``J``
        is its Jacobian. ``lambda_JTJ=0`` gives a plain Gauss-Newton step.

        Args:
            cos0_0, cos0_channel: cosine of the phase of the first / last
                channel of frame 0 (treated as known).
            cos4_0, cos4_channel: same for frame 4.
            mag: magnitudes, shape [hop_length*2+1, 5].
            L: autocorrelation targets indexed as ``L[lag, frame]``.
            initial_guess: starting phases, shape [(channels-2)*2, 1]; it is
                copied and converted to float64, the caller's object is not
                modified.
            n_iters: number of iterations; 0 just reconstructs from the
                initial guess.
            lambda_JTJ: damping factor applied to the diagonal of ``J^T J``.

        Returns:
            Tensor of shape [1, 2*win_length] with the reconstructed signal.

        Raises:
            Whatever ``torch.linalg.solve`` raises when the (damped) normal
            matrix is singular.
        """
        self.L = L
        self.mag = mag
        self.cos0_0, self.cos0_channel = cos0_0, cos0_channel
        self.cos4_0, self.cos4_channel = cos4_0, cos4_channel
        start_time = time.time()

        # Private float64 copy; avoids the copy-construct warning of
        # torch.tensor(tensor) and keeps the whole update in torch.
        guess = torch.as_tensor(initial_guess).detach().to(torch.float64).clone()

        for i in range(n_iters):
            J = torch.autograd.functional.jacobian(self.func, guess)
            J = J.squeeze(2).detach()           # [n_residuals, n_unknowns]
            target = self.func(guess).detach()  # residual at the current guess

            if (i+1)%10==0:
                # Report the residual of the current iterate (before the update).
                print('\rIter %d/%d: Used times: %.2f' %(i,n_iters,time.time()-start_time), end="")
                print('target', target.numpy())
                print('avg target', target.abs().sum().item())

            JTJ = J.T @ J
            damped = JTJ + lambda_JTJ * torch.diag(torch.diag(JTJ, 0))
            # Solve the normal equations directly instead of forming an inverse.
            minus = torch.linalg.solve(damped, J.T @ target)
            guess = guess - minus.unsqueeze(1)

        return self.recon(guess).unsqueeze(0)
    
# Self-check of the solver with the TRUE phase as the "guess" and zero
# iterations: the reconstruction must equal audio_origin and the residual
# must be numerically zero (see the recorded output).
print(audio_origin[0, :])
newton_method = hop_8_solver(hop_length, channels, win_length, stft_forward_basis = stft_fn.forward_basis)
# Unknowns are the inner-channel phases (1..channels-2) of frame 0 stacked
# on top of those of frame 4, as a column vector.
initial_guess = torch.cat([phase_origin[0,1:-1,[0]], phase_origin[0,1:-1,[4]]], dim =0)
# The cosines of the first/last channel phase of frames 0 and 4 are passed
# in as known quantities.
ans = newton_method.solve(torch.cos(phase_origin[0,0,0]),
                          torch.cos(phase_origin[0,-1,0]),
                          torch.cos(phase_origin[0,0,4]),
                          torch.cos(phase_origin[0,-1,4]),
                          magnitude[0, :,:], L, initial_guess, n_iters=0, lambda_JTJ=0)

print('recon', newton_method.recon(initial_guess))
print('error', newton_method.func(initial_guess))

tensor([0.4242, 0.8498, 0.0606, 0.7570, 0.3827, 0.2695, 0.6314, 0.9837, 0.4788,
        0.7080, 0.3457, 0.2514, 0.3277, 0.0867, 0.2142, 0.9158, 0.9147, 0.0239,
        0.6456, 0.0585, 0.6067, 0.3957, 0.7460, 0.2007, 0.7511, 0.5549, 0.0854,
        0.6587, 0.9055, 0.1797, 0.3534, 0.0264], dtype=torch.float64)
recon tensor([0.4242, 0.8498, 0.0606, 0.7570, 0.3827, 0.2695, 0.6314, 0.9837, 0.4788,
        0.7080, 0.3457, 0.2514, 0.3277, 0.0867, 0.2142, 0.9158, 0.9147, 0.0239,
        0.6456, 0.0585, 0.6067, 0.3957, 0.7460, 0.2007, 0.7511, 0.5549, 0.0854,
        0.6587, 0.9055, 0.1797, 0.3534, 0.0264], dtype=torch.float64)
error tensor([ 5.3291e-15, -4.4409e-16, -1.7764e-15, -3.9968e-15,  4.4409e-16,
        -4.4409e-16,  2.6645e-15, -2.6645e-15,  4.4409e-16,  1.7764e-15,
        -2.2204e-15, -4.4409e-15, -5.3291e-15, -2.2204e-15, -2.2204e-15,
        -8.8818e-16, -3.9968e-15, -1.7764e-15,  3.5527e-15, -2.2204e-15,
        -3.5527e-15, -4.8850e-15, -1.3323e-15, -1.3323e-15, -4.4409e-16,
   

In [9]:
# Run the solver from a random phase initialisation in [0, pi).
# NOTE: torch.rand defaults to float32 here and no seed is set, so results
# vary between runs.
initial_guess = torch.rand(((channels-2)*2, 1)) * torch.pi
# initial_guess = torch.cat([phase_origin[0,1:-1,[0]], phase_origin[0,1:-1,[4]]], dim =0) + \
#                 0.*torch.rand(((channels-2)*2, 1))

newton_method = hop_8_solver(hop_length, channels, win_length, stft_forward_basis = stft_fn.forward_basis)
# lambda_JTJ=0 disables the diagonal damping term inside solve().
ans = newton_method.solve(torch.cos(phase_origin[0,0,0]),
                          torch.cos(phase_origin[0,-1,0]),
                          torch.cos(phase_origin[0,0,4]),
                          torch.cos(phase_origin[0,-1,4]),
                          magnitude[0, :,:], L, initial_guess, n_iters=20, lambda_JTJ=0)

print('\n')
print('ans   ', ans[0, :])
print('origin', audio_origin[0, :])
# Residual of the solver result versus the residual of the true signal
# (the latter should be numerically zero).
print('error       ', newton_method.error(ans[0, :]))
print('error origin', newton_method.error(audio_origin[0, :]))
# print('ans part    3', ans[0,3*hop_length:4*hop_length])
# print('origin part 3', audio_origin[0, 3*hop_length:4*hop_length])



Iter 9/20: Used times: 0.25target [-0.2340149   0.14791865  0.40136501  0.10178723  0.34207428  0.24944167
  0.15713371  0.109893   -0.10629829 -0.1608375  -0.10541091 -0.21131065
  0.36886166  0.05440102  0.12159605 -0.26331697  0.06307748 -0.03842696
 -0.15358998  0.00086287  0.08756503  0.23036176  0.09400955  0.15018598
  0.24233606 -0.24811426 -0.0147288 ]
avg target 4.4589202417342255
Iter 19/20: Used times: 0.53target [-3.01217562e-01  3.37444821e-02  3.07865178e-01 -2.67707890e-02
  2.22298677e-01  1.33819427e-01  9.23933599e-02  5.99921857e-02
 -1.82788437e-01 -4.25323582e-02 -8.82001635e-02 -1.45868098e-01
  4.27429895e-01  7.85418333e-02  1.45449153e-01 -2.07373580e-01
  1.18932859e-01  3.81615264e-04 -1.22599054e-01  1.22818850e-02
  9.30313587e-02  2.34604129e-01  1.01526639e-01  1.46809925e-01
  2.72387845e-01 -2.41504459e-01 -1.95644977e-02]
avg target 3.8599094461818306


ans    tensor([ 0.2870,  0.6828,  0.2467,  0.8308,  1.0250,  0.4258,  0.5264,  0.4046,
        -0.0

In [11]:

def griffin_lim(mag, stft_fn, n_iters=30):
    """Iterative magnitude-only reconstruction for the 5-frame toy problem.

    Starting from random phases, each iteration (1) synthesises a signal by
    overlap-adding the inverse of every frame and dividing by the number of
    overlapping frames per sample, and (2) re-analyses that signal with
    ``stft_fn.transform`` to obtain the next phase estimate.

    This definition shadows the ``griffin_lim`` imported from
    ``audio_processing`` at the top of the notebook. It reads the notebook
    globals ``hop_length``, ``win_length``, ``compare`` and ``audio_origin``.

    Args:
        mag: magnitudes of shape [channels, 5].
        stft_fn: object providing ``inverse_basis`` and ``transform``.
        n_iters: number of iterations. With ``n_iters <= 0`` the signal
            synthesised from the random initial phases is returned (the
            previous code raised UnboundLocalError in that case).

    Returns:
        ``(signal, angles)``: ``signal`` has shape [1, 8*hop_length];
        ``angles`` is the last phase estimate, shape [channels, 5].
    """
    n_frames = 5
    signal_length = 8 * hop_length

    # Random initial phase. For debugging, angles can instead be initialised
    # from phase_origin[0, :, :] (optionally plus noise).
    angles = torch.rand(mag.shape, dtype = torch.float64)

    # The factor 4 presumably compensates a scaling built into
    # stft_fn.inverse_basis (win_length / hop_length == 4 here) — TODO confirm.
    inverse_basis = stft_fn.inverse_basis.squeeze(1) * 4

    # Number of frames covering each sample, used to average the overlap-add.
    overlap_count = torch.zeros((1, signal_length), dtype = torch.float64)
    for i in range(n_frames):
        overlap_count[0, hop_length*i : hop_length*i + win_length] += 1

    def _synthesize(frame_angles):
        """Overlap-add the inverse transform of every frame and normalise."""
        coefficients = torch.cat(
            [mag*torch.cos(frame_angles), mag*torch.sin(frame_angles)], dim=0)
        out = torch.zeros((1, signal_length), dtype = torch.float64)
        for i in range(n_frames):
            out[0, hop_length*i : hop_length*i + win_length] = \
                out[0, hop_length*i : hop_length*i + win_length] + \
                (inverse_basis.T @ coefficients[:, [i]]).T
        return out / overlap_count

    if n_iters <= 0:
        return _synthesize(angles), angles

    for n_iter in range(n_iters):
        signal = _synthesize(angles)
        if (n_iter+1)%10==0:
            # Report the error of the signal produced in THIS iteration.
            a1,a2 = compare(signal, audio_origin)
            print('%d/%d:%.4f, %.4f'%(n_iter,n_iters,a1,a2))

        _, angles = stft_fn.transform(signal)
        # Same frame selection (2..6) as applied to the reference magnitudes.
        angles = angles[0, :, 2:7]
    return signal, angles

# Baseline: 100 iterations of the notebook's griffin_lim on the reference
# magnitudes; the returned phases are reused later to seed the solver.
griffin_ans, griffin_angles = griffin_lim(magnitude[0,:,:], stft_fn, n_iters=100)
print('\n')
print('ans   ', griffin_ans[0, :])
print('origin', audio_origin[0, :])

9/100:0.3538, 0.1940
19/100:0.3674, 0.2051
29/100:0.3683, 0.2065
39/100:0.3689, 0.2054
49/100:0.3694, 0.2041
59/100:0.3690, 0.2024
69/100:0.3675, 0.2003
79/100:0.3650, 0.1981
89/100:0.3650, 0.1960
99/100:0.3652, 0.1948


ans    tensor([ 0.5975,  0.4171,  0.8631,  0.5848,  0.6656,  0.6040,  0.0889,  0.0320,
         0.5384,  0.1427,  0.6574,  0.2104,  1.1109,  0.5060,  0.1009,  0.4747,
         0.3029,  0.2906,  0.6853,  0.8538,  0.2529,  0.1195,  0.5875,  0.1184,
         0.7928,  0.0836,  0.5187,  0.3026,  0.8740, -0.0480,  0.9668,  0.5253],
       dtype=torch.float64)
origin tensor([0.4242, 0.8498, 0.0606, 0.7570, 0.3827, 0.2695, 0.6314, 0.9837, 0.4788,
        0.7080, 0.3457, 0.2514, 0.3277, 0.0867, 0.2142, 0.9158, 0.9147, 0.0239,
        0.6456, 0.0585, 0.6067, 0.3957, 0.7460, 0.2007, 0.7511, 0.5549, 0.0854,
        0.6587, 0.9055, 0.1797, 0.3534, 0.0264], dtype=torch.float64)


In [19]:
# Run the solver again, this time seeded with the phases estimated by
# griffin_lim (inner channels of frames 0 and 4) instead of random phases.
# initial_guess = torch.rand(((channels-2)*2, 1)) * torch.pi
initial_guess = torch.cat([griffin_angles[1:-1,[0]], griffin_angles[1:-1,[4]]], dim =0)

newton_method = hop_8_solver(hop_length, channels, win_length, stft_forward_basis = stft_fn.forward_basis)
# The first/last channel cosines still come from the TRUE phase;
# lambda_JTJ=0 disables the diagonal damping term inside solve().
ans = newton_method.solve(torch.cos(phase_origin[0,0,0]),
                          torch.cos(phase_origin[0,-1,0]),
                          torch.cos(phase_origin[0,0,4]),
                          torch.cos(phase_origin[0,-1,4]),
                          magnitude[0, :,:], L, initial_guess, n_iters=100, lambda_JTJ=0)

print('\n')
print('ans   ', ans[0, :])
print('origin', audio_origin[0, :])
# Residual of the solver result versus the residual of the true signal.
print('error       ', newton_method.error(ans[0, :]))
print('error origin', newton_method.error(audio_origin[0, :]))
# print('ans part    3', ans[0,3*hop_length:4*hop_length])
# print('origin part 3', audio_origin[0, 3*hop_length:4*hop_length])



Iter 9/100: Used times: 0.25target [-0.39660411  0.28907039  0.32774638  0.47761056  0.24082735  0.43773888
  0.31839581 -0.40602678  0.30534204  0.01181641  0.70180364  0.12273286
  0.75068731  0.49042385  0.34353426 -0.18746268  0.22531115  0.53040692
 -0.14485491  0.09350456  0.30612155  0.32163734  0.22621337  0.32714608
  0.01377993 -0.0079995   0.07440662]
avg target 8.079205227425991
Iter 19/100: Used times: 0.53target [-0.03564264  0.15041255  0.54479453  0.3436594   0.5689903   0.40001172
  0.27150355  0.13440493 -0.10724524  0.27475873  0.19727713  0.12366838
  0.53686761  0.17438751  0.2947527  -0.13092242  0.10490154  0.08266359
  0.25492212  0.29830605  0.26853639  0.56239983  0.55394963  0.44531478
  0.13710545  0.09822899  0.42759543]
avg target 7.5232231205955475
Iter 29/100: Used times: 0.81target [-0.20596562  0.19014217  0.30590786  0.29000405  0.42600108  0.16847578
  0.39341625 -0.08614728  0.10732678 -0.62888002 -0.29898718 -0.21948391
  0.16124648 -0.48706904 -0.