In [25]:
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import yaml
import librosa
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import IPython.display as ipd

from lava.lib.dl import slayer
from audio_dataloader import DNSAudio
from snr import si_snr
from dnsmos import DNSMOS

In [26]:
from train_nas_fullconv_baseline import collate_fn, stft_splitter, stft_mixer, nop_stats, Network

In [27]:
import wavio

In [28]:
trained_folder = '../20230707_loh_nas_runs'
args = yaml.safe_load(open(trained_folder + '/args.txt', 'rt'))
if 'out_delay' not in args.keys():
    args['out_delay'] = 0
if 'n_fft' not in args.keys():
    args['n_fft'] = 512
device = torch.device('cpu') #torch.device('cuda:0')
root = '/mnt/data4tb/stadtmann/dns_challenge_4/datasets_fullband/' #args['path']
out_delay = args['out_delay']
n_fft = args['n_fft']
win_length = n_fft
hop_length = n_fft // 4
stats = slayer.utils.LearningStats(accuracy_str='SI-SNR', accuracy_unit='dB')

In [29]:
train_set = DNSAudio(root=root + 'training_set/')
validation_set = DNSAudio(root=root + 'validation_set/')

train_loader = DataLoader(train_set,
                          batch_size=32,
                          shuffle=True,
                          collate_fn=collate_fn,
                          num_workers=4,
                          pin_memory=True)
validation_loader = DataLoader(validation_set,
                               batch_size=32,
                               shuffle=True,
                               collate_fn=collate_fn,
                               num_workers=4,
                               pin_memory=True)

In [30]:
kk = 5
cc = 256
dd = 5

net = Network(kk,cc,dd,
                args['threshold'],
                args['tau_grad'],
                args['scale_grad'],
                args['dmax'],
                args['out_delay']).to(device)

In [31]:
debug_sample_no = 0

In [32]:
noisy, clean, noise, metadata = train_set[debug_sample_no]
noisy = torch.unsqueeze(torch.FloatTensor(noisy), dim=0).to(device)
noisy_abs, noisy_arg = stft_splitter(noisy, n_fft)
net(noisy_abs)
net.load_state_dict(torch.load(trained_folder + '/trained_k5c256d5_optfcn_Adamlr_customschedule' + '/network.pt'))

<All keys matched successfully>

In [33]:
denoised_abs = net(noisy_abs)
cleaned = stft_mixer(denoised_abs, noisy_arg, n_fft)

wavio.write("debug_in.wav", noisy.squeeze(0).detach().cpu(),16000,sampwidth=2)
wavio.write("debug_out.wav", cleaned.squeeze(0).detach().cpu(),16000,sampwidth=2)
wavio.write("debug_ref.wav", clean,16000,sampwidth=2)

In [34]:
ipd.Audio(noisy.squeeze(0).detach().cpu(), rate=16000)

In [35]:
ipd.Audio(cleaned.squeeze(0).detach().cpu(), rate=16000)

In [36]:
ipd.Audio(clean, rate=16000)

In [37]:
dnsmos = DNSMOS()
dnsmos_noisy = np.zeros(3)
dnsmos_clean = np.zeros(3)
dnsmos_noise = np.zeros(3)
dnsmos_cleaned  = np.zeros(3)
#train_event_counts = []

t_st = datetime.now()
for i, (noisy, clean, noise) in enumerate(train_loader):
    net.eval()
    with torch.no_grad():
        noisy = noisy.to(device)
        clean = clean.to(device)

        noisy_abs, noisy_arg = stft_splitter(noisy, n_fft)
        clean_abs, clean_arg = stft_splitter(clean, n_fft)

        #denoised_abs, count = net(noisy_abs)
        denoised_abs = net(noisy_abs)
        #train_event_counts.append(count.cpu().data.numpy())
        noisy_arg = slayer.axon.delay(noisy_arg, out_delay)
        clean_abs = slayer.axon.delay(clean_abs, out_delay)
        clean = slayer.axon.delay(clean, win_length * out_delay)

        loss = F.mse_loss(denoised_abs, clean_abs)
        clean_rec = stft_mixer(denoised_abs, noisy_arg, n_fft)
        score = si_snr(clean_rec, clean)

        dnsmos_noisy += np.sum(dnsmos(noisy.cpu().data.numpy()), axis=0)
        dnsmos_clean += np.sum(dnsmos(clean.cpu().data.numpy()), axis=0)
        dnsmos_noise += np.sum(dnsmos(noise.cpu().data.numpy()), axis=0)
        dnsmos_cleaned += np.sum(dnsmos(clean_rec.cpu().data.numpy()), axis=0)

        stats.training.correct_samples += torch.sum(score).item()
        stats.training.loss_sum += loss.item()
        stats.training.num_samples += noisy.shape[0]

        processed = i * train_loader.batch_size
        total = len(train_loader.dataset)
        time_elapsed = (datetime.now() - t_st).total_seconds()
        samples_sec = time_elapsed / (i + 1) / train_loader.batch_size
        header_list = [f'Train: [{processed}/{total} '
                        f'({100.0 * processed / total:.0f}%)]']
        #header_list.append(f'Event rate: {[c.item() for c in count]}')
        print(f'\r{header_list[0]}', end='')

dnsmos_clean /= len(train_loader.dataset)
dnsmos_noisy /= len(train_loader.dataset)
dnsmos_noise /= len(train_loader.dataset)
dnsmos_cleaned /= len(train_loader.dataset)

print()
stats.print(0, i, samples_sec, header=header_list)
print('Avg DNSMOS clean   [ovrl, sig, bak]: ', dnsmos_clean)
print('Avg DNSMOS noisy   [ovrl, sig, bak]: ', dnsmos_noisy)
print('Avg DNSMOS noise   [ovrl, sig, bak]: ', dnsmos_noise)
print('Avg DNSMOS cleaned [ovrl, sig, bak]: ', dnsmos_cleaned)

# mean_events = np.mean(train_event_counts, axis=0)

# neuronops = []
# for block in net.blocks[:-1]:
#     neuronops.append(np.prod(block.neuron.shape))

# synops = []
# for events, block in zip(mean_events, net.blocks[1:]):
#     synops.append(events * np.prod(block.synapse.shape))
# print(f'SynOPS: {synops}')
# print(f'Total SynOPS: {sum(synops)}')
# print(f'Total NeuronOPS: {sum(neuronops)}')
# print(f'Time-step per sample: {noisy_abs.shape[-1]}')

[0A
Epoch    0: i =  1874 ,     213.2955 ms elapsed        
Train loss =     0.03179                          SI-SNR = 13.53153 dB
Avg DNSMOS clean   [ovrl, sig, bak]:  [3.2952963  3.57391434 4.08399461]
Avg DNSMOS noisy   [ovrl, sig, bak]:  [2.43810771 3.17403667 2.69062128]
Avg DNSMOS noise   [ovrl, sig, bak]:  [1.34275018 1.57635708 1.93369491]
Avg DNSMOS cleaned [ovrl, sig, bak]:  [2.81519338 3.1857744  3.99692714]


In [38]:
dnsmos = DNSMOS()
dnsmos_noisy = np.zeros(3)
dnsmos_clean = np.zeros(3)
dnsmos_noise = np.zeros(3)
dnsmos_cleaned  = np.zeros(3)
#train_event_counts = []

t_st = datetime.now()
for i, (noisy, clean, noise) in enumerate(validation_loader):
    net.eval()
    with torch.no_grad():
        noisy = noisy.to(device)
        clean = clean.to(device)

        noisy_abs, noisy_arg = stft_splitter(noisy, n_fft)
        clean_abs, clean_arg = stft_splitter(clean, n_fft)

        #denoised_abs, count = net(noisy_abs)
        denoised_abs = net(noisy_abs)
        #train_event_counts.append(count.cpu().data.numpy())
        noisy_arg = slayer.axon.delay(noisy_arg, out_delay)
        clean_abs = slayer.axon.delay(clean_abs, out_delay)
        clean = slayer.axon.delay(clean, win_length * out_delay)

        loss = F.mse_loss(denoised_abs, clean_abs)
        clean_rec = stft_mixer(denoised_abs, noisy_arg, n_fft)
        score = si_snr(clean_rec, clean)

        dnsmos_noisy += np.sum(dnsmos(noisy.cpu().data.numpy()), axis=0)
        dnsmos_clean += np.sum(dnsmos(clean.cpu().data.numpy()), axis=0)
        dnsmos_noise += np.sum(dnsmos(noise.cpu().data.numpy()), axis=0)
        dnsmos_cleaned += np.sum(dnsmos(clean_rec.cpu().data.numpy()), axis=0)

        stats.training.correct_samples += torch.sum(score).item()
        stats.training.loss_sum += loss.item()
        stats.training.num_samples += noisy.shape[0]

        processed = i * validation_loader.batch_size
        total = len(validation_loader.dataset)
        time_elapsed = (datetime.now() - t_st).total_seconds()
        samples_sec = time_elapsed / (i + 1) / validation_loader.batch_size
        header_list = [f'Train: [{processed}/{total} '
                        f'({100.0 * processed / total:.0f}%)]']
        #header_list.append(f'Event rate: {[c.item() for c in count]}')
        print(f'\r{header_list[0]}', end='')

dnsmos_clean /= len(validation_loader.dataset)
dnsmos_noisy /= len(validation_loader.dataset)
dnsmos_noise /= len(validation_loader.dataset)
dnsmos_cleaned /= len(validation_loader.dataset)

print()
stats.print(0, i, samples_sec, header=header_list)
print('Avg DNSMOS clean   [ovrl, sig, bak]: ', dnsmos_clean)
print('Avg DNSMOS noisy   [ovrl, sig, bak]: ', dnsmos_noisy)
print('Avg DNSMOS noise   [ovrl, sig, bak]: ', dnsmos_noise)
print('Avg DNSMOS cleaned [ovrl, sig, bak]: ', dnsmos_cleaned)

# mean_events = np.mean(train_event_counts, axis=0)

# neuronops = []
# for block in net.blocks[:-1]:
#     neuronops.append(np.prod(block.neuron.shape))

# synops = []
# for events, block in zip(mean_events, net.blocks[1:]):
#     synops.append(events * np.prod(block.synapse.shape))
# print(f'SynOPS: {synops}')
# print(f'Total SynOPS: {sum(synops)}')
# print(f'Total NeuronOPS: {sum(neuronops)}')
# print(f'Time-step per sample: {noisy_abs.shape[-1]}')

[4A
Epoch    0: i =  1874 ,     199.6000 ms elapsed        
Train loss =     0.03176                          SI-SNR = 13.52253 dB
Avg DNSMOS clean   [ovrl, sig, bak]:  [3.29625107 3.5746113  4.08414389]
Avg DNSMOS noisy   [ovrl, sig, bak]:  [2.44817278 3.17793735 2.70845216]
Avg DNSMOS noise   [ovrl, sig, bak]:  [1.34752958 1.58465638 1.95197176]
Avg DNSMOS cleaned [ovrl, sig, bak]:  [2.81851724 3.1888065  3.99714343]
