In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

cache_dir = "/data4/yoomcache"
model_cache_dir = os.path.join(cache_dir, 'huggingface')
data_cache_dir = os.path.join(cache_dir, 'datasets')
checkpoint_dir = os.path.join(cache_dir, 'checkpoint')

import torch
from datasets import load_dataset, load_metric
import math
from itertools import groupby

import wandb
wandb.init(project="testing-wav2vec2gpt", entity="yoom-private")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myoom-private[0m (use `wandb login --relogin` to force relogin)


In [2]:
# %reload_ext autoreload
# %autoreload 2
from wav2vec2GPTwCTC import *
from configuration_wav2vec2gpt import Wav2Vec2GPTConfig

from transformers import Wav2Vec2FeatureExtractor
from transformers import GPT2Tokenizer, AddedToken
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

In [3]:
wav2vec_pretrained = "facebook/wav2vec2-base"
gpt_pretrained = "gpt2"

# Should aware that pad_token_id is used to compute CTC loss, 
# so pad_token configuration for both tokenizer and model should be the same
args = {
#     'pad_token': 'Ġ', 'pad_token_id': 220,
#     'unk_token': 'Ġ', 'unk_token_id': 220,
    'pad_token': "<|endoftext|>", 'pad_token_id': 50256,
    'unk_token': "<|endoftext|>", 'unk_token_id': 50256,
    'bos_token': "<|endoftext|>", 'bos_token_id': 50256,
    'eos_token': "<|endoftext|>", 'eos_token_id': 50256,
    
    'n_positions': 512,
    
    'add_adapter': True,
    'adapter_kernel_size': 6, 
    'adapter_stride': 2,
    'num_adapter_layers': 3,
}


config = Wav2Vec2GPTConfig(**args)

In [4]:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_pretrained, 
                                                             cache_dir=model_cache_dir,
                                                             **args)

tokenizer = GPT2Tokenizer.from_pretrained(gpt_pretrained,
                                          cache_dir=model_cache_dir,
                                          **args)

In [5]:
model = Wav2Vec2GPTModel(config=config)

model.wav2vec2.from_pretrained(wav2vec_pretrained, cache_dir=model_cache_dir)
model.transformer.from_pretrained(gpt_pretrained, cache_dir=model_cache_dir)


# device_map = {
#     0: [0, 1, 2, 3, 4,],
#     2: [5, 6, 7, 8, 9, 10, 11, ],
# }
# model.gpt2lm.parallelize(device_map)


model.freeze_feature_extractor()
model.freeze_feature_projection()
# model.freeze_wav2vec_encoder() # not exists here
model.unfreeze_wav2vec_adapter()
model.unfreeze_rnn_compressor()
model.freeze_gpt_decoder()
model.unfreeze_lm_head()

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model2: ['wav2vec2.encoder.layers.10.attention.k_proj.bias', 'quantizer.codevectors', 'wav2vec2.encoder.layers.3.attention.k_proj.bias', 'wav2vec2.encoder.layers.8.final_layer_norm.bias', 'wav2vec2.encoder.layers.5.attention.q_proj.weight', 'wav2vec2.encoder.layers.8.attention.k_proj.weight', 'wav2vec2.encoder.layers.11.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.3.attention.q_proj.bias', 'wav2vec2.encoder.layers.2.layer_norm.weight', 'wav2vec2.encoder.layers.3.attention.v_proj.bias', 'wav2vec2.encoder.layers.1.layer_norm.bias', 'wav2vec2.encoder.layers.5.feed_forward.output_dense.weight', 'wav2vec2.encoder.layers.5.attention.k_proj.weight', 'wav2vec2.encoder.layers.10.layer_norm.bias', 'wav2vec2.encoder.layers.8.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.3.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.pos_conv_embed.conv.bias', 

In [6]:
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", 
                       split="validation", 
                       cache_dir=data_cache_dir
                      )

dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate
audio_inputs = [d["audio"]["array"] for d in dataset]

print(dataset, sampling_rate)

Reusing dataset librispeech_asr (/data4/yoomcache/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)
Loading cached sorted indices for dataset at /data4/yoomcache/datasets/hf-internal-testing___librispeech_asr/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-2f7c0cbee6ef3aa1.arrow


Dataset({
    features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id'],
    num_rows: 73
}) 16000


In [7]:
# text_inputs = dataset["text"]
from example.librispeech_asr_demo import text_inputs

In [8]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, input_values, tokenized_output):
        self.input_values = input_values
        self.tokenized_output = tokenized_output

    def __getitem__(self, idx):
        item = dict()
        item['input_values'] = self.input_values['input_values'][idx]
        item['labels'] = self.tokenized_output['input_ids'][idx]
        item['output_attention_mask'] = self.tokenized_output['attention_mask'][idx]
        return item

    def __len__(self):
        return len(self.input_values['input_values'])

    
input_values = feature_extractor(audio_inputs, 
                                      sampling_rate=sampling_rate,
                                      return_tensors="pt",
                                      padding='longest',
                                     )

tokenized_output = tokenizer(text_inputs,
                             return_tensors="pt",
                             # padding='longest',
                             padding='max_length',
                             max_length=args['n_positions']
                         )

train_dataset = CustomDataset(input_values, tokenized_output)
# val_dataset = CustomDataset(input_values, tokenized_output)
# test_dataset = CustomDataset(input_values, tokenized_output)

In [9]:
# # load rouge for validation
# rouge = load_metric("rouge")

# def compute_metrics(pred):
#     labels_ids = pred.label_ids
#     pred_ids = pred.predictions

#     # all unnecessary tokens are removed
#     pred_str = decoder_tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
#     labels_ids[labels_ids == -100] = decoder_tokenizer.eos_token_id
#     label_str = decoder_tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

#     rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

#     return {
#         "rouge2_precision": round(rouge_output.precision, 4),
#         "rouge2_recall": round(rouge_output.recall, 4),
#         "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
#     }

In [10]:
batch_size = 4
steps_per_epoch = math.ceil(len(train_dataset) / batch_size)


# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
#     predict_with_generate=True,
    output_dir=os.path.join(checkpoint_dir, "wav2vec2gpt/unfreeze-rnn"),
    # do_train=True,
    # do_eval=False,
    # do_predict=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size * 5,
    learning_rate=5e-5, 
    weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0,
    num_train_epochs=1000,
    max_steps=-1,
    lr_scheduler_type='cosine', 
    # warmup_ratio=0.0, 
    
    logging_strategy='steps',
    save_strategy='steps',
    evaluation_strategy='steps',
    logging_steps=1 * steps_per_epoch,
    save_steps=2 * steps_per_epoch,
    eval_steps=1 * steps_per_epoch,
    warmup_steps=100 * steps_per_epoch,
    save_total_limit=10,
    overwrite_output_dir=True,
)

In [11]:
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
#     compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=train_dataset,
)


# start training
trainer.train()

***** Running training *****
  Num examples = 73
  Num Epochs = 1000
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 10000
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
  result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,


Step,Training Loss,Validation Loss
19,39.0973,184.115067
38,37.0443,178.938858
57,36.5862,169.55423
76,35.1982,157.169357
95,33.017,144.311462
114,31.3683,133.494247
133,28.7641,123.64872
152,29.9879,116.307739
171,27.9577,109.201309
190,25.6724,101.60701


***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-38
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-38/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-38/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9652] due to args.save_total_limit
  result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-76
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-76/config.json
Model weights saved 

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-266
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-266/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-266/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9880] due to args.save_total_limit
  result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-304
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-304/config.json
Model weights s

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-494
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-494/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-494/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-114] due to args.save_total_limit
  result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-532
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-532/config.json
Model weights sa

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-722
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-722/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-722/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-342] due to args.save_total_limit
  result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-760
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-760/config.json
Model weights sa

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-950
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-950/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-950/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-570] due to args.save_total_limit
  result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-988
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-988/config.json
Model weights sa

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1178
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1178/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1178/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-798] due to args.save_total_limit
  result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1216
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1216/config.json
Model weigh

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1406
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1406/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1406/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1026] due to args.save_total_limit
  result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1444
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1444/config.json
Model weig

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1634
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1634/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1634/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1254] due to args.save_total_limit
  result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1672
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1672/config.json
Model weig

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1862
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1862/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1862/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1482] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1900
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1900/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1

Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1824] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2242
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2242/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2242/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-1862] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2280
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unf

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2584
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2584/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2584/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2204] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2622
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2622/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2

Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2546] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2964
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2964/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2964/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2584] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-3002
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unf

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-3306
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-3306/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-3306/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-2926] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-3344
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-3344/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-3

Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-3268] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-3686
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-3686/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-3686/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-3306] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-3724
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unf

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4028
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4028/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4028/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-3648] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4066
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4066/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4

Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-3990] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4408
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4408/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4408/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4028] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4446
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unf

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4750
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4750/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4750/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4370] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4788
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4788/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4

Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4712] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5130
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5130/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5130/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-4750] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5168
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unf

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5472
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5472/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5472/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5092] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5510
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5510/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5

Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5434] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5852
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5852/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5852/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5472] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5890
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unf

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6194
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6194/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6194/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-5814] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6232
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6232/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6

Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6156] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6574
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6574/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6574/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6194] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6612
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unf

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6916
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6916/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6916/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6536] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6954
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6954/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6

Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6878] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-7296
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-7296/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-7296/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-6916] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-7334
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unf

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-7638
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-7638/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-7638/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-7258] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-7676
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-7676/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-7

Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-7600] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8018
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8018/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8018/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-7638] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8056
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unf

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8360
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8360/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8360/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-7980] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8398
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8398/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8

Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8322] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8740
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8740/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8740/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8360] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8778
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unf

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9082
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9082/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9082/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-8702] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9120
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9120/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9

Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9044] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9462
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9462/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9462/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9082] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9500
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unf

***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9804
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9804/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9804/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9424] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
***** Running Evaluation *****
  Num examples = 73
  Batch size = 40
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9842
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9842/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-rnn/checkpoint-9

TrainOutput(global_step=10000, training_loss=1.1349248749673366, metrics={'train_runtime': 11175.5503, 'train_samples_per_second': 6.532, 'train_steps_per_second': 0.895, 'total_flos': 2.357870970107523e+19, 'train_loss': 1.1349248749673366, 'epoch': 1000.0})

In [12]:
wandb.finish()




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/loss,█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eval/runtime,▃▂▃▂▆▄▁▃▅▅▅▅▅▅▅▅█▄▅▃▃▅▂▂▂▃▂▅▄▅▅▅▅▄▄▅▃▄▅▂
eval/samples_per_second,▆▇▆▇▃▅█▆▃▄▄▄▄▄▃▄▁▅▃▆▅▄▆▆▆▆▆▃▄▄▄▃▄▄▄▃▅▄▃▆
eval/steps_per_second,▆▇▆▇▃▅█▆▃▄▄▄▄▄▃▄▁▅▃▆▅▄▆▆▆▅▆▃▄▄▄▃▄▄▄▃▅▄▃▆
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
train/learning_rate,▂▂▃▄▅▆▇██████▇▇▇▇▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁
train/loss,█▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_flos,▁
train/train_loss,▁

0,1
eval/loss,0.34623
eval/runtime,3.8404
eval/samples_per_second,19.009
eval/steps_per_second,0.521
train/epoch,1000.0
train/global_step,10000.0
train/learning_rate,0.0
train/loss,0.0797
train/total_flos,2.357870970107523e+19
train/train_loss,1.13492


In [13]:
##### example


BATCH_SIZE = 8
i = 3
device = 'cuda:0'


audio_batch = audio_inputs[i*BATCH_SIZE:i*BATCH_SIZE+BATCH_SIZE]
audio_feature_batch = feature_extractor(audio_batch, 
                                      sampling_rate=sampling_rate,
                                      return_tensors="pt",
                                      padding='longest',
                                     ).input_values
print(audio_feature_batch.size())


text_batch = text_inputs[i*BATCH_SIZE:i*BATCH_SIZE+BATCH_SIZE]

text_tokens_batch = tokenizer(text_batch, 
                              return_tensors="pt",
                              padding='max_length',
                              max_length=train_dataset.tokenized_output['input_ids'].shape[1]
                             )
print(text_tokens_batch['attention_mask'].size())

with torch.no_grad():
    audio_embedding = model(input_values=audio_feature_batch.to(device), 
                            labels=text_tokens_batch['input_ids'].to(device),
                            output_attention_mask=text_tokens_batch['attention_mask'].to(device),)
print(audio_embedding.logits.shape)

pred_ids = torch.argmax(audio_embedding.logits, axis=-1)
print(pred_ids.size())
print()

for idx in range(BATCH_SIZE):
    print(text_batch[idx])
    print(tokenizer.decode([key for key, _group in groupby(pred_ids[idx])]))
    print()

torch.Size([8, 143920])
torch.Size([8, 512])
torch.Size([8, 512, 50257])
torch.Size([8, 512])

"He doesn't work at all."
"He doesn't doesn't<|endoftext|> work all."<|endoftext|>

In fact, there is nothing he can do in these dominions as well as our nomes, whose numbers are so great that it worries us to keep them all busy.
In fact<|endoftext|>, is there is there he<|endoftext|> can do<|endoftext|> in these domin these domin theseions these as well as well as well as, whose,<|endoftext|> numbers so are<|endoftext|> it<|endoftext|> it great<|endoftext|> it<|endoftext|> great that great that it that great that great

"Not exactly," returned Kaliko.
"Not exactly<|endoftext|>," returned Kal returned Kal returned Kal returned Kaliko<|endoftext|>

"Where is my brother now?"
"Where<|endoftext|> is<|endoftext|> my brother my brother my brother now?"<|endoftext|>

inquired Shaggy. "In the Metal Forest."
inquiredinquiredinquired<|endoftext|>ag Shag<|endoftext|>gy.<|endoftext|> "In " the Metal the

In [14]:
# import IPython

# IPython.display.Audio(dataset[4]['audio']['path'])

In [15]:
model.wav2vec2

Wav2Vec2Model2(
  (feature_extractor): Wav2Vec2FeatureEncoder(
    (conv_layers): ModuleList(
      (0): Wav2Vec2GroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1): Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (2): Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (3): Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (4): Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5): Wav2Vec2NoLa