In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

class small_basic_block(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(small_basic_block, self).__init__()
        q_ch_out = ch_out // 4
        self.block = nn.Sequential(
            nn.Conv2d(ch_in, q_ch_out, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(q_ch_out, q_ch_out, kernel_size=(3, 1), padding=(1, 0)),
            nn.ReLU(),
            nn.Conv2d(q_ch_out, q_ch_out, kernel_size=(1, 3), padding=(0, 1)),
            nn.ReLU(),
            nn.Conv2d(q_ch_out, ch_out, kernel_size=1),
        )
    def forward(self, x):
        return self.block(x)


class LPRNet(nn.Module):
    def __init__(self, lpr_max_len, class_num, dropout_rate=0.5):
        super(LPRNet, self).__init__()
        self.lpr_max_len = lpr_max_len
        self.class_num = class_num

        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),  # 2
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 1, 1)),
            small_basic_block(ch_in=64, ch_out=128),  # [-1, 128, 20, 90]
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(),  # 6
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(2, 1, 2)),
            small_basic_block(ch_in=64, ch_out=256),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            small_basic_block(ch_in=256, ch_out=256),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),  # 13
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 1, 2)),
            nn.Dropout(dropout_rate),
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Conv2d(in_channels=256, out_channels=class_num, kernel_size=(13, 1), stride=1),
            nn.BatchNorm2d(num_features=class_num),
            nn.ReLU(),  # 22
        )

        self.container = nn.Sequential(
            nn.Conv2d(in_channels=448 + self.class_num,
                      out_channels=self.class_num,
                      kernel_size=(1, 1), stride=(1, 1)),
        )

        self.localization = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 2 * 20, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        # x = torch.Size([2, 1, 24, 94])
        xs = self.localization(x)
        xs = xs.view(xs.size(0), -1) # torch.Size([2, 10, 2, 20])
        theta = self.fc_loc(xs)  # torch.Size([2, 6])
        theta = theta.view(theta.size(0), 2, 3)

        grid = F.affine_grid(theta, x.size(), align_corners=True)
        x = F.grid_sample(x, grid, align_corners=True)

        return x

    def forward(self, x):
        # transform the input
        x = self.stn(x)

        keep_features = list()
        for i, layer in enumerate(self.backbone.children()):
            x = layer(x)
            if i in [2, 6, 13, 22]:
                keep_features.append(x)

        global_context = list()
        for i, f in enumerate(keep_features):
            if i in [0, 1]:
                f = nn.AvgPool2d(kernel_size=5, stride=5)(f)
            if i in [2]:
                f = nn.AvgPool2d(kernel_size=(4, 10), stride=(4, 2))(f)
            f_pow = torch.pow(f, 2)
            f_mean = torch.mean(f_pow)
            f = torch.div(f, f_mean)
            global_context.append(f)

        x = torch.cat(global_context, 1)
        x = self.container(x)
        logits = torch.mean(x, dim=2)
        return logits

    def show_num_layer(self):
        for i, layer in enumerate(self.backbone.children()):
            print(f"{i}: {layer}")


def build_lprnet(lpr_max_len=10, class_num=37, dropout_rate=0.5):
    return LPRNet(lpr_max_len, class_num, dropout_rate)

# Define DEVICE before building the model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = build_lprnet().to(DEVICE) # Move model to DEVICE
summary(model, input_size=(3, 24, 94), device=DEVICE.type) # Specify device for summary

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 18, 88]           1,184
         MaxPool2d-2             [-1, 8, 9, 44]               0
              ReLU-3             [-1, 8, 9, 44]               0
            Conv2d-4            [-1, 10, 5, 40]           2,010
         MaxPool2d-5            [-1, 10, 2, 20]               0
              ReLU-6            [-1, 10, 2, 20]               0
            Linear-7                   [-1, 32]          12,832
              ReLU-8                   [-1, 32]               0
            Linear-9                    [-1, 6]             198
           Conv2d-10           [-1, 64, 22, 92]           1,792
      BatchNorm2d-11           [-1, 64, 22, 92]             128
             ReLU-12           [-1, 64, 22, 92]               0
        MaxPool3d-13           [-1, 64, 20, 90]               0
           Conv2d-14           [-1, 32,

In [3]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import string

class ImageFolderCTCDataset(Dataset):
    def __init__(self, folder_path, image_shape=(3, 24, 94), augment=False):
        self.folder_path = folder_path
        self.image_files = os.listdir(folder_path)
        _, height, width = image_shape

        if augment:
            self.transform = transforms.Compose([
                transforms.Resize((height, width)),
                transforms.RandomAffine(
                    degrees=5,
                    translate=(0.05, 0.05),
                    scale=(0.9, 1.1),
                    shear=0
                ),
                transforms.ToTensor(),  # RGB tensor in [0,1], shape (3,H,W)
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize((height, width)),
                transforms.ToTensor(),  # RGB tensor in [0,1], shape (3,H,W)
            ])

        # dictionary build
        self.chars = list(string.digits + string.ascii_uppercase)
        self.chars.append('-')
        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx_to_char = {i: ch for i, ch in enumerate(self.chars)}

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.folder_path, img_name)

        # Load as RGB
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)

        # Extract label
        label_str = img_name.split('.')[0].split('_')[0]
        label_encoded = [self.char_to_idx[ch] for ch in label_str]

        return image, torch.tensor(label_encoded, dtype=torch.long), len(label_encoded)


In [5]:
IMAGE_SHAPE = (3, 24, 94)
data_set_folder = r"/content/drive/MyDrive/Colab Notebooks/lprnet_split_data"
train_ds = ImageFolderCTCDataset(data_set_folder + "/train", image_shape=IMAGE_SHAPE, augment=True)
val_ds = ImageFolderCTCDataset(data_set_folder + "/val", image_shape=IMAGE_SHAPE)
test_ds = ImageFolderCTCDataset(data_set_folder + "/test", image_shape=IMAGE_SHAPE)

img, label_encoded, label_length = train_ds[0]

print(img.shape)           # torch.Size([3, 24, 94])
print(label_encoded)       # tensor([ 0,  0, 21, 17,  2,  8,  7,  7])
print(label_length)        # 8



torch.Size([3, 24, 94])
tensor([10, 12,  1,  7,  5,  8])
6


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# -------- Config ----------
CLASS_NUM = 37              # number of classes (0-9, A-Z, plus blank)
MAX_LABEL_LEN = 10          # max characters per sample (fake, for model design)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LR = 1e-3
EPOCHS = 100
BATCH_SIZE = 32
# --------------------------

# ===== Dataset & DataLoader =====
def ctc_collate_fn(batch):
    images, labels, lengths = zip(*batch)
    images = torch.stack(images, dim=0)
    labels = torch.cat(labels)
    lengths = torch.tensor(lengths, dtype=torch.long)
    return images, labels, lengths

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=ctc_collate_fn)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=ctc_collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=ctc_collate_fn)

# check a batch
images, labels, lengths = next(iter(train_loader))
print(images.shape)   # [32, 3, 24, 94]
print(labels.shape)   # flat 1D tensor, e.g. torch.Size([180])
print(lengths.shape)  # [32], lengths of each label


torch.Size([32, 3, 24, 94])
torch.Size([192])
torch.Size([32])


In [8]:

# ===== Model, Loss, Optimizer =====
model = build_lprnet(MAX_LABEL_LEN, CLASS_NUM).to(DEVICE)
model.load_state_dict(torch.load("chinese_lprnet_best.pth", map_location=DEVICE))

criterion = nn.CTCLoss(blank=train_ds.char_to_idx['-'], reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=LR)


# ===== Training & Validation =====
def train_one_epoch(epoch):
    model.train()
    running_loss = 0.0

    for batch_idx, (images, labels, target_lengths) in enumerate(train_loader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        target_lengths = target_lengths.to(DEVICE)

        # Forward
        logits = model(images)                     # [N, C, T]
        logits = logits.permute(2, 0, 1)           # [T, N, C]
        log_probs = logits.log_softmax(2)

        # Input lengths = all T
        input_lengths = torch.full(size=(images.size(0),),
                                   fill_value=logits.size(0),
                                   dtype=torch.long).to(DEVICE)

        # Loss
        loss = criterion(log_probs, labels, input_lengths, target_lengths)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if (batch_idx+1) % 10 == 0:
            print(f"Epoch [{epoch+1}], Step [{batch_idx+1}/{len(train_loader)}], "
                  f"Loss: {running_loss/10:.4f}")
            running_loss = 0.0


def validate(epoch):
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for images, labels, target_lengths in val_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            target_lengths = target_lengths.to(DEVICE)

            logits = model(images)
            logits = logits.permute(2, 0, 1)
            log_probs = logits.log_softmax(2)

            input_lengths = torch.full(size=(images.size(0),),
                                       fill_value=logits.size(0),
                                       dtype=torch.long).to(DEVICE)

            loss = criterion(log_probs, labels, input_lengths, target_lengths)
            val_loss += loss.item()

    val_loss /= len(val_loader)
    print(f"Epoch [{epoch+1}] Validation Loss: {val_loss:.4f}")
    return val_loss

In [9]:
# ===== Main Loop =====
best_val_loss = float("inf")

for epoch in range(EPOCHS):
    train_one_epoch(epoch)
    val_loss = validate(epoch)

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), f"chinese_lprnet_best_{str(best_val_loss).replace('.', '_')}.pth")
        print(f"✅ Saved best model at epoch {epoch+1} with val_loss={val_loss:.4f}")


Epoch [1], Step [10/25], Loss: 4.5267
Epoch [1], Step [20/25], Loss: 3.4087
Epoch [1] Validation Loss: 23.6184
✅ Saved best model at epoch 1 with val_loss=23.6184
Epoch [2], Step [10/25], Loss: 3.0917
Epoch [2], Step [20/25], Loss: 3.0818
Epoch [2] Validation Loss: 7.2242
✅ Saved best model at epoch 2 with val_loss=7.2242
Epoch [3], Step [10/25], Loss: 2.9596
Epoch [3], Step [20/25], Loss: 2.9706
Epoch [3] Validation Loss: 3.1290
✅ Saved best model at epoch 3 with val_loss=3.1290
Epoch [4], Step [10/25], Loss: 2.8922
Epoch [4], Step [20/25], Loss: 2.8405
Epoch [4] Validation Loss: 2.9501
✅ Saved best model at epoch 4 with val_loss=2.9501
Epoch [5], Step [10/25], Loss: 2.7829
Epoch [5], Step [20/25], Loss: 2.7819
Epoch [5] Validation Loss: 2.9520
Epoch [6], Step [10/25], Loss: 2.6637
Epoch [6], Step [20/25], Loss: 2.5567
Epoch [6] Validation Loss: 2.7913
✅ Saved best model at epoch 6 with val_loss=2.7913
Epoch [7], Step [10/25], Loss: 2.3956
Epoch [7], Step [20/25], Loss: 2.2667
Epoch [

In [11]:
def validate_test():
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for images, labels, target_lengths in test_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            target_lengths = target_lengths.to(DEVICE)

            logits = model(images)
            logits = logits.permute(2, 0, 1)
            log_probs = logits.log_softmax(2)

            input_lengths = torch.full(size=(images.size(0),),
                                       fill_value=logits.size(0),
                                       dtype=torch.long).to(DEVICE)

            loss = criterion(log_probs, labels, input_lengths, target_lengths)
            val_loss += loss.item()

    val_loss /= len(val_loader)
    return val_loss

print(f"Test Loss: {validate_test():.4f}")

Test Loss: 0.0801


In [12]:
def greedy_decode(logits, idx_to_char, blank_idx):
    """
    logits: [T, N, C] tensor (log probs or raw logits)
    idx_to_char: dictionary mapping int -> char
    blank_idx: index of blank symbol
    """
    preds = logits.argmax(2).permute(1, 0)   # [N, T]

    results = []
    for pred in preds:
        string = ""
        prev = None
        for p in pred.cpu().numpy():
            if p != prev and p != blank_idx:   # collapse + remove blank
                string += idx_to_char[p]
            prev = p
        results.append(string)
    return results

def predict_image(model, image, dataset):
    """
    model   : trained model
    image   : tensor [3, 24, 94]
    dataset : dataset object (for idx_to_char, blank index)
    """
    model.eval()
    with torch.no_grad():
        image = image.unsqueeze(0).to(DEVICE)        # add batch dim [1, 3, 24, 94]
        logits = model(image)                        # [N, C, T]
        logits = logits.permute(2, 0, 1)             # [T, N, C]

        preds = greedy_decode(logits, dataset.idx_to_char, dataset.char_to_idx['-'])
        return preds[0]



In [13]:
# Take a sample from your test set
i = 0
img, _, _ = test_ds[i]
print("Ground truth:", test_ds.image_files[i])

pred = predict_image(model, img, test_ds)
print("Prediction  :", pred)


Ground truth: AY021B.jpg
Prediction  : AY021B


In [16]:
import os
import cv2
import torch
import numpy as np # Import numpy

def evaluate_and_save(model, dataset, output_dir="results"):
    model.eval()
    correct = 0
    total = len(dataset)

    # Prepare output folders
    right_dir = os.path.join(output_dir, "right")
    wrong_dir = os.path.join(output_dir, "wrong")
    os.makedirs(right_dir, exist_ok=True)
    os.makedirs(wrong_dir, exist_ok=True)

    with torch.no_grad():
        for i in range(total):
            img, label_encoded, _ = dataset[i]

            # Ground truth string
            label_str = "".join(dataset.idx_to_char[idx.item()] for idx in label_encoded)

            # Prediction
            pred_str = predict_image(model, img, dataset)

            # Convert tensor -> numpy for saving
            if isinstance(img, torch.Tensor):
                np_img = img.squeeze().cpu().numpy() * 255.0
                np_img = np.transpose(np_img, (1, 2, 0)) # Transpose dimensions to (height, width, channels)
                np_img = np_img.astype("uint8")
            else:
                np_img = img

            # Build filename: prediction_GT_index.png
            filename = f"{pred_str}_GT-{label_str}_{i}.png"

            if pred_str == label_str:
                correct += 1
                save_path = os.path.join(right_dir, filename)
            else:
                save_path = os.path.join(wrong_dir, filename)

            # Save with OpenCV
            cv2.imwrite(save_path, np_img)

    accuracy = correct / total
    return accuracy


# Run

test_acc = evaluate_and_save(model, test_ds, output_dir="predictions")
print(f"Test Accuracy: {test_acc:.2%}")

Test Accuracy: 89.00%


In [17]:
!zip -r /content/predictions.zip /content/predictions

  adding: content/predictions/ (stored 0%)
  adding: content/predictions/wrong/ (stored 0%)
  adding: content/predictions/wrong/AB8888_GT-AB88B8_59.png (stored 0%)
  adding: content/predictions/wrong/AE0178_GT-AED178_27.png (stored 0%)
  adding: content/predictions/wrong/A11111_GT-AJ0J22_46.png (deflated 1%)
  adding: content/predictions/wrong/AKR0356_GT-KR0356_94.png (stored 0%)
  adding: content/predictions/wrong/AY0083_GT-AY0D83_49.png (stored 0%)
  adding: content/predictions/wrong/AP0992_GT-AP099X_82.png (stored 0%)
  adding: content/predictions/wrong/APB2153_GT-PB2153_44.png (stored 0%)
  adding: content/predictions/wrong/AW1V2U_GT-AW112U_48.png (stored 0%)
  adding: content/predictions/wrong/KL2283_GT-KZ2283_86.png (stored 0%)
  adding: content/predictions/wrong/AHL28K_GT-AH128K_57.png (stored 0%)
  adding: content/predictions/wrong/A30CJ56_GT-A30J56_47.png (stored 0%)
  adding: content/predictions/right/ (stored 0%)
  adding: content/predictions/right/AER805_GT-AER805_81.png (s