In [1]:
import IPython.display as ipd

import numpy as np
import torch


# import stft_64_pad_0 as stft
import stft_64 as stft
from audio_processing import griffin_lim

from scipy.io.wavfile import read
import time

In [2]:
hop_length = 4
win_length = 4*hop_length
channels = hop_length*2+1
wav_length = hop_length*7

audio_origin = torch.rand((1,wav_length), dtype=torch.float64)
print(audio_origin.shape)
print(audio_origin[0,:10])

torch.Size([1, 28])
tensor([0.6863, 0.3328, 0.8530, 0.9166, 0.9978, 0.2222, 0.9217, 0.1788, 0.3215,
        0.2862], dtype=torch.float64)


In [3]:
# stft_fn = stft.STFT(filter_length=4, hop_length=1, win_length=4,
#                     window='hann')

stft_fn = stft.STFT(filter_length=win_length, hop_length=hop_length, win_length=win_length,
                    window=None)


def compare(a,b):
    return torch.mean(torch.abs(a-b)), torch.mean((a-b)*(a-b))

def compare_L1(ori,gen):
    return torch.mean(torch.abs(ori-gen)/torch.abs(ori))


def compare_L2(a,b):
    return torch.sum(torch.abs(a-b)), torch.sum((a-b)*(a-b))

In [4]:
magnitude, phase_origin = stft_fn.transform(audio_origin)

forward_basis = stft_fn.forward_basis
print(forward_basis.shape)
start_frame = 0
M_Rc = torch.sum(forward_basis[:channels,0,:] * audio_origin[:,start_frame: start_frame+win_length], dim =1)
M_Ic = torch.sum(forward_basis[channels:,0,:] * audio_origin[:,start_frame: start_frame+win_length], dim =1)
M_c_square = torch.sqrt(M_Rc**2+M_Ic**2)
print(M_c_square.shape)
magnitude    = magnitude[:,:,2:6]
phase_origin = phase_origin[:,:,2:6]
print(magnitude[0,:,start_frame])
print(magnitude.shape)


torch.Size([18, 1, 16])
torch.Size([9])
tensor([10.2010,  0.7522,  1.8549,  0.2704,  0.8865,  1.3528,  0.3495,  1.4804,
         1.6199], dtype=torch.float64)
torch.Size([1, 9, 4])


In [5]:
class hop_6_solver():
    #magnitude shape: [channels=2H+1, 4]
    #[channels, 0] for [1, hop_length*0: hop_length*4]
    #[channels, 1] for [1, hop_length*1: hop_length*5]
    #[channels, 2] for [1, hop_length*2: hop_length*6]
    #[channels, 3] for [1, hop_length*3: hop_length*7]
    #x: [hop_length*7,1]
    
    def __init__(self, hop_length, channels, win_length, stft_fn):
        self.hop_length = hop_length
        self.channels = channels
        self.win_length = win_length
        self.stft_fn = stft_fn
    
    def solve(self, mag, initial_guess, n_iters=50, lambda_JTJ=0):
        #magnitude shape:      [hop_length*2+1, 4]
        #initial_guess shape:  [1, 7*hop_length]
        forward_basis = self.stft_fn.forward_basis
        
        start_time = time.time()
        for n_iter_idx in range(n_iters):
            if (n_iter_idx+1)%10==0:
                print('Iter %d/%d: Used times: %.2f, Loss:%.8f' %(n_iter_idx,n_iters,time.time()-start_time, 
                                                                  torch.sum(torch.abs(targets)))
                     )
            targets = torch.zeros((self.channels*4, 1))
            A_matrix = torch.zeros((self.channels*4, 7*self.hop_length))
            for i in range(4):
                start_hop = self.hop_length*i
                end_hop   = self.hop_length*i + self.win_length
                R_ic = torch.sum(forward_basis[:self.channels,0,:] * initial_guess[start_hop:end_hop, 0].unsqueeze(0), dim =1)
                I_ic = torch.sum(forward_basis[self.channels:,0,:] * initial_guess[start_hop:end_hop, 0].unsqueeze(0), dim =1)
                
                targets[self.channels*i:self.channels*(i+1),0] = (mag[:, i]**2 - (R_ic**2+I_ic**2))/2
                
                A_matrix[self.channels*i:self.channels*(i+1), start_hop:end_hop] = \
                     forward_basis[:self.channels,0,:] * R_ic.unsqueeze(1) + \
                     forward_basis[self.channels:,0,:] * I_ic.unsqueeze(1)
                
            # print('target', targets)
            # print('A_matrix', A_matrix)
            temp = A_matrix.T @ A_matrix
            # print('temp', temp)
            add_ons = (torch.inverse(temp)) @ (A_matrix.T @ targets)
            # print('add_ons', add_ons)
            initial_guess = initial_guess + add_ons
            
        return initial_guess

In [6]:
initial_guess = torch.rand((7*hop_length,1), dtype = torch.float64)
# initial_guess[:,0] = audio_origin[0, :] + initial_guess[:, 0]

newton_method = hop_6_solver(hop_length, channels, win_length, stft_fn)
ans = newton_method.solve(magnitude[0, :,:], initial_guess, n_iters=100)
print('\n')
print('ans   ', ans[:, 0])
print('origin', audio_origin[0, :])
# print('error       ', newton_method.func(ans))
# print('error origin', newton_method.func(audio_origin))
# print('ans part    3', ans[0,3*hop_length:4*hop_length])
# print('origin part 3', audio_origin[0, 3*hop_length:4*hop_length])

Iter 9/100: Used times: 0.00, Loss:15.11869812
Iter 19/100: Used times: 0.01, Loss:13.66700172
Iter 29/100: Used times: 0.01, Loss:11.73900986
Iter 39/100: Used times: 0.02, Loss:12.38022137
Iter 49/100: Used times: 0.02, Loss:11.75629139
Iter 59/100: Used times: 0.02, Loss:12.40001011
Iter 69/100: Used times: 0.03, Loss:11.75977898
Iter 79/100: Used times: 0.03, Loss:12.39877129
Iter 89/100: Used times: 0.03, Loss:11.75954628
Iter 99/100: Used times: 0.04, Loss:12.39823437


ans    tensor([ 0.1445,  0.3448,  1.0901,  0.9040,  0.6132,  1.3468,  0.1708,  1.0450,
         0.3204,  0.4823,  0.4504,  0.8784,  0.7010,  0.5576,  0.7759,  0.3750,
         0.8527,  0.2357, -0.2118,  0.5502, -0.2421,  0.2937,  0.6919,  0.8209,
        -0.1437,  0.3829,  0.3701,  0.7501], dtype=torch.float64)
origin tensor([0.6863, 0.3328, 0.8530, 0.9166, 0.9978, 0.2222, 0.9217, 0.1788, 0.3215,
        0.2862, 0.9611, 0.7934, 0.7195, 0.9254, 0.4497, 0.6352, 0.2595, 0.1618,
        0.8676, 0.4410, 0.0720, 0.1916,

In [11]:

def griffin_lim(mag, stft_fn, n_iters=30):
    
    angles = torch.rand(mag.shape, dtype = torch.float64)
    # angles[:,:] = phase_origin[0, :,:] + 1*torch.rand(mag.shape, dtype = torch.float64)
    # angles[:,:] = phase_origin[0, :,:] 
    inverse_basis = stft_fn.inverse_basis.squeeze(1) * 4
    divident = torch.zeros((1, 7*hop_length), dtype = torch.float64)
    for i in range(4):
        divident[0, hop_length*i : hop_length*i + win_length] = \
                   divident[0, hop_length*i : hop_length*i + win_length]+1
    # print('divident', divident)
    for n_iter in range(n_iters):
        if (n_iter+1)%10==0:
            # f.write('%d/%d:%.4f'%(i,n_iters,compare_L1(magnitude,MAG)))
            a1,a2 = compare(signal, audio_origin)
            print('%d/%d:%.4f, %.4f'%(n_iter,n_iters,a1,a2))
        signal = torch.zeros((1, 7*hop_length), dtype = torch.float64)
        recombine_magnitude_phase = torch.cat(
            [mag*torch.cos(angles), mag*torch.sin(angles)], dim=0)
        
        # print(recombine_magnitude_phase.shape)
        # print(inverse_basis.shape)
        for i in range(4):
            signal[0, hop_length*i : hop_length*i + win_length] = \
                   signal[0, hop_length*i : hop_length*i + win_length] + \
                   (inverse_basis.T @ recombine_magnitude_phase[:, [i]]).T
            # print((inverse_basis.T @ recombine_magnitude_phase[:, [i]]).T[0,:])
            # break
        signal = signal/divident
        # break
        
        _, angles = stft_fn.transform(signal)
        angles = angles[0, :, 2:6]
    return signal

griffin_ans = griffin_lim(magnitude[0,:,:], stft_fn, n_iters=100)
print('\n')
print('ans   ', griffin_ans[0, :])
print('origin', audio_origin[0, :])

9/100:0.2185, 0.0620
19/100:0.1610, 0.0380
29/100:0.1380, 0.0291
39/100:0.1272, 0.0248
49/100:0.1209, 0.0226
59/100:0.1171, 0.0213
69/100:0.1144, 0.0205
79/100:0.1125, 0.0199
89/100:0.1110, 0.0195
99/100:0.1098, 0.0192


ans    tensor([ 0.7227,  0.5825,  0.7562,  0.8082,  1.1968,  0.0474,  0.7270,  0.3265,
         0.3684,  0.3142,  0.9189,  0.9305,  0.7249,  0.7885,  0.4513,  0.5407,
         0.3538,  0.1456,  0.8571,  0.4490, -0.0508,  0.2516,  0.0469,  0.4073,
         0.5124,  0.6857, -0.1067,  0.7347], dtype=torch.float64)
origin tensor([0.6863, 0.3328, 0.8530, 0.9166, 0.9978, 0.2222, 0.9217, 0.1788, 0.3215,
        0.2862, 0.9611, 0.7934, 0.7195, 0.9254, 0.4497, 0.6352, 0.2595, 0.1618,
        0.8676, 0.4410, 0.0720, 0.1916, 0.1095, 0.3379, 0.2544, 0.9017, 0.0227,
        0.4099], dtype=torch.float64)


In [10]:
class hop_7_L_solver():
    #L shape: [channels=2H+1, 4]
    #[channels, 0] for [1, hop_length*0: hop_length*4]
    #[channels, 1] for [1, hop_length*1: hop_length*5]
    #[channels, 2] for [1, hop_length*2: hop_length*6]
    #[channels, 3] for [1, hop_length*3: hop_length*7]
    #x: [1, hop_length*7]
    
    def __init__(self, hop_length, channels, win_length):
        self.hop_length = hop_length
        self.channels = channels
        self.win_length = win_length
        
    def test(self, p, L):
        targets = []
        
        # print(M_c_square.shape)
        for hop_iter in range(4):
            start_frame = hop_iter * self.hop_length
            for c in range(self.channels):
                L_ci = torch.sum(p[0, start_frame: start_frame + self.win_length] * \
                                 torch.roll(p[0, start_frame: start_frame + self.win_length], c)
                                ) - L[c,hop_iter]
                targets.append(L_ci)
        return torch.stack(targets)
    
    def func(self, p):
        targets = []
        
        # print(M_c_square.shape)
        for hop_iter in range(4):
            start_frame = hop_iter * self.hop_length
            for c in range(self.channels):
                L_ci = torch.sum(p[0, start_frame: start_frame + self.win_length] * \
                                 torch.roll(p[0, start_frame: start_frame + self.win_length], c)
                                ) - self.L[c,hop_iter]
                targets.append(L_ci)
        return torch.stack(targets)
    
    def solve(self, L, initial_guess, n_iters=50, lambda_JTJ=1):
        #magnitude shape:      [1, hop_length*2+1, 4]
        #initial_guess shape:  [1, 7*hop_length]
        
        self.L =L
        start_time = time.time()
        for i in range(n_iters):
            print('\rIter %d/%d: Used times: %.2f' %(i,n_iters,time.time()-start_time), end="")
            # check(recon)
            #print('#')
            x = torch.tensor(initial_guess.detach().numpy(), dtype=torch.float64, requires_grad = True) 
            # x = torch.from_numpy(np.zeros(initial_guess.shape))
            
            J = torch.autograd.functional.jacobian(self.func, x)
            # print('J shape', J.shape)
           
            J = J.squeeze(1)
            # print('')
            # print('J',J)
            target = self.func(x).detach().numpy()
            
            
            temp = J.T @ J
            minus = (torch.inverse(temp))@ (J.T @ target)
            minus = minus.numpy()
        
            initial_guess = initial_guess - minus.T
            
        return initial_guess

In [17]:

from solver_test import *
L = magnitude_to_L(magnitude, stft_fn.forward_basis[:2*hop_length+1,:,:])
initial_guess = torch.zeros((1, 7*hop_length), dtype = torch.float64)
initial_guess[:,:] = griffin_ans[:,:]

newton_method = hop_7_L_solver(hop_length, channels, win_length)
ans = newton_method.solve(L[:,:], initial_guess, n_iters=20, lambda_JTJ=0)
print('\n')
print('ans   ', ans[0, :])
print('origin', audio_origin[0, :])
print('error       ', newton_method.func(ans))
print('error origin', newton_method.func(audio_origin))
print('ans part    3', ans[0,3*hop_length:4*hop_length])
print('origin part 3', audio_origin[0, 3*hop_length:4*hop_length])

Iter 19/20: Used times: 1.06

ans    tensor([ 0.7316,  0.3779,  0.8279,  0.8773,  1.0273,  0.1878,  0.8982,  0.2403,
         0.3131,  0.2451,  0.9633,  0.8630,  0.7246,  0.9053,  0.4225,  0.5957,
         0.3054,  0.1843,  0.8467,  0.4204,  0.0016,  0.2584,  0.1662,  0.3159,
         0.4708,  0.5871, -0.1811,  0.7365], dtype=torch.float64)
origin tensor([0.6863, 0.3328, 0.8530, 0.9166, 0.9978, 0.2222, 0.9217, 0.1788, 0.3215,
        0.2862, 0.9611, 0.7934, 0.7195, 0.9254, 0.4497, 0.6352, 0.2595, 0.1618,
        0.8676, 0.4410, 0.0720, 0.1916, 0.1095, 0.3379, 0.2544, 0.9017, 0.0227,
        0.4099], dtype=torch.float64)
error        tensor([ 3.2415e-03,  4.0435e-03, -1.2372e-02, -9.2928e-03,  1.1186e-02,
         9.4537e-03, -2.4953e-03, -7.5473e-04, -1.8573e-03, -2.2764e-04,
         8.1812e-03,  1.0197e-02, -2.0565e-03, -7.8618e-03, -2.4817e-03,
         1.5292e-02, -3.8524e-03, -1.7165e-02, -2.8891e-03, -4.4661e-03,
        -6.3626e-03,  7.9339e-03,  3.8242e-04, -9.6620e-03, -3.0856