<a href="https://colab.research.google.com/github/tj191073-droid/tj191073/blob/main/_untitled10_ipynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os, torch
import numpy as np
from torch import nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from sklearn.utils.class_weight import compute_class_weight

# Environment and paths.
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_dir = "/content/fish_dataset_split"
train_dir = os.path.join(data_dir, "train")
val_dir = os.path.join(data_dir, "val")
# Count only sub-directories of train_dir. torchvision's ImageFolder treats each
# sub-directory as one class, so counting every entry (as os.listdir does) would
# inflate the classifier head whenever a stray file sits next to the class folders.
with os.scandir(train_dir) as entries:
    num_classes = sum(1 for entry in entries if entry.is_dir())
save_path = "best_model.pth"  # relative path: resolved against the current working directory

# Hyper-parameters.
image_size = 160          # images are resized to image_size x image_size
batch_size = 8
learning_rate = 3e-4
max_epochs = 50           # upper bound on training epochs
early_stop_patience = 6   # epochs without validation improvement before stopping

# Image preprocessing and augmentation.
# The backbone used in this notebook is loaded with ImageNet-pretrained weights, which
# were trained on inputs standardised with the ImageNet channel statistics below.
# Apply the same normalisation so the pretrained features see the input distribution
# they expect.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, random flips and colour jitter, then tensor + normalise.
train_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])
# Validation pipeline: deterministic resize only, with the same normalisation.
val_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# Build the datasets (one sub-folder per class) and their batch loaders.
train_data = datasets.ImageFolder(root=train_dir, transform=train_transform)
val_data = datasets.ImageFolder(root=val_dir, transform=val_transform)
# Only the training loader reshuffles its samples every epoch.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=batch_size, shuffle=False)

# Per-class loss weights to counter the long-tailed class distribution.
# "balanced" is scikit-learn's preset that weights classes inversely to their frequency.
present_classes = np.unique(train_data.targets)
balanced_weights = compute_class_weight(
    class_weight="balanced",
    classes=present_classes,
    y=train_data.targets,
)
# Move the weights to the training device as float32 so the loss can consume them.
weights = torch.tensor(balanced_weights, dtype=torch.float32).to(device)

# Model: ImageNet-pretrained MobileNetV2 with every layer left trainable.
model = models.mobilenet_v2(weights="IMAGENET1K_V1")
model.requires_grad_(True)  # mark all parameters as trainable (full fine-tuning)
# Swap the final classifier layer for one sized to this dataset's class count.
model.classifier[1] = nn.Linear(in_features=model.last_channel, out_features=num_classes)
model = model.to(device)

# Loss (class-weighted cross-entropy) and optimizer (Adam over all parameters).
criterion = nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Training loop with early stopping on validation accuracy.
# After every epoch the model is evaluated on the validation set. The weights are
# checkpointed to save_path whenever validation accuracy improves, and training stops
# once early_stop_patience consecutive epochs bring no improvement.
best_val_acc, patience_counter = 0.0, 0
# Epoch whose weights are stored at save_path during THIS run (None until the first save).
# Tracking it guarantees the first epoch always writes a checkpoint, so the reload below
# can never pick up a stale file left behind by an earlier run.
best_epoch = None
for epoch in range(1, max_epochs + 1):
    # === Train ===
    model.train()
    correct, total = 0, 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        pred = model(x)
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()
        # Accumulate running accuracy from the predictions made during this step.
        correct += (pred.argmax(1) == y).sum().item()
        total += y.size(0)
    train_acc = correct / total

    # === Validation ===
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():  # no gradients needed for evaluation
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            correct += (pred.argmax(1) == y).sum().item()
            total += y.size(0)
    val_acc = correct / total

    print(f"Epoch {epoch}: Train Acc = {train_acc:.4f}, Val Acc = {val_acc:.4f}")

    # === Early Stopping ===
    if best_epoch is None or val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch
        torch.save(model.state_dict(), save_path)
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= early_stop_patience:
            print("⏹️ Early stopping triggered.")
            break

# Restore the best weights, but only if this run actually wrote a checkpoint
# (e.g. max_epochs == 0 writes none). map_location keeps the tensors on the active device.
if best_epoch is not None:
    model.load_state_dict(torch.load(save_path, map_location=device))
print(f"✅ Done. Best Validation Accuracy: {best_val_acc:.4f}")

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth
100%|██████████| 13.6M/13.6M [00:00<00:00, 102MB/s] 


Epoch 1: Train Acc = 0.1802, Val Acc = 0.3021
Epoch 2: Train Acc = 0.3531, Val Acc = 0.3937
Epoch 3: Train Acc = 0.4562, Val Acc = 0.5312
Epoch 4: Train Acc = 0.5344, Val Acc = 0.5604
Epoch 5: Train Acc = 0.6062, Val Acc = 0.6292
Epoch 6: Train Acc = 0.6432, Val Acc = 0.5979
Epoch 7: Train Acc = 0.6880, Val Acc = 0.6292
Epoch 8: Train Acc = 0.6995, Val Acc = 0.6208
Epoch 9: Train Acc = 0.7260, Val Acc = 0.6500
Epoch 10: Train Acc = 0.7604, Val Acc = 0.7000
Epoch 11: Train Acc = 0.7635, Val Acc = 0.6604
Epoch 12: Train Acc = 0.7849, Val Acc = 0.7146
Epoch 13: Train Acc = 0.8161, Val Acc = 0.7146
Epoch 14: Train Acc = 0.8172, Val Acc = 0.6729
Epoch 15: Train Acc = 0.8120, Val Acc = 0.6854
Epoch 16: Train Acc = 0.8526, Val Acc = 0.6896
Epoch 17: Train Acc = 0.8344, Val Acc = 0.7208
Epoch 18: Train Acc = 0.8552, Val Acc = 0.6979
Epoch 19: Train Acc = 0.8651, Val Acc = 0.6854
Epoch 20: Train Acc = 0.8734, Val Acc = 0.6937
Epoch 21: Train Acc = 0.8734, Val Acc = 0.7271
Epoch 22: Train Acc = 

In [None]:
# Clone the fish-vista dataset repository from Hugging Face, then move into it.
# NOTE: %cd changes the kernel's working directory, so relative paths used by cells
# run afterwards resolve against the cloned repository.
!git clone https://huggingface.co/datasets/imageomics/fish-vista
%cd fish-vista

# Set up Git LFS and pull the large files (including the images).
!git lfs install
!git lfs pull

Cloning into 'fish-vista'...
remote: Enumerating objects: 80157, done.[K
remote: Total 80157 (delta 0), reused 0 (delta 0), pack-reused 80157 (from 1)[K
Receiving objects: 100% (80157/80157), 212.04 MiB | 25.35 MiB/s, done.
Resolving deltas: 100% (104/104), done.
Updating files: 100% (75833/75833), done.
Filtering content: 100% (75699/75699), 11.04 GiB | 29.52 MiB/s, done.
/content/fish-vista
Updated git hooks.
Git LFS initialized.


In [None]:
import os
import pandas as pd

# Load the metadata CSV (classification_train.csv) from the cloned dataset.
df = pd.read_csv("/content/fish-vista/classification_train.csv", low_memory=False)

# Keep only the rows whose family is Cyprinidae.
is_cyprinidae = df['family'].eq('Cyprinidae')
cy_df = df.loc[is_cyprinidae]

# Species that have at least 60 images.
species_counts = cy_df['standardized_species'].value_counts()
valid_species = species_counts.loc[species_counts.ge(60)]

# Randomly pick 40 of those species (fixed seed) and keep only their rows.
picked_species = valid_species.sample(n=40, random_state=42)
selected_species = picked_species.index.tolist()
in_selection = cy_df['standardized_species'].isin(selected_species)
sub_df = cy_df.loc[in_selection]


# Draw 60 random images from every selected species.
def _sample_sixty(group):
    """Return 60 rows sampled from one species group with a fixed seed."""
    return group.sample(60, random_state=42)


per_species = sub_df.groupby('standardized_species', group_keys=False)
subset = per_species.apply(_sample_sixty).reset_index(drop=True)

# Build the on-disk image path for a metadata file name.
def resolve_path(file_name, images_root="/content/fish-vista/Images"):
    """Map a metadata ``file_name`` to the image's path under ``images_root``.

    The chunk id is taken from the second '/'-separated component of ``file_name``
    (the text after its first underscore), and the image is expected at
    ``<images_root>/chunk_<id>/<basename of file_name>``.

    Args:
        file_name: Value of the metadata ``file_name`` column.
        images_root: Directory holding the ``chunk_<id>`` folders. Defaults to the
            location used by this notebook.

    Returns:
        The resolved path as a string, or ``None`` when ``file_name`` is not a string
        of the expected shape (e.g. NaN/None, or too few path components).
    """
    try:
        chunk_id = file_name.split('/')[1].split('_')[1]
        return os.path.join(images_root, f"chunk_{chunk_id}", os.path.basename(file_name))
    # Only swallow malformed-input errors. A bare `except:` would also hide
    # KeyboardInterrupt/SystemExit and make the cell impossible to interrupt cleanly.
    except (AttributeError, IndexError, TypeError):
        return None

# Resolve every metadata file name to its expected location on disk.
subset['image_path'] = subset['file_name'].apply(resolve_path)

# Drop rows whose image is missing. resolve_path returns None for file names it cannot
# parse, and os.path.exists(None) raises TypeError, so non-string paths are treated as
# missing instead of crashing the cell. astype(bool) keeps the mask boolean even when
# the frame is empty.
subset['image_exists'] = subset['image_path'].apply(
    lambda path: isinstance(path, str) and os.path.exists(path)
).astype(bool)
subset = subset[subset['image_exists']].reset_index(drop=True)

# Report how many images and species survived the filtering.
print(f"✅ 最终提取：{len(subset)} 张图像，{subset['standardized_species'].nunique()} 个种类")
print(subset[['standardized_species', 'image_path']].head())