In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

cache_dir = "/data4/yoomcache"
model_cache_dir = os.path.join(cache_dir, 'huggingface')
data_cache_dir = os.path.join(cache_dir, 'datasets')
checkpoint_dir = os.path.join(cache_dir, 'checkpoint')

import torch
from datasets import load_dataset, load_metric
import math
from itertools import groupby

import wandb
wandb.init(project="testing-wav2vec2gpt", entity="yoom-private")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myoom-private[0m (use `wandb login --relogin` to force relogin)


In [2]:
# %reload_ext autoreload
# %autoreload 2
from wav2vec2GPTwCTC import *
from configuration_wav2vec2gpt import Wav2Vec2GPTConfig

from transformers import Wav2Vec2FeatureExtractor
from transformers import GPT2Tokenizer, AddedToken
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

In [3]:
wav2vec_pretrained = "facebook/wav2vec2-base"
gpt_pretrained = "gpt2"

# Should aware that pad_token_id is used to compute CTC loss, 
# so pad_token configuration for both tokenizer and model should be the same
args = {
    'pad_token': 'Ġ', 'pad_token_id': 220,
    'unk_token': 'Ġ', 'unk_token_id': 220,
    # 'pad_token': "<|endoftext|>", 'pad_token_id': 50256,
    # 'unk_token': "<|endoftext|>", 'unk_token_id': 50256,
    'bos_token': "<|endoftext|>", 'bos_token_id': 50256,
    'eos_token': "<|endoftext|>", 'eos_token_id': 50256,
    
    'n_positions': 128,
    
    'add_adapter': True,
    'adapter_kernel_size': 3, 
    'adapter_stride': 2,
    'num_adapter_layers': 3,
}


config = Wav2Vec2GPTConfig(**args)

In [4]:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_pretrained, 
                                                             cache_dir=model_cache_dir,
                                                             **args)

tokenizer = GPT2Tokenizer.from_pretrained(gpt_pretrained,
                                          cache_dir=model_cache_dir,
                                          **args)

In [5]:
model = Wav2Vec2GPTModel(config=config)

model.wav2vec2.from_pretrained(wav2vec_pretrained, cache_dir=model_cache_dir)
model.transformer.from_pretrained(gpt_pretrained, cache_dir=model_cache_dir)


# device_map = {
#     0: [0, 1, 2, 3, 4,],
#     2: [5, 6, 7, 8, 9, 10, 11, ],
# }
# model.gpt2lm.parallelize(device_map)


model.freeze_feature_extractor()
model.freeze_feature_projection()
# model.freeze_wav2vec_encoder() # not exists here
model.unfreeze_wav2vec_adapter()
model.unfreeze_rnn_compressor()
model.freeze_gpt_decoder()
model.unfreeze_lm_head()

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model2: ['wav2vec2.encoder.layers.10.attention.k_proj.bias', 'wav2vec2.encoder.layers.5.attention.k_proj.bias', 'wav2vec2.encoder.layers.1.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.1.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.10.final_layer_norm.weight', 'wav2vec2.encoder.layers.0.final_layer_norm.weight', 'wav2vec2.encoder.layers.7.final_layer_norm.weight', 'wav2vec2.encoder.layers.8.attention.k_proj.bias', 'wav2vec2.encoder.layers.11.attention.v_proj.weight', 'project_hid.weight', 'wav2vec2.encoder.layers.9.attention.out_proj.weight', 'wav2vec2.encoder.layers.6.attention.q_proj.weight', 'wav2vec2.encoder.layers.8.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layer_norm.weight', 'wav2vec2.encoder.layers.10.final_layer_norm.bias', 'wav2vec2.encoder.layers.9.final_layer_norm.weight', 'wav2vec2.encoder.layers.11.final_layer_norm.bias

In [6]:
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", 
                       split="validation", 
                       cache_dir=data_cache_dir
                      )

dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate
audio_inputs = [d["audio"]["array"] for d in dataset]

print(dataset, sampling_rate, len(audio_inputs))

Reusing dataset librispeech_asr (/data4/yoomcache/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)
Loading cached sorted indices for dataset at /data4/yoomcache/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-2f7c0cbee6ef3aa1.arrow


Dataset({
    features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id'],
    num_rows: 73
}) 16000 73


In [7]:
# text_inputs = dataset["text"]
from example.librispeech_asr_demo import text_inputs

In [8]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, input_values, tokenized_output):
        self.input_values = input_values
        self.tokenized_output = tokenized_output

    def __getitem__(self, idx):
        item = dict()
        item['input_values'] = self.input_values['input_values'][idx]
        item['labels'] = self.tokenized_output['input_ids'][idx]
        item['output_attention_mask'] = self.tokenized_output['attention_mask'][idx]
        return item

    def __len__(self):
        return len(self.input_values['input_values'])

    
input_values = feature_extractor(audio_inputs, 
                                      sampling_rate=sampling_rate,
                                      return_tensors="pt",
                                      padding='longest',
                                     )

tokenized_output = tokenizer(text_inputs,
                             return_tensors="pt",
                             padding='longest',
                         )

train_dataset = CustomDataset(input_values, tokenized_output)
# val_dataset = CustomDataset(input_values, tokenized_output)
# test_dataset = CustomDataset(input_values, tokenized_output)

In [9]:
# # load rouge for validation
# rouge = load_metric("rouge")

# def compute_metrics(pred):
#     labels_ids = pred.label_ids
#     pred_ids = pred.predictions

#     # all unnecessary tokens are removed
#     pred_str = decoder_tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
#     labels_ids[labels_ids == -100] = decoder_tokenizer.eos_token_id
#     label_str = decoder_tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

#     rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

#     return {
#         "rouge2_precision": round(rouge_output.precision, 4),
#         "rouge2_recall": round(rouge_output.recall, 4),
#         "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
#     }

In [10]:
batch_size = 3
steps_per_epoch = math.ceil(len(train_dataset) / batch_size)


# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
#     predict_with_generate=True,
    output_dir=os.path.join(checkpoint_dir, "wav2vec2gpt/unfreeze-rnn"),
    # do_train=True,
    # do_eval=False,
    # do_predict=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size * 5,
    learning_rate=1e-4, 
    weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0,
    num_train_epochs=200,
    max_steps=-1,
    # lr_scheduler_type='linear', warmup_ratio=0.0, 
    
    logging_strategy='steps',
    save_strategy='steps',
    evaluation_strategy='steps',
    logging_steps=1 * steps_per_epoch,
    save_steps=2 * steps_per_epoch,
    eval_steps=1 * steps_per_epoch,
    warmup_steps=10 * steps_per_epoch,
    save_total_limit=10,
    overwrite_output_dir=True,
)

In [None]:
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
#     compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=train_dataset,
)


# start training
trainer.train()

***** Running training *****
  Num examples = 73
  Num Epochs = 200
  Instantaneous batch size per device = 3
  Total train batch size (w. parallel, distributed & accumulation) = 3
  Gradient Accumulation steps = 1
  Total optimization steps = 5000
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Step,Training Loss,Validation Loss
25,3778.2278,18605.064453
50,3702.2463,17605.548828
75,3548.1747,15933.613281
100,3340.7556,14062.597656
125,2993.7131,11564.162109
150,2479.2923,8303.015625
175,1736.3438,4069.257812
200,800.6461,2689.722168
225,499.2102,2068.931885
250,517.0004,1997.456543


***** Running Evaluation *****
  Num examples = 73
  Batch size = 15
***** Running Evaluation *****
  Num examples = 73
  Batch size = 15
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-50
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-50/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-50/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4550] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 15
***** Running Evaluation *****
  Num examples = 73
  Batch size = 15
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-100
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-100/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-100/pytor

Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-750/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-750/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-250] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 15
***** Running Evaluation *****
  Num examples = 73
  Batch size = 15
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-800
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-800/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-800/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-300] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 15
***** Running Evaluation

In [None]:
wandb.finish()

In [None]:
# example


BATCH_SIZE = 8
i = 3
device = 'cuda:0'


audio_batch = audio_inputs[i*BATCH_SIZE:i*BATCH_SIZE+BATCH_SIZE]
audio_feature_batch = feature_extractor(audio_batch, 
                                      sampling_rate=sampling_rate,
                                      return_tensors="pt",
                                      padding='longest',
                                     ).input_values
print(audio_feature_batch.size())


text_batch = text_inputs[i*BATCH_SIZE:i*BATCH_SIZE+BATCH_SIZE]

text_tokens_batch = tokenizer(text_batch, 
                              return_tensors="pt",
                              padding='max_length',
                              max_length=train_dataset.tokenized_output['input_ids'].shape[1]
                             )
print(text_tokens_batch['attention_mask'].size())

with torch.no_grad():
    audio_embedding = model(input_values=audio_feature_batch.to(device), 
                            labels=text_tokens_batch['input_ids'].to(device),
                            output_attention_mask=text_tokens_batch['attention_mask'].to(device),)
print(audio_embedding.logits.shape)

pred_ids = torch.argmax(audio_embedding.logits, axis=-1)
print(pred_ids.size())
print()

for idx in range(BATCH_SIZE):
    print(text_batch[idx])
    print(tokenizer.decode([key for key, _group in groupby(pred_ids[idx])]))
    print()

In [None]:
import IPython

IPython.display.Audio(dataset[listen_idx-6]['audio']['path'])

In [None]:
tokenizer.pad_token_id = 220
