# Feature Engineering Pipeline

This notebook demonstrates the complete feature engineering pipeline for protein interactor prediction, including:

1. **UniProt Data Extraction**: Subcellular locations, GO terms, protein domains, and PTMs
2. **Ensembl Mapping**: Converting UniProt IDs to Ensembl gene IDs for HPA data
3. **HPA Integration**: Brain tissue expression data from the Human Protein Atlas (brain is used in this example; the extraction can be adapted to other tissues)
4. **ESM-2 Embeddings**: 1280-dimensional protein sequence embeddings
5. **Multi-hot Encoding**: Converting categorical features to ML-ready format

The pipeline processes protein IDs and generates comprehensive feature matrices suitable for machine learning.

In [None]:
# Import necessary libraries
import pandas as pd
import numpy as np
import requests
from tqdm.notebook import tqdm
import time
import json
from bs4 import BeautifulSoup
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import MultiLabelBinarizer
import os

print("Libraries imported successfully!")

## 1. UniProt Data Extraction

The `get_uniprot_data()` function extracts comprehensive protein annotations from UniProt, including:
- **Subcellular locations** from comments, features, keywords, and GO terms
- **GO terms** for functional annotations
- **Protein domains** from InterPro cross-references
- **Post-translational modifications (PTMs)** from features and keywords

In [None]:
def robust_get(url, headers=None, max_retries=5, sleep_time=1, timeout=30):
    """Issue an HTTP GET, retrying on failure, and return the successful response.

    Args:
        url: Fully formed URL to request.
        headers: Optional dict of HTTP headers passed through to ``requests.get``.
        max_retries: Maximum number of attempts before giving up.
        sleep_time: Seconds to wait between attempts after a network error or
            an unexpected status code.
        timeout: Seconds to wait for the server on each attempt. Without a
            timeout a stalled connection would block the caller indefinitely.

    Returns:
        The ``requests.Response`` of the first attempt answered with HTTP 200,
        or ``None`` if every attempt failed (or ``max_retries`` is 0).
    """
    for attempt in range(max_retries):
        try:
            response = requests.get(url, headers=headers, timeout=timeout)
        except requests.RequestException as e:
            print(f"Error on attempt {attempt+1} for {url}: {e}")
            delay = sleep_time
        else:
            if response.status_code == 200:
                return response
            if response.status_code == 429:
                print(f"Rate limited! Sleeping and retrying... (attempt {attempt+1})")
                # Back off progressively longer while the server is throttling us.
                delay = 5 * (attempt + 1)
            else:
                print(f"Status {response.status_code} for {url}")
                delay = sleep_time
        # Waiting after the final attempt would only delay reporting the failure.
        if attempt < max_retries - 1:
            time.sleep(delay)
    print(f"Failed to get {url} after {max_retries} attempts.")
    return None

# Substrings (matched against the lower-cased keyword text) that mark a UniProt
# keyword as describing a subcellular location.
_LOCATION_KEYWORD_MARKERS = (
    "subcellular location",
    "membrane",
    "cytoplasm",
    "nucleus",
    "mitochondrion",
)

# Substrings matched against the upper-cased feature type.
# NOTE(review): several markers (e.g. CHAIN, VARIANT, REGION, BINDING) are not
# PTM types in the strict sense; the list is kept exactly as it was so that
# the generated PTM features do not change.
_PTM_FEATURE_TYPE_MARKERS = (
    "PTM", "MOD", "MOD_RES", "CARBOHYD", "LIPID", "GLYCOSYLATION", "PHOSPHO",
    "UBIQUITIN", "SUMO", "ACETYL", "METHYL", "DISULFID", "CROSSLNK",
    "GPI_ANCHOR", "PEPTIDE", "PROPEP", "SIGNAL", "TRANSIT", "CHAIN", "SE_CYS",
    "NON_STD", "VARIANT", "VAR_SEQ", "MUTAGEN", "TOPO_DOM", "ZN_FING",
    "COILED", "COMPBIAS", "REGION", "REPEAT", "MOTIF", "BINDING", "SITE",
    "METAL", "NP_BIND", "DNA_BIND", "CA_BIND",
)

# Substrings matched against the lower-cased feature description.
_PTM_DESCRIPTION_MARKERS = (
    "ptm", "modification", "glycosylation", "phospho", "ubiquitin", "sumo",
    "acetyl", "methyl", "disulfide", "cross-link",
)

# Substrings matched against the lower-cased keyword text.
_PTM_KEYWORD_MARKERS = (
    "ptm", "modification", "glycosylation", "phospho", "ubiquitin",
    "sumoylation", "acetylation", "methylation", "disulfide", "cross-link",
)


def _empty_uniprot_annotations():
    """Return the result used when no annotations could be retrieved."""
    return {'locations': [], 'go_terms': [], 'domains': [], 'ptms': []}


def _contains_any(text, markers):
    """Return True if at least one marker occurs as a substring of text."""
    return any(marker in text for marker in markers)


def _uniprot_keyword_text(keyword):
    """Return the display text of a UniProt keyword record ('' if absent).

    The UniProtKB REST JSON stores the text under ``name``; ``value`` is still
    accepted so payloads using that key keep working.
    """
    return keyword.get("name") or keyword.get("value") or ""


def _go_term_label(xref):
    """Return the ``aspect:name`` label (e.g. ``C:nucleus``) of a GO cross-reference."""
    properties = xref.get("properties") or []
    for prop in properties:
        if prop.get("key") == "GoTerm":
            return prop.get("value") or ""
    # No explicit GoTerm property: fall back to the first property, which is
    # the position this function has always read the label from.
    if properties:
        return properties[0].get("value") or ""
    return ""


def get_uniprot_data(protein_id):
    """Extract location, GO, domain and PTM annotations for a protein from UniProt.

    Args:
        protein_id: UniProt accession, or an entry name containing ``_HUMAN``.
            Entry names are first resolved to a primary accession through the
            UniProt search endpoint.

    Returns:
        A dict with the keys ``locations``, ``go_terms``, ``domains`` and
        ``ptms``. Each value is a sorted list of unique, non-empty strings.
        All four lists are empty when the entry cannot be resolved or fetched.

        * ``locations``: subcellular-location comments, features typed
          ``SUBCELLULAR LOCATION``, location-like keywords, and the names of
          GO cellular-component terms (without the ``C:`` aspect prefix).
        * ``go_terms``: lower-cased GO term labels, e.g. ``c:nucleus``.
        * ``domains``: InterPro cross-reference IDs.
        * ``ptms``: descriptions of features whose type or description matches
          the PTM marker lists, plus PTM-like keywords.
    """
    accession = protein_id
    if '_HUMAN' in protein_id:
        url = f"https://rest.uniprot.org/uniprotkb/search?query={protein_id}&fields=accession&format=json"
        r = robust_get(url)
        if not r:
            return _empty_uniprot_annotations()
        hits = r.json().get('results')
        if not hits:
            return _empty_uniprot_annotations()
        accession = hits[0]['primaryAccession']

    response = robust_get(f"https://rest.uniprot.org/uniprotkb/{accession}.json")
    if not response:
        print(f"Failed to fetch UniProt data for {protein_id}")
        return _empty_uniprot_annotations()

    data = response.json()
    comments = data.get("comments") or []
    features = data.get("features") or []
    keywords = data.get("keywords") or []
    xrefs = data.get("uniProtKBCrossReferences") or []
    go_xrefs = [xref for xref in xrefs if xref.get("database") == "GO"]
    # Empty labels are dropped: they would otherwise become a meaningless
    # feature column after multi-hot encoding.
    go_labels = [label for label in map(_go_term_label, go_xrefs) if label]

    # --- Subcellular locations ---
    locations = set()
    for comment in comments:
        if comment.get("commentType") != "SUBCELLULAR LOCATION":
            continue
        for loc in comment.get("subcellularLocations") or []:
            place = (loc.get("location") or {}).get("value")
            if place:
                locations.add(place)
    for feature in features:
        if (feature.get("type") or "").upper() == "SUBCELLULAR LOCATION":
            desc = feature.get("description")
            if desc:
                locations.add(desc)
    for kw in keywords:
        text = _uniprot_keyword_text(kw)
        if text and _contains_any(text.lower(), _LOCATION_KEYWORD_MARKERS):
            locations.add(text)
    # GO cellular-component terms: the aspect is the "C:" prefix of the label.
    for label in go_labels:
        if label[:2].upper() == "C:":
            component = label[2:].strip()
            if component:
                locations.add(component)
    # Legacy payload shape in which the aspect is a separate "term" property.
    for xref in go_xrefs:
        for prop in xref.get("properties") or []:
            if prop.get("term") == "C" and prop.get("value"):
                locations.add(prop["value"])

    # --- GO terms ---
    go_terms = {label.lower() for label in go_labels}

    # --- Domains ---
    domains = {
        xref.get("id")
        for xref in xrefs
        if xref.get("database") == "InterPro" and xref.get("id")
    }

    # --- PTMs ---
    ptms = set()
    for feature in features:
        desc = feature.get("description") or ""
        if not desc:
            # Only the description is recorded, so a feature without one
            # contributes nothing.
            continue
        ftype = (feature.get("type") or "").upper()
        if (
            _contains_any(ftype, _PTM_FEATURE_TYPE_MARKERS)
            or _contains_any(desc.lower(), _PTM_DESCRIPTION_MARKERS)
        ):
            ptms.add(desc)
    for kw in keywords:
        text = _uniprot_keyword_text(kw)
        if text and _contains_any(text.lower(), _PTM_KEYWORD_MARKERS):
            ptms.add(text)

    # Sorted so that repeated runs (and their checkpoints) are reproducible.
    return {
        'locations': sorted(locations),
        'go_terms': sorted(go_terms),
        'domains': sorted(domains),
        'ptms': sorted(ptms),
    }

In [None]:
def manual_map_uniprot_to_ensembl(uniprot_id):
    """Map a UniProt identifier to an Ensembl gene ID (``ENSG...``).

    The identifier is tried as given and, if different, in cleaned form
    (isoform suffix ``-N`` and ``_HUMAN`` removed). For each candidate the
    UniProt entry is fetched and three strategies are tried in order:

    1. the ``GeneId`` property of an Ensembl cross-reference,
    2. resolving an Ensembl transcript cross-reference (``ENST...``) to its
       parent gene through the Ensembl REST lookup endpoint,
    3. looking up the primary gene name through the Ensembl symbol endpoint.

    Args:
        uniprot_id: UniProt accession or entry name.

    Returns:
        The first Ensembl gene ID found, or ``None`` if every strategy fails
        for every candidate. An ID found by strategy 1 is returned exactly as
        UniProt reports it.
    """
    ensembl_headers = {"Content-Type": "application/json"}

    def clean_protein_id(protein_id):
        """Drop an isoform suffix (``-N``) and any ``_HUMAN`` marker."""
        return str(protein_id).split('-')[0].replace('_HUMAN', '')

    def get_ensembl_id_from_transcript(transcript_id):
        """Return the parent gene of an Ensembl transcript, or None."""
        # UniProt cross-references may carry a version suffix (e.g. ".9");
        # the lookup is made with the unversioned stable ID.
        stable_id = transcript_id.split('.')[0]
        url = f"https://rest.ensembl.org/lookup/id/{stable_id}?"
        response = robust_get(url, headers=ensembl_headers)
        if response:
            return response.json().get('Parent')
        return None

    def get_ensembl_id_from_gene_name(gene_name):
        """Return the first ENSG ID registered for a human gene symbol, or None."""
        if not gene_name:
            return None
        url = f"https://rest.ensembl.org/xrefs/symbol/homo_sapiens/{gene_name}?"
        response = robust_get(url, headers=ensembl_headers)
        if response:
            for result in response.json() or []:
                gene_id = result.get('id')
                if gene_id and gene_id.startswith('ENSG'):
                    return gene_id
        return None

    # De-duplicate while keeping order: for a plain accession the cleaned ID
    # equals the original, and repeating every request would gain nothing.
    candidates = list(dict.fromkeys([uniprot_id, clean_protein_id(uniprot_id)]))

    for test_id in candidates:
        print(f"\nProcessing {test_id}")
        uniprot_url = f"https://rest.uniprot.org/uniprotkb/{test_id}.json"
        # robust_get adds a timeout, retries and rate-limit back-off, so a
        # transient failure is not recorded as "no mapping".
        response = robust_get(uniprot_url)
        if not response:
            continue
        try:
            data = response.json()
            ensembl_xrefs = [
                xref for xref in data.get("uniProtKBCrossReferences", [])
                if xref.get("database") == "Ensembl"
            ]
            # Method 1: Direct Ensembl cross-reference
            for xref in ensembl_xrefs:
                for prop in xref.get("properties", []):
                    if prop.get("key") == "GeneId":
                        gene_id = prop.get("value")
                        if gene_id and gene_id.startswith('ENSG'):
                            print(f"Found Ensembl gene ID (xref): {gene_id}")
                            return gene_id
            # Method 2: Transcript-to-gene mapping
            for xref in ensembl_xrefs:
                transcript_id = xref.get("id")
                if transcript_id and transcript_id.startswith("ENST"):
                    gene_id = get_ensembl_id_from_transcript(transcript_id)
                    if gene_id and gene_id.startswith('ENSG'):
                        print(f"Found Ensembl gene ID (transcript): {gene_id}")
                        return gene_id
            # Method 3: Gene name lookup
            genes = data.get('genes', [])
            if genes and genes[0].get('geneName'):
                gene_name = genes[0]['geneName'].get('value')
                print(f"Found gene name: {gene_name}")
                gene_id = get_ensembl_id_from_gene_name(gene_name)
                if gene_id and gene_id.startswith('ENSG'):
                    print(f"Found Ensembl gene ID (gene name): {gene_id}")
                    return gene_id
        except (ValueError, KeyError, TypeError, AttributeError, IndexError) as e:
            # Malformed JSON or an unexpected payload shape: report it and
            # move on to the next candidate.
            print(f"Error mapping {test_id} to Ensembl: {str(e)}")
    print(f"No Ensembl Gene ID found for {uniprot_id} after all mapping strategies")
    return None

## 3. Human Protein Atlas Integration

The `get_hpa_brain_ntpm_from_xml()` function extracts brain tissue expression data from HPA XML files, including:
- nTPM values for every brain region where nTPM > 1 (regions at or below 1 are omitted)
- Brain expression summary

In [None]:
def get_hpa_brain_ntpm_from_xml(ensembl_id, sleep_time=1, timeout=30):
    """Fetch human-brain nTPM values and the RNA distribution summary from HPA.

    Args:
        ensembl_id: Ensembl gene ID; a version suffix (``.N``) is ignored.
        sleep_time: Seconds to pause after the request, to throttle
            successive calls.
        timeout: Seconds to wait for the HPA server before the request fails.

    Returns:
        A tuple ``(ntpm_dict, summary)``. ``ntpm_dict`` maps brain region name
        to nTPM and only contains regions with nTPM > 1. ``summary`` is the
        text of the ``rnaDistribution`` element of the ``humanBrain`` block,
        or ``None`` if absent. ``({}, None)`` is returned when the server does
        not answer with HTTP 200.

    Raises:
        requests.RequestException: If the request itself fails (connection
            error, timeout).
    """
    core_ensembl_id = ensembl_id.split('.')[0]  # HPA is queried by the unversioned ID

    url = f"https://www.proteinatlas.org/{core_ensembl_id}.xml"
    try:
        r = requests.get(url, timeout=timeout)
    finally:
        # Throttle after every request, successful or not.
        time.sleep(sleep_time)
    if r.status_code != 200:
        print(f"Failed to fetch XML for {ensembl_id}")
        return {}, None

    soup = BeautifulSoup(r.content, "xml")
    ntpm_dict = {}
    summary = None

    for rna_expr in soup.find_all("rnaExpression"):
        if rna_expr.get("assayType") != "humanBrain":
            continue
        # One <data> element per brain region.
        for data in rna_expr.find_all("data"):
            tissue_tag = data.find("tissue")
            if not tissue_tag:
                continue
            region = tissue_tag.text.strip()
            for level in data.find_all("level"):
                if level.get("type") != "normalizedRNAExpression" or level.get("unitRNA") != "nTPM":
                    continue
                try:
                    ntpm_val = float(level.get("expRNA"))
                except (TypeError, ValueError):
                    # Missing or non-numeric expRNA attribute: skip this level.
                    continue
                if ntpm_val > 1:
                    ntpm_dict[region] = ntpm_val
        summary_tag = rna_expr.find("rnaDistribution")
        if summary_tag:
            summary = summary_tag.text.strip()
        break  # Only the first humanBrain block is used

    return ntpm_dict, summary

## 4. ESM-2 Protein Language Model

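The ESM-2 model (`facebook/esm2_t33_650M_UR50D`) generates one 1280-dimensional embedding per protein sequence by averaging the last hidden states over all non-padding tokens.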

In [None]:
# Pick the compute device first: GPU when CUDA is available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the ESM2 tokenizer and model, placing the model on the chosen device
model_name = "facebook/esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)

def get_sequence_embedding(sequence, tokenizer, model):
    """Return the mask-weighted mean of the model's last hidden states for a sequence.

    The tokenized inputs are moved to the module-level ``device``; the result
    is the first row of the pooled output as a NumPy array.
    """
    encoded = tokenizer(sequence, return_tensors="pt", padding=True)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Inference only: no gradients are needed
    with torch.no_grad():
        model_output = model(**encoded)

    hidden = model_output.last_hidden_state
    mask = encoded['attention_mask']

    # Sum the hidden states of unmasked tokens, then divide by their count
    masked_sum = (hidden * mask.unsqueeze(-1)).sum(dim=1)
    token_counts = mask.sum(dim=1, keepdim=True)
    pooled = masked_sum / token_counts
    return pooled.cpu().numpy()[0]

# Report where the model lives and the size of the embeddings it produces
print("ESM-2 model loaded on {}".format(device))
print("Embedding dimension: 1280")

## 5. Multi-hot Encoding Pipeline

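List-valued annotations (locations, GO terms, domains, PTMs) are converted into binary multi-hot columns, and the per-region nTPM values are log2-transformed.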

In [None]:
def multi_hot(df, col):
    """Expand a column of label lists into one 0/1 indicator column per label.

    Cells that are not lists are treated as having no labels. The returned
    frame shares ``df``'s index and names its columns ``<col>_<label>``.
    """
    label_lists = [cell if isinstance(cell, list) else [] for cell in df[col]]
    binarizer = MultiLabelBinarizer()
    indicator_matrix = binarizer.fit_transform(label_lists)
    column_names = [f"{col}_{label}" for label in binarizer.classes_]
    return pd.DataFrame(indicator_matrix, columns=column_names, index=df.index)

def process_features_to_ml_format(features_df):
    """Convert per-protein feature records into a flat, ML-ready DataFrame.

    Args:
        features_df: Mapping of protein ID to a record with the keys
            ``ensembl_id``, ``hpa_data``, ``uniprot_data`` and ``status``.
            ``hpa_data`` and ``uniprot_data`` are normally dicts; any other
            value (e.g. ``None`` for a protein whose processing failed) is
            treated as an empty record.

    Returns:
        A DataFrame with one row per protein containing ``protein_id``,
        ``ensembl_id``, ``status``, the flattened ``uniprot_*`` / ``hpa_*``
        fields, multi-hot columns for the list-valued UniProt fields,
        ``log2_nTPM_<region>`` columns (0 where the value is missing or not
        positive) and ``hpa_brain_expressed`` encoded as 0/1. For empty input
        an empty DataFrame with the three identifying columns is returned.
    """
    print("Converting to DataFrame...")
    if len(features_df) == 0:
        # Nothing to flatten; the column selections below would raise KeyError.
        return pd.DataFrame(columns=['protein_id', 'ensembl_id', 'status'])

    features_df = pd.DataFrame.from_dict(features_df, orient='index')
    features_df.index.name = 'protein_id'
    features_df = features_df.reset_index()
    print("Created features_df:", features_df.shape)

    def _records(column):
        # json_normalize cannot handle None entries, which is what failed
        # proteins carry; substitute an empty record for them.
        return [rec if isinstance(rec, dict) else {} for rec in features_df[column]]

    print("Flattening uniprot_data and hpa_data...")
    uniprot_df = pd.json_normalize(_records('uniprot_data'))
    uniprot_df.columns = [f'uniprot_{col}' for col in uniprot_df.columns]
    hpa_df = pd.json_normalize(_records('hpa_data'))
    hpa_df.columns = [f'hpa_{col}' for col in hpa_df.columns]
    flat_df = pd.concat([features_df[['protein_id', 'ensembl_id', 'status']], uniprot_df, hpa_df], axis=1)
    print("Flattened DataFrame:", flat_df.shape)

    # --- Multi-hot encode list features ---
    multi_hot_cols = ['uniprot_locations', 'uniprot_go_terms', 'uniprot_domains', 'uniprot_ptms']
    multi_hot_dfs = []
    for col in multi_hot_cols:
        if col in flat_df.columns:
            print(f"Multi-hot encoding {col}...")
            mh = multi_hot(flat_df, col)
            print(f"  {col}: {mh.shape[1]} columns")
            multi_hot_dfs.append(mh)
            flat_df = flat_df.drop(columns=[col])

    if multi_hot_dfs:
        flat_df = pd.concat([flat_df] + multi_hot_dfs, axis=1)
    print("After multi-hot encoding:", flat_df.shape)

    # --- Rename nTPM columns to match training data format ---
    print("Renaming nTPM columns...")
    expression_prefix = 'hpa_all_tissue_expression.'
    ntpm_columns = {}
    for col in flat_df.columns:
        if col.startswith(expression_prefix):
            region = col.split('.', 1)[1]
            # Underscores become spaces; a resulting double space is collapsed.
            region = region.replace('_', ' ').replace('  ', ' ')
            ntpm_columns[col] = f'log2_nTPM_{region}'

    flat_df = flat_df.rename(columns=ntpm_columns)
    print("After renaming nTPM columns:", flat_df.shape)

    # --- Apply log2 transformation to nTPM values ---
    print("Applying log2 transformation to nTPM values...")
    log2_ntpm_cols = [col for col in flat_df.columns if col.startswith('log2_nTPM_')]
    for col in log2_ntpm_cols:
        # Missing and non-positive values map to 0 rather than NaN / -inf.
        flat_df[col] = flat_df[col].apply(lambda x: np.log2(x) if pd.notna(x) and x > 0 else 0)

    print("After log2 transformation:", flat_df.shape)

    # --- Binary encode brain_expressed robustly ---
    if 'hpa_brain_expressed' in flat_df.columns:
        # Missing values count as "not expressed"; anything else is encoded by
        # its truthiness, so non-boolean values such as '' cannot break the cast.
        flat_df['hpa_brain_expressed'] = flat_df['hpa_brain_expressed'].apply(
            lambda flag: int(bool(flag)) if pd.notna(flag) else 0
        )
        print("Binary encoded hpa_brain_expressed.")

    print("Final DataFrame shape:", flat_df.shape)
    print("Memory usage (MB):", flat_df.memory_usage(deep=True).sum() / 1e6)

    return flat_df

In [None]:
def _write_checkpoint(path, payload):
    """Write payload as JSON to path, replacing any previous file atomically."""
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(payload, f, indent=2)
    # Swap the finished file in with a rename, so an interruption during the
    # write cannot leave a truncated checkpoint behind.
    os.replace(tmp_path, path)


def complete_feature_engineering_pipeline(protein_ids, batch_size=1000, checkpoint_interval=100):
    """Collect UniProt, Ensembl and HPA features for a list of proteins.

    Proteins are de-duplicated (keeping first-seen order) and processed in
    batches. Each batch keeps a JSON checkpoint in the working directory that
    is loaded if present, so an interrupted run resumes where it stopped.

    NOTE: checkpoint files are named after batch positions only
    (``checkpoint_<start>_<end>.json``). Entries in an existing file are
    merged into the result even when they belong to proteins that are not in
    ``protein_ids``, so use a separate working directory per input list.

    Args:
        protein_ids: Sized iterable of hashable protein identifiers.
        batch_size: Number of proteins per batch / checkpoint file.
        checkpoint_interval: Save the batch checkpoint after this many newly
            processed proteins.

    Returns:
        Dict mapping ``str(protein_id)`` to a record with the keys
        ``ensembl_id``, ``hpa_data``, ``uniprot_data`` and ``status``. A
        protein whose processing raised gets ``None`` for the three data
        fields and a status starting with ``error:``.

    Raises:
        BaseException: Anything escaping the per-protein handling (including
            ``KeyboardInterrupt``) is re-raised after the batch progress has
            been saved.
    """
    print(f"Starting feature engineering for {len(protein_ids)} proteins...")

    # Order-preserving de-duplication. A set would order the IDs differently
    # on every run, so position-named checkpoints would stop matching.
    unique_ids = list(dict.fromkeys(protein_ids))
    print(f"Loaded DataFrame with {len(unique_ids)} proteins")

    results = {}
    num_proteins = len(unique_ids)

    for batch_start in range(0, num_proteins, batch_size):
        batch_end = min(batch_start + batch_size, num_proteins)
        batch_ids = unique_ids[batch_start:batch_end]
        print(f"\nProcessing batch {batch_start} to {batch_end-1}...")

        # Define checkpoint path
        batch_checkpoint_path = f"checkpoint_{batch_start}_{batch_end-1}.json"
        if os.path.exists(batch_checkpoint_path):
            print(f"Checkpoint for batch {batch_start}-{batch_end-1} exists, loading...")
            with open(batch_checkpoint_path, 'r') as f:
                batch_results = json.load(f)
        else:
            batch_results = {}

        checkpoint_counter = 0

        try:
            for i, protein_id in enumerate(tqdm(batch_ids)):
                if str(protein_id) in batch_results:
                    continue  # Skip already processed

                print(f"Processing {batch_start + i}: {protein_id}")
                try:
                    # Get UniProt data
                    uniprot_data = get_uniprot_data(protein_id)

                    # Get Ensembl mapping
                    ensembl_id = manual_map_uniprot_to_ensembl(protein_id)

                    # Get HPA data
                    ntpm_dict, brain_summary = {}, None
                    if ensembl_id:
                        try:
                            ntpm_dict, brain_summary = get_hpa_brain_ntpm_from_xml(ensembl_id)
                        except Exception as hpa_e:
                            print(f"Error extracting HPA XML for {protein_id}: {hpa_e}")

                    # Expressed if any region passed the nTPM threshold or the
                    # summary reports broad detection. Coerced to a real bool
                    # so the stored value is never None or ''.
                    brain_expressed = bool(ntpm_dict) or bool(
                        brain_summary
                        and brain_summary.lower().startswith(
                            ("detected in all", "detected in many")
                        )
                    )

                    hpa_data = {
                        'brain_expressed': brain_expressed,
                        'all_tissue_expression': ntpm_dict,
                        'brain_expression_summary': brain_summary
                    }

                    status = 'mapped' if ensembl_id else 'no_ensembl_mapping'

                    batch_results[str(protein_id)] = {
                        'ensembl_id': ensembl_id,
                        'hpa_data': hpa_data,
                        'uniprot_data': uniprot_data,
                        'status': status
                    }

                except Exception as inner_e:
                    print(f"Error processing {protein_id}: {inner_e}")
                    batch_results[str(protein_id)] = {
                        'ensembl_id': None,
                        'hpa_data': None,
                        'uniprot_data': None,
                        'status': f'error: {inner_e}'
                    }

                checkpoint_counter += 1
                if checkpoint_counter % checkpoint_interval == 0:
                    print(f"\nSaving checkpoint for batch {batch_start}-{batch_end-1} after {checkpoint_counter} proteins...")
                    _write_checkpoint(batch_checkpoint_path, batch_results)
                    print("Checkpoint saved!")

        except BaseException as e:
            # BaseException so that Ctrl-C also saves the work done so far;
            # the per-protein handler above already absorbs ordinary errors.
            print(f"\nError encountered in batch {batch_start}-{batch_end-1}: {str(e)}")
            print("Saving progress before exit...")
            _write_checkpoint(batch_checkpoint_path, batch_results)
            print("Progress saved!")
            raise

        # Save batch results
        _write_checkpoint(batch_checkpoint_path, batch_results)
        print(f"Batch {batch_start}-{batch_end-1} complete and saved.")

        # Merge into overall results
        results.update(batch_results)

    return results

# Example usage (commented out to avoid running)
# protein_ids = ['P12345', 'Q67890', 'P11111']  # Example protein IDs
# results = complete_feature_engineering_pipeline(protein_ids)
# ml_ready_df = process_features_to_ml_format(results)

## Summary


The pipeline can process thousands of proteins efficiently and generates features suitable for machine learning models.