# Pre-training
Ensure that:
- You already have captchas separated into individual images named \<char>\*.png or \<char>\*.jpg
  - e.g. 0-0123.png, a-abcd.jpg ... z-z69420.png
- You have your segmentation pipeline ready
  - Know where your pipeline writes the individual images (e.g. ./data/segmented)

## Sample Segmentor
Below is a sample segmentor that uses basic k-means clustering and a sharp drop-off heuristic to split a captcha into individual character images. Feel free to replace it with your own.

In [None]:
# imports
import os
import cv2
import numpy as np
import torch
from sklearn.cluster import KMeans
from PIL import Image
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
def save_color_clusters(roi, k, out_dir="clusters", pad=3, min_pixels=30, captcha_name="unknown", idx=0):
    """Split a colour ROI into per-character crops with k-means and save them.

    The ROI's pixels are clustered into ``k + 1`` colour groups (``k``
    characters plus the background). The brightest cluster is treated as the
    background; every other cluster holding at least ``min_pixels`` pixels is
    cropped to its padded bounding box, painted white outside the cluster and
    written to ``out_dir``.

    Crops are written from left to right (ordered by each cluster's leftmost
    pixel). K-means label numbers carry no spatial meaning, so ordering by
    label would attach the wrong character of ``captcha_name`` to a crop
    whenever a ROI contains more than one character.

    Args:
        roi: BGR image region (``h x w x 3``) containing one or more characters.
        k: Estimated number of characters (colour clusters) in the ROI.
        out_dir: Directory the crops are written to; created if missing.
        pad: Pixels of padding added around each cluster's bounding box.
        min_pixels: Clusters with fewer pixels than this are skipped as noise.
        captcha_name: Captcha file stem; ``captcha_name[idx]`` labels the first
            crop, ``captcha_name[idx + 1]`` the next one, and so on.
        idx: Position in ``captcha_name`` of the ROI's first character.

    Raises:
        IndexError: If more crops are produced than characters remain in
            ``captcha_name`` starting at ``idx``.
    """
    os.makedirs(out_dir, exist_ok=True)
    h, w, _ = roi.shape

    # Near-black pixels are treated as interference lines and inpainted away
    # so they do not claim a colour cluster of their own.
    gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
    line_mask = (gray < 10).astype(np.uint8) * 255
    if cv2.countNonZero(line_mask) > 0:
        roi = cv2.inpaint(roi, line_mask, 3, cv2.INPAINT_TELEA)

    pixels = roi.reshape(-1, 3)
    kmeans = KMeans(n_clusters=k + 1, n_init=10, random_state=42)
    labels_img = kmeans.fit_predict(pixels).reshape(h, w)

    # The brightest cluster centre is taken to be the background
    # (simple but usually correct). Computed once, not per cluster.
    bg_idx = int(np.argmax(kmeans.cluster_centers_.mean(axis=1)))

    # Collect the foreground clusters first so they can be ordered spatially.
    candidates = []
    for i in range(k + 1):
        if i == bg_idx:
            continue

        mask = (labels_img == i).astype(np.uint8) * 255
        # Tiny clusters are leftovers of thin lines and noise.
        if cv2.countNonZero(mask) < min_pixels:
            continue

        ys, xs = np.nonzero(mask)
        if xs.size == 0:
            continue
        candidates.append((int(xs.min()), i, mask, xs, ys))

    # Left-to-right order; the cluster label only breaks ties deterministically.
    candidates.sort(key=lambda cand: (cand[0], cand[1]))

    for _, i, mask, xs, ys in candidates:
        x1 = max(0, xs.min() - pad)
        y1 = max(0, ys.min() - pad)
        x2 = min(w - 1, xs.max() + pad)
        y2 = min(h - 1, ys.max() + pad)

        cropped = roi[y1:y2 + 1, x1:x2 + 1].copy()
        # Everything that is not part of this cluster becomes white background.
        cropped[mask[y1:y2 + 1, x1:x2 + 1] == 0] = 255

        label = captcha_name[idx]
        filename = os.path.join(out_dir, f"{label}_{captcha_name}.png")
        counter = 0  # suffix used to keep filenames unique
        while os.path.exists(filename):
            counter += 1
            filename = os.path.join(out_dir, f"{label}_{captcha_name}_{counter}.png")

        cv2.imwrite(filename, cropped)
        idx += 1
        print(f"Saved cluster {i}: {filename}")

    print("Done.")

In [16]:
def segmentor(data_path, out_path):
    """Segment every captcha image under ``data_path`` into character crops.

    For each ``.png``/``.jpg`` file the function:

    1. inpaints near-black pixels and thresholds the grayscale image to find
       the outer contours of the character blobs,
    2. sorts the blobs' bounding boxes from left to right,
    3. estimates how many characters each blob holds from the sharpest
       drop-off between the pixel counts of its most frequent usable colours,
    4. hands every blob to ``save_color_clusters`` if the estimated total
       equals the number of characters in the file name (the file stem minus
       its two-character ``-0`` style suffix); otherwise prints a warning.

    Existing files directly inside ``out_path`` are deleted first. The output
    folder is never walked, so an ``out_path`` nested inside ``data_path`` is
    not re-segmented.

    Args:
        data_path: Root folder that is searched recursively for captchas.
        out_path: Folder that receives the crops; created if missing.

    Raises:
        ValueError: If ``data_path`` does not exist.
    """
    if not os.path.exists(data_path):
        raise ValueError(f"Data path {data_path} does not exist.")
    if not os.path.exists(out_path):
        os.makedirs(out_path)
    else:
        # Start from a clean output folder; os.remove cannot delete
        # directories, so anything that is not a regular file is left alone.
        for f in os.listdir(out_path):
            stale = os.path.join(out_path, f)
            if os.path.isfile(stale):
                os.remove(stale)

    out_abs = os.path.abspath(out_path)
    for root, dirs, files in os.walk(data_path):
        # Prune the output folder in place so the walk never feeds freshly
        # written crops back in as if they were captchas.
        dirs[:] = [d for d in dirs if os.path.abspath(os.path.join(root, d)) != out_abs]

        for file in files:
            if not file.endswith(('.png', '.jpg')):
                continue

            captcha_name = os.path.splitext(file)[0]
            print(f"Processing file: {captcha_name}")

            img_path = os.path.join(root, file)
            gray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            # Colour matters for the per-blob analysis; read it once per file.
            color_img = cv2.imread(img_path)
            if gray is None or color_img is None:
                print(f"Warning: could not read image {img_path}; skipping.")
                continue

            # Remove near-black interference lines, then threshold the letters
            # (anything darker than 250 counts as foreground).
            line_mask = (gray < 5).astype(np.uint8) * 255
            gray = cv2.inpaint(gray, line_mask, 3, cv2.INPAINT_TELEA)
            _, thresh = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY_INV)
            contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

            # Keep only outermost contours (no parent) and order them left to right.
            bounding_boxes = sorted(
                (cv2.boundingRect(c) for i, c in enumerate(contours) if hierarchy[0][i][3] == -1),
                key=lambda box: box[0],
            )

            index = 0
            rois_to_analyse = []
            for x, y, w, h in bounding_boxes:
                roi = color_img[y:y + h, x:x + w]

                # Count how often each distinct colour occurs in this blob.
                colors, counts = np.unique(roi.reshape(-1, 3), axis=0, return_counts=True)
                print(f"Unique colors: {len(colors)}")

                # Most frequent colours first.
                ranked = sorted(zip(counts, colors), key=lambda pair: pair[0], reverse=True)
                usable_counts = []
                for count, color in ranked[:8]:
                    # Widen from uint8 first: `color - 255` on uint8 wraps
                    # around, which made every near-white colour look far
                    # from white and defeated this filter.
                    color = color.astype(np.int32)
                    not_white = np.linalg.norm(color - 255) > 5
                    not_black = np.linalg.norm(color) > 5
                    if not_white and not_black and count > w * h * 0.001:
                        usable_counts.append(count)

                # Safety check: the drop-off ratio needs at least two counts.
                if len(usable_counts) <= 1:
                    continue

                # The sharpest drop between consecutive counts separates the
                # character colours from anti-aliasing/noise colours.
                counts_arr = np.asarray(usable_counts, dtype=float)
                ratios = counts_arr[:-1] / (counts_arr[1:] + 1e-5)  # avoid div by 0
                est_k = int(np.argmax(ratios)) + 1
                rois_to_analyse.append((roi, est_k, index))
                index += est_k

            # File stems currently end with a two-character suffix such as "-0".
            expected = len(captcha_name) - 2
            if index != expected:
                print(f"Warning: Expected {expected} chars but got {index} segments.")
            else:
                for roi, k, idx in rois_to_analyse:
                    save_color_clusters(roi, k, out_dir=out_path, captcha_name=captcha_name, idx=idx)

# Segment every captcha found under ../data and write the crops to ../data/segmented.
segmentor('../data', '../data/segmented')

Processing file: 002k-0
Unique colors: 70
Unique colors: 72
Unique colors: 77
Unique colors: 46
Processing file: 006aguv-0
Unique colors: 43
Unique colors: 43
Unique colors: 66
Unique colors: 46
Unique colors: 59
Unique colors: 51
Unique colors: 42
Saved cluster 1: ../data/segmented\0_006aguv-0.png
Done.
Saved cluster 1: ../data/segmented\0_006aguv-0_1.png
Done.
Saved cluster 0: ../data/segmented\6_006aguv-0.png
Done.
Saved cluster 1: ../data/segmented\a_006aguv-0.png
Done.
Saved cluster 1: ../data/segmented\g_006aguv-0.png
Done.
Saved cluster 1: ../data/segmented\u_006aguv-0.png
Done.
Saved cluster 1: ../data/segmented\v_006aguv-0.png
Done.
Processing file: 00fh-0
Unique colors: 24
Unique colors: 21
Unique colors: 33
Unique colors: 25
Unique colors: 31
Processing file: 00hai-0
Unique colors: 29
Unique colors: 29
Unique colors: 4
Unique colors: 41
Unique colors: 3
Processing file: 00hgi3n7-0
Unique colors: 66
Unique colors: 69
Unique colors: 32
Unique colors: 74
Unique colors: 14
Uniqu

KeyboardInterrupt: 

## Balanced Class Selector 
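Select a capped number of images per character (up to 100 for training and 20 for testing in the cell below) so that every class is represented evenly for a fair test.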

### Method
This recursively searches a data folder for N paths of each letter and writes them into a text file. <br><br>
Class is *determined by the FIRST char of the filename* excluding the parent folders. <br><br>
After getting the filepaths, feed them downstream to a Dataset 

In [None]:

# Folder that is searched recursively for labelled character images.
folder_name = "../data/segmented"
# Text files that receive the selected image paths, one path per line.
train = "../data/train_balanced.txt"
test = "../data/test_balanced.txt"

def get_class_from_filename(filename):
    """Return the class label of an image: the lower-cased first character of its base name.

    Parent folders in ``filename`` are ignored; only the final path component
    is inspected.
    """
    _, leaf = os.path.split(filename)
    first_char = leaf[0]
    return first_char.lower()

# Fail early if the image folder is missing.
if not os.path.exists(folder_name):
    raise ValueError(f"Folder {folder_name} does not exist.")

def create_balanced_file_list(file, n=100, exclude=None):
    """Write up to ``n`` image paths per class found under ``folder_name`` to ``file``.

    ``folder_name`` is searched recursively for ``.png``/``.jpg`` files. The
    class of a file is the lower-cased first character of its base name. The
    per-class cap applies to the whole search, not to each sub-folder, and
    folders/files are visited in sorted order so the selection is reproducible.

    Args:
        file: Path of the text file to write (one image path per line).
        n: Maximum number of paths written per class.
        exclude: Optional path of a previously written list file; any image
            path listed there is skipped. Use it to keep a test list disjoint
            from the training list.
    """
    excluded = set()
    if exclude is not None:
        with open(exclude, 'r') as f_in:
            excluded = {line.strip() for line in f_in if line.strip()}

    # One counter for the whole walk so the cap is global across sub-folders.
    class_counts = {}
    with open(file, 'w') as f_out:
        for root, dirs, files in os.walk(folder_name):
            dirs.sort()  # deterministic traversal order
            for name in sorted(files):
                if not name.endswith(('.png', '.jpg')):
                    continue

                full_path = os.path.join(root, name)
                if full_path in excluded:
                    continue

                class_char = get_class_from_filename(name)
                seen = class_counts.get(class_char, 0)
                if seen < n:
                    f_out.write(full_path + '\n')
                    class_counts[class_char] = seen + 1


create_balanced_file_list(train)
# Skip everything already chosen for training; otherwise the test list would
# be the first 20 of the very same files and the evaluation would leak.
create_balanced_file_list(test, n=20, exclude=train)

In [None]:
# Run this cell only if torch is not installed; it installs the CPU-only builds.
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

## Create Pytorch Dataset
To create the dataset we read the image paths listed in the train and test text files. As above, the class label is determined by the first character of the base filename.

In [None]:
import logging
 
# Label alphabet: lower-case letters first, then digits.
# A character's position in this list is its integer class index.
chars = list("abcdefghijklmnopqrstuvwxyz0123456789")
char_to_class = dict(zip(chars, range(len(chars))))

class ImageDatasetFromTextFile(Dataset):
    """Dataset of grayscale character images listed in a text file.

    The text file holds one image path per line. The label of an image is
    looked up in ``char_to_class`` using the lower-cased first character of
    the path's base name (parent folders are ignored).

    Args:
        text_file: Path of the list file to read.
        transform: Optional callable applied to each PIL image on access.
    """

    def __init__(self, text_file, transform=None):
        self.image_files = []
        self.labels = []
        self.transform = transform

        with open(text_file, 'r') as f:
            for line in f:
                img_path = line.strip()
                if not img_path:
                    continue
                # First char of the filename, not including the folders.
                alphanumeric_char = os.path.basename(img_path)[0].lower()
                label = char_to_class.get(alphanumeric_char)
                if label is None:
                    # A file whose name does not start with a known class
                    # character cannot be labelled; skip it instead of
                    # aborting the whole load with a bare KeyError.
                    logging.warning(f"Skipping {img_path}: unknown class character {alphanumeric_char!r}")
                    continue
                self.image_files.append(img_path)
                self.labels.append(label)

        if not self.image_files:
            logging.warning(f"Dataset is empty after reading file: {text_file}")
        else:
            print(f"Loaded {len(self.image_files)} images from {text_file}")

    def __len__(self):
        """Return the number of usable images."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        img_path = self.image_files[idx]
        # The context manager closes the underlying file promptly;
        # convert() returns an independent, fully loaded image.
        with Image.open(img_path) as raw:
            image = raw.convert('L')  # Convert to grayscale
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [None]:
# Training-time preprocessing: resize, random affine augmentation, then
# scale pixel values to [-1, 1].
my_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomAffine(degrees=30, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Evaluation preprocessing: the same pipeline without the random augmentation,
# so test results are deterministic and measured on undistorted images.
eval_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = ImageDatasetFromTextFile(train, transform=my_transform)
test_dataset = ImageDatasetFromTextFile(test, transform=eval_transform)

In [None]:
# most basic CNN model
class CNN(nn.Module):
    """Small two-convolution classifier for single-channel character images.

    Each convolution is followed by ReLU and 2x2 max pooling. The feature map
    is then adaptively averaged to 7x7 before the fully connected layers, so
    the network accepts the 32x32 inputs produced by this notebook's
    transforms (8x8 after pooling) as well as 28x28 inputs (already 7x7, for
    which the adaptive pooling is an identity). Flattening with a hard-coded
    ``64 * 7 * 7`` directly after the max pooling failed on 32x32 inputs.

    Args:
        num_classes: Number of output logits (36 = a-z plus 0-9).
    """

    def __init__(self, num_classes=36):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # Parameter-free; fixes the spatial size expected by fc1.
        self.adapt = nn.AdaptiveAvgPool2d((7, 7))
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.adapt(x)
        # Flatten everything except the batch dimension, so the batch size
        # can never be silently changed by the reshape.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [None]:
from torchvision.models import resnet18

# ResNet-18 with randomly initialised weights, adapted to this task below.
resnet_model = resnet18(weights=None)
# Replace the stem so the network takes 1-channel (grayscale) input.
resnet_model.conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
# Replace the classification head: 36 output classes (a-z, 0-9).
resnet_model.fc = nn.Linear(in_features=resnet_model.fc.in_features, out_features=36)

In [None]:
# try out the models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet_model.to(device)
# model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Validate the datasets before building the loaders. Training on an empty
# dataset cannot succeed: the per-epoch average below divides by
# len(train_loader), which would be zero.
if len(train_dataset) == 0:
    logging.error("Train dataset is empty! Check your train_txt paths.")
    raise ValueError("Train dataset is empty; nothing to train on.")
if len(test_dataset) == 0:
    logging.error("Test dataset is empty! Check your test_txt paths.")

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean batch loss over the epoch.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

# Evaluate the model
model.eval()
