# In [None]:
import os
import torch
from PIL import Image
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from torch.utils.data import DataLoader, Dataset
import json

# Custom dataset for dental X-ray object detection.
class DentalXrayDataset(Dataset):
    """Detection dataset pairing X-ray images with JSON shape annotations.

    Images in ``root`` and annotation files in ``annotation_dir`` are paired by
    index after sorting each directory listing by file name, so both
    directories must contain corresponding files in the same sorted order.

    Each annotation file must hold a ``"shapes"`` list whose entries carry a
    ``"label"`` string and a ``"points"`` list of ``[x, y]`` pairs; every
    recognized shape becomes the axis-aligned box enclosing its points.

    Args:
        root: Directory containing the image files.
        annotation_dir: Directory containing the JSON annotation files.
        transforms: Optional callable applied to the RGB PIL image. When not
            given, the image is converted with ``F.to_tensor`` because the
            torchvision detection models consume tensors, not PIL images.
    """

    # Ordered (substring, class id) rules matched against the lower-cased label.
    # "deep caries" must precede "caries": every "deep caries" label also
    # contains "caries", so the reverse order would never yield id 2.
    # Class id 0 is reserved for background and is never assigned here.
    _LABEL_RULES = (
        ("deep caries", 2),        # deep caries
        ("caries", 1),             # caries
        ("periapical lesion", 3),  # periapical lesion
        ("impacted", 4),           # impacted tooth
    )

    def __init__(self, root, annotation_dir, transforms=None):
        self.root = root
        self.annotation_dir = annotation_dir
        self.transforms = transforms
        # NOTE(review): pairing is purely positional; a stray or missing file in
        # either directory silently misaligns images and annotations.
        self.image_files = sorted(os.listdir(root))
        self.annotation_files = sorted(os.listdir(annotation_dir))

    @classmethod
    def _label_id(cls, label_name):
        """Return the class id for ``label_name``, or ``None`` if unrecognized."""
        name = label_name.lower()
        for keyword, class_id in cls._LABEL_RULES:
            if keyword in name:
                return class_id
        return None

    @classmethod
    def _build_target(cls, annotation, idx):
        """Build the torchvision detection target dict for one parsed annotation.

        Shapes with an unrecognized label are skipped (class id 0 means
        background to Faster R-CNN, so it cannot double as an "other" class),
        as are shapes whose enclosing box has zero width or height.

        Returns:
            dict with ``boxes`` (float32, shape ``(N, 4)`` as x1, y1, x2, y2),
            ``labels`` (int64, shape ``(N,)``), ``image_id`` (``[idx]``),
            ``area`` (box areas) and ``iscrowd`` (all zeros).
        """
        boxes = []
        labels = []
        for shape in annotation["shapes"]:
            label = cls._label_id(shape["label"])
            if label is None:
                continue
            xs = [point[0] for point in shape["points"]]
            ys = [point[1] for point in shape["points"]]
            xmin, xmax = min(xs), max(xs)
            ymin, ymax = min(ys), max(ys)
            if xmax <= xmin or ymax <= ymin:
                # torchvision detection models reject degenerate boxes in training.
                continue
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(label)

        # reshape keeps a (0, 4) shape when no box survives, so the column
        # slicing below still works for images without annotations.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        return {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            "iscrowd": torch.zeros((len(labels),), dtype=torch.int64),
        }

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``."""
        img_path = os.path.join(self.root, self.image_files[idx])
        # Context manager closes the underlying file once the RGB copy is made.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        annotation_path = os.path.join(self.annotation_dir, self.annotation_files[idx])
        # JSON is UTF-8; don't depend on the platform's default locale encoding.
        with open(annotation_path, encoding="utf-8") as f:
            annotation = json.load(f)

        target = self._build_target(annotation, idx)

        if self.transforms:
            img = self.transforms(img)
        else:
            img = F.to_tensor(img)

        return img, target

    def __len__(self):
        return len(self.image_files)

# Model construction.
def get_model(num_classes):
    """Build a pretrained Faster R-CNN (ResNet-50 FPN) with a fresh box predictor.

    Args:
        num_classes: Number of output classes, including background (id 0).

    Returns:
        The detection model whose ROI box predictor outputs ``num_classes``
        class scores and matching per-class box regressions.
    """
    # Local import keeps this fix self-contained within the function.
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

    # NOTE(review): `pretrained=` is deprecated in newer torchvision in favor of
    # `weights=`; kept here for compatibility with older installs.
    model = fasterrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # RoIHeads unpacks the predictor output into (class_logits, box_regression);
    # a bare nn.Linear returns only logits and breaks both training and inference.
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

# Dataset locations.
train_img_dir = 'D:/DENTEX_dataset/training_data/disease/input'
train_annotation_dir = 'D:/DENTEX_dataset/training_data/disease/label'
test_img_dir = 'D:/DENTEX_dataset/test_data/disease/input'
test_annotation_dir = 'D:/DENTEX_dataset/test_data/disease/label'


def collate_fn(batch):
    """Regroup a list of ``(image, target)`` samples into ``(images, targets)``.

    Detection targets carry a different number of boxes per image, so the
    default collate (which stacks tensors) cannot batch them. A named function,
    unlike a lambda, can be pickled for worker processes if ``num_workers`` is
    ever raised (spawn start method, e.g. on Windows).
    """
    return tuple(zip(*batch))


# Create datasets and data loaders.
train_dataset = DentalXrayDataset(train_img_dir, train_annotation_dir)
test_dataset = DentalXrayDataset(test_img_dir, test_annotation_dir)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

# Model, device and optimizer settings.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 5 output classes: 4 disease categories plus background.
model = get_model(num_classes=5)
model.to(device)

num_epochs = 10
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# Training loop.
import math

# Nothing to train on: fail up front with a clear message.
if len(train_loader) == 0:
    raise RuntimeError("Training data loader is empty; check train_img_dir.")

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, targets in train_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # In training mode the model returns a dict of named loss tensors;
        # their sum is the objective being optimized.
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        loss_value = losses.item()
        # Stop before the update: stepping on a NaN/inf loss would corrupt the
        # weights and waste every remaining epoch.
        if not math.isfinite(loss_value):
            raise RuntimeError(
                f"Non-finite loss {loss_value} in epoch {epoch + 1}: {loss_dict}"
            )

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += loss_value
        num_batches += 1

    # Report the mean over the whole epoch rather than only the last batch.
    mean_loss = running_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}")

# Test / evaluation pass: run inference over the test set.
model.eval()
with torch.no_grad():
    for images, targets in test_loader:
        images = [img.to(device) for img in images]
        predictions = model(images)
        # TODO: process predictions here, e.g. compute mAP or visualize detections.

print("训练和评估完成。")
