<a href="https://colab.research.google.com/github/tngjody/bt4221-airfare-booking/blob/main/BT4221_Group_20.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BT4221 Group 20: Optimising Airfare Booking with Predictive Analytics

Group Members:

Gong Yongjia A0286144X

Jody Tng Jin Zi A0238195W

Wu Shuhan A0266501L

Zhou Jingchu Jeslyn A0275993H

### Importing Libraries

In [21]:


# PySpark Imports
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import col, sum, isnan, datediff, lit
from pyspark.sql.types import NumericType, DoubleType, IntegerType, FloatType, DateType
from pyspark.ml.stat import Correlation
from pyspark.ml.feature import VectorAssembler, StringIndexer

# Visualization Imports
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Other Imports
from datetime import datetime
from datetime import timedelta


# 1. Data Loading & Preprocessing

### Initialize Spark Session & Load Dataset

In [12]:
# Start (or reuse) the Spark session used throughout the notebook
spark = (
    SparkSession.builder
    .appName("FlightPrices")
    .getOrCreate()
)

# Read the raw flight CSV; the first row is the header and column types are inferred
flight_df = spark.read.csv(path="flight.csv", header=True, inferSchema=True)

# Preview the first 5 rows without truncating wide columns
flight_df.show(n=5, truncate=False)

# Print the inferred column names and types
flight_df.printSchema()


                                                                                

+--------------------------------+----------+----------+---------------+------------------+-------------+--------------+-----------+--------------+------------+---------+--------+---------+--------------+-------------------+---------------------------------+------------------------------------------------------------+-------------------------------+------------------------------------------------------------+--------------------------+----------------------------+-------------------+-------------------+------------------------------+-------------------------+----------------+-----------------+
|legId                           |searchDate|flightDate|startingAirport|destinationAirport|fareBasisCode|travelDuration|elapsedDays|isBasicEconomy|isRefundable|isNonStop|baseFare|totalFare|seatsRemaining|totalTravelDistance|segmentsDepartureTimeEpochSeconds|segmentsDepartureTimeRaw                                    |segmentsArrivalTimeEpochSeconds|segmentsArrivalTimeRaw                          

                                                                                

+--------------------------------+----------+----------+---------------+------------------+-------------+--------------+-----------+--------------+------------+---------+--------+---------+--------------+-------------------+---------------------------------+-----------------------------+-------------------------------+-----------------------------+--------------------------+----------------------------+-------------------+-------------------+----------------------------+-------------------------+----------------+-----------------+
|legId                           |searchDate|flightDate|startingAirport|destinationAirport|fareBasisCode|travelDuration|elapsedDays|isBasicEconomy|isRefundable|isNonStop|baseFare|totalFare|seatsRemaining|totalTravelDistance|segmentsDepartureTimeEpochSeconds|segmentsDepartureTimeRaw     |segmentsArrivalTimeEpochSeconds|segmentsArrivalTimeRaw       |segmentsArrivalAirportCode|segmentsDepartureAirportCode|segmentsAirlineName|segmentsAirlineCode|segmentsEquipment

                                                                                

# 2. Feature Engineering

In [23]:
def convert_duration_to_hours(duration_col):
    """Convert an ISO-8601 style duration column (e.g. ``PT2H30M``) to hours.

    Args:
        duration_col: Column (or column name) holding the duration strings.

    Returns:
        A Column of doubles: days * 24 + hours + minutes / 60, rounded to the
        nearest 0.5 hour. A null duration stays null.
    """
    def _component(pattern):
        # regexp_extract yields '' when the unit is absent (e.g. no minutes in
        # 'PT2H'). Casting '' to double gives null, which would null out the
        # whole sum, so an absent unit counts as 0 instead.
        raw = F.regexp_extract(duration_col, pattern, 1)
        return F.when(raw == '', F.lit(0.0)).otherwise(raw.cast('double'))

    days = _component(r'P(\d+)D')                # e.g. 'P1DT2H30M' -> 1
    hours = _component(r'T(\d+)H')               # e.g. 'PT2H30M'   -> 2
    minutes = _component(r'T(?:\d+H)?(\d+)M')    # e.g. 'PT45M'     -> 45

    # Calculate total hours and round to the nearest 0.5
    total_hours = days * 24 + hours + (minutes / 60)
    rounded_hours = F.round(total_hours * 2) / 2

    return rounded_hours

def get_total_distance(total_col, segments_col):
    """Return the total travel distance, falling back to the per-segment sum.

    Args:
        total_col: Name of the column holding the total distance.
        segments_col: Name of the column holding per-segment distances joined
            by '||' (e.g. '500||700').

    Returns:
        A Column equal to ``total_col`` where it is not null, otherwise the sum
        of the '||'-separated values in ``segments_col`` as a double. If any
        segment value cannot be cast to double, the fallback sum is null.
    """
    # Split by '||' and convert each piece to double. The segments_col argument
    # is honoured here rather than a hard-coded column name.
    segment_values = F.transform(
        F.split(F.col(segments_col), r'\|\|'),
        lambda value: value.cast('double'),
    )
    segment_sum = F.aggregate(
        segment_values,
        F.lit(0.0),
        lambda acc, value: acc + value,
    ).cast('double')

    return F.when(
        F.col(total_col).isNotNull(),
        F.col(total_col)
    ).otherwise(segment_sum)

def calculate_layover_duration():
    """Build a Column with the total layover time of a journey, in hours.

    Reads the '||'-separated columns ``segmentsDepartureTimeEpochSeconds`` and
    ``segmentsArrivalTimeEpochSeconds`` of the DataFrame the expression is
    applied to.

    Returns:
        A Column that is 0.0 when the departure column splits into at most one
        element, and otherwise the sum of (departure of segment i+1 minus
        arrival of segment i), converted from seconds to hours and rounded to
        the nearest 0.5.
    """
    return F.when(
        F.size(F.split('segmentsDepartureTimeEpochSeconds', '\\|\\|')) <= 1,
        0.0  # No layover for single segment
    ).otherwise(
        # For multiple segments:
        # 1. Split the arrival and departure times into arrays
        # 2. Calculate differences between next departure and previous arrival
        #    (departures from the 2nd element on, zipped with arrivals minus the last)
        # 3. Sum up the differences and convert to hours
        # 4. Round to nearest 0.5 hours (multiply by 2, round, divide by 2)
        F.round(
            F.expr("""
                (
                    aggregate(
                        zip_with(
                            slice(split(segmentsDepartureTimeEpochSeconds, '\\\\|\\\\|'), 2, size(split(segmentsDepartureTimeEpochSeconds, '\\\\|\\\\|'))),
                            slice(split(segmentsArrivalTimeEpochSeconds, '\\\\|\\\\|'), 1, size(split(segmentsArrivalTimeEpochSeconds, '\\\\|\\\\|')) - 1),
                            (x, y) -> cast(x as double) - cast(y as double)
                        ),
                        0D,
                        (acc, x) -> acc + x
                    ) / 3600
                ) * 2
            """)
        ) / 2
    )

# Public holidays (name, ISO date) falling inside the period covered by the data
holiday_data = [
    ("Good Friday", "2022-04-15"),
    ("Labor Day (International)", "2022-05-01"),
    ("Memorial Day", "2022-05-30"),
    ("Juneteenth (observed)", "2022-06-20"),
    ("Independence Day", "2022-07-04"),
]
holiday_df = spark.createDataFrame(holiday_data, ["holiday_name", "date"])
holiday_df = holiday_df.withColumn("date", F.to_date(F.col("date")))

# Build the set of 'yyyy-MM-dd' strings for holiday-1, holiday and holiday+1.
# The dates are literals, so they are parsed in plain Python instead of
# launching one Spark job per holiday.
holiday_dates = set()
for _holiday_name, holiday_str in holiday_data:
    holiday_day = datetime.strptime(holiday_str, '%Y-%m-%d').date()
    for day_offset in (-1, 0, 1):
        shifted_day = holiday_day + timedelta(days=day_offset)
        holiday_dates.add(shifted_day.strftime('%Y-%m-%d'))

# Sorted so the list (and the generated isin() expression) is reproducible
holiday_dates_list = sorted(holiday_dates)

# Airline codes that each get their own 1/0 indicator column (column name = code)
airline_codes = ['UA', 'NK', 'AA', '4B', 'LF', 'B6', 'DL', '9K', 'F9', 'HA', '9X', 'AS', 'KG', 'SY']

cleaned_flight_df = flight_df.select(
    'legid', # Keep original flight ID
    F.col('startingAirport').cast('string'), # ensure the string type
    F.col('destinationAirport').cast('string'), # ensure the string type

    # Convert duration format (PT2H30M) to hours (2.5)
    convert_duration_to_hours(F.col('travelDuration')).alias('travelDuration'),

    # ensure the integer type
    F.col('elapsedDays').cast('int').alias('elapsedDays'),

    # Convert boolean 'true'/'false' strings to 1/0 integers
    F.when(F.col('isBasicEconomy') == 'true', 1).otherwise(0).cast('int').alias('isBasicEconomy'),
    F.when(F.col('isRefundable') == 'true', 1).otherwise(0).cast('int').alias('isRefundable'),
    F.col('seatsRemaining').cast('int').alias('seatsRemaining'),

    # Check if flight date is within ±1 day of holidays
    F.when(
        F.date_format('flightDate', 'yyyy-MM-dd').isin(holiday_dates_list),
        1
    ).otherwise(0).cast('int').alias('isFestival'),

    # Calculate total distance (use segments if total not available)
    get_total_distance('totalTravelDistance', 'segmentsDistance').cast('int').alias('totalTravelDistance'),

    # Calculate days between search and flight date
    F.datediff(F.col('flightDate'), F.col('searchDate')).cast('int').alias('daysUntilDeparture'),

    # Extract day of week (Mon, Tue, etc.)
    F.date_format(F.col('flightDate'), 'E').alias('dayOfTheWeek'),

    # Count number of flight segments: 1 if there is no '||', otherwise the
    # number of '||'-separated parts. split() already returns one element per
    # segment, so no extra +1 is needed ('a||b' -> 2 segments).
    F.when(F.col('segmentsDistance').contains('||'),
           F.size(F.split('segmentsDistance', '\\|\\|'))
    ).otherwise(1).cast('int').alias('numOfSegments'),

    # Calculate total layover time between segments
    calculate_layover_duration().alias('layoverDuration'),

    # Create binary indicators (1/0) for each airline's presence in the journey
    *[
        F.when(F.col('segmentsAirlineCode').contains(code), 1).otherwise(0).cast('int').alias(code)
        for code in airline_codes
    ],

    # 1 if any segment is in the given cabin class; split on the same '||'
    # delimiter used for the other segment columns
    F.array_contains(F.split(F.col('segmentsCabinCode'), '\\|\\|'), 'business').cast('int').alias('hasBusinessClass'),
    F.array_contains(F.split(F.col('segmentsCabinCode'), '\\|\\|'), 'first').cast('int').alias('hasFirstClass'),
)
cleaned_flight_df.show()

+--------------------+---------------+------------------+--------------+-----------+--------------+------------+--------------+----------+-------------------+------------------+------------+-------------+---------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----------------+-------------+
|               legid|startingAirport|destinationAirport|travelDuration|elapsedDays|isBasicEconomy|isRefundable|seatsRemaining|isFestival|totalTravelDistance|daysUntilDeparture|dayOfTheWeek|numOfSegments|layoverDuration| UA| NK| AA| 4B| LF| B6| DL| 9K| F9| HA| 9X| AS| KG| SY|hasBusinessClass|hasFirstClass|
+--------------------+---------------+------------------+--------------+-----------+--------------+------------+--------------+----------+-------------------+------------------+------------+-------------+---------------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----------------+-------------+
|620cd51404373b9bd...|            LAX|               ATL|           4.5|    

# 3. Data Cleaning 

### Drop Duplicates

In [6]:
# Drop exact duplicate rows (rows where every column is identical).
# Null handling is not done here.
flight_df = flight_df.dropDuplicates()

# Count once and reuse the result: every count() call triggers a Spark job
row_count = flight_df.count()

print(f"Number of rows after dropping duplicates: {row_count}")
print(f"Number of columns: {len(flight_df.columns)}")

Number of rows after dropping duplicates: 2000000
Number of columns: 30


In [None]:
def _missing_count(field):
    """Return an aggregate expression counting the missing values of one schema field.

    Numeric fields count both nulls and NaNs; every other field counts nulls only.
    The resulting column is named after the field.
    """
    is_missing = col(field.name).isNull()
    if isinstance(field.dataType, NumericType):
        is_missing = is_missing | isnan(col(field.name))
    return sum(is_missing.cast("int")).alias(field.name)


# One missing-value counter per column of the DataFrame
missing_exprs = [_missing_count(field) for field in flight_df.schema.fields]

# Evaluate all counters in a single pass and display the one-row result
missing_counts = flight_df.select(*missing_exprs)
missing_counts.show()

+-----+----------+----------+---------------+------------------+-------------+--------------+-----------+--------------+------------+---------+--------+---------+--------------+-------------------+---------------------------------+------------------------+-------------------------------+----------------------+--------------------------+----------------------------+-------------------+-------------------+----------------------------+-------------------------+----------------+-----------------+-----------+-----------+--------------+
|legId|flightDate|searchDate|startingAirport|destinationAirport|fareBasisCode|travelDuration|elapsedDays|isBasicEconomy|isRefundable|isNonStop|baseFare|totalFare|seatsRemaining|totalTravelDistance|segmentsDepartureTimeEpochSeconds|segmentsDepartureTimeRaw|segmentsArrivalTimeEpochSeconds|segmentsArrivalTimeRaw|segmentsArrivalAirportCode|segmentsDepartureAirportCode|segmentsAirlineName|segmentsAirlineCode|segmentsEquipmentDescription|segmentsDurationInSeconds|s

### TODO: Populate the missing `totalTravelDistance` values

Missing values in `segmentsEquipmentDescription` can be ignored because this column is not important for the analysis.

In [None]:
# TODO: populate missing values

# 4. Data Exploration

### Distribution of Categorical Columns

In [None]:
categorical_cols = [
    'startingAirport',
    'destinationAirport',
    'isBasicEconomy',
    'isRefundable',
    'isNonStop',
]


def explore_categorical(df, column):
    """Show the 5 most frequent values of ``column`` in ``df`` with their row counts."""
    frequency = F.count("*").alias("count")
    ranked_counts = df.groupBy(column).agg(frequency).orderBy(F.col("count").desc())
    # Top 5 categories only, without truncating long values
    ranked_counts.show(5, truncate=False)


# Print the frequency table of every categorical column
for col_name in categorical_cols:
    explore_categorical(flight_df, col_name)

# Fare Distribution by Airline
if 'segmentsAirlineCode' in flight_df.columns:
    # Extract the first airline code for each flight (codes are '||'-separated)
    flight_df = flight_df.withColumn('primaryAirline', F.split(F.col('segmentsAirlineCode'), r'\|\|').getItem(0))

    # Calculate fare statistics by airline, most frequent airline first
    airline_stats = flight_df.groupBy("primaryAirline").agg(
        F.count("*").alias("count"),
        F.mean("baseFare").alias("mean_fare"),
        F.stddev("baseFare").alias("stddev_fare"),
        F.min("baseFare").alias("min_fare"),
        F.max("baseFare").alias("max_fare"),
        F.percentile_approx("baseFare", [0.25, 0.5, 0.75], 10000).alias("percentiles")
    ).orderBy(F.col("count").desc())

    # Spark is lazy: without an action the statistics above are never computed
    # or displayed, so show them explicitly.
    airline_stats.show(truncate=False)
    

In [None]:
# Summary statistics and a coarse equal-width histogram for every numeric column
numerical_cols = [name for name, dtype in flight_df.dtypes if dtype in ("int", "double")]
num_summary_bins = 10

for col_name in numerical_cols:
    # Exclude nulls and NaNs so they cannot distort the mean/min/max
    valid_df = flight_df.filter(
        F.col(col_name).isNotNull() & ~F.isnan(F.col(col_name))
    )

    stats = valid_df.select(
        F.mean(F.col(col_name)).alias("mean"),
        F.stddev(F.col(col_name)).alias("stddev"),
        F.min(F.col(col_name)).alias("min"),
        F.max(F.col(col_name)).alias("max"),
        F.percentile_approx(F.col(col_name), [0.25, 0.5, 0.75], 10000).alias(
            "percentiles"
        ),
    ).collect()[0]

    mean_val = stats["mean"]
    stddev_val = stats["stddev"]
    min_val = stats["min"]
    max_val = stats["max"]
    percentiles = stats["percentiles"]

    # A column with no usable values has no range to bin
    if min_val is None:
        print(f"\n{col_name}: no non-missing values")
        continue

    bin_width = (max_val - min_val) / num_summary_bins
    bins = [min_val + i * bin_width for i in range(num_summary_bins + 1)]

    # Assign every row to a bin in ONE aggregation instead of one filter+count
    # job per bin. least() clamps the maximum value into the last bin, which a
    # strict '< upper' comparison would drop.
    if bin_width > 0:
        bin_index = F.least(
            F.floor((F.col(col_name) - F.lit(min_val)) / F.lit(bin_width)),
            F.lit(num_summary_bins - 1),
        )
    else:
        # All values identical: everything belongs to the first bin
        bin_index = F.lit(0)

    bin_counts = [0] * num_summary_bins
    bin_rows = valid_df.groupBy(bin_index.cast("int").alias("bin")).count().collect()
    for row in bin_rows:
        bin_counts[row["bin"]] = row["count"]

    # Display the results (previously computed but never shown)
    print(
        f"\n{col_name}: mean={mean_val}, stddev={stddev_val}, "
        f"min={min_val}, max={max_val}, quartiles={percentiles}"
    )
    for i, bin_count in enumerate(bin_counts):
        closing = "]" if i == num_summary_bins - 1 else ")"
        print(f"  [{bins[i]:.2f}, {bins[i + 1]:.2f}{closing}: {bin_count}")


def collect_data_for_viz(df, columns, limit=10000):
    """Bring a capped slice of ``df`` to the driver for plotting.

    Args:
        df: DataFrame to read from.
        columns: Column name(s) passed straight to ``df.select``.
        limit: Maximum number of rows to collect (default 10000).

    Returns:
        The collected rows.
    """
    selected = df.select(columns)
    capped = selected.limit(limit)
    return capped.collect()


# Create histogram visualization for totalFare using PySpark
# Get min and max, then let Spark count the frequency of each equal-width bin
min_max = flight_df.select(F.min("totalFare"), F.max("totalFare")).collect()[0]
min_fare, max_fare = min_max[0], min_max[1]
num_bins = 30
bin_width = (max_fare - min_fare) / num_bins

# Assign each row to a bin in a single aggregation instead of one filter+count
# job per bin. least() clamps the maximum fare into the last bin, which a
# strict '< upper' comparison would silently drop.
if bin_width > 0:
    fare_bin = F.least(
        F.floor((F.col("totalFare") - F.lit(min_fare)) / F.lit(bin_width)),
        F.lit(num_bins - 1),
    )
else:
    # Every fare is identical: all rows fall into the first bin
    fare_bin = F.lit(0)

fare_bin_rows = (
    flight_df
    .filter(F.col("totalFare").isNotNull() & ~F.isnan(F.col("totalFare")))
    .groupBy(fare_bin.cast("int").alias("bin"))
    .count()
    .collect()
)

# Bin centres for the x-axis and the matching counts (empty bins stay at 0)
bins = [min_fare + (i + 0.5) * bin_width for i in range(num_bins)]
counts = [0] * num_bins
for row in fare_bin_rows:
    counts[row["bin"]] = row["count"]

# Create histogram with matplotlib
fig, ax = plt.subplots(figsize=(10, 6))
# Fall back to a visible bar width when all fares are identical (bin_width == 0)
ax.bar(bins, counts, width=bin_width * 0.8 if bin_width > 0 else 1.0)
ax.set_xlabel("Total Fare (USD)")
ax.set_ylabel("Frequency")
ax.set_title("Histogram of Total Fare")
ax.grid(alpha=0.3)

# Display histogram
plt.tight_layout()
plt.show()

In [None]:
# Create fare comparison for direct vs. connecting flights
if 'isNonStop' in flight_df.columns and 'baseFare' in flight_df.columns:

    # Group data by flight type (ordered so the bar order is reproducible)
    flight_types = flight_df.groupBy("isNonStop") \
                         .agg(
                             F.count("*").alias("count"),
                             F.avg("baseFare").alias("avgFare"),
                             F.stddev("baseFare").alias("stddevFare")
                         ) \
                         .orderBy("isNonStop") \
                         .collect()

    # Extract data
    labels = [f"{'Direct' if row['isNonStop'] else 'Connecting'} Flights" for row in flight_types]
    avg_fares = [row["avgFare"] for row in flight_types]
    # stddev is None for a group with a single row; draw no error bar in that case
    std_fares = [row["stddevFare"] if row["stddevFare"] is not None else 0 for row in flight_types]

    # Create bar chart with error bars
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(labels, avg_fares, yerr=std_fares, capsize=10, alpha=0.7)
    ax.set_title('Average Fare Comparison: Direct vs. Connecting Flights')
    ax.set_ylabel('Average Base Fare (USD)')
    ax.grid(axis='y', alpha=0.3)

    # Add value labels on top of bars
    for i, v in enumerate(avg_fares):
        ax.text(i, v + 5, f"${v:.2f}", ha='center')

    plt.tight_layout()
    plt.show()

    # Base fare distribution per seatsRemaining value
    if 'seatsRemaining' in flight_df.columns:
        # Five-number summary per group, computed in Spark. seat_stats was
        # previously undefined, so this part raised a NameError.
        seat_stats = flight_df \
            .filter(F.col("seatsRemaining").isNotNull() & F.col("baseFare").isNotNull()) \
            .groupBy(F.col("seatsRemaining").alias("seatGroup")) \
            .agg(
                F.min("baseFare").alias("minFare"),
                F.max("baseFare").alias("maxFare"),
                F.percentile_approx("baseFare", [0.25, 0.5, 0.75], 10000).alias("percentiles")
            ) \
            .orderBy("seatGroup") \
            .collect()

        # Build precomputed box statistics: whiskers span min..max, box spans q1..q3
        box_stats = []
        for row in seat_stats:
            q1, median, q3 = row["percentiles"]
            box_stats.append({
                "label": str(row["seatGroup"]),
                "whislo": row["minFare"],
                "q1": q1,
                "med": median,
                "q3": q3,
                "whishi": row["maxFare"],
                "fliers": [],
            })

        # bxp() draws boxes from precomputed statistics. boxplot() would
        # instead treat the five summary numbers as raw observations.
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.bxp(box_stats, showfliers=False)
        ax.set_title('Base Fare Distribution by Seats Remaining')
        ax.set_xlabel('Seats Remaining')
        ax.set_ylabel('Base Fare (USD)')
        ax.grid(axis='y', alpha=0.3)

        plt.tight_layout()
        plt.show()

# Line chart of average base fare vs. booking lead time (days before the flight).
# Lead time is derived as datediff(flightDate, searchDate), the same definition
# the feature-engineering cell uses for 'daysUntilDeparture'. 'elapsedDays' is a
# separate column there, so it is not used as the lead time here.
if all(name in flight_df.columns for name in ('searchDate', 'flightDate', 'baseFare')):

    lead_time_df = flight_df \
        .withColumn("daysUntilDeparture", F.datediff(F.col("flightDate"), F.col("searchDate"))) \
        .filter(F.col("daysUntilDeparture").isNotNull() & F.col("baseFare").isNotNull())

    # Group by days before flight
    lead_time_stats = lead_time_df.groupBy("daysUntilDeparture") \
                                  .agg(F.avg("baseFare").alias("avgFare")) \
                                  .orderBy("daysUntilDeparture") \
                                  .collect()

    # Extract data
    lead_days = [row["daysUntilDeparture"] for row in lead_time_stats]
    avg_fares = [row["avgFare"] for row in lead_time_stats]

    # Create line chart
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(lead_days, avg_fares, marker='o', linestyle='-', linewidth=2)
    ax.set_title('Average Fare vs. Days Before Flight')
    ax.set_xlabel('Days Before Flight')
    ax.set_ylabel('Average Base Fare (USD)')
    ax.grid(True, alpha=0.3)

    # Add a linear trend line (needs at least two distinct lead times)
    if len(lead_days) > 1:
        z = np.polyfit(lead_days, avg_fares, 1)
        p = np.poly1d(z)
        ax.plot(lead_days, p(lead_days), "r--", linewidth=2)

        # Add trend direction annotation; z[0] is the slope in USD per day of lead time
        trend = "increases" if z[0] > 0 else "decreases"
        ax.text(0.05, 0.95, f"Trend: Price {trend} by ${abs(z[0]):.2f} per extra day booked in advance",
                transform=ax.transAxes, fontsize=12, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.tight_layout()
    plt.show()

In [None]:
# Select only numeric columns
numeric_cols = [name for name, dtype in flight_df.dtypes if dtype in ('int', 'double')]

if numeric_cols:
    # Size the grid from the number of columns: a fixed 2x3 grid raises as soon
    # as there are more than 6 numeric columns.
    n_plot_cols = 3
    n_plot_rows = int(np.ceil(len(numeric_cols) / n_plot_cols))
    fig, axes = plt.subplots(
        n_plot_rows, n_plot_cols, figsize=(12, 4 * n_plot_rows), squeeze=False
    )
    flat_axes = axes.flatten()

    for ax, col_name in zip(flat_axes, numeric_cols):
        ax.set_title(f"Distribution of {col_name}")

        # Keep only usable values (no nulls / NaNs)
        values_rdd = (
            flight_df.select(col_name)
            .filter(F.col(col_name).isNotNull() & ~F.isnan(F.col(col_name)))
            .rdd.flatMap(lambda row: row)
        )
        if values_rdd.isEmpty():
            continue  # nothing to plot for this column

        # Let Spark compute the 30-bin histogram so only the bin edges and
        # counts reach the driver, not every value of the column.
        edges, counts = values_rdd.histogram(30)
        widths = np.diff(edges)
        if not widths.any():
            widths = 1.0  # all values identical: single bar with a visible width

        ax.bar(edges[:-1], counts, width=widths, align='edge', alpha=0.7, edgecolor='black')

    # Hide the unused axes of the last grid row
    for ax in flat_axes[len(numeric_cols):]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()

In [None]:
# Select only numeric columns
numeric_cols = [name for name, dtype in flight_df.dtypes if dtype in ('int', 'double')]

# Work on a separate DataFrame (flight_df itself keeps its original dtypes):
# - cast every numeric column to double for the feature vector
# - drop rows with a missing value in any numeric column, because
#   VectorAssembler raises on nulls by default and the data is known to
#   contain missing totalTravelDistance values
numeric_df = flight_df.select(
    [F.col(name).cast("double").alias(name) for name in numeric_cols]
).na.drop()

# Assemble numeric columns into a feature vector
vector_col = "features"
assembler = VectorAssembler(inputCols=numeric_cols, outputCol=vector_col)
df_vector = assembler.transform(numeric_df).select(vector_col)

# Compute the correlation matrix over the complete rows
correlation_matrix = Correlation.corr(df_vector, vector_col).head()[0].toArray()

# Plot heatmap directly
plt.figure(figsize=(10, 8))
sns.heatmap(correlation_matrix, annot=True, cmap="coolwarm", fmt=".2f",
            xticklabels=numeric_cols, yticklabels=numeric_cols)

plt.title("Correlation Matrix")
plt.show()

### Feature selection after correlation matrix is done

### Use PCA? 

### Machine Learning Models below

### Linear Regression


**TODO:** Include a summary of the results.

### Polynomial Regression

### Random Forest

### Gradient Boosting (XGBoost)

### Neural Network

### Validation of models using K-Fold, time-based split