# Faster R-CNN for CAPTCHA Recognition

End-to-end character detection and recognition.

Training: images + per-character bounding boxes → learn detection and recognition.  
Testing: images only → character sequence (detected characters read left to right).

Total loss = RPN objectness loss + RPN box-regression loss + RoI classifier loss (36 character classes + background).


In [8]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.ops import RoIAlign, nms, box_iou
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import matplotlib.patches as patches

# Seed torch's global RNG so weight init, loader shuffling and proposal sampling repeat.
torch.manual_seed(42)
# Prefer a GPU visible to PyTorch; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')


Device: cpu


## 1. Dataset


In [9]:
# Character vocabulary: ten digits followed by the lowercase ASCII letters.
CHARS = '0123456789abcdefghijklmnopqrstuvwxyz'
# Two-way lookups between a character and its class index.
CHAR_TO_IDX = {ch: pos for pos, ch in enumerate(CHARS)}
IDX_TO_CHAR = dict(enumerate(CHARS))
# Number of foreground classes; the index NUM_CLASSES itself is used as "background".
NUM_CLASSES = len(CHARS)

class CaptchaDataset(Dataset):
    """CAPTCHA ``*.png`` images with optional per-character bounding boxes.

    The ground-truth text is the file stem up to the first '-'
    (``'ab12-3.png'`` -> ``'ab12'``). When labels are loaded, each label file
    line holds five whitespace-separated numbers: an ignored leading field,
    then centre x, centre y, width and height, which are multiplied by the
    image width/height to obtain pixel ``[x1, y1, x2, y2]`` boxes.

    Args:
        img_dir: directory containing the ``*.png`` images.
        label_dir: directory with one ``<stem>.txt`` label file per image;
            labels are loaded only when this is set and ``is_test`` is False.
        is_test: when True, no boxes/labels are loaded.

    Each item is a dict with ``'image'`` (normalised CxHxW float tensor),
    ``'boxes'`` (float32, shape (N, 4)), ``'labels'`` (int64, shape (N,))
    and ``'text'`` (str).
    """

    def __init__(self, img_dir, label_dir=None, is_test=False):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.is_test = is_test
        # Sorted so sample indices do not depend on os.listdir's arbitrary order.
        self.files = sorted(f for f in os.listdir(img_dir) if f.endswith('.png'))
        self.to_tensor = transforms.ToTensor()
        # ImageNet channel mean/std.
        self.normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def __len__(self):
        return len(self.files)

    def _read_targets(self, img_name, text, w, h):
        """Parse the label file of ``img_name`` into pixel xyxy boxes and class ids.

        Line ``i`` (0-based, blank lines included in the count) belongs to
        character ``text[i]``; blank lines and lines past ``len(text)`` are
        skipped.
        """
        stem = os.path.splitext(img_name)[0]
        label_file = os.path.join(self.label_dir, stem + '.txt')
        boxes, labels = [], []
        with open(label_file, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i >= len(text):
                    break
                if not line.strip():
                    continue
                _, cx, cy, bw, bh = map(float, line.split())
                boxes.append([(cx - bw / 2) * w, (cy - bh / 2) * h,
                              (cx + bw / 2) * w, (cy + bh / 2) * h])
                labels.append(CHAR_TO_IDX[text[i].lower()])
        return boxes, labels

    def __getitem__(self, idx):
        img_name = self.files[idx]
        # Context manager closes the underlying file handle promptly.
        with Image.open(os.path.join(self.img_dir, img_name)) as im:
            img = im.convert('RGB')
        w, h = img.size
        # Strip the extension first so names without '-' do not keep '.png' in the text.
        text = os.path.splitext(img_name)[0].split('-')[0]

        boxes, labels = [], []
        if not self.is_test and self.label_dir:
            boxes, labels = self._read_targets(img_name, text, w, h)

        img = self.normalize(self.to_tensor(img))
        return {
            'image': img,
            # reshape keeps a box-less image at shape (0, 4) rather than (0,)
            'boxes': torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.tensor(labels, dtype=torch.long),
            'text': text,
        }

def collate_fn(batch):
    """Merge variably-sized samples into one zero-padded batch.

    Every image is padded on the bottom/right to the largest height and
    width present in the batch; the unpadded ``(height, width)`` of each
    image is reported in ``img_sizes``. Boxes, labels and texts are kept as
    per-sample lists because their lengths differ between samples.
    """
    heights = [sample['image'].shape[1] for sample in batch]
    widths = [sample['image'].shape[2] for sample in batch]
    pad_h, pad_w = max(heights), max(widths)

    padded_images = []
    for sample, img_h, img_w in zip(batch, heights, widths):
        channels = sample['image'].shape[0]
        canvas = torch.zeros(channels, pad_h, pad_w)
        canvas[:, :img_h, :img_w] = sample['image']
        padded_images.append(canvas)

    def gather(key):
        return [sample[key] for sample in batch]

    return {
        'images': torch.stack(padded_images),
        'boxes_list': gather('boxes'),
        'labels_list': gather('labels'),
        'texts': gather('text'),
        'img_sizes': list(zip(heights, widths)),
    }

# Training split carries per-character label files; the test split is images only.
train_dataset = CaptchaDataset(
    img_dir='../dataset/train',
    label_dir='../Segmented_dataset/train_labels',
    is_test=False,
)
test_dataset = CaptchaDataset(img_dir='../dataset/test', label_dir=None, is_test=True)

print(f'Train: {len(train_dataset)}, Test: {len(test_dataset)}')


Train: 7764, Test: 1943


## 2. Helper Functions


In [None]:
def match_proposals_to_gt(proposals, gt_boxes, gt_labels, pos_th=0.5, neg_th=0.3):
    """Assign a class target to each box in ``proposals`` by IoU with ``gt_boxes``.

    Returns:
        labels: int64 tensor, one entry per proposal -- the class of the
            best-overlapping GT where that IoU >= ``pos_th``, ``NUM_CLASSES``
            (background) where it is < ``neg_th``, and -1 (ignore) otherwise.
            With no GT boxes every proposal is background.
        matched: the best-overlapping GT box for each proposal, or a copy of
            ``proposals`` when there are no GT boxes.
    """
    count = len(proposals)
    dev = proposals.device
    background = NUM_CLASSES

    if len(gt_boxes) == 0:
        return torch.full((count,), background, dtype=torch.long, device=dev), proposals.clone()

    best_iou, best_gt = box_iou(proposals, gt_boxes).max(dim=1)

    labels = torch.full((count,), -1, dtype=torch.long, device=dev)
    is_pos = best_iou >= pos_th
    labels[is_pos] = gt_labels[best_gt[is_pos]]
    # Written after the positives, so background wins if the thresholds ever overlap.
    labels[best_iou < neg_th] = background
    return labels, gt_boxes[best_gt]

def sample_proposals(proposals, labels, n=256, pos_frac=0.5):
    """Randomly choose up to ``n`` proposal indices for classifier training.

    Foreground means ``0 <= label < NUM_CLASSES`` and is capped at
    ``int(n * pos_frac)``; background (``label == NUM_CLASSES``) fills the
    remainder. Ignored proposals (-1) are never chosen. ``proposals`` is not
    read. Returns a 1-D int64 index tensor (foreground indices first).
    """
    fg = torch.nonzero((labels >= 0) & (labels < NUM_CLASSES), as_tuple=True)[0]
    bg = torch.nonzero(labels == NUM_CLASSES, as_tuple=True)[0]

    n_fg = min(int(n * pos_frac), len(fg))
    n_bg = min(n - n_fg, len(bg))

    picks = []
    for pool, quota in ((fg, n_fg), (bg, n_bg)):
        if len(pool) == 0:
            continue
        shuffled = torch.randperm(len(pool), device=pool.device)
        picks.append(pool[shuffled[:quota]])

    if not picks:
        return torch.tensor([], dtype=torch.long, device=labels.device)
    return torch.cat(picks)


## 3. RPN


In [11]:
class RPN(nn.Module):
    """Region proposal head over a single feature map.

    For every feature-map cell and every anchor shape it predicts two
    objectness logits and four box deltas ``(dx, dy, dw, dh)``.

    Args:
        in_ch: channels of the input feature map.
        n_anchors: anchor shapes per cell; must equal ``len(anchor_sizes)``.
        anchor_sizes: ``(width, height)`` of each anchor shape, in input-image
            pixels.

    Raises:
        ValueError: if ``n_anchors`` disagrees with ``anchor_sizes``.
    """

    def __init__(self, in_ch=512, n_anchors=2, anchor_sizes=((30, 50), (40, 60))):
        super().__init__()
        self.anchor_sizes = [tuple(size) for size in anchor_sizes]
        # The conv heads below are sized by n_anchors; a mismatch would make the
        # prediction count differ from the anchor count.
        if len(self.anchor_sizes) != n_anchors:
            raise ValueError(f'n_anchors={n_anchors} does not match '
                             f'{len(self.anchor_sizes)} anchor_sizes')
        self.conv = nn.Conv2d(in_ch, 512, 3, padding=1)
        self.cls = nn.Conv2d(512, n_anchors * 2, 1)
        self.reg = nn.Conv2d(512, n_anchors * 4, 1)

    def generate_anchors(self, fm_size, img_size, device):
        """Return ``(fh * fw * n_anchors, 4)`` float32 xyxy anchors.

        Ordering is row-major over cells (row, then column), then over anchor
        shapes, matching the layout produced by ``forward``. The stride along
        each axis is ``img_size / fm_size`` and anchors are centred on cell
        centres.
        """
        fh, fw = fm_size
        img_h, img_w = img_size
        stride_h, stride_w = img_h / fh, img_w / fw

        # Computed in float64 (the precision of Python floats) and cast once,
        # so the result equals a per-cell Python loop, without the loop.
        ys = (torch.arange(fh, dtype=torch.float64) + 0.5) * stride_h
        xs = (torch.arange(fw, dtype=torch.float64) + 0.5) * stride_w
        cy, cx = torch.meshgrid(ys, xs, indexing='ij')  # each (fh, fw)
        cx = cx.reshape(-1, 1)
        cy = cy.reshape(-1, 1)
        sizes = torch.tensor(self.anchor_sizes, dtype=torch.float64)  # (A, 2) = (w, h)
        half_w, half_h = sizes[:, 0] / 2, sizes[:, 1] / 2
        anchors = torch.stack([cx - half_w, cy - half_h, cx + half_w, cy + half_h], dim=-1)
        return anchors.reshape(-1, 4).to(device=device, dtype=torch.float32)

    def forward(self, x, img_sizes):
        """Return ``(obj (B, N, 2), deltas (B, N, 4), [anchors (N, 4)] * B)``.

        ``N = H * W * n_anchors`` for a feature map ``x`` of spatial size
        ``(H, W)``; one anchor set is generated per entry of ``img_sizes``.
        """
        B, _, H, W = x.shape
        x = F.relu(self.conv(x))

        # (B, A*k, H, W) -> (B, H, W, A*k) -> (B, H*W*A, k): same order as the anchors.
        obj = self.cls(x).permute(0, 2, 3, 1).reshape(B, -1, 2)
        deltas = self.reg(x).permute(0, 2, 3, 1).reshape(B, -1, 4)

        anchors_list = [self.generate_anchors((H, W), sz, x.device) for sz in img_sizes]
        return obj, deltas, anchors_list

    def apply_deltas(self, anchors, deltas):
        """Decode ``(dx, dy, dw, dh)`` deltas relative to xyxy ``anchors`` into xyxy boxes."""
        w = anchors[:, 2] - anchors[:, 0]
        h = anchors[:, 3] - anchors[:, 1]
        cx = anchors[:, 0] + 0.5 * w
        cy = anchors[:, 1] + 0.5 * h

        dx, dy, dw, dh = deltas.unbind(dim=1)
        pred_cx = dx * w + cx
        pred_cy = dy * h + cy
        # Clamp log-scale deltas so exp() cannot overflow to huge boxes.
        pred_w = torch.exp(dw.clamp(max=4)) * w
        pred_h = torch.exp(dh.clamp(max=4)) * h

        return torch.stack([pred_cx - 0.5 * pred_w, pred_cy - 0.5 * pred_h,
                            pred_cx + 0.5 * pred_w, pred_cy + 0.5 * pred_h], dim=1)

    def compute_targets(self, anchors, gt_boxes, eps=1e-6):
        """Encode xyxy ``gt_boxes`` as ``(dx, dy, dw, dh)`` relative to ``anchors``.

        This is the inverse of ``apply_deltas`` (up to its ``dw``/``dh``
        clamp). Widths/heights are floored at ``eps`` inside the log so a
        zero-size GT box yields a finite target instead of ``-inf``.
        """
        aw = anchors[:, 2] - anchors[:, 0]
        ah = anchors[:, 3] - anchors[:, 1]
        acx = anchors[:, 0] + 0.5 * aw
        acy = anchors[:, 1] + 0.5 * ah

        gw = gt_boxes[:, 2] - gt_boxes[:, 0]
        gh = gt_boxes[:, 3] - gt_boxes[:, 1]
        gcx = gt_boxes[:, 0] + 0.5 * gw
        gcy = gt_boxes[:, 1] + 0.5 * gh

        return torch.stack([(gcx - acx) / aw, (gcy - acy) / ah,
                            torch.log(gw.clamp(min=eps) / aw),
                            torch.log(gh.clamp(min=eps) / ah)], dim=1)


## 4. Model


In [12]:
class FasterRCNN(nn.Module):
    """Two-stage character detector and recogniser.

    A conv/BN/ReLU backbone with three 2x max-pools (feature stride 8) feeds
    an RPN and a RoIAlign + MLP classifier. The classifier outputs
    ``num_classes + 1`` logits; index ``num_classes`` is background.

    Args:
        num_classes: number of foreground character classes.
        score_thresh: test-time minimum RPN foreground probability.
        nms_thresh: test-time IoU threshold for NMS over proposals.
        max_dets: test-time maximum proposals kept per image after NMS.
    """

    # Three MaxPool2d(2) layers in the backbone -> one feature cell per 8x8 pixels.
    FEAT_STRIDE = 8

    def __init__(self, num_classes=36, score_thresh=0.5, nms_thresh=0.3, max_dets=10):
        super().__init__()
        self.num_classes = num_classes
        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.max_dets = max_dets

        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(256, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU()
        )

        self.rpn = RPN(512, 2)
        self.roi_align = RoIAlign((7, 7), spatial_scale=1 / self.FEAT_STRIDE, sampling_ratio=2)

        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 1024), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(512, num_classes + 1)
        )

    def forward(self, images, img_sizes, gt_boxes_list=None, gt_labels_list=None, mode='train'):
        """Run the detector on a padded batch.

        ``img_sizes`` holds each image's unpadded ``(h, w)``, used to clip
        proposals. In ``mode='train'`` returns a dict of losses
        (``'rpn_cls'``, ``'rpn_reg'``, ``'cls'``); otherwise returns
        ``(boxes_list, preds_list)`` with exactly one entry per image.
        """
        features = self.backbone(images)
        # Anchors must sit on the feature grid at the backbone's real stride
        # (the same 1/8 RoIAlign uses). Deriving the stride from each image's
        # unpadded size misplaces anchors for images smaller than the padded
        # batch, and for sizes that are not multiples of 8 (MaxPool floors).
        fh, fw = features.shape[-2:]
        grid_extent = (fh * self.FEAT_STRIDE, fw * self.FEAT_STRIDE)
        obj, deltas, anchors_list = self.rpn(features, [grid_extent] * len(img_sizes))

        if mode == 'train':
            return self._train(features, obj, deltas, anchors_list,
                               gt_boxes_list, gt_labels_list, img_sizes)
        return self._test(features, obj, deltas, anchors_list, img_sizes)

    @staticmethod
    def _clip_boxes(boxes, img_size):
        """Clamp xyxy ``boxes`` in place to the ``(h, w)`` image extent; return them."""
        h, w = img_size
        boxes[:, 0::2].clamp_(0, w)
        boxes[:, 1::2].clamp_(0, h)
        return boxes

    @staticmethod
    def _match_best_anchors(anchors, gt_boxes, gt_labels, anc_labels, matched_gt):
        """Make each GT's highest-IoU anchor(s) positive, updating the tensors in place.

        Without this rule a GT whose best anchor IoU is below the 0.7
        threshold (e.g. a box much narrower than the fixed anchor shapes)
        never receives a positive RPN target.
        """
        ious = box_iou(anchors, gt_boxes)                        # (A, G)
        best = ious.max(dim=0, keepdim=True).values              # (1, G)
        a_idx, g_idx = torch.where((ious == best) & (best > 0))
        anc_labels[a_idx] = gt_labels[g_idx]
        matched_gt[a_idx] = gt_boxes[g_idx]

    def _train(self, feat, obj, deltas, anchors_list, gt_boxes_list, gt_labels_list, img_sizes):
        """Compute the RPN objectness/regression losses and the RoI classifier loss.

        RPN losses are summed over images with GT boxes and divided by the
        batch size; the classifier loss is the mean over all sampled RoIs.
        A term stays the plain number 0 when no image contributed to it.
        """
        B = len(anchors_list)
        rpn_cls_loss = rpn_reg_loss = 0
        all_props, all_labels, all_bidx = [], [], []

        for i in range(B):
            anchors = anchors_list[i]
            gt_boxes = gt_boxes_list[i]
            gt_labels = gt_labels_list[i]
            if len(gt_boxes) == 0:
                continue  # nothing to supervise on this image

            # ---- RPN losses: label anchors by IoU with the GT boxes ----
            anc_labels, matched_gt = match_proposals_to_gt(anchors, gt_boxes, gt_labels, 0.7, 0.3)
            self._match_best_anchors(anchors, gt_boxes, gt_labels, anc_labels, matched_gt)
            valid = anc_labels >= 0  # -1 marks ignored anchors
            if valid.any():
                # Objectness target: 1 = foreground character, 0 = background.
                is_fg = (anc_labels[valid] < self.num_classes).long()
                rpn_cls_loss += F.cross_entropy(obj[i][valid], is_fg)

            pos = valid & (anc_labels < self.num_classes)
            if pos.any():
                targets = self.rpn.compute_targets(anchors[pos], matched_gt[pos])
                rpn_reg_loss += F.smooth_l1_loss(deltas[i][pos], targets)

            # ---- Classifier samples ----
            # Proposals are detached: they are inputs to the second stage, not a
            # path for classifier gradients back into the RPN regressor.
            props = self._clip_boxes(self.rpn.apply_deltas(anchors, deltas[i].detach()),
                                     img_sizes[i])
            # Add the GT boxes themselves so the classifier always sees
            # positives, even while the RPN is still untrained.
            props = torch.cat([props, gt_boxes.to(props.dtype)])
            prop_labels, _ = match_proposals_to_gt(props, gt_boxes, gt_labels, 0.5, 0.3)
            sampled = sample_proposals(props, prop_labels, 128)

            if len(sampled) > 0:
                all_props.append(props[sampled])
                all_labels.append(prop_labels[sampled])
                all_bidx.extend([i] * len(sampled))

        cls_loss = 0
        if all_props:
            props = torch.cat(all_props)
            # RoIAlign expects rows of [batch_index, x1, y1, x2, y2].
            bidx = torch.tensor(all_bidx, dtype=props.dtype, device=feat.device).unsqueeze(1)
            roi_feat = self.roi_align(feat, torch.cat([bidx, props], dim=1)).flatten(1)
            cls_loss = F.cross_entropy(self.classifier(roi_feat), torch.cat(all_labels))

        return {
            'rpn_cls': rpn_cls_loss / B,
            'rpn_reg': rpn_reg_loss / B,
            'cls': cls_loss
        }

    def _test(self, feat, obj, deltas, anchors_list, img_sizes):
        """Return ``(boxes_list, preds_list)`` with one entry per image.

        Each boxes entry is ``(K, 4)`` xyxy proposals that passed the
        objectness threshold and NMS (at most ``max_dets``); each preds entry
        is the matching ``(K, num_classes + 1)`` classifier logits. Images
        with no surviving proposals get empty tensors rather than being
        dropped, so callers can zip the result with their batch.
        """
        n_logits = self.num_classes + 1
        batch_props = []

        for i, anchors in enumerate(anchors_list):
            scores = F.softmax(obj[i], dim=1)[:, 1]  # foreground probability
            props = self._clip_boxes(self.rpn.apply_deltas(anchors, deltas[i]), img_sizes[i])

            keep = scores > self.score_thresh
            props, scores = props[keep], scores[keep]
            if len(props) > 0:
                # nms returns indices sorted by decreasing score -> keep the best.
                props = props[nms(props, scores, self.nms_thresh)[:self.max_dets]]
            batch_props.append(props)

        counts = [len(p) for p in batch_props]
        if sum(counts) == 0:
            return batch_props, [torch.empty(0, n_logits, device=feat.device) for _ in batch_props]

        rois = torch.cat([
            torch.cat([torch.full((n, 1), bi, dtype=torch.float32, device=p.device), p], dim=1)
            for bi, (p, n) in enumerate(zip(batch_props, counts)) if n > 0
        ])
        preds = self.classifier(self.roi_align(feat, rois).flatten(1))

        preds_list, start = [], 0
        for n in counts:
            if n > 0:
                preds_list.append(preds[start:start + n])
                start += n
            else:
                preds_list.append(torch.empty(0, n_logits, device=feat.device))
        return batch_props, preds_list

# Build the detector over the NUM_CLASSES-character vocabulary and place it on `device`.
model = FasterRCNN(num_classes=NUM_CLASSES).to(device)
print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')


Parameters: 33,105,201


## 5. Training


In [13]:
# Both loaders pad variable-size images via collate_fn; only training data is shuffled.
train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=4, shuffle=False)

# Adam at 1e-3; StepLR divides the learning rate by 10 every 10 scheduler.step() calls.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.1)

def train_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader`` and return mean losses.

    Batches whose summed loss is not positive (e.g. no GT boxes at all) get
    no optimiser step and add nothing to the sums, but still count in the
    ``len(loader)`` denominator. Returns a dict with keys ``'total'``,
    ``'rpn_cls'``, ``'rpn_reg'`` and ``'cls'``.
    """
    model.train()
    parts_keys = ('rpn_cls', 'rpn_reg', 'cls')
    sums = dict.fromkeys(('total',) + parts_keys, 0)

    for batch in tqdm(loader, desc='Training'):
        images = batch['images'].to(device)
        gt_boxes = [t.to(device) for t in batch['boxes_list']]
        gt_labels = [t.to(device) for t in batch['labels_list']]

        optimizer.zero_grad()
        parts = model(images, batch['img_sizes'], gt_boxes, gt_labels, 'train')
        loss = parts['rpn_cls'] + parts['rpn_reg'] + parts['cls']
        if not loss > 0:
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        # float() accepts both 0-dim tensors and the plain 0 used for absent terms.
        sums['total'] += float(loss)
        for key in parts_keys:
            sums[key] += float(parts[key])

    n_batches = len(loader)
    return {key: value / n_batches for key, value in sums.items()}

def _decode_sequence(boxes, preds):
    """Read foreground predictions left to right by box x1; None if there are none.

    ``preds`` holds per-box class scores whose index ``NUM_CLASSES`` is
    background; ``boxes`` holds the matching xyxy boxes.
    """
    if len(preds) == 0:
        return None
    pred_cls = preds.argmax(1)
    fg = pred_cls < NUM_CLASSES
    if not fg.any():
        return None
    order = np.argsort(boxes[fg][:, 0].cpu().numpy())
    chars = pred_cls[fg].tolist()  # hoisted: index once, not once per character
    return ''.join(IDX_TO_CHAR[chars[j]] for j in order)


def evaluate(model, loader):
    """Return exact-match sequence accuracy (in %) of ``model`` over ``loader``.

    A prediction is the image's foreground detections read left to right;
    comparison is case-insensitive. Every image in the loader counts toward
    the denominator, including images for which the model returned nothing.
    """
    model.eval()
    correct = total = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc='Eval'):
            images = batch['images'].to(device)
            boxes_list, preds_list = model(images, batch['img_sizes'], mode='test')
            n_out = min(len(boxes_list), len(preds_list))

            # Iterate over the ground truth, not the model output: zipping with
            # the outputs would silently skip images when the model returns
            # fewer entries (e.g. empty lists when nothing was detected),
            # inflating the accuracy.
            for k, text in enumerate(batch['texts']):
                total += 1
                if k >= n_out:
                    continue
                pred_text = _decode_sequence(boxes_list[k], preds_list[k])
                if pred_text is not None and pred_text.lower() == text.lower():
                    correct += 1

    return 100 * correct / total if total > 0 else 0

# Marks that this setup cell (loaders, optimizer, scheduler, train/eval functions) ran.
print('Ready to train')


Ready to train


## 6. Run Training


In [14]:
# Train for NUM_EPOCHS epochs; every 5th epoch, score exact-match accuracy on the
# test split and checkpoint the weights whenever that accuracy improves.
NUM_EPOCHS = 20
best_acc = 0  # best test accuracy (%) seen so far

for epoch in range(NUM_EPOCHS):
    print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}')
    losses = train_epoch(model, train_loader, optimizer)
    print(f'Loss: {losses["total"]:.4f} (RPN_cls:{losses["rpn_cls"]:.4f} '
          f'RPN_reg:{losses["rpn_reg"]:.4f} Cls:{losses["cls"]:.4f})')
    
    # NOTE(review): the checkpoint is selected on the test split, so "best"
    # accuracy is optimistically biased; a validation split carved from the
    # training data would keep the test score unbiased. Also, nothing is saved
    # before the first evaluation (epoch 5), so a run interrupted earlier leaves
    # no 'faster_rcnn.pth' for the evaluation cell to load.
    if (epoch + 1) % 5 == 0:
        acc = evaluate(model, test_loader)
        print(f'Test Acc: {acc:.2f}%')
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), 'faster_rcnn.pth')
            print('Saved')
    
    # One scheduler step per epoch.
    scheduler.step()

print(f'\nBest: {best_acc:.2f}%')



Epoch 1/20


Training:  31%|███       | 601/1941 [27:17<1:00:50,  2.72s/it]


KeyboardInterrupt: 

## 7. Evaluation


In [None]:
# Reload the best checkpoint written by the training loop and re-score it.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load('faster_rcnn.pth', map_location=device))
acc = evaluate(model, test_loader)
print(f'\nTest Accuracy: {acc:.2f}%')


## 8. Visualization


In [None]:
def visualize(model, dataset, idx):
    """Plot ``dataset[idx]`` with the model's detections and decoded text.

    Boxes classified as characters are drawn in red and read left to right
    (by x1) into the predicted string. The title shows ``'(none)'`` when the
    model returned no proposals, and ``'(bg)'`` when every proposal was
    classified as background (all proposals are drawn in that case).
    """
    model.eval()
    sample = dataset[idx]
    image = sample['image']
    h, w = image.shape[1], image.shape[2]

    with torch.no_grad():
        boxes_list, preds_list = model(image.unsqueeze(0).to(device), [(h, w)], mode='test')

    # Work on CPU: indexing CPU boxes with a mask that lives on the GPU raises.
    # Also guard against the model returning empty lists instead of one entry per image.
    has_out = len(boxes_list) > 0 and len(preds_list) > 0
    boxes = boxes_list[0].cpu() if has_out and len(boxes_list[0]) > 0 else torch.empty(0, 4)
    preds = preds_list[0].cpu() if has_out else torch.empty(0, NUM_CLASSES + 1)

    # Denormalize (ImageNet mean/std) for display.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    img = torch.clamp(image * std + mean, 0, 1).permute(1, 2, 0).numpy()

    # Decode text from foreground boxes, left to right.
    valid_boxes = boxes
    if len(preds) == 0:
        pred_text = '(none)'
    else:
        pred_cls = preds.argmax(1)
        valid = pred_cls < NUM_CLASSES
        if valid.any():
            valid_boxes = boxes[valid]
            order = np.argsort(valid_boxes[:, 0].numpy())
            chars = pred_cls[valid].tolist()
            pred_text = ''.join(IDX_TO_CHAR[chars[j]] for j in order)
        else:
            pred_text = '(bg)'

    # Plot
    fig, ax = plt.subplots(1, 1, figsize=(12, 3))
    ax.imshow(img)
    for box in valid_boxes.numpy():
        x1, y1, x2, y2 = box
        rect = patches.Rectangle((x1, y1), x2-x1, y2-y1,
                                 fill=False, edgecolor='red', linewidth=2)
        ax.add_patch(rect)

    correct = pred_text.lower() == sample['text'].lower()
    color = 'green' if correct else 'red'
    ax.set_title(f'GT: {sample["text"]} | Pred: {pred_text} | {"OK" if correct else "WRONG"}',
                 fontsize=11, color=color, fontweight='bold')
    ax.axis('off')
    plt.tight_layout()
    plt.show()
    plt.close()

# Show predictions for the first (up to) 10 test images.
for i in range(min(10, len(test_dataset))):
    visualize(model, test_dataset, i)
