#### 数据预处理

In [8]:
# uv 环境下可运行如下命令下载模型：
# uv run modelscope download --model tiansz/dinov2-base --local_dir models/dinov2-base

import evaluate
import numpy as np
from datasets import load_dataset
from transformers import AutoImageProcessor
from transformers import DefaultDataCollator
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

In [9]:
EPOCHS = 5  # 模型学习训练集的轮次
BATCH_SIZE = 4  # 模型并行学习的样本数量
PRETRAINED_MODEL_NAME_OR_PATH = "../models/dinov2-base"  # 预训练模型路径
IMADE_PATH = "../datasets/图像分类数据集"  # 图像训练集路径
OUPUT_CHECKPOINT_PATH = "../models/image_classification_checkpoint"  # 训练过程中产生的模型文件
OUTPUT_MODEL_PATH = "../models/image_classification_model"  # 微调后的文本分类模型路径
VAL_SIZE = 0.2  # 验证集的占比
ACCURACY_PATH = "../common/accuracy.py"  # 评估脚本本地路径

In [10]:
img_dataset = load_dataset("imagefolder", data_dir=IMADE_PATH)
img_dataset = img_dataset["train"].train_test_split(test_size=VAL_SIZE)
print(img_dataset)

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 16
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 4
    })
})


In [11]:
labels = img_dataset["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

print("标签映射关系：", label2id)

标签映射关系： {'cat': '0', 'dog': '1'}


In [12]:
image_processor = AutoImageProcessor.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])


def transforms(examples):
    examples["pixel_values"] = [
        _transforms(img.convert("RGB")) for img in examples["image"]
    ]
    del examples["image"]
    return examples


img_dataset = img_dataset.with_transform(transforms)

#### 模型微调

In [13]:
accuracy = evaluate.load(ACCURACY_PATH)


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [14]:
model = AutoModelForImageClassification.from_pretrained(
    PRETRAINED_MODEL_NAME_OR_PATH,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

training_args = TrainingArguments(
    output_dir=OUPUT_CHECKPOINT_PATH,
    remove_unused_columns=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    warmup_ratio=0.1,
    load_best_model_at_end=True,
    save_total_limit=1,
    fp16=True,
)

data_collator = DefaultDataCollator()

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=img_dataset["train"],
    eval_dataset=img_dataset["test"],
    processing_class=image_processor,
    compute_metrics=compute_metrics,
)

trainer.train()
model.half()
model.save_pretrained(OUTPUT_MODEL_PATH)
image_processor.save_pretrained(OUTPUT_MODEL_PATH)

Some weights of Dinov2ForImageClassification were not initialized from the model checkpoint at ../models/dinov2-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.413891,0.75
2,No log,0.007806,1.0
3,No log,0.073578,1.0
4,No log,0.70621,0.75
5,No log,0.119363,1.0


['../models/image_classification_model\\preprocessor_config.json']

#### 模型推理

In [15]:
import torch
from transformers import pipeline

clf = pipeline(
    "image-classification",
    model=OUTPUT_MODEL_PATH,
    torch_dtype=torch.float16,
)

print(clf("../tests/猫(测试用).jpg"))

Device set to use cpu


[{'label': 'cat', 'score': 0.9027382731437683}, {'label': 'dog', 'score': 0.0972617045044899}]
