In [36]:
from pyspark.sql import SparkSession,DataFrame
from pyspark.sql.types import StructField, StructType, StringType, DoubleType, IntegerType, LongType, DecimalType
import os
from pyspark.sql.functions import lit, rand

In [2]:
local = True

if local:
    spark = SparkSession.builder \
        .master("local[4]") \
        .appName("SparkDoubleTypeTest")\
        .getOrCreate()
else:
    spark = SparkSession.builder\
        .master("k8s://https://kubernetes.default.svc:443") \
        .appName("SparkDoubleTypeTest")\
        .config("spark.kubernetes.container.image", "inseefrlab/jupyter-datascience:py3.9.7-spark3.2.0")\
        .config("spark.kubernetes.authenticate.driver.serviceAccountName", os.environ['KUBERNETES_SERVICE_ACCOUNT'])\
        .config("spark.executor.instances", "4")\
        .config("spark.executor.memory", "8g")\
        .config("spark.kubernetes.namespace", os.environ['KUBERNETES_NAMESPACE'])\
        .getOrCreate()

22/03/20 11:24:27 WARN Utils: Your hostname, ubuntu resolves to a loopback address: 127.0.1.1; using 192.168.184.146 instead (on interface ens33)
22/03/20 11:24:27 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
22/03/20 11:24:28 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [3]:
data = [("James", "Sales", "NY", 900.01, 34, 10000),
        ("Michael", "Sales", "NY", 860.02, 56, 20000),
        ("Robert", "Sales", "CA", 810.03, 30, 23000),
        ("Maria", "Finance", "CA", 900.04, 24, 23000),
        ("Raman", "Finance", "CA", 990.05, 40, 24000),
        ("Scott", "Finance", "NY", 830.06, 36, 19000),
        ("Jen", "Finance", "NY", 790.07, 53, 15000),
        ("Jeff", "Marketing", "CA", 800.08, 25, 18000),
        ("Kumar", "Marketing", "NY", 910.09, 50, 21000)
        ]



schema = StructType([
    StructField("employee_name", StringType(), True),
    StructField("department", StringType(), True),
    StructField("state", StringType(), True),
    StructField("salary", DoubleType(), True),
    StructField("age", IntegerType(), True),
    StructField("bonus", LongType(), True)
])
df = spark.createDataFrame(data=data, schema=schema)

In [4]:
df.show(5)

+-------------+----------+-----+------+---+-----+
|employee_name|department|state|salary|age|bonus|
+-------------+----------+-----+------+---+-----+
|        James|     Sales|   NY|900.01| 34|10000|
|      Michael|     Sales|   NY|860.02| 56|20000|
|       Robert|     Sales|   CA|810.03| 30|23000|
|        Maria|   Finance|   CA|900.04| 24|23000|
|        Raman|   Finance|   CA|990.05| 40|24000|
+-------------+----------+-----+------+---+-----+
only showing top 5 rows



In [5]:
df.printSchema()

root
 |-- employee_name: string (nullable = true)
 |-- department: string (nullable = true)
 |-- state: string (nullable = true)
 |-- salary: double (nullable = true)
 |-- age: integer (nullable = true)
 |-- bonus: long (nullable = true)



In [6]:
from pyspark.sql.functions import col

df.withColumn("raised_salary", col("salary") + 0.1) \
    .withColumn("reduced_salary", col("salary") - 0.1).show()

+-------------+----------+-----+------+---+-----+-----------------+-----------------+
|employee_name|department|state|salary|age|bonus|    raised_salary|   reduced_salary|
+-------------+----------+-----+------+---+-----+-----------------+-----------------+
|        James|     Sales|   NY|900.01| 34|10000|           900.11|           899.91|
|      Michael|     Sales|   NY|860.02| 56|20000|           860.12|           859.92|
|       Robert|     Sales|   CA|810.03| 30|23000|           810.13|           809.93|
|        Maria|   Finance|   CA|900.04| 24|23000|           900.14|899.9399999999999|
|        Raman|   Finance|   CA|990.05| 40|24000|           990.15|989.9499999999999|
|        Scott|   Finance|   NY|830.06| 36|19000|           830.16|829.9599999999999|
|          Jen|   Finance|   NY|790.07| 53|15000|790.1700000000001|           789.97|
|         Jeff| Marketing|   CA|800.08| 25|18000|800.1800000000001|           799.98|
|        Kumar| Marketing|   NY|910.09| 50|21000|     

You can notice after the addition or soustraction with 0.1, many results are wrong. Because the underlying spark engine is implemented by using Scala.

**In scala, double is internally stored as a fraction in binary -- like 1/4=0.25 + 1/8=0.125 + 1/16=0.0625 + ...**

As a result, the value 829.94 -- or the value 789.96 -- cannot be stored as an exact fraction in binary, so double cannot store the exact value 0.94 or 0.96, and the subtracted value isn't quite exact.

If you want to get exact precise decimal arithmetic, convert column type from double/float to DecimalType. For all supported column type in spark: https://spark.apache.org/docs/latest/sql-ref-datatypes.html

In fact, this is a general problem for many programing language. For more detail about float calculation, you can visit this https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html

In [7]:
df_decimal=df.withColumn("salary_decimal",col("salary").cast(DecimalType(10,2))).drop("salary").withColumnRenamed("salary_decimal","salary")

In [8]:
df_decimal.show()

+-------------+----------+-----+---+-----+------+
|employee_name|department|state|age|bonus|salary|
+-------------+----------+-----+---+-----+------+
|        James|     Sales|   NY| 34|10000|900.01|
|      Michael|     Sales|   NY| 56|20000|860.02|
|       Robert|     Sales|   CA| 30|23000|810.03|
|        Maria|   Finance|   CA| 24|23000|900.04|
|        Raman|   Finance|   CA| 40|24000|990.05|
|        Scott|   Finance|   NY| 36|19000|830.06|
|          Jen|   Finance|   NY| 53|15000|790.07|
|         Jeff| Marketing|   CA| 25|18000|800.08|
|        Kumar| Marketing|   NY| 50|21000|910.09|
+-------------+----------+-----+---+-----+------+



In [9]:
df_decimal.printSchema()

root
 |-- employee_name: string (nullable = true)
 |-- department: string (nullable = true)
 |-- state: string (nullable = true)
 |-- age: integer (nullable = true)
 |-- bonus: long (nullable = true)
 |-- salary: decimal(10,2) (nullable = true)



In [10]:
df_wrong=df_decimal.withColumn("raised_salary", col("salary") + 0.1) \
    .withColumn("reduced_salary", col("salary") - 0.1)

In [11]:
df_wrong.show()

+-------------+----------+-----+---+-----+------+-----------------+-----------------+
|employee_name|department|state|age|bonus|salary|    raised_salary|   reduced_salary|
+-------------+----------+-----+---+-----+------+-----------------+-----------------+
|        James|     Sales|   NY| 34|10000|900.01|           900.11|           899.91|
|      Michael|     Sales|   NY| 56|20000|860.02|           860.12|           859.92|
|       Robert|     Sales|   CA| 30|23000|810.03|           810.13|           809.93|
|        Maria|   Finance|   CA| 24|23000|900.04|           900.14|899.9399999999999|
|        Raman|   Finance|   CA| 40|24000|990.05|           990.15|989.9499999999999|
|        Scott|   Finance|   NY| 36|19000|830.06|           830.16|829.9599999999999|
|          Jen|   Finance|   NY| 53|15000|790.07|790.1700000000001|           789.97|
|         Jeff| Marketing|   CA| 25|18000|800.08|800.1800000000001|           799.98|
|        Kumar| Marketing|   NY| 50|21000|910.09|     

In [12]:
df_wrong.printSchema()

root
 |-- employee_name: string (nullable = true)
 |-- department: string (nullable = true)
 |-- state: string (nullable = true)
 |-- age: integer (nullable = true)
 |-- bonus: long (nullable = true)
 |-- salary: decimal(10,2) (nullable = true)
 |-- raised_salary: double (nullable = true)
 |-- reduced_salary: double (nullable = true)



You can notice that if one parameter in the operation is double, the result will be double, you will not get the decimal precision.

In [13]:

df_right=df_decimal.withColumn("variant",lit(0.1).cast(DecimalType(10,2)))\
    .withColumn("raised_salary", col("salary") + col("variant")) \
    .withColumn("reduced_salary", col("salary") - col("variant"))

In [14]:
df_right.show()

+-------------+----------+-----+---+-----+------+-------+-------------+--------------+
|employee_name|department|state|age|bonus|salary|variant|raised_salary|reduced_salary|
+-------------+----------+-----+---+-----+------+-------+-------------+--------------+
|        James|     Sales|   NY| 34|10000|900.01|   0.10|       900.11|        899.91|
|      Michael|     Sales|   NY| 56|20000|860.02|   0.10|       860.12|        859.92|
|       Robert|     Sales|   CA| 30|23000|810.03|   0.10|       810.13|        809.93|
|        Maria|   Finance|   CA| 24|23000|900.04|   0.10|       900.14|        899.94|
|        Raman|   Finance|   CA| 40|24000|990.05|   0.10|       990.15|        989.95|
|        Scott|   Finance|   NY| 36|19000|830.06|   0.10|       830.16|        829.96|
|          Jen|   Finance|   NY| 53|15000|790.07|   0.10|       790.17|        789.97|
|         Jeff| Marketing|   CA| 25|18000|800.08|   0.10|       800.18|        799.98|
|        Kumar| Marketing|   NY| 50|21000|9

In [15]:
df_right.printSchema()

root
 |-- employee_name: string (nullable = true)
 |-- department: string (nullable = true)
 |-- state: string (nullable = true)
 |-- age: integer (nullable = true)
 |-- bonus: long (nullable = true)
 |-- salary: decimal(10,2) (nullable = true)
 |-- variant: decimal(10,2) (nullable = true)
 |-- raised_salary: decimal(11,2) (nullable = true)
 |-- reduced_salary: decimal(11,2) (nullable = true)



As the two operands are both decimal, so this time the precision of the result is correct. But you can still notice that the Data Type of the result is Decimal(11,2), not Decimal(10,2). This is spark do the auto cast to prevent arithmetic overflow.

## Performance test between decimal and double

As we explained before, double and float data type has precision problem for arithmetic operations. To avoid this problem, we need to use decimal data type. But decimal data type will suffer some performance lost. In below test, we will try to quantify the performance lost.

In [89]:
df_size=1000000000
def generate_decimal_sample(total:int,fraction:int)->DataFrame:
    return spark.range(0,df_size).withColumn("x",(rand(seed=48) * 10).cast(DecimalType(total,fraction))).withColumn("y",(rand(seed=46) * 10).cast(DecimalType(total,fraction)))

In [126]:
decimal_sample=generate_decimal_sample(25,15)
decimal_sample.show(5)

+---+-----------------+-----------------+
| id|                x|                y|
+---+-----------------+-----------------+
|  0|3.710369234532471|0.401060051090526|
|  1|7.975769828959885|0.620780005870033|
|  2|7.171375356646551|1.788344735940591|
|  3|3.636958873390777|3.283184244659680|
|  4|9.905156945373632|1.353439163856851|
+---+-----------------+-----------------+
only showing top 5 rows



In [99]:
decimal_sample.count()

1000000000

In [100]:
import time

def df_performance_test(input_df:DataFrame):
    start = time.time()
    res=input_df.withColumn("addition",col("x")+col("y"))\
        .withColumn("soustraction",col("x")-col("y")).\
        withColumn("multiple",col("x")*col("y")).\
        withColumn("division",col("x")/col("y"))
    res.show(5)
    end= time.time()
    print(f"operation duration is {end-start}")
    return res

In [135]:
decimal_res=df_performance_test(decimal_sample)


+---+-----------------+-----------------+------------------+-----------------+--------------------+----------------+
| id|                x|                y|          addition|     soustraction|            multiple|        division|
+---+-----------------+-----------------+------------------+-----------------+--------------------+----------------+
|  0|3.710369234532471|0.401060051090526| 4.111429285622997|3.309309183441945| 1.48808087476630867| 9.2514056796322|
|  1|7.975769828959885|0.620780005870033| 8.596549834829918|7.354989823089852| 4.95119844123974951|12.8479811745575|
|  2|7.171375356646551|1.788344735940591| 8.959720092587142|5.383030620705960|12.82489136851293785| 4.0100631676446|
|  3|3.636958873390777|3.283184244659680| 6.920143118050457|0.353774628731097|11.94080607159181893| 1.1077535107287|
|  4|9.905156945373632|1.353439163856851|11.258596109230483|8.551717781516781|13.40602733401736885| 7.3185091800855|
+---+-----------------+-----------------+------------------+----

In [118]:
decimal_res.printSchema()

root
 |-- id: long (nullable = false)
 |-- x: decimal(20,10) (nullable = true)
 |-- y: decimal(20,10) (nullable = true)
 |-- addition: decimal(21,10) (nullable = true)
 |-- soustraction: decimal(21,10) (nullable = true)
 |-- multiple: decimal(38,17) (nullable = true)
 |-- division: decimal(38,18) (nullable = true)



In [103]:
double_sample=spark.range(0,df_size).withColumn("x",(rand(seed=48) * 10).cast(DoubleType())).withColumn("y",(rand(seed=46) * 10).cast(DoubleType()))
double_sample.show(5)

+---+------------------+------------------+
| id|                 x|                 y|
+---+------------------+------------------+
|  0|3.7103692345324712| 0.401060051090526|
|  1| 7.975769828959885| 0.620780005870033|
|  2| 7.171375356646551|1.7883447359405913|
|  3|3.6369588733907765|3.2831842446596804|
|  4| 9.905156945373632| 1.353439163856851|
+---+------------------+------------------+
only showing top 5 rows



In [104]:
double_res=df_performance_test(double_sample)

+---+------------------+------------------+------------------+------------------+------------------+------------------+
| id|                 x|                 y|          addition|      soustraction|          multiple|          division|
+---+------------------+------------------+------------------+------------------+------------------+------------------+
|  0|3.7103692345324712| 0.401060051090526|4.1114292856229975| 3.309309183441945|1.4880808747663088| 9.251405679632196|
|  1| 7.975769828959885| 0.620780005870033| 8.596549834829919| 7.354989823089852| 4.951198441239749| 12.84798117455751|
|  2| 7.171375356646551|1.7883447359405913| 8.959720092587142| 5.383030620705959| 12.82489136851294| 4.010063167644644|
|  3|3.6369588733907765|3.2831842446596804| 6.920143118050457|0.3537746287310961|11.940806071591819|1.1077535107286576|
|  4| 9.905156945373632| 1.353439163856851|11.258596109230483|  8.55171778151678|13.406027334017368| 7.318509180085518|
+---+------------------+----------------

In [105]:
double_res.printSchema()

root
 |-- id: long (nullable = false)
 |-- x: double (nullable = false)
 |-- y: double (nullable = false)
 |-- addition: double (nullable = false)
 |-- soustraction: double (nullable = false)
 |-- multiple: double (nullable = false)
 |-- division: double (nullable = true)



With the above test, we can conclude that the performance of decimal is highly depends on the precision.
For example, with a dataframe of 1,000,000,000 rows, two columns x, y as parameter, to do the four operation (e.g. addition, ...)

- For type double, it takes 0.045064687728881836 s
- For type decimal,
      - with decimal(10,4), it takes 0.04919123649597168 (lose 9.16 percent in time compare to double)
      - with decimal(20,10), it takes 0.05225372314453125 (lose 15.95 percent in time compare to double)
      - with decimal(25,15), it takes 0.05529522895812988 (lose 22.70 percent in time compare to double)



In [138]:
def performance_percent(x,y):
    return ((x-y)/y)*100


In [139]:
case1=performance_percent(0.04919123649597168,0.045064687728881836)
print(case1)

9.156945215988149


In [137]:
case2=performance_percent(0.05225372314453125,0.045064687728881836)
print(case2)

15.952702166494722


In [141]:
case3=performance_percent(0.05529522895812988,0.045064687728881836)
print(case3)

22.70190196545248
