In [4]:
# ============================================================
# Automated OCR Model Training + Evaluation
# Works with lansinuote/ocr_id_card dataset
# ============================================================

import os
import time
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from datasets import load_dataset
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset

# ------------------------------------------------------------
# Configuration
# ------------------------------------------------------------
SAVE_DIR = Path("saved_model")
SAVE_DIR.mkdir(exist_ok=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMG_W, IMG_H = 128, 32
EPOCHS = 3
BATCH_SIZE = 8

# ------------------------------------------------------------
# Shared CNN-LSTM Model Definition
# ------------------------------------------------------------
class OCRNet(nn.Module):
    def __init__(self, nclass=80):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2)
        )
        self.lstm = nn.LSTM(64 * 8, 128, num_layers=2,
                             bidirectional=True, batch_first=True)
        self.fc = nn.Linear(256, nclass)

    def forward(self, x):
        x = self.conv(x)
        b, c, h, w = x.size()
        x = x.permute(0, 3, 1, 2).reshape(b, w, -1)
        x, _ = self.lstm(x)
        x = self.fc(x)
        return x.mean(dim=1)

# ------------------------------------------------------------
# Preprocessing Functions
# ------------------------------------------------------------
def deskew_red(img):
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 50, 150)
    lines = cv2.HoughLines(edges, 1, np.pi / 180, 200)
    angle = 0
    if lines is not None:
        angles = [theta for rho, theta in lines[:, 0]]
        angle = (np.mean(angles) - np.pi / 2) * 180 / np.pi
    (h, w) = img.shape[:2]
    M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1)
    img = cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_CUBIC)
    b, g, r = cv2.split(img)
    red_only = cv2.merge([np.zeros_like(r), np.zeros_like(r), r])
    gray = cv2.cvtColor(red_only, cv2.COLOR_BGR2GRAY)
    return cv2.resize(gray, (IMG_W, IMG_H))

def deskew_grey(img):
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 50, 150)
    lines = cv2.HoughLines(edges, 1, np.pi / 180, 200)
    angle = 0
    if lines is not None:
        angles = [theta for rho, theta in lines[:, 0]]
        angle = (np.mean(angles) - np.pi / 2) * 180 / np.pi
    (h, w) = img.shape[:2]
    M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1)
    img = cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_CUBIC)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return cv2.resize(gray, (IMG_W, IMG_H))

def deskew_red_denoise(img):
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 50, 150)
    lines = cv2.HoughLines(edges, 1, np.pi / 180, 200)
    angle = 0
    if lines is not None:
        angles = [theta for rho, theta in lines[:, 0]]
        angle = (np.mean(angles) - np.pi / 2) * 180 / np.pi
    (h, w) = img.shape[:2]
    M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1)
    img = cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_CUBIC)
    b, g, r = cv2.split(img)
    red_only = cv2.merge([np.zeros_like(r), np.zeros_like(r), r])
    gray = cv2.cvtColor(red_only, cv2.COLOR_BGR2GRAY)
    denoiseImg = cv2.fastNlMeansDenoising(gray, h=10)
    return cv2.resize(denoiseImg, (IMG_W, IMG_H))

def predenoise_red(img):
    img = cv2.fastNlMeansDenoisingColored(img, None, 10, 10, 7, 21)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 50, 150)
    lines = cv2.HoughLines(edges, 1, np.pi / 180, 200)
    angle = 0
    if lines is not None:
        angles = [theta for rho, theta in lines[:, 0]]
        angle = (np.mean(angles) - np.pi / 2) * 180 / np.pi
    (h, w) = img.shape[:2]
    M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1)
    img = cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_CUBIC)
    b, g, r = cv2.split(img)
    red_only = cv2.merge([np.zeros_like(r), np.zeros_like(r), r])
    gray = cv2.cvtColor(red_only, cv2.COLOR_BGR2GRAY)
    return cv2.resize(gray, (IMG_W, IMG_H))

# ------------------------------------------------------------
# Dataset Wrapper
# ------------------------------------------------------------
class OCRDataset(Dataset):
    def __init__(self, hf_split, label_encoder, preprocess_fn):
        self.ds = hf_split
        self.label_encoder = label_encoder
        self.preprocess_fn = preprocess_fn

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        img = np.array(self.ds[idx]["image"])
        ocr_list = self.ds[idx]["ocr"]
        label_text = "".join([w["word"] for w in ocr_list if "word" in w and w["word"]])
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        proc = self.preprocess_fn(img)
        x = torch.tensor(proc / 255.0).unsqueeze(0).float()
        y = torch.tensor(self.label_encoder.transform([label_text])[0])
        return x, y

# ------------------------------------------------------------
# Utility Functions
# ------------------------------------------------------------
def train_model(model, train_loader, n_epochs, save_path):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(n_epochs):
        model.train()
        total_loss = 0
        for x, y in train_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * x.size(0)
        print(f"Epoch {epoch + 1}/{n_epochs}, Loss: {total_loss / len(train_loader.dataset):.4f}")
    torch.save(model.state_dict(), save_path)

def evaluate_model(model, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            out = model(x)
            preds = out.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return 100 * correct / total

def evaluate_latency(model, dataset, preprocess_fn, n=10):
    model.eval()
    latencies = []
    with torch.no_grad():
        for i in range(n):
            img = np.array(dataset[i]["image"])
            proc = preprocess_fn(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
            x = torch.tensor(proc / 255.0).unsqueeze(0).unsqueeze(0).float().to(DEVICE)
            t0 = time.perf_counter()
            _ = model(x)
            t1 = time.perf_counter()
            latencies.append((t1 - t0) * 1000)
    latency = np.mean(latencies)
    throughput = 1000 / latency
    return latency, throughput

# ------------------------------------------------------------
# Load dataset and prepare label encoder
# ------------------------------------------------------------
train_ds = load_dataset("lansinuote/ocr_id_card", split="train[:80%]")
test_ds = load_dataset("lansinuote/ocr_id_card", split="train[80%:]")

all_texts = []
for sample in train_ds:
    text = "".join([w["word"] for w in sample["ocr"] if "word" in w and w["word"]])
    if text:
        all_texts.append(text)

label_encoder = LabelEncoder().fit(all_texts)
nclass = len(label_encoder.classes_)

# ------------------------------------------------------------
# Train and Evaluate All Models
# ------------------------------------------------------------
models_info = [
    ("Model 1: Deskewed Red", deskew_red, "model1.pt"),
    ("Model 2: Deskewed Grey", deskew_grey, "model2.pt"),
    ("Model 3: Red+Denoise", deskew_red_denoise, "model3.pt"),
    ("Model 4: PreDenoise Red", predenoise_red, "model4.pt")
]

results = {}

for name, preprocess_fn, fname in models_info:
    print(f"\nTraining {name}")
    train_data = OCRDataset(train_ds, label_encoder, preprocess_fn)
    test_data = OCRDataset(test_ds, label_encoder, preprocess_fn)
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

    model = OCRNet(nclass=nclass).to(DEVICE)
    train_model(model, train_loader, EPOCHS, SAVE_DIR / fname)
    acc = evaluate_model(model, test_loader)
    lat, thr = evaluate_latency(model, test_ds, preprocess_fn)
    results[name] = (acc, lat, thr)
    print(f"{name}: Accuracy={acc:.2f}%, Latency={lat:.2f}ms, Throughput={thr:.2f} img/s")

# ------------------------------------------------------------
# Plot Comparisons
# ------------------------------------------------------------
models = list(results.keys())
accuracy = [v[0] for v in results.values()]
latency = [v[1] for v in results.values()]
throughput = [v[2] for v in results.values()]

fig, ax1 = plt.subplots()
ax1.set_xlabel("Model Version")
ax1.set_ylabel("Accuracy (%)", color="tab:blue")
ax1.plot(models, accuracy, "o-", color="tab:blue", label="Accuracy")
ax1.tick_params(axis="y", labelcolor="tab:blue")

ax2 = ax1.twinx()
ax2.set_ylabel("Latency (ms)", color="tab:red")
ax2.plot(models, latency, "s--", color="tab:red", label="Latency")
ax2.tick_params(axis="y", labelcolor="tab:red")

plt.title("Model Accuracy vs Latency")
fig.autofmt_xdate(rotation=20)
plt.tight_layout()
plt.show()

fig, ax1 = plt.subplots()
ax1.set_xlabel("Model Version")
ax1.set_ylabel("Accuracy (%)", color="tab:blue")
ax1.plot(models, accuracy, "o-", color="tab:blue", label="Accuracy")
ax1.tick_params(axis="y", labelcolor="tab:blue")

ax2 = ax1.twinx()
ax2.set_ylabel("Throughput (images/sec)", color="tab:green")
ax2.plot(models, throughput, "d--", color="tab:green", label="Throughput")
ax2.tick_params(axis="y", labelcolor="tab:green")

plt.title("Model Accuracy vs Throughput")
fig.autofmt_xdate(rotation=20)
plt.tight_layout()
plt.show()



Training Model 1: Deskewed Red
Epoch 1/3, Loss: 9.8193
Epoch 2/3, Loss: 9.8085
Epoch 3/3, Loss: 9.7808


ValueError: y contains previously unseen labels: np.str_('萧逸齐男汉1984927湖北省宜昌市伍家岗区965652198409276669')