Skip to content

Commit

Permalink
Merge pull request #2012 from natsukagami/async-curl-backends
Browse files Browse the repository at this point in the history
Make AbstractCurlBackend async friendly
  • Loading branch information
adamw committed Nov 30, 2023
2 parents 263ffb3 + 6f99ee5 commit 7c3b815
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import sttp.monad.MonadError
import sttp.monad.syntax._

import scala.collection.immutable.Seq
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import scala.scalanative.libc.stdio.{fclose, fopen, FILE}
import scala.scalanative.libc.stdlib._
Expand All @@ -25,33 +26,68 @@ import scala.scalanative.unsigned._
abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean) extends GenericBackend[F, Any] {
override implicit def monad: MonadError[F] = _monad

/** Given a [[CurlHandle]], perform the request and return a [[CurlCode]]. */
protected def performCurl(c: CurlHandle): F[CurlCode]

/** Same as [[performCurl]], but also checks and throws runtime exceptions on bad [[CurlCode]]s. */
private final def perform(c: CurlHandle) = performCurl(c).flatMap(lift)

type R = Any with Effect[F]

override def close(): F[Unit] = monad.unit(())

private var headers: CurlList = _
private var multiPartHeaders: Seq[CurlList] = Seq()
/** A request-specific context, with allocated zones and headers. */
private class Context() {
implicit val zone: Zone = Zone.open()
private val headers = ArrayBuffer[CurlList]()

/** Create a new Headers list that gets cleaned up when the context is destroyed. */
def transformHeaders(reqHeaders: Iterable[Header]): CurlList = {
val h = reqHeaders
.map(header => s"${header.name}: ${header.value}")
.foldLeft(new CurlList(null)) { case (acc, h) =>
new CurlList(acc.ptr.append(h))
}
headers += h
h
}

def close() = {
zone.close()
headers.foreach(l => if (l.ptr != null) l.ptr.free())
}
}

private object Context {

/** Create a new context and evaluates the body with it. Closes the context at the end. */
def evaluateUsing[T](body: Context => F[T]): F[T] = {
implicit val ctx = new Context()
body(ctx).ensure(monad.unit(ctx.close()))
}
}

override def send[T](request: GenericRequest[T, R]): F[Response[T]] =
adjustExceptions(request) {
unsafe.Zone { implicit z =>
def perform(implicit ctx: Context): F[Response[T]] = {
implicit val z = ctx.zone
val curl = CurlApi.init
if (verbose) {
curl.option(Verbose, parameter = true)
}
if (request.tags.nonEmpty) {
monad.error(new UnsupportedOperationException("Tags are not supported"))
return monad.error(new UnsupportedOperationException("Tags are not supported"))
}
val reqHeaders = request.headers
if (reqHeaders.nonEmpty) {
reqHeaders.find(_.name == "Accept-Encoding").foreach(h => curl.option(AcceptEncoding, h.value))
request.body match {
val headers = request.body match {
case _: MultipartBody[_] =>
headers = transformHeaders(
ctx.transformHeaders(
reqHeaders :+ Header.contentType(MediaType.MultipartFormData)
)
case _ =>
headers = transformHeaders(reqHeaders)
ctx.transformHeaders(reqHeaders)
}
curl.option(HttpHeader, headers.ptr)
}
Expand All @@ -62,6 +98,8 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean
case None => handleBase(request, curl, spaces)
}
}

Context.evaluateUsing(ctx => perform(ctx))
}

private def adjustExceptions[T](request: GenericRequest[_, _])(t: => F[T]): F[T] =
Expand All @@ -70,22 +108,21 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean
)

private def handleBase[T](request: GenericRequest[T, R], curl: CurlHandle, spaces: CurlSpaces)(implicit
z: unsafe.Zone
ctx: Context
) = {
implicit val z = ctx.zone
curl.option(WriteFunction, AbstractCurlBackend.wdFunc)
curl.option(WriteData, spaces.bodyResp)
curl.option(TimeoutMs, request.options.readTimeout.toMillis)
curl.option(HeaderData, spaces.headersResp)
curl.option(Url, request.uri.toString)
setMethod(curl, request.method)
setRequestBody(curl, request.body)
monad.flatMap(lift(curl.perform)) { _ =>
monad.flatMap(perform(curl)) { _ =>
curl.info(ResponseCode, spaces.httpCode)
val responseBody = fromCString((!spaces.bodyResp)._1)
val responseHeaders_ = parseHeaders(fromCString((!spaces.headersResp)._1))
val httpCode = StatusCode((!spaces.httpCode).toInt)
if (headers.ptr != null) headers.ptr.free()
multiPartHeaders.foreach(_.ptr.free())
free((!spaces.bodyResp)._1)
free((!spaces.headersResp)._1)
free(spaces.bodyResp.asInstanceOf[Ptr[CSignedChar]])
Expand All @@ -112,19 +149,18 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean
}

private def handleFile[T](request: GenericRequest[T, R], curl: CurlHandle, file: SttpFile, spaces: CurlSpaces)(
implicit z: unsafe.Zone
implicit ctx: Context
) = {
implicit val z = ctx.zone
val outputPath = file.toPath.toString
val outputFilePtr: Ptr[FILE] = fopen(toCString(outputPath), toCString("wb"))
curl.option(WriteData, outputFilePtr)
curl.option(Url, request.uri.toString)
setMethod(curl, request.method)
setRequestBody(curl, request.body)
monad.flatMap(lift(curl.perform)) { _ =>
monad.flatMap(perform(curl)) { _ =>
curl.info(ResponseCode, spaces.httpCode)
val httpCode = StatusCode((!spaces.httpCode).toInt)
if (headers.ptr != null) headers.ptr.free()
multiPartHeaders.foreach(_.ptr.free())
free(spaces.httpCode.asInstanceOf[Ptr[CSignedChar]])
fclose(outputFilePtr)
curl.cleanup()
Expand Down Expand Up @@ -159,7 +195,10 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean
lift(m)
}

private def setRequestBody(curl: CurlHandle, body: GenericRequestBody[R])(implicit zone: Zone): F[CurlCode] =
private def setRequestBody(curl: CurlHandle, body: GenericRequestBody[R])(implicit
ctx: Context
): F[CurlCode] = {
implicit val z = ctx.zone
body match { // todo: assign to monad object
case b: BasicBodyPart =>
val str = basicBodyToString(b)
Expand All @@ -176,9 +215,8 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean

val otherHeaders = headers.filterNot(_.is(HeaderNames.ContentType))
if (otherHeaders.nonEmpty) {
val curlList = transformHeaders(otherHeaders)
val curlList = ctx.transformHeaders(otherHeaders)
part.withHeaders(curlList.ptr)
multiPartHeaders = multiPartHeaders :+ curlList
}
}
lift(curl.option(Mimepost, mime))
Expand All @@ -187,6 +225,7 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean
case NoBody =>
monad.unit(CurlCode.Ok)
}
}

private def basicBodyToString(body: BodyPart[_]): String =
body match {
Expand Down Expand Up @@ -253,13 +292,6 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean
override protected def cleanupWhenGotWebSocket(response: Nothing, e: GotAWebSocketException): F[Unit] = response
}

private def transformHeaders(reqHeaders: Iterable[Header])(implicit z: Zone): CurlList =
reqHeaders
.map(header => s"${header.name}: ${header.value}")
.foldLeft(new CurlList(null)) { case (acc, h) =>
new CurlList(acc.ptr.append(h))
}

private def toByteArray(str: String): F[Array[Byte]] = monad.unit(str.getBytes)

private def lift(code: CurlCode): F[CurlCode] =
Expand All @@ -269,6 +301,12 @@ abstract class AbstractCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean
}
}

/** Curl backends that performs the curl operation with a simple `curl_easy_perform`. */
abstract class AbstractSyncCurlBackend[F[_]](_monad: MonadError[F], verbose: Boolean)
extends AbstractCurlBackend[F](_monad, verbose) {
override def performCurl(c: CurlHandle): F[CurlCode.CurlCode] = monad.unit(c.perform)
}

object AbstractCurlBackend {
val wdFunc: CFuncPtr4[Ptr[Byte], CSize, CSize, Ptr[CurlFetch], CSize] = {
(ptr: Ptr[CChar], size: CSize, nmemb: CSize, data: Ptr[CurlFetch]) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import scala.util.Try

// Curl supports redirects, but it doesn't store the history, so using FollowRedirectsBackend is more convenient

private class CurlBackend(verbose: Boolean) extends AbstractCurlBackend(IdMonad, verbose) with SyncBackend {}
private class CurlBackend(verbose: Boolean) extends AbstractSyncCurlBackend(IdMonad, verbose) with SyncBackend {}

object CurlBackend {
def apply(verbose: Boolean = false): SyncBackend = FollowRedirectsBackend(new CurlBackend(verbose))
}

private class CurlTryBackend(verbose: Boolean) extends AbstractCurlBackend(TryMonad, verbose) with Backend[Try] {}
private class CurlTryBackend(verbose: Boolean) extends AbstractSyncCurlBackend(TryMonad, verbose) with Backend[Try] {}

object CurlTryBackend {
def apply(verbose: Boolean = false): Backend[Try] = FollowRedirectsBackend(new CurlTryBackend(verbose))
Expand Down

0 comments on commit 7c3b815

Please sign in to comment.