In [1]:
import os
import json
import torch
import torchaudio
import numpy as np
import pandas as pd

from tqdm import tqdm
from glob import glob
from scipy.spatial import distance
from textblob import TextBlob

from datasets import Dataset
from transformers import AutoProcessor, Wav2Vec2ForCTC

  from .autonotebook import tqdm as notebook_tqdm


In [35]:
# Restrict this process to a single physical GPU. This only takes effect if it
# runs before the first CUDA call in the kernel.
VISIBLE_GPU_INDEX = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = VISIBLE_GPU_INDEX

In [3]:
# Directory holding one JSON transcript file per test utterance.
TEST_JSON_DIR = '/home/kyoungmin_temp/laboratory/kor2kor/dataset/aihub_older_jeju/test_circum_01'

# Sort so the sample order is deterministic across runs.
test_txt_path_lst = glob(f'{TEST_JSON_DIR}/*.json')
test_txt_path_lst.sort()
len(test_txt_path_lst)

4498

In [16]:
# Collect, for every test sample, the path of its audio file together with the
# dialect transcript and the standard-language transcript.
test_info = {'path': [], 'dialect': [], 'standard': []}

for sample_path in tqdm(test_txt_path_lst):
    # The audio lives in a sibling directory with the same file stem. Swap only
    # the file extension: a blanket .replace('json', 'wav') would also rewrite
    # any "json" that happens to appear elsewhere in the path.
    speech_path = sample_path.replace('test_circum_01', 'test_speech_circum_01')
    test_info['path'].append(os.path.splitext(speech_path)[0] + '.wav')

    # The transcripts are Korean text; read them as UTF-8 explicitly instead of
    # relying on the platform's default encoding.
    with open(sample_path, encoding='utf-8') as f:
        sample_json = json.load(f)

    segments = sample_json['transcription']['segments']
    dialect_txt = ' '.join(x['dialect'] for x in segments)
    # A segment whose 'standard' field is None falls back to its dialect text.
    standard = ' '.join(
        x['dialect'] if x['standard'] is None else x['standard'] for x in segments
    )
    test_info['dialect'].append(dialect_txt)
    test_info['standard'].append(standard)

100%|███████████████████████████████████████████████████████████████| 4498/4498 [00:09<00:00, 455.04it/s]


In [17]:
# Wrap the collected columns ('path', 'dialect', 'standard') in a Dataset
# and display its summary.
test_ds = Dataset.from_dict(mapping=test_info)
test_ds

Dataset({
    features: ['path', 'dialect', 'standard'],
    num_rows: 4498
})

In [18]:
import librosa
from pyctcdecode import build_ctcdecoder
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoModelForCTC,
    AutoTokenizer,
    Wav2Vec2ProcessorWithLM,
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline

# Load the model, the tokenizer and the other modules needed for prediction.
ASR_CHECKPOINT = "42MARU/ko-spelling-wav2vec2-conformer-del-1s"
HF_CACHE_DIR = '/home/kyoungmin_temp/HF_CACHE'

model = AutoModelForCTC.from_pretrained(ASR_CHECKPOINT, cache_dir=HF_CACHE_DIR)
feature_extractor = AutoFeatureExtractor.from_pretrained(ASR_CHECKPOINT, cache_dir=HF_CACHE_DIR)
tokenizer = AutoTokenizer.from_pretrained(ASR_CHECKPOINT, cache_dir=HF_CACHE_DIR)

# CTC beam-search decoder over the tokenizer vocabulary; no KenLM language
# model is attached (kenlm_model_path=None).
vocab_labels = list(tokenizer.encoder.keys())
beamsearch_decoder = build_ctcdecoder(labels=vocab_labels, kenlm_model_path=None)

Found entries of length > 1 in alphabet. This is unusual unless style is BPE, but the alphabet was not recognized as BPE type. Is this correct?


Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize


In [19]:
# Bundle the feature extractor, tokenizer and beam-search decoder together.
processor = Wav2Vec2ProcessorWithLM(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
    decoder=beamsearch_decoder,
)

# Plug the modules defined above into the pipeline used for actual prediction.
asr_components = {
    "model": model,
    "tokenizer": processor.tokenizer,
    "feature_extractor": processor.feature_extractor,
    "decoder": processor.decoder,
}
# device=-1 keeps the speech-recognition model on the CPU.
asr_pipeline = AutomaticSpeechRecognitionPipeline(device=-1, **asr_components)

In [32]:
# Transcribe every test utterance with the ASR pipeline; the resulting texts
# are kept in the same order as test_ds.
decoder_options = {"beam_width": 100}

input_txt = []
for sample in tqdm(test_ds):
    # Audio is loaded (and resampled if necessary) at 16 kHz.
    waveform, _ = librosa.load(sample['path'], sr=16000)
    prediction = asr_pipeline(inputs=waveform, decoder_kwargs=decoder_options)
    input_txt.append(prediction["text"])

100%|████████████████████████████████████████████████████████████████| 4498/4498 [19:48<00:00,  3.79it/s]

dialect





In [33]:
# Sanity check: number of ASR transcripts produced by the loop above.
len(input_txt)

4498

In [50]:
import torch
from transformers import pipeline

# Fine-tuned KoBART checkpoint used to rewrite dialect text into standard Korean.
MODEL_NAME = 'KoBART_base_v2-trial2'

# Use the first visible GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Pass the selected device to the pipeline explicitly; previously `device` was
# computed but never used, so the model stayed on the default device.
pipe = pipeline(
    "translation", model=MODEL_NAME, max_length=40, device=device
)

You passed along `num_labels=3` with an incompatible id to label map: {'0': 'NEGATIVE', '1': 'POSITIVE'}. The number of labels wil be overwritten to 2.
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'NEGATIVE', '1': 'POSITIVE'}. The number of labels wil be overwritten to 2.


In [45]:
# Re-list the test transcript JSON files (sorted for a deterministic order).
TEST_JSON_DIR = '/home/kyoungmin_temp/laboratory/kor2kor/dataset/aihub_older_jeju/test_circum_01'

test_txt_path_lst = glob(f'{TEST_JSON_DIR}/*.json')
test_txt_path_lst.sort()
len(test_txt_path_lst)

4498

In [53]:
def _drop_subsumed_tail(tokens, width):
    """Drop trailing `width`-word groups that add no new characters (in place).

    The last `width` words are compared with the `width` words before them.
    While the character set of one group is contained in the character set of
    the other, the trailing group is removed. The loop only runs while two
    full groups are available, so it can never pop from a list that is too
    short.
    """
    while len(tokens) >= 2 * width:
        tail_chars = set(''.join(tokens[-width:]))
        prev_chars = set(''.join(tokens[-2 * width:-width]))
        # Same condition as "one group's character set equals the character
        # set of both groups together".
        if not (tail_chars <= prev_chars or prev_chars <= tail_chars):
            break
        del tokens[-width:]


def output_processing(result_txt):
    """Trim repeated words or phrases from the end of a generated sentence.

    The text is split on single spaces and cut at the first empty token
    (i.e. at the first double space). Trailing repetitions are then removed
    in three passes: single words, 2-word groups and 3-word groups. The
    2- and 3-word passes first remove exact repeats and then near-repeats
    whose characters are already covered by the neighbouring group.

    Args:
        result_txt: Raw text produced by the generation model.

    Returns:
        The cleaned text, with words joined by single spaces.
    """
    tokens = result_txt.strip(' ').replace('\n', '').split(' ')
    # Keep only what precedes the first empty token, if there is one.
    if '' in tokens:
        tokens = tokens[:tokens.index('')]

    # Pass 1: collapse a repeated final word, never going below two tokens.
    # The length guard also prevents indexing tokens[-2] on a 1-element list.
    while len(tokens) > 2 and tokens[-1] == tokens[-2]:
        tokens.pop()

    # Pass 2: repeated 2-word groups at the end.
    if len(tokens) >= 4:
        while tokens[-2:] == tokens[-4:-2]:
            del tokens[-2:]
            if len(tokens) == 4:
                break
        _drop_subsumed_tail(tokens, 2)

    # Pass 3: repeated 3-word groups at the end.
    if len(tokens) >= 6:
        while tokens[-3:] == tokens[-6:-3]:
            del tokens[-3:]
        _drop_subsumed_tail(tokens, 3)

    return ' '.join(tokens)

In [57]:
import nltk.translate.bleu_score as bleu
import nlptutti as metrics

In [None]:
# Per-sample evaluation of the ASR -> translation chain: each ASR transcript is
# translated, post-processed, and scored against the reference standard text
# with sentence-level BLEU and character error rate (CER).
bleu_result = {'path': [], 'bleu_score': [], 'dialect': [], 'standard': [], 'predict': [], 'cer_score': []}

# zip() would silently drop samples if the two lists got out of sync.
if len(test_txt_path_lst) != len(input_txt):
    raise ValueError(
        f"Got {len(test_txt_path_lst)} transcript files but {len(input_txt)} ASR outputs"
    )

# zip() has no length, so pass `total` to give tqdm a real progress bar and ETA.
for sample_path, conformer_output in tqdm(zip(test_txt_path_lst, input_txt), total=len(input_txt)):
    # The transcripts are Korean text; read them as UTF-8 explicitly.
    with open(sample_path, encoding='utf-8') as f:
        sample_json = json.load(f)

    dialect_txt = conformer_output
    # Reference text: a segment whose 'standard' field is None falls back to
    # its dialect text.
    ground_truth = ' '.join(
        x['dialect'] if x['standard'] is None else x['standard']
        for x in sample_json['transcription']['segments']
    )
    model_result = pipe(dialect_txt, num_return_sequences=1, pad_token_id=0)[0]['translation_text']
    post_process_txt = output_processing(model_result)

    # BLEU works on token lists (one reference per sample).
    reference = [ground_truth.split()]
    model_output = post_process_txt.split()
    bleu_score = bleu.sentence_bleu(reference, model_output)
    # CER compares two strings, so pass the hypothesis text rather than the
    # token list used for BLEU.
    cer_score = metrics.get_cer(ground_truth, post_process_txt)['cer']

    bleu_result['path'].append(os.path.basename(sample_path))
    bleu_result['bleu_score'].append(bleu_score)
    bleu_result['dialect'].append(dialect_txt)
    bleu_result['standard'].append(ground_truth)
    bleu_result['predict'].append(post_process_txt)
    bleu_result['cer_score'].append(cer_score)

8it [00:08,  1.04s/it]