# PySpark Tutorial: Data Processing and Analysis

## Overview
This notebook covers fundamental PySpark concepts and operations for distributed data processing. PySpark is the Python API for Apache Spark, a powerful distributed computing framework designed for processing large-scale datasets.

### Key Topics Covered:
1. **SparkSession Initialization** - Entry point for Spark functionality
2. **Data Loading** - Reading Parquet files
3. **DataFrame Inspection** - Schema, columns, and statistics
4. **Data Selection** - Choosing specific columns
5. **Data Sorting** - Ordering data by criteria
6. **Data Filtering** - Subsetting data based on conditions
7. **Data Cleaning** - Handling missing values
8. **Feature Engineering** - Creating new columns from existing data
9. **Column Renaming and Dropping** - Restructuring DataFrame columns

---
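[Download dataset](https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page) (NYC TLC Trip Record Data; this notebook uses the Yellow Taxi file for January 2025, `yellow_tripdata_2025-01.parquet`)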

In [69]:
%pip install pyspark



In [70]:
# Initialize SparkSession - the entry point for Spark functionality
from pyspark.sql import SparkSession

# Get the active SparkSession if one exists; otherwise create one named 'SparkApp'
spark = SparkSession.builder.appName("SparkApp").getOrCreate()
spark

In [71]:
# Load data from a Parquet file into a Spark DataFrame
# Parquet is a columnar storage format optimized for analytical queries
spark_df = spark.read.parquet("/content/yellow_tripdata_2025-01.parquet")

In [72]:
# Count the total number of rows in the DataFrame
# Action: This triggers Spark to compute and return the result
spark_df.count()

3475226

In [73]:
# Display the first 20 rows of the DataFrame
# This is useful for a quick visual inspection of the data
spark_df.show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|       1| 2025-01-01 00:18:38|  2025-01-01 00:26:59|              1|          1.6|         1|                 N|         229|    

## 1. SparkSession Initialization

**What is SparkSession?**
- The entry point for all Spark functionality in PySpark
- Represents the connection to a Spark cluster
- Enables you to read data, run SQL queries, and perform distributed computations

**Key Properties:**
- `appName`: Name of your Spark application (shown in the Spark UI and by cluster managers)
- `getOrCreate()`: Returns the active session if one already exists, otherwise creates a new one
- The same session can be reused throughout your script or notebook

**Why it's important:**
All data processing operations in Spark depend on an active SparkSession instance.

## 2. Reading Data with Parquet Files

**What is Parquet?**
- An open-source columnar storage format maintained by the Apache Software Foundation
- Optimized for analytical queries and data warehousing
- Stores data compressed and encoded column by column, reducing storage space
- Typically faster than row-based formats like CSV for analytical operations

**Why use Parquet?**
- Strong compression (files are often much smaller than the equivalent CSV; the exact savings depend on the data)
- Efficient column-oriented queries (only the needed columns are read)
- Preserves data types and schema information
- Ideal for big data analytics

**The `spark.read.parquet()` method:**
- Reads Parquet files into a Spark DataFrame
- Takes the schema from the file's own metadata, so no inference pass is needed (unlike CSV)
- Returns a distributed DataFrame object

## 3. DataFrame Row Counting

**The `count()` Method:**
- Returns the total number of rows in the DataFrame
- This is an **Action** in Spark - it triggers actual computation
- Useful for understanding dataset size and volume

**Actions vs Transformations:**
- **Actions**: Execute computations and return results (count, show, collect, write)
- **Transformations**: Create new DataFrames from existing ones (select, filter, sort); they are evaluated lazily, so nothing runs until an action is called

**Performance Note:**
- For very large datasets, counting can take significant time
- Spark must read and process all partitions to get an accurate count

## 4. Viewing DataFrame Data

**The `show()` Method:**
- Prints the first 20 rows of the DataFrame (by default)
- Renders them as a text table; the method returns `None`, not a DataFrame
- Useful for quick data inspection and verification
- As an action, it computes only as many rows as it needs to display

**Parameters:**
- `n`: Number of rows to display (default: 20)
- `truncate`: If `True` (default), strings longer than 20 characters are cut off; `False` shows full values; an integer sets a custom maximum width
- `vertical`: Display rows vertically instead of horizontally

## 5. Inspecting DataFrame Columns

**The `columns` Property:**
- Returns a Python list of all column names in the DataFrame
- Useful for understanding what fields are available
- Helps with programmatic column access and validation
- Returns column names in the order they appear in the schema

## 6. Understanding DataFrame Schema

**The `printSchema()` Method:**
- Displays the schema in a tree-like format
- Shows column names, data types, and nullable properties
- Essential for understanding data structure and types
- Helps identify potential data type mismatches

**What's in the Schema:**
- Column name
- Data type (String, Integer, Double, Boolean, etc.)
- Nullable flag (can column contain NULL values?)
- Nested structures for complex data types

## 7. Accessing Schema as an Object

**The `schema` Property:**
- Returns a `StructType` object representing the DataFrame schema
- Allows programmatic access to schema details
- Useful when you need to manipulate or inspect schema in code
- Can be serialized/deserialized for schema management

**Schema Inspection Benefits:**
- Extract field information programmatically
- Validate data types before processing
- Generate schema-aware code dynamically

## 8. Statistical Summary of Data

**The `describe()` Method:**
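- With no arguments, computes statistics for every numeric and string column (timestamp columns are skipped); for string columns, only count, min, and max are meaningful
- Returns a new DataFrame with count, mean, standard deviation, min, and max, each stored as a string
- Call `.show()` on the result to display it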

**Statistics Provided:**
- **count**: Non-null values in each column
- **mean**: Average value
- **stddev**: Standard deviation (data spread)
- **min**: Minimum value
- **max**: Maximum value

**Use Cases:**
- Quick data quality checks
- Identify outliers and data ranges
- Understand distribution of numerical data
- Detect potential data entry errors

In [74]:
# Get all column names as a list
spark_df.columns

['VendorID',
 'tpep_pickup_datetime',
 'tpep_dropoff_datetime',
 'passenger_count',
 'trip_distance',
 'RatecodeID',
 'store_and_fwd_flag',
 'PULocationID',
 'DOLocationID',
 'payment_type',
 'fare_amount',
 'extra',
 'mta_tax',
 'tip_amount',
 'tolls_amount',
 'improvement_surcharge',
 'total_amount',
 'congestion_surcharge',
 'Airport_fee',
 'cbd_congestion_fee']

In [75]:
# Print the schema of the DataFrame in a tree format
# Shows all column names, data types, and nullable properties
spark_df.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)
 |-- cbd_congestion_fee: double (nullable = true)



In [76]:
# Get the schema as a StructType object for programmatic access
spark_df.schema

StructType([StructField('VendorID', IntegerType(), True), StructField('tpep_pickup_datetime', TimestampNTZType(), True), StructField('tpep_dropoff_datetime', TimestampNTZType(), True), StructField('passenger_count', LongType(), True), StructField('trip_distance', DoubleType(), True), StructField('RatecodeID', LongType(), True), StructField('store_and_fwd_flag', StringType(), True), StructField('PULocationID', IntegerType(), True), StructField('DOLocationID', IntegerType(), True), StructField('payment_type', LongType(), True), StructField('fare_amount', DoubleType(), True), StructField('extra', DoubleType(), True), StructField('mta_tax', DoubleType(), True), StructField('tip_amount', DoubleType(), True), StructField('tolls_amount', DoubleType(), True), StructField('improvement_surcharge', DoubleType(), True), StructField('total_amount', DoubleType(), True), StructField('congestion_surcharge', DoubleType(), True), StructField('Airport_fee', DoubleType(), True), StructField('cbd_congestio

In [77]:
# Display descriptive statistics (count, mean, stddev, min, max) for numerical columns
spark_df.describe().show()

+-------+-------------------+------------------+------------------+------------------+------------------+-----------------+------------------+------------------+------------------+------------------+-------------------+------------------+------------------+---------------------+------------------+--------------------+-------------------+-------------------+
|summary|           VendorID|   passenger_count|     trip_distance|        RatecodeID|store_and_fwd_flag|     PULocationID|      DOLocationID|      payment_type|       fare_amount|             extra|            mta_tax|        tip_amount|      tolls_amount|improvement_surcharge|      total_amount|congestion_surcharge|        Airport_fee| cbd_congestion_fee|
+-------+-------------------+------------------+------------------+------------------+------------------+-----------------+------------------+------------------+------------------+------------------+-------------------+------------------+------------------+---------------------+-

---

## Data Transformation Operations

### 9. Column Selection (Projection)

**The `select()` Method:**
- Selects specific columns from a DataFrame
- Returns a new DataFrame with only chosen columns
- Reduces the data Spark has to read and process; with columnar sources such as Parquet, only the selected columns are read from disk (column pruning)
- Can be used with column names or Column objects

**Syntax:**
- `df.select(['col1', 'col2', ...])` - using list of strings
- `df.select('col1', 'col2')` - using variable arguments
- `df.select(col('col1'), col('col2'))` - using Column objects

**Benefits:**
- Improves query performance (only processes needed columns)
- Simplifies data by removing irrelevant fields
- Foundation for feature engineering

In [78]:
# Select only specific columns: fare_amount and passenger_count
# This creates a new DataFrame with just these two columns
spark_df.select(['fare_amount','passenger_count']).show()

+-----------+---------------+
|fare_amount|passenger_count|
+-----------+---------------+
|       10.0|              1|
|        5.1|              1|
|        5.1|              1|
|        7.2|              3|
|        5.8|              3|
|       19.1|              2|
|        4.4|              0|
|       12.1|              0|
|       19.1|              0|
|       11.4|              1|
|       11.4|              1|
|        5.8|              1|
|       14.2|              3|
|        7.9|              1|
|       26.1|              1|
|       17.7|              3|
|       16.3|              1|
|       -7.2|              1|
|        7.2|              1|
|       15.6|              2|
+-----------+---------------+
only showing top 20 rows


In [79]:
# Sort the DataFrame by a single column in ascending order (default)
# The sort() method is a transformation that creates a new sorted DataFrame
spark_df.sort('fare_amount').show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|       2| 2025-01-07 19:12:25|  2025-01-07 19:14:04|              1|          0.1|         5|                 N|         226|    

---

### 10. Sorting Data

**The `sort()` Method:**
- Arranges rows based on one or more column values
- Returns a new sorted DataFrame
- Default order is ascending (smallest to largest)
- Can sort by multiple columns with different orders

**Parameters:**
- Column name or list of column names
- `ascending`: Boolean or list of booleans for sort order
- Multiple sort criteria applied in order

**Use Cases:**
- Ranking data by values
- Finding top N or bottom N records
- Organizing data for presentation
- Preparing data for time-series analysis

In [80]:
# Sort by multiple columns in descending order
# Sorts first by fare_amount (descending), then by passenger_count (descending)
spark_df.sort(['fare_amount',"passenger_count"], ascending = [False,False]).show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|       1| 2025-01-20 12:07:18|  2025-01-20 12:12:42|              1|          1.6|         1|                 N|         138|    

---

### 11. Filtering Data (Selection)

**The `filter()` Method:**
- Selects rows that satisfy a condition
- Returns a new DataFrame with only matching rows
- Critical for data cleaning and subsetting
- Conditions can use SQL syntax or PySpark expressions

**Filter Syntax Options:**
1. **SQL String**: `df.filter('column > 100')`
2. **PySpark Expression**: `df.filter(df['column'] > 100)`
3. **Column Object**: `df.filter(col('column') > 100)`

**Combining Conditions:**
- AND operator: `&` (not `and`)
- OR operator: `|` (not `or`)
- NOT operator: `~` (not `not`)
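- Wrap each condition in parentheses when combining them, because Python's `&`, `|`, and `~` bind more tightly than comparison operators such as `>` and `==`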

**Performance Considerations:**
- Filter early to reduce data size
- More selective filters improve performance
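- Spark's Catalyst optimizer pushes filters down to the data source where it can; with Parquet, row groups whose min/max statistics rule out a match can be skipped entirely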

In [81]:
# Filter for rows where Airport_fee is greater than 0
# Uses SQL string syntax for the filter condition
spark_df.filter('Airport_fee >0').show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|       1| 2025-01-01 00:51:41|  2025-01-01 01:06:26|              1|          7.2|         1|                 N|         132|    

In [82]:
# Filter for rides where pickup time is after a specific datetime
# Uses PySpark expression syntax with column reference
spark_df.filter(spark_df['tpep_pickup_datetime']>'2025-01-01 00:11:59').show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|       1| 2025-01-01 00:18:38|  2025-01-01 00:26:59|              1|          1.6|         1|                 N|         229|    

In [83]:
# Filter for rows that meet BOTH conditions using AND operator (&)
# Airport_fee must be greater than 0 AND pickup time must be after specified time
# Note: Must use & (bitwise AND) not 'and', and wrap conditions in parentheses
spark_df.filter((spark_df['Airport_fee']>0) & (spark_df['tpep_pickup_datetime']> '2025-01-01 00:11:59')).show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|       1| 2025-01-01 00:51:41|  2025-01-01 01:06:26|              1|          7.2|         1|                 N|         132|    

In [84]:
# Combine multiple operations: filter, select, and display
# Step 1: Filter for rides with more than 2 passengers
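# Step 2: Select VendorID, passenger_count, and total_amount. 'VendorId' below still resolves
# because Spark matches column names case-insensitively by default (spark.sql.caseSensitive=false)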
# Step 3: Display the results
spark_df.filter(spark_df['passenger_count']>2)\
.select(['VendorId','passenger_count','total_amount'])\
.show()

+--------+---------------+------------+
|VendorId|passenger_count|total_amount|
+--------+---------------+------------+
|       2|              3|         9.7|
|       2|              3|         8.3|
|       2|              3|        19.2|
|       1|              3|        22.7|
|       1|              3|       15.45|
|       2|              3|       30.13|
|       2|              4|       50.76|
|       2|              4|       29.76|
|       2|              4|       16.13|
|       1|              4|       37.95|
|       2|              9|      111.32|
|       2|              4|        23.4|
|       1|              4|        12.2|
|       2|              4|       36.48|
|       1|              3|       12.29|
|       2|              4|       28.08|
|       1|              4|        16.4|
|       2|              3|       56.38|
|       2|              3|        34.8|
|       1|              4|       17.15|
+--------+---------------+------------+
only showing top 20 rows


In [85]:
# Challenge: Write a query combining sort, select, and filter
# Find non-airport rides with exactly 1 passenger, sort and show relevant columns
# Step 1: Filter for Airport_fee == 0 (non-airport) AND passenger_count == 1
# Step 2: Sort by Airport_fee (after the filter every row has Airport_fee == 0, so this sort only demonstrates chaining)
# Step 3: Select only passenger_count and Airport_fee columns
spark_df.filter((spark_df['Airport_fee']==0) & (spark_df['passenger_count']==1))\
.sort('Airport_fee')\
.select('passenger_count','Airport_fee')\
.show()

+---------------+-----------+
|passenger_count|Airport_fee|
+---------------+-----------+
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
|              1|        0.0|
+---------------+-----------+
only showing top 20 rows


In [86]:
# Challenge: Select only trip distance and total_amount columns
spark_df.select(['trip_distance', 'total_amount']).show()

+-------------+------------+
|trip_distance|total_amount|
+-------------+------------+
|          1.6|        18.0|
|          0.5|       12.12|
|          0.6|        12.1|
|         0.52|         9.7|
|         0.66|         8.3|
|         2.63|        24.1|
|          0.4|       11.75|
|          1.6|        19.1|
|          2.8|        27.1|
|         1.71|        16.4|
|         2.29|        16.4|
|         0.56|       12.96|
|         1.99|        19.2|
|          1.1|        12.9|
|          3.2|        38.9|
|          2.5|        22.7|
|          1.9|       25.55|
|         0.71|       -8.54|
|         0.71|        12.2|
|          1.2|        20.6|
+-------------+------------+
only showing top 20 rows


In [87]:
# Challenge: Sort the resulting dataframe by trip distance in descending order
# Shows longest trips first
spark_df.select(['trip_distance', 'total_amount'])\
.sort("trip_distance", ascending=False).show()

+-------------+------------+
|trip_distance|total_amount|
+-------------+------------+
|    276423.57|         5.0|
|    276099.95|       13.88|
|    222167.49|       35.94|
|    206137.99|       29.64|
|    202771.63|       14.85|
|    189687.43|       17.45|
|    181139.99|       11.08|
|    168079.57|        4.49|
|    167452.94|        5.45|
|    164959.95|       18.05|
|    158925.09|       14.82|
|    156037.94|       33.49|
|    143712.27|       12.51|
|    135116.83|       36.52|
|    134033.15|       22.89|
|    124083.23|        15.0|
|    121799.97|       31.88|
|    121555.16|        8.46|
|    118927.12|        3.68|
|    118435.89|       19.59|
+-------------+------------+
only showing top 20 rows


In [88]:
# Combined challenge problem: Filter, Select, Sort
# Find all non-airport solo rides and show them sorted by distance (longest first)
spark_df.filter((spark_df['Airport_fee']==0) & (spark_df['passenger_count']==1))\
.select('trip_distance', 'total_amount')\
.sort('trip_distance', ascending=False)\
.show()

+-------------+------------+
|trip_distance|total_amount|
+-------------+------------+
|      44730.3|        45.0|
|      44684.1|       54.94|
|      33588.9|        30.0|
|      11187.2|       61.94|
|      2001.95|       19.35|
|      1847.61|       43.37|
|      1472.37|       20.58|
|        265.9|      139.14|
|       255.33|     2506.71|
|       206.45|      243.32|
|        199.3|        19.0|
|       188.88|      1311.7|
|        181.9|       39.94|
|       150.11|      501.75|
|        148.3|      549.91|
|       122.77|        64.7|
|       119.66|      794.82|
|       114.25|       396.0|
|       105.24|       361.2|
|       104.21|      249.53|
+-------------+------------+
only showing top 20 rows


In [89]:
# Import necessary functions for missing value detection
from pyspark.sql.functions import col, isnull

---

## Data Cleaning and Missing Values

### 12. Detecting and Handling Missing Values

**Missing Data in DataFrames:**
- Missing values are represented as `NULL` in Spark (a floating-point `NaN` is a distinct value; detect it with `isnan()` rather than `isnull()`)
- Can skew analysis and cause errors in computations
- Must be identified and handled appropriately

**The `isnull()` Function:**
- Returns True where values are NULL
- Often combined with `filter()` to count missing values
- Can be used to create masks for data cleaning

**Common Missing Value Strategies:**
1. **Removal**: Delete rows with NULL values (loses data)
2. **Imputation**: Fill with default, mean, median, or forward-filled values
3. **Domain-specific**: Use business rules to fill values
4. **Keep separate**: Mark and analyze NULL values separately

**The `fillna()` Method:**
- Replaces NULL values with specified values
- Can use a dictionary to fill different columns differently
- Returns a new DataFrame with filled values
- Useful for imputation strategies

In [90]:
# Count the number of NULL/missing values in the 'fare_amount' column
# isnull() returns True for NULL values, count() counts them
spark_df.filter(isnull(col('fare_amount'))).count()

0

In [91]:
# Count the number of NULL/missing values in the 'passenger_count' column
spark_df.filter(isnull(col('passenger_count'))).count()

540149

In [92]:
# Fill missing values in 'passenger_count' column with default value of 1
# This assumes if passenger count is missing, there was 1 passenger
# Returns a new DataFrame with filled values
df = spark_df.fillna({'passenger_count':1})

In [93]:
# Verify that missing values in 'passenger_count' have been filled
# After fillna(), the count should be 0
df.filter(isnull(col('passenger_count'))).count()

0

In [94]:
# Import functions needed for feature engineering
# unix_timestamp(): Convert a timestamp to seconds since the Unix epoch
# round(): Spark's column-wise round; note that this import shadows Python's built-in round()
from pyspark.sql.functions import unix_timestamp, round

---

### 13. Feature Engineering - Creating New Columns

**What is Feature Engineering?**
- Process of creating new features (columns) from existing data
- Transforms raw data into meaningful features for analysis
- Critical step in data preprocessing and machine learning

**The `withColumn()` Method:**
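- Adds a new column, or replaces an existing column that has the same name
- Returns a new DataFrame with the added/modified column
- Accepts a column name and a Column expression
- Can chain multiple `withColumn()` calls; to add many columns at once, `withColumns()` (PySpark 3.3+) takes a dict mapping names to expressions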

**Common PySpark Functions:**
- `unix_timestamp()`: Converts datetime to Unix timestamp (seconds since 1970)
- `round()`: Rounds numerical values to specified decimal places
- String functions: concat, substring, length, upper, lower
- Math functions: abs, sqrt, pow, etc.
- Date functions: year, month, day, hour, minute

**Example: Trip Duration Calculation**
- Extract pickup and dropoff times
- Convert to Unix timestamps
- Calculate difference in seconds
- Convert to minutes and round

In [95]:
# Feature Engineering: Create a new column 'trip_duration' in minutes
# Step 1: Convert dropoff time to Unix timestamp (seconds)
# Step 2: Convert pickup time to Unix timestamp (seconds)
# Step 3: Calculate difference in seconds
# Step 4: Divide by 60 to convert to minutes
# Step 5: Round to 1 decimal place
# withColumn() creates a new DataFrame with the added column
df1 = df.withColumn(
    'trip_duration',
    round((unix_timestamp('tpep_dropoff_datetime') - unix_timestamp('tpep_pickup_datetime')) / 60, 1),
)
df1.select('trip_duration').show()

+-------------+
|trip_duration|
+-------------+
|          8.4|
|          2.6|
|          2.0|
|          5.6|
|          3.5|
|         20.0|
|          1.5|
|         12.4|
|         19.7|
|          9.6|
|          7.6|
|          3.4|
|         13.0|
|          6.7|
|         34.1|
|         18.3|
|         16.9|
|          5.6|
|          5.6|
|         17.0|
+-------------+
only showing top 20 rows


In [96]:
# Data Quality Check: Count negative trip durations
# Negative values indicate data issues (dropoff before pickup)
# This helps identify potential data quality problems
df1.filter(df1['trip_duration'] < 0).count()

117

### Data Quality Analysis - Negative Trip Duration Check

**What does this check tell us?**

If the count of negative `trip_duration` values is **zero**:
- ✓ All trip durations are positive or zero
- ✓ Indicates the calculation is accurate in terms of time order
- ✓ No timestamp-ordering issues detected (other checks, such as for implausible trip distances, may still be needed)
- ✓ Pickup times are correctly recorded before dropoff times

If there are **negative values**:
- ✗ Implies data quality issues
- ✗ Drop-off time recorded before pick-up time (impossible for valid trips)
- ✗ Might indicate:
  - Incorrect timestamp recording
  - System clock issues during data capture
  - Data entry errors
  - Records that need investigation and possible removal

**How to Handle Negative Durations:**
1. Filter them out for analysis (treat as invalid records)
2. Investigate the source data for systematic issues
3. Apply business logic (e.g., if duration < some threshold, mark as suspicious)
4. Consider removing outliers and anomalies

In [97]:
# Rename columns for clarity and consistency
# Step 1: Select specific columns
# Step 2: Rename them with withColumnsRenamed() (available in PySpark 3.4+)
#   - 'tpep_pickup_datetime' → 'pu_datetime' (shorter, clearer)
#   - 'tpep_dropoff_datetime' → 'do_datetime' (shorter, clearer)
#   - 'fare_amount' → 'ride-amount' (a hyphenated name needs backticks in SQL strings, e.g. df2.filter('`ride-amount` > 10'))
# This creates a new DataFrame with renamed columns
df2 = df1.select('tpep_pickup_datetime','tpep_dropoff_datetime','fare_amount')\
.withColumnsRenamed({'tpep_pickup_datetime':'pu_datetime','tpep_dropoff_datetime':'do_datetime', 'fare_amount':'ride-amount'})

df2.show()

+-------------------+-------------------+-----------+
|        pu_datetime|        do_datetime|ride-amount|
+-------------------+-------------------+-----------+
|2025-01-01 00:18:38|2025-01-01 00:26:59|       10.0|
|2025-01-01 00:32:40|2025-01-01 00:35:13|        5.1|
|2025-01-01 00:44:04|2025-01-01 00:46:01|        5.1|
|2025-01-01 00:14:27|2025-01-01 00:20:01|        7.2|
|2025-01-01 00:21:34|2025-01-01 00:25:06|        5.8|
|2025-01-01 00:48:24|2025-01-01 01:08:26|       19.1|
|2025-01-01 00:14:47|2025-01-01 00:16:15|        4.4|
|2025-01-01 00:39:27|2025-01-01 00:51:51|       12.1|
|2025-01-01 00:53:43|2025-01-01 01:13:23|       19.1|
|2025-01-01 00:00:02|2025-01-01 00:09:36|       11.4|
|2025-01-01 00:20:28|2025-01-01 00:28:04|       11.4|
|2025-01-01 00:33:58|2025-01-01 00:37:23|        5.8|
|2025-01-01 00:42:40|2025-01-01 00:55:38|       14.2|
|2025-01-01 00:30:07|2025-01-01 00:36:48|        7.9|
|2025-01-01 00:39:55|2025-01-01 01:13:59|       26.1|
|2025-01-01 00:16:54|2025-01

---

### 14. Renaming Columns

**The `withColumnsRenamed()` Method:**
- Renames one or more columns in a DataFrame (added in PySpark 3.4; on older versions, chain `withColumnRenamed(old, new)` calls)
- Takes a dictionary mapping old names to new names
- Useful for:
  - Standardizing column naming conventions
  - Simplifying long or unclear column names
  - Preparing data for downstream analysis
  - Making data more readable and accessible
- Returns a new DataFrame with renamed columns

**Use Cases:**
- Shortening verbose column names
- Converting naming conventions (snake_case to camelCase, etc.)
- Making column names more domain-friendly
- Standardizing across multiple data sources
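As an illustration of the naming-convention use case, the sketch below builds the dictionary that `withColumnsRenamed()` expects by converting each name to snake_case. The helpers `to_snake_case` and `snake_case_mapping` are hypothetical, written for this tutorial, and the regex rules are an assumption that covers names like `PULocationID` and `VendorID` rather than every possible convention.

```python
import re


def to_snake_case(name: str) -> str:
    """Convert a column name such as 'PULocationID' or 'Airport_fee' to snake_case."""
    # 'VendorID' -> 'Vendor_ID': underscore between a lowercase letter/digit and an uppercase letter
    s = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", name)
    # 'PULocation_ID' -> 'PU_Location_ID': split an acronym from the capitalized word after it
    s = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", s)
    return s.lower()


def snake_case_mapping(columns: list) -> dict:
    """Build the {old_name: new_name} dict for withColumnsRenamed(), skipping names that don't change."""
    mapping = {}
    for col_name in columns:
        new_name = to_snake_case(col_name)
        if new_name != col_name:
            mapping[col_name] = new_name
    return mapping


print(snake_case_mapping(["VendorID", "PULocationID", "trip_distance", "Airport_fee"]))
# {'VendorID': 'vendor_id', 'PULocationID': 'pu_location_id', 'Airport_fee': 'airport_fee'}
```

On a real DataFrame, something like `df.withColumnsRenamed(snake_case_mapping(df.columns))` would apply it (Spark 3.4+). Check the mapping for collisions first, i.e. two source names that convert to the same snake_case name.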

In [98]:
# Drop columns that this analysis does not need: VendorID and RatecodeID
# (Spark resolves column names case-insensitively by default, so 'RateCodeID' would also match)
df1.drop('VendorID', 'RatecodeID').show()

+--------------------+---------------------+---------------+-------------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+-------------+
|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|trip_duration|
+--------------------+---------------------+---------------+-------------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+-------------+
| 2025-01-01 00:18:38|  2025-01-01 00:26:59|              1|          1.6|                 N|         229|         237|           1|       10.0|  3.

---

## 15. Dropping Unnecessary Columns

**The `drop()` Method:**
- Removes one or more columns from a DataFrame
- Useful for cleaning up unwanted columns after joins or transformations
- Takes column name(s) as arguments
- Names that do not match any column are silently ignored, so a typo raises no error
- Returns a new DataFrame without the specified columns

**Use Cases:**
- Remove duplicate columns after joins
- Clean up intermediate columns not needed for analysis
- Reduce data size by removing non-essential fields
- Prepare data for output or downstream processing

---

## 16. Combining DataFrames: Unions & Joins

### Union - Combining Rows

**The `union()` Method:**
- Combines two DataFrames with the same schema by stacking rows
- Creates a single DataFrame from multiple sources
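- **Does NOT remove duplicates**: chain `distinct()` afterwards if you need set semantics (`unionByName()` keeps duplicates too; it differs only in matching columns by name instead of position)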
- Useful for combining data from different time periods or sources

**Syntax:**
```python
combined_df = df1.union(df2)
```

**Key Points:**
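- Both DataFrames must have the same number of columns; `union()` matches them by **position**, not by name, so use `unionByName()` when the column order may differ
- Row counts add up: `combined_df.count() == df1.count() + df2.count()`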
- Duplicates are preserved; filter or use `distinct()` if cleanup needed

### Join - Combining Columns

**The `join()` Method:**
- Combines two DataFrames based on a condition (key)
- Similar to SQL joins (INNER, LEFT, RIGHT, OUTER, CROSS)
- Creates a new DataFrame with columns from both tables

**Syntax:**
```python
joined_df = df1.join(df2, df1.key_col == df2.key_col, 'left')
```

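**Join Types:**
- **inner**: Only rows with a match in both tables (default)
- **left**: All rows from the left table; right-side columns are NULL where there is no match
- **right**: All rows from the right table; left-side columns are NULL where there is no match
- **outer** (full outer): All rows from both tables, with NULLs wherever either side has no match
- **cross**: Cartesian product (every left row paired with every right row)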

**Use Cases:**
- Enriching data with lookup tables
- Combining related datasets
- Merging data from different sources
- Creating complex analytical datasets

In [99]:
# Load February data to combine with January data
# We'll use union to combine datasets from different months
df_feb = spark.read.parquet('/content/yellow_tripdata_2025-02.parquet')

In [100]:
# Combine January and February data using union()
# This stacks rows from both DataFrames vertically
df_2025_combined = df.union(df_feb)

In [101]:
# Count the combined records
# Note: union() preserves duplicates, so row count = jan_count + feb_count
df_2025_combined.count()

7052769

In [102]:
# Load taxi zone lookup data for enrichment
# This will be joined with trip data to add location information
# Without inferSchema, every CSV column (including LocationID) is read as a string
taxi_zone_lookup = spark.read.option('header', 'true').csv('/content/taxi_zone_lookup.csv')

In [103]:
# Preview the taxi zone lookup table structure
# Contains LocationID, Borough, Zone, and service_zone information
taxi_zone_lookup.show()

+----------+-------------+--------------------+------------+
|LocationID|      Borough|                Zone|service_zone|
+----------+-------------+--------------------+------------+
|         1|          EWR|      Newark Airport|         EWR|
|         2|       Queens|         Jamaica Bay|   Boro Zone|
|         3|        Bronx|Allerton/Pelham G...|   Boro Zone|
|         4|    Manhattan|       Alphabet City| Yellow Zone|
|         5|Staten Island|       Arden Heights|   Boro Zone|
|         6|Staten Island|Arrochar/Fort Wad...|   Boro Zone|
|         7|       Queens|             Astoria|   Boro Zone|
|         8|       Queens|        Astoria Park|   Boro Zone|
|         9|       Queens|          Auburndale|   Boro Zone|
|        10|       Queens|        Baisley Park|   Boro Zone|
|        11|     Brooklyn|          Bath Beach|   Boro Zone|
|        12|    Manhattan|        Battery Park| Yellow Zone|
|        13|    Manhattan|   Battery Park City| Yellow Zone|
|        14|     Brookly

In [104]:
# Join combined trip data with zone lookup using LEFT JOIN
# This adds location information to each trip based on pickup location
df_joined = df_2025_combined.join(taxi_zone_lookup,
                                   df_2025_combined.PULocationID == taxi_zone_lookup.LocationID,
                                   'left')
df_joined.show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+----------+---------+--------------------+------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|LocationID|  Borough|                Zone|service_zone|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+----------+---------+

In [105]:
# Group by payment type and count records in each group
# Then sort by payment_type for consistent output
payment_type_counts = df_2025_combined.groupby('payment_type').count().sort('payment_type')
payment_type_counts.show()

+------------+-------+
|payment_type|  count|
+------------+-------+
|           0|1347086|
|           1|4780568|
|           2| 729910|
|           3|  45622|
|           4| 149582|
|           5|      1|
+------------+-------+



---

## 17. Grouping and Aggregation

### The `groupBy()` and `agg()` Methods

**Grouping Data:**
- `groupBy()` (alias `groupby()`) divides data into groups based on column values
- It returns a `GroupedData` object, which must be followed by an aggregation to get a DataFrame back
- The result has one row per group with the aggregated values, not the raw grouped rows

**Common Aggregation Functions:**
- `count()` - Number of records per group
- `sum()` - Total value per group
- `avg()` - Average value per group
- `min()` / `max()` - Minimum/Maximum values
- `stddev()` - Standard deviation (a column function, used inside `agg()`)
- `agg()` - Apply multiple aggregations at once

**Use Cases:**
- Summary statistics by category
- Identifying patterns in grouped data
- Business metrics by dimension (e.g., sales by region)
- Data validation (e.g., record counts per day)

In [106]:
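# Calculate the average total_amount (the total charged to the passenger, not just the fare) by payment type
# Without a sort, the order of the groups in the output is not guaranteed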
payment_type_avg_fare = df_2025_combined.groupby('payment_type').avg('total_amount')
payment_type_avg_fare.show()

+------------+------------------+
|payment_type| avg(total_amount)|
+------------+------------------+
|           5|               0.0|
|           1| 28.01570129324061|
|           3| 9.396843628074157|
|           2|  21.5920279349519|
|           4| 6.355892152799148|
|           0|20.401269013263903|
+------------+------------------+



In [107]:
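# Import the avg() column function so the result column can be given a custom name
from pyspark.sql.functions import avg

# Calculate the average total_amount and alias the result column
# agg() accepts column expressions, so .alias() can rename the output
# (note: 'avg_fare_amount' actually holds the average total_amount, not fare_amount)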
payment_type_avg_fare_named = df_2025_combined.groupby('payment_type').agg(avg('total_amount').alias('avg_fare_amount'))
payment_type_avg_fare_named.show()

+------------+------------------+
|payment_type|   avg_fare_amount|
+------------+------------------+
|           5|               0.0|
|           1| 28.01570129324061|
|           3| 9.396843628074157|
|           2|  21.5920279349519|
|           4| 6.355892152799148|
|           0|20.401269013263903|
+------------+------------------+



In [108]:
# Calculate and persist average fare by payment type
# Aggregate by payment_type, calculate average, and save results
avg_fare_by_payment = df_2025_combined.groupBy('payment_type').agg(avg('total_amount')).sort('payment_type')

# Write aggregated results to CSV for reporting or further analysis
# Spark writes a directory of part-*.csv files (one per partition), not a single CSV file
avg_fare_by_payment.write.csv('/content/avg_fare_by_payment', header=True, mode='overwrite')

---

## 18. Challenge Exercise: DataFrame Operations

**Objective:** Combine multiple PySpark concepts to create a comprehensive data pipeline

**Instructions:**

1. **Load Data**: Create two DataFrames (df_jan and df_feb) from January and February taxi data
2. **Union**: Combine them into df_2025_combined using union()
3. **Transform Columns**: Select relevant columns and rename for clarity:
   - `tpep_pickup_datetime` → `pu_datetime`
   - `tpep_dropoff_datetime` → `do_datetime`
   - `PULocationID` → `pu_location_id`
   - `DOLocationID` → `do_location_id`
   - `Airport_fee` → `airport_fee`
   

4. **Join**: Load the taxi zone lookup and LEFT JOIN it on `pu_location_id` to add pickup borough information
5. **Clean**: Drop the duplicate key and unneeded lookup columns (LocationID, Zone, service_zone)
6. **Rename**: Rename 'Borough' to 'pu_boro'
7. **Output**: Display the result and save to CSV

**Skills Practiced:**
- Reading multiple data sources
- Union operations
- Column selection and renaming
- Join operations
- Data cleanup
- CSV output

In [109]:
# Challenge Solution: Complete data pipeline
# Step 1: Load January and February data
df_jan = spark.read.parquet('/content/yellow_tripdata_2025-01.parquet')
df_feb = spark.read.parquet('/content/yellow_tripdata_2025-02.parquet')

# Step 2: Combine using union
df_2025_combined = df_jan.union(df_feb)

# Step 3: Select and rename columns for clarity and consistency
df_2025_combined = df_2025_combined\
                      .select('tpep_pickup_datetime','tpep_dropoff_datetime', \
                              'PULocationID', 'DOLocationID', 'passenger_count',\
                              'fare_amount', 'Airport_fee', 'total_amount', 'payment_type','trip_distance','VendorID')\
                      .withColumnsRenamed({'tpep_pickup_datetime':'pu_datetime',\
                              'tpep_dropoff_datetime':'do_datetime',\
                              'PULocationID':'pu_location_id',\
                              'DOLocationID':'do_location_id',\
                              'Airport_fee': 'airport_fee'})

# Step 4: Load zone lookup and join for location enrichment
taxi_zones = spark.read.option('header','true').csv('/content/taxi_zone_lookup.csv')

# Left join on the PICKUP location, since Borough is renamed to pu_boro below
df_2025_combined = df_2025_combined.join(taxi_zones,
                                          taxi_zones.LocationID == df_2025_combined.pu_location_id,
                                          'left')

# Step 5 & 6: Drop superfluous lookup columns and rename Borough
df_2025_combined = df_2025_combined.drop('LocationID','Zone','service_zone')\
                                     .withColumnsRenamed({'Borough':'pu_boro'})

# Step 7: Write results to CSV and display
df_2025_combined.write.csv("/content/2025_combined", header=True, mode='overwrite')
df_2025_combined.show()

+-------------------+-------------------+--------------+--------------+---------------+-----------+-----------+------------+------------+-------------+--------+---------+
|        pu_datetime|        do_datetime|pu_location_id|do_location_id|passenger_count|fare_amount|airport_fee|total_amount|payment_type|trip_distance|VendorID|  pu_boro|
+-------------------+-------------------+--------------+--------------+---------------+-----------+-----------+------------+------------+-------------+--------+---------+
|2025-01-01 00:18:38|2025-01-01 00:26:59|           229|           237|              1|       10.0|        0.0|        18.0|           1|          1.6|       1|Manhattan|
|2025-01-01 00:32:40|2025-01-01 00:35:13|           236|           237|              1|        5.1|        0.0|       12.12|           1|          0.5|       1|Manhattan|
|2025-01-01 00:44:04|2025-01-01 00:46:01|           141|           141|              1|        5.1|        0.0|        12.1|           1|        

---

## 19. PySpark SQL - SQL Interface for DataFrames

**What is Spark SQL?**
- Enables writing SQL queries directly on Spark DataFrames
- Combines the power of SQL with distributed computing
- Familiar syntax for SQL users
- Shares the Catalyst optimizer with the DataFrame API, so equivalent SQL and DataFrame code are optimized the same way

**How to Use:**
1. Create a temporary view from a DataFrame: `df.createOrReplaceTempView('table_name')`
2. Query using SQL: `spark.sql('SELECT ...')`
3. Can combine with DataFrame operations

**Benefits:**
- Easy for SQL-familiar users
- Complex queries can be more readable
- No performance penalty for choosing SQL over the DataFrame API (both go through Catalyst)
- Join DataFrames using SQL syntax

In [110]:
# Load taxi data for SQL operations
taxi_df = spark.read.parquet('/content/yellow_tripdata_2025-01.parquet')

In [111]:
# Register the DataFrame as a temporary SQL view
# This allows SQL queries to be run on the data
taxi_df.createOrReplaceTempView('taxi')

In [112]:
# Simple SQL query: Select all trips with high fare amounts
# This demonstrates basic WHERE clause filtering in SQL
spark.sql('SELECT * FROM taxi WHERE total_amount > 50').show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|       2| 2025-01-01 00:15:41|  2025-01-01 01:03:03|              4|         3.05|         1|                 N|         114|    

In [113]:
# Combining SQL queries with DataFrame operations
# The SQL query does the initial filtering; DataFrame methods then refine the result
high_fare_trips = spark.sql('SELECT * FROM taxi WHERE total_amount > 50')\
.filter('passenger_count > 2')\
.select('payment_type','passenger_count','total_amount')

# show() only prints and returns None, so call it separately to keep the DataFrame in high_fare_trips
high_fare_trips.show()

+------------+---------------+------------+
|payment_type|passenger_count|total_amount|
+------------+---------------+------------+
|           1|              4|       50.76|
|           2|              9|      111.32|
|           1|              3|       56.38|
|           1|              4|        58.3|
|           1|              3|       51.55|
|           1|              4|       91.19|
|           2|              3|        59.1|
|           1|              4|       61.27|
|           3|              4|      123.44|
|           1|              3|       59.09|
|           2|              3|       192.4|
|           1|              4|       62.52|
|           1|              3|      151.35|
|           1|              3|       88.92|
|           1|              4|       100.0|
|           1|              4|      115.05|
|           2|              4|       80.94|
|           1|              3|        58.1|
|           1|              3|       62.49|
|           1|              3|  

In [114]:
# More complex SQL query with multiple conditions
# Using multi-line SQL for better readability
complex_query = """
SELECT payment_type, passenger_count, total_amount
FROM taxi
WHERE
  total_amount > 50
  AND
  passenger_count > 2
"""
spark.sql(complex_query).show()

+------------+---------------+------------+
|payment_type|passenger_count|total_amount|
+------------+---------------+------------+
|           1|              4|       50.76|
|           2|              9|      111.32|
|           1|              3|       56.38|
|           1|              4|        58.3|
|           1|              3|       51.55|
|           1|              4|       91.19|
|           2|              3|        59.1|
|           1|              4|       61.27|
|           3|              4|      123.44|
|           1|              3|       59.09|
|           2|              3|       192.4|
|           1|              4|       62.52|
|           1|              3|      151.35|
|           1|              3|       88.92|
|           1|              4|       100.0|
|           1|              4|      115.05|
|           2|              4|       80.94|
|           1|              3|        58.1|
|           1|              3|       62.49|
|           1|              3|  

---

## 20. Challenge Exercise: SQL and Aggregation

**Objective:** Combine SQL queries with aggregation and joins to answer business questions

**Instructions:**

1. **Load Data**: Load January taxi data and register as a temp view called 'taxi'
2. **Load Lookup**: Load taxi zone lookup CSV and register as 'taxi_lookup'
3. **SQL Join**: Use SQL to LEFT JOIN taxi data with zone lookup on DOLocationID
   - Select: DOLocationID, Borough, total_amount
   - Assign result to `joined_df`
4. **Aggregation**: Group by Borough and calculate average total_amount
   - Alias the average column as `avg_amount`
5. **Output**: Display results and save to CSV

**Business Question:** What is the average total amount (`total_amount`) per trip, by dropoff borough?

**Skills Practiced:**
- Creating temporary SQL views
- SQL JOIN operations
- DataFrame grouping and aggregation
- Aliasing columns for clarity
- CSV export

In [115]:
# Challenge Solution: SQL Join + Aggregation
from pyspark.sql.functions import avg

# Step 1: Load and register taxi data
taxi_df = spark.read.parquet('/content/yellow_tripdata_2025-01.parquet')
taxi_df.createOrReplaceTempView('taxi')

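# Step 2: Load and register zone lookup
# (no mode argument: on a reader, mode is a parse mode such as PERMISSIVE or FAILFAST, not a save mode like 'overwrite')
taxi_zone_lookup = spark.read.csv('/content/taxi_zone_lookup.csv', header=True)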
taxi_zone_lookup.createOrReplaceTempView('taxi_lookup')

# Step 3: SQL JOIN query - combine trip data with zone information
join_query = """
 SELECT DOLocationID, Borough, total_amount
 FROM taxi
 LEFT JOIN taxi_lookup
 ON taxi.DOLocationID = taxi_lookup.LocationID
 """
joined_df = spark.sql(join_query)

# Step 4: Group by Borough and calculate average fare amount
avg_fare_by_borough = joined_df.groupBy('Borough').agg(avg('total_amount').alias('avg_amount'))

In [116]:
# Step 5: Display results
# Shows average fare amount by borough
avg_fare_by_borough.show()

+-------------+------------------+
|      Borough|        avg_amount|
+-------------+------------------+
|       Queens| 51.61163036538719|
|          EWR|123.16705659828357|
|      Unknown|25.926336005344027|
|     Brooklyn| 42.40119152829651|
|Staten Island| 88.12902995720401|
|          N/A|107.62149760052951|
|    Manhattan|22.818652849443495|
|        Bronx| 42.89454769447338|
+-------------+------------------+



In [117]:
# Step 5 (continued): Export results to CSV for reporting and analysis
avg_fare_by_borough.write.option('header', True).mode('overwrite').csv('/content/avg_fare_by_borough')

---

## 21. Production Environment Requirements

When deploying PySpark applications in production, consider:

### Core Requirements

**Scalability**
- Horizontal scaling with cluster management
- Load balancing across nodes
- Efficient resource allocation

**Reliability**
- Fault tolerance and recovery mechanisms
- Data consistency and integrity
- Error handling and logging

**Security**
- Access control and authentication
- Data encryption (in transit and at rest)
- Audit trails and monitoring

### Infrastructure Components

- **Data Sources**: HDFS, S3, databases, data lakes
- **Distributed Storage**: For persistent data, checkpoints, and logs
- **Cluster Management**: Spark standalone, YARN, or Kubernetes for orchestration (Mesos support is deprecated since Spark 3.2)
- **Job Scheduling**: Airflow, Oozie, or cloud-native schedulers
- **Consistent Environments**: Docker or other container images for reproducible dependencies
- **Monitoring & Logging**: Tools like Prometheus, ELK Stack, CloudWatch
- **Security & Access Control**: Role-based access, encryption, VPNs

---

## 22. Advanced Topic: Writing Data with Partitioning

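**What is Partitioning?**
- `DataFrameWriter.partitionBy()` writes one subdirectory per distinct value of the partition column(s), e.g. `payment_type=1/`
- When a query filters on a partition column, Spark reads only the matching directories (partition pruning)
- The layout persists on disk, so every later read can benefit from it

**Partitioning Strategies:**
1. **Single Column**: Partition by date, region, type
2. **Multiple Columns**: Partition by year, month, day for hierarchical structure
3. **Range Partitioning**: By numeric ranges; in Spark this is in-memory partitioning via `repartitionByRange()`, not a directory layout
4. **Hash Partitioning**: Distribute rows by a hash of a column, e.g. `repartition(n, col)` in memory or `bucketBy()` when saving a table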

**Benefits:**
- Faster queries that filter on partition columns, because less data is scanned
- Easier data management and retention (e.g., deleting one month's directory)
- Better parallelization

**Trade-offs:**
- Many small files, especially with high-cardinality partition columns
- Pruning only helps queries that filter on the partition columns
- Uneven partitions can cause skew

In [118]:
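# Example: Write data partitioned by payment type
# Spark creates one payment_type=<value>/ directory per distinct value, which speeds up queries that filter on payment_type
# Note: the listing printed below is hardcoded for illustration; this data also contains payment types 0 and 5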

df_2025_combined.write \
    .partitionBy('payment_type') \
    .mode('overwrite') \
    .parquet('/content/taxi_data_partitioned_by_payment')

print("Data written to: /content/taxi_data_partitioned_by_payment")
print("Directory structure:")
print("  /payment_type=1/")
print("  /payment_type=2/")
print("  /payment_type=3/")
print("  /payment_type=4/")

Data written to: /content/taxi_data_partitioned_by_payment
Directory structure:
  /payment_type=1/
  /payment_type=2/
  /payment_type=3/
  /payment_type=4/
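Since the listing above is printed by hand, a small helper can read the actual `column=value` directories instead. The sketch assumes the output sits on the local filesystem (as with `/content/...` in Colab); `list_partition_values` is a hypothetical helper, and S3 or HDFS paths would need the corresponding filesystem client. The demo builds a throwaway directory that mimics Spark's layout.

```python
import os
import tempfile


def list_partition_values(root: str, column: str) -> list:
    """Return the values of `column` found in Hive-style `column=value` subdirectories of root."""
    prefix = f"{column}="
    return sorted(
        name[len(prefix):]
        for name in os.listdir(root)
        if name.startswith(prefix) and os.path.isdir(os.path.join(root, name))
    )


# Demo on a throwaway directory that mimics a partitioned Parquet output
demo_root = tempfile.mkdtemp()
for value in ["1", "0", "5"]:
    os.makedirs(os.path.join(demo_root, f"payment_type={value}"))
open(os.path.join(demo_root, "_SUCCESS"), "w").close()  # marker file Spark writes on success

print(list_partition_values(demo_root, "payment_type"))  # ['0', '1', '5']
```

On the notebook's output this would be `list_partition_values('/content/taxi_data_partitioned_by_payment', 'payment_type')`. Values come back as strings, and rows whose partition value is NULL are stored under `payment_type=__HIVE_DEFAULT_PARTITION__`.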


In [119]:
# Example: Multi-level partitioning by payment type and trip month
# Creates a hierarchical layout payment_type=<value>/trip_month=<value>/ (the listing printed below is illustrative)

from pyspark.sql.functions import month

df_with_month = df_2025_combined.withColumn('trip_month', month('pu_datetime'))

df_with_month.write \
    .partitionBy('payment_type', 'trip_month') \
    .mode('overwrite') \
    .parquet('/content/taxi_data_partitioned_multi')

print("Multi-level partitioned data written!")
print("Directory structure (example):")
print("  /payment_type=1/trip_month=1/")
print("  /payment_type=1/trip_month=2/")
print("  /payment_type=2/trip_month=1/")
print("  /payment_type=2/trip_month=2/")

Multi-level partitioned data written!
Directory structure (example):
  /payment_type=1/trip_month=1/
  /payment_type=1/trip_month=2/
  /payment_type=2/trip_month=1/
  /payment_type=2/trip_month=2/


In [120]:
# Example: Read back only specific partitions (fast due to partition pruning)
# When you filter on partition columns, Spark skips reading other partitions

# Read all data for payment_type = 1 only
df_payment_1 = spark.read.parquet('/content/taxi_data_partitioned_by_payment') \
    .filter('payment_type = 1')

print(f"Trips with payment_type = 1: {df_payment_1.count()}")

# Because payment_type is a partition column, Spark reads only the payment_type=1 directory
# instead of scanning every payment type and filtering afterwards

Trips with payment_type = 1: 4780568


---

## 23. Advanced Topic: Data Quality Checks & Outlier Detection

**What are Outliers?**
- Data points that significantly deviate from expected values
- Can indicate errors, fraud, or anomalies
- Important for data quality and accuracy

**Common Detection Methods:**
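- **Statistical (Z-Score)**: Values beyond mean ± 3*stddev; extreme values inflate the stddev itself, which can widen the thresholds enough to hide outliers
- **IQR Method**: Values outside Q1-1.5*IQR and Q3+1.5*IQR; on right-skewed data such as trip distance it tends to flag many legitimate long trips
- **Domain-based**: Known business rules (e.g., fare > $1000)
- **Isolation Forest**: Machine learning-based anomaly detection (not part of Spark MLlib; requires a third-party library)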

**Use Cases:**
- Fraud detection
- Data quality validation
- Sensor anomaly detection
- Financial transaction monitoring
- Trip validity checks
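The arithmetic behind the first two methods is small enough to check in plain Python. In this sketch (hypothetical helper names, toy fares), a single extreme value inflates the sample standard deviation so much that it stays inside mean ± 3·stddev, the masking effect that makes z-scores unreliable on heavy-tailed data. The IQR call reuses the quartiles that the trip-distance example reports (Q1 = 0.99, Q3 = 3.14).

```python
import statistics


def zscore_bounds(values, k=3.0):
    """Return (mean - k*stddev, mean + k*stddev) using the sample standard deviation."""
    mu = statistics.mean(values)
    sigma = statistics.stdev(values)  # sample stddev, the same definition as Spark's stddev()
    return mu - k * sigma, mu + k * sigma


def iqr_bounds(q1, q3, k=1.5):
    """Return Tukey's fences (Q1 - k*IQR, Q3 + k*IQR)."""
    iqr = q3 - q1
    return q1 - k * iqr, q3 + k * iqr


# Hypothetical fares with one extreme value
fares = [10.0, 12.0, 15.0, 20.0, 25.0, 30.0, 5000.0]
low, high = zscore_bounds(fares)
print(f"z-score bounds: {low:.2f} .. {high:.2f}; is 5000 flagged? {max(fares) > high}")

# Quartiles reported for trip_distance in the IQR example
lower, upper = iqr_bounds(0.99, 3.14)
print(f"IQR fences: {lower:.3f} .. {upper:.3f}")
```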

In [121]:
# Example: Statistical Outlier Detection using Z-Score method
from pyspark.sql.functions import mean, stddev, col

# Calculate mean and standard deviation for total_amount
fare_stats = df_2025_combined.select(
    mean('total_amount').alias('mean_fare'),
    stddev('total_amount').alias('stddev_fare')
).first()

mean_fare = fare_stats['mean_fare']
stddev_fare = fare_stats['stddev_fare']

print(f"Mean Fare: ${mean_fare:.2f}")
print(f"StdDev: ${stddev_fare:.2f}")

# Find outliers: beyond 3 standard deviations
outlier_threshold_low = mean_fare - 3 * stddev_fare
outlier_threshold_high = mean_fare + 3 * stddev_fare

outliers = df_2025_combined.filter(
    (col('total_amount') < outlier_threshold_low) |
    (col('total_amount') > outlier_threshold_high)
)

print(f"\nOutliers (beyond \u00b13 stddev):")
print(f"Low threshold: ${outlier_threshold_low:.2f}")
print(f"High threshold: ${outlier_threshold_high:.2f}")
print(f"Number of outliers: {outliers.count()}")

outliers.select('total_amount','trip_distance', 'passenger_count').show(10)

Mean Fare: $25.32
StdDev: $329.62

Outliers (beyond ±3 stddev):
Low threshold: $-963.53
High threshold: $1014.17
Number of outliers: 12
+------------+-------------+---------------+
|total_amount|trip_distance|passenger_count|
+------------+-------------+---------------+
|     2506.71|       255.33|              1|
|   863380.37|          1.6|              1|
|      1311.7|       188.88|              1|
|     1869.25|        270.2|              1|
|    -1832.85|       268.82|              1|
|     1832.85|       268.82|              1|
|     2235.17|       457.64|              1|
|     -987.94|       202.21|              1|
|     1183.59|       174.69|              1|
|   132555.41|          2.2|              1|
+------------+-------------+---------------+
only showing top 10 rows


In [122]:
# Example: IQR Method for Outlier Detection
from pyspark.sql.functions import percentile_approx

# Calculate Q1, Q3, and IQR for trip_distance
quartiles = df_2025_combined.select(
    percentile_approx('trip_distance', 0.25).alias('Q1'),
    percentile_approx('trip_distance', 0.75).alias('Q3')
).first()

Q1 = quartiles['Q1']
Q3 = quartiles['Q3']
IQR = Q3 - Q1

lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR

print(f"Q1: {Q1:.2f}, Q3: {Q3:.2f}, IQR: {IQR:.2f}")
print(f"Lower Bound: {lower_bound:.2f}, Upper Bound: {upper_bound:.2f}")

# Find distance outliers
distance_outliers = df_2025_combined.filter(
    (col('trip_distance') < lower_bound) |
    (col('trip_distance') > upper_bound)
)

print(f"\nTrip Distance Outliers (IQR method): {distance_outliers.count()}")
distance_outliers.select('trip_distance', 'total_amount', 'passenger_count').show(10)

Q1: 0.99, Q3: 3.14, IQR: 2.15
Lower Bound: -2.24, Upper Bound: 6.37

Trip Distance Outliers (IQR method): 832272
+-------------+------------+---------------+
|trip_distance|total_amount|passenger_count|
+-------------+------------+---------------+
|          7.2|        40.6|              1|
|        10.42|       74.28|              1|
|        14.84|       75.55|              1|
|        31.97|      111.32|              9|
|         7.73|        38.8|              2|
|         7.23|       44.04|              1|
|         10.9|        57.1|              1|
|         8.66|       48.89|              2|
|         9.17|       80.28|              2|
|          7.2|       46.56|              1|
+-------------+------------+---------------+
only showing top 10 rows


---

## 24. Advanced Topic: User Defined Functions (UDFs)

**What are UDFs?**
- Custom Python functions applied to DataFrames
- Extend PySpark's built-in functionality
- Allow complex business logic in transformations
- Standard UDFs process one row at a time; pandas UDFs process batches of rows, and grouped variants can process whole groups

**Types of UDFs:**
1. **Standard UDF**: Takes Python function, converts to PySpark UDF
2. **Pandas UDF**: Vectorized, uses Apache Arrow, faster performance
3. **SQL UDF**: Registered as SQL functions for use in SQL queries

**Performance Considerations:**
- UDFs are slower than built-in functions: rows are serialized between the JVM and a Python worker, and Catalyst cannot optimize the function body
- Pandas UDFs reduce that overhead by transferring data in Arrow batches
- Prefer built-in functions (`pyspark.sql.functions`) whenever they can express the logic

**Use Cases:**
- Complex business logic transformations
- Conditional categorization
- String manipulations
- Custom validation rules

In [123]:
# Example: UDF to categorize trip distance
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# Define Python function
def categorize_distance(distance):
    """Categorize trip distance into Short, Medium, or Long"""
    if distance is None:
        return 'Unknown'
    elif distance < 2:
        return 'Short'
    elif distance < 10:
        return 'Medium'
    else:
        return 'Long'

# Register as UDF with return type StringType
distance_category_udf = udf(categorize_distance, StringType())

# Apply UDF to create new column
df_with_distance_category = df_2025_combined.withColumn(
    'distance_category',
    distance_category_udf(df_2025_combined['trip_distance'])
)

# Display results
df_with_distance_category.select('trip_distance', 'distance_category').show(10)

+-------------+-----------------+
|trip_distance|distance_category|
+-------------+-----------------+
|          1.6|            Short|
|          0.5|            Short|
|          0.6|            Short|
|         0.52|            Short|
|         0.66|            Short|
|         2.63|           Medium|
|          0.4|            Short|
|          1.6|            Short|
|          2.8|           Medium|
|         1.71|            Short|
+-------------+-----------------+
only showing top 10 rows


### User Defined Functions Challenge

**Challenge: Create a Fare Efficiency Score UDF**

Calculate a fare efficiency score: the fare per mile, `total_amount / trip_distance`.

**Steps:**
1. Define a Python function `calculate_fare_efficiency()` that returns fare per mile
2. Handle cases where distance is 0 or null (return 0 or -1)
3. Register it as a FloatType UDF
4. Apply it to create a new column `fare_per_mile`
5. Find trips with highest and lowest fare efficiency
6. Filter for trips with efficiency > $10 per mile (potentially overcharged trips)

**Bonus:** Create another UDF to flag suspicious trips (efficiency > mean + 2*stddev)
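One possible starting point for steps 1 and 2, as a sketch: the function below returns -1.0 for undefined ratios (the challenge allows 0 or -1) and also treats negative distances as undefined, which is an added assumption.

```python
def calculate_fare_efficiency(total_amount, trip_distance):
    """Fare per mile (total_amount / trip_distance).

    Returns -1.0 when the ratio is undefined: a missing value, or a distance
    that is zero or negative (treating negatives as undefined is an assumption).
    """
    if total_amount is None or trip_distance is None or trip_distance <= 0:
        return -1.0
    return float(total_amount) / float(trip_distance)


print(calculate_fare_efficiency(20.0, 2.5))  # 8.0
print(calculate_fare_efficiency(12.0, 0.0))  # -1.0
```

Wrapping it as `udf(calculate_fare_efficiency, FloatType())` and calling it on the `total_amount` and `trip_distance` columns covers steps 3 and 4. For comparison, the built-in `when(col('trip_distance') > 0, col('total_amount') / col('trip_distance')).otherwise(-1.0)` computes the same ratio without a UDF, except that a NULL `total_amount` yields NULL rather than -1.0.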

---

## 25. Advanced Topic: Pivot Tables

**What are Pivot Tables?**
- Reorganize and summarize data for reporting
- Turn the distinct values of one column into separate columns
- Aggregate data across multiple dimensions
- Similar to Excel pivot tables

**Syntax:**
```python
df.groupBy('dimension1').pivot('dimension2').agg(aggregation_function)
```
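Plain Python can mimic what `groupBy('VendorID').pivot('payment_type').agg(max('trip_distance'))` produces, which makes the reshaping concrete. This is only a sketch: the rows are made up and `pivot_max` is a hypothetical helper, while Spark does the same grouping in a distributed way. Combinations that never occur come back as `None`, matching the NULL cells in Spark output.

```python
from collections import defaultdict

# Toy rows (hypothetical values): (VendorID, payment_type, trip_distance)
rows = [
    (1, 1, 3.2), (1, 1, 7.5), (1, 2, 1.1),
    (2, 1, 4.0), (2, 3, 0.8), (2, 3, 2.4),
]

def pivot_max(rows):
    """Mimic df.groupBy('VendorID').pivot('payment_type').agg(max('trip_distance'))."""
    cells = defaultdict(dict)            # vendor -> {payment_type: max distance}
    pivot_values = set()
    for vendor, payment, dist in rows:
        pivot_values.add(payment)
        current = cells[vendor].get(payment)
        cells[vendor][payment] = dist if current is None else max(current, dist)
    # Sorted for a stable column order (Spark likewise sorts the distinct values it discovers)
    columns = sorted(pivot_values)
    # Missing (vendor, payment) combinations become None, like NULL in the Spark output
    return columns, {v: [cells[v].get(c) for c in columns] for v in sorted(cells)}

columns, table = pivot_max(rows)
print("VendorID", columns)
for vendor, values in table.items():
    print(vendor, values)
```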


**Benefits:**
- Easy cross-tabulation of data
- Quick multi-dimensional summaries
- Better readability for certain analyses
- Useful for creating reports and dashboards

**Performance Note:**
- Can be expensive on large datasets
- Works best when the pivot column has a limited number of distinct values
- If the pivot values are not supplied, Spark first runs an extra job to compute the distinct values. Passing them explicitly, e.g. `pivot('payment_type', [0, 1, 2, 3, 4, 5])`, avoids that pass
- Use filters to reduce data before pivoting

### Pivot Tables Challenge

**Challenge: Create Multi-Dimensional Pivot Table**

Create a pivot table showing the maximum trip distance for each combination of payment type and VendorID.

**Steps:**
1. Use groupBy() with VendorID
2. Pivot on payment_type
3. Aggregate using max('trip_distance')
4. Display the resulting pivot table
5. Interpret the results: which vendor/payment combination has the longest trips?

**Bonus:** Create another pivot showing min('total_amount') for the same dimensions

In [124]:
# Alias the imports so they do not shadow Python's built-in max() and min()
from pyspark.sql.functions import max as spark_max, min as spark_min

# Steps 1-3: group by VendorID, pivot on payment_type, aggregate with max('trip_distance')
# With a single aggregate, the output columns are named after the pivot values (0, 1, 2, ...)
pivot_max_distance = (
    df_2025_combined.groupBy('VendorID')
    .pivot('payment_type')
    .agg(spark_max('trip_distance').alias('max_trip_distance'))
)

# Step 4: display the resulting pivot table
pivot_max_distance.show()

# Step 5: each cell holds the longest trip for one VendorID/payment_type combination;
# NULL means that combination never occurs. The largest cell answers the question.

# Bonus: minimum total_amount for the same dimensions
pivot_min_total = (
    df_2025_combined.groupBy('VendorID')
    .pivot('payment_type')
    .agg(spark_min('total_amount').alias('min_total_amount'))
)

pivot_min_total.show()

+--------+---------+-------+-------+-----+------+----+
|VendorID|        0|      1|      2|    3|     4|   5|
+--------+---------+-------+-------+-----+------+----+
|       1|     92.4|44730.3|  265.9| 52.8|  71.3| 0.0|
|       6|    32.74|   NULL|   NULL| NULL|  NULL|NULL|
|       7|     NULL|  39.72|  62.75|11.96|  2.88|NULL|
|       2|276423.57|8002.41|2001.95|81.14|268.82|NULL|
+--------+---------+-------+-------+-----+------+----+

+--------+------+------+------+------+--------+----+
|VendorID|     0|     1|     2|     3|       4|   5|
+--------+------+------+------+------+--------+----+
|       1|   0.0|   0.0|   0.0|   0.0|     0.0| 0.0|
|       6|   0.0|  NULL|  NULL|  NULL|    NULL|NULL|
|       7|  NULL|   7.2|   5.9|   5.2|   10.85|NULL|
|       2|-35.73|-104.9|-960.0|-901.0|-1832.85|NULL|
+--------+------+------+------+------+--------+----+
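You can answer step 5 by eye, but on wider tables a programmatic scan helps. The plain-Python sketch below assumes the small pivot result was collected to the driver, for example with `[r.asDict() for r in pivot_max_distance.collect()]`. That is safe here only because the table has four rows. The values are copied from the first table above, and `largest_cell` is a hypothetical helper.

The scan picks VendorID 2 with payment type 0 at 276,423.57 miles. No taxi trip is that long, so this "longest trip" points to a data-quality problem rather than a real trip.

```python
# Rows of the first pivot table above, as dicts (the shape Row.asDict() returns);
# None stands for NULL. Column keys are strings because Spark column names are strings.
pivot_rows = [
    {"VendorID": 1, "0": 92.4, "1": 44730.3, "2": 265.9, "3": 52.8, "4": 71.3, "5": 0.0},
    {"VendorID": 6, "0": 32.74, "1": None, "2": None, "3": None, "4": None, "5": None},
    {"VendorID": 7, "0": None, "1": 39.72, "2": 62.75, "3": 11.96, "4": 2.88, "5": None},
    {"VendorID": 2, "0": 276423.57, "1": 8002.41, "2": 2001.95, "3": 81.14, "4": 268.82, "5": None},
]

def largest_cell(rows, key="VendorID"):
    """Return (row key, pivot column, value) for the largest non-null cell."""
    best = None
    for row in rows:
        for column, value in row.items():
            if column == key or value is None:
                continue  # skip the grouping column and NULL cells
            if best is None or value > best[2]:
                best = (row[key], column, value)
    return best

print(largest_cell(pivot_rows))
```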



---

## 26. Advanced Topic: Window Functions

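**What are Window Functions?**
- Compute a value for each row from a set of related rows (its "window")
- Useful for running totals, rankings, and time-series analysis
- The window is defined with `Window.partitionBy(...)` and `.orderBy(...)` (SQL: PARTITION BY and ORDER BY)
- Return a new column without reducing the number of rows

**Key Concepts:**
- **PARTITION BY**: Divides data into groups (like GROUP BY, but keeps all rows)
- **ORDER BY**: Specifies the order of rows within each partition
- **Frame** (`rowsBetween` / `rangeBetween`): Limits which rows of the partition are used, e.g. `rowsBetween(Window.unboundedPreceding, Window.currentRow)` for a running total
- **Functions usable over a window**: aggregates (sum, avg, min, max), ranking functions (row_number, rank, dense_rank), and offset functions (lag, lead)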


**Use Cases:**
- Running totals and cumulative sums
- Ranking within groups
- Differences from the previous or next row (lag/lead)
- Moving averages
- Sequential numbering within groups
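You can illustrate the PARTITION BY / ORDER BY model without Spark. This plain-Python sketch uses hypothetical rows and a hypothetical helper, `add_window_columns`. It sorts by partition key and then by order key. For each row it computes three values: a row number, the previous amount (like `lag(..., 1)`), and a running sum.

The output has exactly as many rows as the input. That is the key difference from `groupBy().agg()`.

```python
from itertools import groupby

# Hypothetical trips: (payment_type, pickup_time, total_amount)
trips = [
    (2, "00:10", 8.0), (1, "00:05", 12.0), (1, "00:01", 20.0),
    (2, "00:02", 15.0), (1, "00:09", 7.5),
]

def add_window_columns(rows):
    """Mimic partitionBy(payment_type).orderBy(pickup_time) with row_number, lag and a running sum."""
    out = []
    ordered = sorted(rows, key=lambda r: (r[0], r[1]))  # partition key first, then order key
    for _, partition in groupby(ordered, key=lambda r: r[0]):
        prev_amount, running = None, 0.0                # reset at the start of each partition
        for row_number, (ptype, t, amount) in enumerate(partition, start=1):
            running += amount
            out.append((ptype, t, amount, row_number, prev_amount, running))
            prev_amount = amount                        # becomes lag(total_amount, 1) for the next row
    return out

for row in add_window_columns(trips):
    print(row)
```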


### Window Functions Challenge

**Challenge: Calculate Moving Average**

**Steps:**
1. Import the `lag` function from pyspark.sql.functions
2. Create a window that partitions by payment_type and orders by pu_datetime
3. Add a new column `prev_fare_1` using lag() to get the previous trip's fare
4. Create another column, `prev_fare_2`, for the fare two trips back
5. Calculate the average of the current and previous 2 fares
6. Display results for payment_type = 1, keeping only rows that have two previous fares (so each average covers 3 trips)

Calculate a 3-trip moving average of fare amount for each payment type (use LAG and window functions).
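You can check the challenge's arithmetic outside Spark. In this plain-Python sketch, `lag(..., 1)` and `lag(..., 2)` become list offsets, and `moving_average_3` is a hypothetical helper. The input values are the payment_type 1 `total_amount` values in pickup order, copied from the solution output in this section. Rows without two predecessors get `None`, mirroring how adding a NULL lag column yields NULL in Spark.

```python
def moving_average_3(values):
    """3-row moving average built from lag(1) and lag(2), rounded to 2 decimals."""
    result = []
    for i, current in enumerate(values):
        prev_1 = values[i - 1] if i >= 1 else None  # lag(value, 1)
        prev_2 = values[i - 2] if i >= 2 else None  # lag(value, 2)
        if prev_1 is None or prev_2 is None:
            result.append(None)                     # not enough history yet
        else:
            result.append(round((current + prev_1 + prev_2) / 3, 2))
    return result

# payment_type 1 total_amount values in pickup order (from the solution output)
fares = [17.16, 32.3, 39.84, 23.6, 23.3, 92.5, 18.0]
print(moving_average_3(fares))
```

Python's `round()` sends exact ties to the even digit, whereas Spark's `round()` rounds half up; the two agree on these values. Another option in Spark is `avg('total_amount').over(window_spec.rowsBetween(-2, 0))`. It matches for rows that have two predecessors, but it also averages the 1 or 2 available rows at the start of each partition instead of returning NULL.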

In [125]:
# Example: Calculate running total of fare amount by payment type
from pyspark.sql.window import Window
from pyspark.sql.functions import sum as spark_sum

# Define window: partition by payment_type, order by pickup time, and use a frame from the
# first row of the partition up to the current row (a running total)
window_spec = (
    Window.partitionBy('payment_type')
    .orderBy('pu_datetime')
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

# Add running total column
df_running_total = df_2025_combined.withColumn(
    'running_total_fare',
    spark_sum('total_amount').over(window_spec)
)

# Display first few rows with the running total
df_running_total.select('payment_type', 'pu_datetime', 'total_amount', 'running_total_fare').show(10)

+------------+-------------------+------------+------------------+
|payment_type|        pu_datetime|total_amount|running_total_fare|
+------------+-------------------+------------+------------------+
|           0|2025-01-01 00:00:00|        -3.6|              -3.6|
|           0|2025-01-01 00:03:52|        4.89|1.2899999999999996|
|           0|2025-01-01 00:04:35|       24.07|             25.36|
|           0|2025-01-01 00:04:37|       20.46|             45.82|
|           0|2025-01-01 00:04:45|       27.79|             73.61|
|           0|2025-01-01 00:04:51|       -4.22|             69.39|
|           0|2025-01-01 00:05:18|       22.35| 91.74000000000001|
|           0|2025-01-01 00:05:43|       34.52|126.26000000000002|
|           0|2025-01-01 00:05:44|       88.68|214.94000000000003|
|           0|2025-01-01 00:05:49|       -8.48|206.46000000000004|
+------------+-------------------+------------+------------------+
only showing top 10 rows
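The cell above passes `rowsBetween(Window.unboundedPreceding, 0)` explicitly. Without a frame, an ordered window defaults to a RANGE frame from unbounded preceding to the current row. In that frame, rows that tie on `pu_datetime` are peers and all receive the same running total. Under a ROWS frame each tied row adds to the total one at a time, though Spark does not guarantee the relative order of the tied rows.

The plain-Python sketch below contrasts the two frames. It uses hypothetical data, and the helper names `rows_frame_totals` and `range_frame_totals` are illustrative only.

```python
# Toy partition already sorted by pickup time; two trips share "00:01" (hypothetical data)
trips = [("00:00", 5.0), ("00:01", 10.0), ("00:01", 20.0), ("00:02", 1.0)]

def rows_frame_totals(trips):
    """ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: every physical row adds to the total."""
    totals, running = [], 0.0
    for _, amount in trips:
        running += amount
        totals.append(running)
    return totals

def range_frame_totals(trips):
    """RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: rows with an equal ordering
    value are peers, so each row's total includes all of its peers.
    (String comparison works here because the times are zero-padded "HH:MM".)"""
    return [sum(a for t2, a in trips if t2 <= t) for t, _ in trips]

print(rows_frame_totals(trips))
print(range_frame_totals(trips))
```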


In [126]:
# Solution: 3-trip moving average of fare amount for each payment type
# 1. Import lag
from pyspark.sql.window import Window
from pyspark.sql.functions import lag, col, round as pyspark_round


# 2. Create a window partitioned by payment_type and ordered by pu_datetime
window_spec = Window.partitionBy('payment_type').orderBy('pu_datetime')

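# 3 & 4. Add the fares of the previous trip and of the trip two back using lag
# (total_amount is used as the "fare" here; fare_amount would use the metered fare alone)
df_with_lags = df_2025_combined \
    .withColumn('prev_fare_1', lag('total_amount', 1).over(window_spec)) \
    .withColumn('prev_fare_2', lag('total_amount', 2).over(window_spec))

# 5. Calculate the 3-trip moving average (current + previous 2) / 3, rounded to 2 decimals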
df_with_ma = df_with_lags.withColumn(
    'moving_avg_fare',
    pyspark_round((
        col('total_amount') +
        col('prev_fare_1') +
        col('prev_fare_2')
    ) / 3, 2)
)

# 6. Show results for payment_type = 1 with at least 3 fares (i.e., not null in all lag columns)
df_with_ma.filter(
    (col('payment_type') == 1) &
    col('prev_fare_1').isNotNull() &
    col('prev_fare_2').isNotNull()
).select(
    'payment_type', 'pu_datetime', 'total_amount', 'prev_fare_1', 'prev_fare_2', 'moving_avg_fare'
).show(10)

+------------+-------------------+------------+-----------+-----------+---------------+
|payment_type|        pu_datetime|total_amount|prev_fare_1|prev_fare_2|moving_avg_fare|
+------------+-------------------+------------+-----------+-----------+---------------+
|           1|2024-12-31 20:54:50|       39.84|       32.3|      17.16|          29.77|
|           1|2024-12-31 21:15:22|        23.6|      39.84|       32.3|          31.91|
|           1|2024-12-31 21:20:05|        23.3|       23.6|      39.84|          28.91|
|           1|2024-12-31 23:25:38|        92.5|       23.3|       23.6|          46.47|
|           1|2024-12-31 23:27:13|        18.0|       92.5|       23.3|           44.6|
|           1|2024-12-31 23:30:03|       26.62|       18.0|       92.5|          45.71|
|           1|2024-12-31 23:37:42|       14.03|      26.62|       18.0|          19.55|
|           1|2024-12-31 23:48:48|        15.9|      14.03|      26.62|          18.85|
|           1|2024-12-31 23:49:2