In [6]:
import pandas as pd
import re
import matplotlib.pyplot as plt
import json
test_name = "4ch_LSaug"
path = f"/scratch/acf15294oh/workspace/apsipa/model/wav2vec2_FCNN/results/wav2vec2_ctc_{test_name}/1234/flame_test.txt"
target_flame_folder = f"/scratch/acf15294oh/database/behavior_TM01/old_text_aug"

hyp_row_data = []
with open(path, 'r') as file:
    for line in file:
        hyp_row_data.append(line)

In [7]:
def process_transcription_data(hyp_data):
    results = []
    for i, hyp_line in enumerate(hyp_data):
        num = i % 2
                
        if(num == 0):
            speech = []
            file = hyp_line.replace('\n', '')
            speech.append(file)
            name = "_".join(file.split("_")[3:]).replace(".wav" , ".txt")
            txt_file = target_flame_folder + "/" + file.split("_")[0] + "_" + file.split("_")[1] + "_" + name

            with open(txt_file, 'r') as file:
                txt_content = file.readlines()
            
            intervals = []  # 初期区間を設定
            last_end = 0  # 最初のスピーチの開始時間を設定
            
            for line in txt_content:
                parts = line.strip().split('\t')
                start, end, label = float(parts[0]), float(parts[1]), parts[2]

                if start > last_end:
                    intervals.append(('blank', (last_end, start)))  # 空白区間を追加
                intervals.append((label, (start, end)))  # ラベル付き区間を追加
                last_end = end

            # 最後の空白区間を追加
            intervals.append(('blank', (last_end, last_end + 0.5)))
            formatted_intervals = str(intervals)

            


        elif(num == 1):
            hyp_flame = hyp_line
            speech.append(hyp_flame)
            speech.append(formatted_intervals)
            results.append(speech)
            

            
        else:
            print("error")
            
    # 結果をデータフレームに変換
    df = pd.DataFrame(results, columns=['File', 'Hypothesis Flame', 'Target Flame'])
    return df

In [8]:
data = process_transcription_data(hyp_row_data)

In [9]:
data["Hypothesis Flame"][0]

"[('#', (2, 3)), ('#', (32, 34)), ('#', (66, 68)), ('#', (101, 104)), ('#', (133, 135)), ('#', (166, 168)), ('#', (199, 201)), ('#', (231, 233)), ('#', (264, 266)), ('#', (297, 299)), ('#', (337, 339)), ('#', (368, 370)), ('#', (400, 402)), ('#', (432, 435)), ('#', (466, 468))]\n"

In [10]:
import ast
for i, d in enumerate(data["Hypothesis Flame"]):
    data_tuples = ast.literal_eval(d)
    flame_list = []

    pre_end_time = 0
    start = 0.0
    end = 0.0

    for index, flame in enumerate(data_tuples):
        label = "ch" if flame[0] in ["#", "#$"] else "sw"
        start_time, end_time = (flame[1][0] * 0.02, flame[1][1] * 0.02)

        # 初期 blank 区間の追加
        if start_time - pre_end_time > 0:
            flame_list.append(('blank', pre_end_time, start_time))

        flame_list.append((label, start_time, end_time))
        pre_end_time = end_time


    

    target_format_data = [(item[0], (item[1], item[2])) for item in flame_list]

    data.loc[i, "Hypothesis Flame"] = json.dumps(target_format_data)

In [11]:
data["Hypothesis Flame"][0]

'[["blank", [0, 0.04]], ["ch", [0.04, 0.06]], ["blank", [0.06, 0.64]], ["ch", [0.64, 0.68]], ["blank", [0.68, 1.32]], ["ch", [1.32, 1.36]], ["blank", [1.36, 2.02]], ["ch", [2.02, 2.08]], ["blank", [2.08, 2.66]], ["ch", [2.66, 2.7]], ["blank", [2.7, 3.3200000000000003]], ["ch", [3.3200000000000003, 3.36]], ["blank", [3.36, 3.98]], ["ch", [3.98, 4.0200000000000005]], ["blank", [4.0200000000000005, 4.62]], ["ch", [4.62, 4.66]], ["blank", [4.66, 5.28]], ["ch", [5.28, 5.32]], ["blank", [5.32, 5.94]], ["ch", [5.94, 5.98]], ["blank", [5.98, 6.74]], ["ch", [6.74, 6.78]], ["blank", [6.78, 7.36]], ["ch", [7.36, 7.4]], ["blank", [7.4, 8.0]], ["ch", [8.0, 8.040000000000001]], ["blank", [8.040000000000001, 8.64]], ["ch", [8.64, 8.700000000000001]], ["blank", [8.700000000000001, 9.32]], ["ch", [9.32, 9.36]]]'

In [12]:
data["Target Flame"][0]

"[('ch', (0.0, 0.22191200000003164)), ('blank', (0.22191200000003164, 0.7089139999999929)), ('ch', (0.7089139999999929, 0.9261940000000095)), ('blank', (0.9261940000000095, 1.341068000000007)), ('ch', (1.341068000000007, 1.548275999999987)), ('blank', (1.548275999999987, 2.0409970000000044)), ('ch', (2.0409970000000044, 2.256340000000023)), ('blank', (2.256340000000023, 2.709904999999992)), ('ch', (2.709904999999992, 2.962363000000039)), ('blank', (2.962363000000039, 3.4660030000000006)), ('ch', (3.4660030000000006, 3.7032710000000293)), ('blank', (3.7032710000000293, 4.0439250000000015)), ('ch', (4.0439250000000015, 4.26751999999999)), ('blank', (4.26751999999999, 4.685064000000011)), ('ch', (4.685064000000011, 4.919549000000018)), ('blank', (4.919549000000018, 5.3360870000000205)), ('ch', (5.3360870000000205, 5.544702000000029)), ('blank', (5.544702000000029, 5.964004999999986)), ('ch', (5.964004999999986, 6.1923530000000255)), ('blank', (6.1923530000000255, 6.749217999999985)), ('ch

In [13]:
data

Unnamed: 0,File,Hypothesis Flame,Target Flame
0,eat_MDN01_TM01_CBG_0343_81.wav,"[[""blank"", [0, 0.04]], [""ch"", [0.04, 0.06]], [...","[('ch', (0.0, 0.22191200000003164)), ('blank',..."
1,eat_MDN01_TM01_CBG_0457_01.wav,"[[""blank"", [0, 0.34]], [""ch"", [0.34, 0.36]], [...","[('ch', (0.0, 0.245207999999991)), ('blank', (..."
2,eat_MDK01_TM01_RTZ_2_0178_58.wav,"[[""blank"", [0, 0.08]], [""ch"", [0.08, 0.12]], [...","[('ch', (0.0, 0.25387100000000373)), ('blank',..."
3,eat_MDN01_TM01_CBG_0208_33.wav,"[[""blank"", [0, 0.08]], [""ch"", [0.08, 0.1]], [""...","[('ch', (0.0, 0.23093099999999822)), ('blank',..."
4,eat_MAN01_TM01_RTZ_1_0022_00.wav,"[[""blank"", [0, 0.4]], [""sw"", [0.4, 0.42]], [""b...","[('sw', (0.0, 1.365836999999999)), ('blank', (..."
...,...,...,...
394,eat_MDK01_TM01_CBG_0429_58.wav,"[[""blank"", [0, 0.16]], [""sw"", [0.16, 0.18]]]","[('sw', (0.0, 1.0233579999999733)), ('blank', ..."
395,eat_MDK01_TM01_CBG_0454_80.wav,"[[""blank"", [0, 0.26]], [""ch"", [0.26, 0.3]], [""...","[('ch', (0.0, 0.4081090000000245)), ('blank', ..."
396,eat_MHF01_TM01_RTZ_1_0000_87.wav,"[[""blank"", [0, 0.1]], [""ch"", [0.1, 0.12]], [""b...","[('ch', (0.0, 0.215341)), ('blank', (0.215341,..."
397,eat_MDK01_TM01_GUM_0001_28.wav,"[[""blank"", [0, 0.38]], [""ch"", [0.38, 0.4]], [""...","[('ch', (0.0, 0.7643099999999998)), ('blank', ..."


In [14]:
labels = ["ch","sw"]
allowances = [0, 0.01, 0.05, 0.1]




for allowance in allowances:
    print(f"allowance:{allowance}")
    for label in labels:
        all_hyp_count = 0
        all_target_TP = 0
        all_target_count = 0
        all_hyp_TP = 0
        for index, row in data.iterrows():

            TP = 0
            FN = 0
            FP = 0
            #print(row["File"])
            target_str = row["Target Flame"]
                # 文字列をリストに変換
            try:
                target_list = ast.literal_eval(target_str)
            except ValueError as e:
                print(f"Error converting string to list: {e}")
                continue  # 変換エラーがある場合は次の行に進む

            hypo_str = row["Hypothesis Flame"]
                # 文字列をリストに変換
            try:
                hypo_list = ast.literal_eval(hypo_str)
            except ValueError as e:
                print(f"Error converting string to list: {e}")
                continue  # 変換エラーがある場合は次の行に進む

            target_count = 0
            
            for target_flame in target_list:

                target_label = target_flame[0]

                if target_label == label:
                    
                    target_count += 1
                    tp_flag = False
                    target_start_time = target_flame[1][0]
                    target_end_time = target_flame[1][1]
                    #print(f"target_Label: {target_label}, Start time: {target_start_time}, End time: {target_end_time}")


                    for hyp_flame in hypo_list:
                        hyp_label = hyp_flame[0]
                        hyp_start_time = hyp_flame[1][0]
                        hyp_end_time = hyp_flame[1][1]
                        
                        if hyp_label == target_label and not (hyp_end_time + allowance < target_start_time or hyp_start_time - allowance > target_end_time):

                            #print(f"hyp_Label: {hyp_label}, Start time: {hyp_start_time}, End time: {hyp_end_time}")
                            TP += 1
                            #print("TP")
                            tp_flag = True
                            break
                            
                    if not(tp_flag):
                        FN += 1
                        #print("FN")
            all_target_count += target_count
            all_target_TP += TP
            #print(target_count,TP,FN)
            #eat_data.loc[index, 'Target Count'] = target_count
            #eat_data.loc[index, 'TP Target'] = TP
            #eat_data.loc[index, 'FN'] = FN

            TP=0
            hyp_count = 0
            for hyp_flame in hypo_list:
                tp_flag = False
                hyp_label = hyp_flame[0]
                hyp_start_time = hyp_flame[1][0]
                hyp_end_time = hyp_flame[1][1]

                if hyp_label == label:
                    hyp_count += 1
                    for target_flame in target_list:
                        target_label = target_flame[0]
                        target_start_time = target_flame[1][0]
                        target_end_time = target_flame[1][1]

                        # 時間範囲とラベルが一致するか確認
                        if (hyp_label == target_label and not (hyp_end_time + allowance < target_start_time or hyp_start_time - allowance > target_end_time)):
                            tp_flag = True
                            TP += 1
                            break

                    if not tp_flag:
                        FP += 1
            all_hyp_count += hyp_count
            all_hyp_TP += TP
            #print(hyp_count, TP, FP)
            #eat_data.loc[index, 'Hyp Count'] = hyp_count
            #eat_data.loc[index, 'TP Hyp'] = TP
            #eat_data.loc[index, 'FP'] = FP

        #print(f"target_count={all_target_count}")
        #print(f"target_TP={all_target_TP}")
        #print(f"hyp_count={all_hyp_count}")
        #print(f"hyp_TP={all_hyp_TP}")


        precision = all_hyp_TP/all_hyp_count
        recall = all_target_TP/all_target_count

        F1_score = 2*precision*recall/(precision+recall)
        print(f"{label}:{F1_score}")

    

allowance:0
ch:0.7445677893147394
sw:0.9271758436944938
allowance:0.01
ch:0.7902468028436758
sw:0.9307282415630551
allowance:0.05
ch:0.9038272350799309
sw:0.9449378330373002
allowance:0.1
ch:0.9544439739955797
sw:0.9520426287744227
