diff --git a/ocaml/libs/http-lib/http_svr.ml b/ocaml/libs/http-lib/http_svr.ml index 4db3df81d2a..404136a8b99 100644 --- a/ocaml/libs/http-lib/http_svr.ml +++ b/ocaml/libs/http-lib/http_svr.ml @@ -149,6 +149,10 @@ let response_error_html ?(version = "1.1") s code message hdrs body = D.debug "Response %s" (Http.Response.to_string res) ; Unixext.really_write_string s (Http.Response.to_wire_string res) +let response_custom_error ?req s error_code reason body = + let version = Option.map get_return_version req in + response_error_html ?version s error_code reason [] body + let response_unauthorised ?req label s = let version = Option.map get_return_version req in let body = diff --git a/ocaml/libs/http-lib/http_svr.mli b/ocaml/libs/http-lib/http_svr.mli index 101479d100d..7d3bae386d7 100644 --- a/ocaml/libs/http-lib/http_svr.mli +++ b/ocaml/libs/http-lib/http_svr.mli @@ -97,6 +97,9 @@ val response_unauthorised : val response_forbidden : ?req:Http.Request.t -> Unix.file_descr -> unit +val response_custom_error : + ?req:Http.Request.t -> Unix.file_descr -> string -> string -> string -> unit + val response_badrequest : ?req:Http.Request.t -> Unix.file_descr -> unit val response_internal_error : diff --git a/ocaml/xapi/console.ml b/ocaml/xapi/console.ml index 08f4c863e87..57072849bfc 100644 --- a/ocaml/xapi/console.ml +++ b/ocaml/xapi/console.ml @@ -40,39 +40,56 @@ type address = module Connection_limit = struct module VMMap = Map.Make (String) - let active_connections : int VMMap.t ref = ref VMMap.empty + let active_connections : (string * string) list VMMap.t ref = ref VMMap.empty let mutex = Mutex.create () let with_lock = Xapi_stdext_threads.Threadext.Mutex.execute - let drop vm_id = + let drop vm_id session_id = with_lock mutex (fun () -> match VMMap.find_opt vm_id !active_connections with - | Some n when n > 1 -> - active_connections := VMMap.add vm_id (n - 1) !active_connections - | Some _ | None -> - active_connections := VMMap.remove vm_id !active_connections + | Some connections -> + let updated_connections = + List.filter (fun (_, sid) -> sid <> session_id) connections + in + if updated_connections = [] then + active_connections := VMMap.remove vm_id !active_connections + else + active_connections := + VMMap.add vm_id updated_connections !active_connections + | None -> + (* Unlikely *) + () ) - (* When the limit is disabled (false), we must still track the connection count for each vm_id. + (* When the limit is disabled (false), we must still track the connections for each vm_id. This ensures that if the limit is later enabled (set to true), any existing connections are accounted for, and the limit can be correctly enforced for subsequent connection attempts. *) - let try_add vm_id is_limit_enabled = + let try_add vm_id user_name session_id is_limit_enabled = with_lock mutex (fun () -> - let count = - VMMap.find_opt vm_id !active_connections |> Option.value ~default:0 + let connections = + VMMap.find_opt vm_id !active_connections |> Option.value ~default:[] in + let count = List.length connections in if is_limit_enabled && count > 0 then ( debug "limit_console_sessions is true. Console connection is rejected \ for VM %s, active connections: %d" vm_id count ; false - ) else ( - active_connections := VMMap.add vm_id (count + 1) !active_connections ; + ) else + let updated_connections = (user_name, session_id) :: connections in + active_connections := + VMMap.add vm_id updated_connections !active_connections ; true - ) + ) + + let get_connected_users vm_id = + with_lock mutex (fun () -> + VMMap.find_opt vm_id !active_connections + |> Option.value ~default:[] + |> List.map fst ) end @@ -230,18 +247,57 @@ let real_proxy' ~__context ~vm vnc_port s = debug "Proxy exited" with exn -> debug "error: %s" (ExnHelper.string_of_exn exn) -let real_proxy __context vm _ _ vnc_port s = +let respond_console_limit_exceeded req s vm_id = + let html_escape s = + let escape_char = function + | '<' -> + "<" + | '>' -> + ">" + | '&' -> + "&" + | '"' -> + """ + | '\'' -> + "'" + | c -> + String.make 1 c + in + String.fold_left (fun acc c -> acc ^ escape_char c) "" s + in + let connected_users = Connection_limit.get_connected_users vm_id in + let users_text = + match connected_users with + | [user] -> + Printf.sprintf "User '%s' is" (html_escape user) + | users -> + let escaped_users = List.map html_escape users in + Printf.sprintf "Users '%s' are" (String.concat ", " escaped_users) + in + let body = + Printf.sprintf + "
%s currently connected \ + to this console. No new connections are allowed. Please check the \ + limit_console_sessions config and try again later.
" + users_text + in + Http_svr.response_custom_error ~req s "503" "Connection Limit Exceeded" body + +let real_proxy __context vm req _ vnc_port s = let vm_id = Ref.string_of vm in let pool = Helpers.get_pool ~__context in let is_limit_enabled = Db.Pool.get_limit_console_sessions ~__context ~self:pool in - if Connection_limit.try_add vm_id is_limit_enabled then + let session_id = Xapi_http.get_session_id req in + let user = Db.Session.get_auth_user_name ~__context ~self:session_id in + let session_id_str = Ref.string_of session_id in + if Connection_limit.try_add vm_id user session_id_str is_limit_enabled then finally (* Ensure we drop the vm connection count if exceptions occur *) (fun () -> real_proxy' ~__context ~vm vnc_port s) - (fun () -> Connection_limit.drop vm_id) + (fun () -> Connection_limit.drop vm_id session_id_str) else - Http_svr.headers s (Http.http_503_service_unavailable ()) + respond_console_limit_exceeded req s vm_id let go_if_no_limit __context s f = let pool = Helpers.get_pool ~__context in