In [1]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.datasets import fetch_olivetti_faces
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split

# Change 1: the model definition is switched to an LSTM.
class FaceLSTM(nn.Module):
    """LSTM classifier: LSTM -> last time step -> BatchNorm1d -> Linear.

    Args:
        input_size: Number of features per time step (default 4096).
        hidden_size: LSTM hidden width; also the BatchNorm/Linear input width.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
        dropout: Dropout probability applied between stacked LSTM layers.

    The defaults reproduce the original fixed architecture, and the submodule
    names (``lstm``, ``bn``, ``fc``) are unchanged so saved state dicts load.
    """

    def __init__(self, input_size=4096, hidden_size=256, num_layers=2,
                 num_classes=40, dropout=0.3):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between layers; with a single
            # layer a non-zero value has no effect and triggers a warning.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.bn = nn.BatchNorm1d(hidden_size)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch-first input.

        Args:
            x: Tensor indexed as (batch, seq_len, input_size), since the
                LSTM is built with ``batch_first=True``.

        Returns:
            Tensor of shape (batch, num_classes) with unnormalised logits.
        """
        # nn.LSTM returns (output, (h_n, c_n)); only the output sequence is used.
        outputs, _ = self.lstm(x)
        # Keep the top-layer output at the final time step.
        last_out = outputs[:, -1, :]
        return self.fc(self.bn(last_out))

# Data preprocessing is kept the same as before.
def prepare_data(test_size=0.2):
    """Load the Olivetti faces and return stratified train/test tensors.

    Args:
        test_size: Passed straight to ``train_test_split`` as the held-out
            share of the data.

    Returns:
        ``((X_train, y_train), (X_test, y_test))`` where the inputs are
        float32 tensors reshaped to (N, 1, 4096) and the labels are int64
        tensors.
    """
    dataset = fetch_olivetti_faces(shuffle=True, random_state=42)

    # Shift/scale the pixels with (x - 0.5) / 0.5, then flatten every image
    # into a single time step of 4096 features.
    # NOTE(review): presumably maps [0, 1] pixels to [-1, 1] -- confirm range.
    scaled = (dataset.images - 0.5) / 0.5
    flattened = scaled.reshape(-1, 1, 4096)

    features = torch.tensor(flattened, dtype=torch.float32)
    targets = torch.tensor(dataset.target, dtype=torch.long)

    # Stratify on the labels; the fixed seed keeps the split reproducible.
    train_x, test_x, train_y, test_y = train_test_split(
        features,
        targets,
        test_size=test_size,
        stratify=targets,
        random_state=42,
    )
    return (train_x, train_y), (test_x, test_y)

def _run_train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        # Clip the global gradient norm to 1.0 before the parameter update.
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


def _count_correct(model, loader, device):
    """Return how many samples in ``loader`` the model classifies correctly."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for inputs, labels in loader:
            # Labels stay on the CPU, so bring the predictions back to compare.
            preds = model(inputs.to(device)).argmax(dim=1).cpu()
            correct += (preds == labels).sum().item()
    return correct


def train_model(epochs=100, batch_size=32, checkpoint_path="best_lstm.pth"):
    """Train ``FaceLSTM`` on the data from ``prepare_data`` and log progress.

    Each epoch trains on the training split, measures accuracy on the test
    split, steps the LR scheduler on that accuracy, writes the loss and
    accuracy to TensorBoard, prints a summary line, and saves the model's
    state dict whenever the accuracy improves.

    Args:
        epochs: Number of training epochs.
        batch_size: Batch size for both data loaders.
        checkpoint_path: Where the best state dict is written.

    NOTE(review): the test split drives both the scheduler and checkpoint
    selection, so the reported best accuracy is not an unbiased estimate.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    writer = SummaryWriter()

    # try/finally guarantees the TensorBoard writer is closed (and its events
    # flushed) even when training raises part-way through.
    try:
        # Load the data.
        (X_train, y_train), (X_test, y_test) = prepare_data()
        train_loader = DataLoader(
            TensorDataset(X_train, y_train),
            batch_size=batch_size,
            shuffle=True,
            # BatchNorm1d cannot normalise a one-sample batch in train mode,
            # so drop the trailing batch only when it would hold one sample.
            drop_last=(len(X_train) % batch_size == 1),
        )
        test_loader = DataLoader(TensorDataset(X_test, y_test),
                                 batch_size=batch_size)

        # Change 2: instantiate the LSTM model.
        model = FaceLSTM().to(device)

        # Training configuration is kept the same as before.
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=0.001,
                                      weight_decay=1e-4)
        # 'max' mode: the monitored quantity (accuracy) should increase.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'max', patience=3)

        best_acc = 0.0
        for epoch in range(epochs):
            avg_loss = _run_train_epoch(model, train_loader, criterion,
                                        optimizer, device)
            correct = _count_correct(model, test_loader, device)

            acc = 100 * correct / len(y_test)
            scheduler.step(acc)

            if acc > best_acc:
                best_acc = acc
                torch.save(model.state_dict(), checkpoint_path)

            writer.add_scalar('Loss/train', avg_loss, epoch)
            writer.add_scalar('Accuracy/test', acc, epoch)
            print(f"Epoch {epoch+1:03d} | Loss: {avg_loss:.4f} | Acc: {acc:.2f}%")
    finally:
        writer.close()

    print(f"Best Accuracy: {best_acc:.2f}%")

# Start training only when this file is executed directly, not when imported.
if __name__ == "__main__":
    train_model()

Epoch 001 | Loss: 2.8285 | Acc: 63.75%
Epoch 002 | Loss: 1.3185 | Acc: 88.75%
Epoch 003 | Loss: 0.6178 | Acc: 95.00%
Epoch 004 | Loss: 0.2528 | Acc: 96.25%
Epoch 005 | Loss: 0.1127 | Acc: 97.50%
Epoch 006 | Loss: 0.0543 | Acc: 97.50%
Epoch 007 | Loss: 0.0307 | Acc: 97.50%
Epoch 008 | Loss: 0.0200 | Acc: 98.75%
Epoch 009 | Loss: 0.0133 | Acc: 96.25%
Epoch 010 | Loss: 0.0112 | Acc: 97.50%
Epoch 011 | Loss: 0.0080 | Acc: 97.50%
Epoch 012 | Loss: 0.0066 | Acc: 97.50%
Epoch 013 | Loss: 0.0061 | Acc: 97.50%
Epoch 014 | Loss: 0.0053 | Acc: 97.50%
Epoch 015 | Loss: 0.0054 | Acc: 97.50%
Epoch 016 | Loss: 0.0055 | Acc: 97.50%
Epoch 017 | Loss: 0.0050 | Acc: 97.50%
Epoch 018 | Loss: 0.0047 | Acc: 97.50%
Epoch 019 | Loss: 0.0048 | Acc: 97.50%
Epoch 020 | Loss: 0.0048 | Acc: 97.50%
Epoch 021 | Loss: 0.0052 | Acc: 97.50%
Epoch 022 | Loss: 0.0045 | Acc: 97.50%
Epoch 023 | Loss: 0.0048 | Acc: 97.50%
Epoch 024 | Loss: 0.0045 | Acc: 97.50%
Epoch 025 | Loss: 0.0049 | Acc: 97.50%
Epoch 026 | Loss: 0.0049 