<a href="https://colab.research.google.com/github/thesimonk/working_with_csv_in_pyspark/blob/master/processing_csv_pyspark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [105]:
from pyspark.sql import SparkSession as ss
# import SparkSession, the main entry point to Spark functionality
# aliasing it as 'ss' is optional and just shortens later references

# Build (or reuse) the SparkSession for this notebook.
#   ss.builder       - starts configuring a Spark application
#   .appName(...)    - application name shown in the Spark UI
#                      (optional, but recommended for tracking jobs)
#   .getOrCreate()   - returns the existing SparkSession if there is one,
#                      otherwise creates a new one
spark = (
    ss.builder
    .appName('CSV Processing')
    .getOrCreate()
)

# Load the CSV file into a Spark DataFrame.
#   spark.read       - DataFrameReader used for loading data
#   header           - True:  the first row holds the column names
#                      False: Spark names the columns _c0, _c1, ...
#   inferSchema      - True:  Spark infers column types (int, double, etc.)
#                      False: every column is read as a string (default)
#   .csv(...)        - reads a CSV file from local disk, HDFS, or cloud storage
# The result is a Spark DataFrame distributed across partitions.
data = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv('/content/sample_data/california_housing_train.csv')
)

# show() is an action: it triggers execution of the lazily built plan and
# prints the first 20 rows by default.
#   show(n)                    - print the first n rows (e.g. show(5))
#   show(truncate=False)       - print full cell contents; by default
#                                (truncate=True) long values are shortened.
#                                Pass it by keyword: the first positional
#                                argument of show() is n, not truncate.
#   show(n, truncate)          - control both row count and truncation,
#                                e.g. show(5, False)
data.show()

+---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+
|longitude|latitude|housing_median_age|total_rooms|total_bedrooms|population|households|median_income|median_house_value|
+---------+--------+------------------+-----------+--------------+----------+----------+-------------+------------------+
|  -114.31|   34.19|              15.0|     5612.0|        1283.0|    1015.0|     472.0|       1.4936|           66900.0|
|  -114.47|    34.4|              19.0|     7650.0|        1901.0|    1129.0|     463.0|         1.82|           80100.0|
|  -114.56|   33.69|              17.0|      720.0|         174.0|     333.0|     117.0|       1.6509|           85700.0|
|  -114.57|   33.64|              14.0|     1501.0|         337.0|     515.0|     226.0|       3.1917|           73400.0|
|  -114.57|   33.57|              20.0|     1454.0|         326.0|     624.0|     262.0|        1.925|           65500.0|
|  -114.58|   33.63|    

In [106]:
# check column data types
# printSchema() prints every column as a tree: name, data type and nullability
data.printSchema()

root
 |-- longitude: double (nullable = true)
 |-- latitude: double (nullable = true)
 |-- housing_median_age: double (nullable = true)
 |-- total_rooms: double (nullable = true)
 |-- total_bedrooms: double (nullable = true)
 |-- population: double (nullable = true)
 |-- households: double (nullable = true)
 |-- median_income: double (nullable = true)
 |-- median_house_value: double (nullable = true)



In [107]:
# change a column from one data type to another
from pyspark.sql.functions import col # ideally imported at the top (first cell)

# Cast a single column: withColumn() with an existing column name replaces
# that column with the new expression.
data = data.withColumn("total_rooms", col("total_rooms").cast("int"))

# Type names commonly passed to .cast():
#     "int", "double", "float", "string", "boolean", "date", "timestamp"

# Cast several columns in one select().
# The cast columns are listed first, so they end up at the front of the
# DataFrame; all remaining columns follow in their existing order.
int_columns = ["total_bedrooms", "median_house_value"]
data = data.select(
    *[col(name).cast("int") for name in int_columns],
    *[name for name in data.columns if name not in int_columns]
)

# Inspect the result: first five rows, then the updated column types.
data.show(5)
data.printSchema()



+--------------+------------------+---------+--------+------------------+-----------+----------+----------+-------------+
|total_bedrooms|median_house_value|longitude|latitude|housing_median_age|total_rooms|population|households|median_income|
+--------------+------------------+---------+--------+------------------+-----------+----------+----------+-------------+
|          1283|             66900|  -114.31|   34.19|              15.0|       5612|    1015.0|     472.0|       1.4936|
|          1901|             80100|  -114.47|    34.4|              19.0|       7650|    1129.0|     463.0|         1.82|
|           174|             85700|  -114.56|   33.69|              17.0|        720|     333.0|     117.0|       1.6509|
|           337|             73400|  -114.57|   33.64|              14.0|       1501|     515.0|     226.0|       3.1917|
|           326|             65500|  -114.57|   33.57|              20.0|       1454|     624.0|     262.0|        1.925|
+--------------+--------

In [108]:
# Rename "population" to "total_population", then cast the renamed column
# to int. Both steps are chained so the cell produces the new frame in one go.
data = (
    data
    .withColumnRenamed("population", "total_population")
    .withColumn("total_population", col("total_population").cast("int"))
)

data.show(5)

+--------------+------------------+---------+--------+------------------+-----------+----------------+----------+-------------+
|total_bedrooms|median_house_value|longitude|latitude|housing_median_age|total_rooms|total_population|households|median_income|
+--------------+------------------+---------+--------+------------------+-----------+----------------+----------+-------------+
|          1283|             66900|  -114.31|   34.19|              15.0|       5612|            1015|     472.0|       1.4936|
|          1901|             80100|  -114.47|    34.4|              19.0|       7650|            1129|     463.0|         1.82|
|           174|             85700|  -114.56|   33.69|              17.0|        720|             333|     117.0|       1.6509|
|           337|             73400|  -114.57|   33.64|              14.0|       1501|             515|     226.0|       3.1917|
|           326|             65500|  -114.57|   33.57|              20.0|       1454|             624|  

In [109]:
# Round a column to 2 decimal places.
#
# Spark's round() is accessed through the `F` module alias on purpose:
# `from pyspark.sql.functions import round` would replace Python's built-in
# round() in the notebook namespace, so any later plain-Python call such as
# round(3.14159, 2) would unexpectedly go to the Spark column function.
from pyspark.sql import functions as F

data = data.withColumn(
    "median_income",
    F.round(F.col("median_income"), 2)  # second argument = number of decimals
)

data.show(5)

+--------------+------------------+---------+--------+------------------+-----------+----------------+----------+-------------+
|total_bedrooms|median_house_value|longitude|latitude|housing_median_age|total_rooms|total_population|households|median_income|
+--------------+------------------+---------+--------+------------------+-----------+----------------+----------+-------------+
|          1283|             66900|  -114.31|   34.19|              15.0|       5612|            1015|     472.0|         1.49|
|          1901|             80100|  -114.47|    34.4|              19.0|       7650|            1129|     463.0|         1.82|
|           174|             85700|  -114.56|   33.69|              17.0|        720|             333|     117.0|         1.65|
|           337|             73400|  -114.57|   33.64|              14.0|       1501|             515|     226.0|         3.19|
|           326|             65500|  -114.57|   33.57|              20.0|       1454|             624|  

In [110]:
# Keep only the rows whose median_income is greater than 2.0
# (where() is an alias of filter()).
data = data.where(data["median_income"] > 2.00)

# Other ways to express filters. These are reference snippets only; the ones
# using "ocean_proximity" are illustrative, as that column is not part of the
# schema printed earlier in this notebook.

# col() style (recommended)
    # from pyspark.sql.functions import col
    # df_filtered = data.filter(col("median_income") > 2.0)

# SQL-style expression string
    # data.filter("median_income > 2.0 AND housing_median_age < 40")

# string columns
    # data.filter(col("ocean_proximity") == "NEAR BAY")
    # data.filter(col("ocean_proximity").like("%BAY%"))
    # data.filter(col("ocean_proximity").startswith("NEAR"))
    # data.filter(col("ocean_proximity").isin("NEAR BAY", "INLAND"))

# value range
    # data.filter(col("median_income").between(3.0, 6.0))

# NULL / NOT NULL values
    # data.filter(col("total_bedrooms").isNull())
    # data.filter(col("total_bedrooms").isNotNull())

# membership in a list (IN / NOT IN)
    # data.filter(col("housing_median_age").isin([10, 20, 30]))
    # negation with ~ :
      # data.filter(~col("housing_median_age").isin([10, 20, 30]))

data.show(5)

+--------------+------------------+---------+--------+------------------+-----------+----------------+----------+-------------+
|total_bedrooms|median_house_value|longitude|latitude|housing_median_age|total_rooms|total_population|households|median_income|
+--------------+------------------+---------+--------+------------------+-----------+----------------+----------+-------------+
|           337|             73400|  -114.57|   33.64|              14.0|       1501|             515|     226.0|         3.19|
|           236|             74000|  -114.58|   33.63|              29.0|       1387|             671|     239.0|         3.34|
|           680|             82400|  -114.58|   33.61|              25.0|       2907|            1841|     633.0|         2.68|
|          1175|             58400|  -114.59|   33.61|              34.0|       4789|            3134|    1056.0|         2.18|
|           309|             48100|   -114.6|   34.83|              46.0|       1497|             787|  

In [111]:
# Sort by total_bedrooms, largest value first.
data = data.orderBy(data["total_bedrooms"].desc())
# equivalent spellings
    # data.orderBy("total_bedrooms", ascending=False)
    # data.orderBy(col("median_income").asc())
    # data.orderBy(col("median_income").desc())

data.show(5)

# Sorting by several columns: pass one expression per column, in priority order
    # data.orderBy(
    #     col("ocean_proximity").asc(),
    #     col("median_income").desc()
    # )

+--------------+------------------+---------+--------+------------------+-----------+----------------+----------+-------------+
|total_bedrooms|median_house_value|longitude|latitude|housing_median_age|total_rooms|total_population|households|median_income|
+--------------+------------------+---------+--------+------------------+-----------+----------------+----------+-------------+
|          6445|            118800|  -121.79|   36.64|              11.0|      32627|           28566|    6082.0|         2.31|
|          5471|            366300|  -117.74|   33.89|               4.0|      37937|           16122|    5189.0|         7.49|
|          5290|            253900|  -117.78|   34.03|               8.0|      32054|           15507|    5050.0|         6.02|
|          4957|            212300|  -117.12|   33.52|               4.0|      30401|           13251|    4339.0|         4.58|
|          4819|            134400|  -117.42|   33.35|              14.0|      25135|           35682|  

In [112]:
# Group the DataFrame by latitude and compute the average median_income per
# group, rounded to 2 decimal places.
#
# Spark's aggregate functions are accessed through the `F` module alias.
# Importing names such as round, sum, min or max directly from
# pyspark.sql.functions would replace the Python built-ins of the same name
# for the rest of the notebook.
from pyspark.sql import functions as F

# NOTE: `data` is rebound to the aggregated frame, which only has the columns
# latitude and avg_median_income. Re-running this cell on its own therefore
# fails (median_income no longer exists); re-run the earlier cells first.
data = (
    data
    .groupBy("latitude")
    .agg(F.round(F.avg("median_income"), 2).alias("avg_median_income"))
)

# common aggregations (reference snippets, not executed)
    # data.groupBy("ocean_proximity").agg(
    #     F.count("*").alias("row_count"),
    #     F.avg("median_income").alias("avg_income"),
    #     F.min("median_income").alias("min_income"),
    #     F.max("median_income").alias("max_income"),
    #     F.sum("population").alias("total_population")
    # ).show()
    # F.mean is available as well.


# most common pattern
    # data.groupBy("column").agg(
    #     F.avg("value").alias("avg_value"),
    #     F.count("*").alias("count")
    # )


data.show(5)

+--------+-----------------+
|latitude|avg_median_income|
+--------+-----------------+
|   38.61|             3.36|
|   37.81|             4.88|
|   40.53|             2.35|
|   35.17|             3.49|
|   37.23|             6.07|
+--------+-----------------+
only showing top 5 rows


In [113]:
# Write the DataFrame to CSV, then read it back to verify the round trip.
#
# The reloaded frame gets its own name instead of being assigned back to
# `data`. If `data` pointed at a frame read from OUTPUT_DIR, re-running this
# cell would ask Spark to overwrite a directory it is reading from in the same
# job, which Spark cannot do safely (it typically fails with "Cannot overwrite
# a path that is also being read from"). Keeping `data` as the in-memory
# result makes the cell safe to re-run.
OUTPUT_DIR = "/content/sample_data/notebook_output"

# data.coalesce(1)
#     Combines all partitions into one partition so that Spark writes one CSV file instead of multiple part-*.csv files.
#     Caution: For very large datasets, this can overwhelm a single executor and cause memory issues.
#     Alternative: omit coalesce(1) to write multiple files in parallel (more efficient for big data).
#
# .write.mode("overwrite")
#     Controls how Spark handles existing data at the output path.
#
# Options for mode:
#     "overwrite" - Delete existing files/folder and write new data.
#     "append" - Add new data to existing files/folder.
#     "ignore" - Do nothing if path exists.
#     "error" or "errorifexists" - Throw an error if the path exists (default behavior).
#     Use "overwrite" when you want to replace old output.
#
# Note: the path passed to .csv() is a directory; Spark writes the part file(s) inside it.
(
    data
    .coalesce(1)
    .write
    .mode("overwrite")
    .option("header", True)
    .csv(OUTPUT_DIR)
)

# Read the directory back; header/inferSchema mirror the options used for the
# original load so column names and numeric types are restored.
data_reloaded = spark.read.csv(
    OUTPUT_DIR,
    header=True,
    inferSchema=True
)

data_reloaded.show(5)


+--------+-----------------+
|latitude|avg_median_income|
+--------+-----------------+
|   38.61|             3.36|
|   37.81|             4.88|
|   40.53|             2.35|
|   35.17|             3.49|
|   37.23|             6.07|
+--------+-----------------+
only showing top 5 rows
