In [1]:
import pandas as pd
from torchvision.ops import box_iou
import os
import json
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import numpy as np
from tqdm import tqdm
import cv2
from mvssnet import get_mvss  # 假设mvssnet.py包含提供的MVSS代码


# 新增：掩码转边界框函数
def mask_to_boxes(mask, threshold=0.5, min_area=10):
    """Convert a segmentation mask into a list of bounding boxes.

    Args:
        mask: 2-D ``(H, W)`` numpy array; values are expected to lie in
            the range 0-1.
        threshold: Pixels strictly greater than this value count as
            foreground.
        min_area: Contours whose ``cv2.contourArea`` is below this value
            are discarded.

    Returns:
        A list of ``[xmin, ymin, xmax, ymax]`` lists, one per kept outer
        contour, where ``xmax = x + w`` and ``ymax = y + h`` of the
        contour's bounding rectangle.
    """
    # OpenCV contour extraction wants an 8-bit single-channel image.
    binarized = np.where(mask > threshold, 255, 0).astype(np.uint8)
    outlines, _hierarchy = cv2.findContours(
        binarized, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Drop small blobs first, then turn each survivor into a rectangle.
    large_enough = (c for c in outlines if cv2.contourArea(c) >= min_area)
    rectangles = (cv2.boundingRect(c) for c in large_enough)
    return [
        [left, top, left + width, top + height]
        for left, top, width, height in rectangles
    ]


class DocumentTamperDataset(Dataset):
    """Document images paired with tamper-region masks and boxes.

    The annotation file is a JSON list of records, each with an ``'id'``
    (joined onto ``image_dir`` to locate the image) and a ``'region'``
    list whose entries are indexed as ``[xmin, ymin, xmax, ymax]`` in
    original-image pixels.

    Each sample is ``(image, target)`` where ``target`` holds:

    * ``"masks"``: float32 tensor of shape ``(H, W)``, 1.0 inside regions.
    * ``"boxes"``: float32 tensor of shape ``(N, 4)`` with the regions
      rescaled and clipped to the target size.
    * ``"image_id"``: 1-element tensor holding the sample index.

    Args:
        image_dir: Directory containing the images.
        annotation_path: Path of the JSON annotation file.
        DEBUG_SUBSET_SIZE: Number of leading records kept when
            ``debug_mode`` is true.
        target_size: Output size as ``(width, height)``.
        transforms: Acts purely as an on/off switch: when it is not
            ``None`` a random horizontal flip (p=0.5) is applied. The
            object itself is never called.
        debug_mode: Truncate the annotations to ``DEBUG_SUBSET_SIZE``.
    """

    def __init__(self, image_dir, annotation_path, DEBUG_SUBSET_SIZE, target_size=(512, 512), transforms=None, debug_mode=False):
        self.image_dir = image_dir
        self.transforms = transforms
        self.target_size = target_size
        with open(annotation_path, encoding="utf-8") as f:
            self.annotations = json.load(f)

        if debug_mode:
            self.annotations = self.annotations[:DEBUG_SUBSET_SIZE]

        self.id_to_anns = {item['id']: item['region'] for item in self.annotations}
        self.ids = list(self.id_to_anns.keys())

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        img_path = os.path.join(self.image_dir, img_id)

        # Close the file handle promptly; convert() returns a loaded copy.
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")
        original_width, original_height = img.size

        target_w, target_h = self.target_size

        # Build the deterministic part of the pipeline only (train=False).
        # The flip is applied below so image, mask and boxes stay aligned.
        # torchvision's Resize takes (height, width), hence the swap.
        preprocess = get_transform(targer_size=(target_h, target_w), train=False)
        img = preprocess(img)

        scale_w = target_w / original_width
        scale_h = target_h / original_height

        mask = np.zeros((target_h, target_w), dtype=np.float32)
        boxes = []
        for region in self.id_to_anns[img_id] or ():
            # Rescale to the target size, then clip to the canvas.
            xmin = max(0, int(float(region[0]) * scale_w))
            ymin = max(0, int(float(region[1]) * scale_h))
            xmax = min(target_w, int(float(region[2]) * scale_w))
            ymax = min(target_h, int(float(region[3]) * scale_h))

            # Regions that collapse to zero area are skipped.
            if xmax > xmin and ymax > ymin:
                mask[ymin:ymax, xmin:xmax] = 1.0
                boxes.append([xmin, ymin, xmax, ymax])

        mask = torch.from_numpy(mask)
        # reshape keeps the (0, 4) shape when there are no regions.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        # Joint horizontal flip: mirror image, mask and boxes together.
        if self.transforms is not None and torch.rand(1).item() < 0.5:
            img = img.flip(-1)
            mask = mask.flip(-1)
            boxes = torch.stack(
                [
                    target_w - boxes[:, 2],
                    boxes[:, 1],
                    target_w - boxes[:, 0],
                    boxes[:, 3],
                ],
                dim=1,
            )

        return img, {
            "masks": mask,
            "boxes": boxes,
            "image_id": torch.tensor([idx]),
        }

    def __len__(self):
        return len(self.ids)


# 修改后的数据转换函数（包含归一化）
def get_transform(targer_size=(512, 512),train=True):
    """Build the image preprocessing pipeline.

    The pipeline is Resize -> ToTensor -> Normalize, followed by a random
    horizontal flip (p=0.5) when ``train`` is true.

    Args:
        targer_size: Size forwarded unchanged to
            ``torchvision.transforms.Resize``. The misspelled name is
            kept because callers pass it by keyword.
        train: Append the random horizontal flip when true.

    Returns:
        A ``torchvision.transforms.Compose`` instance.
    """
    tv = torchvision.transforms
    steps = [
        tv.Resize(targer_size),
        tv.ToTensor(),
        # Fixed per-channel statistics used to normalize the RGB tensor.
        tv.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]),
    ]
    if train:
        steps.append(tv.RandomHorizontalFlip(0.5))
    return tv.Compose(steps)


# 修改模型创建函数
def create_model(num_classes=1):
    """Instantiate the MVSS network via ``get_mvss``.

    Args:
        num_classes: Forwarded to ``get_mvss`` as ``nclass``.

    Returns:
        Whatever ``get_mvss`` returns, built with ``sobel=True`` and
        ``constrain=True``.
    """
    return get_mvss(nclass=num_classes, sobel=True, constrain=True)


# 修改后的训练函数
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Train ``model`` for one pass over ``data_loader``.

    Each batch is expected as ``(images, targets)`` where ``images`` is a
    sequence of image tensors and ``targets`` a sequence of dicts holding
    a ``'masks'`` tensor. The second element of the model's output is
    treated as logits and compared with the masks using
    ``BCEWithLogitsLoss``.

    Args:
        model: Network called as ``model(images)``; must return a pair.
        optimizer: Optimizer stepping the model's parameters.
        data_loader: Iterable of batches.
        device: Device the batch tensors are moved to.
        epoch: Epoch number, used only for the progress-bar label.

    Returns:
        Mean of the per-batch losses, or ``0.0`` if no batch was seen.

    Raises:
        RuntimeError: Re-raised after printing every mask shape when the
            batch tensors cannot be stacked.
    """
    model.train()
    criterion = nn.BCEWithLogitsLoss()
    total_loss = 0.0
    # Count batches ourselves: works for loaders without __len__ and
    # avoids dividing by zero on an empty loader.
    num_batches = 0

    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch} Training", leave=True)

    for images, targets in progress_bar:
        try:
            images = torch.stack([img.to(device) for img in images])
            masks = torch.stack([t['masks'].to(device) for t in targets])
        except RuntimeError:
            # Stacking fails when samples disagree in size; show the
            # shapes so the offending sample can be identified.
            print("\n尺寸不一致错误详情：")
            for i, t in enumerate(targets):
                print(f"样本 {i} 掩码尺寸：{t['masks'].shape}")
            raise

        # Only the second element of the model output is supervised here.
        _, outputs = model(images)
        # Masks are (B, H, W); add a channel axis for the loss.
        loss = criterion(outputs, masks.unsqueeze(1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() forces a device sync.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        progress_bar.set_postfix(loss=batch_loss)

    return total_loss / num_batches if num_batches else 0.0


def calculate_metrics(pred_boxes, true_boxes, iou_threshold=0.5):
    """Count TP, FP and FN for one sample by IoU box matching.

    Ground-truth boxes are visited in order. Each one claims the
    still-unclaimed prediction with the highest IoU, provided that IoU is
    strictly greater than ``iou_threshold``. A prediction can be claimed
    at most once.

    Args:
        pred_boxes: Predicted boxes, ``(P, 4)`` in the format accepted by
            ``box_iou``.
        true_boxes: Ground-truth boxes, ``(T, 4)`` in the same format.
        iou_threshold: Minimum IoU (exclusive) for a match.

    Returns:
        Tuple ``(tp, fp, fn)``.
    """
    num_pred = len(pred_boxes)
    num_true = len(true_boxes)

    # Trivial cases first; box_iou is never called with an empty side.
    if num_pred == 0:
        return 0, 0, num_true

    if num_true == 0:
        return 0, num_pred, 0

    # Rows index predictions, columns index ground truth. Plain floats
    # keep the Python-level comparison loop cheap.
    iou_rows = box_iou(pred_boxes, true_boxes).tolist()

    claimed_preds = set()
    tp = 0
    for true_idx in range(num_true):
        best_iou = iou_threshold
        best_pred = -1
        for pred_idx in range(num_pred):
            # Skip predictions already claimed, so a taken best match
            # falls through to the next-best free prediction instead of
            # leaving this ground-truth box unmatched.
            if pred_idx in claimed_preds:
                continue
            iou = iou_rows[pred_idx][true_idx]
            if iou > best_iou:
                best_iou = iou
                best_pred = pred_idx
        if best_pred != -1:
            claimed_preds.add(best_pred)
            tp += 1

    fp = num_pred - tp
    fn = num_true - tp

    return tp, fp, fn


# 修改后的验证函数
def evaluate(model, data_loader, device):
    """Compute micro-averaged box precision, recall and F1.

    For every sample the second element of the model output is passed
    through a sigmoid. Channel 0 is resized to the ground-truth mask
    size, converted to boxes with ``mask_to_boxes`` and matched against
    the ground-truth boxes with ``calculate_metrics``.

    Ground-truth boxes come from ``target['boxes']``. When that tensor is
    empty they are derived from ``target['masks']`` instead, so a dataset
    that only fills in masks still produces meaningful scores. An
    all-zero mask still yields no boxes.

    Args:
        model: Network called as ``model(images)``; must return a pair.
        data_loader: Iterable of ``(images, targets)`` batches.
        device: Device the images are moved to.

    Returns:
        Dict with float entries ``"Micro_Prec"``, ``"Micro_Recall"`` and
        ``"Micro_F1"``.
    """
    model.eval()
    total_tp = total_fp = total_fn = 0

    progress_bar = tqdm(data_loader, desc="Validating", leave=True)

    with torch.no_grad():
        for images, targets in progress_bar:
            images = torch.stack([img.to(device) for img in images])
            _, outputs = model(images)
            preds = torch.sigmoid(outputs).cpu().numpy()

            for pred, target in zip(preds, targets):
                gt_mask = target['masks']
                h, w = gt_mask.shape[-2:]

                # Bring the single-channel prediction to the mask's size.
                pred_mask = cv2.resize(pred[0], (w, h))
                pred_boxes = torch.as_tensor(
                    mask_to_boxes(pred_mask), dtype=torch.float32
                ).reshape(-1, 4)

                true_boxes = target['boxes'].cpu()
                if len(true_boxes) == 0:
                    # No explicit boxes: recover them from the mask.
                    # min_area=0 keeps every annotated region, however
                    # thin.
                    true_boxes = torch.as_tensor(
                        mask_to_boxes(gt_mask.cpu().numpy(), min_area=0),
                        dtype=torch.float32,
                    ).reshape(-1, 4)

                tp, fp, fn = calculate_metrics(pred_boxes, true_boxes)

                total_tp += tp
                total_fp += fp
                total_fn += fn

    # The epsilon keeps the ratios defined when a denominator is zero.
    precision = total_tp / (total_tp + total_fp + 1e-10)
    recall = total_tp / (total_tp + total_fn + 1e-10)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-10)

    return {
        "Micro_Prec": float(precision),
        "Micro_Recall": float(recall),
        "Micro_F1": float(f1)
    }


# 在main函数中需要调整优化器设置（因为模型参数结构变化）
def main():
    """Train the MVSS model, keep the best checkpoint and log metrics.

    Steps: build separate training and validation datasets over the same
    annotation file, split them with a fixed seed, train for
    ``NUM_EPOCHS`` epochs with SGD and a step LR schedule, save the
    weights whenever validation F1 improves, reload the best weights for
    a final evaluation, and write the per-epoch metrics to CSV.
    """
    # Hyper-parameters.
    BATCH_SIZE = 4
    NUM_EPOCHS = 15
    LR = 0.005
    VAL_SPLIT = 0.2
    DEBUG_MODE = True
    DEBUG_SUBSET_SIZE = 10
    CHECKPOINT_PATH = "MVss_best_model.pth"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Two dataset objects over the same files. The dataset decides on
    # augmentation from whether `transforms` is None, and both random
    # splits would otherwise share one underlying object. Giving the
    # validation split its own instance with transforms=None is what
    # actually switches augmentation off for it.
    dataset_kwargs = dict(
        image_dir="train/images",
        annotation_path="train/label_train.json",
        DEBUG_SUBSET_SIZE=DEBUG_SUBSET_SIZE,
        debug_mode=DEBUG_MODE,
    )
    train_source = DocumentTamperDataset(
        transforms=get_transform(train=True), **dataset_kwargs
    )
    val_source = DocumentTamperDataset(transforms=None, **dataset_kwargs)

    # Seeded split; the validation indices are re-applied to val_source.
    dataset_size = len(train_source)
    val_size = int(VAL_SPLIT * dataset_size)
    train_size = dataset_size - val_size
    train_dataset, val_split = random_split(
        train_source, [train_size, val_size],
        generator=torch.Generator().manual_seed(42))
    val_dataset = torch.utils.data.Subset(val_source, val_split.indices)

    # Batches are kept as tuples of samples; stacking happens downstream.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))

    model = create_model(num_classes=1)
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=LR, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    metrics_history = []

    # Start below any reachable F1 so the first epoch always writes a
    # checkpoint; the final reload can then never hit a missing file.
    best_f1 = float("-inf")
    for epoch in range(NUM_EPOCHS):
        train_loss = train_one_epoch(model, optimizer, train_loader, device, epoch)
        val_metrics = evaluate(model, val_loader, device)
        lr_scheduler.step()

        print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
        print(f"Train Loss: {train_loss:.4f}")
        print(
            f"Validation Metrics: Precision={val_metrics['Micro_Prec']:.4f}, Recall={val_metrics['Micro_Recall']:.4f}, F1={val_metrics['Micro_F1']:.4f}")

        metrics_history.append({
            "Epoch": epoch + 1,
            "Train_Loss": train_loss,
            "Precision": val_metrics["Micro_Prec"],
            "Recall": val_metrics["Micro_Recall"],
            "F1": val_metrics["Micro_F1"]
        })

        if val_metrics['Micro_F1'] > best_f1:
            best_f1 = val_metrics['Micro_F1']
            torch.save(model.state_dict(), CHECKPOINT_PATH)
            print("↦ 保存新最佳模型")

    # Final evaluation with the best weights, mapped onto the active
    # device so a GPU-saved checkpoint also loads on a CPU-only machine.
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
    final_metrics = evaluate(model, val_loader, device)
    print("\n" + "=" * 50)
    print(
        f"最终评估结果: Precision={final_metrics['Micro_Prec']:.4f}, Recall={final_metrics['Micro_Recall']:.4f}, F1={final_metrics['Micro_F1']:.4f}")
    print("=" * 50)

    save_metrics_to_csv(metrics_history)


def save_metrics_to_csv(metrics_history, filename="training_metrics.csv"):
    """Write the collected per-epoch metric records to a CSV file.

    Args:
        metrics_history: Records accepted by ``pandas.DataFrame`` (here a
            list of dicts, one per epoch).
        filename: Destination path of the CSV file.
    """
    # Build the frame and write it in one go, without the index column.
    pd.DataFrame(metrics_history).to_csv(filename, index=False)
    print(f"Training metrics saved to {filename}")


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

ValueError: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject