# insert_into_hive experiments

In [1]:
import pprint
from pprint import pformat

import os
import datetime

from operator import and_
from collections import defaultdict

import six
import luigi
import pyspark.sql.functions as sqlfn

import json
import itertools as it

from pyspark.sql.types import MapType, ArrayType, FloatType, StringType, NumericType

if six.PY3:
    from functools import reduce  # make flake8 happy

In [None]:
# pprint.pprint(dict(os.environ), width=1)
def log(obj, msg=""):
    """Print the type of `obj` and its pretty-printed value, preceded by `msg` if given.

    The value is rendered with `pformat(indent=1, width=1)`, so containers are
    broken into one element per line. Returns None.
    """
    if msg:
        print(msg)
    rendered = pformat(obj, indent=1, width=1)
    report = "type: {}\ndata: {}".format(type(obj), rendered)
    print(report)

log(os.environ, "os.environ")
print()
log(dict(os.environ), "dict(os.environ)")

In [None]:
import os
import sys

from pyspark.sql import SparkSession, SQLContext

# Pack executable prj conda environment into zip
# The environment root is taken to be two directory levels above the running interpreter
# (presumably `<env>/bin/python` -> `<env>`; this matches the `bin/python` path built below).
TMP_ENV_BASEDIR = "tmpenv"  # Reserved directory to store environment archive
env_dir = os.path.dirname(os.path.dirname(sys.executable))
env_name = os.path.basename(env_dir)
# "<basedir>/<env>.zip#<basedir>"; this string is passed to `spark.yarn.dist.archives` when the session is built.
# NOTE(review): the `#<basedir>` suffix is presumably the directory alias the archive is unpacked under on YARN — confirm.
env_archive = "{basedir}/{env}.zip#{basedir}".format(basedir=TMP_ENV_BASEDIR, env=env_name)
# Relative interpreter path for PySpark: <basedir>/<env>/bin/python.
os.environ["PYSPARK_PYTHON"] = "{}/{}/bin/python".format(TMP_ENV_BASEDIR, env_name)

# you need this only first time!
# !rm -rf {TMP_ENV_BASEDIR} && mkdir {TMP_ENV_BASEDIR} && cd {TMP_ENV_BASEDIR} && rsync -a {env_dir} . && zip -rq {env_name}.zip {env_name}

log(env_archive)

In [None]:
# Create Spark session with prj conda environment and JVM extensions
# `spark-submit ... --driver-java-options "-Dlog4j.configuration=file:/home/vlk/driver_log4j.properties"`
# spark.driver.extraJavaOptions
queue = "root.regular"

# "spark.driver.extraJavaOptions", "-Xss10M"
# catalyst SO while building parts. filter expression

# 4 TB of data
# sssp = 4 * 4 * 1024
sssp = 1 * 4 * 1024

# Build (or reuse, via getOrCreate) the Spark session used by every cell below.
spark = (
SparkSession.builder
    .master("yarn-client")
    .appName("TRG-75523-insert_into_hive-test-ipynb")
    .config("spark.yarn.queue", queue)
    # Executor sizing.
    .config("spark.executor.instances", "4")
    .config("spark.executor.memory", "8G")
    .config("spark.executor.cores", "6")
    .config("spark.executor.memoryOverhead", "2G")
    # Shuffle partition count comes from `sssp`, assigned just above this statement.
    .config("spark.sql.shuffle.partitions", sssp)
    # Driver sizing and JVM options (custom log4j configuration file).
    .config("spark.driver.memory", "4G")
    .config("spark.driver.maxResultSize", "1G")
    .config("spark.driver.extraJavaOptions", "-Dlog4j.configuration=file:/home/vlk/driver2_log4j.properties")
    # Speculation and dynamic allocation (4 to 512 executors, 300s idle timeout).
    .config("spark.speculation", "true")
    .config("spark.dynamicAllocation.enabled", "true")
    .config("spark.dynamicAllocation.minExecutors", "4")
    .config("spark.dynamicAllocation.maxExecutors", "512")
    .config("spark.dynamicAllocation.executorIdleTimeout", "300s")
    # Network / shuffle timeouts and retries.
    .config("spark.network.timeout", "800s")
    .config("spark.reducer.maxReqsInFlight", "10")
    .config("spark.shuffle.io.retryWait", "60s")
    .config("spark.shuffle.io.maxRetries", "10")
    # Kryo serialization with a 1024m maximum buffer.
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.kryoserializer.buffer.max", "1024m")
    # Hive: dynamic partitioning in nonstrict mode, raised partition limits, long metastore timeout.
    .config("spark.hadoop.hive.exec.dynamic.partition", "true")
    .config("spark.hadoop.hive.exec.dynamic.partition.mode", "nonstrict")
    .config("spark.hadoop.hive.exec.max.dynamic.partitions", "1000000")
    .config("spark.hadoop.hive.exec.max.dynamic.partitions.pernode", "100000")
    .config("spark.hadoop.hive.metastore.client.socket.timeout", "3600s")
    .config("spark.ui.enabled", "true")
    # Partition column type inference is disabled (the `dt` column read later shows up as string).
    .config("spark.sql.sources.partitionColumnTypeInference.enabled", "false")
    # Archive string built in the environment-packing cell.
    .config("spark.yarn.dist.archives", env_archive)
    .getOrCreate()
)
# .config("spark.driver.extraJavaOptions", "-Xss10M -Dlog4j.configuration=file:/home/vlk/driver_log4j.properties")
#     .config("spark.jars", "hdfs:/lib/prj-transformers-assembly-dev-1.5.1.jar")
sql_ctx = SQLContext(spark.sparkContext)
(spark, sql_ctx)

In [5]:
# end of env. setup

In [None]:
import os
import numpy as np

from pprint import pformat

import luigi
import pyspark.sql.functions as sqlfn

from pyspark.storagelevel import StorageLevel
from pyspark.sql import DataFrame, SQLContext
from pyspark.sql.types import (
    MapType, ArrayType, FloatType, DoubleType, StringType, StructType, IntegralType, IntegerType
)
from pyspark.sql.utils import CapturedException
from pyspark.ml.wrapper import JavaWrapper

from luigi.contrib.hdfs import HdfsTarget

from dmprj.apps.utils.common import add_days
from dmprj.apps.utils.common.hive import format_table, select_clause
from dmprj.apps.utils.common.luigix import HiveExternalTask
from dmprj.apps.utils.control.luigix.task import ControlApp
from dmprj.apps.utils.control.client.exception import FailedStatusException, MissingDepsStatusException

from dmprj.apps.utils.common import unfreeze_json_param
from dmprj.apps.utils.common.fs import HdfsClient
from dmprj.apps.utils.common.hive import FindPartitionsEngine
from dmprj.apps.utils.common.spark import prjUDFLibrary
from dmprj.apps.utils.common.luigix import HiveTableSchemaTarget

from dmprj.apps.utils.common.hive import select_clause
from dmprj.apps.utils.common.spark import prjUDFLibrary, insert_into_hive
from dmprj.apps.utils.common.luigix import HiveExternalTask, HiveGenericTarget
from dmprj.apps.utils.control.luigix import ControlApp, ControlDynamicOutputPySparkTask

In [None]:
# CustomUDFLibrary(spark, "hdfs:/lib/prj-transformers-assembly-dev-1.5.1.jar").register_all_udf()

In [6]:
def show(df, message="dataframe", nlines=20, truncate=False, heavy=True):
    """Print the schema of `df` and return `df` unchanged, so calls can be chained.

    With heavy=True (the default) only `printSchema()` runs. With heavy=False the
    row count is printed first under `message`, and afterwards up to `nlines` rows
    are displayed with `truncate` forwarded to `df.show`.
    """
    inspect_rows = not heavy

    if inspect_rows:
        header = "\n{}, rows: {}:".format(message, df.count())
        print(header)

    df.printSchema()

    if inspect_rows:
        df.show(nlines, truncate)

    return df

## reload table to `dt=yyyy-MM-dd` partitions

In [None]:
def reload():
    """Rewrite the `source` parquet snapshot as `dt`-partitioned parquet under `parts`.

    Reads `<root>/source`, persists it, coalesces to 1024 * 8 partitions and writes
    to `<root>/parts` in overwrite mode partitioned by `dt`. Afterwards the frame is
    unpersisted and the notebook-level Spark session is stopped.
    """
    root = "/user/vlk/test/ds_auditories/mob_app_audience"
    source_dir = os.path.join(root, "source")
    parts_dir = os.path.join(root, "parts")

    source_df = spark.read.parquet(source_dir).persist()

    writer = source_df.coalesce(1024 * 8).write.mode("overwrite").partitionBy("dt")
    writer.parquet(parts_dir)

    source_df.unpersist()
    spark.stop()

## ds_auditories.mob_app_audience tests

In [7]:
drop_table_sql = "drop table if exists user_vlk.mob_app_audience_TRG75523 purge"

# TODO: check whether this is true: for some reason, a table created by Spark is incompatible with the rest of the ETL pipeline.
# Creating the table in the Hive console works fine.
create_table_sql = """
CREATE TABLE user_vlk.mob_app_audience_TRG75523 (
  `uid` string COMMENT 'Uid',
  `score` double COMMENT 'Uid [0,1]-interval app install duration ranking score')
COMMENT 'Installed mobile apps user lists in audience format'
PARTITIONED BY (
  `audience_name` string COMMENT 'mobile app store id as in store_id from md_mobile.mobile_app',
  `category` string COMMENT 'mobile app aggregation type, positive stands for default last year aggregation period',
  `dt` string COMMENT 'Data date',
  `uid_type` string COMMENT 'Type of uid, e.g. GAID, IDFA')
ROW FORMAT SERDE
  'org.apache.hadoop.hive.ql.io.orc.OrcSerde'
STORED AS INPUTFORMAT
  'org.apache.hadoop.hive.ql.io.orc.OrcInputFormat'
OUTPUTFORMAT
  'org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat'
"""

In [8]:
spark.sql(drop_table_sql)

DataFrame[]

In [9]:
spark.sql(create_table_sql)

DataFrame[]

In [None]:
# spark.sql("REFRESH TABLE ds_auditories.mob_app_audience")

In [None]:
# from table

# (2022-01-07,11972) # corrupted
# (2022-01-10,21558)
# (2022-01-01,31263) # good
# (2022-01-16,32050) # good
# (2022-01-17,32101) # good

# dts = "dt in ('2022-01-01', '2022-01-16', '2022-01-17')"
dts = "dt in ('2022-01-01')"

# `_show` is not defined anywhere in this notebook (NameError on a fresh kernel);
# the inspection helper defined earlier is `show`, which returns the frame it is given.
df = show(
    spark.table("ds_auditories.mob_app_audience")
    .where(dts)
    .cache(),
    "ds_auditories.mob_app_audience, 1 dt",
    heavy=True
)

temp_path = "/user/vlk/test/ds_auditories/mob_app_audience/source"
# df.write.mode("overwrite").parquet(temp_path)
# df.unpersist()
# df = spark.read.parquet(temp_path).persist()

In [10]:
# from hdfs

base_path = "/user/vlk/test/ds_auditories/mob_app_audience/parts"
temp_path = base_path

# dts = "dt in ('2022-01-17', '2022-01-16', '2022-01-01')"
dts = "dt in ('2022-01-17')"

df = (
    spark.read.option("basePath", base_path)
    .parquet(temp_path)
    .where(dts)    
    .persist()
)
# .drop("dt").withColumn("dt", sqlfn.lit("2022-01-17")) 
# spark.sql.sources.partitionColumnTypeInference.enabled
show(df, "source")

root
 |-- uid: string (nullable = true)
 |-- score: double (nullable = true)
 |-- audience_name: string (nullable = true)
 |-- category: string (nullable = true)
 |-- uid_type: string (nullable = true)
 |-- dt: string (nullable = true)



DataFrame[uid: string, score: double, audience_name: string, category: string, uid_type: string, dt: string]

In [11]:
insert_into_hive(
    df,
    database="user_vlk",
    table="mob_app_audience_TRG75523",
    max_rows_per_bucket=2505505,
    overwrite=True,
    raise_on_missing_columns=True,
    check_parameter="markedAsDataLoaded",
    jar="hdfs:/lib/prj-transformers-assembly-dev-1.5.2.jar"
)

# 22/01/19 17:14:11 INFO Writer: Updating Hive table partitions parameters `markedAsDataLoaded -> None` for 32101 partitions ...
# java.lang.StackOverflowError
#	at org.apache.spark.sql.catalyst.expressions.BinaryOperator.sql(Expression.scala:592)
# "spark.driver.extraJavaOptions": "-Xss10M"

# java.lang.StackOverflowError
#	at org.apache.spark.sql.hive.client.Shim_v0_13.org$apache$spark$sql$hive$client$Shim_v0_13$$convert$1(HiveShim.scala:714)
   

In [None]:
# TODO: rewrite one dt, two dt, see how result changed

In [12]:
df.unpersist()
spark.stop()

## failed insert_into_hive, exploit

In [None]:
# /user/vlk/test/target_novlk/target
df = spark.read.parquet("/user/vlk/test/target_novlk/target")

In [None]:
show(df)

In [None]:
# dmprj_source.app_feature_extended
# show create table dmprj_source.app_feature_extended;

In [None]:
original_sql = """
CREATE TABLE `dmprj_source.app_feature_extended`(
  `store_id` string COMMENT 'Public store app_id', 
  `app_type` map<string,int> COMMENT 'Mobile app type: game/application/album', 
  `developer` map<string,int> COMMENT 'Mobile app developer', 
  `os_ver` map<string,int> COMMENT 'OS concatenated with version, e.g. iOS_11.0', 
  `all_categories` map<string,int> COMMENT 'All app categories from App Store or Google Play', 
  `categories` map<string,int> COMMENT 'Main app category from App Store or Google Play', 
  `age_limit` map<string,int> COMMENT 'Type of age restriction for mobile app, like 16+ etc.', 
  `desc_language` map<string,float> COMMENT 'App description language', 
  `app_topics100` map<string,float> COMMENT 'LDA topics based on mobile app description', 
  `desc_taxons` map<string,float> COMMENT 'App description taxons assigned by the trg.text.cls.bert.BertCategoryModel', 
  `desc_embedding` map<string,int> COMMENT 'App description embedding acquired by the trg.encoding.SparseAutoEncoder', 
  `events` map<string,float> COMMENT 'Scale of number of events, associated with the app, splitted by event type', 
  `misc_info` array<float> COMMENT 'The rest dense features in order: price, rating, votes_scale, weeks from creation')
COMMENT 'Mobile application extended features'
PARTITIONED BY ( 
  `dt` string COMMENT 'Data date', 
  `os` string COMMENT 'iOS/Android')
"""

In [None]:
hive_script = """
drop table if exists dmprj_dev_source.TRG_73352_test purge;

create table dmprj_dev_source.TRG_73352_test (
  `store_id` string COMMENT 'Public store app_id', 
  `app_type` map<string,int> COMMENT 'Mobile app type: game/application/album', 
  `developer` map<string,int> COMMENT 'Mobile app developer', 
  `os_ver` map<string,int> COMMENT 'OS concatenated with version, e.g. iOS_11.0', 
  `all_categories` map<string,int> COMMENT 'All app categories from App Store or Google Play', 
  `categories` map<string,int> COMMENT 'Main app category from App Store or Google Play', 
  `age_limit` map<string,int> COMMENT 'Type of age restriction for mobile app, like 16+ etc.', 
  `desc_language` map<string,float> COMMENT 'App description language', 
  `app_topics100` map<string,float> COMMENT 'LDA topics based on mobile app description', 
  `desc_taxons` map<string,float> COMMENT 'App description taxons assigned by the trg.text.cls.bert.BertCategoryModel', 
  `desc_embedding` map<string,int> COMMENT 'App description embedding acquired by the trg.encoding.SparseAutoEncoder', 
  `events` map<string,float> COMMENT 'Scale of number of events, associated with the app, splitted by event type', 
  `misc_info` array<float> COMMENT 'The rest dense features in order: price, rating, votes_scale, weeks from creation'
 )
COMMENT 'Mobile application extended features'
partitioned by (
  `dt` string COMMENT 'Data date', 
  `os` string COMMENT 'iOS/Android'
)
stored as orc;

gdfs ls -lah /data/dm/prj/dev/hive/dmprj_dev_source.db/trg_73352_test/

"""

In [None]:
def rename_columns(df, **mapping):
    """Return `df` with its columns renamed according to `mapping` (old name -> new name).

    Columns that are absent from `mapping` keep their names. Both the input frame
    and the renamed frame are passed through `show` for debugging output.
    """
    show(df)

    projections = []
    for old_name in df.columns:
        new_name = mapping.get(old_name, old_name)
        projections.append("{} as {}".format(old_name, new_name))

    renamed = df.selectExpr(*projections)
    return show(renamed)

In [None]:
def create_partition_buckets(df, max_rows_per_bucket, *partition_columns):
    """Repartition `df` into buckets of about `max_rows_per_bucket` rows each.

    Without `partition_columns` the frame is repartitioned into
    ceil(row_count / max_rows_per_bucket) partitions, but never fewer than one.

    With `partition_columns`, rows are counted per distinct combination of those
    columns, each combination is given its own contiguous range of bucket indexes,
    every row receives a random index inside its combination's range, and the frame
    is range-repartitioned by that index.

    Args:
        df: DataFrame to repartition.
        max_rows_per_bucket: divisor used to derive bucket counts; zero raises
            ZeroDivisionError.
        *partition_columns: names of the columns that identify a table partition.

    Returns:
        DataFrame with the original column names and the new partitioning.

    Intermediate frames are printed through `show`/`log`, and the per-combination
    counts are collected to the driver with `toPandas()`.
    """
    log(partition_columns, "\npartition columns")
    if not partition_columns:
        # ceil(0 / n) is 0 for an empty frame; clamp to 1 (as the partitioned branch
        # below does) because Spark needs a positive number of partitions.
        num_buckets = max(int(np.ceil(df.count() / float(max_rows_per_bucket))), 1)
        return df.repartition(num_buckets)

    # Work on "a_"-prefixed column names and restore the original names at the end.
    # NOTE(review): presumably this avoids a clash with the temporary "index" column — confirm.
    prefixed_cols = [("a_" + col) for col in df.columns]
    prefixed_partition_cols = [("a_" + col) for col in partition_columns]
    add_prefix_map = dict(zip(df.columns, prefixed_cols))
    remove_prefix_map = dict(zip(prefixed_cols, df.columns))

    df = rename_columns(df, **add_prefix_map)
    show(df)

    # Row counts per combination of partition values, collected to the driver.
    part_df = df.groupBy(*prefixed_partition_cols).count()
    show(part_df)
    part_df = part_df.toPandas()
    part_df["num_buckets"] = np.ceil(part_df["count"] / float(max_rows_per_bucket)).astype(int)
    # Exclusive running total: the first bucket index owned by each combination.
    part_df["beg"] = part_df["num_buckets"].cumsum() - part_df["num_buckets"]
    log(part_df, "\npartitions counts")

    # partition values -> (first bucket index, number of buckets), stored as built-in ints.
    partition_map = {}
    for _, row in part_df.iterrows():
        partition_map[tuple(row[prefixed_partition_cols])] = (int(row.beg), int(row.num_buckets))
    log(partition_map, "\nrepartition data")
    partition_map_bc = df.sql_ctx.sparkSession.sparkContext.broadcast(partition_map)

    @sqlfn.udf(returnType=IntegerType())
    def _partition_index(*cols):
        # Random bucket inside the index range reserved for this combination.
        beg, count = partition_map_bc.value[tuple(cols)]
        return int(beg + np.random.randint(count))

    # NOTE(review): the cache() presumably keeps the random indexes stable across re-evaluation — confirm.
    indexed_df = show(df.withColumn("index", _partition_index(*prefixed_partition_cols)).cache())

    # pandas' sum() yields a numpy integer; hand Spark a built-in int as the partition count.
    total_buckets = max(int(part_df["num_buckets"].sum()), 1)
    bucket_df = show(indexed_df.repartitionByRange(total_buckets, "index")).drop("index")

    return rename_columns(bucket_df, **remove_prefix_map)


In [None]:
def insert_into_hive(df, database, table, max_rows_per_bucket, overwrite=True, raise_on_missing_columns=True):
    """Insert `df` into the existing table `database.table` after bucketing it.

    The table's columns are read from the Spark catalog; `df` is projected to the
    non-partition columns followed by the partition columns, repartitioned with
    `create_partition_buckets`, and written with `insertInto`.

    Args:
        df: DataFrame to write.
        database, table: target table coordinates.
        max_rows_per_bucket: forwarded to `create_partition_buckets`.
        overwrite: forwarded to `DataFrameWriter.insertInto`.
        raise_on_missing_columns: when False, table columns absent from `df` are
            added as NULL literals; when True no filling is done, so the projection
            onto the table columns is left to fail if one is missing.

    `spark.sql.sources.partitionOverwriteMode` is switched to "dynamic" for the
    duration of the write and restored to its previous value afterwards, even if
    the write raises.
    """
    overwrite_mode_key = "spark.sql.sources.partitionOverwriteMode"
    session = df.sql_ctx.sparkSession
    table_columns = session.catalog.listColumns(dbName=database, tableName=table)
    log(table_columns, "\ncatalog columns")

    if not raise_on_missing_columns:
        projected = []
        for column in table_columns:
            if column.name in df.columns:
                projected.append(sqlfn.col(column.name))
            else:
                projected.append(sqlfn.lit(None).alias(column.name))
        df = df.select(*projected)

    # Catalog order is kept inside each group.
    main_columns = [column.name for column in table_columns if not column.isPartition]
    partition_columns = [column.name for column in table_columns if column.isPartition]

    previous_mode = session.conf.get(overwrite_mode_key)
    session.conf.set(overwrite_mode_key, "dynamic")
    log(previous_mode, "\nsaved spark.sql.sources.partitionOverwriteMode")

    try:
        ordered_df = df.select(*(main_columns + partition_columns))
        bucketed_df = create_partition_buckets(ordered_df, max_rows_per_bucket, *partition_columns)
        bucketed_df.write.insertInto("{}.{}".format(database, table), overwrite=overwrite)
    finally:
        session.conf.set(overwrite_mode_key, previous_mode)



In [None]:
db = "dmprj_dev_source"
table = "TRG_73352_test"

In [None]:
# insert_into_hive(df, database=db, table=table, max_rows_per_bucket=300000, raise_on_missing_columns=True)

In [None]:
from dmprj.apps.utils.common.spark import insert_into_hive
df = spark.read.parquet("/user/vlk/test/target_novlk/target")
insert_into_hive(
    df, database=db, table=table, max_rows_per_bucket=1000000, overwrite=True, raise_on_missing_columns=True
)

t = """
vlk@host> gdfs du -h /data/dm/prj/dev/hive/dmprj_dev_source.db/trg_73352_test/
9.8M    /data/dm/prj/dev/hive/dmprj_dev_source.db/trg_73352_test/dt=2021-12-12/os=Android/part-00000-44369aa8-30e8-4a36-91e6-b2d301dc0243.c000
9.8M    /data/dm/prj/dev/hive/dmprj_dev_source.db/trg_73352_test/dt=2021-12-12/os=Android
2.5M    /data/dm/prj/dev/hive/dmprj_dev_source.db/trg_73352_test/dt=2021-12-12/os=iOS/part-00001-44369aa8-30e8-4a36-91e6-b2d301dc0243.c000
2.5M    /data/dm/prj/dev/hive/dmprj_dev_source.db/trg_73352_test/dt=2021-12-12/os=iOS
12.3M   /data/dm/prj/dev/hive/dmprj_dev_source.db/trg_73352_test/dt=2021-12-12
12.3M   /data/dm/prj/dev/hive/dmprj_dev_source.db/trg_73352_test
"""

## read_orc_table test

In [None]:
def read_orc_table(what, partition_filter_expr, spark, jar="hdfs:/lib/dwh/common-1.21.21.jar"):
    """Load `what` through the JVM helper `TableUtils.readTableAsUnionOrcFiles`.

    The helper jar is registered with `ADD JAR` first. `partition_filter_expr` is a
    SQL expression string; it is turned into a Column and its Java counterpart is
    passed to the helper together with `what` and the Java Spark session.

    Returns a Python DataFrame wrapping the Java DataFrame produced by the helper.
    """
    add_jar_sql = "ADD JAR {}".format(jar)
    spark.sql(add_jar_sql)

    read_union_orc = spark._sc._jvm.ru.mail.dwh.common.TableUtils.readTableAsUnionOrcFiles
    filter_jcolumn = sqlfn.expr(partition_filter_expr)._jc
    java_df = read_union_orc(what, filter_jcolumn, spark._jsparkSession)

    sql_context = SQLContext(spark.sparkContext)
    return DataFrame(java_df, sql_context)

In [None]:
# /data/dm/prj/dev/hive/dmprj_dev_source.db/bits_user_vector/
# source_name=DM-8225-hid-v1/dt=2020-10-16/uid_type=HID
df = read_orc_table(
    what="dmprj_dev_source.bits_user_vector", 
    partition_filter_expr="source_name='DM-8225-hid-v1' and dt='2020-10-16' and uid_type='HID'",
    spark=spark
)

In [None]:
show(df, nlines=10, truncate=True)

In [None]:
def read_orc_table_jvm(what, partition_filter_expr, spark, jar="hdfs:/lib/dwh/common-1.21.21.jar"):
    """Variant of `read_orc_table` that reaches the JVM helper through `JavaWrapper`.

    Registers `jar` with `ADD JAR`, then invokes the static method
    `ru.mail.dwh.common.TableUtils.readTableAsUnionOrcFiles` with the table name,
    the Java Column built from the SQL string `partition_filter_expr`, and the Java
    Spark session.

    Returns a Python DataFrame wrapping the Java DataFrame produced by the helper.
    """
    from pyspark.ml.wrapper import JavaWrapper
    spark.sql("ADD JAR {}".format(jar))

    # The dotted path is resolved from the JVM root, so it has to be fully qualified;
    # this is the same package path `read_orc_table` uses.
    # NOTE(review): `_new_java_obj` is a private PySpark helper — confirm it is still
    # available in the deployed Spark version.
    jdf = JavaWrapper._new_java_obj(
        "ru.mail.dwh.common.TableUtils.readTableAsUnionOrcFiles",
        what, sqlfn.expr(partition_filter_expr)._jc, spark._jsparkSession
    )
    return DataFrame(jdf, SQLContext(spark.sparkContext))

In [None]:
# /data/dm/prj/dev/hive/dmprj_dev_source.db/bits_user_vector/
# source_name=DM-8225-hid-v1/dt=2020-10-16/uid_type=HID
df = read_orc_table_jvm(
    what="dmprj_dev_source.bits_user_vector", 
    partition_filter_expr="source_name='DM-8225-hid-v1' and dt='2020-10-16' and uid_type='HID'",
    spark=spark
)

In [None]:
show(df, nlines=10, truncate=True)

## insert_into_hive, ~1TB to write

In [None]:
df = spark.sql("select * from ds_auditories.mob_app_audience").where("dt = '2021-12-01'")

In [None]:
show(df, "input DF")
# 41 059 901 473 rows

In [None]:
hive_script = """
drop table if exists dmprj_dev_source.TRG_73352_test purge;

CREATE TABLE IF NOT EXISTS dmprj_dev_source.TRG_73352_test (
  uid           string             comment 'user id',
  score         double             comment 'score value'
) comment 'user score'
partitioned by (
  category      string             comment 'category',
  dt            string             comment 'date as string',
  uid_type      string             comment 'uid type'
)
stored as orc;

gdfs ls -lah /data/dm/prj/dev/hive/dmprj_dev_source.db/trg_73352_test/

"""

In [None]:
spark.sql("DROP TABLE IF EXISTS dmprj_dev_source.trg_73352_test")

In [None]:
ddl = """
CREATE TABLE IF NOT EXISTS dmprj_dev_source.TRG_73352_test (
  uid           string             comment 'user id',
  score         double             comment 'score value'
) comment 'user score'
partitioned by (
  category      string             comment 'category',
  dt            string             comment 'date as string',
  uid_type      string             comment 'uid type'
)
stored as orc
""".strip()

In [None]:
spark.sql(ddl)

In [None]:
def insert_into_hive_jvm(
    df,
    database,
    table,
    max_rows_per_bucket,
    overwrite=True,
    raise_on_missing_columns=True,
    check_parameter=None,
    jar="hdfs:/lib/dm/prj-transformers-assembly-dev-1.5.0.jar",
):
    """Delegate the insert to the JVM class `prj.hive.Writer`.

    Registers `jar` with `ADD JAR`, instantiates `prj.hive.Writer` without
    constructor arguments and calls its `insertIntoHive` with the Java DataFrame
    followed by the remaining arguments, positionally and in signature order.
    Returns None.
    """
    add_jar_sql = "ADD JAR {}".format(jar)
    df.sql_ctx.sql(add_jar_sql)

    # Default constructor is the "SQL" variant; the alternative that was tried:
    #     JavaWrapper._create_from_java_class("prj.hive.Writer", "RDD")
    java_writer = JavaWrapper._create_from_java_class("prj.hive.Writer")._java_obj

    call_args = (
        df._jdf,
        database,
        table,
        max_rows_per_bucket,
        overwrite,
        raise_on_missing_columns,
        check_parameter,
    )
    java_writer.insertIntoHive(*call_args)


In [None]:
insert_into_hive_jvm(df, "dmprj_dev_source", "trg_73352_test", 6000000, True, False, "markedAsDataLoaded")
# 30 min SQL; 45 min Python; 45 min RDD; 35 min TRF

In [None]:
partitions_finder = FindPartitionsEngine()
partitions = partitions_finder.find(
            database="dmprj_dev_source",
            table="trg_73352_test",
            partition_conf={"dt": "2021-12-01"},
            min_dt="2021-12-01",
            max_dt="2021-12-01",
            check_parameter="markedAsDataLoaded",
)
log((len(partitions), partitions))

### Python

In [None]:
from dmprj.apps.utils.common.spark import CustomUDFLibrary, insert_into_hive
insert_into_hive(df, "dmprj_dev_source", "trg_73352_test", 6000000, True, False)

## experiments

In [None]:
df = show(spark.createDataFrame([
        # A 09
        {"uid": "a", "feature": 1.1, "uid_type": "A", "dt": "2021-11-09"},
        {"uid": "b", "feature": 1.2, "uid_type": "A", "dt": "2021-11-09"},
        {"uid": "c", "feature": 1.3, "uid_type": "A", "dt": "2021-11-09"},
        # A 10
        {"uid": "a", "feature": 1.1, "uid_type": "A", "dt": "2021-11-10"},
        {"uid": "b", "feature": 1.2, "uid_type": "A", "dt": "2021-11-10"},
        {"uid": "c", "feature": 1.3, "uid_type": "A", "dt": "2021-11-10"},
        # B 09
        {"uid": "a", "feature": 1.1, "uid_type": "B", "dt": "2021-11-09"},
        {"uid": "b", "feature": 1.2, "uid_type": "B", "dt": "2021-11-09"},
        {"uid": "c", "feature": 1.3, "uid_type": "B", "dt": "2021-11-09"},
        # B 10
        {"uid": "a", "feature": 1.1, "uid_type": "B", "dt": "2021-11-10"},
        {"uid": "b", "feature": 1.2, "uid_type": "B", "dt": "2021-11-10"},
        {"uid": "c", "feature": 1.3, "uid_type": "B", "dt": "2021-11-10"},
    ]).persist(StorageLevel.MEMORY_ONLY))

df.createGlobalTempView("test_features")

In [None]:
db = "dmprj_dev_source"
table = "float_feature_test"

In [None]:
log(np.random.randint(7), "np.random.randint(7), must be 0..6")

In [None]:
def rename_columns(df, **mapping):
    """Project `df` onto the same columns, aliasing each one via `mapping` (old -> new).

    A column without an entry in `mapping` is aliased to its own name. The frame is
    passed through `show` before and after the projection for debugging output.
    """
    show(df)

    def _alias_expr(name):
        return "{} as {}".format(name, mapping.get(name, name))

    select_exprs = [_alias_expr(name) for name in df.columns]
    return show(df.selectExpr(*select_exprs))

In [None]:
def create_partition_buckets(df, max_rows_per_bucket, *partition_columns):
    """Repartition `df` into buckets of about `max_rows_per_bucket` rows each.

    Without `partition_columns` the frame is repartitioned into
    ceil(row_count / max_rows_per_bucket) partitions, but never fewer than one.

    With `partition_columns`, rows are counted per distinct combination of those
    columns, each combination is given its own contiguous range of bucket indexes,
    every row receives a random index inside its combination's range, and the frame
    is range-repartitioned by that index.

    Args:
        df: DataFrame to repartition.
        max_rows_per_bucket: divisor used to derive bucket counts; zero raises
            ZeroDivisionError.
        *partition_columns: names of the columns that identify a table partition.

    Returns:
        DataFrame with the original column names and the new partitioning.

    Intermediate frames are printed through `show`/`log`, and the per-combination
    counts are collected to the driver with `toPandas()`.
    """
    log(partition_columns, "\npartition columns")
    if not partition_columns:
        # ceil(0 / n) is 0 for an empty frame; clamp to 1 (as the partitioned branch
        # below does) because Spark needs a positive number of partitions.
        num_buckets = max(int(np.ceil(df.count() / float(max_rows_per_bucket))), 1)
        return df.repartition(num_buckets)

    # Work on "a_"-prefixed column names and restore the original names at the end.
    # NOTE(review): presumably this avoids a clash with the temporary "index" column — confirm.
    prefixed_cols = [("a_" + col) for col in df.columns]
    prefixed_partition_cols = [("a_" + col) for col in partition_columns]
    add_prefix_map = dict(zip(df.columns, prefixed_cols))
    remove_prefix_map = dict(zip(prefixed_cols, df.columns))

    df = rename_columns(df, **add_prefix_map)
    show(df)

    # Row counts per combination of partition values, collected to the driver.
    part_df = df.groupBy(*prefixed_partition_cols).count()
    show(part_df)
    part_df = part_df.toPandas()
    part_df["num_buckets"] = np.ceil(part_df["count"] / float(max_rows_per_bucket)).astype(int)
    # Exclusive running total: the first bucket index owned by each combination.
    part_df["beg"] = part_df["num_buckets"].cumsum() - part_df["num_buckets"]
    log(part_df, "\npartitions counts")

    # partition values -> (first bucket index, number of buckets), stored as built-in ints.
    partition_map = {}
    for _, row in part_df.iterrows():
        partition_map[tuple(row[prefixed_partition_cols])] = (int(row.beg), int(row.num_buckets))
    log(partition_map, "\nrepartition data")
    partition_map_bc = df.sql_ctx.sparkSession.sparkContext.broadcast(partition_map)

    @sqlfn.udf(returnType=IntegerType())
    def _partition_index(*cols):
        # Random bucket inside the index range reserved for this combination.
        beg, count = partition_map_bc.value[tuple(cols)]
        return int(beg + np.random.randint(count))

    # NOTE(review): the cache() presumably keeps the random indexes stable across re-evaluation — confirm.
    indexed_df = show(df.withColumn("index", _partition_index(*prefixed_partition_cols)).cache())

    # pandas' sum() yields a numpy integer; hand Spark a built-in int as the partition count.
    total_buckets = max(int(part_df["num_buckets"].sum()), 1)
    bucket_df = show(indexed_df.repartitionByRange(total_buckets, "index")).drop("index")

    return rename_columns(bucket_df, **remove_prefix_map)


In [None]:
def insert_into_hive(df, database, table, max_rows_per_bucket, overwrite=True, raise_on_missing_columns=True):
    """Insert `df` into the existing table `database.table` after bucketing it.

    The table's columns are read from the Spark catalog; `df` is projected to the
    non-partition columns followed by the partition columns, repartitioned with
    `create_partition_buckets`, and written with `insertInto`.

    Args:
        df: DataFrame to write.
        database, table: target table coordinates.
        max_rows_per_bucket: forwarded to `create_partition_buckets`.
        overwrite: forwarded to `DataFrameWriter.insertInto`.
        raise_on_missing_columns: when False, table columns absent from `df` are
            added as NULL literals; when True no filling is done, so the projection
            onto the table columns is left to fail if one is missing.

    `spark.sql.sources.partitionOverwriteMode` is switched to "dynamic" for the
    duration of the write and restored to its previous value afterwards, even if
    the write raises.
    """
    overwrite_mode_key = "spark.sql.sources.partitionOverwriteMode"
    session = df.sql_ctx.sparkSession
    table_columns = session.catalog.listColumns(dbName=database, tableName=table)
    log(table_columns, "\ncatalog columns")

    if not raise_on_missing_columns:
        projected = []
        for column in table_columns:
            if column.name in df.columns:
                projected.append(sqlfn.col(column.name))
            else:
                projected.append(sqlfn.lit(None).alias(column.name))
        df = df.select(*projected)

    # Catalog order is kept inside each group.
    main_columns = [column.name for column in table_columns if not column.isPartition]
    partition_columns = [column.name for column in table_columns if column.isPartition]

    previous_mode = session.conf.get(overwrite_mode_key)
    session.conf.set(overwrite_mode_key, "dynamic")
    log(previous_mode, "\nsaved spark.sql.sources.partitionOverwriteMode")

    try:
        ordered_df = df.select(*(main_columns + partition_columns))
        bucketed_df = create_partition_buckets(ordered_df, max_rows_per_bucket, *partition_columns)
        bucketed_df.write.insertInto("{}.{}".format(database, table), overwrite=overwrite)
    finally:
        session.conf.set(overwrite_mode_key, previous_mode)


In [None]:
insert_into_hive(df, database=db, table=table, max_rows_per_bucket=1, raise_on_missing_columns=True)

In [None]:
# experiments with Spark SQL functions
# https://spark.apache.org/docs/latest/api/sql/index.html

In [None]:
# filter array values
df_0 = show(spark.sql(
    "select filter("
    "array("
    "cast(1 as float), cast(null as float), cast('NaN' as float), cast(4 as float)"
    "), _x -> not isnull(_x) and not isnan(_x)"
    ") as arrcol"
))

In [None]:
# create map
df_1 = show(
    spark.sql(
        "select map_from_arrays("
        "array('a', 'b', 'c', '101', '-1', '303'), "
        "array(cast(1 as double), cast(null as double), cast('NaN' as double), 1.01, 2.02, 3.03)"
        ") as mapcol"
    )
)

In [None]:
# convert map to array
df_2 = show(df_1.selectExpr(
    "user_dmdesc.map_key_values(mapcol) as arrtuples"
))

In [None]:
# drop invalid tuples
df_3 = show(df_2.selectExpr(
    "filter(arrtuples, _x -> "
    "not isnull(_x['value']) and not isnan(_x['value']) and is_uint32(_x['key'])"
    ") as arrtuples"
))

In [None]:
# convert array to map
df_4 = show(df_3.selectExpr(
    "cast("
    "map_from_entries(arrtuples)"
    "as map<string,float>) as mapcol"
))

In [None]:
spark.sql("select cast(0 as float)").show()

In [None]:
spark.catalog.dropGlobalTempView("test_features")
df.unpersist()

In [None]:
spark.stop()