In [None]:
!pip install accelerate -U
!pip install transformers soundfile datasets jiwer gdown pyctcdecode kenlm

In [19]:
# !mkdir ./dataset
import gdown
def drive_download(idx, output):
    url = 'https://drive.google.com/uc?id=' + idx
    gdown.download(url, output, quiet=False)
# drive_download("1ZBL3h6bHMmd8MIUNXqg72PucUkC9ZSWJ", "./dataset/train_data.zip")
drive_download("1ZepptsTrVSjQEx-dpBBmQ2b7xYFLn_64", "./dataset/public_test.zip")
# drive_download("1K_07kix1OgBGO2FNPh-Lxqr1yLbtqFYt", "./dataset/train.jsonl")

Downloading...
From (uriginal): https://drive.google.com/uc?id=1ZepptsTrVSjQEx-dpBBmQ2b7xYFLn_64
From (redirected): https://drive.google.com/uc?id=1ZepptsTrVSjQEx-dpBBmQ2b7xYFLn_64&confirm=t&uuid=82e7cbd1-d4dc-4ccc-8d1d-499caa4c4b29
To: /workspace/dataset/public_test.zip
100%|██████████| 131M/131M [00:03<00:00, 35.5MB/s] 


In [2]:
drive_download("1VHhkv0FhU6mPieMb-aFLvJskz6A6oyLh", "./wav2vec_large_st_2.zip")
drive_download("186Tv-dPED5QiIJy4sRvlsNsvYxLpXfWX", "./vn_base_vocab.json")
drive_download("1K-oNwBu2svshAkmifU9wISKPMvKgeKy4", "./train_20230909.jsonl")

Downloading...
From (uriginal): https://drive.google.com/uc?id=1VHhkv0FhU6mPieMb-aFLvJskz6A6oyLh
From (redirected): https://drive.google.com/uc?id=1VHhkv0FhU6mPieMb-aFLvJskz6A6oyLh&confirm=t&uuid=399d86ef-600f-4439-8f46-3ff8c148a8f4
To: /workspace/wav2vec_large_st_2.zip
100%|██████████| 1.17G/1.17G [00:28<00:00, 40.7MB/s]
Downloading...
From: https://drive.google.com/uc?id=186Tv-dPED5QiIJy4sRvlsNsvYxLpXfWX
To: /workspace/vn_base_vocab.json
100%|██████████| 1.35k/1.35k [00:00<00:00, 3.73MB/s]
Downloading...
From: https://drive.google.com/uc?id=1K-oNwBu2svshAkmifU9wISKPMvKgeKy4
To: /workspace/train_20230909.jsonl
100%|██████████| 3.30M/3.30M [00:00<00:00, 12.0MB/s]


In [None]:
!unzip ./dataset/public_test.zip -d ./dataset/test
!unzip ./dataset/train_data.zip -d ./dataset/train

In [3]:
import torch, json, utils, os
import numpy as np
from functools import partial
from datasets import load_metric
from dataset import Wav2VecDataset
from torch.utils.data import DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, TrainingArguments, Trainer

In [4]:
def custom_collate(processor, batch):
    processed_batch = [
        processor(i["input_values"], text=i["label"], sampling_rate=16000) for i in batch
    ]
    input_features = [{"input_values": i.input_values[0]} for i in processed_batch]
    input_features = processor.feature_extractor.pad(input_features, padding=True, return_tensors="pt")
    if batch[0]["label"] is None:
        return input_features
    label_features = [{"input_ids": i.labels} for i in processed_batch]
    label_features = processor.tokenizer.pad(
        label_features,
        padding=True,
        return_tensors="pt"
    )
    label_features = label_features["input_ids"].masked_fill(label_features.attention_mask.ne(1), -100)
    input_features["labels"] = label_features
    return input_features

In [6]:
def train_test_split(root_path, notation_file, test_size=0.3):
    notations = utils.load_annotation(notation_file)
    dataset = Wav2VecDataset(root_path, [i["file"] for i in notations], notations)
    N = len(dataset)
    print(f"Len dataset: {N}")
    if test_size == 0:
        return dataset, dataset
    train_size = int(N * (1-test_size))
    train_set, valid_set = torch.utils.data.random_split(dataset, [train_size, N-train_size])
    return train_set, valid_set

In [8]:
train_ds, valid_ds = train_test_split("./dataset/train/Train/", "./train_20230909.jsonl", test_size=0)
len(train_ds), len(valid_ds)

Len dataset: 7490


(7490, 7490)

In [9]:
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
tokenizer = Wav2Vec2CTCTokenizer("./vn_base_vocab.json")
extractor = Wav2Vec2FeatureExtractor.from_pretrained("nguyenvulebinh/wav2vec2-large-vi-vlsp2020")
processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=extractor)

Downloading (…)rocessor_config.json:   0%|          | 0.00/263 [00:00<?, ?B/s]

In [11]:
model = Wav2Vec2ForCTC.from_pretrained(
    "./wav2vec_large_st_2/checkpoint-5616",
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    layerdrop=0.025,
    mask_time_prob=0.3,
    ctc_zero_infinity=True,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)
model.freeze_feature_encoder()
model.gradient_checkpointing_enable()
# model.config.vocab_size = 111
# 768 - base model
# model.lm_head = torch.nn.Linear(in_features=1024, out_features=111, bias=True)

In [12]:
wer_metric = load_metric("wer")
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

  wer_metric = load_metric("wer")


Downloading builder script:   0%|          | 0.00/1.90k [00:00<?, ?B/s]

In [15]:
batch_size = 32
steps = len(train_ds) // batch_size
steps

234

In [16]:
training_args = TrainingArguments(
    output_dir="./wav2vec_large_st_3",
    save_total_limit=1,
    
    group_by_length=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy="steps",
    # fp16=True,
    gradient_checkpointing=True,
    learning_rate=7e-5,
    # warmup_steps=500,
    
    save_steps=steps,
    eval_steps=steps,
    logging_steps=steps,
    max_steps=steps*25,
    
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False
)
trainer = Trainer(
    model=model,
    data_collator=partial(custom_collate, processor),
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    tokenizer=processor.feature_extractor,
)
trainer.train()

Step,Training Loss,Validation Loss,Wer
234,0.2067,0.053868,0.104931
468,0.2102,0.054371,0.104425
702,0.2009,0.04934,0.105873
936,0.1965,0.048345,0.104189
1170,0.1951,0.048484,0.10307
1404,0.1787,0.044628,0.10234
1638,0.1731,0.04418,0.102634
1872,0.1593,0.043067,0.105355
2106,0.1616,0.041671,0.101681
2340,0.1603,0.039239,0.102599


KeyboardInterrupt: 

In [17]:
model = Wav2Vec2ForCTC.from_pretrained(
    "./wav2vec_large_st_3/checkpoint-3744",
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    layerdrop=0,
    mask_time_prob=0.25,
    ctc_zero_infinity=True,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)
model.freeze_feature_encoder()
model.gradient_checkpointing_enable()
# model.config.vocab_size = 111
# 768 - base model
# model.lm_head = torch.nn.Linear(in_features=1024, out_features=111, bias=True)

In [18]:
training_args = TrainingArguments(
    output_dir="./wav2vec_large_st_4",
    save_total_limit=1,
    
    group_by_length=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy="steps",
    # fp16=True,
    gradient_checkpointing=True,
    learning_rate=5e-5,
    # warmup_steps=500,
    weight_decay=0.01,
    save_steps=steps,
    eval_steps=steps,
    logging_steps=steps,
    max_steps=steps*25,
    
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False
)
trainer = Trainer(
    model=model,
    data_collator=partial(custom_collate, processor),
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    tokenizer=processor.feature_extractor,
)
trainer.train()

Step,Training Loss,Validation Loss,Wer
234,0.0879,0.03548,0.100373
468,0.0985,0.033936,0.099266
702,0.0915,0.033657,0.099855
936,0.0933,0.032799,0.098783
1170,0.0967,0.033864,0.099325
1404,0.0879,0.032423,0.099207
1638,0.0833,0.030122,0.098171
1872,0.0745,0.030185,0.098642
2106,0.0798,0.028683,0.097641
2340,0.0822,0.028028,0.09783


KeyboardInterrupt: 

In [20]:
import os
test_ds = Wav2VecDataset("./dataset/test/public_test", os.listdir("./dataset/test/public_test"))
len(test_ds)

1299

In [21]:
def test_collate(processor, batch):
    processed_batch = [
        processor(i["input_values"], sampling_rate=i["sample_rate"]) for i in batch
    ]
    input_features = [{"input_values": i.input_values[0]} for i in processed_batch]
    input_features = processor.pad(input_features, padding=True, return_tensors="pt")
    input_features["id"] = [i["file"] for i in batch]
    return input_features

In [22]:
test_loader = DataLoader(test_ds, shuffle=False, batch_size=1, collate_fn=partial(test_collate, processor))
len(test_loader)

1299

In [23]:
drive_download("1-sKIn6-MMt1S5wbNuK2etmJZYztTlS8b", "./4gram_correct.arpa")

Downloading...
From: https://drive.google.com/uc?id=1-sKIn6-MMt1S5wbNuK2etmJZYztTlS8b
To: /workspace/4gram_correct.arpa
100%|██████████| 2.78M/2.78M [00:00<00:00, 22.1MB/s]


In [24]:
from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2ProcessorWithLM

vocab_dict = processor.tokenizer.get_vocab()
sorted_vocab_dict = {k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}
decoder = build_ctcdecoder(
    labels=list(sorted_vocab_dict.keys()),
    kenlm_model_path="./4gram_correct.arpa",
)
processor_with_lm = Wav2Vec2ProcessorWithLM(
    feature_extractor=extractor,
    tokenizer=tokenizer,
    decoder=decoder
)

Loading the LM will be faster if you build a binary file.
Reading /workspace/4gram_correct.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
Only 898 unigrams passed as vocabulary. Is this small or artificial data?


In [25]:
def wav2vec_lm_inference(model, test_loader, processor_with_lm, device=None):
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.eval()
    model = model.to(device)
    pred_sentences = {}
    for idx, batch in enumerate(test_loader, 1):
        X_test =  batch["input_values"].to(device)
        file_test = batch["id"]
        with torch.set_grad_enabled(False):
            logits = model(input_values=X_test).logits
        # logits = torch.argmax(logits, dim=-1)
        logits = logits.cpu().clone().detach().numpy()
        # transcriptions = processor.batch_decode(logits, skip_special_tokens=True)
        transcriptions = processor_with_lm.batch_decode(logits, beam_width=64).text
        for file_id, trans in zip(file_test, transcriptions):
            pred_sentences[file_id] = trans
        # pred_sentences += transcriptions
        print("\r", end="")
        print(f"\r {idx} / {len(test_loader)}", end = "" if idx != len(test_loader) else "\n")
    return pred_sentences

In [26]:
pred_sens = wav2vec_lm_inference(model, test_loader, processor_with_lm, torch.device("cuda:0"))

 1299 / 1299


In [28]:
with open("./w2v_v3_test_sentences.json", "w", encoding="utf-8") as f:
    json.dump(pred_sens, f, ensure_ascii=False)
    f.close()

In [27]:
pred_sens

{'gkr2nW4Zxwv9ay6iR1od5jP.wav': 'bạn ơi tăng giúp mình cái đèn hắt ở mức 8%',
 'WoGuEH1SVfdNDpDGXxcTbHJ.wav': 'em ơi tắt giúp anh cái đèn treo tường với',
 'SDU8HKVUyOmIpvwzWcMgJW7.wav': 'tầm 6 giờ 8 phút thì mở cho mình cái máy sưởi nha bạn',
 'QaiOJwzIYKRxVrLDQHxODfn.wav': 'em ơi giúp anh đóng cái lò nướng nhé',
 'vda3D3tnIOiwHHelu3W7PJM.wav': 'chị ơi bật cho em cái đèn bếp',
 '7xuHw67gjTqZkCqun94ts1A.wav': 'bật cho tôi cái bóng sợi đốt ở đầu hè',
 'DGm677nnv8ThrLnoGJdgGrV.wav': 'giảm xuống giúp anh 58% em nhé hãy để nó trong vòng 5 tiếng 31 phút',
 'd9bNQAZe5phqi3U9sFwQKW1.wav': 'mở cho tôi cái mành ở bên trong phòng học với',
 'YqMTIJJJIqGmym4oXXjkEet.wav': 'này em ơi đóng cho anh cái vòi nước số 5 nhá',
 '0DjYlBtXtacDRsbwi6xogBS.wav': 'mở hộ mình cái máy lạnh với',
 'mhp0355TKqlzJHmjmEjg8yh.wav': 'đi đóng giúp anh cái màn cuốn',
 'LpkMj8SFJENNcQLYqC6S79p.wav': 'mình sử dụng phòng vệ sinh để đi tắm xong rồi nhé không cần phải làm gì nữa đâu',
 'j3ppLDJGGWJp5rIsxSeB5Wq.wav': 'bạn bậ