In [None]:
# Put the project root -- two directory levels above the current working
# directory (normally the folder this notebook lives in) -- on sys.path so
# project modules can be imported from the cells below.

import os
import sys

notebook_dir = os.path.abspath("")
parent_dir = os.path.dirname(notebook_dir)
grandparent_dir = os.path.dirname(parent_dir)

# Guard against piling up duplicate entries when this cell is re-run in the same kernel.
if grandparent_dir not in sys.path:
    sys.path.append(grandparent_dir)

In [None]:
import os
import sys
import csv
import psycopg2
from psycopg2 import sql
from dotenv import load_dotenv
from datetime import datetime
import re
from pathlib import Path

# Populate environment variables from a .env file (python-dotenv) so the
# os.getenv(...) lookups for the database settings can pick them up.
load_dotenv()

In [None]:
# Keyword arguments for psycopg2.connect(), read from the environment (which
# load_dotenv() in the imports cell may have populated from a .env file).
# Every setting uses the PG_* prefix. The port is read from PG_PORT for consistency
# with the others; DB_PORT is still honoured as a fallback for existing .env files.
DB_CONFIG = {
    'host': os.getenv('PG_HOST', 'localhost'),
    'port': os.getenv('PG_PORT') or os.getenv('DB_PORT', '5432'),
    'database': os.getenv('PG_DATABASE', 'your_database'),
    'user': os.getenv('PG_USERNAME', 'postgres'),
    'password': os.getenv('PG_PASSWORD', '')
}

In [None]:
# Folder scanned (non-recursively) for *.csv files to load. The path is relative,
# so it resolves against the kernel's current working directory.
DATA_DIR = '/'.join(('..', 'data'))

In [None]:
class CSVSchemaInferrer:
    """Infer a PostgreSQL column schema from a CSV file's header and sample rows.

    Every method is a ``staticmethod``; the class only groups related helpers.
    """

    # PostgreSQLTableManager.create_table always adds these two columns itself
    # (``id SERIAL PRIMARY KEY`` and ``loaded_at``). A CSV column that sanitizes
    # to one of these names is renamed. Otherwise CREATE TABLE fails with
    # 'column "id" specified more than once'.
    RESERVED_COLUMNS = frozenset({'id', 'loaded_at'})

    # strptime formats accepted as dates/timestamps, tried in order.
    # NOTE(review): a value such as '25/12/2024' matches only '%d/%m/%Y' and is
    # typed DATE, but whether PostgreSQL accepts it on insert depends on the
    # server's DateStyle setting -- verify for day-first data.
    DATE_FORMATS = (
        '%Y-%m-%d', '%m/%d/%Y', '%d/%m/%Y',
        '%Y-%m-%d %H:%M:%S', '%m/%d/%Y %H:%M:%S',
        '%Y-%m-%dT%H:%M:%S', '%Y-%m-%d %H:%M:%S.%f',
    )

    # Upper-cased tokens accepted as booleans.
    BOOLEAN_TOKENS = frozenset({'TRUE', 'FALSE', 'T', 'F', '1', '0', 'YES', 'NO'})

    # (type, min, max) for integer columns, narrowest first.
    _INTEGER_TYPES = (
        ('SMALLINT', -2**15, 2**15 - 1),
        ('INTEGER', -2**31, 2**31 - 1),
        ('BIGINT', -2**63, 2**63 - 1),
    )

    @staticmethod
    def sanitize_column_name(column_name):
        """Convert a raw CSV header into a lowercase, PostgreSQL-friendly identifier.

        Characters outside ``[a-zA-Z0-9_]`` become ``_``. Runs of ``_`` are
        collapsed and leading/trailing ``_`` are removed. A leading digit gets a
        ``col_`` prefix. Returns ``'unnamed_column'`` if nothing usable remains.
        """
        sanitized = re.sub(r'[^a-zA-Z0-9_]', '_', column_name.lower())
        sanitized = re.sub(r'_+', '_', sanitized).strip('_')
        if sanitized and sanitized[0].isdigit():
            sanitized = f'col_{sanitized}'
        return sanitized or 'unnamed_column'

    @staticmethod
    def infer_data_type(values, column_name):
        """Pick the narrowest PostgreSQL type that fits every non-empty sample value.

        Checks run in this order:
        1. Integer: SMALLINT, INTEGER or BIGINT by range, or NUMERIC beyond BIGINT.
        2. NUMERIC.
        3. DATE or TIMESTAMP.
        4. BOOLEAN.
        5. VARCHAR(50), VARCHAR(255) or TEXT by maximum length.

        Args:
            values: raw sample values for one column (strings; ``None``/blank
                entries are ignored).
            column_name: original header. It is unused and kept for interface
                compatibility.

        Returns:
            The PostgreSQL type name. Columns with no non-empty sample value
            default to ``'TEXT'``.

        Note: integer ranges and VARCHAR lengths come from the sample only. Rows
        outside the sample that exceed them will fail on insert. Use
        ``analyze_csv(..., sample_size=None)`` to scan the whole file.
        """
        non_empty_values = [v for v in values if v and str(v).strip()]
        if not non_empty_values:
            return 'TEXT'

        if all(CSVSchemaInferrer._is_integer(v) for v in non_empty_values):
            parsed = [int(v) for v in non_empty_values]
            low, high = min(parsed), max(parsed)
            for type_name, lower, upper in CSVSchemaInferrer._INTEGER_TYPES:
                if lower <= low and high <= upper:
                    return type_name
            # Beyond BIGINT: only arbitrary-precision NUMERIC can hold these.
            return 'NUMERIC'

        if all(CSVSchemaInferrer._is_numeric(v) for v in non_empty_values):
            return 'NUMERIC'

        if all(CSVSchemaInferrer._is_date(v) for v in non_empty_values):
            # A 'T' separator or a ':' indicates a time-of-day component.
            if any('T' in str(v) or ':' in str(v) for v in non_empty_values):
                return 'TIMESTAMP'
            return 'DATE'

        if all(str(v).strip().upper() in CSVSchemaInferrer.BOOLEAN_TOKENS
               for v in non_empty_values):
            return 'BOOLEAN'

        max_length = max(len(str(v)) for v in non_empty_values)
        if max_length <= 50:
            return 'VARCHAR(50)'
        if max_length <= 255:
            return 'VARCHAR(255)'
        return 'TEXT'

    @staticmethod
    def _is_integer(value):
        """Return True if ``int(value)`` succeeds and the text has no decimal point."""
        try:
            int(value)
            return '.' not in str(value)
        except (ValueError, TypeError):
            return False

    @staticmethod
    def _is_numeric(value):
        """Return True if ``float(value)`` succeeds."""
        try:
            float(value)
            return True
        except (ValueError, TypeError):
            return False

    @staticmethod
    def _is_date(value):
        """Return True if the stripped value parses with any of ``DATE_FORMATS``."""
        text = str(value).strip()
        for fmt in CSVSchemaInferrer.DATE_FORMATS:
            try:
                datetime.strptime(text, fmt)
                return True
            except ValueError:
                continue
        return False

    @staticmethod
    def _unique_name(base, taken):
        """Return ``base``, or the first ``base_N`` (N = 1, 2, ...) not in ``taken``."""
        if base not in taken:
            return base
        suffix = 1
        while f'{base}_{suffix}' in taken:
            suffix += 1
        return f'{base}_{suffix}'

    @staticmethod
    def analyze_csv(csv_file_path, sample_size=100):
        """Read a CSV's header plus sample rows and infer a table schema.

        Args:
            csv_file_path: path to a UTF-8 CSV file with a header row.
            sample_size: number of data rows to inspect. ``None`` scans the
                whole file, which is slower but avoids under-sized types.

        Returns:
            ``(schema, column_mapping)``:

            - ``schema`` is a list of ``(sanitized_name, pg_type)`` tuples in
              header order.
            - ``column_mapping`` maps each original header to its sanitized name.

            Sanitized names are unique and never equal ``RESERVED_COLUMNS``.
            Collisions get a ``_1``, ``_2``, ... suffix.

        Raises:
            ValueError: if the file has no header row (e.g. it is empty).
        """
        # newline='' is required by the csv module so quoted fields containing
        # line breaks are parsed correctly.
        with open(csv_file_path, 'r', encoding='utf-8', newline='') as f:
            reader = csv.DictReader(f)
            original_columns = reader.fieldnames
            if not original_columns:
                raise ValueError(f"CSV file has no header row: {csv_file_path}")

            sample_data = {col: [] for col in original_columns}
            for i, row in enumerate(reader):
                if sample_size is not None and i >= sample_size:
                    break
                for col in original_columns:
                    sample_data[col].append(row.get(col, ''))

        schema = []
        column_mapping = {}  # original header -> sanitized column name
        taken = set(CSVSchemaInferrer.RESERVED_COLUMNS)
        for orig_col in original_columns:
            sanitized_col = CSVSchemaInferrer._unique_name(
                CSVSchemaInferrer.sanitize_column_name(orig_col), taken
            )
            taken.add(sanitized_col)
            data_type = CSVSchemaInferrer.infer_data_type(sample_data[orig_col], orig_col)
            schema.append((sanitized_col, data_type))
            column_mapping[orig_col] = sanitized_col

        return schema, column_mapping



In [None]:

class PostgreSQLTableManager:
    """Create PostgreSQL tables from an inferred schema and bulk-load CSV rows into them.

    Call ``connect()`` before any other method and ``disconnect()`` when done.
    """

    def __init__(self, db_config):
        """Store connection settings. No connection is opened until ``connect()``.

        Args:
            db_config: keyword arguments passed verbatim to ``psycopg2.connect``.
        """
        self.db_config = db_config
        self.conn = None
        self.cursor = None

    def connect(self):
        """Open a connection and a cursor using ``self.db_config``."""
        self.conn = psycopg2.connect(**self.db_config)
        self.cursor = self.conn.cursor()

    def disconnect(self):
        """Close the cursor and connection if they are open. Safe to call repeatedly."""
        if self.cursor is not None:
            self.cursor.close()
            self.cursor = None
        if self.conn is not None:
            self.conn.close()
            self.conn = None

    def create_table(self, table_name, schema, drop_if_exists=False):
        """Create ``table_name`` with a surrogate key, the given columns and a load timestamp.

        The table gets these columns, in order:

        - ``id SERIAL PRIMARY KEY``;
        - one column per ``(name, type)`` pair in ``schema``;
        - ``loaded_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP``.

        CREATE TABLE IF NOT EXISTS is used, so an existing table is kept as-is
        unless ``drop_if_exists`` is true. Commits on success.

        Args:
            table_name: target table name. It is quoted as an identifier.
            schema: iterable of ``(column_name, sql_type)`` pairs. Column names
                are quoted as identifiers. Types are inserted verbatim, so they
                must come from a trusted source such as CSVSchemaInferrer.
            drop_if_exists: run ``DROP TABLE IF EXISTS ... CASCADE`` first. This
                also drops objects that depend on the table.
        """
        if drop_if_exists:
            drop_sql = sql.SQL("DROP TABLE IF EXISTS {} CASCADE;").format(
                sql.Identifier(table_name)
            )
            self.cursor.execute(drop_sql)
            print(f"Dropped existing table: {table_name}")

        # Quote every identifier. Sanitized CSV headers can still be reserved
        # words (e.g. "order", "user"), and file-derived table names can start
        # with a digit. Both break an unquoted CREATE TABLE.
        column_defs = [sql.SQL("id SERIAL PRIMARY KEY")]
        column_defs.extend(
            sql.SQL("{} {}").format(sql.Identifier(col_name), sql.SQL(col_type))
            for col_name, col_type in schema
        )
        column_defs.append(sql.SQL("loaded_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP"))

        create_sql = sql.SQL("CREATE TABLE IF NOT EXISTS {} ({});").format(
            sql.Identifier(table_name),
            sql.SQL(", ").join(column_defs),
        )
        self.cursor.execute(create_sql)
        self.conn.commit()
        print(f"Table '{table_name}' created successfully.")

    def create_indexes(self, table_name, index_columns):
        """Create one index per column, named ``idx_<table>_<column>``.

        A failing index is reported and skipped; the remaining indexes are
        still created. All successful indexes are committed together at the end.
        """
        for col in index_columns:
            index_name = f"idx_{table_name}_{col}"
            index_sql = sql.SQL(
                "CREATE INDEX IF NOT EXISTS {} ON {} ({});"
            ).format(
                sql.Identifier(index_name),
                sql.Identifier(table_name),
                sql.Identifier(col)
            )
            # A failed statement aborts the whole PostgreSQL transaction, which
            # would make every later index (and the commit) fail too. The
            # savepoint lets us undo just this index and carry on.
            self.cursor.execute("SAVEPOINT create_index")
            try:
                self.cursor.execute(index_sql)
            except psycopg2.Error as e:
                self.cursor.execute("ROLLBACK TO SAVEPOINT create_index")
                print(f"Warning: Could not create index on {col}: {e}")
            else:
                self.cursor.execute("RELEASE SAVEPOINT create_index")

        self.conn.commit()
        print(f"Indexes created for table '{table_name}'.")

    def load_csv_data(self, table_name, csv_file_path, column_mapping, batch_size=1000):
        """Insert every data row of a CSV into ``table_name`` and commit.

        Values are whitespace-stripped. Empty fields, and fields missing from
        short rows, are inserted as NULL. All rows go in one transaction that is
        committed at the end.

        Args:
            table_name: target table (quoted as an identifier).
            csv_file_path: path to a UTF-8 CSV file with a header row.
            column_mapping: original header -> column name. Every header in the
                file must be present (a missing one raises KeyError).
            batch_size: rows sent to the server per round trip.

        Returns:
            The number of rows inserted.

        Raises:
            ValueError: if the file has no header row.
        """
        # psycopg2's executemany() issues one statement per row; execute_batch
        # sends each page of rows in a single round trip.
        from psycopg2.extras import execute_batch

        # newline='' is required by the csv module for quoted multi-line fields.
        with open(csv_file_path, 'r', encoding='utf-8', newline='') as f:
            reader = csv.DictReader(f)
            original_columns = reader.fieldnames
            if not original_columns:
                raise ValueError(f"CSV file has no header row: {csv_file_path}")
            sanitized_columns = [column_mapping[col] for col in original_columns]

            insert_sql = sql.SQL("INSERT INTO {} ({}) VALUES ({})").format(
                sql.Identifier(table_name),
                sql.SQL(', ').join(map(sql.Identifier, sanitized_columns)),
                sql.SQL(', ').join(sql.Placeholder() * len(sanitized_columns))
            ).as_string(self.cursor)

            rows_inserted = 0
            batch = []
            for row in reader:
                # DictReader fills fields missing from short rows with None, so
                # coerce to '' before stripping. Blank values become NULL.
                batch.append([(row.get(col) or '').strip() or None
                              for col in original_columns])
                if len(batch) >= batch_size:
                    execute_batch(self.cursor, insert_sql, batch, page_size=batch_size)
                    rows_inserted += len(batch)
                    batch = []

            if batch:
                execute_batch(self.cursor, insert_sql, batch, page_size=batch_size)
                rows_inserted += len(batch)

        self.conn.commit()
        print(f"Loaded {rows_inserted} rows into '{table_name}'.")
        return rows_inserted


def get_table_name_from_filename(filename):
    """Derive a lowercase PostgreSQL table name from a file name.

    Steps:

    - The extension is dropped.
    - Characters outside ``[a-zA-Z0-9_]`` become ``_``.
    - Runs of ``_`` are collapsed and leading/trailing ``_`` are removed.
    - Like ``CSVSchemaInferrer.sanitize_column_name``, a leading digit gets a
      prefix (``tbl_``) so the result is a valid unquoted identifier.
    - An empty result falls back to ``'unnamed_table'``.

    Example: ``'Sales Data 2024.csv'`` -> ``'sales_data_2024'``.
    """
    name = re.sub(r'[^a-zA-Z0-9_]', '_', Path(filename).stem.lower())
    name = re.sub(r'_+', '_', name).strip('_')
    if name and name[0].isdigit():
        name = f'tbl_{name}'
    return name or 'unnamed_table'


# Substrings that mark a column as worth indexing after load (matched anywhere
# in the sanitized column name).
_INDEX_KEYWORDS = ('id', 'date', 'status', 'category', 'priority', 'region')


def _index_candidates(schema, keywords):
    """Return schema column names that contain any of ``keywords`` as a substring."""
    return [col for col, _ in schema if any(keyword in col for keyword in keywords)]


def _load_one_csv(manager, csv_file, drop_existing, keywords):
    """Infer the schema for one CSV, create its table, load its rows and add indexes.

    Returns a summary dict with keys 'file', 'table', 'rows' and 'columns'.
    """
    print(f"\n{'='*60}")
    print(f"Processing: {csv_file.name}")
    print(f"{'='*60}")

    table_name = get_table_name_from_filename(csv_file.name)
    print(f"Table name: {table_name}")

    schema, column_mapping = CSVSchemaInferrer.analyze_csv(str(csv_file))
    print(f"Detected {len(schema)} columns:")
    for col_name, col_type in schema:
        print(f"  - {col_name}: {col_type}")

    manager.create_table(table_name, schema, drop_if_exists=drop_existing)
    rows_loaded = manager.load_csv_data(table_name, str(csv_file), column_mapping)

    index_candidates = _index_candidates(schema, keywords)
    if index_candidates:
        print(f"Creating indexes on: {', '.join(index_candidates)}")
        manager.create_indexes(table_name, index_candidates)

    return {
        'file': csv_file.name,
        'table': table_name,
        'rows': rows_loaded,
        'columns': len(schema)
    }


def _print_summary(results):
    """Print one line per loaded file followed by the total table count."""
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    for result in results:
        print(f"✓ {result['file']} → {result['table']}: "
              f"{result['rows']} rows, {result['columns']} columns")
    print(f"\nTotal: {len(results)} tables created successfully.")


def load_all_csvs(data_directory, db_config, drop_existing=False, index_keywords=None):
    """Load every ``*.csv`` directly inside ``data_directory`` into its own table.

    Files are processed in sorted name order so reruns are reproducible
    (``Path.glob`` order depends on the filesystem). For each file:

    - the table name comes from ``get_table_name_from_filename``;
    - the schema comes from ``CSVSchemaInferrer.analyze_csv``;
    - indexes are created on columns whose name contains one of the index
      keywords.

    A summary is printed at the end.

    Args:
        data_directory: directory to scan (non-recursive).
        db_config: keyword arguments for ``psycopg2.connect``.
        drop_existing: drop each table (with CASCADE) and recreate it before loading.
        index_keywords: substrings that mark a column for indexing. ``None``
            uses ``_INDEX_KEYWORDS``.

    Raises:
        Re-raises the first error hit while processing. Before that, the current
        transaction is rolled back and the connection is closed. Tables finished
        earlier stay committed, because each table is committed as it is loaded.
    """
    csv_files = sorted(Path(data_directory).glob('*.csv'))

    if not csv_files:
        print(f"No CSV files found in {data_directory}")
        return

    keywords = _INDEX_KEYWORDS if index_keywords is None else tuple(index_keywords)
    print(f"Found {len(csv_files)} CSV files to process.\n")

    manager = PostgreSQLTableManager(db_config)
    manager.connect()

    results = []
    try:
        for csv_file in csv_files:
            results.append(_load_one_csv(manager, csv_file, drop_existing, keywords))
    except Exception as e:
        print(f"\nError during processing: {e}")
        # Don't let a failing rollback (e.g. a dropped connection) mask the original error.
        try:
            manager.conn.rollback()
        except psycopg2.Error as rollback_error:
            print(f"Rollback failed: {rollback_error}")
        raise
    finally:
        manager.disconnect()

    _print_summary(results)



In [None]:
# WARNING: with DROP_EXISTING = True, every target table is dropped with
# DROP TABLE ... CASCADE (removing objects that depend on it) and recreated
# before loading. Set it to False to keep existing tables and append rows.
DROP_EXISTING = True
load_all_csvs(DATA_DIR, DB_CONFIG, drop_existing=DROP_EXISTING)