In [0]:
%run ./_resources/01-setup $reset_all_data=false

## The aggregation function that updates each session via `applyInPandasWithState()`:

In [0]:
from typing import Tuple, Iterator
from pyspark.sql.streaming.state import GroupState, GroupStateTimeout

# Inactivity window after which a session is closed: 30 seconds, expressed in milliseconds.
max_session_duration = 30_000

def func(key: Tuple[str], events: Iterator[pd.DataFrame], state: GroupState) -> Iterator[pd.DataFrame]:
    """Maintain one click session per user; meant to be passed to applyInPandasWithState().

    Args:
        key: Grouping key tuple supplied by Spark; unpacked into a single ``user_id``.
        events: Iterator of pandas DataFrames holding the new rows for this key.
            Each frame must expose an ``event_date`` column.
        state: Per-key state handle storing ``(user_id, click_count, start_time, end_time)``.

    Yields:
        A one-row DataFrame with columns ``user_id``, ``click_count``, ``start_time``,
        ``end_time`` and ``status``. On timeout the state is removed and the row is
        emitted with status ``"offline"``; otherwise the state is updated and the row
        carries the freshly computed status.
    """
    # Unpack the key tuple passed by Spark.
    # Even with a single group key like 'user_id', it's wrapped in a tuple.
    (user_id,) = key

    # Resume from the stored session if there is one, otherwise start from neutral
    # values so that the min/max/sum below work for the first batch.
    if state.exists:
        (user_id, click_count, start_time, end_time) = state.get
    else:
        click_count = 0
        start_time = sys.maxsize
        end_time = 0

    if state.hasTimedOut:
        # End of the session: drop the session from the state.
        state.remove()
        # Emit a final offline session update.
        yield pd.DataFrame({"user_id": [user_id], "click_count": [click_count],
                            "start_time": [start_time], "end_time": [end_time], "status": ["offline"]})
    else:
        # Events can arrive out of order, so fold every frame into the running min/max/count.
        for df in events:
            start_time = min(start_time, df["event_date"].min())
            end_time = max(end_time, df["event_date"].max())
            click_count += len(df)

        # Update the state with the new values (cast to plain ints for the LONG state columns):
        state.update((user_id, int(click_count), int(start_time), int(end_time)))

        # (Re)arm the inactivity timeout; setTimeoutDuration expects milliseconds,
        # which is the unit max_session_duration is expressed in.
        state.setTimeoutDuration(max_session_duration)

        # Compute the status so that sessions replayed after a restart, whose last event is
        # already older than the inactivity window, are flagged offline immediately.
        # time.time() is in seconds, so the millisecond window is converted before comparing.
        # NOTE(review): assumes event_date holds epoch seconds — confirm against the events table.
        now = int(time.time())
        status = "online" if end_time >= now - max_session_duration // 1000 else "offline"

        # Emit the change.
        # We could also yield an empty dataframe if we only want to emit when the session is closed: yield pd.DataFrame()
        yield pd.DataFrame({"user_id": [user_id], "click_count": [click_count],
                            "start_time": [start_time], "end_time": [end_time], "status": [status]})

In [0]:
# DDL schemas: the per-user state kept between micro-batches, and the rows emitted by func.
state_schema = "user_id STRING, click_count LONG, start_time LONG, end_time LONG"
output_schema = "user_id STRING, click_count LONG, start_time LONG, end_time LONG, status STRING"

# Group the events stream by user and run func on each group with arbitrary state.
# ProcessingTimeTimeout enables processing-time-based timeouts per group, which is what
# lets func use state.hasTimedOut and state.setTimeoutDuration(...) to close idle sessions.
sessions = (
    spark.readStream.table("events")
    .groupBy(F.col("user_id"))
    .applyInPandasWithState(
        func=func,
        outputStructType=output_schema,
        stateStructType=state_schema,
        outputMode="append",
        timeoutConf=GroupStateTimeout.ProcessingTimeTimeout,
    )
)

In [0]:
# Live preview of the sessions stream. The checkpoint folder is passed through the
# documented `checkpointLocation` keyword rather than positionally.
# NOTE(review): get_chkp_folder() is defined outside this notebook — presumably returns a fresh checkpoint path.
display(sessions, checkpointLocation=get_chkp_folder())

## Updating the session table with number of clicks and end/start time:

We want to have the session information in real time for each user. 

To do that, we'll create a Session table. Every time we update the state, we'll UPSERT the session information with a Delta `MERGE` operation, invoked from `foreachBatch`:

- If the session doesn't exist, we add it.
- If it exists, we update it with the new count and potential new status.

In [0]:
from delta.tables import DeltaTable

def upsert_sessions(df, epoch_id):
    """Upsert one micro-batch of session updates into the `sessions` Delta table.

    Intended to be used as a `foreachBatch` callback.

    Args:
        df: Micro-batch DataFrame of session rows; matched against the table on `user_id`.
        epoch_id: Micro-batch identifier supplied by `foreachBatch` (unused, kept for the
            callback signature).
    """
    # Use the session attached to the micro-batch instead of the notebook-level global,
    # and the public catalog API instead of the private JVM handle (spark._jsparkSession).
    batch_spark = df.sparkSession

    # Create the table whenever it is missing (we need it to be able to perform the merge).
    # This is checked on every batch, not only epoch 0, so a query restarted from an existing
    # checkpoint still works if the table is absent.
    # limit(0) writes the batch schema with zero rows.
    if not batch_spark.catalog.tableExists("sessions"):
        (df.limit(0).write
            .option('mergeSchema', 'true')
            .mode('append')
            .saveAsTable("sessions"))

    # Load the Delta table by name and MERGE: update the row of a known user, insert otherwise.
    (DeltaTable.forName(batch_spark, "sessions").alias("s")
        .merge(
            source=df.alias("u"),
            condition="s.user_id = u.user_id")
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute())
    
# Start the streaming query: each micro-batch of session updates is handed to
# upsert_sessions, with progress tracked under the "gold" checkpoint folder.
(sessions.writeStream
 .foreachBatch(upsert_sessions)
 .option("checkpointLocation", volume_folder + "/checkpoints/gold")
 .start())

# NOTE(review): Utils.wait_for_table is defined outside this notebook — presumably blocks
# until the "sessions" table is available; confirm against the setup notebook.
Utils.wait_for_table("sessions")


In [0]:
%sql
-- Inspect the current content of the sessions table.
SELECT * FROM sessions;

In [0]:
%sql
-- Average of (end_time - start_time) across all rows of sessions, cast to INT.
SELECT
  CAST(AVG(end_time - start_time) AS INT) AS average_session_duration
FROM sessions;

In [0]:
# Stop the streams started by this notebook through the shared Utils helper.
# NOTE(review): sleep_time=120 is presumably a delay in seconds before stopping — confirm against Utils.stop_all_streams.
Utils.stop_all_streams(sleep_time=120)