In [None]:
import os
import os.path as op
import psutil
from pathlib import Path
from IPython.display import display

import findspark

In [None]:
# Let findspark locate the Spark installation so that the `pyspark` imports
# in the following cell can be resolved.
findspark.init()

In [None]:
import pyspark.sql.functions as F
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

In [None]:
# Pick the scratch directory for Spark data. Candidates, in priority order:
#   1. $SLURM_TMPDIR  (resolved strictly, so it must already exist)
#   2. $TMPDIR        (used as given)
#   3. $SCRATCH/tmp
# Unset or empty variables are skipped; with no candidate at all we fail early.
slurm_tmpdir = os.environ.get('SLURM_TMPDIR')
generic_tmpdir = os.environ.get("TMPDIR")
scratch_dir = os.environ.get('SCRATCH')

if not (slurm_tmpdir or generic_tmpdir or scratch_dir):
    raise Exception("Could not find a temporary directory for SPARK data!")

if slurm_tmpdir:
    SPARK_TMPDIR = Path(slurm_tmpdir).resolve(strict=True)
elif generic_tmpdir:
    SPARK_TMPDIR = Path(generic_tmpdir)
else:
    SPARK_TMPDIR = Path(scratch_dir) / 'tmp'

# Create the directory (and any missing parents); a no-op if it already exists.
SPARK_TMPDIR.mkdir(parents=True, exist_ok=True)

In [None]:
# Total memory reported by psutil, converted from bytes to whole MiB.
# Used below to size the Spark driver and executor memory settings.
vmem = psutil.virtual_memory().total // 1024**2

In [None]:
# Assemble the Spark configuration for one of two deployment modes:
#   * standalone cluster, when SPARK_MASTER_HOST is set in the environment;
#   * local mode otherwise, using every logical CPU of this machine.
# Defines SPARK_MASTER (the master URL) and spark_conf for the session cell.
spark_conf = SparkConf()
# NOTE(review): Spark 3.x documents this key as deprecated in favour of
# "spark.sql.execution.arrow.pyspark.enabled". It is kept unchanged because a
# later cell asserts on this exact key.
spark_conf.set("spark.sql.execution.arrow.enabled", "true")

# psutil.cpu_count() may return None when the count cannot be determined;
# fall back to a single CPU instead of failing on arithmetic with None.
cpu_count = psutil.cpu_count() or 1

if "SPARK_MASTER_HOST" in os.environ:
    SPARK_MASTER = f"spark://{os.environ['SPARK_MASTER_HOST']}:7077"

    CORES_PER_WORKER = 16
    num_workers = max(1, cpu_count // CORES_PER_WORKER)
    print(f"num_workers: {num_workers}")
    # Make sure we are not wasting any cores: warn whenever the CPU count is
    # not an exact multiple of CORES_PER_WORKER.
    if cpu_count % CORES_PER_WORKER != 0:
        print("WARNING!!! Not using all available CPUs!")

    spark_conf.set("spark.driver.memory", "65000M")
    spark_conf.set("spark.driver.maxResultSize", "65000M")

    spark_conf.set("spark.executor.cores", f"{CORES_PER_WORKER}")
    # 80% of (total memory minus 1 GiB), split evenly between the workers.
    spark_conf.set("spark.executor.memory", f"{int((vmem - 1024) * 0.8 / num_workers)}M")

    spark_conf.set("spark.network.timeout", "600s")
    spark_conf.set("spark.sql.shuffle.partitions", "2001")

    # spark_conf.set("spark.local.dir", SPARK_TMPDIR.as_posix())
else:
    SPARK_MASTER = f"local[{cpu_count}]"

    # Driver gets half of the machine's memory, capped at 64000 MiB; the
    # executor setting receives whatever remains.
    driver_memory = min(64000, int(vmem // 2))
    executor_memory = int(vmem - driver_memory)

    spark_conf.set("spark.driver.memory", f"{driver_memory}M")
    spark_conf.set("spark.driver.maxResultSize", f"{driver_memory}M")

    spark_conf.set("spark.executor.memory", f"{executor_memory}M")

    # spark_conf.set("spark.network.timeout", "600s")
    spark_conf.set("spark.sql.shuffle.partitions", "200")

    # The Spark property is "spark.local.dir" (singular). The previously used
    # "spark.local.dirs" is not a Spark property, so the scratch directory
    # setting was silently ignored.
    spark_conf.set("spark.local.dir", SPARK_TMPDIR.as_posix())
    # Point the JVM temp directory of driver and executors at the same location.
    spark_conf.set("spark.driver.extraJavaOptions", f"-Djava.io.tmpdir={SPARK_TMPDIR.as_posix()}")
    spark_conf.set("spark.executor.extraJavaOptions", f"-Djava.io.tmpdir={SPARK_TMPDIR.as_posix()}")

In [None]:
# Optional hook: if a SPARK_CONF_EXTRA mapping has been defined before this
# cell runs, copy each of its key/value pairs into spark_conf (overriding any
# value set above). When the name is not defined, nothing happens.
if "SPARK_CONF_EXTRA" in globals():
    spark_conf.setAll(SPARK_CONF_EXTRA.items())

In [None]:
# The application is named after the parent directory of the current working
# directory.
app_name = op.basename(op.dirname(os.getcwd()))

# Configure the builder step by step, then obtain the session.
session_builder = SparkSession.builder.master(SPARK_MASTER)
session_builder = session_builder.appName(app_name)
session_builder = session_builder.config(conf=spark_conf)

spark = session_builder.getOrCreate()

In [None]:
# Sanity check: the live session must report the Arrow option set on spark_conf.
# NOTE(review): presumably this guards against getOrCreate() handing back an
# already-running session that was built with different settings — confirm.
assert spark.conf.get("spark.sql.execution.arrow.enabled") == "true"

In [None]:
# Show the session object in the notebook output using IPython's rich display.
display(spark)