In [9]:
import json
import os
from PIL import Image
import torch
from torch.utils.data import DataLoader
from datasets import Dataset
from transformers import AutoFeatureExtractor, AutoModelForObjectDetection, TrainingArguments, Trainer


In [10]:
# 加载COCO数据集
def load_coco_data(data_path, images_path):
    with open(data_path) as f:
        coco_data = json.load(f)

    image_data = [
        {
            "image_id": image["id"],
            "image": os.path.join(images_path, image["file_name"]),
            "height": image["height"],
            "width": image["width"],
        }
        for image in coco_data["images"]
    ]

    annotations_data = [
        {
            "image_id": ann["image_id"],
            "bbox": ann["bbox"],
            "category_id": ann["category_id"],
        }
        for ann in coco_data["annotations"]
    ]

    dataset = Dataset.from_dict({"image_id": [d["image_id"] for d in image_data]})
    dataset = dataset.add_column("image", [d["image"] for d in image_data])
    dataset = dataset.add_column("height", [d["height"] for d in image_data])
    dataset = dataset.add_column("width", [d["width"] for d in image_data])
    dataset = dataset.add_column("annotations", [[] for _ in image_data])

    for ann in annotations_data:
        idx = dataset["image_id"].index(ann["image_id"])
        dataset[idx]["annotations"].append(ann)

    return dataset

In [11]:
data_dir = r"./data/test-strip-head.v3i.coco"  # 更改为你的数据集路径

train_data_path = f"{data_dir}/train/_annotations.coco.json"
train_images_path = f"{data_dir}/train"
val_data_path = f"{data_dir}/valid/_annotations.coco.json"
val_images_path = f"{data_dir}/valid"
test_data_path = f"{data_dir}/test/_annotations.coco.json"
test_images_path = f"{data_dir}/test"

train_dataset = load_coco_data(train_data_path, train_images_path)
val_dataset = load_coco_data(val_data_path, val_images_path)
test_dataset = load_coco_data(test_data_path, test_images_path)

# 加载模型和特征提取器
extractor = AutoFeatureExtractor.from_pretrained("facebook/detr-resnet-101")
model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-101")



In [12]:
# 自定义数据整理器
class CustomDataCollator:
    def __call__(self, batch):
        pixel_values = [item["pixel_values"] for item in batch]
        pixel_values = torch.stack(pixel_values, dim=0)

        labels = [item["target"]["labels"] for item in batch]
        boxes = [item["target"]["boxes"] for item in batch]

        targets = []
        for i in range(len(labels)):
            target = {}
            target["labels"] = labels[i]
            target["boxes"] = boxes[i]
            targets.append(target)

        return {"pixel_values": pixel_values, "target": targets}

data_collator = CustomDataCollator()


In [22]:
# 处理样本
def process_sample(example, extractor):
    image = Image.open(example["image"]).convert("RGB")
    inputs = extractor(images=image, return_tensors="pt")
    example["pixel_values"] = inputs["pixel_values"][0]  # 更改这一行，去掉多余的维度
    
    target = {}
    target["labels"] = torch.tensor([ann["category_id"] for ann in example["annotations"]], dtype=torch.int64)
    target["boxes"] = torch.tensor([ann["bbox"] for ann in example["annotations"]], dtype=torch.float32)
    example["target"] = target
    
    return example  # 返回单个样本，而不是批处理样本


In [23]:
# 使用修复后的process_sample函数处理数据集
train_dataset = train_dataset.map(lambda x: process_sample(x, extractor))
val_dataset = val_dataset.map(lambda x: process_sample(x, extractor))

# 移除 'image' 和 'annotations' 列
train_dataset = train_dataset.remove_columns(["image", "annotations"])
val_dataset = val_dataset.remove_columns(["image", "annotations"])

# 创建训练参数
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    learning_rate=5e-5,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    eval_steps=100,
    remove_unused_columns=False,
    warmup_steps=warmup_steps,
)

# 创建Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
)

# 开始微调模型
trainer.train()

Map:   0%|          | 0/134 [00:00<?, ? examples/s]

KeyError: 'image'