In [None]:
import os
import multiprocessing
import json
from pathlib import Path
import numpy as np
from openbabel import pybel
import xgboost as xgb
from sklearn.metrics import classification_report
import warnings
import pickle
import pandas as pd
import sys
from contextlib import contextmanager
from concurrent.futures import ProcessPoolExecutor
import concurrent.futures

@contextmanager
def suppress_stdout_stderr():
    """
    Silence everything written to stdout/stderr inside the ``with`` body.

    Both the Python-level streams (``sys.stdout`` / ``sys.stderr``) and, when
    usable, the process-level file descriptors 1 and 2 are pointed at
    ``os.devnull``. Replacing only the Python objects does not affect output
    that native extensions (such as the Open Babel library used by
    ``calculate_fingerprint``) may write directly to the descriptors, which is
    why the descriptors are redirected too. Everything is restored on exit,
    including when the body raises.
    """
    # Flush pending Python-level output so it is not lost once fd 1/2 move.
    for stream in (sys.stdout, sys.stderr):
        try:
            stream.flush()
        except Exception:
            pass
    with open(os.devnull, 'w') as fnull:
        saved_streams = (sys.stdout, sys.stderr)
        saved_fds = []
        try:
            for fd in (1, 2):
                try:
                    saved_fds.append((fd, os.dup(fd)))
                    os.dup2(fnull.fileno(), fd)
                except OSError:
                    # Descriptor not usable in this process: fall back to
                    # Python-level redirection only for this stream.
                    pass
            sys.stdout, sys.stderr = fnull, fnull
            yield
        finally:
            sys.stdout, sys.stderr = saved_streams
            for fd, saved in reversed(saved_fds):
                os.dup2(saved, fd)
                os.close(saved)

def calculate_fingerprint(cif_path, hof_name):
    """
    Compute the Open Babel FP2 fingerprint of the first structure in a CIF file.

    Args:
        cif_path: Path to the ``.cif`` file.
        hof_name: Identifier used only in the progress messages.

    Returns:
        A list of 1024 ints (0/1): position ``bit`` is 1 for every value
        ``bit < 1024`` reported by ``fp.bits``.

    Raises:
        ValueError: if Open Babel yields no molecule for ``cif_path``.
    """
    print(f"开始计算指纹: {hof_name}")
    with suppress_stdout_stderr():
        # next(..., None) instead of a bare next(): a bare StopIteration would
        # reach worker() as an error with an empty message.
        mol = next(pybel.readfile("cif", cif_path), None)
        if mol is None:
            raise ValueError(f"no molecule could be read from {cif_path}")
        fp = mol.calcfp(fptype='FP2')  # FP2 fingerprint type

    # NOTE(review): pybel documents Fingerprint.bits as 1-based positions. If
    # that holds, index 0 below is never set and bit 1024 is dropped by the
    # bound check. Left unchanged deliberately: fingerprints already saved to
    # JSON by update_fingerprints use this exact layout, and mixing layouts
    # would make the stored feature vectors inconsistent.
    bits_fp = [0] * 1024
    for bit in fp.bits:
        if bit < 1024:  # keep the index inside the 1024-slot vector
            bits_fp[bit] = 1
    print(f"完成计算指纹: {hof_name}")
    return bits_fp

def worker(cif_path, hof_name):
    """
    Pool task: compute one structure's fingerprint without any timeout control.

    Returns ``(hof_name, fingerprint)``. ``fingerprint`` is None when the CIF
    file does not exist or when the computation raised; both cases are
    reported with a message instead of propagating.
    """
    fingerprint = None
    if os.path.exists(cif_path):
        try:
            fingerprint = calculate_fingerprint(cif_path, hof_name)
        except Exception as e:
            print(f"计算分子指纹时出错 ({hof_name}): {e}")
    else:
        print(f"文件不存在: {cif_path}")
    return hof_name, fingerprint

def update_fingerprints(input_file, new_keys_file, output_file, timeout=5,
                        cifs_dir='/data/user2/wty/HOF/moftransformer/data/HOF_solvent/cifs',
                        processes=4):
    """
    Add fingerprints for structures that are requested but not yet stored.

    Loads the fingerprint JSON ``input_file`` ({name: bit list}; treated as
    empty when the file does not exist). For every top-level key of
    ``new_keys_file`` that is missing from it, a fingerprint is computed in a
    process pool. The merged mapping is written to ``output_file``, which may
    be the same path as ``input_file``.

    Args:
        input_file: Existing fingerprint JSON to extend.
        new_keys_file: JSON whose top-level keys are the structure names needed.
        output_file: Destination JSON.
        timeout: Seconds to wait on each pending result before moving on. Waits
            happen one after another; tasks still running when the wait loop
            ends are killed when the pool is terminated on leaving ``with``.
        cifs_dir: Directory containing ``<name>.cif`` files.
        processes: Number of worker processes.

    Structures whose CIF is missing or whose computation fails are omitted. A
    timed-out structure is stored only if its result arrives before the pool
    is terminated.
    """
    # Load from input_file (the previous version read output_file and silently
    # ignored input_file).
    if os.path.exists(input_file):
        with open(input_file, 'r') as f:
            fingerprints = json.load(f)
    else:
        fingerprints = {}

    with open(new_keys_file, 'r') as f:
        new_keys = set(json.load(f).keys())

    # Requested names that have no fingerprint yet; sorted for a stable order.
    keys_to_process = sorted(new_keys - set(fingerprints.keys()))

    def collect_result(result):
        # Runs in the parent process's pool result-handler thread.
        hof_name, fingerprint = result
        if fingerprint is not None:
            fingerprints[hof_name] = fingerprint

    if keys_to_process:  # avoid spawning worker processes when nothing to do
        with multiprocessing.Pool(processes=processes) as pool:
            pending = []
            for hof_name in keys_to_process:
                cif_path = os.path.join(cifs_dir, f'{hof_name}.cif')
                async_result = pool.apply_async(worker, (cif_path, hof_name), callback=collect_result)
                pending.append((hof_name, async_result))

            for hof_name, async_result in pending:
                try:
                    async_result.get(timeout=timeout)
                except multiprocessing.TimeoutError:
                    print(f"计算超时: {hof_name}")

    # Temp file + atomic rename so an interruption cannot leave a truncated
    # output_file (which is frequently the only copy of the data).
    tmp_path = f"{output_file}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(fingerprints, f, indent=4)
    os.replace(tmp_path, output_file)
    print(f"分子指纹已更新并保存到 {output_file}")

# Example run: extend the shared fingerprint file in place with every
# structure listed in the time-split training set that has no fingerprint yet.
existing_fp_file = "/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/all_fp.json"
new_keys_file = "/data/user2/wty/HOF/moftransformer/data/HOF_time/foldall/train_time.json"
output_file = existing_fp_file  # results are written back to the same file

update_fingerprints(
    input_file=existing_fp_file,
    new_keys_file=new_keys_file,
    output_file=output_file,
    timeout=5,
)


In [None]:
import os
import subprocess
import json
from pathlib import Path

def run_command_on_cif_files(folder_path, command_template, json_file_path):
    """
    Run a shell command for each ``.cif`` file whose stem is a key of a JSON object.

    Args:
        folder_path: Directory searched recursively (``os.walk``) for ``.cif`` files.
        command_template: Shell command containing a ``{cif_file_path}``
            placeholder. The path is shell-quoted before substitution, so paths
            with spaces or shell metacharacters reach the command as one argument.
        json_file_path: JSON file whose top-level keys are the CIF names to process.

    A non-zero exit status is reported but does not stop the loop.
    """
    import shlex  # local import keeps this cell self-contained

    with open(json_file_path, 'r', encoding='utf-8') as f:
        cif_keys = set(json.load(f).keys())

    for root, _, files in os.walk(folder_path):
        for file in files:
            # Only .cif files whose name (without extension) is a requested key.
            if file.endswith('.cif') and file.rsplit('.', 1)[0] in cif_keys:
                cif_file_path = Path(root) / file
                command = command_template.format(cif_file_path=shlex.quote(str(cif_file_path)))
                print(f"运行命令: {command}")
                completed = subprocess.run(command, shell=True)
                if completed.returncode != 0:
                    print(f"Command failed (exit code {completed.returncode}): {command}")

# Example run: invoke ./network with -vol (presumably the Zeo++ volume
# calculation) on every CIF in the solvent set that is listed in hofs.json.
folder_path = '/data/user2/wty/HOF/moftransformer/data/HOF_solvent/cifs'
command_template = './network -vol 1.5 1.5 1000000 {cif_file_path}'
json_file_path = '/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/hofs.json'

run_command_on_cif_files(
    folder_path=folder_path,
    command_template=command_template,
    json_file_path=json_file_path,
)


In [3]:
import os
import json
import re
from pathlib import Path

def remove_empty_lists_from_json(src_json_path, dest_json_path):
    """
    Copy the JSON object at ``src_json_path`` to ``dest_json_path``, dropping
    every entry whose value is an empty list. Other falsy values are kept.
    """
    with open(src_json_path, 'r') as fin:
        payload = json.load(fin)

    kept = {}
    for key, value in payload.items():
        if value == []:
            continue  # skip empty-list entries
        kept[key] = value

    with open(dest_json_path, 'w') as fout:
        json.dump(kept, fout, indent=4)
    print(f"已删除空列表的键值对，并保存到 {dest_json_path}")

class HbondExtractor:
    """
    Collect hydrogen-bond atoms from ``.lis`` reports and map them to atom-row
    positions in the matching ``.cif`` files.

    Both files are looked up in ``cifs_path`` as ``<cif_id>.lis`` and
    ``<cif_id>.cif``.
    """

    def __init__(self, cifs_path):
        # Directory holding the .cif files and their .lis reports.
        self.cifs_path = cifs_path

    def get_Hbond_lists(self, cif_id):
        """
        Parse the H-bond table of ``<cif_id>.lis``.

        Returns:
            ``(donors, hs, acceptors)``: lists of ``(atom_label, row_idx)`` tuples,
            where ``row_idx`` is the row index within the matched table block.
            Rows whose donor label starts with 'C' are skipped. For kept rows, the
            H and acceptor entries are added only when their column matches the
            label pattern. All lists are empty when the ``.lis`` file is missing
            or contains no "Nr Typ Res Donor" table.
        """
        donors, hs, acceptors = [], [], []
        lis_path = os.path.join(self.cifs_path, f"{cif_id}.lis")
        # Without a .lis report there is nothing to parse.
        if not os.path.exists(lis_path):
            print(f"No LIS file found for CIF ID {cif_id}.")
            return donors, hs, acceptors
        with open(lis_path, 'r') as file:
            content = file.read()
        # Table block: from the "Nr Typ Res Donor" header up to (not including)
        # the first later line that starts with an uppercase letter.
        data_block_match = re.search(r"(Nr Typ Res Donor.*?)(?=\n[A-Z])", content, re.DOTALL | re.MULTILINE)
        if data_block_match:
            data_block = data_block_match.group(0)
            lines = data_block.splitlines()
            for idx, line in enumerate(lines):
                # Rows containing '?' are skipped entirely.
                if "?" in line:
                    continue
                line = re.sub(r'Intra', ' ', line)
                # NOTE(review): the original comment said "digit*" -> "digit ",
                # but this always writes "1 " (the digit itself is not kept).
                # Kept as-is; confirm which is intended.
                line = re.sub(r'\d\*', '1 ', line)
                # Replace '_' plus the following lowercase letter/digit/'*', and
                # any leftover '_', '>' and '<', with spaces so labels split cleanly.
                line = re.sub(r'_[a-z*]', ' ', line)
                line = re.sub(r'_[0-9*]', ' ', line)
                line = re.sub(r'_', ' ', line)
                line = re.sub(r'>', ' ', line)
                line = re.sub(r'<', ' ', line)
                columns = line.split()
                # Data rows start with a number (or '**') followed by a number.
                # At least 5 columns are required because columns[2..4] are read
                # below; a shorter row previously raised IndexError and aborted.
                if len(columns) > 4 and (columns[0].isdigit() or columns[0].startswith('**')) and columns[1].isdigit():
                    # Atom labels: letters + digits + optional uppercase suffix.
                    donor = re.search(r'[A-Za-z]+\d+[A-Z]*$', columns[2])
                    h = re.search(r'[A-Za-z]+\d+[A-Z]*$', columns[3])
                    acceptor = re.search(r'[A-Za-z]+\d+[A-Z]*$', columns[4])
                    # Keep the row only if a donor label was found and it does
                    # not start with 'C'.
                    if donor and not donor.group().startswith('C'):
                        donors.append((donor.group(), idx))
                        if h:
                            hs.append((h.group(), idx))
                        if acceptor:
                            acceptors.append((acceptor.group(), idx))
            # Consistency check (compares donors against acceptors only).
            if len(donors) != len(acceptors):
                print('donors:', donors)
                print('hs:', hs)
                print('acceptors:', acceptors)
                print(f"Error in {cif_id}: Donor, H, Acceptor lists have different lengths.")
        return donors, hs, acceptors

    def get_atom_indices(self, cif_id, atoms):
        """
        Return positions of atom-site rows in ``<cif_id>.cif`` whose first column
        (the site label) is in ``atoms``.

        Positions are counted from the line right after ``_atom_site_occupancy``
        (every line counts, including blank ones) up to the next ``loop_``.
        Returns ``[]`` if the CIF file is missing.
        """
        cif_path = os.path.join(self.cifs_path, f"{cif_id}.cif")
        if not os.path.exists(cif_path):
            print(f"No CIF file found for CIF ID {cif_id}.")
            return []
        atom_indices = []
        with open(cif_path, 'r') as file:
            lines = file.readlines()
        atom_block = False
        atom_list_start_index = None
        for idx, line in enumerate(lines):
            if line.strip() == "_atom_site_occupancy":
                atom_block = True
                atom_list_start_index = idx + 1
            elif atom_block and line.strip() == "loop_":
                break
            elif atom_block:
                columns = line.split()
                if len(columns) > 1 and columns[0] in atoms:
                    atom_indices.append(idx - atom_list_start_index)
        return atom_indices

    def get_atom_binary_list(self, cif_id, atoms):
        """
        Return one 0/1 flag per atom-site row (rows with more than one column)
        of ``<cif_id>.cif``. The flag is 1 when the row's second column is in
        ``atoms``.

        NOTE(review): this tests columns[1] (presumably the element/type symbol)
        while get_atom_indices tests columns[0] (the site label); confirm which
        kind of name callers pass in ``atoms``. Returns ``[]`` if the CIF is missing.
        """
        cif_path = os.path.join(self.cifs_path, f"{cif_id}.cif")
        if not os.path.exists(cif_path):
            print(f"No CIF file found for CIF ID {cif_id}.")
            return []

        binary_list = []
        with open(cif_path, 'r') as file:
            lines = file.readlines()
        atom_block = False
        for line in lines:
            if line.strip() == "_atom_site_occupancy":
                atom_block = True
            elif atom_block and line.strip() == "loop_":
                break
            elif atom_block:
                columns = line.split()
                if len(columns) > 1:
                    binary_list.append(1 if columns[1] in atoms else 0)
        return binary_list

    def create_json_from_cifs(self, output_json_path):
        """
        Write ``{cif_id: [atom-row positions of H-bond donors/Hs/acceptors]}`` for
        every ``.cif`` file in ``cifs_path`` to ``output_json_path``.

        Files are processed in sorted order so the output is reproducible.
        """
        hbond_data = {}
        for filename in sorted(os.listdir(self.cifs_path)):
            if filename.endswith(".cif"):
                cif_id = os.path.splitext(filename)[0]
                donors, hs, acceptors = self.get_Hbond_lists(cif_id)
                # Union of all H-bond participants as (label, row_idx) tuples.
                all_atoms = list(set(donors + hs + acceptors))
                atom_symbols = [atom[0] for atom in all_atoms]
                atom_indices = self.get_atom_indices(cif_id, atom_symbols)
                hbond_data[cif_id] = atom_indices
        with open(output_json_path, 'w') as json_file:
            json.dump(hbond_data, json_file, indent=4)
        print(f"JSON file created at {output_json_path}")

# Example run: build the H-bond atom-position JSON for every CIF in hofchecker.
if __name__ == "__main__":
    # Folder holding both the .cif files and their .lis reports.
    cifs_path = "/data/user2/wty/HOF/MOFDiff/mofdiff/data/mof_models/mof_models/bwdb_hoff/hofchecker"
    # Destination JSON: {cif_id: [atom-row positions]}.
    output_json_path = "/data/user2/wty/HOF/moftransformer/data/HOF_pretrain_new/fold8/all_tobacco_hbond.json"
    extractor = HbondExtractor(cifs_path)
    extractor.create_json_from_cifs(output_json_path=output_json_path)

    # Optional follow-up: drop entries whose H-bond list is empty.
    # src_json_path = Path('/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/all_tobacco_hbond.json')
    # dest_json_path = Path('/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/all_tobacco_hbond.json')
    # remove_empty_lists_from_json(src_json_path, dest_json_path)

    

JSON file created at /data/user2/wty/HOF/moftransformer/data/HOF_pretrain_new/all_tobacco_hbond.json


In [None]:
import csv
import json
import ast

# Inputs/outputs for turning validation-set MTP logits into class labels.
# NOTE: keys_json_path and output_json_path are the same file, so the label
# file is rewritten in place (keys kept, values replaced by predictions).
csv_file_path = '/data/user2/wty/HOF/logs/HOF_pretrain/fold3/nh_na/pretrained_mof_seed0_from_pmtransformer/version_0/val_prediction.csv'
keys_json_path = '/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/fold3/val_mtp.json'
output_json_path = keys_json_path

# Argmax of the mtp_logits list in every row of a prediction CSV.
def extract_max_confidence_indices(csv_path):
    """
    Return one int per CSV row: the position of the largest value in that
    row's ``mtp_logits`` column (a Python-literal list). Ties resolve to the
    first occurrence.
    """
    with open(csv_path, mode='r', encoding='utf-8') as handle:
        parsed = (ast.literal_eval(record['mtp_logits']) for record in csv.DictReader(handle))
        return [scores.index(max(scores)) for scores in parsed]

# Pair the keys of a JSON file (in file order) with per-row predictions.
def create_output_json(keys_path, max_indices, output_path):
    """
    Write ``{key: max_indices[i]}`` where keys come from ``keys_path`` in order.

    Raises:
        ValueError: if the number of keys differs from ``len(max_indices)``.
    """
    with open(keys_path, 'r', encoding='utf-8') as keyfile:
        keys_dict = json.load(keyfile)
    if len(keys_dict) != len(max_indices):
        raise ValueError("keys文件中的键数量与CSV文件中的行数不匹配")

    labelled = dict(zip(keys_dict.keys(), max_indices))

    with open(output_path, 'w', encoding='utf-8') as outfile:
        json.dump(labelled, outfile, indent=4, ensure_ascii=False)

if __name__ == "__main__":
    # Step 1: per-row argmax of mtp_logits from the validation predictions.
    max_indices = extract_max_confidence_indices(csv_file_path)

    # Step 2: write {key: predicted class} using the key order of keys_json_path.
    create_output_json(keys_path=keys_json_path, max_indices=max_indices, output_path=output_json_path)

    print(f"结果已保存到 {output_json_path}")

In [None]:
import json

def reset_values_to_zero(input_file, output_file):
    """
    Write a copy of the JSON object in ``input_file`` to ``output_file`` with
    the same keys (in the same order) and every value set to 0.
    """
    with open(input_file, 'r', encoding='utf-8') as src:
        zeroed = dict.fromkeys(json.load(src).keys(), 0)

    with open(output_file, 'w', encoding='utf-8') as dst:
        json.dump(zeroed, dst, indent=4, ensure_ascii=False)

    print("新的JSON文件已保存到", output_file)

# Example run: write fold2/mtp_train.json with the keys of all_fp.json and
# every value set to 0.
input_file = '/data/user2/wty/HOF/ML_method/all_fp.json'
output_file = '/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/fold2/mtp_train.json'
reset_values_to_zero(input_file=input_file, output_file=output_file)


In [None]:
import json
import pandas as pd

def update_json_with_csv(csv_file, json_file, output_file):
    """
    Replace the values of a JSON object with the ``vfp_logits`` column of a CSV.

    Keys are taken in file order from ``json_file`` and paired row by row with
    the CSV. Any value that does not satisfy ``>= 0`` (negatives, NaN) is
    written as 0.0. If the key count and the row count differ, an error is
    printed and nothing is written.
    """
    with open(json_file, 'r', encoding='utf-8') as f:
        keys = list(json.load(f).keys())

    logits = pd.read_csv(csv_file)['vfp_logits'].tolist()

    if len(logits) != len(keys):
        print("错误: JSON文件中的键数量和CSV文件中的vfp_logits数量不匹配")
        return

    updated_data = {}
    for key, value in zip(keys, logits):
        # Clamp: anything failing ``>= 0`` becomes 0.0.
        updated_data[key] = value if value >= 0 else 0.0

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(updated_data, f, indent=4, ensure_ascii=False)
    print(f"已保存新的JSON文件到 {output_file}")

# Example run: overwrite the fold3 test VFP labels with the model's
# vfp_logits predictions. json_file and output_file are the same path, so the
# file is rewritten in place.
csv_file = '/data/user2/wty/HOF/logs/HOF_pretrain/fold3/nh_na/pretrained_mof_seed0_from_pmtransformer/version_0/test_prediction.csv'
json_file = '/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/fold3/test_vfp.json'
output_file = '/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/fold3/test_vfp.json'

update_json_with_csv(csv_file=csv_file, json_file=json_file, output_file=output_file)


In [None]:
import json

def replace_values_from_reference(reference_file, input_files, output_files, missing_keys_file):
    """
    For each input JSON, write an output JSON whose values come from a reference JSON.

    ``input_files[i]`` is paired with ``output_files[i]``. Every key of an input
    file that exists in ``reference_file`` is written with the reference value.
    Keys absent from the reference are left out of that output, printed, and
    collected (with value 0) into ``missing_keys_file``.

    Raises:
        ValueError: if ``input_files`` and ``output_files`` differ in length
            (``zip`` would otherwise silently skip the unmatched files).
    """
    input_files = list(input_files)
    output_files = list(output_files)
    if len(input_files) != len(output_files):
        raise ValueError(
            f"input_files ({len(input_files)}) and output_files ({len(output_files)}) must have the same length"
        )

    with open(reference_file, 'r', encoding='utf-8') as f:
        reference_data = json.load(f)

    # Union of missing keys across all inputs, each mapped to 0.
    all_missing_keys = {}

    for input_file, output_file in zip(input_files, output_files):
        with open(input_file, 'r', encoding='utf-8') as f:
            input_data = json.load(f)

        updated_data = {}
        missing_keys = []
        for key in input_data.keys():
            if key in reference_data:
                updated_data[key] = reference_data[key]
            else:
                missing_keys.append(key)
                all_missing_keys[key] = 0

        if missing_keys:
            print(f"{input_file} 缺失的key: {missing_keys}")

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(updated_data, f, indent=4, ensure_ascii=False)
        print(f"已保存更新后的文件: {output_file}")

    with open(missing_keys_file, 'w', encoding='utf-8') as f:
        json.dump(all_missing_keys, f, indent=4, ensure_ascii=False)
    print(f"缺失的key已保存到 {missing_keys_file}")

# Example run: build train_fp.json for the solvent training split from
# all_fp.json and record keys that have no fingerprint.
reference_file = '/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/all_fp.json'
input_files = ['/data/user2/wty/HOF/moftransformer/data/HOF_solvent/foldall/train_solvent.json']
output_files = ['/data/user2/wty/HOF/moftransformer/data/HOF_solvent/foldall/train_fp.json']
missing_keys_file = '/data/user2/wty/HOF/moftransformer/data/HOF_solvent/foldall/missing_keys.json'

replace_values_from_reference(
    reference_file=reference_file,
    input_files=input_files,
    output_files=output_files,
    missing_keys_file=missing_keys_file,
)


In [None]:
import os
import json
import shutil

def process_json_files(source_folder, destination_folder):
    """
    Strip keys starting with "sample" or "tobacco" from every ``*.json`` file
    at the top level of ``source_folder``, then move the cleaned files into
    ``destination_folder`` (created if needed).

    Each file is rewritten in place before being moved, so it no longer
    exists in ``source_folder`` afterwards.
    """
    os.makedirs(destination_folder, exist_ok=True)

    json_names = [name for name in os.listdir(source_folder) if name.endswith(".json")]
    for name in json_names:
        path = os.path.join(source_folder, name)

        with open(path, 'r', encoding='utf-8') as fh:
            content = json.load(fh)

        # Keep only entries whose key starts with neither prefix.
        content = {
            key: value
            for key, value in content.items()
            if not key.startswith(("sample", "tobacco"))
        }

        with open(path, 'w', encoding='utf-8') as fh:
            json.dump(content, fh, ensure_ascii=False, indent=4)

        shutil.move(path, os.path.join(destination_folder, name))
        print(f"Processed and moved file: {name}")

# Example run: clean every JSON file in fold2 and move it into fold3.
source_folder = "/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/fold2"
destination_folder = "/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/fold3"
process_json_files(source_folder=source_folder, destination_folder=destination_folder)


In [None]:
import os
import multiprocessing
import json
from pathlib import Path
import numpy as np
from openbabel import pybel
import xgboost as xgb
from sklearn.metrics import classification_report
import warnings
import pickle
import pandas as pd
import sys
from contextlib import contextmanager
from concurrent.futures import ProcessPoolExecutor
import concurrent.futures

@contextmanager
def suppress_stdout_stderr():
    """
    Silence everything written to stdout/stderr inside the ``with`` body.

    Both the Python-level streams (``sys.stdout`` / ``sys.stderr``) and, when
    usable, the process-level file descriptors 1 and 2 are pointed at
    ``os.devnull``. Replacing only the Python objects does not affect output
    that native extensions (such as the Open Babel library used by
    ``calculate_fingerprint``) may write directly to the descriptors, which is
    why the descriptors are redirected too. Everything is restored on exit,
    including when the body raises.
    """
    # Flush pending Python-level output so it is not lost once fd 1/2 move.
    for stream in (sys.stdout, sys.stderr):
        try:
            stream.flush()
        except Exception:
            pass
    with open(os.devnull, 'w') as fnull:
        saved_streams = (sys.stdout, sys.stderr)
        saved_fds = []
        try:
            for fd in (1, 2):
                try:
                    saved_fds.append((fd, os.dup(fd)))
                    os.dup2(fnull.fileno(), fd)
                except OSError:
                    # Descriptor not usable in this process: fall back to
                    # Python-level redirection only for this stream.
                    pass
            sys.stdout, sys.stderr = fnull, fnull
            yield
        finally:
            sys.stdout, sys.stderr = saved_streams
            for fd, saved in reversed(saved_fds):
                os.dup2(saved, fd)
                os.close(saved)

def calculate_fingerprint(cif_path, hof_name):
    """
    Compute the Open Babel FP2 fingerprint of the first structure in a CIF file.

    Args:
        cif_path: Path to the ``.cif`` file.
        hof_name: Identifier used only in the progress messages.

    Returns:
        A list of 1024 ints (0/1): position ``bit`` is 1 for every value
        ``bit < 1024`` reported by ``fp.bits``.

    Raises:
        ValueError: if Open Babel yields no molecule for ``cif_path``.
    """
    print(f"开始计算指纹: {hof_name}")
    with suppress_stdout_stderr():
        # next(..., None) instead of a bare next(): a bare StopIteration would
        # reach worker() as an error with an empty message.
        mol = next(pybel.readfile("cif", cif_path), None)
        if mol is None:
            raise ValueError(f"no molecule could be read from {cif_path}")
        fp = mol.calcfp(fptype='FP2')  # FP2 fingerprint type

    # NOTE(review): pybel documents Fingerprint.bits as 1-based positions. If
    # that holds, index 0 below is never set and bit 1024 is dropped by the
    # bound check. Left unchanged deliberately: fingerprints already saved to
    # JSON by update_fingerprints use this exact layout, and mixing layouts
    # would make the stored feature vectors inconsistent.
    bits_fp = [0] * 1024
    for bit in fp.bits:
        if bit < 1024:  # keep the index inside the 1024-slot vector
            bits_fp[bit] = 1
    print(f"完成计算指纹: {hof_name}")
    return bits_fp

def worker(cif_path, hof_name):
    """
    Pool task: compute one structure's fingerprint without any timeout control.

    Returns ``(hof_name, fingerprint)``. ``fingerprint`` is None when the CIF
    file does not exist or when the computation raised; both cases are
    reported with a message instead of propagating.
    """
    fingerprint = None
    if os.path.exists(cif_path):
        try:
            fingerprint = calculate_fingerprint(cif_path, hof_name)
        except Exception as e:
            print(f"计算分子指纹时出错 ({hof_name}): {e}")
    else:
        print(f"文件不存在: {cif_path}")
    return hof_name, fingerprint

def update_fingerprints(input_file, new_keys_file, output_file, timeout=5,
                        cifs_dir='/data/user2/wty/HOF/moftransformer/data/HOF_solvent/cifs',
                        processes=40):
    """
    Add fingerprints for structures that are requested but not yet stored.

    Loads the fingerprint JSON ``input_file`` ({name: bit list}; treated as
    empty when the file does not exist). For every top-level key of
    ``new_keys_file`` that is missing from it and has a CIF on disk, a
    fingerprint is computed in a process pool. The merged mapping is written
    to ``output_file``, which may be the same path as ``input_file``.

    Args:
        input_file: Existing fingerprint JSON to extend.
        new_keys_file: JSON whose top-level keys are the structure names needed.
        output_file: Destination JSON.
        timeout: Seconds to wait on each pending result before moving on. Waits
            happen one after another; tasks still running when the wait loop
            ends are killed when the pool is terminated on leaving ``with``.
        cifs_dir: Directory containing ``<name>.cif`` files.
        processes: Number of worker processes.

    Structures whose CIF is missing or whose computation fails are omitted. A
    timed-out structure is stored only if its result arrives before the pool
    is terminated.
    """
    # Load from input_file (the previous version read output_file and silently
    # ignored input_file).
    if os.path.exists(input_file):
        with open(input_file, 'r') as f:
            fingerprints = json.load(f)
    else:
        fingerprints = {}

    with open(new_keys_file, 'r') as f:
        new_keys = set(json.load(f).keys())

    # Requested names that have no fingerprint yet; sorted for a stable order.
    keys_to_process = sorted(new_keys - set(fingerprints.keys()))

    def collect_result(result):
        # Runs in the parent process's pool result-handler thread.
        hof_name, fingerprint = result
        if fingerprint is not None:
            fingerprints[hof_name] = fingerprint

    if keys_to_process:  # avoid spawning worker processes when nothing to do
        with multiprocessing.Pool(processes=processes) as pool:
            pending = []
            for hof_name in keys_to_process:
                cif_path = os.path.join(cifs_dir, f'{hof_name}.cif')
                # Skip missing CIFs here so no task is submitted for them.
                if not os.path.exists(cif_path):
                    print(f"文件不存在: {cif_path}")
                    continue
                async_result = pool.apply_async(worker, (cif_path, hof_name), callback=collect_result)
                pending.append((hof_name, async_result))

            for hof_name, async_result in pending:
                try:
                    async_result.get(timeout=timeout)
                except multiprocessing.TimeoutError:
                    print(f"计算超时: {hof_name}")

    # Temp file + atomic rename so an interruption cannot leave a truncated
    # output_file (which is frequently the only copy of the data).
    tmp_path = f"{output_file}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(fingerprints, f, indent=4)
    os.replace(tmp_path, output_file)
    print(f"分子指纹已更新并保存到 {output_file}")

# Example run: extend the shared fingerprint file in place with every
# structure in the temperature test split that has no fingerprint yet.
existing_fp_file = "/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/all_fp.json"
new_keys_file = "/data/user2/wty/HOF/moftransformer/data/HOF_temperature/foldfake/test_temperature.json"
output_file = existing_fp_file  # results are written back to the same file

update_fingerprints(
    input_file=existing_fp_file,
    new_keys_file=new_keys_file,
    output_file=output_file,
    timeout=5,
)


In [None]:
import json

def delete_keys_from_json(target_json_path, specified_json_path):
    """
    Remove from ``specified_json_path`` every key that appears in ``target_json_path``.

    The specified file is rewritten in place. The printed count is the number
    of keys actually removed; the previous version printed the number of
    candidate keys, even those that were not present.
    """
    with open(target_json_path, 'r', encoding='utf-8') as target_file:
        keys_to_delete = set(json.load(target_file).keys())

    with open(specified_json_path, 'r', encoding='utf-8') as specified_file:
        specified_data = json.load(specified_file)

    removed = 0
    for key in keys_to_delete:
        if key in specified_data:
            del specified_data[key]
            removed += 1

    with open(specified_json_path, 'w', encoding='utf-8') as specified_file:
        json.dump(specified_data, specified_file, ensure_ascii=False, indent=4)

    print(f"已删除 {removed} 个键值对，结果已保存到 {specified_json_path}")

# Example run: remove the keys listed in missing_keys.json from train_solvent.json.
target_json_path = '/data/user2/wty/HOF/moftransformer/data/HOF_solvent/foldall/missing_keys.json'
specified_json_path = '/data/user2/wty/HOF/moftransformer/data/HOF_solvent/foldall/train_solvent.json'
delete_keys_from_json(target_json_path=target_json_path, specified_json_path=specified_json_path)


In [None]:
import json
import os

def extract_fp(reference_file, fp_all_file, output_file):
    """
    Write the subset of ``fp_all_file`` whose keys appear in ``reference_file``.

    Keys are emitted in the reference file's order. Reference keys with no
    entry in ``fp_all_file`` are reported with an error message and skipped.
    """
    with open(reference_file, 'r', encoding='utf-8') as ref_file:
        wanted = json.load(ref_file)
    with open(fp_all_file, 'r', encoding='utf-8') as fp_file:
        available = json.load(fp_file)

    selected = {}
    for key in wanted.keys():
        if key not in available:
            print(f"Error: Key '{key}' not found in {fp_all_file}.")
            continue
        selected[key] = available[key]

    with open(output_file, 'w', encoding='utf-8') as out_file:
        json.dump(selected, out_file, ensure_ascii=False, indent=4)
    print(f"提取结果已保存到 {output_file}")

def main():
    """
    Build ``{split}_fp.json`` for the train/test/val splits of one fold.

    For each split, the fingerprints of the keys in ``{split}_mtp.json`` are
    extracted from the shared ``all_fp.json``. Input and output files both
    live in ``base_folder``.
    """
    base_folder = "/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/fold2"
    fp_all_file = '/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/all_fp.json'

    # Reference files are named relative to base_folder. Previously they were
    # absolute paths, which made os.path.join ignore base_folder for inputs
    # while outputs still used it, so changing the fold mixed folds silently.
    files = {
        "train": "train_mtp.json",
        "test": "test_mtp.json",
        "val": "val_mtp.json",
    }

    for file_type, filename in files.items():
        reference_file = os.path.join(base_folder, filename)
        output_file = os.path.join(base_folder, f"{file_type}_fp.json")
        extract_fp(reference_file, fp_all_file, output_file)

if __name__ == "__main__":
    main()


In [None]:
import json
from pathlib import Path

def count_key_value_pairs(json_path):
    """Print how many top-level entries the JSON document at ``json_path`` holds."""
    with open(json_path, 'r', encoding='utf-8') as json_file:
        num_pairs = len(json.load(json_file))
    print(f"JSON 文件中有 {num_pairs} 个键值对。")

# Example run: report how many structures all_fp.json contains.
json_path = Path('/data/user2/wty/HOF/moftransformer/data/HOF_pretrain_new/all_fp.json')

count_key_value_pairs(json_path=json_path)

In [None]:
import json
import random
from pathlib import Path

def split_json(json_path, train_path, val_path, test_path, train_ratio=0.67, val_ratio=0.17, test_ratio=0.16, seed=None):
    """
    Randomly split the entries of a JSON object into train/val/test JSON files.

    Args:
        json_path: Source JSON object.
        train_path, val_path, test_path: Output paths for the three splits.
        train_ratio: Fraction of entries (rounded down) written to train.
        val_ratio: Fraction of entries (rounded down) written to val.
        test_ratio: Expected test fraction. The test split always receives all
            remaining entries; this value is only used to warn when the three
            ratios do not sum to 1.
        seed: If given, the shuffle uses ``random.Random(seed)`` so the split is
            reproducible. If None (default), the module-level ``random`` state is
            used, as before.
    """
    import warnings  # local import keeps this cell self-contained

    ratio_sum = train_ratio + val_ratio + test_ratio
    if abs(ratio_sum - 1.0) > 1e-6:
        warnings.warn(
            f"train/val/test ratios sum to {ratio_sum}, not 1; "
            "the test split receives all remaining entries"
        )

    with open(json_path, 'r', encoding='utf-8') as json_file:
        data = json.load(json_file)

    items = list(data.items())
    if seed is None:
        random.shuffle(items)
    else:
        random.Random(seed).shuffle(items)

    total = len(items)
    train_size = int(total * train_ratio)
    val_size = int(total * val_ratio)

    splits = (
        (train_path, dict(items[:train_size])),
        (val_path, dict(items[train_size:train_size + val_size])),
        (test_path, dict(items[train_size + val_size:])),
    )
    for path, subset in splits:
        with open(path, 'w', encoding='utf-8') as out_file:
            json.dump(subset, out_file, indent=4, ensure_ascii=False)

    print(f"数据集已拆分并保存到 {train_path}, {val_path}, {test_path}")

# Example run: split av_volume_fractions.json into fold3 train/val/test files
# using the default ratios (the shuffle is not seeded here).
json_path = '/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/av_volume_fractions.json'
train_path = '/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/fold3/train_vfp.json'
val_path = '/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/fold3/val_vfp.json'
test_path = '/data/user2/wty/HOF/moftransformer/data/HOF_pretrain/fold3/test_vfp.json'

split_json(json_path=json_path, train_path=train_path, val_path=val_path, test_path=test_path)

In [5]:
import torch
import pickle
import os

def get_graph(cif_id, cifs_path):
    """
    Load ``<cifs_path>/<cif_id>.graphdata`` and return its graph fields.

    The pickle is read as a sequence indexed positionally:
    ``[cif_id, atom_num, nbr_idx, nbr_dist, uni_idx, uni_count]``.

    Returns:
        dict with ``atom_num`` (1-D LongTensor of length N), ``nbr_idx``
        (LongTensor reshaped to ``(N, -1)``), and ``uni_idx`` / ``uni_count``
        exactly as stored.

    Security: ``pickle.load`` can execute arbitrary code; only use on trusted files.
    """
    file_graph = os.path.join(cifs_path, f"{cif_id}.graphdata")

    # Context manager closes the handle (previously the file object leaked).
    with open(file_graph, "rb") as f:
        graphdata = pickle.load(f)

    atom_num = torch.LongTensor(graphdata[1].copy())
    # One row of neighbour indices per atom.
    nbr_idx = torch.LongTensor(graphdata[2].copy()).view(len(atom_num), -1)
    uni_idx = graphdata[4]
    uni_count = graphdata[5]

    return {
        "atom_num": atom_num,
        "nbr_idx": nbr_idx,
        "uni_idx": uni_idx,
        "uni_count": uni_count,
    }
# Sanity check: load one generated structure's graph and print its atom count.
cifs_path = "/data/user2/wty/HOF/MOFDiff/mofdiff/data/mof_models/mof_models/bwdb_hoff/hofchecker/total"
cif_id = "sample_822_47"
result = get_graph(cif_id=cif_id, cifs_path=cifs_path)
print(len(result["atom_num"]))

75


In [3]:
import json
# Merge train H-bond labels from fold4 into fold7: keys already present in the
# fold7 file keep their value; only keys missing from it are copied over.
source_json = '/data/user2/wty/HOF/moftransformer/data/HOF_pretrain_new/fold4/train_hbond.json'
des_json = '/data/user2/wty/HOF/moftransformer/data/HOF_pretrain_new/fold7/train_hbond.json'

# Context managers close both handles (json.load(open(...)) leaked them).
with open(source_json, 'r', encoding='utf-8') as f:
    source_data = json.load(f)
with open(des_json, 'r', encoding='utf-8') as f:
    des_data = json.load(f)

for key, value in source_data.items():
    des_data.setdefault(key, value)

with open(des_json, 'w', encoding='utf-8') as f:
    json.dump(des_data, f, indent=4, ensure_ascii=False)