In [1]:
import os
import json
import torch
import librosa
import numpy as np
from tqdm import tqdm
from discrete_speech_metrics import SpeechBERTScore

DATASETS_DIR = "/datasets/sbm/"
# Each dataset directory holds one sub-folder per split listed here.
SPLITS = ["inputs", "references"]

# Hidden-state indices scored per layer. Assumes the model exposes 25 hidden
# states (index 0 = pre-transformer features, 1-24 = transformer layers), as
# WavLM-Large does in Hugging Face transformers -- confirm if the model changes.
LAYERS_TO_ANALYZE = list(range(1, 25))

# os.listdir order is arbitrary: sort so the loading/scoring loops and the saved
# results follow a reproducible order, and skip stray files (e.g. .DS_Store)
# that would otherwise be treated as dataset directories.
dataset_names = sorted(
    name
    for name in os.listdir(DATASETS_DIR)
    if os.path.isdir(os.path.join(DATASETS_DIR, name))
)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
def bert_score(v_generated, v_reference, eps=1e-8):
    """Greedy-matching BERTScore between two sequences of frame features.

    Every generated frame is matched to its most cosine-similar reference frame
    (precision), and every reference frame to its most similar generated frame
    (recall).

    Args:
        v_generated (torch.Tensor): Generated feature tensor (T_gen, D).
        v_reference (torch.Tensor): Reference feature tensor (T_ref, D). The
            frame count may differ from v_generated; the feature dim D must match.
        eps (float): Lower bound on frame norms, so an all-zero frame yields a
            cosine similarity of 0 instead of NaN.
    Returns:
        float: Precision.
        float: Recall.
        float: F1 score (0.0 when precision + recall == 0).
    """
    # Normalize each frame to unit length; clamping avoids 0/0 for zero frames.
    gen_unit = v_generated / torch.norm(v_generated, dim=1, keepdim=True).clamp(min=eps)
    ref_unit = v_reference / torch.norm(v_reference, dim=1, keepdim=True).clamp(min=eps)

    # (T_gen, T_ref) cosine-similarity matrix.
    sim_matrix = torch.matmul(gen_unit, ref_unit.T)

    # Best match per generated frame (rows) and per reference frame (columns).
    precision = torch.max(sim_matrix, dim=1)[0].mean().item()
    recall = torch.max(sim_matrix, dim=0)[0].mean().item()

    # Harmonic mean; cosine scores can cancel to exactly 0, so guard the division.
    denominator = precision + recall
    f1_score = 2 * precision * recall / denominator if denominator != 0 else 0.0

    return precision, recall, f1_score

class SpeechBERTScoreWithRef(SpeechBERTScore):
    """SpeechBERTScore variant that caches the reference features.

    Call ``set_ref`` once per reference waveform, then ``score_with_ref`` for
    each generated waveform; the reference is featurized only once.
    """

    def __init__(self, sr=16000, model_type="hubert-base", layer=None, use_gpu=True):
        super().__init__(sr=sr, model_type=model_type, layer=layer, use_gpu=use_gpu)
        # Features of the current reference, or None until set_ref() is called.
        self.ref = None

    def _wav_to_tensor(self, wav):
        """Turn a 1-D numpy waveform into a (1, T) float tensor on self.device.

        When self.sr != 16000 the tensor is passed through self.resampler
        (presumably resampling to the model's 16 kHz rate).
        """
        wav = torch.from_numpy(wav).unsqueeze(0).to(self.device).float()
        if self.sr != 16000:
            wav = self.resampler(wav)
        return wav

    def set_ref(self, gt_wav):
        """
        Args:
            gt_wav (np.ndarray): Ground truth waveform (T,).
        """
        # Release the previous reference before the new forward pass. Assigning
        # None (rather than `del`) keeps the attribute defined if extraction fails.
        self.ref = None

        # Inference only: no_grad stops autograd from building (and the cached
        # reference from keeping alive) a computation graph.
        with torch.no_grad():
            v_ref = self.process_feats(self._wav_to_tensor(gt_wav))
        self.ref = v_ref.squeeze(0)

    def score_with_ref(self, gen_wav):
        """
        Args:
            gen_wav (np.ndarray): Generated waveform (T,).
        Returns:
            float: Precision.
            float: Recall.
            float: F1 score.
        Raises:
            RuntimeError: If set_ref() has not been called yet.
        """
        if self.ref is None:
            raise RuntimeError("set_ref() must be called before score_with_ref().")

        with torch.no_grad():
            v_gen = self.process_feats(self._wav_to_tensor(gen_wav))

        precision, recall, f1_score = bert_score(v_gen.squeeze(0), self.ref)
        return precision, recall, f1_score

class SpeechBERTScoreWithRefLayerAnalysis(SpeechBERTScore):
    """SpeechBERTScore variant that scores a generated clip against a cached reference at several layers.

    The model runs once per waveform with all hidden states requested. Precision,
    recall and F1 are then computed for every index in ``layers`` from that
    single forward pass.
    """

    def __init__(self, sr=16000, model_type="hubert-base", layers=[8,], use_gpu=True):
        super().__init__(sr=sr, model_type=model_type, layer=None, use_gpu=use_gpu)
        # All hidden states of the current reference, or None until set_ref() is called.
        self.ref_feats = None
        # Copy so neither the shared default list nor the caller's list is aliased
        # (and a one-shot iterable is not exhausted after the first scoring call).
        self.layers = list(layers)

    def _wav_to_tensor(self, wav):
        """Turn a 1-D numpy waveform into a (1, T) float tensor on self.device.

        When self.sr != 16000 the tensor is passed through self.resampler
        (presumably resampling to the model's 16 kHz rate).
        """
        wav = torch.from_numpy(wav).unsqueeze(0).to(self.device).float()
        if self.sr != 16000:
            wav = self.resampler(wav)
        return wav

    def process_all_feats(self, audio):
        """Run the model once and return every hidden state.

        Args:
            audio (torch.Tensor): Batched waveform, (1, T).
        Returns:
            The model output's ``hidden_states``, indexed by layer number below
            (presumably a tuple of (1, frames, D) tensors, HF-style -- confirm).
        """
        feats_hiddens = self.model(audio, output_hidden_states=True).hidden_states
        return feats_hiddens

    def set_ref(self, gt_wav):
        """
        Args:
            gt_wav (np.ndarray): Ground truth waveform (T,).
        """
        # Release the previous reference first; assigning None (rather than `del`)
        # keeps the attribute defined if feature extraction raises.
        self.ref_feats = None

        # Inference only: without no_grad every cached hidden state would keep an
        # autograd graph alive for as long as this reference is in use.
        with torch.no_grad():
            self.ref_feats = self.process_all_feats(self._wav_to_tensor(gt_wav))

    def score_with_ref_by_layer(self, gen_wav):
        """
        Args:
            gen_wav (np.ndarray): Generated waveform (T,).
        Returns:
            results (dict): Dictionary of prec, rec, f1 scores across selected layers
            i.e.:
                {0: (prec, rec, f1_score), 1: (prec, rec, f1_score), ...}
        Raises:
            RuntimeError: If set_ref() has not been called yet.
        """
        if self.ref_feats is None:
            raise RuntimeError("set_ref() must be called before score_with_ref_by_layer().")

        with torch.no_grad():
            gen_feats = self.process_all_feats(self._wav_to_tensor(gen_wav))

        results = {}
        for layer_num in self.layers:
            # Drop the batch dim so bert_score receives (frames, D) matrices.
            v_gen = gen_feats[layer_num].squeeze(0)
            v_ref = self.ref_feats[layer_num].squeeze(0)
            results[layer_num] = bert_score(v_gen, v_ref)

        return results

In [4]:
# Waveforms per dataset and split: datasets[name][split] -> list of 1-D arrays.
datasets = {}
# Filenames in the same order as the waveform lists in `datasets`, so each score
# appended to `results` can be traced back to the file it came from.
dataset_filenames = {}
# Clip counts per split across all datasets (used to size the progress bar).
totals = {
    "inputs": 0,
    "references": 0,
}

# e.g. ['atc', 'cv-id', 'cv-vi', 'maritime', 'nsc-room']
for dataset_name in dataset_names:
    datasets[dataset_name] = {}
    dataset_filenames[dataset_name] = {}

    for split in SPLITS:
        split_dir = os.path.join(DATASETS_DIR, dataset_name, split)
        # os.listdir order is arbitrary; sort so runs (and result order) are reproducible.
        filenames = sorted(os.listdir(split_dir))

        loaded_audio = []
        for filename in filenames:
            # Load as 16 kHz mono to match the metric's sr=16000; the returned
            # sample rate is therefore always 16000, so only the waveform is kept.
            arr = librosa.load(os.path.join(split_dir, filename), sr=16000, mono=True)[0]
            loaded_audio.append(arr)
        totals[split] += len(loaded_audio)

        datasets[dataset_name][split] = loaded_audio
        dataset_filenames[dataset_name][split] = filenames

In [5]:
# WavLM-Large scorer on 16 kHz audio. Every layer in LAYERS_TO_ANALYZE is scored
# from a single forward pass per waveform.
metrics = SpeechBERTScoreWithRefLayerAnalysis(
    model_type="wavlm-large",
    layers=LAYERS_TO_ANALYZE,
    sr=16000,
    use_gpu=True,
)

In [6]:
# results[ref_dataset][input_dataset] -> list of per-layer score dicts, one per
# (reference clip, input clip) pair, in iteration order.
results = {}

# Every reference is scored against every input of every dataset, so the bar
# spans the full inputs x references grid.
with tqdm(total=totals["inputs"] * totals["references"]) as pbar:
    for ref_dataset_name in dataset_names:
        results[ref_dataset_name] = per_ref = {}
        for ref_audio in datasets[ref_dataset_name]["references"]:
            # Featurize the reference once, then compare all inputs against it.
            metrics.set_ref(ref_audio)

            for input_dataset_name in dataset_names:
                # Lists are created on first use only, so a dataset without
                # reference clips maps to an empty dict.
                bucket = per_ref.setdefault(input_dataset_name, [])
                for input_audio in datasets[input_dataset_name]["inputs"]:
                    scores = metrics.score_with_ref_by_layer(input_audio)
                    bucket.append(scores)
                    pbar.update(1)

100% 25000/25000 [2:16:50<00:00,  3.04it/s]  


In [7]:
# Persist the nested results. JSON serialization turns the integer layer keys
# into strings and the (precision, recall, f1) tuples into lists.
with open("results_all.json", mode="w") as results_file:
    json.dump(results, results_file)