Skip to content

Commit

Permalink
Merge pull request #121 from gaborigloi/test
Browse files Browse the repository at this point in the history
Add server write and client<-> server interaction tests
  • Loading branch information
gaborigloi committed Mar 9, 2018
2 parents 247e7c2 + ee7169c commit 189fc96
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 24 deletions.
117 changes: 117 additions & 0 deletions lib_test/client_server_test.ml
@@ -0,0 +1,117 @@
(** This module tests the NBD server and the NBD client provided by the core
NBD library, by connecting their inputs and outputs together so that they
communicate with each other through an in-memory pipe.
The server and client IO is run in concurrently by Lwt, in the same
process. *)

open Lwt.Infix

let with_channels f =
let section = Lwt_log_core.Section.make("with_channels") in
let make_channel name (ic, oc) =
let read c =
let len = Cstruct.len c in
Lwt_log.debug_f ~section "%s read: %d" name len >>= fun () ->
let b = Bytes.create len in
Lwt_io.read_into_exactly ic b 0 len >>= fun () ->
Cstruct.blit_from_bytes b 0 c 0 len;
Lwt_log.debug_f ~section "%s read: %d: %s finished" name len (String.escaped (Cstruct.to_string c))
in
let write c =
let len = Cstruct.len c in
Lwt_log.debug_f ~section "%s write: %d: %s" name len (String.escaped (Cstruct.to_string c)) >>= fun () ->
Lwt_io.write_from_string_exactly oc (Cstruct.to_string c) 0 len >>= fun () ->
Lwt_log.debug_f ~section "%s write: %d: %s finished" name len (String.escaped (Cstruct.to_string c))
in
(write, read)
in
let client_to_server = Lwt_io.pipe () in
let server_to_client = Lwt_io.pipe () in
let client_write, server_read = make_channel "client -> server" client_to_server in
let server_write, client_read = make_channel "server -> client" server_to_client in
let noop () = Lwt.return_unit in
let client_channel =
Nbd.Channel.{ read=client_read; write=client_write; close=noop; is_tls=false }
in
let server_channel =
Nbd.Channel.{ read_clear=server_read; write_clear=server_write; close_clear=noop; make_tls_channel=None }
in
Lwt_unix.with_timeout 0.5 (fun () -> f client_channel server_channel)

(** Run the given server and client test sequences concurrently with channels
connecting the server and the client together. *)
let test ~server ~client () =
Lwt_log.add_rule "*" Lwt_log.Debug;
let t =
with_channels (fun client_channel server_channel ->
let test_server = server server_channel in
let cancel, _ = Lwt.task () in
let test_server =
Lwt.catch
(fun () -> Lwt.pick [test_server; cancel])
(function Lwt.Canceled -> Lwt.return_unit | e -> Lwt.fail e)
in
let test_client =
client client_channel
(* TODO: because Client.disconnect does not send NBD_CMD_DISC,
the server loop will not stop - we have to stop it manually.
Once this is fixed, this cancel mechanism should be removed. *)
>|= fun () -> Lwt.cancel cancel
in
Lwt.join [test_server; test_client]
)
in
Lwt_main.run t

(** We fail the test if an error occurs *)
let check msg =
function
| Result.Ok a -> Lwt.return a
| Result.Error _ -> Lwt.fail_with msg

let test_connect_disconnect =
let test_block = (Cstruct.of_string "asdf") in
test
~server:(fun server_channel ->
Nbd.Server.connect server_channel () >>= fun (export_name, svr) ->
Alcotest.(check string) "export name received by server"
"export1" export_name;
Nbd.Server.serve svr ~read_only:false (module Cstruct_block.Block) test_block
)
~client:(fun client_channel ->
Nbd.Client.negotiate client_channel "export1" >>= fun (t, size, flags) ->
Alcotest.(check int64) "size received by client"
(Int64.of_int (Cstruct.len test_block))
size;
Nbd.Client.disconnect t
)

let test_read_write =
let test_block = (Cstruct.of_string "asdf") in
test
~server:(fun server_channel ->
Nbd.Server.connect server_channel () >>= fun (export_name, svr) ->
Nbd.Server.serve svr ~read_only:false (module Cstruct_block.Block) test_block
)
~client:(fun client_channel ->
Nbd.Client.negotiate client_channel "export1" >>= fun (t, size, flags) ->

let buf = Cstruct.create 2 in
Nbd.Client.read t 1L [buf] >>= check "1st read failed" >>= fun () ->
Alcotest.(check string) "2 bytes at offset 1" "sd" (Cstruct.to_string buf);

let buf = Cstruct.of_string "12" in
Nbd.Client.write t 2L [buf] >>= check "Write failed" >>= fun () ->

let buf = Cstruct.create 2 in
Nbd.Client.read t 2L [buf] >>= check "2nd read failed" >>= fun () ->
Alcotest.(check string) "2 modified bytes at offset 2" "12" (Cstruct.to_string buf);

Nbd.Client.disconnect t
)

let tests =
"Nbd client-server connection tests",
[ "test_connect_disconnect", `Quick, test_connect_disconnect
; "test_read_write", `Quick, test_read_write
]
24 changes: 24 additions & 0 deletions lib_test/cstruct_block.ml
@@ -0,0 +1,24 @@

(** A Mirage block module backed by a Cstruct for unit testing the NBD server *)
module Block : (Mirage_block_lwt.S with type t = Cstruct.t) = struct
type page_aligned_buffer = Cstruct.t
type error = Mirage_block.error
let pp_error = Mirage_block.pp_error
type write_error = Mirage_block.write_error
let pp_write_error = Mirage_block.pp_write_error
type 'a io = 'a Lwt.t
type t = Cstruct.t
let disconnect _ = Lwt.return_unit
let get_info contents = Lwt.return Mirage_block.{ read_write = true; sector_size = 1; size_sectors = (Cstruct.len contents |> Int64.of_int) }
let read contents sector_start buffers =
let sector_start = Int64.to_int sector_start in
List.fold_left
(fun contents buffer -> Cstruct.fillv [contents] buffer |> ignore; Cstruct.shift contents (Cstruct.len buffer))
(Cstruct.shift contents sector_start)
buffers
|> ignore; Lwt.return_ok ()
let write contents sector_start buffers =
let sector_start = Int64.to_int sector_start in
Cstruct.fillv buffers (Cstruct.shift contents sector_start)
|> ignore; Lwt.return_ok ()
end
106 changes: 82 additions & 24 deletions lib_test/protocol_test.ml
Expand Up @@ -12,9 +12,14 @@
* GNU Lesser General Public License for more details.
*)

(** This module tests the core NBD library by verifying that the communication
between the client and the server exactly matches the specified test
sequences. *)

open Nbd
open Lwt.Infix

(** An Alcotest TESTABLE for the data transmissions in the test sequences *)
let transmission =
let fmt =
Fmt.of_to_string
Expand All @@ -30,8 +35,18 @@ let option_reply_magic_number = "\x00\x03\xe8\x89\x04\x55\x65\xa9"
let nbd_request_magic = "\x25\x60\x95\x13"
let nbd_reply_magic = "\x67\x44\x66\x98"

(** The client or server wanted to read and there is no more data from the
other side. *)
exception Failed_to_read_empty_stream

(** [make_channel role test_sequence] creates a channel for use by the NBD library
from a test sequence containing the expected communication between the
client and the server. Reads and writes will verify that the communication
matches exactly what is in [test_sequence], which is a list of data
transmission tuples, each specifying whether the client or the server is
sending the data, and the actual data sent. [role] specifies whether the
client or the server will use the created channel, the other side will be
simulated by taking the responses from [test_sequence]. *)
let make_channel role test_sequence =
let next = ref test_sequence in
let rec read buf =
Expand Down Expand Up @@ -68,12 +83,18 @@ let make_channel role test_sequence =
let assert_processed_complete_sequence () = Alcotest.(check (list transmission)) "processed complete sequence" [] !next in
(assert_processed_complete_sequence, (read, write, close))

(** Passes a channel for use by the NBD client to the given function, verifying
that all communcation matches the given test sequence and that the complete
sequence has been processed after the function returns. *)
let with_client_channel s f =
fun () ->
let (assert_processed_complete_sequence, (read, write, close)) = make_channel `Client s in
f Channel.{read; write; close; is_tls=false};
assert_processed_complete_sequence ()

(** Passes a channel for use by the NBD server to the given function, verifying
that all communcation matches the given test sequence and that the complete
sequence has been processed after the function returns. *)
let with_server_channel s f =
fun () ->
let (assert_processed_complete_sequence, (read, write, close)) = make_channel `Server s in
Expand Down Expand Up @@ -218,29 +239,6 @@ module V2_list_export_success = struct
)
end

module Cstruct_block : (Mirage_block_lwt.S with type t = Cstruct.t) = struct
type page_aligned_buffer = Cstruct.t
type error = Mirage_block.error
let pp_error = Mirage_block.pp_error
type write_error = Mirage_block.write_error
let pp_write_error = Mirage_block.pp_write_error
type 'a io = 'a Lwt.t
type t = Cstruct.t
let disconnect _ = Lwt.return_unit
let get_info contents = Lwt.return Mirage_block.{ read_write = true; sector_size = 1; size_sectors = (Cstruct.len contents |> Int64.of_int) }
let read contents sector_start buffers =
let sector_start = Int64.to_int sector_start in
List.fold_left
(fun contents buffer -> Cstruct.fillv [contents] buffer |> ignore; Cstruct.shift contents (Cstruct.len buffer))
(Cstruct.shift contents sector_start)
buffers
|> ignore; Lwt.return_ok ()
let write contents sector_start buffers =
let sector_start = Int64.to_int sector_start in
Cstruct.fillv buffers (Cstruct.shift contents sector_start)
|> ignore; Lwt.return_ok ()
end

module V2_read_only_test = struct

let test_block = (Cstruct.of_string "asdf")
Expand Down Expand Up @@ -313,13 +311,72 @@ module V2_read_only_test = struct
Server.connect channel ()
>>= fun (export_name, svr) ->
Alcotest.(check string) "The server did not receive the correct export name" "export1" export_name;
Server.serve svr ~read_only:true (module Cstruct_block) test_block
Server.serve svr ~read_only:true (module Cstruct_block.Block) test_block
in
Lwt_main.run t
)

end

module V2_write_test = struct

let test_block = (Cstruct.of_string "asdf")

let sequence = [
`Server, "NBDMAGIC";
`Server, "IHAVEOPT";
`Server, "\000\001"; (* handshake flags: NBD_FLAG_FIXED_NEWSTYLE *)
`Client, "\000\000\000\001"; (* client flags: NBD_FLAG_C_FIXED_NEWSTYLE *)

`Client, "IHAVEOPT";
`Client, "\000\000\000\001"; (* NBD_OPT_EXPORT_NAME *)
`Client, "\000\000\000\007"; (* length of export name *)
`Client, "export1";

`Server, "\000\000\000\000\000\000\000\004"; (* size: 4 bytes *)
`Server, "\000\001"; (* transmission flags: NBD_FLAG_HAS_FLAGS (bit 0) *)
`Server, (String.make 124 '\000');
(* Now we've entered transmission mode *)

`Client, nbd_request_magic;
`Client, "\000\000"; (* command flags *)
`Client, "\000\001"; (* request type: NBD_CMD_WRITE *)
`Client, "\000\000\000\000\000\000\000\001"; (* handle: 4 bytes *)
`Client, "\000\000\000\000\000\000\000\002"; (* offset *)
`Client, "\000\000\000\002"; (* length *)
`Client, "12"; (* 2 bytes of data *)

(* We're allowed to read from a read-only export *)
`Server, nbd_reply_magic;
`Server, "\000\000\000\000"; (* error: no error *)
`Server, "\000\000\000\000\000\000\000\001"; (* handle *)

`Client, nbd_request_magic;
`Client, "\000\000"; (* command flags *)
`Client, "\000\002"; (* request type: NBD_CMD_DISC *)
`Client, "\000\000\000\000\000\000\000\002"; (* handle: 4 bytes *)
`Client, "\000\000\000\000\000\000\000\000"; (* offset *)
`Client, "\000\000\000\000"; (* length *)
]

let server_test =
"Serve a read-write export and test that writes are handled correctly.",
`Quick,
with_server_channel sequence (fun channel ->
let t =
Server.connect channel ()
>>= fun (export_name, svr) ->
Alcotest.(check string) "The server did not receive the correct export name" "export1" export_name;
Server.serve svr ~read_only:false (module Cstruct_block.Block) test_block
in
Lwt_main.run t;
Alcotest.(check string) "Data written by server"
"as12"
(Cstruct.to_string test_block)
)

end

let tests =
"Nbd client tests",
[ V2_negotiation.client_negotiation
Expand All @@ -328,4 +385,5 @@ let tests =
; V2_list_export_disabled.server_list_disabled
; V2_list_export_success.client_list_success
; V2_read_only_test.server_test
; V2_write_test.server_test
]
1 change: 1 addition & 0 deletions lib_test/suite.ml
Expand Up @@ -4,4 +4,5 @@ let () =
"Nbd library test suite"
[ Mux_test.tests
; Protocol_test.tests
; Client_server_test.tests
]

0 comments on commit 189fc96

Please sign in to comment.