In [None]:
# Required imports
import os
import subprocess

# Define target protein and the residue to center triangle attention on
PROT = "6KWC"  # protein identifier; also used to name the FASTA directory and all output directories
TRI_RESIDUE_IDX = 18  # residue index passed to --triangle_residue_idx and used for triangle attention plots

# Define all relevant directories
# Path to the AlphaFold database. Set the OPENFOLD_DATA_DIR environment variable to
# point elsewhere without editing the notebook (the default is a machine-specific path).
BASE_DATA_DIR = os.environ.get("OPENFOLD_DATA_DIR", "/ime/hdd/rhaas/SUP-5301/database")

# Local paths for saving results (these probably can remain unchanged)
ATTN_MAP_DIR = f"./outputs/attention_files_{PROT}_demo_tri_{TRI_RESIDUE_IDX}"  # text files with top-k attention scores
ALIGNMENT_DIR = "./examples/monomer/alignments"  # pre-computed alignment files (and MSAs)
OUTPUT_DIR = f"./outputs/my_outputs_align_{PROT}_demo_tri_{TRI_RESIDUE_IDX}"  # OpenFold inference outputs
IMAGE_OUTPUT_DIR = f"./outputs/attention_images_{PROT}_demo_tri_{TRI_RESIDUE_IDX}"  # rendered images and plots
FASTA_DIR = f"./examples/monomer/fasta_dir_{PROT}"  # directory holding the input FASTA file

# Note: For a new protein without pre-computed alignments, ALIGNMENT_DIR does not need to be
# passed to the inference command in the next cell. OpenFold will then compute the MSAs and
# alignments itself, which can take several hours.


In [None]:
# Run OpenFold inference and save top attention scores to text files.
# The command is built as an argument list and run without a shell, so paths
# containing spaces or shell metacharacters are passed through verbatim.
inference_cmd = [
    "python3", "run_pretrained_openfold.py",
    FASTA_DIR,
    f"{BASE_DATA_DIR}/pdb_mmcif/mmcif_files",
]
# Reuse pre-computed alignments only when ALIGNMENT_DIR is set. Set it to None (or "")
# for a new protein so OpenFold computes the MSAs itself (this can take several hours).
if ALIGNMENT_DIR:
    inference_cmd += ["--use_precomputed_alignments", ALIGNMENT_DIR]
inference_cmd += [
    "--output_dir", OUTPUT_DIR,
    "--config_preset", "model_1_ptm",
    "--uniref90_database_path", f"{BASE_DATA_DIR}/uniref90/uniref90.fasta",
    "--mgnify_database_path", f"{BASE_DATA_DIR}/mgnify/mgy_clusters_2022_05.fa",
    "--pdb70_database_path", f"{BASE_DATA_DIR}/pdb70/pdb70",
    "--uniclust30_database_path", f"{BASE_DATA_DIR}/uniclust30/uniclust30_2018_08",
    "--bfd_database_path", f"{BASE_DATA_DIR}/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt",
    "--save_outputs",
    "--model_device", "cuda:0",
    "--attn_map_dir", ATTN_MAP_DIR,
    "--num_recycles_save", "1",
    "--triangle_residue_idx", str(TRI_RESIDUE_IDX),
    "--demo_attn",
]

# check=True raises CalledProcessError if inference exits non-zero,
# so later cells never run against missing outputs.
subprocess.run(inference_cmd, check=True)


In [None]:
# Render predicted 3D structure and save as PNG image
from visualize_attention_general_utils import render_pdb_to_image

# Relaxed model written by the inference cell; later cells also read PDB_FILE.
PDB_FILE = os.path.join(OUTPUT_DIR, f"predictions/{PROT}_1_model_1_ptm_relaxed.pdb")
FNAME = f"predicted_structure_{PROT}_tri_{TRI_RESIDUE_IDX}.png"

# Fail fast with an actionable message if inference did not produce the expected file.
if not os.path.isfile(PDB_FILE):
    raise FileNotFoundError(
        f"Predicted structure not found: {PDB_FILE}. "
        "Run the inference cell first and check that it wrote a relaxed model."
    )
os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)

render_pdb_to_image(PDB_FILE, IMAGE_OUTPUT_DIR, FNAME)


In [None]:
# Import visualization utilities
from visualize_attention_3d_demo_utils import plot_pymol_attention_heads
from visualize_attention_arc_diagram_demo_utils import generate_arc_diagrams, parse_fasta_sequence

# Setup visualization output directories
output_dir_msa = os.path.join(IMAGE_OUTPUT_DIR, 'msa_row_attention_plots')  # MSA row attention visuals
output_dir_tri = os.path.join(IMAGE_OUTPUT_DIR, 'tri_start_attention_plots')  # triangle start attention visuals
# Derive the FASTA path from FASTA_DIR rather than a user-specific absolute path,
# so the notebook runs from any checkout.
FASTA_PATH = os.path.join(FASTA_DIR, f"{PROT}.fasta")
LAYER_IDX = 47  # selected layer for attention evaluation
TOP_K = 50  # show top-k attention links (limit to 500)

# One entry per attention type: (attention type, output directory, extra keyword arguments).
# Only triangle start attention is centered on TRI_RESIDUE_IDX.
attention_plot_specs = [
    ("msa_row", output_dir_msa, {}),
    ("triangle_start", output_dir_tri, {"residue_indices": [TRI_RESIDUE_IDX]}),
]

for _, spec_output_dir, _ in attention_plot_specs:
    os.makedirs(spec_output_dir, exist_ok=True)

# Parse the FASTA first so a missing or invalid file fails before any plotting work
residue_seq = parse_fasta_sequence(FASTA_PATH)

# Generate 3D attention plots for each attention type
for attention_type, spec_output_dir, extra_kwargs in attention_plot_specs:
    plot_pymol_attention_heads(
        pdb_file=PDB_FILE,
        attention_dir=ATTN_MAP_DIR,
        output_dir=spec_output_dir,
        protein=PROT,
        attention_type=attention_type,
        top_k=TOP_K,
        layer_idx=LAYER_IDX,
        **extra_kwargs,
    )

# Generate arc diagrams for each attention type
for attention_type, spec_output_dir, extra_kwargs in attention_plot_specs:
    generate_arc_diagrams(
        attention_dir=ATTN_MAP_DIR,
        residue_sequence=residue_seq,
        output_dir=spec_output_dir,
        protein=PROT,
        attention_type=attention_type,
        top_k=TOP_K,
        layer_idx=LAYER_IDX,
        **extra_kwargs,
    )


In [None]:
# Import function for combining attention plots
from visualize_attention_general_utils import generate_combined_attention_panels

# Combine each attention type's 3D and arc-diagram outputs (both were written to the
# same per-type directory) into panels saved under IMAGE_OUTPUT_DIR.
# MSA row panels come first; only the triangle start panels take residue indices.
for panel_type, panel_dir, panel_extras in (
    ("msa_row", output_dir_msa, {}),
    ("triangle_start", output_dir_tri, {"residue_indices": [TRI_RESIDUE_IDX]}),
):
    generate_combined_attention_panels(
        attention_type=panel_type,
        protein=PROT,
        layer_idx=LAYER_IDX,
        output_dir_3d=panel_dir,
        output_dir_arc=panel_dir,
        combined_output_dir=IMAGE_OUTPUT_DIR,
        **panel_extras,
    )


In [None]:
# Generate contact map with attention overlay
from visualize_attention_general_utils import compute_contact_map, plot_contact_map_with_attention
from Bio.PDB import PDBParser
import matplotlib.pyplot as plt  # needed for plt.subplots / plt.show below
import numpy as np

CONTACT_THRESHOLD_ANGSTROM = 8.0  # CA-CA distance cutoff passed to compute_contact_map
MAX_ATTENTION_EDGES = 500  # cap on attention edges drawn over the contact map

# Parse the predicted structure
parser = PDBParser(QUIET=True)
structure = parser.get_structure('protein', PDB_FILE)

# Extract CA coordinates of standard residues from the first model only.
# Iterating over every model would concatenate residues from all models, and
# HETATM residues (e.g. a calcium ion, whose atom is also named "CA") would be
# mistaken for amino-acid alpha carbons.
first_model = structure[0]
ca_coords = np.array([
    residue['CA'].coord
    for chain in first_model
    for residue in chain
    if residue.id[0] == ' ' and 'CA' in residue
])
if ca_coords.size == 0:
    raise ValueError(f"No standard-residue CA atoms found in {PDB_FILE}")

contact_map = compute_contact_map(ca_coords, threshold=CONTACT_THRESHOLD_ANGSTROM)

# Load attention edges from attention files
# TODO: populate with actual attention data; the overlay is empty for now.
attn_edges = []

# Create contact map visualization with attention overlay
fig, ax = plt.subplots(figsize=(8, 8))
plot_contact_map_with_attention(contact_map, attn_edges, len(ca_coords), max_edges=MAX_ATTENTION_EDGES, ax=ax)

# Save through the figure handle (not pyplot's "current figure") into an existing directory
os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)
contact_map_path = os.path.join(IMAGE_OUTPUT_DIR, f'contact_map_{PROT}.png')
fig.savefig(contact_map_path, dpi=100, bbox_inches='tight')
plt.show()

print(f"Saved contact map to {contact_map_path}")