In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import os

In [None]:
# Output alphabet: uppercase Latin letters followed by the ten digits.
# A character's position in this list is the class index used for it.
CHARS = [*"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"]

# Reverse lookup used when decoding: class index -> character.
IDX_TO_CHAR = dict(enumerate(CHARS))

# Number of classes each per-position head chooses between.
VOCAB_SIZE = len(CHARS)

# Number of character positions (one classification head per position).
SEQ_LENGTH = 6

In [None]:
class OCRModel(nn.Module):
    """CNN feature extractor followed by one linear classifier per character position.

    The network stacks four ``Conv2d -> ReLU -> MaxPool2d(2)`` stages
    (3 -> 32 -> 64 -> 128 -> 256 channels), flattens the result, projects it
    to 512 features and feeds those features to ``seq_length`` independent
    linear heads, each producing ``vocab_size`` scores.

    Args:
        vocab_size: Number of classes per head. ``None`` (default) uses the
            module-level ``VOCAB_SIZE``.
        seq_length: Number of heads / character positions. ``None`` (default)
            uses the module-level ``SEQ_LENGTH``.
        input_size: Spatial size of the input images, either a single int for
            square images or a ``(height, width)`` pair. The default of 256
            reproduces the original ``256 * 16 * 16`` flattened feature size.

    Raises:
        ValueError: If ``vocab_size`` or ``seq_length`` is smaller than 1, or
            if either input dimension is smaller than 16 (the four pooling
            stages would leave an empty feature map).

    Calling ``OCRModel()`` with no arguments builds the same layers, under the
    same attribute names, as before these parameters existed, so previously
    saved ``state_dict`` checkpoints keep loading.
    """

    def __init__(self, vocab_size=None, seq_length=None, input_size=256):
        super().__init__()

        # Resolve the defaults lazily so the module-level constants are only
        # required when the caller does not supply explicit values.
        if vocab_size is None:
            vocab_size = VOCAB_SIZE
        if seq_length is None:
            seq_length = SEQ_LENGTH
        if vocab_size < 1 or seq_length < 1:
            raise ValueError(
                f"vocab_size and seq_length must be >= 1, got {vocab_size} and {seq_length}"
            )

        if isinstance(input_size, int):
            height = width = input_size
        else:
            height, width = input_size

        # The 3x3 convolutions use padding=1 and keep the spatial size; each
        # of the four MaxPool2d(2) stages halves it (rounding down), so the
        # final feature map is (height // 16) x (width // 16).
        feat_h, feat_w = height // 16, width // 16
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"input_size must be at least 16 in each dimension, got {(height, width)}"
            )

        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(256 * feat_h * feat_w, 512)
        self.relu = nn.ReLU()
        self.heads = nn.ModuleList([nn.Linear(512, vocab_size) for _ in range(seq_length)])

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Image batch of shape ``(batch, 3, height, width)`` whose spatial
                size matches the ``input_size`` the model was built with.

        Returns:
            A list with one tensor per head, in head order; each tensor has
            shape ``(batch, vocab_size)`` and holds unnormalised scores.
        """
        features = self.cnn(x)
        features = self.flatten(features)
        features = self.relu(self.fc(features))
        outputs = [head(features) for head in self.heads]
        return outputs

In [None]:
def predict(model, image_path, device):
    """Decode the character sequence shown in a single image file.

    The image is converted to RGB, resized to 256x256 (the resolution that
    yields the ``256 * 16 * 16`` flattened features a default ``OCRModel``
    expects), turned into a tensor and passed through ``model`` as a batch
    of one.

    Args:
        model: Module whose forward pass returns a sequence of score tensors,
            one per character position, each of shape ``(1, num_classes)``.
            It must already live on ``device``; only the input tensor is
            moved here.
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        device: Device the input tensor is moved to before inference.

    Returns:
        The predicted string, one character per model output, looked up
        through ``IDX_TO_CHAR``.

    Side effects:
        Switches ``model`` to evaluation mode and leaves it in that mode.
    """
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])

    # Image.open is lazy and keeps the underlying file open. convert() loads
    # the pixel data and returns an independent image, so the handle can be
    # released deterministically as soon as the conversion is done.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        outputs = model(batch)

    # Batch size is 1, so each argmax over the class dimension is a
    # one-element tensor and .item() yields the class index as a Python int.
    predictions = [torch.argmax(o, dim=1).item() for o in outputs]
    return ''.join(IDX_TO_CHAR[idx] for idx in predictions)


In [None]:
if __name__ == "__main__":
    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = OCRModel()

    # The checkpoint is first materialised on the CPU and the whole model is
    # moved to the target device afterwards.
    # NOTE(review): depending on the installed PyTorch version's
    # `weights_only` default, torch.load may unpickle arbitrary objects --
    # only load checkpoints from a trusted source.
    state_dict = torch.load("ocr_model.pth", map_location="cpu")
    model.load_state_dict(state_dict)
    model.to(device)

    # Smoke test on a single image.
    test_image_path = "test_image.jpg"  # <-- replace with the path to your own image
    prediction = predict(model, test_image_path, device)
    print(f"Prediction: {prediction}")