### ZscoreSustain中有这么一条：

    def subtype_and_stage_individuals_newData(self, data_new, samples_sequence, samples_f, N_samples):

        numStages_new                   = self.__sustainData.getNumStages() #data_new.shape[1]
        sustainData_newData             = ZScoreSustainData(data_new, numStages_new)

        ml_subtype,         \
        prob_ml_subtype,    \
        ml_stage,           \
        prob_ml_stage,      \
        prob_subtype,       \
        prob_stage,         \
        prob_subtype_stage          = self.subtype_and_stage_individuals(sustainData_newData, samples_sequence, samples_f, N_samples)

        return ml_subtype, prob_ml_subtype, ml_stage, prob_ml_stage, prob_subtype, prob_stage, prob_subtype_stage

    # ********************* STATIC METHODS
    @staticmethod
    def linspace_local2(a, b, N, arange_N):
        return a + (b - a) / (N - 1.) * arange_N

    @staticmethod

In [10]:
# Load libraries

import os
import pandas
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pySuStaIn
import statsmodels.formula.api as smf
from scipy import stats
import sklearn.model_selection


In [20]:
# Load new data
data = pandas.read_excel('/Users/qiuyuyue/Documents/MRI/data/SuStaIn/mri_from4500_renamed_with_zscores_17zcomined.xlsx')

# store our biomarker labels as a variable
biomarkers = [col for col in data.columns if col.endswith("_zcombined")]

# 对除了 "ChP_sumv_w" 之外的列乘以 -1
for col in biomarkers:
    if col != "ChP_zcombined":
        data[col] = data[col] * (-1)

print(biomarkers)
 

['Amygdala_zcombined', 'Hippo_zcombined', 'Parahippo_zcombined', 'Entorhinal_zcombined', 'Temp_Neo_zcombined', 'PCC_zcombined', 'SPL_zcombined', 'IPL_zcombined', 'Precuneus_zcombined', 'OFC_zcombined', 'PFC_zcombined', 'Lat_Occi_zcombined', 'Med_Occi_zcombined', 'Insula_zcombined', 'BG_zcombined', 'CC_zcombined', 'ChP_zcombined']


In [21]:
# 检查并去除含有 NaN 或 Inf 值的行

# 1. 检查 NaN 或 Inf
# 使用 .isin() 检查是否有 NaN 或 Inf，然后使用 .any(axis=1) 找出包含这些值的行
# ~ 符号用于取反，即保留那些不包含 NaN 或 Inf 的行
rows_to_keep = ~data[biomarkers].replace([np.inf, -np.inf], np.nan).isna().any(axis=1)

# 2. 筛选数据
data_cleaned = data[rows_to_keep].copy()

# 3. 打印结果
num_removed = len(data) - len(data_cleaned)
print(f"\n--- 数据清洗结果 ---")
print(f"原始行数: {len(data)}")
print(f"移除的行数 (含 NaN/Inf): {num_removed}")
print(f"保留的行数 (干净数据): {len(data_cleaned)}")

# 将 data 变量更新为清洗后的数据，以便后续步骤使用
data = data_cleaned


--- 数据清洗结果 ---
原始行数: 3373
移除的行数 (含 NaN/Inf): 0
保留的行数 (干净数据): 3373


In [22]:
# 基本统计
print(data[biomarkers].describe().T)

                       count      mean       std        min       25%  \
Amygdala_zcombined    3373.0  1.186248  2.347180 -32.163653 -0.153969   
Hippo_zcombined       3373.0  3.621499  1.887754 -26.614790  2.536439   
Parahippo_zcombined   3373.0  1.240936  1.553245 -10.317692  0.257861   
Entorhinal_zcombined  3373.0  1.161498  1.442170 -25.951521  0.359245   
Temp_Neo_zcombined    3373.0  1.507117  1.250126  -3.193455  0.630121   
PCC_zcombined         3373.0  1.208086  1.135703  -5.666528  0.429761   
SPL_zcombined         3373.0  1.124454  1.453967 -16.961349  0.237877   
IPL_zcombined         3373.0  0.996915  1.192973  -4.667249  0.224958   
Precuneus_zcombined   3373.0  1.028417  1.721010 -19.039533 -0.086182   
OFC_zcombined         3373.0  1.279857  1.513672 -24.459872  0.449333   
PFC_zcombined         3373.0  0.925015  1.066287  -4.205288  0.252755   
Lat_Occi_zcombined    3373.0  1.138883  1.307613 -28.392792  0.414276   
Med_Occi_zcombined    3373.0  0.327286  1.154182 -1

In [23]:
import numpy as np
import pandas as pd
# ----------------- 数据截断（Clipping）代码 -----------------

clip_min = -4.0
clip_max = 7.0

for col in biomarkers:
    # 截断小于 clip_min 的值
    original_min = data[col].min()
    data[col] = np.maximum(data[col], clip_min)
    
    # 截断大于 clip_max 的值
    original_max = data[col].max()
    data[col] = np.minimum(data[col], clip_max)
    
    # 打印受影响的列信息 (可选)
    if original_min < clip_min or original_max > clip_max:
         print(f"列: {col:<30} | 原始 Min/Max: {original_min:.4f} / {original_max:.4f}")

# 打印新的基本统计，确认截断生效
print("\n--- 截断后的数据基本统计 ---")
print(data[biomarkers].describe().T)

列: Amygdala_zcombined             | 原始 Min/Max: -32.1637 / 10.8185
列: Hippo_zcombined                | 原始 Min/Max: -26.6148 / 11.9342
列: Parahippo_zcombined            | 原始 Min/Max: -10.3177 / 8.1492
列: Entorhinal_zcombined           | 原始 Min/Max: -25.9515 / 6.6882
列: Temp_Neo_zcombined             | 原始 Min/Max: -3.1935 / 9.5991
列: PCC_zcombined                  | 原始 Min/Max: -5.6665 / 8.2118
列: SPL_zcombined                  | 原始 Min/Max: -16.9613 / 10.0913
列: IPL_zcombined                  | 原始 Min/Max: -4.6672 / 9.3989
列: Precuneus_zcombined            | 原始 Min/Max: -19.0395 / 12.0559
列: OFC_zcombined                  | 原始 Min/Max: -24.4599 / 13.3159
列: PFC_zcombined                  | 原始 Min/Max: -4.2053 / 9.6094
列: Lat_Occi_zcombined             | 原始 Min/Max: -28.3928 / 9.3537
列: Med_Occi_zcombined             | 原始 Min/Max: -14.2121 / 6.2627
列: Insula_zcombined               | 原始 Min/Max: -20.3825 / 14.4250
列: BG_zcombined                   | 原始 Min/Max: -12.6490 / 12.8767
列: ChP_

In [24]:
# 准备new data矩阵
import numpy as np
# 选出 biomarker 数据
new_data = data[biomarkers].to_numpy()  # shape = n_samples × n_ROI

# 1️⃣ 准备新样本数据
new_data_matrix = data[biomarkers].values
numStages_new = len(biomarkers)  # 或用字典里已有的数据



In [25]:
N = len(biomarkers)         # number of biomarkers

SuStaInLabels = biomarkers
Z_vals = np.array([[1,2,3]]*N)     # Z-scores for each biomarker
Z_max  = np.array([5]*N)           # maximum z-score

#### 加载已有的model

In [26]:
pk = pandas.read_pickle('/Users/qiuyuyue/Documents/MRI/data/SuStaIn/Output8_原/pickle_files/Output_subtype3.pickle')
# let's take a look at all of the things that exist in SuStaIn's output (pickle) file
pk.keys()

dict_keys(['samples_sequence', 'samples_f', 'samples_likelihood', 'ml_subtype', 'prob_ml_subtype', 'ml_stage', 'prob_ml_stage', 'prob_subtype', 'prob_stage', 'prob_subtype_stage', 'ml_sequence_EM', 'ml_sequence_prev_EM', 'ml_f_EM', 'ml_f_prev_EM'])

In [27]:
 
from pySuStaIn import ZScoreSustainData, ZscoreSustain
# 1. 从 pickle 文件中提取 MCMC 样本
samples_sequence = pk["samples_sequence"]
samples_f = pk["samples_f"]
N_samples = samples_sequence.shape[0] # 使用 MCMC 采样的次数作为样本数

# 2. 初始化 ZscoreSustain 对象

# 注意：这些参数需要与你运行 MCMC 训练模型时使用的参数一致
N_startpoints = 50     # 随便设一个值，此处不跑MCMC，但初始化需要
N_S_max = 4            # 你的pickle文件名是 Output_subtype3，所以这里设置为3
N_iterations_MCMC = 10000 # 随便设一个值
output_folder = '/Users/qiuyuyue/Documents/MRI/data/SuStaIn/Output8_原/'
dataset_name = 'Output'
use_parallel_startpoints = False

# 实例化 ZscoreSustain 模型
sustain_object = pySuStaIn.ZscoreSustain(
    # 传入要预测的数据
    new_data, 
    Z_vals,
    Z_max,
    SuStaInLabels,
    N_startpoints,
    N_S_max, 
    N_iterations_MCMC, 
    output_folder, 
    dataset_name, 
    use_parallel_startpoints
)

# 3. 运行预测函数，获取每个样本的分型分期结果
# N_samples: 这里我们使用 MCMC 样本的总数，通常是可靠的选择。
ml_subtype, \
prob_ml_subtype, \
ml_stage, \
prob_ml_stage, \
prob_subtype, \
prob_stage, \
prob_subtype_stage = sustain_object.subtype_and_stage_individuals_newData(
    new_data, 
    samples_sequence, 
    samples_f, 
    N_samples
)

# 4. 将分型和分期结果添加到原始 DataFrame
data['ML_Subtype'] = ml_subtype
data['ML_Stage'] = ml_stage
data['ML_Subtype'] = data['ML_Subtype'] + 1


# 查看结果的分布
print("\n--- 最大似然子类型 (ML_Subtype) 分布 ---")
print(data['ML_Subtype'].value_counts())
print("\n--- 最大似然阶段 (ML_Stage) 分布 ---")
print(data['ML_Stage'].value_counts())


# 如果你想保存带有结果的数据集
data.to_excel('/Users/qiuyuyue/Documents/MRI/data/SuStaIn/mri_from4500_with_sustain_results_17ROI.xlsx', index=False)



--- 最大似然子类型 (ML_Subtype) 分布 ---
ML_Subtype
1.0    2394
3.0     348
2.0     331
4.0     300
Name: count, dtype: int64

--- 最大似然阶段 (ML_Stage) 分布 ---
ML_Stage
2.0     279
1.0     226
3.0     180
4.0     126
8.0     123
16.0    121
11.0    117
6.0     116
7.0     114
17.0    114
15.0    107
9.0     106
5.0     106
12.0    103
10.0    103
13.0    101
0.0      98
21.0     94
20.0     86
18.0     82
14.0     81
19.0     77
22.0     73
23.0     63
26.0     60
24.0     53
27.0     51
30.0     44
25.0     36
29.0     35
33.0     35
35.0     30
28.0     27
31.0     26
37.0     25
32.0     23
34.0     20
41.0     17
38.0     15
36.0     15
39.0     14
42.0     11
45.0      9
44.0      9
40.0      7
43.0      6
46.0      5
48.0      3
49.0      1
Name: count, dtype: int64
