diff --git a/.github/CODE_OF_CONDUCT.md b/.github/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..d02f050 --- /dev/null +++ b/.github/CODE_OF_CONDUCT.md @@ -0,0 +1,75 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, gender identity and expression, level of +experience, +education, socio-economic status, nationality, personal appearance, race, +religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +- Using welcoming and inclusive language +- Being respectful of differing viewpoints and experiences +- Gracefully accepting constructive criticism +- Focusing on what is best for the community +- Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +- The use of sexualized language or imagery and unwelcome sexual attention or +advances +- Trolling, insulting/derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or electronic +address, without explicit permission +- Other conduct which could reasonably be considered inappropriate in a +professional setting + + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at yoshuawuyts@gmail.com, or through +IRC. All complaints will be reviewed and investigated and will result in a +response that is deemed necessary and appropriate to the circumstances. The +project team is obligated to maintain confidentiality with regard to the +reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the Contributor Covenant, version 1.4, +available at +https://www.contributor-covenant.org/version/1/4/code-of-conduct.html diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 0000000..96806b5 --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,63 @@ +# Contributing +Contributions include code, documentation, answering user questions, running the +project's infrastructure, and advocating for all types of users. + +The project welcomes all contributions from anyone willing to work in good faith +with other contributors and the community. No contribution is too small and all +contributions are valued. + +This guide explains the process for contributing to the project's GitHub +Repository. + +- [Code of Conduct](#code-of-conduct) +- [Bad Actors](#bad-actors) +- [Developer Certificate of Origin](#developer-certificate-of-origin) + +## Code of Conduct +The project has a [Code of Conduct][./CODE_OF_CONDUCT.md] that *all* +contributors are expected to follow. This code describes the *minimum* behavior +expectations for all contributors. + +As a contributor, how you choose to act and interact towards your +fellow contributors, as well as to the community, will reflect back not only +on yourself but on the project as a whole. The Code of Conduct is designed and +intended, above all else, to help establish a culture within the project that +allows anyone and everyone who wants to contribute to feel safe doing so. + +Should any individual act in any way that is considered in violation of the +[Code of Conduct][./CODE_OF_CONDUCT.md], corrective actions will be taken. It is +possible, however, for any individual to *act* in such a manner that is not in +violation of the strict letter of the Code of Conduct guidelines while still +going completely against the spirit of what that Code is intended to accomplish. + +Open, diverse, and inclusive communities live and die on the basis of trust. +Contributors can disagree with one another so long as they trust that those +disagreements are in good faith and everyone is working towards a common +goal. + +## Bad Actors +All contributors to tacitly agree to abide by both the letter and +spirit of the [Code of Conduct][./CODE_OF_CONDUCT.md]. Failure, or +unwillingness, to do so will result in contributions being respectfully +declined. + +A *bad actor* is someone who repeatedly violates the *spirit* of the Code of +Conduct through consistent failure to self-regulate the way in which they +interact with other contributors in the project. In doing so, bad actors +alienate other contributors, discourage collaboration, and generally reflect +poorly on the project as a whole. + +Being a bad actor may be intentional or unintentional. Typically, unintentional +bad behavior can be easily corrected by being quick to apologize and correct +course *even if you are not entirely convinced you need to*. Giving other +contributors the benefit of the doubt and having a sincere willingness to admit +that you *might* be wrong is critical for any successful open collaboration. + +Don't be a bad actor. + +## Developer Certificate of Origin +All contributors must read and agree to the [Developer Certificate of +Origin (DCO)](../CERTIFICATE). + +The DCO allows us to accept contributions from people to the project, similarly +to how a license allows us to distribute our code. diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000..ba20edd --- /dev/null +++ b/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,9 @@ +## Summary +Explain what is going on. + +## Your Environment +| Software | Version(s) | +| ------------------ | ---------- | +| hypercore-protocol | +| Rustc | +| Operating System | diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..47ff452 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,23 @@ +--- +name: Bug Report +about: Did something not work as expected? +--- + +# Bug Report +## Your Environment +| Software | Version(s) | +| ------------------ | ---------- | +| hypercore-protocol | +| Rustc | +| Operating System | + +## Expected Behavior +Tell us what should have happened. + +## Current Behavior +Tell us what happens instead of the expected behavior. If you are seeing an +error, please include the full error message and stack trace. + +## Code Sample +Please provide a code repository, gist, code snippet or sample files to +reproduce the issue. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..2f86a30 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,43 @@ +--- +name: Feature Request +about: Want us to add something to hypercore-protocol? +--- + +# Feature Request +## Summary +One paragraph explanation of the feature. + +## Motivation +Why are we doing this? What use cases does it support? What is the expected +outcome? + +## Guide-level explanation +Explain the proposal as if it was already included in the project and you +were teaching it to another programmer. That generally means: + +- Introducing new named concepts. +- Explaining the feature largely in terms of examples. +- If applicable, provide sample error messages, deprecation warnings, or + migration guidance. + +## Reference-level explanation +This is the technical portion of the feature request. Explain the design in +sufficient detail that: + +- Its interaction with other features is clear. +- It is reasonably clear how the feature would be implemented. +- Corner cases are dissected by example. + +## Drawbacks +Why should we _not_ do this? + +## Rationale and alternatives +- Why is this design the best in the space of possible designs? +- What other designs have been considered and what is the rationale for not + choosing them? +- What is the impact of not doing this? + +## Unresolved Questions +What related issues do you consider out of scope for this feature that could be +addressed in the future independently of the solution that comes out of this +feature? diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md new file mode 100644 index 0000000..be188e6 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/question.md @@ -0,0 +1,18 @@ +--- +name: Question +about: Have any questions regarding how hypercore-protocol works? +--- + +# Question +## Your Environment +| Software | Version(s) | +| ------------------ | ---------- | +| hypercore-protocol | +| Rustc | +| Operating System | + +## Question +Provide your question here. + +## Context +How has this issue affected you? What are you trying to accomplish? diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..d820588 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,21 @@ + + +**Choose one:** is this a 🐛 bug fix, a 🙋 feature, or a 🔦 documentation change? + + + +## Checklist + +- [ ] tests pass +- [ ] tests and/or benchmarks are included +- [ ] documentation is changed or added + +## Context + + +## Semver Changes + diff --git a/.github/stale.yml b/.github/stale.yml new file mode 100644 index 0000000..b8550d0 --- /dev/null +++ b/.github/stale.yml @@ -0,0 +1,17 @@ +# Configuration for probot-stale - https://github.com/probot/stale + +daysUntilStale: 90 +daysUntilClose: 7 +exemptLabels: + - pinned + - security +exemptProjects: false +exemptMilestones: false +staleLabel: wontfix +markComment: > + This issue has been automatically marked as stale because it has not had + recent activity. It will be closed if no further activity occurs. Thank you + for your contributions. +unmarkComment: false +closeComment: false +limitPerRun: 30 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..07913c6 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,72 @@ +name: 'CI' +on: + pull_request: + push: + branches: + - master + +env: + RUST_BACKTRACE: 1 + CARGO_TERM_COLOR: always + +jobs: + ci-pass: + name: CI is green + runs-on: ubuntu-latest + needs: + - test + - build-extra + - lint + steps: + - run: exit 0 + + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + - name: Run tests + run: | + cargo check --no-default-features --features tokio + cargo check --no-default-features --features async-std + cargo test --no-default-features --features js_interop_tests,tokio + cargo test --no-default-features --features js_interop_tests,async-std + cargo test --benches + + build-extra: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@stable + with: + targets: wasm32-unknown-unknown + - name: Build WASM + run: | + cargo build --target=wasm32-unknown-unknown --no-default-features --features wasm-bindgen,tokio + cargo build --target=wasm32-unknown-unknown --no-default-features --features wasm-bindgen,async-std + - name: Build release + run: | + cargo build --release --no-default-features --features tokio + cargo build --release --no-default-features --features async-std + - name: Build examples + run: | + cargo build --example replication + + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt, clippy + - uses: actions-rs/clippy-check@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} + - name: Format check + run: | + cargo fmt -- --check diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml deleted file mode 100644 index b3c3ef3..0000000 --- a/.github/workflows/rust.yml +++ /dev/null @@ -1,24 +0,0 @@ -name: Rust - -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] - -env: - CARGO_TERM_COLOR: always - -jobs: - build: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v2 - - name: Build - run: cargo build --verbose - - name: Run tests - run: cargo test --verbose - - name: Run clippy - run: cargo clippy diff --git a/.gitignore b/.gitignore index 3b2d033..e9c85c7 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,6 @@ Cargo.lock SANDBOX flamegraph.svg +tests/js/work +package-lock.json +node_modules diff --git a/Cargo.toml b/Cargo.toml index f0e4eab..01f4c50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,16 @@ [package] name = "hypercore-protocol" -version = "0.3.1" +version = "0.4.0-alpha.7" license = "MIT OR Apache-2.0" description = "Replication protocol for Hypercore feeds" -authors = ["Franz Heinzmann (Frando) "] +authors = [ + "Franz Heinzmann (Frando) ", + "Timo Tiuraniemi " +] documentation = "https://docs.rs/hypercore-protocol" -repository = "https://github.com/Frando/hypercore-protocol-rs" +repository = "https://github.com/datrs/hypercore-protocol-rs" readme = "README.md" -edition = "2018" +edition = "2021" keywords = ["dat", "p2p", "replication", "hypercore", "protocol"] categories = [ "asynchronous", @@ -22,47 +25,54 @@ categories = [ bench = false [dependencies] -async-channel = "1.5" -snow = { version = "0.7.0-alpha5", features = ["risky-raw-split"] } -prost = "0.7" -bytes = "1.0.1" -varinteger = "1.0" -rand = "0.7" -blake2-rfc = "0.2" +async-channel = "1" +snow = { version = "0.9", features = ["risky-raw-split"] } +bytes = "1" +rand = "0.8" +blake2 = "0.10" hex = "0.4" async-trait = "0.1" -salsa20 = "0.6" -log = "0.4" +tracing = "0.1" pretty-hash = "0.4" -futures-timer = "3.0" -instant = "0.1" -getrandom = "0.1" -futures-lite = "1.11.3" +futures-timer = "3" +futures-lite = "1" +hypercore = { version = "0.12", default-features = false } +sha2 = "0.10" +curve25519-dalek = "4" +crypto_secretstream = "0.2" [dev-dependencies] -async-std = { version = "1.9.0", features = ["attributes", "unstable"] } +async-std = { version = "1.12.0", features = ["attributes", "unstable"] } +async-compat = "0.2.1" +tokio = { version = "1.27.0", features = ["macros", "net", "process", "rt", "rt-multi-thread", "sync", "time"] } env_logger = "0.7.1" -# hypercore from master branch as of 2021-03-03 -hypercore = { git = "https://github.com/datrs/hypercore", rev = "8d8cbef8a884a70e8d12d80968c1d97be2ceea0b" } -random-access-disk = "2.0.0" -random-access-memory = "2.0.0" -random-access-storage = "4.0.0" +random-access-storage = "5.0.0" +random-access-disk = { version = "3.0.0", default-features = false } +random-access-memory = "3.0.0" anyhow = "1.0.28" -criterion = "0.3.2" +instant = "0.1" +criterion = { version = "0.4", features = ["async_std"] } pretty-bytes = "0.2.2" duplexify = "1.1.0" sluice = "0.5.4" futures = "0.3.13" - -[build-dependencies] -prost-build = "0.6.1" +log = "0.4" +test-log = { version = "0.2.11", default-features = false, features = ["trace"] } +tracing-subscriber = { version = "0.3.16", features = ["env-filter", "fmt"] } [features] +default = ["async-std", "sparse"] wasm-bindgen = [ - "instant/wasm-bindgen", - "getrandom/wasm-bindgen", "futures-timer/wasm-bindgen" ] +sparse = ["hypercore/sparse"] +cache = ["hypercore/cache"] +tokio = ["hypercore/tokio", "random-access-disk/tokio"] +async-std = ["hypercore/async-std", "random-access-disk/async-std"] +# Used only in interoperability tests under tests/js-interop which use the javascript version of hypercore +# to verify that this crate works. To run them, use: +# cargo test --features js_interop_tests +js_interop_tests = [] [profile.bench] # debug = true diff --git a/README.md b/README.md index 2966743..b8ed180 100644 --- a/README.md +++ b/README.md @@ -1,105 +1,108 @@ -

hypercore-protocol

-
- - Rust implementation of the Hypercore wire protocol - -
- -
- -
- - - Crates.io version - - - - Download - - - - docs.rs docs - -
- -
-

- - API Docs - - | - - Contributing - -

-
- -**NOTE**: The master branch currently only works with the old hypercore version 9. -For ongoing work to support the latest version 10 of hypercore [see the v10 branch](https://github.com/datrs/hypercore-protocol-rs/tree/v10). - -This crate provides a low-level streaming API to hypercore-protocol and exposes an interface that should make it easy to implement actual protocol logic on top. This crate targets Hypercore 9 (Dat 2) only. - -It uses [async-std](https://async.rs) for async IO, and [snow](https://github.com/mcginty/snow) for the Noise handshake. - -Current features are: - -* Complete the Noise handshake and set up the transport encryption -* Open channels with a key -* Accept channels opened by the remote end if your end knows the key -* Create and verify capability hashes -* Send and receive all protocol messages -* Register and use protocol extensions - -*We're actively looking for contributors to the datrust development! If you're interested, say hi in the `#rust` channel on the [Hypercore Protocol Discord](https://chat.hypercore-protocol.org/) :-)* +# Hypercore Protocol +[![crates.io version][1]][2] [![build status][3]][4] +[![downloads][5]][6] [![docs.rs docs][7]][8] + +Hypercore protocol is a streaming, message based protocol. This is a Rust port of +the wire protocol implementation in +[the original Javascript version](https://github.com/holepunchto/hypercore). This +crate targets the Hypercore LTS version 10. + +This crate provides a low-level streaming API to hypercore-protocol and exposes an +interface that should make it easy to implement actual protocol logic on top. + +This crate uses either [async-std](https://async.rs) or [tokio](https://tokio.rs/) +for async IO, [snow](https://github.com/mcginty/snow) for the Noise handshake and +[RustCrypto's crypto_secretsteram](https://github.com/RustCrypto/nacl-compat/tree/master/crypto_secretstream) +for encryption. + +## Features + +- [x] Complete the Noise handshake +- [x] Establish libsodium's `crypto_secretstream`. +- [x] Open channels with a key +- [x] Accept channels opened by the remote end if your end knows the key +- [x] Create and verify capability hashes +- [x] Send and receive all protocol messages +- [x] Support `async-std` or `tokio` runtimes +- [x] Support WASM +- [x] Test Javascript interoperability +- [ ] Support the new [manifest](https://github.com/holepunchto/hypercore/blob/main/lib/manifest.js) in the wire protocol to remain compatible with upcoming v11 +- [ ] Finalize documentation and release v1.0.0 + +## Installation + +```bash +cargo add hypercore-protocol +``` ## Examples These examples sync data between Rust and NodeJS hypercore-protocol implementations. To prepare, run +```bash +cd examples-nodejs && npm install && cd .. ``` -cd examples-nodejs -npm install + +### [replication.rs](examples/replication.rs) + +Runs the `replication.rs` example by replicating a hypercore between Rust and Node hypercores and printing the result. + +* Node Server / Rust Client + +```bash +node examples-nodejs/run.js nodeServer ``` -### [hypercore.rs](examples/hypercore.rs) +* Rust Server / Node Client -`node examples-nodejs/run.js hypercore` +```bash +node examples-nodejs/run.js rustServer +``` -Runs the `hypercore.rs` example with a replication stream from NodeJS hypercore. The `hypercore.rs` example fetches all blocks of a Node.js hypercore and inserts them into a Rust in-memory hypercore. +* Rust Server / Rust Client -### [basic.rs](examples/basic.rs) +```bash +node examples-nodejs/run.js rust +``` -Accepts a hypercore-protocol stream and fetches all blocks of the first hypercore. +* Node Server / Node Client -`node examples-nodejs/run.js basic` +```bash +node examples-nodejs/run.js node +``` -Runs the `basic.rs` example with a replication stream from NodeJS hypercore. The `basic.rs` example fetches all blocks of a hypercore and prints them to STDOUT. +## Development -* Share a file over a hypercore on a local TCP server. Prints a hypercore key. - `node examples-nodejs/replicate.js server 8000 ./README.md` +To test interoperability with Javascript, enable the `js_interop_tests` feature: -* Use this key to connect from Rust and pipe the file content to stdout: - `cargo run --example basic -- server 8000 KEY` +```bash +cargo test --features js_interop_tests +``` +Run benches with: + +```bash +cargo bench +``` ## Contributing -We're actively looking for contributors to the datrust development! +We're actively looking for contributors to the datrust development! If you're interested, the +easiest is to say hi in the `#rust` channel on the +[Hypercore Protocol Discord](https://chat.hypercore-protocol.org/). -If you're interested, the easiest is to say hi in the `#rust` channel on the [Hypercore Protocol Discord](https://chat.hypercore-protocol.org/). +Want to help with Hypercore Protocol? Check out our +["Contributing" guide](https://github.com/datrs/hypercore-protocol-rs/blob/master/.github/CONTRIBUTING.md) +and take a look at the open [issues](https://github.com/datrs/hypercore-protocol-rs/issues). -Contributions include pull requests, issue reports, documentation, design -and other work that benefits this project. +## License -This project is welcoming contributions from anyone who acts in good faith! -We do not tolerate toxic behavior or discriminations against other contributors. -People who engage with this project in bad faith or fail to reflect and change -harmful behavior may be excluded from contributing. Should you feel that someone -acted in such a way, please reach out to the authors of this project. +[MIT](./LICENSE-MIT) OR [Apache-2.0](./LICENSE-APACHE) -Open, diverse, and inclusive communities live and die on the basis of trust. -Contributors can disagree with one another so long as they trust that those -disagreements are in good faith and everyone is working towards a common -goal. +[1]: https://img.shields.io/crates/v/hypercore-protocol.svg?style=flat-square +[2]: https://crates.io/crates/hypercore-protocol +[3]: https://github.com/datrs/hypercore-protocol-rs/actions/workflows/ci.yml/badge.svg +[4]: https://github.com/datrs/hypercore-protocol-rs/actions +[5]: https://img.shields.io/crates/d/hypercore-protocol.svg?style=flat-square +[6]: https://crates.io/crates/hypercore-protocol +[7]: https://img.shields.io/badge/docs-latest-blue.svg?style=flat-square +[8]: https://docs.rs/hypercore-protocol diff --git a/benches/pipe.rs b/benches/pipe.rs index c793313..630146c 100644 --- a/benches/pipe.rs +++ b/benches/pipe.rs @@ -2,7 +2,7 @@ use async_std::task; use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use futures::io::{AsyncRead, AsyncWrite}; use futures::stream::StreamExt; -use hypercore_protocol::schema::*; +use hypercore_protocol::{schema::*, Duplex}; use hypercore_protocol::{Channel, Event, Message, Protocol, ProtocolBuilder}; use log::*; use pretty_bytes::converter::convert as pretty_bytes; @@ -44,10 +44,10 @@ async fn run_echo(i: u64) -> Result<()> { let encrypted = true; let a = ProtocolBuilder::new(true) - .set_encrypted(encrypted) + .encrypted(encrypted) .connect_rw(ar, aw); let b = ProtocolBuilder::new(false) - .set_encrypted(encrypted) + .encrypted(encrypted) .connect_rw(br, bw); let ta = task::spawn(async move { onconnection(i, a).await }); let tb = task::spawn(async move { onconnection(i, b).await }); @@ -58,7 +58,7 @@ async fn run_echo(i: u64) -> Result<()> { // The onconnection handler is called for each incoming connection (if server) // or once when connected (if client). -async fn onconnection(i: u64, mut protocol: Protocol) -> Result +async fn onconnection(i: u64, mut protocol: Protocol>) -> Result where R: AsyncRead + Send + Unpin + 'static, W: AsyncWrite + Send + Unpin + 'static, @@ -104,8 +104,8 @@ async fn on_channel_resp(_i: u64, mut channel: Channel) -> Result { while let Some(message) = channel.next().await { match message { Message::Data(ref data) => { - len += data.value.as_ref().map_or(0, |v| v.len() as u64); - debug!("[b] echo {}", data.index); + len += value_len(data); + debug!("[b] echo {}", index(data)); channel.send(message).await?; } Message::Close(_) => { @@ -129,18 +129,14 @@ async fn on_channel_init(i: u64, mut channel: Channel) -> Result { while let Some(message) = channel.next().await { match message { Message::Data(mut data) => { - len += data.value.as_ref().map_or(0, |v| v.len() as u64); - debug!("[a] recv {}", data.index); - if data.index >= COUNT { - debug!("close at {}", data.index); - channel - .send(Message::Close(Close { - discovery_key: None, - })) - .await?; + len += value_len(&data); + debug!("[a] recv {}", index(&data)); + if index(&data) >= COUNT { + debug!("close at {}", index(&data)); + channel.close().await?; break; } else { - data.index += 1; + increment_index(&mut data); channel.send(Message::Data(data)).await?; } } @@ -153,14 +149,34 @@ async fn on_channel_init(i: u64, mut channel: Channel) -> Result { } fn msg_data(index: u64, value: Vec) -> Message { + use hypercore::DataBlock; + Message::Data(Data { - index, - value: Some(value), - nodes: vec![], - signature: None, + request: index, + fork: 0, + block: Some(DataBlock { + index, + value, + nodes: vec![], + }), + hash: None, + seek: None, + upgrade: None, }) } +fn index(data: &Data) -> u64 { + data.request +} + +fn increment_index(data: &mut Data) { + data.request += 1; +} + +fn value_len(data: &Data) -> u64 { + data.block.as_ref().map_or(0, |b| b.value.len() as u64) +} + fn print_stats(msg: impl ToString, instant: Instant, bytes: f64) { let msg = msg.to_string(); let time = instant.elapsed(); diff --git a/benches/throughput.rs b/benches/throughput.rs index ad5d07e..76d6874 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -4,7 +4,7 @@ use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use futures::future::Either; use futures::io::{AsyncRead, AsyncWrite}; use futures::stream::{FuturesUnordered, StreamExt}; -use hypercore_protocol::schema::*; +use hypercore_protocol::{schema::*, Duplex}; use hypercore_protocol::{Channel, Event, Message, ProtocolBuilder}; use log::*; use std::time::Instant; @@ -88,14 +88,14 @@ async fn start_server(address: &str) -> futures::channel::oneshot::Sender<()> { kill_tx } -async fn onconnection(reader: R, writer: W, is_initiator: bool) -> (R, W) +async fn onconnection(reader: R, writer: W, is_initiator: bool) -> Duplex where R: AsyncRead + Send + Unpin + 'static, W: AsyncWrite + Send + Unpin + 'static, { let key = [0u8; 32]; let mut protocol = ProtocolBuilder::new(is_initiator) - .set_encrypted(false) + .encrypted(false) .connect_rw(reader, writer); while let Some(Ok(event)) = protocol.next().await { // eprintln!("RECV EVENT [{}] {:?}", protocol.is_initiator(), event); @@ -122,11 +122,7 @@ async fn onchannel(mut channel: Channel, is_initiator: bool) { } else { channel_server(&mut channel).await } - let _res = channel - .send(Message::Close(Close { - discovery_key: None, - })) - .await; + let _res = channel.close().await; } async fn channel_server(channel: &mut Channel) { @@ -141,30 +137,20 @@ async fn channel_server(channel: &mut Channel) { async fn channel_client(channel: &mut Channel) { let data = vec![0u8; SIZE as usize]; let start = Instant::now(); - let message = Message::Data(Data { - index: 0, - value: Some(data.clone()), - nodes: vec![], - signature: None, - }); + let message = msg_data(0, data.clone()); channel.send(message).await.unwrap(); while let Some(message) = channel.next().await { match message { - Message::Data(msg) => { - if msg.index < COUNT { - let message = Message::Data(Data { - index: msg.index + 1, - value: Some(data.clone()), - nodes: vec![], - signature: None, - }); + Message::Data(ref msg) => { + if index(msg) < COUNT { + let message = msg_data(index(msg) + 1, data.clone()); channel.send(message).await.unwrap(); } else { let time = start.elapsed(); let bytes = COUNT * SIZE; trace!( "client completed. {} blocks, {} bytes, {:?}", - msg.index, + index(msg), bytes, time ); @@ -175,3 +161,24 @@ async fn channel_client(channel: &mut Channel) { } } } + +fn msg_data(index: u64, value: Vec) -> Message { + use hypercore::DataBlock; + + Message::Data(Data { + request: index, + fork: 0, + block: Some(DataBlock { + index, + value, + nodes: vec![], + }), + hash: None, + seek: None, + upgrade: None, + }) +} + +fn index(msg: &Data) -> u64 { + msg.request +} diff --git a/build.rs b/build.rs deleted file mode 100644 index 8b00481..0000000 --- a/build.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - prost_build::compile_protos(&["src/schema.proto"], &["src/"]).unwrap(); -} diff --git a/examples-nodejs/.gitignore b/examples-nodejs/.gitignore deleted file mode 100644 index f846c68..0000000 --- a/examples-nodejs/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -node_modules -package-lock.json -yarn.lock diff --git a/examples-nodejs/bench-echo-client.js b/examples-nodejs/bench-echo-client.js deleted file mode 100644 index d6deb40..0000000 --- a/examples-nodejs/bench-echo-client.js +++ /dev/null @@ -1,76 +0,0 @@ -const PORT = 11011 -const KEY = Buffer.alloc(32, 0) - -const net = require('net') -const Protocol = require('hypercore-protocol') -const pretty = require('pretty-bytes') - -const COUNT = 100000 -const SIZE = 1000 -const data = Buffer.alloc(SIZE, 0) -const conns = 1 - -for (let i = 0; i < conns; i++) { - const socket = net.connect(PORT) - onconnection(socket, i) -} - -function onconnection (socket, i) { - const proto = new Protocol(true) - socket.pipe(proto).pipe(socket) - const timer = clock() - const channel = proto.open(KEY, { - onopen () { - channel.data({ - value: data, - index: 0 - }) - }, - ondata (msg) { - if (msg.index < COUNT) { - channel.data({ - value: data, - index: msg.index + 1 - }) - // if (msg.index % 10000 === 0) { - // console.log('done', msg.index) - // } - } else { - const bytes = COUNT * SIZE - const time = timer() - console.log('bytes:', pretty(bytes)) - console.log('time:', formatTime(time)) - console.log('throughput', throughput(time, bytes)) - process.exit(0) - } - } - }) -} - -function clock () { - const [ss, sn] = process.hrtime() - return () => { - const [ds, dn] = process.hrtime([ss, sn]) - const ns = (ds * 1e9) + dn - return ns - } -} - -function formatTime (ns) { - const ms = round(ns / 1e6) - const s = round(ms / 1e3) - let time - if (s >= 1) time = s + 's' - else if (ms >= 0.01) time = ms + 'ms' - else if (ns) time = ns + 'ns' - return time -} - -function throughput (ns, bytes) { - const bytespers = pretty(bytes / (ns / 1e9)) - return `${bytespers}/s` -} - -function round (num, decimals = 2) { - return Math.round(num * Math.pow(10, decimals)) / Math.pow(10, decimals) -} diff --git a/examples-nodejs/bench-echo-server.js b/examples-nodejs/bench-echo-server.js deleted file mode 100644 index b81570e..0000000 --- a/examples-nodejs/bench-echo-server.js +++ /dev/null @@ -1,22 +0,0 @@ -const PORT = 11011 -const KEY = Buffer.alloc(32, 0) - -const net = require('net') -const Protocol = require('hypercore-protocol') - -net.createServer(onconnection).listen(PORT, () => { - console.log('listening on localhost:%s', PORT) -}) -function onconnection (socket) { - console.log('new connection from %s:%s', socket.remoteAddress, socket.remotePort) - socket.on('end', () => { - console.log('connection closed from %s:%s', socket.remoteAddress, socket.remotePort) - }) - const proto = new Protocol(false) - socket.pipe(proto).pipe(socket) - const channel = proto.open(KEY, { - ondata (msg) { - channel.data(msg) - } - }) -} diff --git a/examples-nodejs/bench-tcp.js b/examples-nodejs/bench-tcp.js deleted file mode 100644 index b460aae..0000000 --- a/examples-nodejs/bench-tcp.js +++ /dev/null @@ -1,77 +0,0 @@ -const net = require('net') -const pretty = require('pretty-bytes') - -const SIZE = 1000 -const COUNT = 1000 -const PORT = 12345 -const ITERS = 100 - -net.createServer(socket => { - socket.pipe(socket) -}).listen(PORT, () => { - const timer = clock() - let total = 0 - let i = 0 - next() - function next (time) { - if (++i <= ITERS) process.nextTick(echobench, i, next) - else done() - } - function done () { - console.log(`finish ${ITERS} iterations, each ${COUNT} * ${pretty(SIZE)}`) - console.log(formatTime(timer(), SIZE * COUNT * ITERS)) - process.exit(0) - } -}) - -function echobench (j, cb) { - const timer = clock() - const socket = net.connect(PORT) - const data = Buffer.alloc(SIZE, 1) - // let result = Buffer.alloc(COUNT * SIZE, 0) - let offset = 0 - let i = 0 - socket.on('data', ondata) - write() - function ondata (buf) { - // result.copy(buf, offset) - // console.log(j, offset, buf.length, buf.slice(buf.length - 2)) - offset += buf.length - // console.log(COUNT * SIZE - offset) - if (offset >= COUNT * SIZE) { - // console.log('done') - // socket.removeListener('data', ondata) - cb(timer()) - } - } - function write () { - socket.write(data) - // console.log(j, 'written', i * data.length) - if (++i < COUNT) process.nextTick(write) - // else console.log(j, 'written', data.length * i) - } -} - -function clock () { - const [ss, sn] = process.hrtime() - return () => { - const [ds, dn] = process.hrtime([ss, sn]) - const ns = (ds * 1e9) + dn - return ns - } -} - -function formatTime (ns, bytes) { - const ms = round(ns / 1e6) - const s = round(ms / 1e3) - const bytespers = pretty(bytes / (ns / 1e9)) - let time - if (s >= 1) time = s + 's' - else if (ms >= 0.01) time = ms + 'ms' - else if (ns) time = ns + 'ns' - return `${time} ${bytespers}/s` -} - -function round (num, decimals = 2) { - return Math.round(num * Math.pow(10, decimals)) / Math.pow(10, decimals) -} diff --git a/examples-nodejs/debug-message.js b/examples-nodejs/debug-message.js deleted file mode 100644 index dffee4d..0000000 --- a/examples-nodejs/debug-message.js +++ /dev/null @@ -1,22 +0,0 @@ -const fs = require('fs') -const { Request } = require('simple-hypercore-protocol/messages.js') - -encode('n') -decode('r') - -function encode (n) { - const msg1 = Request.encode({ - index: 127 - }) - const msg2 = Request.encode({ - index: 128 - }) - fs.writeFileSync('msg1' + n, msg1) - fs.writeFileSync('msg2' + n, msg2) -} -function decode (n) { - const buf1 = fs.readFileSync('msg1' + n) - const buf2 = fs.readFileSync('msg2' + n) - console.log(Request.decode(buf1)) - console.log(Request.decode(buf2)) -} diff --git a/examples-nodejs/extension.js b/examples-nodejs/extension.js deleted file mode 100644 index a7cb5a7..0000000 --- a/examples-nodejs/extension.js +++ /dev/null @@ -1,131 +0,0 @@ -const RAM = require('random-access-memory') -const net = require('net') -const pretty = require('pretty-bytes') -const hypercore = require('hypercore') -const { Duplex } = require('streamx') - -let n = 5 -for (let i = 0; i < n; i++) { - main(9000 + i).catch(console.error) -} - -async function main (port) { - const feedA = hypercore(RAM) - await new Promise(resolve => feedA.ready(resolve)) - const feedB = hypercore(RAM, feedA.key) - await new Promise(resolve => feedB.ready(resolve)) - - const server = net.createServer(socket => { - const proto = feedA.replicate(false, { live: true }) - socket.pipe(proto).pipe(socket) - }) - - await new Promise((resolve, reject) => { - server.listen(port, err => { - if (err) return reject(err) - const socket = net.connect(port, resolve) - socket.once('error', reject) - const proto = feedB.replicate(true, { live: true }) - socket.pipe(proto).pipe(socket) - }) - }) - console.log('connected') - - - // const protoA = feedA.replicate(true, { live: true }) - // const protoB = feedB.replicate(false, { live: true }) - // protoA.pipe(protoB).pipe(protoA) - - - - const extA = streamExtension(feedA, 'ext') - const extB = streamExtension(feedB, 'ext') - - let limit = 1024 * 1024 * 64 - // let limit = 1024 * 64 * 9 - const timer = clock() - - process.nextTick(() => { - let len = 0 - extB.on('data', buf => { - len += buf.length - // console.log('B recv', buf.length, len) - extB.write(buf) - }) - }) - - process.nextTick(() => { - let buf = Buffer.alloc(1024 * 64, 0) - let len = 0 - - next() - function next () { - extA.write(buf) - len += buf.length - if (len < limit + 1) { - setImmediate(next) - // setTimeout(next, 0) - } - } - }) - - let printed = false - await new Promise(resolve => { - let len = 0 - extA.on('data', buf => { - len += buf.length - // console.log('A recv', buf.length, len, limit) - if (len >= limit) { - if (!printed) done() - resolve() - } - }) - }) - - function done () { - printed = true - // console.log('written: ' + pretty(limit)) - const time = timer() - const formatted = formatTime(time, limit) - console.log(pretty(limit), formatted) - } -} - -function streamExtension (feed, name) { - const ext = feed.registerExtension(name, { - onmessage (message, peer) { - stream.push(message) - } - }) - const stream = new Duplex({ - write (data, cb) { - ext.broadcast(data) - cb() - } - }) - return stream -} - -function clock () { - const [ss, sn] = process.hrtime() - return () => { - const [ds, dn] = process.hrtime([ss, sn]) - const ns = (ds * 1e9) + dn - return ns - } -} - -function formatTime (ns, bytes) { - const ms = round(ns / 1e6) - const s = round(ms / 1e3) - const bytespers = pretty(bytes / (ns / 1e9)) - let time - if (s >= 1) time = s + 's' - else if (ms >= 0.01) time = ms + 'ms' - else if (ns) time = ns + 'ns' - return `${time} ${bytespers}/s` -} - -function round (num, decimals = 2) { - return Math.round(num * Math.pow(10, decimals)) / Math.pow(10, decimals) -} diff --git a/examples-nodejs/handshake.js b/examples-nodejs/handshake.js deleted file mode 100644 index bbc1c9f..0000000 --- a/examples-nodejs/handshake.js +++ /dev/null @@ -1,75 +0,0 @@ -const net = require('net') -const Protocol = require('hypercore-protocol') - -const KEY = Buffer.from('01234567890123456789012345678901') -const hostname = 'localhost' -let [mode, port] = process.argv.slice(2) -if (['client', 'server'].indexOf(mode) === -1 || !port) { - exit('usage: node index.js [client|server] PORT') -} - -start({ port, hostname, mode }) - -function start ({ port, hostname, mode }) { - const isInitiator = mode === 'client' - if (mode === 'client') { - const socket = net.connect(port, hostname) - onconnection({ socket, isInitiator }) - } else { - const server = net.createServer(socket => onconnection({ socket, isInitiator })) - server.listen(port, hostname, () => { - const { address, port } = server.address() - console.error(`server listening on ${address}:${port}`) - }) - } -} - -function onconnection (opts) { - const { socket, isInitiator } = opts - const { remoteAddress, remotePort } = socket - if (!isInitiator) { - console.error(`new connection from ${remoteAddress}:${remotePort}`) - } - socket.on('close', () => { - if (!isInitiator) { - console.error(`connection closed from ${remoteAddress}:${remotePort}`) - } else { - console.error('connection closed from server') - } - }) - - const proto = new Protocol(isInitiator, { noise: true, encrypted: false }) - - console.error('init protocol') - console.error('local public key: ', proto.publicKey) - - proto.pipe(socket).pipe(proto) - - proto.on('error', err => { - console.error('protocol error', err) - socket.destroy() - }) - proto.on('handshake', () => { - console.error('handshake finished') - console.error('remote public key:', proto.remotePublicKey) - console.error('noise handshake nonces', { - local: proto.state._payload.nonce, - remote: proto.state.remotePayload.nonce - }) - console.error('noise handshake split lengths:', { rx: proto.state._split.rx.length, tx: proto.state._split.tx.length }) - console.error('noise handshake split:', proto.state._split) - setTimeout(() => { - console.error('now open channel') - proto.open(KEY, { - onopen () { - console.error('channel opened!') - } - }) - }, 0) - }) -} - -function exit (msg) { - console.error(msg) - process.exit(1) -} diff --git a/examples-nodejs/package.json b/examples-nodejs/package.json index d29d43a..721786a 100644 --- a/examples-nodejs/package.json +++ b/examples-nodejs/package.json @@ -1,20 +1,14 @@ { - "name": "hypercore-protocol-nodejs", + "name": "hypercore-protocol-rs-nodejs", "version": "1.0.0", - "description": "", + "description": "hypercore-protocol-rs NodeJS example", "main": "index.js", - "scripts": { - "test": "echo \"Error: no test specified\" && exit 1" - }, - "keywords": [], - "author": "", - "license": "GPL-3.0", + "license": "MIT", "dependencies": { "chalk": "^4.0.0", - "hypercore": "^9.0", - "hypercore-protocol": "^8.0", + "hypercore": "^10", "pretty-bytes": "^5.3.0", - "random-access-memory": "^3.1.1", + "random-access-memory": "^6.0.0", "split2": "^3.1.1" } } diff --git a/examples-nodejs/replicate.js b/examples-nodejs/replicate.js index 4317d4a..6425b19 100644 --- a/examples-nodejs/replicate.js +++ b/examples-nodejs/replicate.js @@ -1,53 +1,29 @@ const net = require('net') -const Protocol = require('hypercore-protocol') -const hypercore = require('hypercore') -const ram = require('random-access-memory') -const { pipeline } = require('stream') -const fs = require('fs') -const p = require('path') -const os = require('os') -const split = require('split2') +const Hypercore = require('hypercore'); +const RAM = require('random-access-memory') -const hostname = 'localhost' -let [mode, port, keyOrFilename] = process.argv.slice(2) -if (['client', 'server'].indexOf(mode) === -1 || !port || !keyOrFilename) { - exit('usage: node index.js [client|server] PORT [KEY|FILENAME]') +const hostname = '127.0.0.1' +let [mode, port, key] = process.argv.slice(2) +if (['client', 'server'].indexOf(mode) === -1 || !port) { + exit('usage: node replicate.js [client|server] PORT (KEY)') } - -const KEY_REGEX = /^[d0-9a-f]{64}$/i -let key, filename -if (keyOrFilename.match(KEY_REGEX)) { - key = keyOrFilename -} else { - filename = keyOrFilename -} - -const feed = hypercore(ram, key) -feed.ready(() => { - console.log('KEY=' + feed.key.toString('hex')) +const hypercore = new Hypercore((_) => new RAM(), key) +hypercore.info().then((_info) => { + console.log('KEY=' + hypercore.key.toString('hex')) console.log() - if (feed.writable && filename) { - feed.append(['hi\n', 'ola\n', 'hello\n', 'mundo\n']) - // pipeline( - // fs.createReadStream(filename), - // split(), - // feed.createWriteStream(), - // err => { - // if (err) console.error('error importing file', err) - // else console.error('import done, new len %o, bytes %o', feed.length, feed.byteLength) - // } - // ) + if (hypercore.writable && !key) { + hypercore.append(['hi\n', 'ola\n', 'hello\n', 'mundo\n']) } }) const opts = { - feed, filename, mode, port, hostname + hypercore, mode, port, hostname } start(opts) function start (opts) { - const { port, hostname, mode, feed, filename } = opts + const { port, hostname, mode } = opts const isInitiator = mode === 'client' opts.isInitiator = isInitiator @@ -64,7 +40,7 @@ function start (opts) { } function onconnection (opts) { - const { socket, isInitiator, feed } = opts + const { socket, isInitiator, mode, hypercore } = opts const { remoteAddress, remotePort } = socket if (!isInitiator) { console.error(`new connection from ${remoteAddress}:${remotePort}`) @@ -77,36 +53,20 @@ function onconnection (opts) { } }) - // const proto = new Protocol(isInitiator, { noise: true, encrypted: false }) - feed.ready(() => { - let mode = feed.writable ? 'write' : 'read' - const proto = feed.replicate(isInitiator, { encrypted: true, live: true }) - - console.error('init protocol') - console.error('key', feed.key.toString('hex')) - - proto.pipe(socket).pipe(proto) - - proto.on('error', err => { - console.error('protocol error', err) - socket.destroy() - }) - - if (mode === 'write') { - // feed.append(feed.length) - // feed.append('hello') - // setTimeout(() => feed.append('world'), 500) - - // const filepath = p.join(os.homedir(), 'Musik', 'foo.mp3') - // const rs = fs.createReadStream(filepath) - // rs.pipe(feed.createWriteStream()) - } - if (mode === 'read') { - feed.createReadStream({ live: true }).pipe(process.stdout) - } - - // setTimeout(() => proto.destroy(), 1000) + hypercore.on('append', _ => { + console.log(`${mode} got append, new length ${hypercore.length} and byte length ${hypercore.byteLength}, replaying:`) + console.log(""); + console.log("### Results (Press Ctrl-C to exit)"); + console.log(""); + console.log("Replication succeeded if you see '0: hi', '1: ola', '2: hello' and '3: mundo' (not necessarily in that order)") + console.log(""); + for (let i = 0; i < hypercore.length; i++) { + hypercore.get(i).then(value => { + console.log(`${i}: ${value}`); + }); + } }) + socket.pipe(hypercore.replicate(isInitiator)).pipe(socket) } function exit (msg) { diff --git a/examples-nodejs/run.js b/examples-nodejs/run.js index 3bb4375..c96541f 100644 --- a/examples-nodejs/run.js +++ b/examples-nodejs/run.js @@ -4,19 +4,17 @@ const chalk = require('chalk') const split = require('split2') const PORT = 8000 -const FILE = p.join(__dirname, '..', 'README.md') const EXAMPLE_NODE = p.join(__dirname, 'replicate.js') -const EXAMPLE_RUST = process.argv[2] -if (!EXAMPLE_RUST) { +const EXAMPLE_RUST = 'replication' +const MODE = process.argv[2] +if (!MODE) { usage() } -const SERVER = process.argv[3] || 'node' function startNode (mode, key, color, name) { const args = [EXAMPLE_NODE, mode, PORT] if (key) args.push(key) - else args.push(FILE) const node = start({ bin: 'node', args, @@ -46,12 +44,18 @@ function startRust (mode, key, color, name) { } let client, server -if (SERVER === 'node') { +if (MODE === 'nodeServer') { server = startNode client = startRust -} else { +} else if (MODE === 'rustServer') { server = startRust client = startNode +} else if (MODE === 'node') { + server = startNode + client = startNode +} else if (MODE === 'rust') { + server = startRust + client = startRust } const procs = [] @@ -89,6 +93,6 @@ function start ({ bin, args, name, color, env = {} }) { } function usage () { - console.error('USAGE: node run.js [basic|hypercore]') + console.error('USAGE: node run.js [node|rust|nodeServer|rustServer]') process.exit(1) } diff --git a/examples/basic.rs b/examples/basic.rs deleted file mode 100644 index 372a488..0000000 --- a/examples/basic.rs +++ /dev/null @@ -1,230 +0,0 @@ -use anyhow::Result; -use async_std::net::TcpStream; -use async_std::sync::Arc; -use async_std::task; -use futures_lite::stream::StreamExt; -use log::*; -use std::collections::HashMap; -use std::convert::TryInto; -use std::env; - -use hypercore_protocol::schema::*; -use hypercore_protocol::{ - discovery_key, Channel, DiscoveryKey, Event, Key, Message, ProtocolBuilder, -}; - -mod util; -use util::{tcp_client, tcp_server}; - -/// Print usage and exit. -fn usage() { - println!("usage: cargo run --example basic -- [client|server] [port] [key]"); - std::process::exit(1); -} - -fn main() { - util::init_logger(); - if env::args().count() < 3 { - usage(); - } - let mode = env::args().nth(1).unwrap(); - let port = env::args().nth(2).unwrap(); - let address = format!("127.0.0.1:{}", port); - - let key = env::args().nth(3); - let key = key.map_or(None, |key| hex::decode(key).ok()); - let key = key.map(|key| key.try_into().expect("Key must be a 32 byte hex string")); - - let mut feedstore = FeedStore::new(); - if let Some(key) = key { - feedstore.add(Feed::new(key)); - } else { - let key = [9u8; 32]; - feedstore.add(Feed::new(key.clone())); - println!("KEY={}", hex::encode(&key)); - } - let feedstore = Arc::new(feedstore); - - task::block_on(async move { - let result = match mode.as_ref() { - "server" => tcp_server(address, onconnection, feedstore).await, - "client" => tcp_client(address, onconnection, feedstore).await, - _ => panic!(usage()), - }; - util::log_if_error(&result); - }); -} - -// The onconnection handler is called for each incoming connection (if server) -// or once when connected (if client). -async fn onconnection( - stream: TcpStream, - is_initiator: bool, - feedstore: Arc, -) -> Result<()> { - let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream); - while let Some(event) = protocol.next().await { - let event = event?; - debug!("EVENT {:?}", event); - match event { - Event::Handshake(_) => { - if is_initiator { - for feed in feedstore.feeds.values() { - protocol.open(feed.key.clone()).await?; - } - } - } - Event::DiscoveryKey(dkey) => { - if let Some(feed) = feedstore.get(&dkey) { - protocol.open(feed.key.clone()).await?; - } - } - Event::Channel(mut channel) => { - if let Some(feed) = feedstore.get(channel.discovery_key()) { - let feed = feed.clone(); - let mut state = FeedState::default(); - task::spawn(async move { - while let Some(message) = channel.next().await { - onmessage(&*feed, &mut state, &mut channel, message).await; - } - }); - } - } - _ => {} - } - } - Ok(()) -} - -struct FeedStore { - pub feeds: HashMap>, -} -impl FeedStore { - pub fn new() -> Self { - let feeds = HashMap::new(); - Self { feeds } - } - - pub fn add(&mut self, feed: Feed) { - let hdkey = hex::encode(&feed.discovery_key); - self.feeds.insert(hdkey, Arc::new(feed)); - } - - pub fn get(&self, discovery_key: &[u8; 32]) -> Option<&Arc> { - let hdkey = hex::encode(discovery_key); - self.feeds.get(&hdkey) - } -} - -/// A Feed is a single unit of replication, an append-only log. -/// This toy feed can only read sequentially and does not save or buffer anything. -#[derive(Debug)] -struct Feed { - key: Key, - discovery_key: DiscoveryKey, -} -impl Feed { - pub fn new(key: Key) -> Self { - Feed { - discovery_key: discovery_key(&key), - key, - } - } -} - -/// A FeedState stores the head seq of the remote. -/// This would have a bitfield to support sparse sync in the actual impl. -#[derive(Debug)] -struct FeedState { - remote_head: Option, -} -impl Default for FeedState { - fn default() -> Self { - FeedState { remote_head: None } - } -} - -async fn onmessage(_feed: &Feed, state: &mut FeedState, channel: &mut Channel, message: Message) { - match message { - Message::Open(_) => { - let msg = Want { - start: 0, - length: None, - }; - channel - .send(Message::Want(msg)) - .await - .expect("failed to send"); - } - Message::Want(_) => { - let msg = Have { - start: 0, - length: Some(3), - bitfield: None, - ack: None, - }; - channel - .send(Message::Have(msg)) - .await - .expect("failed to send"); - } - Message::Have(msg) => { - let new_remote_head = msg.start + msg.length.unwrap_or(0); - if state.remote_head == None { - state.remote_head = Some(new_remote_head); - let msg = Request { - index: 0, - bytes: None, - hash: None, - nodes: None, - }; - channel.send(Message::Request(msg)).await.unwrap(); - } else if let Some(remote_head) = state.remote_head { - if remote_head < new_remote_head { - state.remote_head = Some(new_remote_head); - } - } - } - Message::Request(msg) => { - channel - .send(Message::Data(Data { - index: msg.index, - value: Some("Hello world".as_bytes().to_vec()), - nodes: vec![], - signature: None, - })) - .await - .unwrap(); - } - Message::Data(msg) => { - debug!( - "receive data: idx {}, {} bytes (remote_head {:?})", - msg.index, - msg.value.as_ref().map_or(0, |v| v.len()), - state.remote_head - ); - - if let Some(value) = msg.value { - eprintln!("{} {}", msg.index, String::from_utf8(value).unwrap()); - // let mut stdout = io::stdout(); - // stdout.write_all(&value).await.unwrap(); - // stdout.flush().await.unwrap(); - } - - let next = msg.index + 1; - if let Some(remote_head) = state.remote_head { - if remote_head >= next { - // Request next data block. - let msg = Request { - index: next, - bytes: None, - hash: None, - nodes: None, - }; - channel.send(Message::Request(msg)).await.unwrap(); - } - } - } - _ => {} - } -} diff --git a/examples/extension.rs b/examples/extension.rs deleted file mode 100644 index eac8d6c..0000000 --- a/examples/extension.rs +++ /dev/null @@ -1,191 +0,0 @@ -use async_std::net::TcpStream; -use async_std::prelude::*; -use async_std::task::{self, JoinHandle}; -use futures_lite::io::{AsyncRead, AsyncWrite}; -use hypercore_protocol::{Channel, Event, Protocol, ProtocolBuilder}; -use log::*; -use pretty_bytes::converter::convert as pretty_bytes; -use std::io; -use std::time::Instant; - -// This example sets up a pair of protocols connected over TCP, -// registers the same extension on each, and then uses the AsyncWrite -// and AsyncRead implementations on the extension to pipe binary data -// from A to B, where B will echo it back, and then A will read the -// echoed stream. Then, the throughput is calculated and reported. -// -// Note: Run in release mode, otherwise it will be *very* slow. -// -// Adjust the consts below for different "settings". - -const BYTES: usize = 1024 * 1024 * 64; -const PARALLEL: usize = 5; -const ENCRYPT: bool = false; - -#[async_std::main] -pub async fn main() -> anyhow::Result<()> { - env_logger::init(); - - let port = 9000u16; - let mut tasks = vec![]; - let instant = Instant::now(); - for n in 0..PARALLEL { - let port = port + n as u16; - let task = task::spawn(async move { - let _ = channel_extension_async_read_write(BYTES, port, ENCRYPT).await; - }); - tasks.push(task); - } - for task in tasks.iter_mut() { - task.await; - } - print_stats("total", instant, (BYTES * PARALLEL) as f64); - Ok(()) -} - -async fn channel_extension_async_read_write( - limit: usize, - port: u16, - encrypted: bool, -) -> anyhow::Result<()> { - // let (mut proto_a, mut proto_b) = create_pair_memory().await?; - let (mut proto_a, mut proto_b) = create_pair_tcp(port, encrypted).await?; - let key = [1u8; 32]; - proto_a.open(key).await?; - proto_b.open(key).await?; - - let next_a = drive_until_channel(proto_a); - let next_b = drive_until_channel(proto_b); - let (proto_a, mut channel_a) = next_a.await?; - let (proto_b, mut channel_b) = next_b.await?; - - let mut ext_a = channel_a.register_extension("ext").await; - let mut ext_b = channel_b.register_extension("ext").await; - - // Drive the protocols and channels. - drive(proto_a); - drive(proto_b); - drive(channel_a); - drive(channel_b); - - let instant = Instant::now(); - - // On B, run an echo loop. - task::spawn(async move { - let mut read_buf = vec![0u8; 1024 * 64]; - let mut len = 0; - loop { - let n = ext_b.read(&mut read_buf).await.unwrap(); - len += n; - debug!("B READ: {}", len); - ext_b.write_all(&read_buf[..n]).await.unwrap(); - } - }); - - // On A, write BYTES bytes. - let mut len = 0; - task::spawn({ - let mut ext_a = ext_a.clone(); - let buf = vec![0u8; 1024 * 64]; - async move { - while len < limit + 10 { - ext_a.write_all(&buf).await.unwrap(); - len += buf.len(); - debug!("A WRITE: {}", len); - } - } - }); - - // On A, read BYTES bytes back (from the echo on B). - let mut read_buf = vec![0u8; 1024 * 64]; - let mut len = 0; - while len < limit { - let n = ext_a.read(&mut read_buf).await.unwrap(); - len += n; - debug!("A READ {}", len); - } - - // Now report how long it all took. - print_stats("done", instant, limit as f64); - Ok(()) -} - -fn print_stats(msg: impl ToString, instant: Instant, bytes: f64) { - let msg = msg.to_string(); - let time = instant.elapsed(); - let secs = time.as_secs_f64(); - let bs = bytes / secs; - eprintln!( - "[{}] time {:.3?} bytes {} throughput {}/s", - msg, - time, - pretty_bytes(bytes), - pretty_bytes(bs) - ); -} - -pub type TcpProtocol = Protocol; -pub async fn create_pair_tcp( - port: u16, - encrypted: bool, -) -> std::io::Result<(TcpProtocol, TcpProtocol)> { - let (stream_a, stream_b) = tcp::pair(port).await?; - let a = ProtocolBuilder::new(true) - .set_encrypted(encrypted) - .connect(stream_a); - let b = ProtocolBuilder::new(false) - .set_encrypted(encrypted) - .connect(stream_b); - Ok((a, b)) -} - -/// Drive a stream to completion in a task. -fn drive(mut proto: S) -> JoinHandle<()> -where - S: Stream + Send + Unpin + 'static, -{ - task::spawn(async move { while let Some(_event) = proto.next().await {} }) -} - -// Drive a protocol stream until the first channel arrives. -fn drive_until_channel( - mut proto: Protocol, -) -> JoinHandle, Channel)>> -where - IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, -{ - task::spawn(async move { - while let Some(event) = proto.next().await { - let event = event?; - match event { - Event::Channel(channel) => return Ok((proto, channel)), - _ => {} - } - } - Err(io::Error::new( - io::ErrorKind::Interrupted, - "Protocol closed before a channel was opened", - )) - }) -} - -pub mod tcp { - use async_std::net::{TcpListener, TcpStream}; - use async_std::prelude::*; - use async_std::task; - use std::io::{Error, ErrorKind, Result}; - pub async fn pair(port: u16) -> Result<(TcpStream, TcpStream)> { - let address = format!("localhost:{}", port); - let listener = TcpListener::bind(&address).await?; - let mut incoming = listener.incoming(); - - let connect_task = task::spawn(async move { TcpStream::connect(&address).await }); - - let server_stream = incoming.next().await; - let server_stream = - server_stream.ok_or_else(|| Error::new(ErrorKind::Other, "Stream closed"))?; - let server_stream = server_stream?; - let client_stream = connect_task.await?; - Ok((server_stream, client_stream)) - } -} diff --git a/examples/hypercore.rs b/examples/hypercore.rs deleted file mode 100644 index 18c205a..0000000 --- a/examples/hypercore.rs +++ /dev/null @@ -1,330 +0,0 @@ -use anyhow::Result; -use async_std::net::TcpStream; -use async_std::sync::{Arc, Mutex}; -use async_std::task; -use futures_lite::stream::StreamExt; -use hypercore::{Feed, Node, NodeTrait, Proof, PublicKey, Signature, Storage}; -use log::*; -use random_access_memory::RandomAccessMemory; -use random_access_storage::RandomAccess; -use std::collections::HashMap; -use std::convert::{TryFrom, TryInto}; -use std::env; -use std::fmt::Debug; - -use hypercore_protocol::schema::*; -use hypercore_protocol::{discovery_key, Channel, Event, Message, ProtocolBuilder}; - -mod util; -use util::{tcp_client, tcp_server}; - -fn main() { - util::init_logger(); - if env::args().count() < 3 { - usage(); - } - let mode = env::args().nth(1).unwrap(); - let port = env::args().nth(2).unwrap(); - let address = format!("127.0.0.1:{}", port); - - let key = env::args().nth(3); - let key: Option<[u8; 32]> = key.map_or(None, |key| { - Some( - hex::decode(key) - .expect("Key has to be a hex string") - .try_into() - .expect("Key has to be a 32 byte hex string"), - ) - }); - - task::block_on(async move { - let mut feedstore: FeedStore = FeedStore::new(); - let storage = Storage::new_memory().await.unwrap(); - // Create a hypercore. - let feed = if let Some(key) = key { - let public_key = PublicKey::from_bytes(&key).unwrap(); - Feed::builder(public_key, storage).build().await.unwrap() - } else { - let mut feed = Feed::default(); - feed.append(b"hello").await.unwrap(); - feed.append(b"world").await.unwrap(); - feed - }; - info!("Opened feed: {}", hex::encode(feed.public_key().as_bytes())); - // Wrap it and add to the feed store. - let feed_wrapper = FeedWrapper::from_memory_feed(feed); - feedstore.add(feed_wrapper); - let feedstore = Arc::new(feedstore); - - let result = match mode.as_ref() { - "server" => tcp_server(address, onconnection, feedstore).await, - "client" => tcp_client(address, onconnection, feedstore).await, - _ => panic!(usage()), - }; - util::log_if_error(&result); - }); -} - -/// Print usage and exit. -fn usage() { - println!("usage: cargo run --example hypercore -- [client|server] [port] [key]"); - std::process::exit(1); -} - -// The onconnection handler is called for each incoming connection (if server) -// or once when connected (if client). -// Unfortunately, everything that touches the feedstore or a feed has to be generic -// at the moment. -async fn onconnection( - stream: TcpStream, - is_initiator: bool, - feedstore: Arc>, -) -> Result<()> -where - T: RandomAccess> + Debug + Send, -{ - let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream); - - while let Some(event) = protocol.next().await { - let event = event?; - debug!("protocol event {:?}", event); - match event { - Event::Handshake(_) => { - if is_initiator { - for feed in feedstore.feeds.values() { - protocol.open(feed.key().clone()).await?; - } - } - } - Event::DiscoveryKey(dkey) => { - if let Some(feed) = feedstore.get(&dkey) { - protocol.open(feed.key().clone()).await?; - } - } - Event::Channel(channel) => { - if let Some(feed) = feedstore.get(channel.discovery_key()) { - feed.onpeer(channel); - } - } - Event::Close(_dkey) => {} - _ => {} - } - } - Ok(()) -} - -/// A container for hypercores. -#[derive(Debug)] -struct FeedStore -where - T: RandomAccess> + Debug + Send, -{ - feeds: HashMap>>, -} -impl FeedStore -where - T: RandomAccess> + Debug + Send, -{ - pub fn new() -> Self { - let feeds = HashMap::new(); - Self { feeds } - } - - pub fn add(&mut self, feed: FeedWrapper) { - let hdkey = hex::encode(&feed.discovery_key); - self.feeds.insert(hdkey, Arc::new(feed)); - } - - pub fn get(&self, discovery_key: &[u8; 32]) -> Option<&Arc>> { - let hdkey = hex::encode(discovery_key); - self.feeds.get(&hdkey) - } -} - -/// A Feed is a single unit of replication, an append-only log. -#[derive(Debug, Clone)] -struct FeedWrapper -where - T: RandomAccess> + Debug + Send, -{ - discovery_key: [u8; 32], - key: [u8; 32], - feed: Arc>>, -} - -impl FeedWrapper { - pub fn from_memory_feed(feed: Feed) -> Self { - let key = feed.public_key().to_bytes(); - FeedWrapper { - key, - discovery_key: discovery_key(&key), - feed: Arc::new(Mutex::new(feed)), - } - } -} - -impl FeedWrapper -where - T: RandomAccess> + Debug + Send + 'static, -{ - pub fn key(&self) -> &[u8; 32] { - &self.key - } - - pub fn onpeer(&self, mut channel: Channel) { - let mut state = PeerState::default(); - let mut feed = self.feed.clone(); - task::spawn(async move { - let msg = Want { - start: 0, - length: None, - }; - channel.send(Message::Want(msg)).await.unwrap(); - while let Some(message) = channel.next().await { - let result = onmessage(&mut feed, &mut state, &mut channel, message).await; - if let Err(e) = result { - error!("protocol error: {}", e); - break; - } - } - }); - } -} - -/// A PeerState stores the head seq of the remote. -/// This would have a bitfield to support sparse sync in the actual impl. -#[derive(Debug)] -struct PeerState { - remote_head: Option, -} -impl Default for PeerState { - fn default() -> Self { - PeerState { remote_head: None } - } -} - -async fn onmessage( - feed: &mut Arc>>, - state: &mut PeerState, - channel: &mut Channel, - message: Message, -) -> Result<()> -where - T: RandomAccess> + Debug + Send, -{ - match message { - Message::Open(_) => { - let msg = Want { - start: 0, - length: None, - }; - channel.send(Message::Want(msg)).await?; - } - Message::Want(msg) => { - let mut feed = feed.lock().await; - if feed.has(msg.start) { - channel - .have(Have { - start: msg.start, - ack: None, - bitfield: None, - length: None, - }) - .await?; - } - } - Message::Have(msg) => { - if state.remote_head == None { - state.remote_head = Some(msg.start); - let msg = Request { - index: 0, - bytes: None, - hash: None, - nodes: None, - }; - channel.send(Message::Request(msg)).await?; - } else if let Some(remote_head) = state.remote_head { - if remote_head < msg.start { - state.remote_head = Some(msg.start) - } - } - } - Message::Request(request) => { - let mut feed = feed.lock().await; - let index = request.index; - let value = feed.get(index).await?; - let proof = feed.proof(index, false).await?; - let nodes = proof - .nodes - .iter() - .map(|node| data::Node { - index: NodeTrait::index(node), - hash: NodeTrait::hash(node).to_vec(), - size: NodeTrait::len(node), - }) - .collect(); - let message = Data { - index, - value: value.clone(), - nodes, - signature: proof.signature.map(|s| s.to_bytes().to_vec()), - }; - channel.data(message).await?; - } - Message::Data(msg) => { - let mut feed = feed.lock().await; - let value: Option<&[u8]> = match msg.value.as_ref() { - None => None, - Some(value) => { - // eprintln!( - // "recv idx {}: {:?}", - // msg.index, - // String::from_utf8(value.clone()).unwrap() - // ); - Some(value) - } - }; - - let signature = match msg.signature { - Some(bytes) => Some(Signature::try_from(&bytes[..])?), - None => None, - }; - let nodes = msg - .nodes - .iter() - .map(|n| Node::new(n.index, n.hash.clone(), n.size)) - .collect(); - let proof = Proof { - index: msg.index, - nodes, - signature, - }; - - feed.put(msg.index, value, proof.clone()).await?; - - let i = msg.index; - let node = feed.get(i).await?; - if let Some(value) = node { - println!("feed idx {}: {:?}", i, String::from_utf8(value).unwrap()); - } else { - println!("feed idx {}: {:?}", i, "NONE"); - } - - let next = msg.index + 1; - if let Some(remote_head) = state.remote_head { - if remote_head >= next { - // Request next data block. - let msg = Request { - index: next, - bytes: None, - hash: None, - nodes: None, - }; - channel.send(Message::Request(msg)).await?; - } - }; - } - _ => {} - }; - Ok(()) -} diff --git a/examples/pipe.rs b/examples/pipe.rs deleted file mode 100644 index 4eec4a2..0000000 --- a/examples/pipe.rs +++ /dev/null @@ -1,201 +0,0 @@ -use anyhow::Result; -use async_std::task; -use futures_lite::prelude::*; -use futures_lite::stream::StreamExt; -use log::*; -use pretty_bytes::converter::convert as pretty_bytes; -use sluice::pipe::pipe; -use std::env; -use std::time::Instant; - -use hypercore_protocol::schema::*; -use hypercore_protocol::{Channel, Event, Message, Protocol, ProtocolBuilder}; - -fn main() { - env_logger::from_env(env_logger::Env::default().default_filter_or("info")).init(); - let config = Config::from_env(); - task::block_on(run_echo_pipes(config)).unwrap(); -} - -#[derive(Clone)] -struct Config { - pub connections: u64, - pub blocksize: u64, - pub length: u64, - pub no_encrypt: bool, -} - -impl Config { - pub fn total_bytes(&self) -> u64 { - self.connections * self.blocksize * self.length * 2 - } - - pub fn from_env() -> Self { - Config { - connections: parse_env_u64("CONNECTIONS", 10), - blocksize: parse_env_u64("BLOCKSIZE", 100), - length: parse_env_u64("LENGTH", 1000), - no_encrypt: env::var("NO_ENCRYPT").is_ok(), - } - } -} - -async fn run_echo_pipes(config: Config) -> Result<()> { - let start = std::time::Instant::now(); - let mut tasks = vec![]; - for i in 0..config.connections { - tasks.push(task::spawn(run_echo(config.clone(), i))); - } - for task in tasks { - task.await?; - } - // futures::future::join_all(futs).await; - print_stats("total", start, config.total_bytes() as f64); - Ok(()) -} - -async fn run_echo(config: Config, i: u64) -> Result<()> { - // let cap: usize = config.blocksize as usize * 10; - let (ar, bw) = pipe(); - let (br, aw) = pipe(); - - let mut a = ProtocolBuilder::new(true); - let mut b = ProtocolBuilder::new(false); - if config.no_encrypt { - a = a.set_encrypted(false); - b = b.set_encrypted(false); - } - let a = a.connect_rw(ar, aw); - let b = b.connect_rw(br, bw); - let c = config.clone(); - let ta = task::spawn(async move { onconnection(c, i, a).await }); - let c = config.clone(); - let tb = task::spawn(async move { onconnection(c, i, b).await }); - let _lena = ta.await?; - let _lenb = tb.await?; - Ok(()) -} - -// The onconnection handler is called for each incoming connection (if server) -// or once when connected (if client). -async fn onconnection(config: Config, i: u64, mut protocol: Protocol) -> Result -where - IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, -{ - let key = [0u8; 32]; - let is_initiator = protocol.is_initiator(); - // let mut len: u64 = 0; - while let Some(event) = protocol.next().await { - match event { - Ok(event) => { - debug!("[init {}] EVENT {:?}", is_initiator, event); - match event { - Event::Handshake(_) => { - protocol.open(key.clone()).await?; - } - Event::DiscoveryKey(_dkey) => {} - Event::Channel(channel) => { - let config = config.clone(); - task::spawn(async move { - if is_initiator { - on_channel_init(config, i, channel).await - } else { - on_channel_resp(config, i, channel).await - } - }); - } - Event::Close(_) => { - return Ok(0); - } - _ => {} - } - } - Err(err) => { - error!("ERROR {:?}", err); - return Err(err.into()); - } - } - } - Ok(0) -} - -async fn on_channel_resp(_config: Config, _i: u64, mut channel: Channel) -> Result { - let mut len: u64 = 0; - while let Some(message) = channel.next().await { - match message { - Message::Data(ref data) => { - len += data.value.as_ref().map_or(0, |v| v.len() as u64); - debug!("[b] echo {}", data.index); - channel.send(message).await?; - } - Message::Close(_) => { - break; - } - _ => {} - } - } - debug!("[b] ch close"); - Ok(len) -} - -async fn on_channel_init(config: Config, i: u64, mut channel: Channel) -> Result { - let data = vec![1u8; config.blocksize as usize]; - let mut len: u64 = 0; - let message = msg_data(0, data); - channel.send(message).await?; - - let start = std::time::Instant::now(); - - while let Some(message) = channel.next().await { - match message { - Message::Data(mut data) => { - len += data.value.as_ref().map_or(0, |v| v.len() as u64); - debug!("[a] recv {}", data.index); - if data.index >= config.length { - debug!("close at {}", data.index); - channel - .send(Message::Close(Close { - discovery_key: None, - })) - .await?; - break; - } else { - data.index += 1; - channel.send(Message::Data(data)).await?; - } - } - _ => {} - } - } - print_stats(i, start, len as f64); - Ok(len) -} - -fn msg_data(index: u64, value: Vec) -> Message { - Message::Data(Data { - index, - value: Some(value), - nodes: vec![], - signature: None, - }) -} - -fn print_stats(msg: impl ToString, instant: Instant, bytes: f64) { - let msg = msg.to_string(); - let time = instant.elapsed(); - let secs = time.as_secs_f64(); - let bs = bytes / secs; - eprintln!( - "[{}] time {:?} bytes {} throughput {}/s", - msg, - time, - pretty_bytes(bytes), - pretty_bytes(bs) - ); -} - -fn parse_env_u64(name: &str, default: u64) -> u64 { - env::var(name) - .map(|v| v.parse().unwrap()) - .unwrap_or(default) -} diff --git a/examples/replication.rs b/examples/replication.rs new file mode 100644 index 0000000..fca2bd9 --- /dev/null +++ b/examples/replication.rs @@ -0,0 +1,491 @@ +use anyhow::Result; +use async_std::net::{TcpListener, TcpStream}; +use async_std::prelude::*; +use async_std::sync::{Arc, Mutex}; +use async_std::task; +use env_logger::Env; +use futures_lite::stream::StreamExt; +use hypercore::{ + Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, Storage, + VerifyingKey, +}; +use log::*; +use random_access_memory::RandomAccessMemory; +use random_access_storage::RandomAccess; +use std::collections::HashMap; +use std::convert::TryInto; +use std::env; +use std::fmt::Debug; + +use hypercore_protocol::schema::*; +use hypercore_protocol::{discovery_key, Channel, Event, Message, ProtocolBuilder}; + +fn main() { + init_logger(); + if env::args().count() < 3 { + usage(); + } + let mode = env::args().nth(1).unwrap(); + let port = env::args().nth(2).unwrap(); + let address = format!("127.0.0.1:{port}"); + + let key = env::args().nth(3); + let key: Option<[u8; 32]> = key.map(|key| { + hex::decode(key) + .expect("Key has to be a hex string") + .try_into() + .expect("Key has to be a 32 byte hex string") + }); + + task::block_on(async move { + let mut hypercore_store: HypercoreStore = HypercoreStore::new(); + let storage = Storage::new_memory().await.unwrap(); + // Create a hypercore. + let hypercore = if let Some(key) = key { + let public_key = VerifyingKey::from_bytes(&key).unwrap(); + HypercoreBuilder::new(storage) + .key_pair(PartialKeypair { + public: public_key, + secret: None, + }) + .build() + .await + .unwrap() + } else { + let mut hypercore = HypercoreBuilder::new(storage).build().await.unwrap(); + let batch: &[&[u8]] = &[b"hi\n", b"ola\n", b"hello\n", b"mundo\n"]; + hypercore.append_batch(batch).await.unwrap(); + hypercore + }; + println!( + "KEY={}", + hex::encode(hypercore.key_pair().public.as_bytes()) + ); + info!("{} opened hypercore", mode); + // Wrap it and add to the hypercore store. + let hypercore_wrapper = HypercoreWrapper::from_memory_hypercore(hypercore); + hypercore_store.add(hypercore_wrapper); + let hypercore_store = Arc::new(hypercore_store); + + let result = match mode.as_ref() { + "server" => tcp_server(address, onconnection, hypercore_store).await, + "client" => tcp_client(address, onconnection, hypercore_store).await, + _ => panic!("{:?}", usage()), + }; + log_if_error(&result); + }); +} + +/// Print usage and exit. +fn usage() { + println!("usage: cargo run --example hypercore -- [client|server] [port] [key]"); + std::process::exit(1); +} + +// The onconnection handler is called for each incoming connection (if server) +// or once when connected (if client). +// Unfortunately, everything that touches the hypercore_store or a hypercore has to be generic +// at the moment. +async fn onconnection( + stream: TcpStream, + is_initiator: bool, + hypercore_store: Arc>, +) -> Result<()> +where + T: RandomAccess + Debug + Send, +{ + info!("onconnection, initiator: {}", is_initiator); + let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream); + info!("protocol created, polling for next()"); + while let Some(event) = protocol.next().await { + let event = event?; + info!("protocol event {:?}", event); + match event { + Event::Handshake(_) => { + if is_initiator { + for hypercore in hypercore_store.hypercores.values() { + protocol.open(*hypercore.key()).await?; + } + } + } + Event::DiscoveryKey(dkey) => { + if let Some(hypercore) = hypercore_store.get(&dkey) { + protocol.open(*hypercore.key()).await?; + } + } + Event::Channel(channel) => { + if let Some(hypercore) = hypercore_store.get(channel.discovery_key()) { + hypercore.onpeer(channel); + } + } + Event::Close(_dkey) => {} + _ => {} + } + } + Ok(()) +} + +/// A container for hypercores. +#[derive(Debug)] +struct HypercoreStore +where + T: RandomAccess + Debug + Send, +{ + hypercores: HashMap>>, +} +impl HypercoreStore +where + T: RandomAccess + Debug + Send, +{ + pub fn new() -> Self { + let hypercores = HashMap::new(); + Self { hypercores } + } + + pub fn add(&mut self, hypercore: HypercoreWrapper) { + let hdkey = hex::encode(hypercore.discovery_key); + self.hypercores.insert(hdkey, Arc::new(hypercore)); + } + + pub fn get(&self, discovery_key: &[u8; 32]) -> Option<&Arc>> { + let hdkey = hex::encode(discovery_key); + self.hypercores.get(&hdkey) + } +} + +/// A Hypercore is a single unit of replication, an append-only log. +#[derive(Debug, Clone)] +struct HypercoreWrapper +where + T: RandomAccess + Debug + Send, +{ + discovery_key: [u8; 32], + key: [u8; 32], + hypercore: Arc>>, +} + +impl HypercoreWrapper { + pub fn from_memory_hypercore(hypercore: Hypercore) -> Self { + let key = hypercore.key_pair().public.to_bytes(); + HypercoreWrapper { + key, + discovery_key: discovery_key(&key), + hypercore: Arc::new(Mutex::new(hypercore)), + } + } +} + +impl HypercoreWrapper +where + T: RandomAccess + Debug + Send + 'static, +{ + pub fn key(&self) -> &[u8; 32] { + &self.key + } + + pub fn onpeer(&self, mut channel: Channel) { + let mut peer_state = PeerState::default(); + let mut hypercore = self.hypercore.clone(); + task::spawn(async move { + let info = { + let hypercore = hypercore.lock().await; + hypercore.info() + }; + + if info.fork != peer_state.remote_fork { + peer_state.can_upgrade = false; + } + let remote_length = if info.fork == peer_state.remote_fork { + peer_state.remote_length + } else { + 0 + }; + + let sync_msg = Synchronize { + fork: info.fork, + length: info.length, + remote_length, + can_upgrade: peer_state.can_upgrade, + uploading: true, + downloading: true, + }; + + if info.contiguous_length > 0 { + let range_msg = Range { + drop: false, + start: 0, + length: info.contiguous_length, + }; + channel + .send_batch(&[Message::Synchronize(sync_msg), Message::Range(range_msg)]) + .await + .unwrap(); + } else { + channel.send(Message::Synchronize(sync_msg)).await.unwrap(); + } + while let Some(message) = channel.next().await { + let result = + onmessage(&mut hypercore, &mut peer_state, &mut channel, message).await; + if let Err(e) = result { + error!("protocol error: {}", e); + break; + } + } + }); + } +} + +/// A PeerState stores the head seq of the remote. +/// This would have a bitfield to support sparse sync in the actual impl. +#[derive(Debug)] +struct PeerState { + can_upgrade: bool, + remote_fork: u64, + remote_length: u64, + remote_can_upgrade: bool, + remote_uploading: bool, + remote_downloading: bool, + remote_synced: bool, + length_acked: u64, +} +impl Default for PeerState { + fn default() -> Self { + PeerState { + can_upgrade: true, + remote_fork: 0, + remote_length: 0, + remote_can_upgrade: false, + remote_uploading: true, + remote_downloading: true, + remote_synced: false, + length_acked: 0, + } + } +} + +async fn onmessage( + hypercore: &mut Arc>>, + peer_state: &mut PeerState, + channel: &mut Channel, + message: Message, +) -> Result<()> +where + T: RandomAccess + Debug + Send, +{ + match message { + Message::Synchronize(message) => { + println!("Got Synchronize message {message:?}"); + let length_changed = message.length != peer_state.remote_length; + let first_sync = !peer_state.remote_synced; + let info = { + let hypercore = hypercore.lock().await; + hypercore.info() + }; + let same_fork = message.fork == info.fork; + + peer_state.remote_fork = message.fork; + peer_state.remote_length = message.length; + peer_state.remote_can_upgrade = message.can_upgrade; + peer_state.remote_uploading = message.uploading; + peer_state.remote_downloading = message.downloading; + peer_state.remote_synced = true; + + peer_state.length_acked = if same_fork { message.remote_length } else { 0 }; + + let mut messages = vec![]; + + if first_sync { + // Need to send another sync back that acknowledges the received sync + let msg = Synchronize { + fork: info.fork, + length: info.length, + remote_length: peer_state.remote_length, + can_upgrade: peer_state.can_upgrade, + uploading: true, + downloading: true, + }; + messages.push(Message::Synchronize(msg)); + } + + if peer_state.remote_length > info.length + && peer_state.length_acked == info.length + && length_changed + { + let msg = Request { + id: 1, // There should be proper handling for in-flight request ids + fork: info.fork, + hash: None, + block: None, + seek: None, + upgrade: Some(RequestUpgrade { + start: info.length, + length: peer_state.remote_length - info.length, + }), + }; + messages.push(Message::Request(msg)); + } + channel.send_batch(&messages).await?; + } + Message::Request(message) => { + println!("Got Request message {message:?}"); + let (info, proof) = { + let mut hypercore = hypercore.lock().await; + let proof = hypercore + .create_proof(message.block, message.hash, message.seek, message.upgrade) + .await?; + (hypercore.info(), proof) + }; + if let Some(proof) = proof { + let msg = Data { + request: message.id, + fork: info.fork, + hash: proof.hash, + block: proof.block, + seek: proof.seek, + upgrade: proof.upgrade, + }; + channel.send(Message::Data(msg)).await?; + } + } + Message::Data(message) => { + println!("Got Data message {message:?}"); + let (_old_info, _applied, new_info, request_block) = { + let mut hypercore = hypercore.lock().await; + let old_info = hypercore.info(); + let proof = message.clone().into_proof(); + let applied = hypercore.verify_and_apply_proof(&proof).await?; + let new_info = hypercore.info(); + let request_block: Option = if let Some(upgrade) = &message.upgrade { + // When getting the initial upgrade, send a request for the first missing block + if old_info.length < upgrade.length { + let request_index = old_info.length; + let nodes = hypercore.missing_nodes(request_index).await?; + Some(RequestBlock { + index: request_index, + nodes, + }) + } else { + None + } + } else if let Some(block) = &message.block { + // When receiving a block, ask for the next, if there are still some missing + if block.index < peer_state.remote_length - 1 { + let request_index = block.index + 1; + let nodes = hypercore.missing_nodes(request_index).await?; + Some(RequestBlock { + index: request_index, + nodes, + }) + } else { + None + } + } else { + None + }; + + // If all have been replicated, print the result + if new_info.contiguous_length == new_info.length { + println!(); + println!("### Results"); + println!(); + println!("Replication succeeded if this prints '0: hi', '1: ola', '2: hello' and '3: mundo':"); + println!(); + for i in 0..new_info.contiguous_length { + println!( + "{}: {}", + i, + String::from_utf8(hypercore.get(i).await?.unwrap()).unwrap() + ); + } + println!("Press Ctrl-C to exit"); + } + (old_info, applied, new_info, request_block) + }; + + let mut messages: Vec = vec![]; + if let Some(upgrade) = &message.upgrade { + let new_length = upgrade.length; + let remote_length = if new_info.fork == peer_state.remote_fork { + peer_state.remote_length + } else { + 0 + }; + messages.push(Message::Synchronize(Synchronize { + fork: new_info.fork, + length: new_length, + remote_length, + can_upgrade: false, + uploading: true, + downloading: true, + })); + } + if let Some(request_block) = request_block { + messages.push(Message::Request(Request { + id: request_block.index + 1, + fork: new_info.fork, + hash: None, + block: Some(request_block), + seek: None, + upgrade: None, + })); + } + channel.send_batch(&messages).await.unwrap(); + } + _ => {} + }; + Ok(()) +} + +/// Init EnvLogger, logging info, warn and error messages to stdout. +pub fn init_logger() { + env_logger::from_env(Env::default().default_filter_or("info")).init(); +} + +/// Log a result if it's an error. +pub fn log_if_error(result: &Result<()>) { + if let Err(err) = result.as_ref() { + log::error!("error: {}", err); + } +} + +/// A simple async TCP server that calls an async function for each incoming connection. +pub async fn tcp_server( + address: String, + onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, + context: C, +) -> Result<()> +where + F: Future> + Send, + C: Clone + Send + 'static, +{ + let listener = TcpListener::bind(&address).await?; + log::info!("listening on {}", listener.local_addr()?); + let mut incoming = listener.incoming(); + while let Some(Ok(stream)) = incoming.next().await { + let context = context.clone(); + let peer_addr = stream.peer_addr().unwrap(); + log::info!("new connection from {}", peer_addr); + task::spawn(async move { + let result = onconnection(stream, false, context).await; + log_if_error(&result); + log::info!("connection closed from {}", peer_addr); + }); + } + Ok(()) +} + +/// A simple async TCP client that calls an async function when connected. +pub async fn tcp_client( + address: String, + onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, + context: C, +) -> Result<()> +where + F: Future> + Send, + C: Clone + Send + 'static, +{ + log::info!("attempting connection to {address}"); + let stream = TcpStream::connect(&address).await?; + log::info!("connected to {address}"); + onconnection(stream, true, context).await +} diff --git a/examples/util.rs b/examples/util.rs deleted file mode 100644 index 05042d1..0000000 --- a/examples/util.rs +++ /dev/null @@ -1,62 +0,0 @@ -use anyhow::Result; -use async_std::net::{TcpListener, TcpStream}; -use async_std::prelude::*; -use async_std::task; -use env_logger::Env; - -// We use this example as a module for the other examples. -#[allow(dead_code)] -fn main() {} - -/// Init EnvLogger, logging info, warn and error messages to stdout. -pub fn init_logger() { - env_logger::from_env(Env::default().default_filter_or("info")).init(); -} - -/// Log a result if it's an error. -pub fn log_if_error(result: &Result<()>) { - if let Err(err) = result.as_ref() { - log::error!("error: {}", err); - } -} - -/// A simple async TCP server that calls an async function for each incoming connection. -pub async fn tcp_server( - address: String, - onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, - context: C, -) -> Result<()> -where - F: Future> + Send, - C: Clone + Send + 'static, -{ - let listener = TcpListener::bind(&address).await?; - log::info!("listening on {}", listener.local_addr()?); - let mut incoming = listener.incoming(); - while let Some(Ok(stream)) = incoming.next().await { - let context = context.clone(); - let peer_addr = stream.peer_addr().unwrap(); - log::info!("new connection from {}", peer_addr); - task::spawn(async move { - let result = onconnection(stream, false, context).await; - log_if_error(&result); - log::info!("connection closed from {}", peer_addr); - }); - } - Ok(()) -} - -/// A simple async TCP client that calls an async function when connected. -pub async fn tcp_client( - address: String, - onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, - context: C, -) -> Result<()> -where - F: Future> + Send, - C: Clone + Send + 'static, -{ - let stream = TcpStream::connect(&address).await?; - log::info!("connected to {}", &address); - onconnection(stream, true, context).await -} diff --git a/src/builder.rs b/src/builder.rs index 02ede9f..d797654 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -1,64 +1,26 @@ -use crate::duplex::Duplex; use crate::Protocol; +use crate::{duplex::Duplex, protocol::Options}; use futures_lite::io::{AsyncRead, AsyncWrite}; -/// Options for a Protocol instance. -#[derive(Debug)] -pub struct Options { - /// Whether this peer initiated the IO connection for this protoccol - pub is_initiator: bool, - /// Enable or disable the handshake. - /// Disabling the handshake will also disable capabilitity verification. - /// Don't disable this if you're not 100% sure you want this. - pub noise: bool, - /// Enable or disable transport encryption. - pub encrypted: bool, -} - -impl Options { - /// Create with default options. - pub fn new(is_initiator: bool) -> Self { - Self { - is_initiator, - noise: true, - encrypted: true, - } - } -} - /// Build a Protocol instance with options. #[derive(Debug)] pub struct Builder(Options); impl Builder { - /// Create a protocol builder. - pub fn new(is_initiator: bool) -> Self { - Self(Options { - is_initiator, - noise: true, - encrypted: true, - }) - } - - /// Default options for an initiating endpoint. - pub fn initiator() -> Self { - Self::new(true) - } - - /// Default options for a responding endpoint. - pub fn responder() -> Self { - Self::new(false) + /// Create a protocol builder as initiator (true) or responder (false). + pub fn new(initiator: bool) -> Self { + Self(Options::new(initiator)) } - /// Set encrypted option. - pub fn set_encrypted(mut self, encrypted: bool) -> Self { + /// Set encrypted option. Defaults to true. + pub fn encrypted(mut self, encrypted: bool) -> Self { self.0.encrypted = encrypted; self } - /// Set handshake option. - pub fn set_noise(mut self, noise: bool) -> Self { - self.0.noise = noise; + /// Set handshake option. Defaults to true. + pub fn handshake(mut self, handshake: bool) -> Self { + self.0.noise = handshake; self } diff --git a/src/channels.rs b/src/channels.rs index da69afc..dea48d7 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -1,10 +1,9 @@ -use crate::extension::{Extension, Extensions}; use crate::message::ChannelMessage; use crate::schema::*; use crate::util::{map_channel_err, pretty_hash}; use crate::Message; use crate::{discovery_key, DiscoveryKey, Key}; -use async_channel::{Receiver, Sender}; +use async_channel::{Receiver, Sender, TrySendError}; use futures_lite::ready; use futures_lite::stream::Stream; use std::collections::HashMap; @@ -20,11 +19,11 @@ use std::task::Poll; /// This is the handle that can be sent to other threads. pub struct Channel { inbound_rx: Option>, - outbound_tx: Sender, + direct_inbound_tx: Sender, + outbound_tx: Sender>, key: Key, discovery_key: DiscoveryKey, local_id: usize, - extensions: Extensions, closed: Arc, } @@ -75,14 +74,37 @@ impl Channel { } let message = ChannelMessage::new(self.local_id as u64, message); self.outbound_tx - .send(message) + .send(vec![message]) .await .map_err(map_channel_err) } - /// Register a protocol extension. - pub async fn register_extension(&mut self, name: impl ToString) -> Extension { - self.extensions.register(name.to_string()).await + /// Send a batch of messages over the channel. + pub async fn send_batch(&mut self, messages: &[Message]) -> Result<()> { + // In javascript this is cork()/uncork(), e.g.: + // + // https://github.com/holepunchto/hypercore/blob/c338b9aaa4442d35bc9d283d2c242b86a46de6d4/lib/replicator.js#L402-L418 + // + // at the protomux level, where there can be messages from multiple channels in a single + // stream write: + // + // https://github.com/holepunchto/protomux/blob/d3d6f8f55e52c2fbe5cd56f5d067ac43ca13c27d/index.js#L368-L389 + // + // Batching messages across channels like protomux is capable of doing is not (yet) implemented. + if self.closed() { + return Err(Error::new( + ErrorKind::ConnectionAborted, + "Channel is closed", + )); + } + let messages = messages + .iter() + .map(|message| ChannelMessage::new(self.local_id as u64, message.clone())) + .collect(); + self.outbound_tx + .send(messages) + .await + .map_err(map_channel_err) } /// Take the receiving part out of the channel. @@ -93,49 +115,13 @@ impl Channel { self.inbound_rx.take() } - /// Send a status message. - pub async fn status(&mut self, msg: Status) -> Result<()> { - self.send(Message::Status(msg)).await - } - - /// Send a options message. - pub async fn options(&mut self, msg: Options) -> Result<()> { - self.send(Message::Options(msg)).await - } - - /// Send a have message. - pub async fn have(&mut self, msg: Have) -> Result<()> { - self.send(Message::Have(msg)).await - } - - /// Send a unhave message. - pub async fn unhave(&mut self, msg: Unhave) -> Result<()> { - self.send(Message::Unhave(msg)).await - } - - /// Send a want message. - pub async fn want(&mut self, msg: Want) -> Result<()> { - self.send(Message::Want(msg)).await - } - - /// Send a unwant message. - pub async fn unwant(&mut self, msg: Unwant) -> Result<()> { - self.send(Message::Unwant(msg)).await - } - - /// Send a request message. - pub async fn request(&mut self, msg: Request) -> Result<()> { - self.send(Message::Request(msg)).await - } - - /// Send a cancel message. - pub async fn cancel(&mut self, msg: Cancel) -> Result<()> { - self.send(Message::Cancel(msg)).await - } - - /// Send a data message. - pub async fn data(&mut self, msg: Data) -> Result<()> { - self.send(Message::Data(msg)).await + /// Clone the local sending part of the channel receiver. Useful + /// for direct local communication to the channel listener. Typically + /// you will only want to send a LocalSignal message with this sender to make + /// it clear what event came from the remote peer and what was local + /// signaling. + pub fn local_sender(&mut self) -> Sender { + self.direct_inbound_tx.clone() } /// Send a close message and close this channel. @@ -144,12 +130,20 @@ impl Channel { return Ok(()); } let close = Close { - discovery_key: None, + channel: self.local_id as u64, }; self.send(Message::Close(close)).await?; self.closed.store(true, Ordering::SeqCst); Ok(()) } + + /// Signal the protocol to produce Event::LocalSignal. If you want to send a message + /// to the channel level, see take_receiver() and local_sender(). + pub async fn signal_local_protocol(&mut self, name: &str, data: Vec) -> Result<()> { + self.send(Message::LocalSignal((name.to_string(), data))) + .await?; + Ok(()) + } } impl Stream for Channel { @@ -159,24 +153,11 @@ impl Stream for Channel { cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.get_mut(); - loop { - match this.inbound_rx.as_mut() { - None => { - return Poll::Ready(None); - } - Some(ref mut inbound_rx) => { - let message = ready!(Pin::new(inbound_rx).poll_next(cx)); - match message { - Some(Message::Extension(msg)) => { - this.extensions.on_message(msg); - } - Some(Message::Options(ref msg)) => { - this.extensions.on_remote_update(msg.extensions.clone()); - return Poll::Ready(message); - } - _ => return Poll::Ready(message), - } - } + match this.inbound_rx.as_mut() { + None => Poll::Ready(None), + Some(ref mut inbound_rx) => { + let message = ready!(Pin::new(inbound_rx).poll_next(cx)); + Poll::Ready(message) } } } @@ -230,24 +211,24 @@ impl ChannelHandle { this } - pub fn discovery_key(&self) -> &[u8; 32] { + pub(crate) fn discovery_key(&self) -> &[u8; 32] { &self.discovery_key } - pub fn local_id(&self) -> Option { + pub(crate) fn local_id(&self) -> Option { self.local_state.as_ref().map(|s| s.local_id) } - pub fn remote_id(&self) -> Option { + pub(crate) fn remote_id(&self) -> Option { self.remote_state.as_ref().map(|s| s.remote_id) } - pub fn attach_local(&mut self, local_id: usize, key: Key) { + pub(crate) fn attach_local(&mut self, local_id: usize, key: Key) { let local_state = LocalState { local_id, key }; self.local_state = Some(local_state); } - pub fn attach_remote(&mut self, remote_id: usize, remote_capability: Option>) { + pub(crate) fn attach_remote(&mut self, remote_id: usize, remote_capability: Option>) { let remote_state = RemoteState { remote_id, remote_capability, @@ -255,11 +236,11 @@ impl ChannelHandle { self.remote_state = Some(remote_state); } - pub fn is_connected(&self) -> bool { + pub(crate) fn is_connected(&self) -> bool { self.local_state.is_some() && self.remote_state.is_some() } - pub fn prepare_to_verify(&self) -> Result<(&Key, Option<&Vec>)> { + pub(crate) fn prepare_to_verify(&self) -> Result<(&Key, Option<&Vec>)> { if !self.is_connected() { return Err(error("Channel is not opened from both local and remote")); } @@ -269,7 +250,7 @@ impl ChannelHandle { Ok((&local_state.key, remote_state.remote_capability.as_ref())) } - pub fn open(&mut self, outbound_tx: Sender) -> Channel { + pub(crate) fn open(&mut self, outbound_tx: Sender>) -> Channel { let local_state = self .local_state .as_ref() @@ -278,26 +259,43 @@ impl ChannelHandle { let (inbound_tx, inbound_rx) = async_channel::unbounded(); let channel = Channel { inbound_rx: Some(inbound_rx), - outbound_tx: outbound_tx.clone(), + direct_inbound_tx: inbound_tx.clone(), + outbound_tx, discovery_key: self.discovery_key, key: local_state.key, local_id: local_state.local_id, - extensions: Extensions::new(outbound_tx, local_state.local_id as u64), closed: self.closed.clone(), }; self.inbound_tx = Some(inbound_tx); channel } - pub fn try_send_inbound(&mut self, message: Message) -> std::io::Result<()> { + pub(crate) fn try_send_inbound(&mut self, message: Message) -> std::io::Result<()> { if let Some(inbound_tx) = self.inbound_tx.as_mut() { inbound_tx .try_send(message) - .map_err(|_e| error("Channel is full")) + .map_err(|e| error(format!("Sending to channel failed: {e}").as_str())) } else { Err(error("Channel is not open")) } } + + pub(crate) fn try_send_inbound_tolerate_closed( + &mut self, + message: Message, + ) -> std::io::Result<()> { + if let Some(inbound_tx) = self.inbound_tx.as_mut() { + if let Err(err) = inbound_tx.try_send(message) { + match err { + TrySendError::Full(e) => { + return Err(error(format!("Sending to channel failed: {e}").as_str())) + } + TrySendError::Closed(_) => {} + } + } + } + Ok(()) + } } impl Drop for ChannelHandle { @@ -315,7 +313,7 @@ pub(crate) struct ChannelMap { } impl ChannelMap { - pub fn new() -> Self { + pub(crate) fn new() -> Self { Self { channels: HashMap::new(), // Add a first None value to local_id to start ids at 1. @@ -325,9 +323,9 @@ impl ChannelMap { } } - pub fn attach_local(&mut self, key: Key) -> &ChannelHandle { + pub(crate) fn attach_local(&mut self, key: Key) -> &ChannelHandle { let discovery_key = discovery_key(&key); - let hdkey = hex::encode(&discovery_key); + let hdkey = hex::encode(discovery_key); let local_id = self.alloc_local(); self.channels @@ -339,13 +337,13 @@ impl ChannelMap { self.channels.get(&hdkey).unwrap() } - pub fn attach_remote( + pub(crate) fn attach_remote( &mut self, discovery_key: DiscoveryKey, remote_id: usize, remote_capability: Option>, ) -> &ChannelHandle { - let hdkey = hex::encode(&discovery_key); + let hdkey = hex::encode(discovery_key); self.alloc_remote(remote_id); self.channels .entry(hdkey.clone()) @@ -357,7 +355,7 @@ impl ChannelMap { self.channels.get(&hdkey).unwrap() } - pub fn get_remote_mut(&mut self, remote_id: usize) -> Option<&mut ChannelHandle> { + pub(crate) fn get_remote_mut(&mut self, remote_id: usize) -> Option<&mut ChannelHandle> { if let Some(Some(hdkey)) = self.remote_id.get(remote_id).as_ref() { self.channels.get_mut(hdkey) } else { @@ -365,7 +363,7 @@ impl ChannelMap { } } - pub fn get_remote(&self, remote_id: usize) -> Option<&ChannelHandle> { + pub(crate) fn get_remote(&self, remote_id: usize) -> Option<&ChannelHandle> { if let Some(Some(hdkey)) = self.remote_id.get(remote_id).as_ref() { self.channels.get(hdkey) } else { @@ -373,7 +371,7 @@ impl ChannelMap { } } - pub fn get_local_mut(&mut self, local_id: usize) -> Option<&mut ChannelHandle> { + pub(crate) fn get_local_mut(&mut self, local_id: usize) -> Option<&mut ChannelHandle> { if let Some(Some(hdkey)) = self.local_id.get(local_id).as_ref() { self.channels.get_mut(hdkey) } else { @@ -381,7 +379,7 @@ impl ChannelMap { } } - pub fn get_local(&self, local_id: usize) -> Option<&ChannelHandle> { + pub(crate) fn get_local(&self, local_id: usize) -> Option<&ChannelHandle> { if let Some(Some(hdkey)) = self.local_id.get(local_id).as_ref() { self.channels.get(hdkey) } else { @@ -389,7 +387,7 @@ impl ChannelMap { } } - pub fn remove(&mut self, discovery_key: &[u8]) { + pub(crate) fn remove(&mut self, discovery_key: &[u8]) { let hdkey = hex::encode(discovery_key); let channel = self.channels.get(&hdkey); if let Some(channel) = channel { @@ -403,17 +401,17 @@ impl ChannelMap { self.channels.remove(&hdkey); } - pub fn prepare_to_verify(&self, local_id: usize) -> Result<(&Key, Option<&Vec>)> { + pub(crate) fn prepare_to_verify(&self, local_id: usize) -> Result<(&Key, Option<&Vec>)> { let channel_handle = self .get_local(local_id) .ok_or_else(|| error("Channel not found"))?; channel_handle.prepare_to_verify() } - pub fn accept( + pub(crate) fn accept( &mut self, local_id: usize, - outbound_tx: Sender, + outbound_tx: Sender>, ) -> Result { let channel_handle = self .get_local_mut(local_id) @@ -425,15 +423,35 @@ impl ChannelMap { Ok(channel) } - pub fn forward_inbound_message(&mut self, remote_id: usize, message: Message) -> Result<()> { + pub(crate) fn forward_inbound_message( + &mut self, + remote_id: usize, + message: Message, + ) -> Result<()> { if let Some(channel_handle) = self.get_remote_mut(remote_id) { channel_handle.try_send_inbound(message)?; } Ok(()) } + pub(crate) fn forward_inbound_message_tolerate_closed( + &mut self, + remote_id: usize, + message: Message, + ) -> Result<()> { + if let Some(channel_handle) = self.get_remote_mut(remote_id) { + channel_handle.try_send_inbound_tolerate_closed(message)?; + } + Ok(()) + } + fn alloc_local(&mut self) -> usize { - let empty_id = self.local_id.iter().skip(1).position(|x| x.is_none()); + let empty_id = self + .local_id + .iter() + .skip(1) + .position(|x| x.is_none()) + .map(|position| position + 1); match empty_id { Some(empty_id) => empty_id, None => { diff --git a/src/constants.rs b/src/constants.rs index 01e19ea..77285ee 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,14 +1,15 @@ -/// Seed for the capability hash -pub const CAP_NS_BUF: &[u8] = b"hypercore capability"; - /// Seed for the discovery key hash -pub const DISCOVERY_NS_BUF: &[u8] = b"hypercore"; +pub(crate) const DISCOVERY_NS_BUF: &[u8] = b"hypercore"; /// Default timeout (in seconds) -pub const DEFAULT_TIMEOUT: u32 = 20; +pub(crate) const DEFAULT_TIMEOUT: u32 = 20; /// Default keepalive interval (in seconds) -pub const DEFAULT_KEEPALIVE: u32 = 10; +pub(crate) const DEFAULT_KEEPALIVE: u32 = 10; + +// 16,78MB is the max encrypted wire message size (will be much smaller usually). +// This limitation stems from the 24bit header. +pub(crate) const MAX_MESSAGE_SIZE: u64 = 0xFFFFFF; -// 4MB is the max wire message size (will be much smaller usually). -pub const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 4; +/// v10: Protocol name +pub(crate) const PROTOCOL_NAME: &str = "hypercore/alpha"; diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs new file mode 100644 index 0000000..c0e54a9 --- /dev/null +++ b/src/crypto/cipher.rs @@ -0,0 +1,186 @@ +use super::HandshakeResult; +use crate::util::{stat_uint24_le, write_uint24_le, UINT_24_LENGTH}; +use blake2::{ + digest::{typenum::U32, FixedOutput, Update}, + Blake2bMac, +}; +use crypto_secretstream::{Header, Key, PullStream, PushStream, Tag}; +use rand::rngs::OsRng; +use std::convert::TryInto; +use std::io; + +const STREAM_ID_LENGTH: usize = 32; +const KEY_LENGTH: usize = 32; +const HEADER_MSG_LEN: usize = UINT_24_LENGTH + STREAM_ID_LENGTH + Header::BYTES; + +pub(crate) struct DecryptCipher { + pull_stream: PullStream, +} + +pub(crate) struct EncryptCipher { + push_stream: PushStream, +} + +impl std::fmt::Debug for DecryptCipher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DecryptCipher(crypto_secretstream)") + } +} + +impl std::fmt::Debug for EncryptCipher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "EncryptCipher(crypto_secretstream)") + } +} + +impl DecryptCipher { + pub(crate) fn from_handshake_rx_and_init_msg( + handshake_result: &HandshakeResult, + init_msg: &[u8], + ) -> io::Result { + if init_msg.len() < 32 + 24 { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + format!( + "Received too short init message, {} < {}.", + init_msg.len(), + 32 + 24 + ), + )); + } + + let key: [u8; KEY_LENGTH] = handshake_result.split_rx[..KEY_LENGTH] + .try_into() + .expect("split_rx with incorrect length"); + let key = Key::from(key); + let handshake_hash = handshake_result.handshake_hash.clone(); + let is_initiator = handshake_result.is_initiator; + + // Read the received message from the other peer + let mut expected_stream_id: [u8; 32] = [0; 32]; + write_stream_id(&handshake_hash, !is_initiator, &mut expected_stream_id); + let remote_stream_id: [u8; 32] = init_msg[0..32] + .try_into() + .expect("stream id slice with incorrect length"); + if expected_stream_id != remote_stream_id { + return Err(io::Error::new( + io::ErrorKind::PermissionDenied, + "Received stream id does not match expected".to_string(), + )); + } + + let header: [u8; 24] = init_msg[32..] + .try_into() + .expect("header slice with incorrect length"); + let pull_stream = PullStream::init(Header::from(header), &key); + Ok(Self { pull_stream }) + } + + pub(crate) fn decrypt( + &mut self, + buf: &mut [u8], + header_len: usize, + body_len: usize, + ) -> io::Result { + let (to_decrypt, _tag) = self.decrypt_buf(&buf[header_len..header_len + body_len])?; + let decrypted_len = to_decrypt.len(); + write_uint24_le(decrypted_len, buf); + let decrypted_end = 3 + to_decrypt.len(); + buf[3..decrypted_end].copy_from_slice(to_decrypt.as_slice()); + // Set extra bytes in the buffer to 0 + let encrypted_end = header_len + body_len; + buf[decrypted_end..encrypted_end].fill(0x00); + Ok(decrypted_end) + } + + pub(crate) fn decrypt_buf(&mut self, buf: &[u8]) -> io::Result<(Vec, Tag)> { + let mut to_decrypt = buf.to_vec(); + let tag = &self.pull_stream.pull(&mut to_decrypt, &[]).map_err(|err| { + io::Error::new(io::ErrorKind::Other, format!("Decrypt failed: {err}")) + })?; + Ok((to_decrypt, *tag)) + } +} + +impl EncryptCipher { + pub(crate) fn from_handshake_tx( + handshake_result: &HandshakeResult, + ) -> std::io::Result<(Self, Vec)> { + let key: [u8; KEY_LENGTH] = handshake_result.split_tx[..KEY_LENGTH] + .try_into() + .expect("split_tx with incorrect length"); + let key = Key::from(key); + + let mut header_message: [u8; HEADER_MSG_LEN] = [0; HEADER_MSG_LEN]; + write_uint24_le(STREAM_ID_LENGTH + Header::BYTES, &mut header_message); + write_stream_id( + &handshake_result.handshake_hash, + handshake_result.is_initiator, + &mut header_message[UINT_24_LENGTH..UINT_24_LENGTH + STREAM_ID_LENGTH], + ); + + let (header, push_stream) = PushStream::init(OsRng, &key); + let header = header.as_ref(); + header_message[UINT_24_LENGTH + STREAM_ID_LENGTH..].copy_from_slice(header); + let msg = header_message.to_vec(); + Ok((Self { push_stream }, msg)) + } + + /// Get the length needed for encryption, that includes padding. + pub(crate) fn safe_encrypted_len(&self, plaintext_len: usize) -> usize { + // ChaCha20-Poly1305 uses padding in two places, use two 15 bytes as a safe + // extra room. + // https://mailarchive.ietf.org/arch/msg/cfrg/u734TEOSDDWyQgE0pmhxjdncwvw/ + plaintext_len + 2 * 15 + } + + /// Encrypts message in the given buffer to the same buffer, returns number of bytes + /// of total message. + pub(crate) fn encrypt(&mut self, buf: &mut [u8]) -> io::Result { + let stat = stat_uint24_le(buf); + if let Some((header_len, body_len)) = stat { + let mut to_encrypt = buf[header_len..header_len + body_len as usize].to_vec(); + self.push_stream + .push(&mut to_encrypt, &[], Tag::Message) + .map_err(|err| { + io::Error::new(io::ErrorKind::Other, format!("Encrypt failed: {err}")) + })?; + let encrypted_len = to_encrypt.len(); + write_uint24_le(encrypted_len, buf); + buf[header_len..header_len + encrypted_len].copy_from_slice(to_encrypt.as_slice()); + Ok(3 + encrypted_len) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Could not encrypt invalid data, len: {}", buf.len()), + )) + } + } +} + +// NB: These values come from Javascript-side +// +// const [NS_INITIATOR, NS_RESPONDER] = crypto.namespace('hyperswarm/secret-stream', 2) +// +// at https://github.com/hyperswarm/secret-stream/blob/master/index.js +const NS_INITIATOR: [u8; 32] = [ + 0xa9, 0x31, 0xa0, 0x15, 0x5b, 0x5c, 0x09, 0xe6, 0xd2, 0x86, 0x28, 0x23, 0x6a, 0xf8, 0x3c, 0x4b, + 0x8a, 0x6a, 0xf9, 0xaf, 0x60, 0x98, 0x6e, 0xde, 0xed, 0xe9, 0xdc, 0x5d, 0x63, 0x19, 0x2b, 0xf7, +]; +const NS_RESPONDER: [u8; 32] = [ + 0x74, 0x2c, 0x9d, 0x83, 0x3d, 0x43, 0x0a, 0xf4, 0xc4, 0x8a, 0x87, 0x05, 0xe9, 0x16, 0x31, 0xee, + 0xcf, 0x29, 0x54, 0x42, 0xbb, 0xca, 0x18, 0x99, 0x6e, 0x59, 0x70, 0x97, 0x72, 0x3b, 0x10, 0x61, +]; + +fn write_stream_id(handshake_hash: &[u8], is_initiator: bool, out: &mut [u8]) { + let mut hasher = + Blake2bMac::::new_with_salt_and_personal(handshake_hash, &[], &[]).unwrap(); + if is_initiator { + hasher.update(&NS_INITIATOR); + } else { + hasher.update(&NS_RESPONDER); + } + let result = hasher.finalize_fixed(); + let result = result.as_slice(); + out.copy_from_slice(result); +} diff --git a/src/crypto/curve.rs b/src/crypto/curve.rs new file mode 100644 index 0000000..48ed841 --- /dev/null +++ b/src/crypto/curve.rs @@ -0,0 +1,101 @@ +use hypercore::{generate_signing_key, SecretKey, SigningKey, VerifyingKey}; +use sha2::Digest; +use snow::{ + params::{CipherChoice, DHChoice, HashChoice}, + resolvers::CryptoResolver, + types::{Cipher, Dh, Hash, Random}, +}; +use std::convert::TryInto; + +/// Wraps ed25519-dalek compatible keypair +#[derive(Default)] +struct Ed25519 { + privkey: [u8; 32], + pubkey: [u8; 32], +} + +impl Dh for Ed25519 { + fn name(&self) -> &'static str { + "Ed25519" + } + + fn pub_len(&self) -> usize { + 32 + } + + fn priv_len(&self) -> usize { + 32 + } + + fn set(&mut self, privkey: &[u8]) { + let secret: SecretKey = privkey + .try_into() + .expect("Can't use given bytes as SecretKey"); + let public: VerifyingKey = SigningKey::from(&secret).verifying_key(); + self.privkey[..privkey.len()].copy_from_slice(privkey); + let public_key_bytes = public.as_bytes(); + self.pubkey[..public_key_bytes.len()].copy_from_slice(public_key_bytes); + } + + fn generate(&mut self, _: &mut dyn Random) { + // NB: Given Random can't be used with ed25519_dalek's SigningKey::generate(), + // use OS's random here from hypercore. + let signing_key = generate_signing_key(); + let secret_key_bytes = signing_key.to_bytes(); + self.privkey[..secret_key_bytes.len()].copy_from_slice(&secret_key_bytes); + let verifying_key = signing_key.verifying_key(); + let public_key_bytes = verifying_key.as_bytes(); + self.pubkey[..public_key_bytes.len()].copy_from_slice(public_key_bytes); + } + + fn pubkey(&self) -> &[u8] { + &self.pubkey + } + + fn privkey(&self) -> &[u8] { + &self.privkey + } + + fn dh(&self, pubkey: &[u8], out: &mut [u8]) -> Result<(), snow::Error> { + let sk: [u8; 32] = sha2::Sha512::digest(self.privkey).as_slice()[..32] + .try_into() + .unwrap(); + // PublicKey is a CompressedEdwardsY in dalek. So we decompress it to get the + // EdwardsPoint and use variable base multiplication. + let cey = + curve25519_dalek::edwards::CompressedEdwardsY::from_slice(&pubkey[..self.pub_len()]) + .map_err(|_| snow::Error::Dh)?; + let pubkey: curve25519_dalek::edwards::EdwardsPoint = match cey.decompress() { + Some(ep) => Ok(ep), + None => Err(snow::Error::Dh), + }?; + let result = pubkey.mul_clamped(sk); + let result: [u8; 32] = *result.compress().as_bytes(); + out[..result.len()].copy_from_slice(result.as_slice()); + Ok(()) + } +} + +#[derive(Default)] +pub(super) struct CurveResolver; + +impl CryptoResolver for CurveResolver { + fn resolve_dh(&self, choice: &DHChoice) -> Option> { + match *choice { + DHChoice::Curve25519 => Some(Box::::default()), + _ => None, + } + } + + fn resolve_rng(&self) -> Option> { + None + } + + fn resolve_hash(&self, _choice: &HashChoice) -> Option> { + None + } + + fn resolve_cipher(&self, _choice: &CipherChoice) -> Option> { + None + } +} diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs new file mode 100644 index 0000000..64db407 --- /dev/null +++ b/src/crypto/handshake.rs @@ -0,0 +1,241 @@ +use super::curve::CurveResolver; +use crate::util::wrap_uint24_le; +use blake2::{ + digest::{typenum::U32, FixedOutput, Update}, + Blake2bMac, +}; +use snow::resolvers::{DefaultResolver, FallbackResolver}; +use snow::{Builder, Error as SnowError, HandshakeState}; +use std::io::{Error, ErrorKind, Result}; + +const CIPHERKEYLEN: usize = 32; +const HANDSHAKE_PATTERN: &str = "Noise_XX_Ed25519_ChaChaPoly_BLAKE2b"; + +// These the output of, see `hash_namespace` test below for how they are produced +// https://github.com/hypercore-protocol/hypercore/blob/70b271643c4e4b1e5ecae5bb579966dfe6361ff3/lib/caps.js#L9 +const REPLICATE_INITIATOR: [u8; 32] = [ + 0x51, 0x81, 0x2A, 0x2A, 0x35, 0x9B, 0x50, 0x36, 0x95, 0x36, 0x77, 0x5D, 0xF8, 0x9E, 0x18, 0xE4, + 0x77, 0x40, 0xF3, 0xDB, 0x72, 0xAC, 0xA, 0xE7, 0xB, 0x29, 0x59, 0x4C, 0x19, 0x4D, 0xC3, 0x16, +]; +const REPLICATE_RESPONDER: [u8; 32] = [ + 0x4, 0x38, 0x49, 0x2D, 0x2, 0x97, 0xC, 0xC1, 0x35, 0x28, 0xAC, 0x2, 0x62, 0xBC, 0xA0, 0x7, + 0x4E, 0x9, 0x26, 0x26, 0x2, 0x56, 0x86, 0x5A, 0xCC, 0xC0, 0xBF, 0x15, 0xBD, 0x79, 0x12, 0x7D, +]; + +#[derive(Debug, Clone, Default)] +pub(crate) struct HandshakeResult { + pub(crate) is_initiator: bool, + pub(crate) local_pubkey: Vec, + pub(crate) remote_pubkey: Vec, + pub(crate) handshake_hash: Vec, + pub(crate) split_tx: [u8; CIPHERKEYLEN], + pub(crate) split_rx: [u8; CIPHERKEYLEN], +} + +impl HandshakeResult { + pub(crate) fn capability(&self, key: &[u8]) -> Option> { + Some(replicate_capability( + self.is_initiator, + key, + &self.handshake_hash, + )) + } + + pub(crate) fn remote_capability(&self, key: &[u8]) -> Option> { + Some(replicate_capability( + !self.is_initiator, + key, + &self.handshake_hash, + )) + } + + pub(crate) fn verify_remote_capability( + &self, + capability: Option>, + key: &[u8], + ) -> Result<()> { + let expected_capability = self.remote_capability(key); + match (capability, expected_capability) { + (Some(c1), Some(c2)) if c1 == c2 => Ok(()), + (None, None) => Err(Error::new( + ErrorKind::PermissionDenied, + "Missing capabilities for verification", + )), + _ => Err(Error::new( + ErrorKind::PermissionDenied, + "Invalid remote channel capability", + )), + } + } +} + +pub(crate) struct Handshake { + result: HandshakeResult, + state: HandshakeState, + payload: Vec, + tx_buf: Vec, + rx_buf: Vec, + complete: bool, + did_receive: bool, +} + +impl Handshake { + pub(crate) fn new(is_initiator: bool) -> Result { + let (state, local_pubkey) = build_handshake_state(is_initiator).map_err(map_err)?; + + let payload = vec![]; + let result = HandshakeResult { + is_initiator, + local_pubkey, + ..Default::default() + }; + Ok(Self { + state, + result, + payload, + tx_buf: vec![0u8; 512], + rx_buf: vec![0u8; 512], + complete: false, + did_receive: false, + }) + } + + pub(crate) fn start(&mut self) -> Result>> { + if self.is_initiator() { + let tx_len = self.send()?; + let wrapped = wrap_uint24_le(&self.tx_buf[..tx_len].to_vec()); + Ok(Some(wrapped)) + } else { + Ok(None) + } + } + + pub(crate) fn complete(&self) -> bool { + self.complete + } + + pub(crate) fn is_initiator(&self) -> bool { + self.result.is_initiator + } + + fn recv(&mut self, msg: &[u8]) -> Result { + self.state + .read_message(msg, &mut self.rx_buf) + .map_err(map_err) + } + fn send(&mut self) -> Result { + self.state + .write_message(&self.payload, &mut self.tx_buf) + .map_err(map_err) + } + + pub(crate) fn read(&mut self, msg: &[u8]) -> Result>> { + // eprintln!("hs read len {}", msg.len()); + if self.complete() { + return Err(Error::new(ErrorKind::Other, "Handshake read after finish")); + } + + let _rx_len = self.recv(msg)?; + + if !self.is_initiator() && !self.did_receive { + self.did_receive = true; + let tx_len = self.send()?; + let wrapped = wrap_uint24_le(&self.tx_buf[..tx_len].to_vec()); + return Ok(Some(wrapped)); + } + + let tx_buf = if self.is_initiator() { + let tx_len = self.send()?; + let wrapped = wrap_uint24_le(&self.tx_buf[..tx_len].to_vec()); + Some(wrapped) + } else { + None + }; + + let split = self.state.dangerously_get_raw_split(); + if self.is_initiator() { + self.result.split_tx = split.0; + self.result.split_rx = split.1; + } else { + self.result.split_tx = split.1; + self.result.split_rx = split.0; + } + self.result.remote_pubkey = self + .state + .get_remote_static() + .expect("Could not read remote static key after handshake") + .to_vec(); + self.result.handshake_hash = self.state.get_handshake_hash().to_vec(); + self.complete = true; + Ok(tx_buf) + } + + pub(crate) fn into_result(self) -> Result { + if !self.complete() { + Err(Error::new(ErrorKind::Other, "Handshake is not complete")) + } else { + Ok(self.result) + } + } +} + +fn build_handshake_state( + is_initiator: bool, +) -> std::result::Result<(HandshakeState, Vec), SnowError> { + use snow::params::{ + BaseChoice, CipherChoice, DHChoice, HandshakeChoice, HandshakeModifierList, + HandshakePattern, HashChoice, NoiseParams, + }; + // NB: HANDSHAKE_PATTERN.parse() doesn't work because the pattern has "Ed25519" + // instead of "25519". + let noise_params = NoiseParams::new( + HANDSHAKE_PATTERN.to_string(), + BaseChoice::Noise, + HandshakeChoice { + pattern: HandshakePattern::XX, + modifiers: HandshakeModifierList { list: vec![] }, + }, + DHChoice::Curve25519, + CipherChoice::ChaChaPoly, + HashChoice::Blake2b, + ); + let builder: Builder<'_> = Builder::with_resolver( + noise_params, + Box::new(FallbackResolver::new( + Box::::default(), + Box::::default(), + )), + ); + let key_pair = builder.generate_keypair().unwrap(); + let builder = builder.local_private_key(&key_pair.private); + let handshake_state = if is_initiator { + tracing::debug!("building initiator"); + builder.build_initiator()? + } else { + tracing::debug!("building responder"); + builder.build_responder()? + }; + Ok((handshake_state, key_pair.public)) +} + +fn map_err(e: SnowError) -> Error { + Error::new(ErrorKind::PermissionDenied, format!("Handshake error: {e}")) +} + +/// Create a hash used to indicate replication capability. +/// See https://github.com/hypercore-protocol/hypercore/blob/70b271643c4e4b1e5ecae5bb579966dfe6361ff3/lib/caps.js#L11 +fn replicate_capability(is_initiator: bool, key: &[u8], handshake_hash: &[u8]) -> Vec { + let seed = if is_initiator { + REPLICATE_INITIATOR + } else { + REPLICATE_RESPONDER + }; + + let mut hasher = + Blake2bMac::::new_with_salt_and_personal(handshake_hash, &[], &[]).unwrap(); + hasher.update(&seed); + hasher.update(key); + let hash = hasher.finalize_fixed(); + let capability = hash.as_slice().to_vec(); + capability +} diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs new file mode 100644 index 0000000..66bb62d --- /dev/null +++ b/src/crypto/mod.rs @@ -0,0 +1,5 @@ +mod cipher; +mod curve; +mod handshake; +pub(crate) use cipher::{DecryptCipher, EncryptCipher}; +pub(crate) use handshake::{Handshake, HandshakeResult}; diff --git a/src/extension.rs b/src/extension.rs deleted file mode 100644 index e93d8e5..0000000 --- a/src/extension.rs +++ /dev/null @@ -1,253 +0,0 @@ -use crate::constants::MAX_MESSAGE_SIZE; -use crate::message::{ChannelMessage, ExtensionMessage, Message}; -use crate::schema::*; -use async_channel::{Receiver, Sender}; -use futures_lite::{ready, AsyncRead, AsyncWrite, FutureExt, Stream}; -use std::collections::HashMap; -use std::future::Future; -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; - -const MAX_BODY_SIZE: usize = MAX_MESSAGE_SIZE as usize - 16; - -#[derive(Debug)] -pub struct Extensions { - extensions: HashMap, - channel: u64, - local_ids: Vec, - remote_ids: Vec, - outbound_tx: Sender, -} - -impl Extensions { - pub fn new(outbound_tx: Sender, channel: u64) -> Self { - Self { - channel, - extensions: HashMap::new(), - local_ids: vec![], - remote_ids: vec![], - outbound_tx, - } - } - - pub fn add_local_name(&mut self, name: String) -> u64 { - self.local_ids.push(name.clone()); - self.local_ids.sort(); - let local_id = self.local_ids.iter().position(|x| x == &name).unwrap(); - local_id as u64 - } - - pub async fn register(&mut self, name: String) -> Extension { - let local_id = self.add_local_name(name.clone()); - let (inbound_tx, inbound_rx) = async_channel::unbounded(); - let handle = ExtensionHandle { - name: name.clone(), - channel: self.channel, - local_id, - inbound_tx, - }; - let extension = Extension { - name: name.clone(), - channel: self.channel, - local_id, - outbound_tx: self.outbound_tx.clone(), - inbound_rx, - write_state: WriteState::Idle, - read_state: None, - }; - self.extensions.insert(name, handle); - - let message = Options { - extensions: self.local_ids.clone(), - ack: None, - }; - let message = ChannelMessage::new(self.channel, Message::Options(message)); - self.outbound_tx.send(message).await.unwrap(); - - extension - } - - pub fn on_remote_update(&mut self, names: Vec) { - self.remote_ids = names; - } - - pub fn on_message(&mut self, message: ExtensionMessage) { - let ExtensionMessage { id, message } = message; - if let Some(name) = self.remote_ids.get(id as usize) { - if let Some(handle) = self.extensions.get_mut(name) { - handle.inbound_send(message); - } - } - } -} - -#[derive(Debug)] -pub struct ExtensionHandle { - name: String, - channel: u64, - local_id: u64, - inbound_tx: Sender>, -} - -impl ExtensionHandle { - fn inbound_send(&mut self, message: Vec) { - // This should be safe because inbound_tx is an unbounded channel, - // and is only dropped when the whole channel is dropped. - let _ = self.inbound_tx.try_send(message); - } -} - -/// A protocol extension. -/// -/// An extension can be registered on either the [`Protocol` stream] or on -/// any [`Channel`]. An extension is identified by a string. When both peers -/// open an extension with the same name, the extensions are connected. Then, they function as a -/// binary duplex stream. The stream is fully encrypted, but there's no authentication -/// performed on individual messages. -/// -/// The Extension struct implements both [`AsyncRead`] and [`AsyncWrite`] -/// and is also a [`Stream`]. You should use the extension either as a stream or as -/// an async reader; if being used as both, the messages would appear in either poll randomly. -/// -/// [`Channel`]: crate::Channel -/// [`Stream`]: futures_lite::Stream -/// [`AsyncRead`]: futures_lite::AsyncRead -/// [`AsyncWrite`]: futures_lite::AsyncWrite -/// [`Protocol` stream]: crate::Protocol -#[derive(Debug)] -pub struct Extension { - name: String, - channel: u64, - local_id: u64, - outbound_tx: Sender, - inbound_rx: Receiver>, - write_state: WriteState, - read_state: Option>, -} - -impl std::clone::Clone for Extension { - fn clone(&self) -> Self { - Self { - name: self.name.clone(), - channel: self.channel, - local_id: self.local_id, - outbound_tx: self.outbound_tx.clone(), - inbound_rx: self.inbound_rx.clone(), - write_state: WriteState::Idle, - read_state: None, - } - } -} - -type SendFuture = Pin> + Send + Sync + 'static>>; - -enum WriteState { - Sending(SendFuture, usize), - Idle, -} - -impl std::fmt::Debug for WriteState { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - WriteState::Sending(_, len) => { - write!(f, "Sending(len={})", len) - } - WriteState::Idle => write!(f, "Idle"), - } - } -} - -impl Extension { - /// Send a message - pub async fn send(&self, message: Vec) { - let message = ExtensionMessage::new(self.local_id, message); - let message = ChannelMessage::new(self.channel, Message::Extension(message)); - self.outbound_tx.send(message).await.unwrap() - } - - fn send_pinned(&self, message: Vec) -> SendFuture { - let message = ExtensionMessage::new(self.local_id, message); - let message = ChannelMessage::new(self.channel, Message::Extension(message)); - // TODO: It would be nice to do this without cloning, but I didn't find a way so far. - let fut = send_message(self.outbound_tx.clone(), message); - Box::pin(fut) - } -} - -pub async fn send_message( - sender: Sender, - message: ChannelMessage, -) -> io::Result<()> { - sender - .send(message) - .await - .map_err(|e| io::Error::new(io::ErrorKind::Interrupted, format!("Channel error: {}", e))) -} - -impl Stream for Extension { - type Item = Vec; - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - Pin::new(&mut self.inbound_rx).poll_next(cx) - } -} - -impl AsyncRead for Extension { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - let mut this = self.get_mut(); - let message = if let Some(message) = this.read_state.take() { - message - } else { - let message = ready!(Pin::new(&mut this).poll_next(cx)); - message.ok_or_else(|| io::Error::new(io::ErrorKind::Interrupted, "Channel closed"))? - }; - let len = message.len().min(buf.len()); - buf[..len].copy_from_slice(&message[..len]); - if message.len() > len { - this.read_state = Some(message[len..].to_vec()); - } else { - this.read_state = None - } - Poll::Ready(Ok(len)) - } -} - -impl AsyncWrite for Extension { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let this = self.get_mut(); - loop { - match this.write_state { - WriteState::Idle => { - let len = buf.len().min(MAX_BODY_SIZE); - let fut = this.send_pinned(buf.to_vec()); - this.write_state = WriteState::Sending(fut, len); - } - WriteState::Sending(ref mut fut, len) => { - let res = ready!(fut.poll(cx)); - let res = res.map(|_| len); - this.write_state = WriteState::Idle; - return Poll::Ready(res); - } - } - } - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} diff --git a/src/lib.rs b/src/lib.rs index 9b68d1e..f0d7866 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,24 +1,49 @@ -//! Hypercore protocol is a streaming, message based protocol +//! ## Introduction +//! +//! Hypercore protocol is a streaming, message based protocol. This is a rust port of the wire +//! protocol implementation in [the original Javascript version][holepunch-hypercore] aiming +//! for interoperability with LTS version. +//! +//! This crate is built on top of the [hypercore] crate, which defines some structs used here. +//! +//! ## Design //! //! This crate does not include any IO related code, it is up to the user to supply a streaming IO //! handler that implements the [AsyncRead] and [AsyncWrite] traits. //! //! When opening a Hypercore protocol stream on an IO handler, the protocol will perform a Noise -//! handshake to setup a secure and authenticated connection. After that, each side can request any -//! number of channels on the protocol. A channel is opened with a [Key], a 32 byte buffer. -//! Channels are only opened if both peers opened a channel for the same key. It is automatically -//! verified that both parties know the key without transmitting the key itself. -//! -//! On a channel, the predefined messages of the Hypercore protocol can be sent and received. -//! Additionally, Hypercore protocol supports protocol extensions that can be registered both on an -//! individual channel and on the main protocol stream. Extensions are registered with a string -//! name and are only established if both peers register an extension with the same name. Each -//! extension then can be used as a duplex stream. Note that individual messages on an extension -//! stream are enrypted but not authenticated. +//! handshake followed by libsodium's [crypto_secretstream] to setup a secure and authenticated +//! connection. After that, each side can request any number of channels on the protocol. A +//! channel is opened with a [Key], a 32 byte buffer. Channels are only opened if both peers +//! opened a channel for the same key. It is automatically verified that both parties know the +//! key without transmitting the key itself. //! -//! [AsyncRead]: futures_lite::AsyncRead -//! [AsyncWrite]: futures_lite::AsyncWrite -//! [TcpStream]: async_std::net::TcpStream +//! On a channel, the predefined messages, including a custom Extension message, of the Hypercore +//! protocol can be sent and received. +//! +//! ## Features +//! +//! ### `sparse` (default) +//! +//! When using disk storage for hypercore, clearing values may create sparse files. On by default. +//! +//! ### `async-std` (default) +//! +//! Use the async-std runtime, on by default. Either this or `tokio` is mandatory. +//! +//! ### `tokio` +//! +//! Use the tokio runtime. Either this or `async_std` is mandatory. +//! +//! ### `wasm-bindgen` +//! +//! Enable for WASM runtime support. +//! +//! ### `cache` +//! +//! Use a moka cache for hypercore's merkle tree nodes to speed-up reading. +//! +//! ## Example //! //! The following example opens a TCP server on localhost and connects to that server. Both ends //! then open a channel with the same key and exchange a message. @@ -28,7 +53,6 @@ //! use hypercore_protocol::{ProtocolBuilder, Event, Message}; //! use hypercore_protocol::schema::*; //! use async_std::prelude::*; -//! //! // Start a tcp server. //! let listener = async_std::net::TcpListener::bind("localhost:8000").await.unwrap(); //! async_std::task::spawn(async move { @@ -55,6 +79,7 @@ //! let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream); //! //! // Iterate over the protocol events. This is required to "drive" the protocol. +//! //! while let Some(Ok(event)) = protocol.next().await { //! eprintln!("{} received event {:?}", name, event); //! match event { @@ -67,7 +92,7 @@ //! // A Channel can be sent to other tasks. //! async_std::task::spawn(async move { //! // A Channel can both send messages and is a stream of incoming messages. -//! channel.want(Want { start: 0, length: None }).await; +//! channel.send(Message::Want(Want { start: 0, length: 1 })).await; //! while let Some(message) = channel.next().await { //! eprintln!("{} received message: {:?}", name, message); //! } @@ -79,35 +104,41 @@ //! } //! # }) //! ``` +//! +//! Find more examples in the [Github repository][examples]. +//! +//! [holepunch-hypercore]: https://github.com/holepunchto/hypercore +//! [datrs-hypercore]: https://github.com/datrs/hypercore +//! [AsyncRead]: futures_lite::AsyncRead +//! [AsyncWrite]: futures_lite::AsyncWrite +//! [examples]: https://github.com/datrs/hypercore-protocol-rs#examples #![forbid(unsafe_code, future_incompatible, rust_2018_idioms)] #![deny(missing_debug_implementations, nonstandard_style)] -// #![warn(missing_docs, missing_doc_code_examples, unreachable_pub)] -#![warn(missing_docs, missing_doc_code_examples)] +#![warn(missing_docs, unreachable_pub)] mod builder; mod channels; mod constants; +mod crypto; mod duplex; -mod extension; mod message; -mod noise; mod protocol; mod reader; mod util; mod writer; /// The wire messages used by the protocol. -#[allow(missing_docs)] -pub mod schema { - include!(concat!(env!("OUT_DIR"), "/hypercore.schema.rs")); - pub use crate::message::ExtensionMessage; -} +pub mod schema; -pub use builder::{Builder as ProtocolBuilder, Options}; +pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; +// Export the needed types for Channel::take_receiver, and Channel::local_sender() +pub use async_channel::{ + Receiver as ChannelReceiver, SendError as ChannelSendError, Sender as ChannelSender, +}; pub use duplex::Duplex; -pub use extension::Extension; +pub use hypercore; // Re-export hypercore pub use message::Message; pub use protocol::{DiscoveryKey, Event, Key, Protocol}; pub use util::discovery_key; diff --git a/src/message.rs b/src/message.rs index f269d15..30344d6 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,142 +1,344 @@ use crate::schema::*; +use crate::util::{stat_uint24_le, write_uint24_le}; +use hypercore::encoding::{ + CompactEncoding, EncodingError, EncodingErrorKind, HypercoreState, State, +}; use pretty_hash::fmt as pretty_fmt; -use prost::Message as _; use std::fmt; use std::io; -use crate::constants::MAX_MESSAGE_SIZE; - -/// Error if the buffer has insufficient size to encode a message. -#[derive(Debug)] -pub struct EncodeError { - required: usize, -} - -impl fmt::Display for EncodeError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Cannot encode message: Write buffer is full") - } -} - -impl EncodeError { - fn new(required: usize) -> Self { - Self { required } - } -} - -impl From for EncodeError { - fn from(e: prost::EncodeError) -> Self { - Self::new(e.required_capacity()) - } -} - -impl From for io::Error { - fn from(e: EncodeError) -> Self { - io::Error::new(io::ErrorKind::Other, format!("{}", e)) - } +/// The type of a data frame. +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum FrameType { + Raw, + Message, } /// Encode data into a buffer. /// /// This trait is implemented on data frames and their components /// (channel messages, messages, and individual message types through prost). -pub trait Encoder: Sized + fmt::Debug { +pub(crate) trait Encoder: Sized + fmt::Debug { /// Calculates the length that the encoded message needs. - fn encoded_len(&self) -> usize; + fn encoded_len(&mut self) -> Result; /// Encodes the message to a buffer. /// /// An error will be returned if the buffer does not have sufficient capacity. - fn encode(&self, buf: &mut [u8]) -> Result; + fn encode(&mut self, buf: &mut [u8]) -> Result; } impl Encoder for &[u8] { - fn encoded_len(&self) -> usize { - self.len() + fn encoded_len(&mut self) -> Result { + Ok(self.len()) } - fn encode(&self, buf: &mut [u8]) -> Result { - let len = self.encoded_len(); + fn encode(&mut self, buf: &mut [u8]) -> Result { + let len = self.encoded_len()?; if len > buf.len() { - return Err(EncodeError::new(len)); + return Err(EncodingError::new( + EncodingErrorKind::Overflow, + &format!("Length does not fit buffer, {} > {}", len, buf.len()), + )); } buf[..len].copy_from_slice(&self[..]); Ok(len) } } -/// The type of a data frame. -#[derive(Debug, Clone, PartialEq)] -pub enum FrameType { - Raw, - Message, -} - /// A frame of data, either a buffer or a message. #[derive(Clone, PartialEq)] -pub enum Frame { - /// A raw binary buffer. Used in the handshaking phase. - Raw(Vec), - /// A message. Used for everything after the handshake. - Message(ChannelMessage), +pub(crate) enum Frame { + /// A raw batch binary buffer. Used in the handshaking phase. + RawBatch(Vec>), + /// Message batch, containing one or more channel messsages. Used for everything after the handshake. + MessageBatch(Vec), } impl fmt::Debug for Frame { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Frame::Raw(buf) => write!(f, "Frame(Raw <{}>)", buf.len()), - Frame::Message(message) => write!(f, "Frame({:?})", message), + Frame::RawBatch(batch) => write!(f, "Frame(RawBatch <{}>)", batch.len()), + Frame::MessageBatch(messages) => write!(f, "Frame({messages:?})"), } } } impl From for Frame { fn from(m: ChannelMessage) -> Self { - Self::Message(m) + Self::MessageBatch(vec![m]) } } impl From> for Frame { fn from(m: Vec) -> Self { - Self::Raw(m) + Self::RawBatch(vec![m]) } } impl Frame { + /// Decodes a frame from a buffer containing multiple concurrent messages. + pub(crate) fn decode_multiple(buf: &[u8], frame_type: &FrameType) -> Result { + match frame_type { + FrameType::Raw => { + let mut index = 0; + let mut raw_batch: Vec> = vec![]; + while index < buf.len() { + // There might be zero bytes in between, and with LE, the next message will + // start with a non-zero + if buf[index] == 0 { + index += 1; + continue; + } + let stat = stat_uint24_le(&buf[index..]); + if let Some((header_len, body_len)) = stat { + raw_batch.push( + buf[index + header_len..index + header_len + body_len as usize] + .to_vec(), + ); + index += header_len + body_len as usize; + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid data in raw batch", + )); + } + } + Ok(Frame::RawBatch(raw_batch)) + } + FrameType::Message => { + let mut index = 0; + let mut combined_messages: Vec = vec![]; + while index < buf.len() { + // There might be zero bytes in between, and with LE, the next message will + // start with a non-zero + if buf[index] == 0 { + index += 1; + continue; + } + + let stat = stat_uint24_le(&buf[index..]); + if let Some((header_len, body_len)) = stat { + let (frame, length) = Self::decode_message( + &buf[index + header_len..index + header_len + body_len as usize], + )?; + if length != body_len as usize { + tracing::warn!( + "Did not know what to do with all the bytes, got {} but decoded {}", + body_len, + length + ); + } + if let Frame::MessageBatch(messages) = frame { + for message in messages { + combined_messages.push(message); + } + } else { + unreachable!("Can not get Raw messages"); + } + index += header_len + body_len as usize; + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid data in multi-message chunk", + )); + } + } + Ok(Frame::MessageBatch(combined_messages)) + } + } + } + /// Decode a frame from a buffer. - pub fn decode(buf: &[u8], frame_type: &FrameType) -> Result { + pub(crate) fn decode(buf: &[u8], frame_type: &FrameType) -> Result { match frame_type { - FrameType::Raw => Ok(Frame::Raw(buf.to_vec())), - FrameType::Message => Ok(Frame::Message(ChannelMessage::decode(buf)?)), + FrameType::Raw => Ok(Frame::RawBatch(vec![buf.to_vec()])), + FrameType::Message => { + let (frame, _) = Self::decode_message(buf)?; + Ok(frame) + } + } + } + + fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { + if buf.len() >= 3 && buf[0] == 0x00 { + if buf[1] == 0x00 { + // Batch of messages + let mut messages: Vec = vec![]; + let mut state = State::new_with_start_and_end(2, buf.len()); + + // First, there is the original channel + let mut current_channel: u64 = state.decode(buf)?; + while state.start() < state.end() { + // Length of the message is inbetween here + let channel_message_length: usize = state.decode(buf)?; + if state.start() + channel_message_length > state.end() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "received invalid message length, {} + {} > {}", + state.start(), + channel_message_length, + state.end() + ), + )); + } + // Then the actual message + let (channel_message, _) = ChannelMessage::decode( + &buf[state.start()..state.start() + channel_message_length], + current_channel, + )?; + messages.push(channel_message); + state.add_start(channel_message_length)?; + // After that, if there is an extra 0x00, that means the channel + // changed. This works because of LE encoding, and channels starting + // from the index 1. + if state.start() < state.end() && buf[state.start()] == 0x00 { + state.add_start(1)?; + current_channel = state.decode(buf)?; + } + } + Ok((Frame::MessageBatch(messages), state.start())) + } else if buf[1] == 0x01 { + // Open message + let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; + Ok((Frame::MessageBatch(vec![channel_message]), length)) + } else if buf[1] == 0x03 { + // Close message + let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; + Ok((Frame::MessageBatch(vec![channel_message]), length)) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid special message", + )) + } + } else if buf.len() >= 2 { + // Single message + let mut state = State::from_buffer(buf); + let channel: u64 = state.decode(buf)?; + let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; + Ok(( + Frame::MessageBatch(vec![channel_message]), + state.start() + length, + )) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("received too short message, {buf:02X?}"), + )) } } - fn body_len(&self) -> usize { + fn preencode(&mut self, state: &mut State) -> Result { match self { - Self::Raw(message) => message.as_slice().encoded_len(), - Self::Message(message) => message.encoded_len(), + Self::RawBatch(raw_batch) => { + for raw in raw_batch { + state.add_end(raw.as_slice().encoded_len()?)?; + } + } + #[allow(clippy::comparison_chain)] + Self::MessageBatch(messages) => { + if messages.len() == 1 { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else { + (*state).preencode(&messages[0].channel)?; + state.add_end(messages[0].encoded_len()?)?; + } + } else if messages.len() > 1 { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.add_end(2)?; + let mut current_channel: u64 = messages[0].channel; + state.preencode(¤t_channel)?; + for message in messages.iter_mut() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.add_end(1)?; + state.preencode(&message.channel)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.preencode(&message_length)?; + state.add_end(message_length)?; + } + } + } } + Ok(state.end()) } } impl Encoder for Frame { - fn encoded_len(&self) -> usize { - let body_len = self.body_len(); - body_len + varinteger::length(body_len as u64) + fn encoded_len(&mut self) -> Result { + let body_len = self.preencode(&mut State::new())?; + match self { + Self::RawBatch(_) => Ok(body_len), + Self::MessageBatch(_) => Ok(3 + body_len), + } } - fn encode(&self, buf: &mut [u8]) -> Result { - let len = self.encoded_len(); + fn encode(&mut self, buf: &mut [u8]) -> Result { + let mut state = State::new(); + let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 }; + let body_len = self.preencode(&mut state)?; + let len = body_len + header_len; if buf.len() < len { - return Err(EncodeError::new(len)); + return Err(EncodingError::new( + EncodingErrorKind::Overflow, + &format!("Length does not fit buffer, {} > {}", len, buf.len()), + )); } - let body_len = self.body_len(); - let header_len = len - body_len; - varinteger::encode(body_len as u64, &mut buf[..header_len]); match self { - Self::Raw(ref message) => message.as_slice().encode(&mut buf[header_len..]), - Self::Message(ref message) => message.encode(&mut buf[header_len..]), - }?; + Self::RawBatch(ref raw_batch) => { + for raw in raw_batch { + raw.as_slice().encode(buf)?; + } + } + #[allow(clippy::comparison_chain)] + Self::MessageBatch(ref mut messages) => { + write_uint24_le(body_len, buf); + let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); + if messages.len() == 1 { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(1_u8), buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(3_u8), buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } else { + state.encode(&messages[0].channel, buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } + } else if messages.len() > 1 { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; + let mut current_channel: u64 = messages[0].channel; + state.encode(¤t_channel, buf)?; + for message in messages.iter_mut() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.encode(&(0_u8), buf)?; + state.encode(&message.channel, buf)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.encode(&message_length, buf)?; + state.add_start(message.encode(&mut buf[state.start()..])?)?; + } + } + } + }; Ok(len) } } @@ -146,104 +348,104 @@ impl Encoder for Frame { #[allow(missing_docs)] pub enum Message { Open(Open), - Options(Options), - Status(Status), - Have(Have), - Unhave(Unhave), - Want(Want), - Unwant(Unwant), + Close(Close), + Synchronize(Synchronize), Request(Request), Cancel(Cancel), Data(Data), - Close(Close), - Extension(ExtensionMessage), + NoData(NoData), + Want(Want), + Unwant(Unwant), + Bitfield(Bitfield), + Range(Range), + Extension(Extension), + /// A local signalling message never sent over the wire + LocalSignal((String, Vec)), } impl Message { - /// Decode a message from a buffer. - pub fn decode(buf: &[u8], typ: u64) -> io::Result { - match typ { - 0 => Ok(Self::Open(Open::decode(buf)?)), - 1 => Ok(Self::Options(Options::decode(buf)?)), - 2 => Ok(Self::Status(Status::decode(buf)?)), - 3 => Ok(Self::Have(Have::decode(buf)?)), - 4 => Ok(Self::Unhave(Unhave::decode(buf)?)), - 5 => Ok(Self::Want(Want::decode(buf)?)), - 6 => Ok(Self::Unwant(Unwant::decode(buf)?)), - 7 => Ok(Self::Request(Request::decode(buf)?)), - 8 => Ok(Self::Cancel(Cancel::decode(buf)?)), - 9 => Ok(Self::Data(Data::decode(buf)?)), - 10 => Ok(Self::Close(Close::decode(buf)?)), - 15 => Ok(Self::Extension(ExtensionMessage::decode(buf)?)), - _ => Err(io::Error::new( - io::ErrorKind::InvalidData, - "Invalid message type", - )), - } - } - /// Wire type of this message. - pub fn typ(&self) -> u64 { + pub(crate) fn typ(&self) -> u64 { match self { - Self::Open(_) => 0, - Self::Options(_) => 1, - Self::Status(_) => 2, - Self::Have(_) => 3, - Self::Unhave(_) => 4, + Self::Synchronize(_) => 0, + Self::Request(_) => 1, + Self::Cancel(_) => 2, + Self::Data(_) => 3, + Self::NoData(_) => 4, Self::Want(_) => 5, Self::Unwant(_) => 6, - Self::Request(_) => 7, - Self::Cancel(_) => 8, - Self::Data(_) => 9, - Self::Close(_) => 10, - Self::Extension(_) => 15, + Self::Bitfield(_) => 7, + Self::Range(_) => 8, + Self::Extension(_) => 9, + value => unimplemented!("{} does not have a type", value), } } -} -impl Encoder for Message { - fn encoded_len(&self) -> usize { - match self { - Self::Open(ref message) => message.encoded_len(), - Self::Options(ref message) => message.encoded_len(), - Self::Status(ref message) => message.encoded_len(), - Self::Have(ref message) => message.encoded_len(), - Self::Unhave(ref message) => message.encoded_len(), - Self::Want(ref message) => message.encoded_len(), - Self::Unwant(ref message) => message.encoded_len(), - Self::Request(ref message) => message.encoded_len(), - Self::Cancel(ref message) => message.encoded_len(), - Self::Data(ref message) => message.encoded_len(), - Self::Close(ref message) => message.encoded_len(), - Self::Extension(ref message) => message.encoded_len(), - } + /// Decode a message from a buffer based on type. + pub(crate) fn decode(buf: &[u8], typ: u64) -> Result<(Self, usize), EncodingError> { + let mut state = HypercoreState::from_buffer(buf); + let message = match typ { + 0 => Ok(Self::Synchronize((*state).decode(buf)?)), + 1 => Ok(Self::Request(state.decode(buf)?)), + 2 => Ok(Self::Cancel((*state).decode(buf)?)), + 3 => Ok(Self::Data(state.decode(buf)?)), + 4 => Ok(Self::NoData((*state).decode(buf)?)), + 5 => Ok(Self::Want((*state).decode(buf)?)), + 6 => Ok(Self::Unwant((*state).decode(buf)?)), + 7 => Ok(Self::Bitfield((*state).decode(buf)?)), + 8 => Ok(Self::Range((*state).decode(buf)?)), + 9 => Ok(Self::Extension((*state).decode(buf)?)), + _ => Err(EncodingError::new( + EncodingErrorKind::InvalidData, + &format!("Invalid message type to decode: {typ}"), + )), + }?; + Ok((message, state.start())) } - fn encode(&self, buf: &mut [u8]) -> Result { + /// Pre-encodes a message to state, returns length + pub(crate) fn preencode(&mut self, state: &mut HypercoreState) -> Result { match self { - Self::Open(ref message) => encode_prost_message(message, buf), - Self::Options(ref message) => encode_prost_message(message, buf), - Self::Status(ref message) => encode_prost_message(message, buf), - Self::Have(ref message) => encode_prost_message(message, buf), - Self::Unhave(ref message) => encode_prost_message(message, buf), - Self::Want(ref message) => encode_prost_message(message, buf), - Self::Unwant(ref message) => encode_prost_message(message, buf), - Self::Request(ref message) => encode_prost_message(message, buf), - Self::Cancel(ref message) => encode_prost_message(message, buf), - Self::Data(ref message) => encode_prost_message(message, buf), - Self::Close(ref message) => encode_prost_message(message, buf), - Self::Extension(ref message) => message.encode(buf), - } + Self::Open(ref message) => state.0.preencode(message)?, + Self::Close(ref message) => state.0.preencode(message)?, + Self::Synchronize(ref message) => state.0.preencode(message)?, + Self::Request(ref message) => state.preencode(message)?, + Self::Cancel(ref message) => state.0.preencode(message)?, + Self::Data(ref message) => state.preencode(message)?, + Self::NoData(ref message) => state.0.preencode(message)?, + Self::Want(ref message) => state.0.preencode(message)?, + Self::Unwant(ref message) => state.0.preencode(message)?, + Self::Bitfield(ref message) => state.0.preencode(message)?, + Self::Range(ref message) => state.0.preencode(message)?, + Self::Extension(ref message) => state.0.preencode(message)?, + Self::LocalSignal(_) => 0, + }; + Ok(state.end()) } -} -fn encode_prost_message( - msg: &impl prost::Message, - mut buf: &mut [u8], -) -> Result { - let len = msg.encoded_len(); - msg.encode(&mut buf)?; - Ok(len) + /// Encodes a message to a given buffer, using preencoded state, results size + pub(crate) fn encode( + &mut self, + state: &mut HypercoreState, + buf: &mut [u8], + ) -> Result { + match self { + Self::Open(ref message) => state.0.encode(message, buf)?, + Self::Close(ref message) => state.0.encode(message, buf)?, + Self::Synchronize(ref message) => state.0.encode(message, buf)?, + Self::Request(ref message) => state.encode(message, buf)?, + Self::Cancel(ref message) => state.0.encode(message, buf)?, + Self::Data(ref message) => state.encode(message, buf)?, + Self::NoData(ref message) => state.0.encode(message, buf)?, + Self::Want(ref message) => state.0.encode(message, buf)?, + Self::Unwant(ref message) => state.0.encode(message, buf)?, + Self::Bitfield(ref message) => state.0.encode(message, buf)?, + Self::Range(ref message) => state.0.encode(message, buf)?, + Self::Extension(ref message) => state.0.encode(message, buf)?, + Self::LocalSignal(_) => 0, + }; + Ok(state.start()) + } } impl fmt::Display for Message { @@ -257,11 +459,13 @@ impl fmt::Display for Message { ), Self::Data(msg) => write!( f, - "Data(index {}, value: <{}>, nodes: {}, signature <{}>)", - msg.index, - msg.value.as_ref().map_or(0, |d| d.len()), - msg.nodes.len(), - msg.signature.as_ref().map_or(0, |d| d.len()), + "Data(request: {}, fork: {}, block: {}, hash: {}, seek: {}, upgrade: {})", + msg.request, + msg.fork, + msg.block.is_some(), + msg.hash.is_some(), + msg.seek.is_some(), + msg.upgrade.is_some(), ), _ => write!(f, "{:?}", &self), } @@ -269,10 +473,17 @@ impl fmt::Display for Message { } /// A message on a channel. -#[derive(Clone, PartialEq)] -pub struct ChannelMessage { - pub channel: u64, - pub message: Message, +#[derive(Clone)] +pub(crate) struct ChannelMessage { + pub(crate) channel: u64, + pub(crate) message: Message, + state: Option, +} + +impl PartialEq for ChannelMessage { + fn eq(&self, other: &Self) -> bool { + self.channel == other.channel && self.message == other.message + } } impl fmt::Debug for ChannelMessage { @@ -283,127 +494,161 @@ impl fmt::Debug for ChannelMessage { impl ChannelMessage { /// Create a new message. - pub fn new(channel: u64, message: Message) -> Self { - Self { channel, message } + pub(crate) fn new(channel: u64, message: Message) -> Self { + Self { + channel, + message, + state: None, + } } /// Consume self and return (channel, Message). - pub fn into_split(self) -> (u64, Message) { + pub(crate) fn into_split(self) -> (u64, Message) { (self.channel, self.message) } - /// Decode a channel message from a buffer. + /// Decodes an open message for a channel message from a buffer. /// - /// Note: `buf` has to have a valid length, and the length - /// prefix has to be removed already. - pub fn decode(buf: &[u8]) -> io::Result { - if buf.is_empty() { + /// Note: `buf` has to have a valid length, and without the 3 LE + /// bytes in it + pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> { + if buf.len() <= 5 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, - "received empty message", + "received too short Open message", )); } - let mut header = 0u64; - let headerlen = varinteger::decode(&buf, &mut header); - // let body = buf.split_off(headerlen); - let channel = header >> 4; - let typ = header & 0b1111; - let message = Message::decode(&buf[headerlen..], typ)?; - - let channel_message = Self { channel, message }; - - Ok(channel_message) - } - fn header(&self) -> u64 { - let typ = self.message.typ(); - self.channel << 4 | typ - } -} - -impl Encoder for ChannelMessage { - fn encoded_len(&self) -> usize { - let header_len = varinteger::length(self.header()); - let body_len = self.message.encoded_len(); - header_len + body_len + let mut state = State::new_with_start_and_end(0, buf.len()); + let open_msg: Open = state.decode(buf)?; + Ok(( + Self { + channel: open_msg.channel, + message: Message::Open(open_msg), + state: None, + }, + state.start(), + )) } - fn encode(&self, buf: &mut [u8]) -> Result { - let header = self.header(); - let header_len = varinteger::length(header); - let body_len = self.message.encoded_len(); - let len = header_len + body_len; - if buf.len() < len || len > MAX_MESSAGE_SIZE as usize { - return Err(EncodeError::new(len)); + /// Decodes a close message for a channel message from a buffer. + /// + /// Note: `buf` has to have a valid length, and without the 3 LE + /// bytes in it + pub(crate) fn decode_close_message(buf: &[u8]) -> io::Result<(Self, usize)> { + if buf.is_empty() { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "received too short Close message", + )); } - varinteger::encode(header, &mut buf[..header_len]); - self.message.encode(&mut buf[header_len..len])?; - Ok(len) + let mut state = State::new_with_start_and_end(0, buf.len()); + let close_msg: Close = state.decode(buf)?; + Ok(( + Self { + channel: close_msg.channel, + message: Message::Close(close_msg), + state: None, + }, + state.start(), + )) } -} -/// A extension message. -#[derive(Debug, Clone, PartialEq)] -pub struct ExtensionMessage { - /// ID of this extension - pub id: u64, - /// Message content - pub message: Vec, -} - -impl ExtensionMessage { - /// Create a new extension message. - pub fn new(id: u64, message: Vec) -> Self { - Self { id, message } - } - - /// Decode an extension message from a buffer. - fn decode(buf: &[u8]) -> io::Result { - if buf.is_empty() { + /// Decode a normal channel message from a buffer. + /// + /// Note: `buf` has to have a valid length, and without the 3 LE + /// bytes in it + pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, usize)> { + if buf.len() <= 1 { return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Extension message may not be empty", + io::ErrorKind::UnexpectedEof, + "received empty message", )); } - let mut id: u64 = 0; - let id_len = varinteger::decode(&buf, &mut id); - Ok(Self { - id, - message: buf[id_len..].to_vec(), - }) + let mut state = State::from_buffer(buf); + let typ: u64 = state.decode(buf)?; + let (message, length) = Message::decode(&buf[state.start()..], typ)?; + Ok(( + Self { + channel, + message, + state: None, + }, + state.start() + length, + )) + } + + /// Performance optimization for letting calling encoded_len() already do + /// the preencode phase of compact_encoding. + fn prepare_state(&mut self) -> Result<(), EncodingError> { + if self.state.is_none() { + let state = if let Message::Open(_) = self.message { + // Open message doesn't have a type + // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 + let mut state = HypercoreState::new(); + self.message.preencode(&mut state)?; + state + } else if let Message::Close(_) = self.message { + // Close message doesn't have a type + // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 + let mut state = HypercoreState::new(); + self.message.preencode(&mut state)?; + state + } else { + // The header is the channel id uint followed by message type uint + // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 + let mut state = HypercoreState::new(); + let typ = self.message.typ(); + (*state).preencode(&typ)?; + self.message.preencode(&mut state)?; + state + }; + self.state = Some(state); + } + Ok(()) } } -impl Encoder for ExtensionMessage { - fn encoded_len(&self) -> usize { - let id_len = varinteger::length(self.id); - id_len + self.message.len() +impl Encoder for ChannelMessage { + fn encoded_len(&mut self) -> Result { + self.prepare_state()?; + Ok(self.state.as_ref().unwrap().end()) } - fn encode(&self, buf: &mut [u8]) -> Result { - let id_len = varinteger::length(self.id); - let len = self.message.len() + id_len; - if buf.len() < len { - return Err(EncodeError::new(len)); + fn encode(&mut self, buf: &mut [u8]) -> Result { + self.prepare_state()?; + let state = self.state.as_mut().unwrap(); + if let Message::Open(_) = self.message { + // Open message is different in that the type byte is missing + self.message.encode(state, buf)?; + } else if let Message::Close(_) = self.message { + // Close message is different in that the type byte is missing + self.message.encode(state, buf)?; + } else { + let typ = self.message.typ(); + state.0.encode(&typ, buf)?; + self.message.encode(state, buf)?; } - varinteger::encode(self.id, &mut buf[..id_len]); - buf[id_len..len].copy_from_slice(&self.message[..]); - Ok(len) + Ok(state.start()) } } #[cfg(test)] mod tests { use super::*; + use hypercore::{ + DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, + }; macro_rules! message_enc_dec { ($( $msg:expr ),*) => { $( let channel = rand::random::() as u64; - let channel_message = ChannelMessage::new(channel, $msg); - let mut buf = vec![0u8; channel_message.encoded_len()]; + let mut channel_message = ChannelMessage::new(channel, $msg); + let encoded_len = channel_message.encoded_len().expect("Failed to get encoded length"); + let mut buf = vec![0u8; encoded_len]; let n = channel_message.encode(&mut buf[..]).expect("Failed to encode message"); - let decoded = ChannelMessage::decode(&buf[..n]).expect("Failed to decode message").into_split(); + let decoded = ChannelMessage::decode(&buf[..n], channel).expect("Failed to decode message").0.into_split(); assert_eq!(channel, decoded.0); assert_eq!($msg, decoded.1); )* @@ -411,53 +656,85 @@ mod tests { } #[test] - fn encode_decode() { + fn message_encode_decode() { message_enc_dec! { - Message::Open(Open{ - discovery_key: vec![2u8; 20], - capability: None + Message::Synchronize(Synchronize{ + fork: 0, + can_upgrade: true, + downloading: true, + uploading: true, + length: 5, + remote_length: 0, }), - Message::Options(Options { - extensions: vec!["test ext".to_string()], - ack: None + Message::Request(Request { + id: 1, + fork: 1, + block: Some(RequestBlock { + index: 5, + nodes: 10, + }), + hash: Some(RequestBlock { + index: 20, + nodes: 0 + }), + seek: Some(RequestSeek { + bytes: 10 + }), + upgrade: Some(RequestUpgrade { + start: 0, + length: 10 + }) }), - Message::Status(Status { - uploading: Some(true), - downloading: Some(false) + Message::Cancel(Cancel { + request: 1, }), - Message::Have(Have { - start: 0, - length: Some(100), - bitfield: None, - ack: Some(true) + Message::Data(Data{ + request: 1, + fork: 5, + block: Some(DataBlock { + index: 5, + nodes: vec![Node::new(1, vec![0x01; 32], 100)], + value: vec![0xFF; 10] + }), + hash: Some(DataHash { + index: 20, + nodes: vec![Node::new(2, vec![0x02; 32], 200)], + }), + seek: Some(DataSeek { + bytes: 10, + nodes: vec![Node::new(3, vec![0x03; 32], 300)], + }), + upgrade: Some(DataUpgrade { + start: 0, + length: 10, + nodes: vec![Node::new(4, vec![0x04; 32], 400)], + additional_nodes: vec![Node::new(5, vec![0x05; 32], 500)], + signature: vec![0xAB; 32] + }) }), - Message::Unhave(Unhave { - start: 0, - length: Some(100), + Message::NoData(NoData { + request: 2, }), Message::Want(Want { start: 0, - length: Some(100), + length: 100, }), - Message::Request(Request { - index: 0, - bytes: None, - hash: Some(true), - nodes: None + Message::Unwant(Unwant { + start: 10, + length: 2, }), - Message::Cancel(Cancel{ - index: 10, - bytes: Some(10), - hash: Some(true) + Message::Bitfield(Bitfield { + start: 20, + bitfield: vec![0x89ABCDEF, 0x00, 0xFFFFFFFF], }), - Message::Data(Data { - index: 1, - value: None, - nodes: vec![], - signature: None + Message::Range(Range { + drop: true, + start: 12345, + length: 100000 }), - Message::Close(Close { - discovery_key: Some(vec![1u8; 10]) + Message::Extension(Extension { + name: "custom_extension/v1/open".to_string(), + message: vec![0x44, 20] }) }; } diff --git a/src/noise/cipher.rs b/src/noise/cipher.rs deleted file mode 100644 index f0cb532..0000000 --- a/src/noise/cipher.rs +++ /dev/null @@ -1,50 +0,0 @@ -use crate::noise::HandshakeResult; -use salsa20::stream_cipher::{NewStreamCipher, SyncStreamCipher}; -use salsa20::XSalsa20; -use std::io::{Error, ErrorKind, Result}; - -// TODO: Don't define here but use the values from the XSalsa20 impl. -const KEY_SIZE: usize = 32; -const NONCE_SIZE: usize = 24; - -pub struct Cipher(XSalsa20); - -impl std::fmt::Debug for Cipher { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Cipher(XSalsa20)") - } -} - -impl Cipher { - pub fn from_handshake_rx(handshake: &HandshakeResult) -> Result { - let cipher = XSalsa20::new_var( - &handshake.split_rx[..KEY_SIZE], - &handshake.remote_nonce[..NONCE_SIZE], - ) - .map_err(|e| { - Error::new( - ErrorKind::PermissionDenied, - format!("Cannot initialize cipher: {}", e), - ) - })?; - Ok(Self(cipher)) - } - - pub fn from_handshake_tx(handshake: &HandshakeResult) -> Result { - let cipher = XSalsa20::new_var( - &handshake.split_tx[..KEY_SIZE], - &handshake.local_nonce[..NONCE_SIZE], - ) - .map_err(|e| { - Error::new( - ErrorKind::PermissionDenied, - format!("Cannot initialize cipher: {}", e), - ) - })?; - Ok(Self(cipher)) - } - - pub fn apply(&mut self, buffer: &mut [u8]) { - self.0.apply_keystream(buffer); - } -} diff --git a/src/noise/handshake.rs b/src/noise/handshake.rs deleted file mode 100644 index ccb19ab..0000000 --- a/src/noise/handshake.rs +++ /dev/null @@ -1,218 +0,0 @@ -// use async_std::io::{BufReader, BufWriter}; -use blake2_rfc::blake2b::Blake2b; -use prost::Message; -use rand::Rng; -pub use snow::Keypair; -use snow::{Builder, Error as SnowError, HandshakeState}; -use std::io::{Error, ErrorKind, Result}; - -use crate::constants::CAP_NS_BUF; -use crate::schema::NoisePayload; - -const CIPHERKEYLEN: usize = 32; -const HANDSHAKE_PATTERN: &str = "Noise_XX_25519_ChaChaPoly_BLAKE2b"; - -#[derive(Debug, Clone, Default)] -pub struct HandshakeResult { - pub is_initiator: bool, - pub local_pubkey: Vec, - pub local_seckey: Vec, - pub remote_pubkey: Vec, - pub local_nonce: Vec, - pub remote_nonce: Vec, - pub split_tx: [u8; CIPHERKEYLEN], - pub split_rx: [u8; CIPHERKEYLEN], -} - -impl HandshakeResult { - pub fn capability(&self, key: &[u8]) -> Option> { - let mut context = Blake2b::with_key(32, &self.split_rx[..32]); - context.update(CAP_NS_BUF); - context.update(&self.split_tx[..32]); - context.update(key); - let hash = context.finalize(); - Some(hash.as_bytes().to_vec()) - } - - pub fn remote_capability(&self, key: &[u8]) -> Option> { - let mut context = Blake2b::with_key(32, &self.split_tx[..32]); - context.update(CAP_NS_BUF); - context.update(&self.split_rx[..32]); - context.update(key); - let hash = context.finalize(); - Some(hash.as_bytes().to_vec()) - } - - pub fn verify_remote_capability(&self, capability: Option>, key: &[u8]) -> Result<()> { - let expected_capability = self.remote_capability(key); - match (capability, expected_capability) { - (Some(c1), Some(c2)) if c1 == c2 => Ok(()), - (None, None) => Err(Error::new( - ErrorKind::PermissionDenied, - "Missing capabilities for verification", - )), - _ => Err(Error::new( - ErrorKind::PermissionDenied, - "Invalid remote channel capability", - )), - } - } -} - -pub fn build_handshake_state( - is_initiator: bool, -) -> std::result::Result<(HandshakeState, Keypair), SnowError> { - let builder: Builder<'_> = Builder::new(HANDSHAKE_PATTERN.parse()?); - let key_pair = builder.generate_keypair().unwrap(); - let builder = builder.local_private_key(&key_pair.private); - // log::trace!("hs local pubkey: {:x?}", &key_pair.public); - let handshake_state = if is_initiator { - builder.build_initiator()? - } else { - builder.build_responder()? - }; - Ok((handshake_state, key_pair)) -} - -pub struct Handshake { - result: HandshakeResult, - state: HandshakeState, - payload: Vec, - tx_buf: Vec, - rx_buf: Vec, - complete: bool, - did_receive: bool, -} - -impl Handshake { - pub fn new(is_initiator: bool) -> Result { - let (state, local_keypair) = build_handshake_state(is_initiator).map_err(map_err)?; - - let local_nonce = generate_nonce(); - let payload = encode_nonce(local_nonce.clone()); - - let result = HandshakeResult { - is_initiator, - local_pubkey: local_keypair.public, - local_seckey: local_keypair.private, - // local_keypair, - local_nonce, - ..Default::default() - }; - Ok(Self { - state, - result, - payload, - tx_buf: vec![0u8; 512], - rx_buf: vec![0u8; 512], - complete: false, - did_receive: false, - }) - } - - pub fn start(&mut self) -> Result> { - if self.is_initiator() { - let tx_len = self.send()?; - Ok(Some(&self.tx_buf[..tx_len])) - } else { - Ok(None) - } - } - - pub fn complete(&self) -> bool { - self.complete - } - - pub fn is_initiator(&self) -> bool { - self.result.is_initiator - } - - fn recv(&mut self, msg: &[u8]) -> Result { - self.state - .read_message(&msg, &mut self.rx_buf) - .map_err(map_err) - } - fn send(&mut self) -> Result { - self.state - .write_message(&self.payload, &mut self.tx_buf) - .map_err(map_err) - } - - pub fn read(&mut self, msg: &[u8]) -> Result> { - // eprintln!("hs read len {}", msg.len()); - if self.complete() { - return Err(Error::new(ErrorKind::Other, "Handshake read after finish")); - } - - // eprintln!( - // "[{}] HANDSHAKE recv len {} {:?}", - // self.is_initiator(), - // msg.len(), - // msg - // ); - let rx_len = self.recv(&msg)?; - // eprintln!("[{}] HANDSHAKE recv post", self.is_initiator()); - - if !self.is_initiator() && !self.did_receive { - self.did_receive = true; - let tx_len = self.send()?; - return Ok(Some(&self.tx_buf[..tx_len])); - } - - let tx_buf = if self.is_initiator() { - let tx_len = self.send()?; - Some(&self.tx_buf[..tx_len]) - } else { - None - }; - - let split = self.state.dangerously_get_raw_split(); - if self.is_initiator() { - self.result.split_tx = split.0; - self.result.split_rx = split.1; - } else { - self.result.split_tx = split.1; - self.result.split_rx = split.0; - } - self.result.remote_nonce = decode_nonce(&self.rx_buf[..rx_len])?; - self.result.remote_pubkey = self.state.get_remote_static().unwrap().to_vec(); - self.complete = true; - - Ok(tx_buf) - } - - pub fn into_result(self) -> Result { - if !self.complete() { - Err(Error::new(ErrorKind::Other, "Handshake is not complete")) - } else { - Ok(self.result) - } - } -} - -fn map_err(e: SnowError) -> Error { - Error::new( - ErrorKind::PermissionDenied, - format!("handshake error: {}", e), - ) -} - -#[inline] -fn generate_nonce() -> Vec { - let random_bytes = rand::thread_rng().gen::<[u8; 24]>(); - random_bytes.to_vec() -} - -#[inline] -fn encode_nonce(nonce: Vec) -> Vec { - let nonce_msg = NoisePayload { nonce }; - let mut buf = vec![0u8; 0]; - nonce_msg.encode(&mut buf).unwrap(); - buf -} - -#[inline] -fn decode_nonce(msg: &[u8]) -> Result> { - let decoded = NoisePayload::decode(msg)?; - Ok(decoded.nonce) -} diff --git a/src/noise/mod.rs b/src/noise/mod.rs deleted file mode 100644 index 15f48d5..0000000 --- a/src/noise/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod cipher; -mod handshake; -pub use cipher::Cipher; -pub use handshake::{Handshake, HandshakeResult}; diff --git a/src/protocol.rs b/src/protocol.rs index f4e0590..83aa21c 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -2,7 +2,6 @@ use async_channel::{Receiver, Sender}; use futures_lite::io::{AsyncRead, AsyncWrite}; use futures_lite::stream::Stream; use futures_timer::Delay; -use log::*; use std::collections::VecDeque; use std::convert::TryInto; use std::fmt; @@ -12,16 +11,13 @@ use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use crate::builder::{Builder, Options}; use crate::channels::{Channel, ChannelMap}; -use crate::constants::DEFAULT_KEEPALIVE; -use crate::extension::{Extension, Extensions}; -use crate::message::{ChannelMessage, EncodeError, Frame, FrameType, Message}; -use crate::noise::{Handshake, HandshakeResult}; +use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; +use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; +use crate::message::{ChannelMessage, Frame, FrameType, Message}; use crate::reader::ReadState; use crate::schema::*; -use crate::util::map_channel_err; -use crate::util::pretty_hash; +use crate::util::{map_channel_err, pretty_hash}; use crate::writer::WriteState; macro_rules! return_error { @@ -35,8 +31,32 @@ macro_rules! return_error { const CHANNEL_CAP: usize = 1000; const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64); +/// Options for a Protocol instance. +#[derive(Debug)] +pub(crate) struct Options { + /// Whether this peer initiated the IO connection for this protoccol + pub(crate) is_initiator: bool, + /// Enable or disable the handshake. + /// Disabling the handshake will also disable capabilitity verification. + /// Don't disable this if you're not 100% sure you want this. + pub(crate) noise: bool, + /// Enable or disable transport encryption. + pub(crate) encrypted: bool, +} + +impl Options { + /// Create with default options. + pub(crate) fn new(is_initiator: bool) -> Self { + Self { + is_initiator, + noise: true, + encrypted: true, + } + } +} + /// Remote public key (32 bytes). -pub type RemotePublicKey = [u8; 32]; +pub(crate) type RemotePublicKey = [u8; 32]; /// Discovery key (32 bytes). pub type DiscoveryKey = [u8; 32]; /// Key (32 bytes). @@ -55,6 +75,9 @@ pub enum Event { Channel(Channel), /// Emitted when a channel is closed. Close(DiscoveryKey), + /// Convenience event to make it possible to signal the protocol from a channel. + /// See channel.signal_local(). + LocalSignal((String, Vec)), } /// A protocol command. @@ -77,17 +100,21 @@ impl fmt::Debug for Event { write!(f, "Channel({})", &pretty_hash(channel.discovery_key())) } Event::Close(discovery_key) => write!(f, "Close({})", &pretty_hash(discovery_key)), + Event::LocalSignal((name, data)) => { + write!(f, "LocalSignal(name={},len={})", name, data.len()) + } } } } /// Protocol state #[allow(clippy::large_enum_variant)] -pub enum State { +pub(crate) enum State { NotInitialized, // The Handshake struct sits behind an option only so that we can .take() // it out, it's never actually empty when in State::Handshake. Handshake(Option), + SecretStream(Option), Established, } @@ -96,6 +123,7 @@ impl fmt::Debug for State { match self { State::NotInitialized => write!(f, "NotInitialized"), State::Handshake(_) => write!(f, "Handshaking"), + State::SecretStream(_) => write!(f, "SecretStream"), State::Established => write!(f, "Established"), } } @@ -114,11 +142,10 @@ pub struct Protocol { channels: ChannelMap, command_rx: Receiver, command_tx: CommandTx, - outbound_rx: Receiver, - outbound_tx: Sender, + outbound_rx: Receiver>, + outbound_tx: Sender>, keepalive: Delay, queued_events: VecDeque, - extensions: Extensions, } impl Protocol @@ -126,9 +153,12 @@ where IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, { /// Create a new protocol instance. - pub fn new(io: IO, options: Options) -> Self { + pub(crate) fn new(io: IO, options: Options) -> Self { let (command_tx, command_rx) = async_channel::bounded(CHANNEL_CAP); - let (outbound_tx, outbound_rx) = async_channel::bounded(1); + let (outbound_tx, outbound_rx): ( + Sender>, + Receiver>, + ) = async_channel::bounded(1); Protocol { io, read_state: ReadState::new(), @@ -137,7 +167,6 @@ where state: State::NotInitialized, channels: ChannelMap::new(), handshake: None, - extensions: Extensions::new(outbound_tx.clone(), 0), command_rx, command_tx: CommandTx(command_tx), outbound_tx, @@ -147,17 +176,6 @@ where } } - /// Create a protocol instance with the default options. - pub fn with_defaults(io: IO, is_initiator: bool) -> Self { - let options = Options::new(is_initiator); - Protocol::new(io, options) - } - - /// Create a protocol builder that allows to set additional options. - pub fn builder(is_initiator: bool) -> Builder { - Builder::new(is_initiator) - } - /// Whether this protocol stream initiated the underlying IO connection. pub fn is_initiator(&self) -> bool { self.options.is_initiator @@ -193,11 +211,6 @@ where self.command_tx.send(command).await } - /// Register a protocol extension on the stream. - pub async fn register_extension(&mut self, name: impl ToString) -> Extension { - self.extensions.register(name.to_string()).await - } - /// Open a new protocol channel. /// /// Once the other side proofed that it also knows the `key`, the channel is emitted as @@ -251,9 +264,10 @@ where } fn init(&mut self) -> Result<()> { - debug!( + tracing::debug!( "protocol init, state {:?}, options {:?}", - self.state, self.options + self.state, + self.options ); match self.state { State::NotInitialized => {} @@ -287,20 +301,35 @@ where /// Poll the keepalive timer and queue a ping message if needed. fn poll_keepalive(&mut self, cx: &mut Context<'_>) { if Pin::new(&mut self.keepalive).poll(cx).is_ready() { - self.write_state.queue_frame(Frame::Raw(vec![0u8; 0])); + if let State::Established = self.state { + // 24 bit header for the empty message, hence the 3 + self.write_state + .queue_frame(Frame::RawBatch(vec![vec![0u8; 3]])); + } self.keepalive.reset(KEEPALIVE_DURATION); } } - fn on_outbound_message(&mut self, message: &ChannelMessage) { + fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool { // If message is close, close the local channel. if let ChannelMessage { channel, message: Message::Close(_), + .. } = message { self.close_local(*channel); + // If message is a LocalSignal, emit an event and return false to indicate + // this message should be filtered out. + } else if let ChannelMessage { + message: Message::LocalSignal((name, data)), + .. + } = message + { + self.queue_event(Event::LocalSignal((name.to_string(), data.to_vec()))); + return false; } + true } /// Poll for inbound messages and processs them. @@ -328,10 +357,14 @@ where } match Pin::new(&mut self.outbound_rx).poll_next(cx) { - Poll::Ready(Some(message)) => { - self.on_outbound_message(&message); - let frame = Frame::Message(message); - self.write_state.park_frame(frame); + Poll::Ready(Some(mut messages)) => { + if !messages.is_empty() { + messages.retain(|message| self.on_outbound_message(message)); + if !messages.is_empty() { + let frame = Frame::MessageBatch(messages); + self.write_state.park_frame(frame); + } + } } Poll::Ready(None) => unreachable!("Channel closed before end"), Poll::Pending => return Ok(()), @@ -341,13 +374,54 @@ where fn on_inbound_frame(&mut self, frame: Frame) -> Result<()> { match frame { - Frame::Raw(buf) => match self.state { - State::Handshake(_) => self.on_handshake_message(buf), - _ => unreachable!("May not receive raw frames outside of handshake state"), - }, - Frame::Message(channel_message) => match self.state { - State::Established => self.on_inbound_message(channel_message), - _ => unreachable!("May not receive message frames when not established"), + Frame::RawBatch(raw_batch) => { + let mut processed_state: Option = None; + for buf in raw_batch { + let state_name: String = format!("{:?}", self.state); + match self.state { + State::Handshake(_) => self.on_handshake_message(buf)?, + State::SecretStream(_) => self.on_secret_stream_message(buf)?, + State::Established => { + if let Some(processed_state) = processed_state.as_ref() { + let previous_state = if self.options.encrypted { + State::SecretStream(None) + } else { + State::Handshake(None) + }; + if processed_state == &format!("{previous_state:?}") { + // This is the unlucky case where the batch had two or more messages where + // the first one was correctly identified as Raw but everything + // after that should have been (decrypted and) a MessageBatch. Correct the mistake + // here post-hoc. + let buf = self.read_state.decrypt_buf(&buf)?; + let frame = Frame::decode(&buf, &FrameType::Message)?; + self.on_inbound_frame(frame)?; + continue; + } + } + unreachable!( + "May not receive raw frames in Established state" + ) + } + _ => unreachable!( + "May not receive raw frames outside of handshake or secretstream state, was {:?}", + self.state + ), + }; + if processed_state.is_none() { + processed_state = Some(state_name) + } + } + Ok(()) + } + Frame::MessageBatch(channel_messages) => match self.state { + State::Established => { + for channel_message in channel_messages { + self.on_inbound_message(channel_message)? + } + Ok(()) + } + _ => unreachable!("May not receive message batch frames when not established"), }, } } @@ -365,43 +439,61 @@ where if !handshake.complete() { self.state = State::Handshake(Some(handshake)); } else { - let result = handshake.into_result()?; + let handshake_result = handshake.into_result()?; + if self.options.encrypted { - self.read_state.upgrade_with_handshake(&result)?; - self.write_state.upgrade_with_handshake(&result)?; + // The cipher will be put to use to the writer only after the peer's answer has come + let (cipher, init_msg) = EncryptCipher::from_handshake_tx(&handshake_result)?; + self.state = State::SecretStream(Some(cipher)); + + // Send the secret stream init message header to the other side + self.queue_frame_direct(init_msg).unwrap(); + } else { + // Skip secret stream and go straight to Established, then notify about + // handshake + self.read_state.set_frame_type(FrameType::Message); + let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; + self.queue_event(Event::Handshake(remote_public_key)); + self.state = State::Established; } - self.read_state.set_frame_type(FrameType::Message); - let remote_public_key = parse_key(&result.remote_pubkey)?; - log::debug!( - "handshake complete, remote_key {}", - pretty_hash(&remote_public_key) - ); - self.handshake = Some(result); - self.state = State::Established; - self.queue_event(Event::Handshake(remote_public_key)); + // Store handshake result + self.handshake = Some(handshake_result); } Ok(()) } + fn on_secret_stream_message(&mut self, buf: Vec) -> Result<()> { + let encrypt_cipher = match &mut self.state { + State::SecretStream(encrypt_cipher) => encrypt_cipher.take().unwrap(), + _ => { + unreachable!("May not call on_secret_stream_message when not in SecretStream state") + } + }; + let handshake_result = &self + .handshake + .as_ref() + .expect("Handshake result must be set before secret stream"); + let decrypt_cipher = DecryptCipher::from_handshake_rx_and_init_msg(handshake_result, &buf)?; + self.read_state.upgrade_with_decrypt_cipher(decrypt_cipher); + self.write_state.upgrade_with_encrypt_cipher(encrypt_cipher); + self.read_state.set_frame_type(FrameType::Message); + + // Lastly notify that handshake is ready and set state to established + let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; + self.queue_event(Event::Handshake(remote_public_key)); + self.state = State::Established; + Ok(()) + } + fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> { // let channel_message = ChannelMessage::decode(buf)?; - log::debug!("[{}] recv {:?}", self.is_initiator(), channel_message); let (remote_id, message) = channel_message.into_split(); - match remote_id { - // Id 0 means stream-level, where only extension and options messages are supported. - 0 => match message { - Message::Options(msg) => self.extensions.on_remote_update(msg.extensions), - Message::Extension(msg) => self.extensions.on_message(msg), - _ => {} - }, - // Any other Id is a regular channel message. - _ => match message { - Message::Open(msg) => self.on_open(remote_id, msg)?, - Message::Close(msg) => self.on_close(remote_id, msg)?, - _ => self - .channels - .forward_inbound_message(remote_id as usize, message)?, - }, + match message { + Message::Open(msg) => self.on_open(remote_id, msg)?, + Message::Close(msg) => self.on_close(remote_id, msg)?, + _ => self + .channels + .forward_inbound_message(remote_id as usize, message)?, } Ok(()) } @@ -428,13 +520,16 @@ where // Tell the remote end about the new channel. let capability = self.capability(&key); + let channel = local_id as u64; let message = Message::Open(Open { + channel, + protocol: PROTOCOL_NAME.to_string(), discovery_key: discovery_key.to_vec(), capability, }); - let channel_message = ChannelMessage::new(local_id as u64, message); + let channel_message = ChannelMessage::new(channel, message); self.write_state - .queue_frame(Frame::Message(channel_message)); + .queue_frame(Frame::MessageBatch(vec![channel_message])); Ok(()) } @@ -458,9 +553,9 @@ where self.queued_events.push_back(event); } - fn queue_frame_direct(&mut self, body: Vec) -> std::result::Result { - let frame = Frame::Raw(body); - self.write_state.try_queue_direct(&frame) + fn queue_frame_direct(&mut self, body: Vec) -> Result { + let mut frame = Frame::RawBatch(vec![body]); + self.write_state.try_queue_direct(&mut frame) } fn accept_channel(&mut self, local_id: usize) -> Result<()> { @@ -482,8 +577,10 @@ where fn on_close(&mut self, remote_id: u64, msg: Close) -> Result<()> { if let Some(channel_handle) = self.channels.get_remote(remote_id as usize) { let discovery_key = *channel_handle.discovery_key(); + // There is a possibility both sides will close at the same time, so + // the channel could be closed already, let's tolerate that. self.channels - .forward_inbound_message(remote_id as usize, Message::Close(msg))?; + .forward_inbound_message_tolerate_closed(remote_id as usize, Message::Close(msg))?; self.channels.remove(&discovery_key); self.queue_event(Event::Close(discovery_key)); } diff --git a/src/reader.rs b/src/reader.rs index 50b9931..51b370b 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -1,4 +1,4 @@ -use crate::noise::{Cipher, HandshakeResult}; +use crate::crypto::DecryptCipher; use futures_lite::io::AsyncRead; use futures_timer::Delay; use std::future::Future; @@ -8,13 +8,14 @@ use std::task::{Context, Poll}; use crate::constants::{DEFAULT_TIMEOUT, MAX_MESSAGE_SIZE}; use crate::message::{Frame, FrameType}; +use crate::util::stat_uint24_le; use std::time::Duration; const TIMEOUT: Duration = Duration::from_secs(DEFAULT_TIMEOUT as u64); const READ_BUF_INITIAL_SIZE: usize = 1024 * 128; #[derive(Debug)] -pub struct ReadState { +pub(crate) struct ReadState { /// The read buffer. buf: Vec, /// The start of the not-yet-processed byte range in the read buffer. @@ -25,16 +26,16 @@ pub struct ReadState { step: Step, /// The timeout after which the connection is closed. timeout: Delay, - /// Optional encryption cipher. - cipher: Option, + /// Optional decryption cipher. + cipher: Option, /// The frame type to be passed to the decoder. frame_type: FrameType, } impl ReadState { - pub fn new() -> ReadState { + pub(crate) fn new() -> ReadState { ReadState { - buf: vec![0u8; READ_BUF_INITIAL_SIZE as usize], + buf: vec![0u8; READ_BUF_INITIAL_SIZE], start: 0, end: 0, step: Step::Header, @@ -48,22 +49,36 @@ impl ReadState { #[derive(Debug)] enum Step { Header, - Body { header_len: usize, body_len: usize }, + Body { + header_len: usize, + body_len: usize, + }, + /// Multiple messages one after another + Batch, } impl ReadState { - pub fn upgrade_with_handshake(&mut self, handshake: &HandshakeResult) -> Result<()> { - let mut cipher = Cipher::from_handshake_rx(handshake)?; - cipher.apply(&mut self.buf[self.start..self.end]); - self.cipher = Some(cipher); - Ok(()) + pub(crate) fn upgrade_with_decrypt_cipher(&mut self, decrypt_cipher: DecryptCipher) { + self.cipher = Some(decrypt_cipher); } - pub fn set_frame_type(&mut self, frame_type: FrameType) { + /// Decrypts a given buf with stored cipher, if present. Used to correct + /// the rare mistake that more than two messages came in where the first + /// one created the cipher, and the next one should have been decrypted + /// but wasn't. + pub(crate) fn decrypt_buf(&mut self, buf: &[u8]) -> Result> { + if let Some(cipher) = self.cipher.as_mut() { + Ok(cipher.decrypt_buf(buf)?.0) + } else { + Ok(buf.to_vec()) + } + } + + pub(crate) fn set_frame_type(&mut self, frame_type: FrameType) { self.frame_type = frame_type; } - pub fn poll_reader( + pub(crate) fn poll_reader( &mut self, cx: &mut Context<'_>, mut reader: &mut R, @@ -71,11 +86,15 @@ impl ReadState { where R: AsyncRead + Unpin, { + let mut incomplete = true; loop { - if let Some(result) = self.process() { - return Poll::Ready(result); + if !incomplete { + if let Some(result) = self.process() { + return Poll::Ready(result); + } + } else { + incomplete = false; } - let n = match Pin::new(&mut reader).poll_read(cx, &mut self.buf[self.end..]) { Poll::Ready(Ok(n)) if n > 0 => n, Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), @@ -90,89 +109,123 @@ impl ReadState { }; let end = self.end + n; - if let Some(ref mut cipher) = self.cipher { - cipher.apply(&mut self.buf[self.end..end]); + let (success, segments) = create_segments(&self.buf[self.start..end])?; + if success { + if let Some(ref mut cipher) = self.cipher { + let mut dec_end = self.start; + for (index, header_len, body_len) in segments { + let de = cipher.decrypt( + &mut self.buf[self.start + index..end], + header_len, + body_len, + )?; + dec_end = self.start + index + de; + } + self.end = dec_end; + } else { + self.end = end; + } + } else { + // Could not segment due to buffer being full, need to cycle the buffer + // and possibly resize it too if the message is too big. + self.cycle_buf_and_resize_if_needed(segments[segments.len() - 1]); + + // Set incomplete flag to skip processing and instead poll more data + incomplete = true; } - self.end = end; self.timeout.reset(TIMEOUT); } } - fn cycle_buf_if_needed(&mut self) { - // TODO: It would be great if we wouldn't have to allocate here. - if self.end == self.buf.len() { - let temp = self.buf[self.start..self.end].to_vec(); - let len = temp.len(); - self.buf[..len].copy_from_slice(&temp[..]); - self.end = len; - self.start = 0; + fn cycle_buf_and_resize_if_needed(&mut self, last_segment: (usize, usize, usize)) { + let (last_index, last_header_len, last_body_len) = last_segment; + let total_incoming_length = last_index + last_header_len + last_body_len; + if self.buf.len() < total_incoming_length { + // The incoming segments will not fit into the buffer, need to resize it + self.buf.resize(total_incoming_length, 0u8); } + let temp = self.buf[self.start..].to_vec(); + let len = temp.len(); + self.buf[..len].copy_from_slice(&temp[..]); + self.end = len; + self.start = 0; } fn process(&mut self) -> Option> { - if self.start == self.end { - return None; - } loop { match self.step { Step::Header => { - let varint = varint_decode(&self.buf[self.start..self.end]); - if let Some((header_len, body_len)) = varint { - let body_len = body_len as usize; - if body_len > MAX_MESSAGE_SIZE as usize { - return Some(Err(Error::new( - ErrorKind::InvalidData, - "Message length above max allowed size", - ))); + let stat = stat_uint24_le(&self.buf[self.start..self.end]); + if let Some((header_len, body_len)) = stat { + if body_len == 0 { + // This is a keepalive message, just remain in Step::Header + self.start += header_len; + return None; + } else if (self.start + header_len + body_len as usize) < self.end { + // There are more than one message here, create a batch from all of + // then + self.step = Step::Batch; + } else { + let body_len = body_len as usize; + if body_len > MAX_MESSAGE_SIZE as usize { + return Some(Err(Error::new( + ErrorKind::InvalidData, + "Message length above max allowed size", + ))); + } + self.step = Step::Body { + header_len, + body_len, + }; } - self.step = Step::Body { - header_len, - body_len, - }; } else { - self.cycle_buf_if_needed(); - return None; + return Some(Err(Error::new(ErrorKind::InvalidData, "Invalid header"))); } } + Step::Body { header_len, body_len, } => { let message_len = header_len + body_len; - if message_len > self.buf.len() { - self.buf.resize(message_len, 0u8); - } - if (self.end - self.start) < message_len { - self.cycle_buf_if_needed(); - return None; - } else { - let range = self.start + header_len..self.start + message_len; - let frame = Frame::decode(&self.buf[range], &self.frame_type); - self.start += message_len; - self.step = Step::Header; - return Some(frame); - } + let range = self.start + header_len..self.start + message_len; + let frame = Frame::decode(&self.buf[range], &self.frame_type); + self.start += message_len; + self.step = Step::Header; + return Some(frame); + } + Step::Batch => { + let frame = + Frame::decode_multiple(&self.buf[self.start..self.end], &self.frame_type); + self.start = self.end; + self.step = Step::Header; + return Some(frame); } } } } } -fn varint_decode(buf: &[u8]) -> Option<(usize, u64)> { - let mut value = 0u64; - let mut m = 1u64; - let mut offset = 0usize; - for _i in 0..8 { - if offset >= buf.len() { - return None; - } - let byte = buf[offset]; - offset += 1; - value += m * u64::from(byte & 127); - m *= 128; - if byte & 128 == 0 { - break; +#[allow(clippy::type_complexity)] +fn create_segments(buf: &[u8]) -> Result<(bool, Vec<(usize, usize, usize)>)> { + let mut index: usize = 0; + let len = buf.len(); + let mut segments: Vec<(usize, usize, usize)> = vec![]; + while index < len { + if let Some((header_len, body_len)) = stat_uint24_le(&buf[index..]) { + let body_len = body_len as usize; + segments.push((index, header_len, body_len)); + if len < index + header_len + body_len { + // The segments will not fit, return false to indicate that more needs to be read + return Ok((false, segments)); + } + index += header_len + body_len; + } else { + return Err(Error::new( + ErrorKind::InvalidData, + "Could not read header while decrypting", + )); } } - Some((offset, value)) + Ok((true, segments)) } diff --git a/src/schema.proto b/src/schema.proto deleted file mode 100644 index 4db5996..0000000 --- a/src/schema.proto +++ /dev/null @@ -1,90 +0,0 @@ -syntax = "proto2"; - -package hypercore.schema; - -// Sent as part of the noise protocol. -message NoisePayload { - required bytes nonce = 1; -} - -// type=0 -message Open { - required bytes discoveryKey = 1; - optional bytes capability = 2; -} - -// type=1, overall feed options. can be sent multiple times -message Options { - repeated string extensions = 1; // Should be sorted lexicographically - optional bool ack = 2; // Should all blocks be explicitly acknowledged? -} - -// type=2, message indicating state changes etc. -// initial state for uploading/downloading is true -message Status { - optional bool uploading = 1; - optional bool downloading = 2; -} - -// type=3, what do we have? -message Have { - required uint64 start = 1; - optional uint64 length = 2 [default = 1]; // defaults to 1 - optional bytes bitfield = 3; - optional bool ack = 4; // when true, this Have message is an acknowledgement -} - -// type=4, what did we lose? -message Unhave { - required uint64 start = 1; - optional uint64 length = 2 [default = 1]; // defaults to 1 -} - -// type=5, what do we want? remote should start sending have messages in this range -message Want { - required uint64 start = 1; - optional uint64 length = 2; // defaults to Infinity or feed.length (if not live) -} - -// type=6, what don't we want anymore? -message Unwant { - required uint64 start = 1; - optional uint64 length = 2; // defaults to Infinity or feed.length (if not live) -} - -// type=7, ask for data -message Request { - required uint64 index = 1; - optional uint64 bytes = 2; - optional bool hash = 3; - optional uint64 nodes = 4; -} - -// type=8, cancel a request -message Cancel { - required uint64 index = 1; - optional uint64 bytes = 2; - optional bool hash = 3; -} - -// type=9, get some data -message Data { - message Node { - required uint64 index = 1; - required bytes hash = 2; - required uint64 size = 3; - } - - required uint64 index = 1; - optional bytes value = 2; - repeated Node nodes = 3; - optional bytes signature = 4; -} - -// type=10, explicitly close a channel. -message Close { - optional bytes discoveryKey = 1; // only send this if you did not do an open -} - -// type=15, extension message - diff --git a/src/schema.rs b/src/schema.rs new file mode 100644 index 0000000..cf4653a --- /dev/null +++ b/src/schema.rs @@ -0,0 +1,537 @@ +use hypercore::encoding::{CompactEncoding, EncodingError, HypercoreState, State}; +use hypercore::{ + DataBlock, DataHash, DataSeek, DataUpgrade, Proof, RequestBlock, RequestSeek, RequestUpgrade, +}; + +/// Open message +#[derive(Debug, Clone, PartialEq)] +pub struct Open { + /// Channel id to open + pub channel: u64, + /// Protocol name + pub protocol: String, + /// Hypercore discovery key + pub discovery_key: Vec, + /// Capability hash + pub capability: Option>, +} + +impl CompactEncoding for State { + fn preencode(&mut self, value: &Open) -> Result { + self.preencode(&value.channel)?; + self.preencode(&value.protocol)?; + self.preencode(&value.discovery_key)?; + if value.capability.is_some() { + self.add_end(1)?; // flags for future use + self.preencode_fixed_32()?; + } + Ok(self.end()) + } + + fn encode(&mut self, value: &Open, buffer: &mut [u8]) -> Result { + self.encode(&value.channel, buffer)?; + self.encode(&value.protocol, buffer)?; + self.encode(&value.discovery_key, buffer)?; + if let Some(capability) = &value.capability { + self.add_start(1)?; // flags for future use + self.encode_fixed_32(capability, buffer)?; + } + Ok(self.start()) + } + + fn decode(&mut self, buffer: &[u8]) -> Result { + let channel: u64 = self.decode(buffer)?; + let protocol: String = self.decode(buffer)?; + let discovery_key: Vec = self.decode(buffer)?; + let capability: Option> = if self.start() < self.end() { + self.add_start(1)?; // flags for future use + let capability: Vec = self.decode_fixed_32(buffer)?.to_vec(); + Some(capability) + } else { + None + }; + Ok(Open { + channel, + protocol, + discovery_key, + capability, + }) + } +} + +/// Close message +#[derive(Debug, Clone, PartialEq)] +pub struct Close { + /// Channel id to close + pub channel: u64, +} + +impl CompactEncoding for State { + fn preencode(&mut self, value: &Close) -> Result { + self.preencode(&value.channel) + } + + fn encode(&mut self, value: &Close, buffer: &mut [u8]) -> Result { + self.encode(&value.channel, buffer) + } + + fn decode(&mut self, buffer: &[u8]) -> Result { + let channel: u64 = self.decode(buffer)?; + Ok(Close { channel }) + } +} + +/// Synchronize message. Type 0. +#[derive(Debug, Clone, PartialEq)] +pub struct Synchronize { + /// Fork id, set to 0 for an un-forked hypercore. + pub fork: u64, + /// Length of hypercore + pub length: u64, + /// Known length of the remote party, 0 for unknown. + pub remote_length: u64, + /// Downloading allowed + pub downloading: bool, + /// Uploading allowed + pub uploading: bool, + /// Upgrade possible + pub can_upgrade: bool, +} + +impl CompactEncoding for State { + fn preencode(&mut self, value: &Synchronize) -> Result { + self.add_end(1)?; // flags + self.preencode(&value.fork)?; + self.preencode(&value.length)?; + self.preencode(&value.remote_length) + } + + fn encode(&mut self, value: &Synchronize, buffer: &mut [u8]) -> Result { + let mut flags: u8 = if value.can_upgrade { 1 } else { 0 }; + flags |= if value.uploading { 2 } else { 0 }; + flags |= if value.downloading { 4 } else { 0 }; + self.encode(&flags, buffer)?; + self.encode(&value.fork, buffer)?; + self.encode(&value.length, buffer)?; + self.encode(&value.remote_length, buffer) + } + + fn decode(&mut self, buffer: &[u8]) -> Result { + let flags: u8 = self.decode(buffer)?; + let fork: u64 = self.decode(buffer)?; + let length: u64 = self.decode(buffer)?; + let remote_length: u64 = self.decode(buffer)?; + let can_upgrade = flags & 1 != 0; + let uploading = flags & 2 != 0; + let downloading = flags & 4 != 0; + Ok(Synchronize { + fork, + length, + remote_length, + can_upgrade, + uploading, + downloading, + }) + } +} + +/// Request message. Type 1. +#[derive(Debug, Clone, PartialEq)] +pub struct Request { + /// Request id, will be returned with corresponding [Data] + pub id: u64, + /// Current fork, set to 0 for un-forked hypercore + pub fork: u64, + /// Request for data + pub block: Option, + /// Request hash + pub hash: Option, + /// Request seek + pub seek: Option, + /// Request upgrade + pub upgrade: Option, +} + +impl CompactEncoding for HypercoreState { + fn preencode(&mut self, value: &Request) -> Result { + self.add_end(1)?; // flags + self.0.preencode(&value.id)?; + self.0.preencode(&value.fork)?; + if let Some(block) = &value.block { + self.preencode(block)?; + } + if let Some(hash) = &value.hash { + self.preencode(hash)?; + } + if let Some(seek) = &value.seek { + self.preencode(seek)?; + } + if let Some(upgrade) = &value.upgrade { + self.preencode(upgrade)?; + } + Ok(self.end()) + } + + fn encode(&mut self, value: &Request, buffer: &mut [u8]) -> Result { + let mut flags: u8 = if value.block.is_some() { 1 } else { 0 }; + flags |= if value.hash.is_some() { 2 } else { 0 }; + flags |= if value.seek.is_some() { 4 } else { 0 }; + flags |= if value.upgrade.is_some() { 8 } else { 0 }; + self.0.encode(&flags, buffer)?; + self.0.encode(&value.id, buffer)?; + self.0.encode(&value.fork, buffer)?; + if let Some(block) = &value.block { + self.encode(block, buffer)?; + } + if let Some(hash) = &value.hash { + self.encode(hash, buffer)?; + } + if let Some(seek) = &value.seek { + self.encode(seek, buffer)?; + } + if let Some(upgrade) = &value.upgrade { + self.encode(upgrade, buffer)?; + } + Ok(self.start()) + } + + fn decode(&mut self, buffer: &[u8]) -> Result { + let flags: u8 = self.0.decode(buffer)?; + let id: u64 = self.0.decode(buffer)?; + let fork: u64 = self.0.decode(buffer)?; + let block: Option = if flags & 1 != 0 { + Some(self.decode(buffer)?) + } else { + None + }; + let hash: Option = if flags & 2 != 0 { + Some(self.decode(buffer)?) + } else { + None + }; + let seek: Option = if flags & 4 != 0 { + Some(self.decode(buffer)?) + } else { + None + }; + let upgrade: Option = if flags & 8 != 0 { + Some(self.decode(buffer)?) + } else { + None + }; + Ok(Request { + id, + fork, + block, + hash, + seek, + upgrade, + }) + } +} + +/// Cancel message for a [Request]. Type 2 +#[derive(Debug, Clone, PartialEq)] +pub struct Cancel { + /// Request to cancel, see field `id` in [Request] + pub request: u64, +} + +impl CompactEncoding for State { + fn preencode(&mut self, value: &Cancel) -> Result { + self.preencode(&value.request) + } + + fn encode(&mut self, value: &Cancel, buffer: &mut [u8]) -> Result { + self.encode(&value.request, buffer) + } + + fn decode(&mut self, buffer: &[u8]) -> Result { + let request: u64 = self.decode(buffer)?; + Ok(Cancel { request }) + } +} + +/// Data message responding to received [Request]. Type 3. +#[derive(Debug, Clone, PartialEq)] +pub struct Data { + /// Request this data is for, see field `id` in [Request] + pub request: u64, + /// Fork id, set to 0 for un-forked hypercore + pub fork: u64, + /// Response for block request + pub block: Option, + /// Response for hash request + pub hash: Option, + /// Response for seek request + pub seek: Option, + /// Response for upgrade request + pub upgrade: Option, +} + +impl CompactEncoding for HypercoreState { + fn preencode(&mut self, value: &Data) -> Result { + self.add_end(1)?; // flags + self.0.preencode(&value.request)?; + self.0.preencode(&value.fork)?; + if let Some(block) = &value.block { + self.preencode(block)?; + } + if let Some(hash) = &value.hash { + self.preencode(hash)?; + } + if let Some(seek) = &value.seek { + self.preencode(seek)?; + } + if let Some(upgrade) = &value.upgrade { + self.preencode(upgrade)?; + } + Ok(self.end()) + } + + fn encode(&mut self, value: &Data, buffer: &mut [u8]) -> Result { + let mut flags: u8 = if value.block.is_some() { 1 } else { 0 }; + flags |= if value.hash.is_some() { 2 } else { 0 }; + flags |= if value.seek.is_some() { 4 } else { 0 }; + flags |= if value.upgrade.is_some() { 8 } else { 0 }; + self.0.encode(&flags, buffer)?; + self.0.encode(&value.request, buffer)?; + self.0.encode(&value.fork, buffer)?; + if let Some(block) = &value.block { + self.encode(block, buffer)?; + } + if let Some(hash) = &value.hash { + self.encode(hash, buffer)?; + } + if let Some(seek) = &value.seek { + self.encode(seek, buffer)?; + } + if let Some(upgrade) = &value.upgrade { + self.encode(upgrade, buffer)?; + } + Ok(self.start()) + } + + fn decode(&mut self, buffer: &[u8]) -> Result { + let flags: u8 = self.0.decode(buffer)?; + let request: u64 = self.0.decode(buffer)?; + let fork: u64 = self.0.decode(buffer)?; + let block: Option = if flags & 1 != 0 { + Some(self.decode(buffer)?) + } else { + None + }; + let hash: Option = if flags & 2 != 0 { + Some(self.decode(buffer)?) + } else { + None + }; + let seek: Option = if flags & 4 != 0 { + Some(self.decode(buffer)?) + } else { + None + }; + let upgrade: Option = if flags & 8 != 0 { + Some(self.decode(buffer)?) + } else { + None + }; + Ok(Data { + request, + fork, + block, + hash, + seek, + upgrade, + }) + } +} + +impl Data { + /// Transform Data message into a Proof emptying fields + pub fn into_proof(&mut self) -> Proof { + Proof { + fork: self.fork, + block: self.block.take(), + hash: self.hash.take(), + seek: self.seek.take(), + upgrade: self.upgrade.take(), + } + } +} + +/// No data message. Type 4. +#[derive(Debug, Clone, PartialEq)] +pub struct NoData { + /// Request this message is for, see field `id` in [Request] + pub request: u64, +} + +impl CompactEncoding for State { + fn preencode(&mut self, value: &NoData) -> Result { + self.preencode(&value.request) + } + + fn encode(&mut self, value: &NoData, buffer: &mut [u8]) -> Result { + self.encode(&value.request, buffer) + } + + fn decode(&mut self, buffer: &[u8]) -> Result { + let request: u64 = self.decode(buffer)?; + Ok(NoData { request }) + } +} + +/// Want message. Type 5. +#[derive(Debug, Clone, PartialEq)] +pub struct Want { + /// Start index + pub start: u64, + /// Length + pub length: u64, +} +impl CompactEncoding for State { + fn preencode(&mut self, value: &Want) -> Result { + self.preencode(&value.start)?; + self.preencode(&value.length) + } + + fn encode(&mut self, value: &Want, buffer: &mut [u8]) -> Result { + self.encode(&value.start, buffer)?; + self.encode(&value.length, buffer) + } + + fn decode(&mut self, buffer: &[u8]) -> Result { + let start: u64 = self.decode(buffer)?; + let length: u64 = self.decode(buffer)?; + Ok(Want { start, length }) + } +} + +/// Un-want message. Type 6. +#[derive(Debug, Clone, PartialEq)] +pub struct Unwant { + /// Start index + pub start: u64, + /// Length + pub length: u64, +} +impl CompactEncoding for State { + fn preencode(&mut self, value: &Unwant) -> Result { + self.preencode(&value.start)?; + self.preencode(&value.length) + } + + fn encode(&mut self, value: &Unwant, buffer: &mut [u8]) -> Result { + self.encode(&value.start, buffer)?; + self.encode(&value.length, buffer) + } + + fn decode(&mut self, buffer: &[u8]) -> Result { + let start: u64 = self.decode(buffer)?; + let length: u64 = self.decode(buffer)?; + Ok(Unwant { start, length }) + } +} + +/// Bitfield message. Type 7. +#[derive(Debug, Clone, PartialEq)] +pub struct Bitfield { + /// Start index of bitfield + pub start: u64, + /// Bitfield in 32 bit chunks beginning from `start` + pub bitfield: Vec, +} +impl CompactEncoding for State { + fn preencode(&mut self, value: &Bitfield) -> Result { + self.preencode(&value.start)?; + self.preencode(&value.bitfield) + } + + fn encode(&mut self, value: &Bitfield, buffer: &mut [u8]) -> Result { + self.encode(&value.start, buffer)?; + self.encode(&value.bitfield, buffer) + } + + fn decode(&mut self, buffer: &[u8]) -> Result { + let start: u64 = self.decode(buffer)?; + let bitfield: Vec = self.decode(buffer)?; + Ok(Bitfield { start, bitfield }) + } +} + +/// Range message. Type 8. +#[derive(Debug, Clone, PartialEq)] +pub struct Range { + /// If true, notifies that data has been cleared from this range. + /// If false, notifies existing data range. + pub drop: bool, + /// Start index + pub start: u64, + /// Length + pub length: u64, +} + +impl CompactEncoding for State { + fn preencode(&mut self, value: &Range) -> Result { + self.add_end(1)?; // flags + self.preencode(&value.start)?; + if value.length != 1 { + self.preencode(&value.length)?; + } + Ok(self.end()) + } + + fn encode(&mut self, value: &Range, buffer: &mut [u8]) -> Result { + let mut flags: u8 = if value.drop { 1 } else { 0 }; + flags |= if value.length == 1 { 2 } else { 0 }; + self.encode(&flags, buffer)?; + self.encode(&value.start, buffer)?; + if value.length != 1 { + self.encode(&value.length, buffer)?; + } + Ok(self.end()) + } + + fn decode(&mut self, buffer: &[u8]) -> Result { + let flags: u8 = self.decode(buffer)?; + let start: u64 = self.decode(buffer)?; + let drop = flags & 1 != 0; + let length: u64 = if flags & 2 != 0 { + 1 + } else { + self.decode(buffer)? + }; + Ok(Range { + drop, + length, + start, + }) + } +} + +/// Extension message. Type 9. Use this for custom messages in your application. +#[derive(Debug, Clone, PartialEq)] +pub struct Extension { + /// Name of the custom message + pub name: String, + /// Message content, use empty vector for no data. + pub message: Vec, +} +impl CompactEncoding for State { + fn preencode(&mut self, value: &Extension) -> Result { + self.preencode(&value.name)?; + self.preencode_raw_buffer(&value.message) + } + + fn encode(&mut self, value: &Extension, buffer: &mut [u8]) -> Result { + self.encode(&value.name, buffer)?; + self.encode_raw_buffer(&value.message, buffer) + } + + fn decode(&mut self, buffer: &[u8]) -> Result { + let name: String = self.decode(buffer)?; + let message: Vec = self.decode_raw_buffer(buffer)?; + Ok(Extension { name, message }) + } +} diff --git a/src/util.rs b/src/util.rs index b1a5ee2..c99ff9c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,4 +1,7 @@ -use blake2_rfc::blake2b::Blake2b; +use blake2::{ + digest::{typenum::U32, FixedOutput, Update}, + Blake2bMac, +}; use std::convert::TryInto; use std::io::{Error, ErrorKind}; @@ -9,18 +12,47 @@ use crate::DiscoveryKey; /// /// The discovery key is a 32 byte namespaced hash of the key. pub fn discovery_key(key: &[u8]) -> DiscoveryKey { - let mut hasher = Blake2b::with_key(32, key); - hasher.update(&DISCOVERY_NS_BUF); - hasher.finalize().as_bytes().try_into().unwrap() + let mut hasher = Blake2bMac::::new_with_salt_and_personal(key, &[], &[]).unwrap(); + hasher.update(DISCOVERY_NS_BUF); + hasher.finalize_fixed().as_slice().try_into().unwrap() } -pub fn pretty_hash(key: &[u8]) -> String { +pub(crate) fn pretty_hash(key: &[u8]) -> String { pretty_hash::fmt(key).unwrap_or_else(|_| "".into()) } -pub fn map_channel_err(err: async_channel::SendError) -> Error { +pub(crate) fn map_channel_err(err: async_channel::SendError) -> Error { Error::new( ErrorKind::BrokenPipe, - format!("Cannot forward on channel: {}", err), + format!("Cannot forward on channel: {err}"), ) } + +pub(crate) const UINT_24_LENGTH: usize = 3; + +#[inline] +pub(crate) fn wrap_uint24_le(data: &Vec) -> Vec { + let mut buf: Vec = vec![0; 3]; + let n = data.len(); + write_uint24_le(n, &mut buf); + buf.extend(data); + buf +} + +#[inline] +pub(crate) fn write_uint24_le(n: usize, buf: &mut [u8]) { + buf[0] = (n & 255) as u8; + buf[1] = ((n >> 8) & 255) as u8; + buf[2] = ((n >> 16) & 255) as u8; +} + +#[inline] +pub(crate) fn stat_uint24_le(buffer: &[u8]) -> Option<(usize, u64)> { + if buffer.len() >= 3 { + let len = + ((buffer[0] as u32) | ((buffer[1] as u32) << 8) | ((buffer[2] as u32) << 16)) as u64; + Some((UINT_24_LENGTH, len)) + } else { + None + } +} diff --git a/src/writer.rs b/src/writer.rs index 65741b0..e3cc5da 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -1,5 +1,6 @@ -use crate::message::{EncodeError, Encoder, Frame}; -use crate::noise::{Cipher, HandshakeResult}; +use crate::crypto::EncryptCipher; +use crate::message::{Encoder, Frame}; + use futures_lite::{ready, AsyncWrite}; use std::collections::VecDeque; use std::fmt; @@ -10,19 +11,19 @@ use std::task::{Context, Poll}; const BUF_SIZE: usize = 1024 * 64; #[derive(Debug)] -pub enum Step { +pub(crate) enum Step { Flushing, Writing, Processing, } -pub struct WriteState { +pub(crate) struct WriteState { queue: VecDeque, buf: Vec, current_frame: Option, start: usize, end: usize, - cipher: Option, + cipher: Option, step: Step, } @@ -41,7 +42,7 @@ impl fmt::Debug for WriteState { } impl WriteState { - pub fn new() -> Self { + pub(crate) fn new() -> Self { Self { queue: VecDeque::new(), buf: vec![0u8; BUF_SIZE], @@ -53,34 +54,37 @@ impl WriteState { } } - pub fn queue_frame(&mut self, frame: F) + pub(crate) fn queue_frame(&mut self, frame: F) where F: Into, { self.queue.push_back(frame.into()) } - pub fn try_queue_direct( - &mut self, - frame: &T, - ) -> std::result::Result { - let len = frame.encoded_len(); - if self.buf.len() < len { - self.buf.resize(len, 0u8); + pub(crate) fn try_queue_direct(&mut self, frame: &mut T) -> Result { + let promised_len = frame.encoded_len()?; + let padded_promised_len = self.safe_encrypted_len(promised_len); + if self.buf.len() < padded_promised_len { + self.buf.resize(padded_promised_len, 0u8); } - if len > self.remaining() { + if padded_promised_len > self.remaining() { return Ok(false); } - let len = frame.encode(&mut self.buf[self.end..])?; - self.advance(len); + let actual_len = frame.encode(&mut self.buf[self.end..])?; + if actual_len != promised_len { + panic!( + "encoded_len() did not return that right size, expected={promised_len}, actual={actual_len}" + ); + } + self.advance(padded_promised_len)?; Ok(true) } - pub fn can_park_frame(&self) -> bool { + pub(crate) fn can_park_frame(&self) -> bool { self.current_frame.is_none() } - pub fn park_frame(&mut self, frame: F) + pub(crate) fn park_frame(&mut self, frame: F) where F: Into, { @@ -89,19 +93,23 @@ impl WriteState { } } - fn advance(&mut self, n: usize) { + fn advance(&mut self, n: usize) -> Result<()> { let end = self.end + n; - if let Some(ref mut cipher) = self.cipher { - cipher.apply(&mut self.buf[self.end..end]); - } - self.end = end; - } - pub fn upgrade_with_handshake(&mut self, handshake: &HandshakeResult) -> Result<()> { - let cipher = Cipher::from_handshake_tx(handshake)?; - self.cipher = Some(cipher); + let encrypted_end = if let Some(ref mut cipher) = self.cipher { + self.end + cipher.encrypt(&mut self.buf[self.end..end])? + } else { + end + }; + + self.end = encrypted_end; Ok(()) } + + pub(crate) fn upgrade_with_encrypt_cipher(&mut self, encrypt_cipher: EncryptCipher) { + self.cipher = Some(encrypt_cipher); + } + fn remaining(&self) -> usize { self.buf.len() - self.end } @@ -110,7 +118,11 @@ impl WriteState { self.end - self.start } - pub fn poll_send(&mut self, cx: &mut Context<'_>, mut writer: &mut W) -> Poll> + pub(crate) fn poll_send( + &mut self, + cx: &mut Context<'_>, + mut writer: &mut W, + ) -> Poll> where W: AsyncWrite + Unpin, { @@ -121,11 +133,12 @@ impl WriteState { self.current_frame = self.queue.pop_front(); } - if let Some(frame) = self.current_frame.take() { - if !self.try_queue_direct(&frame)? { + if let Some(mut frame) = self.current_frame.take() { + if !self.try_queue_direct(&mut frame)? { self.current_frame = Some(frame); } } + if self.pending() == 0 { return Poll::Ready(Ok(())); } @@ -149,4 +162,12 @@ impl WriteState { } } } + + fn safe_encrypted_len(&self, encoded_len: usize) -> usize { + if let Some(cipher) = &self.cipher { + cipher.safe_encrypted_len(encoded_len) + } else { + encoded_len + } + } } diff --git a/tests/_util.rs b/tests/_util.rs index 2ceeaf5..9d0f9bf 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -3,6 +3,7 @@ use async_std::prelude::*; use async_std::task::{self, JoinHandle}; use futures_lite::io::{AsyncRead, AsyncWrite}; use hypercore_protocol::{Channel, DiscoveryKey, Duplex, Event, Protocol, ProtocolBuilder}; +use instant::Duration; use std::io; pub type MemoryProtocol = Protocol>; @@ -31,12 +32,11 @@ pub fn next_event( where IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, { - let task = task::spawn(async move { + task::spawn(async move { let e1 = proto.next().await; let e1 = e1.unwrap(); (proto, e1) - }); - task + }) } pub fn event_discovery_key(event: Event) -> DiscoveryKey { @@ -65,9 +65,8 @@ where task::spawn(async move { while let Some(event) = proto.next().await { let event = event?; - match event { - Event::Channel(channel) => return Ok((proto, channel)), - _ => {} + if let Event::Channel(channel) = event { + return Ok((proto, channel)); } } Err(io::Error::new( @@ -97,3 +96,23 @@ pub mod tcp { Ok((server_stream, client_stream)) } } + +const RETRY_TIMEOUT: u64 = 100_u64; +const NO_RESPONSE_TIMEOUT: u64 = 1000_u64; +pub async fn wait_for_localhost_port(port: u32) { + loop { + let timeout = async_std::future::timeout( + Duration::from_millis(NO_RESPONSE_TIMEOUT), + TcpStream::connect(format!("localhost:{}", port)), + ) + .await; + if timeout.is_err() { + continue; + } + if timeout.unwrap().is_err() { + async_std::task::sleep(Duration::from_millis(RETRY_TIMEOUT)).await; + } else { + break; + } + } +} diff --git a/tests/basic.rs b/tests/basic.rs index 9ab24d9..8a99c7e 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -7,11 +7,12 @@ use futures_lite::io::{AsyncRead, AsyncWrite}; use hypercore_protocol::{discovery_key, Channel, Event, Message, Protocol, ProtocolBuilder}; use hypercore_protocol::{schema::*, DiscoveryKey}; use std::io; +use test_log::test; mod _util; use _util::*; -#[async_std::test] +#[test(async_std::test)] async fn basic_protocol() -> anyhow::Result<()> { // env_logger::init(); let (proto_a, proto_b) = create_pair_memory().await?; @@ -29,7 +30,7 @@ async fn basic_protocol() -> anyhow::Result<()> { let key = [3u8; 32]; - proto_a.open(key.clone()).await?; + proto_a.open(key).await?; let next_a = next_event(proto_a); let next_b = next_event(proto_b); @@ -38,7 +39,7 @@ async fn basic_protocol() -> anyhow::Result<()> { assert!(matches!(event_b, Ok(Event::DiscoveryKey(_)))); assert_eq!(event_discovery_key(event_b.unwrap()), discovery_key(&key)); - proto_b.open(key.clone()).await?; + proto_b.open(key).await?; let next_b = next_event(proto_b); let (proto_b, event_b) = next_b.await; @@ -51,54 +52,33 @@ async fn basic_protocol() -> anyhow::Result<()> { assert_eq!(channel_a.discovery_key(), channel_b.discovery_key()); - channel_a - .want(Want { - start: 0, - length: Some(10), - }) - .await?; + channel_a.send(want(0, 10)).await?; - channel_b - .want(Want { - start: 10, - length: Some(5), - }) - .await?; + channel_b.send(want(10, 5)).await?; let next_a = next_event(proto_a); let next_b = next_event(proto_b); let channel_event_b = channel_b.next().await; - assert_eq!( - channel_event_b, - Some(Message::Want(Want { - start: 0, - length: Some(10) - })) - ); + assert_eq!(channel_event_b, Some(want(0, 10))); // eprintln!("channel_event_b: {:?}", channel_event_b); let channel_event_a = channel_a.next().await; - assert_eq!( - channel_event_a, - Some(Message::Want(Want { - start: 10, - length: Some(5) - })) - ); + assert_eq!(channel_event_a, Some(want(10, 5))); channel_a.close().await?; - channel_b.close().await?; let (_, event_a) = next_a.await; let (_, event_b) = next_b.await; assert!(matches!(event_a, Ok(Event::Close(_)))); assert!(matches!(event_b, Ok(Event::Close(_)))); - return Ok(()); + assert!(channel_a.closed()); + assert!(channel_b.closed()); + Ok(()) } -#[async_std::test] +#[test(async_std::test)] async fn open_close_channels() -> anyhow::Result<()> { let (mut proto_a, mut proto_b) = create_pair_memory().await?; @@ -143,8 +123,8 @@ async fn open_close_channels() -> anyhow::Result<()> { let (mut proto_b, ev_b) = next_b.await; let ev_a = ev_a?; let ev_b = ev_b?; - eprintln!("next a: {:?}", ev_a); - eprintln!("next b: {:?}", ev_b); + eprintln!("next a: {ev_a:?}"); + eprintln!("next b: {ev_b:?}"); let channels_a: Vec<&DiscoveryKey> = proto_a.channels().collect(); let channels_b: Vec<&DiscoveryKey> = proto_b.channels().collect(); @@ -153,17 +133,17 @@ async fn open_close_channels() -> anyhow::Result<()> { assert_eq!(channels_a.len(), 1); assert_eq!(channels_b.len(), 1); - let res = channel_a1.want(want(1)).await; + let res = channel_a1.send(want(0, 1)).await; assert!(matches!(res, Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted)); - let res = channel_b1.want(want(2)).await; + let res = channel_b1.send(want(0, 2)).await; assert!(matches!(res, Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted)); // Test that channel 2 still works - let res = channel_a2.want(want(10)).await; + let res = channel_a2.send(want(0, 10)).await; assert!(matches!(res, Ok(()))); - let res = channel_b2.want(want(20)).await; + let res = channel_b2.send(want(0, 20)).await; assert!(matches!(res, Ok(()))); // Check that the message arrives. @@ -181,17 +161,14 @@ async fn open_close_channels() -> anyhow::Result<()> { let msg_a = channel_a2.next().await; let msg_b = channel_b2.next().await; - assert_eq!(msg_a, Some(Message::Want(want(20)))); - assert_eq!(msg_b, Some(Message::Want(want(10)))); + assert_eq!(msg_a, Some(want(0, 20))); + assert_eq!(msg_b, Some(want(0, 10))); eprintln!("all good!"); Ok(()) } -fn want(len: u64) -> Want { - Want { - start: 0, - length: Some(len), - } +fn want(start: u64, length: u64) -> Message { + Message::Want(Want { start, length }) } diff --git a/tests/extension.rs b/tests/extension.rs deleted file mode 100644 index ff1ca17..0000000 --- a/tests/extension.rs +++ /dev/null @@ -1,175 +0,0 @@ -#![allow(dead_code, unused_imports)] - -use async_std::net::TcpStream; -use async_std::prelude::*; -use async_std::task::{self, JoinHandle}; -use futures_lite::io::{AsyncRead, AsyncWrite}; -// use futures_lite::{AsyncReadExt, AsyncWriteExt}; -use hypercore_protocol::schema::*; -use hypercore_protocol::{discovery_key, Channel, Event, Message, Protocol, ProtocolBuilder}; -use std::io; - -mod _util; -use _util::*; - -// Drive a stream to completion in a task. -fn drive(mut proto: S) -> JoinHandle<()> -where - S: Stream + Send + Unpin + 'static, -{ - task::spawn(async move { while let Some(_event) = proto.next().await {} }) -} - -// Drive a number of streams to completion. -// fn drive_all(streams: Vec) -> JoinHandle<()> -// where -// S: Stream + Send + Unpin + 'static, -// { -// let join_handles = streams.into_iter().map(drive); -// task::spawn(async move { -// for join_handle in join_handles { -// join_handle.await; -// } -// }) -// } - -// Drive a protocol stream until the first channel arrives. -fn drive_until_channel( - mut proto: Protocol, -) -> JoinHandle, Channel)>> -where - IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, -{ - task::spawn(async move { - while let Some(event) = proto.next().await { - let event = event?; - match event { - Event::Channel(channel) => return Ok((proto, channel)), - _ => {} - } - } - Err(io::Error::new( - io::ErrorKind::Interrupted, - "Protocol closed before a channel was opened", - )) - }) -} - -#[async_std::test] -async fn stream_extension() -> anyhow::Result<()> { - // env_logger::init(); - let (mut proto_a, mut proto_b) = create_pair_memory().await?; - - let mut ext_a = proto_a.register_extension("ext").await; - let mut ext_b = proto_b.register_extension("ext").await; - - drive(proto_a); - drive(proto_b); - - task::spawn(async move { - while let Some(message) = ext_b.next().await { - assert_eq!(message, b"hello".to_vec()); - // eprintln!("B received: {:?}", String::from_utf8(message)); - ext_b.send(b"ack".to_vec()).await; - } - }); - - ext_a.send(b"hello".to_vec()).await; - let response = ext_a.next().await; - assert_eq!(response, Some(b"ack".to_vec())); - // eprintln!("A received: {:?}", response.map(String::from_utf8)); - Ok(()) -} - -#[async_std::test] -async fn channel_extension() -> anyhow::Result<()> { - // env_logger::init(); - let (mut proto_a, mut proto_b) = create_pair_memory().await?; - let key = [1u8; 32]; - - proto_a.open(key).await?; - proto_b.open(key).await?; - - let next_a = drive_until_channel(proto_a); - let next_b = drive_until_channel(proto_b); - let (proto_a, mut channel_a) = next_a.await?; - let (proto_b, mut channel_b) = next_b.await?; - - let mut ext_a = channel_a.register_extension("ext").await; - let mut ext_b = channel_b.register_extension("ext").await; - - drive(proto_a); - drive(proto_b); - drive(channel_a); - drive(channel_b); - - task::spawn(async move { - while let Some(message) = ext_b.next().await { - // eprintln!("B received: {:?}", String::from_utf8(message)); - assert_eq!(message, b"hello".to_vec()); - ext_b.send(b"ack".to_vec()).await; - } - }); - - ext_a.send(b"hello".to_vec()).await; - let response = ext_a.next().await; - assert_eq!(response, Some(b"ack".to_vec())); - // eprintln!("A received: {:?}", response.map(String::from_utf8)); - Ok(()) -} - -#[async_std::test] -async fn channel_extension_async_read_write() -> anyhow::Result<()> { - // env_logger::init(); - let (mut proto_a, mut proto_b) = create_pair_memory().await?; - let key = [1u8; 32]; - - proto_a.open(key).await?; - proto_b.open(key).await?; - - let next_a = drive_until_channel(proto_a); - let next_b = drive_until_channel(proto_b); - let (proto_a, mut channel_a) = next_a.await?; - let (proto_b, mut channel_b) = next_b.await?; - - let mut ext_a = channel_a.register_extension("ext").await; - let mut ext_b = channel_b.register_extension("ext").await; - - drive(proto_a); - drive(proto_b); - drive(channel_a); - drive(channel_b); - - task::spawn(async move { - let mut read_buf = vec![0u8; 3]; - // let mut total = 0; - let mut res = vec![]; - while res.len() < 10 { - let n = ext_b.read(&mut read_buf).await.unwrap(); - // eprintln!( - // "B read: n {} buf {}", - // n, - // std::str::from_utf8(&read_buf[..n]).unwrap() - // ); - res.extend_from_slice(&read_buf[..n]); - } - assert_eq!(res, b"helloworld".to_vec()); - - let write = b"ack".to_vec(); - ext_b.write_all(&write).await.unwrap(); - }); - - ext_a.write_all(b"hello").await.unwrap(); - ext_a.write_all(b"world").await.unwrap(); - - let mut read_buf = vec![0u8; 5]; - let n = ext_a.read(&mut read_buf).await.unwrap(); - assert_eq!(n, 3); - assert_eq!(&read_buf[..n], b"ack"); - // eprintln!( - // "A read: n {} buf {}", - // n, - // std::str::from_utf8(&read_buf[..n]).unwrap() - // ); - Ok(()) -} diff --git a/tests/js/interop.js b/tests/js/interop.js new file mode 100644 index 0000000..ad7d5ad --- /dev/null +++ b/tests/js/interop.js @@ -0,0 +1,161 @@ +const Hypercore = require('hypercore'); +const net = require('net'); +const fs = require('fs').promises; + +// Static test key pair obtained with: +// +// const crypto = require('hypercore-crypto'); +// const keyPair = crypto.keyPair(); +// console.log("public key", keyPair.publicKey.toString('hex').match(/../g).join(' ')); +// console.log("secret key", keyPair.secretKey.toString('hex').match(/../g).join(' ')); +const testKeyPair = { + publicKey: Buffer.from([ + 0x97, 0x60, 0x6c, 0xaa, 0xd2, 0xb0, 0x8c, 0x1d, 0x5f, 0xe1, 0x64, 0x2e, 0xee, 0xa5, 0x62, 0xcb, + 0x91, 0xd6, 0x55, 0xe2, 0x00, 0xc8, 0xd4, 0x3a, 0x32, 0x09, 0x1d, 0x06, 0x4a, 0x33, 0x1e, 0xe3]), + secretKey: Buffer.from([ + 0x27, 0xe6, 0x74, 0x25, 0xc1, 0xff, 0xd1, 0xd9, 0xee, 0x62, 0x5c, 0x96, 0x2b, 0x57, 0x13, 0xc3, + 0x51, 0x0b, 0x71, 0x14, 0x15, 0xf3, 0x31, 0xf6, 0xfa, 0x9e, 0xf2, 0xbf, 0x23, 0x5f, 0x2f, 0xfe, + 0x97, 0x60, 0x6c, 0xaa, 0xd2, 0xb0, 0x8c, 0x1d, 0x5f, 0xe1, 0x64, 0x2e, 0xee, 0xa5, 0x62, 0xcb, + 0x91, 0xd6, 0x55, 0xe2, 0x00, 0xc8, 0xd4, 0x3a, 0x32, 0x09, 0x1d, 0x06, 0x4a, 0x33, 0x1e, 0xe3]), +} +const hostname = 'localhost' + +if (process.argv.length !== 9 || process.argv[7].length != 1) { + console.error("Usage: node interop.js [server/client] [writer/reader] [port] [count of items to replicate] [size in bytes of items] [character to repeat in item data] [test set]") + process.exit(1); +} + +const isWriter = process.argv[3] === 'writer'; +const port = parseInt(process.argv[4]); +const itemCount = parseInt(process.argv[5]); +const itemSize = parseInt(process.argv[6]); +const itemChar = process.argv[7]; +const testSet = process.argv[8]; +const resultFile = `work/${testSet}/result.txt`; + +if (process.argv[2] === 'server') { + runServer(isWriter, itemCount, itemSize, itemChar, testSet).then(_ => { + // console.log("NODE: Server created"); + }); +} else if (process.argv[2] === 'client') { + runClient(isWriter, itemCount, itemSize, itemChar, testSet).then(_ => { + // console.log("NODE: client run"); + }); +} else { + console.error(`Invalid mode {}, only server/client supported`, process.argv[2]); + process.exit(2); +} + +async function runServer(isWriter, itemCount, itemSize, itemChar, testSet) { + const isInitiator = false; + const hypercore = isWriter ? await createWriteHypercore(itemCount, itemSize, itemChar, testSet) : await createReadHypercore(testSet); + const server = net.createServer(async socket => onconnection({ isInitiator, hypercore, socket, itemCount })) + try { + server.listen(port, hostname, async () => { const { address, port } = server.address() + // console.log(`NODE: server listening on ${address}:${port}`) + }); + } catch (error) { + console.error(`NODE: ${isInitiator} server listen got error`, error); + } +} + +async function runClient(isWriter, itemCount, itemSize, itemChar, testSet) { + const isInitiator = true; + const hypercore = isWriter ? await createWriteHypercore(itemCount, itemSize, itemChar, testSet) : await createReadHypercore(testSet); + const socket = await net.connect(port, hostname); + await onconnection({ isInitiator, hypercore, socket, itemCount }); +} + +class Mutex { + constructor () { + this.locked = false + this.destroyed = false + + this._destroying = null + this._destroyError = null + this._queue = [] + this._enqueue = (resolve, reject) => this._queue.push([resolve, reject]) + } + + lock () { + if (this.destroyed) return Promise.reject(this._destroyError) + if (this.locked) return new Promise(this._enqueue) + this.locked = true + return Promise.resolve() + } + + unlock () { + if (!this._queue.length) { + this.locked = false + return + } + this._queue.shift()[0]() + } + + destroy (err) { + if (!this._destroying) this._destroying = this.locked ? this.lock().catch(() => {}) : Promise.resolve() + + this.destroyed = true + this._destroyError = err || new Error('Mutex has been destroyed') + + if (err) { + while (this._queue.length) this._queue.shift()[1](err) + } + + return this._destroying + } +} + +let mutex = new Mutex() +async function onconnection (opts) { + const { isInitiator, hypercore, socket, itemCount } = opts + const { remoteAddress, remotePort } = socket + if (!isInitiator) { + // console.log(`NODE: new connection from ${remoteAddress}:${remotePort}`) + } + socket.on('close', () => { + if (!isInitiator) { + // console.log(`NODE: connection closed from ${remoteAddress}:${remotePort}`) + } else { + // console.log('NODE: connection closed from server') + } + }) + + hypercore.on('append', async _ => { + await mutex.lock() + // console.log(`NODE: ${isInitiator} got append, new length ${hypercore.length} and byte length ${hypercore.byteLength}, count match=${hypercore.length === itemCount}`) + if (hypercore.length === itemCount) { + let fileContent = ""; + for (let i = 0; i < hypercore.length; i++) { + // console.log(`${isInitiator} Getting value for index ${i}`); + let value = await hypercore.get(i); + fileContent += `${i} ${value}\n`; + } + try { + // console.log(`NODE: ${isInitiator} Writing file`); + await fs.writeFile(resultFile, fileContent); + } catch (error) { + // console.log(`NODE: ${isInitiator} got error`, error); + process.exit(3); + } + + // console.log(`NODE: ${isInitiator} Wrote content exiting`); + process.exit(0); + } + mutex.unlock() + }) + socket.pipe(hypercore.replicate(isInitiator)).pipe(socket) +} + +async function createWriteHypercore(itemCount, itemSize, itemChar, testSet){ + const core = new Hypercore(`work/${testSet}/writer`, testKeyPair.publicKey, {keyPair: testKeyPair}); + let data = Buffer.alloc(itemSize, itemChar); + for (let i=0; i (String, String, String) { + let path_result = format!("tests/js/work/{}/result.txt", test_set); + let path_writer = format!("tests/js/work/{}/writer", test_set); + let path_reader = format!("tests/js/work/{}/reader", test_set); + create_dir_all(&path_writer).expect("Unable to create work writer directory"); + create_dir_all(&path_reader).expect("Unable to create work reader directory"); + (path_result, path_writer, path_reader) +} + +pub struct JavascriptServer { + handle: Option>, +} + +impl JavascriptServer { + pub fn new() -> JavascriptServer { + JavascriptServer { handle: None } + } + + pub async fn run( + &mut self, + is_writer: bool, + port: u32, + data_count: usize, + data_size: usize, + data_char: char, + test_set: String, + ) { + self.handle = Some(task::spawn(async move { + // This sometimes fails on OSX immediately with unix signal 4, let's retry a few times + let mut retries = 3; + let mut code: Option = None; + while code.is_none() && retries > 0 { + let status = process::Command::new("node") + .current_dir("tests/js") + .args([ + "interop.js", + "server", + if is_writer { "writer" } else { "reader" }, + &port.to_string(), + &data_count.to_string(), + &data_size.to_string(), + &data_char.to_string(), + &test_set, + ]) + .kill_on_drop(true) + .status() + .await + .expect("Unable to execute node"); + code = status.code(); + if code.is_none() { + sleep(Duration::from_millis(100)).await; + retries -= 1; + } + } + + assert_eq!( + Some(0), + code, + "node server did not exit successfully, is_writer={}, port={}, data_count={}, data_size={}, data_char={}, test_set={}", + is_writer, + port, + data_count, + data_size, + data_char, + test_set, + ); + })); + wait_for_localhost_port(port).await; + } +} + +impl Drop for JavascriptServer { + fn drop(&mut self) { + #[cfg(feature = "async-std")] + if let Some(handle) = self.handle.take() { + async_std::task::block_on(handle.cancel()); + } + } +} + +pub async fn js_start_server( + is_writer: bool, + port: u32, + data_count: usize, + data_size: usize, + data_char: char, + test_set: String, +) -> Result { + let mut server = JavascriptServer::new(); + server + .run(is_writer, port, data_count, data_size, data_char, test_set) + .await; + Ok(server) +} + +pub async fn js_run_client( + is_writer: bool, + port: u32, + data_count: usize, + data_size: usize, + data_char: char, + test_set: &str, +) { + let status = process::Command::new("node") + .current_dir("tests/js") + .args([ + "interop.js", + "client", + if is_writer { "writer" } else { "reader" }, + &port.to_string(), + &data_count.to_string(), + &data_size.to_string(), + &data_char.to_string(), + test_set, + ]) + .kill_on_drop(true) + .status() + .await + .expect("Unable to execute node"); + assert_eq!( + Some(0), + status.code(), + "node client did not run successfully, \ + is_writer={is_writer}, \ + port={port}, \ + data_count={data_count}, \ + data_size={data_size}, \ + data_char={data_char}, \ + test_set={test_set}" + ); +} diff --git a/tests/js/package.json b/tests/js/package.json new file mode 100644 index 0000000..721537e --- /dev/null +++ b/tests/js/package.json @@ -0,0 +1,7 @@ +{ + "name": "hypercore-protocol-rs-js-interop-tests", + "version": "0.0.1", + "dependencies": { + "hypercore": "^10" + } +} diff --git a/tests/js_interop.rs b/tests/js_interop.rs new file mode 100644 index 0000000..fb0e4ea --- /dev/null +++ b/tests/js_interop.rs @@ -0,0 +1,937 @@ +use _util::wait_for_localhost_port; +use anyhow::Result; +use futures::Future; +use futures_lite::stream::StreamExt; +use hypercore::SigningKey; +use hypercore::{ + Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, Storage, + VerifyingKey, PUBLIC_KEY_LENGTH, SECRET_KEY_LENGTH, +}; +use instant::Duration; +use random_access_disk::RandomAccessDisk; +use random_access_storage::RandomAccess; +use std::fmt::Debug; +use std::path::Path; +use std::sync::Arc; +use std::sync::Once; + +#[cfg(feature = "tokio")] +use async_compat::CompatExt; +#[cfg(feature = "async-std")] +use async_std::{ + fs::{metadata, File}, + io::{prelude::BufReadExt, BufReader, BufWriter, WriteExt}, + net::{TcpListener, TcpStream}, + sync::Mutex, + task::{self, sleep}, + test as async_test, +}; +use test_log::test; +#[cfg(feature = "tokio")] +use tokio::{ + fs::{metadata, File}, + io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}, + net::{TcpListener, TcpStream}, + sync::Mutex, + task, test as async_test, + time::sleep, +}; + +use hypercore_protocol::schema::*; +use hypercore_protocol::{discovery_key, Channel, Event, Message, ProtocolBuilder}; + +pub mod _util; +mod js; +use js::{cleanup, install, js_run_client, js_start_server, prepare_test_set}; + +static INIT: Once = Once::new(); +fn init() { + INIT.call_once(|| { + // run initialization here + cleanup(); + install(); + }); +} + +const TEST_SET_NODE_CLIENT_NODE_SERVER: &str = "ncns"; +const TEST_SET_RUST_CLIENT_NODE_SERVER: &str = "rcns"; +const TEST_SET_NODE_CLIENT_RUST_SERVER: &str = "ncrs"; +const TEST_SET_RUST_CLIENT_RUST_SERVER: &str = "rcrs"; +const TEST_SET_SERVER_WRITER: &str = "sw"; +const TEST_SET_CLIENT_WRITER: &str = "cw"; +const TEST_SET_SIMPLE: &str = "simple"; + +#[test(async_test)] +#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +async fn js_interop_ncns_simple_server_writer() -> Result<()> { + js_interop_ncns_simple(true, 8101).await?; + Ok(()) +} + +#[test(async_test)] +#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +async fn js_interop_ncns_simple_client_writer() -> Result<()> { + js_interop_ncns_simple(false, 8102).await?; + Ok(()) +} + +#[test(async_test)] +#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +async fn js_interop_rcns_simple_server_writer() -> Result<()> { + js_interop_rcns_simple(true, 8103).await?; + Ok(()) +} + +#[test(async_test)] +#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +async fn js_interop_rcns_simple_client_writer() -> Result<()> { + js_interop_rcns_simple(false, 8104).await?; + Ok(()) +} + +#[test(async_test)] +#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +async fn js_interop_ncrs_simple_server_writer() -> Result<()> { + js_interop_ncrs_simple(true, 8105).await?; + Ok(()) +} + +#[test(async_test)] +#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +async fn js_interop_ncrs_simple_client_writer() -> Result<()> { + js_interop_ncrs_simple(false, 8106).await?; + Ok(()) +} + +#[test(async_test)] +#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +async fn js_interop_rcrs_simple_server_writer() -> Result<()> { + js_interop_rcrs_simple(true, 8107).await?; + Ok(()) +} + +#[test(async_test)] +#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +async fn js_interop_rcrs_simple_client_writer() -> Result<()> { + js_interop_rcrs_simple(false, 8108).await?; + Ok(()) +} + +async fn js_interop_ncns_simple(server_writer: bool, port: u32) -> Result<()> { + init(); + let test_set = format!( + "{}_{}_{}", + TEST_SET_NODE_CLIENT_NODE_SERVER, + if server_writer { + TEST_SET_SERVER_WRITER + } else { + TEST_SET_CLIENT_WRITER + }, + TEST_SET_SIMPLE + ); + let (result_path, _writer_path, _reader_path) = prepare_test_set(&test_set); + let item_count = 4; + let item_size = 4; + let data_char = '1'; + let _server = js_start_server( + server_writer, + port, + item_count, + item_size, + data_char, + test_set.clone(), + ) + .await?; + js_run_client( + !server_writer, + port, + item_count, + item_size, + data_char, + &test_set, + ) + .await; + assert_result(result_path, item_count, item_size, data_char).await?; + + Ok(()) +} + +async fn js_interop_rcns_simple(server_writer: bool, port: u32) -> Result<()> { + init(); + let test_set = format!( + "{}_{}_{}", + TEST_SET_RUST_CLIENT_NODE_SERVER, + if server_writer { + TEST_SET_SERVER_WRITER + } else { + TEST_SET_CLIENT_WRITER + }, + TEST_SET_SIMPLE + ); + let (result_path, writer_path, reader_path) = prepare_test_set(&test_set); + let item_count = 4; + let item_size = 4; + let data_char = '1'; + let server = js_start_server( + server_writer, + port, + item_count, + item_size, + data_char, + test_set.clone(), + ) + .await?; + run_client( + !server_writer, + port, + item_count, + item_size, + data_char, + if server_writer { + &reader_path + } else { + &writer_path + }, + &result_path, + ) + .await?; + assert_result(result_path, item_count, item_size, data_char).await?; + drop(server); + + Ok(()) +} + +async fn js_interop_ncrs_simple(server_writer: bool, port: u32) -> Result<()> { + init(); + let test_set = format!( + "{}_{}_{}", + TEST_SET_NODE_CLIENT_RUST_SERVER, + if server_writer { + TEST_SET_SERVER_WRITER + } else { + TEST_SET_CLIENT_WRITER + }, + TEST_SET_SIMPLE + ); + let (result_path, writer_path, reader_path) = prepare_test_set(&test_set); + let item_count = 4; + let item_size = 4; + let data_char = '1'; + + let _server = start_server( + server_writer, + port, + item_count, + item_size, + data_char, + if server_writer { + &writer_path + } else { + &reader_path + }, + &result_path, + ) + .await?; + js_run_client( + !server_writer, + port, + item_count, + item_size, + data_char, + &test_set.clone(), + ) + .await; + + assert_result(result_path, item_count, item_size, data_char).await?; + + Ok(()) +} + +async fn js_interop_rcrs_simple(server_writer: bool, port: u32) -> Result<()> { + init(); + let test_set = format!( + "{}_{}_{}", + TEST_SET_RUST_CLIENT_RUST_SERVER, + if server_writer { + TEST_SET_SERVER_WRITER + } else { + TEST_SET_CLIENT_WRITER + }, + TEST_SET_SIMPLE + ); + let (result_path, writer_path, reader_path) = prepare_test_set(&test_set); + let item_count = 4; + let item_size = 4; + let data_char = '1'; + + let _server = start_server( + server_writer, + port, + item_count, + item_size, + data_char, + if server_writer { + &writer_path + } else { + &reader_path + }, + &result_path, + ) + .await?; + run_client( + !server_writer, + port, + item_count, + item_size, + data_char, + if server_writer { + &reader_path + } else { + &writer_path + }, + &result_path, + ) + .await?; + + assert_result(result_path, item_count, item_size, data_char).await?; + + Ok(()) +} + +async fn assert_result( + result_path: String, + item_count: usize, + item_size: usize, + data_char: char, +) -> Result<()> { + // First we need to wait for the file to be ready + loop { + let path = Path::new(&result_path); + if path.exists() { + let metadata = metadata(path).await?; + // There's a index + space + line feed + if metadata.len() >= (item_count * (3 + item_size)) as u64 { + break; + } + } + sleep(Duration::from_millis(100)).await; + } + + let mut reader = BufReader::new(File::open(result_path).await?); + let mut i: usize = 0; + let expected_value = data_char.to_string().repeat(item_size); + let mut line = String::new(); + while reader.read_line(&mut line).await? != 0 { + assert_eq!(line, format!("{} {}\n", i, expected_value)); + i += 1; + line = String::new(); + } + assert_eq!(i, item_count); + Ok(()) +} + +async fn run_client( + is_writer: bool, + port: u32, + data_count: usize, + data_size: usize, + data_char: char, + data_path: &str, + result_path: &str, +) -> Result<()> { + let hypercore = if is_writer { + create_writer_hypercore(data_count, data_size, data_char, data_path).await? + } else { + create_reader_hypercore(data_path).await? + }; + let hypercore_wrapper = HypercoreWrapper::from_disk_hypercore( + hypercore, + if is_writer { + None + } else { + Some(result_path.to_string()) + }, + ); + tcp_client(port, on_replication_connection, Arc::new(hypercore_wrapper)).await?; + Ok(()) +} + +async fn start_server( + is_writer: bool, + port: u32, + item_count: usize, + item_size: usize, + data_char: char, + data_path: &str, + result_path: &str, +) -> Result { + let hypercore = if is_writer { + create_writer_hypercore(item_count, item_size, data_char, data_path).await? + } else { + create_reader_hypercore(data_path).await? + }; + let hypercore_wrapper = HypercoreWrapper::from_disk_hypercore( + hypercore, + if is_writer { + None + } else { + Some(result_path.to_string()) + }, + ); + let mut server = RustServer::new(); + server.run(Arc::new(hypercore_wrapper), port).await; + Ok(server) +} + +async fn create_writer_hypercore( + data_count: usize, + data_size: usize, + data_char: char, + path: &str, +) -> Result> { + let path = Path::new(path).to_owned(); + let key_pair = get_test_key_pair(true); + let storage = Storage::new_disk(&path, false).await?; + let mut hypercore = HypercoreBuilder::new(storage) + .key_pair(key_pair) + .build() + .await?; + for _ in 0..data_count { + let value = vec![data_char as u8; data_size]; + hypercore.append(&value).await?; + } + Ok(hypercore) +} + +async fn create_reader_hypercore(path: &str) -> Result> { + let path = Path::new(path).to_owned(); + let key_pair = get_test_key_pair(false); + let storage = Storage::new_disk(&path, false).await?; + Ok(HypercoreBuilder::new(storage) + .key_pair(key_pair) + .build() + .await?) +} + +const TEST_PUBLIC_KEY_BYTES: [u8; PUBLIC_KEY_LENGTH] = [ + 0x97, 0x60, 0x6c, 0xaa, 0xd2, 0xb0, 0x8c, 0x1d, 0x5f, 0xe1, 0x64, 0x2e, 0xee, 0xa5, 0x62, 0xcb, + 0x91, 0xd6, 0x55, 0xe2, 0x00, 0xc8, 0xd4, 0x3a, 0x32, 0x09, 0x1d, 0x06, 0x4a, 0x33, 0x1e, 0xe3, +]; +// NB: In the javascript version this is 64 bytes, but that's because sodium appends the the public +// key after the secret key for some reason. Only the first 32 bytes are actually used in +// javascript side too for signing. +const TEST_SECRET_KEY_BYTES: [u8; SECRET_KEY_LENGTH] = [ + 0x27, 0xe6, 0x74, 0x25, 0xc1, 0xff, 0xd1, 0xd9, 0xee, 0x62, 0x5c, 0x96, 0x2b, 0x57, 0x13, 0xc3, + 0x51, 0x0b, 0x71, 0x14, 0x15, 0xf3, 0x31, 0xf6, 0xfa, 0x9e, 0xf2, 0xbf, 0x23, 0x5f, 0x2f, 0xfe, +]; + +pub fn get_test_key_pair(include_secret: bool) -> PartialKeypair { + let public = VerifyingKey::from_bytes(&TEST_PUBLIC_KEY_BYTES).unwrap(); + let secret = if include_secret { + let signing_key = SigningKey::from_bytes(&TEST_SECRET_KEY_BYTES); + assert_eq!( + TEST_PUBLIC_KEY_BYTES, + signing_key.verifying_key().to_bytes() + ); + Some(signing_key) + } else { + None + }; + + PartialKeypair { public, secret } +} + +#[cfg(feature = "async-std")] +async fn on_replication_connection( + stream: TcpStream, + is_initiator: bool, + hypercore: Arc>, +) -> Result<()> +where + T: RandomAccess + Debug + Send, +{ + let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream); + while let Some(event) = protocol.next().await { + let event = event?; + match event { + Event::Handshake(_) => { + if is_initiator { + protocol.open(*hypercore.key()).await?; + } + } + Event::DiscoveryKey(dkey) => { + if hypercore.discovery_key == dkey { + protocol.open(*hypercore.key()).await?; + } else { + panic!("Invalid discovery key"); + } + } + Event::Channel(channel) => { + hypercore.on_replication_peer(channel); + } + Event::Close(_dkey) => { + break; + } + _ => {} + } + } + Ok(()) +} + +#[cfg(feature = "tokio")] +async fn on_replication_connection( + stream: TcpStream, + is_initiator: bool, + hypercore: Arc>, +) -> Result<()> +where + T: RandomAccess + Debug + Send, +{ + let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream.compat()); + while let Some(event) = protocol.next().await { + let event = event?; + match event { + Event::Handshake(_) => { + if is_initiator { + protocol.open(*hypercore.key()).await?; + } + } + Event::DiscoveryKey(dkey) => { + if hypercore.discovery_key == dkey { + protocol.open(*hypercore.key()).await?; + } else { + panic!("Invalid discovery key"); + } + } + Event::Channel(channel) => { + hypercore.on_replication_peer(channel); + } + Event::Close(_dkey) => { + break; + } + _ => {} + } + } + Ok(()) +} + +#[derive(Debug, Clone)] +pub struct HypercoreWrapper +where + T: RandomAccess + Debug + Send, +{ + discovery_key: [u8; 32], + key: [u8; 32], + hypercore: Arc>>, + result_path: Option, +} + +impl HypercoreWrapper { + pub fn from_disk_hypercore( + hypercore: Hypercore, + result_path: Option, + ) -> Self { + let key = hypercore.key_pair().public.to_bytes(); + HypercoreWrapper { + key, + discovery_key: discovery_key(&key), + hypercore: Arc::new(Mutex::new(hypercore)), + result_path, + } + } +} + +impl HypercoreWrapper +where + T: RandomAccess + Debug + Send + 'static, +{ + pub fn key(&self) -> &[u8; 32] { + &self.key + } + + pub fn on_replication_peer(&self, mut channel: Channel) { + let mut peer_state = PeerState::default(); + let mut hypercore = self.hypercore.clone(); + let result_path = self.result_path.clone(); + task::spawn(async move { + let info = { + let hypercore = hypercore.lock().await; + hypercore.info() + }; + + if info.fork != peer_state.remote_fork { + peer_state.can_upgrade = false; + } + let remote_length = if info.fork == peer_state.remote_fork { + peer_state.remote_length + } else { + 0 + }; + + let sync_msg = Synchronize { + fork: info.fork, + length: info.length, + remote_length, + can_upgrade: peer_state.can_upgrade, + uploading: true, + downloading: true, + }; + + if info.contiguous_length > 0 { + let range_msg = Range { + drop: false, + start: 0, + length: info.contiguous_length, + }; + channel + .send_batch(&[Message::Synchronize(sync_msg), Message::Range(range_msg)]) + .await + .unwrap(); + } else { + channel.send(Message::Synchronize(sync_msg)).await.unwrap(); + } + while let Some(message) = channel.next().await { + let ready = on_replication_message( + &mut hypercore, + &mut peer_state, + result_path.clone(), + &mut channel, + message, + ) + .await + .expect("on_replication_message should return Ok"); + if ready { + channel.close().await.expect("Should be able to close"); + break; + } + } + }); + } +} + +async fn on_replication_message( + hypercore: &mut Arc>>, + peer_state: &mut PeerState, + result_path: Option, + channel: &mut Channel, + message: Message, +) -> Result +where + T: RandomAccess + Debug + Send, +{ + match message { + Message::Synchronize(message) => { + let length_changed = message.length != peer_state.remote_length; + let first_sync = !peer_state.remote_synced; + let info = { + let hypercore = hypercore.lock().await; + hypercore.info() + }; + let same_fork = message.fork == info.fork; + + peer_state.remote_fork = message.fork; + peer_state.remote_length = message.length; + peer_state.remote_can_upgrade = message.can_upgrade; + peer_state.remote_uploading = message.uploading; + peer_state.remote_downloading = message.downloading; + peer_state.remote_synced = true; + + peer_state.length_acked = if same_fork { message.remote_length } else { 0 }; + + let mut messages = vec![]; + + if first_sync { + // Need to send another sync back that acknowledges the received sync + let msg = Synchronize { + fork: info.fork, + length: info.length, + remote_length: peer_state.remote_length, + can_upgrade: peer_state.can_upgrade, + uploading: true, + downloading: true, + }; + messages.push(Message::Synchronize(msg)); + } + + if peer_state.remote_length > info.length + && peer_state.length_acked == info.length + && length_changed + { + let msg = Request { + id: 1, + fork: info.fork, + hash: None, + block: None, + seek: None, + upgrade: Some(RequestUpgrade { + start: info.length, + length: peer_state.remote_length - info.length, + }), + }; + messages.push(Message::Request(msg)); + } + + channel.send_batch(&messages).await?; + } + Message::Request(message) => { + let (info, proof) = { + let mut hypercore = hypercore.lock().await; + let proof = hypercore + .create_proof(message.block, message.hash, message.seek, message.upgrade) + .await?; + (hypercore.info(), proof) + }; + if let Some(proof) = proof { + let msg = Data { + request: message.id, + fork: info.fork, + hash: proof.hash, + block: proof.block, + seek: proof.seek, + upgrade: proof.upgrade, + }; + channel.send(Message::Data(msg)).await?; + } else { + panic!("Could not create proof from {:?}", message.id); + } + } + Message::Data(message) => { + let (old_info, applied, new_info, request_block, synced) = { + let mut hypercore = hypercore.lock().await; + let old_info = hypercore.info(); + let proof = message.clone().into_proof(); + let applied = hypercore.verify_and_apply_proof(&proof).await?; + let new_info = hypercore.info(); + let request_block: Option = if let Some(upgrade) = &message.upgrade { + // When getting the initial upgrade, send a request for the first missing block + if old_info.length < upgrade.length { + let request_index = old_info.length; + let nodes = hypercore.missing_nodes(request_index).await?; + Some(RequestBlock { + index: request_index, + nodes, + }) + } else { + None + } + } else if let Some(block) = &message.block { + // When receiving a block, ask for the next, if there are still some missing + if block.index < peer_state.remote_length - 1 { + let request_index = block.index + 1; + let nodes = hypercore.missing_nodes(request_index).await?; + Some(RequestBlock { + index: request_index, + nodes, + }) + } else { + None + } + } else { + None + }; + let synced = new_info.contiguous_length == new_info.length; + (old_info, applied, new_info, request_block, synced) + }; + assert!(applied, "Could not apply proof"); + let mut messages: Vec = vec![]; + if let Some(upgrade) = &message.upgrade { + let new_length = upgrade.length; + + let remote_length = if new_info.fork == peer_state.remote_fork { + peer_state.remote_length + } else { + 0 + }; + + messages.push(Message::Synchronize(Synchronize { + fork: new_info.fork, + length: new_length, + remote_length, + can_upgrade: false, + uploading: true, + downloading: true, + })); + } + if let Some(block) = &message.block { + // Send Range if the number of items changed, both for the single and + // for the contiguous length + if old_info.length < new_info.length { + messages.push(Message::Range(Range { + drop: false, + start: block.index, + length: 1, + })); + } + if old_info.contiguous_length < new_info.contiguous_length { + messages.push(Message::Range(Range { + drop: false, + start: 0, + length: new_info.contiguous_length, + })); + } + } + if let Some(request_block) = request_block { + messages.push(Message::Request(Request { + id: request_block.index + 1, + fork: new_info.fork, + hash: None, + block: Some(request_block), + seek: None, + upgrade: None, + })); + } + let exit = if synced { + if let Some(result_path) = result_path.as_ref() { + let mut hypercore = hypercore.lock().await; + let mut writer = BufWriter::new(File::create(result_path).await?); + for i in 0..new_info.contiguous_length { + let value = String::from_utf8(hypercore.get(i).await?.unwrap()).unwrap(); + let line = format!("{} {}\n", i, value); + writer.write(line.as_bytes()).await?; + } + writer.flush().await?; + true + } else { + false + } + } else { + false + }; + channel.send_batch(&messages).await.unwrap(); + if exit { + return Ok(true); + } + } + Message::Range(message) => { + if result_path.is_none() { + let info = { + let hypercore = hypercore.lock().await; + hypercore.info() + }; + if message.start == 0 && message.length == info.contiguous_length { + // Let's sleep here for a while so that close messages can pass + sleep(Duration::from_millis(100)).await; + return Ok(true); + } + } + } + _ => { + panic!("Received unexpected message {:?}", message); + } + }; + Ok(false) +} + +#[derive(Debug)] +struct PeerState { + can_upgrade: bool, + remote_fork: u64, + remote_length: u64, + remote_can_upgrade: bool, + remote_uploading: bool, + remote_downloading: bool, + remote_synced: bool, + length_acked: u64, +} +impl Default for PeerState { + fn default() -> Self { + PeerState { + can_upgrade: true, + remote_fork: 0, + remote_length: 0, + remote_can_upgrade: false, + remote_uploading: true, + remote_downloading: true, + remote_synced: false, + length_acked: 0, + } + } +} + +pub(crate) struct RustServer { + handle: Option>, +} + +impl RustServer { + pub fn new() -> RustServer { + RustServer { handle: None } + } + + pub async fn run(&mut self, hypercore: Arc>, port: u32) { + self.handle = Some(task::spawn(async move { + tcp_server(port, on_replication_connection, hypercore) + .await + .expect("Server return ok"); + })); + wait_for_localhost_port(port).await; + } +} + +impl Drop for RustServer { + fn drop(&mut self) { + #[cfg(feature = "async-std")] + if let Some(handle) = self.handle.take() { + task::block_on(handle.cancel()); + } + } +} + +#[cfg(feature = "async-std")] +pub async fn tcp_server( + port: u32, + onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, + context: C, +) -> Result<()> +where + F: Future> + Send, + C: Clone + Send + 'static, +{ + let listener = TcpListener::bind(&format!("localhost:{}", port)).await?; + let mut incoming = listener.incoming(); + while let Some(Ok(stream)) = incoming.next().await { + let context = context.clone(); + let _peer_addr = stream.peer_addr().unwrap(); + task::spawn(async move { + onconnection(stream, false, context) + .await + .expect("Should return ok"); + }); + } + Ok(()) +} + +#[cfg(feature = "tokio")] +pub async fn tcp_server( + port: u32, + onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, + context: C, +) -> Result<()> +where + F: Future> + Send, + C: Clone + Send + 'static, +{ + let listener = TcpListener::bind(&format!("localhost:{}", port)).await?; + + while let Ok((stream, _peer_address)) = listener.accept().await { + let context = context.clone(); + task::spawn(async move { + onconnection(stream, false, context) + .await + .expect("Should return ok"); + }); + } + Ok(()) +} + +pub async fn tcp_client( + port: u32, + onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, + context: C, +) -> Result<()> +where + F: Future> + Send, + C: Clone + Send + 'static, +{ + let stream = TcpStream::connect(&format!("localhost:{}", port)).await?; + onconnection(stream, true, context).await +}