In [1]:
import sys
import os
import cv2
import torch
import json
import numpy as np
from PIL import Image
from typing import Dict, List
from ultralytics import YOLO

# --- Path configuration ---
# Make the project root (parent of the current working directory) importable so
# that the local `utils` package can be found. os.getcwd() is used instead of
# __file__ because this code runs inside a notebook, where __file__ is undefined.
_PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
# Guard against appending the same entry again every time the cell is re-run.
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

# 1. Import the sanitizer used to post-process extracted entities.
try:
    from utils.sanitizer import SmartSanitizer
except ImportError as exc:
    # Fallback: keep the pipeline usable by passing entities through unchanged.
    # The original error is included because the ImportError may be raised by
    # something sanitizer.py itself imports, not by the file being missing.
    print(f"CẢNH BÁO: Không tìm thấy file utils/sanitizer.py ({exc})")

    class SmartSanitizer:
        """No-op stand-in for utils.sanitizer.SmartSanitizer."""

        @staticmethod
        def sanitize(x):
            """Return *x* unchanged."""
            return x

# 2. Import VietOCR & Transformer
from vietocr.tool.predictor import Predictor
from vietocr.tool.config import Cfg
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor

# ==============================================================================
# PHẦN 1: CLASS KIE (GIỮ NGUYÊN CODE CỦA BẠN)
# ==============================================================================

class SROIEInference:
    """LayoutLMv3 inference on SROIE receipts (with auto-cleaning).

    Tags pre-computed OCR words with BIO labels, groups consecutive labels into
    entity strings (company/date/address/total) and post-processes the result
    with ``SmartSanitizer.sanitize``.
    """

    def __init__(self, model_path: str, device=None):
        """Load the token-classification model, its processor and the label map.

        Args:
            model_path: Directory holding the fine-tuned model and processor and,
                optionally, a ``label2id.json`` label mapping.
            device: Torch device or device string. Defaults to CUDA when
                available, otherwise CPU.
        """
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load the model in inference mode.
        print(f"Loading KIE model from {model_path} to {self.device}...")
        self.model = LayoutLMv3ForTokenClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # apply_ocr=False: words and boxes are supplied by the caller.
        self.processor = LayoutLMv3Processor.from_pretrained(
            model_path,
            apply_ocr=False
        )

        # Label mapping: prefer the file saved next to the model; otherwise use
        # the default SROIE BIO scheme below.
        # NOTE(review): the fallback must match the label ids used in training.
        label_file = os.path.join(model_path, 'label2id.json')
        if os.path.exists(label_file):
            with open(label_file, 'r', encoding='utf-8') as f:
                self.label2id = json.load(f)
            self.id2label = {int(v): k for k, v in self.label2id.items()}
        else:
            self.label2id = {
                "O": 0, "B-COMPANY": 1, "I-COMPANY": 2, "B-DATE": 3, "I-DATE": 4,
                "B-ADDRESS": 5, "I-ADDRESS": 6, "B-TOTAL": 7, "I-TOTAL": 8
            }
            self.id2label = {v: k for k, v in self.label2id.items()}

    @staticmethod
    def _normalize_boxes(boxes, width, height, scale=1000):
        """Scale pixel ``[x1, y1, x2, y2]`` boxes to a 0..``scale`` grid, clamped.

        LayoutLMv3 expects coordinates in [0, 1000]; boxes that extend past the
        image edge would otherwise produce out-of-range values.
        """
        def _scale(value, size):
            return min(scale, max(0, int(scale * value / size)))

        return [
            [
                _scale(box[0], width),
                _scale(box[1], height),
                _scale(box[2], width),
                _scale(box[3], height),
            ]
            for box in boxes
        ]

    def predict_single(self, image_path: str, words: List[str], boxes: List[List[int]]) -> Dict[str, str]:
        """Extract entities from one receipt image.

        Args:
            image_path: Path to the receipt image. It is the model's visual input
                and its size is used to normalise ``boxes``.
            words: OCR text segments, presumably in reading order.
            boxes: One ``[x1, y1, x2, y2]`` pixel box per entry of ``words``.

        Returns:
            Entity-name -> text mapping produced by ``extract_entities``.
            Words beyond the 512-token limit are truncated and never tagged.
        """
        from PIL import ImageOps  # only needed here; PIL is already a file dependency

        if not words:
            # Nothing to tag: return the (sanitised) empty entity dict.
            return self.extract_entities([])

        # Open with a context manager so the file handle is released. Apply the
        # EXIF orientation because cv2.imread (which produces the upstream boxes
        # in this file) honours it by default while PIL's Image.open does not.
        with Image.open(image_path) as img:
            image = ImageOps.exif_transpose(img).convert('RGB')
        width, height = image.size

        normalized_boxes = self._normalize_boxes(boxes, width, height)

        # Encode
        encoding = self.processor(
            image,
            words,
            boxes=normalized_boxes,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors='pt'
        )

        # word_ids maps each token position to its source word (None for
        # special/padding tokens); read it before moving tensors to the device.
        word_ids = encoding.word_ids(batch_index=0)
        input_data = {k: v.to(self.device) for k, v in encoding.items()}

        with torch.no_grad():
            outputs = self.model(**input_data)
            predictions = torch.argmax(outputs.logits, dim=-1)

        predictions = predictions.cpu().numpy()[0]

        # Label each word by its first sub-token's prediction.
        word_predictions = []
        previous_word_idx = None
        for idx, word_idx in enumerate(word_ids):
            if word_idx is not None and word_idx != previous_word_idx:
                word_predictions.append({
                    'word': words[word_idx],
                    'label': self.id2label[int(predictions[idx])]
                })
                previous_word_idx = word_idx

        return self.extract_entities(word_predictions)

    def extract_entities(self, word_predictions: List[Dict]) -> Dict[str, str]:
        """Group word-level BIO labels into entity strings.

        Consecutive words of the same entity are joined with single spaces.
        An ``I-`` tag that does not continue the current entity starts a new
        span. If an entity occurs in several spans, the last one wins.

        Args:
            word_predictions: Dicts with ``'word'`` and ``'label'`` keys, in order.

        Returns:
            The entities dict (keys ``company``, ``date``, ``address``, ``total``
            plus any other lower-cased entity names seen), passed through
            ``SmartSanitizer.sanitize``.
        """
        entities = {'company': '', 'date': '', 'address': '', 'total': ''}
        current_entity = None
        current_text = []

        def flush():
            # Store the span being built, if any (reads the latest bindings).
            if current_entity and current_text:
                entities[current_entity.lower()] = ' '.join(current_text)

        for item in word_predictions:
            word = item['word']
            label = item['label']

            # Continuation of the entity currently being built.
            if label.startswith('I-') and label[2:] == current_entity:
                current_text.append(word)
                continue

            flush()
            if label.startswith(('B-', 'I-')):
                current_entity = label[2:]
                current_text = [word]
            else:
                current_entity = None
                current_text = []

        flush()

        clean_entities = SmartSanitizer.sanitize(entities)
        return clean_entities

# ==============================================================================
# PHẦN 2: PIPELINE TÍCH HỢP (YOLO + VietOCR -> gọi SROIEInference)
# ==============================================================================

class EndToEndPipeline:
    """Receipt pipeline: YOLO text detection -> VietOCR recognition -> LayoutLMv3 KIE."""

    def __init__(self, yolo_path, kie_model_path, device='cpu'):
        """Load the detector, the recogniser and the KIE model.

        Args:
            yolo_path: Path to the YOLO detection weights.
            kie_model_path: Directory of the fine-tuned LayoutLMv3 model.
            device: Device string used for YOLO inference, VietOCR and the KIE model.
        """
        self.device = device

        # 1. Text detector (YOLO).
        print(f"--- Loading YOLO from {yolo_path} ---")
        self.yolo_model = YOLO(yolo_path)

        # 2. Text recogniser (VietOCR, vgg_transformer config).
        print("--- Loading VietOCR (vgg_transformer) ---")
        config = Cfg.load_config_from_name('vgg_transformer')
        config['device'] = device
        config['predictor']['beamsearch'] = False  # greedy decoding
        # Presumably skips a separate ImageNet backbone download; Predictor loads
        # the full vgg_transformer checkpoint itself.
        config['cnn']['pretrained'] = False
        self.ocr_model = Predictor(config)

        # 3. Key-information extractor; SROIEInference loads its own model/processor.
        self.kie_engine = SROIEInference(model_path=kie_model_path, device=device)

    @staticmethod
    def _clip_box(coords, width, height):
        """Clamp pixel ``(x1, y1, x2, y2)`` to the image bounds ``[0, width] x [0, height]``."""
        x1, y1, x2, y2 = coords
        return max(0, x1), max(0, y1), min(width, x2), min(height, y2)

    def run(self, image_path):
        """Detect, recognise and extract entities from one receipt image.

        Args:
            image_path: Path to the image file.

        Returns:
            The entity dict from ``SROIEInference.predict_single``, or ``{}`` when
            the image cannot be read or no text is recognised.
        """
        print(f"\n>>> Processing: {os.path.basename(image_path)}")

        # A. Detect text regions (YOLO). Decode once with OpenCV and hand the same
        # BGR array to YOLO, so box coordinates and crops refer to the same pixels.
        img_cv = cv2.imread(image_path)
        if img_cv is None:
            print("Error: Image not found.")
            return {}
        img_h, img_w = img_cv.shape[:2]

        results = self.yolo_model(img_cv, device=self.device, verbose=False)

        detected_data = []  # entries: {"box": [x1, y1, x2, y2], "text": str}

        for result in results:
            for box in result.boxes:
                x1, y1, x2, y2 = self._clip_box(map(int, box.xyxy[0]), img_w, img_h)

                crop = img_cv[y1:y2, x1:x2]
                if crop.size == 0:
                    continue

                # B. Recognise text (VietOCR takes an RGB PIL image).
                crop_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
                try:
                    text = self.ocr_model.predict(crop_pil)
                except Exception as exc:
                    # Best-effort: skip this region, but say why instead of
                    # silently swallowing (a bare except also caught Ctrl-C).
                    print(f"Warning: OCR failed for box {[x1, y1, x2, y2]}: {exc}")
                    continue

                if text.strip():
                    detected_data.append({
                        "box": [x1, y1, x2, y2],
                        "text": text
                    })

        if not detected_data:
            print("No text detected.")
            return {}

        # C. Prepare KIE input: order boxes top-to-bottom, then left-to-right by
        # their top-left corner.
        # NOTE(review): exact-y1 ordering can interleave words of one visual line
        # whose tops differ by a few pixels; consider grouping into lines.
        detected_data.sort(key=lambda item: (item['box'][1], item['box'][0]))

        words = [item['text'] for item in detected_data]
        boxes = [item['box'] for item in detected_data]

        print(f"-> Detected {len(words)} words. Running KIE...")

        # D. Key-information extraction (reloads the image from its path).
        return self.kie_engine.predict_single(image_path, words, boxes)

# ==============================================================================
# PHẦN 3: CHẠY THỬ
# ==============================================================================

# Path configuration. Each value can be overridden with an environment variable;
# the defaults are the original author's local paths.
YOLO_WEIGHTS = os.environ.get(
    'SROIE_YOLO_WEIGHTS',
    r'D:\ADMIN\Documents\Classwork\advance_cv_project\train\runs\detect\sroie_yolov8m_finetune\weights\best.pt',
)
KIE_MODEL_DIR = os.environ.get(
    'SROIE_KIE_MODEL_DIR',
    r'D:\ADMIN\Documents\Classwork\advance_cv_project\train\layoutlmv3_sroie_output\best_model',
)
TEST_IMAGE = os.environ.get(
    'SROIE_TEST_IMAGE',
    r'D:\ADMIN\Documents\Classwork\advance_cv_project\data\Receipt_OCR_1\raw\A0011.png',
)
# Device for all pipeline models ('cpu', 'cuda', ...).
DEVICE = os.environ.get('SROIE_DEVICE', 'cpu')

# Run the end-to-end demo on TEST_IMAGE.
if __name__ == "__main__":
    if os.path.exists(TEST_IMAGE):
        # Build the pipeline (loads all models).
        pipeline = EndToEndPipeline(YOLO_WEIGHTS, KIE_MODEL_DIR, device=DEVICE)

        # Predict
        result = pipeline.run(TEST_IMAGE)

        print("\n=== FINAL RESULT ===")
        print(json.dumps(result, indent=4, ensure_ascii=False))
    else:
        print(f"File ảnh không tồn tại: {TEST_IMAGE}")

  from .autonotebook import tqdm as notebook_tqdm


--- Loading YOLO from D:\ADMIN\Documents\Classwork\advance_cv_project\train\runs\detect\sroie_yolov8m_finetune\weights\best.pt ---
--- Loading VietOCR (vgg_transformer) ---




Model weight C:\Users\OS\AppData\Local\Temp\vgg_transformer.pth exsits. Ignore download!
Loading KIE model from D:\ADMIN\Documents\Classwork\advance_cv_project\train\layoutlmv3_sroie_output\best_model to cpu...

>>> Processing: A0011.png
-> Detected 35 words. Running KIE...





=== FINAL RESULT ===
{
    "company": "VM% QNH Dư án KDC lấn biển coc 6",
    "date": "14/08/2020",
    "address": "TP. Câm Phà, T, Quảng Ninh",
    "total": "41.000"
}
