## Rewrite of Glow Tutorial w/ Sparse Data Structures

In [1]:
// Load samtools lib before glow as this is necessary to avoid this error on vcf writes:
// htsjdk.variant.variantcontext.VariantContextBuilder.getGenotypes() method not found
import $ivy.`com.github.samtools:htsjdk:2.21.1`
import $file.^.^.init.spark, spark._
import $file.^.^.init.paths, paths._
import $file.^.^.init.glow, glow._
import $file.^.^.init.benchmark, benchmark._
import $file.^.^.init.{plotly => init_plotly}, init_plotly._
import sys.process._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import io.projectglow.Glow
import plotly._
import plotly.element._
import plotly.layout._
import plotly.Almond.{init => init_plotly_js, _}
import better.files.File
import org.apache.log4j.{Level, Logger}
import java.nio.file.Paths
// Only show warnings and above from the Glow PLINK logger
Logger.getLogger("io.projectglow.plink").setLevel(Level.WARN)

// Spark session built by a helper from the init scripts, requested with a single shuffle partition
val ss = getLocalSparkSession(shufflePartitions=1)
// Brings the $"col" syntax and Dataset encoders used throughout this notebook into scope
import ss.implicits._
// Register Glow on this session
// NOTE(review): presumably this is what makes genotype_states(...) available below -- confirm
Glow.register(ss)

// Run `block` under the "glowmt" label with operation name `op`.
// NOTE(review): `optimer` is defined outside this notebook (benchmark init script); it appears to
// print the elapsed time and hand back the block's result -- confirm
def timeop[T](op: String)(block: => T) = optimer("glowmt", op, block)

init_plotly_js(offline=false)

// Directory holding the tutorial datasets read and written by the cells below
val data_dir = GWAS_TUTORIAL_DATA_DIR / "1_QC_GWAS"

Compiling /home/eczech/repos/gwas-analysis/notebooks/init/paths.scCompiling /home/eczech/repos/gwas-analysis/notebooks/init/glow.scCompiling /home/eczech/repos/gwas-analysis/notebooks/init/benchmark.scCompiling /home/eczech/repos/gwas-analysis/notebooks/init/plotly.scLoading spark-stubs


SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/home/eczech/.cache/coursier/v1/https/repo1.maven.org/maven2/org/slf4j/slf4j-log4j12/1.7.16/slf4j-log4j12-1.7.16.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/home/eczech/.cache/coursier/v1/https/repo1.maven.org/maven2/org/slf4j/slf4j-log4j12/1.7.25/slf4j-log4j12-1.7.25.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.slf4j.impl.Log4jLoggerFactory]


Creating SparkSession


Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
20/01/16 04:58:48 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


[32mimport [39m[36m$ivy.$                                  
[39m
[32mimport [39m[36m$file.$           , spark._
[39m
[32mimport [39m[36m$file.$           , paths._
[39m
[32mimport [39m[36m$file.$          , glow._
[39m
[32mimport [39m[36m$file.$               , benchmark._
[39m
[32mimport [39m[36m$file.$                             , init_plotly._
[39m
[32mimport [39m[36msys.process._
[39m
[32mimport [39m[36morg.apache.spark.sql.DataFrame
[39m
[32mimport [39m[36morg.apache.spark.sql.functions._
[39m
[32mimport [39m[36morg.apache.spark.sql.SparkSession
[39m
[32mimport [39m[36morg.apache.spark.sql.expressions.Window
[39m
[32mimport [39m[36mio.projectglow.Glow
[39m
[32mimport [39m[36mplotly._
[39m
[32mimport [39m[36mplotly.element._
[39m
[32mimport [39m[36mplotly.layout._
[39m
[32mimport [39m[36mplotly.Almond.{init => init_plotly_js, _}
[39m
[32mimport [39m[36mbetter.files.File
[39m
[32mimport [39m[36morg.apache.log4j.

<h2><a id="load_raw_data">Load Raw Data</a></h2>

In [2]:
// Location of the PLINK .bed file for the unfiltered (QC0) dataset
val path = (data_dir / QC0_FILE + ".bed").toString

// Genotype data read through the "plink" source, with the first entry of `names`
// exposed as `variantId`. Declared as a def, so every reference builds the read again.
def df = ss.read
    .format("plink")
    .load(path)
    .withColumn("variantId", $"names".getItem(0))

defined [32mfunction[39m [36mdf[39m

In [3]:
// Pedigree (.fam) data: space-delimited text with six columns, plus a `sampleId`
// built as "<fid>_<iid>". Declared as a def, so every reference re-reads the file.
def dp = {
    val famPath = (data_dir / QC0_FILE + ".fam").toString
    val pedigree = ss.read
        .option("delimiter", " ")
        .csv(famPath)
        .toDF("fid","iid","iidp","iidm", "gender", "phenotype")
    pedigree.withColumn("sampleId", concat($"fid", lit("_"), $"iid"))
}
dp.show(3)

+----+-------+----+----+------+---------+------------+
| fid|    iid|iidp|iidm|gender|phenotype|    sampleId|
+----+-------+----+----+------+---------+------------+
|1328|NA06989|   0|   0|     2|        2|1328_NA06989|
|1377|NA11891|   0|   0|     1|        2|1377_NA11891|
|1349|NA11843|   0|   0|     1|        1|1349_NA11843|
+----+-------+----+----+------+---------+------------+
only showing top 3 rows



defined [32mfunction[39m [36mdp[39m

In [4]:
/**
 * Container for three related data frames describing a matrix in sparse form.
 *
 * As built by `MatrixTable.fromPLINKDataset` in this notebook, `rows` carries a `rowId`,
 * `cols` carries a `colId`, and `entries` holds (`rowId`, `colId`, `state`) records.
 * The class itself does not validate those schemas.
 *
 * @param rows    row-level records
 * @param cols    column-level records
 * @param entries cell-level records
 */
class MatrixTable(val rows: DataFrame, val cols: DataFrame, val entries: DataFrame) {

    /**
     * Write the table as three parquet datasets (`rows.parquet`, `cols.parquet`,
     * `entries.parquet`) under `path`, overwriting any previous contents.
     *
     * @param path             local directory to write into; created if it does not exist
     * @param entryPartitions  number of partitions `entries` is repartitioned into before
     *                         writing (defaults to the previously hard-coded 16)
     * @throws IllegalArgumentException if `entryPartitions` is not positive
     */
    def save(path: String, entryPartitions: Int = 16): Unit = {
        require(entryPartitions > 0, s"entryPartitions must be positive, got $entryPartitions")

        val root = Paths.get(path).toFile
        if (!root.exists)
            root.mkdirs

        // Absolute-or-relative location of one of the three parquet datasets
        def target(name: String): String = Paths.get(path, name).toString

        rows.write.format("parquet").mode("overwrite").save(target("rows.parquet"))
        cols.write.format("parquet").mode("overwrite").save(target("cols.parquet"))
        entries.repartition(entryPartitions)
            .write.format("parquet").mode("overwrite").save(target("entries.parquet"))
    }

}

/** Constructors for [[MatrixTable]]: reading a saved table and building one from PLINK data. */
object MatrixTable {

    /** Throw an IllegalArgumentException carrying `message` when `violated` is true. */
    private def failIf(violated: Boolean, message: String): Unit =
        if (violated) throw new IllegalArgumentException(message)

    /**
     * Replace/add `idCol` with a 1-based sequential id: rows are first tagged with
     * monotonically_increasing_id and then numbered in that order over a single,
     * unpartitioned window.
     */
    private def withSequentialId(frame: DataFrame, idCol: String): DataFrame =
        frame
            .withColumn(idCol, monotonically_increasing_id())
            .withColumn(idCol, row_number.over(Window.orderBy($"colId".as(idCol)).orderBy(idCol)))

    /**
     * Read the three parquet datasets (`rows.parquet`, `cols.parquet`, `entries.parquet`)
     * found under `path` into a new MatrixTable.
     */
    def load(path: String)(implicit ss: SparkSession) = {
        def readPart(name: String): DataFrame =
            ss.read.parquet(Paths.get(path, name).toString)

        new MatrixTable(
            rows=readPart("rows.parquet"),
            cols=readPart("cols.parquet"),
            entries=readPart("entries.parquet")
        )
    }

    /**
     * Build a MatrixTable from a genotype frame and a pedigree frame.
     *
     * @param df genotype data; must have a unique `variantId` column and a `genotypes` column
     * @param dp pedigree data; must have a unique `sampleId` column
     * @throws IllegalArgumentException if either key column is missing or not unique
     */
    def fromPLINKDataset(df: DataFrame, dp: DataFrame)(implicit ss: SparkSession) = {
        // Key columns must exist ...
        failIf(!df.schema.names.contains("variantId"),
            "Genotype data frame must contain 'variantId' field")
        failIf(!dp.schema.names.contains("sampleId"),
            "Pedigree data frame must contain 'sampleId' field")
        // ... and must identify records uniquely
        failIf(df.select("variantId").distinct.count != df.count,
            "Genotype field 'variantId' must be unique")
        failIf(dp.select("sampleId").distinct.count != dp.count,
            "Pedigree field 'sampleId' must be unique")

        // One record per distinct sample found in the genotypes, enriched with pedigree fields
        val samples = df
            .withColumn("genotypes", explode($"genotypes"))
            .select("genotypes.sampleId")
            .dropDuplicates("sampleId")
            .join(dp, Seq("sampleId"), "left")
        val cols = samples
            .withColumn("colId", monotonically_increasing_id())
            .withColumn("colId", row_number.over(Window.orderBy($"colId")))

        // One record per variant, without the per-sample genotype array
        val rows = df
            .drop("genotypes")
            .withColumn("rowId", monotonically_increasing_id())
            .withColumn("rowId", row_number.over(Window.orderBy($"rowId")))

        // Long-form (variantId, sampleId, state) records, one per genotype call
        val calls = df
            .withColumn("state", expr("genotype_states(genotypes)"))
            .withColumn("sampleId", $"genotypes.sampleId")
            .selectExpr("variantId", "explode(arrays_zip(sampleId, state)) as gt")
            .select("variantId", "gt.*")

        // Swap the string keys for the integer ids and keep only non-zero states
        val rowKeys = rows.select("rowId", "variantId")
        val colKeys = cols.select("colId", "sampleId")
        val entries = calls
            .join(rowKeys, Seq("variantId"), "inner")
            .join(colKeys, Seq("sampleId"), "inner")
            .select($"rowId", $"colId", $"state".cast("byte").as("state"))
            .filter($"state" =!= 0) // Ignore homozygous reference

        new MatrixTable(rows=rows, cols=cols, entries=entries)
    }
}

defined [32mclass[39m [36mMatrixTable[39m
defined [32mobject[39m [36mMatrixTable[39m

In [6]:
// Matrix table assembled from the PLINK genotype (df) and pedigree (dp) frames
val mt = MatrixTable.fromPLINKDataset(df = df, dp = dp)(ss)

[36mmt[39m: [32mMatrixTable[39m = ammonite.$sess.cmd3$Helper$MatrixTable@109f0e6c

In [7]:
(mt.rows.count, mt.cols.count)

[36mres6[39m: ([32mLong[39m, [32mLong[39m) = ([32m1457897L[39m, [32m165L[39m)

In [43]:
mt.rows.printSchema

root
 |-- contigName: string (nullable = true)
 |-- names: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- position: double (nullable = true)
 |-- start: long (nullable = true)
 |-- end: long (nullable = true)
 |-- referenceAllele: string (nullable = true)
 |-- alternateAlleles: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- variantId: string (nullable = true)
 |-- rowId: integer (nullable = true)



In [44]:
mt.cols.printSchema

root
 |-- sampleId: string (nullable = true)
 |-- fid: string (nullable = true)
 |-- iid: string (nullable = true)
 |-- iidp: string (nullable = true)
 |-- iidm: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- phenotype: string (nullable = true)
 |-- colId: integer (nullable = true)



In [8]:
mt.entries.printSchema

root
 |-- rowId: integer (nullable = true)
 |-- colId: integer (nullable = true)
 |-- state: byte (nullable = true)



Save the processed PLINK data into this "Glow Matrix Table" (gmt) format:

In [10]:
mt.save(data_dir / QC0_FILE + ".gmt" toString)

20/01/15 23:14:14 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
20/01/15 23:14:25 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
20/01/15 23:14:36 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
20/01/15 23:14:36 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.


Define a few simple calculations that can be run over a layout like this:

In [None]:
/**
 * Extension methods for [[MatrixTable]]: counting, row/column filtering and simple
 * call-rate statistics computed from the sparse `entries` frame.
 *
 * The statistics treat `state === -1` as an uncalled genotype and `state === 1` as a
 * heterozygous call; rows/columns with no stored entries are reported with zero counts.
 */
implicit class MatrixTableOps(mt: MatrixTable) {

    /** Number of columns; counted on first access of this wrapper instance. */
    lazy val nCols: Long = mt.cols.count

    /** Number of rows; counted on first access of this wrapper instance. */
    lazy val nRows: Long = mt.rows.count

    /** (number of rows, number of columns) */
    def count: (Long, Long) = (nRows, nCols)

    /**
     * Apply `fn` to the column frame and keep only the entries whose `colId`
     * survives in the result. Rows are left untouched.
     */
    def filterCols(fn: DataFrame => DataFrame): MatrixTable = {
        val cols = fn(mt.cols)
        new MatrixTable(
            rows=mt.rows,
            cols=cols,
            entries = mt.entries.join(broadcast(cols.select("colId")), Seq("colId"), "leftsemi")
        )
    }

    /**
     * Apply `fn` to the row frame and keep only the entries whose `rowId`
     * survives in the result. Columns are left untouched.
     */
    def filterRows(fn: DataFrame => DataFrame): MatrixTable = {
        val rows = fn(mt.rows)
        new MatrixTable(
            rows=rows,
            cols=mt.cols,
            entries = mt.entries.join(broadcast(rows.select("rowId")), Seq("rowId"), "leftsemi")
        )
    }

    /**
     * Per-sample statistics: one record per `colId` with `nUncalled`, `nHet`,
     * `nCalled` (= number of rows - nUncalled) and `callRate`.
     */
    def getSampleStats(): DataFrame = callStats("colId", mt.cols, nRows)

    /**
     * Per-variant statistics: one record per `rowId` with `nUncalled`, `nHet`,
     * `nCalled` (= number of columns - nUncalled) and `callRate`.
     */
    def getVariantStats(): DataFrame = callStats("rowId", mt.rows, nCols)

    /** Apply `fn` to this table; lets table-level steps be chained fluently. */
    def transform(fn: MatrixTable => MatrixTable): MatrixTable = {
        fn(mt)
    }

    /**
     * Shared implementation of the sample/variant statistics.
     *
     * @param idCol     id column to aggregate entries by (`colId` or `rowId`)
     * @param ids       frame listing every id that must appear in the output
     * @param nOpposite size of the other axis, i.e. the number of calls each id could have
     */
    private def callStats(idCol: String, ids: DataFrame, nOpposite: Long): DataFrame = {
        mt.entries
            .groupBy(idCol).agg(
                sum(when($"state" === -1, 1).otherwise(0)).as("nUncalled"),
                sum(when($"state" === 1, 1).otherwise(0)).as("nHet")
            )
            // Right join so ids with no stored entries still get a record
            .join(ids.select(idCol), Seq(idCol), "right")
            // Ids introduced by the join have null aggregates; they really mean zero.
            // nHet must be defaulted as well, otherwise rates derived from it become null.
            .withColumn("nUncalled", coalesce($"nUncalled", lit(0)))
            .withColumn("nHet", coalesce($"nHet", lit(0)))
            .withColumn("nCalled", lit(nOpposite) - $"nUncalled")
            .withColumn("callRate", $"nCalled" / ($"nCalled" + $"nUncalled"))
    }
}

<h2><a id="step_1">Step 1: Sample/Variant Absence Filter</a></h2>

In [11]:
// Reload the unfiltered (QC0) matrix table from its saved .gmt directory
val mt_qc0 = {
    val gmtPath = (data_dir / QC0_FILE + ".gmt").toString
    MatrixTable.load(gmtPath)(ss)
}

[36mmt_qc0[39m: [32mMatrixTable[39m = ammonite.$sess.cmd3$Helper$MatrixTable@6333f3dc

In [13]:
timeop("qc1-count") {
    mt_qc0.count
}

Elapsed time: 0.1 seconds


[36mres12[39m: ([32mLong[39m, [32mLong[39m) = ([32m1457897L[39m, [32m165L[39m)

In [14]:
// Histogram of the per-sample call rate for the unfiltered dataset.
// NOTE(review): `.fn` and `.plot` are not defined in this notebook (init scripts); `.fn` looks
// like it simply applies the given function to the DataFrame -- confirm
timeop("qc1-samplestats")(mt_qc0.getSampleStats())
    .fn(d => {
        Histogram(
            // Bring the callRate column to the driver as a local list of doubles
            x=d.select("callRate").collect.map(_.getAs[Double]("callRate")).toList
        ).plot(title="Sample Call Rate Distribution", xaxis=Axis(title="Call Rate"))
    })

Elapsed time: 0.1 seconds


[36mres13[39m: [32mString[39m = [32m"plot-36d80bab-8a8d-4d19-807e-0e198774ee9a"[39m

In [15]:
// Bar chart of the per-variant call rate distribution (log-scaled counts).
// NOTE(review): `.fn` and `.plot` are not defined in this notebook (init scripts) -- confirm
// their exact behavior there
timeop("qc1-variantstats")(mt_qc0.getVariantStats())
    // Snap each call rate to a multiple of 0.005 so variants can be counted per bin
    .withColumn("bin", bround($"callRate"/.005)*.005)
    .groupBy("bin").count.sort($"bin".asc)
    .fn(d => {
        Bar(
            // Both columns are collected separately from the same sorted frame
            x=d.map(_.getAs[Double]("bin")).collect.toList,
            y=d.map(_.getAs[Long]("count")).collect.toList
        ).plot(
            title="Variant Call Rate Distribution", 
            xaxis=Axis(title="Call Rate"),
            yaxis=Axis(`type`=AxisType.Log, title="Num Variants")
        )
    })

Elapsed time: 0.1 seconds


[36mres14[39m: [32mString[39m = [32m"plot-bb3c3ea5-ae9a-46a0-9d61-e0fb85d89d25"[39m

In [16]:
/**
 * Keep only the variants (rows) of `mt` whose call rate is at least `threshold`.
 * Curried so that supplying the threshold yields a MatrixTable => MatrixTable step.
 */
def filterByVariantCallRate(threshold: Double)(mt: MatrixTable): MatrixTable = {
    // Ids of the variants that meet the threshold
    val passingRowIds = mt.getVariantStats()
        .filter($"callRate" >= threshold)
        .select("rowId")
    mt.filterRows(rows => rows.join(passingRowIds, Seq("rowId"), "leftsemi"))
}
/**
 * Keep only the samples (columns) of `mt` whose call rate is at least `threshold`.
 * Curried so that supplying the threshold yields a MatrixTable => MatrixTable step.
 */
def filterBySampleCallRate(threshold: Double)(mt: MatrixTable): MatrixTable = {
    // Ids of the samples that meet the threshold
    val passingColIds = mt.getSampleStats()
        .filter($"callRate" >= threshold)
        .select("colId")
    mt.filterCols(cols => cols.join(passingColIds, Seq("colId"), "leftsemi"))
}

defined [32mfunction[39m [36mfilterByVariantCallRate[39m
defined [32mfunction[39m [36mfilterBySampleCallRate[39m

In [18]:
// QC step 1: alternate variant and sample call-rate filters, first at a lenient
// threshold (0.8) and then at a strict one (0.98), printing the final dimensions.
val mt_qc1 = timeop("qc1") {
    val stages = Seq[MatrixTable => MatrixTable](
        filterByVariantCallRate(threshold=.8),
        filterBySampleCallRate(threshold=.8),
        filterByVariantCallRate(threshold=.98),
        filterBySampleCallRate(threshold=.98)
    )
    // Apply the stages left to right, each one to the previous stage's output
    val filtered = stages.foldLeft(mt_qc0)((table, stage) => stage(table))
    println(filtered.count)
    filtered
}

(1430443,165)
Elapsed time: 63.2 seconds


[36mmt_qc1[39m: [32mMatrixTable[39m = ammonite.$sess.cmd3$Helper$MatrixTable@28cf04c1

In [19]:
mt_qc1.save(data_dir / QC1_FILE + ".gmt" toString)

<h2><a id="step_2">Step 2: Gender Discrepancy</a></h2>

In [22]:
// Reload the call-rate-filtered (QC1) matrix table from its saved .gmt directory
val mt_qc1 = {
    val gmtPath = (data_dir / QC1_FILE + ".gmt").toString
    MatrixTable.load(gmtPath)(ss)
}

[36mmt_qc1[39m: [32mMatrixTable[39m = ammonite.$sess.cmd3$Helper$MatrixTable@191747fd

In [28]:
// Per-sample heterozygosity/homozygosity rates computed over contig "23" only,
// joined back to the sample metadata. Declared as a def, so each use rebuilds it.
def mt_qc2_stat = {
    val contig23Stats = mt_qc1
        .filterRows(rows => rows.filter($"contigName" === "23"))
        .getSampleStats()
    val withRates = contig23Stats
        .withColumn("hetRate", $"nHet" / $"nCalled")
        .withColumn("homRate", lit(1.0) - $"hetRate")
    withRates.join(mt_qc1.cols, Seq("colId"))
}

// Overlaid histograms of the contig-23 homozygosity rate, one trace per gender code.
// NOTE(review): `.fn` and `.plot` come from the notebook init scripts -- confirm their behavior there
mt_qc2_stat.fn(d => {
        // See https://www.cog-genomics.org/plink/1.9/formats#fam for encoding
        val genderLabels = Map("1" -> "Male", "2" -> "Female")
        // Create histogram trace for each gender
        d.select("gender").dropDuplicates().map(_.getAs[String]("gender")).collect.map(g => 
            Histogram(
                x=d.filter($"gender" === g).map(_.getAs[Double]("homRate")).collect.toList, 
                // Any code other than "1"/"2" (e.g. "0") gets a fallback label instead of
                // failing the whole plot with a NoSuchElementException
                name=genderLabels.getOrElse(g, s"Unknown ($g)"),
                xbins=Bins(0.7, 1.01, .01)
            )
        ).toSeq.plot(
            title="Sex Chromosome Homozygosity Rates",
            yaxis=Axis(`type`=AxisType.Log, title="Sample Count"),
            xaxis=Axis(title="Homozygosity Rate", range=(0.7, 1.01))
        )
    })

defined [32mfunction[39m [36mmt_qc2_stat[39m
[36mres27_1[39m: [32mString[39m = [32m"plot-3bd21485-df11-4830-b177-8b0de2d9a25a"[39m

In [29]:
// QC step 2: rewrite the `gender` column of the sample metadata. Samples recorded as "2"
// whose contig-23 homozygosity rate exceeds 0.9 are relabelled "1"; all others keep their
// recorded value. Rows and entries are carried over unchanged.
val mt_qc2 = timeop("qc2") {
    val correctedGender = mt_qc2_stat
        .withColumn("gender", when($"gender" === "2" && $"homRate" > .9, "1").otherwise($"gender"))
        .select("colId", "gender")

    val corrected = new MatrixTable(
        rows=mt_qc1.rows,
        cols=mt_qc1.cols.drop("gender").join(correctedGender, Seq("colId"), "inner"),
        entries=mt_qc1.entries
    )
    println(corrected.count)
    corrected
}

(1430443,165)
Elapsed time: 0.9 seconds


[36mmt_qc2[39m: [32mMatrixTable[39m = ammonite.$sess.cmd3$Helper$MatrixTable@4d943b98

In [None]:
mt_qc2.save(data_dir / QC2_FILE + ".gmt" toString)

<h2><a id="step_3">Step 3: Autosomal Variants and MAF Filtering</a></h2>

<h2><a id="step_4">Step 4: Hardy-Weinberg Equilibrium Filtering</a></h2>

Only the per-variant genotype counts nHet, nHomAlt, and nHomRef are needed for this:

- https://github.com/projectglow/glow/blob/master/core/src/main/scala/io/projectglow/sql/expressions/VariantQcExprs.scala#L54
- import io.projectglow.sql.util.LeveneHaldane