In [1]:
import IPython.display as ipd

import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F

import stft_64 as stft
from audio_processing import griffin_lim

from scipy.io.wavfile import read
import time
from newton_method_solver import *

#### Set up hyperparameters (STFT sizes and sampling rate)

In [2]:
# STFT geometry: frames advance by `hop_length` samples and each window is four hops
# long, so consecutive frames overlap by three hops.
hop_length = 16
win_length = hop_length * 4
channels = 1 + win_length // 2   # frequency bins kept per frame (33 for a 64-sample window)
wav_length = 7 * hop_length


class hparams_class:
    """Tiny attribute container; further hyper-parameters are attached after construction."""

    def __init__(self):
        self.temp = False


hparams = hparams_class()
hparams.sampling_rate = 22050  # Hz; passed as the playback rate to the ipd.Audio widgets

In [3]:
def load_wav_to_torch(full_path):
    """Read a WAV file into a float64 torch tensor.

    Samples are cast to float64 without any rescaling, so integer PCM files keep
    their original integer amplitude range.

    Args:
        full_path: path of the WAV file to read.

    Returns:
        ``(tensor, sampling_rate)``: a ``torch.float64`` tensor with the raw samples
        and the sampling rate stored in the file header.
    """
    rate, samples = read(full_path)
    as_float = samples.astype(np.float64)  # astype copies, so the tensor owns fresh memory
    return torch.from_numpy(as_float), rate

# Load the reference waveform (raw sample values, float64).
audio_origin, sampling_rate = load_wav_to_torch('demo.wav')
print(sampling_rate)
# Later cells add a batch axis here and then view the signal as (B, 1, T),
# which only works for a single-channel file.
assert audio_origin.dim() == 1, 'expected a mono WAV file, got shape %s' % (tuple(audio_origin.shape),)
audio_origin = audio_origin.unsqueeze(0)  # -> (1, n_samples)
# Play back at the rate stored in the file, not the configured hparams rate,
# so a file with a different rate is not played at the wrong speed/pitch.
ipd.Audio(audio_origin[0].data.cpu().numpy(), rate=sampling_rate)

22050


In [4]:

# STFT helper used for analysis (transform) and synthesis (inverse) below.
# Filter and window length both equal `win_length`; `window=None` is forwarded
# as-is (its meaning is defined by the stft_64 module).
stft_fn = stft.STFT(
    filter_length=win_length,
    win_length=win_length,
    hop_length=hop_length,
    window=None,
)

def compare(a, b):
    """Return ``(mean absolute error, mean squared error)`` between tensors ``a`` and ``b``."""
    diff = a - b
    return diff.abs().mean(), (diff * diff).mean()

def compare_L1(ori, gen):
    """Mean relative absolute error of ``gen`` against the reference ``ori``.

    Computes ``mean(|ori - gen| / |ori|)`` over the positions where ``ori`` is
    non-zero. Positions with a zero reference are skipped, because dividing by
    them makes the whole mean inf or nan (e.g. silent audio samples).

    Args:
        ori: reference tensor.
        gen: tensor broadcastable against ``ori``.

    Returns:
        0-d tensor holding the mean relative error; nan if ``ori`` has no
        non-zero entries.
    """
    # Broadcast first so the boolean mask indexes both operands identically.
    abs_err, abs_ref = torch.broadcast_tensors(torch.abs(ori - gen), torch.abs(ori))
    nonzero = abs_ref != 0
    return torch.mean(abs_err[nonzero] / abs_ref[nonzero])


def compare_L2(a, b):
    """Return ``(sum of absolute errors, sum of squared errors)`` between ``a`` and ``b``.

    Despite the name, the first element is an L1 total and the second is the
    squared L2 norm of the difference.
    """
    delta = a - b
    return delta.abs().sum(), (delta * delta).sum()

In [5]:


# Forward STFT: `magnitude` is the target the solver later reconstructs from.
magnitude, phase_origin = stft_fn.transform(audio_origin)
print(magnitude.shape)

# Sanity check: inverting with the true phase should reproduce the input almost exactly.
reconstruction = stft_fn.inverse(magnitude, phase_origin)
mean_abs_err, mean_sq_err = compare(reconstruction, audio_origin)
print((mean_abs_err, mean_sq_err))
ipd.Audio(reconstruction[0].detach().cpu().numpy(), rate=hparams.sampling_rate)

torch.Size([1, 33, 6161])
(tensor(7.3618e-07, dtype=torch.float64), tensor(4.9818e-09, dtype=torch.float64))


# Recover the waveform x from the magnitude spectrogram

In [6]:


# Reflect-pad the waveform by half a filter length on both sides. The numpy copy is
# the ground truth that the reconstruction loop seeds from and compares against.
pad_width = int(stft_fn.filter_length / 2)
batched_audio = audio_origin.view(audio_origin.size(0), 1, 1, audio_origin.size(1))  # (B, 1, 1, T)
input_data = F.pad(batched_audio, (pad_width, pad_width, 0, 0), mode='reflect').squeeze(1)
input_data = input_data.double().numpy()
print(input_data.shape)

(1, 1, 98624)


In [21]:
INI_START = 5          # frames whose samples are copied from the ground truth to seed the recursion
STEP_APPLY_7HOP = 100  # every this many frames, jointly re-solve a 7-hop span with method_7hop
NEWTON_ITERS = 50      # iterations passed to solution.fix_error_newton for each frame
LOG_PATH = "log_nocumulative.txt"

method_7hop = hop_7_solver(stft_fn.forward_basis, hop_length, channels, win_length)
solution = get_audio_from_spectrogram(stft_fn)

# `generate` uses the same (reflect-padded) sample coordinates as `input_data`.
# The first INI_START hops plus one window overlap are copied from the ground
# truth; every later hop is recovered from the magnitude spectrogram frame by frame.
SEED_LEN = INI_START*hop_length + win_length - hop_length
generate = np.zeros(input_data.shape, dtype=np.float64)
generate[0, 0, :SEED_LEN] = input_data[0, 0, :SEED_LEN]


def loss_l1(a, b):
    """Return the sum of absolute element-wise differences between ``a`` and ``b``."""
    return np.sum(np.abs(a - b))


def _frame_magnitude(frame):
    """Magnitude of one frame computed directly from ``stft_fn.forward_basis``.

    ``frame`` is a (1, win_length) slice of samples. The first ``channels`` basis
    rows give the real part and the remaining rows the imaginary part.
    Returns the per-channel magnitudes as a numpy array.
    """
    basis = stft_fn.forward_basis[:, 0, 0:win_length]
    sums_real = torch.sum(basis[:channels] * frame, dim=1)
    sums_img = torch.sum(basis[channels:] * frame, dim=1)
    return torch.sqrt(sums_real**2 + sums_img**2).numpy()


INTERVALs = magnitude.shape[-1]
# Open the log once (append mode, as before) so it is closed even if a solver step
# raises; flush after every frame so the log can be followed while the loop runs.
with open(LOG_PATH, "a") as log_file:
    for i in range(INI_START, INTERVALs):
        print('\r%d/%d' % (i, INTERVALs), end="")

        START = hop_length*i
        new_start = START + win_length - hop_length  # end of frame i-1: first sample only frame i covers
        new_end = START + win_length
        target_mag = magnitude[:, :, i]

        # Estimate the newest hop from the already-known overlap, then refine it with Newton steps.
        result = solution.calculate_part_audio(target_mag, generate[0, :, START:new_start])
        generate[0, 0, new_start:new_end] = result[:, 0]
        result_fix_error = solution.fix_error_newton(target_mag, generate[0, :, START:new_end], n_iters=NEWTON_ITERS)
        generate[0, 0, new_start:new_end] = result_fix_error[:]

        # Periodic joint refinement over 7 hops; the frame slice i-5 needs i >= 5.
        if i % STEP_APPLY_7HOP == 0 and i >= 5:
            span7 = slice(hop_length*(i-3), hop_length*(i+4))
            ans = method_7hop.solve(magnitude[:, :, i-5:i+3], torch.from_numpy(generate[0, :, span7]))
            generate[0, :, span7] = ans[:, :]

        generated = generate[0, 0, new_start:new_end]
        ground_truth = input_data[0, 0, new_start:new_end]
        mag_ground = magnitude[0, :, i].numpy()

        log_file.write('#####:' + str(i - INI_START) + '\n')
        log_file.write('total error      :' + str(loss_l1(generated, ground_truth)) + '\n')
        log_file.write('element max error:' + str(np.max(np.abs(generated - ground_truth))) + '\n')
        # Log the same leading samples of this hop for both signals so the two rows line up.
        log_file.write(str(generated[:20]) + '\n')
        log_file.write(str(ground_truth[:20]) + '\n')

        # Magnitude mismatch of the reconstructed frame vs. the ground-truth frame.
        log_file.write('MAG diff recon:  ' + str(loss_l1(_frame_magnitude(generate[:, 0, START:new_end]), mag_ground)) + '\n')
        log_file.write('MAG diff ground:  ' + str(loss_l1(_frame_magnitude(input_data[:, 0, START:new_end]), mag_ground)) + '\n')
        log_file.flush()
    

6160/61610: Used times: 9.722

In [22]:

# Listen to the reconstruction (still in the reflect-padded coordinates of `input_data`).
ipd.Audio(data=generate[0, 0], rate=hparams.sampling_rate)