# ProtSpace: Interactive protein embedding visualization
### About ProtSpace ([github](https://github.com/tsenoner/protspace))
ProtSpace is a tool for interactive visualization of protein embeddings that:
- Converts high-dimensional protein embeddings into 2D/3D visualizations
- Supports multiple dimension reduction methods (PCA, UMAP, t-SNE, PaCMAP)
- Allows annotation-based coloring and shaping of data points
- Integrates protein structure visualization alongside embedding space
- Enables publication-quality exports and sharing of visualization sessions

### Basic Workflow:
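1. Upload protein embeddings (H5 file), or use the provided example data
2. Upload feature annotations (CSV file with an `identifier` column matching the H5 keys)
3. Choose dimension reduction methods and generate the visualization (JSON) file
4. Launch ProtSpace and explore your protein space interactively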

In [None]:
#@title Install Dependencies and Import Libraries (~2min)
%%capture
!pip install -q protspace[frontend]

import os
import requests
import sys
from google.colab import files
from IPython.display import clear_output
from pathlib import Path
from urllib.parse import urlparse

import h5py
import pandas as pd

from protspace.app import ProtSpace

## Data preparation

In [None]:
#@title 📤 Data upload and file analysis system
#@markdown Choose your data source by selecting one of the following options:

class DataAnalysisSystem:
    """Obtain the embedding (H5) and feature (CSV) files and summarize them.

    Files come either from Colab uploads or from GitHub URLs; every file is
    saved under its original name in the current working directory.
    """

    def __init__(self):
        # Extensions are compared lower-cased (see validate_file_extension)
        self.h5_allowed_extensions = {'.h5', '.hdf5', '.hdf'}
        self.csv_allowed_extensions = {'.csv'}

    def validate_file_extension(self, filename, allowed_extensions):
        """Return True if *filename* ends with one of *allowed_extensions*.

        The comparison is case-insensitive. On failure a user-facing message is
        printed and False is returned.
        """
        ext = os.path.splitext(filename)[1].lower()
        if ext not in allowed_extensions:
            # sorted() keeps the message stable regardless of set iteration order
            print(f"❌ Invalid file type. Please upload one of: {', '.join(sorted(allowed_extensions))}")
            return False
        return True

    def save_file(self, content, filename):
        """Write *content* (bytes) to *filename* and return that path."""
        with open(filename, 'wb') as f:
            f.write(content)
        return filename

    def handle_file_upload(self, file_type, allowed_extensions):
        """Prompt for a Colab upload, validate it and save it locally.

        Only the first uploaded file is used.

        Returns:
            ``(filepath, filename)`` on success, ``(None, None)`` if nothing was
            uploaded or the extension is not allowed.
        """
        print(f"\n=== Upload {file_type} ===")
        print(f"Allowed extensions: {', '.join(sorted(allowed_extensions))}")

        uploaded = files.upload()
        if not uploaded:
            return None, None

        filename = next(iter(uploaded))
        content = uploaded[filename]

        if not self.validate_file_extension(filename, allowed_extensions):
            return None, None

        # Save the uploaded file locally
        filepath = self.save_file(content, filename)
        print(f"✅ File '{filename}' uploaded and saved")
        return filepath, filename

    def get_github_filename(self, url):
        """Return the last path component of *url* (the file's original name)."""
        return os.path.basename(urlparse(url).path)

    @staticmethod
    def to_raw_github_url(url):
        """Convert a github.com ``/blob/`` URL into its raw.githubusercontent.com form.

        Only the host and the ``blob`` segment that follows ``/<owner>/<repo>``
        are rewritten, so owners, repositories or directories that happen to be
        named "blob" are preserved. Non-github.com URLs are returned unchanged.
        """
        parsed = urlparse(url)
        if parsed.netloc not in ('github.com', 'www.github.com'):
            return url
        parts = parsed.path.split('/')
        # Blob URLs look like /<owner>/<repo>/blob/<ref>/<path>; drop the 'blob' segment
        if len(parts) > 3 and parts[3] == 'blob':
            del parts[3]
        return parsed._replace(netloc='raw.githubusercontent.com', path='/'.join(parts)).geturl()

    def download_and_save_github_file(self, url, timeout=30):
        """Download a GitHub file and save it under its original name.

        Args:
            url: A github.com ``/blob/`` URL (an already-raw URL is used as-is).
            timeout: Seconds to wait for the server before giving up.

        Returns:
            The local path of the saved file.

        Raises:
            ValueError: If no file name can be derived from *url*.
            requests.HTTPError: If the download returns an error status.
        """
        filename = self.get_github_filename(url)
        if not filename:
            raise ValueError(f"Could not determine a file name from URL: {url}")

        response = requests.get(self.to_raw_github_url(url), timeout=timeout)
        response.raise_for_status()
        return self.save_file(response.content, filename)

    def analyze_data_structure(self, h5_filepath, csv_filepath, id_column='identifier'):
        """Print a summary of both files and how well their identifiers line up.

        Args:
            h5_filepath: Local path to the embedding file; each top-level key is
                counted as one embedding.
            csv_filepath: Local path to the feature CSV.
            id_column: CSV column whose values are matched against the H5 keys.

        Raises:
            ValueError: If the CSV lacks *id_column* or the H5 file has no keys.
        """
        df = pd.read_csv(csv_filepath)
        if id_column not in df.columns:
            raise ValueError(
                f"CSV file '{csv_filepath}' has no '{id_column}' column "
                f"(found: {', '.join(map(str, df.columns))})"
            )

        with h5py.File(h5_filepath, 'r') as hdf:
            keys = list(hdf.keys())
            if not keys:
                raise ValueError(f"H5 file '{h5_filepath}' contains no embeddings")
            first_embedding = hdf[keys[0]]
            # Read the metadata while the file is still open
            embedding_shape = first_embedding.shape
            embedding_dtype = first_embedding.dtype

        # H5 keys are strings, but pandas may parse numeric-looking IDs as numbers;
        # compare on strings and ignore missing identifiers.
        h5_ids = set(keys)
        csv_ids = set(df[id_column].dropna().astype(str))
        matching_ids = h5_ids & csv_ids

        print("\n=== Data Analysis Results ===")
        print(f"\nCSV File ({csv_filepath}):")
        print(f"- Entries: {len(df):,}")
        print(f"- Columns: {', '.join(map(str, df.columns))}")

        print(f"\nH5 File ({h5_filepath}):")
        print(f"- Embeddings: {len(keys):,}")
        print(f"- Dimensions: {embedding_shape}")
        print(f"- Data type: {embedding_dtype}")

        print("\nAlignment:")
        print(f"- Matching entries: {len(matching_ids):,}")
        print(f"- Only in CSV: {len(csv_ids - h5_ids):,}")
        print(f"- Only in H5: {len(h5_ids - csv_ids):,}")

data_source = "Use example data from GitHub" #@param ["Use example data from GitHub", "Upload my own files"]

# Obtain local copies of the embedding (H5) and feature (CSV) files, either by
# prompting for uploads or by downloading the example dataset, then print a
# structural summary of both. Failures are reported instead of raised.
try:
    analyzer = DataAnalysisSystem()

    if data_source != "Use example data from GitHub":
        # Re-prompt until an upload with an accepted extension succeeds
        h5_filepath, h5_filename = None, None
        while h5_filepath is None:
            h5_filepath, h5_filename = analyzer.handle_file_upload(
                "Embedding File", analyzer.h5_allowed_extensions
            )

        csv_filepath, csv_filename = None, None
        while csv_filepath is None:
            csv_filepath, csv_filename = analyzer.handle_file_upload(
                "Feature File", analyzer.csv_allowed_extensions
            )
    else:
        print("Downloading example files...")

        # Example dataset hosted in the ProtSpace repository
        h5_url = "https://github.com/tsenoner/protspace/blob/main/data/Pla2g2/embs/esm2_3b.h5"
        csv_url = "https://github.com/tsenoner/protspace/blob/main/data/Pla2g2/Pla2g2.csv"

        h5_filepath = analyzer.download_and_save_github_file(h5_url)
        csv_filepath = analyzer.download_and_save_github_file(csv_url)
        print("✅ Downloads complete and saved locally")

    # Summarize both files from their local copies
    analyzer.analyze_data_structure(h5_filepath, csv_filepath)

except Exception as err:
    print(f"\n❌ Error: {err!s}")
    print("Please try again or write a GitHub issue if the problem persists.")

In [None]:
#@title 📂 File Paths
#@markdown ### Enter paths to your files

#@markdown #### Path to embedding file (.h5/.hdf5/.hdf):
embedding_file = "" #@param {type:"string"}

#@markdown #### Path to feature file (.csv):
feature_file = "" #@param {type:"string"}

# Pasted paths often carry stray whitespace.
embedding_file = embedding_file.strip()
feature_file = feature_file.strip()

# Empty fields fall back to the files saved by the data-upload cell, if it ran
# (mirroring how the launch cell falls back to `output_file`).
if not embedding_file and globals().get('h5_filepath'):
    embedding_file = str(h5_filepath)
    print(f"ℹ️ Using embedding file from data preparation: {embedding_file}")
if not feature_file and globals().get('csv_filepath'):
    feature_file = str(csv_filepath)
    print(f"ℹ️ Using feature file from data preparation: {feature_file}")

# Validate files (extension checks are case-insensitive, like the upload cell)
valid_files = True
if not embedding_file:
    print("⚠️ Please enter path to embedding file")
    valid_files = False
elif not os.path.isfile(embedding_file):
    print(f"⚠️ Embedding file not found: {embedding_file}")
    valid_files = False
elif not embedding_file.lower().endswith(('.h5', '.hdf5', '.hdf')):
    print("⚠️ Embedding file must be .h5, .hdf5, or .hdf format")
    valid_files = False

if not feature_file:
    print("⚠️ Please enter path to feature file")
    valid_files = False
elif not os.path.isfile(feature_file):
    print(f"⚠️ Feature file not found: {feature_file}")
    valid_files = False
elif not feature_file.lower().endswith('.csv'):
    print("⚠️ Feature file must be .csv format")
    valid_files = False

if valid_files:
    print("✅ Files validated successfully")
    # The JSON is written next to the embedding file, e.g. data/emb.h5 -> data/emb.json
    output_file = str(Path(embedding_file).with_suffix('.json'))

In [None]:
#@title 🔧 Configure Visualization
#@markdown ### Choose dimension reduction methods:
#@markdown - **PCA**: Fast, linear reduction
#@markdown - **UMAP**: Preserves global and local structure
#@markdown - **t-SNE**: Focuses on local structure
#@markdown - **PaCMAP**: Balances global and local structure

use_pca = True #@param {type:"boolean"}
use_umap = False #@param {type:"boolean"}
use_tsne = False #@param {type:"boolean"}
use_pacmap = False #@param {type:"boolean"}

#@markdown ### Choose dimensions:
dimensions = "2D only" #@param ["2D only", "3D only", "2D and 3D"]

# Method names carry the target dimensionality as a suffix (e.g. "pca2", "umap3")
# and are passed to protspace-json via --methods.
methods = []

def add_method(use_method, method_name):
    """Append `method_name` with the selected dimension suffix(es) to `methods`.

    Does nothing when `use_method` is falsy. Any `dimensions` value other than
    "2D only" / "3D only" adds both the 2D and the 3D variant.
    """
    if not use_method:
        return
    if dimensions == "2D only":
        suffixes = ("2",)
    elif dimensions == "3D only":
        suffixes = ("3",)
    else:
        suffixes = ("2", "3")
    methods.extend(f"{method_name}{suffix}" for suffix in suffixes)

add_method(use_pca, "pca")
add_method(use_umap, "umap")
add_method(use_tsne, "tsne")
add_method(use_pacmap, "pacmap")

if not methods:
    print("⚠️ No dimension reduction method selected. Please enable at least one method.")

methods_str = " ".join(methods)
# Reset here so parameters from an earlier run of the optional
# "Advanced Parameters" cell are not silently reused.
params_str = ""

## Advanced settings (optional)

In [None]:
#@title ## Advanced Parameters

#@markdown #### UMAP Parameters:
if use_umap:
    #@markdown - Number of neighbors influences locality preservation
    umap_n_neighbors = 50 #@param {type:"slider", min:2, max:200, step:1}
    #@markdown - Minimum distance between points
    umap_min_dist = 0.5 #@param {type:"slider", min:0.0, max:1.0, step:0.01}
    #@markdown - Distance metric
    umap_metric = "euclidean" #@param ["euclidean", "cosine"]

# @markdown ---
#@markdown #### t-SNE Parameters:
if use_tsne:
    #@markdown - Perplexity balances local and global structure
    tsne_perplexity = 30 #@param {type:"slider", min:5, max:100, step:5}
    #@markdown - Learning rate influences optimization
    tsne_learning_rate = 200 #@param {type:"number"}

# @markdown ---
#@markdown #### PaCMAP Parameters:
if use_pacmap:
    #@markdown - Number of neighbors
    pacmap_n_neighbors = 25 #@param {type:"slider", min:2, max:100, step:1}
    #@markdown - MN ratio (Mid-Near pairs ratio): Controls local structure preservation
    #@markdown   - Higher values (→1.0): Better preserves local structure
    #@markdown   - Lower values (→0.1): Allows more global structure influence
    pacmap_mn_ratio = 0.5 #@param {type:"slider", min:0.1, max:1.0, step:0.1}
    #@markdown - FP ratio (Further Pairs ratio): Controls global structure preservation
    #@markdown   - Higher values (→5.0): Better preserves global structure, more separation between clusters
    #@markdown   - Lower values (→0.1): Focuses more on local relationships
    pacmap_fp_ratio = 2.0 #@param {type:"slider", min:0.1, max:5.0, step:0.1}
    #@markdown
    #@markdown Recommended combinations:
    #@markdown - Balanced view: MN=0.5, FP=2.0
    #@markdown - Local focus: MN=0.8, FP=1.0
    #@markdown - Global focus: MN=0.3, FP=3.0

# Build the extra CLI flags appended after --methods in the protspace-json call
params = []

if use_umap:
    params.extend([
        f"--n_neighbors {umap_n_neighbors}",
        f"--min_dist {umap_min_dist}",
        f"--metric {umap_metric}"
    ])

if use_tsne:
    params.extend([
        f"--perplexity {tsne_perplexity}",
        f"--learning_rate {tsne_learning_rate}"
    ])

if use_pacmap:
    params.extend([
        f"--n_neighbors {pacmap_n_neighbors}",
        f"--mn_ratio {pacmap_mn_ratio}",
        f"--fp_ratio {pacmap_fp_ratio}"
    ])

# UMAP and PaCMAP share the --n_neighbors flag; when both are enabled the flag
# is passed twice, so the CLI will most likely honour only one of the values.
if use_umap and use_pacmap and umap_n_neighbors != pacmap_n_neighbors:
    print(
        "⚠️ UMAP and PaCMAP both set --n_neighbors "
        f"({umap_n_neighbors} vs {pacmap_n_neighbors}); only one value is likely to take effect. "
        "Use the same value or run the methods separately."
    )

params_str = " ".join(params)

## 📊 Generate JSON File

In [None]:
#@title Generate the visualization data file
import shlex
import subprocess

if not valid_files:
    print("⚠️ Please fix file path issues before continuing")
elif not methods_str:
    print("⚠️ Please select at least one dimension reduction method before continuing")
else:
    print("Generating visualization data...")
    # Build an argument list instead of a shell string so paths containing
    # spaces or shell metacharacters are passed through unchanged.
    command = [
        "protspace-json",
        "-i", embedding_file,
        "-m", feature_file,
        "-o", output_file,
        "--methods", *methods_str.split(),
        *shlex.split(params_str),
    ]
    try:
        # Stream the tool's combined output into the notebook while it runs
        with subprocess.Popen(
            command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
        ) as process:
            for line in process.stdout:
                print(line, end="")
        return_code = process.returncode
    except FileNotFoundError:
        return_code = None
        print("❌ 'protspace-json' was not found. Please re-run the installation cell.")

    # Only report success when the command actually succeeded
    if return_code == 0:
        print(f"✅ JSON file saved as: {output_file}")
    elif return_code is not None:
        print(f"❌ protspace-json failed (exit code {return_code}); see the output above.")

## 🚀 Launch ProtSpace

In [None]:
#@title 🌟 Visualization Settings and Launch {display-mode: "form"}
#@markdown ### Configure and launch ProtSpace visualization

#@markdown #### Path to JSON file:
json_file = "" #@param {type:"string"}
# An empty path makes launch_protspace() fall back to `output_file`
# (set by the File Paths cell), when that variable exists.

#@markdown #### Visualization Height (pixels):
jupyter_height = 800 #@param {type:"slider", min:400, max:1200, step:50}

#@markdown #### Jupyter Display Mode:
jupyter_mode = "inline" #@param ["inline", "external"]

# os, sys and clear_output repeat imports from the first cell so this cell can
# run on its own; display and HTML render the status boxes below.
import os
import sys
from IPython.display import display, HTML, clear_output

def build_launch_info_html(json_file, height, mode):
    """Return the HTML summary box shown above the dashboard.

    User-supplied values are HTML-escaped so file names containing ``<``, ``>``
    or ``&`` render literally instead of breaking the markup.
    """
    from html import escape

    return f"""
    <div style="background-color: #666666; padding: 10px; margin-bottom: 20px; border-radius: 5px;">
        <p style="margin: 0;"><strong>📊 Launching ProtSpace with:</strong></p>
        <ul style="margin: 10px 0;">
            <li>File: {escape(os.path.basename(json_file))}</li>
            <li>Height: {escape(str(height))}px</li>
            <li>Mode: {escape(str(mode))}</li>
        </ul>
    </div>
    """

def show_launch_info(json_file, height, mode):
    """Display launch information in a formatted box separate from the dashboard."""
    display(HTML(build_launch_info_html(json_file, height, mode)))

def launch_protspace(json_file, height=800, mode="inline"):
    """
    Launch ProtSpace with customized visualization settings.

    Args:
        json_file: Path to the JSON file. If empty, the notebook-global
            ``output_file`` is used when it is defined.
        height: Height of the visualization in pixels
        mode: Jupyter display mode ('inline', 'external', 'jupyterlab', 'tab')

    Returns:
        True if the app was started; False if no usable JSON file was found or
        launching raised an exception (an error box is displayed in both cases).
    """
    from contextlib import redirect_stderr, redirect_stdout
    from html import escape

    def _show_error(message):
        # Red status box; callers escape any user/exception text they embed
        display(HTML(f'<div style="color: #dc3545; padding: 10px;">{message}</div>'))

    json_file = (json_file or "").strip()

    # Use generated JSON file if no other file specified
    if not json_file and 'output_file' in globals():
        json_file = output_file

    if not json_file:
        _show_error('⚠️ No JSON file specified. Please enter a file path or ensure a default output file exists.')
        return False

    if not os.path.exists(json_file):
        _show_error(f'⚠️ JSON file not found: {escape(json_file)}')
        return False

    # Clear any previous output
    clear_output(wait=True)

    # Show launch information
    show_launch_info(json_file, height, mode)

    try:
        # Silence start-up output while the app is created and started; the
        # devnull handle is closed and stdout/stderr are restored when the
        # `with` block exits, before any error is reported.
        with open(os.devnull, 'w') as devnull, redirect_stdout(devnull), redirect_stderr(devnull):
            protspace = ProtSpace(default_json_file=json_file)
            app = protspace.create_app()
            app.run(
                jupyter_mode=mode,
                jupyter_height=height,  # For inline mode, apply height restriction
                jupyter_width='100%',
                dev_tools_silence_routes_logging=True,
                dev_tools_prune_errors=True,
            )
    except Exception as e:
        _show_error(f'❌ Error launching ProtSpace: {escape(str(e))}')
        return False
    return True

# Launch ProtSpace with the configured settings. The result is bound to a name
# so the cell does not echo a bare True/False below the dashboard.
launch_succeeded = launch_protspace(json_file, height=jupyter_height, mode=jupyter_mode)