## source: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html

In [1]:
// Show the version of the Spark session this notebook is attached to.
spark.version

2.4.5

In [2]:
import java.sql.Date
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions.udf

In [3]:
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

/**
 * User-defined aggregate that counts how many input strings match the first
 * string seen by the aggregation buffer.
 *
 * The first value passed to `update` becomes the "base string"; each later
 * value is compared to it with `similarity`, and the number of matches is the
 * aggregate's result. The result therefore depends on the order in which rows
 * reach the buffer.
 *
 * @param threshold similarity threshold supplied by the caller.
 *                  NOTE(review): this value is not read anywhere in the class yet;
 *                  `similarity` is a plain equality test.
 */
class DiffCount(val threshold:Double) extends UserDefinedAggregateFunction {

  // Positions of the fields declared in `bufferSchema`.
  private val CountIdx = 0
  private val BaseIdx = 1
  private val MatchIdx = 2

  // These are the input fields for the aggregate function: a single string column.
  override def inputSchema: StructType =
    StructType(StructField("sms", StringType) :: Nil)

  // These are the internal fields kept while computing the aggregate:
  // rows seen so far, the first string seen, and how many later rows matched it.
  override def bufferSchema: StructType = StructType(
    StructField("count", LongType) ::
    StructField("base_string", StringType) ::
    StructField("match_count", LongType) :: Nil
  )

  /**
   * Similarity score between two strings.
   *
   * @return 1.0 when the strings are equal, 0.0 otherwise
   */
  def similarity(str1: String, str2: String): Double =
    if (str1 == str2) 1.0 else 0.0

  // This is the output type of the aggregation function.
  override def dataType: DataType = LongType

  // This is the initial value for the buffer: no rows seen, empty base, no matches.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(CountIdx) = 0L
    buffer(BaseIdx) = ""
    buffer(MatchIdx) = 0L
  }

  override def deterministic: Boolean = true

  // This is how to update the buffer given an input row.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val newString = input.getString(0)
    val seen = buffer.getLong(CountIdx)

    if (seen == 0L) {
      // First row: it becomes the base string and is not counted as a match.
      buffer(BaseIdx) = newString
      buffer(CountIdx) = 1L
    } else {
      // The score is truncated to a Long, so only a full 1.0 adds to the match count.
      val matched = similarity(buffer.getString(BaseIdx), newString).toLong
      buffer(MatchIdx) = buffer.getLong(MatchIdx) + matched
      buffer(CountIdx) = seen + 1
    }
  }

  /**
   * Merges the partial buffer `buffer2` into `buffer1`.
   *
   * Previously this was a no-op, which dropped every partial result.
   *
   *  - If `buffer2` has seen no rows, `buffer1` is left untouched.
   *  - If `buffer1` has seen no rows, it takes over `buffer2`'s state unchanged.
   *  - Otherwise `buffer1`'s base string is kept. `buffer2`'s base string is
   *    scored against it like any other row, and `buffer2`'s own match count is
   *    carried over only when the two bases match. When the bases differ,
   *    `buffer2` no longer holds enough information to tell which of its rows
   *    matched `buffer1`'s base, so those rows are not counted.
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val incoming = buffer2.getLong(CountIdx)

    if (incoming > 0L) {
      val existing = buffer1.getLong(CountIdx)

      if (existing == 0L) {
        buffer1(BaseIdx) = buffer2.getString(BaseIdx)
        buffer1(MatchIdx) = buffer2.getLong(MatchIdx)
      } else {
        val baseMatch =
          similarity(buffer1.getString(BaseIdx), buffer2.getString(BaseIdx)).toLong
        val carried = if (baseMatch > 0L) buffer2.getLong(MatchIdx) else 0L
        buffer1(MatchIdx) = buffer1.getLong(MatchIdx) + baseMatch + carried
      }

      buffer1(CountIdx) = existing + incoming
    }
  }

  // This is where the final value is produced from the final buffer: the match count.
  override def evaluate(buffer: Row): Any = buffer.getLong(MatchIdx)
}

defined class DiffCount


In [4]:
// Sample data: a single string column named "sms", one row per value, in this order.
val devicesDf = Seq(
  "notebook", "notebook",
  "small phone", "camera", "small phone",
  "large phone", "camera", "small phone"
).toDF("sms")

devicesDf = [sms: string]


[sms: string]

In [5]:
val diff_count = new DiffCount(0.0)

// Evaluate the aggregate over a frame running from the current row to the last row.
// The window declares no partitionBy/orderBy, so the frame follows whatever row
// order the DataFrame arrives in.
devicesDf
  .withColumn(
    "diff_count",
    diff_count(col("sms")).over(
      Window.rowsBetween(Window.currentRow, Window.unboundedFollowing)
    )
  )
  .show()

+-----------+----------+
|        sms|diff_count|
+-----------+----------+
|   notebook|         1|
|   notebook|         0|
|small phone|         2|
|     camera|         1|
|small phone|         1|
|large phone|         0|
|     camera|         0|
|small phone|         0|
+-----------+----------+



diff_count = DiffCount@51b29ee9


DiffCount@51b29ee9