<a href="https://colab.research.google.com/github/tj191073-droid/tj191073/blob/main/%E2%80%9Cfish_mobilenet_train_ipynb%E2%80%9D%E7%9A%84%E5%89%AF%E6%9C%AC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

from google.colab import files

# Open the browser file picker; select your Kaggle API token (kaggle.json).
uploaded = files.upload()

# Fail fast: the next cell's `!cp kaggle.json ...` would otherwise fail without
# stopping the notebook, and the error would only surface at download time.
if "kaggle.json" not in uploaded:
    raise FileNotFoundError(
        f"Expected kaggle.json to be uploaded, got: {sorted(uploaded)}"
    )

# Write the freshly uploaded bytes to ./kaggle.json explicitly, so the copy in
# the next cell uses this upload even if Colab saved it under another name
# (it may de-duplicate on re-upload).
with open("kaggle.json", "wb") as token_file:
    token_file.write(uploaded["kaggle.json"])

Saving kaggle.json to kaggle.json


In [None]:

# 配置 Kaggle API 权限
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# 下载数据集
!kaggle datasets download -d crowww/a-large-scale-fish-dataset -p /content

# 解压
!unzip -q /content/a-large-scale-fish-dataset.zip -d /content/fish_dataset

Dataset URL: https://www.kaggle.com/datasets/crowww/a-large-scale-fish-dataset
License(s): Attribution 4.0 International (CC BY 4.0)
Downloading a-large-scale-fish-dataset.zip to /content
100% 3.24G/3.24G [00:27<00:00, 253MB/s]
100% 3.24G/3.24G [00:27<00:00, 126MB/s]


In [None]:
import os, shutil
from sklearn.model_selection import train_test_split

original_data_path = "/content/fish_dataset/NA_Fish_Dataset"
base_dir = "/content/fish_dataset_split"
train_dir = os.path.join(base_dir, "train")
val_dir = os.path.join(base_dir, "val")

# Classes with fewer images than this are skipped entirely.
min_required_images = 30
# Extensions are matched case-insensitively. The raw class folders are not all
# lowercase ".png", and a case-sensitive ".png" check silently drops whole
# classes.
image_extensions = (".png", ".jpg", ".jpeg")
val_fraction = 0.2
split_seed = 42

# Start from a clean output tree so re-running the cell is idempotent.
if os.path.exists(base_dir):
    shutil.rmtree(base_dir)

# Create the train/ and val/ roots
for subset in [train_dir, val_dir]:
    os.makedirs(subset, exist_ok=True)

skipped = []

# os.listdir order is filesystem-dependent. Sort both the class list and each
# image list so the random_state-seeded split is reproducible.
for class_name in sorted(os.listdir(original_data_path)):
    class_path = os.path.join(original_data_path, class_name)
    if not os.path.isdir(class_path):
        continue  # ignore stray files next to the class folders

    images = sorted(
        f for f in os.listdir(class_path)
        if f.lower().endswith(image_extensions)
    )

    if len(images) < min_required_images:
        skipped.append((class_name, len(images)))
        continue

    os.makedirs(os.path.join(train_dir, class_name), exist_ok=True)
    os.makedirs(os.path.join(val_dir, class_name), exist_ok=True)

    train_imgs, val_imgs = train_test_split(
        images, test_size=val_fraction, random_state=split_seed
    )
    for subset_dir, subset_imgs in ((train_dir, train_imgs), (val_dir, val_imgs)):
        for img in subset_imgs:
            shutil.copy(
                os.path.join(class_path, img),
                os.path.join(subset_dir, class_name, img),
            )

print("✅ 图像划分完成。")
if skipped:
    print(f"⚠️ 以下类别被跳过（图片不足 {min_required_images} 张）:")
    for cname, num in skipped:
        print(f" - {cname}: {num} 张")

✅ 图像划分完成。
⚠️ 以下类别被跳过（图片不足 30 张）:
 - Gilt Head Bream: 1 张
 - Sea Bass: 0 张
 - Red Sea Bream: 0 张


In [None]:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# The pretrained MobileNetV2 weights were trained on ImageNet-normalized inputs,
# so apply the same per-channel mean/std after ToTensor().
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize plus a random horizontal flip for light augmentation.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation/inference pipeline: deterministic, no augmentation.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

train_data = datasets.ImageFolder(train_dir, transform=train_transform)
val_data = datasets.ImageFolder(val_dir, transform=val_transform)
# Label index i must mean the same species in both splits.
assert train_data.classes == val_data.classes, (train_data.classes, val_data.classes)

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
val_loader = DataLoader(val_data, batch_size=16)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

# Number of target classes, taken from the training ImageFolder
num_classes = len(train_data.classes)
print(f"📦 识别类别数：{num_classes}")

# Seed before the new head is initialized and before DataLoader shuffling so
# repeated runs are comparable.
torch.manual_seed(42)

# MobileNetV2 with ImageNet weights. IMAGENET1K_V1 is what the deprecated
# `pretrained=True` resolved to, so the starting weights are unchanged.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)

# Replace the final classification layer with one sized for our classes
model.classifier[1] = nn.Linear(model.last_channel, num_classes)

# Train on CPU
device = torch.device("cpu")
model = model.to(device)

# Loss function and optimizer (the optimizer is bound to this model's parameters)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


def _evaluate_accuracy(model, loader, device):
    """Return top-1 accuracy of `model` over every batch in `loader`.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    if total == 0:
        raise ValueError("validation loader yielded no samples")
    return correct / total


def train_model(model, train_loader, val_loader, epochs=10):
    """Fine-tune `model` for `epochs` epochs, printing per-epoch metrics.

    Uses the module-level `criterion`, `optimizer` and `device` defined in this
    cell. `optimizer` must have been built from this `model`'s parameters.

    Returns:
        A list with one dict per epoch, with keys "epoch", "train_loss",
        "train_acc" and "val_acc".

    Raises:
        ValueError: if either loader yields no samples.
    """
    history = []
    for epoch in range(1, epochs + 1):
        model.train()
        loss_sum, correct, total = 0.0, 0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # criterion returns a batch mean; weight by batch size so a smaller
            # final batch does not skew the epoch average.
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += outputs.argmax(dim=1).eq(labels).sum().item()

        if total == 0:
            raise ValueError("train_loader yielded no samples")
        train_loss = loss_sum / total
        train_acc = correct / total
        val_acc = _evaluate_accuracy(model, val_loader, device)

        history.append({
            "epoch": epoch,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_acc": val_acc,
        })
        print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, "
              f"Train Acc = {train_acc:.4f}, Val Acc = {val_acc:.4f}")
    return history


# Start training; keep the per-epoch metrics for later inspection/plotting.
history = train_model(model, train_loader, val_loader, epochs=10)

📦 识别类别数：6


Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth
100%|██████████| 13.6M/13.6M [00:00<00:00, 87.6MB/s]


Epoch 1: Train Acc = 0.7902, Val Acc = 0.8929
Epoch 2: Train Acc = 0.8393, Val Acc = 0.6071
Epoch 3: Train Acc = 0.8973, Val Acc = 0.8393
Epoch 4: Train Acc = 0.9688, Val Acc = 0.9643
Epoch 5: Train Acc = 0.9821, Val Acc = 0.9643
Epoch 6: Train Acc = 0.9732, Val Acc = 1.0000
Epoch 7: Train Acc = 0.9821, Val Acc = 1.0000
Epoch 8: Train Acc = 0.9955, Val Acc = 1.0000
Epoch 9: Train Acc = 0.9821, Val Acc = 0.9643
Epoch 10: Train Acc = 0.9643, Val Acc = 0.9821


In [None]:
from PIL import Image
import torchvision.transforms as transforms

def predict_image(model, image_path, transform, class_names):
    """Classify one image and print the predicted class name.

    Args:
        model: trained classifier. This call switches it to eval mode.
        image_path: path or file-like object accepted by PIL.Image.open.
        transform: preprocessing pipeline. Pass the deterministic eval
            transform (e.g. val_transform), not the augmenting train one.
        class_names: sequence mapping output index to class name
            (e.g. train_data.classes).

    Returns:
        The predicted class name, which is also printed.
    """
    model.eval()
    # Send the input to wherever the model's weights live, rather than relying
    # on a global `device` that can drift out of sync after a reload.
    model_device = next(model.parameters()).device
    # The context manager closes the file handle once the image is decoded.
    with Image.open(image_path) as img:
        img_t = transform(img.convert('RGB')).unsqueeze(0).to(model_device)

    with torch.no_grad():
        output = model(img_t)
        _, pred = torch.max(output, 1)

    predicted_name = class_names[pred.item()]
    print(f"✅ 预测结果: {predicted_name}")
    return predicted_name

In [None]:
# Sanity-check on a held-out validation image: a file picked straight from the
# raw dataset has an 80% chance of being in the training split. Use
# val_transform, because train_transform's random flip makes inference
# non-deterministic.
sample_path, sample_label = val_data.samples[0]
print(f"真实类别: {val_data.classes[sample_label]}")
predicted_label = predict_image(model, sample_path, val_transform, train_data.classes)

✅ 预测结果: Black Sea Sprat


In [None]:
# Persist only the learned weights (state_dict). The reload cell rebuilds the
# MobileNetV2 architecture in code before loading them.
model_save_path = "/content/mobilenet_fish_model.pt"
trained_state = model.state_dict()
torch.save(trained_state, model_save_path)
print("✅ 模型保存成功")

✅ 模型保存成功


In [None]:
# Rebuild the architecture without downloading ImageNet weights (they are
# replaced by the checkpoint anyway), then load the fine-tuned weights.
model = models.mobilenet_v2(weights=None)
model.classifier[1] = nn.Linear(model.last_channel, len(train_data.classes))
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code when loaded.
state_dict = torch.load(
    "/content/mobilenet_fish_model.pt", map_location=device, weights_only=True
)
model.load_state_dict(state_dict)
model = model.to(device)
# The trailing semicolon suppresses the very long module repr in the output.
model.eval();



MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

In [None]:
from google.colab import files

# Open the browser file picker; choose one or more images to classify.
uploaded = files.upload()
# Stop here if the dialog was cancelled, rather than failing later on a
# missing or stale file.
if not uploaded:
    raise RuntimeError("No image uploaded; please select at least one image file.")

Saving 00002.png to 00002.png


In [None]:
from PIL import Image
import torchvision.transforms as transforms
import torch

def predict_uploaded_image(model, uploaded_file, transform, class_names):
    """Classify one uploaded image and print the predicted class name.

    Args:
        model: trained classifier. This call switches it to eval mode.
        uploaded_file: path or file-like object (e.g. io.BytesIO of the upload
            bytes) accepted by PIL.Image.open.
        transform: preprocessing pipeline. Pass the deterministic eval
            transform (e.g. val_transform), not the augmenting train one.
        class_names: sequence mapping output index to class name
            (e.g. train_data.classes).

    Returns:
        The predicted class name, which is also printed.
    """
    model.eval()
    # Send the input to wherever the model's weights live, rather than relying
    # on a global `device` that can drift out of sync after a reload.
    model_device = next(model.parameters()).device
    # The context manager closes the file handle once the image is decoded.
    with Image.open(uploaded_file) as img:
        img_t = transform(img.convert('RGB')).unsqueeze(0).to(model_device)

    with torch.no_grad():
        output = model(img_t)
        _, pred = torch.max(output, 1)

    predicted_name = class_names[pred.item()]
    print(f"✅ 上传图像预测结果: {predicted_name}")
    return predicted_name

In [None]:
import io

# Classify every image from the most recent upload.
# - Decode from the uploaded bytes instead of a hard-coded filename, which
#   breaks as soon as a different image is uploaded (and Colab may save a
#   re-upload under a de-duplicated name).
# - Use val_transform: train_transform's random flip makes inference
#   non-deterministic.
for uploaded_name, uploaded_bytes in uploaded.items():
    print(uploaded_name)
    predict_uploaded_image(model, io.BytesIO(uploaded_bytes), val_transform, train_data.classes)

✅ 上传图像预测结果: Shrimp
