In [3]:
# @title Mount Drive and Install BRI Package (Fast Mode)
from google.colab import drive
import os

drive.mount('/content/drive')

package_path = '/content/drive/MyDrive/BRI Analysis/backbone_rigid_invariant-1.2.2.tar.gz'

if os.path.exists(package_path):
  print(f"Found package at: {package_path}")
  print("Installing without recompiling pandas...")

  # 2. Install your package, explicitly telling pip NOT to check versions
  # This bypasses the errors about "incompatible versions"
  !pip install "{package_path}"

print("Installation complete. Ignoring version mismatch warnings.")


Mounted at /content/drive
Found package at: /content/drive/MyDrive/BRI Analysis/backbone_rigid_invariant-1.2.2.tar.gz
Installing without recompiling pandas...
Processing ./drive/MyDrive/BRI Analysis/backbone_rigid_invariant-1.2.2.tar.gz
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting biotite==1.2.0 (from backbone-rigid-invariant==1.2.2)
  Downloading biotite-1.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.3 kB)
Collecting matplotlib==3.7.2 (from backbone-rigid-invariant==1.2.2)
  Downloading matplotlib-3.7.2.tar.gz (38.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m38.1/38.1 MB[0m [31m50.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecti



In [1]:
import bri
import os
import os
import time
import warnings
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from multiprocessing import Pool, cpu_count
from datetime import date

# Import the specific cleaning tools
from bri.filter import minientry_integrated_cleaning

# Work relative to the project folder on Drive so the './data/...' paths used
# by the cells below resolve there. This path lives under the Drive mount
# created in the install cell, so that cell must have run first.
os.chdir('/content/drive/MyDrive/BRI Analysis')

In [None]:
# @title Bulk Cleaning (Fast Multiprocessing Version)

# --- 1. Define Helper Function for Parallelization ---
# This must be defined at the top level (outside the loop) to work with multiprocessing
def process_single_entry(pdb_id):
    """
    Run the package's integrated cleaning on one PDB entry.

    Defined at cell top level so multiprocessing can pickle it by name.

    Returns:
        (pdb_id, clean_df, dirty_df) on success; None when the cleaner
        returns None, returns something that is not a 2-item result, or
        raises any exception (the caller skips None results).
    """
    try:
        outcome = minientry_integrated_cleaning(pdb_id)
        if outcome is None:
            return None
        # Unpacking stays inside the try so a malformed result is skipped too.
        clean_part, dirty_part = outcome
    except Exception:
        # Best-effort: a failing entry is dropped rather than killing the pool.
        return None
    return (pdb_id, clean_part, dirty_part)

# --- 2. Setup Input/Output ---
input_file = './data/entry_ids_30Jan25.txt'
today = date.today()
# ddMonYY (e.g. 30Jan25), matching the input file's naming. The stamp must not
# contain '/': a slash would be treated as a directory separator and to_csv
# would fail on the nonexistent sub-directories.
d1 = today.strftime("%d%b%y")
clean_output_file = f'./data/cleaned_connective_chains_{d1}.csv'
dirty_output_file = f'./data/dropout_chains_{d1}.csv'

os.makedirs('./data', exist_ok=True)

# --- 3. Read PDB IDs ---
# The ID file may be comma- and/or newline-separated and may quote the IDs.
try:
    with open(input_file, 'r') as f:
        content = f.read()
    raw_ids = (pid.replace("'", "").replace('"', "").strip()
               for pid in content.replace('\n', ',').split(','))
    # Drop tokens that are empty after quote stripping, and de-duplicate while
    # keeping first-seen order (dict preserves insertion order), so no entry is
    # processed twice and the "unique" count below is accurate.
    pdb_ids = list(dict.fromkeys(pid for pid in raw_ids if pid))

    # OPTIONAL: Uncomment to test on a small subset first
    #pdb_ids = pdb_ids[0:100]

    print(f"Loaded {len(pdb_ids)} unique PDB IDs to process.")
except FileNotFoundError:
    print(f"ERROR: Could not find {input_file}")
    pdb_ids = []

# --- 4. Run Parallel Processing ---
if pdb_ids:
    print(f"Starting processing with {cpu_count()} CPU cores...")
    # At least 4 workers: on small machines (e.g. 2-core Colab) the extra
    # processes overlap network latency from downloads; on larger machines
    # this is one worker per core.
    num_workers = max(cpu_count(), 4)

    start_time = time.time()

    clean_chains_list = []
    dirty_chains_list = []

    # --- 5. Process & aggregate results as they stream in ---
    with Pool(processes=num_workers) as pool:
        # imap (ordered) keeps output rows in input-ID order. A small chunksize
        # cuts IPC round-trips for ~10^5 short tasks without hurting load
        # balancing much. Results are consumed directly instead of first being
        # materialised into one large list.
        results_iter = pool.imap(process_single_entry, pdb_ids, chunksize=8)
        for res in tqdm(results_iter, total=len(pdb_ids), desc="Processing Entries"):
            if res is None:
                continue

            _, clean_df, dirty_df = res

            if clean_df is not None and not clean_df.empty:
                clean_chains_list.append(clean_df)

            if dirty_df is not None and not dirty_df.empty:
                dirty_chains_list.append(dirty_df)

    end_time = time.time()
    total_duration = end_time - start_time

    # --- 6. Save Files ---
    print("\nSaving to disk...")

    if clean_chains_list:
        final_clean_df = pd.concat(clean_chains_list, ignore_index=True)
        # Keep only the downstream columns that are actually present
        cols = ['pdb_id', 'entity_id', 'model_id', 'chain_id', 'start_residue', 'chain_length', 'seq']
        final_cols = [c for c in cols if c in final_clean_df.columns]
        final_clean_df = final_clean_df[final_cols]

        final_clean_df.to_csv(clean_output_file, index=False)
        print(f"✅ CLEANED chains saved to: {clean_output_file}")
        print(f"   Count: {len(final_clean_df)} chains")
    else:
        print("⚠️ No clean chains found.")

    if dirty_chains_list:
        final_dirty_df = pd.concat(dirty_chains_list, ignore_index=True)
        final_dirty_df.to_csv(dirty_output_file, index=False)
        print(f"❌ REJECTED chains saved to: {dirty_output_file}")
        print(f"   Count: {len(final_dirty_df)} chains")
    else:
        print("No rejected chains found.")

    # Report Stats
    print("\n" + "="*30)
    print(f"Total Time:      {total_duration:.2f}s ({total_duration/60:.2f} min)")
    print(f"Avg Time/Entry:  {total_duration/len(pdb_ids):.4f}s")
    print("="*30)

Loaded 230744 unique PDB IDs to process.
Starting processing with 44 CPU cores...


Processing Entries:   0%|          | 0/230744 [00:00<?, ?it/s]

Aggregating results...

Saving to disk...
✅ CLEANED chains saved to: ./data/cleaned_connective_chains.csv
   Count: 726838 chains
❌ REJECTED chains saved to: ./data/dropout_chains.csv
   Count: 3817208 chains

Total Time:      4288.74s (71.48 min)
Avg Time/Entry:  0.0186s


In [2]:
# @title Read back in data

# Paths of the CSVs produced by a previous bulk-cleaning run.
# NOTE(review): the bulk-cleaning cell writes date-stamped file names; point
# these at the run you want to load.
clean_output_file = './data/cleaned_connective_chains.csv'
dirty_output_file = './data/dropout_chains.csv'


def _read_chains_csv(path):
    """
    Load a chains CSV written by DataFrame.to_csv(index=False).

    pandas' default NA parsing turns strings such as 'NA', 'N/A', 'NULL' or
    'nan' into NaN, which corrupts legitimate chain IDs (e.g. chain 'NA') and
    short sequences. Only empty fields — what to_csv writes for missing
    values — are treated as missing here. chain_id is always read as str, so
    numeric-looking IDs are never inferred as numbers (which could otherwise
    come back as e.g. '1.0').

    Returns:
        DataFrame with chain_id as str; empty chain IDs are labelled 'NA'.
    """
    df = pd.read_csv(path, keep_default_na=False, na_values=[''],
                     dtype={'chain_id': str})
    df['chain_id'] = df['chain_id'].fillna('NA').astype(str)
    return df


final_clean_df = _read_chains_csv(clean_output_file)
final_dirty_df = _read_chains_csv(dirty_output_file)

In [3]:
# Preview the first five cleaned chains
final_clean_df.head()

Unnamed: 0,pdb_id,model_id,chain_id,start_residue,chain_length,seq
0,101M,1,A,1,154,MVLSEGEWQLVLHVWAKVEADVAGHGQDILIRLFKSHPETLEKFDR...
1,102L,1,A,1,163,MNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAAKSE...
2,102M,1,A,1,154,MVLSEGEWQLVLHVWAKVEADVAGHGQDILIRLFKSHPETLEKFDR...
3,103M,1,A,1,154,MVLSEGEWQLVLHVWAKVEADVAGHGQDILIRLFKSHPETLEKFDR...
4,104L,1,A,1,164,MNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSAA...


In [None]:
# @title Robust Batched & Multithreaded BRI
import os
import time
import pandas as pd
import numpy as np
import gc
from tqdm.notebook import tqdm
from multiprocessing import Pool, cpu_count
from bri.pdbx2df import MiniChain

# --- 1. The Worker Function ---
# Geometry columns coerced to float64 when present in a chain's BRI frame.
_GEOM_COLS = [
    'x(AN)', 'x(AC)', 'y(AC)', 'x(N)', 'y(N)', 'z(N)',
    'x(A)', 'y(A)', 'z(A)', 'x(C)', 'y(C)', 'z(C)',
    'length(N)', 'length(A)', 'length(C)',
    'angle(N)', 'angle(A)', 'angle(C)',
    'tau(NA)', 'tau(AC)', 'tau(CN)'
]


def process_single_chain(row):
    """
    Compute the BRI table for one chain (multiprocessing worker).

    Args:
        row: dict with keys 'pdb_id', 'model_id', 'chain_id',
            'start_residue' and 'chain_length' (one record of
            final_clean_df). A plain dict is used instead of a pandas Series
            because it is cheaper to send to worker processes.

    Returns:
        The DataFrame from MiniChain.get_chain_invariant(angles=True), with
        the chain metadata attached as columns, geometry columns coerced to
        float64 (unparseable -> NaN), 'residue_id' coerced to int
        (unparseable -> -1) and 'residue_label' cast to str, each only when
        the column exists. Returns None if the result is None/empty or any
        step raises.
    """
    try:
        mc = MiniChain(
            pdb_id=row['pdb_id'],
            model_id=row['model_id'],
            chain_id=row['chain_id'],
            start_residue=row['start_residue'],
            chain_length=row['chain_length']
        )

        # Named chain_bri so it does not shadow the imported `bri` package.
        chain_bri = mc.get_chain_invariant(angles=True)
        if chain_bri is None or chain_bri.empty:
            return None

        # Re-attach meta-data
        chain_bri = chain_bri.assign(
            pdb_id=str(row['pdb_id']),
            chain_id=str(row['chain_id']),
            model_id=int(row['model_id']),
            start_residue=int(row['start_residue']),
            chain_length=int(row['chain_length'])
        )

        valid_floats = [c for c in _GEOM_COLS if c in chain_bri.columns]
        if valid_floats:
            chain_bri[valid_floats] = (
                chain_bri[valid_floats]
                .apply(pd.to_numeric, errors='coerce')
                .astype('float64')
            )

        # Guarded like the float columns: a missing column used to raise a
        # KeyError that silently discarded the whole chain.
        if 'residue_id' in chain_bri.columns:
            chain_bri['residue_id'] = (
                pd.to_numeric(chain_bri['residue_id'], errors='coerce')
                .fillna(-1)
                .astype(int)
            )

        if 'residue_label' in chain_bri.columns:
            chain_bri['residue_label'] = chain_bri['residue_label'].astype(str)

        return chain_bri

    except Exception:
        # Best-effort: skip chains that fail. In production, consider logging
        # the pdb_id/chain_id and exception to a file here.
        return None

# --- 2. Main Execution Block ---

def _run_batch(pool, batch, chunksize):
    """
    Compute BRI for every chain record in `batch` using `pool`.

    Returns:
        (batch_df, n_chains): the concatenated per-chain frames and how many
        chains produced a result, or (None, 0) when none did. Keeping the
        work in a function lets the per-chain list be freed on return.
    """
    # imap_unordered yields results as soon as any worker finishes, so row
    # order within a batch file does not follow the input order.
    results = [
        res
        for res in pool.imap_unordered(process_single_chain, batch, chunksize=chunksize)
        if res is not None
    ]
    if not results:
        return None, 0
    return pd.concat(results, ignore_index=True), len(results)


if __name__ == '__main__':

    OUTPUT_DIR = './data/bri_computations'
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    BATCH_SIZE = 5000  # chains per output parquet file
    # Batches with index < START_BATCH are skipped. 104 reproduces the previous
    # resume point (presumably batches 0-103 were written by an earlier run);
    # set to 0 for a full run.
    START_BATCH = 104
    # Set True to also skip batches whose parquet file already exists, e.g. to
    # resume an interrupted run without recomputing finished batches.
    SKIP_EXISTING = False
    CHUNK_SIZE = 10  # chains sent to a worker per IPC message
    # Workers are replaced after this many *tasks*. With chunked imap one task
    # is one chunk, so a worker handles up to MAX_TASKS_PER_CHILD * CHUNK_SIZE
    # chains before being recycled; this bounds memory growth in workers.
    MAX_TASKS_PER_CHILD = 25

    print("Preparing data for processing...")

    # Plain dicts are cheaper to pickle to worker processes than pandas Series.
    all_records = final_clean_df.to_dict('records')
    total_records = len(all_records)

    # List of lists: [[row0, ..., row4999], [row5000, ...], ...]
    batches = [all_records[i:i + BATCH_SIZE] for i in range(0, total_records, BATCH_SIZE)]

    print(f"Total Chains: {total_records}")
    print(f"Total Batches: {len(batches)} (Size: {BATCH_SIZE})")

    time_start = time.time()
    total_chains_calculated = 0

    n_cores = max(cpu_count() - 1, 1)  # Leave 1 core for OS/Main process

    with Pool(processes=n_cores, maxtasksperchild=MAX_TASKS_PER_CHILD) as pool:
        for i, batch in enumerate(tqdm(batches, desc="Processing Batches")):
            # One file per batch index: batch_0.parquet, batch_1.parquet, ...
            save_path = f'{OUTPUT_DIR}/batch_{i}.parquet'
            if i < START_BATCH or (SKIP_EXISTING and os.path.exists(save_path)):
                continue

            batch_df, n_batch_chains = _run_batch(pool, batch, CHUNK_SIZE)
            if batch_df is not None:
                batch_df.to_parquet(save_path, index=False, compression='snappy')
                total_chains_calculated += n_batch_chains

            # --- MEMORY CLEANUP ---
            # Drop the batch frame before the next batch and force collection.
            del batch_df
            gc.collect()

    end_time = time.time()
    total_duration = end_time - time_start

    print("\n✅ Processing complete.")
    if total_chains_calculated > 0:
        print(f"{total_chains_calculated} valid chains calculated in {total_duration:.2f} seconds.")
        print(f"Speed: {total_duration/total_chains_calculated:.4f} seconds/chain")
    else:
        print("No valid chains were calculated.")

Preparing data for processing...
Total Chains: 726838
Total Batches: 146 (Size: 5000)


Processing Batches:   0%|          | 0/146 [00:00<?, ?it/s]


✅ Processing complete.
206628 valid chains calculated in 14198.71 seconds.
Speed: 0.0687 seconds/chain


In [2]:
# @title Code to create AlphaFold predictions

In [None]:
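# TODO: quickly generate ~700k MMseqs2 .a3m MSA files, one per cleaned chain
# TODO: predict structures from those MSAs
# TODO: compute BRI on the predicted structures

# Open question: how long will this take? It might require far too much compute.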

In [None]:
# @title Auto-Shutdown Runtime
import os
import time

from google.colab import drive, runtime

print("Job complete. Shutting down runtime in 30 seconds to save compute units...")

# Files written under /content/drive are synced to Google Drive asynchronously;
# flush them explicitly before releasing the VM so recent outputs are not lost.
# Leave the mount first, because an earlier cell chdir'ed into it.
os.chdir('/content')
try:
    drive.flush_and_unmount()
except Exception as exc:  # e.g. Drive was never mounted in this session
    print(f"Drive flush skipped: {exc}")

time.sleep(30)  # Small buffer to ensure logs/files are saved
runtime.unassign()

Job complete. Shutting down runtime in 30 seconds to save compute units...
