In [1]:
import os

In [2]:
class CharJudge:
    """Character-level precision/recall accumulator.

    Feed one item per sentence to ``itemInput`` (the dict built by
    ``extract``: keys 'gt', 'sys', 'gt_set', 'sys_set', 'match'), then call
    ``result`` to compute, print and return the scores.
    """

    def __init__(self):
        self._sys_chtotalErrors = 0    # positions flagged by the system
        self._gt_chtotalErrors = 0     # positions listed in the ground truth
        self._sys_chdetectErrors = 0   # positions present in both system and ground truth
        self._sys_chcorrectErrors = 0  # of those, positions where the system char equals the gt char

    def itemInput(self, item):
        """Add one sentence's counts.

        Args:
            item: dict with 'gt' / 'sys' ({position: char}), 'gt_set' /
                'sys_set' (their key sets) and 'match' (positions in both).
        """
        self._gt_chtotalErrors += len(item['gt_set'])
        self._sys_chtotalErrors += len(item['sys_set'])
        self._sys_chdetectErrors += len(item['match'])
        # Every idx in 'match' is a key of both dicts, so the lookups are safe.
        self._sys_chcorrectErrors += sum(
            1 for idx in item['match'] if item['gt'][idx] == item['sys'][idx]
        )

    @staticmethod
    def _safe_div(num, den):
        """Return num / den, or 0.0 when den is 0 (e.g. nothing was flagged)."""
        return num / den if den else 0.0

    def result(self):
        """Compute, store, print and return the character-level scores.

        All scores are ratios in [0, 1], not percentages. A zero denominator
        yields 0.0 instead of raising ZeroDivisionError.

        Returns:
            {'detection': [precision, recall], 'correction': [precision, recall]}
        """
        div = self._safe_div
        self.detect_precision = div(self._sys_chdetectErrors, self._sys_chtotalErrors)
        self.detect_recall = div(self._sys_chdetectErrors, self._gt_chtotalErrors)

        self.correct_precision = div(self._sys_chcorrectErrors, self._sys_chtotalErrors)
        self.correct_recall = div(self._sys_chcorrectErrors, self._gt_chtotalErrors)

        print('Char')
        print('=== Detection Stage ===')
        print('Precision = {:.4f}'.format(self.detect_precision))
        print('Recall = {:.4f}'.format(self.detect_recall))
        print('=== Correction Stage ===')
        print('Precision = {:.4f}'.format(self.correct_precision))
        print('Recall = {:.4f}'.format(self.correct_recall))

        return {'detection': [self.detect_precision, self.detect_recall],
                'correction': [self.correct_precision, self.correct_recall]}
        

In [3]:
class SeqJudge:
    """Sentence-level detection/correction accumulator.

    A sentence with ground-truth positions counts as correctly *detected*
    when the set of positions the system flagged equals the ground-truth set
    exactly. It counts as correctly *corrected* when, in addition, the
    system's character equals the ground-truth character at every one of
    those positions.
    """

    def __init__(self):
        self._gtCorrectLine = 0    # sentences whose ground truth lists no positions
        self._sysTrueCorrect = 0   # ...of which the system flagged nothing
        self._fp = 0               # ...of which the system flagged something (false positive)
        self._gtErrorLine = 0      # sentences whose ground truth lists at least one position
        self._sysTrueError = 0     # ...of which the system flagged exactly the gt positions
        self._corsysTrueError = 0  # ...and also proposed the gt character at each of them
        self._sysErrorLine = 0     # sentences (either kind) where the system flagged anything

    def detail(self):
        """Print the raw counters, one per line (debug aid)."""
        for count in (self._sysTrueCorrect, self._fp, self._gtErrorLine,
                      self._sysTrueError, self._corsysTrueError, self._sysErrorLine):
            print(count)

    def itemInput(self, item):
        """Classify one sentence and update the counters.

        Args:
            item: dict with 'gt' / 'sys' ({position: char}), 'gt_set' /
                'sys_set' (their key sets) and 'match' (positions in both).
        """
        gt_set = item['gt_set']
        sys_set = item['sys_set']

        if not gt_set:
            self._gtCorrectLine += 1
            if sys_set:
                self._fp += 1
            else:
                self._sysTrueCorrect += 1
        else:
            self._gtErrorLine += 1
            if gt_set == sys_set:
                self._sysTrueError += 1
                if all(item['gt'][idx] == item['sys'][idx] for idx in item['match']):
                    self._corsysTrueError += 1
        if sys_set:
            self._sysErrorLine += 1

    @staticmethod
    def _safe_div(num, den):
        """Return num / den, or 0.0 when den is 0."""
        return num / den if den else 0.0

    def result(self, verbose=False):
        """Compute (and store) sentence-level scores.

        A zero denominator yields 0.0 instead of raising ZeroDivisionError.
        Examples: no error-free sentences for FPR, nothing flagged for
        precision, or precision + recall == 0 for F1.

        Args:
            verbose: if True, also print a human-readable report.

        Returns:
            {'detection': [fpr, acc, precision, recall, f1],
             'correction': [acc, precision, recall, f1]}
        """
        div = self._safe_div
        total_lines = self._gtCorrectLine + self._gtErrorLine

        self.d_acc = div(self._sysTrueError + self._sysTrueCorrect, total_lines)
        self.d_recall = div(self._sysTrueError, self._gtErrorLine)
        self.d_precision = div(self._sysTrueError, self._sysErrorLine)
        self.d_fprate = div(self._fp, self._gtCorrectLine)
        self.d_f1 = div(2 * self.d_recall * self.d_precision, self.d_recall + self.d_precision)

        self.c_acc = div(self._corsysTrueError + self._sysTrueCorrect, total_lines)
        self.c_recall = div(self._corsysTrueError, self._gtErrorLine)
        self.c_precision = div(self._corsysTrueError, self._sysErrorLine)
        self.c_f1 = div(2 * self.c_recall * self.c_precision, self.c_recall + self.c_precision)

        if verbose:
            print('Sentence')
            print('=== Detection Level ===')
            print('FPR = {:.4f}'.format(self.d_fprate))
            print('Accuracy = {:.4f}'.format(self.d_acc))
            print('Precision = {:.4f}'.format(self.d_precision))
            print('Recall = {:.4f}'.format(self.d_recall))
            print('F1 Score = {:.4f}'.format(self.d_f1))
            print('=== Correction Level ===')
            print('Accuracy = {:.4f}'.format(self.c_acc))
            print('Precision = {:.4f}'.format(self.c_precision))
            print('Recall = {:.4f}'.format(self.c_recall))
            print('F1 Score = {:.4f}'.format(self.c_f1))

        return {'detection': [self.d_fprate, self.d_acc, self.d_precision, self.d_recall, self.d_f1],
                'correction': [self.c_acc, self.c_precision, self.c_recall, self.c_f1]}

In [4]:
def fileJudge(judge_file, verbose=False):
    """Score one system-output file at character and sentence level.

    The first line of ``judge_file`` is skipped (presumably a header; TODO
    confirm). Each remaining non-blank line is stripped, split on '|||' and
    parsed by ``extract``.

    Args:
        judge_file: path to a UTF-8 text file.
        verbose: if True, also print the character-level report and the raw
            sentence-level counters.

    Returns:
        The dict returned by ``SeqJudge.result()``.
    """
    chjudge = CharJudge()
    seqjudge = SeqJudge()
    with open(judge_file, 'r', encoding='utf8') as fp:
        fp.readline()  # skip the first line

        for raw_line in fp:
            line = raw_line.strip()
            if not line:
                # A blank line (e.g. trailing newline at EOF) splits to [''];
                # extract() would raise IndexError and the whole file would be lost.
                continue
            item = extract(line.split('|||'))
            chjudge.itemInput(item)
            seqjudge.itemInput(item)

    if verbose:
        chjudge.result()
        print()
        seqjudge.detail()
    return seqjudge.result()

In [5]:
def _parse_index_chars(field):
    """Parse an 'idx, ch, idx, ch, ...' field into {int(idx): ch}.

    Tokens are split on ', ' and paired as tokens[:-1:2] with tokens[1::2].
    A field with a single token (e.g. '0') or an empty field therefore
    yields an empty dict.
    """
    tokens = field.split(', ')
    return {int(idx): ch for idx, ch in zip(tokens[:-1:2], tokens[1::2])}


def extract(lst):
    """Build the per-sentence item consumed by CharJudge / SeqJudge.

    Args:
        lst: one input line already split on '|||'. Field 2 is the ground
            truth and field 3 is the system output, each formatted as
            'idx, ch, idx, ch, ...'. Other fields are ignored.

    Returns:
        dict with 'gt' / 'sys' ({position: char}), 'gt_set' / 'sys_set'
        (their key sets) and 'match' (positions present in both).
    """
    ground_truth = _parse_index_chars(lst[2])
    sys_truth = _parse_index_chars(lst[3])

    ground_set = set(ground_truth)
    sys_set = set(sys_truth)

    return {'gt': ground_truth, 'sys': sys_truth, 'gt_set': ground_set,
            'sys_set': sys_set, 'match': ground_set & sys_set}

In [9]:
# Directory holding the system-output .txt files to score (trailing slash kept
# because it is concatenated with file names elsewhere in this notebook).
dataroot = "./UDN_benchmark/0711_beam/"

In [10]:
# Map each .txt file name found under `dataroot` (recursively) to its full path.
# - The old `if os.path.splitext` tested a function object (always truthy), so it
#   filtered nothing; the scoring loop below only uses .txt files.
# - Paths are built from os.walk's `root`, so files in sub-directories are correct.
# - Sorting makes the processing (and CSV row) order deterministic; if two
#   directories contain the same file name, the later one in sorted order wins.
filelist = {
    file: os.path.join(root, file)
    for root, _, files in sorted(os.walk(dataroot))
    for file in sorted(files)
    if os.path.splitext(file)[1] == '.txt'
}

In [11]:
# Score every .txt file in `filelist` and write one CSV row per file: the
# '_'-separated parts of the file name after the first, then the 5 detection
# and 4 correction metrics returned by fileJudge().
detail_file = os.path.join(dataroot, 'detail.csv')
with open(detail_file, 'w', encoding='utf8') as wp:
    for file, path in filelist.items():
        cur_filename, file_extension = os.path.splitext(file)
        if file_extension != '.txt':
            continue

        print('Current file = {}'.format(file))
        try:
            result = fileJudge(path)
        except Exception as exc:
            # Report the reason instead of silently skipping; a bare `except:`
            # would also swallow KeyboardInterrupt.
            print('Failed: {!r}'.format(exc))
            continue

        fileinfo = ','.join(cur_filename.split('_')[1:])
        metrics = result['detection'] + result['correction']
        wp.write('{},{}\n'.format(fileinfo, ','.join(str(m) for m in metrics)))

Current file = re_beam10_2_10.txt
Current file = re_beam10_2_9.txt
Current file = re_beam10_3_10.txt
Current file = re_beam10_3_9.txt
