In [1]:
from google.colab import drive
import os

drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tqdm
from scipy.spatial import KDTree
import ast
import seaborn as sns
import time
import os
from multiprocessing import Pool
import requests
import ast
from matplotlib.colors import LogNorm
import pickle
import scipy.sparse as sp
from scipy.signal import convolve2d
import gc

In [3]:
os.chdir('/content/drive/MyDrive/BRI Analysis')

In [6]:
restrict_suffix = ""

### Now do some plotting

In [7]:
import os
import pandas as pd
import tqdm
import gc

# @title 1. Generate Mean Invariants with Batch Numbers
inv_dir = "./data/bri_computations"
output_file = "./data/PDB727K_mean_invariants_with_batch.csv"

if os.path.exists(output_file):
    os.remove(output_file)

# Get files and sort them to ensure deterministic order (optional but good practice)
files = sorted([f for f in os.listdir(inv_dir) if f.endswith('.parquet')])

for i, filename in enumerate(tqdm.tqdm(files, desc="Processing Batches")):
    try:
        # Extract batch number from filename "batch_123.parquet"
        # Adjust split logic if your naming convention differs
        try:
            batch_num = int(filename.split('_')[1].split('.')[0])
        except (IndexError, ValueError):
            # Fallback if filename is weird, though user stated "batch_i.parquet"
            print(f"Warning: Could not parse batch number from {filename}. assigning {i}.")
            batch_num = i

        inv_data = pd.read_parquet(os.path.join(inv_dir, filename))

        # Calculate means
        mean_data = calculate_means(inv_data)

        # --- ADD BATCH NUMBER ---
        mean_data['batch_number'] = batch_num

        # Write incrementally
        mode = 'w' if i == 0 else 'a'
        header = (i == 0)
        mean_data.to_csv(output_file, index=False, mode=mode, header=header)

        del inv_data, mean_data
        if i % 10 == 0: gc.collect()

    except Exception as e:
        print(f"Skipping {filename} due to error: {e}")

Processing Batches: 100%|██████████| 146/146 [08:23<00:00,  3.45s/it]


In [14]:
import pandas as pd
import numpy as np
from scipy.spatial import cKDTree
import os
import tqdm
import gc

# @title 2. Mean Comparison (Partitioned by Chain Length) with Distance
restrict_suffix = ""
input_path = "./data/PDB727K_mean_invariants_with_batch.csv"
output_csv = f"./data/PDB727K_mean_pairs_chebyshev_001{restrict_suffix}.csv" # Updated filename for clarity

bri_features = ['x(N)', 'y(N)', 'z(N)', 'x(A)', 'y(A)', 'z(A)', 'x(C)', 'y(C)', 'z(C)']
id_cols = ['pdb_id', 'model_id', 'chain_id', 'start_residue', 'chain_length', 'batch_number']

# --- UPDATED THRESHOLD ---
radius = 0.01
# -------------------------

query_batch_size = 5000

if not os.path.exists(input_path):
    raise FileNotFoundError(f"Run Step 1 first to generate {input_path}")

print(f"Loading data...")
mean_data_complete = pd.read_csv(input_path)

# 1. Prepare Data
valid_data = mean_data_complete.dropna(subset=bri_features)

# 2. Initialize Output
if os.path.exists(output_csv):
    os.remove(output_csv)

# Update Header to include distance
header_cols = [f"{c}_1" for c in id_cols] + [f"{c}_2" for c in id_cols] + ['chebyshev_dist']
pd.DataFrame(columns=header_cols).to_csv(output_csv, index=False)

# 3. Process by Chain Length
grouped = valid_data.groupby('chain_length')

print(f"Processing {len(grouped)} unique chain lengths...")
total_pairs = 0

for length, group_df in tqdm.tqdm(grouped, desc="Processing Length Groups"):

    if len(group_df) < 2:
        continue

    group_df = group_df.reset_index(drop=True)
    points = group_df[bri_features].values

    tree = cKDTree(points)
    group_pairs = []

    for i in range(0, len(points), query_batch_size):
        batch_points = points[i : i + query_batch_size]

        try:
            results = tree.query_ball_point(batch_points, r=radius, p=np.inf, workers=-1)
        except TypeError:
            results = tree.query_ball_point(batch_points, r=radius, p=np.inf)

        for local_idx, neighbors in enumerate(results):
            if len(neighbors) < 2: continue

            global_idx_1 = i + local_idx

            # Filter (j > i)
            valid_neighbors = [n for n in neighbors if n > global_idx_1]

            if valid_neighbors:
                group_pairs.extend([(global_idx_1, n) for n in valid_neighbors])

    if group_pairs:
        total_pairs += len(group_pairs)

        pairs_arr = np.array(group_pairs)

        # 4. Calculate Distance Vectorized
        p1 = points[pairs_arr[:, 0]]
        p2 = points[pairs_arr[:, 1]]

        dists = np.max(np.abs(p1 - p2), axis=1)

        # Map indices to identifiers
        df_1 = group_df.iloc[pairs_arr[:, 0]][id_cols].reset_index(drop=True)
        df_2 = group_df.iloc[pairs_arr[:, 1]][id_cols].reset_index(drop=True)

        df_1.columns = [f"{c}_1" for c in id_cols]
        df_2.columns = [f"{c}_2" for c in id_cols]

        # Create Distance DataFrame
        dist_df = pd.DataFrame({'chebyshev_dist': dists})

        # Concatenate Identifiers + Distance
        output_chunk = pd.concat([df_1, df_2, dist_df], axis=1)

        output_chunk.to_csv(output_csv, mode='a', header=False, index=False)

    del tree, points, group_df, group_pairs
    if total_pairs % 10000 == 0:
        gc.collect()

print(f"Done. Found {total_pairs} pairs across all lengths.")

Loading data...
Processing 1247 unique chain lengths...


Processing Length Groups: 100%|██████████| 1247/1247 [00:17<00:00, 69.59it/s] 

Done. Found 3247407 pairs across all lengths.





In [22]:
import pandas as pd
import numpy as np
import os
import tqdm
import gc

# @title 3. Full Comparison (Memory Optimized: Filter-on-Load)
# ==============================================================================
# Configuration
# ==============================================================================
restrict_suffix = ""
pairs_file = f"./data/PDB727K_mean_pairs_chebyshev_001{restrict_suffix}.csv"
parquet_dir = "./data/bri_computations"
output_full_diff_file = f"./data/PDB727K_full_comparison_results_001_seq{restrict_suffix}.csv"

full_dist_threshold = 0.01

# Columns to load from Parquet
id_cols = ['pdb_id', 'model_id', 'chain_id', 'start_residue', 'chain_length']
bri_cols = ['x(N)', 'y(N)', 'z(N)', 'x(A)', 'y(A)', 'z(A)', 'x(C)', 'y(C)', 'z(C)']
seq_col = 'residue_label'

load_columns = list(set(id_cols + bri_cols + [seq_col]))

# ==============================================================================
# 1. Identify "Relevant Chains"
# ==============================================================================
if not os.path.exists(pairs_file):
    raise FileNotFoundError("Run Step 2 first.")

print("Loading pairs to identify relevant chains...")
pairs_df = pd.read_csv(pairs_file)

if len(pairs_df) == 0:
    print("No pairs found.")
    exit()

# Extract unique keys (Chain 1 and Chain 2) needed for analysis
# We use a set of tuples for O(1) lookup: (pdb_id, model_id, chain_id, start_residue, chain_length)
print("Building set of required chains...")
keys_1 = list(zip(pairs_df['pdb_id_1'], pairs_df['model_id_1'], pairs_df['chain_id_1'], pairs_df['start_residue_1'], pairs_df['chain_length_1']))
keys_2 = list(zip(pairs_df['pdb_id_2'], pairs_df['model_id_2'], pairs_df['chain_id_2'], pairs_df['start_residue_2'], pairs_df['chain_length_2']))

required_keys = set(keys_1) | set(keys_2)

print(f"Total unique chains to load: {len(required_keys)}")

# ==============================================================================
# 2. Load and Filter Data (One Pass over Files)
# ==============================================================================
# Store data as: chain_data[key] = {'mat': np.array, 'seq': str}
chain_data_store = {}

# Get list of batch files
batch_files = sorted([f for f in os.listdir(parquet_dir) if f.endswith('.parquet')])

print(f"Scanning {len(batch_files)} batch files...")

for f in tqdm.tqdm(batch_files, desc="Loading Data"):
    try:
        path = os.path.join(parquet_dir, f)

        # Load batch (only relevant columns)
        df = pd.read_parquet(path, columns=load_columns)

        # Create a tuple key column for filtering
        # Note: Vectorized zip is faster than apply
        # We ensure types match the pairs_df types (usually int/str)
        current_keys = list(zip(df['pdb_id'], df['model_id'], df['chain_id'], df['start_residue'], df['chain_length']))

        # Filter: keep rows where the key is in our required set
        # Using a boolean mask with map/set is fast
        mask = [k in required_keys for k in current_keys]

        if not any(mask):
            continue # Nothing useful in this batch

        filtered_df = df[mask].copy()

        # Group by chain to extract Matrix and Sequence
        # We groupby the full key
        grouped = filtered_df.groupby(id_cols)

        for key, group in grouped:
            # key is the tuple (pdb, model, chain, start, length)

            # Extract Matrix
            mat = group[bri_cols].to_numpy()

            # Extract Sequence
            labels = group[seq_col]
            if len(labels) > 0 and isinstance(labels.iloc[0], str):
                # Standard case: sequence of characters
                seq = "".join(labels)
            else:
                seq = ""

            chain_data_store[key] = {'mat': mat, 'seq': seq}

        del df, filtered_df, mask, current_keys
        # gc.collect() # Optional here, usually not needed per file if filtered aggressively

    except Exception as e:
        print(f"Error reading {f}: {e}")

print(f"Successfully loaded {len(chain_data_store)} chains into memory.")

# ==============================================================================
# 3. Compute Distances
# ==============================================================================
print("Computing pairwise comparisons...")

results_list = []

# Iterate through pairs and lookup data from memory
for idx, row in tqdm.tqdm(pairs_df.iterrows(), total=len(pairs_df), desc="Comparing"):

    key1 = (row['pdb_id_1'], row['model_id_1'], row['chain_id_1'], row['start_residue_1'], row['chain_length_1'])
    key2 = (row['pdb_id_2'], row['model_id_2'], row['chain_id_2'], row['start_residue_2'], row['chain_length_2'])

    # Retrieve data
    if key1 not in chain_data_store or key2 not in chain_data_store:
        # Should not happen if logic is correct, but safe to skip
        continue

    data1 = chain_data_store[key1]
    data2 = chain_data_store[key2]

    mat1 = data1['mat']
    mat2 = data2['mat']

    # Check length compatibility (should match from Step 2)
    min_len = min(len(mat1), len(mat2))

    # Compute Distance
    dist = np.max(np.abs(mat1[:min_len] - mat2[:min_len]))

    # Check Threshold
    if dist <= full_dist_threshold:
        seq1 = data1['seq']
        seq2 = data2['seq']

        res_row = row.to_dict()
        res_row['full_chebyshev_dist'] = dist
        res_row['sequence_1'] = seq1
        res_row['sequence_2'] = seq2
        res_row['sequences_identical'] = 1 if seq1 == seq2 else 0

        results_list.append(res_row)

# ==============================================================================
# 4. Save Results
# ==============================================================================
if results_list:
    final_df = pd.DataFrame(results_list)
    final_df.to_csv(output_full_diff_file, index=False)
    print(f"Saved {len(final_df)} passing pairs to {output_full_diff_file}")
else:
    print("No pairs passed the full distance threshold.")

Loading pairs to identify relevant chains...
Building set of required chains...
Total unique chains to load: 288643
Scanning 146 batch files...


Loading Data: 100%|██████████| 146/146 [03:01<00:00,  1.24s/it]


Successfully loaded 287499 chains into memory.
Computing pairwise comparisons...


Comparing: 100%|██████████| 3247407/3247407 [01:49<00:00, 29759.81it/s]


Saved 832774 passing pairs to ./data/PDB727K_full_comparison_results_001_seq.csv


In [26]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os

# Configuration
restrict_suffix = ""
input_file = f"./data/PDB727K_full_comparison_results_001_seq{restrict_suffix}.csv"
output_dir = './plotting/nearest_neighbours_001A'

# 1. Create Output Directory
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Check if file exists
if os.path.exists(input_file):
    print("Loading data for plotting...")
    df = pd.read_csv(input_file)

    # Filter data
    diff_seq_data = df[df['sequences_identical'] == 0]
    same_seq_data = df[df['sequences_identical'] == 1]

    # Common Plot Settings
    x_label = r'$L_{\infty}$ distance on pairs of BRI, Angstroms'
    y_label = 'Pairs of close chains'
    bins_range = (0, 0.01)
    bin_width = 0.001

    # Helper function to generate plots efficiently
    def generate_histogram(data, color, filename_suffix, log_scale=False):
        plt.figure(figsize=(10, 6))
        sns.set_style("whitegrid")
        sns.set(font_scale=1.2)

        # Plot (Note: edgecolor removed to drop black border)
        sns.histplot(
            data=data,
            x='full_chebyshev_dist',
            binwidth=bin_width,
            binrange=bins_range,
            color=color,
            element="bars",
            linewidth=0  # Explicitly ensure no border
        )

        if log_scale:
            plt.yscale('log')
            filename_suffix += "_log"

        plt.xlabel(x_label)
        plt.ylabel(y_label)
        # No Title

        plt.tight_layout()

        # Construct filename
        filename = f'PDB727K_pairwise_BRI_comparisons_{filename_suffix}.png'
        save_path = os.path.join(output_dir, filename)

        plt.savefig(save_path)
        plt.close()

        print(f"Saved: {save_path}")

    # --- Generate the 4 Plots ---

    # 1. Linear Scale
    generate_histogram(diff_seq_data, 'orange', 'different_seq', log_scale=False)
    generate_histogram(same_seq_data, 'cornflowerblue', 'identical_seq', log_scale=False)

    # 2. Log Scale
    generate_histogram(diff_seq_data, 'orange', 'different_seq', log_scale=True)
    generate_histogram(same_seq_data, 'cornflowerblue', 'identical_seq', log_scale=True)

else:
    print(f"Input file not found: {input_file}")
    print("Please ensure you have run the 'Full Comparison' step to generate the results CSV.")

Loading data for plotting...
Saved: ./plotting/nearest_neighbours_001A/PDB727K_pairwise_BRI_comparisons_different_seq.png
Saved: ./plotting/nearest_neighbours_001A/PDB727K_pairwise_BRI_comparisons_identical_seq.png
Saved: ./plotting/nearest_neighbours_001A/PDB727K_pairwise_BRI_comparisons_different_seq_log.png
Saved: ./plotting/nearest_neighbours_001A/PDB727K_pairwise_BRI_comparisons_identical_seq_log.png
