In [1]:
import torch
# Load the object stored in test_set.pth.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source;
# the path is absolute and machine-specific.
data = torch.load("/root/zqs_project/drug/MolCRAFT/small_data/test_set.pth")

In [4]:
from torch_scatter import scatter_mean


def build_local_coordinate_system(atom_positions, batch_indices, offset):
    """
    Build one right-handed orthonormal frame per batch element (molecule).

    For every batch element the two atoms closest to that element's origin
    define the frame: the x axis points from the origin to the nearest atom,
    the z axis is perpendicular to the plane spanned by the origin and the two
    nearest atoms, and the y axis completes the right-handed system.

    Args:
        atom_positions: coordinates of the atoms of all batch elements, shape (N, 3).
        batch_indices: integer batch id of every atom, shape (N,), values in [0, B).
        offset: frame origin of every batch element (e.g. its centroid),
            shape (B, 3), expressed in the same coordinate frame as
            ``atom_positions``.

    Returns:
        transform_matrices: tensor of shape (B, 3, 3). Row 0/1/2 of each matrix
        is the x/y/z axis. For row-vector positions ``p`` measured from the
        origin, ``p @ R.T`` gives the coordinates in the local frame.

    Raises:
        ValueError: if any of the B batch elements has fewer than two atoms.
    """
    # Distance of every atom to the origin of its own batch element.
    distances = torch.norm(atom_positions - offset[batch_indices], dim=1)

    num_batches = offset.shape[0]
    counts = torch.bincount(batch_indices, minlength=num_batches)
    if bool((counts < 2).any()):
        raise ValueError(
            "every batch element needs at least two atoms to define a local frame"
        )

    # Order atoms by (batch, distance): sort by distance first, then regroup by
    # batch with a *stable* sort so the distance order inside a batch survives.
    # Sorting by distance alone mixes the batches and makes the per-batch start
    # offsets below point at atoms of other batch elements.
    by_distance = torch.argsort(distances)
    _, regroup = torch.sort(batch_indices[by_distance], stable=True)
    sorted_indices = by_distance[regroup]

    # First position of every batch element inside the grouped ordering.
    starts = torch.cumsum(counts, dim=0) - counts

    pos_A = offset                                    # frame origin
    pos_B = atom_positions[sorted_indices[starts]]      # nearest atom
    pos_C = atom_positions[sorted_indices[starts + 1]]  # second nearest atom

    # NOTE: if the origin and the two nearest atoms are collinear (or the
    # nearest atom coincides with the origin) a norm below is zero and the
    # resulting frame contains NaN/inf values.
    x_axis = (pos_B - pos_A) / torch.norm(pos_B - pos_A, dim=1, keepdim=True)
    z_axis = torch.cross(x_axis, pos_C - pos_A, dim=1)
    z_axis = z_axis / torch.norm(z_axis, dim=1, keepdim=True)
    y_axis = torch.cross(z_axis, x_axis, dim=1)

    transform_matrices = torch.stack([x_axis, y_axis, z_axis], dim=1)  # shape=(B, 3, 3)

    return transform_matrices

def center_pos(protein_pos, ligand_pos, batch_protein, batch_ligand, mode="protein"):
    """
    Translate protein and ligand coordinates according to ``mode``.

    Args:
        protein_pos: protein atom coordinates.
        ligand_pos: ligand atom coordinates.
        batch_protein: batch id of every protein atom, used to index the offsets.
        batch_ligand: batch id of every ligand atom, used to index the offsets.
        mode: ``"none"`` leaves the inputs untouched; ``"protein"`` subtracts the
            per-batch mean of ``protein_pos`` (computed with ``scatter_mean``)
            from both protein and ligand coordinates.

    Returns:
        Tuple ``(protein_pos, ligand_pos, offset)``. ``offset`` is the float
        ``0.0`` for ``"none"`` and the per-batch mean tensor for ``"protein"``.

    Raises:
        NotImplementedError: for any other ``mode``.
    """
    if mode == "none":
        # Nothing to shift; report a scalar zero offset.
        return protein_pos, ligand_pos, 0.0

    if mode == "protein":
        centroids = scatter_mean(protein_pos, batch_protein, dim=0)
        shifted_protein = protein_pos - centroids[batch_protein]
        shifted_ligand = ligand_pos - centroids[batch_ligand]
        return shifted_protein, shifted_ligand, centroids

    raise NotImplementedError

# Synthetic mini-batch with random coordinates. Every batch-index tensor and
# the per-batch "offset" entry share the same number of batch elements.
NUM_BATCHES = 2
example = {
    "protein_pos": torch.rand(442, 3),  # 442 protein atoms
    "ligand_pos": torch.rand(31, 3),   # 31 ligand atoms
    "surface_pos": torch.rand(362, 3), # 362 surface points
    "batch_protein": torch.randint(0, NUM_BATCHES, (442,)),  # batch id of every protein atom
    "batch_ligand": torch.randint(0, NUM_BATCHES, (31,)),
    "batch_surface": torch.randint(0, NUM_BATCHES, (362,)),
    "offset": torch.rand(NUM_BATCHES, 3)  # one origin (centroid) per batch element
}

# Configuration
class Config:
    """Small stand-in configuration object with ``dynamics`` and ``data`` sections."""

    class Dynamics:
        """Settings read as ``cfg.dynamics.*``."""
        # Mode string handed to center_pos().
        center_pos_mode = "protein"

    class Data:
        """Settings read as ``cfg.data.*``."""
        # Divisor used to rescale coordinates.
        normalizer = 2.0

    # One shared instance per section, stored as class attributes.
    dynamics, data = Dynamics(), Data()

cfg = Config()
gt_protein_pos = example["protein_pos"]
ligand_pos = example["ligand_pos"]
surface_pos = example["surface_pos"]
batch_protein = example["batch_protein"]
batch_ligand = example["batch_ligand"]
batch_surface = example["batch_surface"]

# Shift protein and ligand by the offset computed in center_pos().
gt_protein_pos, ligand_pos, offset = center_pos(
    gt_protein_pos,
    ligand_pos,
    batch_protein,
    batch_ligand,
    mode=cfg.dynamics.center_pos_mode
)

# Surface points receive the same per-batch shift.
surface_pos = surface_pos - offset[batch_surface]

# The coordinates are already centred, so the frame origin in this coordinate
# system is the zero vector; passing `offset` again would subtract the
# centroid a second time when searching for the nearest atoms.
transform_matrix = build_local_coordinate_system(
    gt_protein_pos, batch_protein, torch.zeros_like(offset)
)

# Rows of every matrix are the local x/y/z axes. Right-multiplying row-vector
# positions by the transpose projects them onto those axes, i.e. expresses
# them in the local frame (multiplying by the matrix itself would apply the
# inverse rotation).
to_local = transform_matrix.transpose(1, 2)
gt_protein_pos = torch.matmul(gt_protein_pos.unsqueeze(1), to_local[batch_protein]).squeeze(1)
ligand_pos = torch.matmul(ligand_pos.unsqueeze(1), to_local[batch_ligand]).squeeze(1)
surface_pos = torch.matmul(surface_pos.unsqueeze(1), to_local[batch_surface]).squeeze(1)

# NOTE(review): only the protein coordinates are rescaled here; ligand_pos and
# surface_pos keep their original scale — confirm this is intended.
gt_protein_pos = gt_protein_pos / cfg.data.normalizer

In [11]:
import plotly.graph_objects as go
import torch

def plot_3d_interactive(original_positions, rotated_positions, title="3D Positions"):
    """
    Show two point clouds in one interactive Plotly 3D scatter plot.

    Args:
        original_positions: tensor of shape (N, 3); drawn in blue and labelled
            "Original Positions".
        rotated_positions: tensor of shape (N, 3); drawn in red and labelled
            "Rotated Positions".
        title: figure title.
    """
    fig = go.Figure()

    point_clouds = (
        (original_positions, 'blue', 'Original Positions'),
        (rotated_positions, 'red', 'Rotated Positions'),
    )
    for positions, color, label in point_clouds:
        # detach() first so tensors that still track gradients can be converted.
        xyz = positions.detach().cpu().numpy()
        fig.add_trace(go.Scatter3d(
            x=xyz[:, 0],
            y=xyz[:, 1],
            z=xyz[:, 2],
            mode='markers',
            marker=dict(size=5, color=color, opacity=0.8),
            name=label
        ))

    # Axis labels and a tight margin that leaves room for the title.
    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title='X-axis',
            yaxis_title='Y-axis',
            zaxis_title='Z-axis'
        ),
        margin=dict(l=0, r=0, b=0, t=40)
    )

    # Render the figure.
    fig.show()


# Work on copies so the tensors stored in `example` stay untouched.
original_protein_pos = example["protein_pos"].clone()
original_ligand_pos = example["ligand_pos"].clone()

# Centre, then rotate into the local frame.
gt_protein_pos, ligand_pos, offset = center_pos(
    original_protein_pos,
    original_ligand_pos,
    batch_protein,
    batch_ligand,
    mode=cfg.dynamics.center_pos_mode
)
# The coordinates are already centred, so the frame origin is the zero vector
# (passing `offset` again would subtract the centroid twice).
transform_matrix = build_local_coordinate_system(
    gt_protein_pos, batch_protein, torch.zeros_like(offset)
)
# Rows of each matrix are the local axes; multiplying row vectors by the
# transpose expresses them in the local frame.
to_local = transform_matrix.transpose(1, 2)
gt_protein_pos_rotated = torch.matmul(gt_protein_pos.unsqueeze(1), to_local[batch_protein]).squeeze(1)
ligand_pos_rotated = torch.matmul(ligand_pos.unsqueeze(1), to_local[batch_ligand]).squeeze(1)

# Interactive 3D plots: each figure compares one set of points before (blue)
# and after (red) the rotation, matching the figure title and legend.
plot_3d_interactive(gt_protein_pos, gt_protein_pos_rotated, title="Protein Positions: Original vs Rotated")
plot_3d_interactive(ligand_pos, ligand_pos_rotated, title="Ligand Positions: Original vs Rotated")


In [5]:
import  torch 
# Load the object stored in generated.pt.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source;
# the path is absolute and machine-specific.
pt = torch.load("/root/zqs_project/drug/MolCRAFT/logs/root_bfn_sbdd/add_cluster_mean/3/test_outputs_v2/20241121-102845/generated.pt")


In [7]:
# Display the first entry of the loaded object.
pt[0]

{'mol': <rdkit.Chem.rdchem.Mol at 0x7f7325ed4a40>,
 'ligand_filename': 'BSD_ASPTE_1_130_0/2z3h_A_rec_1wn6_bst_lig_tt_docked_3.sdf',
 'pred_pos': array([[ 19.07535744,  34.89083862, 106.89019775],
        [ 24.17471313,  35.18234634, 105.21728516],
        [ 26.1620903 ,  34.51578903, 103.89488983],
        [ 23.72534561,  33.87718582, 105.25028229],
        [ 29.57939529,  33.21274948, 102.87207794],
        [ 15.53934956,  34.09431839, 102.69537354],
        [ 21.09123993,  34.18183899, 105.84368134],
        [ 20.57048607,  34.76352692, 107.14292908],
        [ 18.2828083 ,  35.71738434, 104.71507263],
        [ 19.93759155,  33.862957  , 104.9485321 ],
        [ 23.64234543,  31.64995003, 104.0189743 ],
        [ 14.83694363,  35.21371078, 102.26055145],
        [ 16.29716873,  36.64857101, 103.52635956],
        [ 22.44967079,  33.5345192 , 105.79893494],
        [ 33.13725281,  33.74886703, 105.21717072],
        [ 15.20415211,  36.48918533, 102.6739502 ],
        [ 32.00395203,  