# In [None]:
from pyspark.sql import SparkSession
import socket
import os
import json

# Job settings. Entries read through os.environ can be overridden per run;
# the second argument is the fallback used when the variable is unset.
# The 'your-...' fallbacks are placeholders and must be replaced (or the
# matching environment variable exported) before running for real.

# Kubernetes identity (fixed values, not read from the environment).
KUBERNETES_SERVICE_ACCOUNT = 'airflow-worker'
KUBERNETES_NAMESPACE = 'airflow'

# Executor / driver sizing. The counts are parsed to int at load time, so a
# non-numeric value in the environment raises ValueError here.
NUM_EXECUTORS = int(os.environ.get('NUM_EXECUTORS', '1'))
EXECUTOR_CORES = int(os.environ.get('EXECUTOR_CORES', '1'))
EXECUTOR_MEMORY = os.environ.get('EXECUTOR_MEMORY', '1g')
DRIVER_MEMORY = os.environ.get('DRIVER_MEMORY', '1g')

# Container image and scheduling.
SPARK_IMAGE = os.environ.get('SPARK_IMAGE', 'sergeygazaryan13/spark3.5.0-python3:v1.0.0')
YUNIKORN_QUEUE = os.environ.get('YUNIKORN_QUEUE', 'your-queue')
EXECUTOR_LABEL = os.environ.get('EXECUTOR_LABEL', 'true')

# S3 connection settings.
S3_ACCESS_KEY = os.environ.get('S3_ACCESS_KEY', 'your-s3-access-key')
S3_SECRET_KEY = os.environ.get('S3_SECRET_KEY', 'your-s3-secret-key')
S3_ENDPOINT = os.environ.get('S3_ENDPOINT', 'your-s3-endpoint')

# Resolve this machine's address; it is handed to Spark as spark.driver.host.
# Resolution is best-effort: on failure the error is reported and the script
# continues with "localhost".
# NOTE(review): "localhost" is presumably not reachable from executor pods —
# confirm the fallback is intended rather than failing fast.
try:
    driver_host = socket.gethostbyname(socket.gethostname())
except (OSError, UnicodeError) as e:
    # OSError covers socket.gaierror / socket.herror (lookup failures);
    # UnicodeError covers host names that cannot be encoded for the lookup.
    # Anything else is a programming error and should not be swallowed.
    print(f"Error getting driver host: {e}")
    driver_host = "localhost"

# Build (or reuse) the SparkSession for this job. The master URL points at
# the in-cluster Kubernetes API server; sizing, image and S3 settings come
# from the module-level constants defined above.
spark = (
    SparkSession.builder
    .appName(f"{KUBERNETES_SERVICE_ACCOUNT}-{NUM_EXECUTORS}e-{EXECUTOR_CORES}c-{EXECUTOR_MEMORY}")
    .master("k8s://https://kubernetes.default.svc.cluster.local:443")
    # Output committer settings.
    .config("mapreduce.fileoutputcommitter.algorithm.version", 2)
    .config("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
    # Kubernetes: image, identity and namespace.
    .config("spark.kubernetes.container.image", SPARK_IMAGE)
    .config("spark.kubernetes.container.image.pullPolicy", "Always")
    # NOTE(review): the service account is hard-coded to "spark";
    # KUBERNETES_SERVICE_ACCOUNT only feeds the app name above. Confirm
    # which account the pods are meant to run as.
    .config("spark.kubernetes.authenticate.driver.serviceAccountName", "spark")
    .config("spark.kubernetes.authenticate.executor.serviceAccountName", "spark")
    # Use the shared constant so the namespace is defined in one place.
    .config("spark.kubernetes.namespace", KUBERNETES_NAMESPACE)
    .config("spark.kubernetes.executor.limit.cores", EXECUTOR_CORES)
    # NOTE(review): YUNIKORN_QUEUE is defined above but not applied by any
    # setting here — confirm whether queue placement is configured elsewhere.
    # S3A access: static keys, custom endpoint, path-style URLs, SSL disabled.
    .config("spark.hadoop.fs.s3a.access.key", S3_ACCESS_KEY)
    .config("spark.hadoop.fs.s3a.secret.key", S3_SECRET_KEY)
    .config("spark.hadoop.fs.s3a.endpoint", S3_ENDPOINT)
    .config("spark.hadoop.fs.s3a.path.style.access", "true")
    .config("spark.hadoop.fs.s3a.connection.ssl.enabled", "false")
    .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
    # Delta Lake SQL extension and catalog, with a Hive catalog implementation.
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
    .config("spark.sql.catalogImplementation", "hive")
    # NOTE(review): the bucket name contains an underscore — confirm the
    # S3A connector in the image accepts it.
    .config("spark.sql.warehouse.dir", "s3a://intent_tmp/warehouse")
    # Executor / driver sizing and placement.
    .config("spark.executor.instances", NUM_EXECUTORS)
    .config("spark.executor.cores", EXECUTOR_CORES)
    .config("spark.executor.memory", EXECUTOR_MEMORY)
    .config("spark.driver.memory", DRIVER_MEMORY)
    # Node selector: executors go to nodes labelled general=<EXECUTOR_LABEL>.
    .config("spark.kubernetes.executor.node.selector.general", EXECUTOR_LABEL)
    .config("spark.driver.maxResultSize", "10g")
    # Driver address resolved earlier in this script.
    .config("spark.driver.host", driver_host)
    # No unit suffix is given — presumably interpreted as seconds; confirm.
    .config("spark.rpc.askTimeout", 36000)
    .getOrCreate()
)

# Demo job: count the elements of a small in-memory RDD and the
# space-separated words inside them, then print the result as JSON.
# The work runs inside try/finally so the SparkSession is stopped even when
# a Spark action raises; otherwise the session would be left open.
try:
    data_array = ["Apache Spark", "is", "a unified analytics engine", "for large-scale data processing."]
    rdd = spark.sparkContext.parallelize(data_array)

    num_elements = rdd.count()
    print(f"Number of elements in the RDD: {num_elements}")

    # Split every element on single spaces and count the resulting tokens.
    num_words = rdd.flatMap(lambda line: line.split(" ")).count()
    print(f"Number of words in the RDD: {num_words}")

    # Emit the results as one JSON line on stdout.
    # NOTE(review): the original comment says this pushes the results to
    # XCom — presumably the calling Airflow task reads stdout; confirm.
    output = {'num_elements': num_elements, 'num_words': num_words}
    print(json.dumps(output))
finally:
    # Stop the SparkSession on both the success and the failure path.
    spark.stop()
