## Import Packages

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from cnn.resnet18 import CaptchaCNN as CNN

## Load data

In [None]:
# Memory-map both .npy files read-only, then copy each into a float32 tensor.
# Shapes as noted by the original author (confirm against the files):
#   images: (num_samples, 1, 100, 120)
#   labels: (num_samples, 4, 26) -- presumably one-hot, since train/validate
#           take argmax over the last axis.
image_array = np.load("dataset/tixcraft_image_set.npy", mmap_mode='r')
label_array = np.load("dataset/tixcraft_label_set.npy", mmap_mode='r')

X, Y = (torch.tensor(arr, dtype=torch.float32) for arr in (image_array, label_array))

# Drop the memmap references; only the in-memory tensors are used from here on.
del image_array, label_array

## Dataset

In [None]:
class CaptchaDataset(Dataset):
    """Map-style dataset that pairs each input with its label by index.

    Args:
        X: Indexable, sized collection of inputs (this notebook passes a
            float32 tensor of images).
        Y: Indexable, sized collection of labels aligned with ``X`` by index.

    Raises:
        ValueError: If ``X`` and ``Y`` have different lengths. Without this
            check, extra labels would be silently ignored, or indexing ``Y``
            would fail partway through an epoch.
    """

    def __init__(self, X, Y):
        if len(X) != len(Y):
            raise ValueError(
                f"X and Y must have the same length, got {len(X)} and {len(Y)}"
            )
        self.X = X
        self.Y = Y

    def __len__(self):
        # The length check in __init__ guarantees len(self.Y) is the same.
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

## Training

### Hyperparameters

In [None]:
# Mini-batch size shared by the training and validation DataLoaders.
BATCH_SIZE: int = 128
# Number of full passes over the training split.
EPOCHS: int = 50
# AdamW step size.
LEARNING_RATE: float = 0.001
# AdamW weight-decay coefficient.
WEIGHT_DECAY: float = 0.0001

### Data Loader

Split the labeled data into a training set and a validation set (`VALIDATION_RATIO` of the samples are held out for validation).

In [None]:
VALIDATION_RATIO = 0.2
# Fixed seed for the train/validation split, so that every run (and any later
# evaluation of the saved checkpoint) uses the same held-out samples.
# Without it the validation set is redrawn on each run.
SPLIT_SEED = 42

validation_size = int(VALIDATION_RATIO * len(X))
training_size = len(X) - validation_size

split_generator = torch.Generator().manual_seed(SPLIT_SEED)
training_set, validation_set = torch.utils.data.random_split(
    CaptchaDataset(X, Y),
    [training_size, validation_size],
    generator=split_generator,
)

# Only the training order is reshuffled each epoch; validation order is fixed.
training_loader = DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(validation_set, batch_size=BATCH_SIZE, shuffle=False)

### Initialization

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the network and move its parameters onto the selected device
# (nn.Module.to returns the module itself).
model = CNN()
model = model.to(device)

# Cross-entropy over class indices; train() applies it once per label position.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

### Training Function

In [None]:
def train():
    """Run one training epoch and return the mean per-batch loss.

    Uses the module-level ``model``, ``optimizer``, ``criterion``,
    ``training_loader`` and ``device``. Labels are converted to class
    indices with ``argmax`` over their last axis. The batch loss is the sum
    of the cross-entropy losses over every label position (axis 1).

    Returns:
        float: ``total summed loss / number of batches``, or NaN if
        ``training_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    for images, labels in training_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        # Convert label scores to class indices once per batch: (batch, num_positions).
        targets = labels.argmax(dim=2)
        # Take the position count from the labels instead of hard-coding it,
        # matching validate(), which compares all positions via .all(dim=1).
        loss = sum(
            criterion(outputs[:, pos, :], targets[:, pos])
            for pos in range(targets.size(1))
        )
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    num_batches = len(training_loader)
    # Guard against an empty training split instead of dividing by zero.
    return total_loss / num_batches if num_batches else float("nan")

### Validation Function

In [None]:
def validate():
    """Return exact-match accuracy of ``model`` on ``validation_loader``.

    A sample counts as correct only if the argmax prediction equals the
    argmax label at every position (all of axis 1). Uses the module-level
    ``model``, ``validation_loader`` and ``device``. Gradients are disabled.

    Returns:
        float: fraction of fully-correct samples in ``[0, 1]``, or NaN when
        the validation split is empty. An empty split happens, for example,
        when ``int(VALIDATION_RATIO * len(X))`` rounds down to 0.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in validation_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            preds = torch.argmax(outputs, dim=2)
            true_labels = labels.argmax(dim=2)

            # Every position must match for the sample to count as correct.
            correct += (preds == true_labels).all(dim=1).sum().item()
            total += labels.size(0)

    if total == 0:
        # Accuracy is undefined without samples; avoid ZeroDivisionError.
        return float("nan")
    return correct / total

### Run

In [None]:
for epoch in range(EPOCHS):
    # One optimisation pass over the training split, then exact-match
    # accuracy on the held-out split (tuple items evaluate left to right).
    training_loss, validation_acc = train(), validate()
    print(f"Epoch [{epoch+1}/{EPOCHS}] -> Training Loss: {training_loss:.4f} / Validation Acc: {validation_acc:.4f}")

# Persist only the final epoch's parameters (the state_dict, not the module object).
torch.save(model.state_dict(), "ocr_128_model.pth")
print("Model saved!")