# convert pandas dataframe to spark dataframe

In [None]:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, DateType, FloatType
from pyspark.sql.functions import col, to_date


df = None  # pandas dataframe object
spark = None  # spark instance

# Explicit Spark schema for the conversion: (column name, Spark type class).
# Every field is declared nullable.
# NOTE(review): presumably these names/order mirror the pandas columns — confirm.
_column_types = [
    ("monthly", StringType),
    ("OP_perct", StringType),
    ("score", StringType),
    ("Score_Value", DoubleType),
    ("bad_tag", StringType),
    ("dim_name", StringType),
    ("dim_value", StringType),
    ("Catch_Rate", DoubleType),
    ("Bad", DoubleType),
    ("Good", DoubleType),
    ("Total", DoubleType),
    ("FPR", DoubleType),
    ("wgt_alias", StringType),
]
schema = StructType(
    [StructField(name, spark_type(), True) for name, spark_type in _column_types]
)

# Convert the pandas dataframe to a Spark dataframe using the schema above.
spark_df = spark.createDataFrame(df, schema=schema)

# debug dataframe

In [None]:

# Tiny in-memory dataframe for debugging: one row per month, holding the
# first day of the month and the 'yyyy-MM' month label.
columns = ["driver_pmt_date", "driver_monthly"]
data = [(f"{month}-01", month) for month in ("2024-07", "2024-09", "2024-10")]

driver_df = spark.createDataFrame(data=data, schema=columns)
driver_df.show()

# adding new column using case when

In [None]:
from pyspark.sql.functions import col, when, lit


# Flag rows as declined-bad (1) unless decline_pattern is null or equals
# the literal 'no pattern' (0); the result is stored as an int column.
decline_pattern = col('decline_pattern')
has_no_pattern = decline_pattern.isNull() | (decline_pattern == lit('no pattern'))
is_declined_bad = when(has_no_pattern, 0).otherwise(1).cast('int')

df = df.withColumn('is_declined_bad', is_declined_bad)

# rename multiple columns in the select method

In [None]:
import pyspark.sql.functions as fn

join_keys = ['TRANS_ID']
norm_postfix = '_normed'
candidate_vars = ['feat1', 'feat2']

# Keep the join keys as-is and select every candidate column that exists in
# df under a new name carrying the postfix (e.g. feat1 -> feat1_normed).
renamed_candidates = [
    fn.col(name).alias(f"{name}{norm_postfix}")
    for name in df.columns
    if name in candidate_vars
]
ret = df.select(join_keys + renamed_candidates)

# spark data IO

Spark glob patterns
- POSIX glob pattern: https://man7.org/linux/man-pages/man7/glob.7.html
- Hadoop glob pattern extension: https://hail.is/docs/0.2/hadoop_glob_patterns.html

In [None]:
# Load parquet data from every path matched by a glob pattern.
parquet_glob = 'hdfs://path/to/data/all_vars_[0-9]*/group*'

# Option 1: generic reader with an explicit format.
spark.read.format('parquet').load(parquet_glob)

# Option 2: parquet shortcut combined with a reader option.
# NOTE(review): option_key / option_value are not defined in this snippet;
# supply real values before running.
spark.read.option(option_key, option_value).parquet(parquet_glob)

# Read CSV files delimited by the \x07 character, with a header row.
csv_reader = spark.read.options(delimiter='\x07', header=True)
df = csv_reader.csv("hdfs://path/to/directory/of/csv_file")


# write spark dataframe to file

In [None]:
# Option 1: write as parquet, using 'overwrite' save mode for the target path.
df.write.mode('overwrite').parquet('hdfs://path/to/data')

# Option 2: write as CSV with a writer option, also in 'overwrite' mode.
# NOTE(review): option_key / option_value are not defined in this snippet;
# supply real values before running.
df.write.option(option_key, option_value).mode('overwrite').csv('hdfs://path/to/data')

# column data type

In [None]:
from pyspark.sql.types import StringType,BooleanType,DateType

# Print the schema field of every column whose data type is StringType.
string_fields = [
    field for field in df.schema if isinstance(field.dataType, StringType)
]
for field in string_fields:
    print(field)

# trim string

In [None]:
from pyspark.sql.functions import col, when, lit, trim


# Replace the column with a whitespace-trimmed copy of itself.
# NOTE(review): column_name is not defined in this snippet; set it first.
trimmed_values = trim(col(column_name))
df = df.withColumn(column_name, trimmed_values)


# union dataframe with different schema

In [None]:
# Two single-row dataframes with the same column names in a different order.
df1 = spark.createDataFrame(data=[[1, 2, 3]], schema=["col0", "col1", "col2"])
df2 = spark.createDataFrame(data=[[4, 5, 6]], schema=["col1", "col2", "col0"])

# Union by column name rather than by position; allowMissingColumns=True is
# passed so columns absent from one side are tolerated.
unioned = df1.unionByName(df2, allowMissingColumns=True)
unioned.show()

# dataframe groupby and sort

In [None]:
from pyspark.sql.functions import col, desc

# Count rows per pmt_start_date and display the counts ordered by that date.
rows_per_start_date = df.groupBy('pmt_start_date').count()
rows_per_start_date.sort(col('pmt_start_date')).show()

# coalesce dataframe

In [None]:
# Request a 32-partition version of df.
# NOTE(review): the returned dataframe is not bound to a name here; assign it
# (e.g. df = df.coalesce(...)) if later code should use the result.
df.coalesce(numPartitions=32)

# parse date from integer

In [None]:
from pyspark.sql import functions as fn


# Derive a date column from an integer unix timestamp.
# NOTE(review): assumes time_event_published holds epoch *seconds* — confirm
# (millisecond values would need dividing by 1000 first).
timestamp_format = 'yyyy-MM-dd HH:mm:ss'

# from_unixtime renders the epoch value as a string in timestamp_format.
# The value is cast to long (not string) so no implicit string->number
# conversion is needed.
event_timestamp = fn.from_unixtime(
    fn.col("time_event_published").cast('long'), timestamp_format
)

# to_date must parse that string with the same pattern it was rendered with;
# the previous pattern ('yyyy-MM-sortdd') was not a valid datetime pattern.
new_df = df.withColumn('event_date', fn.to_date(event_timestamp, timestamp_format))

# load hive table into spark dataframe

Configure your notebook to access the Horton cluster:
```shell
%url -c horton
```

In [None]:
spark = None  # spark instance

# Run a SQL query against the Hive table and keep the result as a dataframe.
hive_query = 'select * from hive_db_name.hive_table_name'
df = spark.sql(hive_query)

# set column as null

In [None]:
from pyspark.sql.functions import lit
from pyspark.sql.types import StringType


v = 'some_column_name'

# Replace column v with an all-null string column: add the null column under
# a temporary name, drop the original, then rename the temporary column to v.
temp_name = f"{v}_null"
null_string = lit(None).cast(StringType())

data_df = data_df.withColumn(temp_name, null_string)
data_df = data_df.drop(v)
data_df = data_df.withColumnRenamed(temp_name, v)

# coalesce columns

In [None]:
from pyspark.sql.functions import col, when, lit, trim, coalesce, udf
from pyspark.sql.types import StringType, DoubleType, IntegerType


@udf(returnType=StringType())
def blank_as_null(x):
    """Return None for None or whitespace-only strings; otherwise return x unchanged.

    Registered as a Spark UDF with a StringType return type.
    """
    is_blank_string = isinstance(x, str) and x.strip() == ''
    return None if (x is None or is_blank_string) else x


# Alias both sides so same-named columns can be addressed unambiguously
# as 'origin_df.<col>' / 'backfill_df.<col>' after the join.
origin_df = origin_df.alias('origin_df')
backfill_df = backfill_df.alias('backfill_df')

# Left join keeps every origin row; backfill values are attached where the
# transaction id matches.
ensembled_df = origin_df.join(backfill_df, on='driver_trans_id_clean', how='left')

origin_cols = set(origin_df.columns)
backfill_vars = [name for name in backfill_df.columns if name != 'driver_trans_id_clean']

for c in backfill_vars:
    if c in origin_cols:
        # Prefer the origin value; fall back to the backfill value when the
        # origin one is null or blank. The merged value is built under a
        # temporary name, both same-named source columns are dropped, and the
        # temporary column takes over the original name.
        temp_name = f"{c}_"
        merged_value = coalesce(blank_as_null(f"origin_df.{c}"), f"backfill_df.{c}")
        ensembled_df = ensembled_df.withColumn(temp_name, merged_value)
        ensembled_df = ensembled_df.drop(c)
        ensembled_df = ensembled_df.withColumnRenamed(temp_name, c)
    else:
        print(f"ignore column {c} since it's not in origin columns")

# apply UDF to columns

In [None]:
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType



@udf(returnType=StringType())
def clean_trans_id(trans_id):
    """Return the part of trans_id before the first '.', or None for null input.

    Spark passes SQL NULL to a Python UDF as None; without the guard a single
    null driver_trans_id would raise AttributeError and fail the whole job.
    """
    if trans_id is None:
        return None
    return trans_id.split('.')[0]


# Add the cleaned id as a new column next to the raw driver_trans_id.
df = df.withColumn('driver_trans_id_clean', clean_trans_id(col('driver_trans_id')))



# check df join coverage

In [None]:
from pyspark.sql import functions as fn


df1 = None  # left-side dataframe (must contain df1_join_key and 'monthly')
df2 = None  # right-side dataframe (must contain df2_join_key)
df1_join_key = 'trans_id_clean'
df2_join_key = 'driver_trans_id_clean'  # was misspelled 'drvier_trans_id_clean'

In [None]:

# Measure, per 'monthly' bucket, the share of df1 rows that find a match in df2.
# Uses df1_join_key / df2_join_key (configured in the previous cell) instead of
# an undefined `join_key` and hard-coded column names.
left_side = df1.select([df1_join_key, 'monthly'])
right_side = df2.select([df2_join_key])

# Left join keeps every df1 row; rows without a match get a null df2 key.
joined_df = left_side.join(
    right_side,
    on=left_side[df1_join_key] == right_side[df2_join_key],
    how='left',
)

# unit = 1 for every row (denominator); has_join = 1 only when the df2 key is
# present (numerator). The right-side column is referenced through its
# dataframe so the check stays unambiguous even if both keys share a name.
joined_df = joined_df.withColumn('unit', fn.lit(1))
joined_df = joined_df.withColumn(
    'has_join',
    fn.when(right_side[df2_join_key].isNull(), 0).otherwise(1),
)

stats = (
    joined_df.groupBy(['monthly'])
    .agg(
        fn.sum('has_join').alias('has_join_sum'),
        fn.sum('unit').alias('unit_sum'),
    )
    .sort('monthly')
)
stats = stats.withColumn('join_coverage', fn.col('has_join_sum') / fn.col('unit_sum'))
stats.printSchema()

print('monthly join coverage')
stats.show(100)
    