In [1]:
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [2]:
from pathlib import Path
import torchaudio
import fairseq
import torch
# Prefer the GPU when CUDA is available, but fall back to the CPU so the
# notebook still runs (more slowly) on machines without a CUDA device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
def get_mos_data(split, sets_dir='../data/phase1-ood/DATA/sets'):
    """Read ``{sets_dir}/{split}_mos_list.txt`` into a ``{file_id: mos}`` dict.

    Each non-blank line is split on commas:

    * ``file_id,mos`` -> ``mos`` is parsed as a float.
    * anything else (e.g. only ``file_id``) -> the first field is used as the
      file id and the score is set to the placeholder ``0``.

    Args:
        split: Name prefix of the list file (e.g. ``'train'``, ``'val'``,
            ``'unlabeled'``).
        sets_dir: Directory containing the list files. Defaults to the
            location the notebook originally hard-coded.

    Returns:
        dict mapping file id (str) to MOS (float, or ``0`` when the line has
        no score).

    Raises:
        OSError: if the list file cannot be opened.
        ValueError: if the score on a two-field line is not a valid float.
    """
    mos_list_file = f'{sets_dir}/{split}_mos_list.txt'
    mos_data = {}
    # Use a context manager so the file handle is closed deterministically.
    with open(mos_list_file) as list_file:
        for line in list_file:
            line = line.rstrip()
            if not line:
                # Blank lines would otherwise register a bogus '' file id.
                continue
            elms = line.split(',')
            if len(elms) == 2:
                file_id, mos = elms
                mos_data[file_id] = float(mos)
            else:
                # No (or malformed) score on this line: keep the file id with
                # a placeholder score of 0.
                mos_data[elms[0]] = 0

    return mos_data

In [4]:
# File id -> MOS mapping read from the 'train' list; the trailing expression
# displays the number of entries as a quick sanity check.
train_mos_data = get_mos_data('train')
len(train_mos_data)

136

In [5]:
# File id -> MOS mapping read from the 'val' list; the trailing expression
# displays the number of entries as a quick sanity check.
val_mos_data = get_mos_data('val')
len(val_mos_data)

136

In [6]:
# File ids read from the 'unlabeled' list. get_mos_data stores a placeholder
# score of 0 for any line that does not have exactly two comma-separated fields.
unlabeled_mos_data = get_mos_data('unlabeled')
len(unlabeled_mos_data)

540

In [7]:
# Directory holding the audio files; the keys of the *_mos_data dicts are
# joined onto this path in the extraction loops below.
wav_dir = Path('../data/phase1-ood/DATA/wav/')


In [9]:
# Checkpoint used as the feature extractor. The commented-out assignment is an
# alternative checkpoint path; swap the two lines to use it instead.
# fairseq_base_model = '../fairseq/w2v_large_lv_fsh_swbd_cv.pt'
fairseq_base_model = '../fairseq/xlsr_53_56k.pt'
# The loader is given a list of checkpoint paths and its result is unpacked
# into three values. `model` is indexed with [0] in the next cell; `cfg` and
# `task` are not used in the cells shown here.
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_base_model])
# model

In [10]:
# Only one checkpoint path was passed to the loader, so take the first entry.
ssl_model = model[0]

In [11]:
# Prepare the model for feature extraction:
#  - remove_pretraining_modules(): presumably drops components only needed for
#    pretraining (see the fairseq model docs for exactly what is removed);
#  - move the model to `device`;
#  - switch to eval mode. The trailing `;` suppresses the cell's output repr.
ssl_model.remove_pretraining_modules()
ssl_model.to(device)
ssl_model.eval();

In [12]:
def extract_mean(wavpath):
    """Return the time-averaged SSL feature vector for one audio file.

    The waveform is loaded with torchaudio, downmixed to a single channel if
    necessary, passed through ``ssl_model`` (no masking, features only) and
    the resulting ``res['x']`` is averaged over its first non-batch axis.

    Args:
        wavpath: Path of the audio file to load.

    Returns:
        torch.Tensor on ``device`` holding the mean of ``res['x']`` for the
        utterance (no gradient is tracked).
    """
    with torch.no_grad():
        # torchaudio.load returns (waveform, sample_rate); the waveform is
        # laid out as (channels, frames) by default.
        wav, _sample_rate = torchaudio.load(wavpath)
        # NOTE(review): the sample rate is not checked here - confirm the wav
        # files match the rate the checkpoint was trained on.
        if wav.size(0) > 1:
            # The model treats dim 0 as the batch axis, so a multi-channel
            # file would become several batch items and the pooling below
            # would average across channels instead of time. Downmix to mono.
            wav = wav.mean(dim=0, keepdim=True)
        res = ssl_model(wav.to(device), mask=False, features_only=True)
        # Drop the (size-1) batch axis, then average over the remaining
        # leading axis to get one vector per utterance.
        return res['x'].squeeze(0).mean(dim=0)


In [13]:
import os  # kept so `os` remains available to any later cells

# Directory that receives one `.npy` feature file per utterance; create it
# (and any missing parents) up front so the extraction loops can write to it.
out_dir = Path('../out/utt_data/w2v_xlsr')
out_dir.mkdir(parents=True, exist_ok=True)


In [14]:
# Pooled feature vectors and scores for the 'val' split, kept index-aligned.
val_vecs, val_moss = [], []

# Walk the split in sorted file-id order so the lists are reproducible.
for key, mos in tqdm(sorted(val_mos_data.items())):
    wavpath = wav_dir / key
    outpath = out_dir / f'{wavpath.stem}.npy'

    # Pool the SSL features, bring them to host memory and cache them on disk.
    vec = extract_mean(wavpath).detach().cpu().numpy()
    np.save(outpath, vec)

    val_vecs.append(vec)
    val_moss.append(mos)

  0%|          | 0/136 [00:00<?, ?it/s]

In [15]:
# Pooled feature vectors and scores for the 'train' split, kept index-aligned.
train_vecs, train_moss = [], []

# Walk the split in sorted file-id order so the lists are reproducible.
for key, mos in tqdm(sorted(train_mos_data.items())):
    wavpath = wav_dir / key
    outpath = out_dir / f'{wavpath.stem}.npy'

    # Pool the SSL features, bring them to host memory and cache them on disk.
    vec = extract_mean(wavpath).detach().cpu().numpy()
    np.save(outpath, vec)

    train_vecs.append(vec)
    train_moss.append(mos)

  0%|          | 0/136 [00:00<?, ?it/s]

In [16]:
# Pooled feature vectors for the 'unlabeled' split. `unlabeled_moss` collects
# whatever get_mos_data stored for each file (0 when the list has no score).
unlabeled_vecs, unlabeled_moss = [], []

# Walk the split in sorted file-id order so the lists are reproducible.
for key, mos in tqdm(sorted(unlabeled_mos_data.items())):
    wavpath = wav_dir / key
    outpath = out_dir / f'{wavpath.stem}.npy'

    # Pool the SSL features, bring them to host memory and cache them on disk.
    vec = extract_mean(wavpath).detach().cpu().numpy()
    np.save(outpath, vec)

    unlabeled_vecs.append(vec)
    unlabeled_moss.append(mos)

  0%|          | 0/540 [00:00<?, ?it/s]