In [1]:
import IPython.display as ipd

import numpy as np
import torch


# import stft_64_pad_0 as stft
import stft_64 as stft
from audio_processing import griffin_lim

from scipy.io.wavfile import read
import time
from solver_test import *

In [2]:
# Problem size: every length is a multiple of the hop.
hop_length = 4
win_length = 4*hop_length      # analysis window / STFT filter length
channels = hop_length*2+1      # presumably the one-sided bin count win_length//2 + 1
wav_length = hop_length*7      # 7 hops -> 4 overlapping windows of length win_length

# Seed the RNG so Restart & Run All reproduces the same test signal
# (and, when run in order, the same random initial guess further down).
SEED = 0
torch.manual_seed(SEED)
audio_origin = torch.rand((1,wav_length), dtype=torch.float64)
print(audio_origin.shape)
print(audio_origin[0,:10])

torch.Size([1, 28])
tensor([0.0598, 0.6319, 0.2921, 0.7176, 0.9370, 0.0388, 0.5276, 0.4439, 0.0714,
        0.8513], dtype=torch.float64)


In [3]:
# stft_fn = stft.STFT(filter_length=4, hop_length=1, win_length=4,
#                     window='hann')

# STFT whose filter length equals the analysis window length.
# NOTE(review): window=None presumably means no (rectangular) windowing in stft_64 — confirm.
stft_fn = stft.STFT(window=None,
                    filter_length=win_length,
                    win_length=win_length,
                    hop_length=hop_length)


def compare(a,b):
    """Return the mean absolute error and mean squared error between ``a`` and ``b``.

    Returns:
        A pair ``(mae, mse)`` of 0-d tensors.
    """
    delta = a - b
    mae = delta.abs().mean()
    mse = (delta * delta).mean()
    return mae, mse

def compare_L1(ori,gen):
    """Return the mean relative absolute error ``mean(|ori - gen| / |ori|)``.

    Note: entries where ``ori == 0`` divide by zero and yield inf/nan.
    """
    relative = (ori - gen).abs() / ori.abs()
    return relative.mean()


def compare_L2(a,b):
    """Return the summed absolute error and summed squared error between ``a`` and ``b``.

    Returns:
        A pair ``(sum|a-b|, sum (a-b)^2)`` of 0-d tensors.
    """
    delta = a - b
    total_abs = delta.abs().sum()
    total_sq = (delta * delta).sum()
    return total_abs, total_sq

In [4]:
# STFT of the test signal; the result is unpacked as (magnitude, phase).
magnitude, phase_origin = stft_fn.transform(audio_origin)

# Manually project the first window onto the STFT basis as a cross-check.
# The basis prints as torch.Size([18, 1, 16]); its first `channels` rows are
# used as the real part (M_Rc) and the remaining rows as the imaginary part (M_Ic).
forward_basis = stft_fn.forward_basis
print(forward_basis.shape)
start_frame = 0
M_Rc = torch.sum(forward_basis[:channels,0,:] * audio_origin[:,start_frame: start_frame+win_length], dim =1)
M_Ic = torch.sum(forward_basis[channels:,0,:] * audio_origin[:,start_frame: start_frame+win_length], dim =1)
# NOTE(review): despite the name, this is the magnitude sqrt(R^2 + I^2), not its square.
M_c_square = torch.sqrt(M_Rc**2+M_Ic**2)
print(M_c_square.shape)
# Keep only STFT frames 2..5 (four frames), one per window the solver uses.
# NOTE(review): presumably these are the frames whose windows lie fully inside
# the signal given stft_64's padding — confirm against the stft module.
magnitude    = magnitude[:,:,2:6]
phase_origin = phase_origin[:,:,2:6]
# After the slice above, index `start_frame` (0) refers to original frame 2.
print(magnitude[0,:,start_frame])
print(magnitude.shape)


torch.Size([18, 1, 16])
torch.Size([9])
tensor([7.7303, 0.3088, 1.3460, 0.5514, 0.2974, 1.5522, 2.3430, 1.2050, 0.8025],
       dtype=torch.float64)
torch.Size([1, 9, 4])


In [5]:
# Convert the sliced STFT magnitudes into windowed autocorrelation targets L.
L = magnitude_to_L(magnitude, stft_fn.forward_basis[:2*hop_length+1,:,:])
print(L.shape)
print(L[:,0])

# Sanity check: column 0 of L should equal the circular autocorrelation of the
# first window, sum(w * roll(w, i)) for lags i = 0 .. 2*hop_length.
first_window = audio_origin[0,:4*hop_length]
direct_autocorr = []
for i in range(2*hop_length+1):
    direct_autocorr.append(torch.sum(first_window * torch.roll(first_window, i)))
    print(direct_autocorr[-1])
assert torch.allclose(L[:,0], torch.stack(direct_autocorr)), \
    'L[:, 0] does not match the direct autocorrelation of the first window'

torch.Size([9, 4])
tensor([5.2315, 3.1121, 3.6610, 4.1980, 2.8735, 3.8415, 3.8671, 3.6269, 4.1663],
       dtype=torch.float64)
tensor(5.2315, dtype=torch.float64)
tensor(3.1121, dtype=torch.float64)
tensor(3.6610, dtype=torch.float64)
tensor(4.1980, dtype=torch.float64)
tensor(2.8735, dtype=torch.float64)
tensor(3.8415, dtype=torch.float64)
tensor(3.8671, dtype=torch.float64)
tensor(3.6269, dtype=torch.float64)
tensor(4.1663, dtype=torch.float64)


In [48]:
class hop_7_L_solver():
    #L shape: [channels=2H+1, 4]
    #[channels, 0] for [1, hop_length*0: hop_length*4]
    #[channels, 1] for [1, hop_length*1: hop_length*5]
    #[channels, 2] for [1, hop_length*2: hop_length*6]
    #[channels, 3] for [1, hop_length*3: hop_length*7]
    #x: [1, hop_length*7]
    
    def __init__(self, hop_length, channels, win_length):
        self.hop_length = hop_length
        self.channels = channels
        self.win_length = win_length
        
    def test(self, p, L):
        targets = []
        
        # print(M_c_square.shape)
        for hop_iter in range(4):
            start_frame = hop_iter * self.hop_length
            for c in range(self.channels):
                L_ci = torch.sum(p[0, start_frame: start_frame + self.win_length] * \
                                 torch.roll(p[0, start_frame: start_frame + self.win_length], c)
                                ) - L[c,hop_iter]
                targets.append(L_ci)
        return torch.stack(targets)
    
    def func(self, p):
        targets = []
        
        # print(M_c_square.shape)
        for hop_iter in range(4):
            start_frame = hop_iter * self.hop_length
            for c in range(self.channels):
                L_ci = torch.sum(p[0, start_frame: start_frame + self.win_length] * \
                                 torch.roll(p[0, start_frame: start_frame + self.win_length], c)
                                ) - self.L[c,hop_iter]
                targets.append(L_ci)
        return torch.stack(targets)
    
    def solve(self, L, initial_guess, n_iters=50, lambda_JTJ=1):
        #magnitude shape:      [1, hop_length*2+1, 4]
        #initial_guess shape:  [1, 7*hop_length]
        
        self.L =L
        start_time = time.time()
        for i in range(n_iters):
            print('\rIter %d/%d: Used times: %.2f' %(i,n_iters,time.time()-start_time), end="")
            # check(recon)
            #print('#')
            x = torch.tensor(initial_guess.detach().numpy(), dtype=torch.float64, requires_grad = True) 
            # x = torch.from_numpy(np.zeros(initial_guess.shape))
            # x[:,:] = initial_guess[:,:]
            # x = torch.DoubleTensor(x)
            
            J = torch.autograd.functional.jacobian(self.func, x)
            # print('J shape', J.shape)
           
            J = J.squeeze(1)
            # print('')
            # print('J',J)
            target = self.func(x).detach().numpy()
            
            # Q, R = np.linalg.qr(J, mode='reduced')
            # Qb = np.matmul(Q.T, target)
            # minus = np.linalg.solve(R,Qb)
            # print(J, target)
            
            temp = J.T @ J
            # minus = (torch.inverse(temp + lambda_JTJ * torch.diag(torch.diag(temp, 0))))@ (J.T @ target)
            minus = (torch.inverse(temp))@ (J.T @ target)
            # print('res', minus.T)
            # print('check', torch.matmul(J, minus)-target)
            # print('error', target)
            minus = minus.numpy()
            # print(J.shape)
            # print(x.shape)
            # print(self.func(x).shape)
            # minus = overdetermined_linear_system_solver(J.numpy(), self.func(x).numpy())
            # minus = torch.from_numpy(minus)
        
            initial_guess = initial_guess - minus.T
            # if torch.sum(torch.abs(norm_vector/1000-magnitude[0,:]))<1e-10:
            #     break
            
        return initial_guess

In [53]:
# Start Gauss-Newton from a random point; solve() promotes it to float64 internally.
initial_guess = torch.rand((1, 7*hop_length))

newton_method = hop_7_L_solver(hop_length, channels, win_length)
ans = newton_method.solve(L[:, :], initial_guess, n_iters=20, lambda_JTJ=0)

# Compare the recovered signal and its residuals with the ground truth.
part3 = slice(3*hop_length, 4*hop_length)
print('\n')
print('ans   ', ans[0, :])
print('origin', audio_origin[0, :])
print('error       ', newton_method.func(ans))
print('error origin', newton_method.func(audio_origin))
print('ans part    3', ans[0, part3])
print('origin part 3', audio_origin[0, part3])

Iter 19/20: Used times: 1.05

ans    tensor([ 0.1703,  0.6513,  0.7265,  0.3488,  0.2655,  0.1177,  0.4774,  0.8396,
         0.3125,  0.6608,  0.1239,  0.7228,  0.2570,  0.0412,  0.2094, -0.0626,
         0.7813,  0.6448,  0.9157,  0.7395,  0.1801,  0.2871,  0.5408,  0.3513,
         0.2685,  0.6434,  0.6922,  0.7710], dtype=torch.float64)
origin tensor([0.1292, 0.2521, 0.5072, 0.7652, 0.6710, 0.0801, 0.1273, 0.5830, 0.4827,
        0.8574, 0.1320, 0.5975, 0.0721, 0.1411, 0.4255, 0.0425, 0.5664, 0.3740,
        0.9140, 0.9769, 0.3372, 0.2533, 0.2719, 0.2596, 0.4828, 0.8983, 0.7136,
        0.5337], dtype=torch.float64)
error        tensor([ 0.0213,  0.0155,  0.0128,  0.0023, -0.0285, -0.0499, -0.0090,  0.0189,
         0.0098,  0.0027, -0.0438, -0.0270,  0.0131,  0.0142,  0.0221,  0.0359,
         0.0038,  0.0033,  0.0098,  0.0446,  0.0080, -0.0013, -0.0114, -0.0037,
        -0.0004, -0.0130, -0.0352,  0.0058, -0.0296,  0.0002, -0.0027,  0.0014,
        -0.0035,  0.0024,  0.0050,  0.0