diff --git a/dlls/http.sys/http.c b/dlls/http.sys/http.c index 6c71a5340695..4d2e10c90d9a 100644 --- a/dlls/http.sys/http.c +++ b/dlls/http.sys/http.c @@ -83,12 +83,21 @@ struct connection static struct list connections = LIST_INIT(connections); +struct listening_socket +{ + struct list entry; + unsigned short port; + SOCKET socket; +}; + +static struct list listening_sockets = LIST_INIT(listening_sockets); + struct url { struct list entry; char *url; HTTP_URL_CONTEXT context; - SOCKET socket; + struct listening_socket *listening_sock; }; struct request_queue @@ -326,30 +335,107 @@ static int parse_number(const char *str, const char **endptr, const char *end) return n; } -static struct url *host_matches(const struct connection *conn, const struct request_queue *queue) + +/* 0 means not a match, 1 and higher means a match, higher means more paths matched. */ +static unsigned int compare_paths(const char *queue_path, const char *conn_path, size_t conn_len) { - const char *conn_host = (conn->url[0] == '/') ? conn->host : conn->url + 7; - struct url *url; + const char *question_mark; + unsigned int i, cnt = 1; + size_t queue_len; + + queue_len = strlen(queue_path); + + if ((question_mark = memchr(conn_path, '?', conn_len))) + conn_len = question_mark - conn_path; + + if (queue_path[queue_len - 1] == '/') + queue_len--; + if (conn_path[conn_len - 1] == '/') + conn_len--; + + if (conn_len < queue_len) + return 0; + + for (i = 0; i < queue_len; ++i) + { + if (queue_path[i] != conn_path[i]) + return 0; + if (queue_path[i] == '/') + cnt++; + } + + if (queue_len == conn_len || conn_path[queue_len] == '/') + return cnt; + else + return 0; +} + +static BOOL host_matches(const struct url *url, const char *conn_host) +{ + size_t host_len; + + if (!url->url) + return FALSE; + + if (url->url[7] == '+') + { + const char *queue_port = strchr(url->url + 7, ':'); + host_len = strchr(queue_port, '/') - queue_port - 1; + if (!strncmp(queue_port, strchr(conn_host, ':'), host_len)) + return TRUE; + } + else + { + host_len = strchr(url->url + 7, '/') - url->url - 7; + if (!memicmp(url->url + 7, conn_host, host_len)) + return TRUE; + } + + return FALSE; +} + +static struct url *url_matches(const struct connection *conn, const struct request_queue *queue, + unsigned int *ret_slash_count) +{ + const char *queue_path, *conn_host, *conn_path; + unsigned int max_slash_count = 0, slash_count; + size_t conn_path_len; + struct url *url, *ret = NULL; + + if (conn->url[0] == '/') + { + conn_host = conn->host; + conn_path = conn->url; + conn_path_len = conn->url_len; + + } + else + { + conn_host = conn->url + 7; + conn_path = strchr(conn_host, '/'); + conn_path_len = (conn->url + conn->url_len) - conn_path; + } LIST_FOR_EACH_ENTRY(url, &queue->urls, struct url, entry) { - if (url->url) + if (host_matches(url, conn_host)) { - if (url->url && url->url[7] == '+') + queue_path = strchr(url->url + 7, '/'); + if (!queue_path) + continue; + slash_count = compare_paths(queue_path, conn_path, conn_path_len); + if (slash_count > max_slash_count) { - /* strip final slash */ - const char *queue_port = strchr(url->url + 7, ':'); - if (!strncmp(queue_port, strchr(conn_host, ':'), strlen(queue_port) - 1)) - return url; + max_slash_count = slash_count; + ret = url; } - - /* strip final slash */ - if (!memicmp(url->url + 7, conn_host, strlen(url->url) - 8)) - return url; } } - return NULL; + if (ret_slash_count) + *ret_slash_count = max_slash_count; + + return ret; } /* Upon receiving a request, parse it to ensure that it is a valid HTTP request, @@ -358,9 +444,10 @@ static struct url *host_matches(const struct connection *conn, const struct requ static int parse_request(struct connection *conn) { const char *const req = conn->buffer, *const end = conn->buffer + conn->len; - struct request_queue *queue; - struct url *conn_url; + struct request_queue *queue, *best_queue = NULL; + struct url *conn_url, *best_conn_url = NULL; const char *p = req, *q; + unsigned int slash_count, best_slash_count = 0; int len, ret; if (!conn->len) return 0; @@ -453,15 +540,24 @@ static int parse_request(struct connection *conn) /* Find a queue which can receive this request. */ LIST_FOR_EACH_ENTRY(queue, &request_queues, struct request_queue, entry) { - if ((conn_url = host_matches(conn, queue))) + if ((conn_url = url_matches(conn, queue, &slash_count))) { - TRACE("Assigning request to queue %p.\n", queue); - conn->queue = queue; - conn->context = conn_url->context; - break; + if (slash_count > best_slash_count) + { + best_slash_count = slash_count; + best_queue = queue; + best_conn_url = conn_url; + } } } + if (best_conn_url) + { + TRACE("Assigning request to queue %p.\n", best_queue); + conn->queue = best_queue; + conn->context = best_conn_url->context; + } + /* Stop selecting on incoming data until a response is queued. */ WSAEventSelect(conn->socket, request_event, FD_CLOSE); @@ -578,8 +674,8 @@ static DWORD WINAPI request_thread_proc(void *arg) { LIST_FOR_EACH_ENTRY(url, &queue->urls, struct url, entry) { - if (url->socket != -1) - accept_connection(url->socket); + if (url->listening_sock && url->listening_sock->socket != -1) + accept_connection(url->listening_sock->socket); } } @@ -596,17 +692,31 @@ static DWORD WINAPI request_thread_proc(void *arg) return 0; } +static struct listening_socket *get_listening_socket(unsigned short port) +{ + struct listening_socket *listening_sock; + + LIST_FOR_EACH_ENTRY(listening_sock, &listening_sockets, struct listening_socket, entry) + { + if (listening_sock->port == port) + return listening_sock; + } + + return NULL; +} + static NTSTATUS http_add_url(struct request_queue *queue, IRP *irp) { const struct http_add_url_params *params = irp->AssociatedIrp.SystemBuffer; + struct request_queue *queue_entry; struct sockaddr_in addr; struct connection *conn; struct url *url_entry, *new_entry; - unsigned int count = 0; + struct listening_socket *listening_sock; char *url, *endptr; + size_t queue_url_len, new_url_len; ULONG true = 1; - const char *p; - SOCKET s; + SOCKET s = INVALID_SOCKET; TRACE("host %s, context %s.\n", debugstr_a(params->url), wine_dbgstr_longlong(params->context)); @@ -615,8 +725,7 @@ static NTSTATUS http_add_url(struct request_queue *queue, IRP *irp) FIXME("HTTPS is not implemented.\n"); return STATUS_NOT_IMPLEMENTED; } - else if (strncmp(params->url, "http://", 7) || !strchr(params->url + 7, ':') - || params->url[strlen(params->url) - 1] != '/') + else if (strncmp(params->url, "http://", 7) || !strchr(params->url + 7, ':')) return STATUS_INVALID_PARAMETER; if (!(addr.sin_port = htons(strtol(strchr(params->url + 7, ':') + 1, &endptr, 10))) || *endptr != '/') return STATUS_INVALID_PARAMETER; @@ -625,82 +734,106 @@ static NTSTATUS http_add_url(struct request_queue *queue, IRP *irp) return STATUS_NO_MEMORY; strcpy(url, params->url); - for (p = url; *p; ++p) - if (*p == '/') ++count; - if (count > 3) - FIXME("Binding to relative URIs is not implemented; binding to all URIs instead.\n"); - if (!(new_entry = malloc(sizeof(struct url)))) { free(url); return STATUS_NO_MEMORY; } + new_url_len = strlen(url); + if (url[new_url_len - 1] == '/') + new_url_len--; + EnterCriticalSection(&http_cs); - LIST_FOR_EACH_ENTRY(url_entry, &queue->urls, struct url, entry) + LIST_FOR_EACH_ENTRY(queue_entry, &request_queues, struct request_queue, entry) { - if (url_entry->url && !strcmp(url_entry->url, url)) + LIST_FOR_EACH_ENTRY(url_entry, &queue_entry->urls, struct url, entry) { + queue_url_len = strlen(url_entry->url); + if (url_entry->url[queue_url_len - 1] == '/') + queue_url_len--; + + if (url_entry->url && queue_url_len == new_url_len && !memcmp(url_entry->url, url, queue_url_len)) + { + LeaveCriticalSection(&http_cs); + free(url); + free(new_entry); + return STATUS_OBJECT_NAME_COLLISION; + } + } + } + + listening_sock = get_listening_socket(addr.sin_port); + + if (!listening_sock) + { + if ((s = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET) + { + ERR("Failed to create socket, error %u.\n", WSAGetLastError()); LeaveCriticalSection(&http_cs); free(url); free(new_entry); - return STATUS_OBJECT_NAME_COLLISION; + return STATUS_UNSUCCESSFUL; } - } - if ((s = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET) - { - ERR("Failed to create socket, error %u.\n", WSAGetLastError()); - LeaveCriticalSection(&http_cs); - free(url); - free(new_entry); - return STATUS_UNSUCCESSFUL; - } + addr.sin_family = AF_INET; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + if (bind(s, (struct sockaddr *)&addr, sizeof(addr)) == -1) + { + LeaveCriticalSection(&http_cs); + closesocket(s); + free(url); + free(new_entry); + if (WSAGetLastError() == WSAEADDRINUSE) + { + WARN("Address %s is already in use.\n", debugstr_a(params->url)); + return STATUS_SHARING_VIOLATION; + } + else if (WSAGetLastError() == WSAEACCES) + { + WARN("Not enough permissions to bind to address %s.\n", debugstr_a(params->url)); + return STATUS_ACCESS_DENIED; + } + ERR("Failed to bind socket, error %u.\n", WSAGetLastError()); + return STATUS_UNSUCCESSFUL; + } - addr.sin_family = AF_INET; - addr.sin_addr.S_un.S_addr = INADDR_ANY; - if (bind(s, (struct sockaddr *)&addr, sizeof(addr)) == -1) - { - LeaveCriticalSection(&http_cs); - closesocket(s); - free(url); - free(new_entry); - if (WSAGetLastError() == WSAEADDRINUSE) + if (listen(s, SOMAXCONN) == -1) { - WARN("Address %s is already in use.\n", debugstr_a(params->url)); - return STATUS_SHARING_VIOLATION; + ERR("Failed to listen to port %u, error %u.\n", addr.sin_port, WSAGetLastError()); + LeaveCriticalSection(&http_cs); + closesocket(s); + free(url); + free(new_entry); + return STATUS_OBJECT_NAME_COLLISION; } - else if (WSAGetLastError() == WSAEACCES) + + if (!(listening_sock = malloc(sizeof(struct listening_socket)))) { - WARN("Not enough permissions to bind to address %s.\n", debugstr_a(params->url)); - return STATUS_ACCESS_DENIED; + LeaveCriticalSection(&http_cs); + closesocket(s); + free(url); + free(new_entry); + return STATUS_NO_MEMORY; } - ERR("Failed to bind socket, error %u.\n", WSAGetLastError()); - return STATUS_UNSUCCESSFUL; - } + listening_sock->port = addr.sin_port; + listening_sock->socket = s; + list_add_head(&listening_sockets, &listening_sock->entry); - if (listen(s, SOMAXCONN) == -1) - { - ERR("Failed to listen to port %u, error %u.\n", addr.sin_port, WSAGetLastError()); - LeaveCriticalSection(&http_cs); - closesocket(s); - free(url); - free(new_entry); - return STATUS_OBJECT_NAME_COLLISION; + ioctlsocket(s, FIONBIO, &true); + WSAEventSelect(s, request_event, FD_ACCEPT); } - ioctlsocket(s, FIONBIO, &true); - WSAEventSelect(s, request_event, FD_ACCEPT); new_entry->url = url; new_entry->context = params->context; - new_entry->socket = s; + new_entry->listening_sock = listening_sock; list_add_head(&queue->urls, &new_entry->entry); /* See if any pending requests now match this queue. */ LIST_FOR_EACH_ENTRY(conn, &connections, struct connection, entry) { - if (conn->available && !conn->queue && host_matches(conn, queue)) + if (conn->available && !conn->queue && url_matches(conn, queue, NULL)) { conn->queue = queue; conn->context = params->context; @@ -713,6 +846,25 @@ static NTSTATUS http_add_url(struct request_queue *queue, IRP *irp) return STATUS_SUCCESS; } +static BOOL is_listening_socket_used(const struct listening_socket *listening_sock) +{ + struct request_queue *queue_entry; + struct url *url_entry; + + LIST_FOR_EACH_ENTRY(queue_entry, &request_queues, struct request_queue, entry) + { + LIST_FOR_EACH_ENTRY(url_entry, &queue_entry->urls, struct url, entry) + { + if (listening_sock == url_entry->listening_sock) + { + return TRUE; + } + } + } + + return FALSE; +} + static NTSTATUS http_remove_url(struct request_queue *queue, IRP *irp) { const char *url = irp->AssociatedIrp.SystemBuffer; @@ -729,9 +881,14 @@ static NTSTATUS http_remove_url(struct request_queue *queue, IRP *irp) free(url_entry->url); url_entry->url = NULL; - shutdown(url_entry->socket, SD_BOTH); - closesocket(url_entry->socket); - url_entry->socket = -1; + if (!is_listening_socket_used(url_entry->listening_sock)) + { + shutdown(url_entry->listening_sock->socket, SD_BOTH); + closesocket(url_entry->listening_sock->socket); + list_remove(&url_entry->listening_sock->entry); + free(url_entry->listening_sock); + } + url_entry->listening_sock = NULL; list_remove(&url_entry->entry); free(url_entry); @@ -959,6 +1116,7 @@ static NTSTATUS WINAPI dispatch_create(DEVICE_OBJECT *device, IRP *irp) static void close_queue(struct request_queue *queue) { struct url *url, *url_next; + struct listening_socket *listening_sock, *listening_sock_next; EnterCriticalSection(&http_cs); list_remove(&queue->entry); @@ -966,11 +1124,17 @@ static void close_queue(struct request_queue *queue) LIST_FOR_EACH_ENTRY_SAFE(url, url_next, &queue->urls, struct url, entry) { free(url->url); - shutdown(url->socket, SD_BOTH); - closesocket(url->socket); - list_remove(&url->entry); free(url); } + + LIST_FOR_EACH_ENTRY_SAFE(listening_sock, listening_sock_next, &listening_sockets, struct listening_socket, entry) + { + shutdown(listening_sock->socket, SD_BOTH); + closesocket(listening_sock->socket); + list_remove(&listening_sock->entry); + free(listening_sock); + } + free(queue); LeaveCriticalSection(&http_cs); diff --git a/dlls/httpapi/tests/httpapi.c b/dlls/httpapi/tests/httpapi.c index be7cb667ceaf..f52c584b93ba 100644 --- a/dlls/httpapi/tests/httpapi.c +++ b/dlls/httpapi/tests/httpapi.c @@ -1121,6 +1121,103 @@ static void test_v1_multiple_urls(void) ok(ret, "Failed to close queue handle, error %lu.\n", GetLastError()); } +static void test_v1_relative_urls(void) +{ + char DECLSPEC_ALIGN(8) req_buffer[2048]; + char relative_req[] = + "GET %s HTTP/1.1\r\n" + "Host: localhost:%u\r\n" + "Connection: keep-alive\r\n" + "User-Agent: WINE\r\n" + "\r\n"; + HTTP_REQUEST_V1 *req = (HTTP_REQUEST_V1 *)req_buffer; + unsigned short port; + WCHAR url[50], url2[50]; + char req_text[200]; + DWORD ret_size; + HANDLE queue, queue2; + SOCKET s; + int ret; + + ret = HttpCreateHttpHandle(&queue, 0); + ok(!ret, "Got error %u.\n", ret); + ret = HttpCreateHttpHandle(&queue2, 0); + ok(!ret, "Got error %u.\n", ret); + + port = add_url_v1(queue); + swprintf(url2, ARRAY_SIZE(url2), L"http://localhost:%u/foobar/foo/", port); + ret = HttpAddUrl(queue2, url2, NULL); + ok(!ret, "Got error %u.\n", ret); + + s = create_client_socket(port); + sprintf(req_text, simple_req, port); + ret = send(s, req_text, strlen(req_text), 0); + ok(ret == strlen(req_text), "send() returned %d.\n", ret); + + memset(req_buffer, 0xcc, sizeof(req_buffer)); + ret = HttpReceiveHttpRequest(queue, HTTP_NULL_ID, 0, (HTTP_REQUEST *)req, sizeof(req_buffer), &ret_size, NULL); + ok(!ret, "Got error %u.\n", ret); + ok(ret_size > sizeof(*req), "Got size %lu.\n", ret_size); + + send_response_v1(queue, req->RequestId, s); + + sprintf(req_text, relative_req, "/foobar/foo/bar", port); + ret = send(s, req_text, strlen(req_text), 0); + ok(ret == strlen(req_text), "send() returned %d.\n", ret); + + memset(req_buffer, 0xcc, sizeof(req_buffer)); + ret = HttpReceiveHttpRequest(queue2, HTTP_NULL_ID, 0, (HTTP_REQUEST *)req, sizeof(req_buffer), &ret_size, NULL); + ok(!ret, "Got error %u.\n", ret); + ok(ret_size > sizeof(*req), "Got size %lu.\n", ret_size); + + send_response_v1(queue2, req->RequestId, s); + + sprintf(req_text, relative_req, "/foobar/foo", port); + ret = send(s, req_text, strlen(req_text), 0); + ok(ret == strlen(req_text), "send() returned %d.\n", ret); + + memset(req_buffer, 0xcc, sizeof(req_buffer)); + ret = HttpReceiveHttpRequest(queue2, HTTP_NULL_ID, 0, (HTTP_REQUEST *)req, sizeof(req_buffer), &ret_size, NULL); + ok(!ret, "Got error %u.\n", ret); + ok(ret_size > sizeof(*req), "Got size %lu.\n", ret_size); + + send_response_v1(queue2, req->RequestId, s); + + sprintf(req_text, relative_req, "/foobar/foo?a=b", port); + ret = send(s, req_text, strlen(req_text), 0); + ok(ret == strlen(req_text), "send() returned %d.\n", ret); + + memset(req_buffer, 0xcc, sizeof(req_buffer)); + ret = HttpReceiveHttpRequest(queue2, HTTP_NULL_ID, 0, (HTTP_REQUEST *)req, sizeof(req_buffer), &ret_size, NULL); + ok(!ret, "Got error %u.\n", ret); + ok(ret_size > sizeof(*req), "Got size %lu.\n", ret_size); + + send_response_v1(queue2, req->RequestId, s); + + closesocket(s); + remove_url_v1(queue, port); + ret = HttpRemoveUrl(queue2, url2); + ok(!ret, "Got error %u.\n", ret); + + swprintf(url, ARRAY_SIZE(url), L"http://localhost:%u/barfoo/", port); + ret = HttpAddUrl(queue, url, NULL); + ok(!ret, "Got error %u.\n", ret); + swprintf(url2, ARRAY_SIZE(url2), L"http://localhost:%u/barfoo/", port); + ret = HttpAddUrl(queue2, url2, NULL); + ok(ret == ERROR_ALREADY_EXISTS, "Got error %u.\n", ret); + + swprintf(url2, ARRAY_SIZE(url2), L"http://localhost:%u/barfoo", port); + ret = HttpAddUrl(queue2, url2, NULL); + ok(ret == ERROR_ALREADY_EXISTS, "Got error %u.\n", ret); + + ret = CloseHandle(queue); + ok(ret, "Failed to close queue handle, error %lu.\n", GetLastError()); + ret = CloseHandle(queue2); + ok(ret, "Failed to close queue handle, error %lu.\n", GetLastError()); + + return; +} + static void test_v1_urls(void) { char DECLSPEC_ALIGN(8) req_buffer[2048]; @@ -1682,6 +1779,7 @@ START_TEST(httpapi) test_v1_cooked_url(); test_v1_unknown_tokens(); test_v1_multiple_urls(); + test_v1_relative_urls(); test_v1_urls(); ret = HttpTerminate(HTTP_INITIALIZE_SERVER, NULL);