From c5e1da3370652a2c6f5c4908983ca8b3d6c5583b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Igl=C3=B3i=20G=C3=A1bor?= Date: Wed, 13 Sep 2017 22:35:19 +0100 Subject: [PATCH 1/2] Make client use cleartext_channel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is required for the correctness of client types, because we always start with a cleartext channel, and then negotiate the upgrade to TLS. Signed-off-by: Iglói Gábor --- lib/client.ml | 34 +++++++++++++++++----------------- lib/s.ml | 4 ++-- lib_test/protocol_test.ml | 2 +- lwt/nbd_lwt_unix.ml | 2 +- lwt/nbd_lwt_unix.mli | 2 +- 5 files changed, 22 insertions(+), 22 deletions(-) diff --git a/lib/client.ml b/lib/client.ml index 798b43e..8a6060d 100644 --- a/lib/client.ml +++ b/lib/client.ml @@ -26,7 +26,7 @@ let get_handle = this module NbdRpc = struct - type transport = channel + type transport = cleartext_channel type id = int64 type request_hdr = Request.t type request_body = Cstruct.t option @@ -35,7 +35,7 @@ module NbdRpc = struct let recv_hdr sock = let buf = Cstruct.create 16 in - sock.read buf + sock.read_clear buf >>= fun () -> match Reply.unmarshal buf with | `Ok x -> Lwt.return (Some x.Reply.handle, x) @@ -48,7 +48,7 @@ module NbdRpc = struct begin match req_hdr.Request.ty with | Command.Read -> (* TODO: use a page-aligned memory allocator *) - Lwt_list.iter_s sock.read response_body + Lwt_list.iter_s sock.read_clear response_body >>= fun () -> Lwt.return (`Ok ()) | _ -> Lwt.return (`Ok ()) @@ -57,12 +57,12 @@ module NbdRpc = struct let send_one sock req_hdr req_body = let buf = Cstruct.create Request.sizeof in Request.marshal buf req_hdr; - sock.write buf + sock.write_clear buf >>= fun () -> match req_body with | None -> Lwt.return () | Some data -> - sock.write data + sock.write_clear data let id_of_request req = req.Request.handle @@ -98,13 +98,13 @@ let make channel size_bytes flags = let list channel = let buf = Cstruct.create Announcement.sizeof in - channel.read buf + channel.read_clear buf >>= fun () -> match Announcement.unmarshal buf with | `Error e -> Lwt.fail e | `Ok kind -> let buf = Cstruct.create (Negotiate.sizeof kind) in - channel.read buf + channel.read_clear buf >>= fun () -> begin match Negotiate.unmarshal buf kind with | `Error e -> Lwt.fail e @@ -114,15 +114,15 @@ let list channel = let buf = Cstruct.create NegotiateResponse.sizeof in let flags = if List.mem GlobalFlag.Fixed_newstyle x then [ ClientFlag.Fixed_newstyle ] else [] in NegotiateResponse.marshal buf flags; - channel.write buf + channel.write_clear buf >>= fun () -> let buf = Cstruct.create OptionRequestHeader.sizeof in OptionRequestHeader.(marshal buf { ty = Option.List; length = 0l }); - channel.write buf + channel.write_clear buf >>= fun () -> let buf = Cstruct.create OptionResponseHeader.sizeof in let rec loop acc = - channel.read buf + channel.read_clear buf >>= fun () -> match OptionResponseHeader.unmarshal buf with | `Error e -> Lwt.fail e @@ -131,7 +131,7 @@ let list channel = Lwt.return (`Error `Policy) | `Ok { OptionResponseHeader.response_type = OptionResponse.Server; length } -> let buf' = Cstruct.create (Int32.to_int length) in - channel.read buf' + channel.read_clear buf' >>= fun () -> begin match Server.unmarshal buf' with | `Ok server -> @@ -145,13 +145,13 @@ let list channel = let negotiate channel export = let buf = Cstruct.create Announcement.sizeof in - channel.read buf + channel.read_clear buf >>= fun () -> match Announcement.unmarshal buf with | `Error e -> Lwt.fail e | `Ok kind -> let buf = Cstruct.create (Negotiate.sizeof kind) in - channel.read buf + channel.read_clear buf >>= fun () -> begin match Negotiate.unmarshal buf kind with | `Error e -> Lwt.fail e @@ -163,18 +163,18 @@ let negotiate channel export = let buf = Cstruct.create NegotiateResponse.sizeof in let flags = if List.mem GlobalFlag.Fixed_newstyle x then [ ClientFlag.Fixed_newstyle ] else [] in NegotiateResponse.marshal buf flags; - channel.write buf + channel.write_clear buf >>= fun () -> let buf = Cstruct.create OptionRequestHeader.sizeof in OptionRequestHeader.(marshal buf { ty = Option.ExportName; length = Int32.of_int (String.length export) }); - channel.write buf + channel.write_clear buf >>= fun () -> let buf = Cstruct.create (ExportName.sizeof export) in ExportName.marshal buf export; - channel.write buf + channel.write_clear buf >>= fun () -> let buf = Cstruct.create DiskInfo.sizeof in - channel.read buf + channel.read_clear buf >>= fun () -> begin match DiskInfo.unmarshal buf with | `Error e -> Lwt.fail e diff --git a/lib/s.ml b/lib/s.ml index 7a036a3..846b0da 100644 --- a/lib/s.ml +++ b/lib/s.ml @@ -25,13 +25,13 @@ module type CLIENT = sig type size = int64 (** The size of a remote disk *) - val list: channel -> [ `Ok of string list | `Error of [ `Policy | `Unsupported ] ] Lwt.t + val list: cleartext_channel -> [ `Ok of string list | `Error of [ `Policy | `Unsupported ] ] Lwt.t (** [list channel] returns a list of exports known by the server. [`Error `Policy] means the server has this function disabled deliberately. [`Error `Unsupported] means the server is old and does not support the query function. *) - val negotiate: channel -> string -> (t * size * Protocol.PerExportFlag.t list) Lwt.t + val negotiate: cleartext_channel -> string -> (t * size * Protocol.PerExportFlag.t list) Lwt.t (** [negotiate channel export] takes an already-connected channel, performs the initial protocol negotiation and connects to the named export. Returns [disk * remote disk size * flags] *) diff --git a/lib_test/protocol_test.ml b/lib_test/protocol_test.ml index 25718d1..fbb8ecf 100644 --- a/lib_test/protocol_test.ml +++ b/lib_test/protocol_test.ml @@ -95,7 +95,7 @@ let make_client_channel test_sequence = else write buf | [] -> Lwt.fail_with "Client tried to write but the stream was empty" in let close () = Lwt.return () in - Channel.{ read; write; close; is_tls=false } + Channel.{ read_clear=read; write_clear=write; close_clear=close; make_tls_channel=None } let client_negotiation = "Perform a negotiation using the second version of the protocol from the diff --git a/lwt/nbd_lwt_unix.ml b/lwt/nbd_lwt_unix.ml index 98be659..06e8d02 100644 --- a/lwt/nbd_lwt_unix.ml +++ b/lwt/nbd_lwt_unix.ml @@ -99,7 +99,7 @@ let connect hostname port = let server_address = host_info.Lwt_unix.h_addr_list.(0) in Lwt_unix.connect socket (Lwt_unix.ADDR_INET (server_address, port)) >>= fun () -> - (generic_channel_of_fd socket None) + Lwt.return (cleartext_channel_of_fd socket None) let init_tls_get_ctx ~certfile ~ciphersuites = Ssl_threads.init (); diff --git a/lwt/nbd_lwt_unix.mli b/lwt/nbd_lwt_unix.mli index a5fc3fc..a35826e 100644 --- a/lwt/nbd_lwt_unix.mli +++ b/lwt/nbd_lwt_unix.mli @@ -20,7 +20,7 @@ type tls_role = | TlsClient of Ssl.context | TlsServer of Ssl.context -val connect: string -> int -> Channel.channel Lwt.t +val connect: string -> int -> Channel.cleartext_channel Lwt.t (** [connect hostname port] connects to host:port and returns a [generic_channel] with no TLS ability or potential. *) From d1b349ff7e25b2e4eb7775fc272f2b00916b165a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Igl=C3=B3i=20G=C3=A1bor?= Date: Wed, 13 Sep 2017 22:51:36 +0100 Subject: [PATCH 2/2] Client: use generic_channel for Mux MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit because in the future, the client may upgrade to TLS, and then Mux will use a TLS channel converted into a generic_channel type. Signed-off-by: Iglói Gábor --- lib/client.ml | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/lib/client.ml b/lib/client.ml index 8a6060d..3d1fdc4 100644 --- a/lib/client.ml +++ b/lib/client.ml @@ -26,7 +26,7 @@ let get_handle = this module NbdRpc = struct - type transport = cleartext_channel + type transport = generic_channel type id = int64 type request_hdr = Request.t type request_body = Cstruct.t option @@ -35,7 +35,7 @@ module NbdRpc = struct let recv_hdr sock = let buf = Cstruct.create 16 in - sock.read_clear buf + sock.read buf >>= fun () -> match Reply.unmarshal buf with | `Ok x -> Lwt.return (Some x.Reply.handle, x) @@ -48,7 +48,7 @@ module NbdRpc = struct begin match req_hdr.Request.ty with | Command.Read -> (* TODO: use a page-aligned memory allocator *) - Lwt_list.iter_s sock.read_clear response_body + Lwt_list.iter_s sock.read response_body >>= fun () -> Lwt.return (`Ok ()) | _ -> Lwt.return (`Ok ()) @@ -57,12 +57,12 @@ module NbdRpc = struct let send_one sock req_hdr req_body = let buf = Cstruct.create Request.sizeof in Request.marshal buf req_hdr; - sock.write_clear buf + sock.write buf >>= fun () -> match req_body with | None -> Lwt.return () | Some data -> - sock.write_clear data + sock.write data let id_of_request req = req.Request.handle @@ -87,6 +87,7 @@ type t = { type id = unit let make channel size_bytes flags = + let channel = generic_of_cleartext_channel channel in Rpc.create channel >>= fun client -> let read_write = not (List.mem PerExportFlag.Read_only flags) in