In [1]:
# Import necessary libraries
import os
import sys
import yaml
import logging
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
from pathlib import Path

# Add the project root to the path
sys.path.append('..')

# Import project modules
from src.data.data_processor import DataProcessor
from src.features.feature_generator import FeatureGenerator
from src.features.feature_selector import FeatureSelector
from src.models.model_trainer import ModelTrainer
from src.visualization.visualizer import Visualizer


## Setup Logging

In [2]:
# Configure logging
def setup_logging(config):
    """Configure and return the shared 'decohere' logger.

    Args:
        config: Mapping with a 'logging' section. Keys read from that section:
            - 'level' (required): a logging level name such as 'INFO'
              (case-insensitive) or a numeric level.
            - 'console' (optional, default True): attach a stream handler.
            - 'file' (optional): path of a log file to attach a file handler to.

    Returns:
        The configured ``logging.Logger`` named 'decohere'.
    """
    log_cfg = config['logging']

    # getattr(logging, 'debug') would return the logging.debug *function*, so
    # level names are normalised to upper case before the lookup.
    level = log_cfg['level']
    log_level = getattr(logging, level.upper()) if isinstance(level, str) else level
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

    # Create logger
    logger = logging.getLogger('decohere')
    logger.setLevel(log_level)

    # Detach AND close existing handlers. Re-running the notebook cell calls
    # this function again; simply dropping the list would leave the previous
    # FileHandler's file descriptor open.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    # Create console handler if enabled
    if log_cfg.get('console', True):
        console_handler = logging.StreamHandler()
        console_handler.setLevel(log_level)
        console_handler.setFormatter(logging.Formatter(log_format))
        logger.addHandler(console_handler)

    # Create file handler if log file is specified
    if 'file' in log_cfg:
        log_file = log_cfg['file']
        log_dir = os.path.dirname(log_file)
        # A bare filename has no directory part; os.makedirs('') would raise.
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(log_level)
        file_handler.setFormatter(logging.Formatter(log_format))
        logger.addHandler(file_handler)

    return logger


## Load Configuration

In [3]:
# Load the main configuration file
def load_main_config():
    """Parse '../config/config.yaml' (relative to the working directory) and return the result."""
    with open('../config/config.yaml', 'r') as stream:
        return yaml.safe_load(stream)

# Load mode-specific configuration
def load_mode_config(mode):
    """Load the configuration for one pipeline mode, layered over the main config.

    Args:
        mode: Key into ``main_config['modes']``; that entry's 'config_file'
            value names the mode-specific YAML file.

    Returns:
        A new dict containing the main configuration with the mode
        configuration's top-level keys applied on top. The merge is shallow:
        a top-level key present in both is taken entirely from the mode file.

    Raises:
        KeyError: if ``mode`` (or its 'config_file' entry) is not in the main config.
    """
    main_config = load_main_config()
    mode_config_path = main_config['modes'][mode]['config_file']

    # Relative paths are resolved against the parent of the working directory,
    # matching the '../config/...' convention used by load_main_config().
    if not os.path.isabs(mode_config_path):
        mode_config_path = os.path.join('..', mode_config_path)

    with open(mode_config_path, 'r') as f:
        mode_config = yaml.safe_load(f)

    # yaml.safe_load returns None for an empty document; treat that as
    # "no overrides" instead of failing in the dict unpacking below.
    if mode_config is None:
        mode_config = {}

    # Merge mode config with main config (mode values win on key clashes)
    return {**main_config, **mode_config}

# Load the main configuration once, at cell execution. `main_config` is a
# module-level name that later cells read.
main_config = load_main_config()
# The path literal below duplicates the one inside load_main_config(); the
# message is only accurate while the two stay in sync.
print(f"Loaded main configuration from {os.path.abspath('../config/config.yaml')}")


Loaded main configuration from /home/siddharth.johri/DECOHERE/config/config.yaml


In [4]:
from src.data.efficient_data_storage import EfficientDataStorage

# Initialize data storage with the main (not mode-specific) configuration.
# NOTE(review): `data_storage` is not referenced again in the visible part of
# this notebook - confirm it is still needed.
data_storage = EfficientDataStorage(config=main_config)

## Select Pipeline Mode

Choose the mode to run the pipeline in:
- **day**: Process a single day of data
- **week**: Process a week of data
- **year**: Process a year of data

In [5]:
# Select the mode
import ipywidgets as widgets
from IPython.display import display

# Mode selector. Its current value is read by get_mode_and_date() each time a
# later cell runs, so results depend on the widget state at execution time.
mode_dropdown = widgets.Dropdown(
    options=['day', 'week', 'year'],
    value='day',
    description='Mode:',
    disabled=False,
)

# Date selector, defaulting to a fixed date rather than "today".
date_picker = widgets.DatePicker(
    description='Date:',
    disabled=False,
    value=datetime.strptime('2024-09-02', '%Y-%m-%d').date()
    # Alternative default (disabled): datetime.now().date()
)

# Render both widgets stacked vertically as this cell's output.
display(widgets.VBox([mode_dropdown, date_picker]))

# Function to get the selected mode and date
def get_mode_and_date():
    """Return the current (mode, date) pair read from the two selection widgets."""
    return mode_dropdown.value, date_picker.value

# Load the configuration for the selected mode
def load_selected_config():
    """Return the merged configuration for whichever mode is currently selected."""
    selected_mode = get_mode_and_date()[0]
    return load_mode_config(selected_mode)


VBox(children=(Dropdown(description='Mode:', options=('day', 'week', 'year'), value='day'), DatePicker(value=d…

In [6]:

# tdf = pd.read_parquet('/home/siddharth.johri/DECOHERE/data/raw/financials/financials_2024_09.pq')
# print(date_picker.value)
# print(tdf.columns)
# tdf[(tdf['pit_date']=='2024-09-02') & (tdf['ID']=='TATA IB Equity') ][['ID','PERIOD_END_DATE','pit_date','NET_INCOME']]



## 1. Data Loading

Load raw financial data based on the selected mode and date.

In [7]:
# Load the raw data
def load_data():
    """Load raw data for the currently selected mode and date.

    Reads the widget selection, loads the matching mode configuration,
    (re)configures logging, builds a DataProcessor and asks it for the raw
    data of the selected date.

    Returns:
        Tuple ``(raw_data, data_processor, config, logger)``.
    """
    selected_mode, selected_date = get_mode_and_date()
    iso_date = selected_date.strftime('%Y-%m-%d')

    # Mode-specific configuration drives both logging and the processor.
    mode_cfg = load_mode_config(selected_mode)

    log = setup_logging(mode_cfg)
    log.info(f"Running pipeline in {selected_mode} mode for date {iso_date}")

    processor = DataProcessor(mode_cfg, log)

    log.info("Loading raw data...")
    frame = processor.load_raw_data(date=iso_date)
    log.info(f"Loaded {len(frame)} rows of raw data")

    return frame, processor, mode_cfg, log

# Execute data loading. This binds four module-level names; `raw_data`,
# `data_processor` and `logger` are passed to process_data() in a later cell.
raw_data, data_processor, config, logger = load_data()

# Display a sample of the raw data
# display(raw_data.head())
# print(f"Raw data shape: {raw_data.shape}")


2025-03-31 03:37:06,125 - decohere - INFO - Running pipeline in day mode for date 2024-09-02
2025-03-31 03:37:06,127 - decohere - INFO - Loading raw data...
2025-03-31 03:37:06,128 - decohere - INFO - Loading raw data from /home/siddharth.johri/DECOHERE/data/raw/financials/financials_2024_09.pq
2025-03-31 03:37:06,165 - decohere - INFO - Filtering data for date: 2024-09-02
2025-03-31 03:37:06,173 - decohere - INFO - Loaded raw data with shape: (5223, 34)
2025-03-31 03:37:06,174 - decohere - INFO - Loaded 5223 rows of raw data


In [8]:
# raw_data[(raw_data['PIT_DATE']=='2024-09-02') & (raw_data['ID']=='ICICIBC IB Equity') ][['ID','PERIOD_END_DATE','PIT_DATE','NET_INCOME']]

## 2. Data Processing

Clean and transform the raw data.

In [9]:
# Process the raw data
def process_data(raw_data, data_processor, logger):
    """
    Process data for the selected mode/date and return the processed result.

    The date range is derived from the widget selection: 'day' covers the
    selected date only, 'week' the selected date plus the next 6 days, and
    'year' the selected date plus the next 364 days.

    Args:
        raw_data: Raw data DataFrame. Not used by this function (the data
            processor is asked to process by date range); kept so existing
            callers keep working.
        data_processor: DataProcessor instance
        logger: Logger instance

    Returns:
        Processed data DataFrame

    Raises:
        ValueError: if the selected mode is not 'day', 'week' or 'year'.
    """
    logger.info("Processing raw data...")

    # Get the selected mode and date
    mode, date = get_mode_and_date()
    date_str = date.strftime('%Y-%m-%d')

    # Number of days to add to the start date to get the inclusive end date.
    extra_days_by_mode = {'day': 0, 'week': 6, 'year': 364}
    if mode not in extra_days_by_mode:
        # Previously an unknown mode fell through the if/elif chain and failed
        # later with an UnboundLocalError on start_date.
        raise ValueError(
            f"Unsupported mode {mode!r}; expected one of {sorted(extra_days_by_mode)}"
        )

    start_date = date_str
    end_date = (date + timedelta(days=extra_days_by_mode[mode])).strftime('%Y-%m-%d')

    logger.info(f"Processing data from {start_date} to {end_date}")

    # Process the data with date range. The return value is not needed here;
    # the processed data is read back through the loader below.
    data_processor.process_data(start_date, end_date)

    # 'day' mode is addressed by a single date, the other modes by a range.
    single_day = mode == 'day'
    processed_data = data_processor.load_processed_data_by_mode(
        mode=mode,
        date=date_str if single_day else None,
        start_date=None if single_day else start_date,
        end_date=None if single_day else end_date
    )

    logger.info(f"Processed data shape: {processed_data.shape}")
    return processed_data

# Execute data processing. `processed_data` is a module-level name that the
# feature-set generation cell below consumes.
processed_data = process_data(raw_data, data_processor, logger)

# Display a sample of the processed data (first rows only, rich rendering)
display(processed_data.head())
print(f"Processed data shape: {processed_data.shape}")


2025-03-31 03:37:06,468 - decohere - INFO - Processing raw data...
2025-03-31 03:37:06,469 - decohere - INFO - Processing data from 2024-09-02 to 2024-09-02
2025-03-31 03:37:06,471 - decohere - INFO - Processing data for date range: 2024-09-02 to 2024-09-02
2025-03-31 03:37:06,472 - decohere - INFO - Processing data for date: 2024-09-02
2025-03-31 03:37:06,472 - decohere - INFO - Loading raw data from /home/siddharth.johri/DECOHERE/data/raw/financials/financials_2024_09.pq
2025-03-31 03:37:06,494 - decohere - INFO - Filtering data for date: 2024-09-02
2025-03-31 03:37:06,502 - decohere - INFO - Loaded raw data with shape: (5223, 34)
2025-03-31 03:37:06,503 - decohere - INFO - Calculating periods for each ticker using COHERE logic
2025-03-31 03:37:06,516 - decohere - INFO - Number of unique fiscal months per ID: fiscal_month
1    457
2     43
Name: count, dtype: int64
2025-03-31 03:37:06,517 - decohere - INFO - Calculating periods by ID using id column: 'ID'
2025-03-31 03:37:08,245 - de

Unnamed: 0,ID,PERIOD_END_DATE,NET_INCOME,NET_INCOME_CSTAT_STD,EBIT,EBIT_CSTAT_STD,EBITDA,EBITDA_CSTAT_STD,SALES,SALES_CSTAT_STD,...,SALES_COEFF_OF_VAR_RATIO_SIGNED_LOG,PE_RATIO_RATIO,PE_RATIO_RATIO_SIGNED_LOG,PREV_PE_RATIO_RATIO,PREV_PE_RATIO_RATIO_SIGNED_LOG,PX_TO_BOOK_RATIO_RATIO,PX_TO_BOOK_RATIO_RATIO_SIGNED_LOG,PREV_PX_TO_BOOK_RATIO_RATIO,PREV_PX_TO_BOOK_RATIO_RATIO_SIGNED_LOG,date
0,360ONE IB Equity,2024-03-31,8042100000.0,1513435000.0,12343300000.0,1863053000.0,12909700000.0,1713381000.0,25070300000.0,5440580000.0,...,0.048624,47.92294,3.890246,,,11.306098,2.510095,,,2024-09-02
1,360ONE IB Equity,2023-03-31,6579300000.0,1513435000.0,11603420000.0,1863053000.0,13705750000.0,1713381000.0,15650000000.0,5440580000.0,...,0.048624,60.009417,4.111028,,,12.396609,2.595002,,,2024-09-02
2,360ONE IB Equity,2022-03-31,5777385000.0,1513435000.0,8938944000.0,1863053000.0,9356376000.0,1713381000.0,18506500000.0,5440580000.0,...,0.048624,294.117528,5.687374,,,12.865364,2.629394,,,2024-09-02
3,360ONE IB Equity,2021-03-31,1561882000.0,1513435000.0,6072078000.0,1863053000.0,6293098000.0,1713381000.0,9706505000.0,5440580000.0,...,0.048624,100.140303,4.616509,,,13.403186,2.667449,,,2024-09-02
4,360ONE IB Equity,2020-03-31,2011639000.0,1513435000.0,11603420000.0,1863053000.0,13705750000.0,1713381000.0,73778600000.0,5440580000.0,...,0.048624,195.783173,5.282102,,,12.283057,2.586489,,,2024-09-02


Processed data shape: (4914, 134)


In [10]:
# processed_data[(processed_data['PIT_DATE']=='2024-09-02') & (processed_data['ID']=='INFO IB Equity') ][['ID','PERIOD','PERIOD_END_DATE','PIT_DATE','NET_INCOME','NET_INCOME_CSTAT_STD']]
#list(processed_data.columns)

In [11]:
# Generate the pre-feature set from the processed data, using SALES as the
# scaling field. Per the recorded run log, the call also saves the result as a
# parquet file.
# NOTE(review): the original comment says PIT_DATE is used automatically; that
# happens inside DataProcessor and is not visible here - confirm.
# NOTE(review): the recorded log shows the file written under
# notebooks/data/processed/..., i.e. apparently relative to the notebook's
# working directory - confirm that location is intended.

processed_data_feat_set = data_processor.processed_data_feat_gen(processed_data, scaling_field='SALES')
#processed_data_feat_set.columns
# # Load feature set (using either YYYY-MM or YYYY-MM-DD format)
# loaded_feature_set = data_processor.load_pre_feature_set('2024-01')  # For monthly data
# # or
# loaded_feature_set = data_processor.load_pre_feature_set('2024-01-01')  # For daily data

2025-03-31 03:37:08,551 - decohere - INFO - Generating feature set from processed data
2025-03-31 03:37:08,555 - decohere - INFO - Generated feature set with shape: (4914, 55)
2025-03-31 03:37:08,556 - decohere - INFO - Number of features: 55
2025-03-31 03:37:08,578 - decohere - INFO - Saved pre-feature set to /home/siddharth.johri/DECOHERE/notebooks/data/processed/pre_feature_data/pre_feature_set_2024-09-02.pq


In [12]:
# df = pd.read_parquet('/home/siddharth.johri/DECOHERE/data/raw/sector/sector.pq')
# df1 = df.groupby(level='ID').first()
# df1[df1.index == 'INFO IN Equity'][['sector_1','sector_2', 'sector_3', 'sector_4']]

# df1.drop(columns=['Partial Errors'], inplace=True)
# df1.to_parquet('/home/siddharth.johri/DECOHERE/data/raw/sector/sector_mappings.pq')


In [13]:
# Work on a deep copy so later cells cannot mutate `processed_data_feat_set`.
df_input=processed_data_feat_set.copy(deep=True)
# NOTE(review): absolute, user-specific path - this cell only runs on a machine
# with this exact home directory. Consider deriving it from the project root
# or from the loaded configuration.
sector_file_path = '/home/siddharth.johri/DECOHERE/data/raw/sector/sector_mappings.pq'

In [14]:
import pandas as pd
import numpy as np
from scipy.stats import linregress
from sklearn.preprocessing import OneHotEncoder
import os # For file operations in example
# Optional: Import tqdm for progress bar
try:
    from tqdm.auto import tqdm
except ImportError:
    tqdm = None # Set tqdm to None if not installed

# --- Helper Functions ---

def inv_signed_log(y: float) -> float:
    """Undo the signed-log transform: ``sign(y) * (exp(|y|) - 1)``.

    Missing input yields NaN. When ``exp(|y|)`` overflows, the result
    saturates to +/- infinity carrying the sign of ``y``.
    """
    if pd.isna(y):
        return np.nan

    try:
        magnitude = np.exp(np.abs(np.float64(y)))
        overflowed = np.isinf(magnitude)
    except OverflowError:
        overflowed = True

    if overflowed:
        return np.inf * np.sign(y)
    return np.sign(y) * (magnitude - 1)


def robust_slope(series: pd.Series, periods: list) -> tuple[float, float]:
    """
    Fit a straight line through the series values at the requested periods.

    The series index (period numbers) is the independent variable and the
    non-NaN values at ``periods`` are the dependent variable.

    Returns:
        tuple[float, float]: ``(slope, r_squared)``. ``(0.0, 0.0)`` is returned
        when nothing can be fitted: no periods, empty series, fewer than two
        usable points, constant values, a NaN fit result, or a regression error.
    """
    no_fit = (0.0, 0.0)

    if not (periods and not series.empty):
        return no_fit

    usable_index = series.index.intersection(periods)
    if usable_index.empty:
        return no_fit

    points = series.loc[usable_index].dropna()
    # Fewer than two points cannot define a line; constant data has slope 0
    # and an ill-defined R^2, which is reported as 0.
    if len(points) < 2 or points.nunique() == 1:
        return no_fit

    xs = points.index.astype(float)
    ys = points.values
    try:
        fit = linregress(xs, ys)
    except ValueError as e:
        print(f"Warning: linregress failed unexpectedly for index {points.index.tolist()}: {e}. Returning (0.0, 0.0).")
        return no_fit

    if pd.isna(fit.slope) or pd.isna(fit.rvalue):
        return no_fit
    return fit.slope, fit.rvalue**2


def calculate_stdev(series: pd.Series, periods: list) -> float:
    """
    Sample standard deviation (ddof=1) of the series values at the given periods.

    NaN values are dropped before the calculation.

    Returns:
        float: the standard deviation; 0.0 when all usable values are equal;
        np.nan when there are no periods, the series is empty, or fewer than
        two usable values remain.
    """
    if not (periods and not series.empty):
        return np.nan

    usable_index = series.index.intersection(periods)
    if usable_index.empty:
        return np.nan

    observed = series.loc[usable_index].dropna()
    n_observed = len(observed)
    if n_observed < 2:
        return np.nan

    # Constant data: report an exact zero.
    all_equal = observed.nunique() == 1
    return 0.0 if all_equal else np.std(observed.values, ddof=1)


# --- Group Processing Function ---
def _relative_dispersion_log1p(slog_estimate, slog_std) -> float:
    """Return log1p(stdev / |estimate|) after undoing the signed-log transform.

    Both inputs are passed through ``inv_signed_log`` first. NaN is returned
    when either input is missing, when the recovered stdev is negative or
    non-finite, or when the recovered estimate is non-finite.
    """
    if pd.isna(slog_estimate) or pd.isna(slog_std):
        return np.nan

    actual_estimate = inv_signed_log(slog_estimate)
    actual_stdev = inv_signed_log(slog_std)

    if pd.isna(actual_stdev) or actual_stdev < 0 or not np.isfinite(actual_stdev):
        return np.nan
    if pd.isna(actual_estimate) or not np.isfinite(actual_estimate):
        return np.nan

    # The floor keeps the denominator strictly positive, so no separate
    # zero-division guard is needed.
    denominator = max(abs(actual_estimate), 1e-9)
    return np.log1p(actual_stdev / denominator)


def process_group(group: pd.DataFrame, period_range: list, # period_range now includes 0
                  raw_scaled_sales_signed_log_cols: list,
                  ratio_signed_log_cols: list,
                  cstat_std_cols: list) -> dict:
    """
    Compute time-series features for a single ID/PIT_DATE group.

    Historical periods are those in ``period_range`` with PERIOD <= 0; forward
    periods are those with PERIOD > 0. Duplicate PERIOD rows are reduced to
    their first occurrence (with a printed warning).

    Args:
        group: Rows of one group; must contain 'ID', 'PIT_DATE' and 'PERIOD'.
        period_range: Period numbers to consider (the caller includes 0).
        raw_scaled_sales_signed_log_cols: Metric columns processed under the
            'scaled' feature prefix.
        ratio_signed_log_cols: Metric columns processed under the 'ratio'
            prefix; also emitted as-is per period.
        cstat_std_cols: Standard-deviation columns; the matching estimate
            column is the same name with '_CSTAT_STD' removed.

    Returns:
        dict: feature name -> value, plus 'ID', 'PIT_DATE' and
        '_numerical_feature_keys' (the numerical feature names, de-duplicated
        in first-seen order). An empty dict is returned when required columns
        are missing, the group is empty, or the PERIOD index cannot be built.
    """
    # --- Initial Checks and Setup ---
    if not all(col in group.columns for col in ['ID', 'PIT_DATE', 'PERIOD']):
        return {}
    if group.empty:
        return {}

    # The group is non-empty here, so position 0 always exists.
    group_id = group['ID'].iloc[0]
    group_pit_date = group['PIT_DATE'].iloc[0]
    feats = {'ID': group_id, 'PIT_DATE': group_pit_date}

    # --- Handle duplicate PERIODs ---
    duplicated_periods = group['PERIOD'].duplicated()
    if duplicated_periods.any():
        n_dups = duplicated_periods.sum()
        print(f"Warning: Found {n_dups} duplicate PERIOD(s) in group ID {group_id}, PIT {group_pit_date}. Keeping first.")
        group = group.drop_duplicates(subset=['PERIOD'], keep='first').copy()

    # --- Set Index ---
    # Deliberately best-effort: a group whose index cannot be built is skipped
    # (with a message) rather than aborting the whole run.
    try:
        group = group.set_index('PERIOD')
        if not group.index.is_unique:
            raise ValueError("Index not unique after dropping duplicates - unexpected.")
        if not group.index.is_monotonic_increasing:
            group = group.sort_index()
    except Exception as e:
        print(f"Error setting index for group ID {group_id}, PIT {group_pit_date}: {e}. Skipping.")
        return {}

    # Index covering the full range, and the period subsets. These do not
    # depend on the metric, so they are computed once.
    full_range_index = pd.Index(period_range, name='PERIOD') # Includes 0
    hist_periods = [p for p in period_range if p <= 0]       # Historical includes 0
    fwd_periods = [p for p in period_range if p > 0]         # Forward remains > 0
    combined_periods = period_range                          # Full range
    negative_hist_periods = [p for p in hist_periods if p < 0]
    min_hist_period = min(hist_periods) if hist_periods else None

    numerical_feature_keys = []

    # --- Process Metric Groups ---
    for metric_group, metric_cols in [('scaled', raw_scaled_sales_signed_log_cols),
                                      ('ratio', ratio_signed_log_cols)]:
        for metric in metric_cols:
            if metric not in group.columns:
                continue

            # Align to the full period range; absent periods become NaN.
            series = group[metric].reindex(full_range_index)

            # Latest strictly-negative period and earliest forward period that
            # actually carry a value.
            valid_negative_hist_indices = series.loc[series.index.intersection(negative_hist_periods)].dropna().index
            valid_fwd_indices = series.loc[series.index.intersection(fwd_periods)].dropna().index

            latest_negative_hist_period = valid_negative_hist_indices.max() if not valid_negative_hist_indices.empty else np.nan
            first_fwd_period = valid_fwd_indices.min() if not valid_fwd_indices.empty else np.nan
            value_period_0 = series.get(0, np.nan)

            # Feature: Levels
            level_latest_neg_hist_key = f'level_latest_neg_hist_{metric}'
            level_period_0_key = f'level_period_0_{metric}'
            level_first_fwd_key = f'level_first_fwd_{metric}'
            feats[level_latest_neg_hist_key] = series.get(latest_negative_hist_period, np.nan)
            feats[level_period_0_key] = value_period_0
            feats[level_first_fwd_key] = series.get(first_fwd_period, np.nan)
            numerical_feature_keys.extend([level_latest_neg_hist_key, level_period_0_key, level_first_fwd_key])

            # Features: Slopes and R-squared
            hist_slope, hist_r2 = robust_slope(series, hist_periods)             # periods <= 0
            fwd_slope, fwd_r2 = robust_slope(series, fwd_periods)                # periods > 0
            combined_slope, combined_r2 = robust_slope(series, combined_periods) # full range

            hist_slope_key = f'{metric_group}_hist_slope_{metric}'
            fwd_slope_key = f'{metric_group}_fwd_slope_{metric}'
            combined_slope_key = f'{metric_group}_combined_slope_{metric}'
            hist_r2_key = f'{metric_group}_hist_r2_{metric}'
            fwd_r2_key = f'{metric_group}_fwd_r2_{metric}'
            combined_r2_key = f'{metric_group}_combined_r2_{metric}'
            feats.update({
                hist_slope_key: hist_slope, fwd_slope_key: fwd_slope, combined_slope_key: combined_slope,
                hist_r2_key: hist_r2, fwd_r2_key: fwd_r2, combined_r2_key: combined_r2
            })
            numerical_feature_keys.extend([
                hist_slope_key, fwd_slope_key, combined_slope_key,
                hist_r2_key, fwd_r2_key, combined_r2_key
            ])

            # Features: Volatility
            hist_vol = calculate_stdev(series, hist_periods)         # periods <= 0
            fwd_vol = calculate_stdev(series, fwd_periods)           # periods > 0
            combined_vol = calculate_stdev(series, combined_periods) # full range

            hist_vol_key = f'{metric_group}_hist_vol_{metric}'
            fwd_vol_key = f'{metric_group}_fwd_vol_{metric}'
            combined_vol_key = f'{metric_group}_combined_vol_{metric}'
            feats[hist_vol_key] = hist_vol
            feats[fwd_vol_key] = fwd_vol
            feats[combined_vol_key] = combined_vol
            numerical_feature_keys.extend([hist_vol_key, fwd_vol_key, combined_vol_key])

            # Features: Normalized Slopes (slope per unit of volatility; NaN
            # when the volatility is missing or zero)
            norm_hist_slope_key = f'{metric_group}_norm_hist_slope_{metric}'
            norm_fwd_slope_key = f'{metric_group}_norm_fwd_slope_{metric}'
            feats[norm_hist_slope_key] = hist_slope / hist_vol if pd.notna(hist_vol) and hist_vol != 0 else np.nan
            feats[norm_fwd_slope_key] = fwd_slope / fwd_vol if pd.notna(fwd_vol) and fwd_vol != 0 else np.nan
            numerical_feature_keys.extend([norm_hist_slope_key, norm_fwd_slope_key])

            # Feature: Slope Divergence
            slope_divergence_key = f'{metric_group}_slope_divergence_{metric}'
            feats[slope_divergence_key] = fwd_slope - hist_slope if pd.notna(fwd_slope) and pd.notna(hist_slope) else np.nan
            numerical_feature_keys.append(slope_divergence_key)

            # Feature: Acceleration (slope of period-over-period differences).
            # The earliest historical period is excluded from the hist set.
            diff_series = series.diff()
            valid_hist_diff_periods = [
                p for p in hist_periods
                if p in diff_series.index and p > min_hist_period and pd.notna(diff_series.get(p))
            ]
            valid_fwd_diff_periods = [
                p for p in fwd_periods
                if p in diff_series.index and pd.notna(diff_series.get(p))
            ]

            hist_accel, _ = robust_slope(diff_series, valid_hist_diff_periods)
            fwd_accel, _ = robust_slope(diff_series, valid_fwd_diff_periods)
            hist_accel_key = f'{metric_group}_hist_accel_{metric}'
            fwd_accel_key = f'{metric_group}_fwd_accel_{metric}'
            feats[hist_accel_key] = hist_accel
            feats[fwd_accel_key] = fwd_accel
            numerical_feature_keys.extend([hist_accel_key, fwd_accel_key])

    # --- Process Relative Dispersion (forward periods > 0 only) ---
    # NOTE(review): both the estimate and the stdev are passed through
    # inv_signed_log, i.e. they are treated as signed-log encoded. Confirm the
    # columns selected as '_CSTAT_STD' (and their estimate counterparts) are
    # actually stored that way.
    for std_col in cstat_std_cols:
        estimate_col = std_col.replace('_CSTAT_STD', '')
        if not all(col in group.columns for col in [estimate_col, std_col]):
            continue

        std_series = group[std_col]
        estimate_series = group[estimate_col]

        for fwd_period in fwd_periods:
            rel_disp_key = f'rel_disp_{std_col}_period_{fwd_period}'
            numerical_feature_keys.append(rel_disp_key)
            try:
                feats[rel_disp_key] = _relative_dispersion_log1p(
                    estimate_series.get(fwd_period),
                    std_series.get(fwd_period),
                )
            except Exception as e:
                # Best-effort: one bad cell must not drop the whole group.
                feats[rel_disp_key] = np.nan
                print(f'Error: Relative dispersion calc failed unexpectedly for ID {group_id}, key {rel_disp_key}: {e}')

    # --- Process As-is Ratio Values (includes PERIOD=0) ---
    # Column lookups do not depend on the period, so do them once.
    ratio_series_by_metric = {metric: group.get(metric) for metric in ratio_signed_log_cols}
    for period in period_range:
        for ratio_metric in ratio_signed_log_cols:
            as_is_key = f'as_is_{ratio_metric}_period_{period}'
            metric_series = ratio_series_by_metric[ratio_metric]
            if metric_series is not None:
                feats[as_is_key] = metric_series.get(period, np.nan)
            else:
                feats[as_is_key] = np.nan
            numerical_feature_keys.append(as_is_key)

    # De-duplicate while keeping first-seen order, so the list is identical
    # from run to run (a set-based de-dup has no stable order for strings).
    feats['_numerical_feature_keys'] = list(dict.fromkeys(numerical_feature_keys))
    return feats


# --- Main Orchestration Function ---
def generate_enhanced_features(
    df: pd.DataFrame,
    hist_window: int = 6,
    fwd_window: int = 6,
    target_metric: str = 'PE_RATIO_RATIO_SIGNED_LOG',
    sector_mapping_path: str | None = None,
    sector_levels_to_include: list = ['sector_1'],
    include_sector_features: bool = True
    ) -> pd.DataFrame:
    """
    Generates time series and optional sector features, including PERIOD=0 in hist/combined calculations.
    """
    print("Starting enhanced feature generation...")

    # --- Input Validation and Setup ---
    required_cols = ['ID', 'PIT_DATE', 'PERIOD']
    if not all(col in df.columns for col in required_cols):
        missing = [col for col in required_cols if col not in df.columns]
        raise ValueError(f"Input DataFrame is missing required columns: {missing}")

    if target_metric not in df.columns:
         print(f"Warning: Target metric '{target_metric}' not found in input DataFrame columns.")

    # *** MODIFICATION: Define period range INCLUDING 0 ***
    if hist_window < 0 or fwd_window < 0:
        raise ValueError("hist_window and fwd_window must be non-negative.")
    period_range = list(range(-hist_window, fwd_window + 1)) # Includes 0
    print(f"Using period range: {min(period_range)} to {max(period_range)} (inclusive)")


    # Identify metric columns dynamically
    raw_scaled_sales_signed_log_cols = [col for col in df.columns if '_RAW_SCALED_SALES_SIGNED_LOG' in col]
    ratio_signed_log_cols = [col for col in df.columns if '_RATIO_SIGNED_LOG' in col and col != target_metric]
    cstat_std_cols = [col for col in df.columns if '_CSTAT_STD' in col]
    print(f"Identified metric columns: "
          f"{len(raw_scaled_sales_signed_log_cols)} scaled sales, "
          f"{len(ratio_signed_log_cols)} ratios (excl. target), "
          f"{len(cstat_std_cols)} stdevs.")


    # --- Core Feature Generation (Group Apply) ---
    global_duplicates = df.duplicated(subset=['ID', 'PIT_DATE', 'PERIOD']).sum()
    if global_duplicates > 0:
        print(f"CRITICAL WARNING: Found {global_duplicates} duplicate rows in input based on (ID, PIT_DATE, PERIOD). Using first occurrence per group.")

    grouped = df.groupby(['ID', 'PIT_DATE'], observed=True, sort=False)
    n_groups = grouped.ngroups
    print(f"Processing {n_groups} ID/PIT_DATE groups...")

    features_list = []
    group_iterator = tqdm(grouped, total=n_groups, desc="Processing groups") if tqdm else grouped

    for group_key, group_data in group_iterator:
        group_result = process_group(group_data.copy(), period_range, # Pass the new period_range
                                     raw_scaled_sales_signed_log_cols,
                                     ratio_signed_log_cols,
                                     cstat_std_cols)
        if group_result:
            features_list.append(group_result)

    if not tqdm: print(f"Finished processing {n_groups} groups.")

    if not features_list:
        print("Warning: No features generated. Returning empty DataFrame.")
        expected_cols = ['ID', 'PIT_DATE'] + ([target_metric] if target_metric in df.columns else [])
        return pd.DataFrame(columns=expected_cols)

    features_df = pd.DataFrame(features_list)

    all_numerical_keys = set()
    if '_numerical_feature_keys' in features_df.columns:
        for keys_list in features_df['_numerical_feature_keys'].dropna():
            if isinstance(keys_list, list):
                 all_numerical_keys.update(keys_list)
        features_df = features_df.drop(columns=['_numerical_feature_keys'])
    numerical_feature_cols = [key for key in all_numerical_keys if key in features_df.columns]
    print(f"Identified {len(numerical_feature_cols)} potential numerical feature columns generated.")


    # --- Optional Sector Feature Integration (No changes needed here) ---
    ohe_feature_names = []
    if include_sector_features and sector_mapping_path:
        try:
            print(f"Loading sector mappings from: {sector_mapping_path}")
            if not os.path.exists(sector_mapping_path):
                raise FileNotFoundError(f"Sector mapping file not found at {sector_mapping_path}")

            sector_df = pd.read_parquet(sector_mapping_path)
            valid_sector_levels = [col for col in sector_levels_to_include if col in sector_df.columns]
            if not valid_sector_levels:
                 print(f"Warning: None of specified sector levels found in {sector_mapping_path}. Skipping.")
            else:
                cols_to_merge = ['ID'] + valid_sector_levels
                sector_df = sector_df[cols_to_merge].drop_duplicates(subset=['ID'], keep='first')
                print(f"Merging sector features for levels: {valid_sector_levels}")
                original_feature_rows = len(features_df)
                try:
                    id_dtype_feat = features_df['ID'].dtype
                    id_dtype_sect = sector_df['ID'].dtype
                    if id_dtype_feat != id_dtype_sect:
                         features_df['ID'] = features_df['ID'].astype(str)
                         sector_df['ID'] = sector_df['ID'].astype(str)
                    features_df = features_df.merge(sector_df, on='ID', how='left', validate='m:1')
                    if len(features_df) != original_feature_rows:
                         print(f"Warning: Row count changed during sector merge.")
                except Exception as merge_err:
                     print(f"Error merging sectors: {merge_err}. Skipping.")
                     valid_sector_levels = []

                sector_cols_in_features = [col for col in valid_sector_levels if col in features_df.columns]
                if sector_cols_in_features:
                    fill_value = 'Missing_Sector'
                    features_df[sector_cols_in_features] = features_df[sector_cols_in_features].fillna(fill_value)
                    print(f"Applying OneHotEncoding to: {sector_cols_in_features}")
                    ohe = OneHotEncoder(sparse_output=False, handle_unknown='ignore', dtype=np.uint8)
                    encoded_sectors = ohe.fit_transform(features_df[sector_cols_in_features])
                    ohe_feature_names = ohe.get_feature_names_out(sector_cols_in_features).tolist()
                    encoded_sectors_df = pd.DataFrame(encoded_sectors, columns=ohe_feature_names, index=features_df.index)
                    features_df = features_df.drop(columns=sector_cols_in_features)
                    features_df = pd.concat([features_df, encoded_sectors_df], axis=1)
                    print(f"Added {len(ohe_feature_names)} OHE sector features (unranked).")
        except FileNotFoundError as e:
            print(f"Warning: {e}. Skipping sector features.")
        except Exception as e:
            print(f"Warning: Error processing sector features: {e}. Skipping.")
            ohe_feature_names = []
    elif include_sector_features and not sector_mapping_path:
         print("Info: Sector features requested but no path provided. Skipping.")


    # --- Ranking (Only Numerical Features - No changes needed here) ---
    processed_features_df = features_df.copy()
    numerical_cols_to_rank_candidates = [col for col in numerical_feature_cols if col in processed_features_df.columns]
    if not numerical_cols_to_rank_candidates:
         print("Warning: No numerical feature columns found to rank.")
    else:
        non_numeric_cols = processed_features_df[numerical_cols_to_rank_candidates].select_dtypes(exclude=[np.number]).columns
        if non_numeric_cols.any():
             print(f"Warning: Non-numeric columns found among numerical candidates for ranking: {non_numeric_cols.tolist()}. Excluding.")
             numerical_cols_to_rank = [col for col in numerical_cols_to_rank_candidates if col not in non_numeric_cols]
        else:
             numerical_cols_to_rank = numerical_cols_to_rank_candidates

        if not numerical_cols_to_rank:
             print("Skipping ranking: No valid numeric columns remain.")
        else:
            print(f"Ranking {len(numerical_cols_to_rank)} numerical features cross-sectionally (by PIT_DATE)...")
            try:
                ranked_data = processed_features_df.groupby('PIT_DATE')[numerical_cols_to_rank].transform(lambda x: x.rank(pct=True))
                rename_dict = {col: f'rank_{col}' for col in numerical_cols_to_rank}
                ranked_data = ranked_data.rename(columns=rename_dict)
                processed_features_df = processed_features_df.drop(columns=numerical_cols_to_rank)
                processed_features_df = pd.concat([processed_features_df, ranked_data], axis=1)
                print("Numerical feature ranking complete.")
            except Exception as e:
                print(f"Error during ranking: {e}. Proceeding with unranked features.")


    # --- Merge Target Variable (No changes needed here) ---
    print(f"Merging target variable: {target_metric}")
    if target_metric not in df.columns:
         print(f"Warning: Target metric '{target_metric}' not in original DataFrame. Cannot merge target.")
         if target_metric not in processed_features_df.columns:
              processed_features_df[target_metric] = np.nan
         final_df = processed_features_df
    else:
        target_df = df[df['PERIOD'] == 1][['ID', 'PIT_DATE', target_metric]].drop_duplicates(subset=['ID', 'PIT_DATE'], keep='first')
        if target_df.empty:
            print(f"Warning: No data for PERIOD=1 to extract target '{target_metric}'. Target column will be all NaNs.")
            target_df_placeholder = processed_features_df[['ID', 'PIT_DATE']].drop_duplicates()
            target_df_placeholder[target_metric] = np.nan
            try:
                final_df = processed_features_df.merge(target_df_placeholder, on=['ID', 'PIT_DATE'], how='left', validate='m:1')
            except Exception as merge_err:
                print(f"Error merging placeholder target: {merge_err}.")
                if target_metric not in processed_features_df.columns: processed_features_df[target_metric] = np.nan
                final_df = processed_features_df
        else:
            try:
                id_dtype_feat = processed_features_df['ID'].dtype
                id_dtype_target = target_df['ID'].dtype
                if id_dtype_feat != id_dtype_target:
                    processed_features_df['ID'] = processed_features_df['ID'].astype(str)
                    target_df['ID'] = target_df['ID'].astype(str)
                pit_dtype_feat = processed_features_df['PIT_DATE'].dtype
                pit_dtype_target = target_df['PIT_DATE'].dtype
                is_dt_feat = pd.api.types.is_datetime64_any_dtype(pit_dtype_feat)
                is_dt_target = pd.api.types.is_datetime64_any_dtype(pit_dtype_target)
                if is_dt_feat != is_dt_target or (is_dt_feat and pit_dtype_feat != pit_dtype_target):
                    try:
                        processed_features_df['PIT_DATE'] = pd.to_datetime(processed_features_df['PIT_DATE'])
                        target_df['PIT_DATE'] = pd.to_datetime(target_df['PIT_DATE'])
                    except Exception as date_err:
                        print(f"Error converting PIT_DATE types: {date_err}.")
                final_df = processed_features_df.merge(target_df, on=['ID', 'PIT_DATE'], how='left', validate='m:1')
            except Exception as merge_err:
                print(f"Error merging target '{target_metric}': {merge_err}.")
                if target_metric not in processed_features_df.columns: processed_features_df[target_metric] = np.nan
                final_df = processed_features_df

    print("Feature generation pipeline complete.")

    # --- Final Checks and Cleanup (No changes needed here) ---
    if final_df.empty and not features_df.empty:
        print("Critical Warning: Final DataFrame is empty after merging target.")
        return final_df

    if target_metric in final_df.columns:
        missing_target_fraction = final_df[target_metric].isnull().mean()
        if missing_target_fraction == 1.0: print(f"Warning: Target '{target_metric}' is ALL missing.")
        elif missing_target_fraction > 0: print(f"Warning: Target '{target_metric}' has {missing_target_fraction:.1%} missing values.")
    else:
        print(f"Warning: Target metric '{target_metric}' column not present in final DataFrame.")

    id_cols = ['ID', 'PIT_DATE']
    feature_cols = [col for col in final_df.columns if col not in id_cols and col != target_metric]
    final_cols_order = id_cols + sorted(feature_cols)
    if target_metric in final_df.columns: final_cols_order.append(target_metric)
    final_cols_order_exist = [col for col in final_cols_order if col in final_df.columns]
    try:
        final_df = final_df[final_cols_order_exist]
    except KeyError as e:
        print(f"Warning: Could not reorder columns: {e}")

    print(f"Final DataFrame shape: {final_df.shape}")
    return final_df


# --- Example Usage ---
if __name__ == '__main__':
    # Example driver: runs generate_enhanced_features with sector features enabled
    # and prints a summary of the resulting columns.
    #
    # NOTE(review): this block reads `df_input` and `sector_file_path`, but the only
    # definitions visible in this cell are commented out below. As written it relies
    # on both names already existing (e.g. from an earlier cell) -- confirm before
    # running on a fresh kernel.

    # --- Create Sample Data (Now includes PERIOD 0) ---
    # print("\n--- Setting up sample data (including duplicates and PERIOD 0) ---")
    # ids = ['A01'] * 8 + ['B02'] * 7 + ['C03'] * 6 # A01 has duplicate period 1
    # pit_dates_list = pd.to_datetime(['2023-01-31'] * 8 + ['2023-01-31'] * 7 + ['2023-01-31'] * 6).tolist()
    # # Sequence now includes 0
    # periods_list = [-2, -1, 0, 1, 1, 2, 3, 4] + [-2, -1, 0, 1, 2, 3, 4] + [-2, -1, 0, 1, 2, 3] # C03 missing period 4

    # df_input = pd.DataFrame({
    #     'ID': ids,
    #     'PIT_DATE': pd.to_datetime(pit_dates_list),
    #     'PERIOD': periods_list,
    #     'SALES_RAW_SCALED_SALES_SIGNED_LOG': np.random.randn(len(ids)) * 2,
    #     'OTHER_RATIO_SIGNED_LOG': np.random.rand(len(ids)) - 0.5,
    #     'PE_RATIO_RATIO_SIGNED_LOG': np.random.randn(len(ids)), # Target column name
    #     'PE_RATIO_CSTAT_STD': np.abs(np.random.randn(len(ids)) * 0.1) + 0.05
    # })

    # # Assign specific target values at PERIOD=1 (affects first occurrence if duplicates)
    # df_input.loc[(df_input['ID'] == 'A01') & (df_input['PERIOD'] == 1), 'PE_RATIO_RATIO_SIGNED_LOG'] = 10.5
    # df_input.loc[(df_input['ID'] == 'B02') & (df_input['PERIOD'] == 1), 'PE_RATIO_RATIO_SIGNED_LOG'] = 15.0
    # df_input.loc[(df_input['ID'] == 'C03') & (df_input['PERIOD'] == 1), 'PE_RATIO_RATIO_SIGNED_LOG'] = 8.2

    # # --- Create dummy sector mapping file ---
    # sector_file_path = './dummy_sector_mappings.pq'
    # sector_data = pd.DataFrame({
    #     'ID': ['A01', 'B02', 'C03', 'D04'],
    #     'sector_1': ['Tech', 'Finance', 'Tech', 'Retail'],
    #     'sector_2': ['Hardware', 'Banking', 'Software', 'Apparel']
    # })
    # try:
    #     sector_data.to_parquet(sector_file_path)
    #     print(f"Created dummy sector file: {sector_file_path}")
    # except Exception as e:
    #     print(f"Error creating dummy sector file: {e}")

    # --- Check for global duplicates in the input data ---
    print("\n--- Checking sample data for duplicates ---")
    key_cols_check = ['ID', 'PIT_DATE', 'PERIOD']
    global_dup_check = df_input.duplicated(subset=key_cols_check).sum()
    print(f"Checking sample data for global duplicates on {key_cols_check}: Found {global_dup_check}") # The commented-out sample data above has one duplicate (A01, PERIOD 1); other inputs may have none

    # --- Define target metric name for the example run ---
    example_target_metric = 'PE_RATIO_RATIO_SIGNED_LOG' # Reused below when categorizing columns

    # --- Run Feature Generation (With Sectors) ---
    print("\n--- Running feature generation WITH sector features ---")
    try:
        final_features_with_sectors = generate_enhanced_features(
            df_input.copy(),
            hist_window=2,   # Periods -2..-1 on the history side
            fwd_window=4,    # Periods 1..4 on the forward side
            target_metric=example_target_metric,
            sector_mapping_path=sector_file_path,
            sector_levels_to_include=['sector_1', 'sector_2'], # Levels missing from the file are ignored by the callee
            include_sector_features=True
        )
    except Exception as e:
        # Any failure is reported and replaced by an empty frame so the checks below are skipped.
        print(f"\n--- ERROR during feature generation: {e} ---")
        final_features_with_sectors = pd.DataFrame()


    # --- Analyze Results (With Sectors) ---
    print("\n--- Final DataFrame Schema (with sectors): ---")
    if not final_features_with_sectors.empty:
        final_features_with_sectors.info(verbose=False, memory_usage='deep')
        print("\n--- Final DataFrame Head (with sectors): ---")
        print(final_features_with_sectors.head())

        # --- Column categorization checks (printed, not raised) ---
        print("\n--- Verifying column categorization (with sectors): ---")
        all_cols = final_features_with_sectors.columns.tolist()
        id_target_cols = ['ID', 'PIT_DATE', example_target_metric]
        id_target_cols_present = [c for c in id_target_cols if c in all_cols]
        # Categories are assigned purely by column-name prefix.
        ohe_cols = [col for col in all_cols if col.startswith('sector_')]
        ranked_cols = [col for col in all_cols if col.startswith('rank_')]
        # Everything that is neither ID/target, 'sector_*' nor 'rank_*'
        other_cols = [c for c in all_cols if c not in id_target_cols_present and c not in ohe_cols and c not in ranked_cols]

        print(f"Total columns: {len(all_cols)}")
        print(f"ID/Target columns found: {len(id_target_cols_present)} ({id_target_cols_present})")
        print(f"OHE columns found: {len(ohe_cols)}")
        print(f"Ranked columns found: {len(ranked_cols)}")
        print(f"Other columns (expected unranked numerical): {len(other_cols)}")

        # NOTE(review): other_cols already excludes every 'rank_*' and 'sector_*' column,
        # so the prefix test inside this loop cannot be true as written.
        categories_ok = True
        for col in other_cols:
            if col.startswith('rank_') or col.startswith('sector_'):
                print(f"Assertion Error: Column '{col}' is not ranked/OHE but starts with prefix.")
                categories_ok = False
            # Name-based guess at "generated numerical feature"; the branch is a no-op.
            # The substrings depend on the naming conventions used in process_group.
            if ('slope' in col or 'vol' in col or 'level' in col or 'r2' in col or 'accel' in col or 'rel_disp' in col or 'as_is' in col):
                 # Likely a numerical feature that stayed unranked (e.g. ranking failed)
                 pass # Tolerated here; check the ranking warnings printed by the pipeline
            # else: # If strict checking desired:
            #     print(f"Assertion Warning: Uncategorized column '{col}' found.")
            #     categories_ok = False

        if categories_ok:
            # Compare the number of distinct categorized names with the column count.
            # NOTE(review): other_cols is the complement of the other categories, so this
            # only differs when all_cols itself contains repeated names -- confirm intent.
            all_categorized = set(id_target_cols_present) | set(ohe_cols) | set(ranked_cols) | set(other_cols)
            if len(all_categorized) != len(all_cols):
                 print("Assertion Error: Column overlap detected or uncategorized columns missed.")
                 categories_ok = False
            else:
                print("Assertions passed: Column categories seem consistent.")
        else:
            print("Assertions failed: Check column naming and categorization.")

    else:
        print("Skipping checks as final DataFrame is empty.")

    # --- Clean up dummy file ---
    # try:
    #     if sector_file_path and os.path.exists(sector_file_path):
    #         #  os.remove(sector_file_path)
    #         #  print(f"\nRemoved dummy sector file: {sector_file_path}")
    # except OSError as e:
    #     # print(f"Error removing dummy file '{sector_file_path}': {e}")


--- Checking sample data for duplicates ---
Checking sample data for global duplicates on ['ID', 'PIT_DATE', 'PERIOD']: Found 0

--- Running feature generation WITH sector features ---
Starting enhanced feature generation...
Using period range: -2 to 4 (inclusive)
Identified metric columns: 12 scaled sales, 18 ratios (excl. target), 9 stdevs.
Processing 500 ID/PIT_DATE groups...


Processing groups:   0%|          | 0/500 [00:00<?, ?it/s]

Identified 664 potential numerical feature columns generated.
Loading sector mappings from: /home/siddharth.johri/DECOHERE/data/raw/sector/sector_mappings.pq
Merging sector features for levels: ['sector_1', 'sector_2']
Applying OneHotEncoding to: ['sector_1', 'sector_2']
Added 33 OHE sector features (unranked).
Ranking 664 numerical features cross-sectionally (by PIT_DATE)...
Numerical feature ranking complete.
Merging target variable: PE_RATIO_RATIO_SIGNED_LOG
Feature generation pipeline complete.
Final DataFrame shape: (500, 700)

--- Final DataFrame Schema (with sectors): ---
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 500 entries, 0 to 499
Columns: 700 entries, ID to PE_RATIO_RATIO_SIGNED_LOG
dtypes: datetime64[ns](1), float64(665), object(1), uint8(33)
memory usage: 2.6 MB

--- Final DataFrame Head (with sectors): ---
                 ID   PIT_DATE  \
0  360ONE IB Equity 2024-09-02   
1      3M IB Equity 2024-09-02   
2    AACL IB Equity 2024-09-02   
3   AAVAS IB Equity 202

In [21]:
# Rich-display preview of the first rows of the feature table built by the driver cell above.
final_features_with_sectors.head()

Unnamed: 0,ID,PIT_DATE,rank_as_is_ASSET_TURNOVER_RATIO_SIGNED_LOG_period_-1,rank_as_is_ASSET_TURNOVER_RATIO_SIGNED_LOG_period_-2,rank_as_is_ASSET_TURNOVER_RATIO_SIGNED_LOG_period_0,rank_as_is_ASSET_TURNOVER_RATIO_SIGNED_LOG_period_1,rank_as_is_ASSET_TURNOVER_RATIO_SIGNED_LOG_period_2,rank_as_is_ASSET_TURNOVER_RATIO_SIGNED_LOG_period_3,rank_as_is_ASSET_TURNOVER_RATIO_SIGNED_LOG_period_4,rank_as_is_CURRENT_RATIO_RATIO_SIGNED_LOG_period_-1,...,sector_2_Oil & Gas,sector_2_Real Estate,sector_2_Renewable Energy,sector_2_Retail & Wholesale - Staples,sector_2_Retail & Whsle - Discretionary,sector_2_Software & Tech Services,sector_2_Tech Hardware & Semiconductors,sector_2_Telecommunications,sector_2_Utilities,PE_RATIO_RATIO_SIGNED_LOG
0,360ONE IB Equity,2024-09-02,0.501002,0.501006,0.501,0.513304,0.493333,0.483491,0.504335,0.551102,...,0,0,0,0,0,0,0,0,0,3.714934
1,3M IB Equity,2024-09-02,0.501002,0.501006,0.501,0.513304,0.493333,,,0.8998,...,0,0,0,0,0,0,0,0,0,4.113432
2,AACL IB Equity,2024-09-02,0.501002,0.501006,0.501,0.513304,0.493333,0.483491,,0.717435,...,0,0,0,0,0,0,0,0,0,3.841716
3,AAVAS IB Equity,2024-09-02,0.501002,0.501006,0.501,0.022173,0.493333,0.483491,0.504335,0.551102,...,0,0,0,0,0,0,0,0,0,3.165248
4,ABB IB Equity,2024-09-02,0.501002,0.501006,0.501,0.513304,0.493333,0.483491,0.504335,0.633267,...,0,0,0,0,0,0,0,0,0,4.512759


## 3. Feature Generation

Generate features for machine learning.

In [15]:
#final_features_with_sectors.to_parquet('/home/siddharth.johri/DECOHERE/data/features/fundamental/test.pq')#['PE_RATIO_RATIO_SIGNED_LOG'].hist(bins=100)

## 4. Feature Selection

Select important features using SHAP.

In [16]:
# Select features
# def select_features(features, config, logger):
#     logger.info("Selecting features...")
#     feature_selector = FeatureSelector(config, logger)
    
#     # Get target variables from config
#     target_vars = config['features']['targets']
    
#     # Prepare X and y
#     X = features.drop(columns=target_vars)
#     y = features[target_vars[0]]  # Use the first target variable
    
#     # Select features
#     selected_features = feature_selector.select_features(X, y)
#     logger.info(f"Selected {len(selected_features)} features")
    
#     # Create dataset with selected features and target
#     selected_data = features[selected_features + target_vars]
    
#     return selected_data, selected_features, feature_selector

# # Execute feature selection
# selected_data, selected_features, feature_selector = select_features(features, config, logger)

# # Display the selected features
# print("Selected features:")
# print(selected_features)
# print(f"\nSelected data shape: {selected_data.shape}")
# display(selected_data.head())


## 5. Model Training and Evaluation

Train and evaluate regression models.

In [17]:
# Train and evaluate models
# def train_and_evaluate_model(selected_data, config, logger):
#     logger.info("Training and evaluating model...")
#     model_trainer = ModelTrainer(config, logger)
    
#     # Get target variables from config
#     target_vars = config['features']['targets']
    
#     # Prepare X and y
#     X = selected_data.drop(columns=target_vars)
#     y = selected_data[target_vars[0]]  # Use the first target variable
    
#     # Train and evaluate model
#     model, metrics = model_trainer.train_and_evaluate(X, y)
    
#     return model, metrics, model_trainer

# # Execute model training and evaluation
# model, metrics, model_trainer = train_and_evaluate_model(selected_data, config, logger)

# # Display model metrics
# print("Model Metrics:")
# for metric_name, metric_value in metrics.items():
#     print(f"{metric_name}: {metric_value:.4f}")


## 6. Visualization

Visualize the results.

In [18]:
# Visualize results
# def visualize_results(model, selected_data, selected_features, feature_selector, config, logger):
#     logger.info("Visualizing results...")
#     visualizer = Visualizer(config, logger)
    
#     # Get target variables from config
#     target_vars = config['features']['targets']
    
#     # Prepare X and y
#     X = selected_data.drop(columns=target_vars)
#     y = selected_data[target_vars[0]]  # Use the first target variable
    
#     # Plot feature importance
#     plt.figure(figsize=(12, 8))
#     visualizer.plot_feature_importance(model, X.columns)
#     plt.tight_layout()
#     plt.show()
    
#     # Plot SHAP values
#     plt.figure(figsize=(12, 8))
#     visualizer.plot_shap_values(model, X)
#     plt.tight_layout()
#     plt.show()
    
#     # Plot actual vs predicted values
#     plt.figure(figsize=(10, 6))
#     y_pred = model.predict(X)
#     visualizer.plot_actual_vs_predicted(y, y_pred)
#     plt.tight_layout()
#     plt.show()
    
#     return visualizer

# # Execute visualization
# visualizer = visualize_results(model, selected_data, selected_features, feature_selector, config, logger)


## 7. Save Results

Save the processed data, features, model, and results.

In [19]:
# Save results
# def save_results(processed_data, features, selected_data, model, metrics, config, logger):
#     logger.info("Saving results...")
    
#     # Get the selected mode and date
#     mode, date = get_mode_and_date()
#     date_str = date.strftime('%Y-%m-%d')
    
#     # Create output directory
#     output_dir = os.path.join('..', 'data', 'results', mode, date_str)
#     os.makedirs(output_dir, exist_ok=True)
    
#     # Save processed data
#     processed_data_path = os.path.join(output_dir, 'processed_data.csv')
#     processed_data.to_csv(processed_data_path, index=False)
#     logger.info(f"Saved processed data to {processed_data_path}")
    
#     # Save features
#     features_path = os.path.join(output_dir, 'features.csv')
#     features.to_csv(features_path, index=False)
#     logger.info(f"Saved features to {features_path}")
    
#     # Save selected data
#     selected_data_path = os.path.join(output_dir, 'selected_data.csv')
#     selected_data.to_csv(selected_data_path, index=False)
#     logger.info(f"Saved selected data to {selected_data_path}")
    
#     # Save model
#     import joblib
#     model_path = os.path.join(output_dir, 'model.joblib')
#     joblib.dump(model, model_path)
#     logger.info(f"Saved model to {model_path}")
    
#     # Save metrics
#     import json
#     metrics_path = os.path.join(output_dir, 'metrics.json')
#     with open(metrics_path, 'w') as f:
#         json.dump(metrics, f, indent=4)
#     logger.info(f"Saved metrics to {metrics_path}")
    
#     return output_dir

# # Execute saving results
# output_dir = save_results(processed_data, features, selected_data, model, metrics, config, logger)
# print(f"Results saved to {output_dir}")


## 8. Summary

Display a summary of the pipeline run.

In [20]:
# # Display summary
# def display_summary(raw_data, processed_data, features, selected_data, selected_features, metrics, config):
#     # Get the selected mode and date
#     mode, date = get_mode_and_date()
#     date_str = date.strftime('%Y-%m-%d')
    
#     print("\n" + "=" * 80)
#     print(f"DECOHERE Pipeline Summary - {mode.upper()} MODE - {date_str}")
#     print("=" * 80)
    
#     print("\nData Processing:")
#     print(f"  Raw data shape: {raw_data.shape}")
#     print(f"  Processed data shape: {processed_data.shape}")
    
#     print("\nFeature Engineering:")
#     print(f"  Total features generated: {features.shape[1] - len(config['features']['targets'])}")
#     print(f"  Selected features: {len(selected_features)}")
    
#     print("\nModel Performance:")
#     for metric_name, metric_value in metrics.items():
#         print(f"  {metric_name}: {metric_value:.4f}")
    
#     print("\nTop 10 Important Features:")
#     for i, feature in enumerate(selected_features[:10]):
#         print(f"  {i+1}. {feature}")
    
#     print("\nResults saved to:")
#     print(f"  {os.path.abspath(output_dir)}")
    
#     print("\n" + "=" * 80)

# # Execute summary display
# display_summary(raw_data, processed_data, features, selected_data, selected_features, metrics, config)
