# GET项目 - 分子图可视化

本notebook用于可视化GET项目中PDBbind identity30数据集的分子复合物图结构。

## 功能包括：
1. 3D分子结构可视化
2. 2D图结构可视化
3. 连接矩阵热图
4. 交互式数据探索

In [None]:
# Import the required libraries
import os
import sys
import pickle
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
import pandas as pd
from IPython.display import display, HTML
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual

# Put the current working directory on sys.path so the project-local modules
# imported below resolve when the notebook is started from the project root.
PROJ_DIR = os.path.abspath('.')
sys.path.append(PROJ_DIR)

from visualize_graph import GraphVisualizer
from data.pdb_utils import VOCAB
from utils.logger import print_log

# Configure matplotlib: list SimHei first in the sans-serif fallback chain
# (presumably so CJK text in figures renders), and draw the minus sign as an
# ASCII hyphen rather than the Unicode minus glyph.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
%matplotlib inline

## 1. 数据集加载和基本信息

In [None]:
# Expected location of each PDBbind identity30 split
dataset_paths = {
    'train': './datasets/PDBBind/processed/identity30/train.pkl',
    'valid': './datasets/PDBBind/processed/identity30/valid.pkl',
    'test': './datasets/PDBBind/processed/identity30/test.pkl'
}

# Keep only the splits whose pickle file exists on disk, reporting each one
available_datasets = {}
for name, path in dataset_paths.items():
    if not os.path.exists(path):
        print(f"✗ Dataset not found: {path}")
        continue
    available_datasets[name] = path
    print(f"✓ Found {name} dataset: {path}")

# Summarise: how many splits were found, or where they were expected to be
if available_datasets:
    print(f"\n✓ Found {len(available_datasets)} datasets")
else:
    print("\n⚠️ No dataset files found. Please check the paths.")
    print("Expected paths:")
    for name, path in dataset_paths.items():
        print(f"  {name}: {path}")

In [None]:
# 加载数据集并显示基本统计信息
def load_and_analyze_dataset(dataset_path):
    """Load a pickled dataset and summarise its per-sample block/atom counts.

    Args:
        dataset_path: Path to a pickle file holding a sized, iterable
            collection of samples. Each sample is indexed with ``'B'`` and
            ``'A'``; the lengths of those entries are reported as the block
            count and atom count of the sample.

    Returns:
        A ``(data, stats)`` tuple: the unpickled dataset, and a dict mapping
        metric names ('Total samples', 'Blocks - Min/Max/Mean',
        'Atoms - Min/Max/Mean') to their values. For an empty dataset the
        min/max/mean entries are NaN rather than raising ``ValueError``.
    """
    # NOTE: unpickling can execute arbitrary code; only open trusted files.
    with open(dataset_path, 'rb') as f:
        data = pickle.load(f)

    print(f"Dataset size: {len(data)}")

    per_sample_counts = (
        ('Blocks', [len(item['B']) for item in data]),
        ('Atoms', [len(item['A']) for item in data]),
    )

    stats = {'Total samples': len(data)}
    for label, counts in per_sample_counts:
        if counts:
            stats[f'{label} - Min'] = min(counts)
            stats[f'{label} - Max'] = max(counts)
            stats[f'{label} - Mean'] = np.mean(counts)
        else:
            # min()/max() raise on an empty sequence; report "undefined" instead
            # so the caller can still tabulate and plot an empty dataset.
            for metric in ('Min', 'Max', 'Mean'):
                stats[f'{label} - {metric}'] = float('nan')

    return data, stats

# Summarise and plot the first dataset that was found on disk
if available_datasets:
    first_dataset = next(iter(available_datasets))
    first_path = available_datasets[first_dataset]

    print(f"Analyzing {first_dataset} dataset...")
    dataset, stats = load_and_analyze_dataset(first_path)

    # Show the metric/value pairs as a two-column table
    stats_df = pd.DataFrame(list(stats.items()), columns=['Metric', 'Value'])
    display(stats_df)

    # One histogram per count type, side by side
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    n_blocks = [len(item['B']) for item in dataset]
    n_atoms = [len(item['A']) for item in dataset]

    _panels = (
        (ax1, n_blocks, 'Number of Blocks', 'Distribution of Block Numbers'),
        (ax2, n_atoms, 'Number of Atoms', 'Distribution of Atom Numbers'),
    )
    for _ax, _counts, _xlabel, _title in _panels:
        _ax.hist(_counts, bins=50, alpha=0.7, edgecolor='black')
        _ax.set_xlabel(_xlabel)
        _ax.set_ylabel('Frequency')
        _ax.set_title(_title)
        _ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## 2. 交互式可视化工具

In [None]:
# Create the visualizer instance shared by the interactive and batch cells
visualizer = GraphVisualizer(figsize=(12, 8))

# 交互式可视化函数
def interactive_visualize(dataset_name, sample_index, k_neighbors, plot_type):
    """Widget callback: load one sample and draw the requested plot(s).

    Args:
        dataset_name: Key into ``available_datasets``; an unknown name prints
            a message and returns without plotting.
        sample_index: Index of the sample, passed to ``visualizer.load_data``.
        k_neighbors: Forwarded to the 2D-graph and connectivity-matrix plots
            (the 3D plot does not take it).
        plot_type: One of '3D Structure', '2D Graph', 'Connectivity Matrix'
            or 'All'; any other value draws nothing.

    Any exception raised while loading or plotting is caught and printed as
    ``Error: ...`` so the widget stays usable.
    """
    if dataset_name not in available_datasets:
        print(f"Dataset {dataset_name} not available")
        return

    dataset_path = available_datasets[dataset_name]

    try:
        sample = visualizer.load_data(dataset_path, sample_index)
        info = visualizer.parse_sample_data(sample)

        # Short textual header describing the selected sample
        print(f"Sample {sample_index} from {dataset_name} dataset:")
        print(f"  ID: {info['id']}")
        print(f"  Affinity: {info['affinity']}")
        print(f"  Blocks: {info['n_blocks']}, Atoms: {info['n_atoms']}")

        # 'All' enables every plot; otherwise only the matching one runs.
        # The order (3D, 2D, matrix) is the same in both cases.
        draw_everything = plot_type == 'All'
        if draw_everything or plot_type == '3D Structure':
            visualizer.plot_3d_structure(sample)
        if draw_everything or plot_type == '2D Graph':
            visualizer.plot_2d_graph(sample, k_neighbors)
        if draw_everything or plot_type == 'Connectivity Matrix':
            visualizer.plot_connectivity_matrix(sample, k_neighbors)

    except Exception as e:
        print(f"Error: {e}")

# Build the controls for the interactive explorer
_dataset_names = list(available_datasets)

dataset_dropdown = widgets.Dropdown(
    options=_dataset_names,
    value=_dataset_names[0] if _dataset_names else None,
    description='Dataset:',
    disabled=False,
)

# The upper bound follows the dataset loaded by the analysis cell when that
# cell has run (capped at 100); otherwise it falls back to 10. Clamp at 0 so an
# empty dataset cannot produce max < min, which the slider rejects when built.
# The bound is not recomputed when another dataset is picked in the dropdown;
# an index that fails to load is reported by interactive_visualize's handler.
if 'dataset' in locals():
    _max_sample_index = max(0, min(100, len(dataset) - 1))
else:
    _max_sample_index = 10

# Keyword arguments shared by both integer sliders. continuous_update=False
# makes the callback fire when the handle is released, not on every drag step.
_slider_common = dict(
    step=1,
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)

sample_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=_max_sample_index,
    description='Sample Index:',
    **_slider_common,
)

k_neighbors_slider = widgets.IntSlider(
    value=9,
    min=3,
    max=20,
    description='K Neighbors:',
    **_slider_common,
)

plot_type_dropdown = widgets.Dropdown(
    options=['3D Structure', '2D Graph', 'Connectivity Matrix', 'All'],
    value='3D Structure',
    description='Plot Type:',
    disabled=False,
)

if available_datasets:
    # Bind the controls to the callback and show the resulting widget
    interactive_plot = interactive(
        interactive_visualize,
        dataset_name=dataset_dropdown,
        sample_index=sample_slider,
        k_neighbors=k_neighbors_slider,
        plot_type=plot_type_dropdown,
    )

    display(interactive_plot)
else:
    print("No datasets available for interactive visualization")

## 3. 批量可视化和保存

In [None]:
# 批量可视化函数
def batch_visualize(dataset_name, num_samples=5, k_neighbors=9):
    """Render the first ``num_samples`` samples of a dataset to disk.

    Args:
        dataset_name: Key into ``available_datasets``; an unknown name prints
            a message and returns.
        num_samples: Number of leading samples (indices 0..num_samples-1).
        k_neighbors: Forwarded to ``visualizer.visualize_complete``.

    Output goes to ``./visualization_output/<dataset_name>_batch``, which is
    created if missing. A failure on one sample is printed and does not stop
    the remaining samples.
    """
    if dataset_name not in available_datasets:
        print(f"Dataset {dataset_name} not available")
        return

    source_path = available_datasets[dataset_name]
    output_dir = f'./visualization_output/{dataset_name}_batch'
    os.makedirs(output_dir, exist_ok=True)

    print(f"Batch visualizing {num_samples} samples from {dataset_name} dataset...")
    print(f"Output directory: {output_dir}")

    for sample_idx in range(num_samples):
        print(f"\nProcessing sample {sample_idx}...")
        try:
            visualizer.visualize_complete(
                dataset_path=source_path,
                index=sample_idx,
                k_neighbors=k_neighbors,
                output_dir=output_dir,
            )
        except Exception as e:
            # Best effort: report the failure and move on to the next sample
            print(f"Error processing sample {sample_idx}: {e}")

    print(f"\nBatch visualization complete! Check {output_dir} for results.")

# Controls for the batch export
_batch_dataset_names = list(available_datasets)

batch_dataset_dropdown = widgets.Dropdown(
    options=_batch_dataset_names,
    value=_batch_dataset_names[0] if _batch_dataset_names else None,
    description='Dataset:',
    disabled=False,
)

# Keyword arguments shared by both integer sliders
_batch_slider_common = dict(
    step=1,
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)

batch_num_slider = widgets.IntSlider(
    value=5,
    min=1,
    max=20,
    description='Num Samples:',
    **_batch_slider_common,
)

batch_k_slider = widgets.IntSlider(
    value=9,
    min=3,
    max=20,
    description='K Neighbors:',
    **_batch_slider_common,
)

if available_datasets:
    batch_button = widgets.Button(
        description='Start Batch Visualization',
        disabled=False,
        button_style='success',
        tooltip='Click to start batch visualization',
    )

    # Text printed from inside a widget callback is not reliably attached to
    # the cell output (JupyterLab, for instance, sends it to the log console).
    # Capturing it in an Output widget keeps the progress and error messages
    # of batch_visualize visible directly under the button.
    batch_output = widgets.Output()

    def on_batch_button_clicked(b):
        """Run the batch export with the current control values.

        Args:
            b: The clicked button, supplied by ``Button.on_click``; unused.
        """
        # Replace the messages of the previous run instead of appending to them
        batch_output.clear_output(wait=True)
        with batch_output:
            batch_visualize(
                batch_dataset_dropdown.value,
                batch_num_slider.value,
                batch_k_slider.value,
            )

    batch_button.on_click(on_batch_button_clicked)

    display(widgets.VBox([
        widgets.HTML("<h3>Batch Visualization Settings</h3>"),
        batch_dataset_dropdown,
        batch_num_slider,
        batch_k_slider,
        batch_button,
        batch_output,
    ]))
else:
    print("No datasets available for batch visualization")

## 4. 数据分析和统计

In [None]:
# 分析数据集中的原子和块类型分布
def analyze_dataset_composition(dataset_path, dataset_name):
    """Plot and summarise the composition of one pickled dataset.

    Draws a 2x2 figure (atom-type counts, block-type counts, affinity
    histogram, atoms-per-sample histogram) and prints a short summary.

    Args:
        dataset_path: Path to a pickle file holding an iterable, sized
            collection of samples. Each sample is indexed with ``'A'`` and
            ``'B'`` and may contain an ``'affinity'`` entry.
        dataset_name: Label used in the printed messages.

    Returns:
        A dict with keys ``'atom_counts'`` and ``'block_counts'`` (plain
        dicts mapping each distinct element of ``item['A']`` / ``item['B']``
        to its number of occurrences), ``'affinity_values'`` (list, in sample
        order, of the affinities that are present) and ``'molecule_sizes'``
        (``len(item['A'])`` per sample).
    """
    # NOTE: unpickling can execute arbitrary code; only open trusted files.
    with open(dataset_path, 'rb') as f:
        data = pickle.load(f)

    print(f"\nAnalyzing {dataset_name} dataset composition...")

    atom_counts = _count_entries(data, 'A')
    block_counts = _count_entries(data, 'B')
    affinity_values = [item['affinity'] for item in data if 'affinity' in item]
    molecule_sizes = [len(item['A']) for item in data]

    _, axes = plt.subplots(2, 2, figsize=(15, 10))

    _plot_type_counts(axes[0, 0], atom_counts,
                      'Atom Type Distribution', 'Atom Type Index')
    _plot_type_counts(axes[0, 1], block_counts,
                      'Block Type Distribution', 'Block Type Index')

    affinity_ax = axes[1, 0]
    if affinity_values:
        _plot_histogram(affinity_ax, affinity_values,
                        'Affinity Distribution', 'Affinity Value')
    else:
        # Say so explicitly instead of leaving an unlabeled empty panel
        affinity_ax.text(0.5, 0.5, 'No affinity data', ha='center',
                         va='center', transform=affinity_ax.transAxes)
        affinity_ax.set_axis_off()

    _plot_histogram(axes[1, 1], molecule_sizes,
                    'Molecule Size Distribution', 'Number of Atoms')

    plt.tight_layout()
    plt.show()

    # Textual summary
    print(f"\n{dataset_name} Dataset Statistics:")
    print(f"  Total samples: {len(data)}")
    print(f"  Unique atom types: {len(atom_counts)}")
    print(f"  Unique block types: {len(block_counts)}")
    if affinity_values:
        print(f"  Affinity range: {min(affinity_values):.2f} to {max(affinity_values):.2f}")
        print(f"  Affinity mean: {np.mean(affinity_values):.2f}")

    return {
        'atom_counts': atom_counts,
        'block_counts': block_counts,
        'affinity_values': affinity_values,
        'molecule_sizes': molecule_sizes
    }


def _count_entries(data, key):
    """Count occurrences of each element across ``item[key]`` of all samples.

    Returns a plain dict whose keys are ordered by first appearance.
    """
    from collections import Counter
    from itertools import chain

    return dict(Counter(chain.from_iterable(item[key] for item in data)))


def _plot_type_counts(ax, counts, title, xlabel):
    """Draw one bar per distinct type on ``ax``.

    Bars sit at ordinal positions 0..n-1 in order of first appearance; the
    x position is therefore not the type value itself.
    """
    ax.bar(range(len(counts)), list(counts.values()))
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Count')


def _plot_histogram(ax, values, title, xlabel):
    """Draw a 50-bin frequency histogram of ``values`` on ``ax``."""
    ax.hist(values, bins=50, alpha=0.7, edgecolor='black')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.grid(True, alpha=0.3)

# Run the composition analysis on every dataset that was found,
# keeping the returned summaries keyed by dataset name
dataset_stats = {}
for dataset_name in available_datasets:
    dataset_path = available_datasets[dataset_name]
    stats = analyze_dataset_composition(dataset_path, dataset_name)
    dataset_stats[dataset_name] = stats

## 5. 使用说明和总结

### 使用说明

1. **交互式可视化**：使用上面的交互式控件来探索不同的样本和可视化选项
2. **批量可视化**：使用批量可视化功能将多个样本的可视化结果保存到文件
3. **数据分析**：查看数据集的组成和统计信息

### 可视化类型说明

- **3D Structure**: 显示分子的三维结构，不同颜色表示不同的段(segment)
- **2D Graph**: 显示块级别的图连接关系和统计信息
- **Connectivity Matrix**: 显示块之间连接关系的热图矩阵

### 输出文件

批量可视化的结果将保存在 `./visualization_output/<数据集名>_batch/` 目录中（例如 `./visualization_output/test_batch/`），包含：
- `*_3d_structure.png`: 3D分子结构图
- `*_2d_graph.png`: 2D图结构和统计信息
- `*_connectivity_matrix.png`: 块连接矩阵热图

In [None]:
# Print a cheat sheet for the equivalent command-line tools
_rule = "=" * 60

# (title, command) pairs, listed in the order they are numbered in the output
_cli_examples = [
    ("基本可视化",
     "python visualize_graph.py --dataset ./datasets/PDBBind/processed/identity30/test.pkl --index 0"),
    ("指定输出目录",
     "python visualize_graph.py --dataset ./datasets/PDBBind/processed/identity30/test.pkl --index 0 --output_dir ./output"),
    ("调整k邻居数",
     "python visualize_graph.py --dataset ./datasets/PDBBind/processed/identity30/test.pkl --index 0 --k_neighbors 15"),
    ("快速演示",
     "python visualize_pdbbind_example.py --demo"),
    ("完整可视化",
     "python visualize_pdbbind_example.py"),
]

print("\n" + _rule)
print("命令行工具使用方法")
print(_rule)
for _step, (_title, _command) in enumerate(_cli_examples, start=1):
    print(f"\n{_step}. {_title}:")
    print(f"   {_command}")
print("\n" + _rule)