In [1]:
import logging
import math
import multiprocessing as mp
import random
import shutil
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import List, Set, Dict, Tuple, Optional, NamedTuple

import numpy as np
import polars as pl
import psutil
from scipy import stats

from shuffle_worker import process_batch_worker


In [2]:
# Notebook-wide logging: INFO and above, each record prefixed with a timestamp
# and its level name. The module-level `logger` is shared by every class below.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

In [None]:
@dataclass
class UniformityTestResult:
    """Outcome of one chi-squared goodness-of-fit test against a uniform distribution.

    Instances are produced by ``ConfigIdxUniformityTester._test_discrete_uniformity``;
    the field order is part of the positional constructor and must not change.
    """
    # Chi-squared statistic and the p-value of the test.
    chi2_statistic: float
    p_value: float
    # Number of categories minus one, and the number of categories itself.
    degrees_of_freedom: int
    n_bins: int
    # Number of observations the test was computed from.
    sample_size: int
    # True when p_value exceeded the tester's significance level.
    is_uniform: bool
    # Normalised score where 1.0 means perfectly uniform counts.
    uniformity_score: float
    # Per-category observed counts and the matching expected counts.
    observed_frequencies: List[int]
    expected_frequencies: List[float]
    # One edge per category boundary (n_bins + 1 entries).
    bin_edges: List[float]
    # Inclusive (min, max) config_idx range the test was run over.
    config_idx_range: Tuple[int, int]

@dataclass
class ChunkTestResult:
    """Result of running the uniformity test on a single chunk file.

    When the chunk could not be loaded or tested, ``result`` is ``None``,
    ``is_uniform`` is ``False`` and the numeric fields are zero, so consumers
    must check ``result is not None`` before reading test statistics.
    """
    # Path of the chunk file, as a string.
    chunk_file: str
    # Zero-based position of the chunk in the tested list.
    chunk_index: int
    # Row count recorded for the chunk, and how many of those rows were tested.
    total_rows: int
    sample_size: int
    # Full test outcome; None marks a chunk whose test failed.
    result: Optional[UniformityTestResult]
    # Convenience copies of result.is_uniform / result.uniformity_score.
    is_uniform: bool
    uniformity_score: float

class ConfigIdxUniformityTester:
    """Chi-squared uniformity checks for the integer ``config_idx`` column of parquet chunks.

    Every integer in ``[config_idx_min, config_idx_max]`` is treated as one
    category; a chunk is considered uniform when the goodness-of-fit p-value
    exceeds ``alpha``.
    """

    def __init__(self, 
                 config_idx_min: int,
                 config_idx_max: int,
                 alpha: float = 0.05
                 ):
        """
        Initialize config_idx uniformity tester
        
        Args:
            config_idx_min: Minimum config_idx value across all files
            config_idx_max: Maximum config_idx value across all files
            alpha: Significance level for hypothesis test

        Raises:
            ValueError: If config_idx_max is smaller than config_idx_min.
        """
        if config_idx_max < config_idx_min:
            raise ValueError(
                f"config_idx_max ({config_idx_max}) must be >= config_idx_min ({config_idx_min})")

        self.config_idx_min = config_idx_min
        self.config_idx_max = config_idx_max
        self.alpha = alpha
        self.n_categories = config_idx_max - config_idx_min + 1
                
        logger.info(f"Initialized config_idx uniformity tester")
        logger.info(f"Range: [{config_idx_min}, {config_idx_max}], Categories: {self.n_categories}, Alpha: {alpha}")
        
    def _test_discrete_uniformity(self, values: np.ndarray) -> UniformityTestResult:
        """
        Test uniformity for config_idx
        
        Values outside the configured range are excluded from both the observed
        counts and the expected total (with a warning), so the two sums passed
        to the chi-squared test always agree.
        
        Args:
            values: Array of config_idx values
            
        Returns:
            UniformityTestResult

        Raises:
            ValueError: If no value falls inside the configured range.
        """
        values = np.asarray(values).ravel()
        in_range_mask = (values >= self.config_idx_min) & (values <= self.config_idx_max)
        in_range = values[in_range_mask].astype(np.int64)

        dropped = int(values.size - in_range.size)
        if dropped:
            logger.warning(f"Ignoring {dropped:,} config_idx values outside "
                           f"[{self.config_idx_min}, {self.config_idx_max}]")

        total_samples = int(in_range.size)
        if total_samples == 0:
            raise ValueError("No config_idx values inside the configured range; cannot test uniformity")

        # Full frequency array, including categories that never occur.
        observed_freq = np.bincount(in_range - self.config_idx_min, minlength=self.n_categories)

        # Expected frequencies (uniform distribution)
        expected_per_category = total_samples / self.n_categories
        expected_freq = np.full(self.n_categories, expected_per_category)
        if expected_per_category < 5:
            logger.debug(f"Expected count per category is {expected_per_category:.2f} (< 5); "
                         f"the chi-squared approximation may be unreliable")

        if self.n_categories > 1:
            chi2_stat, p_value = stats.chisquare(observed_freq, expected_freq)
            chi2_stat, p_value = float(chi2_stat), float(p_value)
            # The statistic is largest when every sample lands in one category,
            # where it equals total_samples * (n_categories - 1).
            max_possible_chi2 = total_samples * (self.n_categories - 1)
            uniformity_score = max(0.0, 1.0 - (chi2_stat / max_possible_chi2))
        else:
            # A single category has zero degrees of freedom: trivially uniform.
            chi2_stat, p_value, uniformity_score = 0.0, 1.0, 1.0

        is_uniform = bool(p_value > self.alpha)
        
        # Create bin edges for discrete values
        bin_edges = list(range(self.config_idx_min, self.config_idx_max + 2))
        
        logger.debug(f"config_idx: χ²={chi2_stat:.4f}, p={p_value:.6f}, "
                    f"uniform={is_uniform}, score={uniformity_score:.4f}")
        
        return UniformityTestResult(
            chi2_statistic=chi2_stat,
            p_value=p_value,
            degrees_of_freedom=self.n_categories - 1,
            n_bins=self.n_categories,
            sample_size=total_samples,
            is_uniform=is_uniform,
            uniformity_score=uniformity_score,
            observed_frequencies=observed_freq.tolist(),
            expected_frequencies=expected_freq.tolist(),
            bin_edges=bin_edges,
            config_idx_range=(self.config_idx_min, self.config_idx_max)
        )

    def _read_chunk(self, chunk_file: str) -> pl.DataFrame:
        """Read one chunk file and check that it has a config_idx column.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the config_idx column is missing.
        """
        chunk_path = Path(chunk_file)
        if not chunk_path.exists():
            raise FileNotFoundError(f"Chunk file does not exist: {chunk_file}")

        chunk_df = pl.read_parquet(chunk_path)

        if 'config_idx' not in chunk_df.columns:
            raise ValueError(f"Missing config_idx column in {chunk_file}")

        return chunk_df

    def _sample_rows(self, chunk_df: pl.DataFrame, sample_size: Optional[int],
                     seed: int, chunk_name: str) -> pl.DataFrame:
        """Return at most ``sample_size`` rows of ``chunk_df`` (all rows when it is None)."""
        total_rows = len(chunk_df)

        if sample_size is None or total_rows <= sample_size:
            logger.debug(f"Using all {total_rows:,} rows from {chunk_name}")
            return chunk_df

        logger.debug(f"Sampling {sample_size:,} rows from {total_rows:,} total rows in {chunk_name}")
        try:
            return chunk_df.sample(n=sample_size, seed=seed)
        except Exception as exc:
            # Fallback: manual sampling with indices. A private generator is
            # used so the global NumPy random state is left untouched.
            logger.warning(f"DataFrame.sample failed for {chunk_name} ({exc}); "
                           f"falling back to index sampling")
            rng = np.random.default_rng(seed)
            indices = np.sort(rng.choice(total_rows, size=sample_size, replace=False))
            return chunk_df[indices.tolist()]
    
    def _load_and_sample_single_chunk(self, chunk_file: str, 
                                    sample_size: Optional[int] = None, 
                                    seed: int = 42) -> pl.DataFrame:
        """Load a single chunk file and optionally sample from it"""
        chunk_df = self._read_chunk(chunk_file)
        return self._sample_rows(chunk_df, sample_size, seed, Path(chunk_file).name)
    
    def _test_single_chunk(self, chunk_file: str, chunk_index: int, 
                          sample_size: Optional[int] = None, seed: int = 42) -> ChunkTestResult:
        """Test uniformity for config_idx in a single chunk

        Any failure while loading or testing is logged and reported as a
        ChunkTestResult whose ``result`` is None, so one bad chunk does not
        abort the whole run.
        """
        chunk_path = Path(chunk_file)
        logger.info(f"Testing chunk {chunk_index + 1}: {chunk_path.name}")
        
        try:
            chunk_df = self._read_chunk(chunk_file)
            # Row count of the whole chunk, taken before sampling.
            total_rows = len(chunk_df)

            sampled_df = self._sample_rows(
                chunk_df, sample_size, seed + chunk_index, chunk_path.name)
            actual_sample_size = len(sampled_df)

            # Nulls are dropped so they are not cast to arbitrary integers.
            config_idx_values = sampled_df['config_idx'].drop_nulls().to_numpy().astype(int)
            
            # Test uniformity
            result = self._test_discrete_uniformity(config_idx_values)
            
            logger.info(f"  Chunk {chunk_index + 1} results: uniform={result.is_uniform}, score={result.uniformity_score:.4f}")
            
            return ChunkTestResult(
                chunk_file=str(chunk_path),
                chunk_index=chunk_index,
                total_rows=total_rows,
                sample_size=actual_sample_size,
                result=result,
                is_uniform=result.is_uniform,
                uniformity_score=result.uniformity_score
            )
            
        except Exception as e:
            logger.error(f"Failed to test chunk {chunk_index + 1} ({chunk_path.name}): {e}")
            # Return dummy result on error
            return ChunkTestResult(
                chunk_file=str(chunk_path),
                chunk_index=chunk_index,
                total_rows=0,
                sample_size=0,
                result=None,
                is_uniform=False,
                uniformity_score=0.0
            )
    
    def test_chunks_uniformity(self, chunk_files: List[str], sample_size_per_chunk: Optional[int] = None, 
                             seed: int = 42) -> List[ChunkTestResult]:
        """Test uniformity of config_idx for each chunk

        Args:
            chunk_files: Paths of the chunk files to test, in report order.
            sample_size_per_chunk: Maximum rows tested per chunk (None = all rows).
            seed: Base seed; chunk ``i`` is sampled with ``seed + i``.

        Returns:
            One ChunkTestResult per input file, in the same order.
        """
        logger.info(f"\n{'='*80}")
        logger.info(f"CONFIG_IDX UNIFORMITY TESTING")
        logger.info(f"{'='*80}")
        logger.info(f"Testing {len(chunk_files)} chunks")
        logger.info(f"config_idx range: [{self.config_idx_min}, {self.config_idx_max}]")
        if sample_size_per_chunk:
            logger.info(f"Sampling {sample_size_per_chunk:,} rows per chunk")
        else:
            logger.info(f"Using all rows in each chunk")
        
        return [
            self._test_single_chunk(chunk_file, chunk_index, sample_size_per_chunk, seed)
            for chunk_index, chunk_file in enumerate(chunk_files)
        ]
    
    def print_detailed_report(self, results: List[ChunkTestResult]):
        """Print detailed report of config_idx uniformity tests

        Prints overall statistics followed by a per-chunk table limited to the
        first 20 chunks. Chunks whose test failed (``result is None``) count as
        not successfully tested and show "N/A" for the p-value.
        """
        
        print(f"\n{'='*100}")
        print(f"CONFIG_IDX UNIFORMITY TEST REPORT")
        print(f"{'='*100}")
        
        if not results:
            print("No results to display!")
            return
                
        # Overall statistics
        total_chunks = len(results)
        successful = [r for r in results if r.result is not None]
        successful_chunks = len(successful)
        uniform_chunks = len([r for r in results if r.is_uniform])
        
        print(f"\nOverall Statistics:")
        print(f"  Total chunks: {total_chunks}")
        print(f"  Successfully tested: {successful_chunks}/{total_chunks} ({successful_chunks/total_chunks*100:.1f}%)")
        
        if successful_chunks > 0:
            print(f"  Uniform chunks: {uniform_chunks}/{successful_chunks} ({uniform_chunks/successful_chunks*100:.1f}%)")
            avg_score = np.mean([r.uniformity_score for r in successful])
            print(f"  Average uniformity score: {avg_score:.4f}")
        else:
            # Every chunk failed: there is no denominator for a percentage.
            print(f"  Uniform chunks: 0/0 (N/A)")
        
        # Per-chunk summary (first 20 chunks)
        print(f"\nPer-Chunk Summary (first 20 chunks):")
        print(f"{'Chunk':<8} {'File':<30} {'Uniform':<8} {'Score':<10} {'p-value':<10}")
        print(f"{'-'*70}")
        
        for result in results[:20]:
            file_name = Path(result.chunk_file).name
            # Truncate long names so the table columns stay aligned.
            chunk_name = file_name[:27] + "..." if len(file_name) > 30 else file_name
            uniform_str = "✓" if result.is_uniform else "✗"
            p_val = f"{result.result.p_value:.6f}" if result.result else "N/A"
            
            print(f"{result.chunk_index+1:<8} {chunk_name:<30} {uniform_str:<8} "
                  f"{result.uniformity_score:<10.4f} {p_val:<10}")
        
        if len(results) > 20:
            print(f"... and {len(results) - 20} more chunks")

In [4]:
@dataclass
class ShuffleConfig:
    """Tunable limits for the memory-aware shuffle.

    The memory ratio, safety margin and chunk bounds drive the batch-size
    calculation; ``chunk_size_variation`` is forwarded to the batch worker.
    """
    # Fraction of the available memory a batch may use (default: 70%).
    max_memory_usage_ratio: float = 0.7
    # Fraction of the computed maximum chunk count that is actually used (default: 90%).
    safety_margin: float = 0.9
    # Lower and upper bounds on the number of chunks in one batch.
    min_chunks_per_batch: int = 2
    max_chunks_per_batch: int = 1000
    # (min, max) variation factors for chunk size.
    chunk_size_variation: Tuple[float, float] = (0.8, 1.2)
    # Number of worker processes; available memory is divided by this value.
    max_workers: int = 1

In [None]:
class MemoryAwareChunkShuffler:
    """Multi-pass shuffle of a directory of parquet chunks under a memory budget.

    Each pass groups the current chunks into batches sized from the available
    memory, hands every batch to ``process_batch_worker`` (which writes new
    chunks into ``temp_dir``), and feeds the produced chunks into the next
    pass. After the last pass the chunks are moved into ``output_dir``.
    """

    # Rows sampled per chunk, and the sampling seed, for the uniformity reports.
    _UNIFORMITY_SAMPLE_SIZE = 10000
    _UNIFORMITY_SEED = 42
    # Distance between the seeds of consecutive passes. It is far larger than
    # any realistic batch count so that `pass_seed + batch_number` never
    # repeats across passes; pass 1 uses the caller's seed unchanged.
    _PASS_SEED_STRIDE = 1_000_003

    def __init__(self, input_dir: str, output_dir: str, temp_dir: str = "shuffle_temp", 
                 config: "Optional[ShuffleConfig]" = None):
        """Discover input chunks and derive batch size, pass count and config_idx range.

        Args:
            input_dir: Directory containing the ``*.parquet`` input chunks.
            output_dir: Directory that receives the final shuffled chunks (created if missing).
            temp_dir: Directory for intermediate chunks (created if missing).
            config: Shuffle settings; a default ``ShuffleConfig()`` is used when None.

        Raises:
            ValueError: If ``input_dir`` contains no parquet files, or no chunk
                yields a config_idx range.
        """
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        self.temp_dir = Path(temp_dir)
        self.config = config or ShuffleConfig()
        
        # Create directories
        self.output_dir.mkdir(exist_ok=True, parents=True)
        self.temp_dir.mkdir(exist_ok=True, parents=True)
        
        # Sorted so that a given seed always sees the chunks in the same order;
        # glob order depends on the filesystem.
        self.input_chunks = sorted(self.input_dir.glob("*.parquet"))
        if not self.input_chunks:
            raise ValueError(f"No parquet files found in {input_dir}")
        
        logger.info(f"Found {len(self.input_chunks)} input chunks")
        
        # Initialize tracking
        self.processed_chunks = set()
        self.memory_stats = self._get_memory_info()
        self.chunk_memory_usage = self._estimate_chunk_memory()
        self.optimal_batch_size = self._calculate_optimal_batch_size()
        self.required_passes = self._calculate_required_passes()
        
        # Scan for global config_idx range
        self.config_idx_min, self.config_idx_max = self._scan_config_idx_range()
        
        logger.info(f"Memory available: {self.memory_stats['available_gb']:.1f} GB")
        logger.info(f"Estimated chunk memory: {self.chunk_memory_usage / 1024**2:.1f} MB")
        logger.info(f"Optimal batch size: {self.optimal_batch_size} chunks")
        logger.info(f"Required passes: {self.required_passes}")
        logger.info(f"Global config_idx range: [{self.config_idx_min}, {self.config_idx_max}]")
    
    def _get_memory_info(self) -> Dict[str, float]:
        """Get current memory information (sizes in GiB, usage in percent)."""
        memory = psutil.virtual_memory()
        return {
            'total_gb': memory.total / 1024**3,
            'available_gb': memory.available / 1024**3,
            'used_gb': memory.used / 1024**3,
            'percent_used': memory.percent
        }
    
    def _estimate_chunk_memory(self, sample_size: int = 3) -> float:
        """Estimate memory usage per chunk, in bytes, by loading a few random chunks.

        Returns the mean ``estimated_size()`` of the sampled chunks times 1.5,
        or 100 MiB when no sampled chunk could be read.
        """
        sample_files = random.sample(self.input_chunks, min(sample_size, len(self.input_chunks)))
        memory_usages = []
        
        for file_path in sample_files:
            try:
                # Load chunk and measure memory
                chunk = pl.read_parquet(file_path)
                memory_usage = chunk.estimated_size()
                memory_usages.append(memory_usage)
                logger.debug(f"Chunk {file_path.name}: {memory_usage / 1024**2:.1f} MB")
            except Exception as e:
                logger.warning(f"Could not estimate memory for {file_path}: {e}")
        
        if not memory_usages:
            # Fallback estimate
            return 100 * 1024**2  # 100MB default
        
        avg_memory = sum(memory_usages) / len(memory_usages)
        # Add buffer for processing overhead
        return avg_memory * 1.5
    
    def _calculate_optimal_batch_size(self) -> int:
        """Calculate optimal number of chunks to process in one batch

        Available memory is split across the configured workers, scaled by the
        memory ratio and safety margin, and the result is clamped to
        ``[min_chunks_per_batch, max_chunks_per_batch]``.
        """
        # A worker count below one is treated as one worker.
        workers = max(1, self.config.max_workers)
        available_memory = self.memory_stats['available_gb'] * 1024**3 / workers
        usable_memory = available_memory * self.config.max_memory_usage_ratio
        
        # Calculate theoretical max chunks
        if self.chunk_memory_usage > 0:
            theoretical_max = int(usable_memory / self.chunk_memory_usage)
        else:
            # Zero-sized estimate (e.g. empty chunks): memory is no constraint.
            theoretical_max = self.config.max_chunks_per_batch
        
        # Apply safety margin and constraints
        safe_max = int(theoretical_max * self.config.safety_margin)
        optimal_size = max(
            self.config.min_chunks_per_batch,
            min(safe_max, self.config.max_chunks_per_batch)
        )
        
        logger.debug(f"Theoretical max chunks: {theoretical_max}")
        logger.debug(f"Safe max chunks: {safe_max}")
        
        return optimal_size
    
    def _calculate_required_passes(self) -> int:
        """Calculate number of shuffle passes needed for uniform distribution

        Returns 1 when a single batch holds every chunk; otherwise the k-way
        merge depth plus extra mixing passes, capped at 10.
        """
        n_chunks = len(self.input_chunks)
        k_batch_size = self.optimal_batch_size
        
        if k_batch_size >= n_chunks:
            # Can process all chunks at once
            return 1
        
        # Theoretical passes based on k-way merging. The base is floored at 2
        # because log(1) == 0 would make the division blow up.
        merge_base = max(2, k_batch_size)
        theoretical_passes = math.ceil(math.log(n_chunks) / math.log(merge_base))
        
        # Additional mixing passes for uniformity
        mixing_passes = max(1, math.ceil(math.log2(n_chunks) * 0.3))
        
        total_passes = theoretical_passes + mixing_passes
        
        logger.debug(f"Theoretical passes: {theoretical_passes}")
        logger.debug(f"Mixing passes: {mixing_passes}")
        
        return min(total_passes, 10)  # Cap at 10 passes
    
    def _scan_config_idx_range(self) -> Tuple[int, int]:
        """Scan all input chunks to find global config_idx min and max

        Chunks that cannot be read, or that hold no non-null config_idx value,
        are skipped with a warning.

        Raises:
            ValueError: If no chunk yielded a range.
        """
        logger.info("Scanning all chunks for config_idx range...")
        
        global_min = None
        global_max = None
        
        for chunk_file in self.input_chunks:
            try:
                # Read just the config_idx column
                df = pl.read_parquet(chunk_file, columns=['config_idx'])
                chunk_min = df['config_idx'].min()
                chunk_max = df['config_idx'].max()
            except Exception as e:
                logger.warning(f"Could not read config_idx from {chunk_file}: {e}")
                continue

            if chunk_min is None or chunk_max is None:
                logger.warning(f"Could not read config_idx from {chunk_file}: no non-null values")
                continue

            global_min = chunk_min if global_min is None else min(global_min, chunk_min)
            global_max = chunk_max if global_max is None else max(global_max, chunk_max)

            logger.debug(f"{chunk_file.name}: config_idx range [{chunk_min}, {chunk_max}]")
        
        if global_min is None or global_max is None:
            raise ValueError("Could not determine config_idx range from any chunks")
        
        logger.info(f"Global config_idx range: [{int(global_min)}, {int(global_max)}]")
        return int(global_min), int(global_max)
    
    def _select_random_chunks_for_batch(self, available_chunks: List[Path], 
                                      batch_size: int) -> List[Path]:
        """Select random chunks for processing, avoiding already processed ones

        When every available chunk has already been handed out, the tracker is
        reset and selection starts over from the full list.
        """
        # Filter out already processed chunks
        unprocessed_chunks = [chunk for chunk in available_chunks 
                            if chunk not in self.processed_chunks]
        
        if not unprocessed_chunks:
            # Reset for new pass
            self.processed_chunks.clear()
            unprocessed_chunks = available_chunks.copy()
            logger.info("Reset processed chunks tracker for new pass")
        
        # Select random batch
        batch_size = min(batch_size, len(unprocessed_chunks))
        selected_chunks = random.sample(unprocessed_chunks, batch_size)
        
        # Mark as processed
        self.processed_chunks.update(selected_chunks)
        
        return selected_chunks
        
    def _cleanup_files(self, files: List[Path]):
        """Clean up temporary files (best effort: failures are logged, not raised)."""
        for file_path in files:
            try:
                if file_path.exists():
                    file_path.unlink()
                    logger.debug(f"Deleted {file_path.name}")
            except Exception as e:
                logger.warning(f"Could not delete {file_path}: {e}")

    def _plan_batches(self, input_chunks: List[Path], seed: int) -> Tuple[List[dict], int]:
        """Split the chunks of one pass into batches, in a seed-determined order.

        Returns:
            ``(batches, avg_rows_per_chunk)``. Each batch is a dict with keys
            'chunks', 'batch_number' (1-based) and 'chunk_offset' (the first
            output index reserved for the batch).
        """
        # The first chunk's row count stands in for the average chunk size.
        if input_chunks:
            avg_rows_per_chunk = len(pl.read_parquet(input_chunks[0]))
        else:
            avg_rows_per_chunk = 10000

        # A private generator keeps the order reproducible for a given seed
        # and leaves the global random state untouched.
        remaining_chunks = list(input_chunks)
        random.Random(seed).shuffle(remaining_chunks)

        batches = []
        chunk_offset = 0
        for start in range(0, len(remaining_chunks), max(1, self.optimal_batch_size)):
            batch_chunks = remaining_chunks[start:start + max(1, self.optimal_batch_size)]
            batches.append({
                'chunks': batch_chunks,
                'batch_number': len(batches) + 1,
                'chunk_offset': chunk_offset
            })
            # One output chunk is expected per input chunk; reserve twice that
            # many indices so size variation cannot collide with the next batch.
            chunk_offset += len(batch_chunks) * 2

        return batches, avg_rows_per_chunk

    def _worker_args(self, batch_info: dict, pass_number: int, seed: int,
                     avg_rows_per_chunk: int) -> tuple:
        """Positional arguments for ``process_batch_worker`` for one batch."""
        return (
            batch_info['chunks'],
            pass_number,
            batch_info['batch_number'],
            seed + batch_info['batch_number'],
            avg_rows_per_chunk,
            batch_info['chunk_offset'],
            self.temp_dir,
            self.config.chunk_size_variation,
        )

    def _finish_pass(self, input_chunks: List[Path], pass_number: int,
                     outputs_by_batch: Dict[int, list], failed_batches: List[int]) -> List[Path]:
        """Validate a pass, delete its temporary inputs and return its outputs.

        Raises:
            RuntimeError: If any batch failed. The inputs are left in place in
                that case, because continuing would silently drop their rows.
        """
        if failed_batches:
            raise RuntimeError(
                f"Shuffle pass {pass_number} aborted: batches {sorted(failed_batches)} failed; "
                f"input chunks were kept")

        # Batch order (not completion order) keeps the result reproducible.
        all_output_chunks = []
        for batch_number in sorted(outputs_by_batch):
            all_output_chunks.extend(outputs_by_batch[batch_number])

        # Clean up input chunks if they're temporary
        if pass_number > 1:
            self._cleanup_files(input_chunks)

        logger.info(f"Pass {pass_number} complete: {len(all_output_chunks)} total chunks created")
        return all_output_chunks

    def _perform_shuffle_pass(self, input_chunks: List[Path], pass_number: int,
                              seed: int) -> List[Path]:
        """Perform one complete shuffle pass, running the batches one after another in-process.

        Raises:
            RuntimeError: If any batch failed.
        """
        logger.info(f"\n{'='*50}")
        logger.info(f"SHUFFLE PASS {pass_number} (SEQUENTIAL)")
        logger.info(f"{'='*50}")

        # Reset processed tracker for this pass
        self.processed_chunks.clear()

        batches, avg_rows_per_chunk = self._plan_batches(input_chunks, seed)
        logger.info(f"Prepared {len(batches)} batches for sequential processing")

        outputs_by_batch = {}
        failed_batches = []
        for batch_info in batches:
            batch_num = batch_info['batch_number']
            try:
                output_files, num_chunks, status = process_batch_worker(
                    *self._worker_args(batch_info, pass_number, seed, avg_rows_per_chunk))
                logger.info(status)
                outputs_by_batch[batch_num] = list(output_files)
            except Exception as e:
                logger.error(f"Batch {batch_num} failed with exception: {e}")
                failed_batches.append(batch_num)

        return self._finish_pass(input_chunks, pass_number, outputs_by_batch, failed_batches)
    
    def _perform_shuffle_pass_parallel(self, input_chunks: List[Path], pass_number: int, 
                                      seed: int, max_workers: int = None) -> List[Path]:
        """Perform one complete shuffle pass with parallel batch processing

        Args:
            input_chunks: Chunks to shuffle in this pass.
            pass_number: 1-based pass index; inputs are deleted afterwards when it is > 1.
            seed: Seed for the batch order; batch ``b`` is run with ``seed + b``.
            max_workers: Process count (None = min(batch count, config.max_workers)).

        Returns:
            Output chunk paths, ordered by batch number.

        Raises:
            RuntimeError: If any batch failed.
        """
        logger.info(f"\n{'='*50}")
        logger.info(f"SHUFFLE PASS {pass_number} (PARALLEL)")
        logger.info(f"{'='*50}")
        
        # Reset processed tracker for this pass
        self.processed_chunks.clear()
        
        batches, avg_rows_per_chunk = self._plan_batches(input_chunks, seed)
        logger.info(f"Prepared {len(batches)} batches for parallel processing")

        if not batches:
            # Nothing to do; a pool with zero workers would raise.
            return self._finish_pass(input_chunks, pass_number, {}, [])
        
        if max_workers is None:
            max_workers = min(len(batches), getattr(self.config, 'max_workers', 4))
        max_workers = max(1, max_workers)

        outputs_by_batch = {}
        failed_batches = []
        
        with ProcessPoolExecutor(max_workers=max_workers, mp_context=mp.get_context('spawn')) as executor:
            # Submit all batch jobs
            future_to_batch = {
                executor.submit(
                    process_batch_worker,
                    *self._worker_args(batch_info, pass_number, seed, avg_rows_per_chunk)
                ): batch_info['batch_number']
                for batch_info in batches
            }
            
            # Collect results as they complete
            for future in as_completed(future_to_batch):
                batch_num = future_to_batch[future]
                try:
                    output_files, num_chunks, status = future.result()
                    logger.info(status)
                    outputs_by_batch[batch_num] = list(output_files)
                except Exception as e:
                    logger.error(f"Batch {batch_num} failed with exception: {e}")
                    failed_batches.append(batch_num)
        
        return self._finish_pass(input_chunks, pass_number, outputs_by_batch, failed_batches)

    def _report_uniformity(self, chunks: list) -> None:
        """Run the config_idx uniformity test on the given chunks and print the report."""
        tester = ConfigIdxUniformityTester(
            config_idx_min=self.config_idx_min,
            config_idx_max=self.config_idx_max,
            alpha=0.05,
        )
        results = tester.test_chunks_uniformity(
            [str(c) for c in chunks], self._UNIFORMITY_SAMPLE_SIZE, self._UNIFORMITY_SEED)
        tester.print_detailed_report(results)
    
    def shuffle_dataset(self, seed: int = 42, use_parallel: bool = True, 
                       max_workers: int = None) -> List[Path]:
        """
        Perform complete multi-pass shuffle of the dataset.
        
        Args:
            seed: Random seed for reproducibility
            use_parallel: Whether to use parallel processing
            max_workers: Maximum number of parallel workers (None = auto)

        Returns:
            Paths of the final chunks inside ``output_dir``.

        Raises:
            RuntimeError: If a batch fails during any pass.
        """
        logger.info(f"\n{'='*60}")
        logger.info(f"STARTING MULTI-PASS SHUFFLE {'(PARALLEL)' if use_parallel else ''}")
        logger.info(f"{'='*60}")
        logger.info(f"Input chunks: {len(self.input_chunks)}")
        logger.info(f"Required passes: {self.required_passes}")
        logger.info(f"Batch size: {self.optimal_batch_size}")
        if use_parallel:
            logger.info(f"Max workers: {max_workers or 'auto'}")
        
        current_chunks = self.input_chunks.copy()
        
        # Perform multiple shuffle passes
        for pass_num in range(1, self.required_passes + 1):
            start_time = time.time()

            # Distinct seed per pass, also when the caller passes seed=0.
            pass_seed = seed + (pass_num - 1) * self._PASS_SEED_STRIDE
            
            if use_parallel:
                current_chunks = self._perform_shuffle_pass_parallel(
                    current_chunks, pass_num, pass_seed, max_workers)
            else:
                current_chunks = self._perform_shuffle_pass(
                    current_chunks, pass_num, pass_seed)
            
            pass_time = time.time() - start_time
            logger.info(f"Pass {pass_num} completed in {pass_time:.1f} seconds")
            
            # Update memory stats
            self.memory_stats = self._get_memory_info()
            logger.info(f"Memory usage: {self.memory_stats['percent_used']:.1f}%")

            # Test config_idx uniformity after each pass
            self._report_uniformity(current_chunks)
        
        # Move final chunks to output directory. shutil.move also works when
        # temp_dir and output_dir live on different filesystems.
        final_chunks = []
        move_failures = 0
        for i, chunk_file in enumerate(current_chunks):
            final_path = self.output_dir / f"chunk_{i:04d}.parquet"
            try:
                shutil.move(str(chunk_file), str(final_path))
                final_chunks.append(final_path)
            except Exception as e:
                logger.error(f"Failed to move {chunk_file} to {final_path}: {e}")
                move_failures += 1

        # Final uniformity test
        self._report_uniformity(final_chunks)
        
        # Clean up temp directory, unless it still holds chunks that could not
        # be moved: deleting those would destroy shuffled data.
        if move_failures:
            logger.warning(f"Skipping temp directory cleanup: {move_failures} chunks "
                           f"could not be moved out of {self.temp_dir}")
        else:
            try:
                for temp_file in self.temp_dir.glob("*.parquet"):
                    temp_file.unlink()
            except Exception as e:
                logger.warning(f"Error cleaning temp directory: {e}")
        
        logger.info(f"\n{'='*60}")
        logger.info(f"SHUFFLE COMPLETE!")
        logger.info(f"{'='*60}")
        logger.info(f"Final output: {len(final_chunks)} chunks in {self.output_dir}")
        
        return final_chunks

In [6]:
# Configure shuffle parameters
config = ShuffleConfig(
    max_memory_usage_ratio=0.4,  # Use 40% of available memory
    safety_margin=0.9,           # Use 90% of the computed max chunks per batch
    min_chunks_per_batch=2,
    max_chunks_per_batch=100,
    max_workers=4                # Available memory is divided across 4 workers
)

In [None]:
# Run parameters. `level` selects the level_{level} data directories used by
# the path cell; `cube_size` is 2 raised to that level.
level = 2
cube_size = pow(2, level)

# Not referenced by the cells visible in this notebook section.
image_source = './data/shapenet_config_ortho_vis_1_128'
angle_id = 18



In [None]:
# Root directory for this level's data.
# NOTE(review): data_dir already ends in 'chunked/level_{level}', and
# input_data_dir appends 'chunked/level_{level}' again, giving
# ./data/chunked/level_N/chunked/level_N. Confirm this nesting matches the
# on-disk layout; if not, data_dir was probably meant to be a higher-level root.
data_dir = Path(f'./data/chunked/level_{level}')

# Per-level input (chunked) and output (shuffled) directories under data_dir.
input_data_dir = data_dir / 'chunked' / f'level_{level}'
output_data_dir = data_dir / 'shuffled' / f'level_{level}'

In [8]:
# Shuffle every dataset directory found directly under input_data_dir.
# Plain files are skipped (the shuffler expects a directory of parquet chunks
# and would raise "No parquet files found" for them), and the directories are
# visited in sorted order so runs are repeatable.
dataset_dirs = sorted(path for path in input_data_dir.glob('*') if path.is_dir())
if not dataset_dirs:
    logger.warning(f"No dataset directories found in {input_data_dir}")

for joined_path in dataset_dirs:
    # Create shuffler
    shuffler = MemoryAwareChunkShuffler(
        input_dir=joined_path,
        output_dir=output_data_dir / joined_path.name, 
        temp_dir="shuffle_temp",
        config=config
    )
    
    # Perform shuffle
    final_chunks = shuffler.shuffle_dataset(seed=42)        

    print(f"{joined_path.name} Shuffle complete! Created {len(final_chunks)} shuffled chunks.")

2025-10-05 23:10:50,296 - INFO - Found 51 input chunks


AttributeError: 'MemoryAwareChunkShuffler' object has no attribute 'max_workers'