In [0]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("Spark DataFrames").getOrCreate()

### Basics

In [0]:
my_range = spark.range(1000).toDF('number')
my_range

DataFrame[number: bigint]

In [0]:
even_number = my_range.where("number % 2 == 0")
even_number

DataFrame[number: bigint]

In [0]:
even_number.filter("number % 2 ==0").explain(True)

== Parsed Logical Plan ==
'Filter (('number % 2) = 0)
+- Filter ((number#4L % cast(2 as bigint)) = cast(0 as bigint))
   +- Project [id#2L AS number#4L]
      +- Range (0, 1000, step=1, splits=Some(4))

== Analyzed Logical Plan ==
number: bigint
Filter ((number#4L % cast(2 as bigint)) = cast(0 as bigint))
+- Filter ((number#4L % cast(2 as bigint)) = cast(0 as bigint))
   +- Project [id#2L AS number#4L]
      +- Range (0, 1000, step=1, splits=Some(4))

== Optimized Logical Plan ==
Project [id#2L AS number#4L]
+- Filter ((id#2L % 2) = 0)
   +- Range (0, 1000, step=1, splits=Some(4))

== Physical Plan ==
*(1) ColumnarToRow
+- PhotonResultStage
   +- PhotonProject [id#2L AS number#4L]
      +- PhotonFilter ((id#2L % 2) = 0)
         +- PhotonRange Range (0, 1000, step=1, splits=4)

== Photon Explanation ==
The query is fully supported by Photon.


In [0]:
display(dbutils.fs.ls("dbfs:/"))

path,name,size,modificationTime
dbfs:/Volume/,Volume/,0,0
dbfs:/Volumes/,Volumes/,0,0
dbfs:/databricks-datasets/,databricks-datasets/,0,0
dbfs:/databricks-results/,databricks-results/,0,0
dbfs:/volume/,volume/,0,0
dbfs:/volumes/,volumes/,0,0


### Working with csv data

In [0]:
import os
os.makedirs("/dbfs/tmp/", exist_ok=True)

In [0]:
import urllib.request

url = "https://raw.githubusercontent.com/databricks/Spark-The-Definitive-Guide/refs/heads/master/data/flight-data/csv/2015-summary.csv"

path = "/dbfs/tmp/2015-summary.csv"

urllib.request.urlretrieve(url, path)


('/dbfs/tmp/2015-summary.csv', <http.client.HTTPMessage at 0x7046c5844e50>)

In [0]:
flight_data = spark.read.option('header',True).csv("/tmp/2015-summary.csv")

In [0]:
flight_data.take(5)

[Row(DEST_COUNTRY_NAME='United States', ORIGIN_COUNTRY_NAME='Romania', count='15'),
 Row(DEST_COUNTRY_NAME='United States', ORIGIN_COUNTRY_NAME='Croatia', count='1'),
 Row(DEST_COUNTRY_NAME='United States', ORIGIN_COUNTRY_NAME='Ireland', count='344'),
 Row(DEST_COUNTRY_NAME='Egypt', ORIGIN_COUNTRY_NAME='United States', count='15'),
 Row(DEST_COUNTRY_NAME='United States', ORIGIN_COUNTRY_NAME='India', count='62')]

In [0]:
flight_data.sort('count').explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- == Initial Plan ==
   ColumnarToRow
   +- PhotonResultStage
      +- PhotonSort [count#91 ASC NULLS FIRST]
         +- PhotonShuffleExchangeSource
            +- PhotonShuffleMapStage
               +- PhotonShuffleExchangeSink rangepartitioning(count#91 ASC NULLS FIRST, 200)
                  +- PhotonRowToColumnar
                     +- FileScan csv [DEST_COUNTRY_NAME#89,ORIGIN_COUNTRY_NAME#90,count#91] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[dbfs:/tmp/2015-summary.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<DEST_COUNTRY_NAME:string,ORIGIN_COUNTRY_NAME:string,count:string>


== Photon Explanation ==
The query is fully supported by Photon.


In [0]:
flight_data.sort('count', ascending=False).take(3)

[Row(DEST_COUNTRY_NAME='United States', ORIGIN_COUNTRY_NAME='The Bahamas', count='986'),
 Row(DEST_COUNTRY_NAME='The Bahamas', ORIGIN_COUNTRY_NAME='United States', count='955'),
 Row(DEST_COUNTRY_NAME='United States', ORIGIN_COUNTRY_NAME='France', count='952')]

In [0]:
flight_data.sort('count').take(3)

[Row(DEST_COUNTRY_NAME='United States', ORIGIN_COUNTRY_NAME='Croatia', count='1'),
 Row(DEST_COUNTRY_NAME='United States', ORIGIN_COUNTRY_NAME='Singapore', count='1'),
 Row(DEST_COUNTRY_NAME='Moldova', ORIGIN_COUNTRY_NAME='United States', count='1')]

In [0]:
flight_data.createOrReplaceTempView("flight_data")

In [0]:
count_df = spark.sql("select count(*) from flight_data")
count_df.display()

count(1)
256


In [0]:
dest_country_5 = spark.sql("select DEST_COUNTRY_NAME, count(*) as dest_count from flight_data \
                        group by DEST_COUNTRY_NAME \
                        order by dest_count desc limit 5")
dest_country_5.display()

DEST_COUNTRY_NAME,dest_count
United States,125
Russia,1
Senegal,1
Anguilla,1
Paraguay,1


In [0]:
from pyspark.sql.functions import count

dest_country_5 = (
                flight_data
                .groupBy("DEST_COUNTRY_NAME")
                .agg(count("*").alias("dest_count"))
                .orderBy("dest_count", ascending=False)
                .limit(5) 
)
dest_country_5.display()

DEST_COUNTRY_NAME,dest_count
United States,125
Russia,1
Senegal,1
Anguilla,1
Paraguay,1


In [0]:
flight_data.take(3)

[Row(DEST_COUNTRY_NAME='United States', ORIGIN_COUNTRY_NAME='Romania', count='15'),
 Row(DEST_COUNTRY_NAME='United States', ORIGIN_COUNTRY_NAME='Croatia', count='1'),
 Row(DEST_COUNTRY_NAME='United States', ORIGIN_COUNTRY_NAME='Ireland', count='344')]

In [0]:
spark.sql("select sum(count) as total_count_of_each_dest, DEST_COUNTRY_NAME from flight_data \
          group by DEST_COUNTRY_NAME \
          order by total_count_of_each_dest desc"
).take(3)

[Row(total_count_of_each_dest=411352.0, DEST_COUNTRY_NAME='United States'),
 Row(total_count_of_each_dest=8399.0, DEST_COUNTRY_NAME='Canada'),
 Row(total_count_of_each_dest=7140.0, DEST_COUNTRY_NAME='Mexico')]

In [0]:
from pyspark.sql.functions import sum as _sum

top_3 = (
    flight_data
    .groupBy('DEST_COUNTRY_NAME')
    .agg(_sum("count").alias("total_count_of_each_dest"))
    .orderBy("total_count_of_each_dest", ascending=False)
    .limit(3).display()
)

DEST_COUNTRY_NAME,total_count_of_each_dest
United States,411352.0
Canada,8399.0
Mexico,7140.0


In [0]:
max_value = spark.sql("select max(count) from flight_data limit 1").display()

max(count)
986


In [0]:
from pyspark.sql.functions import max as _max

max_value = flight_data.agg(_max("count").alias("max_count")).display(1)

max_count
986


In [0]:
flight_data.limit(5).display()

DEST_COUNTRY_NAME,ORIGIN_COUNTRY_NAME,count
United States,Romania,15
United States,Croatia,1
United States,Ireland,344
Egypt,United States,15
United States,India,62


In [0]:
spark.sql("select sum(count) as top_dest_country_count, DEST_COUNTRY_NAME \
        from flight_data group by DEST_COUNTRY_NAME order by top_dest_country_count desc limit 5"
).display()

top_dest_country_count,DEST_COUNTRY_NAME
411352.0,United States
8399.0,Canada
7140.0,Mexico
2025.0,United Kingdom
1548.0,Japan


In [0]:
(
    flight_data
    .groupBy('DEST_COUNTRY_NAME')
    .agg(_sum("count").alias("top_dest_country_count"))
    .orderBy("top_dest_country_count", ascending=False)
    .limit(5).display()
)

DEST_COUNTRY_NAME,top_dest_country_count
United States,411352.0
Canada,8399.0
Mexico,7140.0
United Kingdom,2025.0
Japan,1548.0


In [0]:
(
    flight_data
    .groupBy('DEST_COUNTRY_NAME')
    .agg(_sum("count").alias("top_dest_country_count"))
    .sort("top_dest_country_count", ascending=False)
    .limit(5).display()
)

DEST_COUNTRY_NAME,top_dest_country_count
United States,411352.0
Canada,8399.0
Mexico,7140.0
United Kingdom,2025.0
Japan,1548.0


### Read from multiple csv files 

In [0]:
import os
import urllib.request

# List of file names
file_names = [
    "2010-12-01.csv", "2010-12-02.csv", "2010-12-03.csv",
     "2010-12-05.csv", "2010-12-06.csv",
    "2010-12-07.csv", "2010-12-08.csv", "2010-12-09.csv"
]

# Base GitHub URL
base_url = "https://raw.githubusercontent.com/databricks/Spark-The-Definitive-Guide/master/data/retail-data/by-day/"

# Directories
raw_dir = "dbfs/tmp/retail_data"
os.makedirs(raw_dir, exist_ok=True)
# Loop to download, read, and save
for file in file_names:
    local_file_path = os.path.join(raw_dir, file)

    # Download if not already present
    if not os.path.exists(local_file_path):
        print(f"Downloading {file}...")
        urllib.request.urlretrieve(base_url + file, local_file_path)


Downloading 2010-12-01.csv...
Downloading 2010-12-02.csv...
Downloading 2010-12-03.csv...
Downloading 2010-12-05.csv...
Downloading 2010-12-06.csv...
Downloading 2010-12-07.csv...
Downloading 2010-12-08.csv...
Downloading 2010-12-09.csv...


In [0]:
os.listdir(raw_dir)

['2010-12-01.csv',
 '2010-12-02.csv',
 '2010-12-03.csv',
 '2010-12-05.csv',
 '2010-12-06.csv',
 '2010-12-07.csv',
 '2010-12-08.csv',
 '2010-12-09.csv']

In [0]:
raw_dir

'dbfs/tmp/retail_data'

In [0]:
import requests

url = "https://api.github.com/repos/databricks/Spark-The-Definitive-Guide/contents/data/retail-data/by-day"
response = requests.get(url)
files = [item['name'] for item in response.json() if item['name'].endswith(".csv")]

In [0]:
import os
import urllib.request

# Base GitHub URL
base_url = "https://raw.githubusercontent.com/databricks/Spark-The-Definitive-Guide/master/data/retail-data/by-day/"

# Directories
raw_dir = "dbfs/by-day-raw"
output_dir = "dbfs/by-day-spark-output"
os.makedirs(raw_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)

# Loop to download, read, and save
for file in files:
    local_file_path = os.path.join(raw_dir, file)
    output_path = os.path.join(output_dir, file.replace(".csv", ""))

    # Download if not already present
    if not os.path.exists(local_file_path):
        print(f"Downloading {file}...")
        urllib.request.urlretrieve(base_url + file, local_file_path)

    # Read the CSV file with Spark
    df = spark.read.option("header", "true").csv("file://" + os.path.abspath(local_file_path))

    # Save to new output directory (no overwrite)
    if not os.path.exists(output_path):
        df.write.mode("overwrite").option("header", "true").csv(output_path)
        print(f"Saved: {output_path}")
    else:
        print(f"Already exists, skipped: {output_path}")

Saved: dbfs/by-day-spark-output/2010-12-01
Saved: dbfs/by-day-spark-output/2010-12-02
Saved: dbfs/by-day-spark-output/2010-12-03
Saved: dbfs/by-day-spark-output/2010-12-05
Saved: dbfs/by-day-spark-output/2010-12-06
Saved: dbfs/by-day-spark-output/2010-12-07
Saved: dbfs/by-day-spark-output/2010-12-08
Saved: dbfs/by-day-spark-output/2010-12-09
Downloading 2010-12-10.csv...
Saved: dbfs/by-day-spark-output/2010-12-10
Downloading 2010-12-12.csv...
Saved: dbfs/by-day-spark-output/2010-12-12
Downloading 2010-12-13.csv...
Saved: dbfs/by-day-spark-output/2010-12-13
Downloading 2010-12-14.csv...
Saved: dbfs/by-day-spark-output/2010-12-14
Downloading 2010-12-15.csv...
Saved: dbfs/by-day-spark-output/2010-12-15
Downloading 2010-12-16.csv...
Saved: dbfs/by-day-spark-output/2010-12-16
Downloading 2010-12-17.csv...
Saved: dbfs/by-day-spark-output/2010-12-17
Downloading 2010-12-19.csv...
Saved: dbfs/by-day-spark-output/2010-12-19
Downloading 2010-12-20.csv...
Saved: dbfs/by-day-spark-output/2010-12-20

In [0]:
static_df = spark.read \
    .option("header", "true") \
    .option("inferSchema", "true") \
    .csv("dbfs:/dbfs/by-day-spark-output/**/*.csv")

In [0]:
static_df.limit(5).display()

InvoiceNo,StockCode,Description,Quantity,InvoiceDate,UnitPrice,CustomerID,Country
580538,23084,RABBIT NIGHT LIGHT,48,2011-12-05T08:38:00Z,1.79,14075.0,United Kingdom
580538,23077,DOUGHNUT LIP GLOSS,20,2011-12-05T08:38:00Z,1.25,14075.0,United Kingdom
580538,22906,12 MESSAGE CARDS WITH ENVELOPES,24,2011-12-05T08:38:00Z,1.65,14075.0,United Kingdom
580538,21914,BLUE HARMONICA IN BOX,24,2011-12-05T08:38:00Z,1.25,14075.0,United Kingdom
580538,22467,GUMBALL COAT RACK,6,2011-12-05T08:38:00Z,2.55,14075.0,United Kingdom


In [0]:
static_df.createOrReplaceGlobalTempView("retail_table")
static_schema = static_df.schema

In [0]:
static_schema

StructType([StructField('InvoiceNo', StringType(), True), StructField('StockCode', StringType(), True), StructField('Description', StringType(), True), StructField('Quantity', IntegerType(), True), StructField('InvoiceDate', TimestampType(), True), StructField('UnitPrice', DoubleType(), True), StructField('CustomerID', DoubleType(), True), StructField('Country', StringType(), True)])

### To compute the total amount spent per customer per day, and list the highest-spending customers first.

In [0]:
%sql
select * from global_temp.retail_data limit 5;

InvoiceNo,StockCode,Description,Quantity,InvoiceDate,UnitPrice,CustomerID,Country
580538,23084,RABBIT NIGHT LIGHT,48,2011-12-05T08:38:00Z,1.79,14075.0,United Kingdom
580538,23077,DOUGHNUT LIP GLOSS,20,2011-12-05T08:38:00Z,1.25,14075.0,United Kingdom
580538,22906,12 MESSAGE CARDS WITH ENVELOPES,24,2011-12-05T08:38:00Z,1.65,14075.0,United Kingdom
580538,21914,BLUE HARMONICA IN BOX,24,2011-12-05T08:38:00Z,1.25,14075.0,United Kingdom
580538,22467,GUMBALL COAT RACK,6,2011-12-05T08:38:00Z,2.55,14075.0,United Kingdom


In [0]:
spark.sql(
    "select sum(total_cost), CustomerID, to_date(InvoiceDate) \
    from (select CustomerID, (UnitPrice*Quantity) as total_cost, InvoiceDate from global_temp.retail_data) \
    group by CustomerID, to_date(InvoiceDate) \
    order by sum(total_cost) desc \
    limit 5"
).display()

sum(total_cost),CustomerID,to_date(InvoiceDate)
71601.44,17450.0,2011-09-20
55316.08,,2011-11-14
42939.17,,2011-11-07
33521.39999999998,,2011-03-29
31975.590000000007,,2011-12-08


In [0]:
from pyspark.sql.functions import col, to_date, sum

(
    static_df
    .select('UnitPrice', 'Quantity', 'InvoiceDate', 'CustomerID')
    .withColumn('total_cost', col('UnitPrice') * col('Quantity'))
    .groupBy("CustomerID",to_date("InvoiceDate").alias("InvoiceDate"))
    .agg(sum('total_cost').alias('total_cost'))
    .orderBy(col('total_cost'), ascending=False)
    .limit(5)
    .display()
)

CustomerID,InvoiceDate,total_cost
17450.0,2011-09-20,71601.44
,2011-11-14,55316.08
,2011-11-07,42939.17
,2011-03-29,33521.39999999998
,2011-12-08,31975.590000000007


In [0]:
( static_df
    .withColumn('total_cost',col('UnitPrice') * col('Quantity'))
    .groupBy("CustomerID", to_date("InvoiceDate").alias("InvoiceDate"))
    .agg(sum("total_cost").alias("total_cost"))
    .orderBy("total_cost", ascending=False)
    .limit(5)
    .display()
)

CustomerID,InvoiceDate,total_cost
17450.0,2011-09-20,71601.44
,2011-11-14,55316.08
,2011-11-07,42939.17
,2011-03-29,33521.39999999998
,2011-12-08,31975.590000000007


### Machine Learning -> kmeans


**_Coalesce_** -> This reduces the number of data partitions in the DataFrame to 5.

📌 Why use this?
To control parallelism or optimize output when writing data to files (like in .write.csv()).

The data is stored across 5 partitions (useful if you'll write it to disk later)

Spark often creates many partitions by default (even hundreds) — if you're writing to disk, you'll get many tiny files. Coalescing fixes that.

In [0]:
from pyspark.sql.functions import date_format, coalesce
prepped_df = (
    static_df
    .na.fill(0)
    .withColumn("day_of_week",date_format(col("InvoiceDate"),"EEEE"))
    .coalesce(5)
)
prepped_df.limit(5).display()

InvoiceNo,StockCode,Description,Quantity,InvoiceDate,UnitPrice,CustomerID,Country,day_of_week
580538,23084,RABBIT NIGHT LIGHT,48,2011-12-05T08:38:00Z,1.79,14075.0,United Kingdom,Monday
580538,23077,DOUGHNUT LIP GLOSS,20,2011-12-05T08:38:00Z,1.25,14075.0,United Kingdom,Monday
580538,22906,12 MESSAGE CARDS WITH ENVELOPES,24,2011-12-05T08:38:00Z,1.65,14075.0,United Kingdom,Monday
580538,21914,BLUE HARMONICA IN BOX,24,2011-12-05T08:38:00Z,1.25,14075.0,United Kingdom,Monday
580538,22467,GUMBALL COAT RACK,6,2011-12-05T08:38:00Z,2.55,14075.0,United Kingdom,Monday


In [0]:
train_df = prepped_df.filter(to_date('InvoiceDate') < '2011-07-01')
test_df = prepped_df.filter(to_date('InvoiceDate') >= '2011-07-01')

In [0]:
train_df.count(), test_df.count()

(245903, 296006)

In [0]:
from pyspark.ml.feature import StringIndexer

indexer = (StringIndexer()
.setInputCol("day_of_week")
.setOutputCol("day_of_week_index"))

understand_string_indexer = (
                                indexer
                                .fit(train_df)
                                .transform(train_df)
                            )

understand_string_indexer.limit(5).display()

InvoiceNo,StockCode,Description,Quantity,InvoiceDate,UnitPrice,CustomerID,Country,day_of_week,day_of_week_index
537226,22811,SET OF 6 T-LIGHTS CACTI,6,2010-12-06T08:34:00Z,2.95,15987.0,United Kingdom,Monday,2.0
537226,21713,CITRONELLA CANDLE FLOWERPOT,8,2010-12-06T08:34:00Z,2.1,15987.0,United Kingdom,Monday,2.0
537226,22927,GREEN GIANT GARDEN THERMOMETER,2,2010-12-06T08:34:00Z,5.95,15987.0,United Kingdom,Monday,2.0
537226,20802,SMALL GLASS SUNDAE DISH CLEAR,6,2010-12-06T08:34:00Z,1.65,15987.0,United Kingdom,Monday,2.0
537226,22052,VINTAGE CARAVAN GIFT WRAP,25,2010-12-06T08:34:00Z,0.42,15987.0,United Kingdom,Monday,2.0


In [0]:
from pyspark.ml.feature import OneHotEncoder

encoder = (
    OneHotEncoder()
    .setInputCol("day_of_week_index")
    .setOutputCol("day_of_week_encoded")
)
    
understand_one_hot_encoder = (
    encoder
    .fit(understand_string_indexer)
    .transform(understand_string_indexer)
    )

understand_one_hot_encoder.limit(5).display()

InvoiceNo,StockCode,Description,Quantity,InvoiceDate,UnitPrice,CustomerID,Country,day_of_week,day_of_week_index,day_of_week_encoded
537226,22811,SET OF 6 T-LIGHTS CACTI,6,2010-12-06T08:34:00Z,2.95,15987.0,United Kingdom,Monday,2.0,"Map(vectorType -> sparse, length -> 5, indices -> List(2), values -> List(1.0))"
537226,21713,CITRONELLA CANDLE FLOWERPOT,8,2010-12-06T08:34:00Z,2.1,15987.0,United Kingdom,Monday,2.0,"Map(vectorType -> sparse, length -> 5, indices -> List(2), values -> List(1.0))"
537226,22927,GREEN GIANT GARDEN THERMOMETER,2,2010-12-06T08:34:00Z,5.95,15987.0,United Kingdom,Monday,2.0,"Map(vectorType -> sparse, length -> 5, indices -> List(2), values -> List(1.0))"
537226,20802,SMALL GLASS SUNDAE DISH CLEAR,6,2010-12-06T08:34:00Z,1.65,15987.0,United Kingdom,Monday,2.0,"Map(vectorType -> sparse, length -> 5, indices -> List(2), values -> List(1.0))"
537226,22052,VINTAGE CARAVAN GIFT WRAP,25,2010-12-06T08:34:00Z,0.42,15987.0,United Kingdom,Monday,2.0,"Map(vectorType -> sparse, length -> 5, indices -> List(2), values -> List(1.0))"


Since there is no Saturday. We have a vector of size 5.

In [0]:
understand_string_indexer.select('day_of_week', 'day_of_week_index').distinct().display()

day_of_week,day_of_week_index
Monday,2.0
Friday,4.0
Wednesday,3.0
Thursday,0.0
Sunday,5.0
Tuesday,1.0



features: [2.5, 6.0, 0.0, 0.0, 1.0, 0.0, 0.0]

Where:

2.5 is UnitPrice

6.0 is Quantity

Rest are the one-hot encoded values for the day of week

In [0]:
from pyspark.ml.feature import VectorAssembler

vector_assembler = (
    VectorAssembler()
    .setInputCols(["UnitPrice", "Quantity", "day_of_week_encoded"])
    .setOutputCol("features")
)

understand_vector_assembler = (
    vector_assembler
.transform(understand_one_hot_encoder)
    .select("features")
)

understand_vector_assembler.limit(5).display()

features
"Map(vectorType -> sparse, length -> 7, indices -> List(0, 1, 4), values -> List(2.95, 6.0, 1.0))"
"Map(vectorType -> sparse, length -> 7, indices -> List(0, 1, 4), values -> List(2.1, 8.0, 1.0))"
"Map(vectorType -> sparse, length -> 7, indices -> List(0, 1, 4), values -> List(5.95, 2.0, 1.0))"
"Map(vectorType -> sparse, length -> 7, indices -> List(0, 1, 4), values -> List(1.65, 6.0, 1.0))"
"Map(vectorType -> sparse, length -> 7, indices -> List(0, 1, 4), values -> List(0.42, 25.0, 1.0))"


### Combine using Pipelines

In [0]:
from pyspark.ml import Pipeline
transformation_pipeline = Pipeline(stages=[indexer, encoder, vector_assembler])

In [0]:
fitted_pipeline = transformation_pipeline.fit(train_df)

fitted_pipeline.transform(train_df)
train_transformed.limit(5).display()

InvoiceNo,StockCode,Description,Quantity,InvoiceDate,UnitPrice,CustomerID,Country,day_of_week,day_of_week_index,day_of_week_encoded,features
537226,22811,SET OF 6 T-LIGHTS CACTI,6,2010-12-06T08:34:00Z,2.95,15987.0,United Kingdom,Monday,2.0,"Map(vectorType -> sparse, length -> 5, indices -> List(2), values -> List(1.0))","Map(vectorType -> sparse, length -> 7, indices -> List(0, 1, 4), values -> List(2.95, 6.0, 1.0))"
537226,21713,CITRONELLA CANDLE FLOWERPOT,8,2010-12-06T08:34:00Z,2.1,15987.0,United Kingdom,Monday,2.0,"Map(vectorType -> sparse, length -> 5, indices -> List(2), values -> List(1.0))","Map(vectorType -> sparse, length -> 7, indices -> List(0, 1, 4), values -> List(2.1, 8.0, 1.0))"
537226,22927,GREEN GIANT GARDEN THERMOMETER,2,2010-12-06T08:34:00Z,5.95,15987.0,United Kingdom,Monday,2.0,"Map(vectorType -> sparse, length -> 5, indices -> List(2), values -> List(1.0))","Map(vectorType -> sparse, length -> 7, indices -> List(0, 1, 4), values -> List(5.95, 2.0, 1.0))"
537226,20802,SMALL GLASS SUNDAE DISH CLEAR,6,2010-12-06T08:34:00Z,1.65,15987.0,United Kingdom,Monday,2.0,"Map(vectorType -> sparse, length -> 5, indices -> List(2), values -> List(1.0))","Map(vectorType -> sparse, length -> 7, indices -> List(0, 1, 4), values -> List(1.65, 6.0, 1.0))"
537226,22052,VINTAGE CARAVAN GIFT WRAP,25,2010-12-06T08:34:00Z,0.42,15987.0,United Kingdom,Monday,2.0,"Map(vectorType -> sparse, length -> 5, indices -> List(2), values -> List(1.0))","Map(vectorType -> sparse, length -> 7, indices -> List(0, 1, 4), values -> List(0.42, 25.0, 1.0))"


In [0]:
train_transformed.cache()

DataFrame[InvoiceNo: string, StockCode: string, Description: string, Quantity: int, InvoiceDate: timestamp, UnitPrice: double, CustomerID: double, Country: string, day_of_week: string, day_of_week_index: double, day_of_week_encoded: vector, features: vector]

In [0]:
from pyspark.ml.clustering import KMeans

kmeans = KMeans().setK(20).setSeed(1)
kmeans_model = kmeans.fit(train_transformed)

In [0]:
kmeans_model.summary.trainingCost

80507237.08491144

In [0]:
test_transformed = fitted_pipeline.transform(test_df)
test_pred = kmeans_model.transform(test_transformed)

centers = kmeans_model.clusterCenters()

### For test data the cost has to be calculated manually

In [0]:
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType
from pyspark.ml.linalg import Vectors

def squared_distance(v1, cluster_id):
    return float(v1.squared_distance(Vectors.dense(centers[cluster_id])))

distance_udf = udf(squared_distance, DoubleType())

# 7. Compute cost for test
test_pred = test_pred.withColumn("squared_error", distance_udf("features", "prediction"))
test_cost = test_pred.agg({"squared_error": "sum"}).collect()[0][0]
print("Test Cost:", test_cost)

Test Cost: 546096339.169738
