# In [None]:
#!/usr/bin/env python
# Protein Embedding Analysis for WT-Mutant Pairs

import os
import glob
import re
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from Bio import SeqIO
from Bio.PDB import PDBParser
from scipy.spatial.distance import pdist, squareform
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE, MDS
from sklearn.metrics import pairwise_distances
from typing import Dict, List, Tuple, Optional
from transformers import AutoTokenizer, AutoModel, pipeline
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Analysis configuration -- edit these values before running the notebook.
# ---------------------------------------------------------------------------
import logging

data_dir = './pdb_mutant_pairs'  # folder holding one "pair_<WT>_<MUT>" sub-folder per protein pair
model_name = 'facebook/esm2_t33_650M_UR50D'  # Hugging Face protein language model checkpoint
output_dir = './results'  # destination for CSVs, embeddings, plots and the log file
visualize = True  # set to False to skip plot generation

# Run the model on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# The log file is written inside output_dir, so the directory must exist first.
os.makedirs(output_dir, exist_ok=True)

# Send INFO-and-above records both to <output_dir>/analysis.log and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(output_dir, 'analysis.log')),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

class ProteinPairAnalyzer:
    """Embed and compare wild-type / mutant protein pairs.

    Combines a Hugging Face protein language model (tokenizer + encoder) with
    Biopython's ``PDBParser`` in order to:

    * discover ``pair_<WT>_<MUT>`` directories on disk,
    * embed each sequence (mean of the last hidden state),
    * compute the CA-atom RMSD between the paired structures,
    * run PCA on the WT-minus-mutant embedding differences and correlate the
      components with mutation counts and RMSD,
    * save CSV / NumPy outputs and plots.
    """

    def __init__(self, model_name: str, device: str = 'cpu'):
        """Load the tokenizer and model for ``model_name`` onto ``device``.

        Args:
            model_name: Hugging Face model id or local path understood by
                ``AutoTokenizer`` / ``AutoModel``.
            device: torch device string such as ``'cpu'`` or ``'cuda'``.
        """
        logger.info(f"Initializing ProteinPairAnalyzer with model: {model_name}")
        self.model_name = model_name
        self.device = device

        logger.info("Loading model and tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)
        # Inference only: be explicit that dropout and similar layers are disabled.
        self.model.eval()
        logger.info("Model and tokenizer loaded successfully")

        # QUIET=True silences Biopython's warnings about malformed PDB records.
        self.pdb_parser = PDBParser(QUIET=True)

    # ------------------------------------------------------------------
    # Discovery of input files
    # ------------------------------------------------------------------
    @staticmethod
    def _candidate_id_splits(suffix: str) -> List[Tuple[str, str]]:
        """Return every ``(id_1, id_2)`` split of ``suffix`` at a single underscore.

        Both halves must be non-empty. The split at the LAST underscore comes
        first, because that is what the historical ``pair_(.+)_(.+)`` greedy
        regex produced. The remaining splits follow, right to left, so that
        protein IDs which themselves contain underscores can still be resolved.
        """
        underscore_positions = [i for i, ch in enumerate(suffix) if ch == "_"]
        return [
            (suffix[:i], suffix[i + 1:])
            for i in reversed(underscore_positions)
            if 0 < i < len(suffix) - 1
        ]

    @staticmethod
    def _locate_pair_files(pair_dir: str, wt_id: str, mut_id: str) -> Dict[str, str]:
        """Find the PDB/FASTA files for a WT/mutant pair inside ``pair_dir``.

        For each file, ``<id lower-cased>.<ext>`` is preferred over
        ``<id>.<ext>`` when both exist (this matches the original lookup
        order).

        Returns:
            Mapping with a subset of the keys ``wt_pdb``, ``wt_fasta``,
            ``mut_pdb``, ``mut_fasta``. Keys are present only for files that
            exist.
        """
        found = {}
        for role, protein_id in (("wt", wt_id), ("mut", mut_id)):
            for ext in ("pdb", "fasta"):
                for name in (protein_id.lower(), protein_id):
                    path = os.path.join(pair_dir, f"{name}.{ext}")
                    if os.path.exists(path):
                        found[f"{role}_{ext}"] = path
                        break
        return found

    def load_protein_pairs(self, data_dir: str) -> Dict[str, Dict[str, str]]:
        """
        Load protein pairs from the specified directory structure.

        Expected structure:
        data_dir/
            pair_{PROTEIN_ID_1}_{PROTEIN_ID_2}/
                PROTEIN_ID_1.pdb (or protein_id_1.pdb)
                PROTEIN_ID_2.pdb (or protein_id_2.pdb)
                PROTEIN_ID_1.fasta (or protein_id_1.fasta)
                PROTEIN_ID_2.fasta (or protein_id_2.fasta)

        Protein IDs may contain underscores. Every split of the directory name
        is tried, starting with the split at the last underscore, and the
        first split for which all four files exist is used.

        Returns:
            Dictionary keyed by ``"{wt_id}_{mut_id}"``. Each value holds the
            paths ``wt_pdb``, ``mut_pdb``, ``wt_fasta``, ``mut_fasta`` and the
            IDs ``wt_id``, ``mut_id``.
        """
        logger.info(f"Loading protein pairs from {data_dir}")
        protein_pairs = {}
        required_keys = ["wt_pdb", "mut_pdb", "wt_fasta", "mut_fasta"]

        # Sorted so that pairs are processed in a reproducible order (glob
        # order is filesystem dependent); only real directories are considered.
        pair_dirs = sorted(
            d for d in glob.glob(os.path.join(data_dir, "pair_*_*")) if os.path.isdir(d)
        )
        logger.info(f"Found {len(pair_dirs)} protein pair directories")

        for pair_dir in pair_dirs:
            dir_name = os.path.basename(pair_dir)
            candidates = self._candidate_id_splits(dir_name[len("pair_"):])
            if not candidates:
                logger.warning(f"Directory name does not match expected pattern: {dir_name}")
                continue

            resolved = None
            for ids in candidates:
                paths = self._locate_pair_files(pair_dir, *ids)
                if all(key in paths for key in required_keys):
                    resolved = (ids, paths)
                    break

            if resolved is None:
                # Report missing files against the default (last-underscore) split.
                protein_id_1, protein_id_2 = candidates[0]
                paths = self._locate_pair_files(pair_dir, protein_id_1, protein_id_2)
                missing_keys = [key for key in required_keys if key not in paths]
                logger.warning(f"Missing files for pair {protein_id_1}_{protein_id_2}: {missing_keys}")
                continue

            (protein_id_1, protein_id_2), file_paths = resolved
            pair_id = f"{protein_id_1}_{protein_id_2}"
            protein_pairs[pair_id] = {
                "wt_pdb": file_paths["wt_pdb"],
                "mut_pdb": file_paths["mut_pdb"],
                "wt_fasta": file_paths["wt_fasta"],
                "mut_fasta": file_paths["mut_fasta"],
                "wt_id": protein_id_1,
                "mut_id": protein_id_2,
            }

        logger.info(f"Successfully loaded {len(protein_pairs)} protein pairs")
        return protein_pairs

    # ------------------------------------------------------------------
    # Per-file readers
    # ------------------------------------------------------------------
    def get_sequence_from_fasta(self, fasta_path: str) -> str:
        """Return the sequence of the first record in ``fasta_path`` ("" if there are no records)."""
        with open(fasta_path, "r") as handle:
            first_record = next(SeqIO.parse(handle, "fasta"), None)
        return "" if first_record is None else str(first_record.seq)

    def get_structure_from_pdb(self, pdb_path: str):
        """Parse ``pdb_path`` into a Biopython ``Structure``.

        The structure id is the file name up to its first dot.
        """
        structure_id = os.path.basename(pdb_path).split('.')[0]
        return self.pdb_parser.get_structure(structure_id, pdb_path)

    def extract_mutations(self, wt_seq: str, mut_seq: str) -> List[Tuple[str, int, str]]:
        """List position-wise substitutions between two sequences.

        The sequences are compared position by position, without alignment.
        If their lengths differ, only the overlapping prefix is compared and a
        warning is logged, because insertions or deletions make the result
        unreliable.

        Returns:
            ``(wt_residue, 1-based position, mutant_residue)`` tuples.
        """
        if len(wt_seq) != len(mut_seq):
            logger.warning("Sequences have different lengths. Simple mutation detection may not be accurate.")

        return [
            (wt_aa, position, mut_aa)
            for position, (wt_aa, mut_aa) in enumerate(zip(wt_seq, mut_seq), start=1)
            if wt_aa != mut_aa
        ]

    def compute_embedding(self, sequence: str) -> np.ndarray:
        """Return a fixed-size embedding for a single protein ``sequence``.

        The embedding is the mean of the model's last hidden state over all
        token positions except the first and the last. Those are the special
        tokens added by the tokenizer, e.g. ``<cls>`` / ``<eos>`` for ESM.

        Returns:
            1-D array of length ``hidden_size``.

        Raises:
            ValueError: if ``sequence`` is empty, because the mean would be
                taken over zero positions and yield NaN.
        """
        if not sequence:
            raise ValueError("Cannot embed an empty sequence")

        inputs = self.tokenizer(sequence, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Mean-pool over residue positions, excluding the first/last special tokens.
        embedding = outputs.last_hidden_state[:, 1:-1].mean(dim=1).cpu().numpy()
        return embedding.squeeze()

    # ------------------------------------------------------------------
    # Structural comparison
    # ------------------------------------------------------------------
    @staticmethod
    def _ca_atoms(structure) -> list:
        """Return the alpha-carbon atoms of amino-acid residues in the first model.

        Atoms are selected per residue rather than by atom name alone. This
        excludes calcium ions, which are also named "CA" but sit in
        non-amino-acid HETATM residues. It also avoids concatenating several
        NMR models.
        """
        from Bio.PDB.Polypeptide import is_aa

        model = next(iter(structure), None)
        if model is None:
            return []
        return [
            residue["CA"]
            for residue in model.get_residues()
            if "CA" in residue and is_aa(residue, standard=False)
        ]

    def _compute_ca_rmsd(self, wt_structure, mut_structure, wt_id: str, mut_id: str) -> float:
        """Return the CA RMSD after optimal superposition, or NaN if it cannot be computed.

        Atoms are paired by order, so both structures must yield the same
        number of CA atoms. The WT atoms are the fixed set and the mutant
        atoms are the moving set.
        """
        from Bio.PDB import Superimposer

        wt_ca_atoms = self._ca_atoms(wt_structure)
        mut_ca_atoms = self._ca_atoms(mut_structure)

        if not wt_ca_atoms or not mut_ca_atoms:
            logger.warning(f"Cannot calculate RMSD: no CA atoms found ({len(wt_ca_atoms)} vs {len(mut_ca_atoms)})")
            return float('nan')
        if len(wt_ca_atoms) != len(mut_ca_atoms):
            logger.warning(f"Cannot calculate RMSD: Different number of CA atoms ({len(wt_ca_atoms)} vs {len(mut_ca_atoms)})")
            return float('nan')

        try:
            sup = Superimposer()
            sup.set_atoms(wt_ca_atoms, mut_ca_atoms)
        except Exception as e:  # best effort: RMSD is optional, keep the pair
            logger.error(f"Error calculating RMSD: {str(e)}")
            return float('nan')

        # The superposition is not applied to the mutant coordinates (sup.apply is not called).
        rmsd = sup.rms
        logger.info(f"RMSD between {wt_id} and {mut_id}: {rmsd:.3f} Å")
        return rmsd

    # ------------------------------------------------------------------
    # Pair analysis
    # ------------------------------------------------------------------
    def analyze_pair(self, pair_data: Dict[str, str]) -> Dict:
        """Analyze a single protein pair.

        Args:
            pair_data: one value produced by ``load_protein_pairs``.

        Returns:
            Dict with ``wt_id``, ``mut_id``, ``wt_embedding``,
            ``mut_embedding``, ``embedding_diff`` (WT minus mutant),
            ``embedding_distance`` (L2 norm of the diff), ``mutations``,
            ``num_mutations``, ``wt_seq_length``, ``mut_seq_length`` and
            ``structural_distance`` (CA RMSD, NaN when unavailable).

        Raises:
            ValueError: if either FASTA file contains no sequence.
        """
        wt_seq = self.get_sequence_from_fasta(pair_data["wt_fasta"])
        mut_seq = self.get_sequence_from_fasta(pair_data["mut_fasta"])
        for seq, path in ((wt_seq, pair_data["wt_fasta"]), (mut_seq, pair_data["mut_fasta"])):
            if not seq:
                raise ValueError(f"No sequence found in FASTA file: {path}")

        wt_structure = self.get_structure_from_pdb(pair_data["wt_pdb"])
        mut_structure = self.get_structure_from_pdb(pair_data["mut_pdb"])

        mutations = self.extract_mutations(wt_seq, mut_seq)

        wt_embedding = self.compute_embedding(wt_seq)
        mut_embedding = self.compute_embedding(mut_seq)
        embedding_diff = wt_embedding - mut_embedding
        embedding_distance = np.linalg.norm(embedding_diff)

        structural_distance = self._compute_ca_rmsd(
            wt_structure, mut_structure, pair_data["wt_id"], pair_data["mut_id"]
        )

        return {
            "wt_id": pair_data["wt_id"],
            "mut_id": pair_data["mut_id"],
            "wt_embedding": wt_embedding,
            "mut_embedding": mut_embedding,
            "embedding_diff": embedding_diff,
            "embedding_distance": embedding_distance,
            "mutations": mutations,
            "num_mutations": len(mutations),
            "wt_seq_length": len(wt_seq),
            "mut_seq_length": len(mut_seq),
            "structural_distance": structural_distance,
        }

    def analyze_all_pairs(self, protein_pairs: Dict[str, Dict[str, str]]) -> Dict[str, Dict]:
        """Analyze every pair. Failing pairs are logged with a traceback and skipped."""
        results = {}

        for pair_id, pair_data in tqdm(protein_pairs.items(), desc="Analyzing protein pairs"):
            logger.info(f"Analyzing pair: {pair_id}")
            try:
                results[pair_id] = self.analyze_pair(pair_data)
            except Exception as e:  # keep going: one bad pair must not abort the batch
                logger.exception(f"Error analyzing pair {pair_id}: {str(e)}")

        return results

    # ------------------------------------------------------------------
    # Disentanglement analysis
    # ------------------------------------------------------------------
    @staticmethod
    def _safe_corr(x, y) -> float:
        """Pearson correlation of ``x`` and ``y``, or NaN if it is undefined.

        Undefined means fewer than two points, or either input is constant.
        Returning NaN in those cases avoids ``np.corrcoef`` division warnings.
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        if x.size < 2 or np.ptp(x) == 0 or np.ptp(y) == 0:
            return float('nan')
        return float(np.corrcoef(x, y)[0, 1])

    @staticmethod
    def _nan_abs_argmax(values) -> int:
        """Index of the largest ``|value|``, ignoring NaNs. Returns 0 if all values are NaN.

        Plain ``np.argmax`` returns the index of the first NaN when one is
        present, which would select a component with an undefined correlation.
        """
        magnitudes = np.abs(np.asarray(values, dtype=float))
        if magnitudes.size == 0 or np.all(np.isnan(magnitudes)):
            return 0
        return int(np.nanargmax(magnitudes))

    def disentangle_embeddings(self, results: Dict[str, Dict]) -> Dict:
        """
        Perform disentanglement analysis on the embeddings.

        Runs PCA on the per-pair embedding differences (up to 10 components)
        and correlates each component's scores with the mutation count and the
        structural RMSD. NaN RMSDs are excluded from the RMSD correlation.

        Raises:
            ValueError: if ``results`` is empty.
        """
        if not results:
            raise ValueError("Cannot perform disentanglement analysis on an empty result set")

        records = list(results.values())
        embedding_diffs = np.stack([r["embedding_diff"] for r in records])

        # PCA cannot have more components than samples or features.
        pca = PCA(n_components=min(10, *embedding_diffs.shape))
        pca_result = pca.fit_transform(embedding_diffs)

        mutation_counts = np.array([r["num_mutations"] for r in records])
        structural_distances = np.array([r["structural_distance"] for r in records], dtype=float)
        valid_idx = ~np.isnan(structural_distances)

        correlations = {
            "mutation_count": [],
            "structural_distance": [],
        }
        for i in range(pca_result.shape[1]):
            scores = pca_result[:, i]
            correlations["mutation_count"].append(self._safe_corr(scores, mutation_counts))
            correlations["structural_distance"].append(
                self._safe_corr(scores[valid_idx], structural_distances[valid_idx])
            )

        # Components whose scores correlate most strongly (in absolute value) with each property.
        top_mutation_component = self._nan_abs_argmax(correlations["mutation_count"])
        top_structure_component = self._nan_abs_argmax(correlations["structural_distance"])

        disentanglement_results = {
            "pca_explained_variance": pca.explained_variance_ratio_,
            "pca_components": pca.components_,
            "correlations": correlations,
            "pca_result": pca_result,
            "mutation_counts": mutation_counts,
            "structural_distances": structural_distances,
            "top_mutation_component": top_mutation_component,
            "top_structure_component": top_structure_component,
            "top_mutation_loadings": pca.components_[top_mutation_component],
            "top_structure_loadings": pca.components_[top_structure_component],
        }

        # Short summary for interactive use in the notebook.
        print(f"Top component correlated with mutations: PC{top_mutation_component+1} "
              f"(r = {correlations['mutation_count'][top_mutation_component]:.3f})")
        print(f"Top component correlated with structural changes: PC{top_structure_component+1} "
              f"(r = {correlations['structural_distance'][top_structure_component]:.3f})")
        print(f"Total variance explained by top 3 components: "
              f"{np.sum(pca.explained_variance_ratio_[:3]):.2%}")

        return disentanglement_results

    # ------------------------------------------------------------------
    # Output
    # ------------------------------------------------------------------
    def visualize_results(self, results: Dict[str, Dict], disentanglement: Dict, output_dir: str):
        """Save (PNG, 300 dpi) and show the analysis plots.

        Returns:
            Dict mapping plot names to the ``Figure`` objects that were drawn.
            ``pca_visualization`` is ``None`` when fewer than two principal
            components exist.
        """
        os.makedirs(output_dir, exist_ok=True)
        figures = {}

        pca_result = disentanglement["pca_result"]
        explained = disentanglement["pca_explained_variance"]
        correlations = disentanglement["correlations"]
        n_components = pca_result.shape[1]

        # 1. Embedding distance vs number of mutations
        fig = plt.figure(figsize=(10, 6))
        x = [r["num_mutations"] for r in results.values()]
        y = [r["embedding_distance"] for r in results.values()]
        plt.scatter(x, y)
        plt.xlabel("Number of Mutations")
        plt.ylabel("Embedding Distance")
        plt.title("Embedding Distance vs Number of Mutations")
        plt.grid(True, alpha=0.3)
        fig.savefig(os.path.join(output_dir, "embedding_vs_mutations.png"), dpi=300)
        plt.show()
        figures["mutation_vs_embedding"] = fig

        # 2. Embedding distance vs structural distance (pairs without RMSD are dropped)
        fig = plt.figure(figsize=(10, 6))
        x = np.array([r["structural_distance"] for r in results.values()], dtype=float)
        y = np.array([r["embedding_distance"] for r in results.values()])
        valid_idx = ~np.isnan(x)
        plt.scatter(x[valid_idx], y[valid_idx])
        plt.xlabel("Structural Distance (RMSD)")
        plt.ylabel("Embedding Distance")
        plt.title("Embedding Distance vs Structural Distance")
        plt.grid(True, alpha=0.3)
        fig.savefig(os.path.join(output_dir, "embedding_vs_structure.png"), dpi=300)
        plt.show()
        figures["structure_vs_embedding"] = fig

        # 3. PCA scatter of embedding differences (needs at least two components)
        if n_components >= 2:
            fig = plt.figure(figsize=(10, 6))
            plt.scatter(pca_result[:, 0], pca_result[:, 1],
                        c=disentanglement["mutation_counts"], cmap="viridis")
            plt.colorbar(label="Number of Mutations")
            plt.xlabel(f"PC1 ({explained[0]:.2%} variance)")
            plt.ylabel(f"PC2 ({explained[1]:.2%} variance)")
            plt.title("PCA of Embedding Differences")
            plt.grid(True, alpha=0.3)
            fig.savefig(os.path.join(output_dir, "pca_embedding_diffs.png"), dpi=300)
            plt.show()
            figures["pca_visualization"] = fig
        else:
            logger.warning("Skipping PCA scatter plot: fewer than 2 principal components available")
            figures["pca_visualization"] = None

        # 4. Correlation heatmap (NaN cells are left blank by seaborn)
        fig = plt.figure(figsize=(12, 6))
        correlation_data = np.array([
            correlations["mutation_count"],
            correlations["structural_distance"],
        ], dtype=float)
        sns.heatmap(correlation_data, cmap="coolwarm", center=0,
                    annot=True, fmt=".2f",
                    xticklabels=[f"PC{i+1}" for i in range(correlation_data.shape[1])],
                    yticklabels=["Mutation Count", "Structural Distance"])
        plt.title("Correlations between PCA Components and Protein Properties")
        plt.tight_layout()
        fig.savefig(os.path.join(output_dir, "correlation_heatmap.png"), dpi=300)
        plt.show()
        figures["correlation_heatmap"] = fig

        # 5. Loadings of the leading (up to three) principal components
        n_top = min(3, n_components)
        # squeeze=False keeps `axes` 2-D even when n_top == 1.
        fig, axes = plt.subplots(n_top, 1, figsize=(12, 8), sharex=True, squeeze=False)
        for i, ax in enumerate(axes[:, 0]):
            loadings = disentanglement["pca_components"][i]
            ax.bar(range(len(loadings)), loadings)
            ax.set_ylabel(f"PC{i+1} Loading")
            ax.set_title(f"PC{i+1} - Var: {explained[i]:.2%}, "
                         f"Corr(Mut): {correlations['mutation_count'][i]:.2f}, "
                         f"Corr(Struct): {correlations['structural_distance'][i]:.2f}")
        axes[-1, 0].set_xlabel("Embedding Dimension")
        plt.tight_layout()
        fig.savefig(os.path.join(output_dir, "top_pca_components.png"), dpi=300)
        plt.show()
        figures["top_components"] = fig

        return figures

    def save_results(self, results: Dict[str, Dict], disentanglement: Dict, output_dir: str):
        """Write the analysis results to ``output_dir``.

        Files written:
            summary_results.csv      one row per pair (IDs, counts, distances, mutation list)
            pca_results.csv          PCA scores per pair plus mutation count and RMSD
            correlation_results.csv  per-component explained variance and correlations
            embeddings/*.npy         WT, mutant and difference embeddings per pair
        """
        os.makedirs(output_dir, exist_ok=True)

        # Summary table; mutations are encoded as e.g. "A12G,K40R".
        summary_df = pd.DataFrame([
            {
                "pair_id": pair_id,
                "wt_id": pair_result["wt_id"],
                "mut_id": pair_result["mut_id"],
                "num_mutations": pair_result["num_mutations"],
                "embedding_distance": pair_result["embedding_distance"],
                "structural_distance": pair_result["structural_distance"],
                "mutations": ",".join(f"{m[0]}{m[1]}{m[2]}" for m in pair_result["mutations"]),
            }
            for pair_id, pair_result in results.items()
        ])
        summary_df.to_csv(os.path.join(output_dir, "summary_results.csv"), index=False)

        # PCA scores; rows follow the iteration order of `results`, as in disentangle_embeddings.
        pca_df = pd.DataFrame(disentanglement["pca_result"])
        pca_df.columns = [f"PC{i+1}" for i in range(pca_df.shape[1])]
        pca_df["pair_id"] = list(results.keys())
        pca_df["num_mutations"] = disentanglement["mutation_counts"]
        pca_df["structural_distance"] = disentanglement["structural_distances"]
        pca_df.to_csv(os.path.join(output_dir, "pca_results.csv"), index=False)

        corr_df = pd.DataFrame({
            "component": [f"PC{i+1}" for i in range(len(disentanglement["correlations"]["mutation_count"]))],
            "explained_variance": disentanglement["pca_explained_variance"],
            "corr_mutation_count": disentanglement["correlations"]["mutation_count"],
            "corr_structural_distance": disentanglement["correlations"]["structural_distance"],
        })
        corr_df.to_csv(os.path.join(output_dir, "correlation_results.csv"), index=False)

        embedding_dir = os.path.join(output_dir, "embeddings")
        os.makedirs(embedding_dir, exist_ok=True)
        for pair_id, pair_result in results.items():
            np.save(os.path.join(embedding_dir, f"{pair_id}_wt_embedding.npy"), pair_result["wt_embedding"])
            np.save(os.path.join(embedding_dir, f"{pair_id}_mut_embedding.npy"), pair_result["mut_embedding"])
            np.save(os.path.join(embedding_dir, f"{pair_id}_embedding_diff.npy"), pair_result["embedding_diff"])

def run_analysis():
    """Run the full pipeline: load -> analyze -> disentangle -> save -> (plot).

    Reads the module-level settings ``model_name``, ``device``, ``data_dir``,
    ``output_dir`` and ``visualize``.

    Returns:
        ``(results, disentanglement)`` on success. Returns ``(None, None)``
        when no valid pair directory is found, or when every pair fails
        analysis, because the disentanglement step needs at least one
        analyzed pair.
    """
    logger.info(f"Starting protein embedding analysis using model: {model_name}")

    analyzer = ProteinPairAnalyzer(model_name, device)

    protein_pairs = analyzer.load_protein_pairs(data_dir)
    if not protein_pairs:
        logger.error("No valid protein pairs found. Exiting.")
        return None, None

    results = analyzer.analyze_all_pairs(protein_pairs)
    if not results:
        # Every pair raised inside analyze_pair; nothing to stack for PCA.
        logger.error("All protein pairs failed to analyze. Exiting.")
        return None, None

    disentanglement = analyzer.disentangle_embeddings(results)

    analyzer.save_results(results, disentanglement, output_dir)

    if visualize:
        logger.info("Generating visualizations")
        analyzer.visualize_results(results, disentanglement, os.path.join(output_dir, "visualizations"))

    logger.info(f"Analysis completed. Results saved to {output_dir}")

    return results, disentanglement

# Example Jupyter notebook usage:

# Run the complete analysis pipeline
results, disentanglement = run_analysis()

# To run only parts of the pipeline, uncomment the lines below. They are kept
# commented out because constructing a second ProteinPairAnalyzer reloads the
# whole language model, which puts a second copy in CPU/GPU memory.
# analyzer = ProteinPairAnalyzer(model_name, device)
# protein_pairs = analyzer.load_protein_pairs(data_dir)

# Example: Analyze just one specific pair
# first_pair_id = list(protein_pairs.keys())[0]
# pair_result = analyzer.analyze_pair(protein_pairs[first_pair_id])

# Example: Additional visualization of a specific result
if results and disentanglement:
    from collections import Counter

    # Pairs ranked by embedding distance, largest first
    sorted_pairs = sorted(results.items(), key=lambda item: item[1]['embedding_distance'], reverse=True)

    print("Top 5 pairs with largest embedding distance:")
    for pair_id, pair_data in sorted_pairs[:5]:
        print(f"Pair {pair_id}: {pair_data['num_mutations']} mutations, " +
              f"Embedding distance: {pair_data['embedding_distance']:.4f}, " +
              f"Structural distance: {pair_data['structural_distance']:.4f}")

    # Frequency of each substitution type across all pairs, e.g. "A->G"
    mutation_counts = Counter(
        f"{wt_aa}->{mut_aa}"
        for pair_data in results.values()
        for wt_aa, _position, mut_aa in pair_data['mutations']
    )

    # Bar chart of substitution types, most frequent first
    if mutation_counts:
        plt.figure(figsize=(12, 6))
        labels, values = zip(*mutation_counts.most_common())
        plt.bar(labels, values)
        plt.xlabel('Mutation Type')
        plt.ylabel('Frequency')
        plt.title('Distribution of Mutation Types Across All Protein Pairs')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.show()


# Captured output from a previous run of this cell:
# 2025-04-29 01:42:06,138 - __main__ - INFO - Starting protein embedding analysis using model: facebook/esm2_t33_650M_UR50D
# 2025-04-29 01:42:06,434 - __main__ - INFO - Initializing ProteinPairAnalyzer with model: facebook/esm2_t33_650M_UR50D
# 2025-04-29 01:42:06,671 - __main__ - INFO - Loading model and tokenizer...
# Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
# You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
# 2025-04-29 01:42:07,464 - __main__ - INFO - Model and tokenizer loaded successfully
# 2025-04-29 01:42:07,466 - __main__ - INFO - Loading protein pairs from ./pdb_mutant_pairs
# 2025-04-29 01:42:07,468 - __main__ - INFO - Found 150 protein pair directories
# 2025-04-29 01:42:07,617 - __main__ - INFO - Successfully loaded 139 protein pairs
# Analyzing protein pairs:   0%|          | 0/139 [00:00<?, ?it/s