In [1]:
# Mount Google Drive so the project directory used below is reachable from Colab.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [2]:
cd /content/drive/MyDrive/00-Scholar/05-wav2vec2gpt/

/content/drive/MyDrive/00-Scholar/05-wav2vec2gpt


In [3]:
!pip install datasets wandb transformers rouge_score

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.6.1-py3-none-any.whl (441 kB)
[K     |████████████████████████████████| 441 kB 4.6 MB/s 
[?25hCollecting wandb
  Downloading wandb-0.13.5-py2.py3-none-any.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 82.0 MB/s 
[?25hCollecting transformers
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[K     |████████████████████████████████| 5.5 MB 75.9 MB/s 
[?25hCollecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
Collecting huggingface-hub<1.0.0,>=0.2.0
  Downloading huggingface_hub-0.10.1-py3-none-any.whl (163 kB)
[K     |████████████████████████████████| 163 kB 84.0 MB/s 
Collecting dill<0.3.6
  Downloading dill-0.3.5.1-py2.py3-none-any.whl (95 kB)
[K     |████████████████████████████████| 95 kB 6.2 MB/s 
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting xxh

In [4]:
import os
import pickle
import random
from tqdm import tqdm
import numpy as np
import torch

from datasets import load_dataset, load_metric
import math
from itertools import groupby

import wandb

# Expose only the first GPU (ordered by PCI bus id). This only takes effect if
# CUDA has not yet been initialised in this kernel.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# os.environ["WANDB_DISABLED"] = "true"

# Model downloads, datasets and training checkpoints all live under one cache root.
cache_dir = "./caches/"
model_cache_dir = os.path.join(cache_dir, 'huggingface')
data_cache_dir = os.path.join(cache_dir, 'datasets')
checkpoint_dir = os.path.join(cache_dir, 'checkpoint')

# Seed every RNG from the single `seed` constant (previously `random` was
# seeded with a literal 0, so changing `seed` left it out of sync).
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

import IPython
import seaborn as sns

# Large default figures for the heatmaps drawn later.
sns.set_theme(rc={'figure.figsize': (16, 8)})

In [5]:
import sys

# Make the project's model code (src/) importable; guard so re-running the
# cell does not append the same path again.
src_dir = '/content/drive/MyDrive/00-Scholar/05-wav2vec2gpt/src'
if src_dir not in sys.path:
    sys.path.append(src_dir)

# Explicit imports instead of `from wav2vec2GPTwCTC import *`, so it is clear
# where each name used in this notebook comes from.
from torch import nn
from wav2vec2GPTwCTC import Wav2Vec2GPTModel, Wav2Vec2Model2
from configuration_wav2vec2gpt import Wav2Vec2GPTConfig

from transformers import Wav2Vec2FeatureExtractor, GPT2Model
from transformers import GPT2Tokenizer, AddedToken
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

In [6]:
# Pretrained checkpoints for the acoustic encoder and the text decoder.
wav2vec_pretrained = "facebook/wav2vec2-base"
# wav2vec_pretrained = "facebook/wav2vec2-base-960h"
gpt_pretrained = "gpt2"

# GPT-2's <|endoftext|> token serves as pad/unk/bos/eos. The pad id is used to
# compute the CTC loss, so the tokenizer and the model must agree on it; that is
# why the same `args` dict is passed to the config, feature extractor and tokenizer.
GPT2_SPECIAL_TOKEN = "<|endoftext|>"
GPT2_SPECIAL_TOKEN_ID = 50256

args = {}
for role in ('pad', 'unk', 'bos', 'eos'):
    args[f'{role}_token'] = GPT2_SPECIAL_TOKEN
    args[f'{role}_token_id'] = GPT2_SPECIAL_TOKEN_ID

# Adapter between wav2vec2 features and GPT-2 (per-layer kernel/stride/padding lists).
# NOTE(review): num_adapter_layers is 3 but each per-layer list has 4 entries;
# confirm how Wav2Vec2GPTConfig / the adapter consumes them.
args.update(
    add_adapter=True,
    output_hidden_size=128,
    num_adapter_layers=3,
    adapter_kernel_size=[4, 3, 3, 4],
    adapter_stride=[2, 2, 1, 1],
    adapter_padding=[0, 0, 0, 0],
    adapter_bias=True,
)

config = Wav2Vec2GPTConfig(**args)

In [7]:
# Both processors share the cache location and the special-token settings in `args`.
processor_kwargs = dict(cache_dir=model_cache_dir, **args)

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_pretrained, **processor_kwargs)
tokenizer = GPT2Tokenizer.from_pretrained(gpt_pretrained, **processor_kwargs)

In [8]:
# Pre-pickled dataset: a dict with 'text', 'normalized_text', 'audio_array', 'sample_rate', ...
# NOTE: pickle.load can execute arbitrary code, so only load files you produced yourself.
# data_path = './data/dataset-vctk-16k.pkl'
data_path = './data/dev-clean.pkl'
with open(data_path, 'rb') as f:
    dataset = pickle.load(f)

print(f"entire dataset length: {len(dataset['text'])}")

entire dataset length: 5736


In [9]:
# Spot-check one utterance: its transcript, normalised transcript and audio.
idx = 11
print(f"text example: {dataset['text'][idx]}")
print(f"normalized text example: {dataset['normalized_text'][idx]}")
# Play back at the dataset's own sample rate (the same value given to the
# feature extractor later) instead of a hard-coded 16 kHz.
IPython.display.Audio(dataset['audio_array'][idx], rate=dataset['sample_rate'])

text example: It cried aloud that eternity was very long, and like a great palace without a quiet room.
normalized text example: It cried aloud that eternity was very long, and like a great palace without a quiet room.


In [10]:
# Subsample: keep the first 20% of every per-utterance list ('sample_rate' is skipped).
# NOTE: this cell mutates `dataset` in place; re-run the loading cell before re-running it.
SUBSET_FRACTION = 0.2
for key, values in dataset.items():
    if key == 'sample_rate':
        continue
    del values[int(len(values) * SUBSET_FRACTION):]
dataset_size = len(dataset['text'])

# Longest raw waveform (in samples); every utterance is padded to this length below.
max_audio_length = max((len(wave) for wave in dataset['audio_array'][:dataset_size]), default=0)
print(f'maximum audio length: {max_audio_length}')

# Replace each raw waveform with the feature extractor's padded `input_values` tensor.
sampling_rate = dataset['sample_rate']
for idx in tqdm(range(dataset_size)):
    extracted = feature_extractor(dataset['audio_array'][idx],
                                  sampling_rate=sampling_rate,
                                  return_tensors="pt",
                                  padding='max_length',
                                  max_length=max_audio_length)
    dataset['audio_array'][idx] = extracted.input_values[0]

maximum audio length: 439360


100%|██████████| 1147/1147 [00:03<00:00, 311.86it/s]


In [11]:
# Tokenise every transcript at once, padding to the longest one. The resulting
# encoding (input_ids + attention_mask tensors) replaces the raw text list.
# NOTE: not idempotent; re-run the loading/subsampling cells before re-running this.
dataset['text'] = tokenizer(dataset['text'][:dataset_size],
                            return_tensors="pt",
                            padding='longest'
                            # padding='max_length',
                            # max_length=max_text_length,
                            )
# Padded label length = number of columns of the attention mask
# (previously `attention_mask[1]`, which is row 1 of the mask, not a length).
max_text_length = dataset['text']['attention_mask'].shape[1]  # VCTK: 42, dev-clean: 92

print(dataset['text']['attention_mask'].shape)

torch.Size([1147, 92])


In [12]:
# Shuffle utterance indices, then cut them 80/10/10 into train/val/test.
split_ratio = (0.8, 0.9)
train_end, val_end = (int(dataset_size * ratio) for ratio in split_ratio)
indices = np.arange(dataset_size)
np.random.shuffle(indices)
train_idx, val_idx, test_idx = np.split(indices, [train_end, val_end])

In [13]:
class CustomDataset(torch.utils.data.Dataset):
    """Index-remapped view over shared audio features and tokenised transcripts.

    Args:
        input_values: indexable collection of per-utterance audio feature tensors.
        tokenized_output: mapping with 'input_ids' and 'attention_mask', each
            indexable by utterance position.
        indices: utterance positions belonging to this split; item `i` of the
            dataset is utterance `indices[i]`.
    """

    def __init__(self, input_values, tokenized_output, indices):
        self.input_values = input_values
        self.tokenized_output = tokenized_output
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        row = self.indices[idx]
        return {
            'input_values': self.input_values[row],
            'labels': self.tokenized_output['input_ids'][row],
            'output_attention_mask': self.tokenized_output['attention_mask'][row],
        }

    

# One Dataset view per split; all three share the same preprocessed tensors.
train_dataset, val_dataset, test_dataset = (
    CustomDataset(dataset['audio_array'], dataset['text'], split_indices)
    for split_indices in (train_idx, val_idx, test_idx)
)

In [14]:
model = Wav2Vec2GPTModel(config=config)

# Swap the freshly initialised encoder and decoder for pretrained weights;
# both are loaded with the shared Wav2Vec2GPTConfig.
model.wav2vec2 = Wav2Vec2Model2.from_pretrained(
    wav2vec_pretrained, cache_dir=model_cache_dir, config=config
)
model.transformer = GPT2Model.from_pretrained(
    gpt_pretrained, cache_dir=model_cache_dir, config=config
)

# Tie the output projection to GPT-2's token-embedding matrix (same Parameter object).
model.lm_head.weight = model.transformer.wte.weight

print(model)

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model2: ['wav2vec2.encoder.layers.5.attention.k_proj.weight', 'wav2vec2.encoder.layers.6.layer_norm.bias', 'wav2vec2.encoder.layers.9.attention.v_proj.bias', 'wav2vec2.encoder.layers.5.attention.out_proj.bias', 'wav2vec2.encoder.layers.1.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.8.final_layer_norm.bias', 'wav2vec2.encoder.layers.11.attention.v_proj.bias', 'wav2vec2.encoder.layers.2.attention.q_proj.bias', 'wav2vec2.encoder.layers.2.attention.v_proj.weight', 'wav2vec2.encoder.layers.10.attention.v_proj.bias', 'wav2vec2.encoder.layers.3.attention.k_proj.weight', 'wav2vec2.encoder.layers.8.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.11.attention.k_proj.bias', 'wav2vec2.encoder.layers.4.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.4.final_layer_norm.weight', 'wav2vec2.encoder.layers.1.layer_norm.bias', 'wav2vec2.encoder.laye

Wav2Vec2GPTModel(
  (wav2vec2): Wav2Vec2Model2(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (2): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (3): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (4): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=Fal

In [15]:
# print(model.wav2vec2.feature_extractor.conv_layers[0].conv.weight)
# print()
# print(model.transformer.wpe.weight)
# print()
# print(model.transformer.wte.weight)
# print()
# print(model.lm_head.weight)

In [16]:
# Freeze the wav2vec2 CNN front-end, the GPT-2 decoder and the LM head.
# The adapter / RNN-compressor freezes are left out (available as
# model.freeze_wav2vec_adapter() / model.freeze_rnn_compressor()), presumably so
# those parts are what gets trained. This model has no feature-projection or
# wav2vec encoder freeze methods.
for freeze in (model.freeze_feature_extractor,
               model.freeze_gpt_decoder,
               model.freeze_lm_head):
    freeze()

In [17]:
# load rouge for validation
rouge = load_metric("rouge")


def compute_metrics(pred):
    """Score greedy CTC decodes of the model output against the references with ROUGE-2.

    Args:
        pred: EvalPrediction from the Trainer. ``pred.predictions[1]`` is assumed
            to be logits whose last axis is the vocabulary (second element of the
            model's logits tuple); ``pred.label_ids`` holds reference token ids.

    Returns:
        dict with ROUGE-2 precision / recall / F-measure (``mid`` aggregate),
        each rounded to 4 decimal places.
    """
    # Work on a copy so the caller's array is not modified.
    labels_ids = np.array(pred.label_ids, copy=True)
    # The Trainer pads with -100 when it concatenates label batches of different
    # lengths; -100 is not a vocabulary id and would break batch_decode.
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id

    # Greedy CTC decoding: argmax per position, then merge consecutive duplicates.
    # The blank/padding id is <|endoftext|>, a special token dropped by
    # skip_special_tokens below.
    pred_ids = pred.predictions[1].argmax(axis=-1)
    pred_ids = [[int(key) for key, _group in groupby(row)] for row in pred_ids]

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])['rouge2'].mid

    print(rouge_output)

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

  


Downloading builder script:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

In [18]:
batch_size = 4
steps_per_epoch = math.ceil(len(train_dataset) / batch_size)
# Log and evaluate every half epoch. Clamped to >= 1 because, with a tiny train
# split, steps_per_epoch // 2 could be 0, and TrainingArguments rejects
# logging_steps=0 for the 'steps' logging strategy.
half_epoch_steps = max(1, steps_per_epoch // 2)

# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=1e-4,
    weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0,
    num_train_epochs=100,
    max_steps=-1,
    lr_scheduler_type="cosine",
    warmup_steps=steps_per_epoch * 10,  # warm up over the first 10 epochs

    logging_strategy='steps',
    evaluation_strategy='steps',
    # Move accumulated eval outputs off the GPU after every eval step.
    eval_accumulation_steps=1,
    logging_steps=half_epoch_steps,
    eval_steps=half_epoch_steps,
    save_steps=steps_per_epoch,  # checkpoint once per epoch

    output_dir=os.path.join(checkpoint_dir, "wav2vec2gpt/unfreeze-adapter-rnn"),
)

In [19]:
# The trainer runs optimisation, periodic evaluation on the validation split
# (scored by compute_metrics) and checkpointing, as configured in training_args.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
# Authenticate with `wandb login` or the WANDB_API_KEY environment variable;
# never store API keys in the notebook.
wandb.init(project="testing-wav2vec2gpt", entity="yoom618")
torch.cuda.empty_cache()

# start training
# NOTE(review): `select_random` is an attribute consumed by the custom model
# code; its semantics are not visible in this notebook.
# model.select_random = True
model.select_random = False
trainer.train()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 

In [None]:
# Mark the W&B run opened before training as finished.
wandb.finish()

In [None]:
# model.load_state_dict(torch.load('/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-8814/pytorch_model.bin'))

##### example: run one training batch through the model and compare each
##### reference transcript with greedy CTC decodes of both logit outputs.

BATCH_SIZE = 4
i = 3
device = 'cuda:0'
batch_idx = train_idx[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]


def ctc_collapse(token_ids):
    """Merge consecutive duplicate ids of a 1-D id tensor (greedy CTC decoding).

    Blank/padding ids are kept; the tokenizer renders them as <|endoftext|>.
    """
    return [int(key) for key, _group in groupby(token_ids.tolist())]


audio_feature_batch = torch.stack([dataset['audio_array'][idx] for idx in batch_idx])
label_batch = dataset['text']['input_ids'][batch_idx]
attention_batch = dataset['text']['attention_mask'][batch_idx]

# Inference only: make sure the model is on `device` (a no-op if the Trainer
# already moved it) and in eval mode, since it may still be in train mode
# with dropout active.
model.to(device)
model.eval()
with torch.no_grad():
    # Full model output; `.logits` is indexed as a pair below.
    audio_embedding = model(input_values=audio_feature_batch.to(device),
                            labels=label_batch.to(device),
                            output_attention_mask=attention_batch.to(device),)
print(audio_embedding.logits[0].shape)

pred_ids_0 = torch.argmax(audio_embedding.logits[0], axis=-1)
pred_ids_1 = torch.argmax(audio_embedding.logits[1], axis=-1)

print(pred_ids_0.size(), pred_ids_1.size())
print()

for idx in range(BATCH_SIZE):
    # Reference: decode only the attended (non-padding) tokens. Labels must not be
    # CTC-collapsed, or genuinely repeated tokens would be merged.
    print(tokenizer.decode(label_batch[idx][attention_batch[idx].bool()]))
    print(tokenizer.decode(ctc_collapse(pred_ids_0[idx])))
    print(tokenizer.decode(ctc_collapse(pred_ids_1[idx])))
    print()

In [None]:
idx = 1

# Per-position token distribution of the wav2vec2 -> rnn_compressor -> lm_head path
# for one utterance of the example batch built in the previous cell.
model.eval()
with torch.no_grad():
    # rnn_compressor returns a pair; only the first element (outputs) is used.
    hidden_states, _ = model.rnn_compressor(
        model.wav2vec2(audio_feature_batch.to(device)).last_hidden_state
    )
    # Softmax over the vocabulary axis (dim 2); keep the first 300 positions x 1200 token ids.
    # Slice before .cpu() so only the plotted window is copied off the GPU.
    token_probs = torch.nn.functional.softmax(model.lm_head(hidden_states), dim=2)[idx, :300, :1200].cpu().numpy()

ax = sns.heatmap(token_probs)
ax.set(xlabel='token id', ylabel='output position', title=f'Token probabilities (batch item {idx})')

IPython.display.Audio(dataset['audio_array'][batch_idx[idx]], rate=dataset['sample_rate'])