In [7]:
import os
import torch
import esm
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_distances, euclidean_distances

# Input locations: `root_dir` holds one sub-directory per pair, and `df_path`
# points at the per-pair summary table stored inside it.
root_dir = Path("pdb_mutant_pairs")
df_path = Path("pdb_mutant_pairs/pairs_analysis_summary.csv")

# Per-pair metadata; the code below reads its 'pair', 'rmsd', 'mutations'
# and 'identity_pct' columns.
meta = pd.read_csv(filepath_or_buffer=df_path)

# Gather sequence pairs
import re

def load_sequence(pair_name, base_dir=None):
    """Load the two protein sequences stored for one mutant pair.

    Looks in ``<base_dir>/<pair_name>`` for exactly two ``*.fasta`` files and
    returns their cleaned sequences in file-name order, so the seq1/seq2
    assignment is identical on every run and every filesystem.

    Parameters
    ----------
    pair_name : str
        Name of the pair's sub-directory.
    base_dir : str or Path, optional
        Directory that contains the pair sub-directories. Defaults to the
        notebook-level ``root_dir``.

    Returns
    -------
    list of str
        Two upper-case sequences of equal length containing only the 20
        standard amino-acid letters.

    Raises
    ------
    ValueError
        If the directory does not hold exactly two FASTA files, if a file
        yields no valid residues, or if the two cleaned sequences differ in
        length.
    """
    pair_dir = Path(root_dir if base_dir is None else base_dir) / pair_name

    # glob() yields entries in filesystem-dependent order; sorting makes the
    # first/second sequence assignment (and therefore the sign of any
    # downstream embedding difference) reproducible.
    fasta_files = sorted(pair_dir.glob("*.fasta"))
    if len(fasta_files) != 2:
        raise ValueError(f"Expected 2 FASTA files in {pair_dir}, found {len(fasta_files)}")

    seqs = []
    for fasta in fasta_files:
        with open(fasta) as f:
            # Every non-header line is concatenated, so a file with several
            # records collapses into a single sequence.
            residues = ''.join(line.strip() for line in f if not line.startswith('>'))
        # Uppercase, then drop anything that is not a standard amino acid
        # (gaps, X, whitespace, ...).
        seq = re.sub(r'[^ACDEFGHIKLMNPQRSTVWY]', '', residues.upper())
        if not seq:
            raise ValueError(f"No valid amino acids in sequence from {fasta}")
        seqs.append(seq)

    if len(seqs[0]) != len(seqs[1]):
        raise ValueError(f"Sequence length mismatch in {pair_name}")
    return seqs

# Collect every pair whose FASTA files load cleanly. Loading stays best-effort
# (a bad pair does not stop the run), but each skipped pair is recorded with
# the reason and reported, instead of disappearing silently.
pairs = []
skipped = []
for _, row in meta.iterrows():
    name = row['pair']
    try:
        seq1, seq2 = load_sequence(name)
    except Exception as exc:
        skipped.append((name, f"{type(exc).__name__}: {exc}"))
        continue
    pairs.append({
        'pair': name,
        'seq1': seq1,
        'seq2': seq2,
        'rmsd': row['rmsd'],
        'mutations': row['mutations'],
        'identity_pct': row['identity_pct']
    })

if skipped:
    print(f"Skipped {len(skipped)} of {len(meta)} pairs:")
    for skipped_name, reason in skipped:
        print(f"  {skipped_name}: {reason}")

# Load ESM-2 model
# Use the Apple-Silicon GPU (MPS) when PyTorch can reach it, otherwise the CPU.
# torch.backends.mps.is_available() is the documented availability check and
# also works on PyTorch releases whose torch.mps module has no is_available().
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# eval() switches the network to inference behaviour before embedding.
model = model.to(device).eval()
batch_converter = alphabet.get_batch_converter()

# Embed unique sequences
# De-duplicate, then sort by (length, sequence). Iteration order of a set of
# strings can change between interpreter runs, which would change which
# sequences share a batch; sorting makes the batching reproducible, and
# grouping similar lengths keeps the padded batch tensors small.
unique_seqs = sorted(
    {p['seq1'] for p in pairs} | {p['seq2'] for p in pairs},
    key=lambda s: (len(s), s),
)
embeddings = {}
batch_size = 32
final_layer = model.num_layers
for i in range(0, len(unique_seqs), batch_size):
    batch = unique_seqs[i:i + batch_size]
    # The batch converter takes (label, sequence) tuples; labels are unused here.
    raw_batch = [(str(idx), seq) for idx, seq in enumerate(batch)]
    labels, seq_strs, tokens = batch_converter(raw_batch)
    with torch.no_grad():
        outputs = model(tokens.to(device), repr_layers=[final_layer], return_contacts=False)
    reps = outputs['representations'][final_layer]
    for j, seq in enumerate(batch):
        # Mean-pool positions 1..len(seq): position 0 is skipped (presumably
        # the start token added by the batch converter, TODO confirm) and
        # anything past len(seq) is padding/end tokens.
        embeddings[seq] = reps[j, 1:len(seq) + 1].mean(0).cpu().numpy()

# Compute embedding distances
# `pairs` only contains the rows of `meta` whose FASTA files loaded, so it can
# be shorter than `meta`. Restrict `meta` to those rows (same order) before
# attaching the distance columns; assigning a shorter list to the full table
# would raise a length-mismatch error.
loaded_names = {p['pair'] for p in pairs}
meta = meta[meta['pair'].isin(loaded_names)].reset_index(drop=True)
if list(meta['pair']) != [p['pair'] for p in pairs]:
    raise ValueError("meta rows and loaded pairs are out of sync; re-run the loading step")

cosines, euclids = [], []
for p in pairs:
    print(f"Processing pair {p['pair']}")
    e1, e2 = embeddings[p['seq1']], embeddings[p['seq2']]
    # The sklearn helpers take 2-D inputs and return a 1x1 matrix here.
    cosines.append(cosine_distances([e1], [e2])[0, 0])
    euclids.append(euclidean_distances([e1], [e2])[0, 0])
meta['embed_cosine'] = cosines
meta['embed_euclid'] = euclids

# Correlation analysis
# Pearson correlation of each embedding distance with RMSD and with the
# mutation count; [0, 1] picks the off-diagonal entry of the 2x2 matrix.
cosine_col, euclid_col = meta['embed_cosine'], meta['embed_euclid']
rmsd_col, mut_col = meta['rmsd'], meta['mutations']

corrs = {}
corrs['cosine_vs_rmsd'] = np.corrcoef(cosine_col, rmsd_col)[0, 1]
corrs['cosine_vs_mut'] = np.corrcoef(cosine_col, mut_col)[0, 1]
corrs['euclid_vs_rmsd'] = np.corrcoef(euclid_col, rmsd_col)[0, 1]
corrs['euclid_vs_mut'] = np.corrcoef(euclid_col, mut_col)[0, 1]
print("Embedding-to-metric correlations:", corrs)

# PCA disentanglement
# Decompose the per-pair embedding differences (seq1 - seq2) and check how
# strongly each principal component tracks RMSD and the mutation count.
if len(pairs) < 2:
    raise ValueError(f"PCA needs at least 2 loaded pairs, found {len(pairs)}")

diffs = np.stack([embeddings[p['seq1']] - embeddings[p['seq2']] for p in pairs])

# Take the metrics from `pairs` itself so they are aligned row-for-row with
# `diffs`, regardless of how many rows `meta` holds.
pair_rmsd = np.array([p['rmsd'] for p in pairs], dtype=float)
pair_mut = np.array([p['mutations'] for p in pairs], dtype=float)

# PCA cannot extract more components than min(n_samples, n_features); cap the
# requested 5 so small datasets still run.
n_components = min(5, *diffs.shape)
pca = PCA(n_components=n_components)
pcs = pca.fit_transform(diffs)

results = []
for idx in range(pcs.shape[1]):
    pc = pcs[:, idx]
    results.append({
        'PC': idx + 1,
        'explained_variance': pca.explained_variance_ratio_[idx],
        'corr_rmsd': np.corrcoef(pc, pair_rmsd)[0, 1],
        'corr_mut': np.corrcoef(pc, pair_mut)[0, 1]
    })
df_pca = pd.DataFrame(results)
print("\nPCA Disentanglement:")
print(df_pca)

# Save outputs
# Write both result tables next to the input data, without the index column.
for frame, filename in (
    (meta, 'pairs_embedding_analysis.csv'),
    (df_pca, 'pairs_pca_disentanglement.csv'),
):
    frame.to_csv(root_dir / filename, index=False)
print("Results saved to CSV.")


: 

In [6]:
import torch

# Environment probe: can PyTorch use the Apple-Silicon GPU (MPS) backend?
# torch.backends.mps.is_available() is the documented check and also works on
# PyTorch releases whose torch.mps module has no is_available().
torch.backends.mps.is_available()

True