# 示例：使用框架训练图像分类模型

利用框架提供的模块，我们可以十分方便地训练深度学习模型。以 Kaggle 图像数据集 <https://www.kaggle.com/datasets/asaniczka/mammals-image-classification-dataset-45-animals/data> 为例。

In [None]:
import torch 
import sys
from pathlib import Path

# Make the parent of the current working directory importable so that the
# `src.*` modules used below can be imported; skip the append when the cell is
# re-run so sys.path does not accumulate duplicate entries.
_project_root = str(Path.cwd().parent)
if _project_root not in sys.path:
    sys.path.append(_project_root)

## 1. 数据集

将数据集下载并解压到 `data` 目录下，解压后的目录结构如下：

```
data/mammals/
        |-  african_elephant/
        |   |-  african_elephant-0001.jpg
        |   |-  *
        |
        |-  alpaca/
        |   |-  alpaca-0001.jpg
        |   |-  *
        |
        *
```

由于数据集中并未划分训练集和测试集，我们需要通过下面的单元格对数据集进行划分。或者你可以采用仓库作者生成的划分方式：`data/mammals_split/`

In [None]:
# 如需重新划分数据集，请取消本单元格最后一行的注释并运行

from pathlib import Path 
import random 

# 划分数据集：为每个类别分配标签索引，并按概率将图片随机分入训练集或测试集
def split_dataset(root: str = "data", split_prob: float = 0.8, seed: int = 0) -> None:
    r"""Randomly split the images under ``<root>/mammals/`` into a train and a test set.

    Every sub-directory of ``<root>/mammals/`` is one class. Classes and files are
    visited in sorted order, so label indices and the split are reproducible
    across machines for a given ``seed``. Three files are written to
    ``<root>/mammals_split/``:

    - ``labels.txt``: one class name per line; the line number is the label index.
    - ``train.txt`` / ``test.txt``: one ``<file name>,<label index>`` per line.

    root (pathlike): directory that contains ``mammals/``.
    split_prob (float): probability that an image goes to the train set, i.e.
        the expected fraction of all images that end up in the train set.
    seed (int): seed of the random number generator used for the split.

    Raises:
        FileNotFoundError: if ``<root>/mammals`` is not a directory.
    """
    root_dir = Path(root)
    image_dir = root_dir / "mammals"
    if not image_dir.is_dir():
        raise FileNotFoundError(f"image directory not found: {image_dir}")

    # A private generator leaves the global `random` state untouched; for the
    # same seed it yields the same sequence as `random.seed(seed); random.random()`.
    rng = random.Random(seed)

    # glob()/iterdir() order is filesystem dependent -> sort for reproducibility.
    # Only directories are classes; `.name` (not `.stem`) keeps dotted names intact,
    # because the dataset rebuilds image paths as "<label>/<file name>".
    class_dirs = sorted(p for p in image_dir.iterdir() if p.is_dir())
    labels = [class_dir.name for class_dir in class_dirs]

    train: list = []
    test: list = []
    for label_idx, class_dir in enumerate(class_dirs):
        for file_path in sorted(p for p in class_dir.iterdir() if p.is_file()):
            split = train if rng.random() < split_prob else test
            split.append((file_path.name, label_idx))

    split_dir = root_dir / "mammals_split"
    split_dir.mkdir(parents=True, exist_ok=True)

    def write_lines(file_name: str, lines) -> None:
        """Write each item of `lines` to `split_dir / file_name`, newline-terminated."""
        with open(split_dir / file_name, "w") as f:
            f.writelines(f"{line}\n" for line in lines)

    write_lines("labels.txt", labels)
    write_lines("train.txt", (f"{name},{label}" for name, label in train))
    write_lines("test.txt", (f"{name},{label}" for name, label in test))

# 重新划分数据集时，请指定随机种子
# split_dataset(seed=0)

搭建数据集类 `MammalsDataset`（继承自 `src.modules.SizedDataset`），访问该数据集对象得到的数据类型为：
```
{
    "image": torch.Tensor,
    "label": int
}
```

In [None]:
from PIL import Image
from src.modules import SizedDataset
from pathlib import Path
from typing import Callable, cast
from typing_extensions import TypedDict
from torchvision import transforms


# Type of one dataset sample: the transformed image and its integer label index.
Data = TypedDict("Data", {"image": torch.Tensor, "label": int})


class MammalsDataset(SizedDataset):
    r""" Dataset of mammals images.

    Indexing returns a ``Data`` dict ``{"image": transforms(<RGB image>), "label": <int>}``.
    Images are read from ``<root>/mammals/<class name>/<file name>``; the class
    names and the split come from ``<root>/mammals_split/labels.txt`` and
    ``train.txt`` / ``test.txt`` (``<file name>,<label index>`` per line).

    root (pathlike): directory where 'mammals/' and 'mammals_split/' are located.
    train (bool): load the train split if True, otherwise the test split.
    transforms (callable | None): applied to every RGB ``PIL.Image``. If None,
        the train split uses center-crop + random horizontal flip + ToTensor +
        ImageNet normalisation, and the test split uses the same pipeline
        without the random flip so that evaluation is deterministic.
    """

    # Built in the class body, where `transforms` still names the torchvision
    # module (inside __init__ it is shadowed by the parameter of the same name).
    _normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    _train_transforms = transforms.Compose([
        transforms.CenterCrop(256),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        _normalize,
    ])
    _eval_transforms = transforms.Compose([
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        _normalize,
    ])

    def __init__(
        self,
        root: str,
        train: bool = True,
        transforms: "Callable | None" = None,
    ) -> None:
        super().__init__()
        self.root: Path = Path(root).absolute()
        self.train: bool = train
        self.labels: list[str] = self.__read_labels()
        self.data: list[dict] = self.__read_data()
        if transforms is None:
            # No random augmentation on the test split: evaluation must be repeatable.
            transforms = self._train_transforms if train else self._eval_transforms
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Data:
        record = self.data[index]
        image = self.transforms(self.__load_image(record["path"]))
        return {
            "image": cast(torch.Tensor, image),
            "label": record["label"],
        }

    def get_image_label(self, index: int) -> dict:
        """Return the untransformed RGB image and its class *name* (for display)."""
        record = self.data[index]
        return {
            "image": self.__load_image(record["path"]),
            "label": self.labels[record["label"]],
        }

    def __load_image(self, relative_path: str) -> Image.Image:
        """Open `relative_path` under the image dir as RGB and close the file handle.

        `convert("RGB")` returns a fully loaded copy, so the returned image stays
        usable after the file is closed; it also turns grayscale/RGBA images into
        the 3 channels the default normalisation expects.
        """
        with Image.open(self.__image_dir / relative_path) as img:
            return img.convert("RGB")

    @property
    def __image_dir(self): return self.root / "mammals"

    @property
    def __split_dir(self): return self.root / "mammals_split"

    @property
    def __label_file(self): return self.__split_dir / "labels.txt"

    @property
    def __split_file(self):
        name = "train.txt" if self.train else "test.txt"
        return self.__split_dir / name

    def __read_labels(self) -> list[str]:
        """Read class names; the line number of a name is its label index."""
        with open(self.__label_file, "r") as f:
            return [line.strip() for line in f]

    def __read_data(self) -> list[dict]:
        """Parse the split file into ``{"path": "<label>/<file>", "label": int}`` records."""
        data = []
        with open(self.__split_file, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing empty lines
                # Split on the LAST comma: the label index never contains one,
                # the file name might.
                name, label_str = line.rsplit(",", 1)
                label_index = int(label_str)
                data.append({
                    "path": f"{self.labels[label_index]}/{name}",
                    "label": label_index,
                })
        return data
        

展示数据

In [None]:
import random

# Identity transform: keep the raw PIL image so the notebook can display it.
image_dataset = MammalsDataset("../data/", transforms=lambda x: x)
index = random.randrange(len(image_dataset))

# Index the dataset once: every `image_dataset[index]` re-reads the file from disk.
sample = image_dataset[index]

print(image_dataset.labels[sample["label"]])
sample["image"]

## 2. 神经网络模型
基于 `src.modules.NeuralNetwork` 搭建 `ResNet` 类。具体地，使用一个 `torchvision.models.resnet18` 实例，并将其分类头替换为 45 类的线性层，该分类头输出对应数据的 logits。

In [None]:
import torch
from torch.nn import Linear
from torchvision.models import resnet18, resnet34
from src.modules import NeuralNetwork 
from typing import Literal

class ResNet(NeuralNetwork):
    r""" torchvision ResNet whose final ``fc`` layer is replaced by a new
    ``Linear`` head that returns ``num_classes`` logits.

    resnet (str): backbone architecture, ``"resnet18"`` or ``"resnet34"``.
    use_pretrained (bool): build the backbone with torchvision's pretrained
        weights (``pretrained=True``).
    freeze_backbone (bool): set ``requires_grad=False`` on every backbone
        parameter; the new classification head stays trainable.
    num_classes (int): number of output logits (45 for the mammals dataset).

    Raises:
        ValueError: if ``resnet`` is not a supported architecture name.
    """

    _ARCHITECTURES = {"resnet18": resnet18, "resnet34": resnet34}

    def __init__(
        self,
        resnet: Literal["resnet18", "resnet34"] = "resnet18",
        use_pretrained: bool = False,
        freeze_backbone: bool = False,
        num_classes: int = 45,
    ) -> None:
        super().__init__()
        try:
            build = self._ARCHITECTURES[resnet]
        except KeyError:
            # Previously an unknown name left `self.resnet` unset and failed later
            # with an unrelated AttributeError.
            raise ValueError(
                f"unsupported resnet {resnet!r}; expected one of {sorted(self._ARCHITECTURES)}"
            ) from None
        self.use_pretrained = use_pretrained
        self.resnet = build(pretrained=use_pretrained)

        if freeze_backbone:
            # Done before the head is replaced, so the new head below stays trainable.
            for params in self.backbone_parameters():
                params.requires_grad = False
        self.resnet.fc = Linear(self.resnet.fc.in_features, num_classes, bias=True)
        self.freeze_backbone = freeze_backbone

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """Return the classification logits for `input_`."""
        return self.resnet(input_)

    def backbone_parameters(self):
        """Yield every parameter except those of the classification head."""
        # Compare by identity rather than by a "fc" substring in the parameter name.
        head_ids = {id(p) for p in self.resnet.fc.parameters()}
        for param in self.parameters():
            if id(param) not in head_ids:
                yield param

    def fc_parameters(self):
        """Return the parameters of the classification head."""
        return self.resnet.fc.parameters()

    def init_weights(self):
        """Initialise weights.

        With pretrained weights only the new head is re-initialised, so the
        pretrained backbone is kept; otherwise defer to the base class.
        """
        if self.use_pretrained:
            self.resnet.fc.reset_parameters()
            return self
        # Propagate the base-class result so both branches return something.
        return super().init_weights()

## 3. 训练算法模型

继承 `src.modules.TrainModel` 编写 `TrainClassifyMammals` 类来管理损失函数。分类模型使用交叉熵损失进行优化。

In [None]:
import torch 
from src.modules import NeuralNetwork, TrainModel


class TrainClassifyMammals(TrainModel):
    r""" Training objective of the mammals classifier.

    ``compute_loss`` reports a single cross-entropy term under the key
    ``"cross-entropy"``; ``_loss_weights`` assigns that term weight 1.0
    (presumably consumed by ``TrainModel`` to combine losses — confirm).
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    _loss_weights = {"cross-entropy": 1.}

    def compute_loss(self, network: NeuralNetwork, batch: Data) -> dict:
        """Return ``{"cross-entropy": loss}`` for one batch.

        ``batch["image"]`` is fed to ``network`` to obtain logits, which are
        scored against ``batch["label"]`` (assumed to be collated class indices
        — confirm against the DataLoader).
        """
        return {"cross-entropy": self.loss_fn(network(batch["image"]), batch["label"])}

## 4. 训练器与插件

超参数：
- 训练轮数：20
- 批次大小：512
- 梯度累计：2
- 初始学习率：0.01
- 初始随机种子：0
- 网络结构：resnet18，预训练权重，骨干权重学习率为分类头权重的 1/10

拓展插件：
- `InitializeNetworkPlugin`：初始化网络权重（使用预训练权重时仅重新初始化分类头）
- `SaveCheckpointPlugin`：保存检查点
- `EvaluatePlugin`：在 epoch 末尾评估模型，需要定义 `EvaluateModel` 的子类
- `AdjustLearningRatePlugin`：分别在第 6、16 个 epoch 处将学习率调整为原来的 1/10
- `LossLoggerPlugin`，`MetricLoggerPlugin`，`LearningRateLoggerPlugin`
- `ProgressBarPlugin`

In [None]:
import torch
from torch.utils.data import DataLoader
from src.modules import EvaluateModel, NeuralNetwork
from src.utils import move_batch

# 评估模型类
class EvaluateClassifyMammalsTest(EvaluateModel):
    r""" Top-1 accuracy of a classifier on the mammals test split.

    root (pathlike): directory passed to ``MammalsDataset`` (``train=False``).
    batch_size (int): evaluation batch size.
    device (str): device the network and the batches are moved to in ``evaluate``.

    ``predict_batch_labels`` restores the network's previous train/eval mode,
    so evaluating in the middle of training does not leave BatchNorm/Dropout
    in eval mode for the following training steps.
    """
    def __init__(self, root, batch_size, device="cpu") -> None:
        self._metrics = ["accuracy"]
        self.dataset = MammalsDataset(root, train=False)
        self.batch_size = batch_size
        self.device = torch.device(device)

    @torch.no_grad()
    def predict_batch_labels(self, network, batch_images):
        """Return the arg-max class index of each image in `batch_images`.

        The images are moved to the device of the network's parameters, which
        lets callers pass CPU tensors to a network that lives on the GPU.
        """
        was_training = network.training
        network.eval()
        try:
            first_param = next(network.parameters(), None)
            if first_param is not None:
                batch_images = batch_images.to(first_param.device)
            logits = network(batch_images)
        finally:
            # Put the network back into whatever mode the caller had it in.
            network.train(was_training)
        return torch.argmax(logits, dim=1)

    def evaluate(self, network: NeuralNetwork) -> dict[str, float]:
        """Return ``{"accuracy": correct / total}`` over the whole test split."""
        network.to(self.device)
        num_correct = 0
        loader = DataLoader(self.dataset, batch_size=self.batch_size, num_workers=4)
        for batch in loader:
            batch = move_batch(batch, self.device)
            pred = self.predict_batch_labels(network, batch["image"])
            num_correct += (pred == batch["label"]).sum().item()
        return {"accuracy": num_correct / len(self.dataset)}

# 学习率调整函数
def lr_adjust_fn(lr, epoch, index, milestones=(6, 16), divisor=10):
    """Return the learning rate to use at ``epoch``.

    The current ``lr`` is divided by ``divisor`` when ``epoch`` is one of
    ``milestones`` and returned unchanged otherwise. The defaults reproduce the
    original schedule (``epoch == 5 + 1 or epoch == 15 + 1``, divide by 10).

    lr: current learning rate.
    epoch: epoch number as passed in by ``AdjustLearningRatePlugin``.
    index: not used here; kept to match the callback signature the plugin
        calls with (presumably a step/param-group index — confirm).
    milestones: epochs at which the learning rate is reduced.
    divisor: factor the learning rate is divided by at each milestone.
    """
    return lr / divisor if epoch in milestones else lr

In [None]:
# exp-1
from pathlib import Path
from src.modules import Trainer 
from torch.optim import Adam
from src.plugins import (
    Plugin,
    LoadCheckpointPlugin,
    InitializeNetworkPlugin,
    SaveCheckpointPlugin,
    EvaluatePlugin,
    AdjustLearningRatePlugin,
    LossLoggerPlugin, MetricLoggerPlugin, LearningRateLoggerPlguin,
    ProgressBarPlugin
)

EXP_INDEX = 1

# Hyper Parameters
DATA_DIR = "../data/"
NUM_EPOCH = 20
BATCH_SIZE = 512
GRADIENT_ACCUMULATION_STEP = 2
LEARNING_RATE = 0.01

SEED = 0
# Fall back to the CPU when no GPU is available instead of failing at startup.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

EVAL_PERIOD = 2

OUTPUT_DIR = Path("../OUTPUTs/mammals").absolute()
LOG_DIR = OUTPUT_DIR / "log" / f"exp-{EXP_INDEX}"
LOG_PERIOD = 5
CHECKPOINT_DIR = OUTPUT_DIR / "checkpoint" / f"exp-{EXP_INDEX}"
CHECKPOINT_PERIOD = EVAL_PERIOD

# modules

# Training dataset
dataset = MammalsDataset(root=DATA_DIR)

# Neural network
network = ResNet("resnet18", use_pretrained=True, freeze_backbone=False)

# Optimizer: the backbone learning rate is 1/10 of the fc-head learning rate
optimizer = Adam([
    {"params": network.backbone_parameters(), "lr": LEARNING_RATE / 10},
    {"params": network.fc_parameters(), "lr": LEARNING_RATE},
])

# Training and evaluation models
train_model = TrainClassifyMammals()
eval_model = EvaluateClassifyMammalsTest(root=DATA_DIR, batch_size=BATCH_SIZE, device=DEVICE)

# Trainer (uses the same DEVICE as the evaluation model)
trainer = (
    Trainer(
        train_model,
        num_epochs=NUM_EPOCH,
        batch_size=BATCH_SIZE,
        gradient_accumulation_step=GRADIENT_ACCUMULATION_STEP,
        init_seed=SEED,
        device=DEVICE,
    )
    # To resume training, uncomment and point at a saved checkpoint directory:
    # .add_plugin(LoadCheckpointPlugin("../OUTPUTs/mammals/checkpoint/exp-8/epoch-6/"))
    .add_plugin(InitializeNetworkPlugin())
    .add_plugin(AdjustLearningRatePlugin(lr_adjust_fn))
    .add_plugin(EvaluatePlugin(eval_model, eval_period=EVAL_PERIOD))
    .add_plugin(SaveCheckpointPlugin(saving_dir=CHECKPOINT_DIR, saving_period=CHECKPOINT_PERIOD))
    .add_plugin(LearningRateLoggerPlguin(log_dir=LOG_DIR, log_period=LOG_PERIOD))
    .add_plugin(LossLoggerPlugin(log_dir=LOG_DIR, log_period=LOG_PERIOD))
    .add_plugin(MetricLoggerPlugin(log_dir=LOG_DIR, log_period=EVAL_PERIOD))
    .add_plugin(ProgressBarPlugin())
)

## 5. 训练模型

In [None]:
# Start training: `trainer` (configured with its plugins in the previous cell)
# trains `network` on `dataset` using `optimizer`.
trainer.loop(dataset, network, optimizer)

## 6. 查看训练结果（TensorBoard）

在命令行中启动 TensorBoard：
```bash
cd ~/pytorch-training-framework
tensorboard --logdir OUTPUTs/mammals/log
```

## 7. 使用模型进行推理

随机读取一张图片，并使用训练后的模型判断其物种。

In [None]:
# Load the trained model weights

NETWORK_WEIGHTS_PATH = "../OUTPUTs/mammals/checkpoint/exp-1/epoch-18/network_state_dict.pth"

# map_location="cpu" lets a checkpoint saved from a CUDA run be loaded on any
# machine; load_state_dict then copies the values into the network's existing
# parameters, which keep their current device.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
network.load_state_dict(torch.load(NETWORK_WEIGHTS_PATH, map_location="cpu"))
eval_dataset = MammalsDataset(root="../data/", train=False)

In [None]:

# Run inference on a randomly chosen image from eval_dataset
index = random.randrange(len(eval_dataset))

data = eval_dataset.get_image_label(index)
image = data["image"]
label = data["label"]

# Add a batch dimension and move the input to the device the network lives on
# (after training it may still be on the GPU, while dataset tensors are on the CPU).
network_device = next(network.parameters()).device
img_tensor = eval_dataset[index]["image"].unsqueeze(0).to(network_device)

print("true label:", label)
pred_index = eval_model.predict_batch_labels(network, img_tensor).tolist()[0]
pred = eval_dataset.labels[pred_index]
print("pred label:", pred)
image