In [None]:
# Load public dataset from University of Tokyo
!wget http://ss-takashi.sakura.ne.jp/corpus/jsut_ver1.1.zip
!unzip jsut_ver1.1.zip

path = 'jsut_ver1.1/basic5000/'
df = pd.read_csv(path + 'transcript_utf8.txt', header = None, delimiter = ":", names=["path", "sentence"], index_col=False)
df["path"] = df["path"].map(lambda x: path + 'wav/' + x + ".wav")
df.head()

jsut_voice_train = Dataset.from_pandas(df)

In [None]:
# Load the Japanese Common Voice splits (train+validation for training, test for
# evaluation) and drop the metadata columns that are not needed, one split at a time.
common_voice_train, common_voice_test = (
    load_dataset('common_voice', 'ja', split=split_name).remove_columns(
        ["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"]
    )
    for split_name in ('train+validation', 'test')
)

# Prepend the JSUT examples to the Common Voice training split.
common_voice_train = datasets.concatenate_datasets(
    [jsut_voice_train, common_voice_train]
)

In [None]:
# MeCab tagger created with the "-Owakati" option.
wakati = MeCab.Tagger("-Owakati")
# Short alias for the neologdn normalizer.
neo = neologdn.normalize

# Unwanted tokens, removed from every transcript.
# Written as a raw string so that Python does not interpret the regex escapes
# as (invalid) string escape sequences; the character class itself is unchanged.
chars_to_ignore_regex = r'[\,\、\。\．\「\」\…\？\・\!\-\;\:\"\“\%\‘\”\�]'

def remove_special_characters(batch):
    """Normalize a transcript and strip the ignored characters from it.

    Args:
        batch: A single dataset example (mapping) with a ``"sentence"`` string.

    Returns:
        The same ``batch`` object, with ``batch["sentence"]`` replaced by the
        neologdn-normalized text with every match of ``chars_to_ignore_regex``
        removed and surrounding whitespace stripped.
    """
    # Call the normalizer. The former `neologdn.normalize=(...)` was a chained
    # assignment: it skipped normalization and overwrote `neologdn.normalize`.
    normalized = neologdn.normalize(batch["sentence"]).strip()
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', normalized).strip()
    return batch

# Clean the transcripts of both splits with the same function.
common_voice_train, common_voice_test = (
    split.map(remove_special_characters)
    for split in (common_voice_train, common_voice_test)
)

In [None]:
# Load the "wer" metric object that compute_metrics uses to score predictions.
wer_metric = load_metric("wer")

def compute_metrics(pred):
    """Score a batch of model predictions with the WER metric.

    Args:
        pred: Object exposing ``predictions`` (logits) and ``label_ids``.
            ``label_ids`` is modified in place: every entry equal to -100 is
            overwritten with the tokenizer's pad token id.

    Returns:
        ``{"wer": value}`` where ``value`` is the result of ``wer_metric.compute``.
    """
    # -100 entries (presumably loss-ignored padding) cannot be decoded as-is,
    # so swap them for the pad token id before turning labels back into text.
    ignored_positions = pred.label_ids == -100
    pred.label_ids[ignored_positions] = processor.tokenizer.pad_token_id

    # Greedy decoding: most likely token id at every position.
    predicted_ids = np.argmax(pred.predictions, axis=-1)
    predicted_text = processor.batch_decode(predicted_ids)
    # we do not want to group tokens when computing the metrics
    reference_text = processor.batch_decode(pred.label_ids, group_tokens=False)

    return {
        "wer": wer_metric.compute(predictions=predicted_text, references=reference_text)
    }

In [None]:
def get_polynomial_decay_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.2, last_epoch=-1
):
    """Build a LambdaLR schedule: linear warmup, then polynomial decay to ``lr_end``.

    Args:
        optimizer: Optimizer whose ``defaults["lr"]`` is taken as the initial rate.
        num_warmup_steps: Number of steps over which the rate rises linearly from 0.
        num_training_steps: Step at which the rate reaches ``lr_end``.
        lr_end: Final learning rate; must be smaller than the initial rate.
        power: Exponent of the polynomial decay.
        last_epoch: Passed through to ``LambdaLR``.

    Returns:
        A ``LambdaLR`` wrapping ``optimizer``.

    Raises:
        AssertionError: If the initial learning rate is not greater than ``lr_end``.
    """

    lr_init = optimizer.defaults["lr"]
    assert lr_init > lr_end, f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})"

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        elif current_step >= num_training_steps:
            # `>=` instead of `>`: at the boundary step the decay formula also
            # yields lr_end, but it divides by zero when there are no decay
            # steps (num_training_steps == num_warmup_steps).
            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
        else:
            lr_range = lr_init - lr_end
            decay_steps = num_training_steps - num_warmup_steps
            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
            decay = lr_range * pct_remaining ** power + lr_end
            return decay / lr_init  # as LambdaLR multiplies by lr_init

    return LambdaLR(optimizer, lr_lambda, last_epoch)

# Trainer subclass that plugs in the custom polynomial-decay learning-rate schedule.
class PolyTrainer(Trainer):
    """Trainer that uses ``get_polynomial_decay_schedule_with_warmup``.

    The constructor is inherited unchanged from ``Trainer``.
    """

    def create_scheduler(self, num_training_steps: int, optimizer=None):
        """Create the polynomial-decay scheduler and store it on the trainer.

        Args:
            num_training_steps: Total number of training steps for the schedule.
            optimizer: Optional optimizer to schedule; defaults to
                ``self.optimizer``. Accepted so that callers passing the
                ``optimizer`` keyword (as newer Trainer code does) keep working.

        Returns:
            The scheduler, which is also assigned to ``self.lr_scheduler``.
        """
        self.lr_scheduler = get_polynomial_decay_schedule_with_warmup(
            self.optimizer if optimizer is None else optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=num_training_steps,
        )
        return self.lr_scheduler

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """Create the optimizer first, then the scheduler that wraps it."""
        self.create_optimizer()
        self.create_scheduler(num_training_steps)

In [None]:
import random
import re

import MeCab
import pykakasi
import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

#config
wakati = MeCab.Tagger("-Owakati")
# Raw string so that Python does not interpret the regex escapes as (invalid)
# string escape sequences; the resulting pattern is unchanged.
# NOTE(review): this rebinds `chars_to_ignore_regex` to a smaller character set
# than the one used when cleaning the training transcripts - confirm intended.
chars_to_ignore_regex = r'[\,\、\。\．\「\」\…\？\・]'

#load model
# `save_dir` is expected to be defined by an earlier cell.
processor = Wav2Vec2Processor.from_pretrained(save_dir)
test_model = Wav2Vec2ForCTC.from_pretrained(save_dir)
test_model.to("cuda")
# Converts 48 kHz audio to the 16 kHz rate passed to the processor in evaluate().
resampler = torchaudio.transforms.Resample(48_000, 16_000)

#load testdata
test_dataset = load_dataset("common_voice", "ja", split="test")
wer = load_metric("wer")

# Preprocessing the datasets.
def speech_file_to_array_fn(batch):
    """Clean one example's transcript and load its audio as a 16 kHz array.

    Args:
        batch: A single dataset example with ``"sentence"`` (text) and
            ``"path"`` (audio file path).

    Returns:
        The same ``batch`` with ``"sentence"`` passed through ``wakati.parse``
        and stripped of ``chars_to_ignore_regex`` matches, and a new
        ``"speech"`` entry holding the resampled audio as a NumPy array.
    """
    batch["sentence"] = wakati.parse(batch["sentence"]).strip()
    batch["sentence"] = re.sub(chars_to_ignore_regex,'', batch["sentence"]).strip()
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    if sampling_rate == 48_000:
        # Common case: reuse the shared 48 kHz -> 16 kHz resampler.
        to_16k = resampler
    else:
        # Honor the file's real rate instead of assuming 48 kHz; otherwise the
        # audio would be resampled by the wrong ratio.
        to_16k = torchaudio.transforms.Resample(sampling_rate, 16_000)
    batch["speech"] = to_16k(speech_array).squeeze().numpy()
    return batch

# Apply the transcript cleanup and audio loading/resampling to every test example.
test_dataset = test_dataset.map(speech_file_to_array_fn)

# Run the model on a batch of examples and decode its predictions to text.
def evaluate(batch):
    """Transcribe a batch of audio arrays with ``test_model``.

    Args:
        batch: Batched examples whose ``"speech"`` entry holds the audio
            arrays, passed to the processor as 16 kHz audio.

    Returns:
        The same ``batch`` with a ``"pred_strings"`` entry containing the
        decoded greedy (argmax) predictions.
    """
    inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

    # Follow whatever device the model lives on instead of hard-coding "cuda".
    device = test_model.device
    with torch.no_grad():
        logits = test_model(inputs.input_values.to(device), attention_mask=inputs.attention_mask.to(device)).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_strings"] = processor.batch_decode(pred_ids)
    return batch

# Transcribe the whole test set, 8 examples per batch.
result = test_dataset.map(evaluate, batched=True, batch_size=8)

# ".2f" prints two decimals; the former "2f" spec only set a minimum width of 2.
print("WER: {:.2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))

# print some results
# NOTE(review): `common_voice_test_transcription` and the "input_values" /
# "labels" columns of `common_voice_test` are not created in the cells shown
# here - confirm an earlier cell provides them.
pick = random.randint(0, len(common_voice_test_transcription)-1)
input_dict = processor(common_voice_test["input_values"][pick], return_tensors="pt", padding=True)
# Inference only: skip building the autograd graph.
with torch.no_grad():
    logits = test_model(input_dict.input_values.to("cuda")).logits
pred_ids = torch.argmax(logits, dim=-1)[0]

print("Prediction:")
print(processor.decode(pred_ids).strip())

print("\nLabel:")
print(processor.decode(common_voice_test['labels'][pick]))