In [1]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os

In [2]:
# Preprocessing applied to every image before it is fed to the model:
#   1. resize to a fixed 224x224 spatial size,
#   2. convert the PIL image to a tensor,
#   3. normalize each of the three channels with fixed mean / std values
#      (these numbers match the widely used ImageNet channel statistics;
#      presumably the same ones used at training time - TODO confirm).
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

In [None]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "best_model.pth"

# ResNet-18 backbone whose final fully connected layer is replaced by a
# single-output linear layer (one logit), so the state_dict keys/shapes
# line up with what is stored in the checkpoint.
model = models.resnet18()
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)

# SECURITY: torch.load unpickles the file. weights_only=True restricts the
# unpickler to tensors and plain containers, which is all a state_dict needs,
# so a tampered checkpoint cannot run arbitrary code while being loaded.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
model = model.to(device)

# Switch to inference mode (affects layers such as batch-norm / dropout).
model.eval()

In [None]:
def predict_image(
    image_path,
    model,
    transform,
    device,
    class_names=("Normal", "Tuberculosis"),
):
    """Classify a single image file with a one-logit binary model.

    Args:
        image_path: Path of the image file to open.
        model: Callable that maps a batched image tensor to a tensor holding
            exactly one value (``.item()`` is called on its sigmoid).
        transform: Callable that turns an RGB PIL image into a tensor.
        device: Device the input tensor is moved to before the forward pass.
        class_names: Two labels, indexed by the predicted class
            (index 0 -> probability <= 0.5, index 1 -> probability > 0.5).
            Any indexable sequence works; the default is an immutable tuple
            so the shared default can never be mutated by accident.

    Returns:
        ``(label, probability)`` where ``probability`` is the sigmoid of the
        model output (a Python float) and ``label`` is the matching entry of
        ``class_names``.

    Note:
        The model is used as-is; this function does not call ``model.eval()``.
    """
    # Use a context manager so the underlying file handle is released as soon
    # as the pixels are decoded; convert() yields an independent in-memory
    # image, so it stays valid after the file is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    # Preprocess, add the leading batch dimension and move to the device.
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(batch)
        pred_prob = torch.sigmoid(output).item()

    # Threshold the probability of the positive class (index 1) at 0.5.
    pred_class = 1 if pred_prob > 0.5 else 0
    return class_names[pred_class], pred_prob

In [None]:
# Run the classifier on one image and report the label together with the
# sigmoid probability returned by predict_image.
image_path = "path_to_new_image.jpg"
predicted_class, probability = predict_image(
    image_path=image_path,
    model=model,
    transform=transform,
    device=device,
)
print(
    f"Dự đoán: {predicted_class}",
    f"Xác suất Tuberculosis: {probability:.4f}",
    sep="\n",
)

In [None]:
def predict_multiple_images(
    image_dir, model, transform, device, class_names=("Normal", "Tuberculosis")
):
    """Run ``predict_image`` on every image file directly inside a directory.

    Files are selected by extension (``.jpg``, ``.jpeg``, ``.png``), compared
    case-insensitively so that names such as ``SCAN.JPG`` are not silently
    skipped. Entries are visited in sorted name order because ``os.listdir``
    returns them in arbitrary order; sorting makes the output reproducible.
    Subdirectories are not traversed.

    Args:
        image_dir: Directory whose entries are scanned.
        model: Forwarded unchanged to ``predict_image``.
        transform: Forwarded unchanged to ``predict_image``.
        device: Forwarded unchanged to ``predict_image``.
        class_names: Forwarded unchanged to ``predict_image``. The default is
            an immutable tuple so the shared default cannot be mutated.

    Returns:
        A list of ``(file_name, predicted_label, probability)`` tuples, one
        per processed image, in sorted file-name order. One summary line per
        image is also printed.
    """
    image_extensions = (".jpg", ".jpeg", ".png")
    results = []
    for image_name in sorted(os.listdir(image_dir)):
        # Only the extension check is lower-cased; the original file name is
        # kept for the path, the result tuple and the printed line.
        if not image_name.lower().endswith(image_extensions):
            continue
        image_path = os.path.join(image_dir, image_name)
        pred_class, pred_prob = predict_image(
            image_path, model, transform, device, class_names
        )
        results.append((image_name, pred_class, pred_prob))
        print(
            f"Ảnh: {image_name}, Dự đoán: {pred_class}, Xác suất TB: {pred_prob:.4f}"
        )
    return results

In [None]:
# Classify every supported image found in the folder; `results` holds one
# (file_name, predicted_label, probability) tuple per processed image.
image_dir = "path_to_image_folder"
results = predict_multiple_images(
    image_dir=image_dir,
    model=model,
    transform=transform,
    device=device,
)