In [2]:
# %% [markdown]
"""
# TransUNet肾脏分割可视化

本脚本提供以下功能：
1. 加载预训练的TransUNet模型
2. 执行单张图像推理
3. 应用后处理优化分割结果
4. 可视化原始图像、真实标签、预测结果对比
5. 计算定量评估指标
"""

# %%
import sys
import os
sys.path.append('/root/TransUNet_fusion')  # 添加项目根目录到Python路径
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import scipy.io as sio
from PIL import Image
from networks.vit_seg_modeling import VisionTransformer_mixRf as ViT_seg_mixRf
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
from torchvision import transforms
# %% [markdown]
"""
## 1. 后处理函数定义
"""
# %%
def remove_small_regions(mask, min_area=500):
    """去除小面积不连通区域"""
    mask = (mask * 255).astype(np.uint8)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    cleaned_mask = np.zeros_like(mask)
    for i in range(1, num_labels):
        if stats[i, cv2.CC_STAT_AREA] >= min_area:
            cleaned_mask[labels == i] = 255
    return cleaned_mask.astype(float) / 255.0

def fill_holes(mask):
    """填充掩码中的孔洞"""
    mask = (mask * 255).astype(np.uint8)
    contours, hierarchy = cv2.findContours(mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        for i, contour in enumerate(contours):
            if hierarchy[0][i][3] != -1:  # 孔洞检测
                cv2.drawContours(mask, [contour], 0, 255, -1)
    return mask.astype(float) / 255.0
    
def make_convex(mask):
    """
    将掩码转换为凸形。
    :param mask: 输入掩码（二值图像，0 或 1）。
    :return: 凸形掩码。
    """
    # 转换为uint8
    mask = (mask * 255).astype(np.uint8)
    
    # 找到外部轮廓
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    # 创建一个空白掩码
    convex_mask = np.zeros_like(mask)
    
    # 计算凸包并填充
    for contour in contours:
        hull = cv2.convexHull(contour)
        cv2.drawContours(convex_mask, [hull], 0, 255, -1)
    
    return convex_mask.astype(float) / 255.0
    
def postprocess_mask(pred_mask, min_area=500):
    """完整后处理流程"""
    mask = (pred_mask > 0.5).astype(float)
    mask = remove_small_regions(mask, min_area)
    mask = fill_holes(mask)
    mask = make_convex(mask)
    return mask

# %% [markdown]
"""
## 2. 模型加载与初始化
"""
# %%
def load_transunet_model(model_path, img_size=224):
    """加载预训练的TransUNet模型"""
    config = CONFIGS_ViT_seg['R50-ViT-B_16']
    config.n_classes = 1
    config.n_skip = 3
    config.patches.grid = (int(img_size / 16), int(img_size / 16))
    
    model = ViT_seg_mixRf(config, img_size=img_size, num_classes=config.n_classes)
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()
    return model

# 配置路径
MODEL_PATH = "/root/TransUNet_fusion/new_model/ours (rf+image)1/best_model.pth"
DATA_DIR = "/root/autodl-tmp/data/imgs"
MASK_DIR = "/root/autodl-tmp/data/binary_masks"
RF_DIR = "/root/autodl-tmp/data/rfs"
OUTPUT_DIR = "/root/TransUNet_fusion/ours_visualization_results"
RESULTS_TXT = os.path.join(OUTPUT_DIR, "results_summary.txt")  # 新增结果文件路径
os.makedirs(OUTPUT_DIR, exist_ok=True)

# %% [markdown]
"""
## 3. 数据预处理
"""
# %%
def preprocess_image(image_path, img_size=224):
    img_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    img = Image.open(image_path).convert('L')
    img = np.array(img)
    img = img_transform(img)
    img = img.unsqueeze(0).cuda()  # (1, 1, H, W)
    return img

def load_ground_truth(mask_path, img_size=224):
    mask = Image.open(mask_path).convert('L')
    mask = mask.resize((img_size, img_size))
    mask = np.array(mask) > 0
    return mask.astype(np.float32)

# %% [markdown]
"""
## 4. 推理与可视化
"""
# %%
def visualize_results(image, true_mask, pred_mask, processed_mask, save_path=None):

    plt.figure(figsize=(18, 6))
    
    plt.subplot(1, 4, 1)
    plt.imshow(image, cmap='gray')
    plt.title("Input Image")
    plt.axis('off')
    
    plt.subplot(1, 4, 2)
    plt.imshow(true_mask, cmap='gray')
    plt.title("Ground Truth")
    plt.axis('off')
    
    plt.subplot(1, 4, 3)
    plt.imshow(pred_mask, cmap='gray')
    plt.title("Raw Prediction")
    plt.axis('off')
    
    plt.subplot(1, 4, 4)
    plt.imshow(processed_mask, cmap='gray')
    plt.title("Post-processed")
    plt.axis('off')
    
    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.show()

# %% [markdown]
"""
## 5. 评估指标计算
"""
# %%
def calculate_metrics(true_mask, pred_mask):
    
    y_true = true_mask.astype(bool)
    y_pred = pred_mask.astype(bool)
    
    tp = np.sum(y_true & y_pred)
    fp = np.sum(~y_true & y_pred)
    fn = np.sum(y_true & ~y_pred)
    
    precision = tp / (tp + fp + 1e-10)
    recall = tp / (tp + fn + 1e-10)
    dice = 2 * tp / (2 * tp + fp + fn + 1e-10)
    
    # print(f"\nEvaluation Metrics:")
    print(f"Dice Score: {dice:.4f}")
    # print(f"Precision: {precision:.4f}")
    # print(f"Recall: {recall:.4f}")
    return dice
    
def dice_score(y_true, y_pred):
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred)
    dice = (2. * intersection) / (union + 1e-7)
    print(f"Dice Score: {dice:.4f}")
    return dice
"""
## 6. 主执行流程
"""
# %%
# 加载模型
model = load_transunet_model(MODEL_PATH)
    
# 获取测试样本
test_images = sorted([f for f in os.listdir(DATA_DIR) if f.endswith('.jpg')])
test_masks = sorted([f for f in os.listdir(MASK_DIR) if f.endswith('.png')])
test_rfs = sorted([f for f in os.listdir(RF_DIR) if f.endswith('.mat')]
                 )
def load_rf_image(rfimg_path, img_size=224):
    """加载并处理RF图像（保持单通道），返回归一化后的RF数据和tensor"""
    rf_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    mat_data = sio.loadmat(rfimg_path)
    frameRF = mat_data['frameRF']
    rf_data = frameRF[0]['data'][0]
    
    rf_data = np.abs(rf_data)
    rf_data = np.log(rf_data + 1e-6)
    
    rf_data = (rf_data - rf_data.min()) / (rf_data.max() - rf_data.min() + 1e-6)
    rf_data = torch.from_numpy(rf_data).float()
    rf_data = rf_data.unsqueeze(0)
    rf_tensor = rf_transform(rf_data)
    rf_tensor = rf_tensor.unsqueeze(0).cuda()
    return rf_tensor
    
def load_original_image_mask(img_path, mask_path, rf_path):
    """加载原始尺寸的图像和掩码"""
    img = Image.open(img_path).convert('L')
    mask = Image.open(mask_path).convert('L')
    original_size = img.size
    
    img_np = np.array(img)
    mask_np = (np.array(mask) > 0).astype(np.float32)
    
    return img_np, mask_np, original_size

def resize_mask(mask, target_size):
    mask_img = Image.fromarray((mask * 255).astype(np.uint8))
    
    resized_img = mask_img.resize(target_size, Image.LANCZOS)
    
    return (np.array(resized_img) > 127).astype(np.float32)

total_dice = 0
with open(RESULTS_TXT, 'w') as f:
    f.write("Image Name\tDice Score\n")
    f.write("------------------------\n")
for img_name, mask_name, rf_name in zip(test_images, test_masks, test_rfs):
    img_path = os.path.join(DATA_DIR, img_name)
    mask_path = os.path.join(MASK_DIR, mask_name)
    rf_path = os.path.join(RF_DIR, rf_name)
    print(img_path)
    original_img, original_mask, original_size = load_original_image_mask(img_path, mask_path, rf_path)
    
    img_tensor = preprocess_image(img_path)
    rf_tensor = load_rf_image(rf_path)

    with torch.no_grad():
        pred = model(img_tensor, rf_tensor)
        rf_aligned = model.rf_stn(rf_tensor).cpu().squeeze()
        pred_mask = torch.sigmoid(pred).squeeze().cpu().numpy()
    
    pred_mask_original_size = resize_mask(pred_mask > 0.5, original_size)
    processed_mask_original_size = postprocess_mask(pred_mask_original_size)
    
    # 可视化（原始尺寸）
    save_path = os.path.join(OUTPUT_DIR, f"result_{os.path.splitext(img_name)[0]}.png")
    # visualize_results(original_img, 
    #                 original_mask, 
    #                 pred_mask_original_size, 
    #                 processed_mask_original_size,
    #                 save_path)
    
    print(f"\nOriginal Size Results for {img_name}:")

    dice = calculate_metrics(original_mask, pred_mask_original_size)
    total_dice += dice
    with open(RESULTS_TXT, 'a') as f:
        f.write(f"{img_name}\t{dice:.4f}\n")
    # 保存原始尺寸预测结果
    cv2.imwrite(
        os.path.join(OUTPUT_DIR, f"ours_{img_name}"), 
        (pred_mask_original_size * 255).astype(np.uint8)
    )
mean_dice = total_dice / len(test_images)
print("Evaluation Finished: Dice Score is ", mean_dice)

/root/autodl-tmp/data/imgs/OCHADNGK.jpg

Original Size Results for OCHADNGK.jpg:
Dice Score: 0.9559
/root/autodl-tmp/data/imgs/OCHADNP4.jpg

Original Size Results for OCHADNP4.jpg:
Dice Score: 0.9073
/root/autodl-tmp/data/imgs/OCHALB1C.jpg

Original Size Results for OCHALB1C.jpg:
Dice Score: 0.9391
/root/autodl-tmp/data/imgs/OCHALB1G.jpg

Original Size Results for OCHALB1G.jpg:
Dice Score: 0.9483
/root/autodl-tmp/data/imgs/OCHALB1K.jpg

Original Size Results for OCHALB1K.jpg:
Dice Score: 0.9444
/root/autodl-tmp/data/imgs/OCHF8JA8.jpg

Original Size Results for OCHF8JA8.jpg:
Dice Score: 0.8199
/root/autodl-tmp/data/imgs/OCHF8JAC.jpg

Original Size Results for OCHF8JAC.jpg:
Dice Score: 0.9099
/root/autodl-tmp/data/imgs/OCHF8JAI.jpg

Original Size Results for OCHF8JAI.jpg:
Dice Score: 0.9041
/root/autodl-tmp/data/imgs/OCHF8JAM.jpg

Original Size Results for OCHF8JAM.jpg:
Dice Score: 0.9419
/root/autodl-tmp/data/imgs/OCHFHDQU.jpg

Original Size Results for OCHFHDQU.jpg:
Dice Score: 0.8970


In [7]:
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import os
from matplotlib.colors import LinearSegmentedColormap

def load_rf_data(rfimg_path):
    """加载原始RF数据（不应用任何处理）"""
    mat_data = sio.loadmat(rfimg_path)
    frameRF = mat_data['frameRF']
    return frameRF[0]['data'][0]

def process_rf_data(rf_data, use_log=True):
    """处理RF数据（可选择是否应用log变换）"""
    rf_data = np.abs(rf_data)
    if use_log:
        rf_data = np.log(1+rf_data + 1e-6)
    rf_data = (rf_data - rf_data.min()) / (rf_data.max() - rf_data.min() + 1e-6)
    return rf_data

def create_red_green_cmap():
    """创建红(负值)-绿(正值)颜色映射"""

    colors = [(1,0.2,0.2), (0.8,0.8,0.8), (0.2,1,0.2)]      # 深绿  # 红 -> 黑 -> 绿
    return LinearSegmentedColormap.from_list('red_green', colors)

# 选择示例RF文件路径
rfimg_path = "/root/autodl-tmp/data/rfs/OCHADNGK_RF.mat"
output_dir = '/root/TransUNet_fusion/rf_data'
os.makedirs(output_dir, exist_ok=True)

# 加载原始数据
raw_rf = load_rf_data(rfimg_path)

# 设置全局绘图参数
plt.rcParams['figure.dpi'] = 300  # 提高分辨率
plt.rcParams['font.size'] = 10

# 计算长宽比
aspect_ratio = raw_rf.shape[0] / raw_rf.shape[1]

def plot_square(data, filename, use_log_scale=False, is_red_green=False):
    """通用绘图函数（无坐标、无标题）"""
    plt.figure(figsize=(8, 8))
    
    # 动态范围设置
    if is_red_green:
        vmax = np.max(np.abs(data))
        vmin = -vmax
        cmap = create_red_green_cmap()
        norm = None
    else:
        vmin = np.percentile(data, 1)
        vmax = np.percentile(data, 99)
        cmap = 'gray'
        norm = LogNorm(vmin=max(vmin, 1e-6), vmax=vmax) if use_log_scale else plt.Normalize(vmin=vmin, vmax=vmax)
    
    plt.imshow(data,
              cmap=cmap,
              norm=norm,
              aspect=1/aspect_ratio,
              extent=[0, data.shape[1], data.shape[0], 0])
    
    plt.axis('off')  # 完全关闭坐标轴
    plt.savefig(os.path.join(output_dir, filename), 
               bbox_inches='tight', 
               pad_inches=0,
               dpi=300)
    plt.close()

# # 1. 原始数据（绝对值）
# plot_square(np.abs(raw_rf), 'raw_rf_abs.png')

# 2. 原始数据（红绿正负）
plot_square(raw_rf, 'raw_rf_red_green.png', is_red_green=True)

# 3. 处理后数据（带log）
# rf_with_log = process_rf_data(raw_rf, use_log=True)
# plot_square(rf_with_log, 'processed_with_log.png', use_log_scale=True)

# 4. 处理后数据（不带log）
# rf_without_log = process_rf_data(raw_rf, use_log=False)
# plot_square(rf_without_log, 'processed_without_log.png')

IndentationError: unexpected indent (419573415.py, line 24)

In [2]:
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import os
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image  # 用于调整图像大小

def load_rf_data(rfimg_path):
    """加载原始RF数据（不应用任何处理）"""
    mat_data = sio.loadmat(rfimg_path)
    frameRF = mat_data['frameRF']
    return frameRF[0]['data'][0]

def resize_to_224x224(data):
    """使用PIL将数据调整为224x224"""
    img = Image.fromarray(data)
    return img.resize((224, 224), Image.LANCZOS)

def plot_rf_224x224(data, title, filename, output_dir):
    """绘制并保存224x224的RF图像"""
    # 调整大小
    resized_data = resize_to_224x224(data)
    
    plt.figure(figsize=(5, 5))
    plt.imshow(resized_data, cmap='gray', aspect='auto')
    plt.axis('off')  # 关闭坐标轴
    plt.savefig(
        os.path.join(output_dir, filename),
        bbox_inches='tight',
        pad_inches=0,
        dpi=300
    )
    plt.close()

def plot_rf_pos_neg_colored(data, title, filename, output_dir):
    """绘制正值为绿色、负值为红色的RF图像"""
    # 创建自定义颜色映射
    colors = [(1, 0, 0), (0, 0, 0), (0, 1, 0)]  # 红 -> 黑 -> 绿
    cmap = LinearSegmentedColormap.from_list('pos_neg', colors)
    
    # 调整大小
    resized_data = resize_to_224x224(data)
    
    plt.figure(figsize=(5, 5))
    plt.imshow(
        resized_data,
        cmap=cmap,
        aspect='auto',
        vmin=-np.max(np.abs(data)),  # 对称范围
        vmax=np.max(np.abs(data))
    )
    plt.axis('off')
    plt.savefig(
        os.path.join(output_dir, filename),
        bbox_inches='tight',
        pad_inches=0,
        dpi=300
    )
    plt.close()

# 选择示例RF文件路径
rfimg_path = "/root/autodl-tmp/data/rfs/OCHADNGK_RF.mat"
output_dir = '/root/TransUNet_fusion/rf_data'

# 确保输出目录存在
os.makedirs(output_dir, exist_ok=True)

# 加载原始数据
raw_rf = load_rf_data(rfimg_path)

# 1. 保存224x224的原始RF图像（灰度）
plot_rf_224x224(
    np.abs(raw_rf),  # 取绝对值避免负值影响
    "Original RF (224x224)",
    "raw_rf_224x224.png",
    output_dir
)

# 2. 保存正负值彩色图（红-绿）
plot_rf_pos_neg_colored(
    raw_rf,  # 使用原始数据（包含正负值）
    "RF (Positive=Green, Negative=Red)",
    "raw_rf_pos_neg_colored.png",
    output_dir
)

In [6]:
print("\n=== 原始信号值分析 ===")
print(f"数据类型: {raw_rf.dtype}")
print(f"数组形状: {raw_rf.shape}")
print(f"绝对值范围: [{np.abs(raw_rf).min():.3e}, {np.abs(raw_rf).max():.3e}]")
print(f"实部范围: [{raw_rf.real.min():.3e}, {raw_rf.real.max():.3e}]")
print(f"虚部范围: [{raw_rf.imag.min():.3e}, {raw_rf.imag.max():.3e}]")


=== 原始信号值分析 ===
数据类型: float64
数组形状: (8512, 280)
绝对值范围: [9.063e-09, 1.000e+00]
实部范围: [-8.448e-01, 1.000e+00]
虚部范围: [0.000e+00, 0.000e+00]
