**Code that reads in models in rmf3 format and reads in crosslinking data, and then maps the crosslinks onto the structure in the rmf3 file, or creates a script that will be read in by chimerax to map the crosslinks and visualize it there.**

In [46]:
import os
import IMP
import RMF
import IMP.pmi.output
from Bio.PDB import PDBParser, Superimposer, PDBExceptions
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import logging
from multiprocessing import Pool, cpu_count
from functools import partial
import shutil
import IMP.pmi.analysis

In [60]:
# read in rmf3 file and extract hierarchy
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def write_pdb(hier, pdb_head):
    """
    Write PDB files for each chain in the hierarchy and combine them into a single PDB file.

    Args:
        hier (IMP.atom.Hierarchy): The hierarchy containing the chains.
        pdb_head (str): The prefix for the output PDB file names.

    Returns:
        str: Path to the combined PDB file.
    """
    output_dir = 'output_pdbs'
    os.makedirs(output_dir, exist_ok=True)
    print("level1", hier.get_children())
    print("level2", hier.get_children()[0].get_children())
    print("level3", hier.get_children()[0].get_child(0).get_children())
    print("level4", hier.get_children()[0].get_child(0).get_child(1).get_children())
    print("level5", hier.get_children()[0].get_child(0).get_child(1).get_child(0).get_children())
    #print("children : ", hier.get_children()[0].get_child(0).get_child(0).get_child(0).get_child(0))
    
    
    chains = hier.get_children()[0].get_children()
    print(f"Total chains in hierarchy: {len(chains)} and chains are {chains}")
    pdb_files = []
    
    # Write individual PDB files for each chain
    for i, ch in enumerate(chains):
        if i < 60:
            continue        
        output_pdb_path = os.path.join(output_dir, f'{pdb_head}_{i}.pdb')
        o = IMP.pmi.output.Output()
        o.init_pdb(output_pdb_path, ch)
        o.write_pdb(output_pdb_path)
        pdb_files.append(output_pdb_path)

    # Combine individual PDB files into a single PDB file
    final_pdb_path = os.path.join(output_dir, f'{pdb_head}.pdb')
    with open(final_pdb_path, 'w') as outfile:
        for fname in pdb_files:
            with open(fname) as infile:
                for line in infile:
                    if line.startswith('ATOM'):
                        outfile.write(line)

    # Delete individual chain PDB files
    for fname in pdb_files:
        os.remove(fname)
    
    return final_pdb_path

def get_moldict_coorddict(hier, molnames):
    """return data structure for the RMSD calculation"""
    moldict = {}
    mol_coords = {}
    mol_XYZs = {}
    for mol in IMP.pmi.tools.get_molecules(hier):
        name = mol.get_name()
        if name not in molnames:
            continue
        parts = True
        mol_coords[mol] = []
        mol_XYZs[mol] = []
        i = 1
        while parts:
            sel = IMP.atom.Selection(
                mol, residue_index=i,
                representation_type=IMP.atom.BALLS,
                resolution=1)
            parts = sel.get_selected_particles()
            if parts:
                mol_coords[mol].append(
                    IMP.core.XYZ(parts[0]).get_coordinates())
                mol_XYZs[mol].append(IMP.core.XYZ(parts[0]))
                i = i+1
        if name in moldict:
            moldict[name].append(mol)
        else:
            moldict[name] = [mol]
    return moldict, mol_coords, mol_XYZs
    
def process_rmf3_frame(rmf_filename, frame_number):
    imp_model = IMP.Model()
    with RMF.open_rmf_file_read_only(rmf_filename) as rmf_file:
        hier = IMP.rmf.create_hierarchies(rmf_file, imp_model)
        IMP.rmf.load_frame(rmf_file, RMF.FrameID(frame_number))
        moldict, mol_coords, mol_XYZs = get_moldict_coorddict(hier, ['DDI1', 'DDI2'])
        for key, value in mol_coords.items():
            print(f"Mol: {key.get_name()}, Num coords: {len(value)}")
        print("Moldict: ", len(moldict['DDI1']))
        print("Mol_coords: ", mol_coords)
        print("Mol_XYZs: ", mol_XYZs.keys())
        print(f"Processing frame {frame_number} from RMF file {rmf_filename}")
        pdb_file = write_pdb(hier[0], f'test_frame_{frame_number}')

# Example usage
rmf_file = './test/6.rmf3'
frame_number = 100
process_rmf3_frame(rmf_file, frame_number)


Mol: DDI1, Num coords: 234
Mol: DDI1, Num coords: 234
Mol: DDI2, Num coords: 230
Mol: DDI2, Num coords: 230
Moldict:  2
Mol_coords:  {name: DDI1: [(2.02622, 31.1022, -71.8765), (4.80582, 33.6756, -71.5302), (3.06615, 37.0103, -70.7137), (4.98488, 40.3283, -70.7645), (3.81192, 42.7664, -68.0354), (4.40828, 46.4977, -68.6152), (4.06075, 47.9897, -65.1113), (3.47131, 51.7464, -64.6394), (3.85987, 52.6869, -60.9415), (2.43177, 55.8475, -59.2263), (5.81416, 57.6232, -59.5867), (5.52194, 56.9796, -63.3855), (8.44683, 54.5292, -63.0394), (8.20062, 51.7268, -65.569), (9.20548, 48.0792, -65.0899), (8.79676, 45.3576, -67.7221), (8.96948, 41.671, -66.7283), (7.71024, 38.3368, -68.1644), (5.67145, 35.6498, -66.3411), (5.31893, 31.9884, -67.4101), (1.69323, 31.0148, -66.5336), (-0.955627, 28.5136, -67.7725), (-4.03942, 29.7067, -69.8028), (-6.29617, 28.3034, -66.9904), (-4.38482, 30.442, -64.3912), (-6.78171, 32.7265, -62.4135), (-6.42892, 36.5437, -62.5651), (-6.20532, 36.6468, -58.7027), (-3.1140

In [66]:
import IMP
import RMF
import IMP.atom
import IMP.rmf
import pandas as pd
import numpy as np
from collections import defaultdict

def load_rmf_frame(rmf_file, frame_index=0):
    """
    Load a specific frame from an RMF file and return the hierarchy.
    """
    mdl = IMP.Model()
    rh = RMF.open_rmf_file_read_only(rmf_file)
    hier = IMP.rmf.create_hierarchies(rh, mdl)[0]
    IMP.rmf.load_frame(rh, RMF.FrameID(frame_index))
    del rh
    return mdl, hier

def get_particle_coordinates(hier, protein_name, residue_index, copy_index, resolution=1):
    """
    Get the coordinates of a specific particle.
    Returns None if particle not found.
    """
    sel = IMP.atom.Selection(hier, 
                            molecule=protein_name,
                            residue_index=residue_index,
                            copy_index=copy_index,
                            resolution=resolution)
    particles = sel.get_selected_particles()
    
    if not particles:
        return None
    
    # Return coordinates as a tuple or list, not Vector3D object
    vec = IMP.core.XYZ(particles[0]).get_coordinates()
    return (vec[0], vec[1], vec[2])  # Return as tuple


def calculate_distance(coord1, coord2):
    """
    Calculate Euclidean distance between two coordinates.
    """
    if coord1 is None or coord2 is None:
        return None
    
    # Calculate distance using numpy or direct calculation
    return np.linalg.norm(np.array(coord1) - np.array(coord2))


def get_valid_copy_combinations(proteins, residues):
    """
    Determine all valid (copy1, copy2, copy3) combinations for trifunctional XL.
    Matches the logic from your modeling script.
    """
    def copy_options(prot1, res1, prot2, res2):
        """Get possible copy pairs for a protein-residue pair"""
        if prot1 != prot2:
            return [(0,0), (0,1), (1,0), (1,1)]  # Different proteins: all combos
        elif res1 == res2:
            return [(0,1), (1,0)]  # Same residue: MUST be different copies
        else:
            return [(0,0), (1,1), (0,1), (1,0)]  # Different residues: ambiguous
    
    # Get possible copy pairs for each edge of the triangle
    combos_12 = copy_options(proteins[0], residues[0], proteins[1], residues[1])
    combos_23 = copy_options(proteins[1], residues[1], proteins[2], residues[2])
    combos_13 = copy_options(proteins[0], residues[0], proteins[2], residues[2])
    
    # Find consistent combinations
    valid = set()
    for (c1_12, c2_12) in combos_12:
        for (c2_23, c3_23) in combos_23:
            for (c1_13, c3_13) in combos_13:
                if c1_12 == c1_13 and c2_12 == c2_23 and c3_23 == c3_13:
                    valid.add((c1_12, c2_12, c3_23))
    
    return sorted(valid)

def analyze_crosslink_distances(rmf_file, xl_csv_file, frame_index=0, resolution=1):
    """
    Main function to analyze distances for all crosslinks in the CSV file.
    
    Parameters:
    - rmf_file: Path to the RMF3 file
    - xl_csv_file: Path to the CSV file with crosslink data
    - frame_index: Which frame to analyze (default 0)
    - resolution: Resolution for particle selection (default 1)
    
    Returns:
    - DataFrame with all distance measurements
    """
    # Load RMF file
    print(f"Loading RMF file: {rmf_file}, frame {frame_index}")
    mdl, hier = load_rmf_frame(rmf_file, frame_index)
    
    # Load crosslink data
    print(f"Loading crosslink data from: {xl_csv_file}")
    xl_data = pd.read_csv(xl_csv_file)
    
    # Results storage
    results = []
    
    # Process each trifunctional crosslink
    for idx, row in xl_data.iterrows():
        # Extract protein names and residue numbers
        proteins = [row['Protein1'], row['Protein2'], row['Protein3']]
        residues = [int(''.join(filter(str.isdigit, str(row[f'Residue{i}'])))) 
                   for i in [1, 2, 3]]
        
        print(f"\nAnalyzing XL {idx}: {proteins[0]}:{residues[0]} - "
              f"{proteins[1]}:{residues[1]} - {proteins[2]}:{residues[2]}")
        
        # Get valid copy combinations
        valid_combos = get_valid_copy_combinations(proteins, residues)
        print(f"  Valid copy combinations: {len(valid_combos)}")
        
        # Calculate distances for each valid combination
        for combo in valid_combos:
            copy1, copy2, copy3 = combo
            
            # Get coordinates for each particle
            coords = []
            missing_particle = False
            
            for i in range(3):
                coord = get_particle_coordinates(hier, proteins[i], residues[i], 
                                                combo[i], resolution)
                if coord is None:
                    missing_particle = True
                    break
                coords.append(coord)
            
            # Skip if any particle not found
            if missing_particle:
                print(f"  Warning: Missing particle in combo {combo}")
                continue
            
            # Calculate all three pairwise distances (triangle)
            dist_12 = calculate_distance(coords[0], coords[1])
            dist_23 = calculate_distance(coords[1], coords[2])
            dist_13 = calculate_distance(coords[0], coords[2])
            
            # Store results
            result_entry = {
                'XL_Index': idx,
                'Protein1': proteins[0],
                'Residue1': residues[0],
                'Copy1': copy1,
                'Protein2': proteins[1],
                'Residue2': residues[1],
                'Copy2': copy2,
                'Protein3': proteins[2],
                'Residue3': residues[2],
                'Copy3': copy3,
                'Distance_1-2': dist_12,
                'Distance_2-3': dist_23,
                'Distance_1-3': dist_13,
                'Max_Distance': max(dist_12, dist_23, dist_13),
                'Min_Distance': min(dist_12, dist_23, dist_13),
                'Mean_Distance': np.mean([dist_12, dist_23, dist_13])
            }
            results.append(result_entry)
            
            print(f"  Combo {proteins[0]}.{copy1}:{residues[0]} - "
                  f"{proteins[1]}.{copy2}:{residues[1]} - "
                  f"{proteins[2]}.{copy3}:{residues[2]}")
            print(f"    Distances: 1-2={dist_12:.2f}Å, 2-3={dist_23:.2f}Å, 1-3={dist_13:.2f}Å")
    
    # Create DataFrame
    df_results = pd.DataFrame(results)
    
    return df_results

def analyze_multiple_frames(rmf_file, xl_csv_file, frame_indices=None, resolution=1):
    """
    Analyze distances across multiple frames.
    
    Parameters:
    - rmf_file: Path to the RMF3 file
    - xl_csv_file: Path to the CSV file with crosslink data
    - frame_indices: List of frame indices to analyze (None = all frames)
    - resolution: Resolution for particle selection
    
    Returns:
    - Dictionary with frame index as key and DataFrame as value
    """
    # Get total number of frames if not specified
    if frame_indices is None:
        rh = RMF.open_rmf_file_read_only(rmf_file)
        n_frames = rh.get_number_of_frames()
        frame_indices = range(n_frames)
        del rh
    
    all_results = {}
    
    for frame_idx in frame_indices:
        print(f"\n{'='*60}")
        print(f"Analyzing frame {frame_idx}")
        print(f"{'='*60}")
        
        df_frame = analyze_crosslink_distances(rmf_file, xl_csv_file, 
                                              frame_idx, resolution)
        all_results[frame_idx] = df_frame
    
    return all_results

def summarize_distances(df_results, threshold=30.0):
    """
    Summarize distance analysis results.
    
    Parameters:
    - df_results: DataFrame from analyze_crosslink_distances
    - threshold: Distance threshold for satisfaction (default 30Å)
    
    Returns:
    - Summary statistics
    """
    print("\n" + "="*60)
    print("DISTANCE ANALYSIS SUMMARY")
    print("="*60)
    
    # Overall statistics
    print(f"\nTotal measurements: {len(df_results)}")
    print(f"Unique crosslinks: {df_results['XL_Index'].nunique()}")
    
    # Distance statistics
    print("\nDistance Statistics (Angstroms):")
    print(f"  Mean of all distances: {df_results['Mean_Distance'].mean():.2f}")
    print(f"  Median of all distances: {df_results['Mean_Distance'].median():.2f}")
    print(f"  Min distance observed: {df_results['Min_Distance'].min():.2f}")
    print(f"  Max distance observed: {df_results['Max_Distance'].max():.2f}")
    
    # Satisfaction analysis
    satisfied = df_results[df_results['Max_Distance'] <= threshold]
    print(f"\nCrosslink Satisfaction (threshold = {threshold}Å):")
    print(f"  Satisfied measurements: {len(satisfied)} / {len(df_results)} "
          f"({100*len(satisfied)/len(df_results):.1f}%)")
    
    # Per-crosslink satisfaction
    print("\nPer-Crosslink Analysis:")
    for xl_idx in df_results['XL_Index'].unique():
        xl_data = df_results[df_results['XL_Index'] == xl_idx]
        xl_satisfied = xl_data[xl_data['Max_Distance'] <= threshold]
        
        print(f"  XL {xl_idx}: {len(xl_satisfied)}/{len(xl_data)} combinations satisfied")
        print(f"    Best (min max-distance): {xl_data['Max_Distance'].min():.2f}Å")
        print(f"    Worst (max max-distance): {xl_data['Max_Distance'].max():.2f}Å")
    
    return {
        'total_measurements': len(df_results),
        'satisfied_count': len(satisfied),
        'satisfaction_rate': len(satisfied) / len(df_results),
        'mean_distance': df_results['Mean_Distance'].mean(),
        'median_distance': df_results['Mean_Distance'].median()
    }

# Example usage
if __name__ == "__main__":
    # Set your file paths
    rmf_file = "test/6.rmf3"  # Your RMF file from modeling
    xl_csv_file = "input_data/reduced_ddi_trifunctional.csv"  # Your crosslink data
    
    # Analyze a single frame
    df_results = analyze_crosslink_distances(rmf_file, xl_csv_file, frame_index=0)
    
    # Save results to CSV
    df_results.to_csv("crosslink_distances_frame0.csv", index=False)
    print(f"\nResults saved to crosslink_distances_frame0.csv")
    
    # Print summary
    #summary = summarize_distances(df_results, threshold=30.0)
    
    # Optional: Analyze multiple frames
    all_frames = analyze_multiple_frames(rmf_file, xl_csv_file, frame_indices=[0, 10, 20, 100, 200, 500])
def create_interactive_dashboard(all_frames, threshold=30.0):
    """
    Create an interactive Plotly dashboard.
    """
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    
    # Combine all data
    combined_df = pd.concat([df.assign(Frame=frame_idx) 
                            for frame_idx, df in all_frames.items()], 
                           ignore_index=True)
    
    # Create subplots
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=('Max Distance vs Frame', 
                       'Satisfaction Rate per XL',
                       'Distance Distribution', 
                       'Satisfaction Heatmap'),
        specs=[[{'type': 'scatter'}, {'type': 'bar'}],
               [{'type': 'box'}, {'type': 'heatmap'}]]
    )
    
    # Plot 1: Max distance over frames for each XL
    for xl_idx in sorted(combined_df['XL_Index'].unique()):
        xl_data = combined_df[combined_df['XL_Index'] == xl_idx]
        fig.add_trace(
            go.Scatter(x=xl_data['Frame'], y=xl_data['Max_Distance'],
                      mode='lines+markers', name=f'XL {xl_idx}'),
            row=1, col=1
        )
    
    fig.add_hline(y=threshold, line_dash="dash", line_color="red", 
                 row=1, col=1, annotation_text=f"{threshold}Å threshold")
    
    # Plot 2: Satisfaction rate bar chart
    satisfaction_by_xl = combined_df.groupby('XL_Index').apply(
        lambda x: (x['Max_Distance'] <= threshold).sum() / len(x) * 100
    )
    
    fig.add_trace(
        go.Bar(x=satisfaction_by_xl.index, y=satisfaction_by_xl.values,
              name='Satisfaction Rate'),
        row=1, col=2
    )
    
    # Plot 3: Box plot of distances
    for xl_idx in sorted(combined_df['XL_Index'].unique()):
        xl_data = combined_df[combined_df['XL_Index'] == xl_idx]
        fig.add_trace(
            go.Box(y=xl_data['Max_Distance'], name=f'XL {xl_idx}'),
            row=2, col=1
        )
    
    # Plot 4: Heatmap
    frame_indices = sorted(all_frames.keys())
    xl_indices = sorted(combined_df['XL_Index'].unique())
    
    z_data = []
    for xl_idx in xl_indices:
        row = []
        for frame_idx in frame_indices:
            frame_data = combined_df[(combined_df['XL_Index'] == xl_idx) & 
                                    (combined_df['Frame'] == frame_idx)]
            row.append(frame_data['Max_Distance'].min())
        z_data.append(row)
    
    fig.add_trace(
        go.Heatmap(z=z_data, x=frame_indices, 
                  y=[f'XL {i}' for i in xl_indices],
                  colorscale='RdYlGn_r'),
        row=2, col=2
    )
    
    fig.update_layout(height=800, showlegend=True, 
                     title_text="Crosslink Analysis Dashboard")
    
    return fig

# Create and save interactive dashboard
fig = create_interactive_dashboard(all_frames, threshold=30.0)
fig.write_html('xl_dashboard.html')
fig.show()    
    # Optional: Calculate average distances across frames
    # for frame_idx, df in all_frames.items():
    #     print(f"\nFrame {frame_idx} summary:")
    #     summarize_distances(df, threshold=30.0)

Loading RMF file: test/6.rmf3, frame 0
Loading crosslink data from: input_data/reduced_ddi_trifunctional.csv

Analyzing XL 0: DDI1:133 - DDI1:133 - DDI1:213
  Valid copy combinations: 4
  Combo DDI1.0:133 - DDI1.1:133 - DDI1.0:213
    Distances: 1-2=29.42Å, 2-3=45.62Å, 1-3=24.01Å
  Combo DDI1.0:133 - DDI1.1:133 - DDI1.1:213
    Distances: 1-2=29.42Å, 2-3=46.65Å, 1-3=70.74Å
  Combo DDI1.1:133 - DDI1.0:133 - DDI1.0:213
    Distances: 1-2=29.42Å, 2-3=24.01Å, 1-3=45.62Å
  Combo DDI1.1:133 - DDI1.0:133 - DDI1.1:213
    Distances: 1-2=29.42Å, 2-3=70.74Å, 1-3=46.65Å

Analyzing XL 1: DDI1:133 - DDI1:161 - DDI1:213
  Valid copy combinations: 8
  Combo DDI1.0:133 - DDI1.0:161 - DDI1.0:213
    Distances: 1-2=31.50Å, 2-3=28.17Å, 1-3=24.01Å
  Combo DDI1.0:133 - DDI1.0:161 - DDI1.1:213
    Distances: 1-2=31.50Å, 2-3=56.06Å, 1-3=70.74Å
  Combo DDI1.0:133 - DDI1.1:161 - DDI1.0:213
    Distances: 1-2=51.55Å, 2-3=68.20Å, 1-3=24.01Å
  Combo DDI1.0:133 - DDI1.1:161 - DDI1.1:213
    Distances: 1-2=51.55Å, 





## Crosslink Mapping Strategy

### Handling Ambiguous Crosslinks

When mapping crosslinks to structural models, we need to account for protein copy numbers and the resulting ambiguity in crosslink assignments.

#### Scenarios:

**Inter-protein crosslinks (p1 ↔ p2)**
- If protein p1 has copy number = 2 and protein p2 has copy number = 2
- A crosslink between p1 and p2 can have **4 possible mappings**:
    - p1_0 ↔ p2_0
    - p1_0 ↔ p2_1  
    - p1_1 ↔ p2_0
    - p1_1 ↔ p2_1

**Intra-protein crosslinks (p1 ↔ p1)**

*Same residue numbers:*
- Must be between **different copies only**
- Mapping: p1_0 ↔ p1_1

*Different residue numbers:*
- Can be between **same copy or different copies**
- Possible mappings:
    - p1_0 ↔ p1_0 (intra-copy)
    - p1_0 ↔ p1_1 (inter-copy)
    - p1_1 ↔ p1_0 (inter-copy)
    - p1_1 ↔ p1_1 (intra-copy)

In [None]:
# map chain ID to protein names
chain_to_protein = {
    'A': 'DDI1',
    'B': 'DDI1',
    'C': 'DDI2',
    'D': 'DDI2'
}