In [1]:
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import numpy as np, cv2, joblib
from pathlib import Path
from tqdm import tqdm

device = torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
ds = load_dataset("lansinuote/ocr_id_card")
maxSamples = 1000
numCardsForDict = 1000 #len(ds['train'])

In [3]:
all_words = []
for k in range(numCardsForDict):
    words = [item['word'] for item in ds['train'][k]['ocr']]
    text = "".join(words)
    all_words.append(text)

# Collect unique characters
vocabulary = sorted(set("".join(all_words)))

# Build dictionaries
char_to_idx = {}
idx_to_char = {}

i = 1
for c in vocabulary:
    char_to_idx[c] = i
    idx_to_char[i] = c
    i += 1

idx_to_char[0] = "<BLANK>"
num_classes = len(idx_to_char)
print("Vocabulary size:", num_classes)

Vocabulary size: 1473


In [4]:
class OCRDataset(Dataset):
    def __init__(self, hf_dataset, char_to_idx, max_samples=None):
        self.ds = hf_dataset
        self.char_to_idx = char_to_idx
        self.max_samples = max_samples or len(hf_dataset)

    def __len__(self):
        return self.max_samples

    def text_to_indices(self, text):
        indices = []
        for c in text:
            if c in self.char_to_idx:
                indices.append(self.char_to_idx[c])
        return indices

    def __getitem__(self, idx):
        item = self.ds[idx]
        image = np.array(item["image"])
        text = "".join([w["word"] for w in item["ocr"]])

        # Convert to grayscale and resize (128x32 works well)
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        img_resized = cv2.resize(gray, (128, 32))
        img_tensor = torch.tensor(img_resized, dtype=torch.float32).unsqueeze(0) / 255.0

        target = torch.tensor(self.text_to_indices(text), dtype=torch.long)
        return img_tensor, target

In [5]:
class DeskewedGrayscaleCNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # (B, 32, 16, 64)
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)   # (B, 64, 8, 32)
        )
        self.lstm = nn.LSTM(64*8, 128, num_layers=2, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        feats = self.features(x)
        b, c, h, w = feats.size()
        feats = feats.permute(0, 3, 1, 2).reshape(b, w, -1)  # (B, W, C*H)
        out, _ = self.lstm(feats)
        out = self.fc(out)
        return out

In [6]:
train_dataset = OCRDataset(ds["train"], char_to_idx, max_samples=maxSamples)  # small subset for demo
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=lambda x: x)

In [7]:
model = DeskewedGrayscaleCNN(num_classes=num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

In [8]:
for epoch in range(3):  # Increase this later
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        images = [b[0] for b in batch]
        targets = [b[1] for b in batch]

        x = torch.stack(images).to(device)
        targets_concat = torch.cat(targets).to(device)
        target_lengths = torch.tensor([len(t) for t in targets], dtype=torch.long).to(device)

        logits = model(x)
        log_probs = F.log_softmax(logits, dim=2)

        input_lengths = torch.full(
            size=(log_probs.size(0),),
            fill_value=log_probs.size(1),
            dtype=torch.long
        ).to(device)

        loss = ctc_loss(log_probs.permute(1, 0, 2), targets_concat, input_lengths, target_lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} - Average loss: {total_loss / len(train_loader):.4f}")

Epoch 1: 100%|██████████| 125/125 [00:18<00:00,  6.92it/s]


Epoch 1 - Average loss: 0.0000


Epoch 2: 100%|██████████| 125/125 [00:24<00:00,  5.03it/s]


Epoch 2 - Average loss: 0.0000


Epoch 3: 100%|██████████| 125/125 [00:22<00:00,  5.59it/s]

Epoch 3 - Average loss: 0.0000





In [9]:
save_dir = Path("saved_model")
torch.save(model.state_dict(), save_dir / "model.joblib")
joblib.dump({
    "char_to_idx": char_to_idx,
    "idx_to_char": idx_to_char,
    "num_classes": num_classes
}, save_dir / "vocab.joblib")


['saved_model\\vocab.joblib']

In [10]:
def ctc_decode(indices, idx_to_char):
    decoded = []
    prev = None
    for i in indices:
        char = idx_to_char.get(int(i), "")
        if char != "<BLANK>" and char != prev:
            decoded.append(char)
        prev = char
    return "".join(decoded)

data = joblib.load("saved_model/vocab.joblib")
char_to_idx = data["char_to_idx"]
idx_to_char = data["idx_to_char"]
num_classes = data["num_classes"]

model_loaded = DeskewedGrayscaleCNN(num_classes=num_classes)
model_loaded.load_state_dict(torch.load("saved_model/model.joblib", map_location=device))
model_loaded.to(device)


DeskewedGrayscaleCNN(
  (features): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (lstm): LSTM(512, 128, num_layers=2, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=256, out_features=1473, bias=True)
)

In [11]:
model_loaded.eval()
with torch.no_grad():
    img, target = train_dataset[0]
    x = img.unsqueeze(0).to(device)
    logits = model_loaded(x)
    log_probs = F.log_softmax(logits, dim=2)
    pred_indices = torch.argmax(log_probs, dim=2)[0].cpu().numpy()
    pred_text = ctc_decode(pred_indices, idx_to_char)

    actual_text = "".join([idx_to_char[i.item()] for i in target])
    print("Predicted:", pred_text)
    print("Actual:", actual_text)

Predicted: 邓兵何辛兵圣芷特藤枫
Actual: 窦加强女汉19881221湖南省衡阳市珠晖区641853198812215365
