<a href="https://colab.research.google.com/github/sestys/aicrowd_captcha/blob/main/captcha_recognition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CAPTCHA recognition
[AI Crowd](https://www.aicrowd.com/challenges/ai-blitz-4/problems/captcha)

[CRNN paper](https://arxiv.org/pdf/1507.05717.pdf)

[ocr.pytorch github](https://github.com/courao/ocr.pytorch)



In [None]:
!apt update

In [None]:
!pip install numpy
!pip install pandas
!pip install scikit-learn
!pip install textdistance
!pip install tqdm

In [None]:
### Download data ###
!rm -rf data
!rm -f train.tar.gz test.tar.gz
!mkdir data
!wget https://datasets.aicrowd.com/default/aicrowd-practice-challenges/public/cptcha/v0.1/train.tar.gz
!wget https://datasets.aicrowd.com/default/aicrowd-practice-challenges/public/cptcha/v0.1/test.tar.gz
!wget https://datasets.aicrowd.com/default/aicrowd-practice-challenges/public/cptcha/v0.1/train_info.csv
!wget https://datasets.aicrowd.com/default/aicrowd-practice-challenges/public/cptcha/v0.1/test_info.csv
!mkdir data/train 
!mkdir data/test
!tar -C data/ -xvzf train.tar.gz
!tar -C data/ -xvzf test.tar.gz
!mv train_info.csv data/train_info.csv
!mv test_info.csv data/test_info.csv

In [None]:
import os
import glob

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18

import string
from tqdm.notebook import tqdm
import cv2
from PIL import Image
from sklearn.model_selection import train_test_split
import multiprocessing as mp

In [None]:
cpu_count = mp.cpu_count()
print(cpu_count)

2


In [73]:
config = {'batch_size': 32,
          'epochs': 50,
          'rnn_hidden_size': 256,
          'lr': 0.001,
          'weight_decay': 1e-3,
          'clip_norm': 5,
          }

# 1. Load Data

In [None]:
train_info_path = "data/train_info.csv"
test_info_path = "data/test_info.csv"

train_images_path = "data/train/"
test_images_path = "data/test/"
train_info = pd.read_csv(train_info_path)
test_info = pd.read_csv(test_info_path)

In [None]:
print('Train images:', len(glob.glob(train_images_path + '/*.png')))
print('Test images:', len(glob.glob(test_images_path + '/*.png')))


Train images: 10000
Test images: 5000


In [None]:
def plot_image(img_path):
    img = cv2.imread(img_path)
    plt.imshow(img)

In [None]:
# fig=plt.figure(figsize=(20,20))
# columns = 3
# rows = 3
# for i in range(1, columns*rows +1):
#     img = train_images_path + train_info['filename'][i]
#     label = train_info['label'][i]
#     fig.add_subplot(rows, columns, i)
#     plot_image(img)
#     print(label)
# plt.show()

In [None]:
X_train, X_val= train_test_split(train_info, test_size=0.1, random_state=42)

In [None]:
labels = [x for x in train_info['label']]
labels = ''.join(labels)
letters = sorted(list(set(labels)))
len(letters) # 26 lower case + 26 upper case + 10 digits

62

In [38]:
# Character to idx mapping
vocabulary = letters + ['-']
idx2char = {k:v for k,v in enumerate(vocabulary, start=0)}
char2idx = {k:v for v,k in enumerate(vocabulary, start=0)}
num_characters = len(vocabulary)

# 2. Dataloaders

In [None]:
class CAPTCHADataset(Dataset):
    
    def __init__(self, data_dir, image_infos):
        self.data_dir = data_dir
        self.image_infos = image_infos
        
    def __len__(self):
        return self.image_infos.shape[0]
    
    def __getitem__(self, index):
        filename_label = self.image_infos.iloc[index]
        image_fp = self.data_dir + filename_label['filename']
        image = Image.open(image_fp).convert('RGB')
        image = self.transform(image)
        label = filename_label['label']
        return image, label
    
    def transform(self, image):
        
        transform_ops = transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        ])
        return transform_ops(image)

In [None]:
trainset = CAPTCHADataset(train_images_path, X_train)
validationset = CAPTCHADataset(train_images_path, X_val)
testset = CAPTCHADataset(test_images_path, test_info)

train_loader = DataLoader(trainset, batch_size=config['batch_size'], num_workers=cpu_count, shuffle=True)
validation_loader = DataLoader(validationset, batch_size=config['batch_size'], num_workers=cpu_count, shuffle=True)
test_loader = DataLoader(testset, batch_size=config['batch_size'], num_workers=cpu_count, shuffle=False)
print(len(train_loader), len(validation_loader), len(test_loader))

282 32 157


In [None]:
image_batch, text_batch = iter(train_loader).next()
print(image_batch.size(), text_batch)

torch.Size([32, 3, 60, 120]) ('O3sx', 'EdJJjZr', '9ePO', 'hwNLh', 'l1wOzMjFv9', 'Dxel', 'Ssi1kJpyT', 'pVzEK4', '6ZSPTKxwEW', '9Baw', 'lXXiV', 'sYKahz2GWg', '4h9kJ7CyT', 'vCw8A', 'bbFGs0VzcA', 'nPCv', 'kC3i', 'BYIrzh', 'qblaAK207', 'hJ5qN5lBb', 'jNH12', '3XSR9', 'sw8co', 'iRKf1XIqs', 'h7KKtIub', 'j0SFPT3jja', 'ZAEX90zeb', 'Xr38S', 'GdiJeJ70YL', '1gLig', '5Yo8wM5zts', 'aBYY')


# 3. Model definition

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [55]:
resnet = resnet18(pretrained=True)
resnet = list(resnet.children())[:-1]

In [None]:
resnet

In [34]:
class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output


In [63]:
class CRNN(nn.Module):

    def __init__(self, num_char, nh, leakyRelu=False):
        super(CRNN, self).__init__()

        self.resnet = nn.Sequential(*resnet)

        # 512x1x16

        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, num_char))


    def forward(self, input):
        # conv features
        conv = self.resnet(input)
        # print(conv.size())

        b, c, h, w = conv.size()
        # assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features
        output = self.rnn(conv)

        return output


In [64]:
def weights_init(m):
    classname = m.__class__.__name__
    if type(m) in [nn.Linear, nn.Conv2d, nn.Conv1d]:
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            m.bias.data.fill_(0.01)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [65]:
crnn = CRNN(num_characters, nh=config['rnn_hidden_size'])
crnn.apply(weights_init)
crnn = crnn.to(device)

In [None]:
crnn

# 4. Loss

In [69]:
criterion = nn.CTCLoss(blank=0)

In [70]:
def encode_text_batch(text_batch):
    
    text_batch_targets_lens = [len(text) for text in text_batch]
    text_batch_targets_lens = torch.IntTensor(text_batch_targets_lens)
    
    text_batch_concat = "".join(text_batch)
    text_batch_targets = [char2idx[c] for c in text_batch_concat]
    text_batch_targets = torch.IntTensor(text_batch_targets)
    
    return text_batch_targets, text_batch_targets_lens

In [71]:
def compute_loss(text_batch, text_batch_logits):
    """
    text_batch: list of strings of length equal to batch size
    text_batch_logits: Tensor of size([T, batch_size, num_classes])
    """
    text_batch_logps = F.log_softmax(text_batch_logits, 2) # [T, batch_size, num_classes]  
    text_batch_logps_lens = torch.full(size=(text_batch_logps.size(1),), 
                                       fill_value=text_batch_logps.size(0), 
                                       dtype=torch.int32).to(device) # [batch_size]  
    text_batch_targets, text_batch_targets_lens = encode_text_batch(text_batch)
    loss = criterion(text_batch_logps, text_batch_targets, text_batch_logps_lens, text_batch_targets_lens)

    return loss

In [72]:
compute_loss(text_batch, text_batch_logits)

tensor(inf, device='cuda:0', grad_fn=<MeanBackward0>)

# 5. Training

In [74]:
optimizer = optim.Adam(crnn.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=5)

In [75]:
crnn = CRNN(num_characters, nh=config['rnn_hidden_size'])
crnn.apply(weights_init)
crnn = crnn.to(device)

In [78]:
epoch_losses = []
iteration_losses = []
num_updates_epochs = []
for epoch in tqdm(range(1, config['epochs']+1)):
    epoch_loss_list = [] 
    num_updates_epoch = 0
    for image_batch, text_batch in tqdm(train_loader, leave=False):
        optimizer.zero_grad()
        text_batch_logits = crnn(image_batch.to(device))
        loss = compute_loss(text_batch, text_batch_logits)
        iteration_loss = loss.item()

        if np.isnan(iteration_loss) or np.isinf(iteration_loss):
            continue
          
        num_updates_epoch += 1
        iteration_losses.append(iteration_loss)
        epoch_loss_list.append(iteration_loss)
        loss.backward()
        nn.utils.clip_grad_norm_(crnn.parameters(), clip_norm)
        optimizer.step()

    epoch_loss = np.mean(epoch_loss_list)
    print("Epoch:{}    Loss:{}    NumUpdates:{}".format(epoch, epoch_loss, num_updates_epoch))
    epoch_losses.append(epoch_loss)
    num_updates_epochs.append(num_updates_epoch)
    lr_scheduler.step(epoch_loss)

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=282.0), HTML(value='')))

Epoch:1    Loss:nan    NumUpdates:0


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


HBox(children=(FloatProgress(value=0.0, max=282.0), HTML(value='')))

Epoch:2    Loss:nan    NumUpdates:0


HBox(children=(FloatProgress(value=0.0, max=282.0), HTML(value='')))

Epoch:3    Loss:nan    NumUpdates:0


HBox(children=(FloatProgress(value=0.0, max=282.0), HTML(value='')))

KeyboardInterrupt: ignored