In [7]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from PIL import Image
import numpy as np

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 加载训练集
train_dir = './dataset/train_set'  # 替换为你的训练集路径
train_dataset = ImageFolder(train_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 加载支持集和查询集
def load_test_data(task_dir):
    query_dir = os.path.join(task_dir, 'query')
    support_dir = os.path.join(task_dir, 'support')

    # 加载查询集
    query_images = []
    for img_name in os.listdir(query_dir):
        img_path = os.path.join(query_dir, img_name)
        img = Image.open(img_path).convert('RGB')
        img = transform(img)
        query_images.append(img)

    query_images = torch.stack(query_images)

    # 加载支持集
    support_images, support_labels = [], []
    for class_name in os.listdir(support_dir):
        class_path = os.path.join(support_dir, class_name)
        for img_name in os.listdir(class_path):
            img_path = os.path.join(class_path, img_name)
            img = Image.open(img_path).convert('RGB')
            img = transform(img)
            support_images.append(img)
            support_labels.append(class_name)

    support_images = torch.stack(support_images)
    return query_images, support_images, support_labels

# 加载Excel记录
excel_path = './dataset/test_set/query_labels.csv'  # 替换为你的Excel路径
query_labels_df = pd.read_csv(excel_path)
query_labels = query_labels_df['label'].tolist()

# 定义模型
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, len(train_dataset.classes))  # 修改分类层
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

# 训练模型
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练阶段
for epoch in range(10):  # 设置合适的epoch
    model.train()
    for images, labels in train_loader:
        images, labels = images.to('cuda'), labels.to('cuda')

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels.view(-1))
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch + 1}/10], Loss: {loss.item():.4f}')

# 测试阶段
model.eval()
task_results = []
for i in range(30):  # 处理30个任务
    task_dir = f'./dataset/test_set/task_{i}'  # 替换为你的测试集路径
    query_images, support_images, support_labels = load_test_data(task_dir)

    # 在支持集上提取特征
    with torch.no_grad():
        support_features = model(support_images.to('cuda'))
        query_features = model(query_images.to('cuda'))

    # 计算支持集和查询集之间的相似度
    support_features = support_features.cpu().numpy()
    query_features = query_features.cpu().numpy()

    # 简单的分类逻辑，使用余弦相似度
    for j, query_feature in enumerate(query_features):
        similarities = np.dot(support_features, query_feature)  # 计算相似度
        predicted_class = np.argmax(similarities)  # 找到最相似的类别
        predicted_label = support_labels[predicted_class]  # 记录结果

        task_results.append(predicted_label)  # 记录结果
        print(f'Task {i}, Query Image {j}, Predicted Class: {predicted_label}')  # 打印预测结果

# 计算准确率
accuracy = accuracy_score(query_labels, task_results)
print(f'Accuracy: {accuracy:.2f}')



Epoch [1/10], Loss: 1.1765
Epoch [2/10], Loss: 0.2666
Epoch [3/10], Loss: 0.2682
Epoch [4/10], Loss: 0.4306
Epoch [5/10], Loss: 0.0624
Epoch [6/10], Loss: 0.7052
Epoch [7/10], Loss: 0.0028
Epoch [8/10], Loss: 0.3847
Epoch [9/10], Loss: 0.0960
Epoch [10/10], Loss: 0.1634
Task 0, Query Image 0, Predicted Class: class_3
Task 0, Query Image 1, Predicted Class: class_7
Task 0, Query Image 2, Predicted Class: class_0
Task 0, Query Image 3, Predicted Class: class_9
Task 0, Query Image 4, Predicted Class: class_7
Task 0, Query Image 5, Predicted Class: class_3
Task 0, Query Image 6, Predicted Class: class_6
Task 0, Query Image 7, Predicted Class: class_9
Task 0, Query Image 8, Predicted Class: class_1
Task 0, Query Image 9, Predicted Class: class_7
Task 0, Query Image 10, Predicted Class: class_2
Task 0, Query Image 11, Predicted Class: class_3
Task 0, Query Image 12, Predicted Class: class_6
Task 0, Query Image 13, Predicted Class: class_9
Task 0, Query Image 14, Predicted Class: class_7
Task