In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import sys

In [2]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.25, 0.25, 0.25])
])

In [3]:
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_built() and torch.backends.mps.is_available():
    device = torch.device("mps")
print(device)

cuda


In [4]:
BATCH_SIZE = 8
NUM_CLASSES = 39 # 38 + 1

In [5]:
from torchvision.models import resnet50, ResNet50_Weights

class plate_OCR(nn.Module):
    def __init__(self):
        super().__init__()
        resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self.features = nn.Sequential( #bruh it was pooling too much
            resnet.conv1, resnet.bn1, resnet.relu,
            resnet.layer1
        ) 
        #self.dimension_reduction = nn.Linear(1024, 512)
        self.linear1 = nn.Linear(in_features=32768, out_features=2048)
        self.rnn = nn.LSTM(input_size=2048, hidden_size=128, num_layers=2, bidirectional=True, batch_first=True)
        self.classify = nn.Linear(128 * 2, NUM_CLASSES)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.features(x) 
        x = self.relu(x) 
        x = x.permute(0, 3, 1, 2)  # (N, W, C, H)
        x = x.reshape(x.size(0), x.size(1), -1)  # (N, W, C*H) 
        x = self.linear1(x)
        rnn, _ = self.rnn(x)
        result = self.classify(rnn)
        result = result.permute(1, 0, 2)
        return F.log_softmax(result, dim=2)

In [6]:
from number_coco import license_coco
from number_coco import license_collate
from torch.utils.data import DataLoader
''' for some reason this line wouldn't work here so reput at the top
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((128, 128)),
])
'''
X_train = "../data/license_numbers/train/images"
y_train = "../data/license_numbers/train/annotations.json"

train_dataset = license_coco(root=X_train, ann_file=y_train, transforms=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, prefetch_factor=2, persistent_workers=True, collate_fn=license_collate, drop_last=True)

In [7]:
def greedy_decode(log_probs, blank=0):
    """Greedy-CTC decode a single sample (T, C) → List[int]."""
    best = log_probs.argmax(dim=1)          # (T,)
    decoded, prev = [], blank
    for idx in best.cpu().tolist():
        if idx != prev and idx != blank:
            decoded.append(idx)
        prev = idx
    return decoded

In [8]:
from contextlib import suppress
from itertools import batched


model = plate_OCR().to(device)
ctc_loss = nn.CTCLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
model.train()
global_step = 0

epochs = 5
for epoch in range(epochs):
    global_step += 1

    for images, targets, target_lengths in train_loader:  # adjust to your loader
        try:
            images          = images.to(device)
            targets         = targets.to(device)
            target_lengths  = target_lengths.to(device)

            log_probs = model(images)
            size = log_probs.size()[0]

            input_lengths = torch.full(
                size=(BATCH_SIZE,),
                fill_value=size,
                dtype=torch.long,
                device=log_probs.device
            )

            optimizer.zero_grad(set_to_none=True)
            loss = ctc_loss(
                log_probs,        # (T, N, C)
                targets,          # 1-D concat labels
                input_lengths,    # (N,)
                target_lengths    # (N,)
            )

            if not torch.isfinite(loss):
                raise FloatingPointError(f"loss is {loss}")

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()
            if global_step % 100 == 0:
                # look at first sample in the batch
                pred_decoded = greedy_decode(log_probs[:, 0, :])
                true_decoded = targets[: target_lengths[0]].cpu().tolist()
                print(
                    f"step {global_step:>6} | "
                    f"loss {loss.item():.4f} | "
                    f"pred {pred_decoded} | "
                    f"true {true_decoded}"
                )

        except (RuntimeError, ValueError, FloatingPointError) as err:
            print(f"[warn] skipping batch")
            with suppress(Exception):
                torch.cuda.empty_cache()
            continue

torch.save(model.state_dict(), "number_model.pth")


In [9]:
X_test = "../data/license_numbers/test/images"
y_test = "../data/license_numbers/test/annotations.json"

test_dataset = license_coco(root=X_test, ann_file=y_test, transforms=transform)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True, num_workers=4, prefetch_factor=2, persistent_workers=True, collate_fn=license_collate)

In [12]:
images, targets, target_lengths = next(iter(train_loader))

images = images.to(device)
targets = targets.to(device)
target_lengths = target_lengths.to(device)
with torch.no_grad():
    output = model(images)
print(targets)
print(greedy_decode(output))

tensor([10, 10,  5, 32, 32, 23, 15, 24, 22,  7,  7, 12, 31,  5, 11,  6,  2, 10,
         7,  3, 35, 24, 12,  6, 10, 10, 22, 22, 33, 21, 23, 33, 32,  7, 24, 31,
         8, 32,  8,  8,  8,  9,  2,  2,  3, 15,  6,  8, 24, 24,  3, 30],
       device='cuda:0')
[[1, 2, 2, 2, 4, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2], [1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [1, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [0, 2, 2, 2, 2, 2, 2, 2, 2, 