In [None]:
import os

# Environment overrides for an RTX 2060 host: switch off the cuEquivariance
# kernels (and ask for the fallback path), request fp16, and let the PyTorch
# CUDA allocator use expandable segments to reduce fragmentation.
# NOTE(review): which library (if any) reads TORCH_DTYPE is not visible here — confirm.
_ENV_OVERRIDES = {
    "SHOULD_USE_CUEQUIVARIANCE": "0",
    "CUEQUIVARIANCE_USE_FALLBACK": "1",
    "DISABLE_CUEQUIVARIANCE": "1",
    "TORCH_DTYPE": "fp16",
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
}
os.environ.update(_ENV_OVERRIDES)

# Echo the two cuEquivariance switches so the notebook log shows they took effect.
for _var in ("SHOULD_USE_CUEQUIVARIANCE", "DISABLE_CUEQUIVARIANCE"):
    print(os.environ.get(_var))

# Output directories, created relative to the kernel's current working directory.
folders = ['out_cdr', 'out_cif']
for folder in folders:
    os.makedirs(os.path.join(os.getcwd(), folder), exist_ok=True)

# Quantify RAM and time usage: memory_profiler provides the %memit / %%memit
# magics used in later cells, and autotime is loaded for cell timing.
# (GPU memory is measured separately by the %%gpumem magic defined below.)
%load_ext memory_profiler
%load_ext autotime

# %%gpumem: IPython cell magic that reports GPU memory (via NVML) and elapsed
# time around the execution of a cell.
from IPython.core.magic import register_cell_magic
import pynvml
import time

# Initialise NVML once per kernel; the magic below queries it on every call.
pynvml.nvmlInit()

@register_cell_magic
def gpumem(line, cell):
    """Run *cell* and print GPU memory usage before/after it plus its runtime.

    Usage::

        %%gpumem        # watch GPU 0 (default, same as before)
        %%gpumem 1      # watch GPU 1

    Args:
        line: Text following ``%%gpumem`` on the magic line: an optional
            integer GPU index. An empty line means GPU 0.
        cell: Source of the cell body, executed with
            ``get_ipython().run_cell``.

    Memory figures are NVML's ``used``/``total`` values for the whole device
    (per the NVML docs, not restricted to this process). The delta may
    therefore also include memory an allocator keeps reserved after the cell
    finishes.
    """
    arg = line.strip()
    device_index = int(arg) if arg else 0
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)

    def get_mem():
        """Return (used, total) device memory in GiB."""
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return info.used / 1024**3, info.total / 1024**3

    used_before, total = get_mem()
    print(f"Before: {used_before:.2f} GB / {total:.2f} GB")

    # perf_counter is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments the way time.time() can.
    start = time.perf_counter()
    get_ipython().run_cell(cell)
    elapsed = time.perf_counter() - start

    used_after, _ = get_mem()
    print(f"After : {used_after:.2f} GB / {total:.2f} GB")
    print(f"Δ     : {used_after - used_before:+.2f} GB")
    print(f"Time  : {elapsed:.2f} s")



# AtomWorks可视化模块
from atomworks.io.utils.visualize import view

%%gpumem
%%memit

from lightning.fabric import seed_everything
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
seed_everything(0)

import json
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine

spec = json.load(open('/home/alex/aidd/PDL1-4ZQK/protein_binder_design.json'))['pdl1_clean'] 


config = RFD3InferenceConfig(
    specification=spec,      
    diffusion_batch_size=2,
)
engine = RFD3InferenceEngine(**config)

# 生成10条backbone
outputs = engine.run(
    inputs='/home/alex/aidd/PDL1-4ZQK/pd_l1_clean.pdb', 
    out_dir='/home/alex/aidd/PDL1-4ZQK/out_cdr',
    n_batches=1,            
)

from pathlib import Path
from atomworks.io.utils.io_utils import load_any
import biotite.structure as struc

# 1. List every CIF written by the RFD3 run. Search the same absolute
#    directory passed as out_dir to engine.run rather than a path relative to
#    the kernel's working directory, so this works wherever the notebook runs.
cdr_dir = Path('/home/alex/aidd/PDL1-4ZQK/out_cdr')
cif_files = sorted(cdr_dir.glob('pd_l1_clean*_model_*.cif.gz'))
if not cif_files:
    raise FileNotFoundError(
        f"No files matching 'pd_l1_clean*_model_*.cif.gz' in {cdr_dir}; "
        "did the RFD3 generation cell run?"
    )

# 2. Take the first design (first in sorted file-name order).
first_cif = cif_files[0]
atom_array = load_any(str(first_cif), model=1)

# 3. Report what was loaded.
print(f'Loaded {first_cif.name}  ({atom_array.shape[0]} atoms)')

# 4. Visualize.
view(atom_array)

### ProteinMPNN序列设计

%%gpumem

from mpnn.inference_engines.mpnn import MPNNInferenceEngine

# Configure MPNN inference engine
# See mpnn.utils.inference.MPNN_GLOBAL_INFERENCE_DEFAULTS for all options
engine_config = {
    "model_type": "protein_mpnn",  # or "protein_mpnn" for vanilla ProteinMPNN
    "is_legacy_weights": True,    # Required for now for ligand_mpnn and protein_mpnn
    "out_directory": None,        # Return results in memory
    "write_structures": False,
    "write_fasta": False,
}

# Configure per-input inference options
# See mpnn.utils.inference.MPNN_PER_INPUT_INFERENCE_DEFAULTS for all options
input_configs = [
    {
        "batch_size": 10,         # Generate 10 sequences per structure
        "remove_waters": True,
    }
]

# Run sequence design on the RFD3-generated backbone
model = MPNNInferenceEngine(**engine_config)
mpnn_outputs = model.run(input_dicts=input_configs, atom_arrays=[atom_array])

%%gpumem
%%memit

from biotite.structure import get_residue_starts
from biotite.sequence import ProteinSequence

# Extract and display the designed sequences
print(f"Generated {len(mpnn_outputs)} designed sequences:\n")

for i, item in enumerate(mpnn_outputs):
    res_starts = get_residue_starts(item.atom_array)
    # Convert 3-letter codes to 1-letter using Biotite
    seq_1letter = ''.join(
        ProteinSequence.convert_letter_3to1(res_name)
        for res_name in item.atom_array.res_name[res_starts]
    )
    print(f"Sequence {i+1}: {seq_1letter}")

%%gpumem
%%memit

from rf3.inference_engines.rf3 import RF3InferenceEngine
from rf3.utils.inference import InferenceInput


# Initialize RF3 inference engine
# inference_engine = RF3InferenceEngine(ckpt_path='rf3', verbose=False)
inference_engine = RF3InferenceEngine(
    ckpt_path='rf3',
    verbose=False,
    n_recycles=1,              # 默认10 → 降到1，大幅减内存（recycle 是主要吃内存点）
    diffusion_batch_size=1,    # 默认5 → 降到1（batch=1 最省）
    num_steps=20               # 默认50 → 减半
)

# Create input from the MPNN-designed structure (first design)
# This re-folds the sequence to validate it adopts the intended structure
input_structure = InferenceInput.from_atom_array(atom_array, example_id="pdl1_clean")
rf3_outputs = inference_engine.run(inputs=input_structure)

# Outputs: dict mapping example_id -> list[RF3Output] (multiple models per input)
print(f"Output keys: {rf3_outputs.keys()}")
print(f"Number of models for 'example_protein': {len(rf3_outputs['pdl1_clean'])}")


### RF3结构预测

# Pick the first model returned for "pdl1_clean" (the top-ranked prediction,
# per the original notes).
pdl1_models = rf3_outputs["pdl1_clean"]
rf3_output = pdl1_models[0]

# List what the RF3Output object exposes; confidences may be empty/None.
print("RF3Output contains:")
print(f"  - atom_array: {len(rf3_output.atom_array)} atoms")
print(f"  - summary_confidences: {list(rf3_output.summary_confidences.keys())}")
print(
    "  - confidences: "
    f"{list(rf3_output.confidences.keys()) if rf3_output.confidences else None}"
)

# Render the predicted structure.
view(rf3_output.atom_array)

# Summary confidences: overall model-quality metrics for this prediction.
summary = rf3_output.summary_confidences

# ipTM may be absent (the fallback text says this happens for single-chain
# inputs). When present, format it like the other scores instead of printing
# the raw value. Anything that cannot be converted to a scalar is shown as-is.
iptm = summary.get('iptm')
if iptm is None:
    iptm_text = 'N/A (single chain)'
else:
    try:
        iptm_text = f"{float(iptm):.3f}"
    except (TypeError, ValueError):
        iptm_text = str(iptm)

print("=== Summary Confidences ===")
print(f"  Overall pLDDT:    {summary['overall_plddt']:.3f}")
print(f"  Overall PAE:      {summary['overall_pae']:.2f} A")
print(f"  Overall PDE:      {summary['overall_pde']:.3f}")
print(f"  pTM:              {summary['ptm']:.3f}")
print(f"  ipTM:             {iptm_text}")
print(f"  Ranking score:    {summary['ranking_score']:.3f}")
print(f"  Has clash:        {summary['has_clash']}")

import numpy as np

# Detailed per-atom/residue confidences.
conf = rf3_output.confidences

print("=== Per-Atom/Residue Confidences ===")
if not conf:
    # The RF3Output inspection cell already treats confidences as optional
    # (it prints None when they are missing), so don't index them blindly here.
    print("  (no per-atom/residue confidences returned)")
else:
    print(f"  atom_plddts:      {len(conf['atom_plddts'])} values (one per atom)")
    print(f"  atom_chain_ids:   {len(conf['atom_chain_ids'])} values")
    print(f"  token_chain_ids:  {len(conf['token_chain_ids'])} values (one per residue)")
    print(f"  token_res_ids:    {len(conf['token_res_ids'])} values")
    print(f"  PAE matrix:       {len(conf['pae'])}x{len(conf['pae'][0])}")

    # Preview the first 10 atom pLDDT scores.
    print(f"\nFirst 10 atom pLDDTs: {np.round(conf['atom_plddts'][:10], 2).tolist()}")

## RMSD验证和导出cif

from biotite.structure import rmsd, superimpose
from atomworks.constants import PROTEIN_BACKBONE_ATOM_NAMES
import numpy as np

# Structures to compare: the RFD3 design loaded earlier vs. the RF3 re-prediction.
aa_generated = atom_array
aa_refolded = rf3_output.atom_array

# Keep only atoms whose name is in PROTEIN_BACKBONE_ATOM_NAMES (N, CA, C, O).
bb_generated, bb_refolded = (
    structure[np.isin(structure.atom_name, PROTEIN_BACKBONE_ATOM_NAMES)]
    for structure in (aa_generated, aa_refolded)
)

# Fit the re-predicted backbone onto the design, then measure the deviation.
bb_refolded_fitted, _ = superimpose(bb_generated, bb_refolded)
rmsd_value = rmsd(bb_generated, bb_refolded_fitted)

if rmsd_value < 1.0:
    verdict = 'Excellent'
elif rmsd_value < 2.0:
    verdict = 'Good'
else:
    verdict = 'Moderate'

print(f"Backbone RMSD: {rmsd_value:.2f} A")
print(f"\nInterpretation: {verdict} designability")

from biotite.structure import rmsd, superimpose
from atomworks.constants import PROTEIN_BACKBONE_ATOM_NAMES
import numpy as np

# Structures to compare: the RFD3 design loaded earlier vs. the RF3 re-prediction.
aa_generated = atom_array
aa_refolded = rf3_output.atom_array

# Restrict both structures to chain A.
# NOTE(review): the original notes call chain A the designed CDR-H3 binder —
# confirm the chain assignment for this run.
mask_a = (aa_generated.chain_id == "A")
mask_a_refolded = (aa_refolded.chain_id == "A")

aa_generated_cdr = aa_generated[mask_a]
aa_refolded_cdr = aa_refolded[mask_a_refolded]

# Further restrict to backbone atoms (N, CA, C, O).
bb_generated = aa_generated_cdr[np.isin(aa_generated_cdr.atom_name, PROTEIN_BACKBONE_ATOM_NAMES)]
bb_refolded = aa_refolded_cdr[np.isin(aa_refolded_cdr.atom_name, PROTEIN_BACKBONE_ATOM_NAMES)]

# superimpose() needs two non-empty, equally sized atom sets. Fail with a
# message pointing at the likely cause (e.g. chain IDs that differ between the
# two structures) instead of an opaque error deep inside the fit.
if len(bb_generated) == 0 or len(bb_generated) != len(bb_refolded):
    raise ValueError(
        "Chain A backbone selections are empty or differ in size "
        f"(generated={len(bb_generated)}, refolded={len(bb_refolded)}); "
        "check the chain IDs of both structures."
    )

# Superimpose and compute the RMSD on the chain-A backbone only.
bb_refolded_fitted, _ = superimpose(bb_generated, bb_refolded)
rmsd_value = rmsd(bb_generated, bb_refolded_fitted)

print(f"CDR-H3 Backbone RMSD (chain A only): {rmsd_value:.2f} Å")
print(f"Interpretation: {'Excellent' if rmsd_value < 1.2 else 'Good' if rmsd_value < 2.0 else 'Moderate' if rmsd_value < 3.0 else 'Poor'} designability")

from pathlib import Path

from atomworks.io.utils.io_utils import to_cif_file

# Export structures to CIF for inspection in PyMOL/ChimeraX.
# The setup cell creates out_cif relative to the kernel's working directory,
# which matches this absolute path only when the kernel starts in the project
# directory. Create the target directory explicitly so the writes can't fail.
out_cif = Path("/home/alex/aidd/PDL1-4ZQK/out_cif")
out_cif.mkdir(parents=True, exist_ok=True)

to_cif_file(aa_generated, str(out_cif / "pdl1_clean_generated.cif"))
to_cif_file(aa_refolded, str(out_cif / "pdl1_clean_refolded.cif"))

to_cif_file(aa_generated_cdr, str(out_cif / "pdl1_clean_generated_cdr.cif"))
to_cif_file(aa_refolded_cdr, str(out_cif / "pdl1_clean_refolded_cdr.cif"))











