In [18]:
import string
import argparse

import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import torch.nn.functional as F
import os
import pandas as pd

import time

from nltk.metrics.distance import edit_distance
from utils import CTCLabelConverter, AttnLabelConverter
from dataset import RawDataset, AlignCollate
from model import Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [19]:
# Test split to evaluate. Keep exactly one assignment active:
#   '01' -> handwriting set 1, '02' -> handwriting set 2, '03' -> printed text.
data_type = '01'    # handwriting1
# data_type = '02'  # handwriting2
# data_type = '03'  # printing

In [20]:
def validation(model, criterion, evaluation_loader, converter, opt):
    """Evaluate ``model`` on ``evaluation_loader`` and return aggregate metrics.

    Each batch is greedily decoded. The loss from ``criterion``, exact-match
    accuracy, and the ICDAR2019 normalized edit distance are accumulated.
    Gradients are disabled for the whole evaluation.

    Args:
        model: recognizer, called as ``model(image, text)`` when ``'CTC'`` is in
            ``opt.Prediction``, otherwise ``model(image, text, is_train=False)``.
        criterion: loss matching the prediction head (CTC-style for CTC,
            token-level for attention).
        evaluation_loader: iterable of ``(image_tensors, labels)`` batches.
        converter: label converter exposing ``encode`` and ``decode``.
        opt: options namespace; reads ``batch_max_length``, ``Prediction``,
            ``baiduCTC``, ``sensitive`` and ``data_filtering_off``.

    Returns:
        ``(valid_loss, accuracy, norm_ED, preds_str, confidence_score_list,
        labels, infer_time, length_of_data)``.
        - ``accuracy`` is a percentage.
        - ``norm_ED`` is the mean ICDAR2019 score.
        - ``infer_time`` is the summed forward-pass wall-clock time in seconds.
        - ``preds_str``, ``confidence_score_list`` and ``labels`` describe only
          the LAST batch.

    Raises:
        ValueError: if ``evaluation_loader`` yields no samples.
    """
    import re  # only needed by the case-insensitive alphanumeric filter below

    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    # Running loss average: sum of loss elements / number of loss elements.
    # (An ``Averager`` helper was referenced here but never imported.)
    loss_sum = 0
    loss_count = 0
    # Defaults so the return values are defined even before the first batch.
    preds_str, confidence_score_list, labels = [], [], []

    with torch.no_grad():
        for image_tensors, labels in evaluation_loader:
            batch_size = image_tensors.size(0)
            length_of_data += batch_size
            image = image_tensors.to(device)
            # Decoder inputs sized for the maximum label length.
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)

            start_time = time.time()
            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)
                forward_time = time.time() - start_time

                # Evaluation loss for the CTC decoder; permute(1, 0, 2) swaps the
                # first two axes into the layout the CTC loss expects.
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                if opt.baiduCTC:
                    cost = criterion(preds.permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) / batch_size
                else:
                    cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)

                # Greedy decoding: most probable class at every step, then map indices to characters.
                _, preds_index = preds.max(2)
                if opt.baiduCTC:
                    preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = model(image, text_for_pred, is_train=False)
                forward_time = time.time() - start_time

                preds = preds[:, :text_for_loss.shape[1] - 1, :]
                target = text_for_loss[:, 1:]  # drop the leading [GO] symbol
                cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))

                # Greedy decoding, then map indices to characters.
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)
                labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

            infer_time += forward_time
            loss_sum += cost.detach().sum()
            loss_count += cost.numel()

            # Accuracy and confidence score.
            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            confidence_score_list = []
            for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
                if 'Attn' in opt.Prediction:
                    # Prune after the end-of-sentence token "[s]". Only prune when it
                    # is present; slicing with find() == -1 would drop the last char.
                    gt_EOS = gt.find('[s]')
                    if gt_EOS != -1:
                        gt = gt[:gt_EOS]
                    pred_EOS = pred.find('[s]')
                    if pred_EOS != -1:
                        pred = pred[:pred_EOS]
                        pred_max_prob = pred_max_prob[:pred_EOS]

                # Evaluate a case-sensitive model under the case-insensitive alphanumeric setting.
                if opt.sensitive and opt.data_filtering_off:
                    pred = pred.lower()
                    gt = gt.lower()
                    alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz'
                    out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]'
                    pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred)
                    gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt)

                if pred == gt:
                    n_correct += 1

                # (old version) ICDAR2017 DOST Normalized Edit Distance
                # https://rrc.cvc.uab.es/?ch=7&com=tasks
                # "For each word we calculate the normalized edit distance to the
                # length of the ground truth transcription":
                #     norm_ED += 1 if len(gt) == 0 else edit_distance(pred, gt) / len(gt)

                # ICDAR2019 Normalized Edit Distance: 1 - ED / max(len(gt), len(pred));
                # contributes 0 when either string is empty.
                if gt and pred:
                    norm_ED += 1 - edit_distance(pred, gt) / max(len(gt), len(pred))

                # Confidence score = product of per-step max probabilities.
                try:
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                except IndexError:
                    confidence_score = 0  # empty prediction after pruning at "[s]"
                confidence_score_list.append(confidence_score)

    if length_of_data == 0:
        raise ValueError('evaluation_loader yielded no samples')

    accuracy = n_correct / float(length_of_data) * 100
    norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance
    valid_loss = loss_sum / float(loss_count) if loss_count else 0

    return valid_loss, accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data


In [21]:
def _read_ground_truth(path):
    """Load a tab-separated ground-truth file (image name, tab, label per line).

    Returns a dict mapping image file name to label. Blank lines are ignored.
    Only the first tab separates name from label. If an image name appears
    more than once, its first entry is kept.
    """
    labels_by_image = {}
    with open(path, 'r', encoding='utf-8') as gt_file:
        for line in gt_file:
            line = line.rstrip('\n')
            if not line.strip():
                continue
            image_name, label = line.split('\t', 1)
            labels_by_image.setdefault(image_name, label)
    return labels_by_image


def _image_basename(image_path):
    """Return the file-name part of ``image_path``, accepting '/' or '\\' separators."""
    return image_path.replace('\\', '/').rsplit('/', 1)[-1]


def _icdar2019_score(pred, gt):
    """ICDAR2019 normalized edit-distance score: 1 - ED / max(len); 0 if either string is empty."""
    if not gt or not pred:
        return 0
    return 1 - edit_distance(pred, gt) / max(len(gt), len(pred))


def demo(opt):
    """Run the recognizer on ``opt.image_folder`` and score it against ground truth.

    Builds the model from ``opt`` and loads ``opt.saved_model``. Each image's
    greedy prediction is then compared with its label from ``opt.ground_truth``
    using the ICDAR2019 normalized edit-distance score.

    Outputs:
    - Per-image rows are printed and written to a tab-separated report
      (``opt.result_txt``, default ``./result.txt``).
    - Two summary lines (average score and per-image inference time) are
      appended to the report.
    - The report is also converted to CSV (``opt.result_csv``, default
      ``./result.csv``).

    Side effects on ``opt``: sets ``num_class``, and sets ``input_channel = 3``
    when ``opt.rgb`` is true.

    Raises:
        KeyError: if an image in ``opt.image_folder`` has no ground-truth entry.
    """
    ground_truth = _read_ground_truth(opt.ground_truth)

    # Model configuration.
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # Load the trained weights.
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # Prepare data.
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo,
        pin_memory=(device.type == 'cuda'))  # pinned memory only helps host->GPU copies

    # Output locations (relative to the working directory unless overridden on opt).
    result_txt = getattr(opt, 'result_txt', './result.txt')
    result_csv = getattr(opt, 'result_csv', './result.csv')

    model.eval()
    total_norm_ED = 0
    n_evaluated = 0

    with open(result_txt, 'w', encoding='utf-8') as log:
        head = f'{"image_path":25s}\t{"GT":25s}\t{"Prediction":25s}\taccuracy'
        log.write(f'{head}\n')

        start_time = time.time()
        with torch.no_grad():
            for image_tensors, image_path_list in demo_loader:
                batch_size = image_tensors.size(0)
                image = image_tensors.to(device)
                # Decoder inputs sized for the maximum label length.
                length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
                text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

                # Greedy decoding: most probable class at every step, then map indices to characters.
                if 'CTC' in opt.Prediction:
                    preds = model(image, text_for_pred)
                    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                    _, preds_index = preds.max(2)
                    preds_str = converter.decode(preds_index, preds_size)
                else:
                    preds = model(image, text_for_pred, is_train=False)
                    _, preds_index = preds.max(2)
                    preds_str = converter.decode(preds_index, length_for_pred)

                for image_path, pred in zip(image_path_list, preds_str):
                    img_name = _image_basename(image_path)
                    if img_name not in ground_truth:
                        raise KeyError(f'no ground-truth label for {img_name!r} in {opt.ground_truth}')
                    gt = ground_truth[img_name]

                    # Prune after the end-of-sentence token BEFORE upper-casing;
                    # otherwise "[s]" becomes "[S]" and is never found.
                    if 'Attn' in opt.Prediction:
                        pred_EOS = pred.find('[s]')
                        if pred_EOS != -1:
                            pred = pred[:pred_EOS]
                    # The character set is lowercase; compare in upper case.
                    pred = pred.upper()

                    norm_ED = _icdar2019_score(pred, gt)

                    row = f'{img_name:25s}\t{gt:25s}\t{pred:25s}\t{norm_ED*100:.2f}%'
                    print(row)
                    log.write(f'{row}\n')

                    total_norm_ED += norm_ED
                    n_evaluated += 1
        end_time = time.time()

        # Average over the images actually evaluated; guard against an empty folder.
        denominator = n_evaluated or 1
        log.write(f'Average Accuracy\t{total_norm_ED*100/denominator:.2f}%\n')
        log.write(f'Inference Speed (ms)\t{(end_time-start_time)*1000/denominator}\n')

    # Read every column as text so digit-only labels such as '0601' keep their
    # leading zeros, then drop the fixed-width padding used in the text report.
    results = pd.read_csv(result_txt, delimiter='\t', dtype=str)
    results.columns = results.columns.str.strip()
    results = results.apply(lambda column: column.str.strip())
    results.to_csv(result_csv, index=False)

    print('Average Accuracy = ', total_norm_ED*100/denominator)
    print('inference time = ', end_time - start_time)
    

In [22]:
# parser = argparse.ArgumentParser()
# parser.add_argument('--image_folder', required=True, help='path to image_folder which contains text images')
# parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
# parser.add_argument('--batch_size', type=int, default=192, help='input batch size')
# parser.add_argument('--saved_model', required=True, help="path to saved_model to evaluation")
# """ Data processing """
# parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
# parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
# parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
# parser.add_argument('--rgb', action='store_true', help='use rgb input')
# parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
# parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
# parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
# """ Model Architecture """
# parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
# parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
# parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
# parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
# parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
# parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor')
# parser.add_argument('--output_channel', type=int, default=512,
#                     help='the number of output channel of Feature extractor')
# parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')

In [23]:
# Evaluation options. Field names mirror the argparse flags of the original
# script; boolean flags default to False as argparse's store_true would.
args = argparse.Namespace(
    # Test images and the matching tab-separated ground-truth file for `data_type`.
    # NOTE(review): images are read from test1/ while the GT file comes from test/;
    # confirm this pairing is intended.
    image_folder=f'../datasets/test1/{data_type}',
    ground_truth=f'../datasets/test/{data_type}/gt_test_{data_type}.txt',

    workers=0,
    batch_size=192,
    saved_model="./saved_models/TPS-ResNet-BiLSTM-CTC/best_norm_ED.pth",
    # saved_model="F:/python/Competition_No2/code/saved_models/addAug/TPS-ResNet-BiLSTM-CTC-Seed213/best_accuracy.pth",

    # Data processing
    batch_max_length=25,
    imgH=32,
    imgW=100,
    rgb=False,
    character='0123456789abcdefghijklmnopqrstuvwxyz',
    sensitive=False,
    PAD=False,

    # Model architecture
    Transformation="TPS",
    FeatureExtraction="ResNet",
    SequenceModeling="BiLSTM",
    Prediction="CTC",
    num_fiducial=20,
    input_channel=1,
    output_channel=512,
    hidden_size=256,
)

In [24]:
opt = args
opt.num_gpu = torch.cuda.device_count()

# Vocabulary: a case-sensitive model uses string.printable minus its 6 trailing
# whitespace characters (94 chars, same as the ASTER setting).
if opt.sensitive:
    opt.character = string.printable[:-6]

# cuDNN: enable kernel autotuning and request deterministic algorithms.
cudnn.benchmark = True
cudnn.deterministic = True

demo(opt)

model input parameters 32 100 20 1 512 256 37 25 TPS ResNet BiLSTM CTC
loading pretrained model from ./saved_models/TPS-ResNet-BiLSTM-CTC/best_norm_ED.pth
000000_3.png             	6601                     	6601                     	100.00%
000006_5.png             	3603                     	3603                     	100.00%
000006_6.png             	1003                     	1003                     	100.00%
000007_1.png             	9102                     	9102                     	100.00%
000009_0.png             	9901                     	9901                     	100.00%
000010_2.png             	2801                     	2801                     	100.00%
000015_0.png             	7702                     	7702                     	100.00%
000020_5.png             	1003                     	1003                     	100.00%
000021_0.png             	4601                     	4601                     	100.00%
000021_1.png             	2003                     	2003               