In [1]:
import os
import pickle
import random
from tqdm import tqdm
import numpy as np
import torch

from datasets import load_dataset, load_metric
import math
from itertools import groupby

import wandb

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1,0"
# os.environ["WANDB_DISABLED"] = "true"

cache_dir = "/data4/yoomcache"
model_cache_dir = os.path.join(cache_dir, 'huggingface')
data_cache_dir = os.path.join(cache_dir, 'datasets')
checkpoint_dir = os.path.join(cache_dir, 'checkpoint')

seed = 0
random.seed(0)
np.random.seed(seed)
torch.manual_seed(seed)

wandb.init(project="testing-wav2vec2gpt", entity="yoom-private")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myoom-private[0m (use `wandb login --relogin` to force relogin)


In [2]:
# %reload_ext autoreload
# %autoreload 2
from wav2vec2GPTwCTC import *
from configuration_wav2vec2gpt import Wav2Vec2GPTConfig

from transformers import Wav2Vec2FeatureExtractor, GPT2Model
from transformers import GPT2Tokenizer, AddedToken
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

In [3]:
wav2vec_pretrained = "facebook/wav2vec2-base"
gpt_pretrained = "gpt2"

# Should aware that pad_token_id is used to compute CTC loss, 
# so pad_token configuration for both tokenizer and model should be the same
args = {
    'pad_token': 'Ġ', 'pad_token_id': 220,
    'unk_token': 'Ġ', 'unk_token_id': 220,
#     'pad_token': "<|endoftext|>", 'pad_token_id': 50256,
#     'unk_token': "<|endoftext|>", 'unk_token_id': 50256,
    'bos_token': "<|endoftext|>", 'bos_token_id': 50256,
    'eos_token': "<|endoftext|>", 'eos_token_id': 50256,
    
    'n_positions': 64, # VCTK: 42, 
    
    'add_adapter': True,
    'output_hidden_size': 128,
    'num_adapter_layers': 3,
    'adapter_kernel_size': [4, 4, 4, 0], 
    'adapter_stride':      [2, 2, 1, 1],
    'adapter_padding':     [2, 2, 0, 0],
    'adapter_bias': False
    
    
}


config = Wav2Vec2GPTConfig(**args)

In [4]:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_pretrained, 
                                                             cache_dir=model_cache_dir,
                                                             **args)

tokenizer = GPT2Tokenizer.from_pretrained(gpt_pretrained,
                                          cache_dir=model_cache_dir,
                                          **args)

In [5]:
with open('/data4/TTS/VCTK-Corpus/dataset-vctk-16k.pkl', 'rb') as f:
    dataset = pickle.load(f)
del dataset['page'], dataset['index']
# del dataset['audio_path']



# dataset_size = len(dataset['text'])
dataset_size = int(len(dataset['text']) * 0.02)

max_audio_length = 0
for arr in dataset['audio_array']:
    if len(arr) > max_audio_length:
        max_audio_length = len(arr)
print(max_audio_length)


for idx in tqdm(range(dataset_size)):
    dataset['audio_array'][idx] = feature_extractor(dataset['audio_array'][idx], 
                                                    sampling_rate=dataset['sample_rate'],
                                                    return_tensors="pt",
                                                    padding='max_length',
                                                    max_length=max_audio_length
                                                    ).input_values[0]
del dataset['audio_array'][dataset_size:]
print(len(dataset['audio_array']))

308533


100%|████████████████████████████████████████| 881/881 [00:01<00:00, 492.63it/s]

881





In [6]:
dataset['text'] = tokenizer(dataset['text'][:dataset_size],
                            return_tensors="pt",
                            # padding='longest', # VCTK: 42,
                            padding='max_length',
                            max_length=args['n_positions']
                            )
print(dataset['text']['attention_mask'].shape)

torch.Size([881, 64])


In [7]:
split_ratio = (0.8, 0.9)
indices = np.arange(dataset_size)
np.random.shuffle(indices)

train_idx = indices[:int(dataset_size * split_ratio[0])]
val_idx = indices[int(dataset_size * split_ratio[0]):int(dataset_size * split_ratio[1])]
test_idx = indices[int(dataset_size * split_ratio[1]):]

In [8]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, input_values, tokenized_output, indices):
        self.input_values = input_values
        self.tokenized_output = tokenized_output
        self.indices = indices

    def __getitem__(self, idx):
        item = dict()
        item['input_values'] = self.input_values[self.indices[idx]]
        item['labels'] = self.tokenized_output['input_ids'][self.indices[idx]]
        item['output_attention_mask'] = self.tokenized_output['attention_mask'][self.indices[idx]]
        return item

    def __len__(self):
        return len(self.indices)

    

train_dataset = CustomDataset(dataset['audio_array'], dataset['text'], train_idx)
val_dataset = CustomDataset(dataset['audio_array'], dataset['text'], val_idx)
test_dataset = CustomDataset(dataset['audio_array'], dataset['text'], test_idx)

In [9]:
model = Wav2Vec2GPTModel(config=config)

model.wav2vec2.from_pretrained(wav2vec_pretrained, cache_dir=model_cache_dir)
# model_gpt = GPT2Model.from_pretrained(gpt_pretrained, cache_dir=model_cache_dir)
# model.wte.weight = model_gpt.wte.weight # wte only
# del model_gpt
model.transformer.from_pretrained(gpt_pretrained, cache_dir=model_cache_dir)

print(model)


# device_map = {
#     0: [0, 1, 2, 3, 4,],
#     1: [5, 6, 7, 8, 9, 10, 11, ],
# }
# model.parallelize(device_map)

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model2: ['wav2vec2.encoder.layers.6.final_layer_norm.weight', 'wav2vec2.encoder.layers.4.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.9.attention.v_proj.bias', 'wav2vec2.encoder.layers.8.attention.v_proj.weight', 'wav2vec2.encoder.layers.6.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.6.final_layer_norm.bias', 'wav2vec2.encoder.layers.1.final_layer_norm.bias', 'wav2vec2.encoder.layers.7.attention.k_proj.bias', 'wav2vec2.encoder.layers.2.feed_forward.intermediate_dense.weight', 'wav2vec2.encoder.layers.1.layer_norm.bias', 'wav2vec2.encoder.layers.1.attention.q_proj.bias', 'wav2vec2.encoder.layers.11.attention.k_proj.weight', 'wav2vec2.encoder.layers.1.feed_forward.intermediate_dense.bias', 'wav2vec2.encoder.layers.3.layer_norm.bias', 'wav2vec2.encoder.layers.4.attention.k_proj.weight', 'wav2vec2.encoder.layers.5.feed_forward.intermediate_dense.b

Wav2Vec2GPTModel(
  (wav2vec2): Wav2Vec2Model2(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (2): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (3): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (4): Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=Fal

In [10]:
model.freeze_feature_extractor()
# model.freeze_feature_projection() # not exists here
# model.freeze_wav2vec_encoder() # not exists here
# model.freeze_wav2vec_adapter()
# model.freeze_rnn_compressor()
model.freeze_gpt_decoder()
model.freeze_lm_head()

In [11]:
# load rouge for validation
rouge = load_metric("rouge")

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions.argmax(axis=-1)
    del pred

    # all unnecessary tokens are removed
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid
    
    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

In [12]:
batch_size = 24
steps_per_epoch = math.ceil(len(train_dataset) / batch_size / 2)


# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
#     predict_with_generate=True,
    output_dir=os.path.join(checkpoint_dir, "wav2vec2gpt/unfreeze-adapter-rnn"),
    # do_train=True,
    # do_eval=False,
    # do_predict=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=1e-4, 
    weight_decay=0.0, adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-08, max_grad_norm=1.0,
    num_train_epochs=100,
    max_steps=-1,
    # lr_scheduler_type="cosine", 
    # warmup_ratio=0.0, 
    
    logging_strategy='steps',
    save_strategy='steps',
    evaluation_strategy='steps',
    logging_steps=int(steps_per_epoch / 2),
    save_steps=int(steps_per_epoch * 1),
    eval_steps=int(steps_per_epoch / 2),
    warmup_steps=int(steps_per_epoch * 10),
    save_total_limit=10,
    overwrite_output_dir=True,
)

In [None]:
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)


# start training
for i in range(3):
    trainer.train()

***** Running training *****
  Num examples = 704
  Num Epochs = 100
  Instantaneous batch size per device = 24
  Total train batch size (w. parallel, distributed & accumulation) = 48
  Gradient Accumulation steps = 1
  Total optimization steps = 1500
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
  result = _VF.rnn_tanh(input, hx, self._flat_weights, self.bias, self.num_layers,


Step,Training Loss,Validation Loss,Rouge2 Precision,Rouge2 Recall,Rouge2 Fmeasure
7,14.3817,13.863242,0.0,0.0,0.0
14,14.077,13.231899,0.0,0.0,0.0
21,13.2288,12.196973,0.0,0.0,0.0
28,12.4153,10.902843,0.0,0.0,0.0
35,11.0729,9.568058,0.0,0.0,0.0
42,10.0146,8.328327,0.0,0.0,0.0
49,8.7654,7.159576,0.0,0.0,0.0
56,7.8158,6.083826,0.0,0.0,0.0
63,6.767,5.088922,0.0,0.0,0.0
70,5.9071,4.180136,0.0,0.0,0.0


***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-15
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-15/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-15/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-330] due to args.save_total_limit
  result = _VF.rnn_tanh(input, hx, self._flat_weights, self.bias, self.num_layers,
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-30
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapte

Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-450] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-150
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-150/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-150/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-465] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-165
Configuration saved in /

***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-285
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-285/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-285/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-135] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-300
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-300/config.json
Model weights saved in /data4/yoomcache/checkp

***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-420
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-420/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-420/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-270] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-435
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/chec

***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-555
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-555/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-555/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-405] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-570
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-570/config.json
Model weights saved in /data4/yoomcache/checkp

  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-690
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-690/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-690/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-540] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-705
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-705/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-70

  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-825
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-825/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-825/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-675] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-840
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-840/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-84

  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-960
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-960/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-960/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-810] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-975
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-975/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-975/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/check

Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-1095/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-1095/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-945] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
***** Running Evaluation *****
  Num examples = 88
  Batch size = 48
Saving model checkpoint to /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-1110
Configuration saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-1110/config.json
Model weights saved in /data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-1110/pytorch_model.bin
Deleting older checkpoint [/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn/checkpoint-960] due to args.save_total_limit
***** Running Evaluation *****
 

In [None]:
wandb.finish()

In [None]:
# model.load_state_dict(torch.load('/data4/yoomcache/checkpoint/wav2vec2gpt/unfreeze-adapter-rnn-lm/checkpoint-8814/pytorch_model.bin'))

In [None]:
##### example


BATCH_SIZE = 4
i = 4
device = 'cuda:0'
batch_idx = train_idx[i*BATCH_SIZE:i*BATCH_SIZE+BATCH_SIZE]

audio_feature_batch = list()
for idx in batch_idx:
    audio_feature_batch.append(dataset['audio_array'][idx])
audio_feature_batch = torch.stack(audio_feature_batch)
print(audio_feature_batch.size())

label_batch = dataset['text']['input_ids'][batch_idx]
attention_batch = dataset['text']['attention_mask'][batch_idx]

print(label_batch.size())

with torch.no_grad():
    audio_embedding = model(input_values=audio_feature_batch.to(device), 
                            labels=label_batch.to(device),
                            output_attention_mask=attention_batch.to(device),)
print(audio_embedding.logits.shape)

pred_ids = torch.argmax(audio_embedding.logits, axis=-1)
print(pred_ids.size())
print()

for idx in range(BATCH_SIZE):
    print(tokenizer.decode(label_batch[idx]))
    print(tokenizer.decode(pred_ids[idx]))
    print()


# for idx in range(BATCH_SIZE):
#     print(tokenizer.decode([key for key, _group in groupby(label_batch[idx])]))
#     print(tokenizer.decode([key for key, _group in groupby(pred_ids[idx])]))
#     print(tokenizer.decode([key for key, _group in groupby(pred_ids[idx])]).replace('<|endoftext|>',''))
#     print()

In [None]:
# with torch.no_grad():
#     print(model.wav2vec2(input_values=audio_feature_batch.to(device),)[0].shape)

In [None]:
# !pip install seaborn
import seaborn as sns

In [None]:
with torch.no_grad():
#     print(model.wav2vec2(audio_feature_batch)[0].shape)
#     print(model.wav2vec2(audio_feature_batch.to(device))[0])
    sns.heatmap(model.wav2vec2(audio_feature_batch.to(device))[0].cpu()[0])

In [None]:
import IPython

IPython.display.Audio(dataset['audio_path'][4])

In [None]:
model.transformer

In [None]:
model.transformer.h[0].attn.c_proj.weight