In [1]:
import json
import os
import pickle
import shutil
import subprocess
import sys
from typing import Tuple, List, Dict, Any, Optional

import numpy as np
import pandas as pd

In [13]:
class DatasetProcessor:
    """Convert a raw CSV table into the ``.npz`` + ``.json`` pair used for CoDi training.

    The pipeline (see ``process_dataset``) reads a CSV, classifies every column as
    continuous or categorical, fills missing values, integer-encodes text columns,
    splits the rows into train/test and writes:

    * ``<output_dir>/<dataset_name>.npz``  -- arrays ``train`` and ``test``
    * ``<output_dir>/<dataset_name>.json`` -- column metadata and the problem type

    Args:
        categorical_threshold: A numeric column with at most this many distinct
            values is treated as categorical.
        numeric_categorical_threshold: A numeric column whose ratio of distinct
            values to non-null rows is below this value is treated as categorical
            (unless it holds many non-integer values, see
            ``auto_detect_column_types``).
        output_dir: Directory the dataset files are written to and read from.
    """

    def __init__(self, categorical_threshold: int = 20, numeric_categorical_threshold: float = 0.05,
                 output_dir: str = 'tabular_datasets'):
        self.categorical_threshold = categorical_threshold
        self.numeric_categorical_threshold = numeric_categorical_threshold
        self.output_dir = output_dir

    # ------------------------------------------------------------------ helpers

    def _dataset_path(self, dataset_name: str, suffix: str) -> str:
        """Return the path of a dataset file, e.g. ``tabular_datasets/iris.npz``."""
        return f'{self.output_dir}/{dataset_name}{suffix}'

    @staticmethod
    def _category_label(value) -> str:
        """Render a numeric category value as text: ``3.0`` -> ``'3'``, ``0.5`` -> ``'0.5'``."""
        number = float(value)
        return str(int(number)) if number.is_integer() else str(value)

    @staticmethod
    def _sorted_categories(values) -> list:
        """Sort category values; fall back to a type/text ordering for mixed types.

        A plain ``sorted`` raises ``TypeError`` when an object column mixes, say,
        strings and numbers. The fallback keeps the encoding deterministic.
        """
        values = list(values)
        try:
            return sorted(values)
        except TypeError:
            return sorted(values, key=lambda v: (type(v).__name__, str(v)))

    # ---------------------------------------------------------------- detection

    def auto_detect_column_types(self, df: pd.DataFrame) -> Tuple[List[str], List[str]]:
        """Classify each column of ``df`` as continuous or categorical.

        Rules, evaluated on the non-null values of each column:

        * non-numeric dtype -> categorical;
        * numeric with non-integer values and more than ``categorical_threshold``
          distinct values -> continuous;
        * numeric with at most ``categorical_threshold`` distinct values, or a
          distinct/total ratio below ``numeric_categorical_threshold`` -> categorical;
        * anything else -> continuous.

        Returns:
            ``(continuous_cols, categorical_cols)`` as lists of column names, each
            in the order the columns appear in ``df``.
        """
        continuous_cols: List[str] = []
        categorical_cols: List[str] = []

        print("Analyzing column types...")
        for col in df.columns:
            col_data = df[col].dropna()  # Remove NaN for analysis
            unique_count = col_data.nunique()
            total_count = len(col_data)
            unique_ratio = unique_count / total_count if total_count > 0 else 0

            if not pd.api.types.is_numeric_dtype(col_data):
                # Non-numeric -> categorical
                categorical_cols.append(col)
                print(f"'{col}': Text categorical ({unique_count} unique)")
                continue

            # Vectorised "every value is a whole number" check (inf/NaN count as non-integer).
            numeric_values = col_data.to_numpy(dtype=float)
            with np.errstate(invalid='ignore'):
                is_integer_like = bool(np.all(np.mod(numeric_values, 1) == 0))

            low_cardinality = (unique_count <= self.categorical_threshold
                               or unique_ratio < self.numeric_categorical_threshold)

            if not is_integer_like and unique_count > self.categorical_threshold:
                # Many decimal values suggest a continuous measurement, even if the
                # distinct/total ratio is small.
                continuous_cols.append(col)
                print(f"'{col}': Continuous decimal ({unique_count} unique)")
            elif low_cardinality:
                categorical_cols.append(col)
                print(f"'{col}': Numeric categorical ({unique_count} unique, ratio: {unique_ratio:.3f})")
            else:
                continuous_cols.append(col)
                print(f"'{col}': Continuous ({unique_count} unique)")

        return continuous_cols, categorical_cols

    # ------------------------------------------------------------ preprocessing

    def preprocess_data(self, df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict]:
        """Fill missing values and integer-encode every non-numeric column.

        Numeric columns are filled with their median (``0`` if the column has no
        observed value at all); other columns are filled with their most frequent
        value, or ``'unknown'`` when there is none. ``df`` itself is not modified.

        Returns:
            ``(df_processed, categorical_mappings)`` where ``categorical_mappings``
            maps each encoded column name to
            ``{'mapping': {value: code}, 'reverse_mapping': {code: value}}`` with
            codes ``0 .. n_categories - 1`` assigned in sorted value order.
        """
        df_processed = df.copy()
        categorical_mappings = {}

        print("\nPreprocessing data...")

        # Handle missing values. Columns are reassigned instead of using
        # ``df[col].fillna(..., inplace=True)``: that chained in-place form is
        # deprecated and has no effect on the frame under pandas Copy-on-Write.
        for col in df_processed.columns:
            column = df_processed[col]
            null_count = int(column.isnull().sum())
            if null_count == 0:
                continue

            if pd.api.types.is_numeric_dtype(column):
                fill_value = column.median()
                if pd.isna(fill_value):
                    # Column is entirely missing: the median is NaN and would leak through.
                    df_processed[col] = column.fillna(0)
                    print(f"Filled {null_count} missing values in '{col}' with 0 (no observed values)")
                else:
                    df_processed[col] = column.fillna(fill_value)
                    print(f"Filled {null_count} missing values in '{col}' with median")
            else:
                mode_val = column.mode()
                fill_value = mode_val.iloc[0] if len(mode_val) > 0 else 'unknown'
                df_processed[col] = column.fillna(fill_value)
                print(f"Filled {null_count} missing values in '{col}' with mode/unknown")

        # Encode categorical variables with proper 0-based indexing
        for col in df_processed.columns:
            if pd.api.types.is_numeric_dtype(df_processed[col]):
                continue
            unique_vals = self._sorted_categories(df_processed[col].unique())
            mapping = {val: idx for idx, val in enumerate(unique_vals)}
            categorical_mappings[col] = {
                'mapping': mapping,
                'reverse_mapping': {idx: val for val, idx in mapping.items()}
            }
            df_processed[col] = df_processed[col].map(mapping)
            print(f"Encoded '{col}': {len(unique_vals)} categories -> [0, {len(unique_vals)-1}]")

        return df_processed, categorical_mappings

    # --------------------------------------------------------------- validation

    def validate_and_fix_categorical_data(self, data: np.ndarray, columns: List[Dict]) -> Tuple[np.ndarray, List[Dict]]:
        """Re-index categorical columns so their codes are exactly ``0 .. n-1``.

        For every entry of ``columns`` whose ``'type'`` is ``'categorical'`` the
        distinct values found in the matching column of ``data`` are ranked in
        ascending order and replaced by their rank. When a column had to be
        changed, its ``'size'`` and ``'i2s'`` metadata are rewritten from the
        observed values.

        The codes are derived from the rows passed in, so call this once on the
        full dataset (not per split) to keep train and test encodings aligned.

        Args:
            data: 2-D array, one column per entry in ``columns``.
            columns: Column metadata dicts with at least ``'name'`` and ``'type'``.

        Returns:
            ``(fixed_data, fixed_columns)``; the inputs are left untouched.
        """
        print("\nValidating and fixing categorical data...")
        fixed_data = np.array(data, copy=True)
        fixed_columns = [col.copy() for col in columns]

        for i, col in enumerate(fixed_columns):
            if col['type'] != 'categorical':
                continue

            # ``codes[k]`` is the rank of ``fixed_data[k, i]`` among the sorted distinct
            # values. Computing all ranks from the untouched column avoids the
            # collisions of a value-by-value in-place rewrite (e.g. -1 -> 0 followed
            # by 0 -> 1) and keeps non-integer categories distinct.
            unique_vals, codes = np.unique(fixed_data[:, i], return_inverse=True)
            n_categories = len(unique_vals)

            if np.array_equal(unique_vals, np.arange(n_categories)):
                print(f"'{col['name']}': Already properly indexed [0, {n_categories-1}]")
                continue

            labels = [self._category_label(val) for val in unique_vals]
            print(f"Fixing '{col['name']}': [{', '.join(labels)}] -> {list(range(n_categories))}")

            fixed_data[:, i] = codes.reshape(-1)

            # Update column metadata
            col['size'] = n_categories
            col['i2s'] = labels

        return fixed_data, fixed_columns

    # ------------------------------------------------------------------- export

    def create_codi_format(self, dataset_name: str, train_data: np.ndarray, test_data: np.ndarray, 
                          column_names: List[str], con_idx: List[int], dis_idx: List[int], 
                          categorical_mappings: Dict) -> Dict:
        """Write the train/test arrays and their metadata in CoDi's on-disk layout.

        Column metadata and the categorical re-indexing are both computed on the
        concatenation of ``train_data`` and ``test_data`` so that a category has
        the same code in both splits even if one split does not contain it.

        Args:
            dataset_name: Base name of the files written inside ``output_dir``.
            train_data, test_data: 2-D arrays with identical column layout.
            column_names: Name of each column, in array order.
            con_idx: Positions of the continuous columns; every other column is
                described as categorical.
            dis_idx: Positions of the categorical columns; only used to decide
                whether the last column makes this a classification problem.
            categorical_mappings: Output of ``preprocess_data``; supplies the
                original labels of encoded columns.

        Returns:
            The metadata dict that was written to ``<dataset_name>.json``:
            ``{'columns': [...], 'problem_type': ...}``.
        """
        train_data = np.asarray(train_data)
        test_data = np.asarray(test_data)
        n_train = train_data.shape[0]
        all_data = np.vstack([train_data, test_data])

        # Create initial columns structure
        continuous_positions = set(con_idx)
        columns = []
        for i, col_name in enumerate(column_names):
            col_data = all_data[:, i]
            if i in continuous_positions:
                columns.append({
                    "name": col_name,
                    "type": "continuous",
                    "min": float(np.min(col_data)),
                    "max": float(np.max(col_data))
                })
                continue

            unique_vals = np.unique(col_data)

            # Create i2s mapping: original label when the column was text-encoded,
            # otherwise the numeric value rendered as text.
            if col_name in categorical_mappings:
                reverse_mapping = categorical_mappings[col_name]['reverse_mapping']
                i2s = [str(reverse_mapping.get(int(val), self._category_label(val))) for val in unique_vals]
            else:
                i2s = [self._category_label(val) for val in unique_vals]

            columns.append({
                "name": col_name,
                "type": "categorical",
                "size": len(unique_vals),
                "i2s": i2s
            })

        # Fix categorical indexing once on the combined rows, then split back.
        fixed_all, fixed_columns = self.validate_and_fix_categorical_data(all_data, columns)
        fixed_train, fixed_test = fixed_all[:n_train], fixed_all[n_train:]

        # Determine problem type from the last column
        last_col_idx = len(column_names) - 1
        if last_col_idx in dis_idx:
            last_col_name = column_names[last_col_idx]
            if last_col_name in categorical_mappings:
                num_classes = len(categorical_mappings[last_col_name]['mapping'])
            else:
                num_classes = len(np.unique(all_data[:, last_col_idx]))

            problem_type = "binary_classification" if num_classes == 2 else "multiclass_classification"
        else:
            problem_type = "regression"

        # Save fixed data
        os.makedirs(self.output_dir, exist_ok=True)
        np.savez(self._dataset_path(dataset_name, '.npz'), train=fixed_train, test=fixed_test)

        # Create metadata
        codi_meta = {
            "columns": fixed_columns,
            "problem_type": problem_type
        }

        with open(self._dataset_path(dataset_name, '.json'), 'w') as f:
            json.dump(codi_meta, f, indent=2)

        return codi_meta

    # ----------------------------------------------------------------- pipeline

    def process_dataset(self, csv_path: str, dataset_name: str, 
                       force_continuous: Optional[List[str]] = None,
                       force_categorical: Optional[List[str]] = None,
                       test_split: float = 0.2,
                       create_backup: bool = True,
                       drop_columns: Optional[List[str]] = None) -> Dict[str, Any]:
        """Run the complete CSV -> CoDi dataset pipeline.

        Args:
            csv_path: CSV file to read.
            dataset_name: Base name of the output files inside ``output_dir``.
            force_continuous: Columns to treat as continuous regardless of detection.
            force_categorical: Columns to treat as categorical regardless of
                detection; wins if a column is listed in both override lists.
            test_split: Fraction of rows placed in the test split, ``0 <= x < 1``.
            create_backup: Copy an existing ``<dataset_name>.npz`` / ``.json`` to
                ``<dataset_name>_backup.*`` before overwriting.
            drop_columns: Columns removed right after loading (e.g. row identifiers
                that should not be modelled).

        Returns:
            Summary dict with ``dataset_name``, ``shape``, ``problem_type``,
            ``continuous_columns``, ``categorical_columns``, ``train_samples`` and
            ``test_samples``.

        Raises:
            ValueError: ``test_split`` is outside ``[0, 1)``.
            KeyError: An override or drop column does not exist in the CSV.
        """
        force_continuous = list(force_continuous or [])
        force_categorical = list(force_categorical or [])
        drop_columns = list(drop_columns or [])

        if not 0 <= test_split < 1:
            raise ValueError(f"test_split must satisfy 0 <= test_split < 1, got {test_split!r}")

        print(f"Processing dataset: {csv_path} -> {dataset_name}")
        print("="*60)

        # Create backup if requested
        npz_path = self._dataset_path(dataset_name, '.npz')
        json_path = self._dataset_path(dataset_name, '.json')
        if create_backup and os.path.exists(npz_path):
            print("Creating backup of existing dataset...")
            shutil.copy(npz_path, self._dataset_path(f'{dataset_name}_backup', '.npz'))
            if os.path.exists(json_path):
                shutil.copy(json_path, self._dataset_path(f'{dataset_name}_backup', '.json'))

        # Load and analyze data
        df = pd.read_csv(csv_path)
        print(f"Original shape: {df.shape}")
        print(f"Columns: {list(df.columns)}")

        if drop_columns:
            missing = [col for col in drop_columns if col not in df.columns]
            if missing:
                raise KeyError(f"Columns to drop not found in dataset: {missing}")
            df = df.drop(columns=drop_columns)
            print(f"Dropped columns: {drop_columns}")

        # Fail early with a readable message instead of a bare KeyError from get_loc.
        unknown = [col for col in force_continuous + force_categorical if col not in df.columns]
        if unknown:
            raise KeyError(f"Override columns not found in dataset: {unknown}")

        # Auto-detect column types
        continuous_cols, categorical_cols = self.auto_detect_column_types(df)

        # Apply manual overrides
        if force_continuous or force_categorical:
            print(f"\nApplying manual overrides...")
            for col in force_continuous:
                if col in categorical_cols:
                    categorical_cols.remove(col)
                if col not in continuous_cols:
                    continuous_cols.append(col)
                print(f"Forced '{col}' to continuous")

            for col in force_categorical:
                if col in continuous_cols:
                    continuous_cols.remove(col)
                if col not in categorical_cols:
                    categorical_cols.append(col)
                print(f"Forced '{col}' to categorical")

        print(f"\nFinal column assignment:")
        print(f"Continuous ({len(continuous_cols)}): {continuous_cols}")
        print(f"Categorical ({len(categorical_cols)}): {categorical_cols}")

        # Preprocess data
        df_processed, categorical_mappings = self.preprocess_data(df)

        # Get indices
        con_idx = [df_processed.columns.get_loc(col) for col in continuous_cols]
        dis_idx = [df_processed.columns.get_loc(col) for col in categorical_cols]

        # Split data (everything is stored as float32)
        data = df_processed.values.astype(np.float32)
        n_samples, n_features = data.shape
        n_test = int(n_samples * test_split)

        # Shuffle and split. A private RandomState seeded with 42 yields the same
        # permutation as ``np.random.seed(42)`` + ``np.random.permutation`` but
        # leaves the global NumPy random state alone.
        rng = np.random.RandomState(42)
        indices = rng.permutation(n_samples)
        test_data = data[indices[:n_test]]
        train_data = data[indices[n_test:]]

        print(f"\nData split:")
        print(f"Training samples: {len(train_data)}")
        print(f"Test samples: {len(test_data)}")

        # Create CoDi format with validation and fixes
        codi_meta = self.create_codi_format(
            dataset_name, train_data, test_data, 
            df_processed.columns.tolist(), con_idx, dis_idx, categorical_mappings
        )

        print(f"\nDataset '{dataset_name}' processed successfully!")
        print(f"Problem type: {codi_meta['problem_type']}")
        print(f"Files created:")
        print(f"  - {npz_path}")
        print(f"  - {json_path}")

        return {
            'dataset_name': dataset_name,
            'shape': (n_samples, n_features),
            'problem_type': codi_meta['problem_type'],
            'continuous_columns': continuous_cols,
            'categorical_columns': categorical_cols,
            'train_samples': len(train_data),
            'test_samples': len(test_data)
        }

    def validate_dataset(self, dataset_name: str) -> bool:
        """Check that a saved dataset has valid categorical codes in every split.

        Loads ``<dataset_name>.npz`` and ``<dataset_name>.json`` from
        ``output_dir`` and verifies that each split has one column per metadata
        entry and that every categorical column only holds integer codes in
        ``[0, size - 1]``. Both ``train`` and (when present and non-empty)
        ``test`` rows are checked.

        Returns:
            ``True`` when no problem was found, ``False`` otherwise -- including
            when the files cannot be read, in which case the error is printed.
        """
        try:
            # ``np.load`` keeps the archive open; the context manager closes it.
            with np.load(self._dataset_path(dataset_name, '.npz')) as archive:
                train_data = archive['train']
                test_data = archive['test'] if 'test' in archive.files else None
            with open(self._dataset_path(dataset_name, '.json'), 'r') as f:
                meta = json.load(f)

            print(f"Validating dataset: {dataset_name}")
            print("="*40)

            splits = [('train', train_data)]
            if test_data is not None and test_data.size > 0:
                splits.append(('test', test_data))

            issues_found = False

            n_columns = len(meta['columns'])
            for split_name, split in splits:
                if split.ndim != 2 or split.shape[1] != n_columns:
                    print(f"{split_name} split has shape {split.shape}, expected {n_columns} columns")
                    issues_found = True

            if not issues_found:
                for i, col in enumerate(meta['columns']):
                    if col['type'] == 'categorical':
                        raw_values = np.concatenate([split[:, i] for _, split in splits])
                        col_data = raw_values.astype(int)
                        min_val, max_val = np.min(col_data), np.max(col_data)

                        if min_val < 0 or max_val >= col['size']:
                            print(f"{col['name']}: values [{min_val}, {max_val}] outside expected [0, {col['size']-1}]")
                            issues_found = True
                        elif not np.array_equal(raw_values, col_data):
                            print(f"{col['name']}: contains non-integer category codes")
                            issues_found = True
                        else:
                            print(f"{col['name']}: properly indexed [0, {col['size']-1}]")
                    else:
                        print(f"{col['name']}: continuous column OK")

            if not issues_found:
                print(f"\nDataset '{dataset_name}' is valid for CoDi training!")
                return True
            else:
                print(f"\nDataset '{dataset_name}' has validation issues.")
                return False

        except Exception as e:
            # Deliberately broad: any read/shape problem means "not valid".
            print(f"❌ Validation failed: {e}")
            return False

# Usage functions
def quick_process(csv_path: str, dataset_name: str, **kwargs) -> Dict[str, Any]:
    """Process ``csv_path`` with a default-configured ``DatasetProcessor``.

    Extra keyword arguments are forwarded unchanged to
    ``DatasetProcessor.process_dataset``, whose summary dict is returned.
    """
    return DatasetProcessor().process_dataset(csv_path, dataset_name, **kwargs)

def process_with_overrides(csv_path: str, dataset_name: str, 
                          continuous_cols: Optional[List[str]] = None,
                          categorical_cols: Optional[List[str]] = None, **kwargs) -> Dict[str, Any]:
    """Process a CSV while forcing the type of selected columns.

    Args:
        csv_path: CSV file to read.
        dataset_name: Base name of the generated dataset files.
        continuous_cols: Columns forwarded as ``force_continuous``; ``None`` means
            no override.
        categorical_cols: Columns forwarded as ``force_categorical``; ``None``
            means no override.
        **kwargs: Any other ``DatasetProcessor.process_dataset`` option.

    Returns:
        The summary dict returned by ``DatasetProcessor.process_dataset``.
    """
    processor = DatasetProcessor()
    return processor.process_dataset(
        csv_path, dataset_name,
        force_continuous=continuous_cols,
        force_categorical=categorical_cols,
        **kwargs
    )

In [14]:
# Load the raw Iris CSV and count the missing values in each column.
iris_original = pd.read_csv("raw_data/iris.csv")
iris_original.isna().sum()

Id               0
SepalLengthCm    0
SepalWidthCm     0
PetalLengthCm    0
PetalWidthCm     0
Species          0
dtype: int64

In [15]:
# Example usage: build the CoDi dataset files for Iris, then check them.
processor = DatasetProcessor(
    categorical_threshold=15,
    numeric_categorical_threshold=0.05,
)

# Process a dataset (an existing 'iris' dataset is backed up first)
result = processor.process_dataset(
    'raw_data/iris.csv',
    'iris',
    test_split=0.2,
    create_backup=True,
)

# Validate the result; the returned bool is the cell's output
processor.validate_dataset('iris')

Processing dataset: raw_data/iris.csv -> iris
Creating backup of existing dataset...
Original shape: (150, 6)
Columns: ['Id', 'SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm', 'Species']
Analyzing column types...
'Id': Continuous (150 unique)
'SepalLengthCm': Continuous decimal (35 unique)
'SepalWidthCm': Continuous decimal (23 unique)
'PetalLengthCm': Continuous decimal (43 unique)
'PetalWidthCm': Continuous decimal (22 unique)
'Species': Text categorical (3 unique)

Final column assignment:
Continuous (5): ['Id', 'SepalLengthCm', 'SepalWidthCm', 'PetalLengthCm', 'PetalWidthCm']
Categorical (1): ['Species']

Preprocessing data...
Encoded 'Species': 3 categories -> [0, 2]

Data split:
Training samples: 120
Test samples: 30

Validating and fixing categorical data...
'Species': Already properly indexed [0, 2]

Validating and fixing categorical data...
'Species': Already properly indexed [0, 2]

Dataset 'iris' processed successfully!
Problem type: multiclass_classification

True

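Command-line form of the training run (the next cell launches it from Python and adds `--train` plus the epoch, batch-size and sample-count options):

```bash
python main.py --data iris --logdir CoDi_exp
```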

In [None]:
# Launch CoDi training in a child process, using the interpreter that runs this kernel.
cmd = [
    sys.executable, 'main.py',
    '--data', 'iris',
    '--total_epochs_both', '20',
    '--training_batch_size', '1024',
    '--num_samples', '500',
    '--logdir', './CoDi_exp',
    '--train'
]

# Output is captured, so nothing is shown until the child process has finished.
# A dedicated name keeps the `result` dict of the processing cell intact.
training_run = subprocess.run(cmd, capture_output=True, text=True)
print("STDOUT:", training_run.stdout)
print("STDERR:", training_run.stderr)

# Fail loudly: later cells read files this run is expected to have produced.
if training_run.returncode != 0:
    raise RuntimeError(f"CoDi training exited with status {training_run.returncode}")

STDOUT: 
STDERR: I0902 20:54:20.012278 140093344847680 main.py:93] Co-evolving Conditional Diffusion models
I0902 20:54:23.486226 140093344847680 co_evolving_condition.py:69] Continuous model params: 450829
I0902 20:54:23.486808 140093344847680 co_evolving_condition.py:70] Discrete model params: 675285
I0902 20:54:23.487200 140093344847680 co_evolving_condition.py:76] Total steps: 20
I0902 20:54:23.487432 140093344847680 co_evolving_condition.py:77] Sample steps: 2000
I0902 20:54:23.487679 140093344847680 co_evolving_condition.py:78] Continuous: 120, 5
I0902 20:54:23.487900 140093344847680 co_evolving_condition.py:79] Discrete: 120, 3
I0902 20:54:24.515496 140093344847680 co_evolving_condition.py:126] Epoch :0, diffusion continuous loss: 0.924, discrete loss: 0.325
I0902 20:54:24.516203 140093344847680 co_evolving_condition.py:127] Epoch :0, CL continuous loss: 0.993, discrete loss: 1.000
I0902 20:54:24.516596 140093344847680 co_evolving_condition.py:128] Epoch :0, Total continuous los

In [None]:
# Enhanced version without dataset index
def combine_all_synthetic_datasets(synthetic_path: str = './CoDi_exp/synthetic_data.pkl',
                                   metadata_path: str = 'tabular_datasets/iris.json') -> pd.DataFrame:
    """Stack all synthetic sample arrays into one labelled DataFrame (no dataset_idx column).

    Args:
        synthetic_path: Pickle file holding a sequence of 2-D arrays, one per
            synthetic dataset, all with the same column layout.
        metadata_path: Dataset metadata JSON with a ``'columns'`` list; column
            names and the ``'i2s'`` label lists are taken from it.

    Returns:
        A DataFrame with one row per synthetic sample. Categorical columns that
        have an ``'i2s'`` list are decoded to their labels; a code outside the
        list (or a non-finite value) becomes ``"unknown_<value>"``.

    Raises:
        ValueError: The pickle holds no datasets, or the number of data columns
            differs from the number of metadata columns.
    """
    # NOTE(security): pickle.load can execute arbitrary code. Only load files
    # written by your own training run.
    with open(synthetic_path, 'rb') as f:
        synthetic_datasets = pickle.load(f)

    with open(metadata_path, 'r') as f:
        metadata = json.load(f)

    if len(synthetic_datasets) == 0:
        raise ValueError(f"No synthetic datasets found in {synthetic_path}")

    # Get column names
    column_names = [col['name'] for col in metadata['columns']]

    # Combine all raw data first, then build a single DataFrame
    combined_raw_data = np.vstack(synthetic_datasets)
    if combined_raw_data.shape[1] != len(column_names):
        raise ValueError(
            f"Synthetic data has {combined_raw_data.shape[1]} columns but metadata "
            f"describes {len(column_names)}"
        )
    combined_df = pd.DataFrame(combined_raw_data, columns=column_names)

    def decode(value, labels):
        """Translate one rounded sample value into its category label."""
        if not np.isfinite(value):
            # NaN/inf cannot be cast to an index; keep them visible instead of crashing.
            return f"unknown_{value}"
        code = int(value)
        return labels[code] if 0 <= code < len(labels) else f"unknown_{code}"

    # Map categorical values back to their labels
    for col_info in metadata['columns']:
        if col_info['type'] == 'categorical' and 'i2s' in col_info:
            col_name = col_info['name']
            i2s = col_info['i2s']
            combined_df[col_name] = combined_df[col_name].round().map(
                lambda value, labels=i2s: decode(value, labels)
            )

    return combined_df


# Build the decoded synthetic table, then show its first rows and summary statistics
combined_clean = combine_all_synthetic_datasets()
print("\nClean combined dataset:")
display(combined_clean.head(), combined_clean.describe())

Combined 1 datasets
Total shape: (500, 6)
Individual dataset shapes: [(500, 6)]

Clean combined dataset:


Unnamed: 0,Id,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species
0,59.968529,7.7,3.014778,6.516171,0.615078,Iris-virginica
1,40.079647,7.462254,3.787957,3.519754,1.371104,Iris-versicolor
2,125.315575,7.409658,4.088545,2.945665,0.434851,Iris-setosa
3,79.446953,5.459857,3.185479,6.396809,2.073986,Iris-virginica
4,24.423725,6.78831,2.891448,4.659242,1.856624,Iris-setosa


Unnamed: 0,Id,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm
count,500.0,500.0,500.0,500.0,500.0
mean,77.346307,6.272344,3.235651,4.446397,1.35667
std,48.727254,1.130016,0.775684,1.873476,0.789619
min,1.0,4.300436,2.0,1.002051,0.1
25%,32.615707,5.242195,2.52309,2.744202,0.613988
50%,76.127281,6.425347,3.228642,4.85456,1.348977
75%,120.902212,7.407302,4.001225,6.245686,2.097555
max,150.0,7.7,4.4,6.7,2.5


In [11]:
# Summary statistics of the original Iris data, as a reference for the synthetic describe() output.
iris_original.describe()

Unnamed: 0,Id,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm
count,150.0,150.0,150.0,150.0,150.0
mean,75.5,5.843333,3.054,3.758667,1.198667
std,43.445368,0.828066,0.433594,1.76442,0.763161
min,1.0,4.3,2.0,1.0,0.1
25%,38.25,5.1,2.8,1.6,0.3
50%,75.5,5.8,3.0,4.35,1.3
75%,112.75,6.4,3.3,5.1,1.8
max,150.0,7.9,4.4,6.9,2.5


In [12]:
# (rows, columns) of the original Iris data.
iris_original.shape

(150, 6)