kill shuffle mode, support two-bit ref-files
ryan-williams committed Sep 8, 2015
1 parent 7095b54 commit 1f797ee
Showing 8 changed files with 127 additions and 110 deletions.
3 changes: 0 additions & 3 deletions adam-cli/src/main/scala/org/bdgenomics/adam/cli/Transform.scala
@@ -87,8 +87,6 @@ class TransformArgs extends Args4jBase with ADAMSaveAnyArgs with ParquetArgs {
var asSingleFile: Boolean = false
@Args4jOption(required = false, name = "-add_md_tags", usage = "Add MD Tags to reads based on the FASTA (or equivalent) file passed to this option.")
var mdTagsReferenceFile: String = null
@Args4jOption(required = false, name = "-md_tag_shuffle", usage = "When adding MD tags to reads, use a shuffle join (as opposed to default region join).")
var mdTagsShuffle: Boolean = false
@Args4jOption(required = false, name = "-md_tag_fragment_size", usage = "When adding MD tags to reads, load the reference in fragments of this size.")
var mdTagsFragmentSize: Long = 1000000L
@Args4jOption(required = false, name = "-md_tag_overwrite", usage = "When adding MD tags to reads, overwrite existing incorrect tags.")
@@ -158,7 +156,6 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans
MDTagging(
adamRecords,
args.mdTagsReferenceFile,
shuffle = args.mdTagsShuffle,
fragmentLength = args.mdTagsFragmentSize,
overwriteExistingTags = args.mdTagsOverwrite,
validationStringency = stringencyOpt.getOrElse(ValidationStringency.STRICT)
12 changes: 12 additions & 0 deletions adam-core/src/main/scala/org/bdgenomics/adam/rdd/ADAMContext.scala
@@ -17,6 +17,7 @@
*/
package org.bdgenomics.adam.rdd

import java.io.File
import java.io.FileNotFoundException
import java.util.regex.Pattern
import htsjdk.samtools.SAMFileHeader
@@ -39,8 +40,10 @@ import org.bdgenomics.adam.rdd.features._
import org.bdgenomics.adam.rdd.read.AlignmentRecordRDDFunctions
import org.bdgenomics.adam.rdd.variation._
import org.bdgenomics.adam.rich.RichAlignmentRecord
import org.bdgenomics.adam.util.{ TwoBitFile, ReferenceContigMap, ReferenceFile }
import org.bdgenomics.formats.avro._
import org.bdgenomics.utils.instrumentation.Metrics
import org.bdgenomics.utils.io.LocalFileByteAccess
import org.bdgenomics.utils.misc.HadoopUtil
import org.seqdoop.hadoop_bam.util.SAMHeaderReader
import org.seqdoop.hadoop_bam._
@@ -537,6 +540,15 @@ class ADAMContext(val sc: SparkContext) extends Serializable with Logging {
loadFeatures(filePath, projection).asGenes()
}

def loadReferenceFile(filePath: String, fragmentLength: Long): ReferenceFile = {
if (filePath.endsWith(".2bit")) {
//TODO(ryan): S3ByteAccess
new TwoBitFile(new LocalFileByteAccess(new File(filePath)))
} else {
ReferenceContigMap(loadSequence(filePath, fragmentLength = fragmentLength))
}
}

def loadSequence(
filePath: String,
projection: Option[Schema] = None,
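For reference, a minimal sketch of calling the new loader (the path and fragment length here are hypothetical): a ".2bit" suffix selects the TwoBitFile-backed implementation, while any other input is loaded as NucleotideContigFragments and wrapped in a ReferenceContigMap.

    import org.bdgenomics.adam.rdd.ADAMContext
    import org.bdgenomics.adam.util.ReferenceFile

    // Hypothetical driver code: dispatch happens on the ".2bit" extension.
    def loadRef(ac: ADAMContext): ReferenceFile =
      ac.loadReferenceFile("/refs/hg19.2bit", fragmentLength = 1000000L)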
107 changes: 7 additions & 100 deletions adam-core/src/main/scala/org/bdgenomics/adam/rdd/read/MDTagging.scala
@@ -24,13 +24,11 @@ import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.bdgenomics.adam.models.ReferenceRegion
import org.bdgenomics.adam.rdd.ShuffleRegionJoin
import org.bdgenomics.adam.util.MdTag
import org.bdgenomics.formats.avro.{ AlignmentRecord, NucleotideContigFragment }
import org.bdgenomics.adam.util.{ ReferenceFile, MdTag }
import org.bdgenomics.formats.avro.AlignmentRecord

case class MDTagging(reads: RDD[AlignmentRecord],
referenceFragments: RDD[NucleotideContigFragment],
shuffle: Boolean = false,
@transient referenceFile: ReferenceFile,
partitionSize: Long = 1000000,
overwriteExistingTags: Boolean = false,
validationStringency: ValidationStringency = ValidationStringency.STRICT) extends Logging {
@@ -41,12 +39,7 @@ case class MDTagging(reads: RDD[AlignmentRecord],
val numUnmappedReads = sc.accumulator(0L, "Unmapped Reads")
val incorrectMDTags = sc.accumulator(0L, "Incorrect Extant MDTags")

val taggedReads =
(if (shuffle) {
addMDTagsShuffle
} else {
addMDTagsBroadcast
}).cache
val taggedReads = addMDTagsBroadcast.cache

def maybeMDTagRead(read: AlignmentRecord, refSeq: String): AlignmentRecord = {

@@ -75,118 +68,32 @@
}

def addMDTagsBroadcast(): RDD[AlignmentRecord] = {
val collectedRefMap =
referenceFragments
.groupBy(_.getContig.getContigName)
.mapValues(_.toSeq.sortBy(_.getFragmentStartPosition))
.collectAsMap
.toMap

log.info(s"Found contigs named: ${collectedRefMap.keys.mkString(", ")}")

val refMapB = sc.broadcast(collectedRefMap)

def getRefSeq(contigName: String, read: AlignmentRecord): String = {
val readStart = read.getStart
val readEnd = readStart + read.referenceLength

val fragments =
refMapB
.value
.getOrElse(
contigName,
throw new Exception(
s"Contig $contigName not found in reference map with keys: ${refMapB.value.keys.mkString(", ")}"
)
)
.dropWhile(f => f.getFragmentStartPosition + f.getFragmentSequence.length < readStart)
.takeWhile(_.getFragmentStartPosition < readEnd)

getReferenceBasesForRead(read, fragments)
}
val referenceFileB = sc.broadcast(referenceFile)
reads.map(read => {
(for {
contig <- Option(read.getContig)
contigName <- Option(contig.getContigName)
if read.getReadMapped
} yield {
maybeMDTagRead(read, getRefSeq(contigName, read))
maybeMDTagRead(read, referenceFileB.value.extract(ReferenceRegion(read)))
}).getOrElse({
numUnmappedReads += 1
read
})
})
}

def addMDTagsShuffle(): RDD[AlignmentRecord] = {
val fragsWithRegions =
for {
fragment <- referenceFragments
region <- ReferenceRegion(fragment)
} yield {
region -> fragment
}

val unmappedReads = reads.filter(!_.getReadMapped)
numUnmappedReads += unmappedReads.count

val readsWithRegions =
for {
read <- reads
region <- ReferenceRegion.opt(read)
} yield region -> read

val sd = reads.adamGetSequenceDictionary()

val readsWithFragments =
ShuffleRegionJoin(sd, partitionSize)
.partitionAndJoin(readsWithRegions, fragsWithRegions)
.groupByKey
.mapValues(_.toSeq.sortBy(_.getFragmentStartPosition))

(for {
(read, fragments) <- readsWithFragments
} yield {
maybeMDTagRead(read, getReferenceBasesForRead(read, fragments))
}) ++ unmappedReads
}

private def getReferenceBasesForRead(read: AlignmentRecord, fragments: Seq[NucleotideContigFragment]): String = {
fragments.map(clipFragment(_, read)).mkString("")
}

private def clipFragment(fragment: NucleotideContigFragment, read: AlignmentRecord): String = {
clipFragment(fragment, read.getStart, read.getStart + read.referenceLength)
}
private def clipFragment(fragment: NucleotideContigFragment, start: Long, end: Long): String = {
val min =
math.max(
0L,
start - fragment.getFragmentStartPosition
).toInt

val max =
math.min(
fragment.getFragmentSequence.length,
end - fragment.getFragmentStartPosition
).toInt

fragment.getFragmentSequence.substring(min, max)
}
}

object MDTagging {
def apply(reads: RDD[AlignmentRecord],
referenceFile: String,
shuffle: Boolean,
fragmentLength: Long,
overwriteExistingTags: Boolean,
validationStringency: ValidationStringency): RDD[AlignmentRecord] = {
val sc = reads.sparkContext
new MDTagging(
reads,
sc.loadSequence(referenceFile, fragmentLength = fragmentLength),
shuffle = shuffle,
sc.loadReferenceFile(referenceFile, fragmentLength = fragmentLength),
partitionSize = fragmentLength,
overwriteExistingTags,
validationStringency
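To illustrate the new call path (the reads RDD and reference path below are placeholders): MDTagging.apply now resolves the reference via loadReferenceFile and broadcasts it to executors, replacing the deleted shuffle-join path.

    import htsjdk.samtools.ValidationStringency
    import org.apache.spark.rdd.RDD
    import org.bdgenomics.adam.rdd.read.MDTagging
    import org.bdgenomics.formats.avro.AlignmentRecord

    // Placeholder inputs; the reference may be a FASTA or a .2bit file.
    def tagReads(reads: RDD[AlignmentRecord]): RDD[AlignmentRecord] =
      MDTagging(
        reads,
        "/refs/hg19.2bit",
        fragmentLength = 1000000L,
        overwriteExistingTags = false,
        validationStringency = ValidationStringency.STRICT)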
2 changes: 2 additions & 0 deletions adam-core/src/main/scala/org/bdgenomics/adam/serialization/ADAMKryoRegistrator.scala
@@ -23,6 +23,7 @@ import it.unimi.dsi.fastutil.io.{ FastByteArrayInputStream, FastByteArrayOutputS
import org.apache.avro.io.{ BinaryDecoder, DecoderFactory, BinaryEncoder, EncoderFactory }
import org.apache.avro.specific.{ SpecificDatumWriter, SpecificDatumReader, SpecificRecord }
import org.apache.spark.serializer.KryoRegistrator
import org.bdgenomics.adam.util.{ TwoBitFileSerializer, TwoBitFile }
import org.bdgenomics.formats.avro._
import org.bdgenomics.adam.models._
import org.bdgenomics.adam.rdd.read.realignment._
@@ -84,5 +85,6 @@ class ADAMKryoRegistrator extends KryoRegistrator {
kryo.register(classOf[IndelRealignmentTarget])
kryo.register(classOf[TargetSet], new TargetSetSerializer)
kryo.register(classOf[ZippedTargetSet], new ZippedTargetSetSerializer)
kryo.register(classOf[TwoBitFile], new TwoBitFileSerializer)
}
}
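This registration matters because MDTagging now broadcasts its ReferenceFile, so TwoBitFile instances must survive Kryo serialization. A sketch of the standard Spark wiring that picks the registrator up (the app name is a placeholder; the properties mirror ADAMFunSuite at the bottom of this diff):

    import org.apache.spark.SparkConf

    object KryoConf {
      // Standard Kryo wiring; mirrors the test properties in ADAMFunSuite below.
      val conf = new SparkConf()
        .setAppName("adam-example") // placeholder app name
        .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        .set("spark.kryo.registrator", "org.bdgenomics.adam.serialization.ADAMKryoRegistrator")
    }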
75 changes: 75 additions & 0 deletions adam-core/src/main/scala/org/bdgenomics/adam/util/ReferenceContigMap.scala
@@ -0,0 +1,75 @@
/**
* Licensed to Big Data Genomics (BDG) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The BDG licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.bdgenomics.adam.util

import org.apache.spark.rdd.RDD
// NOTE(ryan): this is necessary for Spark <= 1.2.1.
import org.apache.spark.SparkContext._
import org.bdgenomics.adam.models.ReferenceRegion
import org.bdgenomics.formats.avro.NucleotideContigFragment

case class ReferenceContigMap(contigMap: Map[String, Seq[NucleotideContigFragment]]) extends ReferenceFile {
/**
* Extract reference sequence from the file.
*
* @param region The desired ReferenceRegion to extract.
* @return The reference sequence at the desired locus.
*/
override def extract(region: ReferenceRegion): String = {
contigMap
.getOrElse(
region.referenceName,
throw new Exception(
s"Contig ${region.referenceName} not found in reference map with keys: ${contigMap.keys.toList.sortBy(x => x).mkString(", ")}"
)
)
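// keep only the fragments that overlap [region.start, region.end)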
.dropWhile(f => f.getFragmentStartPosition + f.getFragmentSequence.length < region.start)
.takeWhile(_.getFragmentStartPosition < region.end)
.map(
clipFragment(_, region.start, region.end)
)
.mkString("")
}

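// Clip the window [start, end) to this fragment, in fragment-relative coordinates.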
private def clipFragment(fragment: NucleotideContigFragment, start: Long, end: Long): String = {
val min =
math.max(
0L,
start - fragment.getFragmentStartPosition
).toInt

val max =
math.min(
fragment.getFragmentSequence.length,
end - fragment.getFragmentStartPosition
).toInt

fragment.getFragmentSequence.substring(min, max)
}
}

object ReferenceContigMap {
def apply(fragments: RDD[NucleotideContigFragment]): ReferenceContigMap =
ReferenceContigMap(
fragments
.groupBy(_.getContig.getContigName)
.mapValues(_.toSeq.sortBy(_.getFragmentStartPosition))
.collectAsMap
.toMap
)
}
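A usage sketch with a hypothetical fragment RDD: constructing the map collects every fragment to the driver, grouped by contig and sorted by start position, after which extract can slice out any region.

    import org.apache.spark.rdd.RDD
    import org.bdgenomics.adam.models.ReferenceRegion
    import org.bdgenomics.adam.util.ReferenceContigMap
    import org.bdgenomics.formats.avro.NucleotideContigFragment

    // Hypothetical input; returns the first 100 bases of chr1.
    def firstBases(frags: RDD[NucleotideContigFragment]): String =
      ReferenceContigMap(frags).extract(ReferenceRegion("chr1", 0L, 100L))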
26 changes: 24 additions & 2 deletions adam-core/src/main/scala/org/bdgenomics/adam/util/TwoBitFile.scala
@@ -19,7 +19,9 @@
package org.bdgenomics.adam.util

import java.nio.{ ByteOrder, ByteBuffer }
import org.bdgenomics.utils.io.ByteAccess
import com.esotericsoftware.kryo.io.{ Output, Input }
import com.esotericsoftware.kryo.{ Kryo, Serializer }
import org.bdgenomics.utils.io.{ ByteArrayByteAccess, ByteAccess }
import org.bdgenomics.adam.models.ReferenceRegion

object TwoBitFile {
@@ -97,7 +99,13 @@ class TwoBitFile(byteAccess: ByteAccess) extends ReferenceFile {
* @return The reference sequence at the desired locus.
*/
def extract(region: ReferenceRegion): String = {
val record = seqRecords(region.referenceName)
val record =
seqRecords.getOrElse(
region.referenceName,
throw new Exception(
s"Contig ${region.referenceName} not found in reference map with keys: ${seqRecords.keys.toList.sortBy(x => x).mkString(", ")}"
)
)
val contigLength = record.dnaSize
assert(region.start >= 0)
assert(region.end <= contigLength.toLong)
@@ -125,6 +133,20 @@ class TwoBitFile(byteAccess: ByteAccess) extends ReferenceFile {
}
}

class TwoBitFileSerializer extends Serializer[TwoBitFile] {
override def write(kryo: Kryo, output: Output, obj: TwoBitFile): Unit = {
val arr = obj.bytes.array()
output.writeInt(arr.length)
output.write(arr)
}

override def read(kryo: Kryo, input: Input, klazz: Class[TwoBitFile]): TwoBitFile = {
val length = input.readInt()
val bytes = input.readBytes(length)
new TwoBitFile(new ByteArrayByteAccess(bytes))
}
}

object TwoBitRecord {
def apply(twoBitBytes: ByteBuffer, name: String, seqRecordStart: Int): TwoBitRecord = {
val dnaSize = twoBitBytes.getInt(seqRecordStart)
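A sketch of the serializer's round trip (the Kryo instance and the TwoBitFile are assumed inputs): the payload is just a length-prefixed byte array, read back through a ByteArrayByteAccess, which is what lets a broadcast 2bit reference reach each executor intact.

    import java.io.ByteArrayOutputStream
    import com.esotericsoftware.kryo.Kryo
    import com.esotericsoftware.kryo.io.{ Input, Output }
    import org.bdgenomics.adam.util.{ TwoBitFile, TwoBitFileSerializer }

    // Illustrative round trip; kryo and ref are assumed to exist.
    def roundTrip(kryo: Kryo, ref: TwoBitFile): TwoBitFile = {
      val serializer = new TwoBitFileSerializer()
      val baos = new ByteArrayOutputStream()
      val out = new Output(baos)
      serializer.write(kryo, out, ref)
      out.close()
      serializer.read(kryo, new Input(baos.toByteArray), classOf[TwoBitFile])
    }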
6 changes: 3 additions & 3 deletions adam-core/src/test/scala/org/bdgenomics/adam/rdd/read/MDTaggingSuite.scala
@@ -18,8 +18,9 @@
package org.bdgenomics.adam.rdd.read

import org.apache.spark.rdd.RDD
import org.bdgenomics.adam.util.ADAMFunSuite
import org.bdgenomics.adam.util.{ ReferenceContigMap, ADAMFunSuite }
import org.bdgenomics.formats.avro.{ AlignmentRecord, NucleotideContigFragment, Contig }
import org.bdgenomics.utils.misc.SparkFunSuite

class MDTaggingSuite extends ADAMFunSuite {
val chr1 =
@@ -75,8 +76,7 @@ class MDTaggingSuite extends ADAMFunSuite {
}

for (i <- List(1, 10)) {
check(MDTagging(reads, makeFrags(fs: _*), partitionSize = i))
check(MDTagging(reads, makeFrags(fs: _*), partitionSize = i, shuffle = true))
check(MDTagging(reads, ReferenceContigMap(makeFrags(fs: _*)), partitionSize = i))
}
}

6 changes: 4 additions & 2 deletions adam-core/src/test/scala/org/bdgenomics/adam/util/ADAMFunSuite.scala
@@ -22,8 +22,10 @@ import org.bdgenomics.utils.misc.SparkFunSuite
trait ADAMFunSuite extends SparkFunSuite {

override val appName: String = "adam"
override val properties: Map[String, String] = Map(("spark.serializer", "org.apache.spark.serializer.KryoSerializer"),
override val properties: Map[String, String] = Map(
("spark.serializer", "org.apache.spark.serializer.KryoSerializer"),
("spark.kryo.registrator", "org.bdgenomics.adam.serialization.ADAMKryoRegistrator"),
("spark.kryo.referenceTracking", "true"))
("spark.kryo.referenceTracking", "true")
)
}
