In [1]:
import os
os.environ['JAVA_TOOL_OPTIONS'] = '-Djava.security.manager=allow'

In [2]:
# =================== 1. Setup Spark and Import Libraries ===================
from pyspark.sql import SparkSession
from pyspark.ml.regression import RandomForestRegressor, DecisionTreeRegressor, LinearRegression
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, RandomForestRegressor
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import RegressionEvaluator
import argparse
from pyspark.ml import Pipeline
from pyspark.ml import PipelineModel
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.ml import Pipeline
from pyspark.ml.feature import *
from pyspark.ml.classification import *
from pyspark.ml.evaluation import *
from pyspark.ml.tuning import *
import json
from pyspark.sql.functions import col, isnan, when, count
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation
from pyspark.sql.types import IntegerType, DoubleType, FloatType, StructField, Row
from pyspark.sql import functions as F

# Initialize Spark Session
spark = SparkSession.builder.appName("MachineLearningProject").getOrCreate()

Picked up JAVA_TOOL_OPTIONS: -Djava.security.manager=allow
Picked up JAVA_TOOL_OPTIONS: -Djava.security.manager=allow
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/01/10 19:14:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
csv_path = './data/2008.csv'
plane_data_path = './data/plane-data.csv'

In [4]:
def load_csv(spark, df_path, planes_data_path) -> DataFrame:
    # Read csv
    df = spark.read.csv(
        df_path,
        header=True,
        inferSchema=True
    )
    forbidden_cols = [
        "ArrTime",
        "ActualElapsedTime",
        "AirTime",
        "TaxiIn",
        "Diverted",
        "CarrierDelay",
        "WeatherDelay",
        "NASDelay",
        "SecurityDelay",
        "LateAircraftDelay"
    ]
    df = df.drop(*forbidden_cols)

    df_planes = spark.read.csv(
        planes_data_path,
        header=True,
        inferSchema=True
    )
    df_planes = df_planes.withColumnRenamed("tailnum", "TailNum")
    df_planes = df_planes.withColumnRenamed("year", "PlaneIssueYear")
    df_planes = df_planes.withColumnRenamed("engine_type", "EngineType")
    df_planes = df_planes.withColumnRenamed("aircraft_type", "AircraftType")
    df_planes = df_planes.withColumnRenamed("model", "Model")
    df_planes = df_planes.withColumnRenamed("manufacturer", "Manufacturer")

    data = df.join(df_planes, on="TailNum", how="inner")
    return data

In [5]:
def organize_data(df):
    quant_time_features = [
        'DepTime',
        'CRSDepTime',
        'CRSArrTime'
    ]

    quantitative_features = [
            'CRSElapsedTime',
            'DepDelay',
            'Distance',
            'TaxiOut',
            'PlaneIssueYear'
        ]

    target_column = "ArrDelay"

    for column in quantitative_features + [target_column]:
        df = df.withColumn(column, col(column).cast(IntegerType()))
    df = df.dropna(subset=[target_column])
    null_count = df.filter(col(target_column).isNull()).count()

    for column in quant_time_features:  # They are strings hhmm
        df = df.withColumn(
            column + "_minutes",
            (F.col(column).substr(1, 2).cast("int") * 60 + F.col(column).substr(3, 2).cast("int"))
        )
        quantitative_features.append(column + "_minutes")
    df = df.drop(*quant_time_features)
    return df

In [6]:
df = load_csv(spark, csv_path, plane_data_path)
df = organize_data(df)
df.printSchema()

                                                                                

root
 |-- TailNum: string (nullable = true)
 |-- Year: integer (nullable = true)
 |-- Month: integer (nullable = true)
 |-- DayofMonth: integer (nullable = true)
 |-- DayOfWeek: integer (nullable = true)
 |-- UniqueCarrier: string (nullable = true)
 |-- FlightNum: integer (nullable = true)
 |-- CRSElapsedTime: integer (nullable = true)
 |-- ArrDelay: integer (nullable = true)
 |-- DepDelay: integer (nullable = true)
 |-- Origin: string (nullable = true)
 |-- Dest: string (nullable = true)
 |-- Distance: integer (nullable = true)
 |-- TaxiOut: integer (nullable = true)
 |-- Cancelled: integer (nullable = true)
 |-- CancellationCode: string (nullable = true)
 |-- type: string (nullable = true)
 |-- Manufacturer: string (nullable = true)
 |-- issue_date: string (nullable = true)
 |-- Model: string (nullable = true)
 |-- status: string (nullable = true)
 |-- AircraftType: string (nullable = true)
 |-- EngineType: string (nullable = true)
 |-- PlaneIssueYear: integer (nullable = true)
 |-- 

In [7]:
numeric_features = [
    'Month',
    'DayofMonth',
    'DayOfWeek',
    'Year',
    'PlaneIssueYear',
    'DepTime_minutes',
    'CRSDepTime_minutes',
    'CRSArrTime_minutes',
    'CRSElapsedTime',
    'DepDelay',
    'Distance',
    'TaxiOut'
]

In [8]:
categorical_features = [
    'UniqueCarrier',
    'FlightNum',
    'TailNum',
    'Origin',
    'Dest',
    'Cancelled',
    'CancellationCode',
    'EngineType',
    'AircraftType',
    'Manufacturer',
    'Model',
    "issue_date", "status",
    "type",
    "ArrDelay"
]

In [9]:
print(f"Number of rows: {df.count()}")
print(f"Number of columns: {len(df.columns)}")

Number of rows: 2235032
Number of columns: 27


In [10]:
def null_values(data, features_list):
    # Calculate null values for each column in features_list
    null_data = data.select([
        count(when(col(c).isNull() | isnan(col(c)), c)).alias(c) for c in features_list
    ])
    
    # Show the results
    null_data.show()
    
    # Return the DataFrame with null counts
    return null_data

In [11]:
def plot_null_percentages(df, null_counts, numeric):# Convert to Pandas and compute percentage
    if numeric == True:
        type = 'Numerical'
        color = 'skyblue'
    else:
        type = 'Categorical'
        color = 'lightcoral'
    total_rows = df.count()
    null_counts_pandas = null_counts.toPandas().T  # Transpose for easier handling
    null_counts_pandas.columns = ["NullCount"]
    null_counts_pandas["Percentage"] = (null_counts_pandas["NullCount"] / total_rows) * 100
    null_counts_pandas = null_counts_pandas.sort_values("Percentage", ascending=False)

    # Plot the bar chart
    null_counts_pandas["Percentage"].plot(kind="barh", color=color)  # Use `barh` for horizontal bars
    plt.xlabel("Percentage of Null Values (%)")
    plt.ylabel(f"{type} Features")
    plt.title(f"Percentage of Null Values by {type} Features")
    plt.savefig(f"output/{type.lower()}/img/null_values_percentage_{type.lower()}.png", dpi=300, bbox_inches="tight")  # Save as PNG
    plt.close()


In [12]:
null_counts_numeric = null_values(df, numeric_features)

[Stage 13:=====>                                                   (1 + 9) / 10]

+-----+----------+---------+----+--------------+---------------+------------------+------------------+--------------+--------+--------+-------+
|Month|DayofMonth|DayOfWeek|Year|PlaneIssueYear|DepTime_minutes|CRSDepTime_minutes|CRSArrTime_minutes|CRSElapsedTime|DepDelay|Distance|TaxiOut|
+-----+----------+---------+----+--------------+---------------+------------------+------------------+--------------+--------+--------+-------+
|    0|         0|        0|   0|        176935|           6950|              3077|             19365|             0|       0|       0|      0|
+-----+----------+---------+----+--------------+---------------+------------------+------------------+--------------+--------+--------+-------+



                                                                                

In [13]:
plot_null_percentages(df, null_counts_numeric, True)

                                                                                

In [14]:
null_counts_categorical = null_values(df, categorical_features)



+-------------+---------+-------+------+----+---------+----------------+----------+------------+------------+------+----------+------+------+--------+
|UniqueCarrier|FlightNum|TailNum|Origin|Dest|Cancelled|CancellationCode|EngineType|AircraftType|Manufacturer| Model|issue_date|status|  type|ArrDelay|
+-------------+---------+-------+------+----+---------+----------------+----------+------------+------------+------+----------+------+------+--------+
|            0|        0|      0|     0|   0|        0|         2235032|    107854|      107854|      107854|107854|    107854|107854|107854|       0|
+-------------+---------+-------+------+----+---------+----------------+----------+------------+------------+------+----------+------+------+--------+



                                                                                

In [15]:
plot_null_percentages(df, null_counts_categorical, False)

25/01/10 19:15:11 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors
                                                                                

## **Statistics Summary**

In [16]:
def statistics_summary(data):
    summary_df = data.select(numeric_features).summary().toPandas()
    summary_df.set_index("summary", inplace=True)
    summary_numeric = summary_df.apply(pd.to_numeric, errors='coerce')
    summary_numeric = summary_numeric.T
    return summary_numeric

In [17]:
summary = statistics_summary(df)
summary

25/01/10 19:15:13 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

summary,count,mean,stddev,min,25%,50%,75%,max
Month,2235032.0,2.51106,1.123274,1.0,1.0,3.0,4.0,4.0
DayofMonth,2235032.0,15.695071,8.74591,1.0,8.0,16.0,23.0,31.0
DayOfWeek,2235032.0,3.915981,1.982679,1.0,2.0,4.0,6.0,7.0
Year,2235032.0,2008.0,1.908099e-13,2008.0,2008.0,2008.0,2008.0,2008.0
PlaneIssueYear,2058097.0,1995.048672,68.06124,0.0,1992.0,2000.0,2003.0,2008.0
DepTime_minutes,2228082.0,1947.94586,1662.226,600.0,839.0,1081.0,3604.0,5709.0
CRSDepTime_minutes,2231955.0,1966.182544,1685.97,600.0,830.0,1070.0,3660.0,5709.0
CRSArrTime_minutes,2215667.0,1614.865577,1467.061,600.0,835.0,1070.0,1310.0,5709.0
CRSElapsedTime,2235032.0,130.778137,70.55441,-21.0,80.0,113.0,162.0,660.0
DepDelay,2235032.0,11.36959,36.27695,-92.0,-4.0,0.0,11.0,2467.0


## **Features Distribution**

In [18]:
def features_distributions(data, features_list, is_numeric_features=True):
    if is_numeric_features == True:
        type = 'Numerical'
        numerical_df = data.select(features_list).toPandas()

        num_features = len(features_list)
        cols = 2  # Number of columns in the grid
        rows = (num_features // cols) + (num_features % cols > 0)  # Calculate rows needed

        fig, axes = plt.subplots(rows, cols, figsize=(15, rows * 5))  # Adjust the figure size
        axes = axes.flatten()  # Flatten the axes array for easy iteration

        # Plot each feature's distribution
        for i, col in enumerate(features_list):
            sns.histplot(numerical_df[col], bins=30, ax=axes[i])  # Use the subplot's axis
            axes[i].set_title(f"Distribution of {col}")
            axes[i].set_xlabel(col)
            axes[i].set_ylabel("Frequency")

            if col == "DepDelay":  # Modify based on feature name
                axes[i].set_xlim(0, 500)  # Set x-axis range (e.g., 0 to 500)
            elif col == "TaxiOut":
                axes[i].set_xlim(0, 150)

        # Remove any unused subplots
        for i in range(len(features_list), len(axes)):
            fig.delaxes(axes[i])

        plt.tight_layout()  # Adjust layout to avoid overlap
        plt.savefig(f"output/{type.lower()}/img/features_distribution_{type.lower()}.png", dpi=300, bbox_inches="tight")  # Save as PNG
        plt.close(fig)
    else:
        type = 'Categorical'
        num_features = len(features_list)
        cols = 2  # Number of columns in the grid
        rows = (num_features // cols) + (num_features % cols > 0)  # Calculate rows needed

        fig, axes = plt.subplots(rows, cols, figsize=(15, rows * 5))  # Adjust figure size
        axes = axes.flatten()  # Flatten the axes array for easy iteration

        # Plot each feature's distribution
        for i, col in enumerate(features_list):
            # Group by column and count occurrences
            # Limit to top 15 categories
            if col == "ArrDelay":  # Special case for "ArrDelay"
                # Group by column and count occurrences for all data
                category_counts = data.groupBy(col).count().orderBy("count", ascending=False)
            else:
                # Limit to top 15 categories for other features
                top_n = 15
                category_counts = data.groupBy(col).count().orderBy("count", ascending=False).limit(top_n)

            category_df = category_counts.toPandas()

            # Plot using the subplot axis
            sns.barplot(data=category_df, x=col, y="count", ax=axes[i])
            axes[i].set_title(f"Distribution of {col}")
            axes[i].set_xlabel(col)
            axes[i].set_ylabel("Count")
            axes[i].tick_params(axis="x", rotation=90)  # Rotate x-axis labels

            if col == "ArrDelay":
                axes[i].set_xlim(50, 120)
                axes[i].tick_params(axis="x", labelsize=8)

        # Remove any unused subplots
        for i in range(len(features_list), len(axes)):
            fig.delaxes(axes[i])

        plt.tight_layout()  # Adjust layout to avoid overlap
        plt.savefig(f"output/{type.lower()}/img/features_distribution_{type.lower()}.png", dpi=300, bbox_inches="tight")  # Save as PNG
        plt.close(fig)

Numerical features:

In [19]:
features_distributions(df, ['DayofMonth','DayOfWeek','CRSDepTime_minutes','CRSArrTime_minutes','CRSElapsedTime','DepDelay','Distance','TaxiOut'], is_numeric_features=True)

                                                                                

Categorical features:

In [20]:
features_distributions(df, ['UniqueCarrier','Origin','Dest','EngineType','AircraftType','Manufacturer','Model','ArrDelay'], is_numeric_features=False)

                                                                                

## **Features Proportions**

In [21]:
def proportions(data, features_list, is_numeric_features = True):
    if is_numeric_features == True:
        type = 'numerical'
    else:
        type = 'categorical'

    total_count = data.count()
    for feature in features_list:
        feature_counts = df.groupBy(feature).count()
        # Calculate proportions
        feature_proportions = feature_counts.withColumn(
            "Proportion", round((col("count") / total_count)*100,2)
        )
        if not os.path.exists(f'output/{type}'):
            os.makedirs(f'output/{type}')
        feature_proportions.write.csv(f"output/{type}/{feature}_proportions.csv", header=True, mode="overwrite")

In [22]:
proportions(df, categorical_features, is_numeric_features=False)

                                                                                

In [23]:
proportions(df, numeric_features, is_numeric_features=True)

                                                                                

## **Average ArrDelay by categorical features**

In [24]:
def avg_ArrDelay(data, features_list):
    # Number of features
    num_features = len(features_list)
    cols = 2  # Number of columns in the grid
    rows = (num_features // cols) + (num_features % cols > 0)  # Calculate rows needed

    # Create subplots
    fig, axes = plt.subplots(rows, cols, figsize=(15, rows * 6))  # Adjust figure size
    axes = axes.flatten()  # Flatten axes for easy iteration

    # Iterate over each categorical feature
    for i, col_name in enumerate(features_list):  # Use col_name for clarity
        if col_name in ["Origin", "Dest", "Model"]:  # Special case for "Origin" and "Dest"
            type = 'categorical'
            top_n = 20
            category_counts = data.groupBy(col_name).count().orderBy("count", ascending=False).limit(top_n)
            # Filter data for the top 20 categories
            top_categories = [row[col_name] for row in category_counts.collect()]
            data_filtered = data.filter(col(col_name).isin(top_categories))  # Correct usage of col()
        else:
            type = 'numerical'
            data_filtered = data

        # Group data and calculate average arrival delay
        grouped_df = data_filtered.groupBy(col_name).agg({"ArrDelay": "mean"})

        # Convert to Pandas
        grouped_pandas = grouped_df.toPandas()

        # Plot bar chart in the subplot
        sns.barplot(data=grouped_pandas, x=col_name, y="avg(ArrDelay)", ax=axes[i])
        axes[i].set_title(f"Average Arrival Delay by {col_name}", fontsize=12)
        axes[i].set_xlabel(col_name)
        axes[i].set_ylabel("Avg. Arrival Delay")
        axes[i].tick_params(axis="x", rotation=90)

        if col_name in ["DepTime_minutes", "DepDelay"]:
            xticks = axes[i].get_xticks()
            axes[i].set_xticks([xticks[0], xticks[-1]])
            axes[i].set_xticklabels([grouped_pandas[col_name].iloc[0], grouped_pandas[col_name].iloc[-1]])  # Set corresponding labels


    # Remove any unused subplots
    for i in range(len(features_list), len(axes)):
        fig.delaxes(axes[i])

    # Adjust layout
    plt.tight_layout()
    plt.savefig(f"output/{type.lower()}/img/avg_ArrDelay_{type.lower()}.png", dpi=300, bbox_inches="tight")  # Save as PNG
    plt.close(fig)


In [25]:
avg_ArrDelay(df, ['Origin','Dest','EngineType','AircraftType','Manufacturer','Model'])

                                                                                

In [26]:
avg_ArrDelay(df, ['Month','DayofMonth','DayOfWeek','PlaneIssueYear','DepTime_minutes','DepDelay'])

                                                                                

## **Correlation Matrix**

In [27]:
def corr_matrix(data, features_list):
    data = data.fillna(0, subset=features_list)
    vector_col = "features_corr"

    vector_assembler = VectorAssembler(inputCols=features_list, outputCol=vector_col)
    df_vector = vector_assembler.transform(data)

    # Compute Correlation Matrix
    correlation_matrix = Correlation.corr(df_vector, vector_col).head()[0]  # Get the DenseMatrix
    correlation_array = correlation_matrix.toArray()
    correlation_df = pd.DataFrame(correlation_array, index=features_list, columns=features_list)
    sns.heatmap(
        correlation_df,
        annot=True,              # Show the correlation values
        fmt=".1f",               # Format to two decimal places
        cmap="coolwarm",         # Color map
        annot_kws={"size": 8}    # Reduce annotation font size
    )
    plt.title("Correlation Matrix Heatmap")
    plt.savefig(f"output/numerical/img/correlation_matrix.png", dpi=300, bbox_inches="tight")  # Save as PNG
    plt.close()

In [28]:
corr_matrix(df, numeric_features)

25/01/10 19:16:24 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
25/01/10 19:16:24 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.VectorBLAS
25/01/10 19:16:25 WARN PearsonCorrelation: Pearson correlation matrix contains NaN values.


In [None]:
# =================== 3. Data Processing ===================
def process_data(data, mode):
    """
    Process the dataset: handle missing values and perform feature engineering.

    Args:
        data (DataFrame): Spark DataFrame to process.
        mode (str): Mode of operation ("train" or "predict").

    Returns:
        DataFrame: Processed Spark DataFrame with new features added.
    """
    # Validate the target variable for training mode
    if mode == "train" and "ArrDelay" not in data.columns:
        raise ValueError("The target variable 'ArrDelay' is missing.")

    # Handle missing values
    if mode == "train":
        # Drop rows where the target variable or features are null
        data = data.dropna(subset=["ArrDelay"])

    # Example: Fill null values in specific columns with a default value
    # Replace 'column_name' with actual column names as needed
    # Uncomment this if specific columns require filling
    # data = data.fillna({"column_name": 0})

    # Transform special variables
    # Feature engineering: Create time-based features
    if "DepTime" in data.columns:
        data = data.withColumn("DepHour", (col("DepTime") / 100).cast("int"))  # Extract hour from departure time

    if "FlightDate" in data.columns:
        data = data.withColumn("DayOfWeek", date_format(col("FlightDate"), "u").cast("int"))  # Convert to day of the week

    # Feature engineering: Create flight distance categories
    if "Distance" in data.columns:
        data = data.withColumn(
            "DistanceCategory",
            when(col("Distance") < 500, "Short")  # Short flights
            .when((col("Distance") >= 500) & (col("Distance") < 1500), "Medium")  # Medium flights
            .otherwise("Long")  # Long flights
        )

    return data


In [None]:
# =================== 4. Feature Engineering ===================
def feature_engineering(data, additional_dataset_path=None):
    """
    Perform feature engineering, including creating new features and optionally integrating additional datasets.

    Args:
        data (DataFrame): Spark DataFrame for feature engineering.
        additional_dataset_path (str): Optional path to an additional dataset for integration.
    
    Returns:
        DataFrame: Enhanced Spark DataFrame with new features.
    """
    # Create new features based on existing columns
    if "existing_column" in data.columns:
        data = data.withColumn("new_feature", col("existing_column") * 2)  # Example transformation

    # Optional: Integrate additional datasets
    if additional_dataset_path:
        try:
            additional_dataset = spark.read.csv(additional_dataset_path, header=True, inferSchema=True)
            
            # Example: Join the datasets on a common column
            if "common_column" in data.columns and "common_column" in additional_dataset.columns:
                data = data.join(additional_dataset, on="common_column", how="left")
        except Exception as e:
            print(f"Error integrating additional dataset: {e}")

    return data


In [None]:
def build_and_train_model(data, pipeline, model_save_path=None):
    """
    Build, train, evaluate, and optionally save the model using cross-validation with three models.

    Args:
        data (DataFrame): Spark DataFrame with features and labels.
        pipeline (Pipeline): Preprocessing pipeline to use before modeling.
        model_save_path (str): Path to save the trained model (optional).
    
    Returns:
        dict: Evaluation metrics for the trained model.
    """
    # Split data into training and testing sets
    train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)

    # Define the models: RandomForestRegressor, DecisionTreeRegressor, LinearRegression
    rf = RandomForestRegressor(featuresCol="features_vector", labelCol="ArrDelay")
    dt = DecisionTreeRegressor(featuresCol="features_vector", labelCol="ArrDelay")
    lr = LinearRegression(featuresCol="features_vector", labelCol="ArrDelay")

    # Initialize metrics dictionary
    all_metrics = {}

    # Evaluate each model separately
    models = [rf, dt, lr]
    model_names = ['Random Forest', 'Decision Tree', 'Linear Regression']
    
    for model, name in zip(models, model_names):
        print(f"Training {name} model...")
        
        # Add the current model to the pipeline
        pipeline.setStages(pipeline.getStages() + [model])

        # Hyperparameter tuning with cross-validation for the current model
        param_grid = ParamGridBuilder() \
            .addGrid(model.numTrees, [10, 50, 100]) if isinstance(model, RandomForestRegressor) else \
            (addGrid(model.maxDepth, [5, 10, 20]) if isinstance(model, DecisionTreeRegressor) else \
            addGrid(model.regParam, [0.1, 0.3, 0.5])) \
            .build()

        evaluator = RegressionEvaluator(labelCol="ArrDelay", predictionCol="prediction", metricName="rmse")

        # Set up cross-validation
        cv = CrossValidator(
            estimator=pipeline,
            estimatorParamMaps=param_grid,
            evaluator=evaluator,
            numFolds=5
        )

        # Train the model with cross-validation
        cv_model = cv.fit(train_data)

        # Generate predictions on the test dataset
        predictions = cv_model.transform(test_data)

        # Evaluate the model using multiple metrics
        metrics = {}
        # Root Mean Square Error (RMSE)
        rmse_evaluator = RegressionEvaluator(labelCol="ArrDelay", predictionCol="prediction", metricName="rmse")
        metrics['rmse'] = rmse_evaluator.evaluate(predictions)
        print(f"{name} - Root Mean Square Error (RMSE) on test data: {metrics['rmse']}")

        # Mean Absolute Error (MAE)
        mae_evaluator = RegressionEvaluator(labelCol="ArrDelay", predictionCol="prediction", metricName="mae")
        metrics['mae'] = mae_evaluator.evaluate(predictions)
        print(f"{name} - Mean Absolute Error (MAE) on test data: {metrics['mae']}")

        # R-Squared (R²)
        r2_evaluator = RegressionEvaluator(labelCol="ArrDelay", predictionCol="prediction", metricName="r2")
        metrics['r2'] = r2_evaluator.evaluate(predictions)
        print(f"{name} - R-Squared (R²) on test data: {metrics['r2']}")

        # Store model-specific metrics
        all_metrics[name] = metrics

        # Save the best model if a save path is provided
        if model_save_path:
            cv_model.bestModel.write().overwrite().save(f"{model_save_path}_{name}")
            print(f"Best {name} model saved to: {model_save_path}_{name}")

    return all_metrics


In [None]:
def predict(data, model_path, output_path):
    """
    Load the trained model and generate predictions for the given data.

    Args:
        data (DataFrame): Spark DataFrame for predictions.
        model_path (str): Path to load the trained model.
        output_path (str): Path to save predictions (CSV).
    
    Returns:
        None
    """
    # Load the trained model
    model = PipelineModel.load(model_path)

    # Make predictions on the input data
    predictions = model.transform(data)

    # Save predictions to the specified output path
    predictions.select("features_vector", "prediction").write.csv(output_path, header=True)
    print(f"Predictions saved to: {output_path}")


In [None]:
def main():
    """
    Main function to execute the pipeline workflow.
    Accepts command-line arguments for dynamic input/output handling.
    """
    parser = argparse.ArgumentParser(description="Flight Delay Prediction Application")
    parser.add_argument("--mode", type=str, required=True, choices=["train", "predict"], help="Mode: train or predict")
    parser.add_argument("--input", type=str, required=True, help="Path to input CSV file")
    parser.add_argument("--model", type=str, required=True, help="Path to save/load the model")
    parser.add_argument("--output", type=str, help="Path to save predictions (required for predict mode)")

    args = parser.parse_args()

    # Start Spark Session
    spark = SparkSession.builder.appName("FlightDelayPipeline").getOrCreate()

    try:
        # Workflow
        data = load_data(spark, args.input, args.mode)  # Load the dataset
        data = eda(data)
        data = process_data(data, args.mode)        # Preprocess the dataset
        pipeline, _ = feature_engineering(data)        # Perform feature engineering

        if args.mode == "train":
            # Train the model, evaluate it, and optionally save it
            metrics = build_and_train_model(data, pipeline, args.model)
            print(f"Training completed. Evaluation metrics: {metrics}")
        elif args.mode == "predict":
            if not args.output:
                raise ValueError("Output path is required for prediction mode.")
            # Use the trained model to generate predictions
            predict(data, args.model, args.output)

    finally:
        # Stop Spark Session
        spark.stop()

# Add this block to execute the script when running it as a standalone script
if __name__ == "__main__":
    main()


In [None]:
# spark-submit notebook.py --mode train --input path/to/train.csv --model path/to/save_model
# spark-submit notebook.py --mode predict --input path/to/test.csv --model path/to/save_model --output path/to/predictions