# In [None]:
import os
import random
import numpy as np
from prody import fetchPDB, parsePDB, confProDy
from scipy.spatial.distance import pdist, squareform
from tqdm import tqdm

# Silence ProDy's own log output.
confProDy(verbosity='none')

# Directory where batch files are written; created up front if it is missing.
DRIVE_PATH = '/content/drive/MyDrive/ProteinData/'
os.makedirs(DRIVE_PATH, exist_ok=True)
# Text file listing the protein IDs to process, one entry per line.
TXT_PATH = '/content/drive/MyDrive/ProteinData/immunoglobulin_ids.txt'

# Number of successfully processed proteins written to each .npz batch file.
BATCH_SIZE = 200
# Processing stops once this many proteins have been processed successfully.
TOTAL_TARGET = 4000
# Chain lengths of every successfully processed protein (filled by
# process_single_pdb, reported by print_length_statistics).
length_stats = []

def import_all_protein_ids(txt_path):
    """Load the unique PDB IDs listed in a text file and return them shuffled.

    Every non-blank line is treated as one entry. Anything from the first
    underscore onwards is dropped (e.g. ``1ABC_A`` becomes ``1abc``) and the
    remaining ID is lower-cased. Duplicates are removed.

    Args:
        txt_path: Path of the text file to read.

    Returns:
        A list of unique, lower-cased IDs in random order. The order is
        reproducible when ``random.seed`` is set before the call.
    """
    unique_ids = set()
    with open(txt_path, 'r', encoding='utf-8') as f:
        for line in f:
            protein_id = line.strip()
            if protein_id:
                unique_ids.add(protein_id.split('_')[0].lower())

    # Sort first so the shuffle starts from a deterministic order; iterating a
    # set of strings directly yields an arbitrary, run-dependent order that a
    # seed cannot reproduce.
    shuffled_ids = sorted(unique_ids)
    random.shuffle(shuffled_ids)
    return shuffled_ids

def has_chain_a(structure):
    """Return whether the protein atoms of a structure include chain ``A``.

    Args:
        structure: Object exposing ``select('protein')``; the resulting
            selection must provide ``getChids()``.

    Returns:
        ``True`` if ``'A'`` is among the chain IDs of the protein selection,
        ``False`` otherwise, including when the selection matches nothing.
    """
    protein = structure.select('protein')
    # select() yields None when no atom matches; treat that as "no chain A"
    # instead of failing on the attribute access below.
    if protein is None:
        return False
    return 'A' in protein.getChids()

def process_single_pdb(pdb_id):
    """Download one PDB entry and extract the C-alpha data of its chain A.

    Args:
        pdb_id: PDB identifier passed to ``fetchPDB``.

    Returns:
        A ``(data, fail_reason)`` tuple. On success ``data`` is a dict with
        the keys ``pdb_id``, ``sequence``, ``distance_matrix``,
        ``coordinates``, ``length`` and ``chain``, and ``fail_reason`` is
        ``None``. On failure ``data`` is ``None`` and ``fail_reason`` is a
        short message; exceptions are caught and reported the same way.

    Side effects:
        On success the chain length is appended to the module-level
        ``length_stats`` list.
    """
    try:
        pdb_file = fetchPDB(pdb_id, compressed=False)
        # Report a failed download explicitly instead of letting parsePDB
        # fail on a missing path with a less helpful message.
        if pdb_file is None:
            return None, "Download failed"
        structure = parsePDB(pdb_file)

        if not has_chain_a(structure):
            return None, "Missing A chain"

        # Only chain A is processed.
        calphas = structure.select('protein and chain A and name CA')

        if calphas is None or len(calphas) == 0:
            return None, "Missing CA atoms"

        num_residues = len(calphas)

        # Reject chains whose residue indices do not form one unbroken range.
        # NOTE(review): getResindices() looks like ProDy's internally assigned
        # residue index rather than the PDB residue number, so this may not
        # catch numbering gaps from missing residues - confirm whether
        # getResnums() or a CA-CA distance check was intended.
        residues = calphas.getResindices()
        if len(set(residues)) != (max(residues) - min(residues) + 1):
            return None, "Residues are not continuous"

        coords = calphas.getCoords()
        seq = calphas.getSequence()
        # Full symmetric matrix of pairwise C-alpha distances.
        dist_matrix = squareform(pdist(coords))

        length_stats.append(num_residues)

        return {
            'pdb_id': pdb_id,
            'sequence': ''.join(seq),
            'distance_matrix': dist_matrix,
            'coordinates': coords,
            'length': num_residues,
            'chain': 'A'
        }, None

    except Exception as e:
        # Best effort: one bad entry must not abort the whole run.
        return None, f"Exception Error: {e}"

def _to_object_array(items):
    """Pack *items* into a 1-D object array holding one element per item."""
    packed = np.empty(len(items), dtype=object)
    # Assign element by element so arrays of equal shape are not merged into
    # a single multi-dimensional array.
    for idx, item in enumerate(items):
        packed[idx] = item
    return packed


def save_batch(processed_data, batch_index, output_dir=None):
    """Write one batch of processed proteins to a compressed ``.npz`` file.

    The distance matrices and coordinate arrays differ in shape from protein
    to protein, so they are stored as explicit object arrays. Handing NumPy a
    ragged list directly is rejected by recent NumPy releases. Because object
    arrays are pickled, read the file with ``np.load(path, allow_pickle=True)``.

    Args:
        processed_data: List of dicts as returned by ``process_single_pdb``.
        batch_index: Zero-based batch number; the file name uses
            ``batch_index + 1``.
        output_dir: Directory to write into. Defaults to ``DRIVE_PATH``.
    """
    target_dir = DRIVE_PATH if output_dir is None else output_dir
    output_file = os.path.join(target_dir, f'immunoglobulin_proteins_{batch_index + 1}.npz')
    np.savez_compressed(
        output_file,
        pdb_ids=[d['pdb_id'] for d in processed_data],
        sequences=[d['sequence'] for d in processed_data],
        distance_matrices=_to_object_array([d['distance_matrix'] for d in processed_data]),
        coordinates=_to_object_array([d['coordinates'] for d in processed_data]),
        lengths=[d['length'] for d in processed_data],
        chains=[d['chain'] for d in processed_data],
    )
    print(f"Saved batch {batch_index + 1} with {len(processed_data)} proteins to {output_file}")

def print_length_statistics():
    """Print min, max, mean and count of the lengths in ``length_stats``.

    Nothing is printed when no lengths have been recorded.
    """
    if not length_stats:
        return

    avg_length = np.mean(length_stats)
    max_length = max(length_stats)
    min_length = min(length_stats)

    print(f"   Minimum length: {min_length}")
    print(f"   Maximum length: {max_length}")
    print(f"   Average length: {avg_length:.1f}")
    print(f"   Total processed: {len(length_stats)}")

# -------------------- Main --------------------

def process_all_batches():
    """Process protein IDs until ``TOTAL_TARGET`` successes, saving in batches.

    IDs come from ``import_all_protein_ids(TXT_PATH)``. Each successful result
    is buffered; whenever the buffer holds ``BATCH_SIZE`` entries it is
    written with ``save_batch`` and cleared. A final partial buffer is written
    after the loop, then the length statistics and a summary are printed.
    """
    protein_ids = import_all_protein_ids(TXT_PATH)
    pending_batch = []
    next_batch = 0
    num_processed = 0

    with tqdm(total=TOTAL_TARGET, desc="Total Processed") as pbar:
        for i, pdb_id in enumerate(protein_ids):
            # Stop as soon as the target number of successes is reached.
            if pbar.n >= TOTAL_TARGET:
                break

            print(f"Trying PDB ID: {pdb_id} ({i+1}/{len(protein_ids)})")

            data, fail_reason = process_single_pdb(pdb_id)
            if not data:
                print(f"× Processing failed: {pdb_id}，reason: {fail_reason}")
                continue

            pending_batch.append(data)
            pbar.update(1)
            num_processed += 1
            print(f"Processed: {pdb_id} (#{num_processed}, length:{data['length']})")

            # Flush the buffer once it holds a full batch.
            if len(pending_batch) == BATCH_SIZE:
                save_batch(pending_batch, next_batch)
                next_batch += 1
                pending_batch = []

        # Write whatever is left over as a final, smaller batch.
        if pending_batch:
            save_batch(pending_batch, next_batch)

    # Report the length statistics of everything processed.
    print_length_statistics()
    print(f"\n Processing completed! Successfully processed {num_processed}")

# -------------------- Run --------------------
# Run the full pipeline only when this file is executed as a script.
if __name__ == "__main__":
    process_all_batches()