In [1]:
!pip install torch torchvision torchaudio --quiet
!pip install albumentations --quiet  # for augmentation (optional)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2

import numpy as np
from PIL import Image
import os
import string


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m30.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.5/207.5 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.1/21.1 MB[0m [31m78.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25h

  check_for_updates()


In [2]:
# Define Vietnamese characters set (you can expand it)
# Including a blank character '' for CTC at index 0
lowercase = "aăâbcdđeêghijklmnoôơpqrstuưvwxyz" \
            "áàảãạằắẳẵặấầẩẫậéèẻẽẹếềểễệíìỉĩịóòỏõọốồổỗộớờởỡợúùủũụứừửữựýỳỷỹỵ0123456789"

uppercase = lowercase.upper()

special_chars = "/!@#$%^&*()_+:,.-;?{}[]|~` "

full_alphabet = lowercase + uppercase + special_chars
print(full_alphabet)
# Map char to index and vice versa
char_to_idx = {char: idx + 1 for idx, char in enumerate(full_alphabet)}  # start at 1; 0 is blank for CTC
idx_to_char = {idx: char for char, idx in char_to_idx.items()}

# Add blank character at index 0
idx_to_char[0] = ''

aăâbcdđeêghijklmnoôơpqrstuưvwxyzáàảãạằắẳẵặấầẩẫậéèẻẽẹếềểễệíìỉĩịóòỏõọốồổỗộớờởỡợúùủũụứừửữựýỳỷỹỵ0123456789AĂÂBCDĐEÊGHIJKLMNOÔƠPQRSTUƯVWXYZÁÀẢÃẠẰẮẲẴẶẤẦẨẪẬÉÈẺẼẸẾỀỂỄỆÍÌỈĨỊÓÒỎÕỌỐỒỔỖỘỚỜỞỠỢÚÙỦŨỤỨỪỬỮỰÝỲỶỸỴ0123456789/!@#$%^&*()_+:,.-;?{}[]|~` 


In [3]:
import pandas as pd

class VietnameseOCRDataset(Dataset):
    def __init__(self, img_dir, labels_csv, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        df = pd.read_csv(labels_csv, encoding='utf-8')
        self.samples = list(zip(df['image_name'], df['text']))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_name, label = self.samples[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('L')  # grayscale
    
        if self.transform:
            augmented = self.transform(image=np.array(image))  # pass as named argument
            image = augmented['image']                        # get transformed image tensor
    
        # Encode label string to list of indices
        label_idx = [char_to_idx[char] for char in label if char in char_to_idx]
    
        return image, torch.tensor(label_idx, dtype=torch.long)


In [4]:
transform = A.Compose([
    A.Resize(32, 2048),  # height fixed to 32, width 128 (adjust as needed)
    A.Normalize(mean=(0.5,), std=(0.5,)),
    ToTensorV2(),
])

In [5]:
train_dataset = VietnameseOCRDataset('/kaggle/input/vaipe-crops/vaipe_crops/train', '/kaggle/input/vaipe-crops/vaipe_crops/train.csv', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=lambda x: x)


In [6]:
class CRNN(nn.Module):
    def __init__(self, imgH, nc, nclass, nh):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, "imgH has to be a multiple of 16"

        self.cnn = nn.Sequential(
            nn.Conv2d(nc, 64, 3, 1, 1),  # conv1
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),          # 32x128 -> 16x64

            nn.Conv2d(64, 128, 3, 1, 1), # conv2
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),          # 16x64 -> 8x32

            nn.Conv2d(128, 256, 3, 1, 1), # conv3
            nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1), # conv4
            nn.ReLU(True),
            nn.MaxPool2d((2,2), (2,1), (0,1)), # 8x32 -> 4x33

            nn.Conv2d(256, 512, 3, 1, 1), # conv5
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1), # conv6
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.MaxPool2d((2,2), (2,1), (0,1)), # 4x33 -> 2x34

            nn.Conv2d(512, 512, 2, 1, 0),  # conv7 kernel=2 no padding
            nn.ReLU(True)
        )

        self.rnn = nn.LSTM(
            input_size=512,
            hidden_size=nh,
            num_layers=2,
            bidirectional=True,
            batch_first=True
        )

        self.embedding = nn.Linear(nh * 2, nclass)

    def forward(self, x):
        # x: (batch, channel=1, height, width)
        conv = self.cnn(x)  # [batch, 512, 1, width']
        b, c, h, w = conv.size()
        assert h == 1, "height after conv must be 1"
        conv = conv.squeeze(2)  # [batch, 512, width]
        conv = conv.permute(0, 2, 1)  # [batch, width, 512]

        rnn_out, _ = self.rnn(conv)  # [batch, width, nh*2]
        output = self.embedding(rnn_out)  # [batch, width, nclass]

        # output: logit sequence for CTC loss
        return output.log_softmax(2)  # for CTC loss: log prob on dim=2


In [7]:
device = torch.device('cuda')

In [8]:
model = CRNN(imgH=32, nc=1, nclass=len(full_alphabet) + 1, nh=256).to(device)
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=0.0005)

In [9]:
model.load_state_dict(torch.load('/kaggle/input/cnn-ocr-2/pytorch/default/4/best_crnn.pth'))

<All keys matched successfully>

In [10]:
def beam_search_decode(probs, beam_width=5, blank=0):
    import math
    from collections import defaultdict

    seq_len, batch_size, nclass = probs.size()
    decoded_batch = []

    for batch_idx in range(batch_size):
        beam = [(tuple(), 0.0)]

        for t in range(seq_len):
            new_beam = defaultdict(lambda: -math.inf)
            time_step_log_prob = probs[t, batch_idx].cpu().numpy()

            for seq, score in beam:
                for c in range(nclass):
                    p = time_step_log_prob[c]
                    if len(seq) > 0 and c == seq[-1]:
                        new_seq = seq
                    else:
                        new_seq = seq + (c,) if c != blank else seq
                    new_score = score + p
                    if new_score > new_beam[new_seq]:
                        new_beam[new_seq] = new_score

            beam = sorted(new_beam.items(), key=lambda x: x[1], reverse=True)[:beam_width]

        best_seq, best_score = beam[0]

        # Filter blanks and repeated characters here
        decoded = []
        prev = None
        for idx in best_seq:
            if idx != blank and idx != prev:
                # Defensive check in case idx_to_char missing key
                char = idx_to_char.get(idx, '')
                if char != '':
                    decoded.append(char)
            prev = idx

        decoded_str = "".join(decoded)
        decoded_batch.append(decoded_str)

    return decoded_batch


In [11]:
def clean_decoded_text(text, blank_char=''):
    """
    Remove duplicates and blanks if any remain.
    Assumes blank_char is '' (empty string) for blank token.
    """
    cleaned = []
    prev_char = None
    for ch in text:
        if ch != blank_char and ch != prev_char:
            cleaned.append(ch)
        prev_char = ch
    return ''.join(cleaned)

In [13]:
with torch.no_grad():
    sample_img, _ = train_dataset[3]
    sample_img = sample_img.unsqueeze(0).to(device)
    output = model(sample_img)
    decoded_texts = beam_search_decode(output, beam_width=10, blank=0)
    raw_text = decoded_texts
    cleaned_text = clean_decoded_text(raw_text)
    print("Raw decoded text:", raw_text)
    print("Cleaned text:", cleaned_text)

Raw decoded text: ['', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', 'M', '', '', '', 'M', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', 'H', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', 'g', 'g', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', ' ', ' ', ' ', '', '', '', '', '', '', '