In [1]:
import os
import pickle
import random
from tqdm import tqdm
import numpy as np
import torch

from datasets import load_dataset, load_metric
import math
from itertools import groupby

import wandb

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

cache_dir = "/data4/yoomcache"
model_cache_dir = os.path.join(cache_dir, 'huggingface')
data_cache_dir = os.path.join(cache_dir, 'datasets')
checkpoint_dir = os.path.join(cache_dir, 'checkpoint')

seed = 0
random.seed(0)
np.random.seed(seed)
torch.manual_seed(seed)

wandb.init(project="testing-wav2vec2gpt", entity="yoom-private")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myoom-private[0m (use `wandb login --relogin` to force relogin)


In [2]:
# %reload_ext autoreload
# %autoreload 2
from wav2vec2GPTwCTC import *
from configuration_wav2vec2gpt import Wav2Vec2GPTConfig

from transformers import Wav2Vec2FeatureExtractor
from transformers import GPT2Tokenizer, AddedToken
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

In [3]:
wav2vec_pretrained = "facebook/wav2vec2-base"
gpt_pretrained = "gpt2"

# Should aware that pad_token_id is used to compute CTC loss, 
# so pad_token configuration for both tokenizer and model should be the same
args = {
#     'pad_token': 'Ġ', 'pad_token_id': 220,
#     'unk_token': 'Ġ', 'unk_token_id': 220,
    'pad_token': "<|endoftext|>", 'pad_token_id': 50256,
    'unk_token': "<|endoftext|>", 'unk_token_id': 50256,
    'bos_token': "<|endoftext|>", 'bos_token_id': 50256,
    'eos_token': "<|endoftext|>", 'eos_token_id': 50256,
    
    'n_positions': 128, # VCTK: 42, 
    
    'add_adapter': True,
    'adapter_kernel_size': 6, 
    'adapter_stride': 2,
    'num_adapter_layers': 3,
}


config = Wav2Vec2GPTConfig(**args)

In [4]:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_pretrained, 
                                                             cache_dir=model_cache_dir,
                                                             **args)

tokenizer = GPT2Tokenizer.from_pretrained(gpt_pretrained,
                                          cache_dir=model_cache_dir,
                                          **args)

In [5]:
if not os.path.exists('/data4/TTS/VCTK-Corpus/dataset-vctk-16k(preprocessed).pkl'):
    with open('/data4/TTS/VCTK-Corpus/dataset-vctk-16k.pkl', 'rb') as f:
        dataset = pickle.load(f)
    del dataset['page'], dataset['index'], dataset['audio_path']


    max_audio_length = 0
    for arr in dataset['audio_array']:
        if len(arr) > max_audio_length:
            max_audio_length = len(arr)
    print(max_audio_length)


    for idx in tqdm(range(len(dataset['audio_array']))):
        dataset['audio_array'][idx] = feature_extractor(dataset['audio_array'][idx], 
                                                        sampling_rate=dataset['sample_rate'],
                                                        return_tensors="pt",
                                                        padding='max_length',
                                                        max_length=max_audio_length
                                                        ).input_values[0]
    dataset['audio_array'] = torch.stack(dataset['audio_array'])
    print(dataset['audio_array'].shape)


    with open('/data4/TTS/VCTK-Corpus/dataset-vctk-16k(preprocessed).pkl', 'wb') as f:
        pickle.dump(dataset, f)
        
        
else:
    with open('/data4/TTS/VCTK-Corpus/dataset-vctk-16k(preprocessed).pkl', 'rb') as f:
        dataset = pickle.load(f)
    print(dataset['audio_array'].shape)

torch.Size([44070, 308533])


In [6]:
dataset['text'] = tokenizer(dataset['text'],
                            return_tensors="pt",
                            # padding='longest', # VCTK: 42,
                            padding='max_length',
                            max_length=args['n_positions']
                            )
print(dataset['text']['attention_mask'].shape)

torch.Size([44070, 128])


In [7]:
split_ratio = (0.8, 0.9)
dataset_size = dataset['text']['attention_mask'].shape[0]
indices = np.arange(dataset_size)
np.random.shuffle(indices)

train_idx = indices[:int(dataset_size * split_ratio[0])]
val_idx = indices[int(dataset_size * split_ratio[0]):int(dataset_size * split_ratio[1])]
test_idx = indices[int(dataset_size * split_ratio[1]):]

In [8]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, input_values, tokenized_output, indices):
        self.input_values = input_values
        self.tokenized_output = tokenized_output
        self.indices = indices

    def __getitem__(self, idx):
        item = dict()
        item['input_values'] = self.input_values[self.indices[idx]]
        item['labels'] = self.tokenized_output['input_ids'][self.indices[idx]]
        item['output_attention_mask'] = self.tokenized_output['attention_mask'][self.indices[idx]]
        return item

    def __len__(self):
        return len(self.indices)

    

train_dataset = CustomDataset(dataset['audio_array'], dataset['text'], train_idx)
val_dataset = CustomDataset(dataset['audio_array'], dataset['text'], val_idx)
test_dataset = CustomDataset(dataset['audio_array'], dataset['text'], test_idx)

In [9]:
model = Wav2Vec2GPTModel(config=config)

model.wav2vec2.from_pretrained(wav2vec_pretrained, cache_dir=model_cache_dir)
model.transformer.from_pretrained(gpt_pretrained, cache_dir=model_cache_dir)


# device_map = {
#     0: [0, 1, 2, 3, 4,],
#     2: [5, 6, 7, 8, 9, 10, 11, ],
# }
# model.gpt2lm.parallelize(device_map)


model.freeze_feature_extractor()
model.freeze_feature_projection()
# model.freeze_wav2vec_encoder() # not exists here
model.unfreeze_wav2vec_adapter()
model.unfreeze_rnn_compressor()
model.freeze_gpt_decoder()
model.unfreeze_lm_head()

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model2: ['wav2vec2.encoder.layers.9.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.9.attention.q_proj.weight', 'wav2vec2.encoder.layers.7.attention.out_proj.weight', 'wav2vec2.encoder.layers.0.final_layer_norm.weight', 'wav2vec2.encoder.layers.6.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.8.attention.out_proj.weight', 'wav2vec2.encoder.layers.11.layer_norm.bias', 'wav2vec2.encoder.layers.11.layer_norm.weight', 'wav2vec2.encoder.layers.0.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.7.attention.k_proj.bias', 'wav2vec2.encoder.layers.0.layer_norm.weight', 'wav2vec2.encoder.layers.4.attention.q_proj.bias', 'wav2vec2.encoder.layers.4.final_layer_norm.bias', 'wav2vec2.encoder.layers.6.attention.q_proj.weight', 'wav2vec2.encoder.layers.9.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.6.attention.k_proj.bias', 'wav2vec2.encoder.layers.6

In [10]:
# # load rouge for validation
# rouge = load_metric("rouge")

# def compute_metrics(pred):
#     labels_ids = pred.label_ids
#     pred_ids = pred.predictions

#     # all unnecessary tokens are removed
#     pred_str = decoder_tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
#     labels_ids[labels_ids == -100] = decoder_tokenizer.eos_token_id
#     label_str = decoder_tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

#     rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

#     return {
#         "rouge2_precision": round(rouge_output.precision, 4),
#         "rouge2_recall": round(rouge_output.recall, 4),
#         "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
#     }

In [11]:
batch_size = 24
steps_per_epoch = math.ceil(len(train_dataset) / batch_size)


# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
#     predict_with_generate=True,
    output_dir=os.path.join(checkpoint_dir, "wav2vec2gpt/unfreeze-adapter-rnn-lm"),
    # do_train=True,
    # do_eval=False,
    # do_predict=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=1e-4, 
    weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0,
    num_train_epochs=100,
    max_steps=-1,
    lr_scheduler_type='cosine', 
    # warmup_ratio=0.0, 
    
    logging_strategy='steps',
    save_strategy='steps',
    evaluation_strategy='steps',
    logging_steps=1 * steps_per_epoch,
    save_steps=2 * steps_per_epoch,
    eval_steps=1 * steps_per_epoch,
    warmup_steps=10 * steps_per_epoch,
    save_total_limit=10,
    overwrite_output_dir=True,
)

In [None]:
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
#     compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)


# start training
trainer.train()

***** Running training *****
  Num examples = 35256
  Num Epochs = 100
  Instantaneous batch size per device = 24
  Total train batch size (w. parallel, distributed & accumulation) = 48
  Gradient Accumulation steps = 1
  Total optimization steps = 73500
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
  result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,


Step,Training Loss,Validation Loss
1469,3.3272,0.48041
2938,0.4639,0.459413


***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-2938
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-2938/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-2938/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-82] due to args.save_total_limit
  result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,


In [None]:
wandb.finish()

In [None]:
##### example


BATCH_SIZE = 16
i = 3
device = 'cuda:0'
batch_idx = test_idx[i*BATCH_SIZE:i*BATCH_SIZE+BATCH_SIZE]

audio_feature_batch = dataset['audio_array'][batch_idx]
print(audio_feature_batch.size())

label_batch = dataset['text']['input_ids'][batch_idx]
attention_batch = dataset['text']['attention_mask'][batch_idx]

print(label_batch.size())

with torch.no_grad():
    audio_embedding = model(input_values=audio_feature_batch.to(device), 
                            labels=label_batch.to(device),
                            output_attention_mask=attention_batch.to(device),)
print(audio_embedding.logits.shape)

pred_ids = torch.argmax(audio_embedding.logits, axis=-1)
print(pred_ids.size())
print()

for idx in range(BATCH_SIZE):
    print(tokenizer.decode(label_batch[idx]).replace('<|endoftext|>',''))
    print(tokenizer.decode([key for key, _group in groupby(pred_ids[idx])]))
    print(tokenizer.decode([key for key, _group in groupby(pred_ids[idx])]).replace('<|endoftext|>',''))
    print()

In [None]:
# import IPython

# IPython.display.Audio(dataset[4]['audio']['path'])

In [None]:
model.wav2vec2