<a href="https://colab.research.google.com/github/sakshamsaxena22/Hand_Written_equation_solver/blob/main/HandwrittenEquationTransformerOCRTranining.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
# Transformer OCR with Pretrained CNN Encoder (Equation Recognition)
# Uses pretrained_cnn_encoder_32.pth from symbol pretraining


# =========================
# 1. SETUP
# =========================
!pip install torch torchvision albumentations opencv-python tqdm


import os
import cv2
import math
import torch
import numpy as np
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm


# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"



In [17]:
import kagglehub

# Download (or reuse the cached copy of) the HME100K dataset and remember where it lives.
DATASET_ROOT = kagglehub.dataset_download("cutedeadu/hme100k")
print("HME100K root:", DATASET_ROOT)

# IMPORTANT:
# In this dataset, DATASET_ROOT already points to the folder containing images
# There is NO nested images/ directory beyond this level

# Detect label file dynamically.
# os.listdir() returns entries in arbitrary order, so the listing is sorted first:
# when more than one candidate name is present the same file is picked on every run.
LABEL_FILE = None
for fname in sorted(os.listdir(DATASET_ROOT)):
    if fname.lower() in {"labels.txt", "annotations.txt", "train.txt"}:
        LABEL_FILE = os.path.join(DATASET_ROOT, fname)
        break

# The message lists every name the loop above accepts.
assert LABEL_FILE is not None, "No label file found (labels.txt / annotations.txt / train.txt)"

# Images sit directly in the dataset root (see the note above).
IMG_DIR = DATASET_ROOT

print("Images directory:", IMG_DIR)
print("Label file:", LABEL_FILE)

Using Colab cache for faster access to the 'hme100k' dataset.
HME100K root: /kaggle/input/hme100k
Images directory: /kaggle/input/hme100k
Label file: /kaggle/input/hme100k/train.txt


In [18]:
# =========================
# 3. CHARSET
# =========================
# Symbols the recognizer can emit. Class index 0 is left free for the CTC blank,
# so real symbols are numbered starting at 1.
CHARSET = "0123456789+-=*/()xyz"
char2idx = {symbol: index for index, symbol in enumerate(CHARSET, start=1)}
idx2char = dict(enumerate(CHARSET, start=1))


# =========================
# 4. DATASET CLASS
# =========================
# class EquationDataset(Dataset):
#     def __init__(self, img_dir, label_file):
#         self.img_dir = img_dir
#         self.samples = []

#         with open(label_file, 'r', encoding='utf-8') as f:
#             for line in f:
#                 parts = line.strip().split('\t')
#                 if len(parts) != 2:
#                     continue
#                 fname, label = parts
#                 img_path = os.path.join(self.img_dir, fname)
#                 if os.path.isfile(img_path):
#                     self.samples.append((fname, label))

#         assert len(self.samples) > 0, "No valid image-label pairs found"

#         self.transform = transforms.Compose([
#             transforms.ToTensor(),
#             transforms.Normalize((0.5,), (0.5,))
#         ])

#         print(f"Loaded {len(self.samples)} samples")

#     def encode(self, text):
#         return torch.tensor([char2idx[c] for c in text if c in char2idx], dtype=torch.long)

#     def __len__(self):
#         return len(self.samples)

#     def __getitem__(self, idx):
#         fname, label = self.samples[idx]
#         img = cv2.imread(os.path.join(self.img_dir, fname), cv2.IMREAD_GRAYSCALE)
#         img = cv2.resize(img, (256, 64))
#         img = self.transform(img)
#         return img, self.encode(label), len(label)


class EquationDataset(Dataset):
    """Grayscale equation images paired with pre-encoded CTC targets.

    The label file is read line by line; a line is used only when it splits
    into exactly two tab-separated fields (image file name, label text) and
    the image file exists inside ``img_dir``.

    Labels are encoded once, at construction time, with the notebook-level
    ``char2idx`` table. Characters that are not in ``char2idx`` are silently
    removed, so a stored target can be shorter than the text in the label
    file. Samples whose label is empty after that filtering are dropped.

    Args:
        img_dir: directory that contains the image files.
        label_file: path of the tab-separated label file (read as UTF-8).
        img_size: ``(width, height)`` every image is resized to, in the
            argument order expected by ``cv2.resize``. The default matches
            the size used for training in this notebook; ``TransformerOCR``
            projects ``128 * 8`` features per column, which corresponds to a
            height of 64 after the encoder's three 2x poolings.

    Raises:
        AssertionError: if no usable sample is left after filtering.
    """

    def __init__(self, img_dir, label_file, img_size=(192, 64)):
        self.img_dir = img_dir
        self.img_size = img_size
        self.samples = []

        with open(label_file, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) != 2:
                    continue

                fname, label = parts
                img_path = os.path.join(self.img_dir, fname)
                if not os.path.isfile(img_path):
                    continue

                # Encode ONCE, here
                encoded = [char2idx[c] for c in label if c in char2idx]

                # Drop samples that collapse after filtering
                if len(encoded) == 0:
                    continue

                self.samples.append((fname, torch.tensor(encoded, dtype=torch.long)))

        assert len(self.samples) > 0, "No valid samples after filtering"
        print(f"Loaded {len(self.samples)} samples")

        # ToTensor turns the uint8 image into a float tensor; Normalize then
        # shifts/scales the single channel with mean 0.5 and std 0.5.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def __len__(self):
        """Return the number of usable samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image tensor, encoded target, target length)`` for sample ``idx``.

        Raises:
            RuntimeError: if the image file can no longer be decoded.
        """
        fname, encoded = self.samples[idx]

        img_path = os.path.join(self.img_dir, fname)
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cv2.resize.
        if img is None:
            raise RuntimeError(f"Could not read image: {img_path}")

        img = cv2.resize(img, self.img_size)
        img = self.transform(img)

        return img, encoded, encoded.numel()


In [19]:
 #=========================
# 5. PRETRAINED CNN ENCODER
# =========================
class CNNEncoder(nn.Module):
    """Three-stage convolutional feature extractor for single-channel images.

    Every stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2),
    with output widths 32, 64 and 128. Each stage halves both spatial
    dimensions, so the output is 1/8 of the input height and width.
    The layers live in ``self.features`` in the same order (and therefore with
    the same state-dict keys) as a hand-written ``nn.Sequential``.
    """

    def __init__(self):
        super().__init__()
        stages = []
        in_channels = 1
        for out_channels in (32, 64, 128):
            stages.extend([
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
            in_channels = out_channels
        self.features = nn.Sequential(*stages)

    def forward(self, x):
        """Run the image batch through all stages and return the feature map."""
        feature_map = self.features(x)
        return feature_map

In [24]:
# =========================
# 6. TRANSFORMER OCR MODEL
# =========================
class TransformerOCR(nn.Module):
    """CNN + Transformer-encoder recognizer that emits per-column class logits.

    The (frozen) ``CNNEncoder`` turns the image into a feature map; every
    column of that map becomes one time step, is projected to ``d_model``,
    gets a positional encoding, is processed by a Transformer encoder and is
    finally classified. The logits are meant to be fed to a CTC loss/decoder.

    Args:
        num_classes: number of output classes (including the CTC blank).
        d_model: Transformer feature width.
        nhead: number of attention heads per encoder layer.
        num_layers: number of Transformer encoder layers.
        cnn_weights: path of a state dict for ``CNNEncoder.features``, loaded
            on the CPU. Pass ``None`` to skip loading, e.g. when a complete
            checkpoint of this model is loaded right after construction.
    """

    def __init__(self, num_classes, d_model=256, nhead=8, num_layers=4,
                 cnn_weights="pretrained_cnn_encoder_32.pth"):
        super().__init__()

        self.cnn = CNNEncoder()
        if cnn_weights is not None:
            self.cnn.features.load_state_dict(
                torch.load(cnn_weights, map_location="cpu")
            )

        # Freeze CNN initially
        for p in self.cnn.parameters():
            p.requires_grad = False

        # Each time step is one feature-map column: 128 channels x 8 rows.
        # A feature height of 8 corresponds to 64-pixel-high inputs, because
        # the encoder halves the height three times.
        self.proj = nn.Linear(128 * 8, d_model)
        self.pos_enc = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.classifier = nn.Linear(d_model, num_classes)

    def train(self, mode=True):
        """Set training mode, keeping a fully frozen CNN in eval mode.

        Setting ``requires_grad = False`` does not stop BatchNorm layers from
        updating their running statistics (or from normalizing with batch
        statistics) while in training mode. As long as no CNN parameter is
        trainable, the CNN is therefore kept in eval mode so that it is
        really frozen. Once any CNN parameter is unfrozen, the CNN follows
        the regular mode switch again.
        """
        super().train(mode)
        if mode and not any(p.requires_grad for p in self.cnn.parameters()):
            self.cnn.eval()
        return self

    def forward(self, x):
        """Map an image batch to logits of shape ``[B, W, num_classes]``.

        ``W`` is the width of the CNN feature map, i.e. the number of time steps.
        """
        x = self.cnn(x)               # [B, 128, 8, W]
        b, c, h, w = x.size()
        # Put width first so every column becomes one time step: [B, W, C * H]
        x = x.permute(0, 3, 1, 2).contiguous().view(b, w, c * h)
        x = self.proj(x)
        x = self.pos_enc(x)
        x = self.transformer(x)
        return self.classifier(x)

In [25]:
# =========================
# 7. POSITIONAL ENCODING
# =========================
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to batch-first sequences.

    The table is built once and stored as the non-trainable buffer ``pe`` of
    shape ``[1, max_len, d_model]``: even feature columns hold sines, odd
    feature columns hold cosines.

    Args:
        d_model: feature width of the sequences (even or odd).
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=1000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = pos * div  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one odd column fewer than even columns,
        # so the cosine block is trimmed to the number of odd columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` (``[B, T, d_model]``) plus the encodings of its first ``T`` positions.

        Raises:
            ValueError: if ``T`` exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds the positional encoding "
                f"table length {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

In [26]:
# =========================
# 8. COLLATE FUNCTION
# =========================
def collate(batch):
    """Merge dataset samples into the tensors expected by the CTC training loop.

    Args:
        batch: sequence of ``(image, encoded_label, label_length)`` triples.

    Returns:
        A tuple ``(images, flat_targets, target_lengths)``: the images stacked
        along a new leading batch dimension, all encoded labels concatenated
        into one 1-D tensor, and a tensor with the per-sample label lengths.
    """
    image_list, label_list, length_list = map(list, zip(*batch))
    stacked_images = torch.stack(image_list)
    flat_targets = torch.cat(label_list)
    return stacked_images, flat_targets, torch.tensor(length_list)


In [27]:
# =========================
# 9. TRAINING LOOP (FIXED)
# =========================
def train(epochs=10, batch_size=8, lr=3e-4,
          save_path="transformer_ocr_with_pretrained_cnn.pth"):
    """Train ``TransformerOCR`` with CTC loss and save its state dict.

    Uses the notebook-level ``IMG_DIR``, ``LABEL_FILE``, ``CHARSET`` and
    ``DEVICE``. Only parameters with ``requires_grad=True`` are handed to the
    optimizer, so the frozen CNN encoder is not updated.

    Args:
        epochs: number of passes over the dataset.
        batch_size: samples per optimization step.
        lr: AdamW learning rate.
        save_path: file the final ``state_dict`` is written to.
    """
    dataset = EquationDataset(IMG_DIR, LABEL_FILE)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate
    )

    # One extra class for the CTC blank (index 0).
    model = TransformerOCR(len(CHARSET) + 1).to(DEVICE)

    # zero_infinity is a constructor option of CTCLoss: infinite losses (e.g. a
    # target that cannot be aligned to the available time steps) and their
    # gradients are replaced by zero instead of breaking the update.
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr
    )

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for imgs, targets, target_lengths in tqdm(loader):
            imgs = imgs.to(DEVICE)
            targets = targets.to(DEVICE)
            target_lengths = target_lengths.to(DEVICE)

            # Forward
            logits = model(imgs)
            log_probs = logits.log_softmax(2)

            # CTC input lengths (time dimension): every sample in the batch
            # uses all time steps, because all images share one width.
            input_lengths = torch.full(
                size=(log_probs.size(0),),
                fill_value=log_probs.size(1),
                dtype=torch.long,
                device=DEVICE
            )

            # The CTCLoss forward call takes exactly these four arguments.
            loss = criterion(
                log_probs.permute(1, 0, 2),  # (T, B, C)
                targets,                     # (sum(target_lengths))
                input_lengths,               # (B,)
                target_lengths               # (B,)
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Mean of the per-batch losses over the epoch.
        print(f"Epoch {epoch + 1} | Loss: {total_loss / len(loader):.4f}")

    torch.save(
        model.state_dict(),
        save_path
    )
    print("Model saved")


train()


Loaded 98086 samples


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12261/12261 [09:22<00:00, 21.80it/s]


Epoch 1 | Loss: 2.5354


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12261/12261 [05:52<00:00, 34.76it/s]


Epoch 2 | Loss: 1.9229


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12261/12261 [05:53<00:00, 34.67it/s]


Epoch 3 | Loss: 1.6821


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12261/12261 [05:51<00:00, 34.93it/s]


Epoch 4 | Loss: 1.5362


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12261/12261 [05:55<00:00, 34.46it/s]


Epoch 5 | Loss: 1.4394


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12261/12261 [05:51<00:00, 34.92it/s]


Epoch 6 | Loss: 1.3629


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12261/12261 [05:49<00:00, 35.06it/s]


Epoch 7 | Loss: 1.3055


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12261/12261 [05:49<00:00, 35.11it/s]


Epoch 8 | Loss: 1.2549


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12261/12261 [05:49<00:00, 35.03it/s]


Epoch 9 | Loss: 1.2182


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12261/12261 [05:49<00:00, 35.11it/s]


Epoch 10 | Loss: 1.1861
Model saved


In [28]:
def ctc_greedy_decode(log_probs, idx2char):
    """Best-path CTC decoding of a single sequence.

    Args:
        log_probs: 2-D tensor ``(T, C)`` with one score row per time step.
        idx2char: mapping from class index to output character; index 0 is
            the CTC blank and is never looked up.

    Returns:
        The decoded string: per time step the highest-scoring class is taken,
        consecutive repeats are merged and blanks are removed.
    """
    best_path = log_probs.argmax(dim=1).cpu().tolist()
    # Pair every step with its predecessor; the first step is compared
    # against the blank (0).
    predecessors = [0] + best_path[:-1]
    return "".join(
        idx2char[current]
        for previous, current in zip(predecessors, best_path)
        if current != 0 and current != previous
    )


In [36]:
# Rebuild the architecture, then restore the weights saved by train().
# map_location keeps the load working on a CPU-only runtime even when the
# checkpoint was written from a GPU.
model = TransformerOCR(len(CHARSET) + 1).to(DEVICE)
model.load_state_dict(
    torch.load("transformer_ocr_with_pretrained_cnn.pth", map_location=DEVICE)
)

<All keys matched successfully>

In [35]:
# Recreate dataset for inference
# Build a notebook-level dataset object so single samples can be pulled for inference.
dataset = EquationDataset(img_dir=IMG_DIR, label_file=LABEL_FILE)

print("Dataset size:", len(dataset))


Loaded 98086 samples
Dataset size: 98086


In [37]:
# Inference mode for the layers that behave differently during training.
model.eval()

# Take the first sample and add a batch dimension of 1.
img, label_encoded, _ = dataset[0]
img = img.unsqueeze(0).to(DEVICE)

# No gradients are needed for prediction.
with torch.no_grad():
    logits = model(img)
    log_probs = torch.log_softmax(logits, dim=2)[0]   # (T, C)

prediction = ctc_greedy_decode(log_probs, idx2char)

# Turn the encoded target back into text for comparison.
true_label = "".join(idx2char[index] for index in label_encoded.tolist())

print("Prediction :", prediction)
print("Ground Truth:", true_label)


Prediction : 22+=2+()2
Ground Truth: (2)2+4=24+()2


Fine-tuning: unfreeze the last Transformer encoder layer and the classifier of the CNN + Transformer OCR model (the CNN encoder stays frozen)


In [41]:
# =========================
# 9. TRAINING LOOP (FINE-TUNING)
# =========================
def train_finetune(epochs=2, batch_size=8, lr=1e-4,
                   init_weights="transformer_ocr_with_pretrained_cnn.pth",
                   save_path="transformer_ocr_with_pretrained_cnn_finetuned.pth"):
    """Fine-tune a trained ``TransformerOCR`` and save the updated state dict.

    Loads ``init_weights``, freezes every parameter, then re-enables training
    for the last Transformer encoder layer and the classifier only. Uses the
    notebook-level ``IMG_DIR``, ``LABEL_FILE``, ``CHARSET`` and ``DEVICE``.

    Args:
        epochs: number of fine-tuning passes over the dataset.
        batch_size: samples per optimization step.
        lr: AdamW learning rate (lower than the one used for initial training).
        init_weights: checkpoint to start from. The default is the relative
            path that ``train()`` writes, so no fixed working directory is assumed.
        save_path: file the fine-tuned ``state_dict`` is written to.
    """

    # Dataset + Loader
    dataset = EquationDataset(IMG_DIR, LABEL_FILE)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate
    )

    # 1. Recreate model architecture
    model = TransformerOCR(len(CHARSET) + 1).to(DEVICE)

    # 2. Load trained weights
    model.load_state_dict(
        torch.load(
            init_weights,
            map_location=DEVICE
        )
    )

    # 3. Freeze EVERYTHING first
    for p in model.parameters():
        p.requires_grad = False

    # 4. Unfreeze ONLY the last Transformer block. It is looked up by position
    #    so this keeps working if the number of encoder layers changes.
    for p in model.transformer.layers[-1].parameters():
        p.requires_grad = True

    # Always keep classifier trainable
    for p in model.classifier.parameters():
        p.requires_grad = True

    # 5. Loss: infinite CTC losses and their gradients are zeroed
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    # 6. Optimizer with LOWER LR, restricted to the unfrozen parameters
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr
    )

    # Fine-tune for only a small number of epochs (default: 2)
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for imgs, targets, target_lengths in tqdm(loader):
            imgs = imgs.to(DEVICE)
            targets = targets.to(DEVICE)
            target_lengths = target_lengths.to(DEVICE)

            logits = model(imgs)
            log_probs = logits.log_softmax(2)

            # Every sample uses all time steps of the model output.
            input_lengths = torch.full(
                (log_probs.size(0),),
                log_probs.size(1),
                dtype=torch.long,
                device=DEVICE
            )

            loss = criterion(
                log_probs.permute(1, 0, 2),  # (T, B, C)
                targets,
                input_lengths,
                target_lengths
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Mean of the per-batch losses over the epoch.
        print(f"Fine-tune Epoch {epoch + 1} | Loss: {total_loss / len(loader):.4f}")

    # 7. Save updated model
    torch.save(
        model.state_dict(),
        save_path
    )

    print("Fine-tuned model saved")


train_finetune()


Loaded 98086 samples


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12261/12261 [04:41<00:00, 43.54it/s]


Fine-tune Epoch 1 | Loss: 1.0805


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12261/12261 [04:39<00:00, 43.87it/s]


Fine-tune Epoch 2 | Loss: 1.0575
Fine-tuned model saved


In [48]:
# Rebuild the architecture, then restore the fine-tuned weights.
# map_location keeps the load working on a CPU-only runtime even when the
# checkpoint was written from a GPU.
model = TransformerOCR(len(CHARSET) + 1).to(DEVICE)
model.load_state_dict(
    torch.load("transformer_ocr_with_pretrained_cnn_finetuned.pth", map_location=DEVICE)
)

<All keys matched successfully>

In [46]:
# Put the currently loaded model into inference mode.
model.eval()

# Same sample as in the earlier check, so the two predictions can be compared.
img, label_encoded, _ = dataset[0]
img = img.unsqueeze(0).to(DEVICE)

# Forward pass without gradient tracking.
with torch.no_grad():
    logits = model(img)
    log_probs = torch.log_softmax(logits, dim=2)[0]   # (T, C)

prediction = ctc_greedy_decode(log_probs, idx2char)

# Decode the stored target indices back to characters.
true_label = "".join(idx2char[index] for index in label_encoded.tolist())

print("Prediction :", prediction)
print("Ground Truth:", true_label)


Prediction : 22+=2+()2
Ground Truth: (2)2+4=24+()2
