Skip to content

Commit

Permalink
Merge deaa220 into 5887022
Browse files Browse the repository at this point in the history
  • Loading branch information
gaborigloi committed Oct 6, 2017
2 parents 5887022 + deaa220 commit 6fdcd47
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 33 deletions.
2 changes: 1 addition & 1 deletion lib/s.ml
Expand Up @@ -62,7 +62,7 @@ module type SERVER = sig
application. If the name is invalid, the only option is to close the connection.
If the name is valid then use the [serve] function. *)

val serve : t -> (module V1_LWT.BLOCK with type t = 'b) -> 'b -> unit Lwt.t
val serve : t -> ?read_only:bool -> (module V1_LWT.BLOCK with type t = 'b) -> 'b -> unit Lwt.t
(** [serve t block b] runs forever processing requests from [t], using [block]
device type [b]. *)

Expand Down
74 changes: 43 additions & 31 deletions lib/server.ml
Expand Up @@ -184,19 +184,23 @@ let error t handle code =
t.channel.write t.reply
)

let serve t (type t) block (b:t) =
let serve t (type t) ?(read_only=false) block (b:t) =
let section = Lwt_log_core.Section.make("Server.serve") in
let module Block = (val block: V1_LWT.BLOCK with type t = t) in

Lwt_log_core.notice ~section "Serving new client" >>= fun () ->
Lwt_log_core.notice_f ~section "Serving new client, read_only = %b" read_only >>= fun () ->

Block.get_info b
>>= fun info ->
let size = Int64.(mul info.Block.size_sectors (of_int info.Block.sector_size)) in
(if info.Block.read_write then Lwt.return [] else
Lwt_log_core.warning ~section "Block is read-only, sending NBD_FLAG_READ_ONLY transmission flag" >>= fun () ->
Lwt.return [ PerExportFlag.Read_only ])
>>= fun flags ->
(match read_only, info.Block.read_write with
| true, _ -> Lwt.return true
| false, true -> Lwt.return false
| false, false ->
Lwt_log_core.error ~section "Read-write access was requested, but block is read-only, sending NBD_FLAG_READ_ONLY transmission flag" >>= fun () ->
Lwt.return true)
>>= fun read_only ->
let flags = if read_only then [ PerExportFlag.Read_only ] else [] in
negotiate_end t size flags
>>= fun t ->

Expand All @@ -208,27 +212,41 @@ let serve t (type t) block (b:t) =
let open Request in
match request with
| { ty = Command.Write; from; len; handle } ->
if Int64.(rem from (of_int info.Block.sector_size)) <> 0L || Int64.(rem (of_int32 len) (of_int info.Block.sector_size) <> 0L)
then error t handle `EINVAL
else begin
let rec copy offset remaining =
let n = min block_size remaining in
let subblock = Cstruct.sub block 0 n in
t.channel.Channel.read subblock
>>= fun () ->
Block.write b Int64.(div offset (of_int info.Block.sector_size)) [ subblock ]
>>= function
| `Error e ->
Lwt_log_core.debug_f ~section "Error while writing: %s; returning EIO error" (Block_error_printer.to_string e) >>= fun () ->
error t handle `EIO
| `Ok () ->
let remaining = remaining - n in
if remaining > 0
then copy Int64.(add offset (of_int n)) remaining
else ok t handle None >>= fun () -> loop () in
copy from (Int32.to_int request.Request.len)
begin
if read_only
then error t handle `EPERM
else if Int64.(rem from (of_int info.Block.sector_size)) <> 0L || Int64.(rem (of_int32 len) (of_int info.Block.sector_size) <> 0L)
then error t handle `EINVAL
else begin
let rec copy offset remaining =
let n = min block_size remaining in
let subblock = Cstruct.sub block 0 n in
t.channel.Channel.read subblock
>>= fun () ->
Block.write b Int64.(div offset (of_int info.Block.sector_size)) [ subblock ]
>>= function
| `Error e ->
Lwt_log_core.debug_f ~section "Error while writing: %s; returning EIO error" (Block_error_printer.to_string e) >>= fun () ->
error t handle `EIO
| `Ok () ->
let remaining = remaining - n in
if remaining > 0
then copy Int64.(add offset (of_int n)) remaining
else ok t handle None in
copy from (Int32.to_int request.Request.len)
end
end
>>= loop
| { ty = Command.Read; from; len; handle } ->
(* It is okay to disconnect here in case of errors. The NBD protocol
documentation says about NBD_CMD_READ:
"If an error occurs, the server SHOULD set the appropriate error code
in the error field. The server MAY then initiate a hard disconnect.
If it chooses not to, it MUST NOT send any payload for this request.
If an error occurs while reading after the server has already sent out
the reply header with an error field set to zero (i.e., signalling no
error), the server MUST immediately initiate a hard disconnect; it
MUST NOT send any further data to the client." *)
if Int64.(rem from (of_int info.Block.sector_size)) <> 0L || Int64.(rem (of_int32 len) (of_int info.Block.sector_size) <> 0L)
then error t handle `EINVAL
else begin
Expand All @@ -240,12 +258,6 @@ let serve t (type t) block (b:t) =
Block.read b Int64.(div offset (of_int info.Block.sector_size)) [ subblock ]
>>= function
| `Error e ->
(* The NBD protocol documentation says about NBD_CMD_READ:
"If an error occurs while reading after the server has already
sent out the reply header with an error field set to zero (i.e.,
signalling no error), the server MUST immediately initiate a
hard disconnect; it MUST NOT send any further data to the
client." *)
Lwt.fail_with (Printf.sprintf "Partial failure during a Block.read: %s; terminating the session" (Block_error_printer.to_string e))
| `Ok () ->
t.channel.write subblock
Expand Down
101 changes: 100 additions & 1 deletion lib_test/protocol_test.ml
Expand Up @@ -34,6 +34,7 @@ module TransmissionList = OUnitDiff.ListSimpleMake(TransmissionDiff)

let option_reply_magic_number = "\x00\x03\xe8\x89\x04\x55\x65\xa9"
let nbd_request_magic = "\x25\x60\x95\x13"
let nbd_reply_magic = "\x67\x44\x66\x98"

exception Failed_to_read_empty_stream

Expand Down Expand Up @@ -100,7 +101,7 @@ module V2_negotiation = struct

let v2_negotiation = v2_negotiation_start @ [
`Server, "\000\000\000\000\001\000\000\000"; (* size *)
`Server, "\000\000"; (* transmission flags *)
`Server, "\000\001"; (* transmission flags: NBD_FLAG_HAS_FLAGS (bit 0) *)
`Server, (String.make 124 '\000');
]

Expand Down Expand Up @@ -222,11 +223,109 @@ module V2_list_export_success = struct
)
end

module Cstruct_block : (V1_LWT.BLOCK with type t = Cstruct.t) = struct
type page_aligned_buffer = Cstruct.t
type error =
[ `Disconnected | `Is_read_only | `Unimplemented | `Unknown of string ]
type 'a io = 'a Lwt.t
type t = Cstruct.t
type id = Id
type info = { read_write : bool; sector_size : int; size_sectors : int64; }

let disconnect _ = Lwt.return_unit
let get_info contents = Lwt.return { read_write = true; sector_size = 1; size_sectors = (Cstruct.len contents |> Int64.of_int) }
let read contents sector_start buffers =
let sector_start = Int64.to_int sector_start in
List.fold_left
(fun contents buffer -> Cstruct.fillv [contents] buffer |> ignore; Cstruct.shift contents (Cstruct.len buffer))
(Cstruct.shift contents sector_start)
buffers
|> ignore; Lwt.return (`Ok ())
let write contents sector_start buffers =
let sector_start = Int64.to_int sector_start in
Cstruct.fillv buffers (Cstruct.shift contents sector_start)
|> ignore; Lwt.return (`Ok ())
end

module V2_read_only_test = struct

let test_block = (Cstruct.of_string "asdf")

let sequence = [
`Server, "NBDMAGIC";
`Server, "IHAVEOPT";
`Server, "\000\001"; (* handshake flags: NBD_FLAG_FIXED_NEWSTYLE *)
`Client, "\000\000\000\001"; (* client flags: NBD_FLAG_C_FIXED_NEWSTYLE *)

`Client, "IHAVEOPT";
`Client, "\000\000\000\001"; (* NBD_OPT_EXPORT_NAME *)
`Client, "\000\000\000\007"; (* length of export name *)
`Client, "export1";

`Server, "\000\000\000\000\000\000\000\004"; (* size: 4 bytes *)
`Server, "\000\003"; (* transmission flags: NBD_FLAG_READ_ONLY (bit 1) + NBD_FLAG_HAS_FLAGS (bit 0) *)
`Server, (String.make 124 '\000');
(* Now we've entered transmission mode *)

`Client, nbd_request_magic;
`Client, "\000\000"; (* command flags *)
`Client, "\000\000"; (* request type: NBD_CMD_READ *)
`Client, "\000\000\000\000\000\000\000\000"; (* handle: 4 bytes *)
`Client, "\000\000\000\000\000\000\000\001"; (* offset *)
`Client, "\000\000\000\002"; (* length *)

(* We're allowed to read from a read-only export *)
`Server, nbd_reply_magic;
`Server, "\000\000\000\000"; (* error: no error *)
`Server, "\000\000\000\000\000\000\000\000"; (* handle *)
`Server, "sd"; (* 2 bytes of data *)

`Client, nbd_request_magic;
`Client, "\000\000"; (* command flags *)
`Client, "\000\001"; (* request type: NBD_CMD_WRITE *)
`Client, "\000\000\000\000\000\000\000\001"; (* handle: 4 bytes *)
`Client, "\000\000\000\000\000\000\000\000"; (* offset *)
`Client, "\000\000\000\004"; (* length *)
(* The server should probably return the EPERM error immediately, and not
read any data associated with the write request, as the client should
recognize the error before transmitting the data, just like for EINVAL,
which is sent for unaligned requests. *)
(*`Client, "nope"; (* 4 bytes of data *)*)

(* However, we're not allowed to write to it *)
`Server, nbd_reply_magic;
`Server, "\000\000\000\001"; (* error: EPERM *)
`Server, "\000\000\000\000\000\000\000\001"; (* handle *)

`Client, nbd_request_magic;
`Client, "\000\000"; (* command flags *)
`Client, "\000\002"; (* request type: NBD_CMD_DISC *)
`Client, "\000\000\000\000\000\000\000\002"; (* handle: 4 bytes *)
`Client, "\000\000\000\000\000\000\000\000"; (* offset *)
`Client, "\000\000\000\000"; (* length *)
]

let server_test =
"Serve a read-only export and test that reads and writes are handled correctly."
>:: fun () ->
with_server_channel sequence (fun channel ->
let t =
Server.connect channel ()
>>= fun (export_name, svr) ->
OUnit.assert_equal ~msg:"The server did not receive the correct export name" "export1" export_name;
Server.serve svr ~read_only:true (module Cstruct_block) test_block
in
Lwt_main.run t
)

end

let tests =
"Nbd client tests" >:::
[ V2_negotiation.client_negotiation
; V2_negotiation.server_negotiation
; V2_list_export_disabled.client_list_disabled
; V2_list_export_disabled.server_list_disabled
; V2_list_export_success.client_list_success
; V2_read_only_test.server_test
]

0 comments on commit 6fdcd47

Please sign in to comment.