diff --git a/twirp/asgi.py b/twirp/asgi.py index 89a9c3d..de97317 100644 --- a/twirp/asgi.py +++ b/twirp/asgi.py @@ -72,6 +72,7 @@ async def __call__(self, scope, receive, send): self._hook.response_prepared(ctx=ctx) body_bytes, headers = encoder(response_data) + headers = dict(ctx.get_response_headers(), **headers) # Todo: middleware await self._respond( send=send, diff --git a/twirp/context.py b/twirp/context.py index fb38524..6596f52 100644 --- a/twirp/context.py +++ b/twirp/context.py @@ -19,6 +19,7 @@ def __init__(self, *args, logger = None, headers = None): if headers is None: headers = {} self._headers = headers + self._response_headers = {} def set(self, key, value): """Set a Context value @@ -50,7 +51,7 @@ def set_logger(self, logger): self._logger = logger def get_headers(self): - """Get headers that are currently stored.""" + """Get request headers that are currently stored.""" return self._headers def set_header(self, key, value): @@ -61,3 +62,16 @@ def set_header(self, key, value): value: Value for the header. """ self._headers[key] = value + + def get_response_headers(self): + """Get response headers that are currently stored.""" + return self._response_headers + + def set_response_header(self, key, value): + """Set a response header + + Arguments: + key: Key for the header. + value: Value for the header. + """ + self._response_headers[key] = value