Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow global level Response headers #410

Merged
merged 2 commits into from
Feb 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions docs/features.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,25 @@ async def hello(request):
return "Hello World"
```

## Global Headers
## Global Request Headers

You can also add global headers for every request.

```python
app.add_request_header("server", "robyn")
```

## Global Response Headers

You can also add global response headers for every request.

```python
app.add_response_header("content-type", "application/json")
```

## Per route headers

You can also add headers for every route.
You can also add request and response headers for every route.

```python
@app.get("/request_headers")
Expand All @@ -194,6 +202,14 @@ async def request_headers():
}
```

```python
@app.get("/response_headers")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have similar endpoint for response_header as for request_header?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

async def response_headers():
return {
"headers": {"Header": "header_value"},
}
```

## Query Params

You can access query params from every HTTP method.
Expand Down
2 changes: 1 addition & 1 deletion integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ async def post(request):


if __name__ == "__main__":
app.add_request_header("server", "robyn")
app.add_response_header("server", "robyn")
app.add_directory(
route="/test_dir",
directory_path=os.path.join(current_file_path, "build"),
Expand Down
6 changes: 6 additions & 0 deletions integration_tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ def test_add_request_header():
assert app.request_headers == [Header(key="server", val="robyn")]


def test_add_response_header():
app = Robyn(__file__)
app.add_response_header("content-type", "application/json")
assert app.response_headers == [Header(key="content-type", val="application/json")]


def test_lifecycle_handlers():
def mock_startup_handler():
pass
Expand Down
6 changes: 6 additions & 0 deletions robyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, file_object: str) -> None:
self.middleware_router = MiddlewareRouter()
self.web_socket_router = WebSocketRouter()
self.request_headers: List[Header] = [] # This needs a better type
self.response_headers: List[Header] = [] # This needs a better type
self.directories: List[Directory] = []
self.event_handlers = {}
load_vars(project_root=directory_path)
Expand Down Expand Up @@ -83,6 +84,9 @@ def add_directory(
def add_request_header(self, key: str, value: str) -> None:
self.request_headers.append(Header(key, value))

def add_response_header(self, key: str, value: str) -> None:
self.response_headers.append(Header(key, value))

def add_web_socket(self, endpoint: str, ws: WS) -> None:
self.web_socket_router.add_route(endpoint, ws)

Expand Down Expand Up @@ -126,6 +130,7 @@ def start(self, url: str = "127.0.0.1", port: int = 8080):
self.event_handlers,
self.config.workers,
self.config.processes,
self.response_headers,
)
else:
event_handler = EventHandler(
Expand All @@ -139,6 +144,7 @@ def start(self, url: str = "127.0.0.1", port: int = 8080):
self.event_handlers,
self.config.workers,
self.config.processes,
self.response_headers,
)
event_handler.start_server()
logger.info(
Expand Down
3 changes: 3 additions & 0 deletions robyn/dev_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ def __init__(
event_handlers: Dict[Events, FunctionInfo],
workers: int,
processes: int,
response_headers: List[Header],
) -> None:
self.url = url
self.port = port
self.directories = directories
self.request_headers = request_headers
self.response_headers = response_headers
self.routes = routes
self.middlewares = middlewares
self.web_sockets = web_sockets
Expand All @@ -48,6 +50,7 @@ def start_server(self):
self.event_handlers,
self.n_workers,
self.n_processes,
self.response_headers,
True,
)

Expand Down
9 changes: 9 additions & 0 deletions robyn/processpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def run_processes(
event_handlers: Dict[Events, FunctionInfo],
workers: int,
processes: int,
response_headers: List[Header],
from_event_handler: bool = False,
) -> List[Process]:
socket = SocketHeld(url, port)
Expand All @@ -37,6 +38,7 @@ def run_processes(
socket,
workers,
processes,
response_headers,
)

if not from_event_handler:
Expand Down Expand Up @@ -66,6 +68,7 @@ def init_processpool(
socket: SocketHeld,
workers: int,
processes: int,
response_headers: List[Header],
) -> List[Process]:
process_pool = []
if sys.platform.startswith("win32"):
Expand All @@ -78,6 +81,7 @@ def init_processpool(
event_handlers,
socket,
workers,
response_headers,
)

return process_pool
Expand All @@ -95,6 +99,7 @@ def init_processpool(
event_handlers,
copied_socket,
workers,
response_headers,
),
)
process.start()
Expand Down Expand Up @@ -128,6 +133,7 @@ def spawn_process(
event_handlers: Dict[Events, FunctionInfo],
socket: SocketHeld,
workers: int,
response_headers: List[Header],
):
"""
This function is called by the main process handler to create a server runtime.
Expand Down Expand Up @@ -156,6 +162,9 @@ def spawn_process(
for header in request_headers:
server.add_request_header(*header.as_list())

for header in response_headers:
server.add_response_header(*header.as_list())

for route in routes:
route_type, endpoint, function, is_const = route
server.add_route(route_type, endpoint, function, is_const)
Expand Down
2 changes: 2 additions & 0 deletions robyn/robyn.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class Server:
pass
def add_request_header(self, key: str, value: str) -> None:
pass
def add_response_header(self, key: str, value: str) -> None:
pass
def add_route(
self,
route_type: str,
Expand Down
27 changes: 24 additions & 3 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub struct Server {
websocket_router: Arc<WebSocketRouter>,
middleware_router: Arc<MiddlewareRouter>,
global_request_headers: Arc<DashMap<String, String>>,
global_response_headers: Arc<DashMap<String, String>>,
directories: Arc<RwLock<Vec<Directory>>>,
startup_handler: Option<Arc<FunctionInfo>>,
shutdown_handler: Option<Arc<FunctionInfo>>,
Expand All @@ -63,6 +64,7 @@ impl Server {
websocket_router: Arc::new(WebSocketRouter::new()),
middleware_router: Arc::new(MiddlewareRouter::new()),
global_request_headers: Arc::new(DashMap::new()),
global_response_headers: Arc::new(DashMap::new()),
directories: Arc::new(RwLock::new(Vec::new())),
startup_handler: None,
shutdown_handler: None,
Expand Down Expand Up @@ -92,6 +94,7 @@ impl Server {
let middleware_router = self.middleware_router.clone();
let web_socket_router = self.websocket_router.clone();
let global_request_headers = self.global_request_headers.clone();
let global_response_headers = self.global_response_headers.clone();
let directories = self.directories.clone();
let workers = Arc::new(workers);

Expand Down Expand Up @@ -145,7 +148,8 @@ impl Server {
.app_data(web::Data::new(router.clone()))
.app_data(web::Data::new(const_router.clone()))
.app_data(web::Data::new(middleware_router.clone()))
.app_data(web::Data::new(global_request_headers.clone()));
.app_data(web::Data::new(global_request_headers.clone()))
.app_data(web::Data::new(global_response_headers.clone()));

let web_socket_map = web_socket_router.get_web_socket_map();
for (elem, value) in (web_socket_map.read().unwrap()).iter() {
Expand All @@ -165,6 +169,7 @@ impl Server {
const_router: web::Data<Arc<ConstRouter>>,
middleware_router: web::Data<Arc<MiddlewareRouter>>,
global_request_headers,
global_response_headers,
body,
req| {
pyo3_asyncio::tokio::scope_local(task_locals.clone(), async move {
Expand All @@ -173,6 +178,7 @@ impl Server {
const_router,
middleware_router,
global_request_headers,
global_response_headers,
body,
req,
)
Expand Down Expand Up @@ -223,19 +229,32 @@ impl Server {
});
}

/// Adds a new header to our concurrent hashmap
/// Adds a new request header to our concurrent hashmap
/// this can be called after the server has started.
pub fn add_request_header(&self, key: &str, value: &str) {
self.global_request_headers
.insert(key.to_string(), value.to_string());
}

/// Removes a new header to our concurrent hashmap
/// Adds a new response header to our concurrent hashmap
/// this can be called after the server has started.
pub fn add_response_header(&self, key: &str, value: &str) {
self.global_response_headers
.insert(key.to_string(), value.to_string());
}

/// Removes a new request header to our concurrent hashmap
/// this can be called after the server has started.
pub fn remove_header(&self, key: &str) {
self.global_request_headers.remove(key);
}

/// Removes a new response header to our concurrent hashmap
/// this can be called after the server has started.
pub fn remove_response_header(&self, key: &str) {
self.global_response_headers.remove(key);
}

/// Add a new route to the routing tables
/// can be called after the server has been started
pub fn add_route(
Expand Down Expand Up @@ -345,6 +364,7 @@ async fn index(
const_router: web::Data<Arc<ConstRouter>>,
middleware_router: web::Data<Arc<MiddlewareRouter>>,
global_request_headers: web::Data<Arc<Headers>>,
global_response_headers: web::Data<Arc<Headers>>,
body: Bytes,
req: HttpRequest,
) -> impl Responder {
Expand All @@ -360,6 +380,7 @@ async fn index(

let mut response_builder = HttpResponse::Ok();
apply_dashmap_headers(&mut response_builder, &global_request_headers);
apply_dashmap_headers(&mut response_builder, &global_response_headers);
apply_hashmap_headers(&mut response_builder, &request.headers);

let response = if let Some(r) = const_router.get_route(req.method(), req.uri().path()) {
Expand Down
1 change: 1 addition & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ pub struct Response {

#[pymethods]
impl Response {
Copy link
Contributor Author

@ParthS007 ParthS007 Feb 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sansyrox We need to make change here to check for response headers and create the Response accordingly?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ParthS007 , in the future yes. But would we need that for now too?

I suppose you are suggesting for things like json content?

// To do: Add check for content-type in header and change response_type accordingly
#[new]
pub fn new(status_code: u16, headers: HashMap<String, String>, body: &PyAny) -> PyResult<Self> {
Ok(Self {
Expand Down