In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from model import resnet50
import matplotlib.pyplot as plt
import os
import random

# Select the compute device: CUDA if available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Preprocessing pipeline (intended to match the one used during training).
# No Resize/crop is applied, so each image is fed at its native size.
# NOTE(review): confirm the training pipeline did not resize/crop either.
transform = transforms.Compose([
    transforms.ToTensor(),
    # The standard ImageNet per-channel mean / std values.
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Build the network and load the trained weights.
model = resnet50().to(device)

checkpoint = torch.load("results/best_model.pth", map_location=device)
# The checkpoint is normally a dict holding the weights under
# 'model_state_dict'. Fall back to treating the file as a bare state_dict
# so weights saved via torch.save(model.state_dict(), ...) also load
# instead of failing with a KeyError.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)

# Switch layers such as dropout / batch-norm to inference behaviour.
model.eval()

# Test-set layout: <test_dir>/<class folder>/<image files>.
test_dir = 'data/test'
class_folders = ['0', '1', '2']  # sub-folder for each class index
class_names = ['normal', 'LUAD', 'LUSC']  # display names, same order as class_folders

# Collect every image path together with its ground-truth class index.
image_paths = []
true_labels = []

for class_idx, folder in enumerate(class_folders):
    folder_path = os.path.join(test_dir, folder)
    # isdir rather than exists: a stray *file* with the folder's name would
    # otherwise pass the check and make os.listdir raise NotADirectoryError.
    if not os.path.isdir(folder_path):
        print(f"警告：目录 {folder_path} 不存在！")
        continue

    # Sort so the collected order does not depend on the filesystem's
    # directory-listing order (makes seeded sampling reproducible).
    for img_name in sorted(os.listdir(folder_path)):
        if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_paths.append(os.path.join(folder_path, img_name))
            true_labels.append(class_idx)

if not image_paths:
    # Abort with a non-zero exit status. Raising SystemExit does not rely on
    # the site-installed exit() helper (which, inside an IPython/Jupyter
    # kernel, shuts the kernel down); the message is printed to stderr.
    raise SystemExit("错误：在所有类别目录中未找到任何图片！")

# Randomly pick up to MAX_IMAGES test images (fewer if not enough exist).
MAX_IMAGES = 6
N_COLS = 3
num_images = min(MAX_IMAGES, len(image_paths))
selected_indices = random.sample(range(len(image_paths)), num_images)

# Derive the subplot grid from the sample size (ceil division) so changing
# MAX_IMAGES can never push a subplot index past the grid. For 4-6 images
# this reproduces the original 2x3 grid and 15x10 figure.
n_rows = max(1, (num_images + N_COLS - 1) // N_COLS)
plt.figure(figsize=(5 * N_COLS, 5 * n_rows))

for i, idx in enumerate(selected_indices, 1):
    img_path = image_paths[idx]
    true_label = true_labels[idx]
    file_name = os.path.basename(img_path)

    # Read the image. The context manager closes the underlying file;
    # convert() returns a new, fully loaded RGB image that outlives it.
    with Image.open(img_path) as img:
        image = img.convert('RGB')
    # Add a leading batch dimension before moving to the model's device.
    input_image = transform(image).unsqueeze(0).to(device)

    # Predict: index of the highest-scoring output along dim 1.
    with torch.no_grad():
        outputs = model(input_image)
        predicted_label = outputs.argmax(dim=1).item()

    # Print the result.
    print(f"图片 {i}: {file_name}")
    print(f"真实类别: {class_names[true_label]} ({true_label})")
    print(f"预测类别: {class_names[predicted_label]} ({predicted_label})\n")

    # Show the image with its true and predicted class in the grid.
    plt.subplot(n_rows, N_COLS, i)
    plt.imshow(image)
    plt.title(f"{file_name}\nTrue: {class_names[true_label]}\nPred: {class_names[predicted_label]}")
    plt.axis('off')

plt.tight_layout()
plt.show()



