In [4]:
from glob import glob

import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

In [54]:
# Collect the WAV files under ../samples and keep only the one at index 1,
# wrapped in a list so the later cells can still iterate over paths.
# NOTE(review): glob() does not promise any particular ordering, so the file
# found at index 1 may differ between machines — confirm whether the listing
# should be sorted before indexing.
samples_path = [glob("../samples/*.wav")[1]]

# Echo how many paths were kept and which file was picked.
(len(samples_path), samples_path[0])

(1, '../samples/LJ001-0001.wav')

In [5]:
# One name for the checkpoint id so the processor and the model are always
# loaded from the same repository.
MODEL_ID = "indonesian-nlp/wav2vec2-large-xlsr-indonesian"

# The processor is used later to batch raw waveforms and to decode predicted ids.
processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
# The CTC model produces the logits that are decoded into text further down.
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [52]:
# Display the configuration object attached to the loaded model.
model.config

Wav2Vec2Config {
  "_name_or_path": "indonesian-nlp/wav2vec2-large-xlsr-indonesian",
  "activation_dropout": 0.055,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForCTC"
  ],
  "attention_dropout": 0.094,
  "bos_token_id": 1,
  "conv_bias": true,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "mean",
  "ctc_zero_infinity": true,
  "do_stable_layer_norm": true,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_dropout": 0.0,
  "feat_extract_norm": "layer",
  "feat_proj_dropout": 0.04,
  "final_dropout": 0.0,
  "gradient_checkpointing": true,
  "hidden_act": "gelu",
  "hidden_dropout": 0.047,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-05,
  "layerdrop": 0.041,
  "mask_channel_length": 

In [24]:
def _count_params(module):
    """Return the total number of elements across all parameters of `module`."""
    return sum(tensor.numel() for tensor in module.parameters())


# Whole model first, then each sub-module of interest.
total_param = _count_params(model)
fe_param = _count_params(model.wav2vec2.feature_extractor)
fp_param = _count_params(model.wav2vec2.feature_projection)
enc_param = _count_params(model.wav2vec2.encoder)
cls_param = _count_params(model.lm_head)

# A single print call emitting one line per entry; the labels are padded so
# that the colons line up in the output.
print(
    f"Total Params            : {total_param:,}",
    f"Feature Encoder Params  : {fe_param:,}",
    f"Feature Projector Params: {fp_param:,}",
    f"Encoder Params          : {enc_param:,}",
    f"Classifier Head Params  : {cls_param:,}",
    sep="\n",
)

Total Params            : 315,467,420
Feature Encoder Params  : 4,210,176
Feature Projector Params: 526,336
Encoder Params          : 310,701,184
Classifier Head Params  : 28,700


In [55]:
def read_wave(path, target_sr=16_000):
    """Load an audio file as a mono, 1-D waveform sampled at `target_sr` Hz.

    Args:
        path: Location of the audio file; forwarded to `torchaudio.load`.
        target_sr: Sampling rate (Hz) of the returned waveform. Defaults to
            16 000, the rate this function previously hard-coded.

    Returns:
        A 1-D NumPy array of shape (num_samples,).
    """
    speech_array, sampling_rate = torchaudio.load(path)

    # torchaudio.load yields a (channels, time) tensor by default. Averaging
    # over the channel axis downmixes multi-channel files to mono; the earlier
    # bare `.squeeze()` left stereo input 2-D. For a single-channel file the
    # mean leaves the samples unchanged.
    mono = speech_array.mean(dim=0, keepdim=True)

    # Only build and run a resampler when the file is not already at the
    # requested rate.
    if sampling_rate != target_sr:
        mono = torchaudio.transforms.Resample(sampling_rate, target_sr)(mono)

    # squeeze(0) removes just the channel axis, so even a one-sample clip
    # stays 1-D instead of collapsing to a scalar.
    return mono.squeeze(0).numpy()

# Decode every selected file into a waveform.
samples = list(map(read_wave, samples_path))

# Let the processor turn the waveforms into a padded batch of PyTorch tensors.
inputs = processor(
    samples,
    sampling_rate=16_000,
    return_tensors="pt",
    padding=True,
)

# Shape of the batched model input.
inputs.input_values.shape

torch.Size([1, 188624])

In [56]:
# Inference only: run the forward pass with gradient tracking disabled.
with torch.no_grad():
    outputs = model(
        inputs.input_values,
        attention_mask=inputs.attention_mask,
        output_hidden_states=True,  # makes `outputs.hidden_states` available
    )

# Raw output scores of the model, kept under their own name for the next cells.
logits = outputs.logits

In [64]:
# Inspect the shape of the logits tensor.
# NOTE(review): presumably (batch, frames, vocabulary size) — confirm.
logits.shape

torch.Size([1, 589, 28])

In [65]:
# Shape of the argmax predictions over the last axis of `logits`.
# Computed directly from `logits` so this cell does not rely on `pred_ids`,
# which is only assigned in a later cell and would raise NameError on a fresh
# top-to-bottom run.
logits.argmax(dim=-1).shape

torch.Size([1, 589])

In [57]:
# Index of the highest-scoring entry along the last axis of the logits.
pred_ids = logits.argmax(dim=-1)

# Turn the id sequences back into text with the processor and display them.
pred_tokens = processor.batch_decode(pred_ids)
pred_tokens

['percetakan dalam satusatunya pengertian yang menjadi perhatian kita saat ini berbeda dari sebagian besar jika tidak semua seni dan kerajinan yang diwakili dalam pameran']

In [58]:
# Print the shape of every tensor in `outputs.hidden_states`, one line per
# entry, prefixed with its position in the sequence.
for i, h in enumerate(outputs.hidden_states):
    print(f"hidden state {i} {h.shape}")

hidden state 0 torch.Size([1, 589, 1024])
hidden state 1 torch.Size([1, 589, 1024])
hidden state 2 torch.Size([1, 589, 1024])
hidden state 3 torch.Size([1, 589, 1024])
hidden state 4 torch.Size([1, 589, 1024])
hidden state 5 torch.Size([1, 589, 1024])
hidden state 6 torch.Size([1, 589, 1024])
hidden state 7 torch.Size([1, 589, 1024])
hidden state 8 torch.Size([1, 589, 1024])
hidden state 9 torch.Size([1, 589, 1024])
hidden state 10 torch.Size([1, 589, 1024])
hidden state 11 torch.Size([1, 589, 1024])
hidden state 12 torch.Size([1, 589, 1024])
hidden state 13 torch.Size([1, 589, 1024])
hidden state 14 torch.Size([1, 589, 1024])
hidden state 15 torch.Size([1, 589, 1024])
hidden state 16 torch.Size([1, 589, 1024])
hidden state 17 torch.Size([1, 589, 1024])
hidden state 18 torch.Size([1, 589, 1024])
hidden state 19 torch.Size([1, 589, 1024])
hidden state 20 torch.Size([1, 589, 1024])
hidden state 21 torch.Size([1, 589, 1024])
hidden state 22 torch.Size([1, 589, 1024])
hidden state 23 torch

In [63]:
# Show the module structure of the model's feature extractor.
model.wav2vec2.feature_extractor

Wav2Vec2FeatureExtractor(
  (conv_layers): ModuleList(
    (0): Wav2Vec2LayerNormConvLayer(
      (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (1): Wav2Vec2LayerNormConvLayer(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (2): Wav2Vec2LayerNormConvLayer(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (3): Wav2Vec2LayerNormConvLayer(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (4): Wav2Vec2LayerNormConvLayer(
      (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (5): Wav2Vec2LayerNormConvLayer(
      (conv): 