In [1]:
from pathlib import Path
import os

from IPython import get_ipython
from IPython.core.magic import register_cell_magic

ipython = get_ipython()


@register_cell_magic
def pybash(line, cell):
    ipython.run_cell_magic("bash", "", cell.format(**globals()))

In [2]:
import gemmi
from pathlib import Path


def extract_domain(
    struct_file: Path, chain: str, start_res: int, end_res: int, output_file: Path
) -> None:
    """
    Extract a domain from a structure file (PDB or mmCIF) using gemmi.

    Args:
        struct_file: Path to the input structure file (PDB or mmCIF)
        chain: Chain identifier
        start_res: Start residue number
        end_res: End residue number
        output_file: Path to save the extracted domain (always as mmCIF)

    Raises:
        ValueError: If the chain is not found or no residues are extracted.
        Exception: For other file reading or processing errors.
    """
    try:
        structure = gemmi.read_structure(
            str(struct_file), merge_chain_parts=True, format=gemmi.CoorFormat.Detect
        )

        # Create a new structure for the domain
        domain = gemmi.Structure()
        domain.name = f"{struct_file.stem}_{chain}_{start_res}_{end_res}"

        # Create a new model
        model = gemmi.Model("1")

        # Find and copy the specified chain
        found_chain = False
        extracted_residues = 0
        original_chain_instance = None

        for ch in structure[0]:  # Assuming model 0
            if ch.name == chain:
                original_chain_instance = ch
                found_chain = True
                break

        if not found_chain:
            raise ValueError(f"Chain {chain} not found in {struct_file}")

        new_chain = gemmi.Chain(chain)

        # Copy residues in the specified range
        for residue in original_chain_instance:
            seq_id = residue.seqid.num
            if start_res <= seq_id <= end_res:
                # Ensure residue is cloned to avoid modifying original structure if needed elsewhere
                new_chain.add_residue(residue.clone())
                extracted_residues += 1

        if extracted_residues == 0:
            raise ValueError(
                f"No residues in range {start_res}-{end_res} found in chain {chain} of {struct_file}"
            )

        model.add_chain(new_chain)
        domain.add_model(model)
        domain.make_mmcif_document().write_file(str(output_file))

        print(
            f"Successfully extracted domain from {struct_file}: chain {chain}, "
            f"residues {start_res}-{end_res} ({extracted_residues} residues) to {output_file}"
        )

    except ValueError as ve:  # Re-raise specific errors
        raise ve
    except Exception as e:
        # Raise a more informative general exception
        raise Exception(f"Failed to extract domain from {struct_file}: {e}") from e

In [3]:
structure_id = "1kt0"
tmp_dir = Path("../tmp/domain_separation")
data_dir = Path(f"../data/{structure_id}")

structure_file = data_dir / f"{structure_id}.cif"
structure_svg = tmp_dir / f"{structure_id}.svg"
domains_svg = tmp_dir / f"{structure_id}-domains.svg"

os.makedirs(tmp_dir, exist_ok=True)

## Generate Structure SVG without Domain Separation


In [4]:
%%pybash

uv run flatprot project {structure_file} -o {structure_svg}

[2;36m2025-04-01 11:23:24[0m[2;36m [0m[34mINFO    [0m Using default styles                               
[2;36m2025-04-01 11:23:24[0m[2;36m [0m[34mINFO    [0m [1mSVG saved to ..[0m[1;35m/tmp/domain_separation/[0m[1;95m1kt0.svg[0m     
[2;36m2025-04-01 11:23:24[0m[2;36m [0m[34mINFO    [0m [1mSuccessfully processed structure:[0m                  
[2;36m2025-04-01 11:23:24[0m[2;36m [0m[34mINFO    [0m   Structure file: ..[35m/data/1kt0/[0m[95m1kt0.cif[0m            
[2;36m2025-04-01 11:23:24[0m[2;36m [0m[34mINFO    [0m   Output file: ..[35m/tmp/domain_separation/[0m[95m1kt0.svg[0m   
[2;36m2025-04-01 11:23:24[0m[2;36m [0m[34mINFO    [0m   Transformation: Inertia-based                    


## Load Domains generated by Chainsaw [(Wells et al. 2024)](https://doi.org/10.1093/bioinformatics/btae296) | [GitHub](https://github.com/JudeWells/chainsaw)


In [5]:
import polars as pl

# Define the path to the chainsaw domains file
chainsaw_file = data_dir / f"{structure_id.lower()}-chainsaw-domains.tsv"

# Read the domains file
domains_df = pl.read_csv(chainsaw_file, separator="\t")

# Parse the chopping column which contains residue ranges
chopping = domains_df["chain_id" == structure_id]["chopping"][0]

# Get all residues from the ranges
domains = []
for range_str in chopping.split(","):
    start, end = map(int, range_str.split("-"))
    domains.append((start, end))

## Split the structure into domains


In [11]:
# Cell 9: Refactor combine_domain_transformations and its call

import numpy as np
from pathlib import Path
from typing import List, Tuple, Dict, Any
from flatprot.io import GemmiStructureParser
from flatprot.utils.coordinate_manger import (
    create_coordinate_manager,
    apply_projection,
    CoordinateType,
)
from flatprot.utils.style import create_style_manager
from flatprot.utils.svg import generate_svg, save_svg
from flatprot.core.components import Structure, Chain
from flatprot.core import CoordinateManager


def combine_domain_transformations(
    structure_id: str,
    domains: List[Tuple[int, int]],
    tmp_dir: Path,
    output_svg: Path,
    chain_id: str,  # Added: Specify which chain the domains belong to
    structure_file: Path,  # Added: Original structure file path needed for extraction
    arrangement_direction: str = "horizontal",  # 'horizontal' or 'vertical'
    spacing_pixels: float = 50.0,  # Spacing in output SVG pixels
) -> None:
    """
    Extracts domains, aligns them independently, and combines them into a single SVG.

    Args:
        structure_id: ID of the structure (used for naming temp files).
        domains: List of (start, end) tuples defining domain boundaries.
        tmp_dir: Directory to store temporary domain structure files.
        output_svg: Path to save the combined SVG.
        chain_id: The chain identifier from which to extract domains.
        structure_file: Path to the original full structure file.
        arrangement_direction: How to arrange domains ('horizontal' or 'vertical').
        spacing_pixels: Spacing between domains in pixels.

    Raises:
        ValueError: If arrangement_direction is invalid or domain extraction fails.
        FileNotFoundError: If the structure_file does not exist.
    """
    if arrangement_direction not in ["horizontal", "vertical"]:
        raise ValueError("arrangement_direction must be 'horizontal' or 'vertical'")
    if not structure_file.exists():
        raise FileNotFoundError(f"Structure file not found: {structure_file}")

    parser = GemmiStructureParser()
    style_manager = create_style_manager(None)  # Using default styles

    domain_data: List[Dict[str, Any]] = []
    temp_files: List[Path] = []

    try:
        # 1. Extract and process each domain independently
        for i, (start, end) in enumerate(domains):
            domain_file = tmp_dir / f"{structure_id}_{chain_id}_{start}_{end}.cif"
            temp_files.append(domain_file)  # Keep track for cleanup

            # Extract domain (will raise error if fails)
            extract_domain(structure_file, chain_id, start, end, domain_file)

            # Run DSSP on the extracted domain file
            # Note: Ensure mkdssp is in PATH or provide full path
            # Error handling for mkdssp might be needed
            os.system(f"mkdssp {domain_file} {domain_file}")

            # Process domain: parse, align (inertia), project
            domain_structure = parser.parse_structure(domain_file)
            # Assign a unique ID to avoid conflicts if chain names are the same
            # Though extract_domain currently uses the original chain_id,
            # the Structure object might benefit from unique chain IDs internally
            # For now, we assume the parser handles structure/chain identity correctly.
            # We'll combine chains later. Give structure a unique name.
            domain_structure.id = f"domain_{i}_{start}_{end}"

            coord_manager = create_coordinate_manager(domain_structure)
            # apply_projection performs inertia alignment and projection
            coord_manager = apply_projection(coord_manager, style_manager)

            # Calculate bounding box from CANVAS coordinates for arrangement
            if CoordinateType.CANVAS in coord_manager.coordinates:
                all_canvas_coords = np.vstack(
                    list(coord_manager.coordinates[CoordinateType.CANVAS].values())
                )
                min_xy = np.min(all_canvas_coords, axis=0)
                max_xy = np.max(all_canvas_coords, axis=0)
                bbox = {
                    "min_x": min_xy[0],
                    "min_y": min_xy[1],
                    "max_x": max_xy[0],
                    "max_y": max_xy[1],
                }
            else:
                # Should not happen if apply_projection worked, but handle defensively
                print(
                    f"Warning: CANVAS coordinates not found for domain {start}-{end}. Cannot determine layout."
                )
                bbox = {"min_x": 0, "min_y": 0, "max_x": 0, "max_y": 0}

            domain_data.append(
                {"structure": domain_structure, "coords": coord_manager, "bbox": bbox}
            )

        # 2. Combine structures and calculate translations for arrangement
        combined_structure_chains: List[Chain] = []
        combined_coords = CoordinateManager()
        current_offset_x = 0.0
        current_offset_y = 0.0

        for i, data in enumerate(domain_data):
            domain_structure: Structure = data["structure"]
            domain_coords: CoordinateManager = data["coords"]
            bbox = data["bbox"]

            # Calculate translation needed for this domain
            # For the first domain (i=0), translation is zero
            translation_x = 0.0
            translation_y = 0.0
            if i > 0:
                prev_bbox = domain_data[i - 1]["bbox"]
                if arrangement_direction == "horizontal":
                    # Place to the right of the previous domain's max_x
                    current_offset_x += (
                        prev_bbox["max_x"] - prev_bbox["min_x"] + spacing_pixels
                    )
                    # We want the *start* (min_x) of the current bbox to be at current_offset_x
                    translation_x = current_offset_x - bbox["min_x"]

                elif arrangement_direction == "vertical":
                    # Place below the previous domain's max_y
                    current_offset_y += (
                        prev_bbox["max_y"] - prev_bbox["min_y"] + spacing_pixels
                    )
                    # We want the *start* (min_y) of the current bbox to be at current_offset_y
                    translation_y = current_offset_y - bbox["min_y"]

            # Add chains to the combined list (ensure unique chain instances if needed)
            for chain in domain_structure:
                combined_structure_chains.append(chain)

            # Add coordinates, applying translation ONLY to CANVAS coords
            for coord_type, ranges_coords in domain_coords.coordinates.items():
                for (start, end), coords_array in ranges_coords.items():
                    translated_coords = coords_array.copy()
                    if coord_type == CoordinateType.CANVAS:
                        translated_coords[:, 0] += translation_x
                        translated_coords[:, 1] += translation_y
                    # Add potentially translated coords, including the domain's structure ID
                    combined_coords.add(
                        start,
                        end,
                        translated_coords,
                        coord_type,
                    )

        # Create the final combined structure object (using the version that caused the TypeError)
        final_combined_structure = Structure(chains=combined_structure_chains)

        # --- DEBUG ---
        print("--- Debug: Combined CANVAS Coordinates ---")
        if CoordinateType.CANVAS in combined_coords.coordinates:
            for (start, end), coords in combined_coords.coordinates[
                CoordinateType.CANVAS
            ].items():
                # Fetch associated structure ID if stored (depends on CoordinateManager implementation)
                struct_id_info = (
                    f" (Struct ID: {combined_coords.get_structure_id(CoordinateType.CANVAS, start, end)})"
                    if hasattr(combined_coords, "get_structure_id")
                    else ""
                )

                print(f"  Range ({start}-{end}){struct_id_info}:")
                print(f"    Shape: {coords.shape}")
                if coords.size > 0:
                    min_xy = np.min(coords, axis=0)
                    max_xy = np.max(coords, axis=0)
                    print(f"    Min XY: {min_xy}")
                    print(f"    Max XY: {max_xy}")
                else:
                    print("    No coordinates")
        else:
            print("  No CANVAS coordinates found in combined_coords.")
        print("--- End Debug ---")

        # 3. Generate final SVG
        svg_content = generate_svg(
            final_combined_structure, combined_coords, style_manager
        )
        save_svg(svg_content, output_svg)
        print(f"Combined SVG saved to {output_svg}")

    finally:
        # 4. Clean up temporary files
        # Commented out for debugging, uncomment for production
        # for f in temp_files:
        #    if f.exists():
        #        os.remove(f)
        #        print(f"Removed temporary file: {f}")
        pass  # Keep files for inspection for now


# Example Usage Update:
# Assuming 'structure_id', 'domains', 'tmp_dir', 'structure_file' are defined as before
domains_svg = tmp_dir / f"{structure_id}-domains-combined.svg"

# Make sure to define the chain_id relevant for the chainsaw file
# For 1kt0, the chainsaw file seems to use chain 'A' based on the notebook context
chain_id_for_domains = "A"

combine_domain_transformations(
    structure_id=structure_id,
    domains=domains,  # List of (start, end) tuples from cell 7
    tmp_dir=tmp_dir,
    output_svg=domains_svg,
    chain_id=chain_id_for_domains,  # Specify the chain
    structure_file=structure_file,  # Provide original structure path
    arrangement_direction="horizontal",  # Or "vertical"
    spacing_pixels=50.0,  # Adjust as needed
)


Successfully extracted domain from ../data/1kt0/1kt0.cif: chain A, residues 34-141 (89 residues) to ../tmp/domain_separation/1kt0_A_34_141.cif
Successfully extracted domain from ../data/1kt0/1kt0.cif: chain A, residues 148-251 (104 residues) to ../tmp/domain_separation/1kt0_A_148_251.cif
Successfully extracted domain from ../data/1kt0/1kt0.cif: chain A, residues 256-411 (152 residues) to ../tmp/domain_separation/1kt0_A_256_411.cif
--- Debug: Combined CANVAS Coordinates ---
  Range (0-89):
    Shape: (89, 2)
    Min XY: [-256.61654924 -273.9778134 ]
    Max XY: [272.18737812 266.0221866 ]
  Range (0-104):
    Shape: (104, 2)
    Min XY: [ 578.80392736 -238.10459304]
    Max XY: [1298.80392736  233.68946882]
  Range (0-152):
    Shape: (152, 2)
    Min XY: [1348.80392736 -243.5166494 ]
    Max XY: [1946.0835796  296.4833506]
--- End Debug ---
Combined SVG saved to ../tmp/domain_separation/1kt0-domains-combined.svg
