# **Colab install lib (optional)**

In [None]:
!pip install accelerate -U

In [9]:
!pip install transformers soundfile datasets jiwer gdown

# **Download dataset (optional)**

In [3]:
!mkdir ./dataset
import gdown
def drive_download(idx, output):
    url = 'https://drive.google.com/uc?id=' + idx
    gdown.download(url, output, quiet=False)
drive_download("1ZBL3h6bHMmd8MIUNXqg72PucUkC9ZSWJ", "./dataset/train_data.zip")
drive_download("1ZepptsTrVSjQEx-dpBBmQ2b7xYFLn_64", "./dataset/public_test.zip")
file_id = '1K_07kix1OgBGO2FNPh-Lxqr1yLbtqFYt'
output_file = './dataset/train.jsonl'
!gdown https://drive.google.com/uc?id=$file_id -O $output_file



Downloading...
From: https://drive.google.com/uc?id=1ZBL3h6bHMmd8MIUNXqg72PucUkC9ZSWJ
To: /content/dataset/train_data.zip
100%|██████████| 733M/733M [00:03<00:00, 215MB/s]
Downloading...
From: https://drive.google.com/uc?id=1ZepptsTrVSjQEx-dpBBmQ2b7xYFLn_64
To: /content/dataset/public_test.zip
100%|██████████| 131M/131M [00:00<00:00, 142MB/s]


Downloading...
From: https://drive.google.com/uc?id=1K_07kix1OgBGO2FNPh-Lxqr1yLbtqFYt
To: /content/dataset/train.jsonl
  0% 0.00/3.30M [00:00<?, ?B/s]100% 3.30M/3.30M [00:00<00:00, 205MB/s]


In [None]:
!unzip ./dataset/public_test.zip -d ./dataset/test
!unzip ./dataset/train_data.zip -d ./dataset/train

# **Train**

In [11]:
import torch
import json
import utils
import numpy as np
from functools import partial
from datasets import load_metric
from dataset import Wav2VecDataset
from torch.utils.data import DataLoader
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

In [12]:
def custom_collate(processor, batch):
    processed_batch = [
        processor(i["input_values"], text=i["label"], sampling_rate=16000) for i in batch
    ]
    input_features = [{"input_values": i.input_values[0]} for i in processed_batch]
    input_features = processor.pad(input_features, padding=True, return_tensors="pt")
    if batch[0]["label"] is None:
        return input_features
    label_features = [{"input_ids": i.labels} for i in processed_batch]
    with processor.as_target_processor():
        label_features = processor.pad(
            label_features,
            padding=True,
            return_tensors="pt"
        )
    label_features = label_features["input_ids"].masked_fill(label_features.attention_mask.ne(1), -100)
    input_features["labels"] = label_features
    return input_features

In [16]:
def train_test_split(root_path, notation_file, test_size=0.3)   :
    notations = utils.load_annotation(notation_file)
    dataset = Wav2VecDataset(root_path, [i["file"] for i in notations], notations)
    N = len(dataset)
    print(f"Len dataset: {N}")
    train_size = int(N * (1-test_size))
    train_set, valid_set = torch.utils.data.random_split(dataset, [train_size, N-train_size])
    return train_set, valid_set

In [18]:
train_ds, valid_ds = train_test_split("./dataset/train/Train/", "./dataset/train.jsonl", test_size=0.2)
len(train_ds), len(valid_ds)

Len dataset: 7490


(5992, 1498)

In [None]:
# nguyenvulebinh/wav2vec2-base-vietnamese-250h
processor = Wav2Vec2Processor.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h")
model = Wav2Vec2ForCTC.from_pretrained(
    "nguyenvulebinh/wav2vec2-base-vietnamese-250h",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id
)
model.freeze_feature_encoder()
model

In [None]:
import numpy as np
wer_metric = load_metric("wer")
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}

In [21]:
!mkdir ./checkpoint

In [None]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./checkpoint/wav2vec_v1.0",
    group_by_length=True,
    per_device_train_batch_size=16,
    evaluation_strategy="epoch",
    num_train_epochs=10,
    fp16=True,
    gradient_checkpointing=True,
    logging_strategy="epoch",
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_steps=100,
    save_total_limit=1,
)

In [None]:
trainer = Trainer(
    model=model,
    data_collator=partial(custom_collate, processor),
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    tokenizer=processor.feature_extractor,
)

In [None]:
trainer.train()

# **Test**

In [23]:
import os
test_ds = Wav2VecDataset("./dataset/test/public_test", os.listdir("./dataset/test/public_test"))
len(test_ds)

1299

In [24]:
test_ds[2]

{'input_values': array([0.        , 0.        , 0.        , ..., 0.00018311, 0.00030518,
        0.00012207]),
 'sample_rate': 16000,
 'label': None,
 'file': 'MdZLyoyXTyox1kXsuoPhP1g.wav'}

In [25]:
def test_collate(processor, batch):
    processed_batch = [
        processor(i["input_values"], sampling_rate=i["sample_rate"]) for i in batch
    ]
    input_features = [{"input_values": i.input_values[0]} for i in processed_batch]
    input_features = processor.pad(input_features, padding=True, return_tensors="pt")
    input_features["id"] = [i["file"] for i in batch]
    return input_features

In [26]:
test_loader = DataLoader(test_ds, shuffle=False, batch_size=16, collate_fn=partial(test_collate, processor))
len(test_loader)

82

In [None]:
def wav2vec_inference(model, test_loader, processor, device=None):
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.eval()
    model = model.to(device)
    pred_sentences = {}
    for idx, batch in enumerate(test_loader, 1):
        X_test =  batch["input_values"].to(device)
        X_test = X_test.half()
        file_test = batch["id"]
        with torch.set_grad_enabled(False):
            with torch.autocast("cuda", dtype=torch.float16, enabled=True):
                logits = model(input_values=X_test).logits
        logits = torch.argmax(logits, dim=-1)
        transcriptions = processor.batch_decode(logits, skip_special_tokens=True)
        for file_id, trans in zip(file_test, transcriptions):
            pred_sentences[file_id] = trans
        # pred_sentences += transcriptions
        print("\r", end="")
        print(f"\r {idx} / {len(test_loader)}", end = "" if idx != len(test_loader) else "\n")
    return pred_sentences

In [None]:
pred_sens = wav2vec_inference(model, test_loader, processor, torch.device("cuda:0"))

 82 / 82


In [None]:
len(pred_sens.items())

1299

In [None]:
with open("./dataset/test_sentences.json", "w", encoding="utf-8") as f:
    json.dump(pred_sens, f)
    f.close()

In [None]:
json.load(open("./test_sentences.json", "r", encoding="utf-8"))

{'TGZD2rv26WxpkV9LRAxwmVb.wav': 'tăng ở cầu thang a mức tăng là 41',
 'QgEpeOnXqmu2gclgZxvGaVo.wav': 'tắt camera của huy đi cho tôi',
 'OswACb2vqKb3OjMJkqveMn2.wav': 'kiểm tra cho mình vào lúc 15 giờ 16 phút nhé',
 'grFpJzV3KfLoCwTY8Wh2o9K.wav': 'đóng hộ anh cái cửa cuốn số 19',
 'bwFuSjwUxFmpjSJTE7X9M2N.wav': 'bật cho mình cái laptop với mình cần làm việc',
 'm0HfaJWGmgixojj38qoEFxC.wav': 'giảm cho mình đi khoảng 45 nhé',
 '32NTIZYh7psVhOSFCeQ2W7M.wav': 'tối nhỉ em kiểm tra cho anh cái đèn ốp trần nhá',
 'xiGv2T0iECUFKxGOA11UxUe.wav': 'sẽ có khách tới nhà ở trong vòng 5 tiếng 39 phút đấy nhé',
 'cSgsEultnRfuhUTbvKyP31Z.wav': 'tôi không muốn dùng cái tiền sảnh nữa không cần thư giãn đâu',
 'bYyk93Vxd5UUS1uDbGjqPFT.wav': 'bạn ơi bật cho mình cái bóng hắt',
 'VFp67Vea96jY863powsqpPj.wav': 'giảm cho anh cái bóng làm việc ở trong phòng thu với',
 'LKx7t6NAyensF2DZF4oiVD3.wav': 'kiểm tra lúc 1 giờ 2 phút em nhé',
 'XeSFaqh4yKUg18SSn9dCLLk.wav': 'bạn ơi tăng cái bóng thả lên 36 giúp mình với

# **Just store config.json, preprocessor_config.json, pytorch_model.bin**

In [None]:
!mkdir ./checkpoint_wav2vec_v1.0
!cp -r ./checkpoint/wav2vec_v1.0/checkpoint-3500/config.json ./checkpoint_wav2vec_v1.0
!cp -r ./checkpoint/wav2vec_v1.0/preprocessor_config.json ./checkpoint_wav2vec_v1.0
!cp -r ./checkpoint/wav2vec_v1.0/pytorch_model.bin ./checkpoint_wav2vec_v1.0