In [1]:
import os
import pickle
import random
from tqdm import tqdm
import numpy as np
import torch

from datasets import load_dataset, load_metric
import math
from itertools import groupby

import wandb

# Enumerate GPUs by PCI bus id and expose physical GPU 1 first, so that
# 'cuda:0' inside this process refers to physical GPU 1 (and 'cuda:1' to GPU 0).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1,0"
# Intended to disable Weights & Biases logging for this notebook.
os.environ["WANDB_DISABLED"] = "true"

# Single cache root: HF model downloads, datasets and checkpoints live under it.
cache_dir = "/data4/yoomcache"
model_cache_dir = os.path.join(cache_dir, 'huggingface')
data_cache_dir = os.path.join(cache_dir, 'datasets')
checkpoint_dir = os.path.join(cache_dir, 'checkpoint')

# Seed every RNG from the one `seed` constant so changing it reseeds all of them.
seed = 0
random.seed(seed)
np.random.seed(seed)
# Trailing ';' suppresses the returned Generator repr in the cell output.
torch.manual_seed(seed);


<torch._C.Generator at 0x7f86bc0a7bf0>

In [2]:
# %reload_ext autoreload
# %autoreload 2

from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model

In [3]:
# Hub id of the pretrained wav2vec 2.0 checkpoint shared by the feature
# extractor and the encoder loaded below.
wav2vec_pretrained: str = "facebook/wav2vec2-base"

In [4]:
# Audio preprocessor for the same checkpoint as the encoder; downloaded files
# are cached under model_cache_dir.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
    wav2vec_pretrained,
    cache_dir=model_cache_dir,
)

In [5]:
# Load the preprocessed VCTK corpus. The pickle is a local, trusted artifact;
# never point this at an untrusted file (pickle.load can execute code).
with open('/data4/TTS/VCTK-Corpus/dataset-vctk-16k.pkl', 'rb') as f:
    dataset = pickle.load(f)
# Drop fields not needed for feature extraction.
del dataset['page'], dataset['index'], dataset['audio_path']

# Fraction of utterances kept for quick experiments; set to 1.0 for a full run.
SUBSET_FRACTION = 0.01
full_size = len(dataset['text'])
dataset_size = int(full_size * SUBSET_FRACTION)

# Pad target is the longest clip in the FULL corpus (not just the subset), so
# tensor shapes match a full run -- at the cost of extra padding/memory.
max_audio_length = max((len(arr) for arr in dataset['audio_array']), default=0)
print(max_audio_length)

for idx in tqdm(range(dataset_size)):
    # NOTE(review): dataset['sample_rate'] is assumed to be a single scalar rate
    # shared by all clips -- confirm against how the pickle was built.
    dataset['audio_array'][idx] = feature_extractor(
        dataset['audio_array'][idx],
        sampling_rate=dataset['sample_rate'],
        return_tensors="pt",
        padding='max_length',
        max_length=max_audio_length,
    ).input_values[0]

# Truncate audio_array AND every other per-utterance list of the same original
# length (e.g. 'text'), so index i still refers to the same utterance in each.
del dataset['audio_array'][dataset_size:]
for field in dataset.values():
    if isinstance(field, list) and len(field) == full_size:
        del field[dataset_size:]
print(len(dataset['audio_array']))

308533


100%|████████████████████████████████████████| 440/440 [00:03<00:00, 144.42it/s]

440





In [6]:
# Bare wav2vec2 encoder. The checkpoint's quantizer / projection weights have no
# slot in Wav2Vec2Model, so the "weights ... were not used" warning is expected.
model_wav2vec = Wav2Vec2Model.from_pretrained(
    wav2vec_pretrained,
    cache_dir=model_cache_dir,
)

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model: ['project_hid.weight', 'quantizer.codevectors', 'project_q.bias', 'project_q.weight', 'quantizer.weight_proj.weight', 'project_hid.bias', 'quantizer.weight_proj.bias']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
##### example: embed one batch with wav2vec2 as a smoke test

BATCH_SIZE = 4
i = 4  # batch number to embed (covers samples i*BATCH_SIZE .. i*BATCH_SIZE+BATCH_SIZE-1)
# Choose the compute device once; set to 'cuda:0' to run on GPU.
device = 'cpu'
batch_idx = range(i*BATCH_SIZE, i*BATCH_SIZE+BATCH_SIZE)

# Entries were padded to a common length by the feature extractor, so they
# stack into a (BATCH_SIZE, max_audio_length) tensor.
audio_feature_batch = torch.stack([dataset['audio_array'][idx] for idx in batch_idx])
print(audio_feature_batch.size())

# Model and inputs must live on the same device; moving the model is a no-op on CPU.
model_wav2vec.to(device)
with torch.no_grad():
    audio_embedding = model_wav2vec(input_values=audio_feature_batch.to(device))

print(audio_embedding.last_hidden_state.shape)


torch.Size([4, 308533])
torch.Size([4, 963, 768])
