In [19]:
import torch
import causal_improved_sudormrf_v3 
import soundfile as sf
from os.path import join as pjoin
import time as t
import numpy as np
import tqdm
import matplotlib.pyplot as plt

In [20]:
# Cap the number of CPU threads PyTorch uses for intra-op parallelism, so the
# timings below are measured with a fixed thread budget.
torch.set_num_threads(8)

In [21]:
# NOTE: despite the variable name, this selects the CPU, not the MPS backend.
mps_device = torch.device('cpu')

In [22]:
def load_sudormrf_causal_cpu(model_path, device):
    """Build a CausalSuDORMRF network and restore its weights from a checkpoint.

    Args:
        model_path: Path of the checkpoint file passed to ``torch.load``.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        The bare (un-wrapped) CausalSuDORMRF module, on ``device`` and in
        evaluation mode.
    """
    # Step 1: instantiate the model class with this experiment's fixed
    # hyper-parameters.
    hyper_params = dict(
        in_audio_channels=1,
        out_channels=512,
        in_channels=256,
        num_blocks=16,
        upsampling_depth=5,
        enc_kernel_size=21,
        enc_num_basis=512,
        num_sources=1,
    )
    network = causal_improved_sudormrf_v3.CausalSuDORMRF(**hyper_params)

    # The weights are loaded through a DataParallel wrapper and the wrapper is
    # discarded afterwards -- presumably because the checkpoint was saved from a
    # DataParallel model (keys prefixed with "module."); TODO confirm.
    wrapper = torch.nn.DataParallel(network)
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version / weights_only setting; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    wrapper.load_state_dict(checkpoint)

    restored = wrapper.module.to(device)
    restored.eval()
    return restored

In [23]:
# Nominal sample rate in Hz. Note that `fs` is reassigned further down with the
# rate reported by sf.read for the loaded file.
fs = 16000

In [24]:
# Restore the trained network from the checkpoint file (expected in the working
# directory) onto the CPU device defined above; the loader returns it in eval mode.
model = load_sudormrf_causal_cpu('e39_sudo_whamr_16k_enhnoisy_augment.pt', mps_device)

In [25]:
# Now load the mixture and keep only its first 5 s.

In [26]:
# Read the audio samples together with the file's sample rate; this overwrites
# the `fs` constant set earlier with the rate stored in the file.
mixture, fs = sf.read('16k_mixture.wav')

In [27]:
# Convert the samples returned by sf.read into a float32 torch tensor.
mixture = torch.tensor(mixture, dtype=torch.float32)

In [28]:
# Keep only the first 5 seconds. The length is derived from the sample rate
# reported by sf.read instead of a hard-coded 16000, so the crop really is 5 s
# even if the input file uses a different rate.
mixture = mixture[:fs * 5]

In [29]:
# Duration of the cropped mixture in seconds (number of samples / sample rate).
audio_length = len(mixture) / fs

In [30]:
# Sanity check: duration of the cropped mixture, in seconds.
audio_length

5.0

In [31]:
# Sanity check: number of samples kept after cropping.
len(mixture)

80000

In [32]:
# test how performance degrades when using short chunks

In [33]:
# Chunk counts to benchmark. Splitting the 5 s mixture into 20 chunks gives
# 250 ms per chunk, i.e. a latency of 250 ms.
num_chunks = [20]

In [34]:
# Benchmark: for every chunk count, split the mixture into equal-sized pieces,
# run the model on each piece and record how long it takes.
times = []        # total wall-clock seconds per chunk count
sisdrs = []       # reserved for an SI-SDR evaluation (not computed here)
chunk_sizes = []  # samples per chunk, one entry per chunk count
minitimes = []    # wall-clock seconds of every single forward pass (all chunk counts, flat)
for chunks in tqdm.tqdm(num_chunks):
    chunk_len = len(mixture) // chunks
    if chunk_len == 0:
        raise ValueError(
            f"cannot split {len(mixture)} samples into {chunks} chunks"
        )
    # Rectangular window, no overlap. When the length is not a multiple of
    # `chunk_len`, torch.split yields a shorter trailing chunk, which is
    # processed (and timed) as well -- nothing is dropped.
    mix_list = torch.split(mixture, chunk_len)
    chunk_sizes.append(len(mix_list[0]))

    rec_sources = []
    # Inference only: disabling autograd avoids building (and keeping alive in
    # `rec_sources`) a graph for every chunk, which would distort the timings.
    with torch.no_grad():
        # perf_counter is monotonic and high-resolution, unlike time.time().
        tic = t.perf_counter()
        for m in mix_list:
            tic2 = t.perf_counter()
            # Add two leading singleton dimensions before calling the model
            # (assumes a 1-D mono chunk -> shape (1, 1, samples); TODO confirm).
            rec_sources.append(model(m.unsqueeze(0).unsqueeze(1)).squeeze())
            tac2 = t.perf_counter()
            minitimes.append(tac2 - tic2)
        tac = t.perf_counter()
    times.append(tac - tic)

100%|██████████| 1/1 [00:03<00:00,  3.53s/it]


In [35]:
# Total wall-clock time, in seconds, needed to process the 5 s mixture
# (one entry per chunk count).
times

[3.1876773834228516]

In [36]:
# Time, in seconds, that each individual 0.25 s frame needs to be processed.
minitimes

[0.1733717918395996,
 0.15123796463012695,
 0.16485881805419922,
 0.16848468780517578,
 0.16979193687438965,
 0.16264700889587402,
 0.15744423866271973,
 0.15974926948547363,
 0.15121984481811523,
 0.15453529357910156,
 0.1512129306793213,
 0.15890216827392578,
 0.1559598445892334,
 0.16179680824279785,
 0.16660618782043457,
 0.15880703926086426,
 0.1587047576904297,
 0.15050601959228516,
 0.15356063842773438,
 0.15816283226013184]

In [None]:
# Chunk length expressed in milliseconds (samples * 1000 / sample rate).
ms_sizes = [n_samples * 1000 / fs for n_samples in chunk_sizes]

In [None]:
# Bundle the measurements: total time per configuration and the matching
# window size in milliseconds.
results = dict(times=times, ms_sizes=ms_sizes)

In [None]:
# Duration of the processed mixture in seconds.
len(mixture)/fs

In [None]:
# Real-time factor = processing time / audio duration; values below 1 mean the
# model runs faster than real time.
mixture_seconds = len(mixture) / fs
realtime_factor = [elapsed / mixture_seconds for elapsed in times]

In [None]:
# Two stacked panels over the window size: total computation time (top) and
# real-time factor (bottom).
fig, (ax_time, ax_rtf) = plt.subplots(2, 1)

# Markers are needed: with a single chunk configuration there is only one data
# point, and a marker-less line through one point draws nothing.
ax_time.plot(ms_sizes, times, marker='o')
ax_time.set_ylabel('Computation time [s]')
ax_time.set_xlabel('window size [ms]')
ax_time.grid(True)

ax_rtf.plot(ms_sizes, realtime_factor, marker='o')
ax_rtf.set_ylabel('realtime factor')
ax_rtf.set_xlabel('window size [ms]')
ax_rtf.grid(True)

# One layout pass once both panels are populated.
fig.tight_layout()

In [None]:
# On my desktop, 16000-sample windows should be enough.
# On this laptop we see that we need windows of at least 250 ms, i.e. 4000 samples at 16 kHz.