Skip to content

Commit

Permalink
Function to get quantiles of array (#656)
Browse files Browse the repository at this point in the history
* Quantile function

Signed-off-by: Henry Davidge <henry@davidge.me>

* python example

Signed-off-by: Henry Davidge <henry@davidge.me>

* fix python example

Signed-off-by: Henry Davidge <henry@davidge.me>

---------

Signed-off-by: Henry Davidge <henry@davidge.me>
Co-authored-by: Henry Davidge <henry@davidge.me>
  • Loading branch information
henrydavidge and Henry Davidge committed Apr 9, 2024
1 parent e08907f commit c075763
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 1 deletion.
18 changes: 18 additions & 0 deletions core/src/main/scala/io/projectglow/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,24 @@ object functions {
new io.projectglow.sql.expressions.SampleGqSummaryStatistics(genotypes.expr)
}

/**
* Array quantile
* @group quality_control
* @since 2.1.0
*
* @param arr An array of numeric values
* @param quantile The desired quantile
* @param is_sorted If true, the input array is assumed to already be sorted
* @return
*/
def array_quantile(arr: Column, quantile: Double, is_sorted: Column): Column = withExpr {
new io.projectglow.sql.expressions.ArrayQuantile(arr.expr, Literal(quantile), is_sorted.expr)
}

def array_quantile(arr: Column, quantile: Double): Column = withExpr {
new io.projectglow.sql.expressions.ArrayQuantile(arr.expr, Literal(quantile))
}

/**
* Performs a linear regression association test optimized for performance in a GWAS setting. See :ref:`linear-regression` for details.
* @group gwas_functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, BinaryExpression, CaseWhen, Cast, CreateNamedStruct, Divide, EqualTo, Exp, ExpectsInputTypes, Expression, Factorial, Generator, GenericInternalRow, GetStructField, If, ImplicitCastInputTypes, LessThan, Literal, Log, Multiply, NamedExpression, Pi, Round, Subtract, UnaryExpression, Unevaluable}
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArraySort, BinaryExpression, CaseWhen, Cast, Ceil, CreateNamedStruct, Divide, EqualTo, Exp, ExpectsInputTypes, Expression, Factorial, Floor, Generator, GenericInternalRow, GetArrayItem, GetStructField, Greatest, If, ImplicitCastInputTypes, Least, LessThan, Literal, Log, Multiply, NamedExpression, Pi, Round, Size, Subtract, UnaryExpression, Unevaluable}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
import io.projectglow.SparkShim.newUnresolvedException
Expand Down Expand Up @@ -314,5 +314,28 @@ case class LogFactorial(n: Expression) extends RewriteAfterResolution {
override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
copy(n = newChildren(0))
}
}

case class ArrayQuantile(arr: Expression, probability: Expression, isSorted: Expression)
extends RewriteAfterResolution {

def this(arr: Expression, probability: Expression) = this(arr, probability, Literal(false))
override def children: Seq[Expression] = Seq(arr, probability, isSorted)

private def getQuantile(arr: Expression): Expression = {
val trueIndex = Add(Multiply(probability, Subtract(Size(arr), Literal(1))), Literal(1))
val roundedIdx = Cast(trueIndex, IntegerType)
val below = GetArrayItem(arr, Greatest(Seq(Literal(0), Subtract(roundedIdx, Literal(1)))))
val above = GetArrayItem(arr, Least(Seq(Subtract(Size(arr), Literal(1)), roundedIdx)))
val frac = Subtract(trueIndex, roundedIdx)
Add(Multiply(frac, above), Multiply(Subtract(Literal(1), frac), below))
}

override def rewrite: Expression = {
If(isSorted, getQuantile(arr), getQuantile(new ArraySort(arr)))
}

override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
copy(arr = newChildren(0), probability = newChildren(1), isSorted = newChildren(2))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.projectglow.tertiary

import com.google.common.math.Quantiles
import org.apache.spark.ml.linalg.{DenseMatrix, DenseVector, SparseVector, Vector}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.functions._
Expand All @@ -24,6 +25,13 @@ import org.apache.spark.unsafe.types.UTF8String
import io.projectglow.sql.GlowBaseTest
import io.projectglow.sql.expressions.{VariantType, VariantUtilExprs}
import io.projectglow.functions._
import org.apache.commons.math3.stat.descriptive.rank.Percentile
import org.apache.commons.math3.stat.descriptive.rank.Percentile.EstimationType
import org.apache.commons.math3.stat.ranking.NaNStrategy
import org.apache.commons.math3.util.{KthSelector, MedianOf3PivotingStrategy}
import org.scalactic.TolerantNumerics

import scala.util.Random

class VariantUtilExprsSuite extends GlowBaseTest {
case class SimpleGenotypeFields(calls: Seq[Int])
Expand Down Expand Up @@ -334,6 +342,78 @@ class VariantUtilExprsSuite extends GlowBaseTest {
val df = spark.createDataFrame(Seq(Outer(Inner(1, "two"))))
assert(df.select(expand_struct(col("inner"))).as[Inner].head == Inner(1, "two"))
}

case class QuantileTest(
arr: Seq[Double],
p25: Double,
p50: Double,
p75: Double,
p90: Double,
p99: Double)
test("quantiles") {
def checkDf(df: DataFrame): Unit = {
val rows = df.collect()
rows.foreach { row =>
row.getAs[Double]("p25") ~== row.getAs[Double]("glow_25") relTol 0.02
row.getAs[Double]("p50") ~== row.getAs[Double]("glow_50") relTol 0.02
row.getAs[Double]("p75") ~== row.getAs[Double]("glow_75") relTol 0.02
row.getAs[Double]("p90") ~== row.getAs[Double]("glow_90") relTol 0.02
row.getAs[Double]("p99") ~== row.getAs[Double]("glow_99") relTol 0.02
}
}
val cases = Range(0, 50).map { n =>
val numbers = Range(0, (Random.nextDouble() * 1000).toInt).map(_ => Random.nextDouble())
val evaluator = new Percentile(1).withEstimationType(EstimationType.R_7)
val golden = Seq(25, 50, 75, 90, 99).map { d =>
evaluator.setQuantile(d)
d -> evaluator.evaluate(numbers.toArray)
}.toMap

QuantileTest(numbers, golden(25), golden(50), golden(75), golden(90), golden(99))

}

val df = spark.createDataFrame(cases)

// Unsorted
val glowUnsortedQuantiles = df.withColumns(
Map(
"glow_25" -> expr("array_quantile(arr, 0.25)"),
"glow_50" -> expr("array_quantile(arr, 0.50)"),
"glow_75" -> expr("array_quantile(arr, 0.75)"),
"glow_90" -> expr("array_quantile(arr, 0.90)"),
"glow_99" -> expr("array_quantile(arr, 0.99)")
))
checkDf(glowUnsortedQuantiles)

val sortedDf = df
.withColumn("sorted_arr", expr("array_sort(arr)"))
.withColumns(Map(
"glow_25" -> expr("array_quantile(sorted_arr, 0.25, true)"),
"glow_50" -> expr("array_quantile(sorted_arr, 0.50, true)"),
"glow_75" -> expr("array_quantile(sorted_arr, 0.75, true)"),
"glow_90" -> expr("array_quantile(sorted_arr, 0.90, true)"),
"glow_99" -> expr("array_quantile(sorted_arr, 0.99, true)")
))
checkDf(sortedDf)
}

test("quantiles respects sorted argument") {
val df = spark.createDataFrame(Seq(Tuple1(Seq(4, 3, 2, 1)))).withColumnRenamed("_1", "arr")
assert(df.selectExpr("array_quantile(arr, 1, true)").first().get(0) == 1)
assert(df.selectExpr("array_quantile(arr, 1, false)").first().get(0) == 4)
}

test("quantiles 0 length array") {
val df = spark.createDataFrame(Seq(Tuple1(Seq()))).withColumnRenamed("_1", "arr")
assert(df.selectExpr("array_quantile(arr, 1, true)").first().get(0) == null)
}

test("quantiles 1 length array") {
val df = spark.createDataFrame(Seq(Tuple1(Seq(5)))).withColumnRenamed("_1", "arr")
assert(df.selectExpr("cast(array_quantile(arr, 1, true) as int)").first().get(0) == 5)
assert(df.selectExpr("cast(array_quantile(arr, 0.001, true) as int)").first().get(0) == 5)
}
}

case class HCTestCase(
Expand Down
19 changes: 19 additions & 0 deletions functions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,25 @@ quality_control:
type: str
returns: Null if true, or throws an exception if not true

- name: array_quantile
doc: Array quantile
since: 2.1.0
expr_class: io.projectglow.sql.expressions.ArrayQuantile
args:
- name: arr
doc: An array of numeric values
- name: quantile
doc: The desired quantile
type: double
- name: is_sorted
doc: If true, the input array is assumed to already be sorted
is_optional: true
examples:
python: |
>>> df = spark.createDataFrame([Row(arr=[1, 2, 3, 4, 5])])
>>> df.select(glow.array_quantile(df.arr, 0.7).alias('p70')).collect()
[Row(p70=3.8)]
gwas_functions:
functions:
Expand Down
28 changes: 28 additions & 0 deletions python/glow/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,34 @@ def sample_gq_summary_stats(genotypes: Union[Column, str]) -> Column:
output = Column(sc()._jvm.io.projectglow.functions.sample_gq_summary_stats(_to_java_column(genotypes)))
return output


__all__.append('array_quantile')
@typechecked
def array_quantile(arr: Union[Column, str], quantile: float, is_sorted: Union[Column, str] | None = None) -> Column:
"""
Array quantile
Added in version 2.1.0.
Examples:
>>> df = spark.createDataFrame([Row(arr=[1, 2, 3, 4, 5])])
>>> df.select(glow.array_quantile(df.arr, 0.7).alias('p70')).collect()
[Row(p70=3.8)]
Args:
arr : An array of numeric values
quantile : The desired quantile
is_sorted : If true, the input array is assumed to already be sorted
Returns:
"""
if is_sorted is None:
output = Column(sc()._jvm.io.projectglow.functions.array_quantile(_to_java_column(arr), quantile))
else:
output = Column(sc()._jvm.io.projectglow.functions.array_quantile(_to_java_column(arr), quantile, _to_java_column(is_sorted)))
return output

########### gwas_functions

__all__.append('linear_regression_gwas')
Expand Down

0 comments on commit c075763

Please sign in to comment.