Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement workaround for initialization order errors in JVM lambdas #448

Merged
merged 3 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ lazy val lambda = crossProject(JSPlatform, JVMPlatform)
)
)
.jvmSettings(
Test / fork := true,
libraryDependencies ++= Seq(
"com.amazonaws" % "aws-lambda-java-core" % "1.2.3",
"co.fs2" %%% "fs2-io" % fs2Version
Expand Down
25 changes: 24 additions & 1 deletion lambda/jvm/src/main/scala/feral/lambda/IOLambdaPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@

package feral.lambda

import cats.effect.Async
import cats.effect.IO
import cats.effect.Resource
import cats.effect.std.Dispatcher
import cats.effect.syntax.all._
import cats.syntax.all._
import com.amazonaws.services.lambda.{runtime => lambdaRuntime}
import io.circe.Printer
Expand All @@ -29,11 +32,31 @@ import java.io.OutputStream
import java.io.OutputStreamWriter
import java.nio.channels.Channels
import scala.concurrent.duration._
import scala.util.control.NonFatal

private[lambda] abstract class IOLambdaPlatform[Event, Result]
extends lambdaRuntime.RequestStreamHandler { this: IOLambda[Event, Result] =>

private[this] val (dispatcher, handle) = {
val handler = {
val h =
try this.handler
catch { case ex if NonFatal(ex) => null }

if (h ne null) {
h.map(IO.pure(_))
} else {
val lambdaName = getClass().getSimpleName()
val msg =
s"""|There was an error initializing `$lambdaName` during startup.
|Falling back to initialize-during-first-invocation strategy.
|To fix, try replacing any `val`s in `$lambdaName` with `def`s.""".stripMargin
System.err.println(msg)

Async[Resource[IO, *]].defer(this.handler).memoize.map(_.allocated.map(_._1))
}
}

Dispatcher
.parallel[IO](await = false)
.product(handler)
Expand All @@ -50,7 +73,7 @@ private[lambda] abstract class IOLambdaPlatform[Event, Result]
val context = Context.fromJava[IO](runtimeContext)
dispatcher
.unsafeRunTimed(
handle(Invocation.pure(event, context)),
handle.flatMap(_(Invocation.pure(event, context))),
runtimeContext.getRemainingTimeInMillis().millis
)
.foreach { result =>
Expand Down
65 changes: 44 additions & 21 deletions lambda/jvm/src/test/scala/feral/lambda/IOLambdaJvmSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,19 @@ import java.util.concurrent.atomic.AtomicInteger

class IOLambdaJvmSuite extends FunSuite {

test("initializes handler once") {
implicit class HandleOps[A, B](lambda: IOLambda[A, B]) {
def handleRequestHelper(in: String): String = {
val os = new ByteArrayOutputStream
lambda.handleRequest(
new ByteArrayInputStream(in.getBytes()),
os,
DummyContext
)
new String(os.toByteArray())
}
}

test("initializes handler once during construction") {

val allocationCounter = new AtomicInteger
val invokeCounter = new AtomicInteger
Expand All @@ -41,18 +53,12 @@ class IOLambdaJvmSuite extends FunSuite {
.as(_.event.map(Some(_)) <* IO(invokeCounter.getAndIncrement()))
}

assertEquals(allocationCounter.get(), 1)

val chars = 'A' to 'Z'
chars.foreach { c =>
val os = new ByteArrayOutputStream

val json = s""""$c""""
lambda.handleRequest(
new ByteArrayInputStream(json.getBytes()),
os,
DummyContext
)

assertEquals(new String(os.toByteArray()), json)
assertEquals(lambda.handleRequestHelper(json), json)
}

assertEquals(allocationCounter.get(), 1)
Expand All @@ -68,21 +74,38 @@ class IOLambdaJvmSuite extends FunSuite {
def handler = Resource.pure(_ => IO(Some(output)))
}

val os = new ByteArrayOutputStream

lambda.handleRequest(
new ByteArrayInputStream(input.toString.getBytes()),
os,
DummyContext
)

assertEquals(
jawn.parseByteArray(os.toByteArray()),
Right(output),
new String(os.toByteArray())
jawn.parse(lambda.handleRequestHelper(input.noSpaces)),
Right(output)
)
}

test("gracefully handles broken initialization due to `val`") {

def go(mkLambda: AtomicInteger => IOLambda[Unit, Unit]): Unit = {
val counter = new AtomicInteger
val lambda = mkLambda(counter)
assertEquals(counter.get(), 0) // init failed
lambda.handleRequestHelper("{}")
assertEquals(counter.get(), 1) // inited
lambda.handleRequestHelper("{}")
assertEquals(counter.get(), 1) // did not re-init
}

go { counter =>
new IOLambda[Unit, Unit] {
val handler = Resource.eval(IO(counter.getAndIncrement())).as(_ => IO(None))
}
}

go { counter =>
new IOLambda[Unit, Unit] {
def handler = resource.as(_ => IO(None))
val resource = Resource.eval(IO(counter.getAndIncrement()))
}
}
}

object DummyContext extends runtime.Context {
override def getAwsRequestId(): String = ""
override def getLogGroupName(): String = ""
Expand Down