## MSA-Search NIM Output: Reformatting JSON into A3M Format

Reformatting the JSON output of the MSA-Search NIM (https://build.nvidia.com/colabfold/msa-search) into A3M format

07Aug2025
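
The functions below assume the nested response layout that the rest of this notebook reads: `alignments` → database name → `a3m` → `alignment`, where the last value is an A3M string. As a minimal sketch on a made-up toy response (the database entry, header, and sequences are illustrative, not real NIM output), the per-database A3M strings could be collected as follows; `extract_a3m_strings` is a hypothetical helper, not part of the notebook:

```python
import json

# Made-up toy response using the nesting read later in this notebook:
# alignments -> <database name> -> "a3m" -> "alignment" (an A3M string)
toy_response_text = json.dumps({
    "alignments": {
        "Uniref30_2302": {
            "a3m": {"alignment": ">query\nMKVL\n>hitA\t120\nMKaVL\n"},
        },
    },
})


def extract_a3m_strings(response):
    """Return {database_name: a3m_string} for every database with an A3M alignment."""
    a3m_by_db = {}
    for db_name, db_data in response.get("alignments", {}).items():
        # Skip entries without an A3M alignment instead of raising KeyError
        alignment = db_data.get("a3m", {}).get("alignment")
        if alignment is not None:
            a3m_by_db[db_name] = alignment
    return a3m_by_db


response = json.loads(toy_response_text)
a3m_strings = extract_a3m_strings(response)
print(list(a3m_strings))  # ['Uniref30_2302']
```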

## 1.1 Set Up Dependencies

In [22]:
import json
import re
from pathlib import Path

## 1.2 Functions 

In [23]:
def load_json_safely(filename, default_value=None):
    """
    Load JSON file with robust error handling for common issues
    
    Args:
        filename (str): Path to JSON file
        default_value: Default value to return if loading fails
    
    Returns:
        dict: Loaded JSON data or default value
    """
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"Successfully loaded {filename}")
        return data
        
    except json.JSONDecodeError as e:
        print(f"JSONDecodeError in {filename}: {e}")
        print(f"Error at line {e.lineno}, column {e.colno}")
        
        # Try to show context around the error
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                content = f.read()
            lines = content.split('\n')
            if e.lineno <= len(lines):
                start_line = max(0, e.lineno - 2)
                end_line = min(len(lines), e.lineno + 1)
                print(f"Context around error:")
                for i in range(start_line, end_line):
                    marker = ">>> " if i == e.lineno - 1 else "    "
                    print(f"{marker}Line {i+1}: {lines[i][:80]}...")
        except Exception:
            pass
            
        if default_value is not None:
            print(f"Using default value")
            return default_value
        else:
            raise
            
    except FileNotFoundError as e:
        print(f"File not found: {filename}")
        if default_value is not None:
            print(f"Using default value")
            return default_value
        else:
            raise
            
    except Exception as e:
        print(f"Unexpected error loading {filename}: {e}")
        if default_value is not None:
            print(f"Using default value")
            return default_value
        else:
            raise


def load_msa_data(msa_file_path):
    """
    Load MSA (Multiple Sequence Alignment) data with robust error handling
    
    Args:
        msa_file_path (str): Path to MSA JSON file
    
    Returns:
        dict: MSA data or error information
    """
    msa_file_path = Path(msa_file_path)
    
    if not msa_file_path.exists():
        print(f"MSA file not found: {msa_file_path}")
        return {"alignments": {}, "error": "file_not_found"}
    
    print(f"Loading MSA data from: {msa_file_path}")
    
    # Check file size
    file_size = msa_file_path.stat().st_size
    print(f"File size: {file_size:,} bytes")
    
    if file_size == 0:
        print(f"MSA file is empty")
        return {"alignments": {}, "error": "empty_file"}
    
    # Try to load the JSON
    try:
        msa_data = load_json_safely(str(msa_file_path), default_value={
            "alignments": {},
            "error": "json_decode_error",
            "templates": {},
            "metrics": {"search_type": "unknown"}
        })
        
        # Validate the structure
        if "alignments" in msa_data:
            alignment_count = len(msa_data.get("alignments", {}))
            print(f"Found {alignment_count} alignment(s)")
        else:
            print(f"No 'alignments' key found in data")
            
        return msa_data
        
    except Exception as e:
        print(f"Failed to load MSA data: {e}")
        return {"alignments": {}, "error": str(e)}


def parse_sequences(input_string, n, sequence):
    """
    Parse the A3M alignment string returned by the MSA-Search NIM for one database
    
    Args:
        input_string (str): The A3M alignment string for one database
        n (int): The maximum number of aligned sequences (hits) to return
        sequence (str): The query sequence (currently unused: the query record is
            skipped here and re-added by `process_msa_alignments`)
    
    Returns:
        list: Up to n (header, aligned_sequence, score) tuples, one per hit, in
              the order they appear in the input. The query record is not included.
    """
    # The regex below requires a newline before '>', so the first record of the
    # stripped string (the query) is intentionally not matched
    remaining_string = input_string.strip()

    # Each record is a '>' header line followed by its sequence, up to the next '>' or the end
    pattern = re.compile(r'\n>(.*?)\n(.*?)(?=\n>|\Z)', re.DOTALL)

    output_list_to_order = []

    for match in pattern.finditer(remaining_string):
        # Header fields are tab-separated: the first is the hit name, the second the alignment score
        header_fields = match.group(1).split('\t')
        name_full = header_fields[0]
        SW_score = header_fields[1]

        # Keep uppercase match columns and gaps; drop lowercase insertion states so
        # every aligned sequence has the same length as the query
        sequence_raw = match.group(2).strip()
        aligned_sequence = ''.join(char for char in sequence_raw if char.isupper() or not char.isalpha())

        # Store (name, aligned sequence, score) so hits can be sorted by score later
        output_list_to_order.append((f'>{name_full}', aligned_sequence, int(SW_score)))

    # One tuple per hit, so slice to n entries (not n * 2)
    return output_list_to_order[:n]


def write_alignments_to_a3m(alignments_data, output_file_path, description="MSA alignments"):
    """
    Write alignment data to a3M format file.
    
    Args:
        alignments_data: Either a list of alternating headers/sequences or a string containing alignments
        output_file_path (str): Path for the output a3M file
        description (str): Description for the file (currently unused; not written to the file)
    
    Returns:
        str: Path to the created a3M file
    """
    output_path = Path(output_file_path)
    
    # Handle both list and string input formats
    if isinstance(alignments_data, list):
        alignments_string = '\n'.join(alignments_data)
    elif isinstance(alignments_data, str):
        alignments_string = alignments_data
    else:
        raise ValueError("alignments_data must be either a list or string")
    
    # Count sequences for reporting
    sequence_count = alignments_string.count('>')
    
    print(f"Writing {sequence_count} sequences to a3M format: {output_path}")
    
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            # Write the alignments
            f.write(alignments_string)
            
            # Ensure file ends with newline
            if not alignments_string.endswith('\n'):
                f.write('\n')
        
        # Verify the file was created successfully
        if output_path.exists():
            file_size = output_path.stat().st_size
            print(f"Successfully created a3M file:")
            print(f"File: {output_path}")
            print(f"Size: {file_size:,} bytes")
            print(f"Sequences: {sequence_count}")
            
            return str(output_path)
        else:
            raise IOError(f"Failed to create file {output_path}")
            
    except Exception as e:
        print(f"Error writing a3M file: {e}")
        raise


def write_filtered_a3m(alignments_data, output_file_path, max_sequences=None, min_length=None, description="Filtered MSA alignments"):
    """
    Write alignment data to a3M format with optional filtering.
    
    Args:
        alignments_data: String (or list of strings) containing alignments in FASTA-like format
        output_file_path (str): Path for the output a3M file
        max_sequences (int, optional): Maximum number of sequences to include
        min_length (int, optional): Minimum sequence length (excluding gaps)
        description (str): Description for the file (currently unused; not written to the file)
    
    Returns:
        str: Path to the created a3M file
    """
    output_path = Path(output_file_path)
    
    # Parse sequences from the input data
    if isinstance(alignments_data, str):
        lines = alignments_data.strip().split('\n')
    else:
        lines = '\n'.join(alignments_data).strip().split('\n')
    
    sequences = []
    current_header = None
    current_sequence = ""
    
    for line in lines:
        line = line.strip()
        if line.startswith('>'):
            # Save previous sequence if it exists
            if current_header is not None:
                sequences.append((current_header, current_sequence))
            current_header = line
            current_sequence = ""
        else:
            current_sequence += line
    
    # Don't forget the last sequence
    if current_header is not None:
        sequences.append((current_header, current_sequence))
    
    print(f"Parsed {len(sequences)} sequences from input data")
    
    # Apply filters
    filtered_sequences = []
    
    for header, sequence in sequences:
        # Apply minimum length filter (count non-gap characters)
        if min_length is not None:
            non_gap_length = len(sequence.replace('-', '').replace('.', ''))
            if non_gap_length < min_length:
                continue
        
        filtered_sequences.append((header, sequence))
        
        # Apply maximum sequences limit
        if max_sequences is not None and len(filtered_sequences) >= max_sequences:
            break
    
    print(f"After filtering: {len(filtered_sequences)} sequences")
    if max_sequences:
        print(f"Max sequences limit: {max_sequences}")
    if min_length:
        print(f"Min length filter: {min_length}")
    
    # Write to a3M format
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            # Write sequences
            for header, sequence in filtered_sequences:
                f.write(f"{header}\n{sequence}\n")
        
        # Report success
        file_size = output_path.stat().st_size
        print(f"Successfully created filtered a3M file:")
        print(f"File: {output_path}")
        print(f"Size: {file_size:,} bytes")
        print(f"Sequences: {len(filtered_sequences)}")
        
        return str(output_path)
        
    except Exception as e:
        print(f"Error writing filtered a3M file: {e}")
        raise


def analyze_a3m_file(file_path):
    """
    Analyze an a3M file and provide statistics.
    
    Args:
        file_path (str): Path to the a3M file
    """
    file_path = Path(file_path)
    
    if not file_path.exists():
        print(f"File not found: {file_path}")
        return
    
    print(f"Analyzing a3M file: {file_path.name}")
    
    try:
        with open(file_path, 'r') as f:
            lines = f.readlines()
        
        # Count statistics
        total_lines = len(lines)
        comment_lines = sum(1 for line in lines if line.startswith('#'))
        sequence_headers = sum(1 for line in lines if line.startswith('>'))
        sequence_lines = total_lines - comment_lines - sequence_headers
        
        # Calculate sequence lengths
        sequence_lengths = []
        current_sequence = ""
        
        for line in lines:
            line = line.strip()
            if line.startswith('#'):
                continue
            elif line.startswith('>'):
                if current_sequence:
                    sequence_lengths.append(len(current_sequence))
                current_sequence = ""
            else:
                current_sequence += line
        
        # Don't forget the last sequence
        if current_sequence:
            sequence_lengths.append(len(current_sequence))
        
        # File statistics
        file_size = file_path.stat().st_size
        
        print(f"File Statistics:")
        print(f"File size: {file_size:,} bytes")
        print(f"Total lines: {total_lines}")
        print(f"Comment lines: {comment_lines}")
        print(f"Sequence headers: {sequence_headers}")
        print(f"Sequence lines: {sequence_lines}")
        
        if sequence_lengths:
            avg_length = sum(sequence_lengths) / len(sequence_lengths)
            min_length = min(sequence_lengths)
            max_length = max(sequence_lengths)
            
            print(f"Sequence Statistics:")
            print(f"Number of sequences: {len(sequence_lengths)}")
            print(f"Average length: {avg_length:.1f}")
            print(f"Length range: {min_length} - {max_length}")
            
            # Show first sequence as example
            with open(file_path, 'r') as f:
                content = f.read()
                
            # Find first sequence
            lines = content.split('\n')
            for i, line in enumerate(lines):
                if line.startswith('>') and not line.startswith('#'):
                    header = line
                    sequence = ""
                    j = i + 1
                    while j < len(lines) and not lines[j].startswith('>'):
                        if not lines[j].startswith('#'):
                            sequence += lines[j].strip()
                        j += 1
                    
                    print(f"First sequence example:")
                    print(f"Header: {header}")
                    print(f"Length: {len(sequence)}")
                    print(f"Preview: {sequence[:80]}{'...' if len(sequence) > 80 else ''}")
                    break
        
    except Exception as e:
        print(f"Error analyzing file: {e}")


def process_msa_alignments(msa_response_dict, databases, sequence, max_sequences_per_db=10000, output_file="merged_alignments_protein.a3m"):
    """
    Process MSA alignments from multiple databases and merge them into A3M format.
    
    Args:
        msa_response_dict (dict): MSA response data containing alignments
        databases (list): List of database names to process
        sequence (str): Query sequence for alignment
        max_sequences_per_db (int): Maximum number of sequences to parse per database
        output_file (str): Output A3M file path
    
    Returns:
        tuple: (merged_alignments_protein, a3m_file_path)
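            - merged_alignments_protein: List whose first entry is the query record
              (header and sequence in one string), followed by alternating header and
              sequence entries for all hits, sorted by descending score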
            - a3m_file_path: Path to the created A3M file
    """
    all_parsed_dataset_output = []
    
    for num_done, database in enumerate(databases):
        print(f"Parsing results from database: {database}")

        # Pull string of alignments stored in json output for specific dataset
        a3m_dict_msa_search = msa_response_dict['alignments'][database]['a3m']['alignment']

        a3m_dict_msa_search_parsed = parse_sequences(a3m_dict_msa_search, max_sequences_per_db, sequence)

        num_sequences_aligned = (len(a3m_dict_msa_search_parsed))
        print(f"Number of sequences aligned: {num_sequences_aligned}")

        all_parsed_dataset_output.extend(a3m_dict_msa_search_parsed)

    # Sort all the alignments based off of the alignment score
    all_parsed_dataset_output.sort(key=lambda x: x[2], reverse=True)

    # Now that the alignments across all datasets are sorted, reformat each entry to name and sequence
    sorted_parsed_output_formatted = []
    for align_tuple in all_parsed_dataset_output:
        sorted_parsed_output_formatted.append(align_tuple[0])
        sorted_parsed_output_formatted.append(align_tuple[1])

    merged_alignments_protein = [f">query_sequence\n{sequence}"]
    merged_alignments_protein.extend(sorted_parsed_output_formatted)

    print(f"Total merged alignments: {len(merged_alignments_protein)}")

    # Write merged_alignments_protein to a3M format
    a3m_file_path = write_alignments_to_a3m(
        merged_alignments_protein, 
        output_file, 
        description=f"Merged protein alignments from MSA-Search NIM ({', '.join(databases)})"
    )
    
    return merged_alignments_protein, a3m_file_path
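
`parse_sequences` depends on two conventions of the A3M records: the header is tab-separated, and this notebook ranks hits by its second field; within the sequence, uppercase letters and `-` are match columns while lowercase letters are insertions relative to the query. The self-contained sketch below applies the same rules to a single toy record (the hypothetical helper `parse_a3m_record` and its inputs are illustrative):

```python
def parse_a3m_record(header_line, sequence_line):
    """Return (header, aligned_sequence, score) for one A3M record.

    Assumes, as parse_sequences does, that the header is tab-separated and
    that its second field is an integer alignment score.
    """
    fields = header_line.lstrip(">").split("\t")
    name, score = fields[0], int(fields[1])
    # Uppercase letters are match columns and '-' is a deletion in a match
    # column; lowercase letters are insertions relative to the query and are dropped
    aligned = "".join(c for c in sequence_line.strip() if c.isupper() or not c.isalpha())
    return f">{name}", aligned, score


query = "MKVLAT"
header, aligned, score = parse_a3m_record(">hitA\t87\t0.5", "MKvlV-AT")
print(header, aligned, score)  # >hitA MKV-AT 87
assert len(aligned) == len(query)
```

Because insertions are dropped, every parsed hit has the query's length, which is why the analysis further down reports identical minimum and maximum lengths. The written files therefore carry no insertion information.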


## 1.3 Apply Functions

#### Specify the Databases Queried by the MSA-Search NIM and the Amino-Acid Query Sequence

**NOTE:** The amino-acid (AA) sequence must not contain spaces or line breaks.
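
A pasted sequence often contains spaces or line breaks. As a sketch, a hypothetical `clean_sequence` helper could strip all whitespace, uppercase the residues, and reject unexpected characters before the sequence is used; the allowed alphabet here (20 standard amino acids plus `X`) is an assumption, not a documented NIM requirement:

```python
import re

# Assumption: 20 standard amino acids plus X (unknown). Adjust if your
# sequence legitimately contains other IUPAC codes (e.g. B, Z, U).
ALLOWED_RESIDUES = set("ACDEFGHIKLMNPQRSTVWYX")


def clean_sequence(raw):
    """Remove all whitespace (spaces, tabs, line breaks), uppercase, and validate."""
    seq = re.sub(r"\s+", "", raw).upper()
    invalid = sorted(set(seq) - ALLOWED_RESIDUES)
    if invalid:
        raise ValueError(f"Unexpected characters in sequence: {invalid}")
    return seq


pasted = """AIPVIITAVY SVVFVVGLVG
NSLVMFVIIR\r\n"""
print(clean_sequence(pasted))  # AIPVIITAVYSVVFVVGLVGNSLVMFVIIR
```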

In [None]:
# Example with KOR (P41145 | OPRK_HUMAN)
sequence = 'AIPVIITAVYSVVFVVGLVGNSLVMFVIIRYTKMKTATNIYIFNLALADALVTTTMPFQSTVYLMNSWPFGDVLCKIVISIDYYNMFTSIFTLTMMSVDRYIAVCHPVKALDFRTPLKAKIINICIWLLSSSVGISAIVLGGTKVREDVDVIECSLQFPDDDYSWWDLFMKICVFIFAFVIPVLIIIVCYTLMILRLKSVRLLSGSREKDRNLRRITRLVLVVVAVFVVCWTPIHIFILVEALGSTSHSTAALSSYYFCIALGYTNSSLNPILYAFLDENFKRCF'

databases = ['Uniref30_2302', 'colabfold_envdb_202108', 'PDB70_220313']


### Load the MSA `JSON` File

In [None]:
# Load MSA data from the JSON file
# Example with KOR (P41145 | OPRK_HUMAN)
msa_file_path = "kor_msa.json"

msa_response_dict = load_msa_data(msa_file_path)

# Display results
if "error" in msa_response_dict:
    print(f"MSA data loading had issues: {msa_response_dict['error']}")
else:
    print(f"MSA data loaded successfully!")
    print(f"Available keys: {list(msa_response_dict.keys())}")
    if "alignments" in msa_response_dict:
        alignments = msa_response_dict["alignments"]
        print(f"Alignment databases: {list(alignments.keys())}")
        for db_name, db_data in alignments.items():
            if isinstance(db_data, dict):
                print(f"- {db_name}: {list(db_data.keys())}")

Loading MSA data from: kor_msa.json
File size: 429,709 bytes
Successfully loaded kor_msa.json
Found 3 alignment(s)
MSA data loaded successfully!
Available keys: ['alignments', 'templates', 'metrics']
Alignment databases: ['Uniref30_2302', 'PDB70_220313', 'colabfold_envdb_202108']
- Uniref30_2302: ['fasta', 'a3m']
- PDB70_220313: ['fasta', 'a3m']
- colabfold_envdb_202108: ['fasta', 'a3m']
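
`process_msa_alignments` indexes `msa_response_dict['alignments'][database]` directly, so a database name missing from the response raises a `KeyError`. Because lookups are by name, the order of the databases in the response (here `PDB70_220313` precedes `colabfold_envdb_202108`) does not matter. A small guard such as the hypothetical `check_databases` helper below could be run before merging; it is a sketch, not part of the original notebook:

```python
def check_databases(msa_response, requested):
    """Return the requested databases that have an "a3m" entry in the response.

    Hypothetical helper: prints any missing names instead of letting a later
    msa_response["alignments"][db] lookup raise KeyError.
    """
    available = msa_response.get("alignments", {})
    missing = [db for db in requested if "a3m" not in available.get(db, {})]
    if missing:
        print(f"Missing from response: {missing}")
    return [db for db in requested if db not in missing]


# Toy response mirroring the keys printed above (contents left empty)
toy_response = {"alignments": {
    "Uniref30_2302": {"fasta": {}, "a3m": {}},
    "PDB70_220313": {"fasta": {}, "a3m": {}},
}}
usable = check_databases(toy_response, ["Uniref30_2302", "colabfold_envdb_202108", "PDB70_220313"])
print(usable)  # ['Uniref30_2302', 'PDB70_220313']
```

In the notebook this could be applied as `databases = check_databases(msa_response_dict, databases)` before calling `process_msa_alignments`.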


### Parse and Merge the MSA Alignments from All Queried Databases

In [26]:
merged_alignments_protein, a3m_file_path = process_msa_alignments(
    msa_response_dict,
    databases,
    sequence,
    max_sequences_per_db=10000,
    output_file="merged_alignments_protein.a3m"
    )

Parsing results from database: Uniref30_2302
Number of sequences aligned: 199
Parsing results from database: colabfold_envdb_202108
Number of sequences aligned: 199
Parsing results from database: PDB70_220313
Number of sequences aligned: 197
Total merged alignments: 1191
Writing 596 sequences to a3M format: merged_alignments_protein.a3m
Successfully created a3M file:
File: merged_alignments_protein.a3m
Size: 180,883 bytes
Sequences: 596
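
The log reports 1191 for `Total merged alignments` but 596 sequences written. The printed value is the length of `merged_alignments_protein`, which holds the query record as a single string followed by a separate header entry and sequence entry for each hit: 1 + 2 × (199 + 199 + 197) = 1191 list entries for 1 + 595 = 596 sequences. The sketch below reproduces the merge on toy tuples (made-up names and scores; `merge_hits` is a hypothetical stand-in for the merge step inside `process_msa_alignments`) to show the ordering and that count:

```python
def merge_hits(query_seq, hits_per_db):
    """Merge (header, sequence, score) tuples from several databases.

    Mirrors process_msa_alignments: sort all hits by score (highest first),
    flatten to alternating header/sequence entries, and put the query first.
    """
    all_hits = [hit for hits in hits_per_db for hit in hits]
    all_hits.sort(key=lambda hit: hit[2], reverse=True)
    merged = [f">query_sequence\n{query_seq}"]
    for header, seq, _score in all_hits:
        merged.extend([header, seq])
    return merged


db1 = [(">a", "MKV-", 50), (">b", "MKVL", 90)]
db2 = [(">c", "MK-L", 70)]
merged = merge_hits("MKVL", [db1, db2])
print(merged)
# ['>query_sequence\nMKVL', '>b', 'MKVL', '>c', 'MK-L', '>a', 'MKV-']
print(len(merged), "list entries for", 1 + len(db1) + len(db2), "sequences")
# 7 list entries for 4 sequences
```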


### Create Filtered Versions of the Full MSA (e.g., Top-100, Sample Subset)

In [27]:
# Create a filtered version with top 100 sequences
filtered_a3m_path = write_filtered_a3m(
    merged_alignments_protein,
    "merged_alignments_protein_top100.a3m", 
    max_sequences=100,
    min_length=50,
    description="Top 100 protein alignments from MSA-Search NIM (min length 50 aa)"
)

# Create a smaller sample for quick testing
sample_a3m_path = write_filtered_a3m(
    merged_alignments_protein,
    "merged_alignments_protein_sample.a3m", 
    max_sequences=10,
    description="Sample of 10 protein alignments for testing"
)


Parsed 596 sequences from input data
After filtering: 100 sequences
Max sequences limit: 100
Min length filter: 50
Successfully created filtered a3M file:
File: merged_alignments_protein_top100.a3m
Size: 30,652 bytes
Sequences: 100
Parsed 596 sequences from input data
After filtering: 10 sequences
Max sequences limit: 10
Successfully created filtered a3M file:
File: merged_alignments_protein_sample.a3m
Size: 3,015 bytes
Sequences: 10
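
Because `write_filtered_a3m` treats the query as an ordinary record, it counts toward `max_sequences` (the Top-100 file holds the query plus 99 hits), and it would be dropped if it were shorter than `min_length`. If the query must always be kept, a variant along the lines of the hypothetical `filter_records` sketch below could be used instead; it assumes the first record is the query, as in the files written above:

```python
def filter_records(records, max_hits=None, min_length=None):
    """Keep the query (first record) plus up to max_hits hits that pass min_length.

    records: list of (header, sequence) tuples, query first.
    Length is counted without '-' and '.' gap characters, as in write_filtered_a3m.
    """
    if not records:
        return []
    query, hits = records[0], records[1:]
    kept = [query]  # the query is never filtered out
    for header, seq in hits:
        # Stop once max_hits hits (excluding the query) have been kept
        if max_hits is not None and len(kept) - 1 >= max_hits:
            break
        if min_length is not None and len(seq.replace("-", "").replace(".", "")) < min_length:
            continue
        kept.append((header, seq))
    return kept


records = [(">query_sequence", "MKVL"), (">a", "MK--"), (">b", "MKVL"), (">c", "MKV-")]
print(filter_records(records, max_hits=1, min_length=3))
# [('>query_sequence', 'MKVL'), ('>b', 'MKVL')]
```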


### Analyze All Created `A3M` Format Files

In [28]:
# Analyze all created a3M files
print("=" * 60)
print("A3M FILE ANALYSIS")
print("=" * 60)

files_to_analyze = [
    "merged_alignments_protein.a3m",
    "merged_alignments_protein_top100.a3m", 
    "merged_alignments_protein_sample.a3m"
]

for file_name in files_to_analyze:
    if Path(file_name).exists():
        analyze_a3m_file(file_name)
        print("-" * 40)
    else:
        print(f"File not found: {file_name}")
        print("-" * 40)



A3M FILE ANALYSIS
Analyzing a3M file: merged_alignments_protein.a3m
File Statistics:
File size: 180,883 bytes
Total lines: 1192
Comment lines: 0
Sequence headers: 596
Sequence lines: 596
Sequence Statistics:
Number of sequences: 596
Average length: 285.0
Length range: 285 - 285
First sequence example:
Header: >query_sequence
Length: 285
Preview: AIPVIITAVYSVVFVVGLVGNSLVMFVIIRYTKMKTATNIYIFNLALADALVTTTMPFQSTVYLMNSWPFGDVLCKIVIS...
----------------------------------------
Analyzing a3M file: merged_alignments_protein_top100.a3m
File Statistics:
File size: 30,652 bytes
Total lines: 200
Comment lines: 0
Sequence headers: 100
Sequence lines: 100
Sequence Statistics:
Number of sequences: 100
Average length: 285.0
Length range: 285 - 285
First sequence example:
Header: >query_sequence
Length: 285
Preview: AIPVIITAVYSVVFVVGLVGNSLVMFVIIRYTKMKTATNIYIFNLALADALVTTTMPFQSTVYLMNSWPFGDVLCKIVIS...
----------------------------------------
Analyzing a3M file: merged_alignments_protein_sample.a3m
File Stati