In [1]:
import os

In [2]:
%pwd

'e:\\GCET\\Machine Learning\\Deep Learning Projects\\text_recognition\\research'

In [3]:
# Step up one level from the notebook folder (research/) to the project root.
# Not idempotent: re-running this cell moves the working directory up again,
# so restart the kernel before executing it a second time.
os.chdir('../')

In [4]:
%pwd

'e:\\GCET\\Machine Learning\\Deep Learning Projects\\text_recognition'

# CRAFT Inference Pipeline

In [5]:
# entity
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class TextDetectionConfig:
    """Immutable settings for the text-detection step (paths plus thresholds)."""
    # Checkpoint file loaded into the CRAFT network.
    craft_weights: Path
    # Checkpoint file loaded into RefineNet; only read when `refine` is true.
    refiner_weights: Path
    # Image read for rendering the final detection result.
    loaded_image_path: Path
    # Image read and fed to the network as its input.
    normalized_image_path: Path
    # Binary file with the resize metadata; its "ratio" entry rescales coordinates.
    resized_data_path: Path
    # Thresholds forwarded unchanged to getDetBoxes.
    text_threshold: float
    low_text: float
    link_threshold: float
    # Forwarded to getDetBoxes; presumably enables polygon output — TODO confirm.
    poly: bool
    # When true, RefineNet replaces the link score map and `poly` is forced on.
    refine: bool
    

In [6]:
#configuration
from mlOCR.constants import *
from mlOCR.utils.common import read_yaml,create_directories
from pathlib import Path

class ConfigurationManager:
    """Loads the project YAML files and turns their sections into config objects."""

    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH,
        schema_filepath = SCHEMA_FILE_PATH):
        """Read the config, params and schema YAML files, then create the artifacts root.

        Args:
            config_filepath: YAML file exposing `artifacts_root` and `text_detection`.
            params_filepath: YAML file exposing the `text_detection` parameters.
            schema_filepath: YAML file stored on `self.schema`.
        """
        yaml_sources = (
            ("config", config_filepath),
            ("params", params_filepath),
            ("schema", schema_filepath),
        )
        # Files are read in the order listed above and stored under the same names.
        for attr_name, yaml_path in yaml_sources:
            setattr(self, attr_name, read_yaml(yaml_path))

        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_text_detection_config(self)->TextDetectionConfig:
        """Build a `TextDetectionConfig` from the `text_detection` sections.

        Path-like entries come from the config file; thresholds and flags come
        from the params file.

        Returns:
            TextDetectionConfig: the populated, frozen settings object.
        """
        path_section = self.config.text_detection
        param_section = self.params.text_detection

        path_fields = (
            "craft_weights",
            "refiner_weights",
            "loaded_image_path",
            "normalized_image_path",
            "resized_data_path",
        )
        float_fields = ("text_threshold", "low_text", "link_threshold")
        flag_fields = ("poly", "refine")

        settings = {name: Path(getattr(path_section, name)) for name in path_fields}
        settings.update((name, float(getattr(param_section, name))) for name in float_fields)
        # bool() of any non-empty string is True, so the YAML must hold real booleans.
        settings.update((name, bool(getattr(param_section, name))) for name in flag_fields)

        return TextDetectionConfig(**settings)


In [7]:
#component
import torch
import torch.nn as nn
from torch.autograd import Variable
from collections import OrderedDict
from skimage import io
from mlOCR.components.Image_Processing import ImageProcessing
from mlOCR.models.craft import CRAFT
from mlOCR.models.refine_net import RefineNet
from mlOCR.utils.common import *
from mlOCR.utils.text_detection_utils import *
from mlOCR import logger


class TextDetection:
    """Runs CRAFT text detection (optionally refined by RefineNet) on one image.

    The instance is configured from a `TextDetectionConfig`; `generate_result`
    is the entry point that reads the images, runs inference and writes the
    heat-map and detection result into `RESULT_DIR`.
    """

    # Folder that receives the heat-map image and the rendered detections.
    RESULT_DIR = 'artifacts/image/result/'

    def __init__(self,TextDetectionConfig):
        """Copy every setting of the given config object onto the instance.

        Args:
            TextDetectionConfig: settings object exposing the attributes read
                below. The parameter name is kept for backward compatibility
                even though it shadows the config class inside this method.
        """
        self.craft_weights=TextDetectionConfig.craft_weights
        self.refiner_weights=TextDetectionConfig.refiner_weights
        self.loaded_image_path=TextDetectionConfig.loaded_image_path
        self.normalized_image_path=TextDetectionConfig.normalized_image_path
        self.resized_data_path=TextDetectionConfig.resized_data_path
        self.text_threshold=TextDetectionConfig.text_threshold
        self.low_text=TextDetectionConfig.low_text
        self.link_threshold=TextDetectionConfig.link_threshold
        self.poly=TextDetectionConfig.poly
        self.refine=TextDetectionConfig.refine

    def copyStateDict(self,state_dict):
        """Return a copy of `state_dict` without a leading "module" key component.

        If the first key starts with "module", the first dot-separated
        component is dropped from every key; otherwise keys are copied as-is.

        Args:
            state_dict: mapping of parameter names to values.

        Returns:
            OrderedDict: the re-keyed mapping, in the original order. An empty
            input yields an empty result instead of raising IndexError.
        """
        new_state_dict = OrderedDict()
        if len(state_dict) == 0:
            return new_state_dict

        first_key = next(iter(state_dict))
        start_idx = 1 if first_key.startswith("module") else 0
        for k, v in state_dict.items():
            new_state_dict[".".join(k.split(".")[start_idx:])] = v
        return new_state_dict

    def _load_weights(self, net, weights_path):
        """Load the CPU-mapped checkpoint at `weights_path` into `net` and return it."""
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(weights_path, map_location='cpu')
        net.load_state_dict(self.copyStateDict(checkpoint))
        return net

    @staticmethod
    def _to_bgr(image):
        """Return `image` as three channels with the channel order reversed.

        2-D (single-channel) arrays are replicated to three channels and a
        fourth channel is dropped, so the reversal never fails on those inputs.
        """
        if image.ndim == 2:
            image = np.stack((image,) * 3, axis=-1)
        elif image.shape[2] == 4:
            image = image[:, :, :3]
        # Reverses the channel axis; presumably RGB -> BGR for saveResult — TODO confirm.
        return image[:, :, ::-1]

    def modelInference(self,norm_image):
        """Run the network(s) on `norm_image` and post-process the score maps.

        Args:
            norm_image: array indexed [h, w, c]; it is cast to float32 before
                being fed to the network.

        Returns:
            tuple: `(boxes, polys, ret_score_text)` where `boxes` and `polys`
            come from getDetBoxes after coordinate adjustment (missing polys
            are replaced by their box) and `ret_score_text` is the heat-map
            image built from the text and link score maps side by side.
        """
        net = self._load_weights(CRAFT(), self.craft_weights)
        logger.info('Craft model weights loaded successfully!')
        net.eval()

        # [h, w, c] -> [c, h, w] -> [1, c, h, w]; plain tensors replace the
        # deprecated Variable wrapper.
        x = torch.from_numpy(norm_image.astype('float32')).permute(2, 0, 1).unsqueeze(0)

        # forward pass
        with torch.no_grad():
            y, feature = net(x)

        # make score and link map
        score_text = y[0,:,:,0].detach().cpu().numpy()
        score_link = y[0,:,:,1].detach().cpu().numpy()

        if self.refine:
            refine_net = self._load_weights(RefineNet(), self.refiner_weights)
            logger.info('RefineNet model weights loaded successfully!')
            refine_net.eval()
            # Refinement forces polygon mode for this and later calls.
            self.poly=True

            with torch.no_grad():
                y_refiner = refine_net(y, feature)
            # The refined output replaces the link map.
            score_link = y_refiner[0,:,:,0].detach().cpu().numpy()

        resized_data = load_bin(self.resized_data_path)
        # Inverse of the stored resize ratio, applied to both axes.
        ratio_h = ratio_w = 1 / resized_data["ratio"]

        # Post-processing
        boxes, polys = getDetBoxes(score_text, score_link, self.text_threshold, self.link_threshold, self.low_text, self.poly)

        # coordinate adjustment
        boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None:
                polys[k] = boxes[k]

        # np.hstack already allocates a new array, so no explicit copy is needed.
        ret_score_text = cvt2HeatmapImg(np.hstack((score_text, score_link)))

        return boxes, polys, ret_score_text

    def generate_result(self):
        """Read the images, run inference and write the outputs to `RESULT_DIR`.

        Writes `res_<name>_mask.jpg` (the heat-map) and then calls saveResult
        with the detected polygons for the loaded image.
        """
        image=io.imread(self.loaded_image_path)
        norm_image=io.imread(self.normalized_image_path)

        boxes, polys, ret_score_text=self.modelInference(norm_image)

        # io.imsave does not create missing folders.
        os.makedirs(self.RESULT_DIR, exist_ok=True)

        filename = os.path.splitext(os.path.basename(self.loaded_image_path))[0]
        mask_image = self.RESULT_DIR + 'res_' + filename + '_mask.jpg'
        io.imsave(mask_image,ret_score_text)

        saveResult(self.loaded_image_path, self._to_bgr(image), polys, dirname=self.RESULT_DIR)



In [8]:
#pipeline

# Build the configuration, run detection once, and log any failure before
# letting it propagate so the cell still shows the error.
try:
    config=ConfigurationManager()
    text_detection_config=config.get_text_detection_config()
    text_detection=TextDetection(text_detection_config)
    text_detection.generate_result()
except Exception as e:
    logger.exception(e)
    # Bare raise re-raises the active exception with its traceback unchanged.
    raise
    

[2025-03-04 12:41:08,303: INFO: common: yaml file: config\config.yaml loaded successfully]
[2025-03-04 12:41:08,307: INFO: common: yaml file: params.yaml loaded successfully]
[2025-03-04 12:41:08,309: INFO: common: yaml file: schema.yaml loaded successfully]
[2025-03-04 12:41:08,310: INFO: common: created directory at: artifacts]
[2025-03-04 12:41:10,320: INFO: 1655308015: Craft model weights loaded successfully!]
[2025-03-04 12:41:13,008: INFO: common: binary file loaded from: artifacts\image\result\resized_data.pkl]
