Skip to content

Commit

Permalink
Merge pull request #448 from armanbilge/fix/init-npe
Browse files Browse the repository at this point in the history
Implement workaround for initialization order errors in JVM lambdas
  • Loading branch information
armanbilge committed Jan 17, 2024
2 parents c88a094 + b6463d6 commit fa51e2f
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 22 deletions.
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ lazy val lambda = crossProject(JSPlatform, JVMPlatform)
)
)
.jvmSettings(
Test / fork := true,
libraryDependencies ++= Seq(
"com.amazonaws" % "aws-lambda-java-core" % "1.2.3",
"co.fs2" %%% "fs2-io" % fs2Version
Expand Down
25 changes: 24 additions & 1 deletion lambda/jvm/src/main/scala/feral/lambda/IOLambdaPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@

package feral.lambda

import cats.effect.Async
import cats.effect.IO
import cats.effect.Resource
import cats.effect.std.Dispatcher
import cats.effect.syntax.all._
import cats.syntax.all._
import com.amazonaws.services.lambda.{runtime => lambdaRuntime}
import io.circe.Printer
Expand All @@ -29,11 +32,31 @@ import java.io.OutputStream
import java.io.OutputStreamWriter
import java.nio.channels.Channels
import scala.concurrent.duration._
import scala.util.control.NonFatal

private[lambda] abstract class IOLambdaPlatform[Event, Result]
extends lambdaRuntime.RequestStreamHandler { this: IOLambda[Event, Result] =>

private[this] val (dispatcher, handle) = {
val handler = {
val h =
try this.handler
catch { case ex if NonFatal(ex) => null }

if (h ne null) {
h.map(IO.pure(_))
} else {
val lambdaName = getClass().getSimpleName()
val msg =
s"""|There was an error initializing `$lambdaName` during startup.
|Falling back to initialize-during-first-invocation strategy.
|To fix, try replacing any `val`s in `$lambdaName` with `def`s.""".stripMargin
System.err.println(msg)

Async[Resource[IO, *]].defer(this.handler).memoize.map(_.allocated.map(_._1))
}
}

Dispatcher
.parallel[IO](await = false)
.product(handler)
Expand All @@ -50,7 +73,7 @@ private[lambda] abstract class IOLambdaPlatform[Event, Result]
val context = Context.fromJava[IO](runtimeContext)
dispatcher
.unsafeRunTimed(
handle(Invocation.pure(event, context)),
handle.flatMap(_(Invocation.pure(event, context))),
runtimeContext.getRemainingTimeInMillis().millis
)
.foreach { result =>
Expand Down
65 changes: 44 additions & 21 deletions lambda/jvm/src/test/scala/feral/lambda/IOLambdaJvmSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,19 @@ import java.util.concurrent.atomic.AtomicInteger

class IOLambdaJvmSuite extends FunSuite {

test("initializes handler once") {
implicit class HandleOps[A, B](lambda: IOLambda[A, B]) {
def handleRequestHelper(in: String): String = {
val os = new ByteArrayOutputStream
lambda.handleRequest(
new ByteArrayInputStream(in.getBytes()),
os,
DummyContext
)
new String(os.toByteArray())
}
}

test("initializes handler once during construction") {

val allocationCounter = new AtomicInteger
val invokeCounter = new AtomicInteger
Expand All @@ -41,18 +53,12 @@ class IOLambdaJvmSuite extends FunSuite {
.as(_.event.map(Some(_)) <* IO(invokeCounter.getAndIncrement()))
}

assertEquals(allocationCounter.get(), 1)

val chars = 'A' to 'Z'
chars.foreach { c =>
val os = new ByteArrayOutputStream

val json = s""""$c""""
lambda.handleRequest(
new ByteArrayInputStream(json.getBytes()),
os,
DummyContext
)

assertEquals(new String(os.toByteArray()), json)
assertEquals(lambda.handleRequestHelper(json), json)
}

assertEquals(allocationCounter.get(), 1)
Expand All @@ -68,21 +74,38 @@ class IOLambdaJvmSuite extends FunSuite {
def handler = Resource.pure(_ => IO(Some(output)))
}

val os = new ByteArrayOutputStream

lambda.handleRequest(
new ByteArrayInputStream(input.toString.getBytes()),
os,
DummyContext
)

assertEquals(
jawn.parseByteArray(os.toByteArray()),
Right(output),
new String(os.toByteArray())
jawn.parse(lambda.handleRequestHelper(input.noSpaces)),
Right(output)
)
}

test("gracefully handles broken initialization due to `val`") {

def go(mkLambda: AtomicInteger => IOLambda[Unit, Unit]): Unit = {
val counter = new AtomicInteger
val lambda = mkLambda(counter)
assertEquals(counter.get(), 0) // init failed
lambda.handleRequestHelper("{}")
assertEquals(counter.get(), 1) // inited
lambda.handleRequestHelper("{}")
assertEquals(counter.get(), 1) // did not re-init
}

go { counter =>
new IOLambda[Unit, Unit] {
val handler = Resource.eval(IO(counter.getAndIncrement())).as(_ => IO(None))
}
}

go { counter =>
new IOLambda[Unit, Unit] {
def handler = resource.as(_ => IO(None))
val resource = Resource.eval(IO(counter.getAndIncrement()))
}
}
}

object DummyContext extends runtime.Context {
override def getAwsRequestId(): String = ""
override def getLogGroupName(): String = ""
Expand Down

0 comments on commit fa51e2f

Please sign in to comment.