forked from hail-is/hail
-
Notifications
You must be signed in to change notification settings - Fork 1
/
AggOp.scala
90 lines (83 loc) · 3.57 KB
/
AggOp.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
package is.hail.expr.ir
import is.hail.expr.ir.agg._
import is.hail.types.TypeWithRequiredness
import is.hail.types.physical._
import is.hail.types.virtual._
import is.hail.utils.FastSeq
object AggSignature {
  /** Narrows an aggregator signature to the type that is actually requested of its result.
    *
    * For `Collect`, `Take`, `ReservoirSample` and `TakeBy` the first seqOp argument type is
    * replaced by the element type of `requestedType`, which is cast to `TArray`. For
    * `PrevNonnull` and `Densify` it is replaced by `requestedType` itself. Any other signature,
    * or one whose argument lists do not have the arity matched below, is returned unchanged.
    *
    * @param agg the signature to narrow
    * @param requestedType the type requested for the aggregator's result
    * @return a signature with the narrowed seqOp argument type, or `agg` itself
    */
  def prune(agg: AggSignature, requestedType: Type): AggSignature = {
    // A `def`, not a `val`: the cast must only run for the array-producing cases below.
    def requestedElementType: Type = requestedType.asInstanceOf[TArray].elementType

    agg match {
      case AggSignature(Collect(), Seq(), Seq(_)) =>
        AggSignature(Collect(), FastSeq(), FastSeq(requestedElementType))

      case AggSignature(Take(), Seq(count), Seq(_)) =>
        AggSignature(Take(), FastSeq(count), FastSeq(requestedElementType))

      case AggSignature(ReservoirSample(), Seq(count), Seq(_)) =>
        AggSignature(ReservoirSample(), FastSeq(count), FastSeq(requestedElementType))

      // The second seqOp argument (the ordering key type) is carried over untouched.
      case AggSignature(TakeBy(sortOrder), Seq(count), Seq(_, keyType)) =>
        AggSignature(TakeBy(sortOrder), FastSeq(count), FastSeq(requestedElementType, keyType))

      case AggSignature(PrevNonnull(), Seq(), Seq(_)) =>
        AggSignature(PrevNonnull(), FastSeq(), FastSeq(requestedType))

      case AggSignature(Densify(), Seq(), Seq(_)) =>
        AggSignature(Densify(), FastSeq(), FastSeq(requestedType))

      case _ =>
        agg
    }
  }
}
/** Describes one aggregator application: which operation it is and the types of its arguments.
  *
  * @param op the aggregation operation
  * @param initOpArgs types of the arguments passed to the operation's init step
  * @param seqOpArgs types of the arguments passed to the operation's seq step
  */
case class AggSignature(
op: AggOp,
initOpArgs: Seq[Type],
seqOpArgs: Seq[Type]) {
// only to be used with virtual non-nested signatures on ApplyAggOp and ApplyScanOp
// Obtained from Extract.getResultType; being a lazy val, it is evaluated on first access and then cached.
lazy val returnType: Type = Extract.getResultType(this)
}
// Closed set of aggregation operations. The trait is sealed, so every AggOp is one of the
// case classes declared in this file.
sealed trait AggOp {}
// Each operation below is a parameterless case class, except TakeBy.
final case class ApproxCDF() extends AggOp
final case class CallStats() extends AggOp
final case class Collect() extends AggOp
final case class CollectAsSet() extends AggOp
final case class Count() extends AggOp
final case class Downsample() extends AggOp
final case class LinearRegression() extends AggOp
final case class Max() extends AggOp
final case class Min() extends AggOp
final case class Product() extends AggOp
final case class Sum() extends AggOp
final case class Take() extends AggOp
final case class ReservoirSample() extends AggOp
final case class Densify() extends AggOp
// The only operation that carries state: a sort order, which defaults to Ascending.
final case class TakeBy(so: SortOrder = Ascending) extends AggOp
final case class Group() extends AggOp
final case class AggElements() extends AggOp
final case class AggElementsLengthCheck() extends AggOp
final case class PrevNonnull() extends AggOp
final case class ImputeType() extends AggOp
final case class NDArraySum() extends AggOp
final case class NDArrayMultiplyAdd() extends AggOp
final case class Fold() extends AggOp
// exists === map(p).sum, needs short-circuiting aggs
// forall === map(p).product, needs short-circuiting aggs
object AggOp {
  /** Maps the textual name of an aggregation operation to its [[AggOp]] instance.
    *
    * This is a partial function: it is undefined for names not listed below, so applying it
    * directly to an unknown name throws a `MatchError`. Use `isDefinedAt` or `lift` to probe.
    * Several operations accept both a lower-camel and an upper-camel spelling. `TakeBy` is
    * created with its default sort order.
    */
  val fromString: PartialFunction[String, AggOp] = {
    case "approxCDF" | "ApproxCDF" => ApproxCDF()
    case "collect" | "Collect" => Collect()
    case "collectAsSet" | "CollectAsSet" => CollectAsSet()
    case "sum" | "Sum" => Sum()
    case "product" | "Product" => Product()
    case "max" | "Max" => Max()
    case "min" | "Min" => Min()
    case "count" | "Count" => Count()
    case "take" | "Take" => Take()
    // The second alternative used to be "Take", which the case above always wins, so it was
    // dead. It now follows the lower-camel / upper-camel convention of the neighbouring cases.
    case "reservoirSample" | "ReservoirSample" => ReservoirSample()
    case "densify" | "Densify" => Densify()
    case "takeBy" | "TakeBy" => TakeBy()
    case "callStats" | "CallStats" => CallStats()
    case "linreg" | "LinearRegression" => LinearRegression()
    case "downsample" | "Downsample" => Downsample()
    case "prevnonnull" | "PrevNonnull" => PrevNonnull()
    case "Group" => Group()
    case "AggElements" => AggElements()
    case "AggElementsLengthCheck" => AggElementsLengthCheck()
    case "ImputeType" => ImputeType()
    case "NDArraySum" => NDArraySum()
    // "NDArrayMutiplyAdd" is the misspelling this parser originally required. It is still
    // accepted so that any producer emitting it keeps working; the correct spelling, which
    // matches the case class name, is now accepted as well.
    case "NDArrayMultiplyAdd" | "NDArrayMutiplyAdd" => NDArrayMultiplyAdd()
    case "Fold" => Fold()
  }
}