In [1]:
import torch
import causal_improved_sudormrf_v3 
import soundfile as sf
from os.path import join as pjoin
import time as t
import numpy as np
import tqdm
import matplotlib.pyplot as plt

In [2]:
torch.set_num_threads(1)

In [3]:
mps_device = torch.device('cpu')

In [4]:
def load_sudormrf_causal_cpu(model_path, device):
    # 1: declarem el model (instanciem la classe)
    model = causal_improved_sudormrf_v3.CausalSuDORMRF(
        in_audio_channels=1,
        out_channels=512,
        in_channels=256,
        num_blocks=16,
        upsampling_depth=5,
        enc_kernel_size=21,
        enc_num_basis=512,
        num_sources=1,
        )
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.module.to(device)
    model.eval()
    return model

In [5]:
fs = 16000

In [6]:
model = load_sudormrf_causal_cpu('e39_sudo_whamr_16k_enhnoisy_augment.pt', mps_device)

In [7]:
#now load 5s mixture

In [8]:
mixture, fs = sf.read('16k_mixture.wav')

In [9]:
mixture = torch.tensor(mixture, dtype=torch.float32)

In [10]:
mixture = mixture[:16000*5]

In [11]:
audio_length = len(mixture) / fs

In [12]:
audio_length

5.0

In [13]:
len(mixture)

80000

In [14]:
# test how performance degrades when using short chunks

In [15]:
num_chunks = (np.linspace(1, 150, num=20)).astype('int').tolist()

In [None]:
times = []
sisdrs = []
chunk_sizes = []
minitimes = []
for chunks in tqdm.tqdm(num_chunks):
    #try:
    if chunks == 1:
        mix_list = [mixture]
    else:
        # split into simple chunks (rectangular window and no overlap), dropping the last one
        mix_list = torch.split(mixture, len(mixture) // chunks)#[0:-1]
    chunk_sizes.append(len(mix_list[0]))
    #targets_list = torch.split(targets, len(mixture) // chunks, dim=1)[0:-1]
    # compute the audio
    rec_sources = []
    tic = t.time()
    for m in mix_list:
        tic2 = t.time()
        rec_sources.append(model(m.unsqueeze(0).unsqueeze(1)).squeeze())
        tac2 = t.time()
        minitimes.append(tac2 - tic2)
    #rec_sources = torch.cat(rec_sources, dim=1)
    #sisdrs.append(loss(rec_sources.unsqueeze(0), targets[:, :rec_sources.shape[1]].unsqueeze(0)).item())
    tac = t.time()
    times.append(tac - tic)
    '''
    except:
        times.append('nan')
        #sisdrs.append('nan')
        chunk_sizes.append('nan')
    ''';

 40%|█████████████████▌                          | 8/20 [00:49<01:32,  7.71s/it]

In [None]:
# necessitem 4.7 segons per processar 5s

In [None]:
# cada frame de 0.25s necessita... 0.22 en la meva CPU

In [None]:
ms_sizes = [x*1000 / fs for x in chunk_sizes]

In [None]:
results = {'times' : times, 'ms_sizes' : ms_sizes}

In [None]:
len(mixture)/fs

In [None]:
realtime_factor = [x/(len(mixture)/fs) for x in times]

In [None]:
plt.subplot(2,1,1)
plt.plot(ms_sizes, times)
plt.ylabel('Computation time [s]')
plt.xlabel('window size [ms]')
plt.tight_layout()
plt.grid(True)

plt.subplot(2,1,2)
plt.plot(ms_sizes, realtime_factor)
plt.ylabel('realtime factor')
plt.xlabel('window size [ms]')
plt.tight_layout()
plt.grid(True)

In [None]:
# in my desktop, 16000samples windows should be enough. 
# in this laptop we see we'll need at least 250ms windows, which means 4000 samples