# In [None]:  (Jupyter notebook cell marker left over from export)
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from Bio.PDB import PDBParser, PPBuilder, DSSP
from Bio.SeqUtils import seq1
import pymol
from pymol import cmd
from scipy.spatial.distance import cosine
import torch
from transformers import AutoTokenizer, AutoModel
from bio_embeddings.embed import ProtTransT5XLU50Embedder
import subprocess
import logging
import warnings
import pkg_resources
import shutil

# Configure logging once for the whole script (timestamp - level - message)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Input layout: one sub-directory of BASE_DIR per analysis task,
# keyed by the task name used throughout the case functions
BASE_DIR = Path("pdb_files")
TASK_DIRS = {
    task_name: BASE_DIR / task_name
    for task_name in (
        "homologous_pairs",
        "mutants",
        "single_domain",
        "multi_domain",
        "disentanglement",
    )
}

# Destination for every CSV / plot / render; created at import time if absent
OUTPUT_DIR = Path("analysis_outputs")
OUTPUT_DIR.mkdir(exist_ok=True)

def check_environment():
    """Confirm that the heavy third-party dependencies can be imported.

    Imports bio-embeddings (including its ProtT5 embedder), transformers and
    scipy, logs the versions found, and emits a warning when no ``mkdssp``
    executable is on the PATH.

    Raises:
        Exception: Whatever the failing import or version lookup raised; the
            error is logged and then re-raised unchanged.
    """
    try:
        import bio_embeddings
        from bio_embeddings.embed import ProtTransT5XLU50Embedder
        import transformers
        import scipy

        bio_ver = pkg_resources.get_distribution("bio-embeddings").version
        tf_ver = getattr(transformers, "__version__", "unknown")
        scipy_ver = getattr(scipy, "__version__", "unknown")
        logger.info(f"Dependencies checked: bio-embeddings={bio_ver}, "
                    f"transformers={tf_ver}, scipy={scipy_ver}")

        # A missing DSSP binary is not fatal: only a warning is issued here
        dssp_binary = shutil.which("mkdssp")
        if not dssp_binary:
            logger.warning("mkdssp binary not found. Install with 'conda install -c salilab dssp' or use fallback secondary structure.")
    except Exception as err:
        logger.error(f"Dependency check failed: {err}")
        raise

def check_directory_files(task_dir, task_name, min_files=1):
    """Check whether a task directory holds enough PDB files for analysis.

    Args:
        task_dir: ``pathlib.Path`` of the directory to scan (top level only,
            pattern ``*.pdb``).
        task_name: Label used in the log messages.
        min_files: Minimum number of PDB files required.

    Returns:
        A ``(ok, pdb_files)`` tuple. ``ok`` is True when at least
        ``min_files`` files were found; ``pdb_files`` is the sorted list of
        matching paths and is returned in both cases.
    """
    # glob() yields entries in filesystem order, which varies between machines.
    # Callers index into this list (pdb_files[0], pdb_files[1]), so sort it to
    # make the selected structures reproducible.
    pdb_files = sorted(task_dir.glob("*.pdb"))
    if len(pdb_files) < min_files:
        logger.warning(f"Insufficient files in {task_name} directory ({len(pdb_files)} files).")
        return False, pdb_files
    logger.info(f"Found {len(pdb_files)} files in {task_name} directory.")
    return True, pdb_files

def load_esm2_model():
    """Fetch the ESM-2 tokenizer and model and switch the model to eval mode.

    Returns:
        A ``(tokenizer, model)`` tuple for the
        ``facebook/esm2_t33_650M_UR50D`` checkpoint.

    Raises:
        Exception: Any loading error, logged and re-raised unchanged.
    """
    checkpoint = "facebook/esm2_t33_650M_UR50D"
    try:
        esm_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        esm_model = AutoModel.from_pretrained(checkpoint)
        esm_model.eval()
    except Exception as err:
        logger.error(f"Failed to load ESM-2 model: {err}")
        raise
    return esm_tokenizer, esm_model

def load_prott5_model():
    """Instantiate the bio-embeddings ProtT5 embedder.

    Returns:
        A new ``ProtTransT5XLU50Embedder`` instance.

    Raises:
        Exception: Any construction error, logged and re-raised unchanged.
    """
    try:
        prott5 = ProtTransT5XLU50Embedder()
    except Exception as err:
        logger.error(f"Failed to load ProtT5 model: {err}")
        raise
    return prott5

def get_sequence_from_pdb(pdb_file):
    """Extract the amino-acid sequence from a PDB file.

    The structure is parsed with ``PDBParser`` and split into polypeptides by
    ``PPBuilder``; the one-letter sequences of all polypeptides are
    concatenated in the order they are built.

    Args:
        pdb_file: Path (or anything ``PDBParser.get_structure`` accepts) of
            the PDB file.

    Returns:
        The concatenated one-letter sequence as ``str``, or ``None`` when
        parsing fails or no residues are found (the error is logged).
    """
    try:
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure("protein", pdb_file)
        ppb = PPBuilder()
        # Polypeptide.get_sequence() already returns one-letter codes. It must
        # NOT be passed through seq1(), which expects three-letter codes: that
        # would read the sequence in triplets and return a string one third as
        # long consisting mostly of "X".
        seq = "".join(str(pp.get_sequence()) for pp in ppb.build_peptides(structure))
        if not seq:
            raise ValueError("No sequence extracted from PDB")
        return seq
    except Exception as e:
        logger.error(f"Failed to extract sequence from {pdb_file}: {e}")
        return None

def compute_esm2_embedding(sequence, tokenizer, model, per_residue=False):
    """Compute an ESM-2 embedding for one sequence (mean-pooled or per-residue).

    Both modes work on the same slice of the last hidden layer: the rows that
    correspond to residues, i.e. without the special tokens the tokenizer
    wraps around the sequence. Previously only the per-residue mode skipped
    the leading token, while the mean averaged the special-token vectors in.

    Args:
        sequence: One-letter amino-acid string.
        tokenizer: Callable returning the model inputs for ``sequence``.
        model: Callable whose output exposes ``last_hidden_state`` with shape
            (1, tokens, dim).
        per_residue: When True return one vector per residue, otherwise the
            mean over residues.

    Returns:
        ``numpy.ndarray`` of shape (residues, dim) or (dim,), or ``None`` on
        any failure (the error is logged).
    """
    try:
        if not sequence:
            raise ValueError("Cannot embed an empty sequence")
        inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        hidden = outputs.last_hidden_state[0]
        # Assumes exactly one leading and one trailing special token, as the
        # ESM tokenizer produces (<cls> ... <eos>) -- TODO confirm if another
        # tokenizer is ever passed in. min() also covers a truncated input.
        n_residues = min(len(sequence), hidden.shape[0] - 2)
        if n_residues < 1:
            raise ValueError("Model output contains no residue positions")
        residue_states = hidden[1:n_residues + 1]
        if per_residue:
            embedding = residue_states.cpu().numpy()
        else:
            embedding = residue_states.mean(dim=0).cpu().numpy()
        logger.debug(f"ESM-2 embedding shape: {embedding.shape}")
        return embedding
    except Exception as e:
        logger.error(f"Failed to compute ESM-2 embedding: {e}")
        return None

def compute_prott5_embedding(sequence, embedder):
    """Return a single ProtT5 vector for ``sequence``.

    The embedder output is squeezed; if more than one dimension remains it is
    averaged over the first axis, otherwise it is returned as is.

    Args:
        sequence: Amino-acid string handed to ``embedder.embed``.
        embedder: Object exposing an ``embed(sequence)`` method.

    Returns:
        ``numpy.ndarray`` embedding, or ``None`` on any failure (logged).
    """
    try:
        squeezed = np.squeeze(embedder.embed(sequence))
        pooled = squeezed.mean(axis=0) if squeezed.ndim > 1 else squeezed
        logger.debug(f"ProtT5 embedding shape: {pooled.shape}")
        return pooled
    except Exception as err:
        logger.error(f"Failed to compute ProtT5 embedding: {err}")
        return None

def compute_rmsd(pdb1, pdb2, per_residue=False):
    """Superpose two structures on their CA atoms with PyMOL and measure deviation.

    Args:
        pdb1: Path of the first (mobile) structure.
        pdb2: Path of the second (target) structure.
        per_residue: When False return the RMSD reported by ``cmd.align``;
            when True return per-position CA-CA distances after alignment.

    Returns:
        * ``per_residue=False``: the first element of the ``cmd.align`` result.
        * ``per_residue=True``: a list of ``(resi, distance)`` tuples, where
          ``resi`` is the residue identifier of the CA atom in ``pdb1`` as
          reported by PyMOL.
        * ``None`` on any failure (the error is logged).

    Note:
        Per-residue pairs are formed positionally: the i-th CA atom of
        ``pdb1`` is paired with the i-th CA atom of ``pdb2`` and the list is
        truncated to the shorter structure. This is only meaningful when both
        structures list equivalent residues in the same order.
    """
    try:
        cmd.reinitialize()
        # Convert to str like the other PyMOL callers in this file do;
        # callers pass pathlib.Path objects.
        cmd.load(str(pdb1), "struct1")
        cmd.load(str(pdb2), "struct2")
        # Align exactly once and reuse the result for the global RMSD.
        alignment = cmd.align("struct1 and name CA", "struct2 and name CA")
        if not per_residue:
            return alignment[0]

        stored_dists = []
        cmd.iterate_state(1, "struct1 and name CA", "stored_dists.append([resi, x, y, z])", space={"stored_dists": stored_dists})
        stored_dists2 = []
        cmd.iterate_state(1, "struct2 and name CA", "stored_dists2.append([resi, x, y, z])", space={"stored_dists2": stored_dists2})
        return [
            (resi1, np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2 + (z1 - z2) ** 2))
            for (resi1, x1, y1, z1), (_resi2, x2, y2, z2) in zip(stored_dists, stored_dists2)
        ]
    except Exception as e:
        logger.error(f"Failed to compute RMSD for {pdb1} and {pdb2}: {e}")
        return None

def compute_foldx_score(pdb_file, mutation, foldx_path="/path/to/foldx"):
    """Compute a FoldX stability score (placeholder for a real implementation).

    Runs ``foldx_path`` with ``--command=Stability`` and parses the last
    whitespace-separated token of the last non-empty stdout line as the
    energy. If anything goes wrong a fixed placeholder is returned instead.

    Args:
        pdb_file: PDB file passed as ``--pdb``.
        mutation: Mutation string passed as ``--mutant``.
        foldx_path: Location of the FoldX binary. The default is the original
            dummy path and must be overridden for real runs.

    Returns:
        The parsed energy as ``float``, or the placeholder ``2.5`` when the
        binary cannot be run, exits with an error, or prints no parsable
        value (a warning is logged in that case).
    """
    try:
        result = subprocess.run(
            [foldx_path, "--command=Stability", f"--pdb={pdb_file}", f"--mutant={mutation}"],
            capture_output=True, text=True, check=True
        )
        # stdout normally ends with a newline, so split("\n")[-1] would be ""
        # and parsing would always fail; use the last non-empty line instead.
        output_lines = [line for line in result.stdout.splitlines() if line.strip()]
        return float(output_lines[-1].split()[-1])
    except Exception as e:
        logger.warning(f"Using placeholder FoldX score for {pdb_file}")
        logger.debug(f"FoldX invocation failed: {e}")
        return 2.5  # Placeholder value

def compute_dssp_features(pdb_file):
    """Return a per-residue helix indicator (1 = helix, 0 = anything else).

    The primary route runs ``Bio.PDB.DSSP`` (needs the ``mkdssp`` binary) on
    the first model and marks DSSP codes H, G and I as helix. If that fails,
    PyMOL's own secondary-structure assignment is used as a fallback and
    residues labelled "H" are marked as helix.

    Args:
        pdb_file: Path of the PDB file.

    Returns:
        ``numpy.ndarray`` of 0/1 values, or ``None`` when both routes fail or
        the fallback finds no CA atoms.
    """
    helix_codes = ("H", "G", "I")
    try:
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure("protein", pdb_file)
        model = structure[0]
        dssp = DSSP(model, pdb_file, dssp="mkdssp")  # Requires mkdssp binary
        # Index 2 of each DSSP record is the secondary-structure code
        return np.array([1 if dssp[key][2] in helix_codes else 0 for key in dssp.keys()])
    except Exception as e:
        logger.warning(f"Failed to compute DSSP features for {pdb_file}: {e}. Using fallback.")

    # Fallback: simplified secondary structure from PyMOL
    try:
        cmd.reinitialize()
        cmd.load(str(pdb_file), "protein")
        # PyMOL only fills the `ss` property from HELIX/SHEET records in the
        # file; run its assignment explicitly so files without those records
        # do not come back as "no helix anywhere".
        cmd.dss("protein")
        stored_ss = []
        cmd.iterate("name CA", "stored_ss.append([resi, ss])", space={"stored_ss": stored_ss})
        ss = [1 if item[1] == "H" else 0 for item in stored_ss]
        return np.array(ss) if ss else None
    except Exception as e:
        logger.error(f"Fallback secondary structure failed for {pdb_file}: {e}")
        return None

def find_differing_residue(seq1, seq2):
    """Locate the single position at which two equal-length sequences differ.

    Args:
        seq1: First sequence.
        seq2: Second sequence.

    Returns:
        ``(position, residue_in_seq1, residue_in_seq2)`` with a 1-based
        position when the sequences differ at exactly one place; ``None``
        (with a logged warning) when the lengths differ or the number of
        mismatches is anything other than one.
    """
    if len(seq1) != len(seq2):
        logger.warning("Sequences have different lengths")
        return None

    mismatches = []
    for position, (left, right) in enumerate(zip(seq1, seq2), start=1):
        if left != right:
            mismatches.append((position, left, right))

    if len(mismatches) != 1:
        logger.warning(f"Expected one differing residue, found {len(mismatches)}")
        return None
    return mismatches[0]  # (res_num, aa1, aa2)

def visualize_protein(pdb_file, output_file, color="blue", highlight_residue=None):
    """Render a cartoon image of a structure with PyMOL on a white background.

    Args:
        pdb_file: Path of the PDB file to load.
        output_file: Path of the PNG to write (800x600, ray-traced).
        color: PyMOL color name applied to the whole structure.
        highlight_residue: Optional residue identifier; when given, the
            selection ``resi <highlight_residue>`` is colored red.

    Returns:
        None. Failures are logged and swallowed so that a rendering problem
        does not abort the calling analysis.
    """
    try:
        cmd.reinitialize()
        cmd.load(str(pdb_file), "protein")
        cmd.show("cartoon")
        cmd.color(color, "protein")
        # Compare against None/"" explicitly: residue number 0 is a valid
        # identifier but is falsy, so a plain truthiness test would skip it.
        if highlight_residue is not None and highlight_residue != "":
            cmd.select("highlight", f"resi {highlight_residue}")
            cmd.color("red", "highlight")
        cmd.bg_color("white")
        cmd.set("ray_opaque_background", 1)
        cmd.png(str(output_file), width=800, height=600, ray=1)
        logger.info(f"Saved visualization to {output_file}")
    except Exception as e:
        logger.error(f"Failed to visualize {pdb_file}: {e}")

def case_1_homologous_pairs():
    """Analyze one homologous pair from real data.

    Takes the first two PDB files returned for the ``homologous_pairs`` task
    directory, logs their global RMSD and the ESM-2 / ProtT5 cosine distances
    of the mean embeddings, and then writes a per-residue table
    (``case_1_homologous_results.csv``), a line plot
    (``case_1_homologous_plot.png``) and one render per structure to
    ``OUTPUT_DIR``. Returns early (after logging) whenever an input cannot be
    computed.
    """
    task_dir = TASK_DIRS["homologous_pairs"]
    has_files, pdb_files = check_directory_files(task_dir, "homologous_pairs", min_files=2)
    if not has_files:
        return

    # Use the first two PDBs of the directory listing
    pdb1, pdb2 = pdb_files[0], pdb_files[1]

    # Run the cheap structure/sequence checks before loading the language
    # models, so nothing heavy is loaded for an unusable pair.
    rmsd = compute_rmsd(pdb1, pdb2)
    seq_a = get_sequence_from_pdb(pdb1)
    seq_b = get_sequence_from_pdb(pdb2)
    if not seq_a or not seq_b or rmsd is None:
        logger.warning(f"Skipping case 1 due to invalid sequence or RMSD for {pdb1.name}, {pdb2.name}")
        return

    tokenizer, esm2_model = load_esm2_model()
    prott5_embedder = load_prott5_model()

    esm2_emb1 = compute_esm2_embedding(seq_a, tokenizer, esm2_model)
    esm2_emb2 = compute_esm2_embedding(seq_b, tokenizer, esm2_model)
    prott5_emb1 = compute_prott5_embedding(seq_a, prott5_embedder)
    prott5_emb2 = compute_prott5_embedding(seq_b, prott5_embedder)
    if any(emb is None for emb in [esm2_emb1, esm2_emb2, prott5_emb1, prott5_emb2]):
        return

    # Global metrics were previously computed and thrown away; report them.
    esm2_dist = cosine(esm2_emb1, esm2_emb2)
    prott5_dist = cosine(prott5_emb1, prott5_emb2)
    logger.info(
        f"Case 1 global metrics for {pdb1.name} vs. {pdb2.name}: "
        f"RMSD={rmsd:.3f}, ESM-2 cosine distance={esm2_dist:.4f}, "
        f"ProtT5 cosine distance={prott5_dist:.4f}"
    )

    # Residue-based analysis
    esm2_res_emb1 = compute_esm2_embedding(seq_a, tokenizer, esm2_model, per_residue=True)
    esm2_res_emb2 = compute_esm2_embedding(seq_b, tokenizer, esm2_model, per_residue=True)
    ss1 = compute_dssp_features(pdb1)
    ss2 = compute_dssp_features(pdb2)
    rmsd_per_res = compute_rmsd(pdb1, pdb2, per_residue=True)

    if esm2_res_emb1 is None or esm2_res_emb2 is None or ss1 is None or ss2 is None or rmsd_per_res is None:
        logger.warning(f"Skipping residue-based analysis for {pdb1.name}, {pdb2.name}")
        return

    # Index the per-residue distances once instead of scanning the list for
    # every residue. Identifiers that are not plain integers (e.g. insertion
    # codes such as "27A") are skipped rather than crashing int(); the first
    # entry wins for duplicate numbers, matching the old first-match lookup.
    rmsd_by_resi = {}
    for resi, dist in rmsd_per_res:
        try:
            rmsd_by_resi.setdefault(int(resi), dist)
        except ValueError:
            continue

    # NOTE(review): sequence position i+1 is matched against the PDB residue
    # number, which presumably assumes numbering starts at 1 without gaps --
    # confirm for the input files.
    n_common = min(len(esm2_res_emb1), len(esm2_res_emb2), len(ss1), len(ss2))
    res_results = [
        {
            "residue": i + 1,
            "embedding_distance": cosine(esm2_res_emb1[i], esm2_res_emb2[i]),
            "ss_match": 1 if ss1[i] == ss2[i] else 0,
            "rmsd": rmsd_by_resi.get(i + 1, np.nan),
        }
        for i in range(n_common)
    ]
    if not res_results:
        # An empty frame has no columns, so the plotting below would raise.
        logger.warning(f"No per-residue results for {pdb1.name}, {pdb2.name}")
        return

    # Save results
    df = pd.DataFrame(res_results)
    df.to_csv(OUTPUT_DIR / "case_1_homologous_results.csv", index=False)

    # Plot residue-based comparison
    plt.figure(figsize=(10, 6))
    plt.plot(df["residue"], df["embedding_distance"], label="Embedding Distance")
    plt.plot(df["residue"], df["rmsd"], label="Per-Residue RMSD")
    plt.title(f"Homologous Pair: {pdb1.name} vs. {pdb2.name}")
    plt.xlabel("Residue Position")
    plt.ylabel("Distance/RMSD")
    plt.legend()
    plt.savefig(OUTPUT_DIR / "case_1_homologous_plot.png")
    plt.close()

    # Render each structure on its own (each call reloads the file, so the
    # images are not superposed)
    visualize_protein(pdb1, OUTPUT_DIR / "case_1_homologous_pdb1.png", color="blue")
    visualize_protein(pdb2, OUTPUT_DIR / "case_1_homologous_pdb2.png", color="green")

def case_2_mutants():
    """Analyze a wild-type vs. mutant pair, focusing on the differing residue.

    Searches the ``mutants`` task directory for the first pair of PDB files
    whose sequences differ at exactly one position. The first file of the
    pair is labelled wild-type and the second mutant purely by listing order.
    Logs the global RMSD and embedding distances, then writes
    ``case_2_mutant_results.csv``, ``case_2_mutant_plot.png`` and one render
    per structure to ``OUTPUT_DIR``. Returns early (after logging) whenever
    an input cannot be computed.
    """
    task_dir = TASK_DIRS["mutants"]
    has_files, pdb_files = check_directory_files(task_dir, "mutants", min_files=2)
    if not has_files:
        return

    # Parse every file at most once, and only when the search reaches it;
    # the pair loop would otherwise re-parse each file for every partner.
    seq_cache = {}

    def _sequence_of(path):
        if path not in seq_cache:
            seq_cache[path] = get_sequence_from_pdb(path)
        return seq_cache[path]

    # Find a valid wild-type/mutant pair with one differing residue
    wt_pdb, mut_pdb, diff_res = None, None, None
    for i, candidate_a in enumerate(pdb_files):
        for candidate_b in pdb_files[i + 1:]:
            seq_a = _sequence_of(candidate_a)
            seq_b = _sequence_of(candidate_b)
            if seq_a and seq_b:
                diff = find_differing_residue(seq_a, seq_b)
                if diff:
                    wt_pdb, mut_pdb = candidate_a, candidate_b
                    diff_res = diff
                    break
        if wt_pdb:
            break

    if not wt_pdb or not diff_res:
        logger.warning("No valid wild-type/mutant pair with one differing residue found")
        return

    res_num, wt_aa, mut_aa = diff_res
    logger.info(f"Selected pair: {wt_pdb.name} (WT) vs. {mut_pdb.name} (Mutant). Differing residue: {wt_aa}{res_num}{mut_aa}")

    # Sequences were already extracted during the search; reuse them
    wt_seq = _sequence_of(wt_pdb)
    mut_seq = _sequence_of(mut_pdb)
    if not wt_seq or not mut_seq:
        logger.warning(f"Skipping case 2 due to invalid sequences")
        return

    # Check the structural input before loading the language models
    rmsd = compute_rmsd(wt_pdb, mut_pdb)
    if rmsd is None:
        logger.warning(f"Skipping case 2 due to invalid RMSD")
        return

    tokenizer, esm2_model = load_esm2_model()
    prott5_embedder = load_prott5_model()

    wt_esm2_emb = compute_esm2_embedding(wt_seq, tokenizer, esm2_model)
    mut_esm2_emb = compute_esm2_embedding(mut_seq, tokenizer, esm2_model)
    wt_prott5_emb = compute_prott5_embedding(wt_seq, prott5_embedder)
    mut_prott5_emb = compute_prott5_embedding(mut_seq, prott5_embedder)
    if any(emb is None for emb in [wt_esm2_emb, mut_esm2_emb, wt_prott5_emb, mut_prott5_emb]):
        return

    # Global metrics were previously computed and thrown away; report them.
    esm2_dist = cosine(wt_esm2_emb, mut_esm2_emb)
    prott5_dist = cosine(wt_prott5_emb, mut_prott5_emb)
    logger.info(
        f"Case 2 global metrics for {wt_pdb.name} vs. {mut_pdb.name}: "
        f"RMSD={rmsd:.3f}, ESM-2 cosine distance={esm2_dist:.4f}, "
        f"ProtT5 cosine distance={prott5_dist:.4f}"
    )

    # Residue-based analysis
    wt_res_emb = compute_esm2_embedding(wt_seq, tokenizer, esm2_model, per_residue=True)
    mut_res_emb = compute_esm2_embedding(mut_seq, tokenizer, esm2_model, per_residue=True)
    wt_ss = compute_dssp_features(wt_pdb)
    mut_ss = compute_dssp_features(mut_pdb)
    rmsd_per_res = compute_rmsd(wt_pdb, mut_pdb, per_residue=True)

    if wt_res_emb is None or mut_res_emb is None or wt_ss is None or mut_ss is None or rmsd_per_res is None:
        logger.warning(f"Skipping residue-based analysis for {wt_pdb.name}, {mut_pdb.name}")
        return

    # Index the per-residue distances once instead of scanning the list for
    # every residue. Identifiers that are not plain integers (e.g. insertion
    # codes such as "27A") are skipped rather than crashing int(); the first
    # entry wins for duplicate numbers, matching the old first-match lookup.
    rmsd_by_resi = {}
    for resi, dist in rmsd_per_res:
        try:
            rmsd_by_resi.setdefault(int(resi), dist)
        except ValueError:
            continue

    # NOTE(review): res_num and i+1 are 1-based positions in the extracted
    # sequence, yet they are matched against PDB residue numbers here and in
    # the highlight below -- presumably assumes numbering starts at 1 without
    # gaps; confirm for the input files.
    n_common = min(len(wt_res_emb), len(mut_res_emb), len(wt_ss), len(mut_ss))
    res_results = [
        {
            "residue": i + 1,
            "embedding_distance": cosine(wt_res_emb[i], mut_res_emb[i]),
            "ss_match": 1 if wt_ss[i] == mut_ss[i] else 0,
            "rmsd": rmsd_by_resi.get(i + 1, np.nan),
            "is_differing_residue": 1 if (i + 1) == res_num else 0,
        }
        for i in range(n_common)
    ]
    if not res_results:
        # An empty frame has no columns, so the plotting below would raise.
        logger.warning(f"No per-residue results for {wt_pdb.name}, {mut_pdb.name}")
        return

    # Save results
    df = pd.DataFrame(res_results)
    df.to_csv(OUTPUT_DIR / "case_2_mutant_results.csv", index=False)

    # Plot residue-based comparison, highlighting differing residue
    plt.figure(figsize=(10, 6))
    plt.plot(df["residue"], df["embedding_distance"], label="ESM-2 Embedding Distance")
    plt.plot(df["residue"], df["rmsd"], label="Per-Residue RMSD")
    plt.axvline(x=res_num, color="red", linestyle="--", label=f"Mutation: {wt_aa}{res_num}{mut_aa}")
    plt.title(f"Wild-Type vs. Mutant: {wt_pdb.name} vs. {mut_pdb.name}")
    plt.xlabel("Residue Position")
    plt.ylabel("Distance/RMSD")
    plt.legend()
    plt.savefig(OUTPUT_DIR / "case_2_mutant_plot.png")
    plt.close()

    # Visualize structures, highlighting differing residue
    visualize_protein(wt_pdb, OUTPUT_DIR / "case_2_wt.png", color="blue", highlight_residue=res_num)
    visualize_protein(mut_pdb, OUTPUT_DIR / "case_2_mutant.png", color="green", highlight_residue=res_num)

def case_3_domains():
    """Profile one single-domain structure and render it next to a multi-domain one.

    Uses the first PDB file of the ``single_domain`` and ``multi_domain``
    task directories. Mean ESM-2 / ProtT5 embeddings are computed for both as
    a sanity gate (the function stops if any is unavailable). For the
    single-domain structure the per-residue embedding norm and helix
    indicator are written to ``case_3_single_domain_results.csv`` and
    plotted to ``case_3_single_domain_plot.png``; both structures are then
    rendered to PNG.
    """
    has_single, single_files = check_directory_files(
        TASK_DIRS["single_domain"], "single_domain", min_files=1
    )
    has_multi, multi_files = check_directory_files(
        TASK_DIRS["multi_domain"], "multi_domain", min_files=1
    )
    if not has_single or not has_multi:
        return

    single_pdb, multi_pdb = single_files[0], multi_files[0]

    tokenizer, esm2_model = load_esm2_model()
    prott5_embedder = load_prott5_model()

    seq_single = get_sequence_from_pdb(single_pdb)
    seq_multi = get_sequence_from_pdb(multi_pdb)
    if not (seq_single and seq_multi):
        logger.warning(f"Skipping case 3 due to invalid sequences")
        return

    # Global embeddings for both structures; a missing one aborts the case
    global_embeddings = (
        compute_esm2_embedding(seq_single, tokenizer, esm2_model),
        compute_esm2_embedding(seq_multi, tokenizer, esm2_model),
        compute_prott5_embedding(seq_single, prott5_embedder),
        compute_prott5_embedding(seq_multi, prott5_embedder),
    )
    for emb in global_embeddings:
        if emb is None:
            return

    # Per-residue profile of the single-domain structure
    residue_vectors = compute_esm2_embedding(seq_single, tokenizer, esm2_model, per_residue=True)
    helix_flags = compute_dssp_features(single_pdb)
    if residue_vectors is None or helix_flags is None:
        logger.warning(f"Skipping residue-based analysis for {single_pdb.name}")
        return

    # zip stops at the shorter of the two, like min(len, len) did
    rows = [
        {
            "residue": position,
            "embedding_norm": np.linalg.norm(vector),
            "secondary_structure": flag,
        }
        for position, (vector, flag) in enumerate(zip(residue_vectors, helix_flags), start=1)
    ]

    # Persist the table
    table = pd.DataFrame(rows)
    table.to_csv(OUTPUT_DIR / "case_3_single_domain_results.csv", index=False)

    # Line plot of both per-residue series
    plt.figure(figsize=(10, 6))
    for column, legend_label in (
        ("embedding_norm", "Embedding Norm"),
        ("secondary_structure", "Secondary Structure (Helix=1)"),
    ):
        plt.plot(table["residue"], table[column], label=legend_label)
    plt.title(f"Single-Domain: {single_pdb.name}")
    plt.xlabel("Residue Position")
    plt.ylabel("Value")
    plt.legend()
    plt.savefig(OUTPUT_DIR / "case_3_single_domain_plot.png")
    plt.close()

    # Renders of both structures
    visualize_protein(single_pdb, OUTPUT_DIR / "case_3_single_domain.png", color="blue")
    visualize_protein(multi_pdb, OUTPUT_DIR / "case_3_multi_domain.png", color="green")

def case_4_disentanglement():
    """Profile the first structure of the disentanglement task directory.

    Writes the per-residue ESM-2 embedding norm and helix indicator to
    ``case_4_disentanglement_results.csv``, plots both series to
    ``case_4_disentanglement_plot.png`` and renders the structure to PNG.
    Returns early (after logging) when the sequence, embedding or secondary
    structure is unavailable.
    """
    has_files, pdb_files = check_directory_files(
        TASK_DIRS["disentanglement"], "disentanglement", min_files=1
    )
    if not has_files:
        return

    pdb_file = pdb_files[0]
    tokenizer, esm2_model = load_esm2_model()

    seq = get_sequence_from_pdb(pdb_file)
    if not seq:
        logger.warning(f"Skipping case 4 due to invalid sequence for {pdb_file.name}")
        return

    residue_vectors = compute_esm2_embedding(seq, tokenizer, esm2_model, per_residue=True)
    helix_flags = compute_dssp_features(pdb_file)
    if residue_vectors is None or helix_flags is None:
        logger.warning(f"Skipping residue-based analysis for {pdb_file.name}")
        return

    # zip stops at the shorter of the two, like min(len, len) did
    rows = [
        {
            "residue": position,
            "embedding_norm": np.linalg.norm(vector),
            "secondary_structure": flag,
        }
        for position, (vector, flag) in enumerate(zip(residue_vectors, helix_flags), start=1)
    ]

    # Persist the table
    table = pd.DataFrame(rows)
    table.to_csv(OUTPUT_DIR / "case_4_disentanglement_results.csv", index=False)

    # Line plot of both per-residue series
    plt.figure(figsize=(10, 6))
    for column, legend_label in (
        ("embedding_norm", "Embedding Norm"),
        ("secondary_structure", "Secondary Structure (Helix=1)"),
    ):
        plt.plot(table["residue"], table[column], label=legend_label)
    plt.title(f"Disentanglement: {pdb_file.name}")
    plt.xlabel("Residue Position")
    plt.ylabel("Value")
    plt.legend()
    plt.savefig(OUTPUT_DIR / "case_4_disentanglement_plot.png")
    plt.close()

    # Render of the structure
    visualize_protein(pdb_file, OUTPUT_DIR / "case_4_disentanglement.png", color="blue")

def main():
    """Check the environment, then run the four case studies in order."""
    check_environment()

    # (log banner, case function) pairs, executed top to bottom
    pipeline = (
        ("Running Case 1: Homologous Pairs...", case_1_homologous_pairs),
        ("Running Case 2: Mutants...", case_2_mutants),
        ("Running Case 3: Domains...", case_3_domains),
        ("Running Case 4: Disentanglement...", case_4_disentanglement),
    )
    for banner, run_case in pipeline:
        logger.info(banner)
        run_case()


if __name__ == "__main__":
    main()