# In [None]:  (Jupyter notebook cell marker kept from the export)
import os
import nibabel as nib
import numpy as np
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
from skimage.measure import label, regionprops
from joblib import Parallel, delayed
import warnings
import psutil
import logging
from datetime import datetime

# Initialise logging: each run writes to its own timestamped file
# (segmentation_log_YYYYMMDD_HHMMSS.log) in the current working directory.
# The hour directive must be "%H"; a bare "H" would be emitted literally.
logging.basicConfig(
    filename=f'segmentation_log_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Constants
MIN_CLUSTERS = 3
MAX_CLUSTERS = 100
CLUSTER_SEARCH_RANGE = range(5, 51, 5)  # candidate cluster counts (5, 10, ..., 50) tried when searching for the best k
MIN_SAMPLE_POINTS = 50  # minimum number of sample points for the cluster-count search
MAX_FEATURES = 250000  # maximum number of feature points
BATCH_SIZE = 64  # chunk size (passed to MiniBatchKMeans as batch_size)

def check_memory_safety(required_mb):
    """Raise ``MemoryError`` if the available RAM cannot cover the request.

    Args:
        required_mb: Amount of memory the caller is about to use, in MiB.

    Raises:
        MemoryError: When the memory currently reported as available by
            ``psutil`` is below ``required_mb`` plus a 20 % safety margin.
    """
    bytes_per_mib = 1024 ** 2
    free_mem = psutil.virtual_memory().available / bytes_per_mib
    needed_with_margin = required_mb * 1.2
    if free_mem < needed_with_margin:
        raise MemoryError(f"需要{required_mb}MB内存，当前可用仅{free_mem:.1f}MB")

def get_voxel_size(img):
    """Return the zooms stored in the image header.

    Args:
        img: Object exposing ``header.get_zooms()`` (e.g. a loaded nibabel
            image).

    Returns:
        Whatever ``img.header.get_zooms()`` returns, unchanged.
    """
    header = img.header
    zooms = header.get_zooms()
    return zooms

def normalize_data(data, roi_mask):
    """Z-score normalise ``data`` in place using statistics of the ROI voxels.

    Mean and standard deviation are computed only from voxels selected by
    ``roi_mask``, but the resulting transform is applied to the whole volume.
    Statistics are gathered slice by slice along axis 2 so that no full-size
    temporary is created.

    Args:
        data: Array with at least three dimensions; modified in place.
            NOTE(review): an integer-typed array would truncate the
            normalised values on assignment - callers in this file pass
            float32.
        roi_mask: Boolean array indexable as ``roi_mask[:, :, z]`` for every
            slice of ``data``.

    Returns:
        The same ``data`` object. It is returned untouched when the ROI is
        empty. NaN/inf values produced by the transform are replaced by 0.
    """
    if np.count_nonzero(roi_mask) == 0:
        return data

    # Slice-wise accumulation of sum and sum of squares.
    total = 0.0
    total_sq = 0.0
    count = 0
    for z in range(data.shape[2]):
        slice_mask = roi_mask[:, :, z]
        if not np.any(slice_mask):
            continue
        # Promote to float64 *before* squaring: squaring in float32 loses
        # precision (and can overflow) before the accumulation even starts.
        values = data[:, :, z][slice_mask].astype(np.float64)
        total += values.sum()
        total_sq += np.square(values).sum()
        count += np.count_nonzero(slice_mask)

    if count == 0:
        return data

    # Plain Python floats keep the slice arithmetic below in data's own dtype.
    mean_val = float(total / count)
    # E[x^2] - E[x]^2 can come out slightly negative through rounding when
    # the ROI is (nearly) constant; clamping avoids sqrt(negative) -> NaN,
    # which would otherwise turn the whole volume into zeros.
    variance = max(float(total_sq / count) - mean_val ** 2, 0.0)
    scale = float(np.sqrt(variance)) + 1e-8

    # Slice-wise normalisation, written back in place.
    for z in range(data.shape[2]):
        data[:, :, z] = (data[:, :, z] - mean_val) / scale

    np.nan_to_num(data, copy=False, nan=0.0, posinf=0.0, neginf=0.0)
    return data

def calculate_silhouette(features, k):
    """Return the silhouette score of a ``k``-cluster MiniBatchKMeans fit.

    Args:
        features: 2-D array of samples (rows) by features (columns).
        k: Number of clusters to evaluate.

    Returns:
        The silhouette score of the fitted labels, ``0`` when ``k <= 1``
        (a single cluster has no silhouette), or ``-1`` when the score is
        undefined for this input: fewer samples than clusters, as many
        clusters as samples, or a fit that yields fewer than two distinct
        labels.
    """
    n_samples = len(features)
    if n_samples < k:
        return -1
    if k <= 1:
        # Nothing to score, so skip the fit entirely.
        return 0
    if n_samples == k:
        # silhouette_score needs at most n_samples - 1 distinct labels.
        return -1

    kmeans = MiniBatchKMeans(
        n_clusters=k,
        batch_size=BATCH_SIZE,
        n_init=3,
        random_state=42
    )
    labels = kmeans.fit_predict(features)

    # Clusters may end up empty; silhouette_score raises ValueError unless
    # 2 <= n_labels <= n_samples - 1, and one such error would abort the
    # whole search over k in the caller.
    n_labels = len(np.unique(labels))
    if n_labels < 2 or n_labels >= n_samples:
        return -1
    return silhouette_score(features, labels)

def find_optimal_clusters(features):
    """Pick the number of clusters with the highest silhouette score.

    Every value in ``CLUSTER_SEARCH_RANGE`` is scored in parallel with
    ``calculate_silhouette``.

    Args:
        features: 2-D array of samples (rows) by features (columns).

    Returns:
        Tuple ``(optimal_k, silhouette)``. With fewer than
        ``MIN_SAMPLE_POINTS`` samples no search is run: a heuristic k
        (about one cluster per ten samples, within
        ``[MIN_CLUSTERS, MAX_CLUSTERS]`` and never above the number of
        samples) is returned with score ``0``. Any failure is logged and
        ``(MIN_CLUSTERS, 0)`` is returned.
    """
    try:
        n_points = len(features)
        # Cap the number of points to keep the search fast.
        sample_size = min(MAX_FEATURES, n_points)
        if sample_size < MIN_SAMPLE_POINTS:
            heuristic_k = min(MAX_CLUSTERS, max(MIN_CLUSTERS, sample_size // 10))
            # Never ask for more clusters than there are points, otherwise
            # the later k-means fit fails for very small ROIs.
            return max(1, min(heuristic_k, sample_size)), 0

        if sample_size < n_points:
            sample_idx = np.random.choice(n_points, sample_size, replace=False)
            sample_features = features[sample_idx]
        else:
            # All points are used: a random permutation would only add a
            # copy and make the (seeded) k-means runs non-reproducible.
            sample_features = features

        # Score every candidate k in parallel.
        silhouettes = Parallel(n_jobs=-1)(
            delayed(calculate_silhouette)(sample_features, k)
            for k in CLUSTER_SEARCH_RANGE
        )

        best = int(np.argmax(silhouettes))
        return CLUSTER_SEARCH_RANGE[best], silhouettes[best]
    except Exception as e:
        # Deliberate best-effort: fall back to the minimum cluster count.
        logging.warning(f"寻找最佳聚类数失败: {str(e)}")
        return MIN_CLUSTERS, 0

def memory_safe_clustering(points, intensities, optimal_k):
    """Cluster voxels on (coordinates, intensity) with MiniBatchKMeans.

    Args:
        points: ``(n, d)`` array of voxel coordinates.
        intensities: ``(n, 1)`` array of voxel intensities, row-aligned with
            ``points``.
        optimal_k: Requested number of clusters. It is reduced to ``n`` when
            there are fewer samples than clusters, since k-means cannot fit
            more clusters than samples.

    Returns:
        1-D integer array with one cluster label per row of ``points``.
        Empty input yields an empty array; on ``MemoryError`` every voxel is
        assigned label 0.
    """
    n_points = len(points)
    if n_points == 0:
        # Nothing to cluster; fitting k-means on no samples would raise.
        return np.zeros(0, dtype=int)

    try:
        # 32-bit floats halve the footprint of the feature matrix.
        features = np.hstack([
            points.astype(np.float32),
            intensities.astype(np.float32)
        ])

        n_clusters = max(1, min(int(optimal_k), n_points))
        if n_clusters != optimal_k:
            logging.warning(
                f"聚类数{optimal_k}超过样本数{n_points}，已调整为{n_clusters}"
            )

        # Mini-batch training keeps per-iteration memory bounded.
        kmeans = MiniBatchKMeans(
            n_clusters=n_clusters,
            batch_size=BATCH_SIZE,
            max_iter=100,
            n_init=3,
            random_state=42
        )
        kmeans.fit(features)
        return kmeans.labels_
    except MemoryError:
        logging.warning("内存不足，使用简化聚类")
        return np.zeros(n_points, dtype=int)

def _smooth_labels_within_mask(segments, sigma=0.5):
    """Smooth a label map in place by Gaussian-weighted voting per label."""
    foreground = segments > 0
    if not foreground.any():
        return segments

    # Work on the bounding box of the foreground only. Everything outside it
    # is background, so zero padding (mode='constant') gives the same scores
    # as filtering the full volume.
    bbox = tuple(
        slice(int(idx.min()), int(idx.max()) + 1)
        for idx in np.nonzero(foreground)
    )
    crop = segments[bbox]
    crop_fg = foreground[bbox]

    smoothed = crop.copy()
    best_score = np.zeros(crop.shape, dtype=np.float32)
    for value in np.unique(crop[crop_fg]):
        score = gaussian_filter(
            (crop == value).astype(np.float32),
            sigma=sigma, mode='constant', cval=0.0
        )
        # Strict ">" makes ties resolve to the lowest label, deterministically.
        wins = crop_fg & (score > best_score)
        smoothed[wins] = value
        best_score[wins] = score[wins]

    segments[bbox] = smoothed
    return segments


def postprocess_segmentation(segments, min_voxels=20):
    """Smooth a label map and drop small connected foreground components.

    Labels are categorical, so the label values themselves are never
    averaged: doing so would create label values that do not exist in the
    input and leak labels into the background. Instead each label's
    indicator image is Gaussian-smoothed (sigma 0.5) and every foreground
    voxel takes the label with the highest smoothed score. Background voxels
    are left untouched.

    Args:
        segments: Integer label array; 0 is background. Not modified.
        min_voxels: Connected foreground components (full connectivity,
            all labels merged) with fewer voxels than this are set to 0.

    Returns:
        A new ``int16`` array of the same shape as ``segments``.
    """
    from scipy.ndimage import generate_binary_structure
    from scipy.ndimage import label as nd_label

    segments = np.array(segments, dtype=np.int16)  # private copy
    segments = _smooth_labels_within_mask(segments, sigma=0.5)

    # Remove small regions. An all-ones structuring element gives full
    # connectivity (26-neighbourhood in 3-D).
    structure = generate_binary_structure(segments.ndim, segments.ndim)
    components, n_components = nd_label(segments > 0, structure=structure)
    if n_components == 0:
        return segments

    # One bincount replaces a full-volume comparison per component.
    sizes = np.bincount(components.ravel(), minlength=n_components + 1)
    too_small = sizes < min_voxels
    too_small[0] = False  # index 0 is the background, not a component
    segments[too_small[components]] = 0
    return segments

def process_single_image(args):
    """Segment the ROI of one image into clusters and optionally save it.

    Pipeline: load image and mask, z-score normalise on the ROI, choose the
    number of clusters on a subsample of ROI voxels, cluster all ROI voxels
    on (coordinates, intensity), post-process, and write the label map.

    Args:
        args: Tuple ``(image_path, mask_path, output_dir,
            min_lesion_size_mm3)``. A falsy ``output_dir`` skips saving.

    Returns:
        Tuple ``(segments, optimal_k, image_path)``. ``segments`` is an
        ``int16`` label volume (0 = background). ``(None, 0, image_path)``
        is returned when the mask is empty or any step fails; failures are
        logged, never raised.
    """
    image_path, mask_path, output_dir, min_lesion_size_mm3 = args
    try:
        # 1. Validate input files.
        if not all(os.path.exists(p) for p in [image_path, mask_path]):
            raise FileNotFoundError("输入文件不存在")

        # 2. Check memory from the image shape *before* the voxel data is
        #    read, so the check can still prevent the allocation.
        img = nib.load(image_path)
        voxel_count = np.prod(img.shape)
        required_mb = voxel_count * 4 / (1024 ** 2)  # float32 = 4 bytes
        check_memory_safety(required_mb)

        # Read directly as float32 to avoid a float64 temporary.
        data = img.get_fdata(dtype=np.float32)
        roi = nib.load(mask_path).get_fdata(dtype=np.float32) != 0

        if roi.shape != data.shape:
            raise ValueError(
                f"图像与掩膜形状不一致: {data.shape} vs {roi.shape}"
            )

        if not np.any(roi):
            logging.info(f"{os.path.basename(image_path)}: 无有效ROI")
            return None, 0, image_path

        # 3. Pre-processing.
        voxel_size = get_voxel_size(img)
        # Only the first three zooms are spatial; a 4th one (if present)
        # must not enter the voxel volume.
        voxel_volume = np.prod(voxel_size[:3])
        data = normalize_data(data, roi)
        coords = np.argwhere(roi)  # integer voxel indices, C order
        points = coords.astype(np.float32)

        # 4. Feature subsample used only for choosing k.
        sample_step = max(1, len(points) // MIN_SAMPLE_POINTS)
        sample_coords = coords[::sample_step]
        sample_intensities = data[tuple(sample_coords.T)].reshape(-1, 1)
        features = np.hstack([points[::sample_step], sample_intensities])

        # 5. Choose the number of clusters.
        optimal_k, sil_score = find_optimal_clusters(features)
        logging.info(f"{os.path.basename(image_path)}: 使用{optimal_k}个聚类 (轮廓系数={sil_score:.2f})")

        # 6. Cluster all ROI voxels. data[roi] follows the same C order as
        #    np.argwhere(roi), so rows stay aligned with `points`.
        intensities = data[roi].reshape(-1, 1).astype(np.float32)
        labels = memory_safe_clustering(points, intensities, optimal_k)

        segments = np.zeros(data.shape, dtype=np.int16)
        segments[tuple(coords.T)] = labels + 1  # shift so 0 stays background

        # 7. Post-processing: convert the minimum volume (mm^3) to voxels.
        min_voxels = max(1, int(min_lesion_size_mm3 / voxel_volume))
        segments = postprocess_segmentation(segments, min_voxels)

        # 8. Save the result under the image's own file name.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
            output_path = os.path.join(output_dir, os.path.basename(image_path))
            nib.save(nib.Nifti1Image(segments, img.affine), output_path)

        return segments, optimal_k, image_path

    except Exception as e:
        # Worker boundary: one bad sample must not abort the whole batch.
        logging.error(f"处理失败 {os.path.basename(image_path)}: {str(e)}", exc_info=True)
        return None, 0, image_path

def batch_process(images_dir, masks_dir, output_dir, min_lesion_size_mm3=50, n_jobs=4):
    """Process every ``.nii.gz`` image/mask pair in parallel and summarise.

    Images and masks are paired by position after sorting the file names of
    each directory. NOTE(review): only the file counts are compared, so the
    two directories are assumed to sort into matching order - confirm
    against the naming convention.

    Args:
        images_dir: Directory containing the ``.nii.gz`` images.
        masks_dir: Directory containing the ``.nii.gz`` masks.
        output_dir: Destination for label maps and
            ``segmentation_stats.csv``. A falsy value skips the CSV.
        min_lesion_size_mm3: Forwarded to ``process_single_image``.
        n_jobs: Number of parallel workers; halved when system memory usage
            is above 90 %.

    Returns:
        ``pandas.DataFrame`` with one row per successfully processed sample
        (empty, with the same columns, when none succeeded).

    Raises:
        ValueError: If an input directory is missing or the image and mask
            counts differ.
    """
    # Validate input directories.
    if not all(os.path.isdir(d) for d in [images_dir, masks_dir]):
        raise ValueError("输入目录无效")

    # Collect file lists.
    image_files = sorted(f for f in os.listdir(images_dir) if f.endswith('.nii.gz'))
    mask_files = sorted(f for f in os.listdir(masks_dir) if f.endswith('.nii.gz'))

    if len(image_files) != len(mask_files):
        raise ValueError("图像与掩膜数量不匹配")

    # Build one argument tuple per sample.
    args_list = [(os.path.join(images_dir, img),
                  os.path.join(masks_dir, mask),
                  output_dir,
                  min_lesion_size_mm3) for img, mask in zip(image_files, mask_files)]

    # Back off when the machine is already short on memory.
    if psutil.virtual_memory().percent > 90:
        logging.warning("内存使用过高，建议减少并行任务数")
        n_jobs = max(1, n_jobs // 2)

    # Parallel processing.
    logging.info(f"开始处理 {len(image_files)} 个样本 (并行数={n_jobs})")
    results = Parallel(n_jobs=n_jobs, max_nbytes=None)(
        delayed(process_single_image)(args) for args in tqdm(args_list))

    # Collect per-sample statistics (voxel counts per label value).
    stats = []
    for seg, k, img_path in results:
        if seg is None:
            continue
        _, counts = np.unique(seg[seg > 0], return_counts=True)
        has_regions = len(counts) > 0
        stats.append({
            'sample': os.path.basename(img_path),
            'optimal_clusters': k,
            'total_voxels': np.sum(seg > 0),
            'min_region_size': np.min(counts) if has_regions else 0,
            'median_region_size': np.median(counts) if has_regions else 0,
            'max_region_size': np.max(counts) if has_regions else 0
        })

    # Explicit columns keep the frame (and CSV header) well-formed even when
    # no sample succeeded.
    columns = ['sample', 'optimal_clusters', 'total_voxels',
               'min_region_size', 'median_region_size', 'max_region_size']
    stats_df = pd.DataFrame(stats, columns=columns)

    # Save statistics; the directory may not exist yet if nothing was saved.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        stats_path = os.path.join(output_dir, "segmentation_stats.csv")
        stats_df.to_csv(stats_path, index=False)

    # Log the summary.
    logging.info("\n===== 处理完成 =====")
    logging.info(f"成功处理: {len(stats)}/{len(image_files)}")
    if stats_df.empty:
        # Aggregating an empty frame is meaningless; stop here.
        logging.warning("没有成功处理的样本，跳过统计摘要")
        return stats_df

    logging.info(f"平均聚类数: {stats_df['optimal_clusters'].mean():.1f}")
    logging.info("区域大小分布 (voxels):")
    logging.info(f"- 最小: {stats_df['min_region_size'].min()}")
    logging.info(f"- 中位数: {stats_df['median_region_size'].median()}")
    logging.info(f"- 最大: {stats_df['max_region_size'].max()}")

    return stats_df

if __name__ == "__main__":
    # Run configuration: input/output locations and processing parameters.
    IMAGES_DIR = r'E:\00000wwc\000houai\isurvival'
    MASKS_DIR = r'E:\00000wwc\000houai\msurvival'
    OUTPUT_DIR = r'E:\00000wwc\000houai\habitat_optimized_fina2'
    MIN_LESION_SIZE_MM3 = 50  # minimum lesion volume (mm^3)
    N_JOBS = 6  # number of parallel workers

    # Make sure the output directory exists before anything is written.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Run the batch; any failure is logged with its traceback and reported
    # on stdout instead of propagating.
    try:
        stats = batch_process(
            IMAGES_DIR,
            MASKS_DIR,
            OUTPUT_DIR,
            min_lesion_size_mm3=MIN_LESION_SIZE_MM3,
            n_jobs=N_JOBS,
        )
        print("处理完成！结果保存在:", OUTPUT_DIR)
    except Exception as e:
        logging.critical(f"主程序失败: {e}", exc_info=True)
        print(f"运行失败: {e}")

# Cell output captured in the export (tqdm progress bar):
#  82%|███████████████████████████████████████████████████████████████████▋              | 66/80 [01:47<00:24,  1.76s/it]