In [1]:
import string
import argparse

import matplotlib.pyplot as plt

import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import torch.nn.functional as F

import pandas as pd
import os
import sys
module_path = os.path.abspath(os.path.join('src'))
if module_path not in sys.path:
    sys.path.append(module_path)
sys.path.insert(0, '../../deep-text-recognition-benchmark/')
sys.path.insert(0, '../')

from utils import CTCLabelConverter, AttnLabelConverter
from dataset import RawDataset, AlignCollate
from model import Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [92]:
import sys
import os
import time
import argparse

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

from PIL import Image

import cv2
from skimage import io
import numpy as np
import craft_utils
import test
import imgproc
import file_utils
import json
import zipfile
import pandas as pd

from craft import CRAFT

from collections import OrderedDict

print(os.getcwd())

args = argparse.Namespace(
    trained_model='../weights/craft_mlt_25k.pth',
    text_threshold=0.3,
    low_text=0.3,
    link_threshold=0.3,
    cuda=True,
    canvas_size=1280,
    mag_ratio=1.5,
    poly=False,
    show_time=False,
    test_folder='../Results/screenshots',
    refine=False,
    refiner_model='../weights/craft_refiner_CTW1500.pth'
)

/home/tung/pj/sel/CRAFT-pytorch/notebooks


In [99]:
image_list, _, _ = file_utils.get_files(args.test_folder)

print(image_list)

image_names = []
image_paths = []

#CUSTOMISE START
start = args.test_folder

for num in range(len(image_list)):
    image_names.append(os.path.relpath(image_list[num], start))


result_folder = '../Results'
if not os.path.isdir(result_folder):
    os.mkdir(result_folder)

data=pd.DataFrame(columns=['image_name', 'word_bboxes', 'pred_words', 'align_text'])
data['image_name'] = image_names
data["image_data"] = [None for x in image_names]

# load net
net = CRAFT()     # initialize

print('Loading weights from checkpoint (' + args.trained_model + ')')
if args.cuda:
    net.load_state_dict(test.copyStateDict(torch.load(args.trained_model)))
else:
    net.load_state_dict(test.copyStateDict(torch.load(args.trained_model, map_location='cpu')))

if args.cuda:
    net = net.cuda()
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = False

net.eval()

# LinkRefiner
refine_net = None
if args.refine:
    from refinenet import RefineNet
    refine_net = RefineNet()
    print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
    if args.cuda:
        refine_net.load_state_dict(torch.load(args.refiner_model))
        refine_net = refine_net.cuda()
        refine_net = torch.nn.DataParallel(refine_net)
    else:
        refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))

    refine_net.eval()
    args.poly = True

t = time.time()

# load data
for k, image_path in enumerate(image_list):
    print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
    image = imgproc.loadImage(image_path)
    data["image_data"][k] = image
#     image = image[:54,:]

    bboxes, polys, score_text, det_scores = test.test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, args, refine_net)

    bbox_score={}

    for box_num in range(len(bboxes)):
        key = str(det_scores[box_num])
        item = bboxes[box_num]
        bbox_score[key]=item

    data['word_bboxes'][k]=bbox_score
    # save score text
    filename, file_ext = os.path.splitext(os.path.basename(image_path))
    mask_file = result_folder + "/res_" + filename + '_mask.jpg'
    cv2.imwrite(mask_file, score_text)

    file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)

# data.to_csv(result_folder + 'data.csv', sep = ',', na_rep='Unknown')
# print("elapsed time : {}s".format(time.time() - t))
data

['../Results/screenshots/screenshot_1.png', '../Results/screenshots/screenshot_3.png', '../Results/screenshots/screenshot_4.png', '../Results/screenshots/screenshot_0.png', '../Results/screenshots/screenshot_2.png', '../Results/screenshots/.ipynb_checkpoints/screenshot_4-checkpoint.png', '../Results/screenshots/.ipynb_checkpoints/screenshot_0-checkpoint.png', '../Results/screenshots/.ipynb_checkpoints/screenshot_2-checkpoint.png', '../Results/screenshots/.ipynb_checkpoints/screenshot_3-checkpoint.png', '../Results/screenshots/.ipynb_checkpoints/screenshot_1-checkpoint.png']
Loading weights from checkpoint (../weights/craft_mlt_25k.pth)
Test image 10/10: ../Results/screenshots/.ipynb_checkpoints/screenshot_1-checkpoint.png

Unnamed: 0,image_name,word_bboxes,pred_words,align_text,image_data
0,screenshot_1.png,"{'0.89525574': [[791.9999, 14.999997], [1112.9...",,,"[[[1, 1, 1], [19, 19, 19], [1, 1, 1], [19, 19,..."
1,screenshot_3.png,"{'0.6253818': [[1641.0, 15.0], [1683.0, 15.0],...",,,"[[[1, 1, 1], [19, 19, 19], [1, 1, 1], [18, 18,..."
2,screenshot_4.png,"{'0.6253818': [[1641.0, 15.0], [1683.0, 15.0],...",,,"[[[1, 1, 1], [19, 19, 19], [1, 1, 1], [18, 18,..."
3,screenshot_0.png,"{'0.8324406': [[1272.0, 12.0], [1389.0, 12.0],...",,,"[[[23, 23, 23], [27, 27, 27], [27, 27, 27], [2..."
4,screenshot_2.png,"{'0.6253818': [[1641.0, 15.0], [1683.0, 15.0],...",,,"[[[1, 1, 1], [19, 19, 19], [1, 1, 1], [18, 18,..."
5,.ipynb_checkpoints/screenshot_4-checkpoint.png,"{'0.73630595': [[219.0, 15.0], [264.0, 15.0], ...",,,"[[[23, 23, 23], [27, 27, 27], [27, 27, 27], [2..."
6,.ipynb_checkpoints/screenshot_0-checkpoint.png,"{'0.8324406': [[1272.0, 12.0], [1389.0, 12.0],...",,,"[[[23, 23, 23], [27, 27, 27], [27, 27, 27], [2..."
7,.ipynb_checkpoints/screenshot_2-checkpoint.png,"{'0.73630595': [[219.0, 15.0], [264.0, 15.0], ...",,,"[[[23, 23, 23], [27, 27, 27], [27, 27, 27], [2..."
8,.ipynb_checkpoints/screenshot_3-checkpoint.png,"{'0.6253818': [[1641.0, 15.0], [1683.0, 15.0],...",,,"[[[1, 1, 1], [19, 19, 19], [1, 1, 1], [18, 18,..."
9,.ipynb_checkpoints/screenshot_1-checkpoint.png,"{'0.73630595': [[219.0, 15.0], [264.0, 15.0], ...",,,"[[[23, 23, 23], [27, 27, 27], [27, 27, 27], [2..."


In [94]:
opt = argparse.Namespace(
    image_folder="/home/tung/pj/sel/CRAFT-pytorch/Results/cropped",
    workers=4,
    batch_size=192,
    saved_model="/home/tung/pj/sel/deep-text-recognition-benchmark/weights/TPS-ResNet-BiLSTM-CTC.pth",
    batch_max_length=25,
    imgH=32,
    imgW=100,
    rgb=False,
    character='0123456789abcdefghijklmnopqrstuvwxyz',
    sensitive=False,
    PAD=True,
    Transformation="TPS",
    FeatureExtraction="ResNet",
    SequenceModeling="BiLSTM",
    Prediction="CTC",
    num_fiducial=20,
    input_channel=1,
    output_channel=512,
    hidden_size=256
)

""" vocab / character number configuration """
if opt.sensitive:
    opt.character = string.printable[:-6]  # same with ASTER setting (use 94 char).

cudnn.benchmark = True
cudnn.deterministic = True
opt.num_gpu = torch.cuda.device_count()
# print (opt.image_folder)

In [95]:
"""Open csv file wherein you are going to write the Predicted Words"""
# data = pd.read_csv(opt.image_folder + 'data.csv')

def extract_text(image, bboxes):

    # Save cropped images
    for i, bbox in enumerate(bboxes):
        filename = f"../Results/cropped/test_{i}.png"
        top_left, bottom_right = bbox
        cv2.imwrite(filename, image[top_left[1]:bottom_right[1], top_left[0]:bottom_right[0], :])

    extracted_texts = []

    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
#     print(converter.character[:37])
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
#     print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
#           opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
#           opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
#     print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:

            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probabilty (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                # preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probabilty (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t {"predicted_labels":25s}\t confidence score'

#             print(f'{dashed_line}\n{head}\n{dashed_line}')
            # log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):


                start = opt.image_folder
                path = os.path.relpath(img_name, start)

                folder = os.path.dirname(path)

                image_name=os.path.basename(path)

                file_name='_'.join(image_name.split('_')[:-8])

                txt_file=os.path.join(start, folder, file_name)                

                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

                # calculate confidence score (= multiply of pred_max_prob)
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                extracted_texts.append([pred, confidence_score])
#                 print(f'{image_name:25s}\t {pred:25s}\t {confidence_score:0.4f}')

    # Save cropped images
    for i in range(len(bboxes)):
        filename = f"../Results/cropped/test_{i}.png"
        os.remove(filename)
    
    return extracted_texts

In [100]:
%matplotlib widget

I_copy = imgproc.loadImage(image_list[1])
bboxes = []
for i, key in enumerate(data["word_bboxes"][1]):
    top_left = tuple([int(x) for x in data["word_bboxes"][1][key][0]])
    bottom_right = tuple([int(x) for x in data["word_bboxes"][1][key][2]])
    bboxes.append([top_left, bottom_right])
#     print(tuple(top_left.tolist()), tybottom_right.tolist())
    cv2.rectangle(I_copy, top_left, bottom_right, 255, 2)
#     cv2.imwrite(f"../Results/cropped/test_{i}.png", 
#                 data["image_data"][0][top_left[1]:bottom_right[1], top_left[0]:bottom_right[0]])
#     plt.imshow(I_copy[top_left[1]:bottom_right[1], top_left[0]:bottom_right[0]])
plt.imshow(I_copy)
plt.show()
# bboxes.append([[211,0], [1620,23]])
print(extract_text(I_copy, bboxes))

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[['accent', tensor(0.4214, device='cuda:0')], ['loginn', tensor(0.1980, device='cuda:0')], ['withh', tensor(0.0799, device='cuda:0')], ['g', tensor(0.0070, device='cuda:0')], ['t', tensor(0.2390, device='cuda:0')], ['orcreate', tensor(0.5701, device='cuda:0')], ['neww', tensor(0.7467, device='cuda:0')], ['account', tensor(0.4195, device='cuda:0')], ['emailaddress', tensor(0.4479, device='cuda:0')], ['o', tensor(0.2487, device='cuda:0')], ['password', tensor(0.8193, device='cuda:0')], ['oi', tensor(0.2006, device='cuda:0')], ['contimpassword', tensor(0.3247, device='cuda:0')], ['create', tensor(0.8140, device='cuda:0')], ['account', tensor(0.5874, device='cuda:0')], ['already', tensor(0.8242, device='cuda:0')], ['lnae', tensor(0.0092, device='cuda:0')], ['encome', tensor(0.0001, device='cuda:0')], ['oginnere', tensor(0.5913, device='cuda:0')], ['es', tensor(4.5099e-09, device='cuda:0')]]
