In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from maml.sampling.direct import BirchClustering, DIRECTSampler, SelectKFromClusters
from dscribe.descriptors import SOAP
from ase.io import write
import os

# Number of clusters requested from the clustering step. It also names the
# output folder ("C" followed by the count), which is created on first use.
N_cluster = 5
output_dir = "C" + str(N_cluster)
os.makedirs(output_dir, exist_ok=True)

def get_soap(atom, species=None, r_cut=5, n_max=10, l_max=8):
    """Compute one SOAP descriptor per structure in ``atom``.

    Parameters
    ----------
    atom : iterable
        Structures to describe. Each item is passed unchanged to
        ``SOAP.create`` (e.g. the ``ase_atoms`` column of the DataFrame).
    species : list of str, optional
        Forwarded to ``SOAP(species=...)``. Defaults to ``['W']``.
    r_cut : float, optional
        Forwarded to ``SOAP(r_cut=...)``. Defaults to 5.
    n_max : int, optional
        Forwarded to ``SOAP(n_max=...)``. Defaults to 10.
    l_max : int, optional
        Forwarded to ``SOAP(l_max=...)``. Defaults to 8.

    Returns
    -------
    list
        The result of ``SOAP.create`` for every structure, in input order.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if species is None:
        species = ['W']

    # Dense, periodic descriptor built with average="inner".
    average_soap = SOAP(
        species=species,
        r_cut=r_cut,
        n_max=n_max,
        l_max=l_max,
        average="inner",
        periodic=True,
        sparse=False,
    )

    # tqdm only wraps the iterable to show progress; order is preserved.
    return [
        average_soap.create(iatom)
        for iatom in tqdm(atom, desc="Processing training atoms")
    ]

# NOTE(review): unpickling can execute arbitrary code; only load this file
# from a trusted source.
df = pd.read_pickle("PRM2020.pckl.gzip", compression="gzip")

# The SOAP matrix is cached on disk. Reuse the cache when it exists;
# otherwise compute it from the structures and store it for the next run.
soap_cache = 'my_soap.npy'
if os.path.exists(soap_cache):
    my_soap = np.load(soap_cache)
else:
    atom = df["ase_atoms"]
    my_soap = np.asarray(get_soap(atom))
    np.save(soap_cache, my_soap)

# Row i of my_soap is later used to index df.iloc[i], so a stale cache with a
# different number of rows would silently select the wrong structures.
if len(my_soap) != len(df):
    raise ValueError(
        f"{soap_cache} has {len(my_soap)} rows but the DataFrame has "
        f"{len(df)}; delete the cache file to recompute it."
    )

# =========================
# 3. Use DIRECTSampler to find the index of each cluster's "center"
# =========================
# my_soap is fed in directly, so no structure encoder is configured.
center_clustering = BirchClustering(n=N_cluster, threshold_init=0.001)
center_selector = SelectKFromClusters(selection_criteria="center", n_sites=None)
DIRECT_partitioner = DIRECTSampler(
    structure_encoder=None,
    clustering=center_clustering,
    select_k_from_clusters=center_selector,
)
DIRECT_partition = DIRECT_partitioner.fit_transform(my_soap)
selected_indexes = DIRECT_partition['selected_indexes']  # indexes of the center points

# Write the atomic structures of the centers to a single extxyz trajectory.
outfile_center_xyz = os.path.join(output_dir, "selected.extxyz")
df_center = df.iloc[selected_indexes]
atoms_center_list = df_center["ase_atoms"].tolist()
write(outfile_center_xyz, atoms_center_list, format="extxyz")

# Save the SOAP vectors of the center structures to an .npy file.
# my_soap is already a 2-D array, so it can be row-indexed directly.
center_soap = my_soap[selected_indexes, :]
np.save(os.path.join(output_dir, "center_soap.npy"), center_soap)

n_centers = len(selected_indexes)
print(f"已将 {n_centers} 个中心结构写入：{outfile_center_xyz}")
print(f"已将 {n_centers} 行 SOAP 向量写入：{output_dir}/center_soap.npy\n")

# =========================
# 4. Run DIRECTSampler again for cluster labelling only (no center
#    selection) and extract the labels
# =========================
DIRECT_partitioner1 = DIRECTSampler(
    structure_encoder=None,
    clustering=BirchClustering(n=N_cluster, threshold_init=0.001),
    select_k_from_clusters=None,
)
DIRECT_partition1 = DIRECT_partitioner1.fit_transform(my_soap)

# One cluster label per sample, e.g. array([0, 1, 4, 1, 2, 0, ...]).
labels = DIRECT_partition1['labels']
# Distinct labels in sorted order ([0, 1, 2, 3, 4] when N_cluster = 5).
unique_vals = np.unique(labels)

# Map each cluster label to the array of sample indexes carrying that label.
indices_dict = {}
for val in unique_vals:
    indices_dict[val] = np.where(labels == val)[0]

# Report how many samples fall into each cluster.
for val in indices_dict:
    idx_array = indices_dict[val]
    print(f"Cluster {val}: {len(idx_array)} 个样本")

# Printed only for comparison with the per-cluster counts above.
print("中心索引数组：", selected_indexes)

# =========================
# 5. For every cluster, write one .xyz trajectory and one .pckl.gzip
#    sub-DataFrame
# =========================
for val in sorted(indices_dict):
    selected_idx = indices_dict[val]      # ndarray of row positions in this cluster
    df_sel = df.iloc[selected_idx]        # all rows that belong to the cluster
    atoms_list = df_sel["ase_atoms"].tolist()

    # Both output files share the stem "<output_dir>/group_<label>".
    group_stem = os.path.join(output_dir, f"group_{val}")

    # 5.1 Structures of the cluster, e.g. "C5/group_0.xyz" (extxyz format).
    filename_xyz = group_stem + ".xyz"
    write(filename_xyz, atoms_list, format="extxyz")
    print(f"Finished writing cluster {val}  ->  {filename_xyz}")

    # 5.2 Pickled sub-DataFrame of the cluster, e.g. "C5/group_0.pckl.gzip".
    filename_pkl = group_stem + ".pckl.gzip"
    df_sel.to_pickle(filename_pkl, compression='gzip', protocol=4)
    print(f"Finished writing cluster {val} DataFrame  ->  {filename_pkl}\n")

  from .autonotebook import tqdm as notebook_tqdm
INFO:maml.sampling.pca:Selected first 15 PCs, explaining 99.26% variance
INFO:maml.sampling.clustering:BirchClustering with threshold_init=0.001 and n=5 gives 5 clusters.
INFO:maml.sampling.stratified_sampling:Finally selected 5 configurations.


已将 5 个中心结构写入：C5/selected.extxyz
已将 5 行 SOAP 向量写入：C5/center_soap.npy



INFO:maml.sampling.pca:Selected first 15 PCs, explaining 99.26% variance
INFO:maml.sampling.clustering:BirchClustering with threshold_init=0.001 and n=5 gives 5 clusters.


Cluster 0: 3081 个样本
Cluster 1: 100 个样本
Cluster 2: 208 个样本
Cluster 3: 85 个样本
Cluster 4: 439 个样本
中心索引数组： [np.int64(495), np.int64(2844), np.int64(3473), np.int64(2792), np.int64(3014)]
Finished writing cluster 0  ->  C5/group_0.xyz
Finished writing cluster 0 DataFrame  ->  C5/group_0.pckl.gzip

Finished writing cluster 1  ->  C5/group_1.xyz
Finished writing cluster 1 DataFrame  ->  C5/group_1.pckl.gzip

Finished writing cluster 2  ->  C5/group_2.xyz
Finished writing cluster 2 DataFrame  ->  C5/group_2.pckl.gzip

Finished writing cluster 3  ->  C5/group_3.xyz
Finished writing cluster 3 DataFrame  ->  C5/group_3.pckl.gzip

Finished writing cluster 4  ->  C5/group_4.xyz
Finished writing cluster 4 DataFrame  ->  C5/group_4.pckl.gzip

