# In [None]:
import os
import numpy as np
from sklearn.model_selection import train_test_split
import pickle

def _load_and_filter(data_path, num_files, min_len, max_len):
    """Load shards 1..num_files and keep (sequence, matrix) pairs within the length bounds.

    Returns:
        (filtered, total): list of ``(seq, mat)`` tuples whose ``len(seq)`` lies in
        ``[min_len, max_len]``, and the number of records seen across all shards.
    """
    filtered = []
    total = 0
    for i in range(1, num_files + 1):
        file_path = os.path.join(data_path, f'immunoglobulin_proteins_{i}.npz')
        print(f"processing {i}/{num_files}: {file_path}")

        # NOTE: allow_pickle=True unpickles object arrays, which can execute
        # arbitrary code -- only use it on trusted files.
        # NpzFile keeps the archive open; the context manager closes it. Item
        # access reads each array fully into memory, so using it afterwards is safe.
        with np.load(file_path, allow_pickle=True) as data:
            sequences = data['sequences']
            matrices = data['distance_matrices']

        for seq, mat in zip(sequences, matrices):
            total += 1
            if min_len <= len(seq) <= max_len:
                filtered.append((seq, mat))
    return filtered, total


def _upper_triangle_stats(matrices):
    """Return (mean, std) over the strict upper-triangle entries of all matrices.

    Sums are accumulated in float64 in a single pass, so the matrices can be
    supplied from a generator.

    Raises:
        ValueError: if the matrices contain no upper-triangle entries at all.
    """
    total = 0.0
    total_sq = 0.0
    count = 0
    for matrix in matrices:
        # k=1 excludes the diagonal; only one triangle is used for a symmetric matrix.
        values = matrix[np.triu_indices_from(matrix, k=1)].astype(np.float64)
        total += np.sum(values)
        total_sq += np.sum(values ** 2)
        count += len(values)

    if count == 0:
        raise ValueError("No upper-triangle entries available to compute statistics")

    mean = total / count
    # E[x^2] - E[x]^2 can round to a tiny negative number; clamp so sqrt is not NaN.
    variance = max(total_sq / count - mean ** 2, 0.0)
    return mean, np.sqrt(variance)


def _validate_sample(data_list, sample_size, rng):
    """Spot-check a random subset of (seq, mat) pairs.

    Each sampled matrix must be 2-D and square, with a side equal to ``len(seq)``,
    and symmetric. ValueError is raised instead of ``assert``, which ``python -O``
    strips.
    """
    n_checked = min(sample_size, len(data_list))
    for i in rng.choice(len(data_list), n_checked, replace=False):
        seq, mat = data_list[i]
        if mat.ndim != 2 or not (len(seq) == mat.shape[0] == mat.shape[1]):
            raise ValueError(f"Data {i}: sequence length does not match matrix dimensions")
        if not np.allclose(mat, mat.T, rtol=1e-10):
            raise ValueError(f"Data {i}: Matrix is not symmetric")
    print(f"Data verification passed! (Sampling verification of {n_checked} samples)")


def _as_object_array(items):
    """Pack a list of (possibly differently-shaped) arrays into a 1-D object array.

    Passing a ragged list straight to ``np.savez*`` makes NumPy call
    ``np.asarray`` on it. That raises ValueError for inhomogeneous shapes
    (NumPy >= 1.24). Element-wise assignment stores each array as-is.
    """
    packed = np.empty(len(items), dtype=object)
    for i, item in enumerate(items):
        packed[i] = item
    return packed


def preprocess_protein_data(
    *,
    data_path='/content/drive/MyDrive/ProteinData/immunoglobulin/',
    output_path='/content/drive/MyDrive/ProteinData/immunoglobulin_processed/',
    num_files=19,
    min_len=100,
    max_len=400,
    test_size=0.2,
    random_state=42,
    validation_sample_size=100,
):
    """Run the immunoglobulin distance-matrix preprocessing pipeline.

    Steps:
        1. Load ``immunoglobulin_proteins_{1..num_files}.npz`` from ``data_path``.
           Keep records with ``min_len <= len(sequence) <= max_len``.
        2. Compute mean/std of upper-triangle distances over all kept records.
        3. Spot-check up to ``validation_sample_size`` records for shape and symmetry.
        4. Split into train/test with ``train_test_split``.
        5. Recompute mean/std on the training matrices only.
        6. Standardize train and test matrices with the training statistics.
        7. Save the results to ``output_path``:
           - ``train_data.npz`` / ``test_data.npz`` hold keys ``sequences`` and
             ``matrices``. ``matrices`` is a 1-D object array, so load it with
             ``np.load(..., allow_pickle=True)``.
           - ``global_stats.pkl`` and ``global_stats.npz`` hold the statistics dict.

    Returns:
        dict with ``train_mean``, ``train_std``, ``global_mean``, ``global_std``,
        ``train_samples``, ``test_samples`` and ``total_samples``.

    Raises:
        ValueError: if no record passes the length filter, a sampled record fails
            validation, or the training-set std is zero.
    """
    os.makedirs(output_path, exist_ok=True)

    # ------------------------- Step 1: Streaming and filtering data by length -------------------------
    print("Step 1: Streaming and filtering data by length...")
    filtered_data, total_count = _load_and_filter(data_path, num_files, min_len, max_len)
    print(f"Filtering completed: Total data {total_count} -> after filtering {len(filtered_data)}")
    if not filtered_data:
        raise ValueError(
            f"No sequences with length in [{min_len}, {max_len}] found under {data_path}"
        )

    # ------------------------- Step 2: Streaming global statistics -------------------------
    print("Step 2: Streaming global statistics...")
    global_mean, global_std = _upper_triangle_stats(mat for _, mat in filtered_data)
    print(f"Global Statistics - Mean: {global_mean:.6f}, Std: {global_std:.6f}")

    # ------------------------- Step 3: Data Verification -------------------------
    print("Step 3: Data Verification...")
    # Seeded so the same records are spot-checked on every run.
    _validate_sample(filtered_data, validation_sample_size, np.random.default_rng(random_state))

    # ------------------------- Step 4: Split the dataset -------------------------
    print("Step 4: Split the training and testing sets...")
    all_sequences = [seq for seq, _ in filtered_data]
    all_matrices = [mat for _, mat in filtered_data]
    del filtered_data

    train_seq, test_seq, train_mat, test_mat = train_test_split(
        all_sequences, all_matrices, test_size=test_size, random_state=random_state
    )
    del all_sequences, all_matrices

    print(f"Training set: {len(train_seq)} samples")
    print(f"Testing set: {len(test_seq)} samples")

    # ------------------------- Step 5: Recalculate training set statistics -------------------------
    print("Step 5: Recalculate training set statistics...")
    # Only training data feeds the normalization, so no test information leaks in.
    train_global_mean, train_global_std = _upper_triangle_stats(train_mat)
    print(f"Training set statistics - Mean: {train_global_mean:.6f}, Std: {train_global_std:.6f}")
    if train_global_std == 0:
        raise ValueError("Training-set distance std is zero; cannot standardize")

    # ------------------------- Step 6: Standardize data -------------------------
    print("Step 6: Standardize and save data...")

    print("Standardized training set...")
    train_standardized_matrices = [
        (mat - train_global_mean) / train_global_std for mat in train_mat
    ]

    print("Standardized test sets...")
    test_standardized_matrices = [
        (mat - train_global_mean) / train_global_std for mat in test_mat
    ]

    # ------------------------- Step 7: Save all data -------------------------
    print("Step 7: Save all data...")

    np.savez_compressed(
        os.path.join(output_path, 'train_data.npz'),
        sequences=train_seq,
        matrices=_as_object_array(train_standardized_matrices),
    )

    np.savez_compressed(
        os.path.join(output_path, 'test_data.npz'),
        sequences=test_seq,
        matrices=_as_object_array(test_standardized_matrices),
    )

    global_stats = {
        'train_mean': train_global_mean,
        'train_std': train_global_std,
        'global_mean': global_mean,
        'global_std': global_std,
        'train_samples': len(train_seq),
        'test_samples': len(test_seq),
        'total_samples': len(train_seq) + len(test_seq),
    }

    with open(os.path.join(output_path, 'global_stats.pkl'), 'wb') as f:
        pickle.dump(global_stats, f)

    np.savez(
        os.path.join(output_path, 'global_stats.npz'),
        **global_stats
    )

    return global_stats

# ------------------------- Run all-------------------------
if __name__ == "__main__":
    # Run the full preprocessing pipeline when executed as a script (a notebook
    # cell also runs with __name__ == "__main__"). The returned statistics dict
    # is bound to `stats` so it can be inspected afterwards.
    stats = preprocess_protein_data()