Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[main] To refactor BodyPartParser of Multipart in order to avoid StackOverflowError easily (backport #11360) by @yousuketto #11469

Merged
merged 1 commit into from
Oct 13, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
63 changes: 38 additions & 25 deletions core/play/src/main/scala/play/core/parsers/Multipart.scala
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,15 @@ object Multipart {
"'boundary' parameter of multipart Content-Type must not end with a space char"
)

// phantom type for ensuring soundness of our parsing method setup
sealed trait StateResult
sealed trait Done extends StateResult
case object Done extends Done
class ContinueParsing(parse: => StateResult) extends StateResult {
def apply(): StateResult = parse
}
object ContinueParsing {
def apply(parse: => StateResult): ContinueParsing = new ContinueParsing(parse)
}

private[this] val needle: Array[Byte] = {
val array = new Array[Byte](boundary.length + 4)
Expand All @@ -350,9 +357,9 @@ object Multipart {

override def createLogic(attributes: Attributes): GraphStageLogic =
new GraphStageLogic(shape) with InHandler with OutHandler {
private var output = collection.immutable.Queue.empty[RawPart]
private var state: ByteString => StateResult = tryParseInitialBoundary
private var terminated = false
private var output = collection.immutable.Queue.empty[RawPart]
private var state: ByteString => Done = tryParseInitialBoundary
private var terminated = false

override def onPush(): Unit = {
if (!terminated) {
Expand All @@ -378,13 +385,23 @@ object Multipart {

setHandlers(in, out, this)

def tryParseInitialBoundary(input: ByteString): StateResult = {
@tailrec
def handleParsingState(parse: ContinueParsing): Done = {
parse() match {
case Done => Done
case c: ContinueParsing => handleParsingState(c)
}
}

def handleParsingState(exec: => StateResult): Done = handleParsingState(ContinueParsing(exec))

def tryParseInitialBoundary(input: ByteString): Done = {
// we don't use boyerMoore here because we are testing for the boundary *without* a
// preceding CRLF and at a known location (the very beginning of the entity)
try {
if (boundary(input, 0)) {
val ix = boundaryLength
if (crlf(input, ix)) parseHeader(input, ix + 2, 0)
if (crlf(input, ix)) handleParsingState(parseHeader(input, ix + 2, 0))
else if (doubleDash(input, ix)) terminate()
else parsePreamble(input, 0)
} else parsePreamble(input, 0)
Expand All @@ -393,11 +410,11 @@ object Multipart {
}
}

def parsePreamble(input: ByteString, offset: Int): StateResult = {
def parsePreamble(input: ByteString, offset: Int): Done = {
try {
@tailrec def rec(index: Int): StateResult = {
@tailrec def rec(index: Int): Done = {
val needleEnd = boyerMoore.nextIndex(input, index) + needle.length
if (crlf(input, needleEnd)) parseHeader(input, needleEnd + 2, 0)
if (crlf(input, needleEnd)) handleParsingState(parseHeader(input, needleEnd + 2, 0))
else if (doubleDash(input, needleEnd)) terminate()
else rec(needleEnd)
}
Expand Down Expand Up @@ -488,7 +505,8 @@ object Multipart {
} else {
// There was not even enough space in the input to contain the needle. Only after we have enough data
// of at least the size of the needle we can decide if the body is empty or not.
state = more => checkEmptyBody(input ++ more, partStart, memoryBufferSize)(nonEmpty)(empty)
state = more =>
handleParsingState(checkEmptyBody(input ++ more, partStart, memoryBufferSize)(nonEmpty)(empty))
done()
}
}
Expand Down Expand Up @@ -517,7 +535,7 @@ object Multipart {
val needleEnd = currentPartEnd + needle.length
if (crlf(input, needleEnd)) {
emit(input.slice(offset, currentPartEnd))
parseHeader(input, needleEnd + 2, memoryBufferSize)
ContinueParsing(parseHeader(input, needleEnd + 2, memoryBufferSize))
} else if (doubleDash(input, needleEnd)) {
emit(input.slice(offset, currentPartEnd))
terminate()
Expand Down Expand Up @@ -546,7 +564,7 @@ object Multipart {
bufferExceeded("Memory buffer full on part " + partName)
} else if (crlf(input, needleEnd)) {
emit(DataPart(partName, input.slice(partStart, currentPartEnd).utf8String))
parseHeader(input, needleEnd + 2, newMemoryBufferSize)
ContinueParsing(parseHeader(input, needleEnd + 2, newMemoryBufferSize))
} else if (doubleDash(input, needleEnd)) {
emit(DataPart(partName, input.slice(partStart, currentPartEnd).utf8String))
terminate()
Expand All @@ -573,7 +591,7 @@ object Multipart {
val needleEnd = currentPartEnd + needle.length
if (crlf(input, needleEnd)) {
emit(BadPart(headers))
parseHeader(input, needleEnd + 2, memoryBufferSize)
ContinueParsing(parseHeader(input, needleEnd + 2, memoryBufferSize))
} else if (doubleDash(input, needleEnd)) {
emit(BadPart(headers))
terminate()
Expand All @@ -600,36 +618,31 @@ object Multipart {
head
}

def continue(input: ByteString, offset: Int)(next: (ByteString, Int) => StateResult): StateResult = {
def continue(input: ByteString, offset: Int)(next: (ByteString, Int) => StateResult): Done = {
state = math.signum(offset - input.length) match {
case -1 => more => next(input ++ more, offset)
case 0 => next(_, 0)
case -1 => more => handleParsingState(next(input ++ more, offset))
case 0 => more => handleParsingState(next(more, 0))
case 1 => throw new IllegalStateException
}
done()
}

def continue(next: (ByteString, Int) => StateResult): StateResult = {
state = next(_, 0)
done()
}

def bufferExceeded(message: String): StateResult = {
def bufferExceeded(message: String): Done = {
emit(MaxMemoryBufferExceeded(message))
terminate()
}

def fail(message: String): StateResult = {
def fail(message: String): Done = {
emit(ParseError(message))
terminate()
}

def terminate(): StateResult = {
def terminate(): Done = {
terminated = true
done()
}

def done(): StateResult = null // StateResult is a phantom type
def done(): Done = Done

// the length of the needle without the preceding CRLF
def boundaryLength: Int = needle.length - 2
Expand Down