In [1]:
# Run this cell to import pyspark and to define start_spark() and stop_spark()

import findspark

findspark.init()

import getpass
import pandas
import pyspark
import random
import re

from IPython.display import display, HTML
from pyspark import SparkContext
from pyspark.sql import SparkSession
import pyspark.sql.functions as F


# Constants used to interact with Azure Blob Storage using the hdfs command or Spark

global username

username = re.sub('@.*', '', getpass.getuser())

global azure_account_name
global azure_data_container_name
global azure_user_container_name
global azure_user_token

azure_account_name = "madsstorage002"
azure_data_container_name = "campus-data"
azure_user_container_name = "campus-user"
azure_user_token = r"sp=racwdl&st=2025-08-01T09:41:33Z&se=2026-12-30T16:56:33Z&spr=https&sv=2024-11-04&sr=c&sig=GzR1hq7EJ0lRHj92oDO1MBNjkc602nrpfB5H8Cl7FFY%3D"


# Functions used below

def dict_to_html(d):
    """Convert a Python dictionary into a two column table for display.
    """

    html = []

    html.append(f'<table width="100%" style="width:100%; font-family: monospace;">')
    for k, v in d.items():
        html.append(f'<tr><td style="text-align:left;">{k}</td><td>{v}</td></tr>')
    html.append(f'</table>')

    return ''.join(html)


def show_as_html(df, n=20):
    """Leverage existing pandas jupyter integration to show a spark dataframe as html.
    
    Args:
        n (int): number of rows to show (default: 20)
    """

    display(df.limit(n).toPandas())

    
def display_spark():
    """Display the status of the active Spark session if one is currently running.
    """
    
    if 'spark' in globals() and 'sc' in globals():

        name = sc.getConf().get("spark.app.name")

        html = [
            f'<p><b>Spark</b></p>',
            f'<p>The spark session is <b><span style="color:green">active</span></b>, look for <code>{name}</code> under the running applications section in the Spark UI.</p>',
            f'<ul>',
            f'<li><a href="http://localhost:{sc.uiWebUrl.split(":")[-1]}" target="_blank">Spark Application UI</a></li>',
            f'</ul>',
            f'<p><b>Config</b></p>',
            dict_to_html(dict(sc.getConf().getAll())),
            f'<p><b>Notes</b></p>',
            f'<ul>',
            f'<li>The spark session <code>spark</code> and spark context <code>sc</code> global variables have been defined by <code>start_spark()</code>.</li>',
            f'<li>Please run <code>stop_spark()</code> before closing the notebook or restarting the kernel or kill <code>{name}</code> by hand using the link in the Spark UI.</li>',
            f'</ul>',
        ]
        display(HTML(''.join(html)))
        
    else:
        
        html = [
            f'<p><b>Spark</b></p>',
            f'<p>The spark session is <b><span style="color:red">stopped</span></b>, confirm that <code>{username} (notebook)</code> is under the completed applications section in the Spark UI.</p>',
            f'<ul>',
            f'<li><a href="http://mathmadslinux2p.canterbury.ac.nz:8080/" target="_blank">Spark UI</a></li>',
            f'</ul>',
        ]
        display(HTML(''.join(html)))


# Functions to start and stop spark

def start_spark(executor_instances=2, executor_cores=1, worker_memory=1, master_memory=1):
    """Start a new Spark session and define globals for SparkSession (spark) and SparkContext (sc).
    
    Args:
        executor_instances (int): number of executors (default: 2)
        executor_cores (int): number of cores per executor (default: 1)
        worker_memory (float): worker memory (default: 1)
        master_memory (float): master memory (default: 1)
    """

    global spark
    global sc

    cores = executor_instances * executor_cores
    partitions = cores * 4
    port = 4000 + random.randint(1, 999)

    spark = (
        SparkSession.builder
        .config("spark.driver.extraJavaOptions", f"-Dderby.system.home=/tmp/{username}/spark/")
        .config("spark.dynamicAllocation.enabled", "false")
        .config("spark.executor.instances", str(executor_instances))
        .config("spark.executor.cores", str(executor_cores))
        .config("spark.cores.max", str(cores))
        .config("spark.driver.memory", f'{master_memory}g')
        .config("spark.executor.memory", f'{worker_memory}g')
        .config("spark.driver.maxResultSize", "0")
        .config("spark.sql.shuffle.partitions", str(partitions))
        .config("spark.kubernetes.container.image", "madsregistry001.azurecr.io/hadoop-spark:v3.3.5-openjdk-8")
        .config("spark.kubernetes.container.image.pullPolicy", "IfNotPresent")
        .config("spark.kubernetes.memoryOverheadFactor", "0.3")
        .config("spark.memory.fraction", "0.1")
        .config(f"fs.azure.sas.{azure_user_container_name}.{azure_account_name}.blob.core.windows.net",  azure_user_token)
        .config("spark.app.name", f"{username} (notebook)")
        .getOrCreate()
    )
    sc = SparkContext.getOrCreate()
    
    display_spark()

    
def stop_spark():
    """Stop the active Spark session and delete globals for SparkSession (spark) and SparkContext (sc).
    """

    global spark
    global sc

    if 'spark' in globals() and 'sc' in globals():

        spark.stop()

        del spark
        del sc

    display_spark()


# Make css changes to improve spark output readability

html = [
    '<style>',
    'pre { white-space: pre !important; }',
    'table.dataframe td { white-space: nowrap !important; }',
    'table.dataframe thead th:first-child, table.dataframe tbody th { display: none; }',
    '</style>',
]
display(HTML(''.join(html)))

In [2]:
# Run this cell to start a spark session in this notebook

start_spark(executor_instances=8, executor_cores=4, worker_memory=8, master_memory=8)

25/09/03 15:11:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


0,1
spark.dynamicAllocation.enabled,false
spark.fs.azure.sas.uco-user.madsstorage002.blob.core.windows.net,"""sp=racwdl&st=2024-09-19T08:00:18Z&se=2025-09-19T16:00:18Z&spr=https&sv=2022-11-02&sr=c&sig=qtg6fCdoFz6k3EJLw7dA8D3D8wN0neAYw8yG4z4Lw2o%3D"""
spark.kubernetes.driver.pod.name,spark-master-driver
spark.app.name,rsh224 (notebook)
spark.kubernetes.executor.podNamePrefix,rsh224-notebook-94d4fc990d8e835c
spark.fs.azure.sas.campus-user.madsstorage002.blob.core.windows.net,"""sp=racwdl&st=2024-09-19T08:03:31Z&se=2025-09-19T16:03:31Z&spr=https&sv=2022-11-02&sr=c&sig=kMP%2BsBsRzdVVR8rrg%2BNbDhkRBNs6Q98kYY695XMRFDU%3D"""
spark.kubernetes.container.image.pullPolicy,IfNotPresent
spark.kubernetes.namespace,rsh224
spark.executor.cores,4
spark.driver.memory,8g


In [3]:
# Write your imports here or insert cells below

from pyspark.sql import functions as F
from pyspark.sql.types import *
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [4]:
directory_path = f'wasbs://{azure_data_container_name}@{azure_account_name}.blob.core.windows.net/ghcnd'

In [5]:
# Load a subset of the last year in daily into Spark from Azure Blob Storage using spark.read.csv

schema = StructType([
    StructField("ID", StringType()),           # Character Station code
    StructField("DATE", StringType()),         # Date Observation date formatted as YYYYMMDD
    StructField("ELEMENT", StringType()),      # Character Element type indicator
    StructField("VALUE", DoubleType()),        # Real Data value for ELEMENT
    StructField("MEASUREMENT", StringType()),  # Character Measurement Flag
    StructField("QUALITY", StringType()),      # Character Quality Flag
    StructField("SOURCE", StringType()),       # Character Source Flag
    StructField("TIME", StringType()),         # Time Observation time formatted as HHMM
])

daily = spark.read.csv(
    path=f'wasbs://{azure_data_container_name}@{azure_account_name}.blob.core.windows.net/ghcnd/daily/',
    schema=schema
)

print(type(daily))
daily.printSchema()
print(daily)
daily.show(100, False)

<class 'pyspark.sql.dataframe.DataFrame'>
root
 |-- ID: string (nullable = true)
 |-- DATE: string (nullable = true)
 |-- ELEMENT: string (nullable = true)
 |-- VALUE: double (nullable = true)
 |-- MEASUREMENT: string (nullable = true)
 |-- QUALITY: string (nullable = true)
 |-- SOURCE: string (nullable = true)
 |-- TIME: string (nullable = true)

DataFrame[ID: string, DATE: string, ELEMENT: string, VALUE: double, MEASUREMENT: string, QUALITY: string, SOURCE: string, TIME: string]


                                                                                

+-----------+--------+-------+------+-----------+-------+------+----+
|ID         |DATE    |ELEMENT|VALUE |MEASUREMENT|QUALITY|SOURCE|TIME|
+-----------+--------+-------+------+-----------+-------+------+----+
|ASN00030019|20100101|PRCP   |24.0  |NULL       |NULL   |a     |NULL|
|ASN00030021|20100101|PRCP   |200.0 |NULL       |NULL   |a     |NULL|
|ASN00030022|20100101|TMAX   |294.0 |NULL       |NULL   |a     |NULL|
|ASN00030022|20100101|TMIN   |215.0 |NULL       |NULL   |a     |NULL|
|ASN00030022|20100101|PRCP   |408.0 |NULL       |NULL   |a     |NULL|
|ASN00029121|20100101|PRCP   |820.0 |NULL       |NULL   |a     |NULL|
|ASN00029126|20100101|TMAX   |371.0 |NULL       |NULL   |S     |NULL|
|ASN00029126|20100101|TMIN   |225.0 |NULL       |NULL   |S     |NULL|
|ASN00029126|20100101|PRCP   |0.0   |NULL       |NULL   |a     |NULL|
|ASN00029126|20100101|TAVG   |298.0 |H          |NULL   |S     |NULL|
|ASN00029127|20100101|TMAX   |371.0 |NULL       |NULL   |a     |NULL|
|ASN00029127|2010010

In [6]:
stations_enriched_path = f'wasbs://{azure_user_container_name}@{azure_account_name}.blob.core.windows.net/{username}/stations-enriched'

In [7]:
stations_enriched = spark.read.csv(stations_enriched_path, header=True, inferSchema=False)

                                                                                

In [8]:
daily.printSchema()

root
 |-- ID: string (nullable = true)
 |-- DATE: string (nullable = true)
 |-- ELEMENT: string (nullable = true)
 |-- VALUE: double (nullable = true)
 |-- MEASUREMENT: string (nullable = true)
 |-- QUALITY: string (nullable = true)
 |-- SOURCE: string (nullable = true)
 |-- TIME: string (nullable = true)



In [9]:
stations_enriched.printSchema()

root
 |-- ID: string (nullable = true)
 |-- STATE_CODE: string (nullable = true)
 |-- COUNTRY_CODE: string (nullable = true)
 |-- LATITUDE: string (nullable = true)
 |-- LONGITUDE: string (nullable = true)
 |-- ELEVATION: string (nullable = true)
 |-- STATION_NAME: string (nullable = true)
 |-- GSN: string (nullable = true)
 |-- HCN_CRN_FLAG: string (nullable = true)
 |-- WMO_ID: string (nullable = true)
 |-- COUNTRY_NAME: string (nullable = true)
 |-- STATE_NAME: string (nullable = true)
 |-- ELEMENTS: string (nullable = true)
 |-- NUM_CORE_ELEMENTS: string (nullable = true)
 |-- NUM_OTHER_ELEMENTS: string (nullable = true)



In [10]:
daily_prcp = daily.filter(F.col('ELEMENT') == 'PRCP')

In [11]:
daily_prcp.count()

                                                                                

1084610240

In [12]:
# daily_prcp.show(20, False)

+-----------+--------+-------+-----+-----------+-------+------+----+
|ID         |DATE    |ELEMENT|VALUE|MEASUREMENT|QUALITY|SOURCE|TIME|
+-----------+--------+-------+-----+-----------+-------+------+----+
|ASN00030019|20100101|PRCP   |24.0 |NULL       |NULL   |a     |NULL|
|ASN00030021|20100101|PRCP   |200.0|NULL       |NULL   |a     |NULL|
|ASN00030022|20100101|PRCP   |408.0|NULL       |NULL   |a     |NULL|
|ASN00029121|20100101|PRCP   |820.0|NULL       |NULL   |a     |NULL|
|ASN00029126|20100101|PRCP   |0.0  |NULL       |NULL   |a     |NULL|
|ASN00029127|20100101|PRCP   |8.0  |NULL       |NULL   |a     |NULL|
|ASN00029129|20100101|PRCP   |174.0|NULL       |NULL   |a     |NULL|
|ASN00029130|20100101|PRCP   |86.0 |NULL       |NULL   |a     |NULL|
|ASN00029131|20100101|PRCP   |56.0 |NULL       |NULL   |a     |NULL|
|ASN00029132|20100101|PRCP   |800.0|NULL       |NULL   |a     |NULL|
|ASN00029136|20100101|PRCP   |22.0 |NULL       |NULL   |a     |NULL|
|ASN00029137|20100101|PRCP   |0.0 

In [13]:
prcp_country = daily_prcp.join(
    stations_enriched,
    how='inner',
    on='ID'
)

In [14]:
# prcp_country.show(20, False)

[Stage 10:>                                                         (0 + 1) / 1]

+-----------+--------+-------+------+-----------+-------+------+----+----------+------------+--------+---------+---------+------------+----+------------+------+------------+----------+---------+-----------------+------------------+
|ID         |DATE    |ELEMENT|VALUE |MEASUREMENT|QUALITY|SOURCE|TIME|STATE_CODE|COUNTRY_CODE|LATITUDE|LONGITUDE|ELEVATION|STATION_NAME|GSN |HCN_CRN_FLAG|WMO_ID|COUNTRY_NAME|STATE_NAME|ELEMENTS |NUM_CORE_ELEMENTS|NUM_OTHER_ELEMENTS|
+-----------+--------+-------+------+-----------+-------+------+----+----------+------------+--------+---------+---------+------------+----+------------+------+------------+----------+---------+-----------------+------------------+
|ALE00100939|19640101|PRCP   |0.0   |NULL       |NULL   |E     |NULL|NULL      |AL          |41.3331 |19.7831  |89.0     |TIRANA      |NULL|NULL        |NULL  |Albania     |NULL      |TMAX;PRCP|2                |0                 |
|ALE00100939|19860101|PRCP   |1.0   |NULL       |NULL   |E     |NULL|NUL

                                                                                

In [15]:
prcp_country = prcp_country.select('ID', 'DATE', F.col('VALUE').alias('PRCP'), 'COUNTRY_CODE', 'COUNTRY_NAME', 'LATITUDE', 'LONGITUDE', 'ELEVATION', 'ID')

In [16]:
# prcp_country.show(20, False)

+-----------+--------+-----+------------+------------+--------+---------+---------+-----------+
|ID         |DATE    |PRCP |COUNTRY_CODE|COUNTRY_NAME|LATITUDE|LONGITUDE|ELEVATION|ID         |
+-----------+--------+-----+------------+------------+--------+---------+---------+-----------+
|ASN00030019|20100101|24.0 |AS          |Australia   |-19.2647|143.6741 |536.0    |ASN00030019|
|ASN00030021|20100101|200.0|AS          |Australia   |-20.7389|144.4853 |430.0    |ASN00030021|
|ASN00030022|20100101|408.0|AS          |Australia   |-20.8192|144.2333 |316.4    |ASN00030022|
|ASN00029121|20100101|820.0|AS          |Australia   |-20.5958|139.6939 |280.0    |ASN00029121|
|ASN00029126|20100101|0.0  |AS          |Australia   |-20.7361|139.4817 |381.0    |ASN00029126|
|ASN00029127|20100101|8.0  |AS          |Australia   |-20.6778|139.4875 |340.3    |ASN00029127|
|ASN00029129|20100101|174.0|AS          |Australia   |-21.215 |140.2333 |282.0    |ASN00029129|
|ASN00029130|20100101|86.0 |AS          

In [17]:
prcp_country = prcp_country.withColumn('DATE', F.to_date(F.col('DATE'), 'yyyyMMdd'))

In [18]:
prcp_country = prcp_country.withColumn('DATE', F.date_trunc('year', 'DATE').cast('date'))

In [19]:
# prcp_country.show(20, False)

+-----------+----------+-----+------------+------------+--------+---------+---------+-----------+
|ID         |DATE      |PRCP |COUNTRY_CODE|COUNTRY_NAME|LATITUDE|LONGITUDE|ELEVATION|ID         |
+-----------+----------+-----+------------+------------+--------+---------+---------+-----------+
|ASN00030019|2010-01-01|24.0 |AS          |Australia   |-19.2647|143.6741 |536.0    |ASN00030019|
|ASN00030021|2010-01-01|200.0|AS          |Australia   |-20.7389|144.4853 |430.0    |ASN00030021|
|ASN00030022|2010-01-01|408.0|AS          |Australia   |-20.8192|144.2333 |316.4    |ASN00030022|
|ASN00029121|2010-01-01|820.0|AS          |Australia   |-20.5958|139.6939 |280.0    |ASN00029121|
|ASN00029126|2010-01-01|0.0  |AS          |Australia   |-20.7361|139.4817 |381.0    |ASN00029126|
|ASN00029127|2010-01-01|8.0  |AS          |Australia   |-20.6778|139.4875 |340.3    |ASN00029127|
|ASN00029129|2010-01-01|174.0|AS          |Australia   |-21.215 |140.2333 |282.0    |ASN00029129|
|ASN00029130|2010-01

In [20]:
grouped_prcp = (prcp_country
                .where(F.col('PRCP') != -9999)
                .withColumn('PRCP', F.col('PRCP') / 10.0)
                .groupBy('DATE', 'COUNTRY_CODE', 'COUNTRY_NAME')
                .agg(
                    F.avg('PRCP').alias('AVG_PRCP'),
                    F.count('*').alias('NUM_OBSERVATIONS'),
                    F.countDistinct('ID').alias('NUM_STATIONS')
                )
               )

In [None]:
# grouped_prcp.show(20, False)

[Stage 16:>                                                      (0 + 32) / 107]

In [None]:
grouped_prcp.printSchema()

In [None]:
prcp_year_country = grouped_prcp.withColumn('YEAR', F.year('DATE')).select('YEAR', 'COUNTRY_CODE', 'COUNTRY_NAME', 'AVG_PRCP', 'NUM_OBSERVATIONS', 'NUM_STATIONS')

In [None]:
# prcp_year_country.show(20, False)

In [None]:
prcp_year_country.write.mode('overwrite').option('header', True).csv(f'{output_path}/rsh224/prcp_year_country')

In [None]:
# !hdfs dfs -ls {output_path}/rsh224/prcp_year_country

### Answer 2(a) second part

In [None]:
prcp_year_country.agg(
    F.count('*').alias('COUNT'),
    F.min('YEAR').alias('START_YEAR'),
    F.max('YEAR').alias('END_YEAR'),
    F.min('AVG_PRCP').alias('MIN_PRCP'),
    F.max('AVG_PRCP').alias('MAX_PRCP'),
    F.avg('AVG_PRCP').alias('AVG_PRCP_TOTAL'),
    F.stddev('AVG_PRCP').alias('STD_PRCP')
).show()

In [None]:
# grouped_prcp.select('AVG_PRCP').describe().show()

In [None]:
# prcp_year_country.printSchema()

In [None]:
prcp_year_country.filter(F.col('NUM_OBSERVATIONS') > 300).orderBy(F.desc('AVG_PRCP')).show(20, False)

In [None]:
prcp_2024 = prcp_year_country.filter(F.col('YEAR') == 2024)

In [None]:
prcp_2024.orderBy(F.desc('AVG_PRCP')).show(20, False)

In [None]:
pdf = prcp_2024.toPandas()

In [None]:
import geopandas as gpd
import numpy as np

In [None]:
# World geometries
world = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))[["iso_a3","name","geometry"]]

In [None]:
world.head()

In [None]:
# Map 2-letter -> 3-letter country codes for a reliable join
# If pycountry isn't available, you can join by name as a fallback (see comment below).
try:    
    print('here')
    def a2_to_a3(a2):
        try:
            return pycountry.countries.get(alpha_2=a2).alpha_3
        except:
            return None
            
    pdf["iso_a3"] = pdf["COUNTRY_CODE"].apply(a2_to_a3)
    
    pdf["iso_a3"] = pdf.apply(lambda r: fixes.get(r["COUNTRY_CODE"], r["iso_a3"]), axis=1)
    world_join = world.merge(pdf, on="iso_a3", how="left")
except Exception:    
    print('WORLD')
    # Fallback: join by cleaned names (expect more mismatches)
    world_join = world.merge(pdf, left_on="name", right_on="COUNTRY_NAME", how="left")



In [None]:
# Robust color limits (avoid outliers skewing the map)
vmin = np.nanpercentile(world_join["AVG_PRCP"], 2)
vmax = np.nanpercentile(world_join["AVG_PRCP"], 98)

# Reproject and plot
world_robin = world_join.to_crs("ESRI:54030")  # Robinson projection
ax = world_robin.plot(
    column="AVG_PRCP",
    cmap="Blues",
    vmin=vmin, vmax=vmax,
    legend=True,
    linewidth=0.2, edgecolor="white",
    missing_kwds={"color": "lightgrey", "edgecolor": "white", "hatch": "///", "label": "No data"},
    figsize=(11.69, 8.27),  # A4 landscape
)

ax.set_title("Average daily rainfall (mm) in 2024 — country means (station-balanced)", fontsize=12, pad=8)
ax.axis("off")

In [None]:
output_path = f'wasbs://{azure_user_container_name}@{azure_account_name}.blob.core.windows.net'

In [None]:
output_states_count_path = f'{output_path}/rsh224/states_count_stations'
states_enriched.write.mode('overwrite').option('header', True).csv(output_states_count_path)

In [None]:
!hdfs dfs -ls {output_path}/rsh224/

In [None]:
stop_spark()