In [1]:
# %% [markdown]
"""
# TransUNet肾脏分割可视化

本脚本提供以下功能：
1. 加载预训练的TransUNet模型
2. 执行单张图像推理
3. 应用后处理优化分割结果
4. 可视化原始图像、真实标签、预测结果对比
5. 计算定量评估指标
"""

# %%
import sys
import os
sys.path.append('/root/TransUNet_fusion')  # 添加项目根目录到Python路径
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import scipy.io as sio
from PIL import Image
from networks.vit_seg_modeling import VisionTransformer_mixRf as ViT_seg_mixRf
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
from torchvision import transforms
# %% [markdown]
"""
## 1. 后处理函数定义
"""
# %%
def remove_small_regions(mask, min_area=500):
    """Drop connected foreground components smaller than ``min_area`` pixels.

    The mask is scaled by 255, cast to uint8 and labelled with
    8-connectivity. Every component other than label 0 whose pixel area is
    at least ``min_area`` is kept; everything else is cleared.

    :param mask: 2-D mask; presumably binary with values 0/1 — confirm
        against callers.
    :param min_area: minimum component area, in pixels, to keep.
    :return: float array with kept pixels set to 1.0 and all others to 0.0.
    """
    binary = (mask * 255).astype(np.uint8)
    _, label_map, component_stats, _ = cv2.connectedComponentsWithStats(
        binary, connectivity=8
    )

    # One boolean per label; label 0 is never kept (OpenCV uses it for the
    # background component).
    keep_label = component_stats[:, cv2.CC_STAT_AREA] >= min_area
    keep_label[0] = False

    # Look up each pixel's label in the keep table to build the output mask.
    kept = np.where(keep_label[label_map], 255, 0).astype(np.uint8)
    return kept.astype(float) / 255.0

def fill_holes(mask):
    """Fill interior holes of a mask.

    Contours are extracted with ``cv2.RETR_CCOMP``; every contour that has a
    parent contour is painted solid onto the mask.

    :param mask: 2-D mask; presumably binary with values 0/1 — confirm
        against callers.
    :return: float array with foreground at 1.0 and background at 0.0.
    """
    canvas = (mask * 255).astype(np.uint8)
    contours, hierarchy = cv2.findContours(canvas, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)

    # Nothing to fill when no contour was found.
    if not contours:
        return canvas.astype(float) / 255.0

    # Each hierarchy row is [next, previous, first_child, parent]; a parent
    # other than -1 marks the contour as a hole boundary.
    for contour, (_, _, _, parent) in zip(contours, hierarchy[0]):
        if parent == -1:
            continue
        cv2.drawContours(canvas, [contour], 0, 255, -1)

    return canvas.astype(float) / 255.0
    
def make_convex(mask):
    """Replace every external region of a mask with its filled convex hull.

    :param mask: 2-D mask; presumably binary with values 0/1 — confirm
        against callers.
    :return: float array in which each external contour has been replaced by
        its filled convex hull (1.0 inside, 0.0 outside).
    """
    source = (mask * 255).astype(np.uint8)

    # Only outer boundaries matter here; nested contours are ignored.
    outer_contours, _ = cv2.findContours(source, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    hulls = [cv2.convexHull(outline) for outline in outer_contours]

    # Paint the hulls one at a time onto an empty canvas of the same shape.
    canvas = np.zeros_like(source)
    for hull in hulls:
        cv2.drawContours(canvas, [hull], 0, 255, -1)

    return canvas.astype(float) / 255.0
    
def postprocess_mask(pred_mask, min_area=500):
    """Run the full post-processing chain on a predicted mask.

    Steps, in order: threshold at 0.5, remove small components, fill holes,
    then take the convex hull of each remaining region.

    :param pred_mask: prediction array that is compared against 0.5.
    :param min_area: minimum component area forwarded to
        ``remove_small_regions``.
    :return: the post-processed mask as a float array of 0.0/1.0 values.
    """
    binarized = (pred_mask > 0.5).astype(float)
    return make_convex(fill_holes(remove_small_regions(binarized, min_area)))

# %% [markdown]
"""
## 2. 模型加载与初始化
"""
# %%
def load_transunet_model(model_path, img_size=224):
    """Build the ``R50-ViT-B_16`` segmentation model and load trained weights.

    :param model_path: path to a checkpoint containing the model state dict.
    :param img_size: input side length; the ViT patch grid is derived from it
        as ``img_size / 16`` per side.
    :return: the model, moved to the GPU and switched to eval mode.
    """
    # NOTE: this is the shared CONFIGS entry, so the assignments below mutate
    # it in place for every later user of the same key.
    config = CONFIGS_ViT_seg['R50-ViT-B_16']
    config.n_classes = 1
    config.n_skip = 3
    grid_side = int(img_size / 16)
    config.patches.grid = (grid_side, grid_side)

    model = ViT_seg_mixRf(config, img_size=img_size, num_classes=config.n_classes)

    # Deserialize onto the CPU first: a checkpoint saved from another GPU
    # index would otherwise fail to load on a machine without that device.
    # The weights are moved to the GPU by model.cuda() below.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.cuda()
    model.eval()
    return model

# Path configuration
# Checkpoint passed to load_transunet_model().
MODEL_PATH = "/root/TransUNet_fusion/new_model/ours (only image)1/best_model.pth"
# Directory listed for input images (*.jpg).
DATA_DIR = "/root/autodl-tmp/data/imgs"
# Directory listed for ground-truth masks (*.png).
MASK_DIR = "/root/autodl-tmp/data/binary_masks"
# Directory listed for RF data files (*.mat).
RF_DIR = "/root/autodl-tmp/data/rfs"
# Directory that receives saved predictions and the results summary.
OUTPUT_DIR = "/root/TransUNet_fusion/only_image_visualization"
RESULTS_TXT = os.path.join(OUTPUT_DIR, "results_summary.txt")  # per-image Dice summary file
# Create the output directory up front so later writes do not fail.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# %% [markdown]
"""
## 3. 数据预处理
"""
# %%
def preprocess_image(image_path, img_size=224):
    """Load an image file as a normalized single-channel batch tensor on the GPU.

    The image is converted to grayscale, resized to ``img_size`` x
    ``img_size``, converted to a tensor and normalized with mean 0.5 and
    std 0.5.

    :param image_path: path of the image file to load.
    :param img_size: side length of the square the image is resized to.
        Previously this argument was ignored and 224 was always used.
    :return: CUDA tensor of shape (1, 1, img_size, img_size).
    """
    img_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    # Resize accepts PIL images directly, so the grayscale image is fed to the
    # pipeline without a detour through a NumPy array.
    img = Image.open(image_path).convert('L')
    img = img_transform(img)
    return img.unsqueeze(0).cuda()  # add the batch dimension: (1, 1, H, W)

def load_ground_truth(mask_path, img_size=224):
    """Load a ground-truth mask as a binary float array of a fixed size.

    :param mask_path: path of the mask image file.
    :param img_size: side length of the square the mask is resized to.
    :return: float32 array of shape (img_size, img_size) holding 1.0 where the
        resized mask is non-zero and 0.0 elsewhere.
    """
    mask = Image.open(mask_path).convert('L')
    # Label images must be resized with nearest-neighbour sampling. An
    # interpolating filter creates intermediate grey values along the
    # boundary, and the `> 0` threshold below would then count all of them as
    # foreground, enlarging the mask.
    mask = mask.resize((img_size, img_size), Image.NEAREST)
    mask = np.array(mask) > 0
    return mask.astype(np.float32)

# %% [markdown]
"""
## 4. 推理与可视化
"""
# %%
def visualize_results(image, true_mask, pred_mask, processed_mask, save_path=None):
    """Show the input, ground truth, raw prediction and post-processed mask side by side.

    :param image: input image to display.
    :param true_mask: ground-truth mask.
    :param pred_mask: raw predicted mask.
    :param processed_mask: predicted mask after post-processing.
    :param save_path: optional file path; when truthy the figure is also saved
        there at 300 dpi with a tight bounding box.
    """
    panels = [
        (image, "Input Image"),
        (true_mask, "Ground Truth"),
        (pred_mask, "Raw Prediction"),
        (processed_mask, "Post-processed"),
    ]

    fig, axes = plt.subplots(1, len(panels), figsize=(18, 6))
    for ax, (data, title) in zip(axes, panels):
        ax.imshow(data, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    if save_path:
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.show()
    # Release the figure explicitly: this function is meant to be called once
    # per image, and unclosed figures accumulate in memory.
    plt.close(fig)

# %% [markdown]
"""
## 5. 评估指标计算
"""
# %%
def calculate_metrics(true_mask, pred_mask):
    """Compute and print the Dice score between two masks.

    Both inputs are cast to bool, so any non-zero value counts as foreground.
    A small epsilon in the denominator avoids division by zero; as a
    consequence two empty masks score 0.0.

    :param true_mask: ground-truth mask array.
    :param pred_mask: predicted mask array of the same shape.
    :return: the Dice score.
    """
    truth = true_mask.astype(bool)
    guess = pred_mask.astype(bool)

    # Confusion-matrix counts needed for Dice.
    hits = np.logical_and(truth, guess).sum()
    false_alarms = np.logical_and(np.logical_not(truth), guess).sum()
    misses = np.logical_and(truth, np.logical_not(guess)).sum()

    dice = 2 * hits / (2 * hits + false_alarms + misses + 1e-10)
    print(f"Dice Score: {dice:.4f}")
    return dice
    
def dice_score(y_true, y_pred):
    """Compute and print a Dice score from the element-wise product of two arrays.

    :param y_true: reference array.
    :param y_pred: predicted array, multiplied element-wise with ``y_true``.
    :return: ``2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred) + 1e-7)``.
    """
    overlap = np.sum(y_true * y_pred)
    true_total = np.sum(y_true)
    pred_total = np.sum(y_pred)
    # The epsilon keeps the ratio defined when both inputs sum to zero.
    score = (2. * overlap) / ((true_total + pred_total) + 1e-7)
    print(f"Dice Score: {score:.4f}")
    return score
# %% [markdown]
"""
## 6. 主执行流程
"""
# %%
# Load the model
model = load_transunet_model(MODEL_PATH)

# Collect the test samples: one sorted file-name list per modality.
test_images = sorted(name for name in os.listdir(DATA_DIR) if name.endswith('.jpg'))
test_masks = sorted(name for name in os.listdir(MASK_DIR) if name.endswith('.png'))
test_rfs = sorted(name for name in os.listdir(RF_DIR) if name.endswith('.mat'))
def load_rf_image(rfimg_path, img_size=224):
    """Load an RF frame from a .mat file as a single-channel batch tensor on the GPU.

    The data is read from ``frameRF[0]['data'][0]``, converted to
    log-magnitude, min-max normalized to [0, 1], and resized to ``img_size`` x
    ``img_size``.

    :param rfimg_path: path of the .mat file.
    :param img_size: side length of the square the RF image is resized to.
        Previously this argument was ignored and 224 was always used.
    :return: CUDA tensor with a leading batch dimension. Only the tensor is
        returned, not the intermediate normalized array.
    """
    rf_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor()
    ])
    mat_data = sio.loadmat(rfimg_path)
    frameRF = mat_data['frameRF']
    rf_data = frameRF[0]['data'][0]

    # Log-compress the magnitude; the offset keeps log() finite at zero.
    rf_data = np.log(np.abs(rf_data) + 1e-6)

    # Min-max normalize; the epsilon guards against a constant frame.
    value_min = rf_data.min()
    value_max = rf_data.max()
    rf_data = (rf_data - value_min) / (value_max - value_min + 1e-6)

    # Add a channel dimension before the image transform pipeline.
    rf_data = torch.from_numpy(rf_data).float().unsqueeze(0)
    rf_tensor = rf_transform(rf_data)
    return rf_tensor.unsqueeze(0).cuda()
    
def load_original_image_mask(img_path, mask_path, rf_path):
    """Load an image and its mask at their original resolution.

    :param img_path: path of the image file; loaded as grayscale.
    :param mask_path: path of the mask file; loaded as grayscale.
    :param rf_path: accepted for call compatibility but not used here.
    :return: tuple ``(img_np, mask_np, original_size)`` where ``img_np`` is the
        grayscale image array, ``mask_np`` is a float32 array that is 1.0
        where the mask is non-zero, and ``original_size`` is the PIL size of
        the image.
    """
    gray_image = Image.open(img_path).convert('L')
    gray_mask = Image.open(mask_path).convert('L')

    image_array = np.array(gray_image)
    # Any non-zero mask pixel counts as foreground.
    mask_array = (np.array(gray_mask) > 0).astype(np.float32)

    return image_array, mask_array, gray_image.size

def resize_mask(mask, target_size):
    """Resize a mask to ``target_size`` and re-binarize it.

    :param mask: mask array; it is scaled by 255 and cast to uint8 before
        resizing.
    :param target_size: size passed to ``PIL.Image.resize``.
    :return: float32 array that is 1.0 where the resized value exceeds 127.
    """
    scaled = (mask * 255).astype(np.uint8)
    resized = Image.fromarray(scaled).resize(target_size, Image.LANCZOS)
    # Lanczos resampling yields grey values at the edges; cut at mid-range.
    above_half = np.array(resized) > 127
    return above_half.astype(np.float32)

# Evaluate every (image, mask, RF) triple, log the per-image Dice score and
# save the raw prediction at the original image resolution.
total_dice = 0
num_evaluated = 0  # zip() stops at the shortest list, so count what was actually scored

# The three lists are paired purely by sorted order; a length mismatch means
# at least one sample is missing a counterpart and pairs may be misaligned.
if not (len(test_images) == len(test_masks) == len(test_rfs)):
    print(
        f"Warning: file counts differ (images={len(test_images)}, "
        f"masks={len(test_masks)}, rfs={len(test_rfs)}); pairs may be misaligned."
    )

# Start a fresh results file with a header.
with open(RESULTS_TXT, 'w') as f:
    f.write("Image Name\tDice Score\n")
    f.write("------------------------\n")

for img_name, mask_name, rf_name in zip(test_images, test_masks, test_rfs):
    img_path = os.path.join(DATA_DIR, img_name)
    mask_path = os.path.join(MASK_DIR, mask_name)
    rf_path = os.path.join(RF_DIR, rf_name)
    print(img_path)
    original_img, original_mask, original_size = load_original_image_mask(img_path, mask_path, rf_path)

    img_tensor = preprocess_image(img_path)
    # NOTE(review): rf_tensor is loaded but not passed to the model below —
    # confirm whether this load is still needed.
    rf_tensor = load_rf_image(rf_path)

    with torch.no_grad():
        pred = model(img_tensor, None)
        # rf_aligned = model.rf_stn(rf_tensor).cpu().squeeze()
        pred_mask = torch.sigmoid(pred).squeeze().cpu().numpy()

    # Binarize at 0.5, then bring the prediction back to the original size.
    pred_mask_original_size = resize_mask(pred_mask > 0.5, original_size)
    # NOTE(review): the post-processed mask is only used by the disabled
    # visualization; the Dice score below is computed on the raw prediction.
    processed_mask_original_size = postprocess_mask(pred_mask_original_size)

    # Visualization (original size)
    save_path = os.path.join(OUTPUT_DIR, f"result_{os.path.splitext(img_name)[0]}.png")
    # visualize_results(original_img,
    #                 original_mask,
    #                 pred_mask_original_size,
    #                 processed_mask_original_size,
    #                 save_path)

    print(f"\nOriginal Size Results for {img_name}:")

    dice = calculate_metrics(original_mask, pred_mask_original_size)
    total_dice += dice
    num_evaluated += 1
    # Append per image so results already computed survive an interrupted run.
    with open(RESULTS_TXT, 'a') as f:
        f.write(f"{img_name}\t{dice:.4f}\n")
    # Save the prediction at the original size
    cv2.imwrite(
        os.path.join(OUTPUT_DIR, f"pred_{img_name}"),
        (pred_mask_original_size * 255).astype(np.uint8)
    )

# Average over the samples that were actually evaluated; avoid dividing by
# zero when no sample was found.
if num_evaluated:
    mean_dice = total_dice / num_evaluated
else:
    print("Warning: no samples were evaluated.")
    mean_dice = float('nan')
print("Evaluation Finished: Dice Score is ", mean_dice)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


/root/autodl-tmp/new_test_dataset/imgs/OCHADNGK.jpg

Original Size Results for OCHADNGK.jpg:
Dice Score: 0.9620
/root/autodl-tmp/new_test_dataset/imgs/OCHADNP4.jpg

Original Size Results for OCHADNP4.jpg:
Dice Score: 0.9465
/root/autodl-tmp/new_test_dataset/imgs/OCHALB1C.jpg

Original Size Results for OCHALB1C.jpg:
Dice Score: 0.9649
/root/autodl-tmp/new_test_dataset/imgs/OCHALB1G.jpg

Original Size Results for OCHALB1G.jpg:
Dice Score: 0.9538
/root/autodl-tmp/new_test_dataset/imgs/OCHALB1K.jpg

Original Size Results for OCHALB1K.jpg:
Dice Score: 0.9515
/root/autodl-tmp/new_test_dataset/imgs/OCHF8JA8.jpg

Original Size Results for OCHF8JA8.jpg:
Dice Score: 0.8036
/root/autodl-tmp/new_test_dataset/imgs/OCHF8JAC.jpg

Original Size Results for OCHF8JAC.jpg:
Dice Score: 0.8353
/root/autodl-tmp/new_test_dataset/imgs/OCHF8JAI.jpg

Original Size Results for OCHF8JAI.jpg:
Dice Score: 0.8964
/root/autodl-tmp/new_test_dataset/imgs/OCHF8JAM.jpg

Original Size Results for OCHF8JAM.jpg:
Dice Score:

KeyboardInterrupt: 

In [1]:
# import os
# import shutil
# from pathlib import Path

# def merge_datasets(source_dirs, target_dir):
#     """
#     将多个源目录中的内容合并到目标目录中
    
#     参数:
#         source_dirs: 源目录列表
#         target_dir: 目标目录
#     """
#     # 确保目标目录存在
#     os.makedirs(target_dir, exist_ok=True)
    
#     # 遍历所有源目录
#     for src_dir in source_dirs:
#         if not os.path.exists(src_dir):
#             print(f"警告: 源目录不存在: {src_dir}")
#             continue
            
#         # 遍历源目录中的所有文件和子目录
#         for root, dirs, files in os.walk(src_dir):
#             # 计算相对于源目录的相对路径
#             rel_path = os.path.relpath(root, src_dir)
#             target_path = os.path.join(target_dir, rel_path)
            
#             # 创建对应的目标子目录
#             os.makedirs(target_path, exist_ok=True)
            
#             # 复制所有文件
#             for file in files:
#                 src_file = os.path.join(root, file)
#                 dst_file = os.path.join(target_path, file)
                
#                 # 如果目标文件已存在，可以选择跳过或覆盖
#                 if os.path.exists(dst_file):
#                     print(f"警告: 文件已存在，将被覆盖: {dst_file}")
                
#                 shutil.copy2(src_file, dst_file)
#                 print(f"已复制: {src_file} -> {dst_file}")

# def main():
#     # 定义文件夹结构
#     folders = ['binary_masks', 'masks', 'imgs', 'rfs']
#     base_train = '/root/autodl-tmp/data'
#     base_test = '/root/autodl-tmp/test_dataset'
#     base_new_train = '/root/autodl-tmp/new_test_dataset'
    
#     # 为每个文件夹执行合并操作
#     for folder in folders:
#         print(f"\n正在处理文件夹: {folder}")
        
#         # 源目录列表
#         source_dirs = [
#             os.path.join(base_train, folder),
#             os.path.join(base_test, folder)
#         ]
        
#         # 目标目录
#         target_dir = os.path.join(base_new_train, folder)
        
#         # 执行合并
#         merge_datasets(source_dirs, target_dir)
    
#     print("\n所有文件夹合并完成!")

# main()


正在处理文件夹: binary_masks
已复制: /root/autodl-tmp/data/binary_masks/OCHADNGK.png -> /root/autodl-tmp/new_test_dataset/binary_masks/./OCHADNGK.png
已复制: /root/autodl-tmp/data/binary_masks/OCHADNP4.png -> /root/autodl-tmp/new_test_dataset/binary_masks/./OCHADNP4.png
已复制: /root/autodl-tmp/data/binary_masks/OCHALB1C.png -> /root/autodl-tmp/new_test_dataset/binary_masks/./OCHALB1C.png
已复制: /root/autodl-tmp/data/binary_masks/OCHALB1G.png -> /root/autodl-tmp/new_test_dataset/binary_masks/./OCHALB1G.png
已复制: /root/autodl-tmp/data/binary_masks/OCHALB1K.png -> /root/autodl-tmp/new_test_dataset/binary_masks/./OCHALB1K.png
已复制: /root/autodl-tmp/data/binary_masks/OCHF8JA8.png -> /root/autodl-tmp/new_test_dataset/binary_masks/./OCHF8JA8.png
已复制: /root/autodl-tmp/data/binary_masks/OCHF8JAC.png -> /root/autodl-tmp/new_test_dataset/binary_masks/./OCHF8JAC.png
已复制: /root/autodl-tmp/data/binary_masks/OCHF8JAI.png -> /root/autodl-tmp/new_test_dataset/binary_masks/./OCHF8JAI.png
已复制: /root/autodl-tmp/data/binary