diff --git a/src/spin_sdk/http/__init__.py b/src/spin_sdk/http/__init__.py index 3292d8e..19d5a0b 100644 --- a/src/spin_sdk/http/__init__.py +++ b/src/spin_sdk/http/__init__.py @@ -4,7 +4,6 @@ import traceback from spin_sdk.http import poll_loop from spin_sdk.http.poll_loop import PollLoop, Sink, Stream -from spin_sdk.wit import exports from spin_sdk.wit.types import Ok, Err from spin_sdk.wit.imports.types import ( IncomingResponse, Method, Method_Get, Method_Head, Method_Post, Method_Put, Method_Delete, Method_Connect, Method_Options, @@ -32,96 +31,105 @@ class Response: headers: MutableMapping[str, str] body: Optional[bytes] -class IncomingHandler(exports.IncomingHandler): - """Simplified handler for incoming HTTP requests using blocking, buffered I/O.""" +try: + from spin_sdk.wit import exports + from spin_sdk.wit.exports import IncomingHandler as Base - def handle_request(self, request: Request) -> Response: - """Handle an incoming HTTP request and return a response or raise an error""" - raise NotImplementedError + class IncomingHandler(Base): + """Simplified handler for incoming HTTP requests using blocking, buffered I/O.""" + + def handle_request(self, request: Request) -> Response: + """Handle an incoming HTTP request and return a response or raise an error""" + raise NotImplementedError + + def handle(self, request: IncomingRequest, response_out: ResponseOutparam): + method = request.method() + + if isinstance(method, Method_Get): + method_str = "GET" + elif isinstance(method, Method_Head): + method_str = "HEAD" + elif isinstance(method, Method_Post): + method_str = "POST" + elif isinstance(method, Method_Put): + method_str = "PUT" + elif isinstance(method, Method_Delete): + method_str = "DELETE" + elif isinstance(method, Method_Connect): + method_str = "CONNECT" + elif isinstance(method, Method_Options): + method_str = "OPTIONS" + elif isinstance(method, Method_Trace): + method_str = "TRACE" + elif isinstance(method, Method_Patch): + method_str = "PATCH" + elif isinstance(method, Method_Other): + method_str = method.value + else: + raise AssertionError + + request_body = request.consume() + request_stream = request_body.stream() + body = bytearray() + while True: + try: + body += request_stream.blocking_read(16 * 1024) + except Err as e: + if isinstance(e.value, StreamError_Closed): + request_stream.__exit__() + IncomingBody.finish(request_body) + break + else: + raise e + + request_uri = request.path_with_query() + if request_uri is None: + uri = "/" + else: + uri = request_uri - def handle(self, request: IncomingRequest, response_out: ResponseOutparam): - method = request.method() - - if isinstance(method, Method_Get): - method_str = "GET" - elif isinstance(method, Method_Head): - method_str = "HEAD" - elif isinstance(method, Method_Post): - method_str = "POST" - elif isinstance(method, Method_Put): - method_str = "PUT" - elif isinstance(method, Method_Delete): - method_str = "DELETE" - elif isinstance(method, Method_Connect): - method_str = "CONNECT" - elif isinstance(method, Method_Options): - method_str = "OPTIONS" - elif isinstance(method, Method_Trace): - method_str = "TRACE" - elif isinstance(method, Method_Patch): - method_str = "PATCH" - elif isinstance(method, Method_Other): - method_str = method.value - else: - raise AssertionError - - request_body = request.consume() - request_stream = request_body.stream() - body = bytearray() - while True: try: - body += request_stream.blocking_read(16 * 1024) - except Err as e: - if isinstance(e.value, StreamError_Closed): - request_stream.__exit__() - IncomingBody.finish(request_body) - break - else: - raise e - - request_uri = request.path_with_query() - if request_uri is None: - uri = "/" - else: - uri = request_uri - - try: - simple_response = self.handle_request(Request( - method_str, - uri, - dict(map(lambda pair: (pair[0], str(pair[1], "utf-8")), request.headers().entries())), - bytes(body) - )) - except: - traceback.print_exc() - - response = OutgoingResponse(Fields()) - response.set_status_code(500) - ResponseOutparam.set(response_out, Ok(response)) - return - - if simple_response.headers.get('content-length') is None: - content_length = len(simple_response.body) if simple_response.body is not None else 0 - simple_response.headers['content-length'] = str(content_length) - - response = OutgoingResponse(Fields.from_list(list(map( - lambda pair: (pair[0], bytes(pair[1], "utf-8")), - simple_response.headers.items() - )))) - response_body = response.body() - response.set_status_code(simple_response.status) - ResponseOutparam.set(response_out, Ok(response)) - response_stream = response_body.write() - if simple_response.body is not None: - MAX_BLOCKING_WRITE_SIZE = 4096 - offset = 0 - while offset < len(simple_response.body): - count = min(len(simple_response.body) - offset, MAX_BLOCKING_WRITE_SIZE) - response_stream.blocking_write_and_flush(simple_response.body[offset:offset+count]) - offset += count - response_stream.__exit__() - OutgoingBody.finish(response_body, None) + simple_response = self.handle_request(Request( + method_str, + uri, + dict(map(lambda pair: (pair[0], str(pair[1], "utf-8")), request.headers().entries())), + bytes(body) + )) + except: + traceback.print_exc() + + response = OutgoingResponse(Fields()) + response.set_status_code(500) + ResponseOutparam.set(response_out, Ok(response)) + return + if simple_response.headers.get('content-length') is None: + content_length = len(simple_response.body) if simple_response.body is not None else 0 + simple_response.headers['content-length'] = str(content_length) + + response = OutgoingResponse(Fields.from_list(list(map( + lambda pair: (pair[0], bytes(pair[1], "utf-8")), + simple_response.headers.items() + )))) + response_body = response.body() + response.set_status_code(simple_response.status) + ResponseOutparam.set(response_out, Ok(response)) + response_stream = response_body.write() + if simple_response.body is not None: + MAX_BLOCKING_WRITE_SIZE = 4096 + offset = 0 + while offset < len(simple_response.body): + count = min(len(simple_response.body) - offset, MAX_BLOCKING_WRITE_SIZE) + response_stream.blocking_write_and_flush(simple_response.body[offset:offset+count]) + offset += count + response_stream.__exit__() + OutgoingBody.finish(response_body, None) + +except ImportError: + # `spin_sdk.wit.exports` won't exist if the use is targeting `spin-imports`, + # so just skip this part + pass + def send(request: Request) -> Response: """Send an HTTP request and return a response or raise an error""" loop = PollLoop()