In [None]:
!pip install torch torchvision torchaudio --quiet
!pip install albumentations --quiet  # for augmentation (optional)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2

import numpy as np
from PIL import Image
import os
import string


In [None]:
import random

batch_size = 32
lr = 1e-4

# Seed every RNG the notebook relies on (DataLoader shuffling, weight initialization)
# so Restart & Run All reproduces the same run (up to nondeterministic GPU kernels).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Vietnamese Alphabet

In [None]:
import unicodedata

# Character set for the CTC classifier. Index 0 is reserved for the CTC blank,
# so real characters are numbered from 1.
# Base letters include f, j, w, z (not native to Vietnamese, but used in loanwords and names).
# NFC normalization guarantees every accented letter is a single precomposed code point.
lowercase = unicodedata.normalize(
    "NFC",
    "aăâbcdđeêfghijklmnoôơpqrstuưvwxyz"
    "áàảãạằắẳẵặấầẩẫậéèẻẽẹếềểễệíìỉĩịóòỏõọốồổỗộớờởỡợúùủũụứừửữựýỳỷỹỵ",
)

uppercase = lowercase.upper()

# Digits are kept out of `lowercase`: upper() leaves them unchanged, so including them
# there would list every digit twice in the alphabet.
digits = "0123456789"

special_chars = "/!@#$%^&*()_+:,.-;?{}[]|~` "

full_alphabet = lowercase + uppercase + digits + special_chars
# Duplicates would make char_to_idx silently overwrite indices and give the model
# output classes that no label can ever use.
assert len(set(full_alphabet)) == len(full_alphabet), "duplicate characters in full_alphabet"
print(full_alphabet)

# Map char to index and vice versa (start at 1; 0 is the CTC blank)
char_to_idx = {char: idx for idx, char in enumerate(full_alphabet, start=1)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}

# The blank decodes to nothing
idx_to_char[0] = ''

# Dataset Preparation

In [None]:
import unicodedata

import pandas as pd

class VietnameseOCRDataset(Dataset):
    """Text-line OCR dataset built from a CSV of (image file, transcription) pairs.

    Args:
        img_dir: directory containing the image files.
        labels_csv: UTF-8 CSV with ``image_name`` and ``text`` columns.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=np_array)`` and returning a dict with key ``'image'``.

    Each item is ``(image, label_indices)``, where ``label_indices`` is a 1-D long tensor
    of ``char_to_idx`` indices. Characters missing from ``char_to_idx`` are skipped.
    """

    def __init__(self, img_dir, labels_csv, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        # dtype=str + keep_default_na=False: by default pandas turns transcriptions such as
        # "NA", "null", "None" or an empty cell into float NaN, and numeric file names into ints.
        df = pd.read_csv(labels_csv, encoding='utf-8', dtype=str, keep_default_na=False)
        # NFC-normalize so decomposed diacritics (base letter + combining marks) match the
        # precomposed characters in the alphabet instead of being dropped during encoding.
        self.samples = [
            (img_name, unicodedata.normalize('NFC', text))
            for img_name, text in zip(df['image_name'], df['text'])
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_name, label = self.samples[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # The context manager closes the file; convert() returns an independent in-memory image.
        with Image.open(img_path) as img:
            image = img.convert('L')  # grayscale

        if self.transform:
            augmented = self.transform(image=np.array(image))  # albumentations takes a named argument
            image = augmented['image']

        # Encode label string to list of indices (unknown characters are skipped)
        label_idx = [char_to_idx[char] for char in label if char in char_to_idx]

        return image, torch.tensor(label_idx, dtype=torch.long)


# Define Transformations (Resize + Normalize)

In [None]:
# Every text-line image is resized to IMG_HEIGHT x IMG_WIDTH (aspect ratio is not preserved).
# The CRNN below only collapses a 32-pixel height to 1; with width 512 it emits
# 512 // 4 + 1 = 129 time steps for CTC.
IMG_HEIGHT = 32
IMG_WIDTH = 512

transform = A.Compose([
    A.Resize(IMG_HEIGHT, IMG_WIDTH),       # (height, width)
    A.Normalize(mean=(0.5,), std=(0.5,)),  # pixel/255 -> (x - 0.5) / 0.5, i.e. roughly [-1, 1]
    ToTensorV2(),                          # numpy image -> channel-first torch tensor
])

# Instantiate Dataset and DataLoader

In [None]:
DATA_ROOT = '/kaggle/input/genrated-text/generated_text_recognition'


def _identity_collate(batch):
    """Return the batch unchanged as a list of (image, label) pairs.

    Labels have different lengths, so the default collate cannot stack them; the
    train/validate functions unpack the list themselves. A named function (unlike a
    lambda) can be pickled if ``num_workers > 0`` is used with the spawn start method.
    """
    return batch


train_dataset = VietnameseOCRDataset(os.path.join(DATA_ROOT, 'train'), os.path.join(DATA_ROOT, 'train.csv'), transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=_identity_collate)
val_dataset = VietnameseOCRDataset(os.path.join(DATA_ROOT, 'val'), os.path.join(DATA_ROOT, 'val.csv'), transform=transform)
# No shuffling for validation. It gains nothing, and when the loss is averaged per batch it
# changes which samples share the final partial batch, so the reported loss drifts between epochs.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=_identity_collate)

# CRNN Model

In [None]:
class CRNN(nn.Module):
    """CRNN text-line recognizer: convolutional feature extractor, 2-layer BiLSTM, linear head.

    Args:
        imgH: input image height. The CNN reduces the height to exactly 1 only for
            imgH == 32, which ``forward`` relies on.
        nc: number of input channels (1 for grayscale).
        nclass: number of output classes, including the CTC blank.
        nh: hidden size of each LSTM direction.

    ``forward`` maps a (batch, nc, 32, W) tensor to (batch, W // 4 + 1, nclass)
    log-probabilities, e.g. 129 time steps for W = 512.
    """

    def __init__(self, imgH, nc, nclass, nh):
        super(CRNN, self).__init__()
        # Only a 32-pixel height ends at height 1 after conv7. Other multiples of 16
        # (16, 48, 64, ...) would fail later in forward, so reject them here.
        assert imgH == 32, "imgH must be 32: the CNN collapses the height to 1 only for 32-pixel inputs"

        # Shape comments are (height x width) for an input of 32 x W.
        # Layer order is significant: it fixes the state_dict keys (cnn.0, cnn.3, ...).
        self.cnn = nn.Sequential(
            nn.Conv2d(nc, 64, 3, 1, 1),  # conv1
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),          # 32 x W -> 16 x W/2

            nn.Conv2d(64, 128, 3, 1, 1), # conv2
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),          # 16 x W/2 -> 8 x W/4

            nn.Conv2d(128, 256, 3, 1, 1), # conv3
            nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1), # conv4
            nn.ReLU(True),
            nn.MaxPool2d((2,2), (2,1), (0,1)), # 8 x W/4 -> 4 x (W/4 + 1): halve height only

            nn.Conv2d(256, 512, 3, 1, 1), # conv5
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1), # conv6
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.MaxPool2d((2,2), (2,1), (0,1)), # 4 x (W/4 + 1) -> 2 x (W/4 + 2)

            nn.Conv2d(512, 512, 2, 1, 0),  # conv7, kernel 2, no padding: -> 1 x (W/4 + 1)
            nn.ReLU(True)
        )

        self.rnn = nn.LSTM(
            input_size=512,
            hidden_size=nh,
            num_layers=2,
            bidirectional=True,
            batch_first=True
        )

        # Both LSTM directions are concatenated, hence nh * 2 input features.
        self.embedding = nn.Linear(nh * 2, nclass)

    def forward(self, x):
        # x: (batch, nc, 32, W)
        conv = self.cnn(x)  # (batch, 512, 1, T) with T = W // 4 + 1
        h = conv.size(2)
        assert h == 1, f"height after conv must be 1, got {h}"
        conv = conv.squeeze(2)        # (batch, 512, T)
        conv = conv.permute(0, 2, 1)  # (batch, T, 512); the LSTM is batch_first

        rnn_out, _ = self.rnn(conv)       # (batch, T, nh * 2)
        output = self.embedding(rnn_out)  # (batch, T, nclass)

        # Log-probabilities over classes; permute to (T, batch, nclass) before nn.CTCLoss.
        return output.log_softmax(2)


# CTC Loss and Optimizer

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # fall back to CPU when no GPU is present

In [None]:
# Output classes: every alphabet character plus the CTC blank at index 0.
model = CRNN(
    imgH=32,
    nc=1,
    nclass=1 + len(full_alphabet),
    nh=256,
).to(device)
# blank=0 matches the index reserved for the blank in char_to_idx; zero_infinity=True zeroes
# infinite losses (inputs too short to align with their targets) instead of propagating inf.
ctc_loss = nn.CTCLoss(zero_infinity=True, blank=0)
optimizer = optim.Adam(params=model.parameters(), lr=lr)

# Early Stopping

In [None]:
import numpy as np
import torch

class EarlyStopping:
    """Signal when the validation loss stops improving, keeping a copy of the best weights.

    Args:
        patience (int): How many consecutive non-improving epochs to tolerate before
            setting ``early_stop``.
        verbose (bool): If True, prints a message on every improvement / non-improvement.
        delta (float): Minimum decrease in validation loss that counts as an improvement.
        path (str): File the best model's ``state_dict`` is written to on every improvement.

    Attributes:
        best_loss: lowest validation loss seen so far.
        counter: number of consecutive calls without improvement.
        early_stop: True once ``counter`` reaches ``patience``.
        best_model_state: detached CPU copy of the weights at ``best_loss`` (None before
            the first improvement).
    """

    def __init__(self, patience=5, verbose=False, delta=0, path='checkpoint.pth'):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path

        self.counter = 0
        # np.inf (lowercase): the np.Inf alias was removed in NumPy 2.0.
        self.best_loss = np.inf
        self.early_stop = False
        self.best_model_state = None

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.delta:
            if self.verbose:
                print(f"Validation loss improved ({self.best_loss:.4f} -> {val_loss:.4f}); saving model to {self.path}")
            self.best_loss = val_loss
            # state_dict() returns references to the live parameters; copy them so later
            # optimizer steps do not overwrite the snapshot of the best weights.
            self.best_model_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            torch.save(self.best_model_state, self.path)
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"No improvement for {self.counter}/{self.patience} epoch(s)")
            # Stopping must not depend on verbosity.
            if self.counter >= self.patience:
                self.early_stop = True


# Train

In [None]:
from tqdm import tqdm

def _prepare_batch(batch, device):
    """Unpack a list of (image, label) pairs into tensors for CTC.

    Returns:
        images: stacked images on ``device``.
        targets: all label index sequences concatenated into one 1-D tensor on ``device``.
        target_lengths: 1-D long tensor with each label's length (kept on CPU).
    """
    images, labels = zip(*batch)
    images = torch.stack(images).to(device)
    target_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)
    targets = torch.cat(labels).to(device)
    return images, targets, target_lengths


def _ctc_forward(model, batch, ctc_loss, device):
    """Run ``model`` on one batch and return ``(loss, batch_size)``."""
    images, targets, target_lengths = _prepare_batch(batch, device)
    log_probs = model(images).permute(1, 0, 2)  # (N, T, C) -> (T, N, C), the layout CTCLoss expects
    # torch.stack forces equal image sizes, so every sequence has the full length T.
    input_lengths = torch.full(size=(log_probs.size(1),), fill_value=log_probs.size(0), dtype=torch.long)
    loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
    return loss, images.size(0)


def validate_epoch(model, dataloader, ctc_loss, device):
    """Compute the mean CTC loss per sample over ``dataloader`` without updating weights.

    Batch losses are weighted by batch size, which assumes ``ctc_loss`` uses the default
    ``reduction='mean'``. The model's previous train/eval mode is restored on return.
    Raises ZeroDivisionError if ``dataloader`` yields no samples.

    Returns:
        float: average validation loss (also printed).
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_samples = 0
    try:
        with torch.no_grad():
            for batch in dataloader:
                loss, n = _ctc_forward(model, batch, ctc_loss, device)
                # Weight by batch size so a small final batch counts no more than its samples.
                total_loss += loss.item() * n
                total_samples += n
    finally:
        model.train(was_training)

    avg_loss = total_loss / total_samples
    print(f"Validation loss: {avg_loss:.4f}")
    return avg_loss


def train_epoch(model, dataloader, optimizer, ctc_loss, device):
    """Run one optimization pass over ``dataloader``.

    Returns:
        float: mean training loss per sample for the epoch (also printed). The progress bar
        shows the latest batch loss.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0
    with tqdm(dataloader, unit="batch") as tepoch:
        for batch in tepoch:
            optimizer.zero_grad()
            loss, n = _ctc_forward(model, batch, ctc_loss, device)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value * n
            total_samples += n
            tepoch.set_postfix(loss=loss_value)

    avg_loss = total_loss / total_samples
    print(f"Epoch loss: {avg_loss:.4f}")
    return avg_loss


In [None]:
num_epochs = 40
early_stopping = EarlyStopping(patience=5, verbose=True, path='best_crnn.pth')
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    train_epoch(model, train_loader, optimizer, ctc_loss, device)
    val_loss = validate_epoch(model, val_loader, ctc_loss, device)
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered.")
        break

# The loop ends on the last epoch run, not the best one. Reload the checkpoint written at
# the lowest validation loss so the inference cells use those weights.
if early_stopping.best_model_state is not None:
    model.load_state_dict(torch.load(early_stopping.path, map_location=device))
    print(f"Restored best model (val loss {early_stopping.best_loss:.4f}) from {early_stopping.path}")

# Decoding

In [None]:
def beam_search_decode(probs, beam_width=5, blank=0, charset=None):
    """CTC prefix beam search over per-frame log-probabilities.

    Args:
        probs: tensor of shape (seq_len, batch, nclass) with log-probabilities
            (the model output permuted from batch-first to time-first).
        beam_width: number of prefixes kept after each time step.
        blank: index of the CTC blank class.
        charset: optional mapping from class index to character; defaults to the
            notebook's global ``idx_to_char``. Unknown indices decode to ''.

    Returns:
        list[str]: one decoded string per batch element.

    Each prefix tracks two scores, for paths ending in blank and in non-blank, and
    merges paths with log-sum-exp. A character repeated across a blank ("a _ a")
    therefore decodes to "aa", while adjacent repeats ("a a") collapse to "a", as CTC requires.
    """
    import math
    from collections import defaultdict

    mapping = idx_to_char if charset is None else charset
    neg_inf = -math.inf

    def log_add(a, b):
        # Numerically stable log(exp(a) + exp(b)).
        if a == neg_inf:
            return b
        if b == neg_inf:
            return a
        hi, lo = (a, b) if a > b else (b, a)
        return hi + math.log1p(math.exp(lo - hi))

    def prefix_score(scores):
        return log_add(scores[0], scores[1])

    # One host transfer for the whole tensor instead of one per time step and batch item.
    log_probs = probs.detach().cpu().double().numpy()
    seq_len, batch_size, nclass = log_probs.shape
    decoded_batch = []

    for batch_idx in range(batch_size):
        # prefix (tuple of class indices) -> [log P(ends in blank), log P(ends in non-blank)]
        beam = {(): [0.0, neg_inf]}

        for t in range(seq_len):
            frame = log_probs[t, batch_idx].tolist()
            new_beam = defaultdict(lambda: [neg_inf, neg_inf])

            for prefix, (p_blank, p_nonblank) in beam.items():
                p_prefix = log_add(p_blank, p_nonblank)
                last = prefix[-1] if prefix else None

                # Emitting a blank keeps the prefix; the path now ends in blank.
                entry = new_beam[prefix]
                entry[0] = log_add(entry[0], p_prefix + frame[blank])

                for c in range(nclass):
                    if c == blank:
                        continue
                    p = frame[c]
                    extended = prefix + (c,)
                    if c == last:
                        # A repeat with no blank in between merges into the same prefix ...
                        entry = new_beam[prefix]
                        entry[1] = log_add(entry[1], p_nonblank + p)
                        # ... while a repeat after a blank starts a new character.
                        ext_entry = new_beam[extended]
                        ext_entry[1] = log_add(ext_entry[1], p_blank + p)
                    else:
                        ext_entry = new_beam[extended]
                        ext_entry[1] = log_add(ext_entry[1], p_prefix + p)

            ranked = sorted(new_beam.items(), key=lambda kv: prefix_score(kv[1]), reverse=True)
            beam = dict(ranked[:beam_width])

        best_prefix = max(beam.items(), key=lambda kv: prefix_score(kv[1]))[0]
        # Prefixes hold no blanks and no CTC repeats already. Collapsing repeats again here
        # would delete genuine double letters.
        decoded_batch.append("".join(mapping.get(idx, '') for idx in best_prefix))

    return decoded_batch


# Inference

In [None]:
def clean_decoded_text(text, blank_char='', collapse_repeats=False):
    """Strip leftover blank symbols from an already-decoded string.

    CTC decoding already removes blanks and merges repeated frames, so repeated
    characters are kept by default. Collapsing them again would turn genuine double
    letters ("hello", "2000") into single ones.

    Args:
        text: decoded string.
        blank_char: symbol used for the blank; '' (the default) means nothing to strip.
        collapse_repeats: if True, also merge consecutive identical characters
            (legacy behaviour; only appropriate for raw, un-collapsed frame output).

    Returns:
        str: the cleaned text.
    """
    if not collapse_repeats:
        return ''.join(ch for ch in text if ch != blank_char)

    cleaned = []
    prev_char = None
    for ch in text:
        if ch != blank_char and ch != prev_char:
            cleaned.append(ch)
        prev_char = ch
    return ''.join(cleaned)

In [None]:
import matplotlib.pyplot as plt

def process_image(image_path, transform_fn=None, show=True):
    """Load an image as grayscale, apply the preprocessing transform and optionally preview it.

    Args:
        image_path: path to the image file.
        transform_fn: albumentations-style callable (``transform_fn(image=array)['image']``);
            defaults to the notebook-level ``transform``.
        show: if True, display the transformed image with matplotlib.

    Returns:
        The transformed image (a channel-first tensor with the notebook's transform).
    """
    if transform_fn is None:
        transform_fn = transform
    # The context manager closes the file; convert() returns an independent in-memory image.
    with Image.open(image_path) as img:
        gray = img.convert('L')
    image = transform_fn(image=np.array(gray))['image']

    if show:
        # Tensors expose .numpy(); plain arrays are used as-is.
        preview = image.numpy() if hasattr(image, 'numpy') else np.asarray(image)
        if preview.ndim == 3 and preview.shape[0] == 1:
            preview = preview[0]  # drop the single channel dim for imshow
        fig, ax = plt.subplots()
        ax.imshow(preview, cmap='gray')
        ax.axis('off')
        plt.show()
    return image

In [None]:
# Evaluation mode: BatchNorm must use its running statistics. In train mode it would
# normalize with this single image's statistics and also update the running averages.
model.eval()
with torch.no_grad():
    image = process_image('/kaggle/input/test-image/Times-New-Roman-Font.png')
    image = image.unsqueeze(0).to(device)  # add the batch dimension
    output = model(image)                  # (batch, T, nclass) log-probabilities
    # beam_search_decode expects time-first input: (T, batch, nclass).
    decoded_texts = beam_search_decode(output.permute(1, 0, 2), beam_width=10, blank=0)
    raw_text = decoded_texts[0]            # single image -> single string
    cleaned_text = clean_decoded_text(raw_text)
    print("Raw decoded text:", raw_text)
    print("Cleaned text:", cleaned_text)