In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image
import sys

In [2]:
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.25, 0.25, 0.25])
])

In [3]:
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_built() and torch.backends.mps.is_available():
    device = torch.device("mps")
print(device)

cuda


In [4]:
CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
NUM_CLASSES = 38 # 37 + 1 blank
CTC_LABELS = [
    "<BLANK>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "-", 
    "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", 
    "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"
]
COCO_TO_CTC = {
    1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10,
    11: 11, 12: 12, 13: 13, 14: 14, 15: 15, 16: 16,
    18: 17, 19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23, 25: 24,
    26: 25, 27: 26, 28: 27, 29: 28, 30: 29, 31: 30, 32: 31, 33: 32,
    34: 33, 35: 34, 36: 35, 37: 36, 38: 37
}
BATCH_SIZE = 4

In [5]:
from torchvision.models import resnet50, ResNet50_Weights

class plate_OCR(nn.Module):
    def __init__(self):
        super().__init__()
        resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self.features = nn.Sequential( #bruh it was pooling too much
            resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,
            resnet.layer1, resnet.layer2
        ) 
        #self.dimension_reduction = nn.Linear(1024, 512)
        self.rnn = nn.LSTM(input_size=512, hidden_size=128, num_layers=2, bidirectional=True, batch_first=True)
        self.classify = nn.Linear(128 * 2, NUM_CLASSES)
        with torch.no_grad():
            self.classify.bias.fill_(0.)
            self.classify.bias[0] = -2.0 
        self.height_conv = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(16, 1), stride=(1, 1), padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        #print(x.shape)
        x = self.features(x) 
        x = self.relu(self.height_conv(x)) 
        x = x.squeeze(2) 
        x = x.permute(0, 2, 1) 
        #print(x.shape)
        #x = self.dimension_reduction(x)
        #print(x.shape)
        rnn, _ = self.rnn(x)
        result = self.classify(rnn)
        return F.log_softmax(result, dim=2)

In [6]:
def decode(result):
    pred = result.argmax(-1).squeeze(0).tolist()
    prev = -1
    output = []
    for p in pred:
        if p != prev and p != len(CHARS):
            output.append(CHARS[p])
        prev = p
    return ''.join(output)

In [7]:
from number_coco import license_coco
from number_coco import license_collate
from torch.utils.data import DataLoader
''' for some reason this line wouldn't work here so reput at the top
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((128, 128)),
])
'''
X_train = "../data/license_numbers/train/images"
y_train = "../data/license_numbers/train/annotations.json"

train_dataset = license_coco(root=X_train, ann_file=y_train, transforms=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, prefetch_factor=2, persistent_workers=True, collate_fn=license_collate)

In [8]:
model = plate_OCR().to(device)
ctc_loss = nn.CTCLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [9]:
testing_train_loader = iter(train_loader)
predictions = []

In [15]:
images, targets, target_lengths = next(testing_train_loader)
print(targets)

tensor([ 7,  2,  2,  2,  3,  2,  9, 14,  2, 24, 10,  3,  8, 34, 34, 23,  5,  5,
        10, 19, 28, 35, 32, 31, 10, 25,  4])


In [10]:
model.train()

images, targets, target_lengths = next(testing_train_loader)  # adjust to your loader
images = images.to(device)
targets = targets.to(device)
target_lengths = target_lengths.to(device)
optimizer.zero_grad()

log_output = model(images) # [B, T, C]
log_output = log_output.permute(1, 0, 2) #[T, B, C]
# Input lengths: full length for each sequence
input_lengths = torch.full(size=(BATCH_SIZE,), fill_value=log_output.size(0), dtype=torch.long).to(device)

assert all(input_lengths >= target_lengths), "Target sequence too long for CTC"
# Loss computation

loss = ctc_loss(log_output, targets, input_lengths, target_lengths)
print(f"Loss: {loss.item():.4f}")

# Backward and optimize
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
optimizer.step()

print(targets[:target_lengths[0]])
predicted = log_output.argmax(dim=2)
sample_pred = predicted[:, 0] 
print(sample_pred)
predictions.append(sample_pred.tolist())

Loss: 11.3916
tensor([27,  5, 19], device='cuda:0')
tensor([12, 13, 22, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34],
       device='cuda:0')


In [11]:
predictions

[[12, 13, 22, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34]]

In [11]:
model = plate_OCR().to(device)
ctc_loss = nn.CTCLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

epochs = 5
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    epoch_loss   = 0.0
    i = 0

    for images, targets, target_lengths in train_loader:  # adjust to your loader
        images = images.to(device)
        targets = targets.to(device)
        target_lengths = target_lengths.to(device)

        optimizer.zero_grad()
        log_output = model(images) # [B, T, C]
        log_output = log_output.permute(1, 0, 2) #[T, B, C]


        input_lengths = torch.full(size=(log_output.size(1),), fill_value=log_output.size(0), dtype=torch.long).to(device)
        if (target_lengths > input_lengths).any():
            # batch is impossible for CTC → skip
            print("Skipping batch: target length exceeds input length.")
            optimizer.zero_grad(set_to_none=True)  # cheap no-op
            continue    


        loss = ctc_loss(log_output, targets, input_lengths, target_lengths)
        #print(f"Loss: {loss.item():.4f}")

        # Backward and optimize
        loss.backward()
        optimizer.step()
        
        loss_val = loss.item()
        running_loss += loss_val
        epoch_loss   += loss_val
        
        if i % 50 == 0 or i == len(train_loader):
            print(f"[{epoch+1}/{epochs}, {i:5d}] "
                f"loss/50 = {running_loss/50:.4f} | "
                f"LR = {optimizer.param_groups[0]['lr']:.2e}")
            running_loss = 0.0
        i += 1
        
    print(f"Epoch {epoch+1} finished - avg loss: {epoch_loss/len(train_loader):.4f}\n")

torch.save(model.state_dict(), "number_model.pth")


[1/5,     0] loss/50 = 0.1448 | LR = 1.00e-04
[1/5,    50] loss/50 = 9.6155 | LR = 1.00e-04
[1/5,   100] loss/50 = 4.6613 | LR = 1.00e-04
Skipping batch: target length exceeds input length.
[1/5,   150] loss/50 = 3.9393 | LR = 1.00e-04
Skipping batch: target length exceeds input length.
[1/5,   200] loss/50 = 3.8894 | LR = 1.00e-04
Skipping batch: target length exceeds input length.
[1/5,   250] loss/50 = -119063179.5394 | LR = 1.00e-04
[1/5,   300] loss/50 = -309300035.8990 | LR = 1.00e-04
[1/5,   350] loss/50 = 3.9312 | LR = 1.00e-04
[1/5,   400] loss/50 = 3.8367 | LR = 1.00e-04
[1/5,   450] loss/50 = 3.9021 | LR = 1.00e-04
[1/5,   500] loss/50 = 3.9164 | LR = 1.00e-04
[1/5,   550] loss/50 = 3.8232 | LR = 1.00e-04
Skipping batch: target length exceeds input length.
[1/5,   600] loss/50 = 3.8685 | LR = 1.00e-04
[1/5,   650] loss/50 = 3.8925 | LR = 1.00e-04
[1/5,   700] loss/50 = 3.8971 | LR = 1.00e-04
[1/5,   750] loss/50 = 3.8529 | LR = 1.00e-04
Skipping batch: target length exceeds 

In [None]:
X_test = "../data/license_numbers/test/images"
y_test = "../data/license_numbers/test/annotations.json"

test_dataset = license_coco(root=X_test, ann_file=y_test, transforms=transform)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True, num_workers=4, prefetch_factor=2, persistent_workers=True, collate_fn=license_collate)

In [None]:
images, targets, target_lengths = next(iter(train_loader))

images = images.to(device)
targets = targets.to(device)
target_lengths = target_lengths.to(device)
with torch.no_grad():
    output = model(images)
print(output)