# In [None]:
# import modules
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, avg, from_json, lag, round, stddev, unix_timestamp, lit
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
import google.cloud.logging
import logging

# Pipeline locations: every GCS path lives under one bucket; results go to BigQuery.
_GCS_BUCKET_ROOT = "gs://sampple-bkt-13022025"

# Glob matching the raw patient JSON files that are ingested.
GCS_BUCKET_SOURCE = f"{_GCS_BUCKET_ROOT}/patients_data/*.json"
# BigQuery table the aggregated results are appended to.
BQ_TABLE = "avd-group-gcp-1111.gold_dataset.patients_insights"
# Prefix that receives the records rejected by validation.
GCS_BUCKET_INVALID = f"{_GCS_BUCKET_ROOT}/invalid/"
# Staging location passed to the BigQuery writer as its temporaryGcsBucket option.
# NOTE(review): the spark-bigquery connector documents temporaryGcsBucket as a
# bucket name - confirm that a gs:// URI with a sub-path is accepted by the
# connector version in use.
GCS_TEMP_BUCKET = f"{_GCS_BUCKET_ROOT}/temp/"

# Spark entry point for the whole pipeline (an already-active session is reused).
spark = (
    SparkSession.builder
    .appName("HealthcareDataProcessing")
    .getOrCreate()
)

# Hook Google Cloud Logging into Python's standard `logging` module, then take
# the named logger that every pipeline message is emitted on.
logging_client = google.cloud.logging.Client()
logging_client.setup_logging()
logger = logging.getLogger('healthcare-pipeline')

# Logging helper function
def log_pipeline_step(step, message, level='INFO'):
    """Emit one pipeline log entry on the module-level ``logger``.

    Args:
        step: Name of the pipeline step the entry belongs to.
        message: Text to record for that step.
        level: Severity name: 'INFO', 'ERROR' or 'WARNING'. Matching is
            case-insensitive. Any other value is logged at INFO level so the
            message is never silently discarded.
    """
    # Level name -> (logger method, label used inside the message text).
    emitters = {
        'INFO': (logger.info, 'Message'),
        'ERROR': (logger.error, 'Error'),
        'WARNING': (logger.warning, 'Warning'),
    }
    emit, label = emitters.get(str(level).upper(), emitters['INFO'])
    emit(f"Step: {step}, {label}: {message}")
        
# Layout applied when reading the patient JSON files; every field is nullable.
schema = StructType([
    StructField(field_name, field_type, True)
    for field_name, field_type in (
        ("patient_id", StringType()),
        ("heart_rate", IntegerType()),
        ("blood_pressure", IntegerType()),
        ("temperature", DoubleType()),
        ("timestamp", StringType()),
    )
])

# Function to validate incoming data
def validate_data(df):
    """Split readings into records with plausible and implausible vital signs.

    A record is valid only when heart_rate is in (40, 200), blood_pressure is
    in (50, 200) and temperature is in (35.0, 42.0), all bounds exclusive.
    If any of those columns is null the combined condition cannot be true, so
    the record falls through to ``otherwise(False)`` and is invalid.
    ``patient_id`` and ``timestamp`` are not inspected here.

    Args:
        df: DataFrame with heart_rate, blood_pressure and temperature columns.

    Returns:
        Tuple ``(valid_records, invalid_records)``. Both keep every input
        column plus the boolean ``is_valid`` column.
    """
    log_pipeline_step("Data Validation", "Starting data validation.")

    in_range = (
        (col("heart_rate") > 40) & (col("heart_rate") < 200)
        & (col("blood_pressure") > 50) & (col("blood_pressure") < 200)
        & (col("temperature") > 35.0) & (col("temperature") < 42.0)
    )
    validated_df = df.withColumn("is_valid", when(in_range, True).otherwise(False))

    # is_valid is never null (see otherwise above), so it can be used directly.
    valid_records = validated_df.filter(col("is_valid"))
    invalid_records = validated_df.filter(~col("is_valid"))

    # One grouped count yields both totals in a single Spark job instead of
    # running two separate count() actions over the input.
    counts = {
        row["is_valid"]: row["count"]
        for row in validated_df.groupBy("is_valid").count().collect()
    }
    log_pipeline_step(
        "Data Validation",
        f"Valid records: {counts.get(True, 0)}, Invalid records: {counts.get(False, 0)}",
    )
    return valid_records, invalid_records


# Main processing function
def process_data():
    """Run the pipeline: read from GCS, validate, aggregate, score, write to BigQuery.

    Steps:
        1. Read the JSON files matched by ``GCS_BUCKET_SOURCE`` using ``schema``.
        2. Split them into valid and invalid records with ``validate_data``.
        3. Append any invalid records to ``GCS_BUCKET_INVALID`` as JSON.
        4. Compute per-patient averages (rounded to 2 decimals) and standard
           deviations of the vital signs.
        5. Assign a ``risk_category`` to each patient.
        6. Append the result to ``BQ_TABLE``.

    Raises:
        Exception: Any failure is logged at ERROR level and then re-raised.
    """
    try:
        # Step 1: Read raw data from GCS
        log_pipeline_step("Data Ingestion", "Reading raw data from GCS.")
        df = spark.read.schema(schema).json(GCS_BUCKET_SOURCE)

        # Step 2: Validate data
        valid_df, invalid_df = validate_data(df)

        # Step 3: Persist invalid records (for auditing).
        # take(1) answers "is there at least one?" without counting every row.
        if invalid_df.take(1):
            log_pipeline_step("Invalid Data", "Found invalid records.", level='WARNING')
            invalid_df.write.mode("append").json(GCS_BUCKET_INVALID)

        # Step 4: Per-patient averages and standard deviations in one aggregation
        log_pipeline_step("Data Transformation", "Aggregating data by patient_id.")
        log_pipeline_step(
            "Data Transformation",
            "Calculating standard deviation for heart rate and blood pressure.",
        )
        df_metrics = (
            valid_df
            # Readings without a patient_id cannot be attributed to a patient
            # and are left out of the output (a join on patient_id would drop
            # them as well, because null keys never match).
            .filter(col("patient_id").isNotNull())
            .groupBy("patient_id")
            .agg(
                round(avg("heart_rate"), 2).alias("avg_heart_rate"),
                round(avg("blood_pressure"), 2).alias("avg_blood_pressure"),
                round(avg("temperature"), 2).alias("avg_temperature"),
                stddev("heart_rate").alias("stddev_heart_rate"),
                stddev("blood_pressure").alias("stddev_blood_pressure"),
            )
        )

        # Step 5: Flag patients with high average heart rate or high heart rate variation.
        # Branches are tried in order; a null stddev_heart_rate fails its
        # comparison and therefore ends up as "Low Risk".
        log_pipeline_step("Data Transformation", "Flagging high-risk patients.")
        df_metrics = df_metrics.withColumn(
            "risk_category",
            when(col("avg_heart_rate") > 100, "High Risk")
            .when(col("stddev_heart_rate") > 15, "Moderate Risk")
            .otherwise("Low Risk"),
        )

        # Step 6: Write the aggregated data with risk categorization to BigQuery
        log_pipeline_step("Data Write", "Writing aggregated and transformed data to BigQuery.")
        (
            df_metrics.write
            .format("bigquery")
            .option("table", BQ_TABLE)
            .option("temporaryGcsBucket", GCS_TEMP_BUCKET)
            .mode("append")
            .save()
        )

    except Exception as e:
        log_pipeline_step("Processing Error", str(e), level='ERROR')
        # Bare raise re-raises the active exception with its traceback intact.
        raise


# Script entry point
def main():
    """Run the pipeline once, bracketed by start and end log entries.

    The end entry is only written when ``process_data`` returns normally.
    """
    log_pipeline_step("Pipeline Start", "Healthcare data processing pipeline initiated.")
    process_data()
    log_pipeline_step("Pipeline End", "Healthcare data processing pipeline completed.")


if __name__ == "__main__":
    main()