In [1]:
#!g1.1
import os
import IPython.display as ipd
import torch, torchaudio
from xnot_matcher import XNOTNeighborsVC

from speechbrain.pretrained import EncoderClassifier

import glob
import numpy as np
import json

from tqdm.auto import tqdm
from collections import defaultdict
from sklearn.metrics import roc_curve
from jiwer import cer, wer

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines where CUDA is not available instead of failing at model load time.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
#!g1.1

from speechkit import configure_credentials, creds
from speechkit import model_repository
from speechkit.stt import AudioProcessingType

# Authenticate the speechkit client; the key is read from the `api_key`
# environment variable (KeyError if it is not set).
_speechkit_credentials = creds.YandexCredentials(api_key=os.environ['api_key'])
configure_credentials(yandex_credentials=_speechkit_credentials)

# Recognition model used by the transcription cells further down.
model = model_repository.recognition_model()
model.model = 'general:rc'
model.language = 'en-US'
model.audio_processing_type = AudioProcessingType.Full

In [5]:
#!g1.1
# Pretrained speechbrain encoder used below to embed utterances for the
# speaker-similarity evaluation; weights are cached under `savedir`.
classifier = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-xvect-voxceleb",
    savedir="pretrained_models/spkrec-xvect-voxceleb",
    run_opts={"device": device},
)

hyperparams.yaml:   0%|          | 0.00/2.04k [00:00<?, ?B/s]

embedding_model.ckpt:   0%|          | 0.00/16.9M [00:00<?, ?B/s]

mean_var_norm_emb.ckpt:   0%|          | 0.00/3.20k [00:00<?, ?B/s]

classifier.ckpt:   0%|          | 0.00/15.9M [00:00<?, ?B/s]

label_encoder.txt:   0%|          | 0.00/129k [00:00<?, ?B/s]

In [6]:
#!g1.1
# Load the pretrained kNN-VC model from torch.hub.
knn_vc = torch.hub.load(
    'bshall/knn-vc',
    'knn_vc',
    prematched=True,
    trust_repo=True,
    pretrained=True,
    device=device,
)

# The XNOT matcher reuses the WavLM, HiFiGAN and `h` objects of the kNN-VC
# model, so both conversion methods share the same components.
xnot_vc = XNOTNeighborsVC(knn_vc.wavlm, knn_vc.hifigan, knn_vc.h, device=device)
xnot_vc = xnot_vc.eval()

Downloading: "https://github.com/bshall/knn-vc/zipball/master" to /tmp/xdg_cache/torch/hub/master.zip
Downloading: "https://github.com/bshall/knn-vc/releases/download/v0.1/prematch_g_02500000.pt" to /tmp/xdg_cache/torch/hub/checkpoints/prematch_g_02500000.pt
100%|██████████| 63.1M/63.1M [00:00<00:00, 118MB/s] 
Downloading: "https://github.com/bshall/knn-vc/releases/download/v0.1/WavLM-Large.pt" to /tmp/xdg_cache/torch/hub/checkpoints/WavLM-Large.pt


Removing weight norm...
[HiFiGAN] Generator loaded with 16,523,393 parameters.


100%|██████████| 1.18G/1.18G [00:09<00:00, 138MB/s] 


WavLM-Large loaded with 315,453,120 parameters.


In [7]:
#!g1.1
# Dataset root; every entry directly inside it is treated as a speaker id.
# Sorting makes the speaker order deterministic.
libri_folder = 'data/LibriSpeech/test-clean/'
speakers = sorted(os.listdir(libri_folder))

In [8]:
#!g1.1
# speaker id -> all .flac files found in any immediate sub-folder of that
# speaker's directory.
all_speaker_audios = {
    speaker: glob.glob(f'{libri_folder}/{speaker}/*/*.flac', recursive=True)
    for speaker in speakers
}

In [9]:
#!g1.1
# Pick 5 distinct utterances per speaker. They are used both as the
# speaker's matching set and as the source utterances to convert.
#
# The selection must be reproducible: later cells skip outputs that already
# exist on disk, so a different random draw after a kernel restart would mix
# results from different utterance sets. A dedicated seeded generator (which
# does not touch the global NumPy random state) combined with a sorted
# candidate list gives the same selection on every run.
chosen = {}
chosen_rng = np.random.default_rng(0)
for speaker in speakers:
    files = sorted(glob.glob(f'{libri_folder}/{speaker}/*/*.flac', recursive=True))
    chosen[speaker] = chosen_rng.choice(files, 5, replace=False)

In [11]:
#!g1.1
# speaker id -> matching set computed by kNN-VC from that speaker's chosen
# utterances, moved to the CPU.
matching_sets = {}

for speaker in tqdm(speakers):
    # torchaudio.load returns (waveform, sample_rate); only the waveform is kept.
    audios = [torchaudio.load(filename)[0] for filename in chosen[speaker]]
    matching_sets[speaker] = knn_vc.get_matching_set(audios).cpu()

# Make sure the target folder exists so the save cannot fail after the
# expensive feature extraction above.
os.makedirs('data/matchings', exist_ok=True)
torch.save(matching_sets, 'data/matchings/matching_sets')

In [None]:
#!g1.1
# Convert every chosen utterance of each of the first `n_speakers` speakers
# to every other of those speakers, once per XNOT weight W and once with
# plain kNN matching.
n_speakers = 10
speakers_list = speakers[:n_speakers]
xnot_weights = [1.0, 2.0, 4.0]

# Create all output folders up front; torchaudio.save / torch.save do not
# create missing parent directories.
os.makedirs('data/audios/knn', exist_ok=True)
for W in xnot_weights:
    os.makedirs(f'data/audios/xnot/w-{int(W)}', exist_ok=True)
    os.makedirs(f'data/x_nots/w-{int(W)}', exist_ok=True)

for src_speaker in tqdm(speakers_list):
    for target_speaker in speakers_list:
        if src_speaker == target_speaker:
            continue
        print(f'Conversing {src_speaker=} to {target_speaker=}')
        for filename in chosen[src_speaker]:
            audio, _ = torchaudio.load(filename)
            query_seq = knn_vc.get_features(audio)
            # Utterance id = file name without directory and extension.
            idx = filename.split('/')[-1].split('.')[0]

            for W in xnot_weights:
                path = f'w-{int(W)}/target-{target_speaker}-idx-{idx}-src-{src_speaker}'
                # The saved checkpoint marks this (W, target, utterance)
                # combination as done, so a re-run resumes where it stopped.
                if os.path.exists(f'data/x_nots/{path}'):
                    continue
                out_wav_xnot, xnot = xnot_vc.match(
                    query_seq,
                    matching_sets[target_speaker],
                    topk=4,
                    algorithm='xnot',
                    W=W,
                    max_steps=200,
                )
                torchaudio.save(f'data/audios/xnot/{path}.wav', out_wav_xnot[None].cpu(), 16000)
                torch.save({'state_dict': xnot.state_dict()}, f'data/x_nots/{path}')

            # kNN baseline for the same source utterance / target speaker.
            path = f'target-{target_speaker}-idx-{idx}-src-{src_speaker}'
            out_wav_knn, _ = xnot_vc.match(query_seq, matching_sets[target_speaker], topk=4, algorithm='knn')
            torchaudio.save(f'data/audios/knn/{path}.wav', out_wav_knn[None].cpu(), 16000)

In [None]:
#!g1.1
def compute_eer(label, pred, positive_label=1):
    """Estimate the equal error rate from labels and scores.

    Args:
        label: ground-truth labels, passed to `roc_curve` as `y_true`.
        pred: scores, passed to `roc_curve` as `y_score`.
        positive_label: label treated as the positive class (`pos_label`).

    Returns:
        The mean of FPR and FNR at the ROC point where |FNR - FPR| is
        smallest (NaN differences are ignored).
    """
    fpr, tpr, _ = roc_curve(label, pred, pos_label=positive_label)
    fnr = 1 - tpr
    # Index of the operating point closest to FPR == FNR.
    crossover = np.nanargmin(np.absolute(fnr - fpr))
    return (fpr[crossover] + fnr[crossover]) / 2

In [None]:
#!g1.1
sim = torch.nn.CosineSimilarity().to(device)


def _embedding_similarity(path_a, path_b):
    """Return the cosine similarity of the `classifier` embeddings of two audio files."""
    # str(...) turns NumPy string scalars (as returned by np.random.choice)
    # into plain path strings before they reach torchaudio.
    emb_a = classifier.encode_batch(torchaudio.load(str(path_a))[0]).squeeze(1)
    emb_b = classifier.encode_batch(torchaudio.load(str(path_b))[0]).squeeze(1)
    return sim(emb_a, emb_b).cpu().item()


def _conversion_similarities(folder):
    """Score each converted file in `folder` against a random real utterance of its target speaker."""
    scores = []
    for target_speaker in tqdm(chosen, total=len(chosen)):
        # `-idx-` is part of the pattern so that a speaker id which is a
        # prefix of another id cannot pick up the other speaker's files.
        conversions = glob.glob(f'{folder}/target-{target_speaker}-idx-*', recursive=True)
        references = np.random.choice(all_speaker_audios[target_speaker], size=len(conversions))
        for reference, conversion in zip(references, conversions):
            scores.append(_embedding_similarity(reference, conversion))
    return scores


# (real target utterance, kNN conversion to that target) pairs.
knn_similarities = _conversion_similarities('data/audios/knn')

# Pairs of real utterances from two different speakers, one pair per kNN score.
gt_similarities = []

for target_speaker in np.random.choice(speakers, size=len(knn_similarities)):
    src_speaker = target_speaker
    while src_speaker == target_speaker:
        src_speaker = np.random.choice(speakers)

    # No `size` argument here: with size=1 np.random.choice returns a
    # length-1 array, which is not a valid path for torchaudio.load.
    target_audio = np.random.choice(all_speaker_audios[target_speaker])
    source_audio = np.random.choice(all_speaker_audios[src_speaker])
    gt_similarities.append(_embedding_similarity(source_audio, target_audio))


# Conversion pairs are labelled 1, different-speaker pairs 0; the EER is
# computed with label 0 passed as the positive class.
preds_knn = np.array(knn_similarities + gt_similarities)
targets = np.array([1.] * len(knn_similarities) + [0.] * len(gt_similarities))
knn_eer = compute_eer(targets, preds_knn, positive_label=0)

# One EER per XNOT weight, in the order W = 1, 2, 4, against the same
# different-speaker pairs as the kNN baseline.
results = []
for W in [1.0, 2.0, 4.0]:
    similarities = _conversion_similarities(f'data/audios/xnot/w-{int(W)}')

    preds_xnot = np.array(similarities + gt_similarities)
    targets = np.array([1.] * len(similarities) + [0.] * len(gt_similarities))

    results.append(compute_eer(targets, preds_xnot, positive_label=0))

In [None]:
#!g1.1
# recognitions[<system>][<audio path>] -> recognized text.
recognitions = defaultdict(dict)
src_transcripts = recognitions['src']

# Transcribe every chosen (unconverted) source utterance once.
for speaker, filenames in tqdm(chosen.items(), total=len(chosen)):
    for filename in filenames:
        if filename not in src_transcripts:
            result = model.transcribe_file(filename)
            # Exactly one recognition result is expected per file.
            assert len(result) == 1
            src_transcripts[filename] = result[0].raw_text

with open('recognitions-src.json', 'w') as out_file:
    json.dump(src_transcripts, out_file, ensure_ascii=False, indent=4)

In [None]:
#!g1.1
knn_wavs = glob.glob(f'data/audios/knn/*.wav', recursive=True)
xnot_all_wavs = glob.glob(f'data/audios/xnot/*/*.wav', recursive=True)

# Transcribe the converted audio of both systems; files that already have
# an entry in `recognitions` are skipped.
for system, wav_paths in (('knn', knn_wavs), ('xnot', xnot_all_wavs)):
    for filename in tqdm(wav_paths):
        if filename not in recognitions[system]:
            recognition = model.transcribe_file(filename)
            # Exactly one recognition result is expected per file.
            assert len(recognition) == 1
            recognitions[system][filename] = recognition[0].raw_text

In [None]:
#!g1.1
# Persist all transcriptions collected so far (every key of `recognitions`).
with open('recognitions.json', 'w') as out_file:
    json.dump(recognitions, out_file, ensure_ascii=False, indent=4)

In [None]:
#!g1.1
# All files whose name ends in `trans.txt`, at any depth below the dataset root.
transcript_pattern = f'{libri_folder}/**/*trans.txt'
text_files = glob.glob(transcript_pattern, recursive=True)

In [None]:
#!g1.1
# utterance id -> lower-cased reference text.
gt_texts = {}

for transcript_path in text_files:
    with open(transcript_path) as transcript:
        for raw_line in transcript:
            # Each line is "<utterance id> <text>": split on the first
            # whitespace run only.
            utt_id, sentence = raw_line.strip().split(maxsplit=1)
            gt_texts[utt_id] = sentence.lower()

with open('gt_texts.json', 'w') as out_file:
    json.dump(gt_texts, out_file, ensure_ascii=False, indent=4)

In [None]:
#!g1.1
# Word / character error rates of the converted audio against the reference
# transcripts, for the kNN baseline and for each XNOT weight.
wer_results = {}

knn_wers = []
knn_cers = []

for filename, rec_text in recognitions['knn'].items():
    # File names look like `.../target-<tgt>-idx-<utterance id>-src-<src>.wav`;
    # strip the first three '-'-separated fields and the last two to get the id.
    idx = filename.split('-', maxsplit=3)[-1].rsplit('-', maxsplit=2)[0]
    knn_wers.append(wer(gt_texts[idx], rec_text))
    # Character error rate (this list was previously filled with `wer`).
    knn_cers.append(cer(gt_texts[idx], rec_text))

for W in [1.0, 2.0, 4.0]:
    weight_dir = f'w-{int(W)}'

    xnot_wers = []
    xnot_cers = []

    for filename, rec_text in recognitions['xnot'].items():
        # Only score files converted with this W, i.e. those stored directly
        # inside the `w-<W>` folder; otherwise every W would get the same lists.
        if os.path.basename(os.path.dirname(filename)) != weight_dir:
            continue
        # Same parsing as above, with one extra '-' from the `w-<W>` folder.
        idx = filename.split('-', maxsplit=4)[-1].rsplit('-', maxsplit=2)[0]
        xnot_wers.append(wer(gt_texts[idx], rec_text))
        xnot_cers.append(cer(gt_texts[idx], rec_text))

    wer_results[int(W)] = (xnot_wers, xnot_cers)

In [None]:
#!g1.1
# Error rates of the recognizer on the original (unconverted) source audio.
gt_wers = []
gt_cers = []

for src_path, hypothesis in recognitions['src'].items():
    # Utterance id = file name up to its first '.'.
    utt_id = src_path.rsplit('/', 1)[-1].split('.', 1)[0]
    reference = gt_texts[utt_id]
    gt_wers.append(wer(reference, hypothesis))
    gt_cers.append(cer(reference, hypothesis))

In [None]:
#!g1.1
