Make variant splitter work with Photon (#655)
* update split expression

Signed-off-by: Henry Davidge <henry@davidge.me>

* Add test for assembly jar

Signed-off-by: Henry Davidge <henry@davidge.me>

* add header

Signed-off-by: Henry Davidge <henry@davidge.me>

* finish making splitter run in photon

Signed-off-by: Henry Davidge <henry@davidge.me>

* add comment

Signed-off-by: Henry Davidge <henry@davidge.me>

* add back doc

Signed-off-by: Henry Davidge <henry@davidge.me>

* less work when not splitting

Signed-off-by: Henry Davidge <henry@davidge.me>

* fmt

Signed-off-by: Henry Davidge <henry@davidge.me>

---------

Signed-off-by: Henry Davidge <henry@davidge.me>
Co-authored-by: Henry Davidge <henry@davidge.me>
henrydavidge and Henry Davidge committed Apr 8, 2024
1 parent 6622f94 commit e08907f
Showing 9 changed files with 224 additions and 166 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/tests.yml
@@ -218,6 +218,9 @@ jobs:
- name: Build artifacts
run: bin/build --scala --python

- name: Test assembly jar
run: java -cp core/target/**/glow*assembly*.jar io.projectglow.TestAssemblyJar

- name: Upload artifacts
uses: actions/upload-artifact@v4
if: success() || failure()
3 changes: 1 addition & 2 deletions build.sbt
@@ -111,7 +111,6 @@ lazy val commonSettings = Seq(
Test / test := ((Test / test) dependsOn (Test / headerCheck)).value,
assembly / test := {},
assembly / assemblyMergeStrategy := {
// Assembly jar is not executable
case p if p.toLowerCase.contains("manifest.mf") =>
MergeStrategy.discard
case _ =>
@@ -184,7 +183,7 @@ ThisBuild / coreDependencies := (providedSparkDependencies.value ++ testCoreDepe
"com.github.broadinstitute" % "picard" % "2.27.5",
"org.apache.commons" % "commons-lang3" % "3.14.0",
// Fix versions of libraries that are depended on multiple times
"org.apache.hadoop" % "hadoop-client" % "3.4.0",
"org.apache.hadoop" % "hadoop-client" % "3.3.6",
"io.netty" % "netty-all" % "4.1.96.Final",
"io.netty" % "netty-handler" % "4.1.96.Final",
"io.netty" % "netty-transport-native-epoll" % "4.1.96.Final",
23 changes: 23 additions & 0 deletions core/src/main/scala/io/projectglow/TestAssemblyJar.scala
@@ -0,0 +1,23 @@
/*
* Copyright 2019 The Glow Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.projectglow

object TestAssemblyJar {
def main(args: Array[String]): Unit = {
println("Assembly jar works") // scalastyle:ignore
}
}
@@ -17,16 +17,14 @@
package io.projectglow.sql.expressions

import io.projectglow.sql.util.{Rewrite, RewriteAfterResolution}

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateNamedStruct, ExpectsInputTypes, Expression, Generator, GenericInternalRow, GetStructField, ImplicitCastInputTypes, Literal, NamedExpression, UnaryExpression, Unevaluable}
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, BinaryExpression, CaseWhen, Cast, CreateNamedStruct, Divide, EqualTo, Exp, ExpectsInputTypes, Expression, Factorial, Generator, GenericInternalRow, GetStructField, If, ImplicitCastInputTypes, LessThan, Literal, Log, Multiply, NamedExpression, Pi, Round, Subtract, UnaryExpression, Unevaluable}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._

import io.projectglow.SparkShim.newUnresolvedException

/**
@@ -231,3 +229,90 @@ object VectorToArray {
new GenericArrayData(vectorType.deserialize(input).toArray)
}
}

case class Comb(n: Expression, k: Expression) extends RewriteAfterResolution {
override def children: Seq[Expression] = Seq(n, k)

override def rewrite: Expression = {
Cast(
Round(
Exp(Subtract(Subtract(LogFactorial(n), LogFactorial(k)), LogFactorial(Subtract(n, k)))),
Literal(0)),
LongType)
}

override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
copy(n = newChildren(0), k = newChildren(1))
}
}
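
For context (not part of the commit): the `Comb` rewrite above computes a binomial coefficient in log space, C(n, k) = round(exp(log n! - log k! - log (n-k)!)), using only expressions that the engine can already evaluate. A minimal standalone sketch of that identity, where `logFact` is a hypothetical helper standing in for `LogFactorial`:

```scala
// Standalone sketch, not part of the commit: the identity behind the Comb rewrite,
// checked with plain doubles. `logFact` is a hypothetical stand-in for LogFactorial.
object CombSketch {
  // Exact log-factorial for small n (sum of logs).
  def logFact(n: Int): Double = (2 to n).map(i => math.log(i.toDouble)).sum

  // C(n, k) = round(exp(log n! - log k! - log (n-k)!)), mirroring the rewrite above.
  def comb(n: Int, k: Int): Long =
    math.round(math.exp(logFact(n) - logFact(k) - logFact(n - k)))

  def main(args: Array[String]): Unit = {
    println(comb(5, 2))  // 10
    println(comb(10, 3)) // 120
  }
}
```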

/**
* Note: not user facing, approximate for n > 47
*/
case class LogFactorial(n: Expression) extends RewriteAfterResolution {
override def children: Seq[Expression] = Seq(n)

override def rewrite: Expression = {
CaseWhen(
Seq(
(EqualTo(n, Literal(0)), Literal(0.0)),
(EqualTo(n, Literal(1)), Literal(0.0)),
(EqualTo(n, Literal(2)), Literal(0.693147180559945)),
(EqualTo(n, Literal(3)), Literal(1.7917594692280554)),
(EqualTo(n, Literal(4)), Literal(3.178053830347945)),
(EqualTo(n, Literal(5)), Literal(4.787491742782047)),
(EqualTo(n, Literal(6)), Literal(6.579251212010102)),
(EqualTo(n, Literal(7)), Literal(8.525161361065415)),
(EqualTo(n, Literal(8)), Literal(10.604602902745249)),
(EqualTo(n, Literal(9)), Literal(12.801827480081467)),
(EqualTo(n, Literal(10)), Literal(15.104412573075514)),
(EqualTo(n, Literal(11)), Literal(17.502307845873887)),
(EqualTo(n, Literal(12)), Literal(19.987214495661885)),
(EqualTo(n, Literal(13)), Literal(22.55216385312342)),
(EqualTo(n, Literal(14)), Literal(25.191221182738683)),
(EqualTo(n, Literal(15)), Literal(27.89927138384089)),
(EqualTo(n, Literal(16)), Literal(30.671860106080672)),
(EqualTo(n, Literal(17)), Literal(33.50507345013689)),
(EqualTo(n, Literal(18)), Literal(36.39544520803305)),
(EqualTo(n, Literal(19)), Literal(39.339884187199495)),
(EqualTo(n, Literal(20)), Literal(42.335616460753485)),
(EqualTo(n, Literal(21)), Literal(45.38013889847691)),
(EqualTo(n, Literal(22)), Literal(48.47118135183522)),
(EqualTo(n, Literal(23)), Literal(51.60667556776438)),
(EqualTo(n, Literal(24)), Literal(54.78472939811232)),
(EqualTo(n, Literal(25)), Literal(58.00360522298052)),
(EqualTo(n, Literal(26)), Literal(61.26170176100201)),
(EqualTo(n, Literal(27)), Literal(64.55753862700634)),
(EqualTo(n, Literal(28)), Literal(67.88974313718153)),
(EqualTo(n, Literal(29)), Literal(71.257038967168)),
(EqualTo(n, Literal(30)), Literal(74.65823634883017)),
(EqualTo(n, Literal(31)), Literal(78.0922235533153)),
(EqualTo(n, Literal(32)), Literal(81.55795945611503)),
(EqualTo(n, Literal(33)), Literal(85.05446701758152)),
(EqualTo(n, Literal(34)), Literal(88.58082754219768)),
(EqualTo(n, Literal(35)), Literal(92.1361756036871)),
(EqualTo(n, Literal(36)), Literal(95.7196945421432)),
(EqualTo(n, Literal(37)), Literal(99.33061245478743)),
(EqualTo(n, Literal(38)), Literal(102.96819861451381)),
(EqualTo(n, Literal(39)), Literal(106.63176026064346)),
(EqualTo(n, Literal(40)), Literal(110.32063971475738)),
(EqualTo(n, Literal(41)), Literal(114.03421178146169)),
(EqualTo(n, Literal(42)), Literal(117.77188139974507)),
(EqualTo(n, Literal(43)), Literal(121.53308151543864)),
(EqualTo(n, Literal(44)), Literal(125.3172711493569)),
(EqualTo(n, Literal(45)), Literal(129.12393363912722)),
(EqualTo(n, Literal(46)), Literal(132.95257503561632)),
(EqualTo(n, Literal(47)), Literal(136.80272263732635))
),
Some(
Add(
Subtract(Multiply(Add(n, Literal(0.5)), Log(n)), n),
Multiply(Literal(0.5), Log(Multiply(Literal(2), Pi())))))
)
}
override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
copy(n = newChildren(0))
}

}
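
Also for context (an observation about the code above, not stated in the commit message): values 0 through 47 come from the literal table, and the `Some(...)` else branch is a Stirling-style approximation, log n! ≈ (n + 1/2) log n - n + (1/2) log(2π). A small standalone sketch comparing the two at the first value past the table:

```scala
// Standalone sketch, not part of the commit: the fallback branch above matches the
// Stirling approximation log n! ~ (n + 0.5) * log(n) - n + 0.5 * log(2 * pi).
object LogFactorialSketch {
  def exact(n: Int): Double = (2 to n).map(i => math.log(i.toDouble)).sum

  def stirling(n: Double): Double =
    (n + 0.5) * math.log(n) - n + 0.5 * math.log(2 * math.Pi)

  def main(args: Array[String]): Unit = {
    val n = 48 // first value not covered by the literal table
    println(exact(n))               // ~140.6739
    println(stirling(n))            // ~140.6722
    println(exact(n) - stirling(n)) // ~1/(12n) = ~0.0017, the leading error term
  }
}
```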
@@ -21,6 +21,7 @@ import htsjdk.variant.vcf.VCFHeaderLineCount
import io.projectglow.common.GlowLogging
import io.projectglow.common.VariantSchemas._
import io.projectglow.vcf.{InternalRowToVariantContextConverter, VCFSchemaInferrer}
import org.apache.commons.math3.util.CombinatoricsUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SQLUtils.structFieldsEqualExceptNullability
import org.apache.spark.sql.functions._
@@ -92,7 +93,9 @@ private[projectglow] object VariantSplitter extends GlowLogging {
lit(":"),
col(startField.name) + 1,
lit(":"),
concat_ws("/", col(refAlleleField.name), col(alternateAllelesField.name))
array_join(
concat(array(col(refAlleleField.name)), col(alternateAllelesField.name)),
"/")
)
).otherwise(lit(null))
)
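
A note on the hunk above (hedged, since the commit message only says "update split expression"): the split-variant label is now built with `array_join` over an explicit `concat` of the ref-allele array and the alt-allele array rather than `concat_ws`, presumably because the earlier form did not run under Photon. Both produce the same "ref/alt1/alt2" string; a minimal local-mode sketch, with column names assumed to mirror Glow's schema fields:

```scala
// Standalone sketch, not part of the commit: the old concat_ws form and the new
// array_join(concat(...)) form yield the same "ref/alt1/alt2" string.
// Column names here are assumptions mirroring Glow's schema fields.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object AlleleJoinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("allele-join").getOrCreate()
    import spark.implicits._

    val df = Seq(("A", Seq("T", "G"))).toDF("referenceAllele", "alternateAlleles")
    df.select(
        concat_ws("/", col("referenceAllele"), col("alternateAlleles")).as("old_form"),
        array_join(concat(array(col("referenceAllele")), col("alternateAlleles")), "/").as("new_form"))
      .show(false) // both columns read A/T/G

    spark.stop()
  }
}
```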
@@ -165,19 +168,11 @@
.fields
.map(field =>
field.name ->
expr(s"transform(${genotypesFieldName}, g -> g.${field.name})"))
when(
col(splitFromMultiAllelicField.name),
expr(s"transform(${genotypesFieldName}, g -> g.${field.name})")).otherwise(array()))
val withExtractedFields = variantDf.withColumns(extractedFields.toMap)

// register the udf that genotypes splitter uses
withExtractedFields
.sqlContext
.udf
.register(
"likelihoodSplitUdf",
(numAlleles: Int, ploidy: Int, alleleIdx: Int) =>
refAltColexOrderIdxArray(numAlleles, ploidy, alleleIdx)
)

// update pulled-out genotypes columns, zip them back together as the new genotypes column,
// and drop the pulled-out columns
// Note: In performance tests, it was seen that nested transform sql functions used below work twice faster if
@@ -193,18 +188,15 @@
structFieldsEqualExceptNullability(phredLikelihoodsField, f) |
structFieldsEqualExceptNullability(posteriorProbabilitiesField, f) =>
// update genotypes subfields that have colex order using the udf
f.name -> when(
col(splitFromMultiAllelicField.name),
expr(s"""transform(${f.name}, c ->
f.name ->
expr(s"""transform(${f.name}, c ->
| filter(
| transform(
| c, (x, idx) ->
| if (
| array_contains(
| likelihoodSplitUdf(
| size(${alternateAllelesField.name}) + 1,
| size(${callsField.name}[0]),
| $splitAlleleIdxFieldName + 1
| transform(array_repeat(0, size(${callsField.name}[0]) + 1), (el, i) ->
| comb(size(${callsField.name}[0]) + $splitAlleleIdxFieldName + 1, size(${callsField.name}[0])) - comb(size(${callsField.name}[0]) + $splitAlleleIdxFieldName + 1 - i, size(${callsField.name}[0]) - i)
| ),
| idx
| ), x, null
@@ -213,90 +205,33 @@
| x -> !isnull(x)
| )
| )""".stripMargin)
).otherwise(col(f.name))

case f if structFieldsEqualExceptNullability(callsField, f) =>
// update GT calls subfield
f.name -> when(
col(splitFromMultiAllelicField.name),
expr(
s"transform(${f.name}, " +
s"c -> transform(c, x -> if(x == 0, x, if(x == $splitAlleleIdxFieldName + 1, 1, -1))))"
)
).otherwise(col(f.name))
f.name ->
expr(
s"transform(${f.name}, " +
s"c -> transform(c, x -> if(x == 0, x, if(x == $splitAlleleIdxFieldName + 1, 1, -1))))"
)

case f if f.dataType.isInstanceOf[ArrayType] =>
// update any ArrayType field with number of elements equal to number of alt alleles
f.name -> when(
col(splitFromMultiAllelicField.name),
expr(
s"transform(${f.name}, c -> if(size(c) == size(${alternateAllelesField.name}) + 1," +
s" array(c[0], c[$splitAlleleIdxFieldName + 1]), null))"
)
).otherwise(col(f.name))
f.name ->
expr(
s"transform(${f.name}, c -> if(size(c) == size(${alternateAllelesField.name}) + 1," +
s" array(c[0], c[$splitAlleleIdxFieldName + 1]), null))"
)
}

withExtractedFields
.withColumns(updatedColumns.toMap)
.withColumn(genotypesFieldName, arrays_zip(gSchema.get.fieldNames.map(col(_)): _*))
.withColumn(
genotypesFieldName,
when(
col(splitFromMultiAllelicField.name),
arrays_zip(gSchema.get.fieldNames.map(col(_)): _*)).otherwise(col(genotypesFieldName)))
.drop(gSchema.get.fieldNames: _*)
}

}
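
A worked example for the genotype rewrite above (not from the commit): when a row is split on alt allele `splitAlleleIdx`, the calls remapping `transform(c, x -> if(x == 0, x, if(x == $splitAlleleIdxFieldName + 1, 1, -1)))` keeps the ref allele as 0, renumbers the targeted alt allele to 1, and marks every other alt allele as -1. A plain-Scala sketch of the same mapping:

```scala
// Standalone sketch, not part of the commit: the GT-call remapping applied per split row.
object CallRemapSketch {
  // splitAlleleIdx is the 0-based index into alternateAlleles, as in the SQL above.
  def remapCalls(calls: Seq[Int], splitAlleleIdx: Int): Seq[Int] =
    calls.map { x =>
      if (x == 0) x                       // ref allele stays ref
      else if (x == splitAlleleIdx + 1) 1 // the alt allele being split out becomes 1
      else -1                             // any other alt allele becomes missing
    }

  def main(args: Array[String]): Unit = {
    println(remapCalls(Seq(0, 2), splitAlleleIdx = 1)) // List(0, 1)
    println(remapCalls(Seq(0, 2), splitAlleleIdx = 0)) // List(0, -1)
    println(remapCalls(Seq(1, 2), splitAlleleIdx = 0)) // List(1, -1)
  }
}
```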

/**
* Given the total number of (ref and alt) alleles (numAlleles), ploidy, and the index an alt allele of interest
* (altAlleleIdx), generates an array of indices of genotypes that only include the ref allele and/or that alt allele
* of interest in the colex ordering of all possible genotypes. The function is general and correctly calculates
* the index array for any given set of values for its arguments.
*
* Example:
* Assume numAlleles = 3 (say A,B,C), ploidy = 2, and altAlleleIdx = 2 (i.e., C)
* Therefore, colex ordering of all possible genotypes is: AA, AB, BB, AC, BC, CC
* and for example refAltColexOrderIdxArray(3, 2, 2) = Array(0, 3, 5)
*
* @param numAlleles : total number of alleles (ref and alt)
* @param ploidy : ploidy
* @param altAlleleIdx : index of alt allele of interest
* @return array of indices of genotypes that only include the ref allele and alt allele
* of interest in the colex ordering of all possible genotypes.
*/
@VisibleForTesting
private[splitmultiallelics] def refAltColexOrderIdxArray(
numAlleles: Int,
ploidy: Int,
altAlleleIdx: Int): Array[Int] = {

if (ploidy < 1) {
throw new IllegalArgumentException("Ploidy must be at least 1.")
}
if (numAlleles < 2) {
throw new IllegalArgumentException(
"Number of alleles must be at least 2 (one REF and at least one ALT).")
}
if (altAlleleIdx > numAlleles - 1 || altAlleleIdx < 1) {
throw new IllegalArgumentException(
"Alternate allele index must be at least 1 and at most one less than number of alleles.")
}

val idxArray = new Array[Int](ploidy + 1)

// generate vector of elements at positions p+1,p,...,2 on the altAlleleIdx'th diagonal of Pascal's triangle
idxArray(0) = 0
var i = 1
idxArray(ploidy) = altAlleleIdx
while (i < ploidy) {
idxArray(ploidy - i) = idxArray(ploidy - i + 1) * (i + altAlleleIdx) / (i + 1)
i += 1
}

// calculate the cumulative vector
i = 1
while (i <= ploidy) {
idxArray(i) = idxArray(i) + idxArray(i - 1)
i += 1
}
idxArray
}

@VisibleForTesting
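One more cross-check (not part of the commit; the helper names below are hypothetical): the doc comment above describes the colex-ordered genotype indices that `refAltColexOrderIdxArray` produced via a UDF, and the new SQL in this commit derives the same indices in closed form as `comb(ploidy + altIdx, ploidy) - comb(ploidy + altIdx - i, ploidy - i)` for i = 0..ploidy, where altIdx = splitAlleleIdx + 1. A small sketch reproducing the Array(0, 3, 5) example from that doc comment:

```scala
// Standalone sketch, not part of the commit: closed-form colex indices matching the
// refAltColexOrderIdxArray example (numAlleles = 3, ploidy = 2, altAlleleIdx = 2).
object ColexIndexSketch {
  // Exact binomial coefficient via the multiplicative formula.
  def comb(n: Int, k: Int): Long =
    if (k < 0 || k > n) 0L
    else (1 to k).foldLeft(1L)((acc, i) => acc * (n - k + i) / i)

  // i-th kept genotype index (i = 0..ploidy) for alt allele altIdx, colex ordering.
  def refAltIdx(ploidy: Int, altIdx: Int, i: Int): Long =
    comb(ploidy + altIdx, ploidy) - comb(ploidy + altIdx - i, ploidy - i)

  def main(args: Array[String]): Unit = {
    // ploidy = 2, altAlleleIdx = 2
    println((0 to 2).map(i => refAltIdx(2, 2, i))) // Vector(0, 3, 5)
  }
}
```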
