In [None]:
import os
import torch
from torch_geometric.data import Data
import numpy as np
from atomic_utils import readcif, lattice_params_to_matrix_torch, read_cif_bonds, compute_image_flag, get_atomic_graph, compute_distance_matrix, frac2cart
from data_utils import frac_to_cart_coords

# 假设已经定义的函数
# pyg_graph_from_cif() 和 mof_criterion()

def process_bb_files(root_dir, max_atoms=200, max_cps=20):
    for subfolder in os.listdir(root_dir):
        subfolder_path = os.path.join(root_dir, subfolder)
        
        if os.path.isdir(subfolder_path):
            for file_name in os.listdir(subfolder_path):
                # 检查文件名是否符合 bb_{i}.cif 格式
                if file_name.startswith("bb_") and file_name.endswith(".cif"):
                    cif_path = os.path.join(subfolder_path, file_name)
                    
                    # 将 cif 文件转换为 Data 数据结构
                    bb = pyg_graph_from_cif(cif_path)
                    
                    # 使用 bb_criterion 进行筛选
                    _, success = bb_criterion(bb, max_atoms=max_atoms, max_cps=max_cps)
                    
                    # 如果筛选失败，删除该文件并打印路径
                    if not success:
                        # os.remove(cif_path)
                        print(f"Deleted invalid CIF file: {cif_path}")

def pyg_graph_from_cif(cif, graph_provided=False, Hbond=False):
    lattice_parameters, atom_types, frac_coords, atom_symbols = readcif(cif)
    frac_coords = torch.FloatTensor(frac_coords)
    atom_types = torch.LongTensor(atom_types)
    num_atoms = len(atom_types)
    lengths = torch.FloatTensor(lattice_parameters[:3]).view(1, -1)
    angles = torch.FloatTensor(lattice_parameters[3:]).view(1, -1)
    cell = lattice_params_to_matrix_torch(lengths, angles).squeeze()
    scaled_lattice = torch.cat(
        [lengths / float(frac_coords.shape[0]) ** (1 / 3), angles], dim=1
    )

    if graph_provided:        
        from_index, to_index, _ = read_cif_bonds(cif)
        edge_index = np.stack([from_index, to_index]).T
        if len(edge_index) > 0:
            edge_index = torch.LongTensor(edge_index).T.contiguous()
            reverse_edge_index = torch.stack([edge_index[1], edge_index[0]])
            edge_index = torch.cat([edge_index, reverse_edge_index], dim=1)
            to_jimages = compute_image_flag(
                cell, frac_coords[edge_index[0]], frac_coords[edge_index[1]]
            )
        else:
            to_jimages = torch.FloatTensor([])
    else:
        edge_index, to_jimages = get_atomic_graph(frac_coords, atom_types, cell, Hbond = Hbond, cif = cif)
    num_bonds = len(edge_index.T)

    return Data(
        frac_coords=frac_coords,
        atom_types=atom_types,
        lengths=lengths,
        angles=angles,
        edge_index=edge_index,
        to_jimages=to_jimages,
        num_atoms=num_atoms,
        num_bonds=num_bonds,
        num_nodes=num_atoms,
        cell=cell,
        scaled_lattice=scaled_lattice,
    )


def bb_criterion(bb, max_atoms=200, max_cps=20):
    # 转换分数坐标为笛卡尔坐标
    cart_coords = frac_to_cart_coords(
        bb.frac_coords, bb.lengths, bb.angles, bb.num_atoms
    )
    pdist = torch.cdist(cart_coords, cart_coords).fill_diagonal_(5.0)

    # 检查边的连接信息是否为空
    edge_index = bb.edge_index
    if edge_index.numel() == 0:
        print("Invalid BB: No bond information (empty edge_index).")
        return None, False

    # 计算键长
    j, i = edge_index
    dist_mat = compute_distance_matrix(bb.cell, cart_coords)
    bond_dist = dist_mat[i, j]
    # bond_dist = (cart_coords[i] - cart_coords[j]).pow(2).sum(dim=-1).sqrt()
    # print("bond_dist:", bond_dist)
    # print("bond_dist:", bond_dist)
    # 检查是否存在键长
    if bond_dist.numel() == 0:
        print("Invalid BB: No valid bond distances (empty bond_dist).")
        return None, False

    # 检查最小原子距离
    if pdist.min() <= 0.25:
        print("Invalid BB: Atom distance is below the minimum threshold of 0.25.")
        return cart_coords, False

    # 检查最大键长
    if bond_dist.max() >= 5.0:
        print("Invalid BB: Bond distance exceeds the maximum threshold of 5.0.")
        return cart_coords, False

    # 检查原子数量上限
    if bb.num_atoms > max_atoms:
        print(f"Invalid BB: Number of atoms ({bb.num_atoms}) exceeds the maximum limit of {max_atoms}.")
        return cart_coords, False

    # 如果满足所有条件
    return cart_coords, True

# 调用函数并指定根目录
root_directory = "/data/user2/wty/HOF/MOFDiff/hofdiff/data/hof_data"
process_bb_files(root_directory)


In [None]:
import os
import torch
from torch_geometric.data import Data
import numpy as np
from atomic_utils import readcif, lattice_params_to_matrix_torch, read_cif_bonds, compute_image_flag, get_atomic_graph, compute_distance_matrix, frac2cart
from data_utils import frac_to_cart_coords

# 假设已经定义的函数
# pyg_graph_from_cif() 和 mof_criterion()

def process_bb_files(root_dir, max_atoms=200, max_cps=20):
    for file_name in os.listdir(root_dir):
        # 检查文件名是否符合 bb_{i}.cif 格式
        if file_name.startswith("bb_") and file_name.endswith(".cif"):
            cif_path = os.path.join(root_dir, file_name)
            
            # 将 cif 文件转换为 Data 数据结构
            bb = pyg_graph_from_cif(cif_path)
            
            # 使用 bb_criterion 进行筛选
            _, success = bb_criterion(bb, max_atoms=max_atoms, max_cps=max_cps)
            
            # 如果筛选失败，删除该文件并打印路径
            if not success:
                # os.remove(cif_path)
                print(f"Deleted invalid CIF file: {cif_path}")

def pyg_graph_from_cif(cif, graph_provided=False, Hbond=False):
    lattice_parameters, atom_types, frac_coords, atom_symbols = readcif(cif)
    frac_coords = torch.FloatTensor(frac_coords)
    atom_types = torch.LongTensor(atom_types)
    num_atoms = len(atom_types)
    lengths = torch.FloatTensor(lattice_parameters[:3]).view(1, -1)
    angles = torch.FloatTensor(lattice_parameters[3:]).view(1, -1)
    cell = lattice_params_to_matrix_torch(lengths, angles).squeeze()
    scaled_lattice = torch.cat(
        [lengths / float(frac_coords.shape[0]) ** (1 / 3), angles], dim=1
    )

    if graph_provided:        
        from_index, to_index, _ = read_cif_bonds(cif)
        edge_index = np.stack([from_index, to_index]).T
        if len(edge_index) > 0:
            edge_index = torch.LongTensor(edge_index).T.contiguous()
            reverse_edge_index = torch.stack([edge_index[1], edge_index[0]])
            edge_index = torch.cat([edge_index, reverse_edge_index], dim=1)
            to_jimages = compute_image_flag(
                cell, frac_coords[edge_index[0]], frac_coords[edge_index[1]]
            )
        else:
            to_jimages = torch.FloatTensor([])
    else:
        edge_index, to_jimages = get_atomic_graph(frac_coords, atom_types, cell, Hbond = Hbond, cif = cif)
    num_bonds = len(edge_index.T)

    return Data(
        frac_coords=frac_coords,
        atom_types=atom_types,
        lengths=lengths,
        angles=angles,
        edge_index=edge_index,
        to_jimages=to_jimages,
        num_atoms=num_atoms,
        num_bonds=num_bonds,
        num_nodes=num_atoms,
        cell=cell,
        scaled_lattice=scaled_lattice,
    )

def compute_distance_matrix(cell, cart_coords, num_cells=1):
    pos = torch.arange(-num_cells, num_cells + 1, 1).to(cell.device)
    combos = (
        torch.stack(torch.meshgrid(pos, pos, pos, indexing="xy"))
        .permute(3, 2, 1, 0)
        .reshape(-1, 3)
        .to(cell.device)
    )
    shifts = torch.sum(cell.unsqueeze(0) * combos.unsqueeze(-1), dim=1)
    shifted = cart_coords.unsqueeze(1) + shifts.unsqueeze(0)
    dist = cart_coords.unsqueeze(1).unsqueeze(1) - shifted.unsqueeze(0)
    # +eps to avoid nan in differentiation
    dist = (dist.pow(2).sum(dim=-1) + 1e-32).sqrt()
    distance_matrix = dist.min(dim=-1)[0]
    return distance_matrix


# def bb_criterion(bb, max_atoms=200, max_cps=20):
#     # bb.num_cps = bb.is_anchor.long().sum()
#     # if (bb.num_atoms > max_atoms) or (bb.num_cps > max_cps):
#     #     return None, False

#     cart_coords = frac_to_cart_coords(
#         bb.frac_coords, bb.lengths, bb.angles, bb.num_atoms
#     )
#     pdist = torch.cdist(cart_coords, cart_coords).fill_diagonal_(5.0)

#     # detect BBs with problematic bond info.
#     edge_index = bb.edge_index
#     print("edge_index:", edge_index)
#     if edge_index.numel() == 0:
#         return None, False
#     j, i = edge_index
#     bond_dist = (cart_coords[i] - cart_coords[j]).pow(2).sum(dim=-1).sqrt()
#     if bond_dist.numel() == 0:
#         success = False
#     else:
#         success = (
#             pdist.min() > 0.25
#             and bond_dist.max() < 5.0
#             and (bb.num_atoms <= max_atoms)
#             # and (bb.num_cps <= max_cps)
#         )
#     return cart_coords, success

def bb_criterion(bb, max_atoms=200, max_cps=20):
    bb.num_cps = bb.is_anchor.long().sum()
    print(bb.num_cps)
    # 转换分数坐标为笛卡尔坐标
    # print(bb.lengths)
    cart_coords = frac_to_cart_coords(
        bb.frac_coords, bb.lengths, bb.angles, bb.num_atoms
    )
    pdist = torch.cdist(cart_coords, cart_coords).fill_diagonal_(5.0)

    # 检查边的连接信息是否为空
    edge_index = bb.edge_index
    if edge_index.numel() == 0:
        print("Invalid BB: No bond information (empty edge_index).")
        return None, False

    # 计算键长
    j, i = edge_index
    print("edge_index:", edge_index)
    print("cart_coords:", cart_coords)
    dist_mat = compute_distance_matrix(bb.cell, cart_coords)
    bond_dist = dist_mat[i, j]
    # bond_dist = (cart_coords[i] - cart_coords[j]).pow(2).sum(dim=-1).sqrt()
    print("bond_dist:", bond_dist)
    # 检查是否存在键长
    if bond_dist.numel() == 0:
        print("Invalid BB: No valid bond distances (empty bond_dist).")
        return None, False

    # 检查最小原子距离
    if pdist.min() <= 0.25:
        print("Invalid BB: Atom distance is below the minimum threshold of 0.25.")
        return cart_coords, False

    # 检查最大键长
    if bond_dist.max() >= 5.0:
        print("Invalid BB: Bond distance exceeds the maximum threshold of 5.0.")
        return cart_coords, False

    # 检查原子数量上限
    if bb.num_atoms > max_atoms:
        print(f"Invalid BB: Number of atoms ({bb.num_atoms}) exceeds the maximum limit of {max_atoms}.")
        return cart_coords, False

    # 如果满足所有条件
    return cart_coords, True

# 调用函数并指定根目录
root_directory = "/data/user2/wty/HOF/MOFDiff/hofdiff/data/hof_data/1_1"
process_bb_files(root_directory)


In [None]:
from ase.io import read, write
from ase.geometry import distance
import numpy as np
from openbabel import pybel


def remove_overlapping_atoms(structure, threshold=0.001):
    positions = structure.get_positions()
    unique_indices = []
    unique_positions = []

    for i, pos in enumerate(positions):
        if all(np.linalg.norm(pos - other) > threshold for other in unique_positions):
            unique_indices.append(i)  # 存储索引
            unique_positions.append(pos)

    # 创建去除重叠原子的结构
    new_structure = structure[unique_indices]
    return new_structure

# 读取 CIF 文件
structure = read("/data/user2/wty/HOF/MOFDiff/hofdiff/data/hof_data/218-hof-tcpb-298/bb_1.cif")

# 去除重叠原子
structure = remove_overlapping_atoms(structure)

# 保存结果（可选）
write("/data/user2/wty/HOF/MOFDiff/hofdiff/data/hof_data/218-hof-tcpb-298/bb_1_mod.cif", structure)

mol = next(pybel.readfile("cif", "/data/user2/wty/HOF/MOFDiff/hofdiff/data/hof_data/218-hof-tcpb-298/bb_1_mod.cif"))
# 保存结果（可选）
mol.write("cif", "/data/user2/wty/HOF/MOFDiff/hofdiff/data/hof_data/218-hof-tcpb-298/bb_1_mod.cif", overwrite=True)


In [None]:
from openbabel import pybel
import numpy as np

def remove_overlapping_atoms_pybel(mol, threshold=0.001):
    unique_atoms = []
    unique_coords = []

    for atom in mol:
        pos = np.array(atom.coords)
        if all(np.linalg.norm(pos - other) > threshold for other in unique_coords):
            unique_atoms.append(atom)
            unique_coords.append(pos)
    
    # 创建新分子结构并添加非重叠原子
    new_mol = pybel.ob.OBMol()
    for atom in unique_atoms:
        new_atom = new_mol.NewAtom()
        new_atom.SetAtomicNum(atom.atomicnum)
        x, y, z = atom.coords
        new_atom.SetVector(x, y, z)

    return pybel.Molecule(new_mol)

# 读取 CIF 文件
mol = next(pybel.readfile("cif", "/data/user2/wty/HOF/MOFDiff/hofdiff/data/hof_data/218-hof-tcpb-298/bb_1.cif"))

# 去除重叠原子
mol_filtered = remove_overlapping_atoms_pybel(mol)

# 保存结果（可选）
mol_filtered.write("cif", "/data/user2/wty/HOF/MOFDiff/hofdiff/data/hof_data/218-hof-tcpb-298/bb_1_mod.cif", overwrite=True)

In [None]:
import os
import torch
from torch_geometric.data import Data
import numpy as np
from atomic_utils import readcif, lattice_params_to_matrix_torch, read_cif_bonds, compute_image_flag, get_atomic_graph, compute_distance_matrix, frac2cart
from data_utils import frac_to_cart_coords

# 假设已经定义的函数
# pyg_graph_from_cif() 和 mof_criterion()

def process_bb_files(root_dir, max_atoms=200, max_cps=20):
    for subfolder in os.listdir(root_dir):
        subfolder_path = os.path.join(root_dir, subfolder)
        
        if os.path.isdir(subfolder_path):
            for file_name in os.listdir(subfolder_path):
                # 检查文件名是否符合 bb_{i}.cif 格式
                if file_name.startswith("bb_") and file_name.endswith(".cif"):
                    cif_path = os.path.join(subfolder_path, file_name)
                    remove_overlapping_atoms_from_cif(cif_path, cif_path, threshold=0.001)


def remove_overlapping_atoms_from_cif(input_cif, output_cif, threshold=0.001):
    """
    读取一个 .cif 文件，去除位置完全重叠的原子，并保存处理后的 .cif 文件。
    
    参数:
    - input_cif (str): 输入 .cif 文件路径。
    - output_cif (str): 输出 .cif 文件路径。
    - threshold (float): 判断原子重叠的距离阈值，单位为 Å，默认值为 0.001。
    """
    
    # 读取 CIF 文件
    structure = read(input_cif)
    
    # 去除重叠原子
    positions = structure.get_positions()
    unique_indices = []
    unique_positions = []

    for i, pos in enumerate(positions):
        if all(np.linalg.norm(pos - other) > threshold for other in unique_positions):
            unique_indices.append(i)  # 存储索引
            unique_positions.append(pos)

    # 创建去除重叠原子的结构
    new_structure = structure[unique_indices]
    
    # 保存去重后的结构到新的 CIF 文件
    write(output_cif, new_structure)
    
    # 使用 Pybel 再次读取并保存以确保格式的兼容性
    mol = next(pybel.readfile("cif", output_cif))
    mol.write("cif", output_cif, overwrite=True)

# 调用函数并指定根目录
root_directory = "/data/user2/wty/HOF/MOFDiff/hofdiff/data/hof_data"
process_bb_files(root_directory)


In [None]:
import os
import torch

def process_bb_files(root_dir, max_atoms=200, max_cps=20):
    for subfolder in os.listdir(root_dir):
        subfolder_path = os.path.join(root_dir, subfolder)
        
        if os.path.isdir(subfolder_path):
            for file_name in os.listdir(subfolder_path):
                # 检查文件名是否和子文件夹名称相同
                if file_name == f"{subfolder}.ckf":
                    cif_path = os.path.join(subfolder_path, file_name)
                    print(f"Processing CIF file: {cif_path}")
                    os.remove(cif_path)
                    # remove_overlapping_atoms_from_cif(cif_path, cif_path, threshold=0.001)

process_bb_files("/data/user2/wty/HOF/MOFDiff/hofdiff/data/hof_data")

In [None]:
from pathlib import Path
import argparse
import pickle
import pandas as pd
from p_tqdm import p_umap
from openbabel import openbabel as ob
import numpy as np
from atomic_utils import readcif, lattice_params_to_matrix_torch, read_cif_bonds, compute_image_flag, get_atomic_graph, compute_distance_matrix, frac2cart
from data_utils import frac_to_cart_coords
import re
import torch

import sys
import os
# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from hofdiff.common.atomic_utils import pyg_graph_from_cif, assemble_local_struct
import multiprocessing as mp
from hofdiff.common.constants import COVALENT_RADII

def bb_criterion(bb, max_atoms=200, max_cps=20):
    bb.num_cps = bb.is_anchor.long().sum()
    print(bb.is_anchor)
    # 转换分数坐标为笛卡尔坐标
    # print(bb.lengths)
    cart_coords = frac_to_cart_coords(
        bb.frac_coords, bb.lengths, bb.angles, bb.num_atoms
    )
    pdist = torch.cdist(cart_coords, cart_coords).fill_diagonal_(5.0)

    # 检查边的连接信息是否为空
    edge_index = bb.edge_index
    if edge_index.numel() == 0:
        print("Invalid BB: No bond information (empty edge_index).")
        return None, False

    # 计算键长
    j, i = edge_index
    # print("edge_index:", edge_index)
    # print("cart_coords:", cart_coords)
    dist_mat = compute_distance_matrix(bb.cell, cart_coords)
    bond_dist = dist_mat[i, j]
    # bond_dist = (cart_coords[i] - cart_coords[j]).pow(2).sum(dim=-1).sqrt()
    # print("bond_dist:", bond_dist)
    # 检查是否存在键长
    if bond_dist.numel() == 0:
        print("Invalid BB: No valid bond distances (empty bond_dist).")
        return None, False

    # 检查最小原子距离
    if pdist.min() <= 0.25:
        print("Invalid BB: Atom distance is below the minimum threshold of 0.25.")
        return cart_coords, False

    # 检查最大键长
    if bond_dist.max() >= 5.0:
        print("Invalid BB: Bond distance exceeds the maximum threshold of 5.0.")
        return cart_coords, False

    # 检查原子数量上限
    if bb.num_atoms > max_atoms:
        print(f"Invalid BB: Number of atoms ({bb.num_atoms}) exceeds the maximum limit of {max_atoms}.")
        return cart_coords, False

    # 如果满足所有条件
    return cart_coords, True

def has_no_overlapping_atoms(cif_path, threshold=0.7):
    """
    判断给定的 CIF 文件中是否有重叠的原子。如果没有重叠原子则返回 True，否则返回 False。

    :param cif_path: CIF 文件路径
    :param threshold: 判定原子是否重叠的阈值，默认为 0.7
    :return: 没有重叠原子返回 True，有重叠原子返回 False
    """
    print(f"Checking {cif_path} for overlapping atoms.")
    obConversion = ob.OBConversion()
    obConversion.SetInFormat("cif")
    mol = ob.OBMol()

    if not obConversion.ReadFile(mol, cif_path):
        print(f"Failed to read {cif_path} file.")
        return False

    # 分离出所有连通分支
    fragments = mol.Separate()

    for frag in fragments:
        frag_mol = ob.OBMol(frag)
        other_atoms = []

        # 遍历分子中的每个原子，检查原子间是否有重叠
        for atom in ob.OBMolAtomIter(frag_mol):
            pos = np.array([atom.GetX(), atom.GetY(), atom.GetZ()])
            e1 = atom.GetType()
            
            for other_atom in other_atoms:
                other_pos = np.array([other_atom.GetX(), other_atom.GetY(), other_atom.GetZ()])
                e2 = other_atom.GetType()
                
                # 去掉e1e2的数字，只留下字母
                e1 = ''.join([i for i in e1 if not i.isdigit()])
                e2 = ''.join([i for i in e2 if not i.isdigit()])
                # 根据原子类型，计算它们的共价半径
                try:
                    min_threshold = min(COVALENT_RADII[e1], COVALENT_RADII[e2])
                except KeyError as e:
                    # print(f"Warning: Unrecognized atom type '{e.args[0]}' encountered.")
                    continue  # Skip or handle the unrecognized atom type
                if np.linalg.norm(pos - other_pos) < threshold * min_threshold:
                    return False  # 找到重叠的原子，直接返回 False

            other_atoms.append(atom)

    return True  # 没有重叠原子，返回 True

def assign_cif_files(base_path):
    # 正则表达式，用于匹配符合 "molecule_{i}.cif" 格式的文件名
    cif_pattern = re.compile(r"^molecule_\d+\.cif$")
    
    # 用于存储不符合特定格式的cif文件路径
    non_conforming_cifs = []
    hid = base_path.parts[-1]
    # 遍历base_path目录下的所有文件
    for file in os.listdir(base_path):
        if file.endswith('.cif') and not cif_pattern.match(file) and file != f'{hid}.cif':
            bb_path = os.path.join(base_path, file)
            if has_no_overlapping_atoms(bb_path):
                print(f"{bb_path} has no overlapping")
                non_conforming_cifs.append(os.path.join(base_path, file))
    
    return non_conforming_cifs

def assemble_mof(m_id, use_asr=True):
    try:
        base_path = Path(f'/data/user2/wty/HOF/MOFDiff/hofdiff/data/hof_data/{m_id}')
        # print(base_path)
        # use the metaloxo algorithm for deconstruction.
        # g_nodes_path, g_linkers_path, g_node_bridges_path = assign_cif_files(base_path)
        bb_cifs = assign_cif_files(base_path)
        # print("bb_cifs:",bb_cifs)
        g_bb_cifs = []
        for bb_cif in bb_cifs:
            g_bb_cifs.append(pyg_graph_from_cif(Path(bb_cif), Hbond=False))
        # print("g_bb_cifs:", len(g_bb_cifs))
        if use_asr:
            g_asr = pyg_graph_from_cif(Path(base_path / f'{m_id}.cif'), Hbond=True)
        else:
            g_asr = None
        # print("g_node_bridges",g_node_bridges)
        # print("g_asr",g_asr)
        data = assemble_local_struct(
            g_bb_cifs, g_asr, device='cpu', 
        )
    except FileNotFoundError:
        print(f"FileNotFoundError: {m_id}")
        return None
    except UnboundLocalError:
        print(f"UnboundLocalError: {m_id}")
        return None
    except IndexError:
        print(f"IndexError: {m_id}")
        return None
    except ValueError:
        print(f"ValueError: {m_id}")
        return None
    return data


data = assemble_mof('218-hof-tcpb-298')
bbs = []
print("length of data.bbs:", len(data.bbs))
for bb in data.bbs:
    cart_coords, success = bb_criterion(bb)
    if success:
        bb.num_nodes = bb.num_atoms
        bb.diameter = torch.pdist(cart_coords).max()
        # print(bb.diameter)
        bbs.append(bb)
print("length of bbs:", len(bbs))



In [None]:
from pathlib import Path
import argparse
import pickle
import pandas as pd
from p_tqdm import p_umap
from openbabel import openbabel as ob
import numpy as np
from atomic_utils import readcif, lattice_params_to_matrix_torch, read_cif_bonds, compute_image_flag, get_atomic_graph, compute_distance_matrix, \
    frac2cart, pyg_graph_from_cif, mof2cif_with_bonds
from data_utils import frac_to_cart_coords
import re
import torch

import sys
import os
import time
# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from hofdiff.common.atomic_utils import pyg_graph_from_cif, assemble_local_struct
import multiprocessing as mp
from hofdiff.common.constants import COVALENT_RADII
from hofdiff.common.optimization import (
    annealed_optimization,
    assemble_mof,
    feasibility_check,
)

def assemble_one(
    mof,
    verbose=True,
    rounds=3,
    sigma_start=3.0,
    sigma_end=0.3,
    max_neighbors_start=30,
    max_neighbors_end=1,
):
    # if not feasibility_check(mof):
    #     return None
    # print("mof:",mof)
    sigma_schedule = np.linspace(sigma_start, sigma_end, rounds)
    max_neighbors_schedule = (
        np.linspace(max_neighbors_start, max_neighbors_end, rounds).round().astype(int)
    )

    now = time.time()
    results, v = annealed_optimization(
        mof,
        0,
        optimize=False,
        sigma_schedule=sigma_schedule,
        max_neighbors_schedule=max_neighbors_schedule,
        maxiter=1000,
        verbose=verbose,
    )
    elapsed = time.time() - now
    vecs = torch.from_numpy(results["x"]).view(mof.num_atoms, 3).float()
    # print("vecs:",vecs)
    # print(" ")
    # for bb in mof.bbs:
    #     print("bb.atom_types:", bb.atom_types)
    bb_local_vectors = [bb.local_vectors for bb in mof.bbs]
    # print(mof)
    assembled_rec = assemble_mof(mof, vecs, bb_local_vectors=bb_local_vectors)
    # print("assembled_rec:",assembled_rec)
    assembled_rec.opt_v = v
    assembled_rec.assemble_time = elapsed
    return assembled_rec

def has_no_overlapping_atoms(cif_path, threshold=0.7):
    """
    判断给定的 CIF 文件中是否有重叠的原子。如果没有重叠原子则返回 True，否则返回 False。

    :param cif_path: CIF 文件路径
    :param threshold: 判定原子是否重叠的阈值，默认为 0.7
    :return: 没有重叠原子返回 True，有重叠原子返回 False
    """
    print(f"Checking {cif_path} for overlapping atoms.")
    obConversion = ob.OBConversion()
    obConversion.SetInFormat("cif")
    mol = ob.OBMol()

    if not obConversion.ReadFile(mol, cif_path):
        print(f"Failed to read {cif_path} file.")
        return False

    # 分离出所有连通分支
    fragments = mol.Separate()

    for frag in fragments:
        frag_mol = ob.OBMol(frag)
        other_atoms = []

        # 遍历分子中的每个原子，检查原子间是否有重叠
        for atom in ob.OBMolAtomIter(frag_mol):
            pos = np.array([atom.GetX(), atom.GetY(), atom.GetZ()])
            e1 = atom.GetType()
            
            for other_atom in other_atoms:
                other_pos = np.array([other_atom.GetX(), other_atom.GetY(), other_atom.GetZ()])
                e2 = other_atom.GetType()
                
                # 去掉e1e2的数字，只留下字母
                e1 = ''.join([i for i in e1 if not i.isdigit()])
                e2 = ''.join([i for i in e2 if not i.isdigit()])
                # 根据原子类型，计算它们的共价半径
                try:
                    min_threshold = min(COVALENT_RADII[e1], COVALENT_RADII[e2])
                except KeyError as e:
                    # print(f"Warning: Unrecognized atom type '{e.args[0]}' encountered.")
                    continue  # Skip or handle the unrecognized atom type
                if np.linalg.norm(pos - other_pos) < threshold * min_threshold:
                    return False  # 找到重叠的原子，直接返回 False

            other_atoms.append(atom)

    return True  # 没有重叠原子，返回 True

def assign_cif_files(base_path):
    # 正则表达式，用于匹配符合 "molecule_{i}.cif" 格式的文件名
    cif_pattern = re.compile(r"^molecule_\d+\.cif$")
    
    # 用于存储不符合特定格式的cif文件路径
    non_conforming_cifs = []
    hid = base_path.parts[-1]
    # 遍历base_path目录下的所有文件
    for file in os.listdir(base_path):
        if file.endswith('.cif') and not cif_pattern.match(file) and file != f'{hid}.cif':
            bb_path = os.path.join(base_path, file)
            if has_no_overlapping_atoms(bb_path):
                print(f"{bb_path} has no overlapping")
                non_conforming_cifs.append(os.path.join(base_path, file))
    
    return non_conforming_cifs
    # 检查是否正好有三个非特定格式的cif文件
    # if len(non_conforming_cifs) == 3:
    #     g_nodes, g_linkers, g_node_bridges = non_conforming_cifs
    #     return g_nodes, g_linkers, g_node_bridges
    # else:
    #     raise ValueError("There are not exactly three non-conforming .cif files in the directory.")


def assemble_mof(m_id, use_asr=True):
    try:
        base_path = Path(f'/data/user2/wty/HOF/MOFDiff/hofdiff/data/hof_data/{m_id}')
        # print(base_path)
        # use the metaloxo algorithm for deconstruction.
        # g_nodes_path, g_linkers_path, g_node_bridges_path = assign_cif_files(base_path)
        bb_cifs = assign_cif_files(base_path)
        print("bb_cifs:",bb_cifs)
        g_bb_cifs = []
        for bb_cif in bb_cifs:
            g_bb_cifs.append(pyg_graph_from_cif(Path(bb_cif), Hbond=False))
        # print("g_bb_cifs:", len(g_bb_cifs))
        if use_asr:
            g_asr = pyg_graph_from_cif(Path(base_path / f'{m_id}.cif'), Hbond=True)
        else:
            g_asr = None
        # print("g_node_bridges",g_node_bridges)
        print("g_asr",g_asr)
        data = assemble_local_struct(
            g_bb_cifs, g_asr, device='cpu', 
        )
    except FileNotFoundError:
        print(f"FileNotFoundError: {m_id}")
        return None
    except UnboundLocalError:
        print(f"UnboundLocalError: {m_id}")
        return None
    except IndexError:
        print(f"IndexError: {m_id}")
        return None
    except ValueError:
        print(f"ValueError: {m_id}")
        return None
    return data

def preprocess_mof():
    df = pd.read_csv('/data/user2/wty/HOF/MOFDiff/hofdiff/data/hof.csv')
    m_id_list = df['hof_id']

    # 对每个 m_id 调用 assemble_mof 函数
    for m_id in m_id_list:
        data = assemble_mof(m_id)
        if data:
            for bb in data.bbs:
                print(bb.atom_types)
            # data = assemble_one(data)
            mof2cif_with_bonds(data, f'/data/user2/wty/HOF/MOFDiff/hofdiff/data/test/{m_id}.cif')

# m_id = "85_1888690_ecut-hof-30"
# data = assemble_mof(m_id)
# index = 1
# for bb in data.bbs: 
#     print(bb.atom_types)
#     print(bb.frac_coords)
    # mof2cif_with_bonds(bb, f'/data/user2/wty/HOF/MOFDiff/hofdiff/data/test/{index}.cif')
    # index += 1
# mof2cif_with_bonds(data, f'/data/user2/wty/HOF/MOFDiff/hofdiff/data/test/{m_id}.cif')
preprocess_mof()

In [3]:
import torch

# Define the cell (cubic cell with side length of 10 Å)
cell = torch.tensor([[10.0, 0.0, 0.0], 
                     [0.0, 10.0, 0.0], 
                     [0.0, 0.0, 10.0]], dtype=torch.float32)

# Atom positions (two atoms placed at opposite ends of the unit cell)
cart_coords = torch.tensor([[0.0, 0.0, 0.0],  # Atom 1 at origin
                            [9.0, 9.0, 9.0]], dtype=torch.float32)  # Atom 2 near the opposite corner

# Compute the distance matrix considering periodic boundary conditions
def compute_distance_matrix(cell, cart_coords, num_cells=1):
    pos = torch.arange(-num_cells, num_cells + 1, 1).to(cell.device)
    combos = (
        torch.stack(torch.meshgrid(pos, pos, pos, indexing="xy"))
        .permute(3, 2, 1, 0)
        .reshape(-1, 3)
        .to(cell.device)
    )
    shifts = torch.sum(cell.unsqueeze(0) * combos.unsqueeze(-1), dim=1)
    shifted = cart_coords.unsqueeze(1) + shifts.unsqueeze(0)
    dist = cart_coords.unsqueeze(1).unsqueeze(1) - shifted.unsqueeze(0)
    dist = (dist.pow(2).sum(dim=-1) + 1e-32).sqrt()
    distance_matrix = dist.min(dim=-1)[0]
    return distance_matrix

# Call the function
distance_matrix = compute_distance_matrix(cell, cart_coords, num_cells=2)

# Print the distance matrix
print("Distance Matrix:")
print(distance_matrix)


Distance Matrix:
tensor([[1.0000e-16, 1.7321e+00],
        [1.7321e+00, 1.0000e-16]])


In [4]:
import numpy as np

# 输入晶格常数和角度（度 -> 弧度）
a = 12.33672
b = 15.67333
c = 13.73446
alpha = 91.57824  # 角度
beta = 97.85991   # 角度
gamma = 92.64503  # 角度

alpha = np.radians(alpha)
beta = np.radians(beta)
gamma = np.radians(gamma)

# 计算晶格矩阵A
A = np.array([
    [a, 0, 0],
    [b * np.cos(gamma), b * np.sin(gamma), 0],
    [c * np.cos(beta), c * (np.cos(alpha) - np.cos(gamma) * np.cos(beta) / np.sin(gamma)), c * np.sin(alpha)]
])

# 分数坐标
frac_coords = np.array([0.23122752, 0.8597536, 0.41181207])

# 计算笛卡尔坐标
cartesian_coords = np.dot(A, frac_coords)
print("A:", A)
print("笛卡尔坐标:", cartesian_coords)


A: [[12.33672     0.          0.        ]
 [-0.72329419 15.65663178  0.        ]
 [-1.87820733 -0.46504263 13.72924979]]
笛卡尔坐标: [ 2.85258917 13.29360002  4.81975548]


In [6]:
import numpy as np

# 输入分数坐标和对应的笛卡尔坐标
frac_coords = np.array([
    [0.23123, 0.85975, 0.41181],  # 第一个分数坐标
    [0.39614, 0.89225, 0.49441],  # 第二个分数坐标
    [0.02337, 0.92821, 0.18190]   # 第三个分数坐标
])

cartesian_coords = np.array([
    [1.45727, 13.26917, 5.59960],  # 第一个笛卡尔坐标
    [3.31305, 13.73944, 6.72269],    # 第二个笛卡尔坐标
    [-0.72472, 14.44803, 2.47333]     # 第三个笛卡尔坐标
])

# 计算反解的晶格矩阵 A
A = np.linalg.inv(frac_coords) @ cartesian_coords
print("反解出的晶格矩阵 A:\n", A)

# 用这个A再对另外的一个分数坐标进行变换
frac_coords = np.array([0.43463, 0.54112, 0.40885])
cartesian_coords = frac_coords @ A
print("笛卡尔坐标:", cartesian_coords)


反解出的晶格矩阵 A:
 [[ 1.23365694e+01 -1.30954541e-03 -8.31390272e-04]
 [-7.23315348e-01  1.56564941e+01 -2.31031764e-04]
 [-1.87815882e+00 -4.64286943e-01  1.35984820e+01]]
笛卡尔坐标: [4.20255754 8.28164921 5.559253  ]


In [3]:
from openbabel import openbabel as ob
import numpy as np
import re
from pathlib import Path
import argparse
import pickle
import pandas as pd
import sys
import os
from hofdiff.common.constants import COVALENT_RADII

def has_no_overlapping_atoms(cif_path, threshold=0.7):
    """
    判断给定的 CIF 文件中是否有重叠的原子。如果没有重叠原子则返回 True，否则返回 False。

    :param cif_path: CIF 文件路径
    :param threshold: 判定原子是否重叠的阈值，默认为 0.7
    :return: 没有重叠原子返回 True，有重叠原子返回 False
    """
    print(f"Checking {cif_path} for overlapping atoms.")
    obConversion = ob.OBConversion()
    obConversion.SetInFormat("cif")
    mol = ob.OBMol()

    if not obConversion.ReadFile(mol, cif_path):
        print(f"Failed to read {cif_path} file.")
        return False

    # 分离出所有连通分支
    fragments = mol.Separate()

    for frag in fragments:
        frag_mol = ob.OBMol(frag)
        other_atoms = []

        # 遍历分子中的每个原子，检查原子间是否有重叠
        for atom in ob.OBMolAtomIter(frag_mol):
            pos = np.array([atom.GetX(), atom.GetY(), atom.GetZ()])
            e1 = atom.GetType()
            
            for other_atom in other_atoms:
                other_pos = np.array([other_atom.GetX(), other_atom.GetY(), other_atom.GetZ()])
                e2 = other_atom.GetType()
                
                # 去掉e1e2的数字，只留下字母
                e1 = ''.join([i for i in e1 if not i.isdigit()])
                e2 = ''.join([i for i in e2 if not i.isdigit()])
                # 根据原子类型，计算它们的共价半径
                try:
                    min_threshold = min(COVALENT_RADII[e1], COVALENT_RADII[e2])
                except KeyError as e:
                    # print(f"Warning: Unrecognized atom type '{e.args[0]}' encountered.")
                    continue  # Skip or handle the unrecognized atom type
                if np.linalg.norm(pos - other_pos) < threshold * min_threshold:
                    return False  # 找到重叠的原子，直接返回 False

            other_atoms.append(atom)

    return True  # 没有重叠原子，返回 True

cif_path = "/data/user2/wty/HOF/MOFDiff/hofdiff/data/hof_models/hof_models/bwdb_hoff/samples_4096_seed_8/cif/sample_57.cif"
print(has_no_overlapping_atoms(cif_path))

ModuleNotFoundError: No module named 'hofdiff'