# In [1]:
import torch
from ultralytics import YOLO
import os
import csv
import yaml
from torchvision.ops import box_iou

# 训练配置
class Config:
    """Static training settings: dataset location, model spec and run options."""

    # Dataset root; expected to contain train/ and val/ sub-folders.
    data_dir: str = "yolo_dataset"
    # Model spec handed to ultralytics.YOLO().
    model_type: str = "yolov8n.yaml"

    # Optimisation / input settings forwarded to YOLO.train().
    epochs: int = 10
    img_size: int = 128
    batch_size: int = 16
    device: str = "cpu"

    # CSV file that receives the final precision / recall / F1 row.
    results_file: str = "training_results.csv"

def calculate_metrics(pred_boxes, true_boxes, iou_threshold=0.5):
    """Count true positives, false positives and false negatives for one image.

    Predictions and ground-truth boxes are matched one-to-one. Every
    (prediction, ground truth) pair whose IoU is at least ``iou_threshold``
    is a candidate. Candidates are visited in order of decreasing IoU, and a
    pair is accepted only if neither of its boxes has been matched yet. A
    ground-truth box whose best prediction was claimed by another box can
    therefore still match its next-best prediction.

    Args:
        pred_boxes: Predicted boxes in ``xyxy`` format, shape ``(N, 4)``.
        true_boxes: Ground-truth boxes in ``xyxy`` format, shape ``(M, 4)``.
        iou_threshold: Minimum IoU (inclusive) for a pair to count as a match.

    Returns:
        ``(tp, fp, fn)`` as Python ints.
    """
    n_pred, n_true = len(pred_boxes), len(true_boxes)
    if n_pred == 0:
        return 0, 0, n_true
    if n_true == 0:
        return 0, n_pred, 0

    iou_matrix = box_iou(pred_boxes, true_boxes)  # rows: predictions, cols: ground truth
    pred_idx, true_idx = torch.nonzero(iou_matrix >= iou_threshold, as_tuple=True)

    # Highest-IoU pairs first; sorted() is stable, so ties keep index order.
    candidates = sorted(
        zip(iou_matrix[pred_idx, true_idx].tolist(), pred_idx.tolist(), true_idx.tolist()),
        key=lambda cand: cand[0],
        reverse=True,
    )

    matched_pred, matched_true = set(), set()
    for _, p, t in candidates:
        if p in matched_pred or t in matched_true:
            continue
        matched_pred.add(p)
        matched_true.add(t)

    tp = len(matched_true)
    return tp, n_pred - tp, n_true - tp

class YOLOTrainer:
    """Train a YOLO detector on ``cfg.data_dir`` and score it with micro P/R/F1.

    Args:
        cfg: Configuration object exposing ``data_dir``, ``model_type``,
            ``epochs``, ``img_size``, ``batch_size``, ``device`` and
            ``results_file``.
    """

    # Extensions treated as images when scanning the validation folder; any
    # other file (e.g. .DS_Store, Thumbs.db) is skipped instead of being
    # passed to predict().
    _IMAGE_EXTENSIONS = frozenset({'.bmp', '.jpeg', '.jpg', '.png', '.tif', '.tiff', '.webp'})

    # Fallback checkpoint used when this session has no trainer record.
    _DEFAULT_WEIGHTS = os.path.join('runs', 'detect', 'doc_tamper_yolo', 'weights', 'best.pt')

    def __init__(self, cfg):
        self.cfg = cfg
        self.model = YOLO(cfg.model_type)
        # Rows of (epoch, precision, recall, f1) written by save_results().
        self.results = []

    def prepare_dataset_config(self):
        """Write ``dataset.yaml`` describing the single-class ``tamper`` dataset.

        Returns:
            Path of the written YAML file.
        """
        config = {
            'path': os.path.abspath(self.cfg.data_dir),
            'train': 'train/images',
            'val': 'val/images',
            'names': {0: 'tamper'},
        }
        yaml_path = os.path.join(self.cfg.data_dir, 'dataset.yaml')
        with open(yaml_path, 'w', encoding='utf-8') as f:
            yaml.dump(config, f)
        return yaml_path

    def _resolve_weights(self):
        """Return the best checkpoint of this session's training run, if known.

        By default Ultralytics appends a numeric suffix to the run name when
        the directory already exists (doc_tamper_yolo2, ...), so the fixed
        fallback path can point at an older run. The trainer's own ``best``
        path is preferred when it exists.
        """
        trainer = getattr(self.model, 'trainer', None)
        best = getattr(trainer, 'best', None)
        if best is not None and os.path.exists(best):
            return str(best)
        return self._DEFAULT_WEIGHTS

    @staticmethod
    def _label_path(val_dir, img_file):
        """Map an image file name to ``<val_dir>/labels/<stem>.txt``."""
        stem = os.path.splitext(img_file)[0]
        return os.path.join(val_dir, 'labels', stem + '.txt')

    def evaluate_custom(self, weights=None):
        """Score a checkpoint on the validation split with micro-averaged metrics.

        TP/FP/FN are accumulated over all validation images, and precision,
        recall and F1 are computed once from the totals.

        Args:
            weights: Optional checkpoint path. Defaults to the best weights
                of the most recent ``train()`` call, falling back to
                ``runs/detect/doc_tamper_yolo/weights/best.pt``.

        Returns:
            dict with float values under ``Micro_Prec``, ``Micro_Recall``
            and ``Micro_F1``.
        """
        model = YOLO(weights if weights is not None else self._resolve_weights())
        val_dir = os.path.join(self.cfg.data_dir, 'val')
        img_dir = os.path.join(val_dir, 'images')

        total_tp = total_fp = total_fn = 0
        for img_file in sorted(os.listdir(img_dir)):
            if os.path.splitext(img_file)[1].lower() not in self._IMAGE_EXTENSIONS:
                continue

            img_path = os.path.join(img_dir, img_file)
            result = model.predict(
                img_path,
                imgsz=self.cfg.img_size,
                device=self.cfg.device,  # evaluate on the same device used for training
                verbose=False,
            )[0]

            # Ground truth is scaled to the original image size, matching xyxy predictions.
            true_boxes = self.load_true_boxes(self._label_path(val_dir, img_file), result.orig_shape)
            pred_boxes = result.boxes.xyxy.cpu()

            tp, fp, fn = calculate_metrics(pred_boxes, true_boxes)
            total_tp += tp
            total_fp += fp
            total_fn += fn

        # The epsilon avoids division by zero when there are no boxes at all.
        precision = total_tp / (total_tp + total_fp + 1e-10)
        recall = total_tp / (total_tp + total_fn + 1e-10)
        f1 = 2 * (precision * recall) / (precision + recall + 1e-10)

        return {
            "Micro_Prec": float(precision),
            "Micro_Recall": float(recall),
            "Micro_F1": float(f1),
        }

    def load_true_boxes(self, label_path, img_shape):
        """Read a YOLO-format label file and return absolute ``xyxy`` boxes.

        Args:
            label_path: Text file with one ``class x_center y_center width
                height`` row per object, coordinates normalised to the image
                size. A missing file means the image has no objects.
            img_shape: ``(height, width, ...)`` of the original image.

        Returns:
            Float tensor of shape ``(N, 4)``; ``(0, 4)`` when there are no boxes.

        Raises:
            ValueError: if a non-blank row does not have exactly 5 fields.
        """
        if not os.path.exists(label_path):
            return torch.zeros((0, 4))

        img_h, img_w = img_shape[0], img_shape[1]
        boxes = []
        with open(label_path, encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / trailing lines
                if len(fields) != 5:
                    raise ValueError(
                        f"{label_path}:{line_no}: expected 5 fields, got {len(fields)}"
                    )
                x_center, y_center, width, height = map(float, fields[1:])
                boxes.append([
                    (x_center - width / 2) * img_w,
                    (y_center - height / 2) * img_h,
                    (x_center + width / 2) * img_w,
                    (y_center + height / 2) * img_h,
                ])

        return torch.tensor(boxes) if boxes else torch.zeros((0, 4))

    def save_results(self):
        """Write ``self.results`` to ``cfg.results_file`` as CSV, overwriting it."""
        with open(self.cfg.results_file, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['Epoch', 'Precision', 'Recall', 'F1'])
            writer.writerows(self.results)

    def train(self):
        """Train the model, evaluate it on the validation split and save metrics."""
        dataset_yaml = self.prepare_dataset_config()

        self.model.train(
            data=dataset_yaml,
            epochs=self.cfg.epochs,
            imgsz=self.cfg.img_size,
            batch=self.cfg.batch_size,
            device=self.cfg.device,
            name='doc_tamper_yolo'
        )

        metrics = self.evaluate_custom()
        print(f"Validation Metrics: Precision={metrics['Micro_Prec']:.4f}, "
              f"Recall={metrics['Micro_Recall']:.4f}, F1={metrics['Micro_F1']:.4f}")

        self.results.append((
            self.cfg.epochs,
            metrics['Micro_Prec'],
            metrics['Micro_Recall'],
            metrics['Micro_F1']
        ))
        self.save_results()

if __name__ == "__main__":
    # Train with the default configuration, then run the custom validation pass.
    trainer = YOLOTrainer(Config())
    trainer.train()

# SyntaxError: invalid syntax (2200601458.py, line 2)