Skip to content

Commit

Permalink
Add comments and doc. Move some classes to the right places.
Browse files Browse the repository at this point in the history
  • Loading branch information
yhuai committed Jul 17, 2015
1 parent a19fea6 commit 0a827b3
Show file tree
Hide file tree
Showing 5 changed files with 278 additions and 225 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -279,77 +279,5 @@ case class Average(child: Expression) extends AlgebraicAggregate {
override def children: Seq[Expression] = child :: Nil
}

abstract class AggregationBuffer(
toCatalystConverters: Array[Any => Any],
toScalaConverters: Array[Any => Any],
bufferOffset: Int)
extends Row {

override def length: Int = toCatalystConverters.length

protected val offsets: Array[Int] = {
val newOffsets = new Array[Int](length)
var i = 0
while (i < newOffsets.length) {
newOffsets(i) = bufferOffset + i
i += 1
}
newOffsets
}
}

class MutableAggregationBuffer(
toCatalystConverters: Array[Any => Any],
toScalaConverters: Array[Any => Any],
bufferOffset: Int,
var underlyingBuffer: MutableRow)
extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {

override def apply(i: Int): Any = {
if (i >= length || i < 0) {
throw new IllegalArgumentException(
s"Could not access ${i}th value in this buffer because it only has $length values.")
}
toScalaConverters(i)(underlyingBuffer(offsets(i)))
}

def update(i: Int, value: Any): Unit = {
if (i >= length || i < 0) {
throw new IllegalArgumentException(
s"Could not update ${i}th value in this buffer because it only has $length values.")
}
underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
}

override def copy(): MutableAggregationBuffer = {
new MutableAggregationBuffer(
toCatalystConverters,
toScalaConverters,
bufferOffset,
underlyingBuffer)
}
}

class InputAggregationBuffer(
toCatalystConverters: Array[Any => Any],
toScalaConverters: Array[Any => Any],
bufferOffset: Int,
var underlyingInputBuffer: Row)
extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {

override def apply(i: Int): Any = {
if (i >= length || i < 0) {
throw new IllegalArgumentException(
s"Could not access ${i}th value in this buffer because it only has $length values.")
}
toScalaConverters(i)(underlyingInputBuffer(offsets(i)))
}

override def copy(): InputAggregationBuffer = {
new InputAggregationBuffer(
toCatalystConverters,
toScalaConverters,
bufferOffset,
underlyingInputBuffer)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.{Expression}
import org.apache.spark.sql.execution.expressions.aggregate2.{ScalaUDAF, UserDefinedAggregateFunction}
import org.apache.spark.sql.expressions.aggregate2.{ScalaUDAF, UserDefinedAggregateFunction}


class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
Expand Down

This file was deleted.

Loading

0 comments on commit 0a827b3

Please sign in to comment.