In [1]:
import os
import pickle
import random
from tqdm import tqdm
import numpy as np
import torch

from datasets import load_dataset, load_metric
import math
from itertools import groupby

import wandb

# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# device = 'cuda:0'
device = 'cpu'

cache_dir = "/data4/yoomcache"
model_cache_dir = os.path.join(cache_dir, 'huggingface')
data_cache_dir = os.path.join(cache_dir, 'datasets')
checkpoint_dir = os.path.join(cache_dir, 'checkpoint')

seed = 0
random.seed(0)
np.random.seed(seed)
torch.manual_seed(seed)

wandb.init(project="testing-wav2vec2gpt", entity="yoom-private")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myoom-private[0m (use `wandb login --relogin` to force relogin)


In [2]:
# %reload_ext autoreload
# %autoreload 2
from wav2vec2GPTwCTC import *
from configuration_wav2vec2gpt import Wav2Vec2GPTConfig

from transformers import Wav2Vec2FeatureExtractor
from transformers import GPT2Tokenizer, AddedToken
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

In [3]:
wav2vec_pretrained = "facebook/wav2vec2-base"
gpt_pretrained = "gpt2"

# Should aware that pad_token_id is used to compute CTC loss, 
# so pad_token configuration for both tokenizer and model should be the same
args = {
#     'pad_token': 'Ġ', 'pad_token_id': 220,
#     'unk_token': 'Ġ', 'unk_token_id': 220,
    'pad_token': "<|endoftext|>", 'pad_token_id': 50256,
    'unk_token': "<|endoftext|>", 'unk_token_id': 50256,
    'bos_token': "<|endoftext|>", 'bos_token_id': 50256,
    'eos_token': "<|endoftext|>", 'eos_token_id': 50256,
    
    'n_positions': 64, # VCTK: 42, 
    
    'add_adapter': True,
    'adapter_kernel_size': 6, 
    'adapter_stride': 2,
    'num_adapter_layers': 3,
}


config = Wav2Vec2GPTConfig(**args)

In [4]:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_pretrained, 
                                                             cache_dir=model_cache_dir,
                                                             **args)

tokenizer = GPT2Tokenizer.from_pretrained(gpt_pretrained,
                                          cache_dir=model_cache_dir,
                                          **args)

In [5]:
# if not os.path.exists('/data4/TTS/VCTK-Corpus/dataset-vctk-16k(preprocessed).pkl'):
#     with open('/data4/TTS/VCTK-Corpus/dataset-vctk-16k.pkl', 'rb') as f:
#         dataset = pickle.load(f)
#     del dataset['page'], dataset['index'], dataset['audio_path']


#     max_audio_length = 0
#     for arr in dataset['audio_array']:
#         if len(arr) > max_audio_length:
#             max_audio_length = len(arr)
#     print(max_audio_length)


#     for idx in tqdm(range(len(dataset['audio_array']))):
#         dataset['audio_array'][idx] = feature_extractor(dataset['audio_array'][idx], 
#                                                         sampling_rate=dataset['sample_rate'],
#                                                         return_tensors="pt",
#                                                         padding='max_length',
#                                                         max_length=max_audio_length
#                                                         ).input_values[0]
#     dataset['audio_array'] = torch.stack(dataset['audio_array'])
#     print(dataset['audio_array'].shape)


#     with open('/data4/TTS/VCTK-Corpus/dataset-vctk-16k(preprocessed).pkl', 'wb') as f:
#         pickle.dump(dataset, f)
        
        
# else:
#     with open('/data4/TTS/VCTK-Corpus/dataset-vctk-16k(preprocessed).pkl', 'rb') as f:
#         dataset = pickle.load(f)
#     print(dataset['audio_array'].shape)

In [6]:
with open('/data4/TTS/VCTK-Corpus/dataset-vctk-16k.pkl', 'rb') as f:
    dataset = pickle.load(f)
del dataset['page'], dataset['index'], dataset['audio_path']


max_audio_length = 0
for arr in dataset['audio_array']:
    if len(arr) > max_audio_length:
        max_audio_length = len(arr)
print(max_audio_length)


for idx in tqdm(range(len(dataset['audio_array']))):
    dataset['audio_array'][idx] = feature_extractor(dataset['audio_array'][idx], 
                                                    sampling_rate=dataset['sample_rate'],
                                                    return_tensors="pt",
                                                    padding='max_length',
                                                    max_length=max_audio_length
                                                    ).input_values[0]
print(len(dataset['audio_array']))

308533


100%|████████████████████████████████████| 44070/44070 [01:35<00:00, 460.61it/s]

44070





In [7]:
dataset['text'] = tokenizer(dataset['text'],
                            return_tensors="pt",
                            # padding='longest', # VCTK: 42,
                            padding='max_length',
                            max_length=args['n_positions']
                            )
print(dataset['text']['attention_mask'].shape)

torch.Size([44070, 64])


In [8]:
split_ratio = (0.8, 0.9)
dataset_size = dataset['text']['attention_mask'].shape[0]
indices = np.arange(dataset_size)
np.random.shuffle(indices)

train_idx = indices[:int(dataset_size * split_ratio[0])]
val_idx = indices[int(dataset_size * split_ratio[0]):int(dataset_size * split_ratio[1])]
test_idx = indices[int(dataset_size * split_ratio[1]):]

In [9]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, input_values, tokenized_output, indices):
        self.input_values = input_values
        self.tokenized_output = tokenized_output
        self.indices = indices

    def __getitem__(self, idx):
        item = dict()
        item['input_values'] = self.input_values[self.indices[idx]]
        item['labels'] = self.tokenized_output['input_ids'][self.indices[idx]]
        item['output_attention_mask'] = self.tokenized_output['attention_mask'][self.indices[idx]]
        return item

    def __len__(self):
        return len(self.indices)

    

train_dataset = CustomDataset(dataset['audio_array'], dataset['text'], train_idx)
val_dataset = CustomDataset(dataset['audio_array'], dataset['text'], val_idx)
test_dataset = CustomDataset(dataset['audio_array'], dataset['text'], test_idx)

In [10]:
model = Wav2Vec2GPTModel(config=config)

model.wav2vec2.from_pretrained(wav2vec_pretrained, cache_dir=model_cache_dir)
model.transformer.from_pretrained(gpt_pretrained, cache_dir=model_cache_dir)


# device_map = {
#     0: [0, 1, 2, 3, 4,],
#     2: [5, 6, 7, 8, 9, 10, 11, ],
# }
# model.gpt2lm.parallelize(device_map)


model.freeze_feature_extractor()
model.freeze_feature_projection()
# model.freeze_wav2vec_encoder() # not exists here
model.unfreeze_wav2vec_adapter()
model.unfreeze_rnn_compressor()
model.freeze_gpt_decoder()
model.unfreeze_lm_head()

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model2: ['project_q.weight', 'wav2vec2.encoder.layers.11.attention.out_proj.weight', 'wav2vec2.encoder.layer_norm.weight', 'wav2vec2.encoder.layers.9.attention.v_proj.weight', 'wav2vec2.encoder.pos_conv_embed.conv.weight_g', 'wav2vec2.encoder.layers.1.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.3.attention.out_proj.bias', 'wav2vec2.encoder.layers.8.layer_norm.bias', 'wav2vec2.encoder.layers.4.attention.out_proj.bias', 'wav2vec2.encoder.layers.11.layer_norm.bias', 'wav2vec2.encoder.layers.0.attention.k_proj.weight', 'wav2vec2.encoder.layers.4.layer_norm.weight', 'wav2vec2.encoder.layers.6.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.11.final_layer_norm.weight', 'wav2vec2.encoder.layers.7.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.8.feed_forward.output_dense.bias', 'wav2vec2.encoder.layers.8.final_layer_norm.bias', 'quantizer.weight_proj.w

In [14]:
# # load rouge for validation
# rouge = load_metric("rouge")

# def compute_metrics(pred):
#     labels_ids = pred.label_ids
#     pred_ids = pred.predictions

#     # all unnecessary tokens are removed
#     pred_str = decoder_tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
#     labels_ids[labels_ids == -100] = decoder_tokenizer.eos_token_id
#     label_str = decoder_tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

#     rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

#     return {
#         "rouge2_precision": round(rouge_output.precision, 4),
#         "rouge2_recall": round(rouge_output.recall, 4),
#         "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
#     }

In [15]:
batch_size = 24
steps_per_epoch = math.ceil(len(train_dataset) / batch_size / 2)


# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
#     predict_with_generate=True,
    output_dir=os.path.join(checkpoint_dir, "wav2vec2gpt/unfreeze-adapter-rnn-lm"),
    # do_train=True,
    # do_eval=False,
    # do_predict=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=1e-4, 
    weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0,
    num_train_epochs=100,
    max_steps=-1,
    # lr_scheduler_type='cosine', 
    # warmup_ratio=0.0, 
    
    logging_strategy='steps',
    save_strategy='steps',
    evaluation_strategy='steps',
    logging_steps=1 * steps_per_epoch,
    save_steps=1 * steps_per_epoch,
    eval_steps=1 * steps_per_epoch,
    warmup_steps=10 * steps_per_epoch,
    save_total_limit=10,
    overwrite_output_dir=True,
)

In [16]:
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
#     compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)


# start training
trainer.train()

***** Running training *****
  Num examples = 35256
  Num Epochs = 100
  Instantaneous batch size per device = 24
  Total train batch size (w. parallel, distributed & accumulation) = 48
  Gradient Accumulation steps = 1
  Total optimization steps = 73500
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
  result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,


Step,Training Loss,Validation Loss
735,119.4855,25.785419
1470,22.9161,19.734936
2205,19.3008,19.108522
2940,18.7088,18.710142
3675,18.3754,18.470869
4410,18.0862,17.991512
5145,17.8112,17.808498
5880,17.5475,17.551544
6615,17.2051,17.552288
7350,16.893,17.225233


***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-735
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-735/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-735/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-96] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-1470
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-1470/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-1470/pytorch_model.bin
Deleting older checkpoint [/data4/yoomca

***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-8085
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-8085/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-8085/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-735] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-8820
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-8820/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-8820/pytorch_model.bin
Deleting older checkpoint [/data4/yo

***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-15435
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-15435/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-15435/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-8085] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-16170
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-16170/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-16170/pytorch_model.bin
Deleting older checkpoint [/d

***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-22785
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-22785/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-22785/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-15435] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-23520
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-23520/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-23520/pytorch_model.bin
Deleting older checkpoint [/

Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-29400/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-22050] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-30135
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-30135/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-30135/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-22785] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-30870
Config

wandb: Network error (ReadTimeout), entering retry loop.
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-36750
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-36750/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-36750/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-29400] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-37485
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-37485/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkp

***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-44100
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-44100/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-44100/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-36750] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-44835
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-44835/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-44835/pytorch_model.bin
Deleting older checkpoint [/

Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-50715/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-43365] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-51450
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-51450/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-51450/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-44100] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-52185
Config

***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-58065
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-58065/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-58065/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-50715] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-58800
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-58800/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-58800/pytorch_model.bin
Deleting older checkpoint [/

***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-65415
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-65415/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-65415/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-58065] due to args.save_total_limit
wandb: Network error (ReadTimeout), entering retry loop.
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-66150
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-66150/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkp

Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-72030/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-64680] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-72765
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-72765/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-72765/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-65415] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4407
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-73500
Config

TrainOutput(global_step=73500, training_loss=8.568285219361181, metrics={'train_runtime': 95106.6696, 'train_samples_per_second': 37.07, 'train_steps_per_second': 0.773, 'total_flos': 7.469035469007829e+20, 'train_loss': 8.568285219361181, 'epoch': 100.0})

In [11]:
wandb.finish()




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

In [11]:
!ls /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/

checkpoint-66885  checkpoint-69090  checkpoint-71295  checkpoint-73500
checkpoint-67620  checkpoint-69825  checkpoint-72030
checkpoint-68355  checkpoint-70560  checkpoint-72765


In [12]:
##### example

model.load_state_dict(torch.load(
    '/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-66885/pytorch_model.bin'))


BATCH_SIZE = 16
i = 3
batch_idx = test_idx[i*BATCH_SIZE:i*BATCH_SIZE+BATCH_SIZE]

audio_feature_batch = list()
for idx in batch_idx:
    audio_feature_batch.append(dataset['audio_array'][idx])
audio_feature_batch = torch.stack(audio_feature_batch)
print(audio_feature_batch.size())

label_batch = dataset['text']['input_ids'][batch_idx]
attention_batch = dataset['text']['attention_mask'][batch_idx]

print(label_batch.size())

with torch.no_grad():
    audio_embedding = model(input_values=audio_feature_batch.to(device), 
                            labels=label_batch.to(device),
                            output_attention_mask=attention_batch.to(device),)
print(audio_embedding.logits.shape)

pred_ids = torch.argmax(audio_embedding.logits, axis=-1)
print(pred_ids.size())
print()

for idx in range(BATCH_SIZE):
    print(tokenizer.decode(label_batch[idx]).replace('<|endoftext|>',''))
    print(tokenizer.decode([key for key, _group in groupby(pred_ids[idx])]))
    print(tokenizer.decode([key for key, _group in groupby(pred_ids[idx])]).replace('<|endoftext|>',''))
    print()

torch.Size([16, 308533])
torch.Size([16, 64])
torch.Size([16, 64, 50257])
torch.Size([16, 64])

This is no reflection on Rangers.
The<|endoftext|> reflection<|endoftext|> interest<|endoftext|>.<|endoftext|>
The reflection interest.

Today, we begin to answer that question.
The<|endoftext|>, will details are<|endoftext|> the suggestion<|endoftext|>.
The, will details are the suggestion.

We think a lot of Allan McGregor.
It<|endoftext|> think<|endoftext|> other not away<|endoftext|> me<|endoftext|>.
It think other not away me.

This gives a financial incentive to switch.
He<|endoftext|> her from incentive<|endoftext|> switch<|endoftext|>.<|endoftext|>
He her from incentive switch.

Sounds like The Sixth Sense?
I<|endoftext|> like the Sixth Sense<|endoftext|>.<|endoftext|>
I like the Sixth Sense.

They had four children together.
It<|endoftext|> support for<|endoftext|> to in<|endoftext|> matter<|endoftext|>.
It support for to in matter.

It was clear.
It<|endoftext|> clear<|endoftext|>.

In [13]:
##### example

model.load_state_dict(torch.load(
    '/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-72765/pytorch_model.bin'))


with torch.no_grad():
    audio_embedding = model(input_values=audio_feature_batch.to(device), 
                            labels=label_batch.to(device),
                            output_attention_mask=attention_batch.to(device),)

for idx in range(BATCH_SIZE):
    print(tokenizer.decode(label_batch[idx]).replace('<|endoftext|>',''))
    print(tokenizer.decode([key for key, _group in groupby(pred_ids[idx])]))
    print(tokenizer.decode([key for key, _group in groupby(pred_ids[idx])]).replace('<|endoftext|>',''))
    print()

This is no reflection on Rangers.
The<|endoftext|> reflection<|endoftext|> interest<|endoftext|>.<|endoftext|>
The reflection interest.

Today, we begin to answer that question.
The<|endoftext|>, will details are<|endoftext|> the suggestion<|endoftext|>.
The, will details are the suggestion.

We think a lot of Allan McGregor.
It<|endoftext|> think<|endoftext|> other not away<|endoftext|> me<|endoftext|>.
It think other not away me.

This gives a financial incentive to switch.
He<|endoftext|> her from incentive<|endoftext|> switch<|endoftext|>.<|endoftext|>
He her from incentive switch.

Sounds like The Sixth Sense?
I<|endoftext|> like the Sixth Sense<|endoftext|>.<|endoftext|>
I like the Sixth Sense.

They had four children together.
It<|endoftext|> support for<|endoftext|> to in<|endoftext|> matter<|endoftext|>.
It support for to in matter.

It was clear.
It<|endoftext|> clear<|endoftext|>.<|endoftext|>
It clear.

He is a sort of a mystery figure.
The<|endoftext|>'m, he would<|endofte

In [None]:
# import IPython

# IPython.display.Audio(dataset[4]['audio']['path'])

In [None]:
model.wav2vec2