In [None]:
import librosa
import soundfile as sf
from matplotlib import pyplot as plt
import japanize_matplotlib
import numpy as np
from IPython.display import Audio

# conda activate pyopenjtalk_julius

# set path
import sys

sys.path.append("/home/takeshun256/PausePrediction")

# import standard library
from pprint import pprint
from pathlib import Path
import yaml
import pandas as pd

# import pyopenjtalk
# from pyopenjtalk import run_frontend, g2p
import jaconv
import re
from tqdm import tqdm

# import own library
from config import DATA_DIR, DATA_TAKESHUN256_DIR, SRC_DIR, DATA_IN_ROOT_DIR

# from src.analyze_jmac.text_preprocessing import (
#     AudiobookScriptPreprocessor as Preprocessor,
# )
from src.analyze_jmac.mecab import mecab_wakati_generator, mecab_detailed_generator

# define path
corpus_name = "jmac"
exp_name = "03_VAD_Adjusted"

# Experiment directory and the corpus-level text/audio dictionary file.
exp_dir = Path(DATA_TAKESHUN256_DIR) / corpus_name / exp_name
yaml_file_path = Path(DATA_IN_ROOT_DIR) / corpus_name / "text_audio_dict_new.yaml"

# Fail fast and report which path is missing instead of a bare AssertionError.
assert exp_dir.exists(), f"experiment directory not found: {exp_dir}"
assert yaml_file_path.exists(), f"yaml file not found: {yaml_file_path}"

In [None]:
# 音声波形抽出
# wav_path -> wav, sr
def extract_waveform(audio_file_path, sr=24000):
    """Load an audio file with librosa and return its samples and sample rate.

    Args:
        audio_file_path: Path of the audio file handed to ``librosa.load``.
        sr (int): Sample rate passed to ``librosa.load`` (default 24000).

    Returns:
        tuple: ``(waveform, sample_rate)`` exactly as produced by
        ``librosa.load(..., sr=sr, mono=True)``.
    """
    loaded = librosa.load(audio_file_path, sr=sr, mono=True)
    samples, rate = loaded
    return samples, rate


# wav -> db変換
def convert_db(waveform):
    """Convert a waveform to decibels with ``librosa.amplitude_to_db``.

    ``ref=np.max`` is passed, so the reference level is derived from the
    input itself rather than being a fixed constant.

    Args:
        waveform (np.ndarray): Amplitude samples.

    Returns:
        np.ndarray: dB values returned by ``librosa.amplitude_to_db``.
    """
    # TODO: how does librosa.power_to_db(waveform) differ, and which of the
    # two is the appropriate conversion here?
    return librosa.amplitude_to_db(waveform, ref=np.max)


# 連続区間抽出
# db -> bool_list
def run_length_encoding(arr, min_run_length=3):
    """Flag every element that belongs to a run of at least ``min_run_length``.

    A run is a maximal stretch of consecutive equal values, whatever that
    value is.

    Args:
        arr (np.ndarray): 1-D array to scan.
        min_run_length (int): Minimum run length for a run to be flagged.

    Returns:
        np.ndarray: Boolean array of ``len(arr)`` elements; True where the
        element sits inside a sufficiently long run.
    """
    # Indices at which the value differs from the previous element.
    change_points = np.where(np.diff(arr) != 0)[0] + 1
    # Run boundaries including both ends of the array.
    boundaries = np.concatenate(([0], change_points, [len(arr)]))
    lengths = np.diff(boundaries)
    # Expand the per-run verdict back to one flag per element.
    return np.repeat(lengths >= min_run_length, lengths)

def run_length_encoding_range(arr, min_run_length=3):
    """Return the ranges of truthy runs that are at least ``min_run_length`` long.

    Args:
        arr (np.ndarray): 1-D array (typically boolean) to scan.
        min_run_length (int): Minimum run length for a run to be reported.

    Returns:
        list[tuple[int, int, int]]: ``(start, end, length)`` for each
        qualifying run, where ``end`` is exclusive (``start + length``).
        Values are built-in ``int`` so the result can be serialized without
        numpy scalar types leaking out. Empty input yields ``[]``.
    """
    # An empty array has no runs; without this guard a non-positive
    # min_run_length would make the code below index arr[0].
    if len(arr) == 0:
        return []

    # Indices at which the value differs from the previous element.
    run_starts = np.where(np.diff(arr) != 0)[0] + 1
    starts = np.concatenate(([0], run_starts))
    ends = np.concatenate((run_starts, [len(arr)]))

    # Keep only runs that are long enough and whose (constant) value is truthy.
    return [
        (int(start), int(end), int(end - start))
        for start, end in zip(starts, ends)
        if end - start >= min_run_length and arr[start]
    ]

# Pause区間抽出
# db, db_threshold, time_threshold, sr -> pause_bool_list
# Treat a stretch as a pause when the dB value stays below db_threshold for at least time_threshold
def detect_pause_position(
    db_sequence, db_threshold=-50, time_threshold=50 / 1000, sample_rate=24000
):
    """Mark the samples that belong to a pause.

    A sample is part of a pause when its dB value is below ``db_threshold``
    and it lies inside a run of at least ``time_threshold * sample_rate``
    consecutive such samples.

    Args:
        db_sequence (np.ndarray): Waveform converted to dB.
        db_threshold (float): dB level below which a sample counts as silent.
        time_threshold (float): Minimum duration of a silent run, in seconds,
            for it to count as a pause.
        sample_rate (int): Samples per second, used to turn ``time_threshold``
            into a sample count.

    Returns:
        np.ndarray: Boolean mask with one entry per element of ``db_sequence``.
    """
    min_samples = int(time_threshold * sample_rate)
    is_quiet = db_sequence < db_threshold
    # True for every element in a long-enough run, loud or quiet.
    in_long_run = run_length_encoding(is_quiet, min_samples)
    # Keep only the long runs that are quiet.
    return is_quiet & in_long_run


# pause区間付きの波形の可視化
def plot_db_with_pause(db, sr, db_threshold, time_threshold, xlim=None):
    """Plot a dB curve, its threshold line, and the detected pause regions.

    Args:
        db (np.ndarray): dB sequence to plot.
        sr (int): Sample rate used to convert sample indices to seconds.
        db_threshold (float): dB threshold, drawn as a horizontal red line.
        time_threshold (float): Minimum pause duration in seconds.
        xlim: Passed to ``Axes.set_xlim``; ``None`` by default.
    """
    _, axis = plt.subplots(figsize=(20, 5))
    seconds = np.arange(len(db)) / sr
    axis.plot(seconds, db, label="db")

    # Horizontal line at the dB threshold.
    threshold_style = {
        "color": "r",
        "linestyle": "-",
        "linewidth": 2,
        "alpha": 0.7,
        "label": "db_threshold",
    }
    axis.axhline(y=db_threshold, **threshold_style)

    # Shade the detected pause regions over the fixed -80..0 band.
    pause_mask = detect_pause_position(db, db_threshold, time_threshold, sr)
    axis.fill_between(seconds, -80, 0, where=pause_mask, facecolor="b", alpha=0.5)

    axis.set_xlim(xlim)
    axis.legend()
    plt.show()


# 波形の可視化
def plot_wavform(waveform, sr, xlim=None):
    """Plot a waveform against time in seconds.

    Args:
        waveform (np.ndarray): Samples to plot.
        sr (int): Sample rate used to convert sample indices to seconds.
        xlim: Passed to ``Axes.set_xlim``; ``None`` by default.
    """
    seconds = np.arange(len(waveform)) / sr
    _, axis = plt.subplots(figsize=(20, 5))
    axis.plot(seconds, waveform, label="waveform")
    axis.set_xlim(xlim)
    axis.legend()
    plt.show()


# 音声再生ボタン生成
def play_button(waveform, sr):
    """Display an auto-playing audio widget for the given samples.

    Args:
        waveform: Audio samples handed to ``IPython.display.Audio``.
        sr (int): Sample rate handed to ``Audio`` as ``rate``.
    """
    player = Audio(waveform, rate=sr, autoplay=True)
    display(player)


# アライメントの抽出
# lab_path -> df_lab
def read_lab(lab_path):
    """Read a .lab alignment file into a DataFrame.

    Each non-blank line must contain exactly three whitespace-separated
    fields: ``start end phoneme``. Blank or whitespace-only lines are skipped.

    Args:
        lab_path: Path of the lab file.

    Returns:
        pd.DataFrame | None: ``None`` (after printing a message) when the file
        does not exist. Otherwise a frame with columns ``start``, ``end``,
        ``phoneme``, ``phoneme_idx`` and ``duration``; the columns are present
        even when the file contains no entries. ``phoneme_idx`` is the
        zero-based line index in the file, so skipped blank lines are counted.
        ``start``/``end`` are parsed as floats (presumably seconds -- confirm
        against the tool that produced the file).

    Raises:
        ValueError: If a non-blank line does not have exactly three fields or
            its start/end fields are not numeric.
    """
    # Missing lab file: report and signal with None.
    if not Path(lab_path).exists():
        print(f"{lab_path} does not exist.")
        return None

    columns = ["start", "end", "phoneme", "phoneme_idx", "duration"]
    rows = []
    with open(lab_path, "r") as f:
        for phoneme_idx, line in enumerate(f):
            fields = line.split()
            # Lines read from a file keep their newline, so compare on the
            # split result rather than against "" to detect blank lines.
            if not fields:
                continue
            start, end, phoneme = fields
            start, end = float(start), float(end)
            rows.append(
                {
                    "start": start,
                    "end": end,
                    "phoneme": phoneme,
                    "phoneme_idx": phoneme_idx,
                    "duration": end - start,
                }
            )
    # Explicit columns keep the schema stable for an empty file.
    return pd.DataFrame(rows, columns=columns)


# アライメントの可視化
def plot_phoneme_alignment(df, xlim=None):
    """Plot an alignment table as dashed boundaries with centered labels.

    Args:
        df (pd.DataFrame): Alignment table whose rows unpack to exactly
            ``(start, end, label)``. A dashed vertical line is drawn at each
            ``start`` and the label is written midway between ``start`` and
            ``end``.
        xlim: Passed to ``Axes.set_xlim``; ``None`` by default.
    """
    fig, ax = plt.subplots(figsize=(20, 2))
    for start, end, label in df.values:
        ax.axvline(start, color="gray", linestyle="--")
        ax.text((start + end) / 2, 0.5, label, ha="center", va="bottom", fontsize=8)
    ax.set_xlim(xlim)
    ax.set_xlabel("Time (seconds)")
    fig.tight_layout()
    # No legend: none of the artists above carries a label, so a legend call
    # would only produce a "no artists with labels" warning.
    plt.show()


# 並べて可視化する。
def plot_all(
    df_temp, wav_path, sample_rate=24000, db_threshold=-50, time_threshold=50 / 1000
):
    """Plot waveform, dB curve with pauses, and alignment; then play the audio.

    Three stacked panels are drawn: the waveform, the dB curve with the
    threshold line and shaded pause regions, and the alignment labels.

    Args:
        df_temp (pd.DataFrame): Alignment table whose rows unpack to exactly
            ``(start, end, label)``; ``start``/``end`` are used as positions
            on the seconds axis.
        wav_path: Audio file to load.
        sample_rate (int): Sample rate used when loading the audio.
        db_threshold (float): dB level below which a sample counts as silent.
        time_threshold (float): Minimum pause duration in seconds.
    """
    wav, sr = extract_waveform(wav_path, sr=sample_rate)
    db = convert_db(wav)
    duration = len(wav) / sr
    xlim = (0, duration)

    print(wav_path)
    print("wav.shape:", wav.shape)
    print("seconds:", duration)

    fig, (ax_wav, ax_db, ax_align) = plt.subplots(
        3, 1, figsize=(20, 10), gridspec_kw={"height_ratios": [4, 4, 2]}
    )

    # Panel 1: waveform.
    ax_wav.plot(np.arange(len(wav)) / sr, wav, label="waveform")
    ax_wav.set_xlim(xlim)
    ax_wav.legend()

    # Panel 2: dB curve, threshold line and shaded pause regions.
    x = np.arange(len(db)) / sr
    ax_db.plot(x, db, label="db")
    ax_db.axhline(
        y=db_threshold,
        color="r",
        linestyle="-",
        linewidth=2,
        alpha=0.7,
        label="db_threshold",
    )
    pause_position = detect_pause_position(db, db_threshold, time_threshold, sr)
    ax_db.fill_between(x, -80, 0, where=pause_position, facecolor="b", alpha=0.5)
    ax_db.set_xlim(xlim)
    ax_db.legend()

    # Panel 3: alignment boundaries with the label centered in each segment.
    for start, end, label in df_temp.values:
        ax_align.axvline(end, color="gray", linestyle="--")
        ax_align.axvline(start, color="gray", linestyle="--")
        ax_align.text(
            (start + end) / 2, 0.5, label, ha="center", va="bottom", fontsize=8
        )
    ax_align.set_xlim(xlim)
    ax_align.set_xlabel("Time (seconds)")
    # No legend on this panel: none of its artists carries a label, so a
    # legend call would only produce a "no artists with labels" warning.

    plt.show()

    play_button(wav, sr)


# 並べて可視化する。
def plot_all2(
    df_temp, wav_path, sample_rate=24000, db_threshold=-50, time_threshold=50 / 1000
):
    """Plot waveform (with alignment overlay) and dB curve; then play the audio.

    Two stacked panels are drawn: the waveform with the alignment boundaries
    and labels drawn on top of it, and the dB curve with the threshold line
    and shaded pause regions.

    Args:
        df_temp (pd.DataFrame): Alignment table whose rows unpack to exactly
            ``(start, end, label)``; ``start``/``end`` are used as positions
            on the seconds axis.
        wav_path: Audio file to load.
        sample_rate (int): Sample rate used when loading the audio.
        db_threshold (float): dB level below which a sample counts as silent.
        time_threshold (float): Minimum pause duration in seconds.
    """
    wav, sr = extract_waveform(wav_path, sr=sample_rate)
    db = convert_db(wav)
    xlim = (0, len(wav) / sr)

    print(wav_path)
    print("wav.shape:", wav.shape)
    print("seconds:", len(wav) / sr)

    _, (wave_ax, db_ax) = plt.subplots(
        2, 1, figsize=(20, 10), gridspec_kw={"height_ratios": [4, 4]}
    )

    # Top panel: waveform.
    wave_times = np.arange(len(wav)) / sr
    wave_ax.plot(wave_times, wav, label="waveform")
    wave_ax.set_xlim(xlim)
    wave_ax.legend()

    # Bottom panel: dB curve, threshold line and shaded pause regions.
    db_times = np.arange(len(db)) / sr
    db_ax.plot(db_times, db, label="db")
    threshold_style = {
        "color": "r",
        "linestyle": "-",
        "linewidth": 2,
        "alpha": 0.7,
        "label": "db_threshold",
    }
    db_ax.axhline(y=db_threshold, **threshold_style)
    pause_mask = detect_pause_position(db, db_threshold, time_threshold, sr)
    db_ax.fill_between(db_times, -80, 0, where=pause_mask, facecolor="b", alpha=0.5)
    db_ax.set_xlim(xlim)
    db_ax.legend()

    # Alignment overlay on the waveform panel.
    boundary_style = {"color": "gray", "linestyle": "--"}
    for seg_start, seg_end, seg_label in df_temp.values:
        wave_ax.axvline(seg_end, **boundary_style)
        wave_ax.axvline(seg_start, **boundary_style)
        wave_ax.text(
            (seg_start + seg_end) / 2,
            0.5,
            seg_label,
            ha="center",
            va="bottom",
            fontsize=12,
        )
    wave_ax.set_xlim(xlim)
    wave_ax.set_xlabel("Time (seconds)")
    wave_ax.legend()

    plt.show()

    play_button(wav, sr)


def get_pause_ranges(
    db_sequence, db_threshold=-50, time_threshold=0.05, sample_rate=24000
):
    """Return the sample ranges of the pauses found in a dB sequence.

    Args:
        db_sequence (np.ndarray): Waveform converted to dB.
        db_threshold (float): dB level below which a sample counts as silent.
        time_threshold (float): Minimum pause duration in seconds.
        sample_rate (int): Samples per second.

    Returns:
        list: ``(start, end, length)`` tuples in samples, as produced by
        ``run_length_encoding_range`` on the pause mask.
    """
    min_samples = int(time_threshold * sample_rate)
    pause_mask = detect_pause_position(
        db_sequence, db_threshold, time_threshold, sample_rate
    )
    return run_length_encoding_range(pause_mask, min_samples)


def classfy_pause(
    db_sequence, lab_path, sample_rate=24000, db_threshold=-50, time_threshold=0.05
):
    """Classify each detected pause by the lab labels that touch it.

    Pauses are detected from ``db_sequence``. For every pause, the lab entries
    whose start or end boundary (converted to samples) lies inside the pause
    range are collected, and the pause gets the first label present among
    ``silE``, ``silB``, ``sil``, ``pau``, ``sp`` (in that priority order), or
    ``"RP"`` when none of them is present.

    Args:
        db_sequence (np.ndarray): Waveform converted to dB.
        lab_path: Path of the lab file read with ``read_lab``.
        sample_rate (int): Samples per second; also used to convert the lab
            times to sample positions.
        db_threshold (float): dB level below which a sample counts as silent.
        time_threshold (float): Minimum pause duration in seconds.

    Returns:
        list[list]: One ``[start, end, length, pause_type]`` entry per pause,
        with start/end/length in samples.

    Raises:
        FileNotFoundError: If ``read_lab`` reports that the lab file is missing.
    """
    pause_position = detect_pause_position(
        db_sequence, db_threshold, time_threshold, sample_rate
    )
    sample_threshold = int(time_threshold * sample_rate)
    pause_ranges = run_length_encoding_range(pause_position, sample_threshold)

    df_lab = read_lab(lab_path)
    if df_lab is None:
        # read_lab returns None for a missing file; fail with a clear error
        # instead of an opaque TypeError on the subscript below.
        raise FileNotFoundError(f"lab file not found: {lab_path}")

    # Loop-invariant: lab boundaries in samples and their labels.
    if df_lab.empty:
        # An empty lab file may come back without columns; treat as no labels.
        phoneme_start = np.array([], dtype=float)
        phoneme_end = np.array([], dtype=float)
        phonemes = np.array([], dtype=object)
    else:
        phoneme_start = df_lab["start"].values * sample_rate
        phoneme_end = df_lab["end"].values * sample_rate
        phonemes = df_lab["phoneme"].values

    # Checked in order; the first label found wins.
    label_priority = ("silE", "silB", "sil", "pau", "sp")

    ans = []
    for pause_start, pause_end, pause_length in pause_ranges:
        # Does the start or the end boundary of a lab entry fall in the pause?
        # NOTE(review): a lab entry that fully contains the pause (starts
        # before it and ends after it) matches neither test -- confirm that
        # this is intended.
        is_start_include = (pause_start <= phoneme_start) & (phoneme_start <= pause_end)
        is_end_include = (pause_start <= phoneme_end) & (phoneme_end <= pause_end)

        include_phonemes = phonemes[is_start_include | is_end_include]
        print(include_phonemes)
        pause_type = next(
            (label for label in label_priority if label in include_phonemes), "RP"
        )

        ans.append([pause_start, pause_end, pause_length, pause_type])
    return ans

def extract_pause_ranges_from_wavpath(
    wav_path, sample_rate=24000, db_threshold=-50, time_threshold=50 / 1000
):
    """Load a wav file and return the sample ranges of its pauses.

    Args:
        wav_path: Audio file to load.
        sample_rate (int): Sample rate used for loading and for converting
            ``time_threshold`` to samples.
        db_threshold (float): dB level below which a sample counts as silent.
        time_threshold (float): Minimum pause duration in seconds.

    Returns:
        list: ``(start, end, length)`` tuples in samples, as produced by
        ``run_length_encoding_range`` on the pause mask.
    """
    waveform, _ = extract_waveform(wav_path, sr=sample_rate)
    db_values = convert_db(waveform)

    min_samples = int(time_threshold * sample_rate)
    pause_mask = detect_pause_position(
        db_values, db_threshold, time_threshold, sample_rate
    )
    return run_length_encoding_range(pause_mask, min_samples)


# ---- settings ----
# Thresholds for pause detection.
db_threshold = -30
# time_threshold = 50 / 1000 # 50ms
time_threshold = 200 / 1000  # 200ms
sample_rate = 24000
# ------------------
# Smoke test on one chapter: plot the waveform and the dB curve with pauses,
# then print the pause ranges computed two ways (from the in-memory dB
# sequence, and again from the file path with the same parameters).
# NOTE(review): absolute, machine-specific path -- consider deriving it from
# DATA_TAKESHUN256_DIR as the batch cell below does.
wav_p = "/data2/takeshun256/jmac_split_and_added_lab/audiobook_25/audiobook_25_109.wav"
wav, sr = extract_waveform(wav_p, sr=sample_rate)
db = convert_db(wav)
plot_wavform(wav, sr)
plot_db_with_pause(db, sr, db_threshold, time_threshold)
pause_ranges = get_pause_ranges(db, db_threshold, time_threshold, sample_rate)
print(pause_ranges)
pause_ranges_by_wav = extract_pause_ranges_from_wavpath(wav_p, sample_rate, db_threshold, time_threshold)
print(pause_ranges_by_wav)

In [None]:
# Load the per-audiobook dictionary from the experiment directory.
morp_phons_yaml_path = exp_dir / "text_audio_dict_new_with_morp_phons_and_lab.yaml"

morp_phons_yaml_data = yaml.safe_load(morp_phons_yaml_path.read_text())

In [None]:
# Display the loaded entry for audiobook_25.
morp_phons_yaml_data["audiobook_25"]

In [None]:
# Display the "morp_lab" value of chapter "000" of audiobook_0.
morp_phons_yaml_data["audiobook_0"]["000"]["morp_lab"]

In [None]:
# ---- settings ----
# Thresholds for pause detection.
# db_threshold = -50
db_threshold = -30
# time_threshold = 50 / 1000  # 50ms
time_threshold = 80 / 1000  # 80ms
# time_threshold = 100 / 1000  # 100ms
# time_threshold = 200 / 1000  # 200ms
sample_rate = 24000
# ------------------
wav_p = "/data2/takeshun256/jmac_split_and_added_lab/audiobook_0/audiobook_0_000.wav"
wav, sr = extract_waveform(wav_p, sr=sample_rate)
db = convert_db(wav)
# plot_wavform(wav, sr)
# plot_db_with_pause(db, sr, db_threshold, time_threshold)
pause_ranges = get_pause_ranges(db, db_threshold, time_threshold, sample_rate)
print(pause_ranges)

# Alignment rows come from the loaded YAML data here; the same kind of rows
# can alternatively be read from the .labm file next to the wav
# (wav_p with ".wav" replaced by ".labm").
ml_data = morp_phons_yaml_data["audiobook_0"]["000"]["morp_lab"]
print(ml_data)

# Each entry is "start end label" separated by single spaces.
alignment_columns = ["start", "end", "phoneme"]
df_temp = pd.DataFrame(
    [entry.split(" ") for entry in ml_data], columns=alignment_columns
)
for time_column in ("start", "end"):
    df_temp[time_column] = df_temp[time_column].astype(float)

plot_all2(
    df_temp,
    wav_p,
    sample_rate=sample_rate,
    db_threshold=db_threshold,
    time_threshold=time_threshold,
)

In [None]:
# ---- settings ----
# Thresholds for pause detection.
# db_threshold = -50
db_threshold = -30
# time_threshold = 50 / 1000  # 50ms
time_threshold = 80 / 1000  # 80ms
# time_threshold = 100 / 1000  # 100ms
# time_threshold = 200 / 1000  # 200ms
sample_rate = 24000
# ------------------

new_morp_phons_yaml_data_path = (
    exp_dir / "text_audio_dict_new_with_morp_phons_and_lab_with_pause_fix_runencode_80ms.yaml"
)

# For every chapter with an existing wav file, attach the wav path and the
# detected pause ranges (as "start end length" strings) to the chapter info.
new_morp_phons_yaml_data = {}
for audiobook_name, info in tqdm(morp_phons_yaml_data.items()):
    audiobook_dict = {}
    print(f"[INFO] audiobook_name: {audiobook_name}")
    for chapter_name, chapter_info in info.items():
        wav_path = (
            Path(DATA_TAKESHUN256_DIR)
            / "jmac_split_and_added_lab"
            / audiobook_name
            / f"{audiobook_name}_{chapter_name}.wav"
        )

        if not wav_path.exists():
            print(f"[INFO] {wav_path} does not exist.")
            continue

        wav, sr = extract_waveform(wav_path, sr=sample_rate)
        try:
            db = convert_db(wav)
        except ValueError:
            # Chapters whose waveform cannot be converted are skipped
            # (presumably empty audio -- confirm).
            print(f"[INFO] {wav_path} is not for convert_db.")
            continue
        pause_ranges = get_pause_ranges(db, db_threshold, time_threshold, sample_rate)

        pause_ranges_str = [
            " ".join(str(value) for value in pause_range[:3])
            for pause_range in pause_ranges
        ]

        # Work on a shallow copy so the loaded morp_phons_yaml_data is not
        # modified in place; re-running this cell then starts from the same
        # input every time.
        chapter_out = dict(chapter_info)
        chapter_out["wav_path"] = str(wav_path)
        chapter_out["pause_ranges_str"] = pause_ranges_str
        audiobook_dict[chapter_name] = chapter_out

    new_morp_phons_yaml_data[audiobook_name] = audiobook_dict
    print(f"[INFO] {audiobook_name} is done.")


# with open(new_morp_phons_yaml_data_path, "w") as f:
#     yaml.dump(new_morp_phons_yaml_data, f, allow_unicode=True)

# pickle
import pickle

# NOTE(review): this writes pickle bytes to a path that ends in ".yaml".
# The path is kept as is because the commented-out loader in a later cell
# reads this same path; renaming it needs a coordinated change.
with open(new_morp_phons_yaml_data_path, "wb") as f:
    pickle.dump(new_morp_phons_yaml_data, f)

print(f"[INFO] {new_morp_phons_yaml_data_path} is saved.")

In [None]:
# Display the augmented entry (with wav_path and pause_ranges_str) for audiobook_25.
new_morp_phons_yaml_data["audiobook_25"]

In [None]:
# The full file is too large, so write a reduced version: only three keys per
# chapter, and only the first MAX_PAUSE_RANGES pause ranges.
new_morp_phons_yaml_data_small_path = (
    exp_dir / "text_audio_dict_new_with_morp_phons_and_lab_with_pause_small_fix_runencode_80ms.yaml"
)

MAX_PAUSE_RANGES = 10
REQUIRED_KEYS = ("wav_path", "pause_ranges_str", "morp_lab")

new_morp_phons_yaml_data_small = {}
for audiobook_name, info in new_morp_phons_yaml_data.items():
    audiobook_dict = {}
    print(f"[INFO] audiobook_name: {audiobook_name}")
    for chapter_name, chapter_info in info.items():
        # Chapters missing any of the required keys are left out.
        if any(key not in chapter_info for key in REQUIRED_KEYS):
            print(f"[INFO] {chapter_name} is skipped.")
            continue
        chapter_dict = {
            "wav_path": chapter_info["wav_path"],
            "pause_ranges_str": chapter_info["pause_ranges_str"][:MAX_PAUSE_RANGES],
            "morp_lab": chapter_info["morp_lab"],
        }

        audiobook_dict[chapter_name] = chapter_dict
    new_morp_phons_yaml_data_small[audiobook_name] = audiobook_dict

# allow_unicode=True emits non-ASCII characters verbatim, so fix the file
# encoding instead of relying on the locale default.
with open(new_morp_phons_yaml_data_small_path, "w", encoding="utf-8") as f:
    yaml.dump(new_morp_phons_yaml_data_small, f, allow_unicode=True)

In [None]:
# The data cannot be read back with yaml.safe_load, so save it as a pickle.
import pickle

# with_suffix swaps only the file extension; a plain str.replace("yaml", "pkl")
# would also rewrite any "yaml" that appears elsewhere in the path.
with open(new_morp_phons_yaml_data_path.with_suffix(".pkl"), "wb") as f:
    pickle.dump(new_morp_phons_yaml_data, f, protocol=4)

# load
# with open(new_morp_phons_yaml_data_path, "rb") as f:
#     new_morp_phons_yaml_data = pickle.load(f)

In [None]:
# Display the path the reduced YAML file was written to.
new_morp_phons_yaml_data_small_path

In [None]:
# ---- settings ----
# Thresholds for pause detection.
# db_threshold = -50
db_threshold = -30
# time_threshold = 50 / 1000  # 50ms
time_threshold = 200 / 1000  # 200ms
sample_rate = 24000
# ------------------
# wav_p = "/data2/takeshun256/jmac_split_and_added_lab/audiobook_25/audiobook_25_100.wav"
wav_p = "/data2/takeshun256/jmac_split_and_added_lab/audiobook_0/audiobook_0_000.wav"
wav, sr = extract_waveform(wav_p, sr=sample_rate)
db = convert_db(wav)
# plot_wavform(wav, sr)
# plot_db_with_pause(db, sr, db_threshold, time_threshold)
pause_ranges = get_pause_ranges(db, db_threshold, time_threshold, sample_rate)
print(pause_ranges)

# Read the alignment rows from the .labm file that sits next to the wav.
ml = wav_p.replace(".wav", ".labm")
with open(ml, "r") as f:
    ml_data = [line.strip() for line in f]
pprint(ml_data)

# Each row is "start end label" separated by single spaces.
alignment_columns = ["start", "end", "phoneme"]
df_temp = pd.DataFrame(
    [entry.split(" ") for entry in ml_data], columns=alignment_columns
)
for time_column in ("start", "end"):
    df_temp[time_column] = df_temp[time_column].astype(float)

plot_all2(
    df_temp,
    wav_p,
    sample_rate=sample_rate,
    db_threshold=db_threshold,
    time_threshold=time_threshold,
)