# Code Block Index

- [roi_extract](#roi_extract)
- [scaler](#scaler)
- [split_dataset](#split_dataset)
- [train](#train)
- [test](#test)


<a id="roi_extract"></a>

In [None]:
# roi_extract
import os
import json
from PIL import Image
import numpy as np

# Path configuration
image_dir = "./toy-dataset/breast_bm_b-mode/images"  # source images
label_dir = "./toy-dataset/breast_bm_b-mode/labels"  # JSON annotations
output_dir = "./toy-dataset/images"  # destination for the cropped ROIs
os.makedirs(output_dir, exist_ok=True)

# Crop the first annotated rectangle of every JSON annotation file.
for json_file in os.listdir(label_dir):
    if not json_file.endswith(".json"):
        continue
    json_path = os.path.join(label_dir, json_file)

    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # default locale encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Skip annotation files that carry no usable two-point shape instead of
    # aborting the whole run with an IndexError.
    shapes = data.get("shapes") or []
    if not shapes or len(shapes[0].get("points", [])) < 2:
        print(f"警告：{json_file} 中没有可用的矩形标注，已跳过。")
        continue

    # The two points are opposite corners, but nothing guarantees which one
    # is top-left, so order the coordinates after truncating them to ints.
    (xa, ya), (xb, yb) = shapes[0]["points"][:2]
    x1, x2 = sorted((int(xa), int(xb)))
    y1, y2 = sorted((int(ya), int(yb)))

    # Keep only the file name; the stored path may use either "/" or "\\"
    # as separator depending on where the annotation was created.
    image_name = os.path.basename(data["imagePath"].replace("\\", "/"))
    image_path = os.path.join(image_dir, image_name)

    # The context manager releases the file handle after each image.
    with Image.open(image_path) as img:
        # Clamp the box to the image bounds.
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(img.width, x2)
        y2 = min(img.height, y2)

        # An empty box cannot be saved as an image.
        if x2 <= x1 or y2 <= y1:
            print(f"警告：{json_file} 的标注区域为空，已跳过。")
            continue

        output_path = os.path.join(output_dir, image_name)
        img.crop((x1, y1, x2, y2)).save(output_path)

print(f"已完成所有图像裁剪，保存在 {output_dir}")

<a id="scaler"></a>

In [None]:
# scaler
import pandas as pd
from sklearn.preprocessing import MinMaxScaler

# Input / output locations
input_csv_path = "./toy-dataset/features_raw.csv"
output_csv_path = "./toy-dataset/features.csv"

# Load the raw feature table; report and stop when the file is absent.
try:
    df = pd.read_csv(input_csv_path)
except FileNotFoundError:
    print(f"错误：找不到文件 {input_csv_path}")
    exit()  # stop the program

# Both bookkeeping columns have to be present.
required_columns = ["anonymous_id", "label"]
if any(col not in df.columns for col in required_columns):
    print(f"错误：CSV文件缺少必要的列。必须包含 {required_columns}")
    exit()  # stop the program

# Every column that is neither the identifier nor the label is a feature.
id_col, label_col = df["anonymous_id"], df["label"]
feature_cols = [col for col in df.columns if col not in ("anonymous_id", "label")]
features_df = df.loc[:, feature_cols]

# Rescale each feature column to [0, 1].
# NOTE(review): the scaler is fitted on every row of this file, and the
# split_dataset cell later divides the resulting features.csv into
# train/val/test, so the scaling statistics include validation and test rows.
# Confirm that this is acceptable.
scaler = MinMaxScaler()
normalized_features = scaler.fit(features_df).transform(features_df)
normalized_df = pd.DataFrame(normalized_features, columns=feature_cols)

# Reassemble as: identifier, scaled features, label; then persist.
final_df = pd.concat([id_col, normalized_df, label_col], axis=1)
final_df.to_csv(output_csv_path, index=False)

print(f"特征已归一化并保存到: {output_csv_path}")

<a id="split_dataset"></a>

In [None]:
# split_dataset
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit
import os

# Input / output locations
input_csv_path = "./toy-dataset/features.csv"
output_csv_path = "./toy-dataset/features_split.csv"

# Target proportions; the three values are expected to sum to roughly 1.0.
TRAIN_SIZE = 0.7
VAL_SIZE = 0.15
TEST_SIZE = 0.15

# Create the destination folder when it does not exist yet.
output_dir = os.path.dirname(output_csv_path)
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
    print(f"创建输出目录: {output_dir}")

print(f"正在读取数据集: {input_csv_path}")
try:
    df = pd.read_csv(input_csv_path)
except FileNotFoundError:
    print(f"错误：找不到文件 {input_csv_path}")
    exit()

# Stratified sampling needs the label column.
if 'label' not in df.columns:
    print("错误：CSV文件缺少 'label' 列，无法进行分层抽样。")
    exit()

print(f"原始数据集大小: {len(df)}")

# Stage 1: hold out the test rows, stratified by label.
split1 = StratifiedShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=42)
train_val_indices, test_indices = next(split1.split(df, df['label']))

# Stage 2: carve the validation rows out of the remainder. test_size is
# relative to that remainder, hence VAL_SIZE / (TRAIN_SIZE + VAL_SIZE).
val_relative_size = VAL_SIZE / (TRAIN_SIZE + VAL_SIZE)
split2 = StratifiedShuffleSplit(n_splits=1, test_size=val_relative_size, random_state=42)
train_val_df = df.iloc[train_val_indices]
train_indices, val_indices = next(split2.split(train_val_df, train_val_df['label']))

# Translate the positional results back to index labels of the full frame.
train_original_indices = train_val_df.iloc[train_indices].index
val_original_indices = train_val_df.iloc[val_indices].index
test_original_indices = df.iloc[test_indices].index

# Tag every row with the subset it belongs to.
df['set'] = 'unknown'
for set_name, row_labels in (
    ('train', train_original_indices),
    ('val', val_original_indices),
    ('test', test_original_indices),
):
    df.loc[row_labels, 'set'] = set_name

print("划分结果:")
print(df['set'].value_counts())

# Any row still tagged 'unknown' was not assigned to a subset.
if (df['set'] == 'unknown').any():
    print("警告：存在未被划分的数据。")

# Persist the table together with the new 'set' column.
df.to_csv(output_csv_path, index=False)

print(f"数据集划分完成，结果已保存到: {output_csv_path}")

<a id="train"></a>

In [None]:
# train
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import pandas as pd
from PIL import Image
import os
import math
import numpy as np

# Training hyper-parameters
BATCH_SIZE = 32
NUM_EPOCHS = 50
LR = 1e-4
NUM_CLASSES = 2
IMG_SIZE = 224  # images are resized to IMG_SIZE x IMG_SIZE by the transform below
# Transformer hyper-parameters
PATCH_SIZE = 16  # NOTE(review): not referenced by the model code in this cell
NUM_HEADS = 8
HIDDEN_DIM = 128
NUM_LAYERS = 4

# Paths
image_dir = "./toy-dataset/images"
csv_split_path = "./toy-dataset/features_split.csv"
output_dir = "./"

# Make sure the output directory exists
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
    print(f"创建输出目录: {output_dir}")

# --- Transformer 工具类 (保持不变) ---
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    Args:
        d_model: Size of the embedding (last) dimension. Odd sizes are
            supported.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, max_len=1000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so the cosine half uses only the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as a buffer: follows .to(device) and is saved in the
        # state_dict, but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        Args:
            x: Tensor of shape (sequence_length, batch_size, d_model).
        """
        # self.pe is (max_len, d_model); slice to the sequence length and
        # repeat along the batch axis.
        pe = self.pe[:x.size(0), :].unsqueeze(1).expand(-1, x.size(1), -1)
        return x + pe


class TransformerEncoderLayer(nn.Module):
    """Encoder block: self-attention and a GELU feed-forward, each applied to
    a layer-normalised input and added back to its un-normalised input.

    Args:
        d_model: Embedding size of the tokens.
        nhead: Number of attention heads.
        dropout: Dropout probability used in attention and on both branches.
    """

    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        ffn_dim = d_model * 4
        # Sequence-first layout (batch_first=False), matching the
        # (sequence_length, batch_size, d_model) input of forward().
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=False)
        self.linear1 = nn.Linear(d_model, ffn_dim)
        self.linear2 = nn.Linear(ffn_dim, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, src):
        """Transform ``src`` of shape (sequence_length, batch_size, d_model)."""
        # Attention branch on the normalised input; attention weights unused.
        normed = self.norm1(src)
        attn_out = self.self_attn(normed, normed, normed)[0]
        residual = src + self.dropout(attn_out)

        # Feed-forward branch on the normalised residual stream.
        hidden = self.activation(self.linear1(self.norm2(residual)))
        ffn_out = self.linear2(self.dropout(hidden))
        return residual + self.dropout(ffn_out)

# --- MultiModalModel 类 (修改以使用 ResNet18 作为图像特征提取器) ---
class MultiModalModel(nn.Module):
    """Fuse an image and a vector of tabular features with a small Transformer.

    A frozen, ImageNet-pretrained ResNet18 trunk (average pool and FC head
    removed) produces a feature map that is projected to ``HIDDEN_DIM``
    channels and cut into patch tokens. The tabular features become one extra
    token, a learnable CLS token is prepended, and the CLS output of the
    Transformer stack is classified.

    Args:
        num_features: Number of tabular feature values per sample.
        num_classes: Number of output classes.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()

        # 1. Pretrained ResNet18 trunk (weights are downloaded on first use).
        backbone = models.resnet18(pretrained=True)

        # Freeze the trunk so only the layers added below are optimised.
        for param in backbone.parameters():
            param.requires_grad = False

        # Drop avgpool and fc so the trunk returns a spatial feature map.
        self.resnet = nn.Sequential(*list(backbone.children())[:-2])

        # 1x1 convolution: 512 ResNet18 channels -> HIDDEN_DIM channels.
        self.conv = nn.Conv2d(512, HIDDEN_DIM, kernel_size=1)

        # 2. Patch embedding over the projected feature map.
        PATCH_KERNEL = 4
        PATCH_STRIDE = 4
        self.patch_embed = nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM,
                                     kernel_size=PATCH_KERNEL, stride=PATCH_STRIDE)
        # NOTE(review): 56 looks like the feature-map size of an earlier CNN
        # stem. With the ResNet18 trunk and 224x224 inputs the map is 7x7, so
        # patch_embed yields a single 1x1 token, not 14x14. The value only
        # sizes the positional-encoding buffer (an upper bound), and is kept
        # so that parameter/buffer shapes of saved checkpoints do not change.
        self.num_patches = (56 // PATCH_KERNEL) * (56 // PATCH_STRIDE)

        # 3. Tabular-feature token.
        self.feature_proj = nn.Sequential(
            nn.Linear(num_features, HIDDEN_DIM),
            nn.LayerNorm(HIDDEN_DIM),
            nn.Dropout(0.1)
        )

        # 4. Transformer: CLS token + feature token + patch tokens.
        self.seq_length = 1 + 1 + self.num_patches
        self.cls_token = nn.Parameter(torch.randn(1, 1, HIDDEN_DIM))
        self.pos_encoder = PositionalEncoding(HIDDEN_DIM, max_len=self.seq_length)
        self.transformer = nn.ModuleList([
            TransformerEncoderLayer(HIDDEN_DIM, NUM_HEADS)
            for _ in range(NUM_LAYERS)
        ])

        # 5. Classification head applied to the CLS output.
        self.classifier = nn.Sequential(
            nn.LayerNorm(HIDDEN_DIM),
            nn.Linear(HIDDEN_DIM, num_classes)
        )

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen trunk in eval mode.

        Setting ``requires_grad = False`` does not stop BatchNorm layers from
        updating their running statistics in training mode; forcing the trunk
        into eval mode keeps the pretrained statistics truly frozen.
        """
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, image, features):
        """Return class logits of shape [B, num_classes].

        Args:
            image: Image batch [B, 3, H, W].
            features: Tabular feature batch [B, num_features].
        """
        # 1. Image tokens (shapes given for 224x224 inputs).
        img_feat = self.resnet(image)  # [B, 512, 7, 7]
        img_feat = self.conv(img_feat)  # [B, HIDDEN_DIM, 7, 7]
        img_feat = self.patch_embed(img_feat)  # [B, HIDDEN_DIM, 1, 1]
        img_feat = img_feat.flatten(2).transpose(1, 2)  # [B, tokens, HIDDEN_DIM]

        # 2. Tabular-feature token.
        num_feat = self.feature_proj(features).unsqueeze(1)  # [B, 1, HIDDEN_DIM]

        # 3. Prepend the CLS token and concatenate both modalities.
        cls_tokens = self.cls_token.expand(image.size(0), -1, -1)
        x = torch.cat([cls_tokens, num_feat, img_feat], dim=1)  # [B, 2 + tokens, HIDDEN_DIM]

        # 4. The Transformer layers work sequence-first.
        x = x.transpose(0, 1)
        x = self.pos_encoder(x)
        for layer in self.transformer:
            x = layer(x)

        # 5. Classify from the CLS position.
        cls_output = x[0]  # [B, HIDDEN_DIM]
        output = self.classifier(cls_output)
        return output

# --- 数据集类 (保持不变) ---
class MultiModalDataset(Dataset):
    """One subset (train/val/test) of the split CSV paired with its images.

    Each item is a dict with ``image``, ``features`` (float tensor of the
    tabular columns) and ``label`` (scalar long tensor), or ``None`` when the
    image file is missing.

    Args:
        csv_path: CSV containing feature columns plus ``label`` and ``set``.
        image_dir: Directory holding the image files.
        set_type: One of ``'train'``, ``'val'``, ``'test'``.
        transform: Optional callable applied to the loaded PIL image.

    Raises:
        ValueError: If ``set_type`` is not one of the three subset names.
    """

    def __init__(self, csv_path, image_dir, set_type, transform=None):
        df_full = pd.read_csv(csv_path)
        if set_type not in ['train', 'val', 'test']:
            raise ValueError("set_type must be 'train', 'val', or 'test'")
        subset = df_full[df_full['set'] == set_type]
        # Remember each row's position in the full CSV BEFORE resetting the
        # index; after reset_index(drop=True) the index is just 0..n-1 and can
        # no longer identify the image that belongs to the row.
        self.original_indices = subset.index.to_numpy()
        self.df = subset.reset_index(drop=True)

        self.image_dir = image_dir
        self.transform = transform
        self.set_type = set_type

        # Cache the tabular feature column names.
        self.feature_cols = [col for col in self.df.columns if col not in ["anonymous_id", "label", "set"]]

        print(f"Loaded {len(self.df)} samples for set_type: {self.set_type}")

    def __len__(self):
        return len(self.df)

    def _image_path(self, idx):
        """Return the image path for the ``idx``-th row of this subset.

        NOTE(review): assumes files are named ``img_<row>.jpg`` after the
        row's position in the full split CSV, as described by the original
        author's notes - confirm against the anonymisation step.
        """
        original_index = int(self.original_indices[idx])
        return os.path.join(self.image_dir, f"img_{original_index:03d}.jpg")

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Tabular features and label.
        features = row[self.feature_cols].values.astype(float)
        label = int(row["label"])

        image_path = self._image_path(idx)
        try:
            image = Image.open(image_path).convert("RGB")
        except FileNotFoundError:
            print(f"警告: 找不到图片文件 {image_path}，跳过此样本。")
            # Missing samples are returned as None; the DataLoader's
            # collate_fn is expected to filter them out.
            return None

        if self.transform:
            image = self.transform(image)

        return {
            "image": image,
            "features": torch.FloatTensor(features),
            "label": torch.tensor(label, dtype=torch.long)
        }

# --- Data loading (training and validation datasets) ---
# Count the tabular feature columns: everything except the identifier, label and subset columns.
try:
    df_temp = pd.read_csv(csv_split_path)
    num_features = len([col for col in df_temp.columns if col not in ["anonymous_id", "label", "set"]])
    print(f"检测到特征数量: {num_features}")
    del df_temp
except FileNotFoundError:
    print(f"错误：找不到划分后的数据集文件 {csv_split_path}。请先运行 split_dataset.py")
    exit()

# Image preprocessing: resize, convert to tensor, normalise with ImageNet statistics
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet mean and std
])

# Both subsets are read from the same split CSV and share the same transform.
train_dataset = MultiModalDataset(csv_split_path, image_dir, set_type='train', transform=transform)
val_dataset = MultiModalDataset(csv_split_path, image_dir, set_type='val', transform=transform)

# --- 定义 collate_fn 处理可能存在的 None (由于图片缺失) ---
def collate_fn(batch):
    """Batch sample dicts, ignoring samples the dataset returned as ``None``.

    Returns ``None`` when no usable sample remains; otherwise a dict mapping
    every key of the first sample to the stacked tensors of all samples.
    """
    samples = [sample for sample in batch if sample is not None]
    if len(samples) == 0:
        return None

    stacked = {}
    for key in samples[0]:
        stacked[key] = torch.stack([sample[key] for sample in samples])
    return stacked

# Training batches are reshuffled on every pass; validation keeps a fixed order.
# Both loaders use collate_fn, which is passed the list of samples for each batch.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

# --- 训练和验证函数 (保持不变) ---
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Module called as ``model(images, features)``.
        dataloader: Iterable of batch dicts with ``image``, ``features`` and
            ``label`` tensors; ``None`` entries are skipped.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimiser stepping the model parameters.
        device: Device the batch tensors are moved to.

    Returns:
        Mean loss over the batches that were actually processed, or ``nan``
        when no batch was processed (empty loader or only skipped batches).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        # Skip empty batches.
        if batch is None:
            continue

        images = batch["image"].to(device)
        features = batch["features"].to(device)
        labels = batch["label"].to(device)

        # Forward pass
        outputs = model(images, features)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Average over processed batches only: dividing by len(dataloader) would
    # understate the loss when batches were skipped and fail on an empty loader.
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Args:
        model: Module called as ``model(images, features)``.
        dataloader: Iterable of batch dicts with ``image``, ``features`` and
            ``label`` tensors; ``None`` entries are skipped.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(avg_loss, accuracy)``. The loss is averaged over the batches that
        were actually processed and is ``nan`` when there were none; the
        accuracy is 0 when no sample was evaluated.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():
        for batch in dataloader:
            # Skip empty batches.
            if batch is None:
                continue

            images = batch["image"].to(device)
            features = batch["features"].to(device)
            labels = batch["label"].to(device)

            outputs = model(images, features)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            num_batches += 1

            # Accuracy bookkeeping: the predicted class is the largest logit.
            _, predicted = torch.max(outputs, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

    # Average over processed batches only; nan (rather than 0) for "no data"
    # so an empty validation set can never look like an improvement.
    avg_loss = total_loss / num_batches if num_batches > 0 else float("nan")
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0
    return avg_loss, accuracy

# --- 主训练流程 (保持不变) ---
if __name__ == "__main__":
    # Run on the GPU when PyTorch can see one, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    model = MultiModalModel(num_features=num_features, num_classes=NUM_CLASSES)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)

    # Bookkeeping for the best checkpoint seen so far.
    model_save_path = os.path.join(output_dir, 'best_model.pth')
    best_epoch = -1
    best_val_accuracy = 0.0
    best_val_loss = float('inf')

    print("开始训练...")
    for epoch in range(NUM_EPOCHS):
        train_loss = train_epoch(model, train_dataloader, criterion, optimizer, device)
        val_loss, val_accuracy = validate_epoch(model, val_dataloader, criterion, device)

        epoch_number = epoch + 1
        print(f"Epoch [{epoch_number}/{NUM_EPOCHS}]")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

        # Checkpoint selection is driven by validation loss: only a strictly
        # lower loss replaces the saved weights.
        improved = val_loss < best_val_loss
        if not improved:
            continue

        best_val_loss, best_val_accuracy, best_epoch = val_loss, val_accuracy, epoch_number
        torch.save(model.state_dict(), model_save_path)
        print(f"  >> 验证集损失改进 ({best_val_loss:.4f})，保存模型到 {model_save_path}")

    print("\n训练完成！")
    print(f"最佳模型保存在 Epoch {best_epoch}，验证集损失为: {best_val_loss:.4f}，验证集准确率为: {best_val_accuracy:.4f}")


<a id="test"></a>

In [None]:
# test
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import pandas as pd
from PIL import Image
import os
import math
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Hyper-parameters (must match the values used for training, especially the model-related ones)
BATCH_SIZE = 32
NUM_CLASSES = 2
IMG_SIZE = 224
# Transformer hyper-parameters
PATCH_SIZE = 16
NUM_HEADS = 8
HIDDEN_DIM = 128
NUM_LAYERS = 4

# Paths
image_dir = "./toy-dataset/images"              # cropped images
csv_split_path = "./toy-dataset/features_split.csv" # CSV with the 'set' column
model_path = "./best_model.pth"    # best model weights saved by training

# --- Transformer 工具类 (与train.py保持一致) ---
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    Args:
        d_model: Size of the embedding (last) dimension. Odd sizes are
            supported.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, max_len=1000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so the cosine half uses only the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer, not parameter: it is part of the state_dict but not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` (sequence_length, batch_size, d_model) plus encodings."""
        pe = self.pe[:x.size(0), :].unsqueeze(1).expand(-1, x.size(1), -1)
        return x + pe

class TransformerEncoderLayer(nn.Module):
    """Encoder block with layer normalisation applied before each branch.

    Self-attention and a 4x-wide GELU feed-forward network are each computed
    on a normalised copy of the input and added back to the input.
    """

    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        hidden_width = d_model * 4
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=False)
        self.linear1 = nn.Linear(d_model, hidden_width)
        self.linear2 = nn.Linear(hidden_width, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, src):
        """Transform ``src`` of shape (sequence_length, batch_size, d_model)."""
        # Attention branch; the returned attention weights are discarded.
        attn_in = self.norm1(src)
        attn_out, _unused_weights = self.self_attn(attn_in, attn_in, attn_in)
        stream = src + self.dropout(attn_out)

        # Feed-forward branch.
        expanded = self.linear1(self.norm2(stream))
        projected = self.linear2(self.dropout(self.activation(expanded)))
        return stream + self.dropout(projected)

# --- MultiModalModel 类 (与train.py保持一致) ---
class MultiModalModel(nn.Module):
    """Inference copy of the model defined in the training cell.

    The module names, creation order and tensor shapes mirror the training
    model (ResNet18 trunk -> 1x1 conv -> patch embedding -> Transformer), so
    the ``state_dict`` written by training can be loaded. The previous
    hand-written ``cnn`` stem had different parameter names and therefore
    could not load that checkpoint.

    Args:
        num_features: Number of tabular feature values per sample.
        num_classes: Number of output classes.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()

        # ResNet18 trunk without avgpool/fc. No pretrained weights are fetched
        # here: every trunk weight and BatchNorm statistic is overwritten when
        # the trained checkpoint is loaded.
        backbone = models.resnet18()
        self.resnet = nn.Sequential(*list(backbone.children())[:-2])

        # 1x1 convolution: 512 ResNet18 channels -> HIDDEN_DIM channels.
        self.conv = nn.Conv2d(512, HIDDEN_DIM, kernel_size=1)

        PATCH_KERNEL = 4
        PATCH_STRIDE = 4
        self.patch_embed = nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM,
                                     kernel_size=PATCH_KERNEL, stride=PATCH_STRIDE)
        # Kept identical to the training cell: this value sizes the
        # positional-encoding buffer, which is part of the saved state_dict.
        self.num_patches = (56 // PATCH_KERNEL) * (56 // PATCH_STRIDE)

        self.feature_proj = nn.Sequential(
            nn.Linear(num_features, HIDDEN_DIM),
            nn.LayerNorm(HIDDEN_DIM),
            nn.Dropout(0.1)
        )

        self.seq_length = 1 + 1 + self.num_patches  # CLS + tabular token + patches
        self.cls_token = nn.Parameter(torch.randn(1, 1, HIDDEN_DIM))
        self.pos_encoder = PositionalEncoding(HIDDEN_DIM, max_len=self.seq_length)
        self.transformer = nn.ModuleList([
            TransformerEncoderLayer(HIDDEN_DIM, NUM_HEADS)
            for _ in range(NUM_LAYERS)
        ])

        self.classifier = nn.Sequential(
            nn.LayerNorm(HIDDEN_DIM),
            nn.Linear(HIDDEN_DIM, num_classes)
        )

    def forward(self, image, features):
        """Return class logits [B, num_classes] for images and tabular features."""
        # Image tokens (shapes given for 224x224 inputs).
        img_feat = self.resnet(image)  # [B, 512, 7, 7]
        img_feat = self.conv(img_feat)  # [B, HIDDEN_DIM, 7, 7]
        img_feat = self.patch_embed(img_feat)  # [B, HIDDEN_DIM, 1, 1]
        img_feat = img_feat.flatten(2).transpose(1, 2)  # [B, tokens, HIDDEN_DIM]

        # Tabular-feature token.
        num_feat = self.feature_proj(features).unsqueeze(1)  # [B, 1, HIDDEN_DIM]

        # CLS token first, then the tabular token, then the image tokens.
        cls_tokens = self.cls_token.expand(image.size(0), -1, -1)
        x = torch.cat([cls_tokens, num_feat, img_feat], dim=1)

        # The Transformer layers work sequence-first.
        x = x.transpose(0, 1)
        x = self.pos_encoder(x)

        for layer in self.transformer:
            x = layer(x)

        # Classify from the CLS position.
        cls_output = x[0]
        output = self.classifier(cls_output)
        return output

# --- 自定义数据集类 (与train.py保持一致, 过滤 set_type='test') ---
class MultiModalDataset(Dataset):
    """One subset (train/val/test) of the split CSV paired with its images.

    Each item is a dict with ``image``, ``features`` (float tensor of the
    tabular columns) and ``label`` (scalar long tensor), or ``None`` when the
    image file is missing.

    Args:
        csv_path: CSV containing feature columns plus ``label`` and ``set``.
        image_dir: Directory holding the image files.
        set_type: One of ``'train'``, ``'val'``, ``'test'``.
        transform: Optional callable applied to the loaded PIL image.

    Raises:
        ValueError: If ``set_type`` is not one of the three subset names.
    """

    def __init__(self, csv_path, image_dir, set_type, transform=None):
        df_full = pd.read_csv(csv_path)
        if set_type not in ['train', 'val', 'test']:
            raise ValueError("set_type must be 'train', 'val', or 'test'")
        subset = df_full[df_full['set'] == set_type]
        # Keep each row's position in the full CSV; reset_index(drop=True)
        # below would otherwise discard it and every subset would address
        # images img_000, img_001, ... regardless of the rows it contains.
        self.original_indices = subset.index.to_numpy()
        self.df = subset.reset_index(drop=True)

        self.image_dir = image_dir
        self.transform = transform
        self.set_type = set_type

        self.feature_cols = [col for col in self.df.columns if col not in ["anonymous_id", "label", "set"]]

        print(f"Loaded {len(self.df)} samples for set_type: {self.set_type}")

    def __len__(self):
        return len(self.df)

    def _image_path(self, idx):
        """Return the image path for the ``idx``-th row of this subset.

        NOTE(review): assumes files are named ``img_<row>.jpg`` after the
        row's position in the full split CSV - confirm against the step that
        produced the image files.
        """
        original_index = int(self.original_indices[idx])
        return os.path.join(self.image_dir, f"img_{original_index:03d}.jpg")

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        features = row[self.feature_cols].values.astype(float)
        label = int(row["label"])

        image_path = self._image_path(idx)

        try:
            image = Image.open(image_path).convert("RGB")
        except FileNotFoundError:
            print(f"警告: 找不到图片文件 {image_path}，跳过此样本。")
            # Returned as None; collate_fn is expected to drop such samples.
            return None

        if self.transform:
            image = self.transform(image)

        return {
            "image": image,
            "features": torch.FloatTensor(features),
            "label": torch.tensor(label, dtype=torch.long)
        }

# Image preprocessing: resize, convert to tensor, normalise (same values as the training cell)
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# --- Data loading (test dataset) ---
# Count the tabular feature columns: everything except the identifier, label and subset columns.
try:
    df_temp = pd.read_csv(csv_split_path)
    num_features = len([col for col in df_temp.columns if col not in ["anonymous_id", "label", "set"]])
    print(f"检测到特征数量: {num_features}")
    del df_temp
except FileNotFoundError:
    print(f"错误：找不到划分后的数据集文件 {csv_split_path}。请先运行 split_dataset.py")
    exit()

# Only the rows whose 'set' column equals 'test' are kept by the dataset.
test_dataset = MultiModalDataset(csv_split_path, image_dir, set_type='test', transform=transform)

# --- collate_fn (与train.py保持一致) ---
def collate_fn(batch):
    """Stack the non-``None`` sample dicts of ``batch`` key by key.

    Returns ``None`` when every sample is ``None`` or the batch is empty.
    """
    kept = [item for item in batch if item is not None]
    if len(kept) == 0:
        return None

    collated = {}
    for field in kept[0]:
        collated[field] = torch.stack([item[field] for item in kept])
    return collated

# Batches are built by collate_fn; the order is kept fixed (shuffle=False).
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn) # Test set does not need shuffle


# --- 测试函数 ---
def test_model(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and print the test metrics.

    Prints accuracy, a per-class classification report and the confusion
    matrix.

    Args:
        model: Module called as ``model(images, features)``.
        dataloader: Iterable of batch dicts with ``image``, ``features`` and
            ``label`` tensors; ``None`` entries are skipped.
        device: Device the input tensors are moved to.

    Returns:
        The accuracy as a float, or ``None`` when no sample could be
        evaluated.
    """
    model.eval()  # evaluation mode
    all_labels = []
    all_predictions = []

    print("开始测试...")
    with torch.no_grad():  # no gradients are needed for testing
        for batch in dataloader:
            if batch is None:
                continue

            images = batch["image"].to(device)
            features = batch["features"].to(device)

            outputs = model(images, features)
            _, predicted = torch.max(outputs, 1)  # index of the largest logit

            # Labels are only compared on the CPU, so they are not moved.
            all_labels.extend(batch["label"].tolist())
            all_predictions.extend(predicted.cpu().tolist())

    # Nothing to score (empty loader or every image missing).
    if not all_labels:
        print("警告：测试集中没有可评估的样本。")
        return None

    # Passing the full label set keeps the report and the matrix at
    # NUM_CLASSES rows even when a class is absent from this test set;
    # without it, target_names of a different length raises a ValueError.
    class_ids = list(range(NUM_CLASSES))

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f"\n测试集准确率: {accuracy:.4f}")

    print("\n分类报告:")
    print(classification_report(
        all_labels,
        all_predictions,
        labels=class_ids,
        target_names=[f'Class {i}' for i in class_ids],
        zero_division=0,
    ))

    print("\n混淆矩阵:")
    print(confusion_matrix(all_labels, all_predictions, labels=class_ids))

    return accuracy


# --- 主测试流程 ---
if __name__ == "__main__":
    # Run on the GPU when PyTorch can see one, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    # The architecture has to match the one used for training.
    model = MultiModalModel(num_features=num_features, num_classes=NUM_CLASSES)
    model = model.to(device)

    # Restore the trained weights, or report that training has not been run.
    if not os.path.exists(model_path):
        print(f"错误：找不到模型权重文件 {model_path}。请先运行 train.py 完成训练。")
        exit()
    else:
        print(f"加载模型权重: {model_path}")
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)

    # Evaluate on the test loader.
    test_model(model, test_dataloader, device)