In [None]:
import multiprocessing as mp
import re
import sys
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from pathlib import Path

import numpy as np
import pandas as pd

In [None]:


class ProteinDataProcessor:
    """Build a unified protein-pocket descriptor table from annotation CSVs.

    Workflow (see ``prepare``): load the known CSV matrices from ``data_dir``,
    index ``.pdb``/``.mol2`` files found under it, split the ``Cavity``
    identifier column into separate fields, derive normalized and categorical
    features, and write the result as CSV files plus a text summary.
    """

    def __init__(self, data_dir="./data"):
        self.data_dir = Path(data_dir)
        # file stem -> {'file_path': Path, 'metadata': dict} for recognised .pdb files
        self.pdb_structures = {}
        # file stem -> {'file_path': Path, 'metadata': dict} for recognised .mol2 files
        self.binding_pockets = {}
        # dataset key -> DataFrame, filled by _load_csv_annotations
        self.csv_data = {}
        # An empty DataFrame (not a dict) so the `.empty` check in
        # _create_unified_dataset works even before extraction has run.
        self.pocket_descriptors = pd.DataFrame()
        self.output_df = pd.DataFrame()

    def prepare(self, system_type="all", output_dir="./processed_data"):
        """
        Process and clean protein-protein interaction and ligand binding data.

        Args:
            system_type: 'HD' (heterodimer), 'PL' (protein-ligand), or 'all'
            output_dir: Directory to save processed data
        """
        self._load_csv_annotations()
        self._process_pdb_structures()
        self._extract_pocket_features()
        self._create_unified_dataset(system_type)
        self._save_processed_data(output_dir)

    def _load_csv_annotations(self):
        """Load each known annotation CSV present in ``data_dir`` into ``self.csv_data``.

        Missing files are reported with a warning and skipped.
        """
        csv_patterns = {
            'HD_orthosteric': 'HD_part8_20230317_matrix_orthosteric.csv',
            'HD_complete': 'HD_part8_20230317PDBe_orthosteric__complete.csv',
            'PL_allosteric': 'PL_part8_20230317_matrix_liganded_allosteric.csv',
            'PL_allosteric_complete': 'PL_part8_20230317PDBe_allosteric__complete.csv',
            'PL_orthosteric_comp': 'PL_part8_20230317_matrix_liganded_orthosteric_competitive.csv',
            'PL_orthosteric_comp_complete': 'PL_part8_20230317PDBe_orthosteric_competitive__complete.csv',
            'PL_orthosteric_noncomp': 'PL_part8_20230317_matrix_liganded_orthosteric_noncompetitive.csv',
            'PL_orthosteric_noncomp_complete': 'PL_part8_20230317PDBe_orthosteric_noncompetitive__complete.csv'
        }

        self.csv_data = {}
        for key, filename in csv_patterns.items():
            filepath = self.data_dir / filename
            if filepath.exists():
                print(f"Loading {filename}")
                self.csv_data[key] = pd.read_csv(filepath)
            else:
                print(f"Warning: {filename} not found")

    def _process_pdb_structures(self):
        """Index every ``.pdb`` and ``.mol2`` file under ``data_dir`` whose name parses."""
        pdb_files = list(self.data_dir.glob("**/*.pdb"))
        mol2_files = list(self.data_dir.glob("**/*.mol2"))

        for pdb_file in pdb_files:
            pdb_info = self._parse_pdb_filename(pdb_file.name)
            if pdb_info:
                self.pdb_structures[pdb_file.stem] = {
                    'file_path': pdb_file,
                    'metadata': pdb_info
                }

        for mol2_file in mol2_files:
            pocket_info = self._parse_pocket_filename(mol2_file.name)
            if pocket_info:
                self.binding_pockets[mol2_file.stem] = {
                    'file_path': mol2_file,
                    'metadata': pocket_info
                }

    def _parse_pdb_filename(self, filename):
        """Parse a structure filename into a metadata dict, or return None.

        Patterns are tried in order; the number of captured groups decides the
        returned layout ('heterodimer_complex', 'single_chain' or 'raw_pdb').
        """
        patterns = [
            r'(\w+)--(\w+)--(\w+)--(\w+)\.pdb',  # heterodimer complex
            r'(\w+)--(\w)--(\w+)__Repair-H\.pdb',  # repaired chain
            r'(\w+)--(\w)--(\w+)\.pdb',  # single chain
            r'pdb(\w+)\.ent'  # raw PDB
        ]

        for pattern in patterns:
            match = re.match(pattern, filename)
            if match:
                groups = match.groups()
                if len(groups) == 4:
                    return {
                        'pdb_code': groups[0],
                        'chain1': groups[1],
                        'uniprot1': groups[2],
                        'chain2_or_ligand': groups[3],
                        'type': 'heterodimer_complex'
                    }
                elif len(groups) == 3:
                    return {
                        'pdb_code': groups[0],
                        'chain': groups[1],
                        'uniprot': groups[2],
                        'type': 'single_chain'
                    }
                elif len(groups) == 1:
                    return {
                        'pdb_code': groups[0],
                        'type': 'raw_pdb'
                    }
        return None

    def _parse_pocket_filename(self, filename):
        """Parse a ``..._CAVITY_N<k>_ALL_<type>.mol2`` filename, or return None."""
        pattern = r'(\w+)-(\w+)-(\w+)-?(\w+)?-?(\w+)?_CAVITY_N(\d+)_ALL_(.+)\.mol2'
        match = re.match(pattern, filename)

        if match:
            groups = match.groups()
            return {
                'pdb_code': groups[0],
                'chain': groups[1],
                'uniprot': groups[2],
                'ligand': groups[3] if groups[3] else None,
                'ligand_num': groups[4] if groups[4] else None,
                'cavity_num': groups[5],
                'pocket_type': groups[6]
            }
        return None

    def _extract_pocket_features(self):
        """Concatenate all loaded CSVs into ``self.pocket_descriptors``.

        Each row is tagged with its source dataset key and system type ('HD' when
        the key contains 'HD', else 'PL'). When a 'Cavity' column exists it is
        split into pdb_code/chain/uniprot/ligand_id/ligand_num/cavity_num/binding_type.
        """
        all_features = []

        for dataset_name, df in self.csv_data.items():
            if df is not None and not df.empty:
                df_copy = df.copy()
                df_copy['dataset_source'] = dataset_name
                df_copy['system_type'] = 'HD' if 'HD' in dataset_name else 'PL'

                # Parse cavity identifier
                if 'Cavity' in df_copy.columns:
                    cavity_info = df_copy['Cavity'].apply(self._parse_cavity_identifier)
                    for key in ['pdb_code', 'chain', 'uniprot', 'ligand_id', 'ligand_num', 'cavity_num', 'binding_type']:
                        df_copy[key] = [info.get(key) for info in cavity_info]

                all_features.append(df_copy)

        if all_features:
            self.pocket_descriptors = pd.concat(all_features, ignore_index=True, sort=False)
        else:
            self.pocket_descriptors = pd.DataFrame()

    def _parse_cavity_identifier(self, cavity_str):
        """Split a cavity identifier string into its fields; ``{}`` if it does not match."""
        pattern = r'(\w+)-(\w+)-(\w+)-?(\w+)?-?(\w+)?_CAVITY_N(\d+)_(.+)'
        match = re.match(pattern, str(cavity_str))

        if match:
            groups = match.groups()
            return {
                'pdb_code': groups[0],
                'chain': groups[1],
                'uniprot': groups[2],
                'ligand_id': groups[3] if groups[3] else None,
                'ligand_num': groups[4] if groups[4] else None,
                'cavity_num': groups[5],
                'binding_type': groups[6]
            }
        return {}

    def _create_unified_dataset(self, system_type):
        """Filter descriptors by system type and add derived feature columns.

        Adds ``<feature>_normalized`` z-scores for the geometric descriptors that
        are present, ``shape_anisotropy`` when PMI1-3 exist, ``binding_category``
        ('unknown' for every row when no ``binding_type`` column was parsed), and
        ``has_pfam``/``has_cath`` flags. The result is stored in ``self.output_df``.
        """
        if self.pocket_descriptors.empty:
            print("No pocket descriptor data available")
            return

        # Filter by system type if specified
        if system_type != "all":
            df = self.pocket_descriptors[self.pocket_descriptors['system_type'] == system_type].copy()
        else:
            df = self.pocket_descriptors.copy()

        # Standardize geometric features
        geometric_features = [
            'Volume', 'PMI1', 'PMI2', 'PMI3', 'NPR1', 'NPR2',
            'Rgyr', 'Asphericity', 'SpherocityIndex', 'Eccentricity', 'InertialShapeFactor'
        ]

        for feature in geometric_features:
            if feature in df.columns:
                df[f'{feature}_normalized'] = (df[feature] - df[feature].mean()) / df[feature].std()

        # Add derived features
        if all(col in df.columns for col in ['PMI1', 'PMI2', 'PMI3']):
            df['shape_anisotropy'] = (df['PMI1'] - df['PMI2']) / (df['PMI1'] + df['PMI2'] + df['PMI3'])

        # Categorize binding types; 'binding_type' only exists when some source
        # CSV had a 'Cavity' column, so fall back to 'unknown' otherwise.
        if 'binding_type' in df.columns:
            df['binding_category'] = df['binding_type'].apply(self._categorize_binding_type)
        else:
            df['binding_category'] = 'unknown'

        # Add sequence-based features if available
        if 'pfam_accession' in df.columns:
            df['has_pfam'] = ~df['pfam_accession'].isna()

        if 'cath' in df.columns:
            df['has_cath'] = ~df['cath'].isna()

        self.output_df = df

    def _categorize_binding_type(self, binding_type):
        """Map a free-text binding type to one of the main categories.

        The more specific substrings are tested first: 'noncompetitive' contains
        'competitive' and 'nonorthosteric' contains 'orthosteric', so checking the
        shorter word first would swallow the longer one.

        Returns one of 'unknown' (missing value), 'nonorthosteric',
        'orthosteric_noncompetitive', 'orthosteric_competitive', 'orthosteric',
        'allosteric' or 'other'.
        """
        if pd.isna(binding_type):
            return 'unknown'

        binding_type = str(binding_type).lower()

        if 'nonorthosteric' in binding_type:
            return 'nonorthosteric'
        if 'orthosteric' in binding_type:
            if 'noncompetitive' in binding_type:
                return 'orthosteric_noncompetitive'
            if 'competitive' in binding_type:
                return 'orthosteric_competitive'
            return 'orthosteric'
        if 'allosteric' in binding_type:
            return 'allosteric'
        return 'other'

    def _save_processed_data(self, output_dir):
        """Write the full dataset, per-system and per-binding-category CSVs and a summary.

        ``output_dir`` is created if needed. Nothing is written when
        ``self.output_df`` is empty.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        if not self.output_df.empty:
            # Save main dataset
            main_file = output_path / "protein_pockets_processed.csv"
            print(f"Saving processed data to: {main_file}")
            self.output_df.to_csv(main_file, index=False)

            # Save system-specific datasets
            for system_type in ['HD', 'PL']:
                system_df = self.output_df[self.output_df['system_type'] == system_type]
                if not system_df.empty:
                    system_file = output_path / f"protein_pockets_{system_type.lower()}.csv"
                    print(f"Saving {system_type} data to: {system_file}")
                    system_df.to_csv(system_file, index=False)

            # Save binding type specific datasets
            for binding_cat in self.output_df['binding_category'].unique():
                if pd.notna(binding_cat):
                    binding_df = self.output_df[self.output_df['binding_category'] == binding_cat]
                    binding_file = output_path / f"protein_pockets_{binding_cat}.csv"
                    print(f"Saving {binding_cat} data to: {binding_file}")
                    binding_df.to_csv(binding_file, index=False)

            # Generate summary statistics
            self._generate_summary_report(output_path)
        else:
            print("No data to save")

    def _generate_summary_report(self, output_path):
        """Write ``processing_summary.txt`` with counts and volume statistics.

        Identifier and volume sections are only written when the corresponding
        columns exist (they come from the optional 'Cavity'/'Volume' CSV columns).
        """
        summary_file = output_path / "processing_summary.txt"
        columns = self.output_df.columns

        with open(summary_file, 'w', encoding='utf-8') as f:
            f.write("Protein-Protein Interaction and Ligand Binding Dataset Processing Summary\n")
            f.write("=" * 70 + "\n\n")

            f.write(f"Total processed entries: {len(self.output_df)}\n")
            f.write("System type distribution:\n")
            for system_type, count in self.output_df['system_type'].value_counts().items():
                f.write(f"  {system_type}: {count}\n")

            f.write("\nBinding category distribution:\n")
            for binding_cat, count in self.output_df['binding_category'].value_counts().items():
                f.write(f"  {binding_cat}: {count}\n")

            if 'pdb_code' in columns:
                f.write(f"\nUnique PDB codes: {self.output_df['pdb_code'].nunique()}\n")
            if 'uniprot' in columns:
                f.write(f"Unique UniProt IDs: {self.output_df['uniprot'].nunique()}\n")

            if 'Volume' in columns:
                f.write("\nPocket volume statistics:\n")
                f.write(f"  Mean: {self.output_df['Volume'].mean():.2f}\n")
                f.write(f"  Std: {self.output_df['Volume'].std():.2f}\n")
                f.write(f"  Min: {self.output_df['Volume'].min():.2f}\n")
                f.write(f"  Max: {self.output_df['Volume'].max():.2f}\n")

if __name__ == '__main__':
    # Inside a Jupyter/IPython kernel __name__ is also '__main__', but sys.argv
    # then holds the kernel launcher's own arguments (e.g. '-f <connection file>'),
    # which must not be mistaken for a data directory or a system type.
    cli_args = [] if 'ipykernel' in sys.modules else sys.argv[1:]
    data_dir = cli_args[0] if len(cli_args) > 0 else "./data"
    system_type = cli_args[1] if len(cli_args) > 1 else "all"
    # Any other value would silently filter out every row in _create_unified_dataset.
    if system_type not in ('HD', 'PL', 'all'):
        raise SystemExit(f"Invalid system_type {system_type!r}; expected 'HD', 'PL' or 'all'")

    processor = ProteinDataProcessor(data_dir)
    processor.prepare(system_type)



In [None]:


# GPU acceleration imports (optional): sets HAS_TORCH and DEVICE for later cells.
try:
    import torch
    import torch.nn.functional as F
except ImportError:
    HAS_TORCH = False
    DEVICE = None
    print("PyTorch not available - using CPU-only processing")
else:
    HAS_TORCH = True
    print(f"PyTorch available: {torch.__version__}")

    # torch.backends.mps only exists in newer PyTorch releases; look it up
    # defensively so older builds fall through to CUDA/CPU instead of raising.
    _mps_backend = getattr(torch.backends, "mps", None)

    # Check for Metal Performance Shaders (macOS)
    if _mps_backend is not None and _mps_backend.is_available():
        DEVICE = torch.device("mps")
        print("Using Apple Metal Performance Shaders")
    # Check for CUDA (NVIDIA)
    elif torch.cuda.is_available():
        DEVICE = torch.device("cuda")
        print(f"Using NVIDIA CUDA: {torch.cuda.get_device_name()}")
    else:
        DEVICE = torch.device("cpu")
        print("Using CPU - consider installing CUDA/MPS support")

# Spark imports for large-scale processing (optional): HAS_SPARK tells later
# cells whether the pyspark names below are actually defined.
try:
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import (
        col, count, avg, stddev, min as spark_min, max as spark_max,
    )
    from pyspark.sql.types import (
        StructType, StructField, StringType, DoubleType, IntegerType,
    )
except ImportError:
    HAS_SPARK = False
    print("PySpark not available - using pandas for processing")
else:
    HAS_SPARK = True
    print("PySpark available")

class AcceleratedProteinDataProcessor:
    """Scope a protein dataset directory and write a plain-text summary report.

    The steps are:
    - inventory the files in the directory;
    - load the known CSV matrices, optionally down-sampled;
    - summarise the PDB/MOL2 files;
    - compute statistics.

    The module-level HAS_TORCH / DEVICE / HAS_SPARK flags from the import cell
    decide two things: whether numeric statistics are computed with PyTorch on
    an accelerator, and whether the CSVs are mirrored into Spark DataFrames.
    """

    def __init__(self, data_dir="/Users/priyanshudey/Code/Qunatum/ippidb-pdb-analyses-042023-zenodo"):
        self.data_dir = Path(data_dir)
        self.device = DEVICE
        self.use_gpu = HAS_TORCH and DEVICE != torch.device("cpu")
        self.use_spark = HAS_SPARK
        self.spark = None

        # Initialize Spark if available
        if self.use_spark:
            self.spark = SparkSession.builder \
                .appName("ProteinDataScoping") \
                .config("spark.sql.adaptive.enabled", "true") \
                .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
                .getOrCreate()
            print(f"Spark initialized with {self.spark.sparkContext.defaultParallelism} cores")

        self.csv_files = {}
        # dataset key -> Spark DataFrame; stays empty when Spark is unused so that
        # consumers can iterate it unconditionally.
        self.spark_dfs = {}
        # category ('csv_files', 'pdb_files', 'mol2_files', 'other_files') -> list[Path]
        self.file_inventory = {}
        self.pdb_structures = {}
        self.binding_pockets = {}
        self.summary_stats = {}

    def initial_data_scoping(self, sample_size=None, parallel_jobs=None):
        """
        Perform initial data scoping with GPU/Metal acceleration.

        Args:
            sample_size: if truthy, randomly keep at most this many rows per CSV.
            parallel_jobs: worker count for file scanning/analysis
                (default: min(cpu_count, 8)).

        Returns:
            The statistics dict also stored in ``self.summary_stats``.
        """
        print("=== Initial Data Scoping with Acceleration ===")
        start_time = time.time()

        if parallel_jobs is None:
            parallel_jobs = min(mp.cpu_count(), 8)

        # Step 1: Discover all data files
        print("1. Discovering data files...")
        self._discover_files_parallel(parallel_jobs)

        # Step 2: Load and analyze CSV files
        print("2. Loading and analyzing CSV files...")
        self._load_csv_files_accelerated(sample_size)

        # Step 3: Analyze PDB structures
        print("3. Analyzing PDB structures...")
        self._analyze_structures_parallel(parallel_jobs)

        # Step 4: Generate comprehensive statistics
        print("4. Generating statistics...")
        self._generate_accelerated_stats()

        # Step 5: Create data summary report
        print("5. Creating summary report...")
        self._create_scoping_report()

        total_time = time.time() - start_time
        print(f"\nData scoping completed in {total_time:.2f} seconds")

        return self.summary_stats

    def _discover_files_parallel(self, n_jobs):
        """Classify every file under ``data_dir`` by suffix into ``self.file_inventory``.

        Files directly inside ``data_dir`` are classified in the calling thread.
        Each immediate subdirectory is walked recursively in a worker thread.
        Directory walking is I/O-bound, so threads are sufficient. Threads also
        allow a local function as the task; ProcessPoolExecutor would need to
        pickle it, which fails.
        """
        print(f"Scanning directory with {n_jobs} parallel workers...")

        categories = ('csv_files', 'pdb_files', 'mol2_files', 'other_files')
        suffix_to_category = {'.csv': 'csv_files', '.pdb': 'pdb_files', '.mol2': 'mol2_files'}

        def classify(paths):
            found = {category: [] for category in categories}
            for file_path in paths:
                if file_path.is_file():
                    category = suffix_to_category.get(file_path.suffix.lower(), 'other_files')
                    found[category].append(file_path)
            return found

        top_level_entries = []
        subdirs = []
        for entry in self.data_dir.iterdir():
            (subdirs if entry.is_dir() else top_level_entries).append(entry)

        # The known CSV matrices live at the top level, so scan those entries too.
        partial_results = [classify(top_level_entries)]
        with ThreadPoolExecutor(max_workers=max(1, n_jobs)) as executor:
            partial_results.extend(executor.map(lambda d: classify(d.rglob("*")), subdirs))

        all_results = {category: [] for category in categories}
        for partial in partial_results:
            for category, paths in partial.items():
                all_results[category].extend(paths)

        self.file_inventory = all_results
        print(f"Found {len(all_results['csv_files'])} CSV files")
        print(f"Found {len(all_results['pdb_files'])} PDB files")
        print(f"Found {len(all_results['mol2_files'])} MOL2 files")

    def _load_csv_files_accelerated(self, sample_size):
        """Load the known CSV matrices, optionally down-sampling each one.

        Fills ``self.csv_files`` (key -> {'data', 'n_rows', 'n_cols', 'columns',
        'file_path'}). When Spark is enabled, it also fills ``self.spark_dfs``.
        """
        csv_patterns = {
            'HD_orthosteric': 'HD_part8_20230317_matrix_orthosteric.csv',
            'HD_complete': 'HD_part8_20230317_matrix_orthosteric__complete.csv',
            'PL_allosteric': 'PL_part8_20230317_matrix_liganded_allosteric.csv',
            'PL_allosteric_complete': 'PL_part8_20230317_matrix_liganded_allosteric__complete.csv',
            'PL_orthosteric_comp': 'PL_part8_20230317_matrix_liganded_orthosteric_competitive.csv',
            'PL_orthosteric_comp_complete': 'PL_part8_20230317_matrix_liganded_orthosteric_competitive__complete.csv',
            'PL_orthosteric_noncomp': 'PL_part8_20230317_matrix_liganded_orthosteric_noncompetitive.csv',
            'PL_orthosteric_noncomp_complete': 'PL_part8_20230317_matrix_liganded_orthosteric_noncompetitive__complete.csv'
        }

        def load_and_sample_csv(file_path, sample_size=None):
            """Return (df, n_rows, columns), or (None, 0, []) if loading fails."""
            try:
                if sample_size:
                    # Count data lines with the file closed afterwards; binary mode
                    # avoids decode errors just for counting.
                    with open(file_path, 'rb') as fh:
                        n_rows = sum(1 for _ in fh) - 1  # subtract header

                    if n_rows > sample_size:
                        # Random sampling: skip a random subset of data lines
                        # (line 0 is the header and is always kept).
                        skip_rows = np.random.choice(range(1, n_rows + 1),
                                                     n_rows - sample_size, replace=False)
                        df = pd.read_csv(file_path, skiprows=skip_rows)
                    else:
                        df = pd.read_csv(file_path)
                else:
                    df = pd.read_csv(file_path)

                return df, len(df), df.columns.tolist()
            except Exception as e:
                # Best-effort scoping: report and continue with the other files.
                print(f"Error loading {file_path}: {e}")
                return None, 0, []

        # Load known CSV files
        loaded_data = {}
        for key, filename in csv_patterns.items():
            file_path = self.data_dir / filename
            if file_path.exists():
                print(f"Loading {filename}...")
                df, n_rows, columns = load_and_sample_csv(file_path, sample_size)
                if df is not None:
                    loaded_data[key] = {
                        'data': df,
                        'n_rows': n_rows,
                        'n_cols': len(columns),
                        'columns': columns,
                        'file_path': file_path
                    }

        self.csv_files = loaded_data

        # If using Spark, also load into Spark DataFrames for large-scale processing
        if self.use_spark and loaded_data:
            print("Loading data into Spark...")
            self.spark_dfs = {}
            for key, data_info in loaded_data.items():
                spark_df = self.spark.createDataFrame(data_info['data'])
                self.spark_dfs[key] = spark_df
                print(f"  {key}: {spark_df.count()} rows, {len(spark_df.columns)} columns")

    def _analyze_structures_parallel(self, n_jobs):
        """Summarise the inventoried PDB and MOL2 files.

        Files whose names do not parse are dropped. Files that raise while being
        read are kept as {'error', 'file_path'} entries. Work runs in threads,
        because the per-file helpers are local functions that close over ``self``
        and therefore cannot be pickled for a process pool.
        """
        pdb_files = self.file_inventory.get('pdb_files', [])
        mol2_files = self.file_inventory.get('mol2_files', [])
        workers = max(1, n_jobs)

        def analyze_pdb_file(pdb_path):
            """Extract basic info and ATOM/HETATM record counts from a PDB file."""
            try:
                info = self._parse_pdb_filename(pdb_path.name)
                if info:
                    file_size = pdb_path.stat().st_size

                    # Stream the file instead of loading all lines into memory.
                    total_lines = atom_count = hetatm_count = 0
                    with open(pdb_path, 'r', encoding='utf-8', errors='replace') as f:
                        for line in f:
                            total_lines += 1
                            if line.startswith('ATOM'):
                                atom_count += 1
                            elif line.startswith('HETATM'):
                                hetatm_count += 1

                    info.update({
                        'file_size': file_size,
                        'total_lines': total_lines,
                        'atom_count': atom_count,
                        'hetatm_count': hetatm_count,
                        'file_path': str(pdb_path)
                    })

                return info
            except Exception as e:
                return {'error': str(e), 'file_path': str(pdb_path)}

        def analyze_mol2_file(mol2_path):
            """Extract basic info from a MOL2 pocket file."""
            try:
                info = self._parse_pocket_filename(mol2_path.name)
                if info:
                    info.update({
                        'file_size': mol2_path.stat().st_size,
                        'file_path': str(mol2_path)
                    })
                return info
            except Exception as e:
                return {'error': str(e), 'file_path': str(mol2_path)}

        print(f"Analyzing {len(pdb_files)} PDB files with {n_jobs} workers...")
        with ThreadPoolExecutor(max_workers=workers) as executor:
            pdb_results = list(executor.map(analyze_pdb_file, pdb_files))

        print(f"Analyzing {len(mol2_files)} MOL2 files with {n_jobs} workers...")
        with ThreadPoolExecutor(max_workers=workers) as executor:
            mol2_results = list(executor.map(analyze_mol2_file, mol2_files))

        self.pdb_structures = {f"pdb_{i}": result for i, result in enumerate(pdb_results) if result}
        self.binding_pockets = {f"pocket_{i}": result for i, result in enumerate(mol2_results) if result}

        print(f"Analyzed {len(self.pdb_structures)} PDB structures")
        print(f"Analyzed {len(self.binding_pockets)} binding pockets")

    def _generate_accelerated_stats(self):
        """Compute per-CSV and per-structure-set statistics into ``self.summary_stats``.

        Numeric statistics use PyTorch on ``self.device``, but only when a GPU is
        in use and the numeric columns have no missing values. Tensor reductions
        propagate NaN, whereas the pandas ``describe`` fallback skips it. The GPU
        path stores {column: (mean, std, min, max)}. The pandas path stores the
        ``describe().to_dict()`` layout.
        """
        stats = {}

        # CSV file statistics
        if self.csv_files:
            csv_stats = {}
            for key, data_info in self.csv_files.items():
                df = data_info['data']

                # Basic stats
                basic_stats = {
                    'n_rows': len(df),
                    'n_cols': len(df.columns),
                    'memory_usage': df.memory_usage().sum(),
                    'columns': df.columns.tolist()
                }

                # Numeric column analysis
                numeric_cols = df.select_dtypes(include=[np.number]).columns
                numeric_frame = df[numeric_cols]
                has_missing = bool(numeric_frame.isna().to_numpy().any())
                if len(numeric_cols) > 0 and self.use_gpu and not has_missing:
                    # GPU-accelerated statistics
                    numeric_data = numeric_frame.to_numpy(dtype=np.float32)
                    if numeric_data.size > 0:
                        tensor_data = torch.tensor(numeric_data, dtype=torch.float32).to(self.device)

                        gpu_stats = {
                            'mean': tensor_data.mean(dim=0).cpu().numpy().tolist(),
                            'std': tensor_data.std(dim=0).cpu().numpy().tolist(),
                            'min': tensor_data.min(dim=0)[0].cpu().numpy().tolist(),
                            'max': tensor_data.max(dim=0)[0].cpu().numpy().tolist()
                        }

                        basic_stats['numeric_stats'] = dict(zip(numeric_cols,
                                                              zip(gpu_stats['mean'], gpu_stats['std'],
                                                                  gpu_stats['min'], gpu_stats['max'])))
                elif len(numeric_cols) > 0:
                    # Fallback to pandas
                    basic_stats['numeric_stats'] = numeric_frame.describe().to_dict()

                # Categorical analysis
                cat_cols = df.select_dtypes(include=['object']).columns
                if len(cat_cols) > 0:
                    cat_stats = {}
                    for col in cat_cols[:5]:  # Limit to first 5 categorical columns
                        cat_stats[col] = {
                            'unique_count': df[col].nunique(),
                            'top_values': df[col].value_counts().head().to_dict()
                        }
                    basic_stats['categorical_stats'] = cat_stats

                csv_stats[key] = basic_stats

            stats['csv_files'] = csv_stats

        # Structure file statistics
        if self.pdb_structures:
            pdb_stats = {
                'total_files': len(self.pdb_structures),
                'total_atoms': sum(s.get('atom_count', 0) for s in self.pdb_structures.values()),
                'total_size': sum(s.get('file_size', 0) for s in self.pdb_structures.values())
            }

            # PDB type distribution
            pdb_types = [s.get('type') for s in self.pdb_structures.values() if 'type' in s]
            if pdb_types:
                pdb_stats['type_distribution'] = pd.Series(pdb_types).value_counts().to_dict()

            stats['pdb_structures'] = pdb_stats

        if self.binding_pockets:
            stats['binding_pockets'] = {
                'total_files': len(self.binding_pockets),
                'total_size': sum(s.get('file_size', 0) for s in self.binding_pockets.values())
            }

        self.summary_stats = stats
        return stats

    def _parse_pdb_filename(self, filename):
        """Parse a structure filename into a metadata dict, or return None.

        Patterns are raw strings with single backslashes. In a raw string,
        ``\\\\w`` would mean a literal backslash followed by 'w' and would never
        match a filename.
        """
        patterns = [
            r'(\w+)--(\w+)--(\w+)--(\w+)\.pdb',  # heterodimer complex
            r'(\w+)--(\w)--(\w+)__Repair-H\.pdb',  # repaired chain
            r'(\w+)--(\w)--(\w+)\.pdb',  # single chain
            r'pdb(\w+)\.ent'  # raw PDB
        ]

        for pattern in patterns:
            match = re.match(pattern, filename)
            if match:
                groups = match.groups()
                if len(groups) == 4:
                    return {
                        'pdb_code': groups[0],
                        'chain1': groups[1],
                        'uniprot1': groups[2],
                        'chain2_or_ligand': groups[3],
                        'type': 'heterodimer_complex'
                    }
                elif len(groups) == 3:
                    return {
                        'pdb_code': groups[0],
                        'chain': groups[1],
                        'uniprot': groups[2],
                        'type': 'single_chain'
                    }
                elif len(groups) == 1:
                    return {
                        'pdb_code': groups[0],
                        'type': 'raw_pdb'
                    }
        return None

    def _parse_pocket_filename(self, filename):
        """Parse a ``..._CAVITY_N<k>_ALL_<type>.mol2`` filename, or return None."""
        pattern = r'(\w+)-(\w+)-(\w+)-?(\w+)?-?(\w+)?_CAVITY_N(\d+)_ALL_(.+)\.mol2'
        match = re.match(pattern, filename)

        if match:
            groups = match.groups()
            return {
                'pdb_code': groups[0],
                'chain': groups[1],
                'uniprot': groups[2],
                'ligand': groups[3] if groups[3] else None,
                'ligand_num': groups[4] if groups[4] else None,
                'cavity_num': groups[5],
                'pocket_type': groups[6]
            }
        return None

    def _create_scoping_report(self):
        """Write ``data_scoping_report.txt`` into ``data_dir`` and return its path."""
        report_path = self.data_dir / "data_scoping_report.txt"

        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("IPPIDB Protein Data Scoping Report\n")
            f.write("=" * 50 + "\n\n")

            # Hardware info
            f.write("Hardware Configuration:\n")
            f.write(f"  GPU Acceleration: {'Yes' if self.use_gpu else 'No'}\n")
            if self.use_gpu:
                f.write(f"  Device: {self.device}\n")
            f.write(f"  Spark Processing: {'Yes' if self.use_spark else 'No'}\n")
            f.write(f"  CPU Cores: {mp.cpu_count()}\n\n")

            # File inventory (only after _discover_files_parallel has run)
            if self.file_inventory:
                f.write("File Inventory:\n")
                for file_type, files in self.file_inventory.items():
                    f.write(f"  {file_type}: {len(files)}\n")
                f.write("\n")

            # CSV data summary
            if 'csv_files' in self.summary_stats:
                f.write("CSV Data Summary:\n")
                for key, stats in self.summary_stats['csv_files'].items():
                    f.write(f"  {key}:\n")
                    f.write(f"    Rows: {stats['n_rows']:,}\n")
                    f.write(f"    Columns: {stats['n_cols']}\n")
                    f.write(f"    Memory: {stats['memory_usage']:,} bytes\n")

                    if 'numeric_stats' in stats:
                        f.write(f"    Numeric columns: {len(stats['numeric_stats'])}\n")

                    if 'categorical_stats' in stats:
                        f.write(f"    Categorical columns: {len(stats['categorical_stats'])}\n")
                    f.write("\n")

            # Structure data summary
            if 'pdb_structures' in self.summary_stats:
                pdb_stats = self.summary_stats['pdb_structures']
                f.write("PDB Structures Summary:\n")
                f.write(f"  Total files: {pdb_stats['total_files']:,}\n")
                f.write(f"  Total atoms: {pdb_stats['total_atoms']:,}\n")
                f.write(f"  Total size: {pdb_stats['total_size']:,} bytes\n")

                if 'type_distribution' in pdb_stats:
                    f.write("  Type distribution:\n")
                    for pdb_type, count in pdb_stats['type_distribution'].items():
                        f.write(f"    {pdb_type}: {count}\n")
                f.write("\n")

            if 'binding_pockets' in self.summary_stats:
                pocket_stats = self.summary_stats['binding_pockets']
                f.write("Binding Pockets Summary:\n")
                f.write(f"  Total files: {pocket_stats['total_files']:,}\n")
                f.write(f"  Total size: {pocket_stats['total_size']:,} bytes\n\n")

        print(f"Scoping report saved to: {report_path}")
        return report_path

    def get_sample_data(self, dataset_key, n_samples=100):
        """Return up to ``n_samples`` random rows of a loaded CSV, or None if unknown."""
        if dataset_key in self.csv_files:
            df = self.csv_files[dataset_key]['data']
            return df.sample(min(n_samples, len(df)))
        return None

    def cleanup(self):
        """Stop the Spark session if one was started."""
        if self.spark:
            self.spark.stop()
            print("Spark session stopped")

# Initialize and run data scoping
print("Initializing Accelerated Protein Data Processor...")
processor = AcceleratedProteinDataProcessor()

# Run initial data scoping: sample at most 10,000 rows per CSV, 6 parallel workers
stats = processor.initial_data_scoping(sample_size=10000, parallel_jobs=6)

# "\n" (a real newline) rather than "\\n", which would print a literal backslash-n
print("\n=== Data Scoping Results ===")
for key, value in stats.items():
    # dict-valued entries are summarised by their number of keys
    print(f"{key}: {len(value) if isinstance(value, dict) else value}")

In [None]:
# Advanced analysis functions with Spark integration
class AdvancedProteinAnalysis:
    """Feature statistics, clustering and data-quality reports over a processor's datasets.

    The wrapped ``processor`` is read through these attributes:
    ``spark``, ``use_spark``, ``spark_dfs`` (name -> Spark DataFrame),
    ``csv_files`` (name -> dict holding a pandas DataFrame under ``'data'``),
    ``use_gpu`` and ``device``.
    """

    # Spark type names treated as numeric. This mirrors pandas'
    # ``select_dtypes(include=[np.number])`` (which also covers 8/16-bit
    # integers), so the Spark and pandas paths analyse the same columns.
    _SPARK_NUMERIC_TYPES = frozenset(
        {'double', 'float', 'integer', 'long', 'short', 'byte', 'decimal'}
    )

    def __init__(self, processor):
        self.processor = processor
        # Only use Spark when the processor was configured for it.
        self.spark = processor.spark if processor.use_spark else None

    def large_scale_feature_analysis(self):
        """Compute mean/std/min/max of every numeric column of each Spark dataset.

        Falls back to ``_pandas_feature_analysis`` when no Spark session is
        available.

        Returns:
            dict mapping dataset name -> ``{'total_rows', 'numeric_features',
            'feature_count'}``. Datasets without numeric columns are omitted.
        """
        if not self.spark:
            print("Spark not available - using pandas fallback")
            return self._pandas_feature_analysis()

        print("=== Large-Scale Feature Analysis with Spark ===")
        results = {}

        for dataset_name, spark_df in self.processor.spark_dfs.items():
            print(f"Analyzing {dataset_name}...")

            numeric_cols = [
                field.name
                for field in spark_df.schema.fields
                if field.dataType.typeName() in self._SPARK_NUMERIC_TYPES
            ]
            if not numeric_cols:
                continue

            # A single aggregation job computes all statistics for all columns.
            agg_exprs = []
            for col_name in numeric_cols:
                agg_exprs.extend([
                    avg(col(col_name)).alias(f"{col_name}_mean"),
                    stddev(col(col_name)).alias(f"{col_name}_std"),
                    spark_min(col(col_name)).alias(f"{col_name}_min"),
                    spark_max(col(col_name)).alias(f"{col_name}_max")
                ])
            stats_row = spark_df.agg(*agg_exprs).collect()[0]

            feature_stats = {
                col_name: {
                    'mean': stats_row[f"{col_name}_mean"],
                    'std': stats_row[f"{col_name}_std"],
                    'min': stats_row[f"{col_name}_min"],
                    'max': stats_row[f"{col_name}_max"]
                }
                for col_name in numeric_cols
            }

            results[dataset_name] = {
                'total_rows': spark_df.count(),
                'numeric_features': feature_stats,
                'feature_count': len(numeric_cols)
            }

        return results

    def _pandas_feature_analysis(self):
        """Pandas fallback for ``large_scale_feature_analysis`` over ``processor.csv_files``.

        Returns the same structure as the Spark path.
        """
        print("=== Feature Analysis with Pandas ===")
        results = {}

        for dataset_name, data_info in self.processor.csv_files.items():
            df = data_info['data']
            numeric_cols = df.select_dtypes(include=[np.number]).columns

            if len(numeric_cols) > 0:
                feature_stats = {
                    column: {
                        'mean': df[column].mean(),
                        'std': df[column].std(),
                        'min': df[column].min(),
                        'max': df[column].max()
                    }
                    for column in numeric_cols
                }

                results[dataset_name] = {
                    'total_rows': len(df),
                    'numeric_features': feature_stats,
                    'feature_count': len(numeric_cols)
                }

        return results

    def gpu_accelerated_clustering_analysis(self, dataset_key, n_clusters=5, max_features=20):
        """K-means cluster the numeric columns of one loaded CSV dataset.

        Uses the torch implementation when the processor has a GPU enabled and
        torch is importable. Otherwise it uses scikit-learn.

        Args:
            dataset_key: key into ``processor.csv_files``.
            n_clusters: number of clusters to fit.
            max_features: only the first ``max_features`` numeric columns are used.

        Returns:
            The result dict from ``_gpu_clustering`` / ``_cpu_clustering``.
            Returns ``None`` when the dataset is missing, has no usable
            numeric data, or has fewer rows than ``n_clusters``.
        """
        if dataset_key not in self.processor.csv_files:
            print(f"Dataset {dataset_key} not found")
            return None

        df = self.processor.csv_files[dataset_key]['data']
        numeric_cols = df.select_dtypes(include=[np.number]).columns

        if len(numeric_cols) == 0:
            print("No numeric features found for clustering")
            return None

        # Limit features for demonstration
        data = df[numeric_cols[:max_features]]
        # +/-inf would poison both feature scaling and distances; treat as missing.
        data = data.replace([np.inf, -np.inf], np.nan)
        # An all-missing column has a NaN mean, so mean-imputation cannot fill it.
        data = data.loc[:, data.notna().any()]
        if data.shape[1] == 0:
            print("No numeric features found for clustering")
            return None
        data = data.fillna(data.mean())

        if len(data) < n_clusters:
            print(f"Not enough samples ({len(data)}) for {n_clusters} clusters")
            return None

        selected_cols = data.columns
        print(f"Performing clustering analysis on {len(selected_cols)} features...")

        if self.processor.use_gpu and HAS_TORCH:
            return self._gpu_clustering(data, n_clusters, selected_cols)
        else:
            return self._cpu_clustering(data, n_clusters, selected_cols)

    @staticmethod
    def _cluster_statistics(data, assignments, n_clusters):
        """Summarise size, share (%) and per-feature means of each non-empty cluster.

        ``assignments`` is a 1-D integer numpy array aligned with the rows of ``data``.
        """
        cluster_stats = {}
        for k in range(n_clusters):
            mask = assignments == k
            cluster_size = int(mask.sum())
            if cluster_size > 0:
                cluster_stats[f'cluster_{k}'] = {
                    'size': cluster_size,
                    'percentage': float(cluster_size / len(data) * 100),
                    'feature_means': data.iloc[mask].mean().to_dict()
                }
        return cluster_stats

    def _gpu_clustering(self, data, n_clusters, feature_names, max_iter=100, seed=42):
        """K-means (Lloyd's algorithm) implemented with torch on ``processor.device``.

        Each feature is z-score standardised like the scikit-learn path
        (population std; constant columns are left unscaled), so GPU and CPU
        runs cluster the same feature space. Centroids start at distinct,
        randomly chosen samples. That way every cluster starts non-empty and
        results are reproducible for a given ``seed``.

        Args:
            data: pandas DataFrame of finite numeric features without NaNs.
            n_clusters: number of clusters; must not exceed ``len(data)``.
            feature_names: column names reported back in the result.
            max_iter: maximum number of Lloyd iterations.
            seed: seed for the centroid initialisation.

        Returns:
            dict with ``method``, ``n_clusters``, ``n_features``, ``feature_names``,
            ``cluster_assignments`` (list of int) and ``cluster_statistics``.
        """
        print("Using GPU-accelerated clustering...")
        device = self.processor.device

        # to_numpy(float64) also handles nullable integer columns, whose
        # ``.values`` would be an object array that torch cannot convert.
        X = torch.tensor(data.to_numpy(dtype=np.float64), dtype=torch.float32, device=device)
        n_samples = X.shape[0]

        # Standardise each feature (column), not each sample. Row-wise L2
        # normalisation lets the largest-magnitude feature dominate.
        mean = X.mean(dim=0, keepdim=True)
        std = (X - mean).pow(2).mean(dim=0, keepdim=True).sqrt()
        std = torch.where(std > 0, std, torch.ones_like(std))
        X_scaled = (X - mean) / std

        generator = torch.Generator().manual_seed(seed)
        init_idx = torch.randperm(n_samples, generator=generator)[:n_clusters].to(device)
        centroids = X_scaled[init_idx].clone()

        for iteration in range(max_iter):
            assignments = torch.cdist(X_scaled, centroids).argmin(dim=1)

            # Clusters that lose all members keep their previous centroid.
            new_centroids = centroids.clone()
            for k in range(n_clusters):
                members = X_scaled[assignments == k]
                if members.shape[0] > 0:
                    new_centroids[k] = members.mean(dim=0)

            converged = torch.allclose(centroids, new_centroids, rtol=1e-4)
            centroids = new_centroids
            if converged:
                print(f"Converged after {iteration+1} iterations")
                break

        final_assignments = torch.cdist(X_scaled, centroids).argmin(dim=1).cpu().numpy()

        return {
            'method': 'gpu_kmeans',
            'n_clusters': n_clusters,
            'n_features': len(feature_names),
            'feature_names': list(feature_names),
            'cluster_assignments': final_assignments.tolist(),
            'cluster_statistics': self._cluster_statistics(data, final_assignments, n_clusters)
        }

    def _cpu_clustering(self, data, n_clusters, feature_names):
        """scikit-learn K-means on standardised features.

        Returns ``None`` if scikit-learn is not installed.
        """
        print("Using CPU-based clustering...")

        try:
            from sklearn.cluster import KMeans
            from sklearn.preprocessing import StandardScaler
        except ImportError:
            print("scikit-learn not available for CPU clustering")
            return None

        X_scaled = StandardScaler().fit_transform(data)
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        cluster_assignments = kmeans.fit_predict(X_scaled)

        return {
            'method': 'cpu_kmeans',
            'n_clusters': n_clusters,
            'n_features': len(feature_names),
            'feature_names': list(feature_names),
            'cluster_assignments': cluster_assignments.tolist(),
            'cluster_statistics': self._cluster_statistics(data, cluster_assignments, n_clusters)
        }

    def generate_data_quality_report(self):
        """Build per-dataset quality metrics for every loaded CSV dataset.

        Counts are plain Python ``int`` and percentages plain ``float``, so
        the report can be serialised with ``json`` directly.

        Returns:
            dict mapping dataset name -> metrics: row/column/cell counts,
            missing-value totals, duplicate rows, dtype distribution, and
            per-numeric-column quality figures.
        """
        print("=== Generating Data Quality Report ===")

        quality_report = {}

        for dataset_name, data_info in self.processor.csv_files.items():
            df = data_info['data']

            # Basic quality metrics
            missing_by_column = df.isnull().sum()
            total_cells = int(df.size)
            missing_cells = int(missing_by_column.sum())
            # Empty frames have no cells; report 0% instead of dividing by zero.
            missing_percentage = (missing_cells / total_cells) * 100 if total_cells else 0.0
            columns_with_missing = missing_by_column[missing_by_column > 0]

            duplicate_rows = int(df.duplicated().sum())
            dtype_counts = df.dtypes.value_counts()

            numeric_quality = {}
            for column in df.select_dtypes(include=[np.number]).columns:
                col_data = df[column].dropna()
                if col_data.empty:
                    continue
                missing_count = int(df[column].isnull().sum())
                numeric_quality[column] = {
                    'missing_count': missing_count,
                    'missing_percentage': float(missing_count / len(df) * 100),
                    'infinite_values': int(np.isinf(col_data).sum()),
                    'zero_values': int((col_data == 0).sum()),
                    'negative_values': int((col_data < 0).sum()),
                    'outliers_iqr': self._count_outliers_iqr(col_data)
                }

            quality_report[dataset_name] = {
                'total_rows': len(df),
                'total_columns': len(df.columns),
                'total_cells': total_cells,
                'missing_cells': missing_cells,
                'missing_percentage': float(missing_percentage),
                'duplicate_rows': duplicate_rows,
                'columns_with_missing': {c: int(n) for c, n in columns_with_missing.items()},
                'data_type_distribution': {str(k): int(v) for k, v in dtype_counts.items()},
                'numeric_column_quality': numeric_quality
            }

        return quality_report

    def _count_outliers_iqr(self, series):
        """Count values outside ``[Q1 - 1.5*IQR, Q3 + 1.5*IQR]`` (Tukey's fences)."""
        q1 = series.quantile(0.25)
        q3 = series.quantile(0.75)
        iqr = q3 - q1
        lower_bound = q1 - 1.5 * iqr
        upper_bound = q3 + 1.5 * iqr
        return int(((series < lower_bound) | (series > upper_bound)).sum())

# Initialize advanced analysis.
# Requires `processor` from the data-scoping cell; at module (notebook) scope
# globals() is the namespace that holds it.
if 'processor' in globals():
    print("\\n=== Initializing Advanced Analysis ===")
    advanced_analysis = AdvancedProteinAnalysis(processor)

    # Run large-scale feature analysis
    feature_results = advanced_analysis.large_scale_feature_analysis()
    print(f"Feature analysis completed for {len(feature_results)} datasets")

    # Run data quality analysis
    quality_report = advanced_analysis.generate_data_quality_report()
    print(f"Data quality analysis completed for {len(quality_report)} datasets")

    # Example clustering analysis (if data is available)
    available_datasets = list(processor.csv_files.keys())
    if available_datasets:
        example_dataset = available_datasets[0]
        print(f"\\nRunning clustering analysis on {example_dataset}...")
        clustering_result = advanced_analysis.gpu_accelerated_clustering_analysis(
            example_dataset, n_clusters=3, max_features=10
        )
        if clustering_result:
            # Report the non-empty clusters actually found, not the requested count.
            cluster_summaries = clustering_result['cluster_statistics']
            print(f"Clustering completed: {len(cluster_summaries)} clusters found")
            # Use a dedicated loop name: `stats` holds the data-scoping results
            # from the previous cell and must not be overwritten here.
            for cluster_id, cluster_summary in cluster_summaries.items():
                print(f"  {cluster_id}: {cluster_summary['size']} samples ({cluster_summary['percentage']:.1f}%)")

else:
    print("Processor not initialized. Please run the previous cell first.")