In [24]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image, ImageOps
import pandas as pd
import os
import pandas as pd
import torchvision.transforms as transforms

In [25]:
# 1. PadToSquare class

class PadToSquare(object):
    """Pad a PIL image with a constant border so that it becomes square.

    The shorter side is padded up to the length of the longer side. The padding is
    split as evenly as possible between the two opposite edges; when the amount is
    odd, the extra pixel goes to the right/bottom edge. The content therefore stays
    centred and its aspect ratio is preserved before the later resize step.
    The border uses fill value 0.
    """

    def __call__(self, img):
        width, height = img.size
        side = max(width, height)
        pad_x = side - width
        pad_y = side - height
        left, top = pad_x // 2, pad_y // 2
        right, bottom = pad_x - left, pad_y - top
        # ImageOps.expand takes the border as a (left, top, right, bottom) tuple.
        return ImageOps.expand(img, (left, top, right, bottom), fill=0)


# 2. Transform
# Preprocessing applied to every image before it reaches the model:
# square-pad -> resize -> float tensor -> per-channel normalisation with mean 0.5 / std 0.5.
# NOTE(review): these values presumably mirror the training pipeline — confirm.
IMAGE_SIZE = (128, 128)
CHANNEL_MEAN = [0.5, 0.5, 0.5]
CHANNEL_STD = [0.5, 0.5, 0.5]

transform = transforms.Compose(
    [
        PadToSquare(),
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=CHANNEL_MEAN, std=CHANNEL_STD),
    ]
)

In [26]:
# 3. Load CSV class mapping

df = pd.read_csv('labels.csv')  # columns: filename,label

# The model's output index i is decoded as classes[i], so this list must be in the
# same order the training code used for its label indices.
# NOTE(review): sorted() matches torchvision ImageFolder's sorted-folder-name
# convention; confirm the training script ordered labels the same way and with the
# same dtype (pandas may read numeric-looking labels as ints, and numeric order
# differs from string order).
# dropna(): a blank label cell would otherwise add a bogus NaN "class" (changing
# num_classes, so the checkpoint no longer fits) or make sorted() fail on mixed
# str/float values.
classes = sorted(df['label'].dropna().unique())

In [27]:
# ---------- Load Model ----------

# Use the best available accelerator: CUDA first, then Apple-silicon MPS, else CPU.
# (This notebook has been run both on a CUDA machine and on a Mac with MPS.)
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

CHECKPOINT_PATH = './resnet18_thai_char_25e.pth'

num_classes = len(classes)

# The checkpoint resnet18_thai_char_25e.pth holds ResNet-18 weights, so the
# architecture built here must be ResNet-18 (not DenseNet-121).
# weights=None: no pretrained download; the weights come from the checkpoint below.
model = models.resnet18(weights=None)

# Replace the default classification head with one sized to our label set.
# ResNet exposes its head as `fc` (DenseNet used `classifier`).
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)

# map_location lets a checkpoint saved on another device load onto `device`.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered .pth file cannot run arbitrary code on load.
# load_state_dict is strict by default: a mismatch between len(classes) and the
# checkpoint's head size fails loudly here instead of producing wrong labels.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

model.to(device)
model.eval()  # inference behaviour for layers such as BatchNorm

print(f"✅ โหลดโมเดล ResNet-18 สำเร็จและพร้อมใช้งานบน {device} แล้ว")


✅ โหลดโมเดล ResNet-18 สำเร็จและพร้อมใช้งานบน cuda:0 แล้ว


In [28]:
# ---------- Define Transform ----------
# The preprocessing pipeline `transform` (PadToSquare -> Resize((128,128)) -> ToTensor
# -> Normalize(mean=0.5, std=0.5)) is defined once, in the PadToSquare cell. That
# keeps a single source of truth: a second verbatim copy would silently shadow the
# first and let the two drift apart. This cell only displays the pipeline in use.
transform

In [29]:
# ---------- Inference Function ----------
def predict_image(img_path):
    """Classify one image file with the loaded model.

    Relies on the notebook-level ``model``, ``transform`` and ``device`` defined in
    earlier cells.

    Args:
        img_path: Path to an image file that PIL can open.

    Returns:
        A tuple ``(pred_idx, probs)``:
        - ``pred_idx``: Python ``int`` index of the most probable class (an index
          into ``classes``).
        - ``probs``: 1-D numpy array of softmax probabilities, one per model output.
    """
    # The context manager guarantees the file handle is released (Pillow opens files
    # lazily). convert() returns a new, fully loaded image that outlives the block.
    with Image.open(img_path) as raw_img:
        rgb_img = raw_img.convert('RGB')

    batch = transform(rgb_img).unsqueeze(0).to(device)  # add a batch dimension of 1
    with torch.no_grad():
        logits = model(batch)
        # Index [0] drops only the batch dimension. squeeze() would also collapse
        # the class dimension of a 1-class model, so probs[pred_idx] would fail.
        probs = torch.softmax(logits, dim=1)[0]
        pred_idx = int(probs.argmax())
    return pred_idx, probs.cpu().numpy()

In [30]:
# ---------- 5. Loop Folder Test ----------
test_root = './test'  # folder test structure: test/class_name/*.jpg
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")  # matched case-insensitively
results = []

correct = 0
total = 0

# Model labels, normalised the same way as the comparison below. `classes` comes from
# pandas and may hold numbers, whereas folder names are always strings.
known_labels = {str(label).strip() for label in classes}

# os.listdir() order is arbitrary and platform-dependent. Sorting makes the row
# order of `results` (and therefore test_predictions.csv) reproducible between runs.
for class_name in sorted(os.listdir(test_root)):
    class_path = os.path.join(test_root, class_name)
    if not os.path.isdir(class_path):
        print(f"Skipping non-directory item: {class_path}")
        continue

    true_label = class_name.strip()
    if true_label not in known_labels:
        # Every image in this folder would be scored as wrong. Flag it so a naming
        # mismatch between labels.csv and the folders is not mistaken for model error.
        print(f"Warning: folder '{class_name}' does not match any label in labels.csv")

    for img_file in sorted(os.listdir(class_path)):
        if not img_file.lower().endswith(IMAGE_EXTENSIONS):
            continue
        img_path = os.path.join(class_path, img_file)
        pred_idx, probs = predict_image(img_path)
        pred_class = classes[pred_idx]

        results.append({
            'filename': img_file,
            'true_label': class_name,
            'pred_label': pred_class,
            'confidence': float(probs[pred_idx])
        })

        # Compare as stripped strings so a numeric label (e.g. 3) matches folder "3".
        total += 1
        if str(pred_class).strip() == true_label:
            correct += 1

print(correct)
print(total)

Skipping non-directory item: ./test\.DS_Store
158
226


In [31]:
# ---------- Save Results to CSV ----------
PREDICTIONS_CSV = 'test_predictions.csv'

# Passing the columns explicitly keeps a header row even when no images were found.
# An empty list would otherwise give a column-less DataFrame and a blank CSV file.
df_results = pd.DataFrame(
    results,
    columns=['filename', 'true_label', 'pred_label', 'confidence'],
)
df_results.to_csv(PREDICTIONS_CSV, index=False)
print(f"Saved predictions to {PREDICTIONS_CSV}")

Saved predictions to test_predictions.csv


In [32]:
# ---------- Print Accuracy ----------
# Guard against an empty or mis-pointed test folder. Without it, correct / total
# would raise ZeroDivisionError.
if total == 0:
    accuracy = float('nan')
    print("No test images were evaluated; accuracy is undefined")
else:
    accuracy = correct / total * 100
    print(f"Accuracy on test folder: {accuracy:.2f}%")

Accuracy on test folder: 69.91%
