# In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, from_json, explode
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, LongType, ArrayType
from pyspark.sql import functions as F
import tensorflow as tf
from pyspark.sql.functions import udf  

# Spark session for this job. The Kafka source connector and the Kafka client
# library are pulled in through the "spark.jars.packages" setting.
spark = (
    SparkSession.builder
    .appName("Stream Processing")
    .config(
        "spark.jars.packages",
        "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.3,"
        "org.apache.kafka:kafka-clients:3.5.0",
    )
    .getOrCreate()
)

def mse(y_true, y_pred):
    """Return the Keras mean-squared-error loss between ``y_true`` and ``y_pred``.

    Thin wrapper around ``tf.keras.losses.mean_squared_error``, defined here so
    it can be registered as a Keras custom object under the name ``"mse"``.
    """
    loss_fn = tf.keras.losses.mean_squared_error
    return loss_fn(y_true, y_pred)


# Make "mse" resolvable through Keras' global custom-object registry.
tf.keras.utils.get_custom_objects().update(mse=mse)
    
def _build_schema():
    """Return the Spark schema used to parse each Kafka message's JSON value.

    Nesting: data -> plan -> itineraries[] -> legs[], where each leg has
    ``from``/``to`` endpoints and a ``trip``.
    """

    def endpoint():
        # Shared shape of a leg's "from" and "to" endpoints.
        pattern_code = StructType([StructField("code", StringType(), True)])
        stop = StructType([StructField("patterns", ArrayType(pattern_code))])
        return StructType([
            StructField("lat", DoubleType()),
            StructField("lon", DoubleType()),
            StructField("name", StringType()),
            StructField("stop", stop),
        ])

    trip = StructType([
        StructField("gtfsId", StringType()),
        StructField("pattern", StructType([
            StructField("trip_pattern_code", StringType(), True),
        ])),
        StructField("tripHeadsign", StringType()),
    ])

    leg = StructType([
        StructField("mode", StringType()),
        StructField("startTime", LongType()),
        StructField("endTime", LongType()),
        StructField("from", endpoint()),
        StructField("to", endpoint()),
        StructField("trip", trip),
    ])

    itinerary = StructType([
        StructField("walkDistance", DoubleType()),
        StructField("duration", IntegerType()),
        StructField("legs", ArrayType(leg)),
    ])

    plan = StructType([StructField("itineraries", ArrayType(itinerary))])
    data = StructType([StructField("plan", plan)])
    return StructType([StructField("data", data)])


schema = _build_schema()
 
# Raw Kafka records. The message payload is in the binary "value" column.
raw_stream = (
    spark.readStream.format("kafka")
    .option("kafka.bootstrap.servers", "localhost:9092")
    .option("subscribe", "your_kafka_topic")
    .load()
)

# Decode the payload as a string and parse it with `schema` into one struct column.
parsed_stream = raw_stream.selectExpr("CAST(value AS STRING) as json_value").select(
    from_json(col("json_value"), schema).alias("data")
)

# Flatten in two steps. Spark allows only one generator (explode) per select,
# and each level is an array of its own:
#   1. one row per itinerary;
#   2. one row per leg, keeping the parent itinerary's walkDistance.
# explode() drops rows whose array is null or empty, e.g. a message that
# failed to parse.
df_itineraries = parsed_stream.select(
    explode(col("data.data.plan.itineraries")).alias("itinerary")
)
df_legs = df_itineraries.select(
    col("itinerary.walkDistance").alias("walk_distance"),
    explode(col("itinerary.legs")).alias("leg"),
)

# "leg" is now a single struct, so its fields are plain column selects.
final_stream = df_legs.select(
    col("leg.startTime").alias("start_time"),
    col("leg.endTime").alias("end_time"),
    col("leg.trip.gtfsId").alias("trip_id"),
    col("leg.mode").alias("mode"),
    col("leg.from.name").alias("from_name"),
    col("leg.to.name").alias("to_name"),
    col("walk_distance"),
)
# Divide by 60000 to get minutes, assuming start/end times are epoch
# milliseconds (TODO confirm against the producer).
final_stream = final_stream.withColumn(
    "leg_duration",
    (col("end_time") - col("start_time")) / 60000.0,
)
 
# Load the trained Keras model on the driver. The custom "mse" loss is supplied
# explicitly so deserialization can resolve it.
# NOTE(review): predict_udf references this global, so Spark must serialize the
# model into the UDF's closure; confirm that works for this TF/Keras version.
try:
    model = tf.keras.models.load_model("final_model.h5", custom_objects={"mse": mse})
except Exception as e:
    # Chain the cause so the underlying load error stays in the traceback.
    raise RuntimeError(f"Failed to load TensorFlow model: {e}") from e
 
@udf(DoubleType())
def predict_udf(walk_distance):
    """Run the Keras ``model`` on a single walk-distance value.

    Evaluated row by row inside a Spark Python worker.

    Args:
        walk_distance: Numeric walk distance for the row, or None for SQL NULL.

    Returns:
        The model's first output as a float. Returns None for a NULL input,
        and NaN if conversion or prediction fails (the error is printed to
        the worker's stdout).
    """
    # Imported in the function so the name resolves in the Spark worker. It is
    # kept outside the try so a missing numpy fails loudly instead of
    # becoming NaN.
    import numpy as np

    if walk_distance is None:
        # NULL in, NULL out, rather than a printed error and NaN for every
        # missing value.
        return None
    try:
        # The model receives a batch containing one sample with one feature.
        features_array = np.array([[float(walk_distance)]])
        prediction = model.predict(features_array, verbose=0)
        return float(prediction[0][0])
    except Exception as e:
        # Best-effort per row: report and emit NaN rather than failing the
        # micro-batch.
        print(f"Prediction error: {e}")
        return float("nan")
 
# Attach the per-row model prediction computed from the leg's walk distance.
prediction = predict_udf(col("walk_distance"))
prediction_stream = final_stream.withColumn("prediction", prediction)

# Columns written to the console sink, in display order.
_OUTPUT_COLUMNS = (
    "trip_id",
    "mode",
    "from_name",
    "to_name",
    "leg_duration",
    "prediction",
)

# Print each micro-batch's new rows to stdout without truncating wide values,
# then block this cell until the streaming query terminates.
query = (
    prediction_stream.select(*_OUTPUT_COLUMNS)
    .writeStream
    .outputMode("append")
    .format("console")
    .option("truncate", False)
    .start()
)

query.awaitTermination()