In [21]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Created by Tu Zhen on 2021/06/11
import logging
from argparse import ArgumentParser
from collections import Counter
from datetime import datetime
from pathlib import Path

# Token in subtitle text marking a sentence boundary after the preceding word.
__END = '[END]'


def __time_ms(time_in_second):
    """Convert a time in seconds to whole milliseconds.

    The product is rounded to the nearest integer instead of truncated:
    binary floating point cannot represent most decimal fractions exactly,
    so e.g. ``1.001 * 1000`` evaluates to ``1000.9999999999999`` and plain
    ``int()`` would give 1000 instead of 1001.

    Args:
        time_in_second: Time in seconds (int or float).

    Returns:
        The time in milliseconds as an ``int``.
    """
    return int(round(time_in_second * 1000))


def __utt_id_from_time(name, begin, end):
    """Build an utterance id from a file name and a time span.

    Args:
        name: Identifier of the source file.
        begin: Start time in seconds.
        end: End time in seconds.

    Returns:
        ``'<name>-<begin_ms>-<end_ms>'`` with both millisecond values
        zero-padded to 7 digits, e.g. ``('test', 1.234, 3.456)`` ->
        ``'test-0001234-0003456'``.
    """
    begin_ms = format(__time_ms(begin), '0>7d')
    end_ms = format(__time_ms(end), '0>7d')
    return f'{name}-{begin_ms}-{end_ms}'


def __read_lines(in_file: Path):
    """Read a UTF-8 text file and return its lines with surrounding whitespace stripped.

    The file is decoded as ``utf-8-sig`` so that a leading byte-order mark
    (written by many subtitle editors) is dropped. With plain ``utf-8`` the
    first line would keep an invisible U+FEFF prefix, which ``str.strip``
    does not remove and which makes ``int()`` on the first SRT index line
    fail. Files without a BOM decode exactly as with ``utf-8``.

    Args:
        in_file: Path of the input file.

    Returns:
        A list with one stripped string per line of the file.
    """
    with in_file.open(encoding='utf-8-sig') as file:
        return [line.strip() for line in file]


def __parse_srt_timestamp(timestamp_str):
    """Convert an SRT timestamp (``HH:MM:SS,fff``) to seconds.

    Args:
        timestamp_str: Timestamp text such as ``'00:01:02,500'``.

    Returns:
        The elapsed time in seconds.
    """
    parsed = datetime.strptime(timestamp_str, '%H:%M:%S,%f')
    whole_seconds = (parsed.hour * 60 + parsed.minute) * 60 + parsed.second
    fraction = parsed.microsecond / 1e6
    return whole_seconds + fraction


def parse_srt(in_file: Path):
    """Parse an SRT subtitle file; only single-line subtitles are supported.

    Every segment must occupy exactly four lines: its 1-based index, a
    ``begin --> end`` timestamp line, a single text line and a blank
    separator line. Any number of trailing blank lines is accepted, and the
    separator after the last segment may be omitted.

    Args:
        in_file: Path of the subtitle file.

    Returns:
        List of segments, each a ``(begin, end, text)`` tuple with ``begin``
        and ``end`` in seconds.

    Raises:
        ValueError: If the file does not follow the layout described above.
    """
    lines = __read_lines(in_file)
    # Files often end with several newlines; those blank lines carry no
    # content and would otherwise break the 4-lines-per-segment grouping.
    while lines and lines[-1] == '':
        lines.pop()
    if not lines:
        return []
    # Re-add the final separator (plus an empty text line when the last
    # segment's text was empty) so the lines split into 4-line groups.
    lines.extend([''] * (-len(lines) % 4))

    segments = []
    for segment_index, offset in enumerate(range(0, len(lines), 4), start=1):
        index_line, timestamp_line, text, separator = lines[offset:offset + 4]
        actual_index = int(index_line)
        if actual_index != segment_index:
            raise ValueError(f'字幕文件格式错误: {segment_index} != '
                             f'{actual_index}.')
        timestamps = timestamp_line.split(' --> ')
        if len(timestamps) != 2:
            raise ValueError(f'字幕文件格式错误: 第{segment_index}段时间戳无效: '
                             f'{timestamp_line!r}, {in_file}.')
        if separator != '':
            # A non-blank 4th line means a multi-line subtitle; reject it
            # explicitly instead of misparsing the following lines.
            raise ValueError(f'字幕文件格式错误: 第{segment_index}段不是单行字幕: '
                             f'{in_file}.')
        begin = __parse_srt_timestamp(timestamps[0])
        end = __parse_srt_timestamp(timestamps[1])
        segments.append((begin, end, text))
    return segments


def __text_to_symbols(text):
    """Convert subtitle text into one boundary symbol per word.

    Every word contributes ``''``. An ``[END]`` token is not a word: it turns
    the symbol of the word before it into ``[END]``. A leading ``[END]`` has
    no preceding word and is ignored (it used to raise IndexError).
    Words are split on any whitespace run, so double spaces do not create
    phantom empty words that would shift the boundary positions.
    """
    symbols = []
    for word in text.split():
        if word != __END:
            symbols.append('')
        elif symbols:
            symbols[-1] = __END
    return symbols


def __parse_srt_dir(srt_dir: Path, file_type=None):
    """Parse every ``.srt`` file in a directory into per-utterance symbols.

    Args:
        srt_dir: Directory containing the subtitle files.
        file_type: Only files whose name starts with this prefix are read
            (``''`` reads all). ``None`` (default) falls back to a
            module-level ``file_type`` variable, as set by the notebook's
            evaluation loop, or to ``''`` when no such variable exists.

    Returns:
        Mapping from utterance id (see ``__utt_id_from_time``) to the list of
        boundary symbols of that utterance.
    """
    if file_type is None:
        # Backward compatibility: callers that do not pass file_type keep
        # the previous behaviour of using the global set by the caller's loop.
        file_type = globals().get('file_type', '')

    id_to_symbols = {}
    for srt_file in sorted(srt_dir.iterdir()):
        if not srt_file.is_file() or srt_file.suffix != '.srt':
            continue
        if not srt_file.name.startswith(file_type):
            continue
        for begin, end, text in parse_srt(srt_file):
            utt_id = __utt_id_from_time(srt_file.stem, begin, end)
            id_to_symbols[utt_id] = __text_to_symbols(text)
    return id_to_symbols


def eval_from_srt(true_srt_dir: Path, pred_srt_dir: Path):
    """Evaluate predicted sentence boundaries against reference subtitles.

    Both directories are parsed into per-utterance symbol lists. The
    reference and predicted lists are compared position by position, except
    for the last position, which is the end of the utterance and not scored.
    The tp/fp/fn counts and precision, recall and F1 for the ``[END]``
    symbol are printed.

    Args:
        true_srt_dir: Directory with the reference subtitle files.
        pred_srt_dir: Directory with the predicted subtitle files.

    Raises:
        ValueError: If the two directories do not yield the same set of
            utterance ids.
    """
    id_to_sym_true = __parse_srt_dir(true_srt_dir)
    id_to_sym_pred = __parse_srt_dir(pred_srt_dir)
    if id_to_sym_true.keys() != id_to_sym_pred.keys():
        raise ValueError('真实和预测的字幕文件id不一致.')

    tp_counter = Counter()
    fp_counter = Counter()
    fn_counter = Counter()
    length_mismatches = 0
    for utt_id, true_sym_list in id_to_sym_true.items():
        pred_sym_list = id_to_sym_pred[utt_id]
        if len(true_sym_list) != len(pred_sym_list):
            # zip() below silently truncates to the shorter list; count it so
            # the truncation is at least reported.
            length_mismatches += 1
        # [:-1]: the last position is the end of the utterance, not scored.
        for true_sym, pred_sym in zip(true_sym_list[:-1], pred_sym_list[:-1]):
            if true_sym == pred_sym:
                tp_counter[true_sym] += 1
            else:
                fp_counter[pred_sym] += 1
                fn_counter[true_sym] += 1

    if length_mismatches:
        logging.warning('%d utterance(s) have different word counts in the '
                        'true and predicted subtitles; only the common prefix '
                        'was scored.', length_mismatches)

    tp, fp, fn = tp_counter[__END], fp_counter[__END], fn_counter[__END]
    print(f'tp = {tp}, fp = {fp}, fn = {fn}')

    # Guard every ratio: with no predicted or no true boundaries (or both
    # metrics at zero) the plain formulas would divide by zero.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2. * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    print(f'precision = {precision:.3%}, recall = {recall:.3%}, '
          f'f1 score = {f1:.3f}')

In [24]:
true_srt_dir = Path('./ali/true/')
pred_srt_dir = Path('./ali/pred/')
# Evaluate all files first (empty prefix), then each filename prefix.
# NOTE: the loop variable must stay named ``file_type``: __parse_srt_dir
# reads the module-level ``file_type`` to select which files to evaluate.
for file_type in ['', 'bcut', 'BV']:
    header = (f"Evaluate Result for {file_type} Type" if file_type
              else "Evaluate Result for Two Types")
    print(header)
    eval_from_srt(true_srt_dir, pred_srt_dir)
    print("=====================================================")

Evaluate Result for Two Types
tp = 2086, fp = 986, fn = 639
precision = 67.904%, recall = 76.550%, f1 score = 0.720
Evaluate Result for bcut Type
tp = 564, fp = 270, fn = 166
precision = 67.626%, recall = 77.260%, f1 score = 0.721
Evaluate Result for BV Type
tp = 1522, fp = 716, fn = 473
precision = 68.007%, recall = 76.291%, f1 score = 0.719
