# 🔬 Interactive Pfam & Clan Explorer

Welcome! This notebook lets you explore protein families (Pfam) and clans by visualizing their protein embeddings with [ProtSpace](https://github.com/tsenoner/protspace).

**Features:**

1. **Setup:** Install libraries and set up environment
2. **Download:** Get necessary data files
3. **Select:** Choose up to 10 Pfam families or clans
4. **Extract:** Pull the selected proteins' embeddings and build a metadata table
5. **Generate:** Create visualization with PCA, UMAP, or PaCMAP
6. **Visualize:** Launch ProtSpace to explore the embedding space interactively

**Usage:**

- Run cells sequentially (▶️ or Shift+Enter)
- Code is hidden by default for clarity. Double-click cell titles to view code if needed.
- Follow instructions in each step


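Step 1 installs ProtSpace only when `importlib.util.find_spec` cannot locate it. The same check-then-install pattern can be wrapped in a small helper. The sketch below is an illustration, not part of the notebook; `ensure_installed` is a hypothetical name, and it assumes `pip` is available for the running interpreter.

```python
import importlib
import importlib.util
import subprocess
import sys


def ensure_installed(module_name: str, pip_spec: str) -> bool:
    """Install `pip_spec` with pip if `module_name` cannot be found.

    Returns True if an installation was performed, False if the module
    was already importable. Raises CalledProcessError if pip fails.
    """
    if importlib.util.find_spec(module_name) is not None:
        return False  # Already importable: nothing to do
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", pip_spec],
        check=True,  # Surface pip failures instead of failing later on import
    )
    # Refresh import-system caches so the new package is visible without a restart
    importlib.invalidate_caches()
    return True
```

For example, `ensure_installed("protspace", "protspace[frontend]")` would install the PyPI package named in Step 1's commented-out command. The `importlib` documentation recommends calling `invalidate_caches()` when modules are created after the interpreter has started.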
In [None]:
# @title 1. 📦 Setup Environment (~ 2 min) { display-mode: "form" }
# @markdown Run this cell to install required libraries and create data directories.
# @markdown Output is hidden, but a success message will appear upon completion.
# ---

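# --- Install ProtSpace (pinned commit) if missing, then import it ---
import importlib
import importlib.util
import subprocess
import sys

if importlib.util.find_spec("protspace") is None:
    subprocess.run(
        # PyPI alternative: [sys.executable, "-m", "pip", "install", "-q", "protspace[frontend]"],
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "-q",
            "git+https://github.com/tsenoner/protspace.git@b54e0b3b248d360729c5434977fb36715765039d#egg=protspace[frontend]",
        ],
        check=True,  # Fail here if pip fails, instead of with an ImportError below
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    # Let the running interpreter see the newly installed package
    importlib.invalidate_caches()
from protspace import ProtSpace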

# --- Other Imports ---
import base64
import csv
import json
import os
import traceback
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import pandas as pd
import polars as pl
from IPython.display import display, clear_output
from matplotlib.colors import to_hex
from tqdm.notebook import tqdm
from ipywidgets import (
    Accordion,
    Button,
    Checkbox,
    Dropdown,
    FloatSlider,
    IntSlider,
    Layout,
    HBox,
    HTML,
    Output,
    Select,
    Text,
    VBox,
    widgets,
)

# --- Create Data Directories using Pathlib ---
parent_data_dir = Path("../data")
if parent_data_dir.exists():
    data_dir = parent_data_dir
else:
    data_dir = Path("data")
raw_data_dir = data_dir / "raw"
explore_data_dir = data_dir / "explore"
raw_data_dir.mkdir(parents=True, exist_ok=True)
explore_data_dir.mkdir(parents=True, exist_ok=True)


# --- Global variable to store path of the latest generated/styled JSON ---
generated_path = None

print("✅ Setup complete. Libraries installed and directories created.")

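Step 2 shells out to `wget`, which is not installed everywhere. A standard-library alternative is possible; the sketch below is not what the notebook uses, and `download_file_urllib` is a hypothetical helper. It streams the response with `urllib.request` into a temporary file in the target directory and renames it only on success, so an interrupted download never leaves a partial file at the final path.

```python
import shutil
import tempfile
import urllib.request
from pathlib import Path


def download_file_urllib(url: str, output_path: Path) -> bool:
    """Download `url` to `output_path`; return True on success or if it already exists."""
    if output_path.exists() and output_path.stat().st_size > 0:
        return True  # Same skip-if-present behaviour as the wget helper
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Stage the download next to the target so the final rename stays on one
    # filesystem and is therefore atomic
    fd, tmp_name = tempfile.mkstemp(dir=output_path.parent, suffix=".part")
    tmp_path = Path(tmp_name)
    try:
        # Open the temp file first so its descriptor is closed even if urlopen fails
        with open(fd, "wb") as out, urllib.request.urlopen(url) as response:
            shutil.copyfileobj(response, out)  # Stream in chunks, never the whole file
        if tmp_path.stat().st_size == 0:
            raise ValueError("downloaded file is empty")
        tmp_path.replace(output_path)
        return True
    except Exception as exc:  # URLError, OSError, ValueError, ...
        print(f"❌ Download of {url} failed: {exc}")
        tmp_path.unlink(missing_ok=True)
        return False
```

Because it reads any URL that `urlopen` understands, the helper can be exercised offline with a `file://` URI (for example `Path(...).as_uri()`).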
In [None]:
# @title 2. 📂 Download Data Files (~ 1 min) { display-mode: "form" }
# @markdown This cell downloads these required files if they don't exist:
# @markdown - UniProt-Pfam Mapping
# @markdown - Pfam Clan Information
# @markdown - UniProt Swiss-Prot Embeddings (large file)
# ---

# Assumes raw_data_dir and explore_data_dir are defined in Cell 1
pfam_file_path = raw_data_dir / "uniprot_pfam.tsv"
sprot_file_path = raw_data_dir / "uniprot_sprot.h5"
pfam_clans_path = raw_data_dir / "Pfam-A.clans.tsv"


# Download a file with wget; its progress/log output (stderr) is captured and shown only on failure
def download_file_wget(url, output_path: Path, description):
    if not output_path.exists():
        print(f"⏳ Downloading {description}...")
        try:
            # Download with wget (must be installed and on PATH)
            process = subprocess.run(
                [
                    "wget",
                    url,
                    "-O",
                    str(output_path),
                ],
                check=True,
                # Keep stderr (wget's progress/log output) for error messages; discard stdout
                stderr=subprocess.PIPE,
                stdout=subprocess.DEVNULL,
                text=True,
                encoding="utf-8",  # Specify encoding
            )
            # Check file existence as primary success indicator after command runs
            if output_path.exists() and output_path.stat().st_size > 0:
                print(f"✅ Successfully downloaded {description}")
                return True
            else:
                # Provide more specific error based on stderr if available
                err_msg = process.stderr if process.stderr else "Unknown reason."
                print(
                    f"❌ Error downloading {description}: 'wget' completed but file not found or empty."
                )
                print(f"Stderr: {err_msg}")
                # Attempt cleanup if file exists but is empty or invalid
                if output_path.exists():
                    try:
                        output_path.unlink()
                    except OSError:
                        pass
                return False

        except FileNotFoundError:
            print(
                "❌ Error: 'wget' command not found. Please ensure wget is installed and in your system's PATH."
            )
            return False
        except subprocess.CalledProcessError as e:
            print(
                f"❌ Error downloading {description} with wget (return code {e.returncode}):"
            )
            print(f"Stderr:\n{e.stderr}")
            # Clean up potentially incomplete file
            if output_path.exists():
                try:
                    output_path.unlink()
                except OSError:
                    pass
            return False
        except Exception as e:
            print(f"❌ An unexpected error occurred during download: {e}")
            return False
    else:
        print(f"✅ {description} already exists locally.")
        return True


# List of files to download
files_to_download = [
    {
        "url": "https://nextcloud.in.tum.de/index.php/s/9q3kefKeND8rkQ8/download",
        "path": pfam_file_path,
        "description": "UniProt-Pfam mapping",
    },
    {
        "url": "https://nextcloud.in.tum.de/index.php/s/P2icGkHy5Msgw2y/download",
        "path": pfam_clans_path,
        "description": "Pfam clans information",
    },
    {
        "url": "https://nextcloud.in.tum.de/index.php/s/CsBBsz4cJ3zwtHa/download",
        "path": sprot_file_path,
        "description": "UniProt Swiss-Prot embeddings (large file)",
    },
]

# Download all files
all_successful = True
print("--- Starting Data Download ---")
for file_info in files_to_download:
    file_path_obj = Path(file_info["path"])
    success = download_file_wget(
        file_info["url"], file_path_obj, file_info["description"]
    )
    if not success:
        all_successful = False

# Verify completion and show file sizes
if all_successful:
    print("\n--- Data Verification ---")
    total_size_mb = 0
    all_files_found = True
    for file_info in files_to_download:
        path_obj = Path(file_info["path"])
        if path_obj.exists():
            try:
                size_bytes = path_obj.stat().st_size
                if size_bytes > 0:
                    size_mb = size_bytes / (1024 * 1024)
                    print(f"  • {path_obj.name}: {size_mb:.2f} MB")
                    total_size_mb += size_mb
                else:
                    print(f"  • {path_obj.name}: ⚠️ File exists but is empty!")
                    all_files_found = False
            except OSError as e:
                print(f"  • {path_obj.name}: ⚠️ Error checking file size: {e}")
                all_files_found = False
        else:
            print(f"  • {path_obj.name}: ⚠️ File not found after download attempt!")
            all_files_found = False

    if all_files_found:
        print(
            f"\n✅ All required data files are present and valid. Total size: {total_size_mb:.2f} MB."
        )
    else:
        print(
            "\n⚠️ Some critical data files are missing or invalid. Please check download errors above and try running this cell again."
        )
else:
    print("\n⚠️ Some files could not be downloaded. Please check the errors above.")
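
Step 3's `ProteinPfamClanLookup` uses Polars to split the semicolon-separated `Pfam` column of `uniprot_pfam.tsv` (columns `Entry` and `Pfam`), explode it into one row per protein/family pair, and index the pairs in both directions. For readers less familiar with Polars, here is a plain-Python sketch of the same transformation; `build_protein_pfam_maps` is a hypothetical helper and the toy rows are made up.

```python
import csv
import io
from collections import defaultdict


def build_protein_pfam_maps(tsv_text: str):
    """Return (protein_to_pfams, pfam_to_proteins) from an Entry/Pfam TSV."""
    protein_to_pfams = defaultdict(set)
    pfam_to_proteins = defaultdict(set)
    reader = csv.DictReader(io.StringIO(tsv_text), delimiter="\t")
    for row in reader:
        # Mirrors fill_null("") -> strip_chars(";") -> split(";") -> drop empty strings
        pfam_field = (row.get("Pfam") or "").strip(";")
        for pfam in pfam_field.split(";"):
            if pfam:
                protein_to_pfams[row["Entry"]].add(pfam)
                pfam_to_proteins[pfam].add(row["Entry"])
    return dict(protein_to_pfams), dict(pfam_to_proteins)
```

Proteins with an empty `Pfam` field end up in neither dictionary, matching the Polars version's `filter(pl.col("Pfam_list") != "")`.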

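Step 3 reads `Pfam-A.clans.tsv` without a header as five tab-separated columns (Pfam ID, clan ID, clan name, Pfam name, description) and drops rows whose clan ID is empty. A clan's protein set is then the union over its member families, as in `get_proteins_for_clan`. A standard-library sketch of both steps follows; the function names are hypothetical and the IDs in any test data are made up.

```python
import csv
import io
from collections import defaultdict


def build_clan_maps(clans_tsv_text: str):
    """Parse headerless Pfam-A.clans.tsv rows into (pfam_to_clan, clan_to_pfams)."""
    pfam_to_clan = {}
    clan_to_pfams = defaultdict(set)
    # QUOTE_NONE: quote characters in free-text descriptions are kept as plain text
    reader = csv.reader(io.StringIO(clans_tsv_text), delimiter="\t", quoting=csv.QUOTE_NONE)
    for row in reader:
        if len(row) < 2:
            continue  # Skip blank or malformed lines
        pfam_id, clan_id = row[0], row[1]
        if not clan_id:
            continue  # Families without a clan have an empty second column
        pfam_to_clan[pfam_id] = clan_id
        clan_to_pfams[clan_id].add(pfam_id)
    return pfam_to_clan, dict(clan_to_pfams)


def proteins_for_clan(clan_id, clan_to_pfams, pfam_to_proteins):
    """Union of the proteins annotated with any Pfam family in the clan."""
    proteins = set()
    for pfam_id in clan_to_pfams.get(clan_id, set()):
        proteins |= pfam_to_proteins.get(pfam_id, set())
    return proteins
```

Because a protein can carry several families of the same clan, the union (not a sum of per-family counts) is what the clan option labels report as "proteins".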
In [None]:
# @title 3. ✨ Initialize Data Lookup & Select Families/Clans { display-mode: "form" }
# @markdown Builds the protein, Pfam family, and clan lookup tables, then lets you pick families or clans to explore.
# @markdown
# @markdown **Quick Guide:**
# @markdown 1. Choose the search type (Pfam Families or Clans)
# @markdown 2. Optionally filter the list by ID
# @markdown 3. Pick an entry from the list *(three Pfam families and three clans are preselected)*
# @markdown 4. Click "Add Selected" to add it (max 10 per mode)
# @markdown 5. Click "Clear All" to reset the selections for the current mode
# @markdown 6. Review your selections in the table below
# ---

# --- ProteinPfamClanLookup Class Definition ---
class ProteinPfamClanLookup:
    """Handles efficient lookups between proteins, Pfam families, and clans."""

    def __init__(self, protein_pfam_path, pfam_clan_path=None):
        protein_pfam_file = Path(protein_pfam_path)
        pfam_clan_file = Path(pfam_clan_path) if pfam_clan_path else None

        # --- Initialize protein-Pfam relationships ---
        if not protein_pfam_file.exists():
            raise FileNotFoundError(
                f"Could not find protein-Pfam mapping file at {protein_pfam_file}. Ensure it was downloaded correctly."
            )
        try:
            schema_overrides = {"Entry": pl.Utf8, "Pfam": pl.Utf8}
            df = pl.read_csv(
                protein_pfam_file, separator="\t", schema_overrides=schema_overrides
            )
        except Exception as e:
            raise RuntimeError(
                f"Could not read protein-Pfam mapping file {protein_pfam_file}. Error: {e}"
            )

        # Process protein-Pfam pairs using Polars
        pairs_df = (
            df.with_columns(pl.col("Pfam").fill_null("").cast(pl.Utf8))
            .with_columns(
                pl.col("Pfam").str.strip_chars(";").str.split(";").alias("Pfam_list")
            )
            .select(["Entry", "Pfam_list"])
            .explode("Pfam_list")
            .filter(pl.col("Pfam_list") != "")
        )

        # Create protein -> {pfam1, pfam2,...}
        protein_groups = pairs_df.group_by("Entry").agg(pl.col("Pfam_list").unique())
        self.protein_to_pfams = {
            row["Entry"]: set(row["Pfam_list"])
            for row in protein_groups.iter_rows(named=True)
        }

        # Create pfam -> {protein1, protein2,...}
        pfam_groups = pairs_df.group_by("Pfam_list").agg(pl.col("Entry").unique())
        self.pfam_to_proteins = {
            row["Pfam_list"]: set(row["Entry"])
            for row in pfam_groups.iter_rows(named=True)
        }

        # --- Initialize Pfam-clan relationships ---
        self.clan_to_pfams = {}
        self.pfam_to_clan = {}
        self.clan_details = {}

        if pfam_clan_file and pfam_clan_file.exists():
            try:
                clan_schema_overrides = {
                    "pfam_id": pl.Utf8,
                    "clan_id": pl.Utf8,
                    "clan_name": pl.Utf8,
                }
                clan_df = pl.read_csv(
                    pfam_clan_file,
                    separator="\t",
                    has_header=False,
                    new_columns=[
                        "pfam_id",
                        "clan_id",
                        "clan_name",
                        "pfam_name",
                        "pfam_description",
                    ],
                    schema_overrides=clan_schema_overrides,
                ).filter(pl.col("clan_id").is_not_null() & (pl.col("clan_id") != ""))

                # Create clan -> {pfam1, pfam2,...}
                clan_groups = clan_df.group_by("clan_id").agg(
                    pl.col("pfam_id").unique().alias("pfams"),
                    pl.first("clan_name").alias("name"),
                )
                self.clan_to_pfams = {
                    row["clan_id"]: set(row["pfams"])
                    for row in clan_groups.iter_rows(named=True)
                }
                self.clan_details = {
                    row["clan_id"]: {
                        "name": row["name"],
                        "pfam_count": len(row["pfams"]),
                    }
                    for row in clan_groups.iter_rows(named=True)
                }

                # Create pfam -> clan dictionary
                self.pfam_to_clan = {
                    row["pfam_id"]: row["clan_id"]
                    for row in clan_df.select(["pfam_id", "clan_id"]).iter_rows(
                        named=True
                    )
                    if row["clan_id"]
                }
            except Exception as e:
                print(
                    f"Warning: Could not read or process Pfam clan file at {pfam_clan_file}. Clan features might be incomplete. Error: {e}"
                )
        elif pfam_clan_file:
            print(
                f"Warning: Pfam clan file specified but not found at {pfam_clan_file}. Clan features will be unavailable."
            )

    # --- Lookup Methods ---
    def get_pfams_for_protein(self, protein):
        return self.protein_to_pfams.get(protein, set())

    def get_proteins_for_pfam(self, pfam):
        return self.pfam_to_proteins.get(pfam, set())

    def get_all_pfams(self):
        return set(self.pfam_to_proteins.keys())

    def get_all_clans(self):
        return list(self.clan_to_pfams.keys())

    def get_pfams_for_clan(self, clan_id):
        return self.clan_to_pfams.get(clan_id, set())

    def get_clan_for_pfam(self, pfam_id):
        return self.pfam_to_clan.get(pfam_id)

    def get_clan_details(self, clan_id):
        return self.clan_details.get(
            clan_id,
            {"name": "Unknown", "pfam_count": len(self.get_pfams_for_clan(clan_id))},
        )

    def get_proteins_for_clan(self, clan_id):
        proteins = set()
        for pfam_id in self.get_pfams_for_clan(clan_id):
            proteins.update(self.get_proteins_for_pfam(pfam_id))
        return proteins

    def get_clans_for_protein(self, protein):
        clans = set()
        for pfam_id in self.get_pfams_for_protein(protein):
            clan_id = self.get_clan_for_pfam(pfam_id)
            if clan_id:
                clans.add(clan_id)
        return clans


# --- Initialize Lookup ---
lookup_status_output = Output()
pfam_options = {}
clan_options = {}
lookup = None
initialization_message = ""  # Store the final message

with lookup_status_output:  # Progress bars and printed warnings go to this (undisplayed) Output widget
    try:
        lookup = ProteinPfamClanLookup(pfam_file_path, pfam_clans_path)

        # Prepare options for Pfam dropdown/selector with counts
        all_pfams = sorted(list(lookup.get_all_pfams()))
        # Count proteins per family; tqdm shows progress while the dict is built
        pfam_proteins_counts = {
            pfam: len(lookup.get_proteins_for_pfam(pfam))
            for pfam in tqdm(
                all_pfams, desc="Loading Pfam options", leave=False, mininterval=0.5
            )
        }
        pfam_options = {
            f"{pfam} ({count} proteins)": pfam
            for pfam, count in pfam_proteins_counts.items()
        }

        # Prepare options for Clan dropdown with counts
        all_clans = sorted(lookup.get_all_clans())
        clan_proteins_counts = {
            clan: len(lookup.get_proteins_for_clan(clan))
            for clan in tqdm(
                all_clans, desc="Loading Clan options", leave=False, mininterval=0.5
            )
        }
        clan_options = {}
        for clan, protein_count in clan_proteins_counts.items():
            details = lookup.get_clan_details(clan)
            pfam_count = details["pfam_count"]
            clan_options[f"{clan} ({pfam_count} Pfams, {protein_count} proteins)"] = (
                clan
            )

        initialization_message = f"✅ Lookup initialized. Found {len(all_pfams)} Pfam families and {len(all_clans)} clans."

    except FileNotFoundError as e:
        initialization_message = f"❌ Error initializing lookup: {e}\n   Please ensure the data files were downloaded successfully in Step 2."
    except Exception as e:
        initialization_message = (
            f"❌ An unexpected error occurred during lookup initialization: {e}"
        )

# --- Define Widgets ---
search_type = Dropdown(
    options=["Pfam Families", "Clans"],
    value="Pfam Families",
    description="Search by:",
    style={"description_width": "initial"},
    layout={"width": "max-content"},
)
search_field = Text(
    placeholder="Filter by ID (e.g., PF00067 or CL0001)...",
    description="Filter:",
    style={"description_width": "initial"},
    layout={"width": "60%"},
)
selector = Select(
    options=[],
    value=None,
    description="Select:",
    rows=10,
    style={"description_width": "initial"},
    layout={"width": "95%"},
)
add_button = Button(description="Add Selected", button_style="info", icon="plus")
clear_button = Button(description="Clear All", button_style="warning", icon="trash")
selected_items_display = HTML(value="<p>No items selected yet.</p>")
selected_items_summary = HTML(value="")
status_display = HTML()

# --- Global state for selections ---
selected_pfams = ["PF00061", "PF00067", "PF00077"]
selected_clans = ["CL0001", "CL0092", "CL0240"]
current_mode = "Pfam Families"


# --- Widget Logic Functions ---
def get_total_selected_proteins():
    """Calculate total unique proteins for the current selection."""
    if not lookup:
        return 0
    total_proteins = set()
    items = selected_pfams if current_mode == "Pfam Families" else selected_clans
    getter = (
        lookup.get_proteins_for_pfam
        if current_mode == "Pfam Families"
        else lookup.get_proteins_for_clan
    )
    for item_id in items:
        total_proteins.update(getter(item_id))
    return len(total_proteins)


def update_display():
    """Update the HTML table showing selected items and the summary."""
    if not lookup:
        return

    target_list = selected_pfams if current_mode == "Pfam Families" else selected_clans
    item_type_plural = "Pfam Families" if current_mode == "Pfam Families" else "Clans"
    limit = 10

    # --- Build DataFrame for HTML Table ---
    if target_list:
        data = []
        if current_mode == "Pfam Families":
            for pfam in target_list:
                proteins = lookup.get_proteins_for_pfam(pfam)
                clan = lookup.get_clan_for_pfam(pfam) or "N/A"
                data.append(
                    {
                        "Pfam Family": pfam,
                        "Proteins": f"{len(proteins):,}",
                        "Clan": clan,
                    }
                )
            df = pd.DataFrame(data)
            display_title = f"Selected Pfam Families ({len(target_list)}/{limit})"
        else:  # Clans mode
            for clan in target_list:
                proteins = lookup.get_proteins_for_clan(clan)
                details = lookup.get_clan_details(clan)
                pfam_count = details["pfam_count"]
                data.append(
                    {
                        "Clan ID": clan,
                        "Pfam Families": pfam_count,
                        "Proteins": f"{len(proteins):,}",
                    }
                )
            df = pd.DataFrame(data)
            display_title = f"Selected Clans ({len(target_list)}/{limit})"

        # Convert DataFrame to HTML
        selected_items_display.value = f"<h4>{display_title}:</h4>{df.to_html(index=False, classes='table table-striped table-sm', border=0, justify='left')}"
    else:
        selected_items_display.value = f"<p>No {item_type_plural.lower()} selected.</p>"

    # --- Update Summary (Total proteins + Warning) ---
    total_proteins = get_total_selected_proteins()
    summary_html = f"<p style='margin-top: 10px;'><b>Total unique proteins in selection: {total_proteins:,}</b></p>"
    if total_proteins > 20000:
        summary_html += '<p style="color: orange;">⚠️ <b>Warning:</b> Selecting over 20,000 proteins may lead to slow performance in subsequent steps (Extraction, Generation, Visualization).</p>'
    selected_items_summary.value = summary_html

    status_display.value = ""


# --- Other Widget Callbacks ---
def update_selector_options():
    if not lookup:
        return
    global current_mode
    search_term = search_field.value.strip().lower()
    options_dict = pfam_options if current_mode == "Pfam Families" else clan_options
    desc = "Select Pfam:" if current_mode == "Pfam Families" else "Select Clan:"

    if search_term:
        filtered_options = {
            k: v for k, v in options_dict.items() if search_term in k.lower()
        }
        selector.options = list(filtered_options.keys())
    else:
        selector.options = list(options_dict.keys())
    selector.description = desc
    selector.value = None  # Reset selection after filtering


def on_search_type_change(change):
    global current_mode  # Ensure we modify the global variable
    if change["new"] != change["old"]:
        current_mode = change["new"]  # Update mode state
        search_field.value = ""
        update_selector_options()
        update_display()  # Update selection list title and summary


def filter_items(change):
    update_selector_options()


def add_selected(button):
    global selected_pfams, selected_clans  # Ensure modification of globals
    selected_key = selector.value
    if not lookup or not selected_key:
        status_display.value = (
            '<p style="color: orange;">⚠️ Please select an item from the list first.</p>'
        )
        return

    options_dict = pfam_options if current_mode == "Pfam Families" else clan_options
    item_id = options_dict[selected_key]

    target_list = selected_pfams if current_mode == "Pfam Families" else selected_clans
    item_type = "Pfam family" if current_mode == "Pfam Families" else "clan"
    item_type_plural = "Pfam families" if current_mode == "Pfam Families" else "clans"
    limit = 10

    if item_id in target_list:
        status_display.value = f'<p style="color: orange;">⚠️ This {item_type} ({item_id}) is already selected.</p>'
        return

    if len(target_list) >= limit:
        status_display.value = f'<p style="color: red;">❌ Maximum limit of {limit} {item_type_plural} reached.</p>'
        return

    target_list.append(item_id)
    status_display.value = (
        f'<p style="color: green;">✅ Added {item_type}: {item_id}</p>'
    )
    update_display()  # Refresh list and summary


def clear_selections(button):
    global selected_pfams, selected_clans  # Ensure modification of globals
    item_type_plural = "Pfam families" if current_mode == "Pfam Families" else "clans"
    cleared = False
    if current_mode == "Pfam Families":
        if selected_pfams:  # Only clear if list is not empty
            selected_pfams = []
            cleared = True
    else:
        if selected_clans:  # Only clear if list is not empty
            selected_clans = []
            cleared = True

    if cleared:
        status_display.value = (
            f'<p style="color: blue;">ℹ️ All selected {item_type_plural} cleared.</p>'
        )
        update_display()


# --- Connect Widgets ---
if lookup:  # Only connect if lookup initialized successfully
    search_type.observe(on_search_type_change, names="value")
    search_field.observe(filter_items, names="value")
    add_button.on_click(add_selected)
    clear_button.on_click(clear_selections)

    # --- Initial Population ---
    update_selector_options()  # Populate selector initially
    update_display()  # Update display with defaults and summary
else:
    status_display.value = "<p style='color: red;'>❌ Lookup initialization failed. Cannot create selection widgets.</p>"

# --- Display Widgets ---
# Display the final initialization message clearly *after* potential tqdm output
print(initialization_message)

display(
    VBox(
        [
            HBox([search_type, search_field]),
            selector,
            HBox([add_button, clear_button]),
            status_display,  # For add/clear/error messages
            HTML("<hr>"),  # Separator
            selected_items_display,  # The HTML table for the list
            selected_items_summary,  # HTML for total count and warning
        ]
    )
)
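
In Pfam mode, Step 4 labels each protein by the first selected family it carries (in selection order) and summarizes its other, unselected families in the `Pfam_Extended` column. Isolated as a pure function (the name `pfam_labels` is hypothetical), the rule looks like this:

```python
def pfam_labels(protein_pfams: set, selected: list):
    """Return (primary, extended) labels, or None if no selected family matches."""
    matches = [pfam for pfam in selected if pfam in protein_pfams]
    if not matches:
        return None
    primary = matches[0]  # First match in *selection* order, not alphabetical
    # Only families outside the selection count as "others"
    num_other = sum(1 for pfam in protein_pfams if pfam not in selected)
    if num_other == 0:
        extended = primary
    elif num_other == 1:
        extended = f"{primary} + 1 other"
    elif num_other == 2:
        extended = f"{primary} + 2 others"
    else:
        extended = f"{primary} + >2 others"
    return primary, extended
```

One consequence of this design: a second *selected* family on the same protein is not counted among the "others"; it is simply hidden behind the primary label, so the order in which families are added affects the coloring.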

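In clan mode, Step 4 precomputes the three most frequent Pfam families per clan among the proteins that actually had embeddings. `collections.Counter.most_common` expresses the same count-and-sort more compactly. This is a sketch with a hypothetical function name, not the notebook's code:

```python
from collections import Counter


def top_pfams_per_clan(selected_clans, saved_proteins, protein_to_pfams, clan_to_pfams, k=3):
    """For each clan, the k Pfam families most frequent among the saved proteins."""
    top = {}
    for clan in selected_clans:
        members = clan_to_pfams.get(clan, set())
        # Count each clan member family once per saved protein that carries it
        counts = Counter(
            pfam
            for protein in saved_proteins
            for pfam in protein_to_pfams.get(protein, set())
            if pfam in members
        )
        top[clan] = [pfam for pfam, _ in counts.most_common(k)]
    return top
```

`most_common` returns ties in first-encountered order; like the notebook's stable `sorted`, that order depends on set iteration, so tied families may swap between runs.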
In [None]:
# @title 4. 📈 Extract Embeddings & Generate Metadata { display-mode: "form" }
# @markdown Extracts protein embeddings for selected Pfam families/clans and creates metadata for ProtSpace.
# @markdown
# @markdown **Process:**
# @markdown 1. Finds proteins for your selection
# @markdown 2. Extracts embeddings from source file
# @markdown 3. Creates metadata CSV with annotations
# @markdown
# @markdown **Note:** Large selections can take a while; progress bars track each stage. If this exact selection was extracted before, the existing files are reused.
# @markdown
# @markdown ---
# @markdown Click to extract:
# ---

# --- Global variables to store results of the latest successful extraction ---
last_extraction_dir: Path | None = None
last_embedding_path: Path | None = None
last_metadata_path: Path | None = None
last_selection_mode: str | None = None  # To track if it was Pfam or Clan
last_selected_items: list | None = None  # To store the specific list used


# --- Helper Functions ---
def get_proteins_for_selection(mode, selected_items_list, lookup_obj):
    """Gets the set of unique protein IDs based on the current selection mode and items."""
    item_type_plural = "Pfam families" if mode == "Pfam Families" else "clans"
    if not lookup_obj:
        return set(), "Lookup system not initialized."
    if not selected_items_list:
        return set(), f"No {item_type_plural} selected."
    all_proteins = set()
    getter = (
        lookup_obj.get_proteins_for_pfam
        if mode == "Pfam Families"
        else lookup_obj.get_proteins_for_clan
    )
    for item_id in selected_items_list:
        all_proteins.update(getter(item_id))
    count = len(all_proteins)
    msg = f"Found {count:,} unique proteins across {len(selected_items_list)} selected {item_type_plural}."
    return all_proteins, msg


def save_embeddings_and_metadata(
    proteins, selection_mode, selected_items_list, lookup_obj, input_h5_path: Path
):
    """
    Extracts embeddings, saves them to a new HDF5 file, and generates the metadata CSV.
    Returns a dictionary with paths and summary, or None on error.
    Updates global variables on success.
    """
    global \
        last_extraction_dir, \
        last_embedding_path, \
        last_metadata_path, \
        last_selection_mode, \
        last_selected_items

    if not lookup_obj:
        print("❌ Error: Lookup object is not available.")
        return None
    if not proteins:
        print("❌ Error: No proteins provided for extraction.")
        return None
    if not input_h5_path.exists():
        print(
            f"❌ Error: Input embedding file not found at {input_h5_path}. Please ensure Step 2 completed."
        )
        return None

    # --- Define Paths ---
    identifier = "_".join(sorted(selected_items_list))
    prefix = "pfam" if selection_mode == "Pfam Families" else "clan"
    # explore_data_dir is assumed defined globally (Cell 1)
    output_dir = explore_data_dir / f"{prefix}_{identifier}"
    output_h5_path = output_dir / "embeddings.h5"
    output_csv_path = output_dir / "metadata.csv"
    output_meta_path = output_dir / "run_metadata.txt"

    # --- Check for Existing Output ---
    if output_dir.exists() and output_h5_path.exists() and output_csv_path.exists():
        print(f"✅ Output for this exact selection already exists in: {output_dir}")
        last_extraction_dir = output_dir
        last_embedding_path = output_h5_path
        last_metadata_path = output_csv_path
        last_selection_mode = selection_mode
        last_selected_items = selected_items_list.copy()
        # -1 means the count was not recorded in run_metadata.txt
        found_count = total_requested_count = -1
        try:
            with open(output_meta_path, "r", encoding="utf-8") as f:
                for line in f:
                    if "Found proteins with embeddings:" in line:
                        found_count = int(line.split(":", 1)[1].strip())
                    if "Total proteins requested:" in line:
                        total_requested_count = int(line.split(":", 1)[1].strip())
        except Exception:
            pass  # Ignore errors reading metadata

        missing_count_str = "N/A"
        if found_count != -1 and total_requested_count != -1:
            missing_count_str = f"{total_requested_count - found_count:,}"
        found_count_str = f"{found_count:,}" if found_count != -1 else "N/A"

        return {
            "output_dir": output_dir,
            "embedding_path": output_h5_path,
            "metadata_path": output_csv_path,
            "total_proteins_requested": len(proteins),  # Current request size
            "found_proteins": found_count_str,
            "missing_proteins": missing_count_str,
            "already_exists": True,
        }

    # --- Create Output Directory ---
    try:
        output_dir.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        print(f"❌ Error creating output directory {output_dir}: {e}")
        return None

    # --- Extract Embeddings ---
    found_count = 0
    missing_proteins_count = 0
    print(f"🔄 Extracting embeddings for {len(proteins):,} proteins...")
    # Use a set to efficiently track proteins we actually save embeddings for
    saved_protein_ids = set()
    try:
        with (
            h5py.File(input_h5_path, "r") as in_file,
            h5py.File(output_h5_path, "w") as out_file,
        ):
            # Iterate directly over the input protein set for potential efficiency
            for protein_id in tqdm(
                proteins, desc="Reading embeddings", unit="protein", mininterval=0.5
            ):
                if protein_id in in_file:
                    embedding = in_file[protein_id][:]
                    out_file.create_dataset(protein_id, data=embedding)
                    found_count += 1
                    saved_protein_ids.add(protein_id)  # Track saved proteins
                else:
                    missing_proteins_count += 1
        print(
            f"➡️ Found embeddings for {found_count:,} out of {len(proteins):,} requested proteins."
        )
        if missing_proteins_count > 0:
            print(f"⚠️ Missing embeddings for {missing_proteins_count:,} proteins.")

    except Exception as e:
        print(f"❌ Error during embedding extraction: {e}")
        if output_h5_path.exists():
            try:
                output_h5_path.unlink()
            except OSError:
                pass
        if output_dir.exists() and not any(output_dir.iterdir()):
            try:
                output_dir.rmdir()
            except OSError:
                pass
        return None

    # --- Generate Metadata CSV ---
    # Use the set of proteins actually saved in the H5 file
    print("🔄 Generating metadata CSV file...")
    try:
        with open(output_csv_path, "w", newline="", encoding="utf-8") as csvfile:
            csv_writer = csv.writer(csvfile)

            if selection_mode == "Pfam Families":
                csv_writer.writerow(
                    ["identifier", "Pfam_Primary", "Pfam_Extended", "Clan"]
                )
                for protein_id in tqdm(
                    saved_protein_ids,
                    desc="Writing Pfam CSV",
                    unit="protein",
                    mininterval=0.5,
                ):
                    all_pfams_for_protein = lookup_obj.get_pfams_for_protein(protein_id)
                    protein_selected_pfams = [
                        pfam
                        for pfam in selected_items_list
                        if pfam in all_pfams_for_protein
                    ]
                    if not protein_selected_pfams:
                        continue  # Should not happen if logic is correct
                    primary_pfam = protein_selected_pfams[0]
                    primary_clan = lookup_obj.get_clan_for_pfam(primary_pfam) or "N/A"
                    other_pfams = [
                        pfam
                        for pfam in all_pfams_for_protein
                        if pfam not in selected_items_list
                    ]
                    num_other = len(other_pfams)
                    if num_other == 0:
                        pfam_extended = primary_pfam
                    elif num_other == 1:
                        pfam_extended = f"{primary_pfam} + 1 other"
                    elif num_other == 2:
                        pfam_extended = f"{primary_pfam} + 2 others"
                    else:
                        pfam_extended = f"{primary_pfam} + >2 others"
                    csv_writer.writerow(
                        [protein_id, primary_pfam, pfam_extended, primary_clan]
                    )

            elif selection_mode == "Clans":
                csv_writer.writerow(["identifier", "Clan_Primary", "Pfam_Label"])
                clan_top_pfams = {}
                # Pre-calculate top Pfams based on *saved* proteins
                for clan in selected_items_list:
                    pfams_in_clan = lookup_obj.get_pfams_for_clan(clan)
                    pfam_counts = {}
                    for protein_id_saved in saved_protein_ids:
                        protein_pfams = lookup_obj.get_pfams_for_protein(
                            protein_id_saved
                        )
                        for pfam in protein_pfams:
                            if pfam in pfams_in_clan:
                                pfam_counts[pfam] = pfam_counts.get(pfam, 0) + 1
                    sorted_pfams = sorted(
                        pfam_counts.items(), key=lambda item: item[1], reverse=True
                    )
                    clan_top_pfams[clan] = [pfam for pfam, _ in sorted_pfams[:3]]

                for protein_id in tqdm(
                    saved_protein_ids,
                    desc="Writing Clan CSV",
                    unit="protein",
                    mininterval=0.5,
                ):
                    all_pfams_for_protein = lookup_obj.get_pfams_for_protein(protein_id)
                    protein_clans_involved = set()
                    pfam_clan_map = {}
                    for pfam in all_pfams_for_protein:
                        clan = lookup_obj.get_clan_for_pfam(pfam)
                        if clan in selected_items_list:
                            protein_clans_involved.add(clan)
                            pfam_clan_map[pfam] = clan
                    if not protein_clans_involved:
                        continue
                    primary_clan = min(protein_clans_involved)  # Alphabetically first clan
                    pfam_label = f"{primary_clan}: Other"  # Default
                    for pfam, clan in pfam_clan_map.items():
                        if clan == primary_clan and pfam in clan_top_pfams.get(
                            primary_clan, []
                        ):
                            pfam_label = f"{primary_clan}: {pfam}"
                            break
                    csv_writer.writerow([protein_id, primary_clan, pfam_label])
            else:
                raise ValueError(
                    f"Invalid selection mode '{selection_mode}' for CSV generation."
                )

        print("✅ Metadata CSV generated successfully.")

    except Exception as e:
        print(f"❌ Error generating metadata CSV: {e}")
        # Clean up files if CSV fails
        if output_h5_path.exists():
            try:
                output_h5_path.unlink()
            except OSError:
                pass
        if output_csv_path.exists():
            try:
                output_csv_path.unlink()
            except OSError:
                pass
        if output_dir.exists() and not any(output_dir.iterdir()):
            try:
                output_dir.rmdir()
            except OSError:
                pass
        return None

    # --- Save Run Metadata ---
    try:
        with open(output_meta_path, "w", encoding="utf-8") as f:
            f.write(f"Selection Mode: {selection_mode}\n")
            f.write(f"Selected Items: {', '.join(selected_items_list)}\n")
            f.write(f"Total proteins requested: {len(proteins)}\n")  # Original request
            f.write(f"Found proteins with embeddings: {found_count}\n")  # Saved to H5
            f.write(f"Missing proteins: {missing_proteins_count}\n")
            f.write(f"Input Embedding File: {input_h5_path.name}\n")
            f.write(f"Output Embedding File: {output_h5_path.name}\n")
            f.write(f"Output Metadata File: {output_csv_path.name}\n")
            f.write(f"Output Directory: {output_dir}\n")
            f.write(f"Created on: {pd.Timestamp.now()}\n")
    except Exception as e:
        print(f"⚠️ Warning: Could not save run metadata file {output_meta_path}: {e}")

    # --- Update Global Variables on Success ---
    last_extraction_dir = output_dir
    last_embedding_path = output_h5_path
    last_metadata_path = output_csv_path
    last_selection_mode = selection_mode
    last_selected_items = selected_items_list.copy()  # Store a copy

    # --- Return Summary ---
    return {
        "output_dir": output_dir,
        "embedding_path": output_h5_path,
        "metadata_path": output_csv_path,
        "total_proteins_requested": len(proteins),
        "found_proteins": found_count,  # Return integer
        "missing_proteins": missing_proteins_count,  # Return integer
        "already_exists": False,
    }


# --- Widget Setup ---
extract_button = Button(
    description="Extract Embeddings & Create Metadata",
    button_style="primary",
    icon="cogs",
)
extraction_output = Output()


# --- Button Click Handler ---
def on_extract_button_click(b):
    with extraction_output:
        clear_output()  # Clear previous output
        print(f"--- Starting Extraction for {current_mode} ---")

        if not lookup:
            print(
                "❌ Cannot proceed: Data lookup system failed to initialize in Step 3."
            )
            return

        target_list = (
            selected_pfams if current_mode == "Pfam Families" else selected_clans
        )
        mode_str = current_mode

        if not target_list:
            item_type = "Pfam family" if mode_str == "Pfam Families" else "clan"
            print(
                f"⚠️ Please select at least one {item_type} in Step 3 before extracting."
            )
            return

        print("1. Identifying target proteins...")
        all_selected_proteins, message = get_proteins_for_selection(
            mode_str, target_list, lookup
        )
        print(f"   {message}")

        if not all_selected_proteins:
            print("⚠️ No proteins found for the current selection. Cannot proceed.")
            return
        elif len(all_selected_proteins) > 100000:
            # Large selections can be slow to extract; remind the user
            print(
                f"   ⚠️ Note: Processing {len(all_selected_proteins):,} proteins may take significant time."
            )

        print("\n2. Processing embeddings and metadata...")
        # sprot_file_path is assumed global (Cell 2)
        process_summary = save_embeddings_and_metadata(
            all_selected_proteins, mode_str, target_list, lookup, sprot_file_path
        )

        print("\n--- Extraction Summary ---")
        if process_summary:
            # Format numbers from summary for display
            found_proteins_val = process_summary["found_proteins"]
            requested_val = process_summary["total_proteins_requested"]
            missing_val = process_summary["missing_proteins"]
            # Use the formatted strings for N/A cases when already_exists is True
            found_str = (
                found_proteins_val
                if isinstance(found_proteins_val, str)
                else f"{found_proteins_val:,}"
            )
            missing_str = (
                missing_val if isinstance(missing_val, str) else f"{missing_val:,}"
            )
            requested_str = f"{requested_val:,}"  # requested should always be int here

            if process_summary.get("already_exists", False):
                print("ℹ️ Result: Found existing data for this selection.")
                print(f"   Directory: {process_summary['output_dir']}")
                print(f"   Embeddings: {process_summary['embedding_path'].name}")
                print(f"   Metadata: {process_summary['metadata_path'].name}")
                # Report counts based on the previously found data
                print(
                    f"   Proteins Found (in existing data): {found_str} / {requested_str} originally requested."
                )
            else:
                print("✅ Result: Extraction and metadata generation successful!")
                print(f"   Directory: {process_summary['output_dir']}")
                print(f"   Embeddings: {process_summary['embedding_path'].name}")
                print(f"   Metadata: {process_summary['metadata_path'].name}")
                print(f"   Proteins Found: {found_str} / {requested_str} requested.")
                print(f"   Proteins Missing: {missing_str}")

            print("\n🎉 Ready for Step 5: Generate Visualization JSON.")
        else:
            print("❌ Extraction failed. Please check the error messages above.")


# --- Connect Button ---
extract_button.on_click(on_extract_button_click)

# --- Display ---
display(extract_button)
display(extraction_output)
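
For reference, the `Pfam_Extended` label written in Step 4 counts how many Pfam families outside the current selection a protein also carries, bucketed as 0, 1, 2, or more than 2. A minimal standalone sketch of that rule follows; `pfam_extended_label` is a hypothetical helper for illustration, not a function defined by the notebook or by ProtSpace:

```python
def pfam_extended_label(primary_pfam, all_pfams, selected):
    """Label a protein by its primary Pfam plus a count of unselected Pfams.

    Mirrors the Step 4 buckets: 0, 1, 2, or more than 2 other families.
    """
    # Only Pfams outside the user's selection count as "others"
    num_other = sum(1 for pfam in all_pfams if pfam not in selected)
    if num_other == 0:
        return primary_pfam
    if num_other == 1:
        return f"{primary_pfam} + 1 other"
    if num_other == 2:
        return f"{primary_pfam} + 2 others"
    return f"{primary_pfam} + >2 others"


# Example: a hypothetical protein with one selected family and two others
print(pfam_extended_label("PF00069", ["PF00069", "PF07714", "PF00017"], ["PF00069"]))
# → PF00069 + 2 others
```

A second selected family on the same protein is not counted as an "other", because only families missing from the selection are tallied.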

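In Clans mode, Step 4 instead ranks each selected clan's Pfam families by how often they occur among the saved proteins, keeps the top three, and labels every protein `<clan>: <pfam>` or `<clan>: Other` under its alphabetically first selected clan. The sketch below reproduces that logic with `collections.Counter`, assuming plain dictionaries in place of the notebook's `lookup` object; the function names and the toy accessions are hypothetical:

```python
from collections import Counter


def top_pfams_per_clan(protein_pfams, pfam_to_clan, clans, k=3):
    """Return the k most frequent Pfams of each clan across all proteins.

    protein_pfams: dict mapping protein ID -> list of Pfam accessions
    pfam_to_clan:  dict mapping Pfam accession -> clan accession
    """
    top = {}
    for clan in clans:
        counts = Counter(
            pfam
            for pfams in protein_pfams.values()
            for pfam in pfams
            if pfam_to_clan.get(pfam) == clan
        )
        # most_common() orders ties by first occurrence, like a stable sort
        top[clan] = [pfam for pfam, _ in counts.most_common(k)]
    return top


def clan_label(pfams, pfam_to_clan, clans, top):
    """Return (primary_clan, pfam_label) for one protein, or None."""
    involved = {pfam_to_clan.get(pfam) for pfam in pfams} & set(clans)
    if not involved:
        return None
    primary = min(involved)  # alphabetically first selected clan
    for pfam in pfams:
        if pfam_to_clan.get(pfam) == primary and pfam in top.get(primary, []):
            return primary, f"{primary}: {pfam}"
    return primary, f"{primary}: Other"


# Toy data: hypothetical accessions, not real Pfam/clan entries
protein_pfams = {"P1": ["PF_A", "PF_B"], "P2": ["PF_A"], "P3": ["PF_C"]}
pfam_to_clan = {"PF_A": "CL_X", "PF_B": "CL_X", "PF_C": "CL_Y"}
top = top_pfams_per_clan(protein_pfams, pfam_to_clan, ["CL_X", "CL_Y"], k=1)
print(top)  # → {'CL_X': ['PF_A'], 'CL_Y': ['PF_C']}
print(clan_label(["PF_B"], pfam_to_clan, ["CL_X", "CL_Y"], top))  # → ('CL_X', 'CL_X: Other')
```

Proteins whose clan Pfams fall outside the top-k list all collapse into the single `Other` label, which keeps the number of legend entries per clan small.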
In [None]:
# @title 5. 📊 Generate ProtSpace Visualization JSON { display-mode: "form" }
# @markdown Runs dimensionality reduction on the embeddings extracted in Step 4 and creates a ProtSpace JSON file for visualization.
# @markdown
# @markdown **Instructions:**
# @markdown 1. Select dimensionality reduction methods
# @markdown 2. Click **"Generate ProtSpace JSON"**
# @markdown 3. Wait for the `Ready for Step 6: Launch Visualization.` message (runtime grows with the number of proteins)
# @markdown 4. Continue to Step 6, or download the JSON file and upload it at https://protspace.rostlab.org/
# @markdown
# @markdown **Note:** Uses data from Step 4. Re-run Step 4 if you change selections.
# ---

# --- Define Methods and Widgets ---
dim_reduction_methods = [
    ("PCA 2D", "pca2"),
    ("PCA 3D", "pca3"),
    ("UMAP 2D", "umap2"),
    ("UMAP 3D", "umap3"),
    ("PaCMAP 2D", "pacmap2"),
    ("PaCMAP 3D", "pacmap3"),
]
# Pre-select only "PCA 2D"; the other methods are opt-in
method_checkboxes = [
    Checkbox(description=name, value=(i == 0))
    for i, (name, _) in enumerate(dim_reduction_methods)
]

# UMAP Parameters
umap_n_neighbors = IntSlider(
    value=25,
    min=5,
    max=200,
    step=1,
    description="n_neighbors:",
    style={"description_width": "initial"},
)
umap_min_dist = FloatSlider(
    value=0.5,
    min=0.01,
    max=1,
    step=0.01,
    description="min_dist:",
    style={"description_width": "initial"},
)

# PaCMAP Parameters
pacmap_mn_ratio = FloatSlider(
    value=0.5,
    min=0.1,
    max=1.0,
    step=0.1,
    description="MN ratio:",
    style={"description_width": "initial"},
)
pacmap_fp_ratio = FloatSlider(
    value=2.0,
    min=0.1,
    max=5.0,
    step=0.1,
    description="FP ratio:",
    style={"description_width": "initial"},
)

# Accordion for Parameters
params_accordion = Accordion(
    children=[
        VBox([umap_n_neighbors, umap_min_dist]),
        VBox([pacmap_mn_ratio, pacmap_fp_ratio]),
    ]
)
params_accordion.set_title(0, "UMAP Parameters")
params_accordion.set_title(1, "PaCMAP Parameters")
params_accordion.selected_index = None  # Start collapsed

# Buttons
generate_button = Button(
    description="Generate ProtSpace JSON",
    button_style="success",
    icon="check",
    layout=Layout(width="200px"),
)
force_generate_button = Button(
    description="Force Regenerate",
    button_style="warning",
    icon="refresh",
    layout=Layout(width="200px"),
)
download_button = Button(
    description="Download JSON",
    button_style="info",
    icon="download",
    layout=Layout(width="200px"),
    disabled=True,
)
gen_output = Output()


# --- Styling Functions ---
def _apply_styling_base(input_json_path: Path, csv_path: Path, style_logic_func):
    """Base function to read JSON, apply styling logic, and write styled JSON."""
    if not input_json_path or not input_json_path.exists():
        print(f"❌ Error applying styling: Input JSON '{input_json_path}' not found.")
        return
    if not csv_path or not csv_path.exists():
        print(f"❌ Error applying styling: Metadata CSV '{csv_path}' not found.")
        return

    try:
        # Build a 40-color palette from tab20b + tab20c (each: 5 hues x 4 shades)
        try:
            tab20b = plt.get_cmap("tab20b")
            tab20c = plt.get_cmap("tab20c")
        except ValueError:
            print(
                "⚠️ Warning: Colormaps 'tab20b' or 'tab20c' not found. Using fallback 'viridis'."
            )
            # Provide a fallback if specific maps aren't available (less likely with standard matplotlib)
            tab20b = plt.get_cmap("viridis")
            tab20c = plt.get_cmap("plasma")  # Different fallback

        # Sample each 20-color map at evenly spaced points in [0, 1]
        colors_b = [to_hex(tab20b(i / 19.0)) for i in range(20)]
        colors_c = [to_hex(tab20c(i / 19.0)) for i in range(20)]
        all_colors = colors_b + colors_c  # 40 colors total
        shapes = [
            "circle",
            "square",
            "diamond",
            "cross",
            "x",
            "circle-open",
            "square-open",
            "diamond-open",
            "triangle-up",
            "star",
        ]  # 10 shapes

        # Read metadata CSV
        df = pd.read_csv(csv_path)

        # Apply the specific styling logic
        visualization_state = style_logic_func(df, all_colors, shapes)

        if not visualization_state:  # Check if styling logic failed internally
            raise ValueError("Styling logic function returned None or empty state.")

        # Read the original JSON data
        with open(input_json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Add visualization state
        data["visualization_state"] = visualization_state

        # Write the updated (styled) JSON
        with open(input_json_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)

        print(f"✅ Custom styling applied. Styled JSON saved to: {input_json_path}")

    except Exception as e:
        print(f"❌ Error applying styling: {e}")
        print("   Proceeding with unstyled JSON.")


def _pfam_style_logic(df, all_colors, shapes):
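    """Generate Pfam-specific color/shape mappings.

    Each primary Pfam gets its own hue and shape; its "+ N others" variants get
    neighboring shades of that hue; each clan reuses the color of its first Pfam.
    """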
    if "Pfam_Primary" not in df.columns or "Pfam_Extended" not in df.columns:
        print(
            "❌ Pfam styling error: Required columns ('Pfam_Primary', 'Pfam_Extended') not found in CSV."
        )
        return None
    base_pfams = sorted(df["Pfam_Primary"].dropna().unique().tolist())
    num_base_pfams = len(base_pfams)
    if num_base_pfams == 0:
        return None
    feature_colors = {"Pfam_Primary": {}, "Pfam_Extended": {}, "Clan": {}}
    marker_shapes = {"Pfam_Primary": {}, "Pfam_Extended": {}, "Clan": {}}
    clan_colors = {}
    for i, base_pfam in enumerate(base_pfams):
        current_shape = shapes[i % len(shapes)]
        base_color_idx = (i * 4) % len(all_colors)
        base_color = all_colors[base_color_idx]
        feature_colors["Pfam_Primary"][base_pfam] = base_color
        marker_shapes["Pfam_Primary"][base_pfam] = current_shape
        feature_colors["Pfam_Extended"][base_pfam] = base_color
        marker_shapes["Pfam_Extended"][base_pfam] = current_shape
        variants = [
            f"{base_pfam} + 1 other",
            f"{base_pfam} + 2 others",
            f"{base_pfam} + >2 others",
        ]
        unique_extended = df["Pfam_Extended"].unique()
        for j, variant in enumerate(variants):
            if variant in unique_extended:
                variant_color_idx = (base_color_idx + j + 1) % len(all_colors)
                feature_colors["Pfam_Extended"][variant] = all_colors[variant_color_idx]
                marker_shapes["Pfam_Extended"][variant] = current_shape
        if "Clan" in df.columns:
            clan_series = df.loc[df["Pfam_Primary"] == base_pfam, "Clan"]
            if not clan_series.empty:
                clan = clan_series.iloc[0]
                if pd.notna(clan) and clan != "N/A":
                    if clan not in clan_colors:
                        # A clan takes the color and shape of its first Pfam
                        clan_colors[clan] = base_color
                        feature_colors["Clan"][clan] = base_color
                        marker_shapes["Clan"][clan] = current_shape
    return {"feature_colors": feature_colors, "marker_shapes": marker_shapes}


def _clan_style_logic(df, all_colors, shapes):
    """Generate Clan-specific color/shape mappings.

    Each clan gets its own hue and shape; its Pfam labels get neighboring shades.
    """
    if "Clan_Primary" not in df.columns or "Pfam_Label" not in df.columns:
        print(
            "❌ Clan styling error: Required columns ('Clan_Primary', 'Pfam_Label') not found in CSV."
        )
        return None
    unique_clans = sorted(df["Clan_Primary"].dropna().unique().tolist())
    num_clans = len(unique_clans)
    if num_clans == 0:
        return None
    feature_colors = {"Clan_Primary": {}, "Pfam_Label": {}}
    marker_shapes = {"Clan_Primary": {}, "Pfam_Label": {}}
    max_clans_to_style = 10
    for i, clan in enumerate(unique_clans):
        if i == max_clans_to_style:  # Warn once, not once per extra clan
            print(
                f"⚠️ Warning: More than {max_clans_to_style} clans selected. Clans beyond the {max_clans_to_style}th may reuse colors/shapes."
            )
        current_shape = shapes[i % len(shapes)]
        base_color_idx = (i * 4) % len(all_colors)
        clan_color = all_colors[base_color_idx]
        feature_colors["Clan_Primary"][clan] = clan_color
        marker_shapes["Clan_Primary"][clan] = current_shape
        clan_pfam_labels = sorted(
            df[df["Clan_Primary"] == clan]["Pfam_Label"].dropna().unique().tolist()
        )
        other_label = f"{clan}: Other"
        color_offset = 1
        # Assign colors to specific Pfam labels first
        for pfam_label in clan_pfam_labels:
            if pfam_label == other_label:
                continue
            if color_offset <= 2:  # Assign colors 1, 2 relative to base
                label_color_idx = (base_color_idx + color_offset) % len(all_colors)
                feature_colors["Pfam_Label"][pfam_label] = all_colors[label_color_idx]
                marker_shapes["Pfam_Label"][pfam_label] = current_shape
                color_offset += 1
            else:  # Assign 'other' color (base+3) to any further specific Pfams
                label_color_idx = (base_color_idx + 3) % len(all_colors)
                feature_colors["Pfam_Label"][pfam_label] = all_colors[label_color_idx]
                marker_shapes["Pfam_Label"][pfam_label] = current_shape
        # Assign 'other' label color (base+3)
        if other_label in clan_pfam_labels:
            other_color_idx = (base_color_idx + 3) % len(all_colors)
            feature_colors["Pfam_Label"][other_label] = all_colors[other_color_idx]
            marker_shapes["Pfam_Label"][other_label] = current_shape
    return {"feature_colors": feature_colors, "marker_shapes": marker_shapes}


# --- Download Function ---
def on_download_button_click(b):
    """Handles the download button click event."""
    if not generated_json_path or not generated_json_path.exists():
        with gen_output:
            print("❌ No valid JSON file available for download.")
        return

    try:
        with open(generated_json_path, "r", encoding="utf-8") as f:
            json_content = f.read()

        # Create downloadable link
        filename = generated_json_path.name
        b64 = base64.b64encode(json_content.encode()).decode()
        payload = f"<a download='{filename}' href='data:application/json;base64,{b64}' target='_blank'>Click to download {filename}</a>"

        with gen_output:
            clear_output()
            print(f"✅ JSON file ready for download: {generated_json_path}")
            print("📥 Click the link below to download:")
            print(
                "You can upload this file to https://protspace.rostlab.org/ for visualization."
            )
            display(HTML(payload))
    except Exception as e:
        with gen_output:
            print(f"❌ Error preparing download: {e}")


# --- Main Generation Function ---
def run_protspace_generation(force=False):
    """Handles the logic for generating the ProtSpace JSON."""
    global generated_json_path  # Allow updating the global variable
    download_button.disabled = True  # Disable download button until generation completes

    with gen_output:
        clear_output()
        print("--- Starting ProtSpace JSON Generation ---")

        # 1. Get Selected Methods
        selected_methods = [
            method_code
            for checkbox, (_, method_code) in zip(
                method_checkboxes, dim_reduction_methods
            )
            if checkbox.value
        ]
        if not selected_methods:
            print("⚠️ Please select at least one dimensionality reduction method.")
            return
        print(f"Selected methods: {', '.join(selected_methods)}")

        # 2. Check if Extraction was Run and Get Paths
        if (
            not last_extraction_dir
            or not last_embedding_path
            or not last_metadata_path
            or not last_selection_mode
            or not last_selected_items
        ):
            print(
                "❌ Error: Embedding extraction (Step 4) has not been run successfully for the current session, or its results are missing."
            )
            print("   Please run Step 4 for your desired selection first.")
            generated_json_path = None  # Ensure path is invalid
            return

        # --- Verify current selection matches last extraction ---
        # current_mode is assumed global (from Cell 3)
        current_selection_list = (
            selected_pfams if current_mode == "Pfam Families" else selected_clans
        )
        if last_selection_mode != current_mode or set(last_selected_items) != set(
            current_selection_list
        ):
            print(
                "⚠️ Warning: Your current selections in Step 3 do not match the data extracted in the last run of Step 4."
            )
            print(
                f"   Last extraction was for {last_selection_mode}: {', '.join(sorted(last_selected_items))}"
            )
            print(
                f"   Current selection is for {current_mode}: {', '.join(sorted(current_selection_list))}"
            )
            print(
                "   Generation will proceed using the *previously extracted data*. If this is not intended, please re-run Step 4 with your current selections."
            )
        # --- End Check ---

        embedding_path = last_embedding_path
        metadata_path = last_metadata_path
        output_dir = last_extraction_dir

        # 3. Define Output JSON Paths using Pathlib
        output_json_dir = output_dir / "protspace"
        output_json_file = output_json_dir / "selected_features_projections.json"

        # 4. Check for Existing Files (Unless Force=True)
        final_existing_path = None
        if output_json_file.exists():
            final_existing_path = output_json_file

        if final_existing_path and not force:
            print("\nℹ️ ProtSpace JSON already exists for this selection:")
            print(f"   {final_existing_path}")
            print("   Using existing file. To regenerate, click 'Force Regenerate'.")
            generated_json_path = final_existing_path  # Update global path
            download_button.disabled = False  # Enable download button
            print("\n🎉 Ready for Step 6: Launch Visualization.")
            print(
                "   You can also download the JSON file to upload to https://protspace.rostlab.org/"
            )
            return

        # 5. Build protspace-local Command (convert Paths to strings for subprocess)
        cmd = [
            "protspace-local",
            "-i",
            str(embedding_path),
            "-f",
            str(metadata_path),
            "-o",
            str(output_json_dir),
            "--methods",
            *selected_methods,
            "--non-binary",
        ]
        if any(m.startswith("umap") for m in selected_methods):
            cmd.extend(
                [
                    "--n_neighbors",
                    str(umap_n_neighbors.value),
                    "--min_dist",
                    str(umap_min_dist.value),
                ]
            )
        if any(m.startswith("pacmap") for m in selected_methods):
            cmd.extend(
                [
                    "--mn_ratio",
                    str(pacmap_mn_ratio.value),
                    "--fp_ratio",
                    str(pacmap_fp_ratio.value),
                ]
            )

        # 6. Execute Command
        print("\n🔄 Running protspace-local command...")
        try:
            subprocess.run(
                cmd, check=True, capture_output=True, text=True, encoding="utf-8"
            )
            print(f"✅ Base ProtSpace JSON generated successfully: {output_json_file}")

            # 7. Apply Styling
            print("\n🔄 Applying custom styling...")
            # Use the correct styling logic based on the *data that was processed*
            style_logic = (
                _pfam_style_logic
                if last_selection_mode == "Pfam Families"
                else _clan_style_logic
            )
            _apply_styling_base(output_json_file, metadata_path, style_logic)
            if output_json_file.exists():
                generated_json_path = output_json_file
                print(f"   Final JSON ready at: {generated_json_path}")
                download_button.disabled = False  # Enable download button
                print("\n🎉 Ready for Step 6: Launch Visualization.")
                print(
                    "   You can also download the JSON file to upload to https://protspace.rostlab.org/"
                )
            else:
                print(
                    "\n❌ Generation failed or styling prevented file finalization. Check errors above."
                )
                generated_json_path = None  # No valid file to hand on to Step 6

        except FileNotFoundError:
            print("❌ Error: 'protspace-local' command not found.")
            print(
                "   Please ensure ProtSpace is correctly installed in the environment."
            )
            generated_json_path = None
        except subprocess.CalledProcessError as e:
            print(
                f"❌ Error running protspace-local command (return code {e.returncode}):"
            )
            print("--- Command stdout ---")
            print(e.stdout)
            print("--- Command stderr ---")
            print(e.stderr)
            generated_json_path = None
            if output_json_file.exists():
                try:
                    output_json_file.unlink(missing_ok=True)
                except OSError:
                    pass
        except Exception as e:
            print(f"❌ An unexpected error occurred during generation: {e}")
            generated_json_path = None
            if output_json_file.exists():
                try:
                    output_json_file.unlink(missing_ok=True)
                except OSError:
                    pass


# --- Button Handlers ---
def on_generate_click(b):
    run_protspace_generation(force=False)


def on_force_generate_click(b):
    run_protspace_generation(force=True)


# --- Connect Buttons ---
generate_button.on_click(on_generate_click)
force_generate_button.on_click(on_force_generate_click)
download_button.on_click(on_download_button_click)

# --- Display Widgets ---
display(
    VBox(
        [
            HTML("<b>Select dimensionality reduction methods:</b>"),
            VBox(method_checkboxes),
            params_accordion,
            HBox([generate_button, force_generate_button, download_button]),
            gen_output,
        ]
    )
)
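
The Step 5 styling code gives each selected item its own hue: the 40-color palette from `tab20b` + `tab20c` is organized in groups of four shades, so the base color index jumps by 4 per item and the variant labels take the next three shades. A pure-Python sketch of that index arithmetic (no matplotlib needed; `pfam_palette_indices` is a hypothetical name) makes the wrap-around after 10 items visible:

```python
VARIANT_SUFFIXES = ["+ 1 other", "+ 2 others", "+ >2 others"]


def pfam_palette_indices(base_pfams, palette_size=40, shades_per_hue=4):
    """Map each Pfam label to an index into a palette of grouped shades.

    Mirrors the Step 5 scheme: item i starts a hue group at i * 4, and its
    "+ N others" variants use the following shades of the same hue.
    """
    indices = {}
    for i, pfam in enumerate(sorted(base_pfams)):
        base_idx = (i * shades_per_hue) % palette_size  # wraps after 10 items
        indices[pfam] = base_idx
        for offset, suffix in enumerate(VARIANT_SUFFIXES, start=1):
            indices[f"{pfam} {suffix}"] = (base_idx + offset) % palette_size
    return indices


mapping = pfam_palette_indices(["PF00069", "PF00017"])
print(mapping["PF00017"], mapping["PF00069"], mapping["PF00069 + 2 others"])
# → 0 4 6
```

Because the marker-shape list in Step 5 also has 10 entries, an 11th item would share both its color and its shape with the first one.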

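The **Download JSON** button works by embedding the whole file in the link itself as a base64 `data:` URI. A minimal sketch of that technique follows; `json_download_link` is a hypothetical helper, and unlike the notebook it HTML-escapes the filename so a quote in the name cannot break the `download` attribute:

```python
import base64
import html


def json_download_link(filename: str, json_text: str) -> str:
    """Return an HTML anchor that downloads json_text as `filename`."""
    # Encode the JSON text as UTF-8 bytes, then as base64 ASCII text
    b64 = base64.b64encode(json_text.encode("utf-8")).decode("ascii")
    # Escape the filename so quotes or angle brackets stay inside the attribute
    safe_name = html.escape(filename, quote=True)
    return (
        f'<a download="{safe_name}" href="data:application/json;base64,{b64}">'
        f"Download {safe_name}</a>"
    )


link = json_download_link("selected_features_projections.json", '{"example": true}')
print(link)
```

Base64 emits 4 characters for every 3 input bytes, so the link is roughly a third larger than the JSON file; for very large selections the notebook output cell grows accordingly.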
In [None]:
# @title 6. 🚀 Launch ProtSpace Visualization { display-mode: "form" }
# @markdown Configure and launch ProtSpace visualization.
# @markdown
# @markdown **Instructions:**
# @markdown 1. Enter a JSON file path, or leave the field empty to use the last generated file
# @markdown 2. Adjust the height (used in `inline` mode)
# @markdown 3. Select the display mode: `inline` or `external` (external mode works only with Chrome)
# @markdown 4. Click **"Launch ProtSpace"**
# @markdown
# @markdown ---

# --- Widget Definitions ---
json_file_input = Text(
    value="",  # Start empty - will use last generated file if left empty
    placeholder="Leave empty to use last generated file, or enter path to JSON file",
    description="JSON File:",
    style={"description_width": "initial"},
    layout=Layout(width="90%"),
)
height_slider = IntSlider(
    value=800,
    min=400,
    max=1200,
    step=50,
    description="Height (px):",
    style={"description_width": "initial"},
    layout=Layout(width="50%"),
)
# Display modes: "inline" (embedded in the notebook) or "external" (served at a local URL)
mode_dropdown = Dropdown(
    options=["inline", "external"],
    value="inline",
    description="Display Mode:",
    style={"description_width": "initial"},
    layout=Layout(width="50%"),
)
launch_button = Button(
    description="Launch ProtSpace", button_style="primary", icon="rocket"
)
launch_output = Output()


# --- Launch Handler ---
def on_launch_button_click(b):
    """Handles the launch button click event."""
    with launch_output:
        clear_output()

        # 1. Determine JSON file path (use user input or fall back to generated_json_path)
        json_file_str = json_file_input.value.strip()
        if not json_file_str:
            # Use the global generated_json_path if available
            global generated_json_path
            if generated_json_path is not None:
                json_file_str = str(generated_json_path)
                print(f"ℹ️ Using generated JSON file: {json_file_str}")
            else:
                print(
                    "❌ Error: Please enter the path to the ProtSpace JSON file generated in Step 5."
                )
                return

        json_file = Path(json_file_str)
        if not json_file.exists():
            print(f"❌ Error: Specified JSON file not found at '{json_file}'.")
            print("   Please ensure the path is correct and the file exists.")
            return
        # Verify it's likely a JSON file
        if not (json_file.suffix == ".json"):
            print(
                f"⚠️ Warning: The specified file '{json_file.name}' does not have a .json extension. Proceeding anyway."
            )

        # 2. Get Parameters
        height = height_slider.value
        mode = mode_dropdown.value
        port = 8051  # Fixed port

        # 3. Display Launch Info
        launch_info_html = f"""
        <div style="border: 1px solid #ccc; background-color: #dddddd; padding: 10px; margin-bottom: 15px; border-radius: 4px; color: black;">
            <h4 style="margin-top: 0; color: black;">🚀 Launching ProtSpace...</h4>
            <ul style="color: black;">
                <li><b>File:</b> {json_file}</li>
                <li><b>Mode:</b> {mode}</li>
                {"<li><b>Height:</b> " + str(height) + "px</li>" if mode == "inline" else ""}
                {'<li><b>URL:</b> <a href="http://127.0.0.1:' + str(port) + '" target="_blank" style="color: blue;">http://127.0.0.1:' + str(port) + "</a></li>" if mode in ["tab", "external"] else ""}
            </ul>
        </div>
        """
        display(HTML(launch_info_html))

        # 4. Launch ProtSpace App
        try:
            protspace_instance = ProtSpace(default_json_file=str(json_file))
            app = protspace_instance.create_app()

            if mode == "inline":
                print("🔄 Loading visualization in 'inline' mode...")
                app.run(
                    port=port,
                    jupyter_mode=mode,
                    jupyter_height=height,
                    jupyter_width="100%",
                    dev_tools_silence_routes_logging=True,
                    dev_tools_prune_errors=True,
                )
            elif mode == "external":
                print("🔄 Starting server for 'external' mode...")
                print("⚠️ Note: External mode works only with Google Chrome.")
                # Silence server start-up logs, but always restore the streams
                # (even if app.run() raises) so later messages stay visible.
                original_stdout, original_stderr = sys.stdout, sys.stderr
                sys.stdout = sys.stderr = open(os.devnull, "w")
                try:
                    app.run(
                        port=port,
                        jupyter_mode=mode,
                        debug=False,
                        dev_tools_silence_routes_logging=True,
                        dev_tools_prune_errors=True,
                        serve_kernel_port_as_iframe=True,
                    )
                finally:
                    sys.stdout, sys.stderr = original_stdout, original_stderr
                print(f"✅ Server running at http://127.0.0.1:{port}")
            else:
                print(f"❌ Error: Unknown display mode '{mode}'.")

        except ImportError as e:
            if "werkzeug" in str(e).lower():
                print("❌ Error: A dependency issue occurred (likely with 'werkzeug').")
                print(
                    "   Consider installing a specific version via pip if problems persist: pip install werkzeug==2.3.7"
                )
            else:
                print(f"❌ An import error occurred: {e}")
                print(
                    "   Please ensure all dependencies from Step 1 are installed correctly."
                )
        except Exception as e:
            print(f"❌ An unexpected error occurred while launching ProtSpace: {e}")
            print("--- Traceback ---")
            print(traceback.format_exc())


# --- Connect Button ---
launch_button.on_click(on_launch_button_click)

# --- Display Widgets ---
display(
    VBox(
        [
            widgets.HTML("<h3>🚀 Visualization Settings</h3>"),
            json_file_input,
            HBox([height_slider, mode_dropdown]),
            launch_button,
            launch_output,
        ]
    )
)

---

## 🧠 Tips for Exploring the Visualization

Once the ProtSpace visualization loads:

- **Navigation:** Switch between projection methods (PCA, UMAP, etc.) using the "Method" dropdown
- **Interaction:**
  - Hover over points to see details
  - Use lasso/box tools to select groups of points
  - Zoom/pan with mouse wheel or toolbar
  - Search for specific proteins in the search box

### Interpreting Results

- **Clustering:** Proteins from the same Pfam family or clan often cluster together because their embeddings are close, which usually reflects shared sequence features or function.
- **Multi-Domain Proteins:** Proteins assigned to multiple families (e.g., those labeled "+ N others" in Pfam mode) may sit between clusters or in unexpected locations.
- **Clan Structure:** In Clan mode, coloring by "Pfam_Label" shows how the Pfam families within a clan are distributed relative to each other. The color scheme tries to give Pfams from the same clan similar tones, and selecting "Clan_Primary" as the Shape gives every point in a clan the same marker, which helps group them visually.
- **Projection Methods:** PCA is a linear projection and best shows global patterns; UMAP and PaCMAP are nonlinear and better reveal local structure (PaCMAP also aims to retain more of the global layout)
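To check the clustering claim numerically rather than by eye, one option is a k-nearest-neighbor label-agreement score: for each protein, the fraction of its k closest neighbors that share its Pfam (or clan) label. The sketch below is not part of ProtSpace or this notebook. The function name `knn_label_agreement` is hypothetical, and it assumes you already have the embeddings as a NumPy array of shape `(n_proteins, dim)` plus one label per protein in the same order (for example, read from the HDF5 file produced in Step 4).

```python
from typing import Sequence

import numpy as np


def knn_label_agreement(
    embeddings: np.ndarray, labels: Sequence[str], k: int = 5
) -> float:
    """Hypothetical helper: mean fraction of each point's k nearest
    neighbors (Euclidean distance) that share the point's label."""
    x = np.asarray(embeddings, dtype=float)
    y = np.asarray(labels)
    if not 0 < k < len(x):
        raise ValueError("k must be between 1 and n_points - 1")

    # Pairwise squared Euclidean distances via |a|^2 + |b|^2 - 2 a.b
    sq_norms = np.sum(x**2, axis=1)
    d2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (x @ x.T)
    np.fill_diagonal(d2, np.inf)  # a point is not its own neighbor

    # Indices of the k closest points for every row
    neighbors = np.argsort(d2, axis=1)[:, :k]

    # Compare each neighbor's label with the query point's label
    same_label = y[neighbors] == y[:, None]
    return float(np.mean(same_label))


# Usage: two well-separated groups of three points each
points = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]])
print(knn_label_agreement(points, ["A", "A", "A", "B", "B", "B"], k=2))  # → 1.0
```

Scores near 1 mean the labels form tight, well-separated groups; a score close to what you get after shuffling the labels suggests little clustering. Computing the score on the raw embeddings and again on a 2-D projection gives a rough sense of how much a projection method distorts local neighborhoods. The full distance matrix needs O(n²) memory, which is fine for a few thousand proteins.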

### Iterating Your Analysis

- **Step 3:** Select different Pfam families or clans
- **Step 4:** Re-run embedding extraction for the new selection
- **Step 5:** Regenerate the visualization JSON
- **Step 6:** Launch the new visualization with the updated JSON path (or leave the path field empty to use the file just generated in Step 5)
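If you regenerate several JSON files and would rather not copy the newest path by hand, a small helper can pick the most recently modified JSON file in a folder. This is a minimal sketch, not part of the notebook: the function name `latest_json` and the folder name in the usage line are hypothetical, so point it at whatever directory Step 5 actually writes to.

```python
from __future__ import annotations

from pathlib import Path


def latest_json(directory: str | Path) -> Path | None:
    """Hypothetical helper: return the most recently modified *.json file
    in `directory`, or None if there is none."""
    candidates = [p for p in Path(directory).glob("*.json") if p.is_file()]
    if not candidates:
        return None
    # Sort by modification time; the last entry is the newest file
    return max(candidates, key=lambda p: p.stat().st_mtime)


# Usage (hypothetical folder name): fill the Step 6 path box automatically
# path = latest_json("protspace_output")
# if path is not None:
#     json_file_input.value = str(path)
```

Modification time is a simple proxy for "latest run"; if you keep files from several analyses in one folder, giving each run a distinct filename is more reliable.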

Happy exploring! 🌌
