diff --git a/man/ttyd.1 b/man/ttyd.1 index 000cb45e..2a7ec97f 100644 --- a/man/ttyd.1 +++ b/man/ttyd.1 @@ -46,6 +46,10 @@ Cross platform: macOS, Linux, FreeBSD/OpenBSD, OpenWrt/LEDE, Windows \-c, \-\-credential USER[:PASSWORD] Credential for Basic Authentication (format: username:password) +.PP +\-H, \-\-auth\-header + HTTP Header name for auth proxy, this will configure ttyd to let a HTTP reverse proxy handle authentication + .PP \-u, \-\-uid User id to run with diff --git a/man/ttyd.man.md b/man/ttyd.man.md index b90cd97d..baf8fb45 100644 --- a/man/ttyd.man.md +++ b/man/ttyd.man.md @@ -28,6 +28,9 @@ ttyd 1 "September 2016" ttyd "User Manual" -c, --credential USER[:PASSWORD] Credential for Basic Authentication (format: username:password) + -H, --auth-header + HTTP Header name for auth proxy, this will configure ttyd to let a HTTP reverse proxy handle authentication + -u, --uid User id to run with diff --git a/scripts/cross-build.sh b/scripts/cross-build.sh index 2405b113..4070ca82 100755 --- a/scripts/cross-build.sh +++ b/scripts/cross-build.sh @@ -106,7 +106,6 @@ build_libwebsockets() { -DLWS_WITH_LEJP=OFF \ -DLWS_WITH_LEJP_CONF=OFF \ -DLWS_WITH_LWSAC=OFF \ - -DLWS_WITH_CUSTOM_HEADERS=OFF \ -DLWS_WITH_SEQUENCER=OFF \ .. make -j"$(nproc)" install diff --git a/src/http.c b/src/http.c index f1442970..eea62afd 100644 --- a/src/http.c +++ b/src/http.c @@ -11,35 +11,36 @@ enum { AUTH_OK, AUTH_FAIL, AUTH_ERROR }; static char *html_cache = NULL; static size_t html_cache_len = 0; -static int check_auth(struct lws *wsi, struct pss_http *pss) { - if (server->credential == NULL) return AUTH_OK; - - char buf[256]; - int len = lws_hdr_copy(wsi, buf, sizeof(buf), WSI_TOKEN_HTTP_AUTHORIZATION); - if (len >= 7 && strstr(buf, "Basic ")) { - if (!strcmp(buf + 6, server->credential)) return AUTH_OK; - } - +static int send_unauthorized(struct lws *wsi, unsigned int code, enum lws_token_indexes header) { unsigned char buffer[1024 + LWS_PRE], *p, *end; p = buffer + LWS_PRE; end = p + sizeof(buffer) - LWS_PRE; - char *body = strdup("401 Unauthorized\n"); - size_t n = strlen(body); - - if (lws_add_http_header_status(wsi, HTTP_STATUS_UNAUTHORIZED, &p, end) || - lws_add_http_header_by_token(wsi, WSI_TOKEN_HTTP_WWW_AUTHENTICATE, - (unsigned char *)"Basic realm=\"ttyd\"", 18, &p, end) || - lws_add_http_header_content_length(wsi, n, &p, end) || - lws_finalize_http_header(wsi, &p, end) || + if (lws_add_http_header_status(wsi, code, &p, end) || + lws_add_http_header_by_token(wsi, header, (unsigned char *)"Basic realm=\"ttyd\"", 18, &p, end) || + lws_add_http_header_content_length(wsi, 0, &p, end) || lws_finalize_http_header(wsi, &p, end) || lws_write(wsi, buffer + LWS_PRE, p - (buffer + LWS_PRE), LWS_WRITE_HTTP_HEADERS) < 0) - return AUTH_ERROR; + return AUTH_FAIL; - pss->buffer = pss->ptr = body; - pss->len = n; - lws_callback_on_writable(wsi); + return lws_http_transaction_completed(wsi) ? AUTH_FAIL : AUTH_ERROR; +} - return AUTH_FAIL; +static int check_auth(struct lws *wsi, struct pss_http *pss) { + if (server->auth_header != NULL) { + if (lws_hdr_custom_length(wsi, server->auth_header, strlen(server->auth_header)) > 0) return AUTH_OK; + return send_unauthorized(wsi, HTTP_STATUS_PROXY_AUTH_REQUIRED, WSI_TOKEN_HTTP_PROXY_AUTHENTICATE); + } + + if(server->credential != NULL) { + char buf[256]; + int len = lws_hdr_copy(wsi, buf, sizeof(buf), WSI_TOKEN_HTTP_AUTHORIZATION); + if (len >= 7 && strstr(buf, "Basic ")) { + if (!strcmp(buf + 6, server->credential)) return AUTH_OK; + } + return send_unauthorized(wsi, HTTP_STATUS_UNAUTHORIZED, WSI_TOKEN_HTTP_WWW_AUTHENTICATE); + } + + return AUTH_OK; } static bool accept_gzip(struct lws *wsi) { @@ -89,8 +90,7 @@ static void access_log(struct lws *wsi, const char *path) { lwsl_notice("HTTP %s - %s\n", path, rip); } -int callback_http(struct lws *wsi, enum lws_callback_reasons reason, void *user, void *in, - size_t len) { +int callback_http(struct lws *wsi, enum lws_callback_reasons reason, void *user, void *in, size_t len) { struct pss_http *pss = (struct pss_http *)user; unsigned char buffer[4096 + LWS_PRE], *p, *end; char buf[256]; @@ -118,8 +118,7 @@ int callback_http(struct lws *wsi, enum lws_callback_reasons reason, void *user, size_t n = sprintf(buf, "{\"token\": \"%s\"}", credential); if (lws_add_http_header_status(wsi, HTTP_STATUS_OK, &p, end) || lws_add_http_header_by_token(wsi, WSI_TOKEN_HTTP_CONTENT_TYPE, - (unsigned char *)"application/json;charset=utf-8", 30, &p, - end) || + (unsigned char *)"application/json;charset=utf-8", 30, &p, end) || lws_add_http_header_content_length(wsi, (unsigned long)n, &p, end) || lws_finalize_http_header(wsi, &p, end) || lws_write(wsi, buffer + LWS_PRE, p - (buffer + LWS_PRE), LWS_WRITE_HTTP_HEADERS) < 0) @@ -134,11 +133,9 @@ int callback_http(struct lws *wsi, enum lws_callback_reasons reason, void *user, // redirects `/base-path` to `/base-path/` if (strcmp(pss->path, endpoints.parent) == 0) { if (lws_add_http_header_status(wsi, HTTP_STATUS_FOUND, &p, end) || - lws_add_http_header_by_token(wsi, WSI_TOKEN_HTTP_LOCATION, - (unsigned char *)endpoints.index, + lws_add_http_header_by_token(wsi, WSI_TOKEN_HTTP_LOCATION, (unsigned char *)endpoints.index, (int)strlen(endpoints.index), &p, end) || - lws_add_http_header_content_length(wsi, 0, &p, end) || - lws_finalize_http_header(wsi, &p, end) || + lws_add_http_header_content_length(wsi, 0, &p, end) || lws_finalize_http_header(wsi, &p, end) || lws_write(wsi, buffer + LWS_PRE, p - (buffer + LWS_PRE), LWS_WRITE_HTTP_HEADERS) < 0) return 1; goto try_to_reuse; @@ -157,15 +154,14 @@ int callback_http(struct lws *wsi, enum lws_callback_reasons reason, void *user, char *output = (char *)index_html; size_t output_len = index_html_len; if (lws_add_http_header_status(wsi, HTTP_STATUS_OK, &p, end) || - lws_add_http_header_by_token(wsi, WSI_TOKEN_HTTP_CONTENT_TYPE, - (const unsigned char *)content_type, 9, &p, end)) + lws_add_http_header_by_token(wsi, WSI_TOKEN_HTTP_CONTENT_TYPE, (const unsigned char *)content_type, 9, &p, + end)) return 1; #ifdef LWS_WITH_HTTP_STREAM_COMPRESSION if (!uncompress_html(&output, &output_len)) return 1; #else if (accept_gzip(wsi)) { - if (lws_add_http_header_by_token(wsi, WSI_TOKEN_HTTP_CONTENT_ENCODING, - (unsigned char *)"gzip", 4, &p, end)) + if (lws_add_http_header_by_token(wsi, WSI_TOKEN_HTTP_CONTENT_ENCODING, (unsigned char *)"gzip", 4, &p, end)) return 1; } else { if (!uncompress_html(&output, &output_len)) return 1; diff --git a/src/protocol.c b/src/protocol.c index 5324655c..b6928a29 100644 --- a/src/protocol.c +++ b/src/protocol.c @@ -137,8 +137,21 @@ static void wsi_output(struct lws *wsi, pty_buf_t *buf) { free(message); } -int callback_tty(struct lws *wsi, enum lws_callback_reasons reason, void *user, void *in, - size_t len) { +static bool check_auth(struct lws *wsi) { + if (server->auth_header != NULL) { + return lws_hdr_custom_length(wsi, server->auth_header, strlen(server->auth_header)) > 0; + } + + if (server->credential != NULL) { + char buf[256]; + size_t n = lws_hdr_copy(wsi, buf, sizeof(buf), WSI_TOKEN_HTTP_AUTHORIZATION); + return n >= 7 && strstr(buf, "Basic ") && !strcmp(buf + 6, server->credential); + } + + return true; +} + +int callback_tty(struct lws *wsi, enum lws_callback_reasons reason, void *user, void *in, size_t len) { struct pss_tty *pss = (struct pss_tty *)user; char buf[256]; size_t n = 0; @@ -153,10 +166,7 @@ int callback_tty(struct lws *wsi, enum lws_callback_reasons reason, void *user, lwsl_warn("refuse to serve WS client due to the --max-clients option.\n"); return 1; } - if (server->credential != NULL) { - n = lws_hdr_copy(wsi, buf, sizeof(buf), WSI_TOKEN_HTTP_AUTHORIZATION); - if (n < 7 || !strstr(buf, "Basic ") || strcmp(buf + 6, server->credential)) return 1; - } + if (!check_auth(wsi)) return 1; n = lws_hdr_copy(wsi, pss->path, sizeof(pss->path), WSI_TOKEN_GET_URI); #if defined(LWS_ROLE_H2) @@ -261,8 +271,8 @@ int callback_tty(struct lws *wsi, enum lws_callback_reasons reason, void *user, } break; case RESIZE_TERMINAL: - json_object_put(parse_window_size(pss->buffer + 1, pss->len - 1, &pss->process->columns, - &pss->process->rows)); + json_object_put( + parse_window_size(pss->buffer + 1, pss->len - 1, &pss->process->columns, &pss->process->rows)); pty_resize(pss->process); break; case PAUSE: diff --git a/src/server.c b/src/server.c index f1af441c..b0a33751 100644 --- a/src/server.c +++ b/src/server.c @@ -54,6 +54,7 @@ static lws_retry_bo_t retry = { static const struct option options[] = {{"port", required_argument, NULL, 'p'}, {"interface", required_argument, NULL, 'i'}, {"credential", required_argument, NULL, 'c'}, + {"auth-header", required_argument, NULL, 'H'}, {"uid", required_argument, NULL, 'u'}, {"gid", required_argument, NULL, 'g'}, {"signal", required_argument, NULL, 's'}, @@ -79,13 +80,7 @@ static const struct option options[] = {{"port", required_argument, NULL, 'p'}, {"version", no_argument, NULL, 'v'}, {"help", no_argument, NULL, 'h'}, {NULL, 0, 0, 0}}; - -#if LWS_LIBRARY_VERSION_NUMBER < 4000000 -static const char *opt_string = "p:i:c:u:g:s:I:b:6aSC:K:A:Rt:T:Om:oBd:vh"; -#endif -#if LWS_LIBRARY_VERSION_NUMBER >= 4000000 -static const char *opt_string = "p:i:c:u:g:s:I:b:P:6aSC:K:A:Rt:T:Om:oBd:vh"; -#endif +static const char *opt_string = "p:i:c:H:u:g:s:I:b:P:6aSC:K:A:Rt:T:Om:oBd:vh"; static void print_help() { // clang-format off @@ -97,7 +92,8 @@ static void print_help() { "OPTIONS:\n" " -p, --port Port to listen (default: 7681, use `0` for random port)\n" " -i, --interface Network interface to bind (eg: eth0), or UNIX domain socket path (eg: /var/run/ttyd.sock)\n" - " -c, --credential Credential for Basic Authentication (format: username:password)\n" + " -c, --credential Credential for basic authentication (format: username:password)\n" + " -H, --auth-header HTTP Header name for auth proxy, this will configure ttyd to let a HTTP reverse proxy handle authentication\n" " -u, --uid User id to run with\n" " -g, --gid Group id to run with\n" " -s, --signal Signal to send to the command when exit it (default: 1, SIGHUP)\n" @@ -132,6 +128,28 @@ static void print_help() { // clang-format on } +static void print_config() { + lwsl_notice("tty configuration:\n"); + if (server->credential != NULL) lwsl_notice(" credential: %s\n", server->credential); + lwsl_notice(" start command: %s\n", server->command); + lwsl_notice(" close signal: %s (%d)\n", server->sig_name, server->sig_code); + lwsl_notice(" terminal type: %s\n", server->terminal_type); + if (endpoints.parent[0]) { + lwsl_notice("endpoints:\n"); + lwsl_notice(" base-path: %s\n", endpoints.parent); + lwsl_notice(" index : %s\n", endpoints.index); + lwsl_notice(" token : %s\n", endpoints.token); + lwsl_notice(" websocket: %s\n", endpoints.ws); + } + if (server->auth_header != NULL) lwsl_notice(" auth header: %s\n", server->auth_header); + if (server->check_origin) lwsl_notice(" check origin: true\n"); + if (server->url_arg) lwsl_notice(" allow url arg: true\n"); + if (server->readonly) lwsl_notice(" readonly: true\n"); + if (server->max_clients > 0) lwsl_notice(" max clients: %d\n", server->max_clients); + if (server->once) lwsl_notice(" once: true\n"); + if (server->index != NULL) lwsl_notice(" custom index.html: %s\n", server->index); +} + static struct server *server_new(int argc, char **argv, int start) { struct server *ts; size_t cmd_len = 0; @@ -178,6 +196,7 @@ static struct server *server_new(int argc, char **argv, int start) { static void server_free(struct server *ts) { if (ts == NULL) return; if (ts->credential != NULL) free(ts->credential); + if (ts->auth_header != NULL) free(ts->auth_header); if (ts->index != NULL) free(ts->index); free(ts->command); free(ts->prefs_json); @@ -362,6 +381,9 @@ int main(int argc, char **argv) { lws_b64_encode_string(optarg, strlen(optarg), b64_text, sizeof(b64_text)); server->credential = strdup(b64_text); break; + case 'H': + server->auth_header = strdup(optarg); + break; case 'u': info.uid = parse_int("uid", optarg); break; @@ -514,24 +536,15 @@ int main(int argc, char **argv) { #endif lwsl_notice("ttyd %s (libwebsockets %s)\n", TTYD_VERSION, LWS_LIBRARY_VERSION); - lwsl_notice("tty configuration:\n"); - if (server->credential != NULL) lwsl_notice(" credential: %s\n", server->credential); - lwsl_notice(" start command: %s\n", server->command); - lwsl_notice(" close signal: %s (%d)\n", server->sig_name, server->sig_code); - lwsl_notice(" terminal type: %s\n", server->terminal_type); - if (endpoints.parent[0]) { - lwsl_notice("endpoints:\n"); - lwsl_notice(" base-path: %s\n", endpoints.parent); - lwsl_notice(" index : %s\n", endpoints.index); - lwsl_notice(" token : %s\n", endpoints.token); - lwsl_notice(" websocket: %s\n", endpoints.ws); + print_config(); + + // lws custom header requires lower case name, and terminating : + if (server->auth_header != NULL) { + size_t auth_header_len = strlen(server->auth_header); + server->auth_header = xrealloc(server->auth_header, auth_header_len + 2); + strcat(server->auth_header + auth_header_len, ":"); + lowercase(server->auth_header); } - if (server->check_origin) lwsl_notice(" check origin: true\n"); - if (server->url_arg) lwsl_notice(" allow url arg: true\n"); - if (server->readonly) lwsl_notice(" readonly: true\n"); - if (server->max_clients > 0) lwsl_notice(" max clients: %d\n", server->max_clients); - if (server->once) lwsl_notice(" once: true\n"); - if (server->index != NULL) lwsl_notice(" custom index.html: %s\n", server->index); void *foreign_loops[1]; foreign_loops[0] = server->loop; diff --git a/src/server.h b/src/server.h index 5f369aee..d12bf6a6 100644 --- a/src/server.h +++ b/src/server.h @@ -1,3 +1,4 @@ +#include #include #include @@ -58,6 +59,7 @@ struct server { int client_count; // client count char *prefs_json; // client preferences char *credential; // encoded basic auth credential + char *auth_header; // header name used for auth proxy char *index; // custom index.html char *command; // full command line char **argv; // command with arguments diff --git a/src/utils.c b/src/utils.c index e3f9b395..f3e87d21 100644 --- a/src/utils.c +++ b/src/utils.c @@ -37,12 +37,20 @@ void *xrealloc(void *p, size_t size) { return p; } -char *uppercase(char *str) { - int i = 0; - do { - str[i] = (char)toupper(str[i]); - } while (str[i++] != '\0'); - return str; +char *uppercase(char *s) { + while(*s) { + *s = (char)toupper((int)*s); + s++; + } + return s; +} + +char *lowercase(char *s) { + while(*s) { + *s = (char)tolower((int)*s); + s++; + } + return s; } bool endswith(const char *str, const char *suffix) { diff --git a/src/utils.h b/src/utils.h index c5e985c3..826bb946 100644 --- a/src/utils.h +++ b/src/utils.h @@ -14,7 +14,10 @@ void *xmalloc(size_t size); void *xrealloc(void *p, size_t size); // Convert a string to upper case -char *uppercase(char *str); +char *uppercase(char *s); + +// Convert a string to lower case +char *lowercase(char *s); // Check whether str ends with suffix bool endswith(const char *str, const char *suffix);