In [53]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torchvision
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image
from tqdm import tqdm

# 自定义数据集

In [105]:
transform = torchvision.transforms.Compose({
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize([224, 224])
})


class CustomDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.classes = os.listdir(self.data_dir)
        self.img_paths, self.labels = [], []

        for i, cls in enumerate(self.classes):
            cls_dir = os.path.join(self.data_dir, cls)
            img_list = os.listdir(cls_dir)
            self.img_paths.extend([os.path.join(cls_dir, img) for img in img_list])
            self.labels.extend([i] * len(img_list))
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, index):
        img_path = self.img_paths[index]
        img = Image.open(img_path).convert('RGB')
        label = self.labels[index]

        if self.transform:
            img = self.transform(img)

        # # 进行one-hot编码
        # one_hot_label = F.one_hot(torch.tensor([label]), num_classes=10)
        # one_hot_label = one_hot_label.to(torch.float32)
        return img, label

    def labelsname(self, i):
        """返回 0 ~ 9 号所对应的标签名"""
        return self.classes[i]


DATA_PATH = '../../DATASETS/animal10_classification/raw-img/'
dataset = CustomDataset(DATA_PATH, transform)
train_data, valid_data = random_split(dataset, [0.8, 0.2])  # generator类用于管理生成器随机种子
# train: 20944, valid: 5235

# 模型

In [93]:
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, downsample=False):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=(1, 1), stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        if downsample:
            if out_channels == 64:
                stride = 1
                self.downsample = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=(1, 1), stride=stride, bias=False),
                    nn.BatchNorm2d(out_channels * self.expansion)
                    )
            else:
                self.downsample = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=(1, 1), stride=1, bias=False),
                    nn.BatchNorm2d(out_channels * self.expansion)
                    )
        else:
            self.downsample = None

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)
        
        out += identity  # 残差连接
        out = self.relu(out)

        return out


class resnet50(nn.Module):
    def __init__(self, in_channels, backbone=False, num_classes=0):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=(7, 7), stride=2, padding=(3, 3), bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

        self.layer1 = nn.Sequential(
            Bottleneck(64, 64, True),
            Bottleneck(256, 64, False),
            Bottleneck(256, 64, False)
        )
        self.layer2 = nn.Sequential(
            Bottleneck(256, 128, True),
            Bottleneck(512, 128, False),
            Bottleneck(512, 128, False),
            Bottleneck(512, 128, False)
        )
        self.layer3 = nn.Sequential(
            Bottleneck(512, 256, True),
            Bottleneck(1024, 256, False),
            Bottleneck(1024, 256, False),
            Bottleneck(1024, 256, False),
            Bottleneck(1024, 256, False),
            Bottleneck(1024, 256, False)
        )
        self.layer4 = nn.Sequential(
            Bottleneck(1024, 512, True),
            Bottleneck(2048, 512, False),
            Bottleneck(2048, 512, False)
        )

        if not backbone:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            # [N, C, H, W] -> [N, C, 1, 1] 即一个批次中，对每份样本，在不同的通道层面上进行二维平均池化
            self.fc = nn.Linear(in_features=2048, out_features=num_classes, bias=True)
            self.softmax = nn.Softmax(dim=3)
        else:
            self.fc = None
        
    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.avgpool(out)
        if self.fc is not None:
            out = out.permute(0, 3, 2, 1)
            out = self.fc(out)
            out = self.softmax(out)
        
        return out

# 相关设置

In [106]:
epochs = 100
lr = 3e-3
weight_decay = 3e-3
bs = 2
loss_list = []  # 存储每个 epoch 的损失
model_path = '../../DATASETS/animal10_classification/weights_resnet50/'  # 模型存储路径

model = resnet50(3, num_classes=10, backbone=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

train_loader = DataLoader(train_data, batch_size=bs, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=bs)
optimizier = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
loss_func = nn.CrossEntropyLoss()

# 训练

In [107]:
for epoch in range(1, epochs + 1):
    pbar = tqdm(train_loader, desc=f'Epoch {epoch}/{epochs}')
    running_loss = 0.0
    info = {
        'loss': 0,
        'lr': 0,
        'bs': bs
    }  # 训练过程中打印的信息
    for img, label in pbar:
        pbar.set_postfix(info)
        img, label = img.to(device), label.to(device)
        pred = model(img)
        pred = torch.squeeze(pred)
        loss = loss_func(pred, label)  # pred: 2-D[N, C]  target: 1-D[N]

        info['loss'] = loss.item()
        info['lr'] = optimizier.param_groups[0]['lr']

        # 反向传播更新参数
        optimizier.zero_grad()
        loss.backward()
        optimizier.step()
        
        # 记录当前批次的损失
        running_loss += loss.item()

    epoch_loss = running_loss / len(train_data)
    loss_list.append(epoch_loss)
    # 存储模型
    if epoch_loss < min(loss_list):
        torch.save(model.state_dict(), model_path + 'best.pth')
    torch.save(model.state_dict(), model_path + 'last.pth')

Epoch 1/100:   0%|          | 0/10472 [00:00<?, ?it/s, loss=0, lr=0, bs=2]

tensor([[0.1164, 0.1120, 0.0592, 0.1870, 0.1712, 0.0782, 0.0832, 0.0395, 0.0854,
         0.0678],
        [0.1047, 0.1255, 0.0493, 0.1889, 0.2097, 0.0908, 0.0835, 0.0275, 0.0742,
         0.0459]], grad_fn=<SqueezeBackward0>)
tensor([8, 5])


Epoch 1/100:   0%|          | 1/10472 [00:04<13:49:28,  4.75s/it, loss=2.32, lr=0.003, bs=2]

tensor([[7.6341e-04, 9.5941e-04, 6.4535e-04, 1.0621e-03, 1.0151e-03, 4.7797e-01,
         7.3332e-04, 5.2015e-04, 5.1565e-01, 6.8080e-04],
        [6.1238e-04, 7.2805e-04, 5.3433e-04, 7.3080e-04, 8.6744e-04, 4.0951e-01,
         6.1365e-04, 4.2443e-04, 5.8547e-01, 5.0893e-04]],
       grad_fn=<SqueezeBackward0>)
tensor([8, 7])


Epoch 1/100:   0%|          | 1/10472 [00:09<27:07:03,  9.32s/it, loss=2.32, lr=0.003, bs=2]


KeyboardInterrupt: 