# In [None]:
# =========================
# 🔹 TEST SINGLE IMAGE
# =========================

import torch
from torchvision import transforms
from PIL import Image
import os

# ---------------------------- Settings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = 64  # Must match training
img_path = r"D:\DeepTech\data\train\Edge-Ring\image_ER_26011_png_jpg.rf.9f54dd1469519a583f9ec785c2358773.jpg"
# Path of the saved checkpoint; kept with the other settings so it is easy to change.
CHECKPOINT_PATH = "cnn_model.pth"

# ---------------------------- Load checkpoint
# The checkpoint is expected to be a dict holding at least 'class_names'
# (index -> label list used for predictions) and 'model_state_dict'
# (the network weights, loaded further below).
# map_location moves every stored tensor straight onto the selected device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust. Consider weights_only=True (available since PyTorch 1.13, the default
# since 2.6) once you confirm the checkpoint holds only tensors/primitive types.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
class_names = checkpoint['class_names']

# Define the same model architecture
import torch.nn as nn
import torch.nn.functional as F

class MyCustomCNN(nn.Module):
    """Small two-stage convolutional classifier for single-channel images.

    Layout: [3x3 conv (padding=1) -> ReLU -> 2x2 max-pool] twice, then two
    fully connected layers. The padded convolutions keep the spatial size and
    each max-pool halves it (flooring), so a square ``image_size`` input
    reaches ``fc1`` as a ``64 x (image_size // 4) x (image_size // 4)``
    feature map.

    Submodule names (conv1, conv2, fc1, fc2) and shapes must match the
    checkpoint being loaded. With the defaults this is exactly the original
    64x64-input layout.

    Args:
        num_classes: Number of output logits. When omitted, uses
            ``len(class_names)`` of the module-level ``class_names``, looked
            up when the model is constructed.
        image_size: Side length in pixels of the square inputs the model will
            receive. Must match the preprocessing resize. Default 64.
    """

    def __init__(self, num_classes=None, image_size=64):
        super().__init__()
        if num_classes is None:
            # Resolved at construction time rather than class-definition time,
            # so the class can be defined before `class_names` exists.
            num_classes = len(class_names)
        # Spatial side after two 2x2/stride-2 pools (MaxPool2d floors odd sizes).
        feat_side = image_size // 2 // 2
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)  # grayscale input
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * feat_side * feat_side, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a (batch, 1, image_size, image_size) tensor to (batch, num_classes) logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Equivalent to x.view(batch, -1) but also safe for non-contiguous input.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Build the network with one output per known class, place it on the selected
# device (nn.Module.to moves parameters in place and returns the same module),
# restore the trained weights, then switch to inference mode.
model = MyCustomCNN(num_classes=len(class_names))
model.to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# ---------------------------- Prediction
def predict_single_image(img_path, model, class_names, image_size=None):
    """Classify one image with *model* and print the top-1 prediction.

    Preprocessing:
      - the image is converted to single-channel grayscale ("L") and resized
        to ``image_size`` x ``image_size``;
      - ``ToTensor`` scales it to [0, 1];
      - it is normalized with mean 0.5 / std 0.5, i.e. mapped to [-1, 1].

    This is assumed to mirror the training pipeline (TODO confirm).

    The function does not call ``model.eval()``; callers should do that
    first. Inputs are placed on the device of the model's parameters.

    Args:
        img_path: Path to the image file to classify.
        model: Module returning a ``(batch, len(class_names))`` logits tensor.
        class_names: Sequence mapping output index -> class label.
        image_size: Side length to resize to. Defaults to the module-level
            ``IMAGE_SIZE``.

    Returns:
        ``(class_name, confidence)``, where ``confidence`` is the softmax
        probability of the predicted class as a float in [0, 1]. Callers that
        ignore the return value are unaffected.
    """
    if image_size is None:
        image_size = IMAGE_SIZE

    # Follow the model's own device so the input always matches its weights.
    first_param = next(model.parameters(), None)
    target_device = (
        first_param.device if first_param is not None else torch.device("cpu")
    )

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    # Context manager closes the file handle; convert() returns an
    # independent, fully loaded image.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("L")  # Convert to grayscale

    # (1, H, W) -> (1, 1, H, W): add the batch dimension.
    img_tensor = transform(img).unsqueeze(0).to(target_device)
    with torch.no_grad():
        logits = model(img_tensor)
        probs = torch.softmax(logits, dim=1)[0]
        max_prob, idx = torch.max(probs, 0)

    # Convert 0-d tensors to plain Python scalars before indexing/formatting.
    class_name = class_names[idx.item()]
    confidence = max_prob.item()

    print(f"Image: {os.path.basename(img_path)}")
    print(f"Predicted Class: {class_name}")
    print(f"Confidence: {confidence * 100:.2f}%")
    return class_name, confidence

# Classify the image configured in the settings section and print the result.
predict_single_image(
    img_path=img_path,
    model=model,
    class_names=class_names,
)
