In [None]:
import os
import torch
import torch.nn as nn
from torchvision import transforms, models
from torchvision.models import ResNet50_Weights
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tkinter import Tk, filedialog

def build_model(num_classes, use_pretrained=False, freeze_backbone=False):
    """Create a ResNet-50 whose final layer produces ``num_classes`` outputs.

    Args:
        num_classes: Output size of the replacement ``fc`` layer.
        use_pretrained: When true, start from ``ResNet50_Weights.DEFAULT``;
            otherwise no pretrained weights are requested.
        freeze_backbone: When true, every parameter that exists before the
            head is replaced gets ``requires_grad = False``.

    Returns:
        The ResNet-50 module with its ``fc`` layer replaced.
    """
    if use_pretrained:
        network = models.resnet50(weights=ResNet50_Weights.DEFAULT)
    else:
        network = models.resnet50(weights=None)

    if freeze_backbone:
        # Freeze first: the head created below is a new layer and is
        # therefore not affected by this loop.
        for weight in network.parameters():
            weight.requires_grad = False

    network.fc = nn.Linear(network.fc.in_features, num_classes)
    return network

# Human-readable label for each class index returned by the classifier.
# Keys are consecutive integers starting at 0; lookups elsewhere in this file
# index this mapping with the predicted class id.
CLASS_NAMES = {
    0: 'Speed limit (20km/h)',
    1: 'Speed limit (30km/h)',
    2: 'Speed limit (50km/h)',
    3: 'Speed limit (60km/h)',
    4: 'Speed limit (70km/h)',
    5: 'Speed limit (80km/h)',
    6: 'End of speed limit (80km/h)',
    7: 'Speed limit (100km/h)',
    8: 'Speed limit (120km/h)',
    9: 'No passing',
    10: 'No passing veh > 3.5 tons',
    11: 'Right-of-way at intersection',
    12: 'Priority road',
    13: 'Yield',
    14: 'Stop',
    15: 'No vehicles',
    16: 'Veh > 3.5 tons prohibited',
    17: 'No entry',
    18: 'General caution',
    19: 'Dangerous curve left',
    20: 'Dangerous curve right',
    21: 'Double curve',
    22: 'Bumpy road',
    23: 'Slippery road',
    24: 'Road narrows on the right',
    25: 'Road work',
    26: 'Traffic signals',
    27: 'Pedestrians',
    28: 'Children crossing',
    29: 'Bicycles crossing',
    30: 'Beware of ice/snow',
    31: 'Wild animals crossing',
    32: 'End speed + passing limits',
    33: 'Turn right ahead',
    34: 'Turn left ahead',
    35: 'Ahead only',
    36: 'Go straight or right',
    37: 'Go straight or left',
    38: 'Keep right',
    39: 'Keep left',
    40: 'Roundabout mandatory',
    41: 'End of no passing',
    42: 'End no passing veh > 3.5 tons',
}

# --- Inference configuration ---
MODEL_PATH = "resnet50_gtsrb.pth"  # relative path: resolved against the working directory
IMG_SIZE = 224  # images are resized to IMG_SIZE x IMG_SIZE before inference
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Size of the classifier head; load_state_dict below fails if the checkpoint
# was saved with a different head size.
num_classes = 43

print("üîÑ Loading model...")
model = build_model(num_classes, use_pretrained=False, freeze_backbone=False)

try:
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code embedded in the file. Only load checkpoints from
    # a trusted source. NOTE(review): if the checkpoint holds nothing but
    # tensors and plain Python values, weights_only=True should work — confirm.
    checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state"])
except Exception as e:
    print(f"‚ùå Error loading model: {e}")
    # Re-raise instead of continuing: otherwise `model` would keep its
    # untrained weights, never be moved to `device`, and stay in training
    # mode, so every later prediction would be silently meaningless.
    raise
else:
    model = model.to(device)
    model.eval()
    # A checkpoint without a stored validation accuracy is still usable;
    # report "n/a" rather than treating the missing key as a load failure.
    val_acc = checkpoint.get("val_acc")
    acc_text = "n/a" if val_acc is None else f"{val_acc:.2%}"
    print(f"‚úÖ Model loaded! Accuracy: {acc_text} | Device: {device}\n")

# Preprocessing applied to every image before it is fed to the model:
# resize to a fixed square, convert to a tensor, normalize each channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics;
# they must match whatever was used during training — confirm.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def predict_image(image_path):
    """Classify one image file with the module-level ``model``.

    Relies on the module-level ``model``, ``transform`` and ``device``.

    Args:
        image_path: Path of the image to classify, as accepted by
            ``PIL.Image.open``.

    Returns:
        A 4-tuple ``(pred, conf, probs, img)``:
        the predicted class index, the softmax probability of that class,
        a NumPy array with the softmax probability of every class, and the
        RGB-converted PIL image that was passed to ``transform``.
    """
    # Image.open is lazy and keeps the underlying file open. convert() loads
    # the pixels and returns an independent copy, so the file can be closed
    # right away instead of lingering (notably for multi-frame formats).
    with Image.open(image_path) as source:
        img = source.convert('RGB')

    # unsqueeze(0) adds the batch dimension the model expects.
    tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(tensor)
        # [0] selects the single item of the batch.
        probs = torch.softmax(outputs, dim=1)[0]
        conf, pred = torch.max(probs, dim=0)

    return pred.item(), conf.item(), probs.cpu().numpy(), img

def run_predict_notebook(image_path, top_k=5):
    """Predict one image, print a summary and plot the top predictions.

    Prints the file name, predicted class and confidence, then shows a
    two-panel matplotlib figure: the input image on the left and a horizontal
    bar chart of the most probable classes on the right.

    Args:
        image_path: Path of the image to classify.
        top_k: Number of classes shown in the bar chart. Values larger than
            the number of classes are clamped. Defaults to 5.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    pred_class, confidence, all_probs, img = predict_image(image_path)
    class_name = CLASS_NAMES[pred_class]

    print("\n" + "="*70)
    print(f"üìÅ File: {os.path.basename(image_path)}")
    print(f"üéØ Predicted: {class_name} (Class ID: {pred_class})")
    print(f"üìä Confidence: {confidence:.2%}")
    print("="*70)

    # Never request more bars than there are classes.
    k = min(top_k, len(all_probs))
    # argsort is ascending: take the last k entries and flip them so the
    # most probable class comes first.
    top_idx = np.argsort(all_probs)[-k:][::-1]
    top_probs = all_probs[top_idx]
    top_names = [CLASS_NAMES[i] for i in top_idx]

    _, axes = plt.subplots(1, 2, figsize=(14, 6))
    axes[0].imshow(img)
    axes[0].axis('off')
    axes[0].set_title("Input Image", fontsize=14, fontweight='bold')

    # barh draws position 0 at the bottom, so plot in reverse order to put
    # the most probable class at the top of the chart.
    positions = range(k)
    axes[1].barh(positions, top_probs[::-1])
    axes[1].set_yticks(positions)
    axes[1].set_yticklabels(top_names[::-1])
    axes[1].set_xlabel('Confidence')
    axes[1].set_title(f'Top {k} Predictions')
    plt.suptitle(f'üéØ {class_name} ({confidence:.1%})', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

def choose_and_predict():
    """Ask the user for an image via a file dialog and show its prediction.

    If the dialog is cancelled, a message is printed and nothing is
    predicted. Returns ``None`` in every case.
    """
    # NOTE(review): the dialog title and the cancel message look like UTF-8
    # text that was decoded with the wrong codec; they are left untouched
    # here — confirm the file's encoding.

    # The dialog needs a Tk root; keep it hidden so no empty window appears.
    root = Tk()
    root.withdraw()
    try:
        file_path = filedialog.askopenfilename(
            title="Ch·ªçn ·∫£nh bi·ªÉn b√°o giao th√¥ng",
            filetypes=[("Image files", "*.jpg *.jpeg *.png *.bmp *.gif *.ppm")]
        )
    finally:
        # Destroy the hidden root so repeated calls (e.g. re-running the
        # notebook cell) do not accumulate Tk instances.
        root.destroy()

    if not file_path:
        print("‚ùå Kh√¥ng ch·ªçn ·∫£nh n√†o. Tho√°t.")
        return
    run_predict_notebook(file_path)

# Entry point: open the file picker and classify the selected image.
choose_and_predict()


üîÑ Loading model...
‚úÖ Model loaded! Accuracy: 99.83% | Device: cpu

