# Speech Enhancement using U-Net Spiking Neural Network

This notebook walks through evaluating the U-Net-based spiking neural network (SNN) model on the speech enhancement task.


In [1]:
import os
from pathlib import Path
import yaml
import numpy as np
import torch
import torchaudio
from collections import OrderedDict
from torch.utils.data import DataLoader
import IPython.display as ipd

In [2]:
from src.model.SurrogateGradient import SuperSpike, SigmoidDerivative, ATan, PiecewiseLinear, SpikeFunc
from src.model.SpikingModel import UNetSNN
from src.data.constants import DataDirectories, AudioParameters
from src.data.DatasetManager import DatasetManager
from src.stft.constants import StftParameters

## 1. Read the training hyperparameters

In [3]:
# Name of the experiment folder that holds the saved hyperparameters and checkpoint.
experiment_filename = 'SpeechEnhancement_Train_UNetSNN_2023_05_26-03_26_01_PM_37652140'

# `__file__` only exists when this code runs as a script; inside a Jupyter
# kernel it raises NameError, so fall back to the current working directory.
try:
    notebook_dir = Path(__file__).parent
except NameError:
    notebook_dir = Path.cwd()

# Kept as a plain string (os.path.join) because later cells join onto it with os.path.
experiment_files_dir = os.path.join(notebook_dir,
                                    DataDirectories.experiments_dirname,
                                    experiment_filename)

In [4]:
# Hyperparameters saved at training time; stays empty if params.json is absent,
# in which case the following cells fail on the first missing key.
params = {}
params_dir = os.path.join(experiment_files_dir, 'params.json')
if os.path.exists(params_dir):
    # Use a context manager so the file handle is closed deterministically.
    with open(params_dir, 'rt') as params_file:
        params = yaml.safe_load(params_file)['hyperparameters']

In [5]:
# Tensor dtype passed to the dataset manager and the network below.
dtype = torch.float32
# Device is taken from the saved hyperparameters; raises KeyError when
# params.json was not found and `params` is still empty.
device = params['device']

In [6]:
# Replace the string literals 'True' / 'False' among the top-level
# hyperparameters with real booleans; every other value is left as loaded.
string_to_bool = {'True': True, 'False': False}
for key in params:
    value = params[key]
    # Only strings can match; the isinstance guard also keeps unhashable
    # values (e.g. lists) away from the dictionary lookup.
    if isinstance(value, str) and value in string_to_bool:
        params[key] = string_to_bool[value]

In [7]:
# Map the surrogate-gradient name stored in the hyperparameters to its class.
spike_fn_by_name = {
    'SuperSpike': SuperSpike,
    'SigmoidDerivative': SigmoidDerivative,
    'ATan': ATan,
    'PiecewiseLinear': PiecewiseLinear,
    'SpikeFunc': SpikeFunc,
}

requested_spike_fn = params['spike_fn']
if isinstance(requested_spike_fn, str) and requested_spike_fn in spike_fn_by_name:
    spike_fn = spike_fn_by_name[requested_spike_fn]
elif requested_spike_fn in spike_fn_by_name.values():
    # The cell was already executed once and params['spike_fn'] now holds the
    # class itself (see the assignment at the end of this cell); keep it.
    spike_fn = requested_spike_fn
else:
    # Fail here with a clear message instead of a NameError on `spike_fn` below.
    raise ValueError(
        f'Unknown spike_fn {requested_spike_fn!r}; '
        f'expected one of {sorted(spike_fn_by_name)}')

# Configure the selected class with the settings used at training time.
spike_fn.spiking_mode = params['spiking_mode']
if params['surrogate_scale'] is not None:
    spike_fn.surrogate_scale = params['surrogate_scale']

# The network constructor receives the class, not its name.
params['spike_fn'] = spike_fn

In [8]:
# Cast to int before handing it to the network.
# NOTE(review): presumably stored as a float or string in params.json — confirm.
params['truncated_bptt_ratio'] = int(params['truncated_bptt_ratio'])

## 2. Instantiate dataset

In [9]:
# Forwarded to DatasetManager below; its effect is defined by that class.
debug_flag = False

In [10]:
# Data folder name is "<task_name>_<data_dirname>", located under the project directory.
task_data_dirname = f'{params["task_name"]}_{DataDirectories.data_dirname}'
data_files_dir = os.path.join(DataDirectories.project_dir, task_data_dirname)

In [11]:
# Plots produced by the dataset manager go into the experiment folder.
plots_dir = os.path.join(experiment_files_dir, 'plots')

# Test-split dataset, configured with the representation/transform used for training.
dataset_manager_test = DatasetManager(
    data_files_dir=data_files_dir,
    data_load=DataDirectories.data_load_test,
    experiment_files_dir=experiment_files_dir,
    plots_dir=plots_dir,
    dtype=dtype,
    representation_name=params['representation_name'],
    representation_dir_name=params['representation_dir_name'],
    transform_name=params['transform_name'],
    debug_flag=debug_flag,
)

In [12]:
# Sequential, single-process loader over the test set.
batch_size = 8
# `prefetch_factor` is deliberately not passed: it only applies to worker
# processes, and PyTorch >= 2.0 raises ValueError when it is set while
# num_workers == 0. Omitting it keeps the old behaviour on PyTorch 1.x.
dataloader_manager_test = DataLoader(dataset_manager_test, batch_size=batch_size,
                                     shuffle=False, num_workers=0,
                                     pin_memory=True, drop_last=False,
                                     sampler=None)

In [13]:
# tensor_noisyspeech_transform_info_dir = os.path.join(data_files_dir, 
#                                                      params['representation_dir_name'], 
#                                                      DataDirectories.transform_info_dirname, 
#                                                      f'{DataDirectories.noisyspeech_dirname}_{DataDirectories.transform_info_filename}.pt')

# metadata = torch.load(tensor_noisyspeech_transform_info_dir, map_location=None)

## 3. Instantiate network

In [14]:
# Constructor arguments for the network, one per line. Entries read with
# params[...] must exist in the saved hyperparameters (KeyError otherwise);
# entries read with params.get(...) are passed as None when absent.
net_kwargs = dict(
    input_dim=params['input_dim'],
    hidden_channels_list=params.get('hidden_channels_list'),
    output_dim=params['output_dim'],
    kernel_size=params.get('kernel_size'),
    stride=params.get('stride'),
    padding=params.get('padding'),
    dilation=params.get('dilation'),
    bias=params['bias'],
    padding_mode=params['padding_mode'],
    pooling_flag=params['pooling_flag'],
    pooling_type=params['pooling_type'],
    use_same_layer=params['use_same_layer'],
    nb_steps=params['nb_steps_bin'],
    truncated_bptt_ratio=params['truncated_bptt_ratio'],
    spike_fn=params['spike_fn'],
    neuron_model=params['neuron_model'],
    neuron_parameters=params['neuron_parameters'],
    weight_init=params['weight_init'],
    upsample_mode=params['upsample_mode'],
    scale_flag=params['scale_flag'],
    scale_factor=params['scale_factor'],
    bn_flag=params['bn_flag'],
    dropout_flag=params['dropout_flag'],
    dropout_p=params.get('dropout_p'),
    device=device,
    dtype=dtype,
    skip_connection_type=params['skip_connection_type'],
    use_intermediate_output=params['use_intermediate_output'],
)

# Build the network and move it to the configured device.
net = UNetSNN(**net_kwargs).to(device)

## 4. Load pretrained network

In [15]:
# Checkpoint file name is "<task_name>_<model_name>_InpDim=<input_dim>.pt" inside the experiment folder.
model_file_dir = os.path.join(experiment_files_dir, f'{params["task_name"]}_{params["model_name"]}_InpDim={params["input_dim"]}.pt')
# map_location remaps the stored tensors onto the evaluation device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
checkpoint = torch.load(model_file_dir, map_location=device)

In [16]:
# Rename checkpoint keys so they match the bare network built above.
state_dict = checkpoint['model_state_dict']
wrapper_prefix = 'module.'
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # Strip the prefix only when it is actually present; slicing every key
    # unconditionally would corrupt keys that do not carry it.
    if k.startswith(wrapper_prefix):
        k = k[len(wrapper_prefix):]
    new_state_dict[k] = v

In [17]:
# Load the renamed weights into the network; with the default strict matching
# this raises if any key is missing or unexpected.
net.load_state_dict(new_state_dict)

<All keys matched successfully>

## 5. Define STFT parameters

In [18]:
# Module-level STFT settings shared by stft_splitter and stft_mixer below.
n_fft = StftParameters.n_fft
win_length = StftParameters.win_length
hop_length = StftParameters.hop_length
# NOTE(review): stft_splitter takes abs/angle of the spectrogram, so this is
# presumably None (complex output) — confirm in StftParameters.
power = StftParameters.power
normalized = StftParameters.normalized
center = StftParameters.center

In [19]:
def stft_splitter(x, compute_stft: bool = False):
    """Split a spectrogram into magnitude and phase.

    Args:
        x: A 4-D tensor. When ``compute_stft`` is False it is used as-is;
            when True it is first moved to the CPU and passed through
            ``torchaudio.transforms.Spectrogram`` configured with the
            module-level STFT settings.
        compute_stft: Whether to compute the spectrogram from ``x`` first.

    Returns:
        A ``(magnitude, phase)`` tuple, each with the last entry along
        dimension 2 removed (presumably the top frequency bin — confirm).
    """
    if compute_stft:
        to_spectrogram = torchaudio.transforms.Spectrogram(
            n_fft=n_fft, win_length=win_length, hop_length=hop_length,
            power=power, normalized=normalized, center=center)
        x = to_spectrogram(x.cpu())

    magnitude = x.abs()
    phase = x.angle()
    return magnitude[:, :, :-1, :], phase[:, :, :-1, :]

In [20]:
def stft_mixer(x_abs, x_arg):
    """Recombine magnitude and phase and invert the spectrogram.

    Args:
        x_abs: Magnitude tensor (4-D, as returned by ``stft_splitter``).
        x_arg: Phase tensor with the same shape as ``x_abs``.

    Returns:
        The output of ``torchaudio.transforms.InverseSpectrogram`` (configured
        with the module-level STFT settings) applied on the CPU.
    """
    real_part = x_abs * torch.cos(x_arg)
    imag_part = x_abs * torch.sin(x_arg)
    spectrum = torch.complex(real_part, imag_part)

    # stft_splitter removes the last entry along dim 2; restore the original
    # size by repeating the final remaining entry.
    last_entry = spectrum[:, :, -1:, :]
    spectrum = torch.cat((spectrum, last_entry), 2)

    to_waveform = torchaudio.transforms.InverseSpectrogram(
        n_fft=n_fft, win_length=win_length, hop_length=hop_length,
        normalized=normalized, center=center)
    return to_waveform(spectrum.cpu())

## 6. Run pretrained network

In [21]:
# Folders holding the noisy and clean test recordings.
noisy_dir = dataset_manager_test.noisyspeech_dir
clean_dir = dataset_manager_test.cleanspeech_dir

# Alphabetically ordered names of the entries in the noisy folder.
audio_filename = os.listdir(noisy_dir)
audio_filename.sort()

In [22]:
# Alternative: point at a custom recording instead of a test-set file.
# noisy_dir_ = os.path.join(Path(__file__).parent, 'examples', 'rec.wav')
# Position of the demo file within the sorted list of noisy recordings.
index_ = 5
noisy_dir_ = os.path.join(noisy_dir, audio_filename[index_])

In [23]:
# Load the selected recording and prepend a batch dimension of size 1.
noisy = dataset_manager_test.load_audio(noisy_dir_, update_info=False).unsqueeze(0)

In [24]:
# Waveform -> magnitude / phase, then apply the dataset's input transform.
noisy_abs, noisy_arg = stft_splitter(noisy, True)
noisy_abs_ = dataset_manager_test.transform_manager(noisy_abs.cpu()).to(device)

# Inference only: switch the network to evaluation mode and skip autograd
# bookkeeping, matching what the evaluation loop later in this notebook does.
net.eval()
with torch.no_grad():
    # Reset the network's state and recordings for this batch size.
    net.init_state(noisy_abs_.shape[0])
    net.init_rec()
    cleaned_abs_, _ = net(noisy_abs_)

# Undo the input transform to get the enhanced magnitude.
cleaned_abs = dataset_manager_test.transform_manager(cleaned_abs_.cpu(), mode='inverse_transform').to(device)

In [25]:
# Rebuild the noisy waveform from its own magnitude and phase.
noisy = stft_mixer(noisy_abs, noisy_arg)
# The enhanced magnitude is paired with the noisy phase; the phase is moved to
# `device` so both operands of the element-wise products live on the same device.
clean_rec = stft_mixer(cleaned_abs, noisy_arg.to(device)).detach()

In [26]:
# Collapse each tensor to 2-D: (first dimension, last dimension).
noisy = noisy.view(noisy.size(0), noisy.size(-1))
clean_rec = clean_rec.view(clean_rec.size(0), clean_rec.size(-1))

In [27]:
# Audio player for the noisy input.
ipd.Audio(noisy[0].cpu(), rate=AudioParameters.sample_rate)

In [28]:
# Audio player for the network's enhanced output.
ipd.Audio(clean_rec[0].cpu(), rate=AudioParameters.sample_rate)

## 7. Compute firing rates

In [29]:
# Number of test samples. Taken from the dataset itself: batch_size * number of
# batches over-counts whenever the last batch is partial (drop_last=False).
nb_data = len(dataset_manager_test)

In [30]:
# NOTE(review): presumably makes the network record per-layer activity in
# net.spk_rec during forward passes (read by the cells below) — confirm in UNetSNN.
net.save_mem = True

In [31]:
# Accumulate, per recorded layer, the fraction of non-zero entries in
# net.spk_rec over the whole test set, then average over batches.
nb_output = len(net.snn_layers)
firing_rates = np.zeros(nb_output)
nb_batches = len(dataloader_manager_test)

# Evaluation mode and no-grad are loop-invariant, so set them once up front.
net.eval()
with torch.no_grad():
    for i, (noisy, clean, index) in enumerate(dataloader_manager_test):
        if i % 10 == 0:
            print(f'itration : {i}/{nb_batches}')
        noisy = noisy.to(device)
        clean = clean.to(device)
        noisy_abs, noisy_arg = stft_splitter(noisy)
        # clean_abs / clean_arg are not used inside the loop; the values from
        # the final batch are read by the reconstruction cell that follows.
        clean_abs, clean_arg = stft_splitter(clean)
        noisy_abs_ = dataset_manager_test.transform_manager(noisy_abs.cpu()).to(device)
        net.init_state(noisy_abs.shape[0])
        net.init_rec()
        cleaned_abs_, _ = net(noisy_abs_)
        cleaned_abs = dataset_manager_test.transform_manager(cleaned_abs_.cpu(), mode='inverse_transform').to(device)
        # Compute firing rates
        for layer_idx in range(nb_output):
            # The recordings are NumPy arrays, so count directly in NumPy
            # (float64 division) instead of round-tripping through torch.
            spk_rec_layer = net.spk_rec[layer_idx]
            firing_rates[layer_idx] += np.count_nonzero(spk_rec_layer) / spk_rec_layer.size

# Unweighted mean over batches.
firing_rates /= nb_batches

itration : 0/103
itration : 10/103
itration : 20/103
itration : 30/103
itration : 40/103
itration : 50/103
itration : 60/103
itration : 70/103
itration : 80/103
itration : 90/103
itration : 100/103


In [32]:
# noisy_abs, noisy_arg, clean_abs, clean_arg and cleaned_abs still hold the
# values from the final iteration of the firing-rate loop, so these lines
# reconstruct waveforms for the last test batch only.
noisy = stft_mixer(noisy_abs, noisy_arg)
clean = stft_mixer(clean_abs, clean_arg)
# Enhanced magnitude combined with the noisy phase.
clean_rec = stft_mixer(cleaned_abs, noisy_arg)

In [33]:
# Collapse each tensor to 2-D: (first dimension, last dimension).
noisy = noisy.view(noisy.size(0), noisy.size(-1))
clean = clean.view(clean.size(0), clean.size(-1))
clean_rec = clean_rec.view(clean_rec.size(0), clean_rec.size(-1))

In [34]:
# Select the last item of the batch reconstructed above.
index_ = -1

In [35]:
# Audio player for the noisy input of the selected item.
ipd.Audio(noisy[index_].cpu(), rate=AudioParameters.sample_rate)

In [36]:
# Audio player for the enhanced output of the selected item.
ipd.Audio(clean_rec[index_].cpu(), rate=AudioParameters.sample_rate)

In [37]:
# ipd.Audio(clean[index_].cpu(), rate=AudioParameters.sample_rate)

## 8. Compute ANN/SNN energy ratio

In [38]:
# Firing rates of every recorded layer except the first, as an independent copy.
binary_firing_rates = firing_rates[1:].copy()

In [39]:
# Average firing rate over those layers, reported as a percentage.
binary_firing_rates_mean = binary_firing_rates.mean()
# print(f'binary_firing_rates_mean \t = \t {binary_firing_rates_mean}')
print(f'binary_firing_rates_mean (%) \t = \t {binary_firing_rates_mean*100:.2f}')

binary_firing_rates_mean (%) 	 = 	 10.88


In [40]:
# Rotate the recordings left by one: entry i of out_rec is recording i + 1,
# and the first recording moves to the end.
out_rec = [*net.spk_rec[1:], net.spk_rec[0]]

In [41]:
# Per-layer operation counts: dense count (OP_ANN) and the same count scaled
# by the layer's firing rate (OP_SNN), printed as a tab-separated table.
print('i \t k_w \t k_h \t c_in \t w_out \t h_out \t c_out \t #OP_ANN_i \t firing_rates_i \t #OP_SNN_i')
print('-----------------------------------------------------------------------------------------------------------------------')
OP_ANN = np.zeros(nb_output)
OP_SNN = np.zeros(nb_output)
for i, snn_layer in enumerate(net.snn_layers):
    k_w, k_h = snn_layer.kernel_size[0], snn_layer.kernel_size[1]
    c_in, c_out = snn_layer.input_channels, snn_layer.output_channels
    # Output extent is read from the first element of the matching recording.
    recorded_shape = out_rec[i][0].shape
    w_out, h_out = recorded_shape[1], recorded_shape[2]
    OP_ANN_i = k_w * k_h * c_in * h_out * w_out * c_out
    # The displayed SNN count uses the rate rounded to 5 decimals.
    firing_rates_i = round(firing_rates[i], 5)
    OP_SNN_i = firing_rates_i * OP_ANN_i
    OP_ANN[i] += OP_ANN_i
    OP_SNN[i] += OP_SNN_i
    # The first row gets one extra tab before the last column.
    gap = '\t\t\t' if i == 0 else '\t\t'
    print(f'{i+1} \t {k_w} \t {k_h} \t {c_in} \t {w_out} \t {h_out} \t {c_out} \t {OP_ANN_i} \t {firing_rates_i} {gap} {OP_SNN_i}')

i 	 k_w 	 k_h 	 c_in 	 w_out 	 h_out 	 c_out 	 #OP_ANN_i 	 firing_rates_i 	 #OP_SNN_i
-----------------------------------------------------------------------------------------------------------------------
1 	 7 	 5 	 1 	 128 	 251 	 64 	 71966720 	 1.0 			 71966720.0
2 	 7 	 5 	 64 	 64 	 251 	 128 	 4605870080 	 0.09513 		 438156420.71040004
3 	 7 	 5 	 128 	 32 	 251 	 256 	 9211740160 	 0.13698 		 1261824167.1167998
4 	 5 	 5 	 256 	 16 	 251 	 512 	 13159628800 	 0.05145 		 677062901.76
5 	 5 	 5 	 512 	 8 	 126 	 512 	 6606028800 	 0.05727 		 378327269.376
6 	 3 	 3 	 512 	 4 	 63 	 512 	 594542592 	 0.12722 		 75637708.55424
7 	 3 	 3 	 512 	 2 	 32 	 512 	 150994944 	 0.12439 		 18782261.08416
8 	 3 	 3 	 512 	 1 	 16 	 512 	 37748736 	 0.21137 		 7978950.32832
9 	 3 	 3 	 512 	 2 	 32 	 512 	 150994944 	 0.03387 		 5114198.75328
10 	 3 	 3 	 1024 	 4 	 63 	 512 	 1189085184 	 0.14475 		 172120080.38399997
11 	 3 	 3 	 1024 	 8 	 126 	 512 	 4756340736 	 0.13975 		 664698617.85

In [42]:
# Per-operation weights used in the ratio below.
# NOTE(review): presumably energy per operation (4.6 for real-valued inputs,
# 0.9 for binary spike inputs) — confirm the units and the source.
cost_per_op_real_valued = 4.6
cost_per_op_binary = 0.9

sum_OP_ANN = OP_ANN.sum()
# The first layer is weighted like the ANN; the remaining layers use the binary weight.
sum_OP_SNN_layer_1 = firing_rates[0] * OP_ANN[0]
sum_OP_SNN_binary = firing_rates[1:] @ OP_ANN[1:]
# print(f'sum_OP_ANN \t\t = \t {sum_OP_ANN}')
# print(f'sum_OP_SNN_layer_1 \t = \t {sum_OP_SNN_layer_1}')
# print(f'sum_OP_SNN_binary \t = \t {sum_OP_SNN_binary}')
# print('--------------------------------------------------------')
ANN_SNN_energy = (sum_OP_ANN * cost_per_op_real_valued) / (
    (sum_OP_SNN_layer_1 * cost_per_op_real_valued) + (sum_OP_SNN_binary * cost_per_op_binary))
print(f'ANN_SNN_energy \t\t = \t {ANN_SNN_energy:.5f}')

ANN_SNN_energy 		 = 	 53.97236
