# Connect drive

In [1]:
# Mount Google Drive into the Colab runtime. The annotation files, logs and
# checkpoints used by the later cells all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Vocab

In [None]:
import torch
from torch.nn import functional

class Vocab:
    """Character vocabulary and label codec for CTC-based OCR.

    Index layout produced by :meth:`create_vocab`:

    * ``0``  - padding (also what characters missing from the vocabulary map to)
    * ``1``  - CTC blank
    * ``2+`` - the sorted unique characters found in the loaded labels

    Args:
        paths: One path, or a list of paths, to tab-separated annotation files
            holding one ``<image_path><TAB><label>`` pair per line.
    """

    def __init__(self, paths):
        self.chars = None
        self.image_paths = []
        self.labels = []
        # Initialised up front so encode()/decode() can report a missing
        # create_vocab() call clearly instead of raising AttributeError.
        self.char_2_idx = None
        self.idx_2_char = None
        self.blank_index = 1  # index of the CTC blank
        self.max_label_len = 0  # padding width, refreshed by load_dataset()
        self.load_dataset(paths)

    def load_dataset(self, paths):
        """Append ``(image_path, label)`` pairs read from the annotation file(s).

        Lines that do not split into exactly two tab-separated fields are
        skipped. Also refreshes ``max_label_len``.
        """
        if isinstance(paths, str):
            paths = [paths]

        for path in paths:
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split('\t')
                    if len(parts) != 2:
                        continue
                    img_path, label = parts
                    self.image_paths.append(img_path)
                    self.labels.append(label)

        # Computed once here: encode() runs for every dataset item, so
        # rescanning all labels there would cost O(N) per sample.
        self.max_label_len = max((len(label) for label in self.labels), default=0)
        print(f"Total {len(self.image_paths)} images loaded from {len(paths)} file(s)")

    def create_vocab(self):
        """Build the character <-> index tables from the loaded labels."""
        letters = "".join(self.labels)
        unique_chars = sorted(set(letters))

        self.chars = "".join(unique_chars)

        # Characters start at 2 because 0 is padding and 1 is the CTC blank.
        self.char_2_idx = {char: idx + 2 for idx, char in enumerate(self.chars)}
        self.blank_index = 1  # index of the CTC blank
        self.idx_2_char = {idx: char for char, idx in self.char_2_idx.items()}

        print(f"Vocab size (with blank): {len(self.chars) + 1}")
        print(f"Model output classes (with pad and blank): {self.num_classes}")
        print(f"char_2_idx: {self.char_2_idx}")

    @property
    def num_classes(self):
        """Number of classes a model must output: characters + pad (0) + blank (1)."""
        self._require_vocab()
        return len(self.chars) + 2

    def _require_vocab(self):
        """Raise ``RuntimeError`` if :meth:`create_vocab` has not been called yet."""
        if self.char_2_idx is None or self.idx_2_char is None:
            raise RuntimeError("Vocabulary not built; call create_vocab() first.")

    def encode(self, input_sequence):
        """Encode a string into a zero-padded index tensor.

        Args:
            input_sequence: Text to encode. Characters that are not in the
                vocabulary map to ``0``.

        Returns:
            ``(padded, length)`` where ``padded`` is a 1-D ``torch.long`` tensor
            right-padded with zeros to ``max_label_len`` and ``length`` is a
            0-D ``torch.long`` tensor holding the unpadded length.

        Raises:
            RuntimeError: If the vocabulary has not been built.
        """
        self._require_vocab()
        encoded = torch.tensor(
            [self.char_2_idx.get(char, 0) for char in input_sequence], dtype=torch.long
        )
        label_len = len(encoded)
        lengths = torch.tensor(label_len, dtype=torch.long)
        # Clamp at zero: a negative pad would crop the tensor while `lengths`
        # still reported the uncropped size.
        pad_right = max(self.max_label_len - label_len, 0)
        padded = functional.pad(encoded, (0, pad_right))
        return padded, lengths

    def decode(self, encode_sequences):
        """Greedy CTC decoding of per-timestep index sequences.

        Consecutive repeats are collapsed, then blank (``blank_index``) and
        padding (``0``) tokens are dropped.

        Args:
            encode_sequences: Iterable of index sequences (lists or tensors).

        Returns:
            List of decoded strings, one per input sequence.
        """
        self._require_vocab()
        decode_sequences = []

        for seq in encode_sequences:
            decode_label = []
            prev_token = None
            if isinstance(seq, torch.Tensor):
                seq = seq.tolist()

            for token in seq:
                if token != self.blank_index and token != 0 and token != prev_token:
                    char = self.idx_2_char.get(token, '')
                    decode_label.append(char)
                # Updated for every token, so a blank between two identical
                # characters keeps both of them.
                prev_token = token
            decode_sequences.append(''.join(decode_label))
        return decode_sequences

    def idx_2_labels(self, labels):
        """Convert padded label index sequences back to strings.

        Unlike :meth:`decode`, repeats are kept; only padding (``0``) is removed.
        """
        self._require_vocab()
        idx_2_labels = []
        for label in labels:
            if isinstance(label, torch.Tensor):
                label = label.tolist()
            chars = [self.idx_2_char.get(idx, '') for idx in label if idx != 0]
            idx_2_labels.append(''.join(chars))
        return idx_2_labels

# Build one shared vocabulary from all three splits, so characters that only
# occur in val/test labels still receive an index.
vocab = Vocab([
    "/content/drive/MyDrive/scene-text-ocr/train.txt",
    "/content/drive/MyDrive/scene-text-ocr/val.txt",
    "/content/drive/MyDrive/scene-text-ocr/test.txt"
])
vocab.create_vocab()


# Dataset

In [None]:
from PIL import Image
from torch.utils.data import Dataset

class OCRDataSet(Dataset):
    """Image/label dataset backed by a tab-separated annotation file.

    Each annotation line is ``<image_path><TAB><label>``. Lines that do not
    split into exactly two fields are skipped, the same rule
    ``Vocab.load_dataset`` applies.

    Args:
        mode: Split to load: ``'train'``, ``'val'`` or ``'test'``.
        label_encoder: Optional callable mapping a label string to
            ``(encoded_label, label_len)``, e.g. ``vocab.encode``.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        data_dir: Directory containing ``train.txt`` / ``val.txt`` / ``test.txt``.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported splits.
    """

    def __init__(self, mode = 'train', label_encoder = None, transform= None,
                 data_dir="/content/drive/MyDrive/scene-text-ocr"):
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.label_encoder = label_encoder

        if mode not in ('train', 'val', 'test'):
            raise NotImplementedError(
                f"Unsupported mode {mode!r}; expected 'train', 'val' or 'test'"
            )
        path = f"{data_dir}/{mode}.txt"

        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('\t')
                # Skip blank/malformed lines instead of failing with IndexError.
                if len(parts) != 2:
                    continue
                img_path, label = parts
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of usable annotation lines."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(image, encoded_label, label_len)`` when a ``label_encoder`` was
            given, otherwise ``(image, label)`` with the raw label string.
        """
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        # The context manager releases the file handle right after conversion;
        # convert() returns an independent, fully loaded image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        if self.label_encoder:
            encoded_label, label_len = self.label_encoder(label)
            return image, encoded_label, label_len

        # Without an encoder, hand back the raw label rather than terminating
        # the interpreter (which would kill the kernel or a DataLoader worker).
        return image, label

# Smoke test: rebuilds the vocabulary (re-reading the same three annotation
# files as the module-level `vocab` above) and instantiates the training
# dataset without any transform.
if __name__ == '__main__':
    vocab = Vocab([
        "/content/drive/MyDrive/scene-text-ocr/train.txt",
        "/content/drive/MyDrive/scene-text-ocr/val.txt",
        "/content/drive/MyDrive/scene-text-ocr/test.txt"
    ])
    vocab.create_vocab()
    hw = OCRDataSet(mode = 'train', label_encoder = vocab.encode)


# Backbone

In [27]:
import torch
import torch.nn as nn
from torchvision.models import resnet34

class ResNet34Backbone(nn.Module):
    """ResNet-34 convolutional trunk adapted to single-channel images.

    The stem convolution is rebuilt with ``in_channels=1``; when pretrained
    weights are requested, its kernel is the mean of the RGB stem kernel over
    the input-channel axis. The feature map height is then averaged down to 1
    while the width is left untouched.

    Args:
        pretrained: Forwarded to ``resnet34``; also controls whether the new
            stem is initialised from the RGB stem.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        resnet = resnet34(pretrained=pretrained)
        rgb_stem = resnet.conv1

        # Same geometry as the RGB stem, but reading a single channel.
        gray_stem = nn.Conv2d(
            1,
            rgb_stem.out_channels,
            kernel_size=rgb_stem.kernel_size,
            stride=rgb_stem.stride,
            padding=rgb_stem.padding,
            bias=rgb_stem.bias is not None,
        )
        if pretrained:
            with torch.no_grad():
                averaged_kernel = rgb_stem.weight.mean(dim=1, keepdim=True)
                gray_stem.weight = nn.Parameter(averaged_kernel)

        # Everything up to (and including) the last residual stage; the
        # classifier head of the torchvision model is not used.
        stages = [
            gray_stem,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
        ]
        self.feature_extractor = nn.Sequential(*stages)

        # Output height 1, width unchanged.
        self.pool = nn.AdaptiveAvgPool2d((1, None))

    def forward(self, x):
        """Map ``(B, 1, H, W)`` images to ``(B, 512, 1, W')`` features."""
        feature_map = self.feature_extractor(x)  # (B, 512, H', W')
        return self.pool(feature_map)            # (B, 512, 1, W')


# Sequence head

In [28]:
import torch
import torch.nn as nn

class BiLSTM(nn.Module):
    """Bidirectional LSTM sequence head producing per-column log-probabilities.

    Args:
        vocab_size: Number of output classes (characters + pad + CTC blank).
        hidden_size: Hidden size of each LSTM direction.
        n_layers: Number of stacked LSTM layers.
        dropout: Inter-layer LSTM dropout; ignored when ``n_layers == 1``.
        input_size: Feature size of each column fed to the LSTM, i.e. the
            backbone's channels times its output height (512 by default).
    """

    def __init__(self, vocab_size, hidden_size, n_layers, dropout=0.3, input_size=512):
        super(BiLSTM, self).__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            bidirectional=True,
            num_layers=n_layers,
            dropout=dropout if n_layers > 1 else 0,
            # forward() feeds (B, W, C). Without batch_first the LSTM would
            # treat the batch axis as time and mix unrelated samples.
            batch_first=True,
        )
        self.layer_norm = nn.LayerNorm(hidden_size * 2)
        self.out = nn.Sequential(
            nn.Linear(hidden_size * 2, vocab_size),
            nn.LogSoftmax(dim=2),
        )

    def forward(self, x):
        """Turn ``(B, C, H, W)`` features into ``(W, B, vocab_size)`` log-probs.

        The output is time-major, which is the layout ``nn.CTCLoss`` expects.
        """
        x = x.permute(0, 3, 1, 2)  # (B, W, C, H)
        # reshape (not view): the permuted tensor is not contiguous.
        x = x.reshape(x.size(0), x.size(1), -1)  # (B, W, C*H)
        x, _ = self.lstm(x)  # (B, W, 2*hidden)
        x = self.layer_norm(x)
        x = self.out(x)  # (B, W, vocab_size)
        x = x.permute(1, 0, 2)  # (W, B, vocab_size)
        return x


# Model

In [29]:
import torch
import torch.nn as nn

class OCRModel(nn.Module):
    """End-to-end OCR network: ``ResNet34Backbone`` followed by a ``BiLSTM`` head.

    Args:
        vocab_size: Number of output classes of the sequence head.
        hidden_size: Hidden size of each LSTM direction.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout passed to the sequence head.
        pretrained_backbone: Whether the backbone starts from pretrained weights.
    """

    def __init__(self, vocab_size, hidden_size, n_layers, dropout=0.3, pretrained_backbone=True):
        super().__init__()
        self.backbone = ResNet34Backbone(pretrained=pretrained_backbone)
        head_config = {
            "vocab_size": vocab_size,
            "hidden_size": hidden_size,
            "n_layers": n_layers,
            "dropout": dropout,
        }
        self.lstm = BiLSTM(**head_config)

    def forward(self, x):
        """Run the backbone, then the sequence head, and return the head's output."""
        features = self.backbone(x)
        return self.lstm(features)

# Shape smoke test: push one random single-channel 100x420 image through the
# model and print the output shape. Note this rebinds the global `model`;
# the training cell later creates its own instance.
if __name__ == '__main__':
    model = OCRModel(vocab_size=38, hidden_size=512, n_layers=2)
    dummy_input = torch.randn(1, 1, 100, 420)  # (B, C, H, W)
    output = model(dummy_input)
    print("Final output shape:", output.shape)


Final output shape: torch.Size([14, 1, 38])


# Trainer

In [9]:
!pip install jiwer



In [30]:
import os
import torch
import torch.nn.functional as F
from jiwer import cer
from tqdm import tqdm

def calculate_cer(preds, target):
    """Return the mean per-sample character error rate.

    Args:
        preds: Sequence of predicted strings.
        target: Sequence of reference strings, aligned with ``preds``.

    Returns:
        The average of ``cer(reference, prediction)`` over all pairs, or
        ``0.0`` when there are no pairs to score.

    Raises:
        ValueError: If ``preds`` and ``target`` have different lengths.
    """
    if len(preds) != len(target):
        raise ValueError(
            f"preds and target must have the same length, got {len(preds)} and {len(target)}"
        )
    if len(preds) == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    # Loop names are chosen so they do not shadow the `target` argument.
    total_cer = 0
    for prediction, reference in zip(preds, target):
        total_cer += cer(reference, prediction)
    return total_cer / len(preds)

class Trainer:
    """CTC training loop with per-epoch validation, text logs and resumable checkpoints.

    Args:
        model: Network whose output is shaped ``(T, B, C)``.
        train_loader: Iterable of ``(inputs, encoded_labels, labels_len)`` batches.
        val_loader: Same batch format; consumed by :meth:`evaluate`.
        criterion: Called as ``criterion(outputs, encoded_labels, logit_lens, labels_len)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch.
        device: Device inputs and labels are moved to.
        epochs: Total number of epochs (resumed runs continue up to this number).
        vocab_util: Object providing ``decode`` and ``idx_2_labels``.
        log_dir: Directory for ``train_log.txt`` / ``val_log.txt`` (created if missing).
        save_path: Directory for ``last_model.pt`` / ``best_model.pt`` (created if missing).
    """

    def __init__(self, model, train_loader, val_loader, criterion, optimizer, scheduler, device, epochs, vocab_util, log_dir, save_path):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.epochs = epochs
        self.vocab_util = vocab_util
        self.save_path = save_path
        self.log_dir = log_dir
        os.makedirs(self.save_path, exist_ok=True)
        os.makedirs(self.log_dir, exist_ok=True)
        self.train_log_path = os.path.join(self.log_dir, "train_log.txt")
        self.val_log_path = os.path.join(self.log_dir, "val_log.txt")
        # Lowest validation CER seen so far (lower is better). The attribute
        # and checkpoint key keep the name "best_acc" so existing checkpoints
        # remain loadable.
        self.best_acc = 1e10
        self.start_epoch = 0

    def load_checkpoint(self, path):
        """Restore model/optimizer/scheduler state from ``path`` if the file exists."""
        if not os.path.exists(path):
            return
        # map_location lets a checkpoint written on a GPU runtime be restored
        # on whatever device this trainer is configured for.
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.start_epoch = checkpoint["epoch"]
        self.best_acc = checkpoint["best_acc"]
        print(f"Loaded checkpoint '{path}' (epoch {self.start_epoch})")

    def save_checkpoint(self, epoch, best_acc, is_best=False):
        """Write ``last_model.pt`` and, when ``is_best``, also ``best_model.pt``.

        Args:
            epoch: Index of the next epoch to run on resume.
            best_acc: Best (lowest) validation CER so far.
            is_best: Whether this state is the best seen so far.
        """
        checkpoint = {
            "epoch": epoch,
            "best_acc": best_acc,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
        }
        torch.save(checkpoint, os.path.join(self.save_path, "last_model.pt"))
        if is_best:
            torch.save(checkpoint, os.path.join(self.save_path, "best_model.pt"))

    def log(self, text, train=True):
        """Append one line to the train log (``train=True``) or the val log."""
        path = self.train_log_path if train else self.val_log_path
        with open(path, "a", encoding="utf-8") as f:
            f.write(text + "\n")

    def _run_epoch(self, loader, desc, training):
        """Run one pass over ``loader`` and return ``(mean_loss, mean_cer)``.

        With ``training=True`` the model is put in train mode and the
        optimizer is stepped per batch; otherwise the model is put in eval
        mode and gradients are disabled.
        """
        self.model.train(training)
        losses = []
        all_preds, all_targets = [], []
        pbar = tqdm(loader, desc=desc, leave=False)

        with torch.set_grad_enabled(training):
            for inputs, encoded_labels, labels_len in pbar:
                inputs = inputs.to(self.device)
                encoded_labels = encoded_labels.to(self.device)

                if training:
                    self.optimizer.zero_grad()
                outputs = self.model(inputs)  # (T, B, C)

                # Every sample uses the full output length T.
                logit_lens = torch.full(
                    size=(outputs.size(1),),
                    fill_value=outputs.size(0),
                    dtype=torch.long,
                ).to(self.device)

                loss = self.criterion(outputs, encoded_labels, logit_lens, labels_len)
                if training:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
                    self.optimizer.step()
                loss_value = loss.item()
                losses.append(loss_value)

                # Move to CPU once so decoding does not sync with the device
                # for every row.
                pred_labels = torch.argmax(outputs, dim=2).permute(1, 0).cpu()  # (B, T)
                pred_texts = self.vocab_util.decode(pred_labels)
                target_texts = self.vocab_util.idx_2_labels(encoded_labels.cpu())
                all_preds.extend(pred_texts)
                all_targets.extend(target_texts)

                batch_cer = calculate_cer(pred_texts, target_texts)
                pbar.set_postfix(loss=loss_value, cer=batch_cer)

        loss_avg = sum(losses) / len(losses)
        cer_score = calculate_cer(all_preds, all_targets)
        return loss_avg, cer_score

    def fit(self):
        """Train until ``epochs``, resuming from ``last_model.pt`` when present.

        After each epoch the train/val metrics are logged, the scheduler is
        stepped and a checkpoint is written once; ``best_model.pt`` is updated
        whenever the validation CER improves.
        """
        checkpoint_path = os.path.join(self.save_path, "last_model.pt")
        self.load_checkpoint(checkpoint_path)
        for epoch in range(self.start_epoch, self.epochs):
            train_loss_avg, train_cer = self._run_epoch(
                self.train_loader, f"Train Epoch {epoch}", training=True
            )
            train_log = f"Epoch {epoch} - Train Loss: {train_loss_avg:.4f}, Train CER: {train_cer:.4f}"
            self.log(train_log, train=True)

            val_loss, val_cer = self.evaluate(epoch)
            val_log = f"Epoch {epoch} - Val Loss: {val_loss:.4f}, Val CER: {val_cer:.4f}"
            self.log(val_log, train=False)

            # Step before saving: the checkpoint is labelled with the next
            # epoch, so it must carry that epoch's learning-rate state or a
            # resumed run would lag one scheduler step behind.
            self.scheduler.step()

            is_best = val_cer < self.best_acc
            if is_best:
                self.best_acc = val_cer
            self.save_checkpoint(epoch + 1, self.best_acc, is_best=is_best)

    def evaluate(self, epoch=None):
        """Evaluate on ``val_loader`` and return ``(mean_loss, mean_cer)``.

        Args:
            epoch: Optional epoch index, used only for the progress-bar label.
        """
        desc = f"Val Epoch {epoch}" if epoch is not None else "Validation"
        return self._run_epoch(self.val_loader, desc, training=False)


# Preprocessing and transform

In [31]:
import cv2
import numpy as np
from PIL import Image
from skimage.morphology import skeletonize
import random
import albumentations as A
from torchvision import transforms

# Per-split preprocessing pipelines, keyed by split name. Only "train" and
# "val" are defined here; there is no "test" entry.
data_transforms = {
    # Training: resize, photometric jitter, grayscale, then mild blur and
    # small geometric distortions before tensor conversion.
    "train": transforms.Compose(
        [
            # Fixed (height, width), matching the 100x420 dummy input used in
            # the model smoke test.
            transforms.Resize((100, 420)),
            # Applied while the image is still RGB (saturation needs colour).
            transforms.ColorJitter(
                brightness=0.5,
                contrast=0.5,
                saturation=0.5,
            ),
            # Single channel, as expected by the backbone's 1-channel stem.
            transforms.Grayscale(
                num_output_channels=1,
            ),
            transforms.GaussianBlur(3),
            transforms.RandomAffine(
                degrees=1,
                shear=1,
            ),
            transforms.RandomPerspective(
                distortion_scale=0.3,
                p=0.5,
                # NOTE(review): integer code, presumably PIL's bicubic (3) -
                # confirm it is still accepted by the installed torchvision.
                interpolation=3,
            ),
            transforms.RandomRotation(degrees=2),
            transforms.ToTensor(),
            # One mean/std pair because the image has a single channel.
            transforms.Normalize((0.5,), (0.5,)),
        ]
    ),
    # Validation: deterministic resize + grayscale + the same normalisation.
    "val": transforms.Compose(
        [
            transforms.Resize((100, 420)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    ),
}

# Train

In [32]:
from torch.utils.data import DataLoader

# Batch sizes for the two loaders.
train_batch_size = 32
val_batch_size = 32

# Datasets: both splits encode labels with the shared vocabulary; only the
# training split uses the augmenting transform pipeline.
train_dataset = OCRDataSet(mode='train', label_encoder=vocab.encode, transform=data_transforms['train'])
val_dataset = OCRDataSet(mode='val', label_encoder=vocab.encode, transform=data_transforms['val'])

# Loaders: training batches are shuffled, validation batches keep file order.
train_loader = DataLoader(
    train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=8, pin_memory=True
)
val_loader = DataLoader(
    val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=8, pin_memory=True
)

In [36]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# Training hyper-parameters.
epochs = 100
learning_rate = 1e-3
weight_decay = 1e-5
# Decay the LR every 40% of the run. StepLR documents an integer step size,
# so round explicitly instead of passing a float.
step_size = max(1, int(round(epochs * 0.4)))
gamma = 0.1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output classes = characters + pad (index 0) + CTC blank (index 1). Derived
# from the vocabulary so the model head cannot drift out of sync with it.
model = OCRModel(
    vocab_size=len(vocab.chars) + 2,
    hidden_size=512,
    n_layers=2,
)

model.to(device)

# The blank index comes from the vocabulary that encodes the targets.
ctc_loss = nn.CTCLoss(blank=vocab.blank_index, zero_infinity=True)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

log_dir = "/content/drive/MyDrive/logs"
save_path = "/content/drive/MyDrive/checkpoints"

trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=ctc_loss,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    epochs=epochs,
    vocab_util=vocab,
    log_dir=log_dir,
    save_path= save_path
)

# Resumes from <save_path>/last_model.pt when that file exists.
trainer.fit()