# ETLFeatures OOM problem

Проекты джойнилки, в которых больше 20 джойнов, периодически падают с OOM на стадии джойна — перед записью результата в таблицу.

Шаг 1: воспроизвести проблему — сджойнить 30 датафреймов (доменов) и посмотреть на метрики/логи в Spark UI.

Шаг 2: экспериментировать с вариантами решения проблемы.

После добавления чекпойнта на каждые 10 джойнов джоба перестала падать на spill, но начала падать с `Container killed by YARN for exceeding memory limits`.

Увеличение памяти на экзекуторах и количества партиций позволило успешно завершить джобу без единого task fail (если не считать container preemption).

Использование `df.checkpoint(eager=True)` вполне работает, но создаёт примерно в 3 раза большую нагрузку на HDFS.

In [1]:
import pprint
from pprint import pformat

import os
import datetime

from operator import and_
from collections import defaultdict

import six
import luigi
import pyspark.sql.functions as sqlfn

import json
import itertools as it

from pyspark.sql.types import MapType, ArrayType, FloatType, StringType, NumericType

if six.PY3:
    from functools import reduce  # make flake8 happy

In [None]:
# pprint.pprint(dict(os.environ), width=1)
def log(obj, msg=""):
    """Print an optional header line, then the type and pretty-printed value of ``obj``.

    :param obj: any object to inspect.
    :param msg: header printed on its own line first; skipped when empty.
    """
    if msg:
        print(msg)
    # width=1 makes pformat wrap containers element by element, which keeps
    # big mappings such as os.environ readable.
    print("type: {}\ndata: {}".format(type(obj), pformat(obj, indent=1, width=1)))


log(os.environ, "os.environ")
# print("") instead of print(): the recorded outputs (`<type 'list'>`) show a
# Python 2 kernel, where a bare print() is a print statement with an empty
# tuple and outputs "()" rather than a blank line.
print("")
log(dict(os.environ), "dict(os.environ)")

In [None]:
import os
import sys

from pyspark.sql import SparkSession, SQLContext

# Ship the driver's conda environment to the YARN executors as a zip archive.
# "tmpenv/<env>.zip#tmpenv" distributes the zip under the alias "tmpenv", and
# PYSPARK_PYTHON points the workers at the interpreter inside the unpacked env.
TMP_ENV_BASEDIR = "tmpenv"  # reserved local directory holding the env archive

# Presumably sys.executable is <env_dir>/bin/python, hence two dirname() hops.
env_dir = os.path.dirname(os.path.dirname(sys.executable))
env_name = os.path.basename(env_dir)

os.environ["PYSPARK_PYTHON"] = "{}/{}/bin/python".format(TMP_ENV_BASEDIR, env_name)
env_archive = "{0}/{1}.zip#{0}".format(TMP_ENV_BASEDIR, env_name)

# Build the archive (needed only the first time):
# !rm -rf {TMP_ENV_BASEDIR} && mkdir {TMP_ENV_BASEDIR} && cd {TMP_ENV_BASEDIR} && rsync -a {env_dir} . && zip -rq {env_name}.zip {env_name}

log(env_archive)

In [None]:
# Create a YARN Spark session that runs the shipped prj conda environment.
# The driver log4j config could also be given on the command line:
#   spark-submit ... --driver-java-options "-Dlog4j.configuration=file:/home/vlk/driver_log4j.properties"
# Considered: adding "-Xss10M" to spark.driver.extraJavaOptions against a
# catalyst "SO" (presumably StackOverflowError) while building a partition
# filter expression.

queue = "root.priority"

# Shuffle partitions sized for ~200 GB of data: (200 * 4) * 2 * 2 * 4 = 12800.
sssp = (200 * 4) * 2 * 2 * 4

# Apply every option to the builder in order, then create (or reuse) the session.
spark = reduce(
    lambda builder, option: builder.config(option[0], option[1]),
    [
        ("spark.yarn.queue", queue),
        ("spark.sql.shuffle.partitions", sssp),
        # executors / driver resources
        ("spark.executor.instances", "2"),
        ("spark.executor.cores", "4"),
        ("spark.executor.memory", "24G"),
        ("spark.executor.memoryOverhead", "8G"),
        ("spark.driver.memory", "4G"),
        ("spark.driver.maxResultSize", "1G"),
        ("spark.driver.extraJavaOptions", "-Dlog4j.configuration=file:/home/vlk/driver2_log4j.properties"),
        ("spark.speculation", "true"),
        # dynamic allocation
        ("spark.dynamicAllocation.enabled", "true"),
        ("spark.dynamicAllocation.minExecutors", "2"),
        ("spark.dynamicAllocation.maxExecutors", "512"),
        ("spark.dynamicAllocation.executorIdleTimeout", "300s"),
        # network / shuffle fetch
        ("spark.network.timeout", "800s"),
        ("spark.reducer.maxReqsInFlight", "10"),
        ("spark.shuffle.io.retryWait", "60s"),
        ("spark.shuffle.io.maxRetries", "10"),
        # serialization
        ("spark.serializer", "org.apache.spark.serializer.KryoSerializer"),
        ("spark.kryoserializer.buffer.max", "1024m"),
        # Hive
        ("spark.hadoop.hive.exec.dynamic.partition", "true"),
        ("spark.hadoop.hive.exec.dynamic.partition.mode", "nonstrict"),
        ("spark.hadoop.hive.exec.max.dynamic.partitions", "1000000"),
        ("spark.hadoop.hive.exec.max.dynamic.partitions.pernode", "100000"),
        ("spark.hadoop.hive.metastore.client.socket.timeout", "60s"),
        # misc
        ("spark.ui.enabled", "true"),
        ("spark.sql.sources.partitionColumnTypeInference.enabled", "false"),
        ("spark.yarn.dist.archives", env_archive),
    ],
    SparkSession.builder.master("yarn-client").appName("TRG-77961-test-ipynb"),
).getOrCreate()
# Alternatives tried:
# .config("spark.driver.extraJavaOptions", "-Xss10M -Dlog4j.configuration=file:/home/vlk/driver_log4j.properties")
# .config("spark.jars", "hdfs:/lib/dm/prj-transformers-assembly-dev-1.5.1.jar")

sql_ctx = SQLContext(spark.sparkContext)
(spark, sql_ctx)

In [5]:
# end of env. setup

In [12]:
import os
import numpy as np

from pprint import pformat

import luigi
import pyspark.sql.functions as sqlfn

from pyspark.storagelevel import StorageLevel
from pyspark.sql import DataFrame, SQLContext
from pyspark.sql.types import (
    MapType, ArrayType, FloatType, DoubleType, StringType, StructType, IntegralType, IntegerType
)
from pyspark.sql.utils import CapturedException
from pyspark.ml.wrapper import JavaWrapper

from luigi.contrib.hdfs import HdfsTarget

from dmprj.apps.utils.common import add_days
from dmprj.apps.utils.common.hive import format_table, select_clause
from dmprj.apps.utils.control.client.exception import FailedStatusException, MissingDepsStatusException
from dmprj.apps.utils.control.client.logs import ControlLoggingMixin
from dmprj.apps.utils.common.external_program import AvoidLuigiFlatTaskRunner

from dmprj.apps.utils.common import unfreeze_json_param
from dmprj.apps.utils.common.fs import HdfsClient
from dmprj.apps.utils.common.hive import FindPartitionsEngine
from dmprj.apps.utils.common.luigix import HiveTableSchemaTarget

from dmprj.apps.utils.common.hive import select_clause
from dmprj.apps.utils.common.spark import CustomUDFLibrary, insert_into_hive
from dmprj.apps.utils.common.luigix import HiveExternalTask, HiveGenericTarget
from dmprj.apps.utils.control.luigix import ControlApp, ControlDynamicOutputPySparkTask

from dmprj.common.hive import HiveMetastoreClient, HiveThriftSASLContext

In [7]:
# CustomUDFLibrary(spark, "hdfs:/lib/dm/prj-transformers-assembly-dev-1.5.1.jar").register_all_udf()

In [6]:
def show(df, message="dataframe", nlines=20, truncate=False, heavy=True):
    """Print the schema of ``df``; for light DataFrames also its row count and rows.

    :param df: DataFrame to inspect.
    :param message: label printed before the row count.
    :param nlines: number of rows passed to ``df.show``.
    :param truncate: ``truncate`` argument passed to ``df.show``.
    :param heavy: when True (default) only the schema is printed, skipping the
        ``count()`` and ``show()`` actions.
    :return: ``df`` itself, so the call can be chained.
    """
    cheap_enough = not heavy
    if cheap_enough:
        row_count = df.count()
        print("\n{}, rows: {}:".format(message, row_count))
    df.printSchema()
    if cheap_enough:
        df.show(nlines, truncate)
    return df

In [7]:
# Source partition to read: Hive table dmprj_source.hid_dataset_3_0,
# partition dt=2022-03-30, rows with uid_type 'HID'.
db, table = "dmprj_source", "hid_dataset_3_0"
dt, uid_type = "2022-03-30", "HID"

In [None]:
# Hive query for one dt partition and one uid type of the source table.
# Placeholders are filled from db/table/dt/uid_type explicitly; the previous
# `.format(**globals())` silently depended on every notebook global. The values
# are notebook constants, not user input, so plain string formatting is fine.
source_expr = "select * from {db}.{table} where dt='{dt}' and uid_type='{uid_type}'".format(
    db=db, table=table, dt=dt, uid_type=uid_type
)
log(source_expr)

In [9]:
# Every domain DataFrame built later is a projection of this one, so cache it.
source_df = spark.sql(source_expr)
source_df = source_df.cache()

In [None]:
# heavy=False: also print the row count and the first rows of the source.
show(source_df, message="source", heavy=False)

In [11]:
# printSchema()-style listing of the source columns (presumably copied from
# `source_df.printSchema()`): 25 domain columns plus uid, dt and uid_type.
# `strip` turns each line into a domain name (or "" for non-domain lines).
columns_text = """
 |-- uid: string (nullable = true)
 |-- shows_activity: array (nullable = true)
 |-- banner_cats_ctr_wscore_30d: map (nullable = true)
 |-- hw_osm_ratios: map (nullable = true)
 |-- hw_osm_stats: array (nullable = true)
 |-- topics_motor200: map (nullable = true)
 |-- living_region_ids: map (nullable = true)
 |-- living_region_stats: array (nullable = true)
 |-- all_profs: array (nullable = true)
 |-- sn_epf: array (nullable = true)
 |-- sn_topics100: array (nullable = true)
 |-- app_stats: array (nullable = true)
 |-- app_topics100: map (nullable = true)
 |-- tp_os_counts: map (nullable = true)
 |-- tp_device_stats: array (nullable = true)
 |-- app_cats_activity: map (nullable = true)
 |-- mob_operators: map (nullable = true)
 |-- device_vendors: map (nullable = true)
 |-- app_events: map (nullable = true)
 |-- onelink_payments: map (nullable = true)
 |-- onelink_logins: map (nullable = true)
 |-- onelink_login_recencies: map (nullable = true)
 |-- topgoal_topics200: map (nullable = true)
 |-- mpop_senders: map (nullable = true)
 |-- app_cats_pc_cri_wscore_180d: map (nullable = true)
 |-- app_cats_installed: map (nullable = true)
 |-- dt: string (nullable = true)
 |-- uid_type: string (nullable = true)
"""

In [12]:
def strip(x):
    """Extract the domain column name from one ``printSchema()``-style line.

    ``" |-- app_stats: array (nullable = true)"`` -> ``"app_stats"``.

    :param x: one line of the schema listing.
    :return: the column name, or ``""`` for blank lines and for the non-domain
        columns ``uid``, ``dt`` and ``uid_type``.
    """
    name = x.replace("|--", "").strip().split(":")[0].strip()
    # Exact match: the former `startswith` test also discarded any domain whose
    # name merely begins with "uid" or "dt" (e.g. "dt_stats").
    if name in ("uid", "dt", "uid_type"):
        return ""
    return name

# Domain columns: every parsed name except the "" returned for blank lines
# and for the uid/dt/uid_type columns.
domains_names = [name for name in map(strip, columns_text.split("\n")) if name]
log(domains_names)

type: <type 'list'>
data: ['shows_activity',
 'banner_cats_ctr_wscore_30d',
 'hw_osm_ratios',
 'hw_osm_stats',
 'topics_motor200',
 'living_region_ids',
 'living_region_stats',
 'all_profs',
 'sn_epf',
 'sn_topics100',
 'app_stats',
 'app_topics100',
 'tp_os_counts',
 'tp_device_stats',
 'app_cats_activity',
 'mob_operators',
 'device_vendors',
 'app_events',
 'onelink_payments',
 'onelink_logins',
 'onelink_login_recencies',
 'topgoal_topics200',
 'mpop_senders',
 'app_cats_pc_cri_wscore_180d',
 'app_cats_installed']


In [None]:
domains_num = 30


def domain_name(i):
    """Source column for domain number ``i`` (0-based), cycling through ``domains_names``."""
    return domains_names[i % len(domains_names)]


# Build `domains_num` single-column DataFrames keyed by (uid, uid_type); the
# n-th one (1-based) exposes domain_name(n - 1) renamed to d<n>. There are
# fewer source columns than domains, so columns repeat.
domains = [
    source_df.selectExpr("uid", "uid_type", "{} as d{}".format(domain_name(n - 1), n))
    for n in range(1, domains_num + 1)
]

log(domains)

In [14]:
# Left outer join on the composite (uid, uid_type) key, starting from the
# first domain DataFrame.
keys, join = ["uid", "uid_type"], "left_outer"
left = domains[0]

# OOM here with 20+ joins; an alternative approach is needed.
def join_domains(left_df, keys, join, domains):
    """Join every DataFrame of ``domains`` onto ``left_df`` one after another.

    :param left_df: leftmost DataFrame.
    :param keys: join column names passed to ``DataFrame.join``.
    :param join: join type, e.g. ``"left_outer"``.
    :param domains: DataFrames joined in order.
    :return: the accumulated join; ``left_df`` itself when ``domains`` is empty.
    """
    return reduce(lambda acc, right_df: acc.join(right_df, keys, join), domains, left_df)

# Checkpointing setup: scratch HDFS directory for this experiment, Spark's
# checkpoint directory inside it, and how often (in join steps) to checkpoint.
hdfs_tmp_dir = "hdfs:/user/vlk/tmp/TRG-77961-etl_features-OOM"
spark.sparkContext.setCheckpointDir(os.path.join(hdfs_tmp_dir, "sc_checkpoint"))
CHECKPOINT_INTERVAL = 10

def checkpoint(df, step, interval=None, manual=False):
    """Checkpoint ``df`` every ``interval`` join steps to cut its logical plan.

    :param df: DataFrame produced by the ``step``-th join.
    :param step: 1-based join step number.
    :param interval: checkpoint when ``step`` is a positive multiple of it;
        defaults to the notebook-level ``CHECKPOINT_INTERVAL``. A value <= 0
        disables checkpointing (instead of a modulo-by-zero error).
    :param manual: if True, materialize through an explicit parquet write and
        read-back under ``hdfs_tmp_dir`` instead of ``DataFrame.checkpoint``.
    :return: the checkpointed DataFrame on checkpoint steps, ``df`` otherwise.
    """
    if interval is None:
        interval = CHECKPOINT_INTERVAL

    def _checkpoint():
        # eager=False was tried as well: the job failed almost immediately
        # after the join stage started with
        # "SparkOutOfMemoryError: Unable to acquire 68 bytes of memory, got 0".
        return df.checkpoint(eager=True)

    def _checkpoint_manual():
        # Works fine, but slow: write the intermediate result as parquet and
        # read it back, which starts a new plan from the written files.
        checkpoint_dir = os.path.join(hdfs_tmp_dir, "checkpoint_step{}".format(step))
        (
            df
            .write
            .mode("overwrite")
            .option("compression", "gzip")
            .option("mapreduce.fileoutputcommitter.algorithm.version", "2")
            .parquet(checkpoint_dir)
        )
        return spark.read.parquet(checkpoint_dir)

    if interval > 0 and step > 0 and step % interval == 0:
        return _checkpoint_manual() if manual else _checkpoint()
    return df

def join_domains_using_checkpoints(left_df, keys, join, domains):
    """Like ``join_domains``, but pass each intermediate join through ``checkpoint``.

    Steps are counted from 1, so ``checkpoint`` receives the number of domains
    joined so far.
    """
    acc = left_df
    step = 0
    for right_df in domains:
        step += 1
        joined = acc.join(right_df, keys, join)
        acc = checkpoint(joined, step)
    return acc

# Baseline: a plain chain of joins without checkpoints (the OOM-prone variant).
# To try the workaround use:
#   joined_df = join_domains_using_checkpoints(left, keys, join, domains[1:])
joined_df = join_domains(left_df=left, keys=keys, join=join, domains=domains[1:])

log(joined_df)
# Spark UI of the recorded run:
# https://rm.adh.vk.team/proxy/application_1649019708481_15899/jobs/

type: <class 'pyspark.sql.dataframe.DataFrame'>
data: DataFrame[uid: string, uid_type: string, d1: array<float>, d2: map<string,float>, d3: map<string,float>, d4: array<float>, d5: map<string,float>, d6: map<string,float>, d7: array<float>, d8: array<float>, d9: array<float>, d10: array<float>, d11: array<float>, d12: map<string,float>, d13: map<string,float>, d14: array<float>, d15: map<string,float>, d16: map<string,float>, d17: map<string,float>, d18: map<string,float>, d19: map<string,float>, d20: map<string,float>, d21: map<string,float>, d22: map<string,float>, d23: map<string,float>, d24: map<string,float>, d25: map<string,float>, d26: array<float>, d27: map<string,float>, d28: map<string,float>, d29: array<float>, d30: map<string,float>]


In [None]:
# Default heavy=True: print only the schema of the joined result.
show(joined_df, message="joined")

In [16]:
result_path = os.path.join(hdfs_tmp_dir, "result")

# This write is the action that actually executes the joins: gzip-compressed
# parquet, replacing any previous output, with committer algorithm v2 requested.
_ = joined_df.write.mode("overwrite") \
    .option("compression", "gzip") \
    .option("mapreduce.fileoutputcommitter.algorithm.version", "2") \
    .parquet(result_path)

In [17]:
# Stop the Spark session (and its SparkContext) now that the experiment is done.
spark.stop()

# Check luigi tasks pool runner

In [18]:
from dmprj.apps.utils.common.external_program import BusyWaitPoolRunner
from pathos.multiprocessing import ProcessPool

class AvoidLuigiFlatTaskRunner(ControlLoggingMixin):
    """Run a batch of luigi tasks in parallel without using native luigi worker-scheduler interface.

    The main reason behind this is avoiding strange deadlocks within luigi framework, when spawning more tasks within a
    task run method. The only limitation here is that all input tasks for this kind of runner should have only external
    dependencies.

    :param tasks: list of luigi tasks to run.
    :param processes: size of processing pool or None (by default) for using len(tasks) as the pool size
        (never less than 1, so an empty task list does not request a zero-sized pool).
    :param log_url: Control logging endpoint.
    :param raise_on_task_failure: if True, any task that fails aborts all pool processes.
    """

    def __init__(self, tasks, processes=None, log_url=None, raise_on_task_failure=True):
        self.tasks = tasks
        # Guard the default pool size: len([]) == 0 is not a usable pool size.
        self.processes = max(len(tasks), 1) if processes is None else processes
        self.log_url = log_url
        self.raise_on_task_failure = raise_on_task_failure
        self._validate_tasks()

    def _validate_tasks(self):
        """Check that each task is a ``luigi.Task`` whose deps are all complete ``luigi.ExternalTask``s.

        :raises TypeError: for a non-task item or a non-external dependency.
        :raises ValueError: for an external dependency that is not complete yet.
        """
        for task in self.tasks:
            if not isinstance(task, luigi.Task):
                raise TypeError("All tasks must be `luigi.Task`, got {}".format(task))
            for required_task in task.deps():
                if not isinstance(required_task, luigi.ExternalTask):
                    raise TypeError("Got non-external dependency {} for a task {}".format(required_task, task))
                if not required_task.complete():
                    raise ValueError("Incomplete external dependency {} for a task {}".format(required_task, task))

    def run(self, log_traceback=True):
        """Run ``_run_task`` for every task on a pathos ``ProcessPool`` through ``BusyWaitPoolRunner``.

        :return: whatever ``BusyWaitPoolRunner.run`` returns (presumably the
            per-task ``(task_id, success_flag)`` pairs — confirm).
        """
        return BusyWaitPoolRunner(
            pool=ProcessPool(nodes=self.processes),
            log_traceback=log_traceback,
            log_url=self.log_url
        ).run(
            self._run_task,
            self.tasks
        )

    def _run_task(self, task):
        """Run one task (unless already complete) and report ``(task_id, success_flag)``.

        On failure ``task.on_failure`` is called; the exception is then
        re-raised if ``raise_on_task_failure`` is set, otherwise the task is
        reported with ``success_flag=False``. An exception raised by
        ``on_failure`` itself propagates regardless of the flag.
        """
        if task.complete():
            self.info("Task {} is already complete".format(task))
            return task.task_id, True

        self.info("Running a task {}".format(task))
        try:
            task.run()
        except Exception as e:
            task.on_failure(e)
            if self.raise_on_task_failure:
                # Bare `raise` keeps the original traceback; `raise e` on
                # Python 2 would restart the traceback at this line.
                raise
            return task.task_id, False

        task.on_success()
        return task.task_id, True


In [None]:
class T(luigi.Task):
    """Demo task: never complete, fails in ``run`` and fails again in ``on_failure``.

    Appears to probe how ``AvoidLuigiFlatTaskRunner`` behaves when the failure
    hook itself raises, even with ``raise_on_task_failure=False``.
    """

    def __init__(self, *args, **kwargs):
        # Extra arguments are ignored on purpose: T declares no luigi parameters.
        super(T, self).__init__()

    def complete(self):
        return False

    def run(self):
        raise ValueError("oops")

    def on_failure(self, exception):
        log(exception, msg="\n on_failure")
        # Raising here (instead of returning an explanation string) makes the
        # failure hook itself fail.
        raise ValueError("failed on_failure")

    def on_success(self):
        pass


tasks = [T()]

results = None
try:
    results = AvoidLuigiFlatTaskRunner(
        tasks=tasks,
        raise_on_task_failure=False,
    ).run(
        log_traceback=True
    )
except Exception as e:
    # Exception, not BaseException: a kernel interrupt (KeyboardInterrupt)
    # should stop the cell instead of being logged and swallowed.
    log(e, msg="\n boo-hoo! Exception:")

log(results, msg="\n results:")