In [2]:
import os
import random
from functools import partial
from typing import List, Tuple

import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import WeightedRandomSampler

class OCRDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs for OCR training.

    Args:
        img_paths: Paths of the image files, one per sample.
        labels: Transcriptions, parallel to ``img_paths``.
        transform: Optional callable applied to the loaded grayscale image
            before it is returned.

    Raises:
        AssertionError: If ``img_paths`` and ``labels`` differ in length.
    """

    def __init__(self, img_paths: List[str], labels: List[str], transform=None):
        assert len(img_paths) == len(labels), (
            f"img_paths ({len(img_paths)}) and labels ({len(labels)}) must have the same length"
        )
        self.img_paths = img_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to single-channel ("L") mode and, when a
        transform was supplied, passed through it.
        """
        # Open inside a context manager so the underlying file handle is closed
        # deterministically; `convert` returns an independent in-memory image,
        # so it stays usable after the source file is closed.
        with Image.open(self.img_paths[idx]) as src:
            img = src.convert("L")
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[idx]

In [None]:
# --- Dataset locations (Kaggle input mounts) and checkpoint directory ---
TRAIN_ROOT = "/kaggle/input/sintetic-texts/ocr_dataset/train"
TEST_ROOT = "/kaggle/input/sintetic-texts-multibackground/ocr_dataset/test"
# Was "/kaggele/output/checkpoints" (typo): that silently created a stray
# directory at the filesystem root. /kaggle/working is the writable directory
# Kaggle keeps as notebook output.
OUT_DIR = "/kaggle/working/checkpoints"

os.makedirs(OUT_DIR, exist_ok=True)

# NOTE(review): `train_labels`, `val_labels`, `test_labels` and `build_charset`
# are only defined by cells further down in this notebook, so on a fresh kernel
# these three statements raise NameError when run top to bottom. They should
# live in a cell placed after the data-loading and `build_charset` cells.
all_labels = train_labels + val_labels + test_labels
char_to_idx, idx_to_char = build_charset(all_labels)
# Character indices start at 1 (0 is the CTC blank), hence the +1.
num_classes = max(char_to_idx.values()) + 1

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def read_labels_file(label_file: str, img_dir: str) -> Tuple[List[str], List[str]]:
    """Parse a ``<file name> <text>`` label file into parallel lists.

    Every non-blank line is split on its first run of whitespace: the first
    token is an image file name relative to ``img_dir`` and the remainder is
    the transcription (empty when the line holds only a file name). Entries
    whose image does not exist on disk are dropped.

    Args:
        label_file: Path of the UTF-8 encoded label file.
        img_dir: Directory the file names are resolved against.

    Returns:
        ``(img_paths, labels)`` — two lists of equal length.
    """
    samples = []
    with open(label_file, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if stripped == "":
                continue
            # `rest` is empty when the line has a file name but no text.
            fname, *rest = stripped.split(maxsplit=1)
            text = rest[0] if rest else ""
            candidate = os.path.join(img_dir, fname)
            if not os.path.exists(candidate):
                continue  # skip labels that point at missing images
            samples.append((candidate, text))

    img_paths = [path for path, _ in samples]
    labels = [text for _, text in samples]
    return img_paths, labels


def load_datasets(root: str = TRAIN_ROOT, test_root: str = TEST_ROOT,
                  val_split: float = 0.1, seed: int = 42):
    """Load the train/validation/test splits from disk.

    For each of the sub-directories ``A``, ``B`` and ``C`` under ``root`` the
    samples listed in ``labels.txt`` are shuffled and the first
    ``int(n * val_split)`` of them are held out for validation. The test split
    is read as-is from ``test_root``.

    Args:
        root: Directory containing the ``A``/``B``/``C`` training subsets.
        test_root: Directory containing the test ``images`` and ``labels.txt``.
        val_split: Fraction of every subset reserved for validation.
        seed: Seed applied to the global ``random`` module before shuffling.

    Returns:
        ``(train_paths, train_labels, train_levels), (val_paths, val_labels),
        (test_paths, test_labels)`` where ``train_levels`` holds the subset
        name of every training sample.
    """
    random.seed(seed)

    train_paths, train_labels, train_levels = [], [], []
    val_paths, val_labels = [], []

    for level in ("A", "B", "C"):
        level_dir = os.path.join(root, level)
        paths, texts = read_labels_file(
            os.path.join(level_dir, "labels.txt"),
            os.path.join(level_dir, "images"),
        )

        # Shuffle (path, text) pairs together so the hold-out is a random
        # sample of the subset rather than its first files.
        samples = list(zip(paths, texts))
        random.shuffle(samples)

        n_val = int(len(samples) * val_split)
        if n_val > 0:
            held_out = samples[:n_val]
            val_paths += [path for path, _ in held_out]
            val_labels += [text for _, text in held_out]

        kept = samples[n_val:]
        train_paths += [path for path, _ in kept]
        train_labels += [text for _, text in kept]
        train_levels += [level] * (len(samples) - n_val)

    # The test split is not shuffled or subdivided.
    test_paths, test_labels = read_labels_file(
        os.path.join(test_root, "labels.txt"),
        os.path.join(test_root, "images"),
    )

    train_split = (train_paths, train_labels, train_levels)
    val_part = (val_paths, val_labels)
    test_part = (test_paths, test_labels)
    return train_split, val_part, test_part


# Materialise the three splits once, using the default roots, split ratio and seed.
(
    (train_paths, train_labels, train_levels),
    (val_paths, val_labels),
    (test_paths, test_labels),
) = load_datasets()

In [None]:
def build_charset(all_labels: List[str], extra_chars: str = ""):
    """Build the character <-> index lookup tables for a set of labels.

    Args:
        all_labels: Label strings whose characters form the alphabet.
        extra_chars: Additional characters to include in the alphabet.

    Returns:
        ``(char_to_idx, idx_to_char)``. Characters are numbered from 1 in
        sorted order; index 0 is left unassigned for the blank symbol.
    """
    alphabet = {ch for label in all_labels for ch in label}
    alphabet |= set(extra_chars)

    char_to_idx = {}
    idx_to_char = {}
    for index, ch in enumerate(sorted(alphabet), start=1):  # 0 is the blank
        char_to_idx[ch] = index
        idx_to_char[index] = ch
    return char_to_idx, idx_to_char

In [None]:
def ctc_collate(batch, char_to_idx):
    """Collate ``(image, label)`` samples into the tensors used for CTC loss.

    Args:
        batch: Sequence of ``(image_tensor, label_string)`` pairs; the image
            tensors must be stackable.
        char_to_idx: Mapping from character to class index. Characters missing
            from the mapping are dropped from the target.

    Returns:
        ``(imgs, targets, target_lengths, labels)`` where ``imgs`` is the
        stacked image tensor, ``targets`` is the 1-D concatenation of all
        encoded labels, ``target_lengths`` holds the encoded length of each
        label, and ``labels`` is the list of original strings.
    """
    images, texts = zip(*batch)
    stacked = torch.stack(images, dim=0)

    # Encode each label separately first so per-sample lengths are available.
    encoded = [
        [char_to_idx[ch] for ch in text if ch in char_to_idx]
        for text in texts
    ]
    flat_targets = [index for seq in encoded for index in seq]

    # An empty list yields a zero-length int64 tensor.
    targets_tensor = torch.tensor(flat_targets, dtype=torch.long)
    target_lengths = torch.tensor([len(seq) for seq in encoded], dtype=torch.long)
    return stacked, targets_tensor, target_lengths, list(texts)



In [None]:
def greedy_ctc_decode(logits, idx_to_char, blank_index=0):
    with torch.no_grad():
        # choose most probable class at each time step
        preds = torch.argmax(logits, dim=2)  # [T, B]
        preds = preds.transpose(0, 1).cpu().numpy()  # [B, T]
    decoded = []
    for seq in preds:
        last = None
        out_chars = []
        for p in seq:
            if p == last:
                continue
            last = p
            if p != blank_index:
                ch = idx_to_char.get(int(p), "")
                out_chars.append(ch)
        decoded.append("".join(out_chars))
    return decoded


def edit_distance(a: str, b: str) -> int:
    """Return the Levenshtein distance between ``a`` and ``b``.

    Insertions, deletions and substitutions each cost 1.
    """
    len_a, len_b = len(a), len(b)
    if len_a == 0:
        return len_b
    if len_b == 0:
        return len_a

    # Only the previous DP row is needed to compute the next one.
    previous_row = list(range(len_b + 1))
    for i, ch_a in enumerate(a, start=1):
        current_row = [i]
        for j, ch_b in enumerate(b, start=1):
            substitution = previous_row[j - 1] + (0 if ch_a == ch_b else 1)
            deletion = previous_row[j] + 1
            insertion = current_row[j - 1] + 1
            current_row.append(min(deletion, insertion, substitution))
        previous_row = current_row
    return previous_row[len_b]


def cer_batch(preds: List[str], targets: List[str]) -> float:
    """Character error rate of ``preds`` against ``targets``.

    The result is the summed edit distance divided by the summed target
    length, where an empty target counts as one character so it still
    contributes to the denominator. Pairs are matched positionally with
    ``zip``, so surplus items in the longer list are ignored.

    Args:
        preds: Predicted strings.
        targets: Reference strings.

    Returns:
        The error rate, or ``0.0`` when there are no pairs to compare.
    """
    total_ed = 0
    total_chars = 0
    for pred, ref in zip(preds, targets):
        total_ed += edit_distance(pred, ref)
        total_chars += max(1, len(ref))
    if total_chars == 0:
        # No pairs were compared; avoid a ZeroDivisionError.
        return 0.0
    return total_ed / total_chars


In [12]:
import torch
import torch.nn as nn

class CRNN(nn.Module):
    """Convolutional-recurrent network producing per-time-step class scores.

    A four-layer CNN reduces the image height by a factor of 8 and the width
    by a factor of 4; each remaining column is flattened into a feature vector
    and fed to a 2-layer bidirectional LSTM followed by a linear classifier.

    Args:
        img_h: Height of the input images. Determines the LSTM input size
            (``256 * (img_h // 8)``); must be at least 8.
        num_channels: Number of input image channels.
        num_classes: Number of output classes.
        hidden_size: Hidden size of each LSTM direction.

    Raises:
        ValueError: If ``img_h`` is smaller than 8.
    """

    def __init__(self, img_h=32, num_channels=1, num_classes=37, hidden_size=128):
        super().__init__()

        if img_h < 8:
            raise ValueError(f"img_h must be at least 8, got {img_h}")

        # CNN feature extractor
        self.cnn = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 1), (2, 1)),  # halves the height only

            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

        # The three pooling layers each halve the height (with flooring), so
        # the CNN output height is img_h // 8. Previously this was hard-coded
        # to 4, which silently ignored `img_h` and only worked for img_h=32.
        feature_h = img_h // 8

        # BiLSTM
        self.rnn = nn.LSTM(
            input_size=256 * feature_h,
            hidden_size=hidden_size,
            num_layers=2,
            bidirectional=True,
            batch_first=True
        )

        # Classifier (both LSTM directions are concatenated, hence * 2)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Return class scores of shape ``[T, B, num_classes]``.

        ``x`` is a ``[B, num_channels, H, W]`` image batch; ``T`` is the width
        of the CNN feature map.
        """
        conv = self.cnn(x)
        b, c, h, w = conv.size()

        # Treat every feature-map column as one time step:
        # [B, C, H', W'] -> [B, W', C, H'] -> [B, W', C * H']
        conv = conv.permute(0, 3, 1, 2)
        conv = conv.reshape(b, w, c * h)

        rnn_out, _ = self.rnn(conv)
        logits = self.fc(rnn_out)

        # Time-major layout: [B, T, classes] -> [T, B, classes]
        logits = logits.permute(1, 0, 2)
        return logits

In [None]:
def make_curriculum_weights(level_labels: List[str], epoch: int, max_epoch: int):
    """Compute per-sample sampling weights for curriculum learning.

    Training progress ``alpha`` grows linearly from 0 to 1 as ``epoch``
    approaches ``max_epoch`` and is capped at 1. Level ``"A"`` samples lose
    weight as ``alpha`` grows while levels ``"B"`` and ``"C"`` gain weight.

    Args:
        level_labels: Level name (``"A"``, ``"B"`` or ``"C"``) of each sample.
        epoch: Current epoch number.
        max_epoch: Epoch at which the curriculum reaches its final weights.
            A value of 0 or less is treated as an already finished curriculum.

    Returns:
        A list of weights parallel to ``level_labels``; unknown levels get 0.1.
    """
    # Guard against division by zero when no curriculum ramp is requested.
    alpha = 1.0 if max_epoch <= 0 else min(epoch / max_epoch, 1.0)
    weights_map = {
        "A": max(0.05, 1.0 - 0.8 * alpha),
        "B": 0.2 + 0.4 * alpha,
        "C": 0.1 + 0.6 * alpha
    }
    return [weights_map.get(lvl, 0.1) for lvl in level_labels]


In [None]:
def _ctc_forward(model, ctc_loss, device, imgs, targets, target_lengths):
    """Run one forward pass on a collated batch and return ``(logits, loss)``."""
    imgs = imgs.to(device)
    targets = targets.to(device)
    target_lengths = target_lengths.to(device)

    logits = model(imgs)
    # The model output is unpacked as (time steps, batch size, classes).
    n_steps, n_items, _ = logits.size()
    # Every item in the batch uses the full number of time steps.
    input_lengths = torch.full((n_items,), n_steps, dtype=torch.long, device=device)

    log_probs = logits.log_softmax(2)  # along class dim
    loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
    return logits, loss


def train_loop(
    model: nn.Module,
    device: torch.device,
    train_ds_tuple,
    val_ds_tuple,
    char_to_idx,
    idx_to_char,
    batch_size: int = 32,
    num_epochs: int = 30,
    lr: float = 1e-3,
    max_epoch_for_curriculum: int = 20,
    save_best_path: str = os.path.join(OUT_DIR, "best.pth"),
):
    """Train ``model`` with CTC loss and curriculum sampling.

    Each epoch re-weights the training samples by difficulty level, runs one
    pass over a weighted random sample of the training set, evaluates loss and
    character error rate on the validation set, and saves a checkpoint
    whenever the validation loss improves.

    Args:
        model: Network to train; called on an image batch and expected to
            return scores that unpack as (time, batch, classes).
        device: Device the batches are moved to.
        train_ds_tuple: ``(paths, labels, levels)`` of the training samples.
        val_ds_tuple: ``(paths, labels)`` of the validation samples.
        char_to_idx: Character-to-index mapping used to encode targets.
        idx_to_char: Index-to-character mapping used to decode predictions.
        batch_size: Batch size for both loaders.
        num_epochs: Number of epochs to run.
        lr: Adam learning rate.
        max_epoch_for_curriculum: Epoch at which the curriculum weights stop
            changing.
        save_best_path: Destination of the best checkpoint; its parent
            directory is created when missing.

    Returns:
        None. Progress is printed once per epoch.
    """
    train_paths, train_labels, train_levels = train_ds_tuple
    val_paths, val_labels = val_ds_tuple

    # Uses the notebook's `transforms` import; the previous `T` alias was
    # never imported and raised NameError.
    transform = transforms.Compose([
        transforms.Resize((32, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_ds = OCRDataset(train_paths, train_labels, transform)
    val_ds = OCRDataset(val_paths, val_labels, transform)

    # A partial (unlike a lambda) can be pickled, which DataLoader workers
    # require on platforms that start processes with "spawn".
    collate = partial(ctc_collate, char_to_idx=char_to_idx)

    # The validation loader does not change between epochs, so build it once.
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            collate_fn=collate,
                            num_workers=2, pin_memory=True)

    # optimizer, loss
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True).to(device)

    # Make sure the checkpoint can be written before spending time on training.
    ckpt_dir = os.path.dirname(save_best_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    best_val_loss = float("inf")

    for epoch in range(1, num_epochs + 1):
        # --- build curriculum sampler for this epoch ---
        weights = make_curriculum_weights(train_levels, epoch, max_epoch_for_curriculum)
        sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

        train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler,
                                  collate_fn=collate,
                                  num_workers=4, pin_memory=True)

        # --- train epoch ---
        model.train()
        total_train_loss = 0.0
        n_train_batches = 0
        for imgs, targets, target_lengths, _ in train_loader:
            optimizer.zero_grad()
            _, loss = _ctc_forward(model, ctc_loss, device, imgs, targets, target_lengths)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            total_train_loss += loss.item()
            n_train_batches += 1

        avg_train_loss = total_train_loss / max(1, n_train_batches)

        # --- validation ---
        model.eval()
        total_val_loss = 0.0
        n_val_batches = 0
        val_preds, val_refs = [], []
        with torch.no_grad():
            for imgs, targets, target_lengths, raw_labels in val_loader:
                logits, loss = _ctc_forward(model, ctc_loss, device, imgs, targets, target_lengths)
                total_val_loss += loss.item()
                n_val_batches += 1

                # Greedy decoding; raw_labels are the original strings.
                val_preds.extend(greedy_ctc_decode(logits, idx_to_char, blank_index=0))
                val_refs.extend(raw_labels)

        avg_val_loss = total_val_loss / max(1, n_val_batches)
        # Compute CER once over the whole validation set. Averaging per-batch
        # ratios over-weights a smaller final batch.
        avg_cer = cer_batch(val_preds, val_refs) if val_refs else 0.0

        print(f"Epoch {epoch}/{num_epochs} | train_loss: {avg_train_loss:.4f} | val_loss: {avg_val_loss:.4f} | val_CER: {avg_cer:.4f}")

        # save best
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "char_to_idx": char_to_idx,
                "idx_to_char": idx_to_char,
                "val_loss": avg_val_loss
            }, save_best_path)
            print(f"Saved best model at epoch {epoch} with val_loss {avg_val_loss:.4f}")

    print("Training finished.")

In [None]:
# Build the recogniser with one output per charset index (plus the blank at 0)
# and place it on the selected device.
model = CRNN(
    img_h=32,
    num_channels=1,
    num_classes=num_classes,
    hidden_size=128,
).to(device)

# Train for 30 epochs; the curriculum weights stop changing after epoch 20 and
# the best checkpoint (lowest validation loss) is written under OUT_DIR.
train_loop(
    model,
    device,
    (train_paths, train_labels, train_levels),
    (val_paths, val_labels),
    char_to_idx,
    idx_to_char,
    batch_size=32,
    num_epochs=30,
    lr=1e-3,
    max_epoch_for_curriculum=20,
    save_best_path=os.path.join(OUT_DIR, "best_crnn.pth"),
)