# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import pandas as pd
import numpy as np
from PIL import Image
import os
from tqdm import tqdm

class TestDataset:
    """Loader for the few-shot evaluation tasks stored under ``root_dir``.

    Layout read by this class::

        root_dir/
            query_labels.csv            # columns: img_name, label
            task_<idx>/
                support/<class_name>/<image files>
                query/<image files>

    Args:
        root_dir: Directory containing the ``task_<idx>`` folders.
        transform: Optional callable applied to every loaded PIL image.
        labels_csv: Optional path of the query label CSV. Defaults to
            ``<root_dir>/query_labels.csv``.
    """

    def __init__(self, root_dir, transform=None, labels_csv=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sorted so the task list does not depend on the filesystem's order.
        self.tasks = sorted(f for f in os.listdir(root_dir) if f.startswith('task_'))

        # Resolve the label file relative to root_dir instead of a fixed
        # working-directory-relative path.
        if labels_csv is None:
            labels_csv = os.path.join(root_dir, 'query_labels.csv')
        self.query_labels_df = pd.read_csv(labels_csv)

        # Build the name -> label map once; the first row wins for duplicated
        # image names, which is what a ``.values[0]`` lookup would return.
        # NOTE(review): labels are keyed by image name only, so names are
        # assumed to be unique across tasks - confirm against the CSV.
        self._label_by_image = {}
        for name, label in zip(self.query_labels_df['img_name'],
                               self.query_labels_df['label']):
            self._label_by_image.setdefault(name, label)

    def _load_image(self, img_path):
        """Open ``img_path`` as RGB, close the file, and apply the transform."""
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

    def load_task(self, task_idx):
        """Load the support and query sets of ``task_<task_idx>``.

        Returns:
            A tuple ``(support_images, support_labels, query_images,
            query_paths, support_classes, query_labels)``. Images are stacked
            with ``torch.stack``; labels are indices into ``support_classes``
            (the sorted class directory names of the support set).

        Raises:
            KeyError: If a query image has no row in the label CSV.
            ValueError: If a query label is not one of the support classes.
        """
        task_dir = os.path.join(self.root_dir, f'task_{task_idx}')

        # Support set: one sub-directory per class. Sorting gives a stable
        # class -> index mapping; stray files next to the class folders are
        # ignored.
        support_dir = os.path.join(task_dir, 'support')
        support_classes = sorted(
            entry for entry in os.listdir(support_dir)
            if os.path.isdir(os.path.join(support_dir, entry))
        )
        support_images = []
        support_labels = []

        for idx, class_name in enumerate(support_classes):
            class_dir = os.path.join(support_dir, class_name)
            for img_name in sorted(os.listdir(class_dir)):
                support_images.append(self._load_image(os.path.join(class_dir, img_name)))
                support_labels.append(idx)

        support_images = torch.stack(support_images)
        support_labels = torch.tensor(support_labels)

        # Query set: a flat directory of images labelled through the CSV.
        query_dir = os.path.join(task_dir, 'query')
        query_images = []
        query_paths = []
        query_labels = []

        for img_name in sorted(os.listdir(query_dir)):
            img_path = os.path.join(query_dir, img_name)
            query_images.append(self._load_image(img_path))
            query_paths.append(img_path)

            try:
                label = self._label_by_image[img_name]
            except KeyError:
                raise KeyError(
                    f'No label for query image {img_name!r} '
                    f'(task_{task_idx}) in the query label CSV'
                ) from None

            # Convert the class name into its index within this task.
            query_labels.append(support_classes.index(label))

        query_images = torch.stack(query_images)
        query_labels = torch.tensor(query_labels)

        return support_images, support_labels, query_images, query_paths, support_classes, query_labels


# 2. Model definition
class PrototypicalNet(nn.Module):
    """Prototypical network whose embedding is a truncated ResNet-50.

    The feature extractor is every child module of torchvision's
    ``resnet50`` except the last one, wrapped in an ``nn.Sequential``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        backbone = models.resnet50(pretrained=pretrained)
        backbone_layers = list(backbone.children())
        self.feature_extractor = nn.Sequential(*backbone_layers[:-1])

    def forward(self, x):
        """Embed a batch and flatten each sample to a single vector."""
        embeddings = self.feature_extractor(x)
        batch_size = embeddings.size(0)
        return embeddings.view(batch_size, -1)

    def get_prototypes(self, support_images, support_labels):
        """Return one prototype per distinct label: the mean support embedding.

        Rows follow the order of ``torch.unique(support_labels)``.
        """
        features = self.forward(support_images)
        class_means = [
            features[support_labels == class_id].mean(0)
            for class_id in torch.unique(support_labels)
        ]
        return torch.stack(class_means)

    def predict(self, prototypes, query_features):
        """Return, for each query row, the index of its closest prototype."""
        return torch.cdist(query_features, prototypes).argmin(dim=1)

def train_episode(model, support_images, support_labels, query_images, query_labels, optimizer):
    """Run one optimisation step on a single episode.

    Prototypes are built from the support set, the query set is embedded, and
    the negated query-to-prototype distances are used as logits for a
    cross-entropy loss.

    Returns:
        ``(loss, accuracy)`` as Python floats; accuracy is measured with the
        parameters in use before the optimizer step.
    """
    model.train()
    optimizer.zero_grad()

    class_prototypes = model.get_prototypes(support_images, support_labels)
    query_embeddings = model(query_images)

    # Closer prototype => larger logit.
    dist_matrix = torch.cdist(query_embeddings, class_prototypes)
    criterion = nn.CrossEntropyLoss()
    loss = criterion(-dist_matrix, query_labels)

    # Nearest-prototype classification on the same distances.
    is_correct = dist_matrix.argmin(dim=1) == query_labels
    accuracy = is_correct.float().mean().item()

    loss.backward()
    optimizer.step()

    return loss.item(), accuracy

def evaluate_task(model, support_images, support_labels, query_images, query_labels=None):
    """Classify the query images of one task by nearest prototype.

    The model is switched to eval mode and everything runs under
    ``torch.no_grad()``.

    Returns:
        ``(predictions, accuracy)``; ``accuracy`` is ``None`` when
        ``query_labels`` is not given, otherwise the fraction of predictions
        equal to ``query_labels`` as a Python float.
    """
    model.eval()
    with torch.no_grad():
        class_prototypes = model.get_prototypes(support_images, support_labels)
        query_embeddings = model(query_images)
        predictions = model.predict(class_prototypes, query_embeddings)

        if query_labels is None:
            accuracy = None
        else:
            matches = predictions == query_labels
            accuracy = matches.float().mean().item()

    return predictions, accuracy

class TrainDataset(Dataset):
    """Image classification dataset laid out as ``root_dir/<class_name>/<image>``.

    Attributes set by ``__init__``:
        classes: Sorted class directory names.
        class_to_idx: Mapping from class name to its position in ``classes``.
        images: Image file paths, grouped by class.
        labels: Integer label of each entry in ``images``.

    Args:
        root_dir: Directory holding one sub-directory per class.
        transform: Optional callable applied to every loaded PIL image.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Only directories are classes; stray files in root_dir (e.g.
        # .DS_Store) are skipped rather than crashing the listing below.
        # Sorting keeps the class -> index mapping stable between runs.
        self.classes = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.images = []
        self.labels = []

        # Collect every image path together with its class index. File names
        # are sorted so sample indices do not depend on the filesystem.
        for class_name in self.classes:
            class_dir = os.path.join(root_dir, class_name)
            label = self.class_to_idx[class_name]
            for img_name in sorted(os.listdir(class_dir)):
                self.images.append(os.path.join(class_dir, img_name))
                self.labels.append(label)

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``."""
        img_path = self.images[idx]
        label = self.labels[idx]

        # The context manager closes the underlying file as soon as the RGB
        # copy exists instead of leaving it to the garbage collector.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, label

class EpisodeSampler:
    """Draw N-way few-shot episodes from a labelled dataset.

    Args:
        dataset: Object exposing ``labels`` (one label per sample) and
            ``dataset[idx] -> (image, label)``.
        n_way: Number of classes per episode.
        n_support: Support samples per class.
        n_query: Query samples per class.
    """

    def __init__(self, dataset, n_way, n_support, n_query):
        self.dataset = dataset
        self.n_way = n_way
        self.n_support = n_support
        self.n_query = n_query

        # Group sample indices by their label.
        self.label_to_indices = {}
        for idx, label in enumerate(dataset.labels):
            self.label_to_indices.setdefault(label, []).append(idx)

    def _eligible_classes(self):
        """Return the labels that have enough samples for one episode."""
        needed = self.n_support + self.n_query
        return [
            label
            for label, indices in self.label_to_indices.items()
            if len(indices) >= needed
        ]

    def _collect(self, indices, episode_label, images, labels):
        """Append the images at ``indices`` and their episode-local label."""
        for idx in indices:
            image, _ = self.dataset[idx]
            images.append(image)
            labels.append(episode_label)

    def sample_episode(self):
        """Sample one episode.

        Classes are drawn without replacement from those holding at least
        ``n_support + n_query`` samples, so a small class can no longer make
        the per-class draw fail at random. Labels in the result are
        episode-local indices ``0 .. n_way - 1``.

        Returns:
            ``(support_images, support_labels, query_images, query_labels)``
            as tensors.

        Raises:
            ValueError: If fewer than ``n_way`` classes have enough samples.
        """
        per_class = self.n_support + self.n_query
        eligible = self._eligible_classes()
        if len(eligible) < self.n_way:
            raise ValueError(
                f'Cannot sample a {self.n_way}-way episode: only '
                f'{len(eligible)} classes have at least {per_class} samples'
            )

        selected_classes = np.random.choice(eligible, self.n_way, replace=False)

        support_images = []
        support_labels = []
        query_images = []
        query_labels = []

        for class_idx, class_label in enumerate(selected_classes):
            # Draw disjoint support/query samples for this class in one go.
            selected_indices = np.random.choice(
                self.label_to_indices[class_label],
                per_class,
                replace=False,
            )
            self._collect(selected_indices[:self.n_support], class_idx,
                          support_images, support_labels)
            self._collect(selected_indices[self.n_support:], class_idx,
                          query_images, query_labels)

        return (
            torch.stack(support_images),
            torch.tensor(support_labels),
            torch.stack(query_images),
            torch.tensor(query_labels),
        )

def main():
    """Fine-tune a prototypical network episodically and evaluate it per epoch.

    Reads training images from ``dataset/train_set`` and evaluation tasks from
    ``dataset/test_set``. After every validated epoch the predictions are
    written to ``predictions_epoch_<n>.csv``; whenever the mean task accuracy
    improves, the weights are saved to ``best_model.pth``.

    Raises:
        FileNotFoundError: If the test set contains no ``task_<idx>`` entry.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Preprocessing: fixed-size resize, tensor conversion, channel-wise
    # normalisation.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])

    print("Loading training dataset...")
    train_dataset = TrainDataset('dataset/train_set', transform=transform)
    test_dataset = TestDataset('dataset/test_set', transform=transform)

    # Discover the evaluation tasks from the test set instead of assuming a
    # fixed count. Only names that round-trip through ``task_<int>`` are kept,
    # because load_task() rebuilds the directory name from the integer.
    task_prefix = 'task_'
    task_indices = []
    for task_name in test_dataset.tasks:
        suffix = task_name[len(task_prefix):]
        if suffix.isdecimal() and f'{task_prefix}{int(suffix)}' == task_name:
            task_indices.append(int(suffix))
    task_indices.sort()
    if not task_indices:
        # Fail before training rather than after the first epoch.
        raise FileNotFoundError("No task_<idx> entries found in dataset/test_set")

    # Episode configuration.
    n_way = 10      # classes per episode
    n_support = 5   # support samples per class
    n_query = 2     # query samples per class
    episode_sampler = EpisodeSampler(train_dataset, n_way, n_support, n_query)

    print("Initializing model...")
    model = PrototypicalNet(pretrained=True).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training schedule.
    num_epochs = 2
    episodes_per_epoch = 50
    validate_every = 1  # run validation every N epochs
    best_train_acc = 0
    best_test_acc = 0

    print("Starting training...")
    for epoch in range(num_epochs):
        total_loss = 0
        total_acc = 0

        # ---- Training phase (train_episode puts the model in train mode) ----
        pbar = tqdm(range(episodes_per_epoch), desc=f'Epoch {epoch+1}/{num_epochs}')
        for _ in pbar:
            support_images, support_labels, query_images, query_labels = episode_sampler.sample_episode()

            support_images = support_images.to(device)
            support_labels = support_labels.to(device)
            query_images = query_images.to(device)
            query_labels = query_labels.to(device)

            loss, acc = train_episode(model, support_images, support_labels,
                                      query_images, query_labels, optimizer)

            total_loss += loss
            total_acc += acc

            pbar.set_postfix({
                'loss': f'{loss:.4f}',
                'acc': f'{acc:.4f}'
            })

        avg_train_loss = total_loss / episodes_per_epoch
        avg_train_acc = total_acc / episodes_per_epoch
        best_train_acc = max(best_train_acc, avg_train_acc)

        print(f'\nEpoch {epoch+1} Training Stats:')
        print(f'Average Loss: {avg_train_loss:.4f}')
        print(f'Average Accuracy: {avg_train_acc:.4f}')
        print(f'Best Training Accuracy: {best_train_acc:.4f}')

        # ---- Validation phase ----
        if (epoch + 1) % validate_every == 0:
            print("\nRunning validation...")
            model.eval()
            test_accuracies = []
            results = []

            for task_idx in tqdm(task_indices, desc="Evaluating tasks"):
                support_images, support_labels, query_images, query_paths, support_classes, query_labels = \
                    test_dataset.load_task(task_idx)

                support_images = support_images.to(device)
                support_labels = support_labels.to(device)
                query_images = query_images.to(device)
                query_labels = query_labels.to(device)

                predictions, accuracy = evaluate_task(model, support_images, support_labels,
                                                      query_images, query_labels)
                test_accuracies.append(accuracy)

                # Map predicted indices back to class names for the report.
                for img_path, pred in zip(query_paths, predictions.cpu().numpy()):
                    results.append({
                        'task': f'task_{task_idx}',
                        'image': os.path.basename(img_path),
                        'predicted_class': support_classes[pred]
                    })

            avg_test_accuracy = np.mean(test_accuracies)
            print(f'\nAverage Test Accuracy: {avg_test_accuracy:.4f}')

            # Checkpoint only when the mean task accuracy improves.
            if avg_test_accuracy > best_test_acc:
                best_test_acc = avg_test_accuracy
                print(f'New Best Test Accuracy: {best_test_acc:.4f}')
                torch.save(model.state_dict(), 'best_model.pth')

            results_df = pd.DataFrame(results)
            results_df.to_csv(f'predictions_epoch_{epoch+1}.csv', index=False)

        print("\n" + "="*50)

if __name__ == '__main__':
    main()