<a href="https://colab.research.google.com/github/yoakiyama/MSA_Pairformer/blob/dev-branch/MSA_Pairformer_with_MMseqs2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **MSA Pairformer**

In [None]:
# @title Setup mmseqs2 + hhsuite + MSA pairformer
import os
import re
import sys
import gc
import io
import json
import time
import pickle
import shutil
import hashlib
import tarfile
import tempfile
import warnings
import importlib
import subprocess
from pathlib import Path
from sys import version_info
from contextlib import redirect_stdout, redirect_stderr
from typing import List, Dict, Optional, Tuple, Union
from google.colab import files

import requests
import torch
import numpy as np
import matplotlib.pyplot as plt

class ColabFoldPairedMSA:
    """Simple class to get paired MSAs from ColabFold with extended filtering and genomic distance support"""
    def __init__(self, host_url: str = "https://api.colabfold.com",
                 cache_dir: Optional[str] = None):
        self.host_url = host_url
        self.job_id = None
        self.parsed_entries = None  # List of parsed entries with metadata

        # Set up cache directory
        if cache_dir is None:
            self.cache_dir = Path.home() / ".colabfold_cache"
        else:
            self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        print(f"Cache directory: {self.cache_dir}")

        # Initialize UniProt converter
        self._init_uniprot_converter()

    def _init_uniprot_converter(self):
        """Initialize UniProt ID to number conversion tables"""
        from string import ascii_uppercase

        # Initialize encoding tables
        self.pa = {a: 0 for a in ascii_uppercase}
        for a in ["O", "P", "Q"]:
            self.pa[a] = 1

        self.ma = [[{} for k in range(6)], [{} for k in range(6)]]

        # Fill encoding tables
        for n, t in enumerate(range(10)):
            for i in [0, 1]:
                for j in [0, 4]:
                    self.ma[i][j][str(t)] = n

        for n, t in enumerate(list(ascii_uppercase) + list(range(10))):
            for i in [0, 1]:
                for j in [1, 2]:
                    self.ma[i][j][str(t)] = n
            self.ma[1][3][str(t)] = n

        for n, t in enumerate(ascii_uppercase):
            self.ma[0][3][str(t)] = n
            for i in [0, 1]:
                self.ma[i][5][str(t)] = n

        self.upi_encoding = {}
        hex_chars = list(range(10)) + ['A', 'B', 'C', 'D', 'E', 'F']
        for n, char in enumerate(hex_chars):
            self.upi_encoding[str(char)] = n

    def _extract_uniprot_id(self, header: str) -> str:
        """Extract UniProt ID from header."""
        pos = header.find("UniRef")
        if pos == -1:
            return ""

        start = header.find('_', pos)
        if start == -1:
            return ""
        start += 1

        end = start
        while end < len(header) and header[end] not in ' _\t':
            end += 1

        uid = header[start:end]

        # Validate - including UPI IDs
        if len(uid) >= 3 and uid[:3] == "UPI":
            return uid

        # Regular UniProt ID validation
        if len(uid) not in [6, 10]:
            return ""
        if not uid[0].isalpha():
            return ""

        return uid

    def _uniprot_to_number(self, uniprot_ids: List[str]) -> List[int]:
        """Convert UniProt IDs to numbers for distance calculation."""
        numbers = []
        for uni in uniprot_ids:
            if not uni or not uni[0].isalpha():
                numbers.append(0)
                continue

            # Handle UPI IDs
            if uni.startswith("UPI") and len(uni) == 13:
                hex_part = uni[3:]  # Remove "UPI" prefix
                num = 0
                tot = 1

                # Process hexadecimal characters in reverse order
                for u in reversed(hex_part):
                    if str(u) in self.upi_encoding:
                        num += self.upi_encoding[str(u)] * tot
                        tot *= 16  # Base 16 for hexadecimal
                    else:
                        # Invalid hex character, assign 0
                        num = 0
                        break
                # Add offset to distinguish UPI IDs from standard ones
                # Use a large offset to avoid collisions
                numbers.append(num + 10**15)
                continue

            p = self.pa.get(uni[0], 0)
            tot, num = 1, 0

            if len(uni) == 10:
                for n, u in enumerate(reversed(uni[-4:])):
                    if str(u) in self.ma[p][n]:
                        num += self.ma[p][n][str(u)] * tot
                        tot *= len(self.ma[p][n].keys())

            for n, u in enumerate(reversed(uni[:6])):
                if n < len(self.ma[p]) and str(u) in self.ma[p][n]:
                    num += self.ma[p][n][str(u)] * tot
                    tot *= len(self.ma[p][n].keys())

            numbers.append(num)

        return numbers

    def _calculate_genomic_distances(self, entry: Dict) -> List[int]:
        """Calculate sequential distances between adjacent chains."""
        distances = []
        nums = entry['uniprot_nums']

        for i in range(1, len(nums)):
            if nums[i-1] and nums[i]:  # Both must be valid numbers
                dist = abs(nums[i] - nums[i-1])
                distances.append(dist)
            else:
                distances.append(-1)  # Invalid distance

        return distances

    def _get_cache_key(self, sequences: List[str], genomic_distance: Optional[int],
                       prefix: Optional[str]) -> str:
        """Generate a unique cache key for the request"""
        # Always use genomic_distance=20 for caching to maximize reuse
        cache_genomic_distance = 20 if genomic_distance is not None else None

        # Create a deterministic string representation
        cache_data = {
            'sequences': sequences,
            'genomic_distance': cache_genomic_distance,
            'prefix': prefix,
            'host_url': self.host_url
        }
        cache_str = json.dumps(cache_data, sort_keys=True)

        # Generate hash
        cache_hash = hashlib.sha256(cache_str.encode()).hexdigest()[:16]

        # Create human-readable prefix
        seq_info = f"{len(sequences)}seq"
        if prefix:
            seq_info += f"_{prefix}"

        return f"{seq_info}_{cache_hash}"

    def _load_from_cache(self, cache_key: str) -> bool:
        """Try to load parsed entries from cache"""
        cache_file = self.cache_dir / f"{cache_key}.pkl"

        if cache_file.exists():
            try:
                with open(cache_file, 'rb') as f:
                    cache_data = pickle.load(f)

                self.parsed_entries = cache_data['parsed_entries']
                self.job_id = cache_data.get('job_id', f"cached_{cache_key}")

                print(f"Loaded from cache: {cache_key}")
                return True
            except Exception as e:
                print(f"Cache load failed: {e}")
                return False

        return False

    def _save_to_cache(self, cache_key: str):
        """Save parsed entries to cache"""
        cache_file = self.cache_dir / f"{cache_key}.pkl"

        try:
            cache_data = {
                'parsed_entries': self.parsed_entries,
                'job_id': self.job_id,
                'timestamp': time.time()
            }

            with open(cache_file, 'wb') as f:
                pickle.dump(cache_data, f)

            print(f"Saved to cache: {cache_key}")

            # Also save a human-readable info file
            info_file = self.cache_dir / f"{cache_key}_info.json"
            info_data = {
                'job_id': self.job_id,
                'num_entries': len(self.parsed_entries),
                'num_chains': len(self.parsed_entries[0]['sequences']) if self.parsed_entries else 0,
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(cache_data['timestamp']))
            }
            with open(info_file, 'w') as f:
                json.dump(info_data, f, indent=2)

        except Exception as e:
            print(f"Cache save failed: {e}")

    def clear_cache(self, older_than_days: Optional[int] = None):
        """Clear cache files, optionally only those older than specified days"""
        import glob

        cache_files = glob.glob(str(self.cache_dir / "*.pkl"))
        removed = 0

        for cache_file in cache_files:
            if older_than_days is not None:
                # Check age
                age_days = (time.time() - os.path.getmtime(cache_file)) / (24 * 3600)
                if age_days < older_than_days:
                    continue

            try:
                os.remove(cache_file)
                # Also remove info file if exists
                info_file = cache_file.replace('.pkl', '_info.json')
                if os.path.exists(info_file):
                    os.remove(info_file)
                removed += 1
            except:
                pass

        print(f"Removed {removed} cache files")

    def submit_or_load_from_cache(self,
                                  sequences: List[str],
                                  genomic_distance: Optional[int] = 20,
                                  prefix: Optional[str] = None,
                                  use_cache: bool = True) -> Tuple[str, bool]:
        """Submit sequences or load from cache if available

        Always uses genomic_distance=20 for caching to maximize reuse.

        Returns:
            Tuple of (job_id, from_cache) where from_cache indicates if data was loaded from cache
        """
        # Always use distance=20 for caching
        cache_genomic_distance = 20 if genomic_distance is not None else None

        # Generate cache key
        cache_key = self._get_cache_key(sequences, cache_genomic_distance, prefix)
        self._current_cache_key = cache_key

        # Try to load from cache
        if use_cache and self._load_from_cache(cache_key):
            return self.job_id, True

        # If not in cache, submit normally with distance=20
        self.submit(sequences, cache_genomic_distance, prefix)
        return self.job_id, False

    def submit(self,
               sequences: List[str],
               genomic_distance: Optional[int] = 20,
               prefix: Optional[str] = None) -> str:
        """Post the sequences to the server and remember the returned job id.

        Chains are numbered from 101 in the FASTA query (optionally prefixed
        with ``<prefix>_``). A single sequence goes to the plain MSA endpoint;
        several sequences go to the pairing endpoint, with proximity filtering
        when ``genomic_distance`` is not None.

        Returns:
            The job id reported by the server (also stored in ``self.job_id``).

        Raises:
            Exception: when the server does not answer with HTTP 200.
        """
        # FASTA records, one per chain.
        records = [
            f">{prefix}_{chain_id}\n{seq}\n" if prefix else f">{chain_id}\n{seq}\n"
            for chain_id, seq in enumerate(sequences, start=101)
        ]
        query = "".join(records)

        if len(sequences) == 1:
            # One chain: regular (unpaired) MSA search.
            endpoint = "ticket/msa"
            mode = "env"
        else:
            # Several chains: paired search.
            endpoint = "ticket/pair"
            if genomic_distance is None:
                mode = "paircomplete"
            else:
                mode = f"paircomplete-pairfilterprox_{genomic_distance}"

        response = requests.post(
            f'{self.host_url}/{endpoint}',
            data={'q': query, 'mode': mode},
            timeout=30
        )
        if response.status_code != 200:
            raise Exception(f"Failed to submit: {response.text}")

        self.job_id = response.json()['id']
        print(f"Job submitted: {self.job_id} with mode: {mode}")
        return self.job_id

    def wait(self, check_interval: int = 5):
        """Poll the ticket endpoint until the job reports COMPLETE.

        Args:
            check_interval: Seconds to sleep between two polls.

        Raises:
            Exception: when the server reports status "ERROR".

        Any other status (including a missing one, shown as "UNKNOWN")
        keeps the loop polling.
        """
        while True:
            reply = requests.get(f'{self.host_url}/ticket/{self.job_id}', timeout=30)
            status = reply.json().get('status', 'UNKNOWN')
            print(f"Status: {status}")

            if status == "COMPLETE":
                return
            if status == "ERROR":
                raise Exception("Job failed")

            time.sleep(check_interval)

    def download_and_parse(self, output_dir: str = "results"):
        """Download the job's result archive, unpack it and parse the MSA.

        The archive is saved as ``<output_dir>/<job_id>.tar.gz`` and extracted
        into ``output_dir``. If ``pair.a3m`` is present it is parsed as a
        paired MSA; otherwise the single-sequence MSA files are combined.
        The parsed entries are stored in ``self.parsed_entries`` and, when a
        cache key was set earlier, written to the cache.

        Args:
            output_dir: Directory that receives the archive and its contents.

        Raises:
            Exception: when the download does not return HTTP 200, or when
                the archive contains a member that would be written outside
                ``output_dir``.
        """
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        tar_path = os.path.join(output_dir, f"{self.job_id}.tar.gz")

        response = requests.get(f'{self.host_url}/result/download/{self.job_id}', timeout=60)
        # Fail early instead of saving an error page as a tarball.
        if response.status_code != 200:
            raise Exception(f"Failed to download results: {response.text}")
        with open(tar_path, 'wb') as f:
            f.write(response.content)

        # The archive comes from the network: refuse members that escape
        # the destination directory (absolute paths, '..' components).
        dest_root = os.path.realpath(output_dir)
        with tarfile.open(tar_path) as tar:
            for member in tar.getmembers():
                target = os.path.realpath(os.path.join(dest_root, member.name))
                if os.path.commonpath([dest_root, target]) != dest_root:
                    raise Exception(f"Unsafe path in downloaded archive: {member.name}")
            if hasattr(tarfile, 'data_filter'):
                # Interpreters with extraction filters also reject unsafe links.
                tar.extractall(output_dir, filter='data')
            else:
                tar.extractall(output_dir)

        # Paired results ship a pair.a3m; single-sequence results do not.
        pair_a3m = os.path.join(output_dir, 'pair.a3m')
        if os.path.exists(pair_a3m):
            self.parsed_entries = self._parse_paired_a3m(pair_a3m)
        else:
            self.parsed_entries = self._parse_single_msas(output_dir)

        # Only cache when a key was set (e.g. by submit_or_load_from_cache).
        if hasattr(self, '_current_cache_key'):
            self._save_to_cache(self._current_cache_key)

    def _parse_msa_lines(self, lines: List[str]) -> List[Dict]:
        """Parse MSA lines into structured entries with UniProt ID extraction"""
        entries = []
        i = 0
        is_first = True

        while i < len(lines):
            line = lines[i].rstrip()

            if line.startswith('>'):
                header = line
                seq_lines = []
                i += 1

                # Collect sequence lines
                while i < len(lines) and not lines[i].startswith('>'):
                    if lines[i].strip():
                        seq_lines.append(lines[i].rstrip())
                    i += 1

                sequence = ''.join(seq_lines)

                # Parse header
                header_parts = header.split('\t')
                header_clean = header_parts[0].lstrip('>').replace('UniRef100_', '')

                # Extract UniProt ID
                uid = self._extract_uniprot_id(header)
                has_uniref = "UniRef" in header
                uniprot_num = 0

                if uid:
                    # Convert to number
                    uniprot_nums = self._uniprot_to_number([uid])
                    uniprot_num = uniprot_nums[0] if uniprot_nums else 0

                # For query sequence
                if is_first:
                    coverage = 1.0
                    identity = 1.0
                    evalue = 0.0
                    alnscore = float('inf')
                    is_first = False
                else:
                    # Extract metadata from header
                    coverage = None
                    identity = None
                    evalue = None
                    alnscore = None

                    if len(header_parts) >= 10:
                        try:
                            alnscore = float(header_parts[1])
                            identity = float(header_parts[2])
                            evalue = float(header_parts[3])
                            q_start = int(header_parts[4])
                            q_end = int(header_parts[5])
                            q_len = int(header_parts[6])
                            coverage = (q_end - q_start + 1) / q_len
                        except:
                            pass

                    # Fallback values
                    if coverage is None:
                        coverage = 0.0
                    if identity is None:
                        identity = 0.0
                    if evalue is None:
                        evalue = float('inf')
                    if alnscore is None:
                        alnscore = 0.0

                entries.append({
                    'header': header_clean,
                    'sequence': sequence,
                    'coverage': coverage,
                    'identity': identity,
                    'evalue': evalue,
                    'alnscore': alnscore,
                    'uid': uid,
                    'uniprot_num': uniprot_num,
                    'has_uniref': has_uniref
                })
            else:
                i += 1

        return entries

    def _parse_paired_a3m(self, a3m_path: str) -> List[Dict]:
        """Parse paired A3M file into list of entries with UniProt ID tracking"""
        # First, separate MSAs by chain ID
        raw_msas = {}
        update_M = True
        M = None

        with open(a3m_path, 'r') as f:
            for line in f:
                if "\x00" in line:
                    line = line.replace("\x00", "")
                    update_M = True

                if line.startswith(">") and update_M:
                    # Extract chain ID (101, 102, etc.)
                    M = int(line[1:].rstrip().split('_')[-1])
                    update_M = False
                    if M not in raw_msas:
                        raw_msas[M] = []

                if M is not None:
                    raw_msas[M].append(line.rstrip())

        # Parse each chain's MSA
        parsed_msas = {}
        for seq_id, lines in raw_msas.items():
            parsed_msas[seq_id] = self._parse_msa_lines(lines)

        # Get sorted chain IDs
        seq_ids = sorted(parsed_msas.keys())

        # IMPORTANT: All chains should have the same number of sequences when paired
        num_entries_per_chain = [len(parsed_msas[sid]) for sid in seq_ids]
        if len(set(num_entries_per_chain)) > 1:
            print(f"Warning: Chains have different numbers of sequences: {dict(zip(seq_ids, num_entries_per_chain))}")
            print("Taking minimum to preserve pairing")

        min_entries = min(num_entries_per_chain)

        # Stitch entries together - sequences at the same position are paired
        stitched_entries = []
        for i in range(min_entries):
            # Collect info from each chain at position i
            headers = []
            sequences = []
            coverages = []
            identities = []
            evalues = []
            alnscores = []
            uids = []
            uniprot_nums = []
            has_uniref = True  # Will be False if any chain doesn't have UniRef

            for sid in seq_ids:
                entry = parsed_msas[sid][i]
                headers.append(entry['header'])
                sequences.append(entry['sequence'])
                coverages.append(entry['coverage'])
                identities.append(entry['identity'])
                evalues.append(entry['evalue'])
                alnscores.append(entry['alnscore'])
                uids.append(entry['uid'])
                uniprot_nums.append(entry['uniprot_num'])
                has_uniref = has_uniref and entry['has_uniref']

            stitched_entries.append({
                'headers': headers,
                'sequences': sequences,
                'coverages': coverages,
                'identities': identities,
                'evalues': evalues,
                'alnscores': alnscores,
                'uids': uids,
                'uniprot_nums': uniprot_nums,
                'has_uniref': has_uniref,
                'is_query': (i == 0)
            })

        return stitched_entries

    def _parse_single_msas(self, output_dir: str) -> List[Dict]:
        """Parse and combine single sequence MSAs"""
        # MSA files to look for
        msa_files = ['uniref.a3m', 'bfd.mgnify30.metaeuk30.smag30.a3m']

        all_entries = []
        seen_sequences = set()

        for msa_file in msa_files:
            msa_path = os.path.join(output_dir, msa_file)
            if os.path.exists(msa_path):
                with open(msa_path, 'r') as f:
                    lines = f.readlines()

                entries = self._parse_msa_lines(lines)

                # Add non-duplicate entries
                for entry in entries:
                    if entry['sequence'] not in seen_sequences:
                        seen_sequences.add(entry['sequence'])
                        all_entries.append({
                            'headers': [entry['header']],
                            'sequences': [entry['sequence']],
                            'coverages': [entry['coverage']],
                            'identities': [entry['identity']],
                            'evalues': [entry['evalue']],
                            'alnscores': [entry['alnscore']],
                            'uids': [entry['uid']],
                            'uniprot_nums': [entry['uniprot_num']],
                            'has_uniref': entry['has_uniref'],
                            'is_query': len(all_entries) == 0
                        })

        return all_entries

    def save_msa(self,
                 output_file: str,
                 min_coverage: Optional[float] = None,
                 min_identity: Optional[float] = None,
                 max_evalue: Optional[float] = None,
                 min_alnscore: Optional[float] = None,
                 max_genomic_distance: Optional[int] = None) -> Tuple[int, List[Dict]]:
        """Save MSA with optional filtering including genomic distance

        Returns:
            Tuple of (number of sequences written, list of filtered entries)
        """
        if not self.parsed_entries:
            raise ValueError("No MSA loaded. Run download_and_parse first.")

        # Smart parsing: convert percentages to fractions
        if min_coverage is not None and min_coverage > 1:
            min_coverage = min_coverage / 100
        if min_identity is not None and min_identity > 1:
            min_identity = min_identity / 100

        sequences_written = 0
        sequences_filtered = 0
        filtered_entries = []  # Store entries that pass the filter

        num_chains = len(self.parsed_entries[0]['sequences']) if self.parsed_entries else 0

        with open(output_file, 'w') as f:
            for entry in self.parsed_entries:
                # Skip filtering for query sequence
                if not entry['is_query']:
                    # All chains must pass the filter
                    filter_reasons = []

                    if min_coverage and any(c < min_coverage for c in entry['coverages']):
                        filter_reasons.append(f"coverage < {min_coverage}")
                    if min_identity and any(i < min_identity for i in entry['identities']):
                        filter_reasons.append(f"identity < {min_identity}")
                    if max_evalue is not None and any(e > max_evalue for e in entry['evalues'] if e is not None):
                        filter_reasons.append(f"evalue > {max_evalue}")
                    if min_alnscore is not None and any(a < min_alnscore for a in entry['alnscores'] if a is not None):
                        filter_reasons.append(f"alnscore < {min_alnscore}")

                    # Genomic distance filtering
                    if max_genomic_distance is not None and entry['has_uniref']:
                        distances = self._calculate_genomic_distances(entry)

                        if num_chains == 2:
                            # Simple case: check single distance
                            if distances[0] != -1 and distances[0] > max_genomic_distance:
                                filter_reasons.append(f"genomic distance > {max_genomic_distance}")
                        else:
                            # For >2 chains: check if all distances exceed threshold
                            # (relaxed filtering - keep if ANY distance is within threshold)
                            valid_distances = [d for d in distances if d != -1]
                            if valid_distances and all(d > max_genomic_distance for d in valid_distances):
                                filter_reasons.append(f"all genomic distances > {max_genomic_distance}")

                    if filter_reasons:
                        sequences_filtered += 1
                        continue

                # Write entry with modified header
                if entry['is_query']:
                    # Query header format: query_len1_len2_len3
                    header = "query"
                    for seq in entry['sequences']:
                        header += f"_len{len(seq)}"
                else:
                    # Regular header format: UID1_UID2_UID3_dist1-2_dist2-3
                    # First write UIDs (or original headers if no UID)
                    header_parts = []
                    for i, uid in enumerate(entry['uids']):
                        if uid:
                            header_parts.append(uid)
                        else:
                            header_parts.append(entry['headers'][i])

                    header = '_'.join(header_parts)

                    # Add distances if we have valid UIDs
                    if entry['has_uniref'] and all(entry['uids']):
                        distances = self._calculate_genomic_distances(entry)
                        for dist in distances:
                            if dist != -1:
                                header += f"_{dist}"

                sequence = ''.join(entry['sequences']).replace('\x00','')
                f.write(f">{header}\n{sequence}\n")
                sequences_written += 1
                filtered_entries.append(entry)

        print(f"Saved {sequences_written} sequences to {output_file}")
        if sequences_filtered > 0:
            print(f"Filtered out {sequences_filtered} sequences")
        return sequences_written, filtered_entries

    def get_stats(self, entries: Optional[List[Dict]] = None) -> Dict:
        """Get statistics about the MSA

        Args:
            entries: Optional list of entries to calculate stats from.
                    If None, uses all parsed entries.
        """
        if entries is None:
            entries = self.parsed_entries

        if not entries:
            return {}

        num_chains = len(entries[0]['sequences']) if entries else 0

        stats = {
            'num_chains': num_chains,
            'num_entries': len(entries),
        }

        # Per-chain statistics
        for i in range(num_chains):
            coverages = [e['coverages'][i] for e in entries[1:]]  # Skip query
            identities = [e['identities'][i] for e in entries[1:]]
            evalues = [e['evalues'][i] for e in entries[1:] if e['evalues'][i] is not None]
            alnscores = [e['alnscores'][i] for e in entries[1:] if e['alnscores'][i] is not None]

            chain_id = i + 101
            stats[f'chain_{chain_id}'] = {
                'query_length': len(entries[0]['sequences'][i]) if entries else 0,
                'avg_coverage': sum(coverages) / len(coverages) if coverages else 0,
                'avg_identity': sum(identities) / len(identities) if identities else 0,
                'avg_evalue': sum(evalues) / len(evalues) if evalues else 0,
                'avg_alnscore': sum(alnscores) / len(alnscores) if alnscores else 0,
                'min_evalue': min(evalues) if evalues else None,
                'max_alnscore': max(alnscores) if alnscores else None,
            }

        return stats

def get_paired_msa(sequences: Union[str, List[str]],
                   output_file: str,
                   genomic_distance: Optional[int] = 20,
                   min_coverage: Optional[float] = None,
                   min_identity: Optional[float] = None,
                   max_evalue: Optional[float] = None,
                   min_alnscore: Optional[float] = None,
                   prefix: Optional[str] = None,
                   host_url: str = "https://api.colabfold.com",
                   cache_dir: Optional[str] = None,
                   use_cache: bool = True,
                   keep_temp: bool = False) -> str:
    """
    Fetch a (paired) MSA from ColabFold, filter it and write it to disk.

    The server (or cache) is always queried with a genomic distance of 20
    when any distance is requested; the caller's value is applied afterwards
    as a local filter, so one cached result serves every distance setting.

    Args:
        sequences: List of protein sequences or a single string with ':' delimiter
        output_file: Path to save the stitched MSA
        genomic_distance: Genomic distance for pairing (default: 20, None for no filtering)
        min_coverage: Minimum coverage filter (0-1 or 0-100 for percentage)
        min_identity: Minimum identity filter (0-1 or 0-100 for percentage)
        max_evalue: Maximum e-value filter (e.g., 1e-5, 0.001)
        min_alnscore: Minimum alignment score filter
        prefix: Optional prefix for sequence IDs
        host_url: ColabFold API URL
        cache_dir: Directory for caching results (default: ~/.colabfold_cache)
        use_cache: Whether to use caching (default: True)
        keep_temp: If True, keep temporary files for debugging (default: False)

    Returns:
        Path to the output file

    Raises:
        ValueError: when no non-empty sequence remains after clean-up.
    """
    def report(stage, stats, total_suffix=""):
        # Print the entry count and per-chain averages for one stage.
        print(f"\n=== Statistics {stage} filtering ===")
        print(f"Total entries: {stats['num_entries']}{total_suffix}")
        for offset in range(stats['num_chains']):
            chain_stats = stats[f'chain_{offset + 101}']
            print(f"Chain {offset + 101}: query_length={chain_stats['query_length']}, "
                  f"avg_coverage={chain_stats['avg_coverage']:.2f}, "
                  f"avg_identity={chain_stats['avg_identity']:.2f}"
                  )

    # A single string holds ':'-separated chains.
    if isinstance(sequences, str):
        sequences = [part.strip() for part in sequences.split(':') if part.strip()]

    # Strip all whitespace, upper-case, and drop chains that end up empty.
    normalised = (''.join(chain.split()).upper() for chain in sequences)
    sequences = [chain for chain in normalised if chain]

    if not sequences:
        raise ValueError("No valid sequences provided")

    # The caller's distance is applied locally; the fetch always uses 20.
    user_genomic_distance = genomic_distance
    fetch_genomic_distance = None if genomic_distance is None else 20

    temp_context = None
    if keep_temp:
        # mkdtemp leaves the directory in place for inspection.
        temp_dir = tempfile.mkdtemp(prefix="colabfold_")
        print(f"Temporary files will be kept in: {temp_dir}")
    else:
        temp_context = tempfile.TemporaryDirectory()
        temp_dir = temp_context.name

    try:
        msa = ColabFoldPairedMSA(host_url, cache_dir)

        # Remember the cache key so a fresh download gets cached.
        msa._current_cache_key = msa._get_cache_key(sequences, fetch_genomic_distance, prefix)

        job_id, from_cache = msa.submit_or_load_from_cache(
            sequences, fetch_genomic_distance, prefix, use_cache
        )

        if not from_cache:
            # New job: wait for the server, then fetch and parse the result.
            msa.wait()
            msa.download_and_parse(temp_dir)

        # Write into the temp directory first, then move into place.
        temp_output = os.path.join(temp_dir, "temp_output.a3m")
        num_sequences, filtered_entries = msa.save_msa(
            temp_output,
            min_coverage,
            min_identity,
            max_evalue,
            min_alnscore,
            max_genomic_distance=user_genomic_distance
        )
        shutil.move(temp_output, output_file)

        print(f"\nMSA saved to: {output_file}")

        report("BEFORE", msa.get_stats())

        filtering_requested = bool(
            min_coverage
            or min_identity
            or max_evalue is not None
            or min_alnscore is not None
            or user_genomic_distance != fetch_genomic_distance
        )
        if filtering_requested:
            report("AFTER", msa.get_stats(filtered_entries), " (saved)")

        if keep_temp:
            print(f"\nTemporary files kept in: {temp_dir}")
            print("Files:")
            for file in os.listdir(temp_dir):
                print(f"  - {file}")

    finally:
        # Remove the scratch directory unless the caller asked to keep it.
        if not keep_temp:
            temp_context.cleanup()

    return output_file


def get_unique_jobname(base_jobname):
    """Return `base_jobname` if that path is unused, else the first unused `base_jobname_<k>`.

    Suffixes are tried in order (k = 1, 2, 3, ...); a path counts as used when
    os.path.exists() reports it at call time.
    """
    candidate = base_jobname
    suffix = 0
    # Keep bumping the numeric suffix until the candidate no longer exists on disk.
    while os.path.exists(candidate):
        suffix += 1
        candidate = f"{base_jobname}_{suffix}"
    return candidate

def prepare_sequences(sequence, remove_duplicates = True):
    """
    Normalise a raw multi-chain input string into a list of chain sequences.

    Args:
        sequence: Raw sequence string, chains separated by ':'
        remove_duplicates: If True, repeated chains are dropped (first occurrence
            wins, order preserved) and a note is printed when anything was removed

    Returns:
        Tuple of (cleaned sequences, chain break indices), where the break
        indices are the cumulative lengths of all chains except the last one
    """
    # Substitutions are applied in this exact order: drop anything that is not
    # A-Z, ':', '/', '(' or ')'; put parenthesised groups into their own chain;
    # collapse repeated separators; trim separators from both ends.
    substitutions = (
        ("[^A-Z:/()]", ""),
        (r"\(", ":("),
        (r"\)", "):"),
        (":+", ":"),
        ("/+", "/"),
        ("^[:/]+", ""),
        ("[:/]+$", ""),
    )
    cleaned = sequence.upper()
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)

    # Split on ':' and discard empty pieces
    sequences = [chunk for chunk in cleaned.split(":") if chunk]

    if remove_duplicates and len(sequences) > 1:
        # dict keys keep insertion order, so this keeps each chain's first occurrence
        unique_sequences = list(dict.fromkeys(sequences))
        n_removed = len(sequences) - len(unique_sequences)
        if n_removed > 0:
            print(f"Note: Removed {n_removed} duplicate sequence(s)")
        sequences = unique_sequences

    # A break sits after every chain but the last, at the running total length
    chain_breaks = []
    running_length = 0
    for chain in sequences[:-1]:
        running_length += len(chain)
        chain_breaks.append(running_length)

    return sequences, chain_breaks

#################################################################################################
def convert_to_numpy(obj):
    """
    Recursively convert PyTorch tensors to numpy arrays, handling BFloat16 and nested structures.

    Tensors are detached from the autograd graph first, so tensors that require
    grad are converted instead of raising. BFloat16 tensors are upcast to
    Float32 before conversion; every tensor is moved to the CPU.

    Args:
        obj: A tensor, or an arbitrarily nested dict / list / tuple of tensors
            and other values.

    Returns:
        The same structure with every tensor replaced by a numpy array. Dicts,
        lists and tuples are rebuilt as plain dict / list / tuple; any other
        value is returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        # .numpy() refuses tensors that require grad, hence the detach()
        tensor = obj.detach()
        # Upcast BFloat16 to Float32 before handing the data to numpy
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        return tensor.cpu().numpy()
    if isinstance(obj, dict):
        # Recursively convert dictionary values
        return {key: convert_to_numpy(value) for key, value in obj.items()}
    if isinstance(obj, list):
        # Recursively convert list elements
        return [convert_to_numpy(item) for item in obj]
    if isinstance(obj, tuple):
        # Recursively convert tuple elements
        return tuple(convert_to_numpy(item) for item in obj)
    # Return as-is for non-tensor types
    return obj

def clear_gpu_memory(keep_model=True):
    """
    Release unreferenced memory: run the garbage collector, then empty the CUDA cache.

    Only objects that are no longer referenced are freed, so a model that is
    still referenced elsewhere is always kept.

    Args:
        keep_model: Accepted for backward compatibility; it has no effect.
            The previous implementation walked gc.get_objects() and ran
            `del obj` on each tensor, but that only unbinds the loop variable
            and never frees a tensor, so the scan was removed.
    """
    # Several passes so objects released by one collection can be collected by the next
    for _ in range(3):
        gc.collect()

    # Clear CUDA cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

#################################################################################################

def run_msa_analysis(
    msa_file,
    sequences,
    breaks=None,
    max_msa_depth=512,
    mode="contacts",  # "contacts", "conservation", or "jacobian"
    mutation_subset=None,  # For jacobian mode: which mutations to test
    device=None,
    show_progress=True,
    use_query_biasing=True,
    fix_weights=True
):
    """
    Unified function for all MSA Pairformer analyses.

    A first forward pass is always run and its contact predictions and sequence
    weights are always stored. The conservation and/or Jacobian scans are added
    on top when `mode` asks for them.

    Args:
        msa_file: Path to MSA file
        sequences: List of protein sequences or a single string with ':' delimiter
        breaks: Chain break indices (optional - will be computed from sequences if not provided)
        max_msa_depth: Maximum MSA depth
        mode: Analysis mode:
            - "conservation": additionally computes the (L, 20) matrix of p(aa) at each position
            - "jacobian": additionally computes the categorical Jacobian
            - "all": both of the above
            - any other value (e.g. "contacts"): only the initial forward pass
        mutation_subset: For jacobian mode, which mutations to test:
            - None: all 20 amino acids (default)
            - ['F', 'D', 'V']: specific amino acids (matched against aa2tok_d
              as given, then upper-cased)
            - [13, 3, 19]: token indices (0-27 accepted)
        device: Device to run on (defaults to CUDA if available)
        show_progress: Whether to show progress bar (for conservation/jacobian modes)
        use_query_biasing: Turn the model's query biasing on (True) or off (False)
            for the initial forward pass
        fix_weights: If True and query biasing is on, the conservation/jacobian
            passes run with query biasing turned off and reuse the sequence
            weights returned by the initial forward pass

    Returns:
        Dict with keys:
            - "predicted_cb_contacts", "predicted_confind_contacts": first batch
              element of the model's contact outputs
            - "seq_weights_list_d": sequence weights from the initial pass
            - "total_length", "max_msa_depth", "msa_depth", "weight_scale"
            - "conservation": (L, 20) array, only for mode "conservation"/"all"
            - "jacobian": (L, K, L, 20) array of baseline logits minus mutated
              logits, K = len(mutation_subset) or 20 if None, only for mode
              "jacobian"/"all"

    Raises:
        AssertionError: If an entry of `mutation_subset` is not a known token
            or is outside 0-27.
        Exception: Any other error is printed, memory is cleared, and the
            error is re-raised.
    """
    # Handle sequence input
    if isinstance(sequences, str):
        sequences = [seq.strip() for seq in sequences.split(':') if seq.strip()]

    # Compute total length and breaks if not provided
    total_length = sum(len(seq) for seq in sequences)
    if breaks is None:
        breaks = []
        position = 0
        for seq in sequences[:-1]:
            position += len(seq)
            breaks.append(position)

    try:
        # Set device (accepts a torch.device or a device string)
        if device is None:
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        else:
            device = torch.device(device)
        # Autocast must target the device type actually in use, not always "cuda"
        autocast_device_type = device.type

        # Clear memory before starting
        clear_gpu_memory(keep_model=True)

        # Load and process MSA
        np.random.seed(42)
        msa_obj = MSA(
            msa_file_path=msa_file,
            max_seqs=max_msa_depth,
            max_length=total_length,
            max_tokens=1e12,
            diverse_select_method="hhfilter",
            hhfilter_kwargs={"binary": "hhfilter"}
        )

        # Store MSA depth
        msa_depth = msa_obj.n_diverse_seqs

        # Prepare MSA tensors
        msa_tokenized_t = msa_obj.diverse_tokenized_msa
        msa_onehot_t = torch.nn.functional.one_hot(msa_tokenized_t, num_classes=len(aa2tok_d)).unsqueeze(0).float().to(device)

        # Prepare masks
        mask, msa_mask, full_mask, pairwise_mask = prepare_msa_masks(msa_obj.diverse_tokenized_msa.unsqueeze(0))
        mask = mask.to(device)
        msa_mask = msa_mask.to(device)
        full_mask = full_mask.to(device)
        pairwise_mask = pairwise_mask.to(device)

        # Get sequence length
        seq_length = msa_onehot_t.shape[2]

        # Set model to eval mode
        global_model.eval()

        # Configure query-biasing
        if use_query_biasing:
            global_model.turn_on_query_biasing()
        else:
            global_model.turn_off_query_biasing()

        # Initialize common model kwargs
        model_kwargs = {
            'mask': mask,
            'msa_mask': msa_mask,
            'full_mask': full_mask,
            'pairwise_mask': pairwise_mask,
            'complex_chain_break_indices': [breaks] if breaks else None
        }

        # Initialize results dictionary
        results = {}

        with torch.no_grad():
            # Step 1: Always run a forward pass for logits, contacts, and sequence weights
            with torch.amp.autocast(dtype=torch.bfloat16, device_type=autocast_device_type):
                contacts_res = global_model(
                    msa=msa_onehot_t.to(torch.bfloat16),
                    return_seq_weights=True,
                    **model_kwargs
                )
            # Store all contact results
            contacts_numpy = convert_to_numpy(contacts_res)
            results.update({
                "predicted_cb_contacts": contacts_numpy["predicted_cb_contacts"][0],
                "predicted_confind_contacts": contacts_numpy["predicted_confind_contacts"][0],
                "seq_weights_list_d": contacts_numpy["seq_weights_list_d"],
                "total_length": total_length,
                "max_msa_depth": max_msa_depth,
                "msa_depth": msa_depth,
                "weight_scale": msa_onehot_t.shape[1]
            })

            # Step 2: Run conservation and/or categorical jacobian analyses if specified
            if mode in ["conservation", "jacobian", "all"]:
                # Turn off query biasing if using fixed sequence weights from initial forward pass
                if use_query_biasing and fix_weights:
                    global_model.turn_off_query_biasing()
                    model_kwargs['seq_weights_dict'] = contacts_res['seq_weights_list_d']

                # Forward pass returning the query row's first 20 logit columns
                # (or their softmax) as a numpy array
                def f(msa_input, return_probs=False):
                    with torch.amp.autocast(dtype=torch.bfloat16, device_type=autocast_device_type):
                        res = global_model(
                            msa=msa_input.to(torch.bfloat16),
                            return_seq_weights=False,
                            return_contacts=False,
                            query_only=True,
                            **model_kwargs
                        )
                    logits = res['logits'][0, 0, :seq_length, :20].float()
                    if return_probs:
                        return torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
                    return logits.cpu().numpy()

                # Conservation analysis
                if mode in ["conservation", "all"]:
                    # Initialize conservation matrix
                    conservation = np.zeros((seq_length, 20))
                    # Mask each position and compute predicted profiles
                    for n in range(seq_length):
                        if show_progress:
                            print(f"\rConservation: {n+1}/{seq_length} ({100*(n+1)//seq_length}%)", end="", flush=True)
                        # Mask position n of the query (row 0) only
                        msa_h = msa_onehot_t.clone()
                        msa_h[0, 0, n, :] = 0
                        msa_h[0, 0, n, aa2tok_d["<mask>"]] = 1
                        # Compute amino acid probabilities
                        probs = f(msa_h, return_probs=True)
                        conservation[n] = probs[n]
                    if show_progress:
                        print()
                    results['conservation'] = conservation

                if mode in ["jacobian", "all"]:
                    # Get baseline logits
                    fx = f(msa_onehot_t, return_probs=False)
                    wt_sequence = msa_tokenized_t[0].cpu().numpy()

                    # Parse mutation subset
                    if mutation_subset is None:
                        mutation_indices = list(range(20))
                    else:
                        mutation_indices = []
                        for mut in mutation_subset:
                            if isinstance(mut, str):
                                # Validate the same key that is looked up: try the token as
                                # given (e.g. "<mask>"), then its upper-case form (e.g. "f")
                                token = mut if mut in aa2tok_d else mut.upper()
                                assert token in aa2tok_d, f"Invalid mutation: {mut}"
                                mutation_indices.append(aa2tok_d[token])
                            else:
                                # Reject negatives too: they would silently index from the end
                                assert 0 <= int(mut) <= 27, f"Invalid mutation index: {mut}"
                                mutation_indices.append(int(mut))

                    # Initialize jacobian tensor and iterate over positions and mutations
                    fx_h = np.zeros((seq_length, len(mutation_indices), seq_length, 20))
                    for n in range(seq_length):
                        if show_progress:
                            print(f"\rJacobian: {n+1}/{seq_length} ({100*(n+1)//seq_length}%)", end="", flush=True)
                        wt_aa = wt_sequence[n]
                        for idx, mutation_aa in enumerate(mutation_indices):
                            if mutation_aa == wt_aa and mutation_subset is None:
                                # Substituting the wild-type residue leaves the input
                                # unchanged, so reuse the baseline logits
                                fx_h[n, idx] = fx.copy()
                            else:
                                msa_h = msa_onehot_t.clone()
                                msa_h[0, 0, n, :] = 0
                                msa_h[0, 0, n, mutation_aa] = 1
                                fx_h[n, idx] = f(msa_h, return_probs=False)
                    if show_progress:
                        print()
                    # Compute delta (baseline broadcast against every mutated run)
                    results['jacobian'] = fx - fx_h

        # Clean up
        del msa_onehot_t, mask, msa_mask, full_mask, pairwise_mask, msa_tokenized_t, msa_obj
        clear_gpu_memory(keep_model=True)

        return results

    except Exception as e:
        print(f"Error in MSA analysis: {e}")
        clear_gpu_memory(keep_model=True)
        raise

#################################################################################################

# Also include the jac_to_con function from your code
def jac_to_con(jac, center=True, diag="remove", apc=True,
               symm_before=True, symm_after=False):
    """Reduce a 4-D Jacobian to a 2-D contact map.

    Steps, in order: optional symmetrisation of the 4-D array, optional
    mean-centering along every axis longer than 1, Frobenius norm over axes 1
    and 3, optional symmetrisation of the 2-D map, diagonal handling
    ("remove" zeroes it, "normalize" divides by sqrt(diag_i * diag_j)),
    and optional average-product correction. The input is not modified.
    """
    X = jac.copy()
    _dim0, _dim1, _dim2, _dim3 = X.shape  # input must be 4-D
    strip_diagonal = diag == "remove"

    if symm_before:
        X = X + np.transpose(X, (2, 3, 0, 1))

    if center:
        # Centre one axis at a time; singleton axes are left alone
        for axis, extent in enumerate(X.shape):
            if extent > 1:
                X -= X.mean(axis, keepdims=True)

    # Frobenius norm of each (position, position) sub-block
    contacts = np.sqrt(np.square(X).sum((1, 3)))

    if symm_after:
        contacts = contacts + contacts.T

    if strip_diagonal:
        np.fill_diagonal(contacts, 0)
    if diag == "normalize":
        diagonal = np.diag(contacts)
        contacts = contacts / np.sqrt(diagonal[:, None] * diagonal[None, :])

    if apc:
        # Average-product correction: (column sums * row sums) / grand total
        col_sums = contacts.sum(0, keepdims=True)
        row_sums = contacts.sum(1, keepdims=True)
        contacts = contacts - col_sums * row_sums / contacts.sum()

    # The correction repopulates the diagonal, so clear it again
    if strip_diagonal:
        np.fill_diagonal(contacts, 0)

    return contacts

#################################################################################################
import pandas as pd
import numpy as np
from bokeh.io import output_notebook
from bokeh.plotting import figure, show
from bokeh.transform import linear_cmap
from bokeh.palettes import gray, viridis, RdBu

# Configure Bokeh so that show() renders plots inline in the notebook output.
output_notebook()

class ContactAnalyzer:
    """Tabulate and plot a residue-residue contact matrix for one or more chains.

    Residues are addressed by 1-based absolute position over the concatenated
    chains; chains are lettered A, B, C... in input order.
    """

    def __init__(self, contacts, sequences, breaks, title):
        """
        Args:
            contacts: 2-D score matrix indexed as contacts[i, j] (0-based)
            sequences: List of chain sequences, in the order they are concatenated
            breaks: 0-based offsets at which a new chain starts (falsy for a single chain)
            title: Plot title
        """
        self.sequences = sequences
        self.breaks = breaks if breaks else []
        self.contacts = contacts
        self._prepare_data()
        self.title = title

    def _prepare_data(self):
        """Create position mapping with chain info"""
        full_seq = ''.join(self.sequences)
        chain_starts = [0] + self.breaks
        chain_ends = self.breaks + [len(full_seq)]

        # Map each 1-based absolute position to "<chain>:<pos in chain><residue>"
        # (pos_info) and to its chain letter (chain_info)
        self.pos_info = {}
        self.chain_info = {}
        for i, (start, end) in enumerate(zip(chain_starts, chain_ends)):
            chain = chr(65 + i)  # A, B, C...
            seq = self.sequences[i]
            for j in range(start, end):
                pos_in_chain = j - start + 1
                abs_pos = j + 1
                self.pos_info[abs_pos] = f"{chain}:{pos_in_chain}{seq[j - start]}"
                self.chain_info[abs_pos] = chain

    def get_table(self, min_score=None):
        """Get contact table for display.

        Args:
            min_score: Keep only pairs scoring at least this value (None keeps all)

        Returns:
            DataFrame with one row per residue pair i < j, ordered by descending
            numeric score. 'Score' holds the score formatted to three decimals.
            The frame is empty (with the same columns) when no pair qualifies.
        """
        columns = ['Residue i', 'Residue j', 'Score', 'Chain i', 'Chain j', 'Type']
        data = []
        n = self.contacts.shape[0]

        for i in range(n):
            for j in range(i + 1, n):  # Upper triangle only
                score = self.contacts[i, j]
                if min_score is None or score >= min_score:
                    chain_i = self.chain_info[i + 1]
                    chain_j = self.chain_info[j + 1]
                    interaction = 'intra' if chain_i == chain_j else 'inter'

                    data.append({
                        'Residue i': self.pos_info[i + 1],
                        'Residue j': self.pos_info[j + 1],
                        'Score': float(score),
                        'Chain i': chain_i,
                        'Chain j': chain_j,
                        'Type': interaction
                    })

        # Passing the columns explicitly keeps an empty result sortable
        df = pd.DataFrame(data, columns=columns)
        # Sort on the numeric score and format afterwards: sorting the formatted
        # strings is lexicographic and mis-orders negative scores
        df = df.sort_values('Score', ascending=False, kind='stable')
        df['Score'] = df['Score'].map("{:.3f}".format)
        return df

    def get_plot_data(self, threshold=None):
        """Get filtered contact data for plotting.

        Returns one row per (i, j) cell of the full matrix whose score is at
        least `threshold` (all cells when threshold is None). 'i' and 'j' are
        1-based positions stored as strings to match the plot's categorical axes.
        """
        data = []
        n = self.contacts.shape[0]

        for i in range(n):
            for j in range(n):
                if threshold is None or self.contacts[i, j] >= threshold:
                    chain_i = self.chain_info[i + 1]
                    chain_j = self.chain_info[j + 1]

                    data.append({
                        'i': str(i + 1),
                        'j': str(j + 1),
                        'value': self.contacts[i, j],
                        'label_i': self.pos_info[i + 1],
                        'label_j': self.pos_info[j + 1],
                        'type': 'intra' if chain_i == chain_j else 'inter'
                    })

        return pd.DataFrame(data)

    def plot(self, threshold=None, size=800):
        """Create bokeh plot.

        Draws the contact map as a greyscale heat map (colour range spans the
        plotted values) with red lines at the chain breaks.

        Args:
            threshold: Only cells scoring at least this value are drawn (None draws all)
            size: Width and height of the figure in pixels
        """
        from bokeh.plotting import figure, show
        from bokeh.transform import linear_cmap
        from bokeh.palettes import gray

        df = self.get_plot_data(threshold)
        n = self.contacts.shape[0]

        p = figure(
            width=size, height=size,
            x_range=[str(i) for i in range(1, n + 1)],
            y_range=[str(i) for i in range(1, n + 1)][::-1],
            tools="hover,save",
            tooltips=[
                ("Residue i", "@label_i"),
                ("Residue j", "@label_j"),
                ("Score", "@value{0.000}"),
                ("Type", "@type")
            ],
            title=self.title
        )
        p.title.text_font_size = '16pt'

        # Reversed grey palette: higher scores are drawn darker
        p.rect(x="i", y="j", width=1, height=1, source=df,
               fill_color=linear_cmap('value', gray(256)[::-1],
                                      low=df.value.min(),
                                      high=df.value.max()
                                      ),
               line_color=None)

        # Add chain breaks
        for b in self.breaks:
            p.line([str(b + 1)] * 2, ['1', str(n)], color='red', width=2)
            p.line(['1', str(n)], [str(b + 1)] * 2, color='red', width=2)

        p.xaxis.visible = False
        p.yaxis.visible = False
        p.grid.visible = False

        show(p)
#################################################################################################

def _setup_tools():
  """Download and compile C++ tools."""

  # Install HHsuite
  hhsuite_path = "hhsuite"
  if not os.path.isdir(hhsuite_path):
      print("Installing HHsuite...")
      os.makedirs(hhsuite_path, exist_ok=True)
      url = "https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz"
      os.system(f"curl -fsSL {url} | tar xz -C {hhsuite_path}/")

  os.environ['PATH'] += f":{hhsuite_path}/bin:{hhsuite_path}/scripts"


python_version = f"{version_info.major}.{version_info.minor}"

# this part might not be needed
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Suppress progress bars and warnings
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.distributed.reduce_op.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*HF_TOKEN.*")

# Install MSA Pairformer (only if not already installed)
if not os.path.isdir("MSA_Pairformer"):
    print("Setting up MSA Pairformer...")

    GIT_REPO = 'https://github.com/yoakiyama/MSA_Pairformer'
    TMP_DIR = "tmp"
    os.makedirs(TMP_DIR, exist_ok=True)

    # Clone quietly (output is captured), but fail loudly: without the
    # checkout the imports below cannot work
    result = subprocess.run(
        ["git", "clone", f"{GIT_REPO}.git"],
        capture_output=True,
        text=True
    )
    if result.returncode != 0:
        raise RuntimeError(f"git clone of {GIT_REPO} failed:\n{result.stderr}")

    # Install the package itself (editable, without its pinned dependencies),
    # then the extra dependencies; pip output is captured and only shown on failure
    pip_commands = (
        ["pip", "install", "-e", "MSA_Pairformer/", "--no-deps"],
        ["pip", "install", "biopython", "einx", "jaxtyping"],
    )
    for pip_command in pip_commands:
        pip_result = subprocess.run(
            pip_command,
            capture_output=True,
            text=True
        )
        if pip_result.returncode != 0:
            raise RuntimeError(f"{' '.join(pip_command)} failed:\n{pip_result.stderr}")

    importlib.invalidate_caches()
    # Add the package to Python path
    package_path = os.path.abspath("MSA_Pairformer")
    if package_path not in sys.path:
        sys.path.insert(0, package_path)

    print("✓ MSA Pairformer installed successfully")

# Import MSA Pairformer modules
from MSA_Pairformer.model import MSAPairformer
from MSA_Pairformer.dataset import MSA, aa2tok_d, prepare_msa_masks

# Initialize device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device_name = torch.cuda.get_device_name(device) if device.type == 'cuda' else 'CPU'
print(f"Using device: {device_name}")

# Load model ONCE and store globally
if 'global_model' not in globals():
    print("Loading MSA Pairformer model (this will only happen once)...")

    # Suppress HuggingFace warnings during model loading
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        # Swallow anything the loader writes to stderr
        with io.StringIO() as buf, redirect_stderr(buf):
            global_model = MSAPairformer.from_pretrained(device=device).to(torch.bfloat16)

    print("✓ Model loaded successfully and cached for reuse!")
else:
    print("✓ Using cached model")

_setup_tools()

In [None]:
# @title Settings
#@markdown **inputs**
sequence_a = "PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK" # @param {"type":"string"}
sequence_b = "" # @param {"type":"string"}
sequence_c = "" # @param {"type":"string"}
jobname = "test"# @param {"type":"string"}

# Drop every non-word character so the job name can be used as a directory name
jobname = re.sub(r'\W+', '', jobname)
# Join the chains with ':'; prepare_sequences discards the empty ones
sequence = f"{sequence_a}:{sequence_b}:{sequence_c}"

# sequences: cleaned chains; breaks: cumulative chain lengths (all but the last chain)
sequences, breaks = prepare_sequences(sequence)
print("lengths",[len(x) for x in sequences])

#@markdown ----
#@markdown **MSA options**
msa_method = "mmseqs2" #@param ["mmseqs2", "custom_fas", "custom_a3m", "custom_sto"]

#@markdown **MSA filters** (only applied when msa_method=mmseqs2)
cov = 75 #@param ["0", "25", "50", "75", "90", "99"] {type:"raw"}
qid = 15 #@param ["0", "15", "20", "30", "40"] {type:"raw"}
#@markdown For MSA Pairformer analyses, we typically recommend starting with
#@markdown 75% coverage (cov), and 15% minimum sequence identity with query (qid).

#@markdown ----
#@markdown **Multimer settings** (experimental option)
neighbor_stitching = True #@param {type:"boolean"}
Δgene = 1 #@param ["0", "1", "5", "10", "20"] {type:"raw"}
#@markdown For prokaryotes, it's sometimes helpful to stitch genes based on how far apart the genes are on the genome.

# Use a fresh output directory so results of earlier runs are never overwritten
jobname = get_unique_jobname(jobname)
os.makedirs(jobname, exist_ok=True)

msa_file = f"{jobname}/msa.a3m"
if msa_method != "mmseqs2":
    # Custom alignment: the format is the part of the method name after '_'
    msa_format = msa_method.split("_")[1]
    print(f"upload {msa_method}")
    msa_dict = files.upload()

    # Concatenate the lines of every uploaded file
    lines = []
    for payload in msa_dict.values():
        lines.extend(payload.decode().splitlines())

    # Strip NUL characters, then drop empty lines and '#' comment lines
    input_lines = [
        text
        for text in (raw.replace("\x00","") for raw in lines)
        if len(text) > 0 and not text.startswith('#')
    ]

    with open(f"{jobname}/msa.{msa_format}","w") as msa:
        msa.write("\n".join(input_lines))

    # Anything that is not already a3m is converted with HHsuite's reformat.pl
    if msa_format != "a3m":
        os.system(f"perl hhsuite/scripts/reformat.pl {msa_format} a3m {jobname}/msa.{msa_format} {msa_file}")
else:
    # Build the MSA with the filters chosen in the form; the gene distance is
    # only passed on when neighbor stitching is enabled
    get_paired_msa(
        sequences,
        msa_file,
        min_coverage=cov,
        min_identity=qid,
        genomic_distance=Δgene if neighbor_stitching else None,
    )

print(f"MSA saved to: {msa_file}")

In [None]:
#@title run MSA Pairformer

mode = "sequence_weights" # @param ["cb_contacts","confind_contacts","conservation","jacobian_approx","jacobian","sequence_weights", "all"]
# max_msa_depth = 512 # @param ["64","128","256","512","1024"] {"type":"raw"}
max_msa_depth = 512 # @param {type:"integer"}
use_query_biasing = True # @param {type:"boolean"}
# Apply the query-biasing choice to the shared model
# (run_msa_analysis sets it again from the same flag)
if use_query_biasing:
  global_model.turn_on_query_biasing()
else:
  global_model.turn_off_query_biasing()

# Use only CERN to compute jacobian if approximating
# ("jacobian_approx" is rewritten to "jacobian", the name run_msa_analysis checks for;
# only the mutation subset and the plot title differ)
if mode == "jacobian_approx":
    mutation_subset = ["C","E","R","N"]
    jacobian_title = "Approximate categorical Jacobian (C, E, R, N)"
    mode = "jacobian"
else:
    mutation_subset = None
    jacobian_title = "Categorical Jacobian"

# Run MSA analyses
# (contacts and sequence weights are always computed; conservation / jacobian
# only for the modes "conservation", "jacobian" and "all")
results = run_msa_analysis(
    msa_file=msa_file,
    sequences=sequences,
    breaks=breaks,
    max_msa_depth=max_msa_depth,
    mode=mode,
    mutation_subset=mutation_subset,
    device=device,
    show_progress=True,
    use_query_biasing=use_query_biasing,
    fix_weights=True
)

# Save and plot predicted amino acid profiles
if mode in ["conservation", "all"]:
    # Save results
    conservation = results['conservation']
    np.savetxt(f"{jobname}/conservation.txt",conservation)
    # Plot results
    # NOTE(review): column labels assume tokens 0-19 follow this amino acid order — confirm against aa2tok_d
    alphabet = "ARNDCEQGHILKMFPSTWYV"
    sequence_length = conservation.shape[0]
    idx = [str(i) for i in np.arange(1, sequence_length+1)]
    df = pd.DataFrame(conservation, index=idx, columns=list(alphabet))
    # Long format: one row per (position, amino acid) pair
    df = df.stack().reset_index()
    df.columns = ['Position', 'Amino Acid', 'Probability']
    num_colors = 256
    palette = viridis(num_colors)
    TOOLS="hover,save,pan,box_zoom,reset,wheel_zoom"
    p = figure(title="CONSERVATION",
              x_range=[str(x) for x in range(1, sequence_length+1)],
              y_range=list(alphabet)[::-1],
              width=900, height=400,
              tools=TOOLS, toolbar_location='below',
              tooltips=[('Position', '@Position'), ('Amino Acid', '@{Amino Acid}'), ('Probability', '@Probability')])

    r = p.rect(x="Position", y="Amino Acid", width=1, height=1, source=df,
              fill_color=linear_cmap('Probability', palette, low=0, high=1),
              line_color=None)
    p.xaxis.visible = False  # Hide the x-axis
    show(p)

# Save and plot predicted Cb-Cb contacts
if mode in ["cb_contacts", "all"]:
    # Save predicted Cb-Cb contacts
    cb_contacts = results["predicted_cb_contacts"]
    np.savetxt(f"{jobname}/predicted_cb_contacts.txt", cb_contacts)
    # Plot predicted contacts
    cb_analyzer = ContactAnalyzer(cb_contacts, sequences, breaks, "Prediced Cb-Cb contacts")
    cb_analyzer.plot(size=512)

# Save and plot predicted ConFind contacts
if mode in ["confind_contacts", "all"]:
    # Save predicted confind contacts
    confind_contacts = results["predicted_confind_contacts"]
    np.savetxt(f"{jobname}/predicted_confind_contacts.txt", confind_contacts)
    # Plot predicted contacts
    confind_analyzer = ContactAnalyzer(confind_contacts, sequences, breaks, "Predicted ConFind contacts")
    confind_analyzer.plot(size=512)

# Save and plot categorical jacobian (or approximated version)
if mode in ["jacobian", "jacobian_approx", "all"]:
    # Convert the jacobian to a contact map and save it
    jac_contacts = jac_to_con(results["jacobian"], symm_before=False, symm_after=True)
    np.savetxt(f"{jobname}/jacobian_contacts.txt", jac_contacts)
    # Plot catjac results
    catjac_analyzer = ContactAnalyzer(jac_contacts, sequences, breaks, jacobian_title)
    catjac_analyzer.plot(size=512)

if mode in ["sequence_weights", "all"]:
    # Save sequence weights
    with open(f"{jobname}/sequence_weights.pkl", "wb") as f:
        pickle.dump(results["seq_weights_list_d"], f)
    # Plot the distribution of per-sequence weights, averaged (mean) over layers 0-15
    mean_seq_weights_a = np.mean(np.stack([results['seq_weights_list_d'][f"layer_{layer_idx}"][0] for layer_idx in range(16)]), axis=0)
    # Scale by the number of MSA rows; presumably this maps uniform weighting to 1
    # (the dashed red line) — assumes each layer's weights sum to 1, TODO confirm
    mean_seq_weights_a *= results["weight_scale"]
    f, ax = plt.subplots(1, 1, figsize=(8,4))
    _ = ax.hist(mean_seq_weights_a, bins=50)
    ax.axvline(x=1, linestyle='--', color='red')
    ax.set_title("Sequence weight distribution", size=18)
    ax.set_xlabel("Normalized sequence weight", size=16)
    ax.set_ylabel("Count", size=16)
    ax.tick_params(axis='both', which='major', labelsize=12)
    plt.show()

In [None]:
#@title Show table of top predicted contacts
from google.colab import data_table
from IPython.display import display, Markdown


def _show_contact_table(analyzer, heading):
    """Display `heading` as Markdown, then the analyzer's unfiltered contact table; return the table."""
    table_df = analyzer.get_table(min_score=None)
    display(Markdown(heading))
    display(data_table.DataTable(table_df, include_index=False, num_rows_per_page=20,  min_width=10))
    return table_df


# Each analyzer only exists if its mode was run, so every table is guarded by the same mode check
if mode in ("all", "cb_contacts"):
    cb_df = _show_contact_table(cb_analyzer, "### Top Predicted Cb-Cb Contacts")

if mode in ("all", "confind_contacts"):
    confind_df = _show_contact_table(confind_analyzer, "### Top Predicted ConFind Contacts")

if mode in ("all", "jacobian", "jacobian_approx"):
    catjac_df = _show_contact_table(catjac_analyzer, "### Top Predicted Categorical Jacobian Contacts")

In [None]:
#@title download results (optional)
from google.colab import files

# Zip the whole job directory next to it, then send the archive to the browser
archive_name = f'{jobname}.zip'
os.system(f"zip -r {jobname}.zip {jobname}/")
files.download(archive_name)