In [3]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.datasets import fetch_olivetti_faces
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split

class FaceGRU(nn.Module):
    """Stacked-GRU classifier producing one logit vector per input sequence.

    The GRU output at the final time step is batch-normalised and projected
    onto ``num_classes`` logits. All defaults reproduce the original
    hard-coded architecture, so ``FaceGRU()`` is unchanged and checkpoints
    saved from it still load (sub-module names are ``gru``, ``bn``, ``fc``).

    Args:
        input_size: Number of features per time step (default 4096).
        hidden_size: Width of each GRU layer; a GRU's parameter count sits
            between a plain RNN and an LSTM of the same width.
        num_layers: Number of stacked GRU layers; more layers add
            expressive power.
        dropout: Dropout probability applied between GRU layers. Ignored
            when ``num_layers`` is 1, because there is no inter-layer
            connection to drop and PyTorch would only emit a warning.
        num_classes: Size of the output logit vector (default 40).
    """

    def __init__(self, input_size=4096, hidden_size=384, num_layers=3,
                 dropout=0.25, num_classes=40):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # nn.GRU applies dropout only between layers.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.bn = nn.BatchNorm1d(hidden_size)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits for ``x``.

        Args:
            x: Tensor laid out as (batch, seq_len, input_size), since the
                GRU is built with ``batch_first=True``.

        Returns:
            Tensor of shape (batch, num_classes).
        """
        outputs, _ = self.gru(x)
        # Only the top layer's output at the last time step feeds the head.
        last_out = outputs[:, -1, :]
        normalized = self.bn(last_out)
        return self.fc(normalized)

def prepare_data(test_size=0.2, random_state=42):
    """Load the Olivetti faces and return stratified train/test tensors.

    Each 64x64 image is flattened into a single time step of 4096 features,
    giving inputs of shape (n_samples, 1, 4096).

    Args:
        test_size: Fraction (or absolute count) of samples held out for
            testing; forwarded to ``train_test_split``.
        random_state: Seed used both for the dataset shuffle and for the
            train/test split. The default matches the previously
            hard-coded value.

    Returns:
        ``((X_train, y_train), (X_test, y_test))`` where the ``X`` tensors
        are ``float32`` and the ``y`` tensors are ``long``.
    """
    faces = fetch_olivetti_faces(shuffle=True, random_state=random_state)

    # scikit-learn documents the pixels as floats in [0, 1]; this rescales
    # them to be centred on zero.
    images = ((faces.images - 0.5) / 0.5).reshape(-1, 1, 4096)
    labels = faces.target

    # Split while the data is still NumPy: scikit-learn documents arrays
    # (not torch tensors) as supported inputs, and the resulting partition
    # is the same for a given seed.
    img_train, img_test, lab_train, lab_test = train_test_split(
        images,
        labels,
        test_size=test_size,
        stratify=labels,
        random_state=random_state,
    )

    X_train = torch.tensor(img_train, dtype=torch.float32)
    X_test = torch.tensor(img_test, dtype=torch.float32)
    y_train = torch.tensor(lab_train, dtype=torch.long)
    y_test = torch.tensor(lab_test, dtype=torch.long)
    return (X_train, y_train), (X_test, y_test)

def _evaluate(model, loader, device):
    """Return how many samples in ``loader`` the model classifies correctly."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for inputs, labels in loader:
            # Labels stay on the CPU, so bring the logits back before comparing.
            logits = model(inputs.to(device)).cpu()
            correct += (logits.argmax(dim=1) == labels).sum().item()
    return correct


def train_model(epochs=80, batch_size=40, checkpoint_path="best_gru.pth"):
    """Train ``FaceGRU`` on the Olivetti faces and log progress to TensorBoard.

    After every epoch the model is scored on the held-out split; whenever
    the accuracy improves, the weights are written to ``checkpoint_path``.
    Loss and accuracy are logged as ``Loss/train`` and ``Accuracy/test``
    and a summary line is printed per epoch.

    Args:
        epochs: Number of passes over the training set.
        batch_size: Mini-batch size for both the train and test loaders.
        checkpoint_path: File the best ``state_dict`` is saved to.

    Returns:
        None. The best accuracy is printed when training finishes.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    (X_train, y_train), (X_test, y_test) = prepare_data()
    train_loader = DataLoader(TensorDataset(X_train, y_train),
                              batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(TensorDataset(X_test, y_test),
                             batch_size=batch_size)

    model = FaceGRU().to(device)

    # Optimiser setup. The scheduler below is stepped once per batch and
    # drives the learning rate between base_lr and max_lr, so the lr given
    # to AdamW only matters until the scheduler is attached.
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.002, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CyclicLR(
        optimizer,
        base_lr=0.001,
        max_lr=0.005,
        step_size_up=15,
        cycle_momentum=False,  # AdamW has no momentum parameter to cycle
    )
    criterion = nn.CrossEntropyLoss()

    best_acc = 0.0
    writer = SummaryWriter()
    try:
        for epoch in range(epochs):
            model.train()
            train_loss = 0.0
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                # Cap the global gradient norm at 0.5 before the update.
                nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()
                scheduler.step()  # learning rate is updated every batch
                train_loss += loss.item()

            correct = _evaluate(model, test_loader, device)

            avg_loss = train_loss / len(train_loader)
            acc = 100 * correct / len(y_test)

            # NOTE(review): the checkpoint is selected on the test split, so
            # the reported best accuracy is optimistic; consider a separate
            # validation split.
            if acc > best_acc:
                best_acc = acc
                torch.save(model.state_dict(), checkpoint_path)

            writer.add_scalar('Loss/train', avg_loss, epoch)
            writer.add_scalar('Accuracy/test', acc, epoch)
            print(f"Epoch {epoch+1:02d} | Loss: {avg_loss:.4f} | Acc: {acc:.2f}% | LR: {scheduler.get_last_lr()[0]:.5f}")
    finally:
        # Close the event file even when training is interrupted.
        writer.close()

    print(f"Best Accuracy: {best_acc:.2f}%")

if __name__ == "__main__":
    # Seed the NumPy and PyTorch global RNGs before any model or loader
    # is created, then run the training loop.
    np.random.seed(42)
    torch.manual_seed(42)
    train_model()

Epoch 01 | Loss: 2.7679 | Acc: 61.25% | LR: 0.00313
Epoch 02 | Loss: 1.0619 | Acc: 82.50% | LR: 0.00473
Epoch 03 | Loss: 0.4644 | Acc: 83.75% | LR: 0.00260
Epoch 04 | Loss: 0.1903 | Acc: 96.25% | LR: 0.00153
Epoch 05 | Loss: 0.1068 | Acc: 88.75% | LR: 0.00367
Epoch 06 | Loss: 0.3216 | Acc: 88.75% | LR: 0.00420
Epoch 07 | Loss: 0.3125 | Acc: 95.00% | LR: 0.00207
Epoch 08 | Loss: 0.2144 | Acc: 92.50% | LR: 0.00207
Epoch 09 | Loss: 0.1125 | Acc: 91.25% | LR: 0.00420
Epoch 10 | Loss: 0.2892 | Acc: 90.00% | LR: 0.00367
Epoch 11 | Loss: 0.2627 | Acc: 91.25% | LR: 0.00153
Epoch 12 | Loss: 0.1218 | Acc: 95.00% | LR: 0.00260
Epoch 13 | Loss: 0.1390 | Acc: 88.75% | LR: 0.00473
Epoch 14 | Loss: 0.2559 | Acc: 87.50% | LR: 0.00313
Epoch 15 | Loss: 0.1867 | Acc: 92.50% | LR: 0.00100
Epoch 16 | Loss: 0.0962 | Acc: 87.50% | LR: 0.00313
Epoch 17 | Loss: 0.1433 | Acc: 90.00% | LR: 0.00473
Epoch 18 | Loss: 0.2578 | Acc: 88.75% | LR: 0.00260
Epoch 19 | Loss: 0.1984 | Acc: 91.25% | LR: 0.00153
Epoch 20 | L