# In [None]:
"""
Module: biogas_pipeline
Description: Processes sensor data from biogas installations by performing data quality checks,
preprocessing sensor readings (flow rate and methane concentration), and merging additional metadata.
Handles S3 data reads and writes with retry logic.
"""

import json
import logging
import time
from typing import Any, Dict, Optional

import boto3
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import awswrangler as wr
from botocore.exceptions import ClientError, EndpointConnectionError
from IPython.display import Markdown, display
from tqdm import tqdm

# Configure root logging at INFO level when this module is imported, and
# create the module-level logger used by the functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class BiogasPreprocessor:
    """
    Preprocess biogas sensor data.

    Attributes:
        methane_col (str): Column name for methane concentration.
        flow_col (str): Column name for flow rate.
        timestamp_col (str): Column name for the timestamp.
        max_methane (float): Maximum allowed methane percentage.
        max_flow (float): Maximum allowed flow rate.
        rolling_window (int): Window size for rolling median smoothing.
        energy_factor (float): Multiplier applied when computing
            ``energy_output_btu`` (defaults to 1010).
    """

    def __init__(
        self,
        methane_col: str = 'methane_percent',
        flow_col: str = 'flow_rate',
        timestamp_col: str = 'timestamp',
        max_methane: float = 100,
        max_flow: float = 500,
        rolling_window: int = 5,
        energy_factor: float = 1010
    ) -> None:
        self.methane_col = methane_col
        self.flow_col = flow_col
        self.timestamp_col = timestamp_col
        self.max_methane = max_methane
        self.max_flow = max_flow
        self.rolling_window = rolling_window
        self.energy_factor = energy_factor

    def _clip_and_fill(self, series: pd.Series, upper: float) -> pd.Series:
        """Clip a reading to ``[0, upper]`` and fill gaps forward, then backward."""
        # Series.ffill()/bfill() replace the deprecated fillna(method=...) form.
        return series.clip(lower=0, upper=upper).ffill().bfill()

    def _rolling_median(self, series: pd.Series) -> pd.Series:
        """Centered rolling median; edge windows are filled from the nearest value."""
        smoothed = series.rolling(window=self.rolling_window, center=True).median()
        # A centered window is incomplete at both ends, which yields NaN there.
        return smoothed.bfill().ffill()

    def preprocess(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Preprocess the input dataframe:
          - Convert the timestamp column to UTC datetimes (unparseable values
            become NaT) and sort by it.
          - Compute the duration in minutes between consecutive readings.
          - Clip and fill missing values for methane and flow rate.
          - Apply rolling median smoothing.
          - Calculate the energy output column.

        The input frame is not modified; a processed copy is returned. The
        cleaned readings are always written to the canonical columns
        ``methane_percent`` and ``flow_rate``, whatever the configured source
        column names are.

        Args:
            df (pd.DataFrame): Input sensor data.

        Returns:
            pd.DataFrame: Processed sensor data with the added columns
                ``duration_min``, ``methane_smooth``, ``flow_smooth`` and
                ``energy_output_btu``.
        """
        df = df.copy()

        # Convert and sort timestamps
        df[self.timestamp_col] = pd.to_datetime(df[self.timestamp_col], errors='coerce', utc=True)
        df = df.sort_values(self.timestamp_col)
        logging.getLogger(__name__).debug(
            "Timestamp column %r parsed as %s",
            self.timestamp_col,
            df[self.timestamp_col].dtype,
        )

        # Duration between readings in minutes. The first row (and any row
        # next to a NaT timestamp) has no defined gap and is assumed to be 1.
        df['duration_min'] = df[self.timestamp_col].diff().dt.total_seconds() / 60
        df['duration_min'] = df['duration_min'].fillna(1)

        # Clip sensor readings to valid ranges and fill the gaps
        df['methane_percent'] = self._clip_and_fill(df[self.methane_col], self.max_methane)
        df['flow_rate'] = self._clip_and_fill(df[self.flow_col], self.max_flow)

        # Apply rolling median smoothing for sensor readings
        df['methane_smooth'] = self._rolling_median(df['methane_percent'])
        df['flow_smooth'] = self._rolling_median(df['flow_rate'])

        # Energy output: flow * duration * methane fraction * conversion factor.
        # NOTE(review): the column is labelled BTU; the flow-rate unit that makes
        # the default factor of 1010 correct is not visible here - confirm.
        df['energy_output_btu'] = (
            df['flow_smooth'] *
            df['duration_min'] *
            (df['methane_smooth'] / 100) *
            self.energy_factor
        )

        return df


def retry_s3_operation(
    func: Any, retries: int = 3, delay: int = 5, *args, **kwargs
) -> Optional[Any]:
    """
    Invoke an S3 operation, retrying when it fails with a transient AWS error.

    Note that ``retries`` and ``delay`` come before ``*args`` in the
    signature, so positional arguments for ``func`` can only be supplied after
    both of them; keyword arguments are forwarded unchanged.

    Args:
        func (callable): The operation to run. A non-callable value is treated
            as an already-computed result and returned unchanged, which keeps
            call sites of the form ``retry_s3_operation(do_call(...))``
            working (they get no retry, because the call already happened).
        retries (int): Maximum number of attempts.
        delay (int): Delay (in seconds) between attempts.
        *args: Additional positional arguments for the function.
        **kwargs: Additional keyword arguments for the function.

    Returns:
        The result of ``func(*args, **kwargs)`` from the first successful
        attempt, or None if every attempt raised ``ClientError`` or
        ``EndpointConnectionError``. Any other exception propagates.
    """
    if not callable(func):
        # Nothing to invoke or retry: hand the precomputed value back.
        return func

    for attempt in range(1, retries + 1):
        try:
            return func(*args, **kwargs)
        except (ClientError, EndpointConnectionError) as e:
            logger.warning(f"Attempt {attempt} failed: {e}")
            # Sleeping after the final attempt would only delay the failure.
            if attempt < retries:
                time.sleep(delay)
    logger.error("All retry attempts failed.")
    return None


def run_data_quality_checks(
    df: pd.DataFrame, timestamp_col: str = 'timestamp'
) -> Dict[str, Any]:
    """
    Run basic data quality checks on the sensor data and display the findings.

    The input dataframe is only read; it is not modified.

    Args:
        df (pd.DataFrame): The sensor data.
        timestamp_col (str): The name of the timestamp column.

    Returns:
        Dict[str, Any]: A report with the keys ``shape``,
            ``missing values by column`` (per-column null counts as a Series),
            ``missing_values`` (percent missing, only columns with gaps),
            ``low_variance_columns``, ``out_of_range`` and, when the timestamp
            column is present, ``duplicate_timestamps`` and ``timestamp_gaps``.
    """
    display(Markdown("### 🔍 Data Quality Report"))
    report: Dict[str, Any] = {}

    # 1. Basic information: dataframe shape.
    report['shape'] = df.shape
    display(Markdown(f"**Data shape:** {df.shape}"))

    # 2. Missing values per column.
    missing = df.isnull().sum()
    report['missing values by column'] = missing

    # Columns in which every row is missing. Plain truthiness is used because
    # the flags are NumPy booleans, for which an `is True` test never matches.
    all_null_flags = df.isnull().all()
    true_keys = [(col, True) for col, is_all_null in all_null_flags.items() if is_all_null]
    display(Markdown(f"**Missing all rows by Column:::** {true_keys}"))

    # 3. Missing value percentage for each column that has any gap.
    missing_percent = (missing / len(df)) * 100
    missing_report = missing_percent[missing_percent > 0].sort_values(ascending=False)
    display(Markdown("**Missing Values (%):**"))
    display(missing_report)
    report['missing_values'] = missing_report.to_dict()

    # 4. Columns with low variance (0 or 1 unique value): flat sensors are
    # uninformative and can be dropped.
    low_var_cols = df.loc[:, df.nunique() <= 1].columns.tolist()
    display(Markdown("**Low Variance Columns:**"))
    display(low_var_cols)
    report['low_variance_columns'] = low_var_cols

    # 5. Timestamp related checks
    if timestamp_col in df.columns:
        duplicates = df.duplicated(subset=timestamp_col).sum()
        display(Markdown(f"**Duplicate Timestamps:** {duplicates}"))
        report['duplicate_timestamps'] = int(duplicates)

        # Parse into a local series so the caller's column keeps its dtype.
        timestamps = pd.to_datetime(df[timestamp_col])
        # NOTE(review): gaps are taken in row order, so unsorted input (for
        # example several facilities stacked) produces negative gaps.
        time_deltas = timestamps.diff().dt.total_seconds()
        gap_summary = time_deltas.describe()
        display(Markdown("**Timestamp Gaps (in seconds):**"))
        display(gap_summary)
        report['timestamp_gaps'] = gap_summary.to_dict()

    # 6. Out-of-range sensor checks (negative values)
    outliers: Dict[str, int] = {}
    if 'methane_percent' in df.columns:
        outliers['negative_ch4'] = int((df['methane_percent'] < 0).sum())
    if 'flow_rate' in df.columns:
        outliers['negative_flow'] = int((df['flow_rate'] < 0).sum())
    display(Markdown("**Out-of-Range Sensor Checks:**"))
    display(outliers)
    report['out_of_range'] = outliers

    # 7. Correlation heatmap of the numeric columns; helps detect redundant
    # or correlated features.
    numeric_cols = df.select_dtypes(include=[np.number])
    if not numeric_cols.empty:
        plt.figure(figsize=(10, 6))
        sns.heatmap(numeric_cols.corr(), annot=False, cmap='coolwarm')
        plt.title("Sensor Correlation Heatmap")
        plt.show()

    return report

def process_bool_columns(df: pd.DataFrame):
    """
    Detect boolean-like columns and convert the object-typed ones to 'boolean'.

    A column counts as boolean-like when it has at least one non-null value
    and every non-null value equals True/False (which also covers 0/1 and
    0.0/1.0). Columns that are entirely null are not reported. Only detected
    columns of dtype ``object`` are converted; numeric 0/1 columns are
    reported but keep their dtype.

    The conversion is applied to ``df`` in place and the same object is
    returned.

    Args:
        df (pd.DataFrame): Input data.

    Returns:
        tuple: ``(df, bool_cols)`` where ``bool_cols`` lists the detected
            column names in dataframe column order.
    """
    # Textual spellings accepted once values have been stringified.
    # NOTE(review): the detection below only accepts values equal to
    # True/False/0/1, so text such as "yes"/"no" never reaches this mapping.
    true_set = {"true", "yes", "1", "1.0"}
    false_set = {"false", "no", "0", "0.0"}

    def to_flag(text):
        if text in true_set:
            return True
        if text in false_set:
            return False
        return pd.NA

    # Step 1: find boolean-like columns.
    bool_cols = []
    for col in df.columns:
        try:
            unique_vals = set(df[col].dropna().unique())
        except TypeError:
            # Unhashable cell values (lists, dicts) cannot be booleans.
            continue
        # An empty set is a subset of anything, so require real values;
        # {True, False} also matches 0, 1, 0.0 and 1.0 because they compare equal.
        if unique_vals and unique_vals.issubset({True, False}):
            bool_cols.append(col)

    # Step 2: convert object-typed boolean columns to the nullable dtype
    # (rows are kept; unrecognised or missing values become <NA>).
    for col in bool_cols:
        if df[col].dtype == 'object':
            df[col] = (
                df[col].astype(str).str.strip().str.lower().map(to_flag).astype('boolean')
            )

    return df, bool_cols

def standardize_column_types(df: pd.DataFrame, verbose: bool = True) -> pd.DataFrame:
    """
    Standardizes mixed-type object columns in the DataFrame:
    - Converts to numeric if >50% of values can be converted (values that
      cannot be converted become NaN)
    - Otherwise, converts the non-null values to stripped strings; missing
      values stay missing instead of turning into the text "nan"/"None"
    - Boolean columns and all other non-object columns are left untouched

    Parameters:
        df (pd.DataFrame): Input DataFrame (not modified)
        verbose (bool): If True, prints conversion steps

    Returns:
        pd.DataFrame: A new DataFrame with standardized column types
    """
    df_standardized = df.copy()

    for col in df_standardized.columns:
        series = df_standardized[col]
        col_dtype = series.dtype

        # Skip boolean columns
        if col_dtype == 'bool':
            if verbose:
                print(f"🛑 Column '{col}' is boolean. Skipping standardization.")
            continue

        # Anything that is not an object column already has a single dtype
        if col_dtype != 'object':
            if verbose:
                print(f"✅ Column '{col}' already {col_dtype}. No change.")
            continue

        # Handle object (likely mixed) columns
        num_converted = pd.to_numeric(series, errors='coerce')
        if num_converted.notnull().sum() > len(series) / 2:
            df_standardized[col] = num_converted
            if verbose:
                print(f"📊 Column '{col}' standardized to numeric.")
        else:
            stripped = series.astype(str).str.strip()
            # astype(str) stringifies nulls too; restore the original nulls.
            df_standardized[col] = stripped.where(series.notna(), series)
            if verbose:
                print(f"🔤 Column '{col}' retained as string (stripped).")

    return df_standardized

def merge_facility_dfs(
    df_1: pd.DataFrame,
    df_2: pd.DataFrame
) -> pd.DataFrame:
    """
    Stack two facility dataframes on the union of their columns.

    Columns that exist in only one of the inputs are filled with missing
    values for the rows of the other input. The column order is
    deterministic: the columns of ``df_1`` first, followed by the columns
    that appear only in ``df_2``.

    Args:
         df_1 (pd.DataFrame): First facility dataframe.
         df_2 (pd.DataFrame): Second facility dataframe.

    Returns:
         pd.DataFrame: Rows of ``df_1`` followed by rows of ``df_2`` with a
             fresh 0..n-1 index.
    """
    # dict.fromkeys keeps first-seen order, unlike a set union whose order
    # can differ from one interpreter run to the next.
    union_columns = list(dict.fromkeys([*df_1.columns, *df_2.columns]))
    df1_union = df_1.reindex(columns=union_columns)
    df2_union = df_2.reindex(columns=union_columns)

    return pd.concat([df1_union, df2_union], ignore_index=True)


def _call_s3_with_retry(operation: Any) -> Optional[Any]:
    """Run a zero-argument S3 call through retry_s3_operation and return its result."""
    outcome = retry_s3_operation(operation)
    # retry_s3_operation can return the callable it was given without having
    # invoked it; detect that case by identity and perform the call once.
    return operation() if outcome is operation else outcome


def load_merge_csv_json(
    csv_s3_path: str,
    json_s3_path: str,
    facility_value: str
) -> Optional[pd.DataFrame]:
    """
    Loads sensor CSV data and merges it with facility JSON metadata from S3.

    The JSON document is expected to be an object; its ``longitude``,
    ``latitude`` and ``site_comm_date`` entries (None when absent) are
    broadcast to every row, together with the facility identifier.

    Args:
        csv_s3_path (str): S3 path to the CSV file containing sensor data.
        json_s3_path (str): S3 path to the JSON file containing facility metadata.
        facility_value (str): Facility identifier to be added to the dataframe.

    Returns:
        pd.DataFrame: Sensor data with the metadata columns added and a fresh
            0..n-1 index, or None if any error occurs.
    """
    # Read CSV data from S3. The read is wrapped in a lambda so that it runs
    # inside the retry helper instead of before the helper is called.
    try:
        df = _call_s3_with_retry(lambda: wr.s3.read_csv(csv_s3_path, low_memory=False))
    except Exception as e:
        logger.error(f"Failed to load CSV from {csv_s3_path}: {e}")
        return None
    if df is None:
        logger.error(f"Failed to load CSV from {csv_s3_path} after retries.")
        return None

    # Load metadata JSON from S3 and extract coordinates
    try:
        s3_client = boto3.client('s3')
        bucket, key = json_s3_path.replace("s3://", "").split("/", 1)
        response = _call_s3_with_retry(lambda: s3_client.get_object(Bucket=bucket, Key=key))
        if response is None:
            logger.error(f"Failed to load JSON from {json_s3_path} after retries.")
            return None
        coords = json.loads(response['Body'].read())
    except Exception as e:
        logger.error(f"Failed to parse JSON from {json_s3_path}: {e}")
        return None

    if not isinstance(coords, dict):
        logger.error(f"Metadata in {json_s3_path} is not a JSON object.")
        return None

    # Merge facility metadata into the dataframe
    df['longitude'] = coords.get('longitude')
    df['latitude'] = coords.get('latitude')
    df['site_comm_date'] = coords.get('site_comm_date')
    df['facility'] = facility_value

    # Return the frame for every facility identifier. Routing it through a
    # per-facility merge would yield an empty result for any identifier other
    # than "facility_1" / "facility_2".
    return df.reset_index(drop=True)


def main() -> None:
    """
    Main function to process sensor data from multiple facilities.
      - Loads data from S3 (facilities that fail to load are skipped).
      - Runs data quality checks.
      - Preprocesses the sensor data.
      - Saves the processed data back to S3.
    """
    facilities = ['facility_1', 'facility_2']
    preprocessor = BiogasPreprocessor()
    base_s3_path = 's3://sagemaker-us-east-2-426179662034'
    dataframes = []
    output_path = f'{base_s3_path}/canvas/processed/facility_merge_processed.csv'

    for facility in tqdm(facilities, desc="Processing Facilities"):
        csv_path = f'{base_s3_path}/canvas/{facility}/{facility}_data.csv'
        json_path = f'{base_s3_path}/canvas/{facility}/{facility}_coordinates.json'

        logger.info(f"\nProcessing {facility.upper()}\n{'=' * 50}")

        df = load_merge_csv_json(csv_path, json_path, facility_value=facility)
        if df is None:
            # The loader reports failure with None; skip this facility.
            logger.warning(f"Skipping {facility} due to previous errors.")
            continue
        dataframes.append(df)

    if not dataframes:
        # pd.concat raises on an empty list, so stop with a clear message.
        logger.error("No facility data could be loaded; nothing to process.")
        return

    # Concatenate all DataFrames into a single DataFrame
    df_merged = pd.concat(dataframes, ignore_index=True)

    # Little cleanup to make the header names readable:
    # remove the literal text 'bop_plc_' wherever it occurs in a column name.
    df_merged.columns = df_merged.columns.str.replace('bop_plc_', '', regex=False)

    # Rename selected columns
    df_merged_renamed = df_merged.rename(columns={
        'abb_gc_outletstream_ch4': 'methane_percent',
        'abb_gc_outletstream_flow': 'flow_rate',
        'bge_skid_running': 'operational'
    })

    # Run and display data quality checks
    run_data_quality_checks(df_merged_renamed)

    df_processed_cleaned_bool, all_bool_cols = process_bool_columns(df_merged_renamed)
    bool_value_counts = {
        col: df_processed_cleaned_bool[col].value_counts(dropna=False)
        for col in all_bool_cols
    }

    # print all of them
    print(bool_value_counts)

    df_processed_standardized = standardize_column_types(df_processed_cleaned_bool, True)

    # Preprocess sensor data, then drop columns that carry no information
    df_preprocess_completed = preprocessor.preprocess(df_processed_standardized)
    low_var_cols = df_preprocess_completed.loc[:, df_preprocess_completed.nunique() <= 1].columns.tolist()
    df_preprocess_completed = df_preprocess_completed.drop(low_var_cols, axis=1)

    numeric_cols = df_preprocess_completed.select_dtypes(include=[np.number])
    if not numeric_cols.empty:
        plt.figure(figsize=(10, 6))
        sns.heatmap(numeric_cols.corr(), annot=False, cmap='coolwarm')
        plt.title("Sensor Correlation Heatmap")
        plt.show()

    df_preprocess_completed_sorted = df_preprocess_completed.sort_values(
        by=preprocessor.timestamp_col
    )

    # Save the processed data back to S3. The write is wrapped in a function
    # so that it runs inside the retry helper instead of before it is called.
    def write_output() -> Any:
        return wr.s3.to_csv(df_preprocess_completed_sorted, path=output_path, index=False)

    try:
        outcome = retry_s3_operation(write_output)
        if outcome is write_output:
            # The helper returned the callable without invoking it; write once here.
            outcome = write_output()
    except Exception as e:
        logger.error(f"Failed to write processed data to {output_path}: {e}")
        return

    if outcome is None:
        logger.error(f"Failed to write processed data to {output_path} after retries.")
    else:
        logger.info(f"Processed data saved to: {output_path}\n")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
