# In [None]:
import sys
import os
sys.path.append('')  # 添加项目根目录到Python路径
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import scipy.io as sio
from PIL import Image
from networks.vit_seg_modeling import VisionTransformer_mixRf as ViT_seg_mixRf
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
from torchvision import transforms
# %% [markdown]
"""
## 1. 后处理函数定义
"""
# %%
def remove_small_regions(mask, min_area=500):
    """Remove 8-connected foreground components smaller than ``min_area`` pixels.

    Args:
        mask: 2-D array with foreground pixels at 1 and background at 0. It is
            scaled by 255 and cast to uint8 before labelling, so any pixel that
            is nonzero after that cast counts as foreground.
        min_area: minimum component area, in pixels, that is kept.

    Returns:
        A float array of the same shape. Pixels of kept components are 1.0 and
        everything else is 0.0.
    """
    mask = (mask * 255).astype(np.uint8)
    _, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
    # Build a keep/drop lookup table with one entry per label. Indexing it with
    # the label image once is O(pixels); rescanning the full image for every
    # component (labels == i in a loop) is O(components * pixels).
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False  # label 0 is the background and is never kept
    return keep[labels].astype(float)

def fill_holes(mask):
    """Fill the interior holes of foreground regions in a binary mask.

    The mask is scaled to 0/255 uint8. Contours are extracted with a two-level
    hierarchy (RETR_CCOMP), and every contour that has a parent is treated as
    a hole border and filled with 255.

    Returns:
        A float array in [0, 1] with the same shape as ``mask``.
    """
    canvas = (mask * 255).astype(np.uint8)
    found, tree = cv2.findContours(canvas, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
    if not found:
        return canvas.astype(float) / 255.0
    for outline, node in zip(found, tree[0]):
        parent = node[3]
        if parent == -1:
            continue  # top-level (outer) boundary, not a hole
        cv2.drawContours(canvas, [outline], 0, 255, -1)
    return canvas.astype(float) / 255.0
    
def make_convex(mask):
    """Replace each outer foreground region of a mask with its filled convex hull.

    Only external contours (RETR_EXTERNAL) are used, so inner holes are ignored.

    Returns:
        A float array in [0, 1] with the same shape as ``mask``.
    """
    binary = (mask * 255).astype(np.uint8)
    outer, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    hulls = [cv2.convexHull(outline) for outline in outer]
    result = np.zeros_like(binary)
    for hull in hulls:
        cv2.drawContours(result, [hull], 0, 255, -1)
    return result.astype(float) / 255.0
    
def postprocess_mask(pred_mask, min_area=500):
    """Turn a probability map into a cleaned binary mask.

    Steps, in order:
    1. Threshold at 0.5.
    2. Drop components smaller than ``min_area`` pixels.
    3. Fill holes.
    4. Replace each remaining region with its convex hull.

    Returns:
        A float array of 0.0/1.0 values.
    """
    steps = (
        lambda current: remove_small_regions(current, min_area),
        fill_holes,
        make_convex,
    )
    result = (pred_mask > 0.5).astype(float)
    for step in steps:
        result = step(result)
    return result

def load_transunet_model(model_path, img_size=224, vit_name='R50-ViT-B_16'):
    """Build a ``VisionTransformer_mixRf`` model and load trained weights into it.

    Args:
        model_path: path to a ``state_dict`` checkpoint readable by ``torch.load``.
        img_size: square input resolution the model is built for. The patch
            grid is ``int(img_size / 16)`` per side.
        vit_name: key into ``CONFIGS_ViT_seg`` that selects the backbone config.

    Returns:
        The model on CUDA, in eval mode, with one output class and
        ``n_skip = 3``.
    """
    import copy

    # Work on a copy so the shared module-level CONFIGS entry is not mutated
    # for any other code that builds a model from the same config.
    config = copy.deepcopy(CONFIGS_ViT_seg[vit_name])
    config.n_classes = 1
    config.n_skip = 3
    config.patches.grid = (int(img_size / 16), int(img_size / 16))

    model = ViT_seg_mixRf(config, img_size=img_size, num_classes=config.n_classes)
    # Load the checkpoint onto the CPU first so it does not matter which device
    # it was saved from. The whole model is moved to CUDA afterwards.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.cuda()
    model.eval()
    return model

# Configure paths. The ".." values appear to be placeholders in this copy of
# the script; point them at the real locations before running.
MODEL_PATH = ".."  # checkpoint file passed to load_transunet_model
DATA_DIR = ".."  # directory scanned for input '.jpg' images
MASK_DIR = ".."  # directory scanned for ground-truth '.png' masks
RF_DIR = ".."  # directory scanned for RF '.mat' files
OUTPUT_DIR = ".."  # destination for saved outputs; created below if missing
os.makedirs(OUTPUT_DIR, exist_ok=True)

def preprocess_image(image_path, img_size=224):
    """Load an image as grayscale and turn it into a normalized model input.

    The image is resized to ``img_size`` x ``img_size``. ``ToTensor`` scales it
    to [0, 1], and ``Normalize(0.5, 0.5)`` then maps it to [-1, 1].

    Returns:
        A CUDA float tensor of shape (1, 1, img_size, img_size).
    """
    img_transform = transforms.Compose([
        # Use the img_size argument rather than a hard-coded 224.
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    # convert('L') already yields a PIL image, so a numpy/ToPILImage round-trip
    # is unnecessary. The context manager closes the file deterministically.
    with Image.open(image_path) as src:
        img = src.convert('L')
    img = img_transform(img)  # (1, img_size, img_size)
    return img.unsqueeze(0).cuda()  # (1, 1, img_size, img_size)

def load_ground_truth(mask_path, img_size=224):
    """Load a ground-truth mask, resize it to ``img_size`` x ``img_size`` and binarize it.

    Returns:
        A float32 array of shape (img_size, img_size). It is 1.0 where the
        resized mask is nonzero and 0.0 elsewhere.
    """
    with Image.open(mask_path) as src:
        mask = src.convert('L')
    # Nearest-neighbour keeps the mask binary. PIL's default resampling filter
    # (bicubic in current Pillow) creates small intermediate values along the
    # edges, and the ``> 0`` test below would then grow the foreground.
    mask = mask.resize((img_size, img_size), Image.NEAREST)
    return (np.array(mask) > 0).astype(np.float32)


def visualize_results(image, true_mask, pred_mask, processed_mask, save_path=None):
    """Show the input, ground truth, raw prediction and post-processed mask side by side.

    Args:
        image: grayscale image to display.
        true_mask: ground-truth mask.
        pred_mask: raw prediction mask.
        processed_mask: prediction after post-processing.
        save_path: if given, the figure is also written there at 300 dpi.
    """
    panels = (
        (image, "Input Image"),
        (true_mask, "Ground Truth"),
        (pred_mask, "Raw Prediction"),
        (processed_mask, "Post-processed"),
    )
    fig, axes = plt.subplots(1, len(panels), figsize=(18, 6))
    for ax, (data, title) in zip(axes, panels):
        ax.imshow(data, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    if save_path:
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.show()
    # Release the figure. Otherwise, calling this once per sample keeps every
    # figure alive under non-interactive backends.
    plt.close(fig)

# Load the pretrained segmentation model.
model = load_transunet_model(MODEL_PATH)

# Collect test samples by file extension. The main loop pairs them purely by
# their position in these sorted lists, so the directories must hold
# matching file sets.
test_images = sorted(name for name in os.listdir(DATA_DIR) if name.endswith('.jpg'))
test_masks = sorted(name for name in os.listdir(MASK_DIR) if name.endswith('.png'))
test_rfs = sorted(name for name in os.listdir(RF_DIR) if name.endswith('.mat'))
def load_rf_image(rfimg_path, img_size=224):
    """Load an RF frame from a .mat file as a single-channel model input tensor.

    The frame is read from ``frameRF[0]['data'][0]``. It is converted to log
    magnitude (``log(|x| + 1e-6)``), min-max scaled to [0, 1], and resized to
    ``img_size`` x ``img_size``.

    Returns:
        Only the CUDA tensor, of shape (1, 1, img_size, img_size) when the
        stored frame is 2-D. The normalized array itself is not returned.
    """
    rf_transform = transforms.Compose([
        # NOTE(review): ToPILImage converts the float tensor to 8-bit, so the
        # normalized RF data is quantized to 256 levels before resizing. This
        # is kept as-is to match the preprocessing the model presumably saw;
        # confirm.
        transforms.ToPILImage(),
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ])
    mat_data = sio.loadmat(rfimg_path)
    frameRF = mat_data['frameRF']
    # Assumes this dataset's .mat layout: a struct array 'frameRF' with a
    # 'data' field. TODO confirm against the data source.
    rf_data = np.asarray(frameRF[0]['data'][0])

    # np.abs overflows on the most negative value of a signed integer type
    # (e.g. int16 -32768 stays negative), which would make the log produce
    # NaN. Cast integer samples to float first; float/complex data is left
    # as it is.
    if np.issubdtype(rf_data.dtype, np.integer):
        rf_data = rf_data.astype(np.float64)
    rf_data = np.log(np.abs(rf_data) + 1e-6)

    # Min-max scale to [0, 1]; the epsilon avoids division by zero on a flat frame.
    rf_data = (rf_data - rf_data.min()) / (rf_data.max() - rf_data.min() + 1e-6)
    rf_data = torch.from_numpy(rf_data).float().unsqueeze(0)  # (1, H, W) for a 2-D frame
    rf_tensor = rf_transform(rf_data)
    return rf_tensor.unsqueeze(0).cuda()
    
def load_original_image_mask(img_path, mask_path, rf_path):
    """Load the image and its mask at their original resolution.

    Args:
        img_path: image file, converted to grayscale.
        mask_path: mask file, converted to grayscale; any nonzero pixel counts
            as foreground.
        rf_path: not used by this function.

    Returns:
        ``(img_np, mask_np, original_size)``: the grayscale image as an array,
        a float32 0/1 mask, and the PIL ``size`` of the image (width, height).
    """
    with Image.open(img_path) as img_file, Image.open(mask_path) as mask_file:
        gray = img_file.convert('L')
        gray_mask = mask_file.convert('L')
    foreground = np.array(gray_mask) > 0
    return np.array(gray), foreground.astype(np.float32), gray.size

def resize_mask(mask, target_size):
    """Resize a binary mask to ``target_size``, given as PIL ``(width, height)``.

    The mask is scaled to 0/255 uint8, resampled with Lanczos, and thresholded
    again at 127.

    Returns:
        A float32 array of 0.0/1.0 values.
    """
    as_uint8 = (mask * 255).astype(np.uint8)
    resized = np.array(Image.fromarray(as_uint8).resize(target_size, Image.LANCZOS))
    return (resized > 127).astype(np.float32)

if not (len(test_images) == len(test_masks) == len(test_rfs)):
    # zip() below pairs files purely by sorted position and stops at the
    # shortest list, so unequal counts mean skipped or misaligned samples.
    print(f"Warning: found {len(test_images)} images, {len(test_masks)} masks "
          f"and {len(test_rfs)} RF files; extra files will be ignored and "
          f"pairs may be misaligned.")

for img_name, mask_name, rf_name in zip(test_images, test_masks, test_rfs):
    img_path = os.path.join(DATA_DIR, img_name)
    mask_path = os.path.join(MASK_DIR, mask_name)
    rf_path = os.path.join(RF_DIR, rf_name)
    print(img_path)
    original_img, original_mask, original_size = load_original_image_mask(img_path, mask_path, rf_path)

    img_tensor = preprocess_image(img_path)
    # Loaded for the currently disabled RF branch below; the model call
    # passes None as its second input.
    rf_tensor = load_rf_image(rf_path)

    with torch.no_grad():
        pred = model(img_tensor, None)
        # rf_aligned = model.rf_stn(rf_tensor).cpu().squeeze()
        pred_mask = torch.sigmoid(pred).squeeze().cpu().numpy()

    # Threshold the probability map, then resize it to the input image's size.
    pred_mask_original_size = resize_mask(pred_mask > 0.5, original_size)
    processed_mask_original_size = postprocess_mask(pred_mask_original_size)

    stem = os.path.splitext(img_name)[0]

    # Visualization at the original size (currently disabled).
    save_path = os.path.join(OUTPUT_DIR, f"result_{stem}.png")
    # visualize_results(original_img,
    #                 original_mask,
    #                 pred_mask_original_size,
    #                 processed_mask_original_size,
    #                 save_path)

    print(f"\nOriginal Size Results for {img_name}:")

    # Save the masks as lossless PNG. Reusing the input's .jpg name would apply
    # JPEG compression and leave non-binary values along the mask edges.
    cv2.imwrite(
        os.path.join(OUTPUT_DIR, f"pred_{stem}.png"),
        (pred_mask_original_size * 255).astype(np.uint8)
    )
    # Also save the post-processed mask so the post-processing result is kept.
    cv2.imwrite(
        os.path.join(OUTPUT_DIR, f"post_{stem}.png"),
        (processed_mask_original_size * 255).astype(np.uint8)
    )
    