In [None]:
# =====================================
# Step 0: 挂载 Google Drive
# =====================================
from google.colab import drive
drive.mount('/content/drive')

import os

# =====================================
# Step 1: 导入所需库
# =====================================
!pip install torch torchvision --quiet

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import copy

from torch.utils.data import DataLoader, Subset
from torchvision import transforms, datasets
from torchvision.models import vit_b_16, ViT_B_16_Weights

from sklearn.model_selection import train_test_split
from sklearn.metrics import (classification_report, confusion_matrix,
                             precision_recall_fscore_support, accuracy_score)

# =====================================
# Step 2: 设置数据集路径 & 训练参数
# =====================================
# 注意：这是你在 Google Drive 中的数据路径，请根据实际情况修改
data_dir = "/content/drive/MyDrive/9517 gp/Aerial_Landscapes"

# 是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# 超参数
batch_size = 32
num_epochs = 100
learning_rate = 1e-3
test_split = 0.2  # 80% 训练 / 20% 测试
random_seed = 42  # 随机种子

# =====================================
# Step 3: 加载数据集 & 图像预处理
# =====================================
# 训练数据增广 & 测试数据变换
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# 利用 ImageFolder 加载图像数据（文件夹名即类别名）
full_dataset = datasets.ImageFolder(root=data_dir)
print("Total samples in dataset:", len(full_dataset))
print("Classes:", full_dataset.classes)

# 按照 80-20 的比例分割：训练集 / 测试集（分层抽样保证各类比例）
indices = np.arange(len(full_dataset))
targets = np.array(full_dataset.targets)
train_idx, test_idx = train_test_split(
    indices, test_size=test_split, stratify=targets, random_state=random_seed
)
print("Train samples:", len(train_idx), "Test samples:", len(test_idx))

# 根据不同的 transform，制作训练集和测试集
train_dataset = copy.deepcopy(full_dataset)
train_dataset.transform = train_transform

test_dataset = copy.deepcopy(full_dataset)
test_dataset.transform = test_transform

train_subset = Subset(train_dataset, train_idx)
test_subset = Subset(test_dataset, test_idx)

# DataLoaders
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False, num_workers=2)

# =====================================
# Step 4: 构建 ViT 模型
# =====================================
# 使用 torchvision 提供的 vit_b_16 预训练权重
weights = ViT_B_16_Weights.IMAGENET1K_V1
model = vit_b_16(weights=weights)

# 获取原分类层(Sequential)中第2个模块(Linear)的 in_features
num_features = model.heads[0].in_features
# 替换为新的线性层，让输出类别数与数据集一致
model.heads[0] = nn.Linear(num_features, len(full_dataset.classes))

model = model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# =====================================
# Step 5: 训练与验证循环
# =====================================
train_losses, train_accuracies = [], []
test_losses, test_accuracies = [], []

for epoch in range(num_epochs):
    # ---------- 训练 ----------
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_acc)

    # ---------- 测试 ----------
    model.eval()
    test_loss, test_correct, test_total = 0.0, 0, 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()

    epoch_test_loss = test_loss / test_total
    epoch_test_acc = test_correct / test_total
    test_losses.append(epoch_test_loss)
    test_accuracies.append(epoch_test_acc)

    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f} | "
          f"Test Loss: {epoch_test_loss:.4f}, Test Acc: {epoch_test_acc:.4f}")

# =====================================
# Step 6: 最终评估 & 指标输出
# =====================================
model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.numpy())

# Overall Accuracy
accuracy_val = accuracy_score(all_labels, all_preds)
# Macro Precision, Recall, F1
precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='macro')

print("\n=== Test Evaluation Metrics ===")
print(f"Accuracy: {accuracy_val:.4f}")
print(f"Precision (Macro): {precision:.4f}")
print(f"Recall (Macro): {recall:.4f}")
print(f"F1-Score (Macro): {f1:.4f}")

print("\n=== Classification Report ===")
print(classification_report(all_labels, all_preds, target_names=full_dataset.classes))

# =====================================
# Step 7: 混淆矩阵可视化
# =====================================
cm = confusion_matrix(all_labels, all_preds)

plt.figure(figsize=(10,8))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.colorbar()

tick_marks = np.arange(len(full_dataset.classes))
plt.xticks(tick_marks, full_dataset.classes, rotation=45)
plt.yticks(tick_marks, full_dataset.classes)

plt.xlabel("Predicted Label")
plt.ylabel("True Label")

# 在混淆矩阵上打印数字
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(j, i, format(cm[i, j], 'd'),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
plt.show()

# =====================================
# Step 8: 绘制 Training History (Loss & Accuracy)
# =====================================
epochs_range = range(1, num_epochs+1)

plt.figure(figsize=(12,5))

# Loss 曲线
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_losses, label="Train Loss")
plt.plot(epochs_range, test_losses, label="Test Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training and Test Loss")
plt.legend()

# Accuracy 曲线
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_accuracies, label="Train Accuracy")
plt.plot(epochs_range, test_accuracies, label="Test Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Training and Test Accuracy")
plt.legend()

plt.tight_layout()
plt.show()


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
Total samples in dataset: 12010
Classes: ['Agriculture', 'Airport', 'Beach', 'City', 'Desert', 'Forest', 'Grassland', 'Highway', 'Lake', 'Mountain', 'Parking', 'Port', 'Railway', 'Residential', 'River']
Train samples: 9608 Test samples: 2402
Epoch [1/100] Train Loss: 2.1251, Train Acc: 0.2833 | Test Loss: 1.6709, Test Acc: 0.4596
Epoch [2/100] Train Loss: 1.5476, Train Acc: 0.4830 | Test Loss: 1.4551, Test Acc: 0.5004
Epoch [3/100] Train Loss: 1.3267, Train Acc: 0.5576 | Test Loss: 1.1790, Test Acc: 0.5941
Epoch [4/100] Train Loss: 1.1967, Train Acc: 0.5969 | Test Loss: 1.1622, Test Acc: 0.6091
Epoch [5/100] Train Loss: 1.0980, Train Acc: 0.6336 | Test Loss: 1.0063, Test Acc: 0.6661
Epoch [6/100] Train Loss: 0.9998, Train Acc: 0.6730 | Test Loss: 1.0417, Test Acc: 0.6532
Epoch [7/100] Train Loss: 0.9310, Train Acc: 0.6877 | Test Loss: 1.0619, Test Acc: 0.6399
Epoch [8/100] Train Loss: 0.8832, Train Acc: 0.7013 | Test Loss: 0.9449, Test Acc: 0.6840
Epoch [9/100] Train Loss: 0.8568, Train Acc: 0.7147 | Test Loss: 0.8876, Test Acc: 0.7090
Epoch [10/100] Train Loss: 0.7923, Train Acc: 0.7366 | Test Loss: 0.8837, Test Acc: 0.7057
Epoch [11/100] Train Loss: 0.7874, Train Acc: 0.7378 | Test Loss: 0.8181, Test Acc: 0.7182
Epoch [12/100] Train Loss: 0.7448, Train Acc: 0.7532 | Test Loss: 0.7715, Test Acc: 0.7365
Epoch [13/100] Train Loss: 0.7342, Train Acc: 0.7542 | Test Loss: 0.7414, Test Acc: 0.7490
Epoch [14/100] Train Loss: 0.7153, Train Acc: 0.7634 | Test Loss: 0.7463, Test Acc: 0.7485
Epoch [15/100] Train Loss: 0.6750, Train Acc: 0.7800 | Test Loss: 0.7991, Test Acc: 0.7206
Epoch [16/100] Train Loss: 0.6603, Train Acc: 0.7789 | Test Loss: 0.7382, Test Acc: 0.7531
Epoch [17/100] Train Loss: 0.6403, Train Acc: 0.7847 | Test Loss: 0.7438, Test Acc: 0.7540
Epoch [18/100] Train Loss: 0.6171, Train Acc: 0.7929 | Test Loss: 0.6978, Test Acc: 0.7648
Epoch [19/100] Train Loss: 0.6144, Train Acc: 0.7927 | Test Loss: 0.7331, Test Acc: 0.7565
Epoch [20/100] Train Loss: 0.5969, Train Acc: 0.7967 | Test Loss: 0.6922, Test Acc: 0.7719
Epoch [21/100] Train Loss: 0.5794, Train Acc: 0.8081 | Test Loss: 0.6544, Test Acc: 0.7843
Epoch [22/100] Train Loss: 0.5457, Train Acc: 0.8179 | Test Loss: 0.6520, Test Acc: 0.7806
Epoch [23/100] Train Loss: 0.5392, Train Acc: 0.8205 | Test Loss: 0.6747, Test Acc: 0.7731
Epoch [24/100] Train Loss: 0.5325, Train Acc: 0.8237 | Test Loss: 0.6597, Test Acc: 0.7860
Epoch [25/100] Train Loss: 0.5322, Train Acc: 0.8208 | Test Loss: 0.6446, Test Acc: 0.7885
Epoch [26/100] Train Loss: 0.5070, Train Acc: 0.8317 | Test Loss: 0.6729, Test Acc: 0.7710
Epoch [27/100] Train Loss: 0.4867, Train Acc: 0.8386 | Test Loss: 0.6350, Test Acc: 0.7839
Epoch [28/100] Train Loss: 0.4903, Train Acc: 0.8366 | Test Loss: 0.6470, Test Acc: 0.7856
Epoch [29/100] Train Loss: 0.4816, Train Acc: 0.8382 | Test Loss: 0.6966, Test Acc: 0.7773
Epoch [30/100] Train Loss: 0.4745, Train Acc: 0.8434 | Test Loss: 0.6565, Test Acc: 0.7939
Epoch [31/100] Train Loss: 0.4560, Train Acc: 0.8479 | Test Loss: 0.6419, Test Acc: 0.7956
Epoch [32/100] Train Loss: 0.4414, Train Acc: 0.8528 | Test Loss: 0.5516, Test Acc: 0.8231
Epoch [33/100] Train Loss: 0.4303, Train Acc: 0.8568 | Test Loss: 0.6037, Test Acc: 0.8093
Epoch [34/100] Train Loss: 0.4197, Train Acc: 0.8593 | Test Loss: 0.6357, Test Acc: 0.7931
Epoch [35/100] Train Loss: 0.4173, Train Acc: 0.8600 | Test Loss: 0.6019, Test Acc: 0.8172
Epoch [36/100] Train Loss: 0.4028, Train Acc: 0.8680 | Test Loss: 0.6298, Test Acc: 0.8081
Epoch [37/100] Train Loss: 0.4123, Train Acc: 0.8621 | Test Loss: 0.6374, Test Acc: 0.8118
Epoch [38/100] Train Loss: 0.3908, Train Acc: 0.8660 | Test Loss: 0.6522, Test Acc: 0.8014
Epoch [39/100] Train Loss: 0.3852, Train Acc: 0.8728 | Test Loss: 0.6460, Test Acc: 0.8072
Epoch [40/100] Train Loss: 0.3727, Train Acc: 0.8770 | Test Loss: 0.5824, Test Acc: 0.8272
Epoch [41/100] Train Loss: 0.3853, Train Acc: 0.8743 | Test Loss: 0.6150, Test Acc: 0.8152
Epoch [42/100] Train Loss: 0.3497, Train Acc: 0.8858 | Test Loss: 0.5436, Test Acc: 0.8306
Epoch [43/100] Train Loss: 0.3536, Train Acc: 0.8829 | Test Loss: 0.6525, Test Acc: 0.7948
Epoch [44/100] Train Loss: 0.3503, Train Acc: 0.8845 | Test Loss: 0.5417, Test Acc: 0.8193
Epoch [45/100] Train Loss: 0.3478, Train Acc: 0.8818 | Test Loss: 0.6313, Test Acc: 0.8072
Epoch [46/100] Train Loss: 0.3343, Train Acc: 0.8910 | Test Loss: 0.5757, Test Acc: 0.8168
Epoch [47/100] Train Loss: 0.3256, Train Acc: 0.8919 | Test Loss: 0.5935, Test Acc: 0.8181
Epoch [48/100] Train Loss: 0.3368, Train Acc: 0.8887 | Test Loss: 0.5629, Test Acc: 0.8206
Epoch [49/100] Train Loss: 0.3201, Train Acc: 0.8900 | Test Loss: 0.5812, Test Acc: 0.8297
Epoch [50/100] Train Loss: 0.3189, Train Acc: 0.8965 | Test Loss: 0.5643, Test Acc: 0.8251
Epoch [51/100] Train Loss: 0.3138, Train Acc: 0.8888 | Test Loss: 0.5848, Test Acc: 0.8318
Epoch [52/100] Train Loss: 0.3054, Train Acc: 0.9014 | Test Loss: 0.6020, Test Acc: 0.8127
Epoch [53/100] Train Loss: 0.3119, Train Acc: 0.8970 | Test Loss: 0.5345, Test Acc: 0.8318
Epoch [54/100] Train Loss: 0.2921, Train Acc: 0.9024 | Test Loss: 0.6199, Test Acc: 0.8181
Epoch [55/100] Train Loss: 0.2989, Train Acc: 0.9017 | Test Loss: 0.5729, Test Acc: 0.8268
Epoch [56/100] Train Loss: 0.2802, Train Acc: 0.9084 | Test Loss: 0.5672, Test Acc: 0.8372
Epoch [57/100] Train Loss: 0.2884, Train Acc: 0.9072 | Test Loss: 0.5526, Test Acc: 0.8360
Epoch [58/100] Train Loss: 0.2697, Train Acc: 0.9128 | Test Loss: 0.5388, Test Acc: 0.8356
Epoch [59/100] Train Loss: 0.2771, Train Acc: 0.9099 | Test Loss: 0.6238, Test Acc: 0.8152
Epoch [60/100] Train Loss: 0.2733, Train Acc: 0.9109 | Test Loss: 0.5352, Test Acc: 0.8293
Epoch [61/100] Train Loss: 0.2637, Train Acc: 0.9112 | Test Loss: 0.5500, Test Acc: 0.8439
Epoch [62/100] Train Loss: 0.2694, Train Acc: 0.9128 | Test Loss: 0.5686, Test Acc: 0.8281
Epoch [63/100] Train Loss: 0.2699, Train Acc: 0.9108 | Test Loss: 0.6027, Test Acc: 0.8268
Epoch [64/100] Train Loss: 0.2560, Train Acc: 0.9172 | Test Loss: 0.5321, Test Acc: 0.8405
Epoch [65/100] Train Loss: 0.2511, Train Acc: 0.9183 | Test Loss: 0.5750, Test Acc: 0.8264
Epoch [66/100] Train Loss: 0.2709, Train Acc: 0.9112 | Test Loss: 0.5224, Test Acc: 0.8426
Epoch [67/100] Train Loss: 0.2370, Train Acc: 0.9224 | Test Loss: 0.5736, Test Acc: 0.8393
Epoch [68/100] Train Loss: 0.2593, Train Acc: 0.9136 | Test Loss: 0.5582, Test Acc: 0.8376
Epoch [69/100] Train Loss: 0.2434, Train Acc: 0.9234 | Test Loss: 0.5518, Test Acc: 0.8410
Epoch [70/100] Train Loss: 0.2429, Train Acc: 0.9223 | Test Loss: 0.5715, Test Acc: 0.8285
Epoch [71/100] Train Loss: 0.2555, Train Acc: 0.9178 | Test Loss: 0.5613, Test Acc: 0.8318
Epoch [72/100] Train Loss: 0.2285, Train Acc: 0.9229 | Test Loss: 0.5907, Test Acc: 0.8397
Epoch [73/100] Train Loss: 0.2207, Train Acc: 0.9306 | Test Loss: 0.5193, Test Acc: 0.8485
Epoch [74/100] Train Loss: 0.2263, Train Acc: 0.9271 | Test Loss: 0.5814, Test Acc: 0.8260
Epoch [75/100] Train Loss: 0.2413, Train Acc: 0.9213 | Test Loss: 0.6278, Test Acc: 0.8347
Epoch [76/100] Train Loss: 0.2304, Train Acc: 0.9236 | Test Loss: 0.5436, Test Acc: 0.8464
Epoch [77/100] Train Loss: 0.2294, Train Acc: 0.9272 | Test Loss: 0.5299, Test Acc: 0.8397
Epoch [78/100] Train Loss: 0.2202, Train Acc: 0.9253 | Test Loss: 0.5752, Test Acc: 0.8405
Epoch [79/100] Train Loss: 0.2267, Train Acc: 0.9259 | Test Loss: 0.5515, Test Acc: 0.8397
Epoch [80/100] Train Loss: 0.2222, Train Acc: 0.9287 | Test Loss: 0.5996, Test Acc: 0.8306
Epoch [81/100] Train Loss: 0.2092, Train Acc: 0.9317 | Test Loss: 0.5389, Test Acc: 0.8347
Epoch [82/100] Train Loss: 0.2114, Train Acc: 0.9315 | Test Loss: 0.5749, Test Acc: 0.8293
Epoch [83/100] Train Loss: 0.2174, Train Acc: 0.9282 | Test Loss: 0.5194, Test Acc: 0.8472
Epoch [84/100] Train Loss: 0.2179, Train Acc: 0.9317 | Test Loss: 0.4970, Test Acc: 0.8489
Epoch [85/100] Train Loss: 0.1953, Train Acc: 0.9365 | Test Loss: 0.5500, Test Acc: 0.8381
Epoch [86/100] Train Loss: 0.2058, Train Acc: 0.9329 | Test Loss: 0.5583, Test Acc: 0.8318
Epoch [87/100] Train Loss: 0.2131, Train Acc: 0.9288 | Test Loss: 0.5887, Test Acc: 0.8381
Epoch [88/100] Train Loss: 0.1938, Train Acc: 0.9374 | Test Loss: 0.5476, Test Acc: 0.8493
Epoch [89/100] Train Loss: 0.1932, Train Acc: 0.9363 | Test Loss: 0.6004, Test Acc: 0.8418
Epoch [90/100] Train Loss: 0.1862, Train Acc: 0.9402 | Test Loss: 0.5386, Test Acc: 0.8485
Epoch [91/100] Train Loss: 0.1904, Train Acc: 0.9381 | Test Loss: 0.5388, Test Acc: 0.8505
Epoch [92/100] Train Loss: 0.2035, Train Acc: 0.9336 | Test Loss: 0.6293, Test Acc: 0.8347
Epoch [93/100] Train Loss: 0.2004, Train Acc: 0.9363 | Test Loss: 0.5346, Test Acc: 0.8505
Epoch [94/100] Train Loss: 0.1856, Train Acc: 0.9424 | Test Loss: 0.5353, Test Acc: 0.8480
Epoch [95/100] Train Loss: 0.1828, Train Acc: 0.9393 | Test Loss: 0.5281, Test Acc: 0.8485
Epoch [96/100] Train Loss: 0.1845, Train Acc: 0.9419 | Test Loss: 0.5372, Test Acc: 0.8410
Epoch [97/100] Train Loss: 0.1932, Train Acc: 0.9393 | Test Loss: 0.5661, Test Acc: 0.8472
Epoch [98/100] Train Loss: 0.1690, Train Acc: 0.9467 | Test Loss: 0.5471, Test Acc: 0.8501
Epoch [99/100] Train Loss: 0.1824, Train Acc: 0.9394 | Test Loss: 0.4839, Test Acc: 0.8530
Epoch [100/100] Train Loss: 0.1804, Train Acc: 0.9433 | Test Loss: 0.5763, Test Acc: 0.8501

=== Test Evaluation Metrics ===
Accuracy: 0.8501
Precision (Macro): 0.8509
Recall (Macro): 0.8502
F1-Score (Macro): 0.8491

=== Classification Report ===
              precision    recall  f1-score   support

 Agriculture       0.83      0.80      0.82       162
     Airport       0.71      0.57      0.64       160
       Beach       0.92      0.89      0.91       160
        City       0.89      0.88      0.88       160
      Desert       0.90      0.95      0.93       160
      Forest       0.97      0.96      0.96       160
   Grassland       0.98      0.86      0.92       160
     Highway       0.74      0.78      0.76       160
        Lake       0.84      0.79      0.81       160
    Mountain       0.88      0.83      0.85       160
     Parking       0.88      0.97      0.93       160
        Port       0.89      0.94      0.92       160
     Railway       0.73      0.79      0.76       160
 Residential       0.91      0.97      0.94       160
       River       0.69      0.76      0.73       160

    accuracy                           0.85      2402
   macro avg       0.85      0.85      0.85      2402
weighted avg       0.85      0.85      0.85      2402


