In [5]:
from deep_learning_features_audio import *
from deep_learning_dict_api import AudioAnalysisAPI
import pandas as pd
from pathlib import Path
from IPython.display import display, HTML
from torchmetrics import ScaleInvariantSignalNoiseRatio, ScaleInvariantSignalDistortionRatio, SignalNoiseRatio, SignalDistortionRatio, PermutationInvariantTraining
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.functional.audio import signal_distortion_ratio
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
from datetime import datetime
from deep_learning_dict_datasets import Datasets
import numpy as np
import json


In [6]:
# When running this tutorial in Google Colab, install the required packages
# with the following.
# !pip install torchaudio librosa boto3

import torch
import torchaudio
import torchaudio.functional as TAF
import torchaudio.transforms as T

# Report the installed torch / torchaudio versions, one per line.
print(torch.__version__, torchaudio.__version__, sep="\n")

1.11.0+cu102
0.11.0+cu102


In [7]:
def _resampled_copy_path(audio_path, rate_tag):
    """Return the path, in the same folder, used to store the resampled copy of ``audio_path``."""
    # Local import: the notebook only imports ``os`` in a later cell, so relying
    # on the module-level name would break a fresh "Restart & Run All".
    import os

    folder, filename = os.path.split(audio_path)
    return os.path.join(folder, "resampled_khz" + str(rate_tag) + filename)


def _average_score(values):
    """Average a list of (1, 1) score tensors; return ``tensor(0.)`` when the list is empty."""
    if not values:
        return torch.zeros(())
    return torch.sum(torch.cat(values)) / len(values)


def _record_score(sample, scores, key, value):
    """Store one metric value both in the per-sample record and in the running score lists."""
    value = torch.reshape(value, (1, 1))
    scores[key].append(value)
    sample[key] = str(value)


def speech_separation_evaluate_metric_with_model_on_libri3mix(model, dataset, metrics, n_test, mix_type="test_mix_clean_file"):
    """Benchmark a 3-source speech separation model on a Libri3Mix metadata table.

    For each of the first ``n_test`` rows of the CSV registered under
    ``Datasets["Speech Separation"][dataset][mix_type]`` the mixture and the
    three reference sources are loaded, the model registered under
    ``AudioAnalysisAPI[model]['function']`` is run on the mixture file, and the
    requested metrics are computed between the stacked predictions and the
    stacked references.

    Args:
        model: Key into ``AudioAnalysisAPI``; the entry must provide
            ``'channels'``, ``'sample_rate'`` and ``'function'``.
        dataset: Key into ``Datasets["Speech Separation"]``; the entry must
            provide ``'channels'``, ``'sample_rate'`` and the ``mix_type`` CSV path.
        metrics: Iterable of metric names among ``"si-snr"``, ``"si-sdr"``,
            ``"snr"``, ``"sdr"``, ``"pesq"``, ``"pit"`` and ``"stoi"``.
            Other names are ignored.
        n_test: Number of rows to benchmark; the loop stops after row
            ``n_test - 1``.
        mix_type: Name of the metadata entry to read. With
            ``"test_mix_both_file"`` the row's noise file is loaded and added
            to every reference source.

    Returns:
        ``(result, n_test)`` where ``result`` is the dictionary that is also
        written to the JSON report. If the model and dataset channel counts
        differ, ``{"error": ...}`` is returned instead and nothing is evaluated.

    Side effects:
        * When the model and dataset sample rates differ, resampled copies of
          the mixture and sources are saved in the same folders as the originals.
        * A timestamped ``*.json`` report is written using a relative file name.
        * Progress and debugging information is printed.

    Notes:
        Averages are taken over the scores that were actually collected for
        each metric, so a failing sample or a table shorter than ``n_test``
        no longer drags the mean towards zero. A metric with no scores is
        reported as ``tensor(0.)``.
    """
    task = "Speech Separation"
    n_sources = 3
    metrics_error = 0
    experiments = {}
    scores = {key: [] for key in ("si-snr", "si-sdr", "snr", "sdr", "pesq", "wb-pesq", "pit", "stoi")}
    # Metrics that are built without arguments and called as metric(preds, target).
    plain_metrics = {
        "si-snr": ScaleInvariantSignalNoiseRatio,
        "si-sdr": ScaleInvariantSignalDistortionRatio,
        "snr": SignalNoiseRatio,
        "sdr": SignalDistortionRatio,
    }

    print("===================================================>  Dataset: {}".format(dataset))
    print("Datasets[task][dataset] : {}".format(Datasets[task][dataset][mix_type]))
    test_table = pd.read_table(Datasets[task][dataset][mix_type], sep=",")
    print("len(test_table): {}".format(len(test_table)))

    # Model/dataset properties do not depend on the row: resolve them once so
    # they are also defined when the table is empty.
    model_channels = AudioAnalysisAPI[model]['channels']
    dataset_channels = Datasets[task][dataset]['channels']
    if model_channels != dataset_channels:
        return {"error": "Model and Dataset have differents channels number"}

    model_sample_rate = AudioAnalysisAPI[model]['sample_rate']
    dataset_sample_rate = Datasets[task][dataset]['sample_rate']
    needs_resampling = model_sample_rate != dataset_sample_rate

    n_processed = 0
    for i, (_, row) in enumerate(test_table.iterrows()):
        n_processed = i + 1
        print("====> Benchmarking: {}/{}   tot {}".format(i+1, n_test, test_table.shape[0]))

        mixture_path = row['mixture_path']
        source_paths = [row['source_{}_path'.format(k)] for k in range(1, n_sources + 1)]
        print("model_sample_rate:{} - dataset_sample_rate:{}".format(model_sample_rate, dataset_sample_rate))

        # The per-sample record keeps the ORIGINAL paths, not the resampled copies.
        sample = {"mixture_path": mixture_path}
        for k, path in enumerate(source_paths, start=1):
            sample["source_{}_path".format(k)] = path
        experiments[str(i)] = sample

        # Load the mixture and the reference sources.
        mixture_waveform, mixture_sample_rate = torchaudio.load(mixture_path)
        source_waveforms = []
        source_sample_rates = []
        for path in source_paths:
            waveform, sample_rate = torchaudio.load(path)
            source_waveforms.append(waveform)
            source_sample_rates.append(sample_rate)

        if needs_resampling:
            mixture_waveform = TAF.resample(mixture_waveform, mixture_sample_rate, model_sample_rate)
            source_waveforms = [
                TAF.resample(waveform, sample_rate, model_sample_rate)
                for waveform, sample_rate in zip(source_waveforms, source_sample_rates)
            ]
            # NOTE(review): the file-name tag carries the DATASET rate although the
            # audio is resampled to the MODEL rate; kept as-is so existing
            # resampled files keep the same names.
            mixture_path = _resampled_copy_path(mixture_path, dataset_sample_rate)
            source_paths = [_resampled_copy_path(path, dataset_sample_rate) for path in source_paths]
            print("revisited mixture_path:{}".format(mixture_path))
            for k, path in enumerate(source_paths, start=1):
                print("revisited source_{}_path:{}".format(k, path))
            torchaudio.save(mixture_path, mixture_waveform, model_sample_rate)
            for path, waveform in zip(source_paths, source_waveforms):
                torchaudio.save(path, waveform, model_sample_rate)

        if mix_type == "test_mix_both_file":
            # NOTE(review): the noise is added to every reference source (after any
            # resampled copies were saved) - confirm this is the intended target.
            noise_waveform, noise_sample_rate = torchaudio.load(row['noise_path'])
            if needs_resampling:
                noise_waveform = TAF.resample(noise_waveform, noise_sample_rate, model_sample_rate)
            source_waveforms = [waveform + noise_waveform for waveform in source_waveforms]

        # Separate the (possibly resampled) mixture with the chosen model;
        # exactly three output paths are expected.
        prediction_1, prediction_2, prediction_3 = AudioAnalysisAPI[model]['function'](audiofile_path=mixture_path)
        prediction_paths = [prediction_1, prediction_2, prediction_3]

        prediction_waveforms = []
        prediction_sample_rates = []
        for path in prediction_paths:
            waveform, sample_rate = torchaudio.load(path)
            prediction_waveforms.append(waveform)
            prediction_sample_rates.append(sample_rate)

        # Stack predictions and references along dim 0 (one row per source for mono files).
        preds = torch.cat(prediction_waveforms, 0)
        target = torch.cat(source_waveforms, 0)

        for k in range(n_sources):
            number = k + 1
            print("source_{}_path : {}".format(number, source_paths[k]))
            # This is the rate of the file as loaded, i.e. before any resampling.
            print("source_{}_sample_rate:{}".format(number, source_sample_rates[k]))
            print("source_{}_waveform.shape:{}".format(number, source_waveforms[k].shape))
            print("source_{}_waveform:{}".format(number, target[k]))
            print()
            print("source_{}_path_prediction : {}".format(number, prediction_paths[k]))
            print("source_{}_prediction_sample_rate: {}".format(number, prediction_sample_rates[k]))
            print("source_{}_prediction_waveform.shape : {}".format(number, prediction_waveforms[k].shape))
            print("source_{}_prediction_waveform: {}".format(number, preds[k]))
            if number < n_sources:
                print()

        print("preds.shape {}\ntarget.shape:{}".format(preds.shape, target.shape))
        print("preds: {},\ntarget:{}".format(preds, target))

        for k, path in enumerate(prediction_paths, start=1):
            sample["source_{}_path_prediction".format(k)] = path

        # The first failing metric aborts the remaining metrics of this sample
        # and counts as one error; scores recorded before the failure are kept.
        try:
            for metric in metrics:
                if metric in plain_metrics:
                    _record_score(sample, scores, metric, plain_metrics[metric]()(preds, target))

                elif metric == "pesq":
                    nb_pesq = PerceptualEvaluationSpeechQuality(model_sample_rate, 'nb')
                    _record_score(sample, scores, "pesq", nb_pesq(preds, target))

                    # Wide-band PESQ is only attempted above 8 kHz.
                    if model_sample_rate > 8000:
                        wb_pesq = PerceptualEvaluationSpeechQuality(model_sample_rate, 'wb')
                        _record_score(sample, scores, "wb-pesq", wb_pesq(preds, target))

                elif metric == "pit":
                    # PIT expects [batch, speakers, time]; add the batch axis so the
                    # three stacked sources are treated as speakers, not as a batch.
                    # Assumes mono files (one row per source) - TODO confirm.
                    pit = PermutationInvariantTraining(signal_distortion_ratio, eval_func='max')
                    _record_score(sample, scores, "pit", pit(preds.unsqueeze(0), target.unsqueeze(0)))

                elif metric == "stoi":
                    stoi_src = ShortTimeObjectiveIntelligibility(model_sample_rate, False)
                    _record_score(sample, scores, "stoi", stoi_src(preds, target))
        except Exception as e:
            print("=====> ERROR: {}".format(e))
            metrics_error += 1

        if i == n_test - 1:
            break

    # Average each metric over the scores that were actually collected.
    total_si_snr = _average_score(scores["si-snr"])
    total_si_sdr = _average_score(scores["si-sdr"])
    total_snr = _average_score(scores["snr"])
    total_sdr = _average_score(scores["sdr"])
    total_pit = _average_score(scores["pit"])
    total_wb_pesq = _average_score(scores["wb-pesq"])
    total_nb_pesq = _average_score(scores["pesq"])
    total_stoi = _average_score(scores["stoi"])

    print("============================================================================================")
    print("total_si_snr:{}".format(total_si_snr))
    print("total_si_sdr:{}".format(total_si_sdr))
    print("total_snr:{}".format(total_snr))
    print("total_sdr:{}".format(total_sdr))
    print("total_pit:{}".format(total_pit))
    print("total_wb_pesq:{}".format(total_wb_pesq))
    print("total_nb_pesq:{}".format(total_nb_pesq))
    print("total_stoi:{}".format(total_stoi))

    # Content of the JSON report.
    result = {
        "model": model,
        "dataset": dataset,
        "n_test": n_test,
        "metrics_error": metrics_error,
        "model_sample_rate": model_sample_rate,
        "dataset_sample_rate": dataset_sample_rate,
        # Based on the rows really processed, which can be fewer than n_test.
        "n_test_done": n_processed - metrics_error,
        "total_si_snr": str(total_si_snr),
        "total_si_sdr": str(total_si_sdr),
        "total_snr": str(total_snr),
        "total_sdr": str(total_sdr),
        "total_pit": str(total_pit),
        "total_wb_pesq": str(total_wb_pesq),
        "total_nb_pesq": str(total_nb_pesq),
        "total_stoi": str(total_stoi),
        "experiments": experiments,
    }

    # Timestamped report file name.
    dateTimeObj = datetime.now()
    print(dateTimeObj)
    timestampStr = dateTimeObj.strftime("%d_%b_%Y__%H_%M_%S_%f")
    result_filename = timestampStr + "_evaluate_" + model.split("/")[-1] + "_" + dataset + "_" + mix_type + "_" + str(n_test)+".json"
    print('result_filename : ', result_filename)
    with open(result_filename, 'w') as f:
        json.dump(result, f)

    return result, n_test

# Benchmark configuration.
speech_separation_3_channels_models = ['/api/audioseparation/speech_separation_sepformer_wsj03mix']
speech_separation_dataset = ["Libri3Mix8kMin", "Libri3Mix8kMax", "Libri3Mix16kMin", "Libri3Mix16kMax"]
metrics = ["si-snr", "si-sdr", "sdr", "snr", "pesq", "stoi"]
n_test = 3

# Metadata entry to evaluate: "test_mix_clean_file", "test_mix_both_file" or
# "test_mix_single_file". Named `mix_type` (not `type`) so the builtin `type`
# stays usable in the rest of the notebook.
mix_type = "test_mix_both_file"

# Run the same model on every dataset variant instead of repeating the call.
evaluation_results = [
    speech_separation_evaluate_metric_with_model_on_libri3mix(
        speech_separation_3_channels_models[0],
        dataset_name,
        metrics,
        n_test=n_test,
        mix_type=mix_type,
    )
    for dataset_name in speech_separation_dataset
]

# Display the result of the last run as the cell output.
evaluation_results[-1]




Datasets[task][dataset] : /storage/data_8T/datasets/audio/LibriMix/Libri3Mix/wav16k/max/metadata/mixture_test_mix_both.csv
len(test_table): 3000
====> Benchmarking: 1/3   tot 3000
model_sample_rate:8000 - dataset_sample_rate:16000
revisited mixture_path:/storage/data_8T/datasets/audio/LibriMix/Libri3Mix/wav16k/max/test/mix_both/resampled_khz160004077-13754-0001_5142-33396-0065_5683-32866-0012.wav
revisited source_1_path:/storage/data_8T/datasets/audio/LibriMix/Libri3Mix/wav16k/max/test/s1/resampled_khz160004077-13754-0001_5142-33396-0065_5683-32866-0012.wav
revisited source_2_path:/storage/data_8T/datasets/audio/LibriMix/Libri3Mix/wav16k/max/test/s2/resampled_khz160004077-13754-0001_5142-33396-0065_5683-32866-0012.wav
revisited source_3_path:/storage/data_8T/datasets/audio/LibriMix/Libri3Mix/wav16k/max/test/s3/resampled_khz160004077-13754-0001_5142-33396-0065_5683-32866-0012.wav
source_1_path : /storage/data_8T/datasets/audio/LibriMix/Libri3Mix/wav16k/max/test/s1/resampled_khz160004077

                not been set for this class (ScaleInvariantSignalNoiseRatio). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_full_state_property`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                
                not been set for this class (ScaleInvariantSignalDistortionRatio). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function

====> Benchmarking: 2/3   tot 3000
model_sample_rate:8000 - dataset_sample_rate:16000
revisited mixture_path:/storage/data_8T/datasets/audio/LibriMix/Libri3Mix/wav16k/max/test/mix_both/resampled_khz160004507-16021-0025_1188-133604-0025_4992-23283-0016.wav
revisited source_1_path:/storage/data_8T/datasets/audio/LibriMix/Libri3Mix/wav16k/max/test/s1/resampled_khz160004507-16021-0025_1188-133604-0025_4992-23283-0016.wav
revisited source_2_path:/storage/data_8T/datasets/audio/LibriMix/Libri3Mix/wav16k/max/test/s2/resampled_khz160004507-16021-0025_1188-133604-0025_4992-23283-0016.wav
revisited source_3_path:/storage/data_8T/datasets/audio/LibriMix/Libri3Mix/wav16k/max/test/s3/resampled_khz160004507-16021-0025_1188-133604-0025_4992-23283-0016.wav
source_1_path : /storage/data_8T/datasets/audio/LibriMix/Libri3Mix/wav16k/max/test/s1/resampled_khz160004507-16021-0025_1188-133604-0025_4992-23283-0016.wav
source_1_sample_rate:16000
source_1_waveform.shape:torch.Size([1, 73720])
source_1_waveform:

({'model': '/api/audioseparation/speech_separation_sepformer_wsj03mix',
  'dataset': 'Libri3Mix16kMax',
  'n_test': 3,
  'metrics_error': 0,
  'model_sample_rate': 8000,
  'dataset_sample_rate': 16000,
  'n_test_done': 3,
  'total_si_snr': 'tensor(-6.2532)',
  'total_si_sdr': 'tensor(-6.2826)',
  'total_snr': 'tensor(-7.2865)',
  'total_sdr': 'tensor(-2.7099)',
  'total_pit': 'tensor(0.)',
  'total_wb_pesq': 'tensor(0.)',
  'total_nb_pesq': 'tensor(1.7595)',
  'total_stoi': 'tensor(0.4512)',
  'experiments': {'0': {'mixture_path': '/storage/data_8T/datasets/audio/LibriMix/Libri3Mix/wav16k/max/test/mix_both/4077-13754-0001_5142-33396-0065_5683-32866-0012.wav',
    'source_1_path': '/storage/data_8T/datasets/audio/LibriMix/Libri3Mix/wav16k/max/test/s1/4077-13754-0001_5142-33396-0065_5683-32866-0012.wav',
    'source_2_path': '/storage/data_8T/datasets/audio/LibriMix/Libri3Mix/wav16k/max/test/s2/4077-13754-0001_5142-33396-0065_5683-32866-0012.wav',
    'source_3_path': '/storage/data_8T/d

In [8]:

# Demonstration of os.path.split(): each path is split into a
# (head, tail) pair and both parts are printed.

import os

# Each entry pairs a sample path with the extra positional arguments given to
# the "Tail" print: the first two samples append "\n" to leave a blank
# line after them, the last one does not.
for path, _trailer in (
    ('/home/User/Desktop/file.txt', ("\n",)),  # ordinary file path
    ('/home/User/Desktop/', ("\n",)),          # trailing slash -> empty tail
    ('file.txt', ()),                          # bare file name -> empty head
):
    # head = everything before the last separator, tail = what follows it
    head_tail = os.path.split(path)

    print("Head of '% s:'" % path, head_tail[0])
    print("Tail of '% s:'" % path, head_tail[1], *_trailer)

Head of '/home/User/Desktop/file.txt:' /home/User/Desktop
Tail of '/home/User/Desktop/file.txt:' file.txt 

Head of '/home/User/Desktop/:' /home/User/Desktop
Tail of '/home/User/Desktop/:'  

Head of 'file.txt:' 
Tail of 'file.txt:' file.txt
