In [7]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from PIL import Image
import numpy as np

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 加载训练集
train_dir = './dataset/train_set'  # 替换为你的训练集路径
train_dataset = ImageFolder(train_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 加载支持集和查询集
def load_test_data(task_dir):
    query_dir = os.path.join(task_dir, 'query')
    support_dir = os.path.join(task_dir, 'support')

    # 加载查询集
    query_images = []
    for img_name in os.listdir(query_dir):
        img_path = os.path.join(query_dir, img_name)
        img = Image.open(img_path).convert('RGB')
        img = transform(img)
        query_images.append(img)

    query_images = torch.stack(query_images)

    # 加载支持集
    support_images, support_labels = [], []
    for class_name in os.listdir(support_dir):
        class_path = os.path.join(support_dir, class_name)
        for img_name in os.listdir(class_path):
            img_path = os.path.join(class_path, img_name)
            img = Image.open(img_path).convert('RGB')
            img = transform(img)
            support_images.append(img)
            support_labels.append(class_name)

    support_images = torch.stack(support_images)
    return query_images, support_images, support_labels

# 加载Excel记录
excel_path = './dataset/test_set/query_labels.csv'  # 替换为你的Excel路径
query_labels_df = pd.read_csv(excel_path)
query_labels = query_labels_df['label'].tolist()

# 定义模型
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, len(train_dataset.classes))  # 修改分类层
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

# 训练模型
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练阶段
for epoch in range(10):  # 设置合适的epoch
    model.train()
    for images, labels in train_loader:
        images, labels = images.to('cuda'), labels.to('cuda')

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels.view(-1))
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch + 1}/10], Loss: {loss.item():.4f}')

# 测试阶段
model.eval()
task_results = []
for i in range(30):  # 处理30个任务
    task_dir = f'./dataset/test_set/task_{i}'  # 替换为你的测试集路径
    query_images, support_images, support_labels = load_test_data(task_dir)

    # 在支持集上提取特征
    with torch.no_grad():
        support_features = model(support_images.to('cuda'))
        query_features = model(query_images.to('cuda'))

    # 计算支持集和查询集之间的相似度
    support_features = support_features.cpu().numpy()
    query_features = query_features.cpu().numpy()

    # 简单的分类逻辑，使用余弦相似度
    for j, query_feature in enumerate(query_features):
        similarities = np.dot(support_features, query_feature)  # 计算相似度
        predicted_class = np.argmax(similarities)  # 找到最相似的类别
        predicted_label = support_labels[predicted_class]  # 记录结果

        task_results.append(predicted_label)  # 记录结果
        print(f'Task {i}, Query Image {j}, Predicted Class: {predicted_label}')  # 打印预测结果

# 计算准确率
accuracy = accuracy_score(query_labels, task_results)
print(f'Accuracy: {accuracy:.2f}')



Epoch [1/10], Loss: 1.1765
Epoch [2/10], Loss: 0.2666
Epoch [3/10], Loss: 0.2682
Epoch [4/10], Loss: 0.4306
Epoch [5/10], Loss: 0.0624
Epoch [6/10], Loss: 0.7052
Epoch [7/10], Loss: 0.0028
Epoch [8/10], Loss: 0.3847
Epoch [9/10], Loss: 0.0960
Epoch [10/10], Loss: 0.1634
Task 0, Query Image 0, Predicted Class: class_3
Task 0, Query Image 1, Predicted Class: class_7
Task 0, Query Image 2, Predicted Class: class_0
Task 0, Query Image 3, Predicted Class: class_9
Task 0, Query Image 4, Predicted Class: class_7
Task 0, Query Image 5, Predicted Class: class_3
Task 0, Query Image 6, Predicted Class: class_6
Task 0, Query Image 7, Predicted Class: class_9
Task 0, Query Image 8, Predicted Class: class_1
Task 0, Query Image 9, Predicted Class: class_7
Task 0, Query Image 10, Predicted Class: class_2
Task 0, Query Image 11, Predicted Class: class_3
Task 0, Query Image 12, Predicted Class: class_6
Task 0, Query Image 13, Predicted Class: class_9
Task 0, Query Image 14, Predicted Class: class_7
Task

主要改进包括以下几个方面：

数据增强


添加了更强的数据增强策略，包括随机裁剪、翻转、旋转和颜色抖动
区分了训练集和测试集的转换策略


模型架构


使用ResNet50替代ResNet18作为backbone，提供更强大的特征提取能力
添加了特征降维层，将特征压缩到更适合度量学习的维度
对特征进行L2归一化，使得距离计算更加稳定


度量学习策略


实现了原型网络（Prototypical Network）的架构
使用三元组损失(Triplet Loss)进行训练，更适合少样本学习任务
在预测时使用欧氏距离而不是简单的余弦相似度


训练优化


使用AdamW优化器并添加权重衰减，减少过拟合
实现学习率调度器（CosineAnnealingLR）
增加了梯度裁剪，提高训练稳定性
延长训练轮数到30轮


其他改进


更好的批处理和设备管理
更清晰的代码结构和模块化设计

要使用这个改进的版本：

确保你的环境中安装了所有必要的库
保持数据集结构不变
直接运行改进后的代码

你还可以尝试以下调优建议：

调整超参数：

增加/减少特征维度（512）
调整triplet loss的margin值（当前0.3）
修改学习率和权重衰减


数据处理：

如果数据集较小，可以增加更多数据增强
考虑添加mixup或cutmix等高级增强策略


模型选择：

可以尝试其他backbone，如EfficientNet或ViT
实验不同的度量学习方法，如关系网络（Relation Network）

In [8]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from PIL import Image
import numpy as np

# 增强数据预处理
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 特征提取器
class FeatureExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super(FeatureExtractor, self).__init__()
        # 使用更强大的backbone
        model = models.resnet50(pretrained=pretrained)
        self.features = nn.Sequential(*list(model.children())[:-1])
        self.fc = nn.Linear(2048, 512)  # 降维到更合适的特征空间
        
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = F.normalize(x, p=2, dim=1)  # L2归一化
        return x

# 原型网络实现
class ProtoNet:
    def __init__(self, model):
        self.model = model
        
    def get_prototypes(self, support_images, support_labels):
        features = self.model(support_images)
        unique_labels = list(set(support_labels))
        prototypes = []
        
        for label in unique_labels:
            mask = torch.tensor([l == label for l in support_labels])
            prototype = features[mask].mean(0)
            prototypes.append(prototype)
            
        return torch.stack(prototypes), unique_labels

    def predict(self, query_images, support_images, support_labels):
        self.model.eval()
        with torch.no_grad():
            prototypes, unique_labels = self.get_prototypes(support_images, support_labels)
            query_features = self.model(query_images)
            
            # 计算欧氏距离
            dists = torch.cdist(query_features, prototypes)
            pred_indices = torch.argmin(dists, dim=1)
            predictions = [unique_labels[idx] for idx in pred_indices]
            
        return predictions

# 加载数据
def load_test_data(task_dir):
    query_dir = os.path.join(task_dir, 'query')
    support_dir = os.path.join(task_dir, 'support')
    
    query_images = []
    for img_name in os.listdir(query_dir):
        img_path = os.path.join(query_dir, img_name)
        img = Image.open(img_path).convert('RGB')
        img = transform_test(img)
        query_images.append(img)
    query_images = torch.stack(query_images)
    
    support_images, support_labels = [], []
    for class_name in os.listdir(support_dir):
        class_path = os.path.join(support_dir, class_name)
        for img_name in os.listdir(class_path):
            img_path = os.path.join(class_path, img_name)
            img = Image.open(img_path).convert('RGB')
            img = transform_test(img)
            support_images.append(img)
            support_labels.append(class_name)
    support_images = torch.stack(support_images)
    
    return query_images, support_images, support_labels

# 主训练循环
def train(model, train_loader, epochs=30):
    criterion = nn.TripletMarginLoss(margin=0.3)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for images, labels in train_loader:
            images = images.to(device)
            
            # 生成三元组
            features = model(images)
            batch_size = len(features)
            
            # 为每个样本找到正样本和负样本
            loss = 0
            for i in range(batch_size):
                anchor = features[i].unsqueeze(0)
                pos_mask = labels == labels[i]
                neg_mask = labels != labels[i]
                
                if pos_mask.sum() > 1 and neg_mask.sum() > 0:
                    positive = features[pos_mask][1].unsqueeze(0)  # 选择同类别的另一个样本
                    negative = features[neg_mask][0].unsqueeze(0)  # 选择不同类别的样本
                    loss += criterion(anchor, positive, negative)
            
            if loss > 0:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                total_loss += loss.item()
        
        scheduler.step()
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {total_loss/len(train_loader):.4f}')

# 主执行流程
def main():
    # 初始化模型
    model = FeatureExtractor(pretrained=True)
    proto_net = ProtoNet(model)
    
    # 加载训练数据
    train_dir = './dataset/train_set'
    train_dataset = ImageFolder(train_dir, transform_train)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    
    # 训练模型
    train(model, train_loader)
    
    # 加载测试标签
    excel_path = './dataset/test_set/query_labels.csv'
    query_labels_df = pd.read_csv(excel_path)
    query_labels = query_labels_df['label'].tolist()
    
    # 测试阶段
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()
    
    task_results = []
    for i in range(30):
        task_dir = f'./dataset/test_set/task_{i}'
        query_images, support_images, support_labels = load_test_data(task_dir)
        
        # 将数据移到GPU
        query_images = query_images.to(device)
        support_images = support_images.to(device)
        
        # 预测
        predictions = proto_net.predict(query_images, support_images, support_labels)
        task_results.extend(predictions)
        
        print(f'Task {i} completed')
    
    # 计算准确率
    accuracy = accuracy_score(query_labels, task_results)
    print(f'Final Accuracy: {accuracy:.4f}')

if __name__ == "__main__":
    main()

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:02<00:00, 36.7MB/s]


Epoch [1/30], Loss: 1.0668
Epoch [2/30], Loss: 1.0315
Epoch [3/30], Loss: 1.0914
Epoch [4/30], Loss: 1.0675
Epoch [5/30], Loss: 0.9852
Epoch [6/30], Loss: 0.9763
Epoch [7/30], Loss: 1.0609
Epoch [8/30], Loss: 0.8420
Epoch [9/30], Loss: 0.8406
Epoch [10/30], Loss: 0.7740
Epoch [11/30], Loss: 0.7009
Epoch [12/30], Loss: 0.6881
Epoch [13/30], Loss: 0.6763
Epoch [14/30], Loss: 0.5586
Epoch [15/30], Loss: 0.6212
Epoch [16/30], Loss: 0.5623
Epoch [17/30], Loss: 0.5379
Epoch [18/30], Loss: 0.4454
Epoch [19/30], Loss: 0.4324
Epoch [20/30], Loss: 0.3867
Epoch [21/30], Loss: 0.3747
Epoch [22/30], Loss: 0.3708
Epoch [23/30], Loss: 0.3406
Epoch [24/30], Loss: 0.3341
Epoch [25/30], Loss: 0.2950
Epoch [26/30], Loss: 0.3002
Epoch [27/30], Loss: 0.2866
Epoch [28/30], Loss: 0.2622
Epoch [29/30], Loss: 0.3104
Epoch [30/30], Loss: 0.2626
Task 0 completed
Task 1 completed
Task 2 completed
Task 3 completed
Task 4 completed
Task 5 completed
Task 6 completed
Task 7 completed
Task 8 completed
Task 9 completed