diff --git a/.github/workflows/lint_and_test.yml b/.github/workflows/lint_test_and_build.yml similarity index 94% rename from .github/workflows/lint_and_test.yml rename to .github/workflows/lint_test_and_build.yml index 5341ca3..db52d40 100644 --- a/.github/workflows/lint_and_test.yml +++ b/.github/workflows/lint_test_and_build.yml @@ -25,3 +25,4 @@ jobs: - run: task deps - run: task lint - run: task test + - run: task e2e:tun diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..03756b8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +/.task/ +/build/ + +# e2e working directory created by Task e2e:tun +/.e2e/ diff --git a/README.md b/README.md index b08d2ec..2577f7a 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ Notes: rawClient, rawServer := net.Pipe() defer rawClient.Close(); defer rawServer.Close() -client := netx.NewFramedConn(rawClient) // default max frame size 16KiB +client := netx.NewFramedConn(rawClient) // default max frame size 32KiB server := netx.NewFramedConn(rawServer, netx.WithMaxFrameSize(64<<10)) msg := []byte("hello frame") @@ -153,10 +153,84 @@ If `Logger` is nil, the server/tunnel use `slog.Default()`. - Unhandled connections are dropped immediately after all routes decline. - `Shutdown(ctx)` will close listeners, then wait for tracked connections until `ctx` is done, after which remaining connections are force‑closed. -## Testing +## CLI -The repository includes unit and end‑to‑end tests (UDP over TCP, TLS routing, graceful shutdown). Run: +An extendable CLI is available at `cmd/netx` with an initial `tun` subcommand to relay between chainable endpoints. + +Build: ```bash -go test ./... +task build ``` + +Install and use: + +```bash +go install github.com/pedramktb/go-netx/cmd/netx@latest + +# Show help +netx tun -h + +# Example: TCP TLS server to TCP TLS+buffered+framed+aesgcm client +netx tun \ + --from tcp+tls[cert=server.crt,key=server.key]://:9000 \ + --to tcp+tls[cert=client.crt]+buffered[size=8192]+framed[maxsize=4096]+aesgcm[key=00112233445566778899aabbccddeeff]://example.com:9443 + +# Example: UDP DTLS server to UDP aesgcm client +netx tun \ + --from udp+dtls[cert=server.crt,key=server.key]://:4444 \ + --to udp+aesgcm[key=00112233445566778899aabbccddeeff]://10.0.0.10:5555 +``` + +Options: + +- `--from ://listenAddr` - Incoming side chain URI (required) +- `--to ://connectAddr` - Peer side chain URI (required) +- `--log ` - Log level: debug|info|warn|error (default: info) +- `-h` - Show help + +Chain syntax: + +Chains use the form `://host:port` where `` is a `+`-separated list starting with a base transport (`tcp` or `udp`), optionally followed by wrappers with parameters in brackets. + +**Supported base transports:** + +- `tcp` - TCP listener or dialer +- `udp` - UDP listener or dialer + +**Supported wrappers:** + +- `tls` - Transport Layer Security + - Server params: `cert`, `key` + - Client params: `cert` (optional, for SPKI pinning), `servername` (required if cert not provided) + +- `utls` - TLS with client fingerprint camouflage via uTLS + - Client-side only + - Params: `cert` (optional, for SPKI pinning), `servername` (required if cert not provided), `hello` (optional: chrome, firefox, ios, android, safari, edge, randomized, randomizednoalpn; default: chrome) + +- `dtls` - Datagram Transport Layer Security + - Server params: `cert`, `key` + - Client params: `cert` (optional, for SPKI pinning), `servername` (required if cert not provided) + +- `tlspsk` - TLS with pre-shared key (TLS 1.2, cipher: TLS_PSK_WITH_AES_256_CBC_SHA) + - Params: `key`, `identity` + +- `dtlspsk` - DTLS with pre-shared key (cipher: TLS_PSK_WITH_AES_128_GCM_SHA256) + - Params: `key`, `identity` + +- `aesgcm` - AES-GCM encryption with passive IV exchange + - Params: `key`, `maxpacket` (optional, default: 32768) + +- `buffered` - Buffered read/write for better performance + - Params: `size` (optional, default: 4096) + +- `framed` - Length-prefixed frames for packet semantics over streams + - Params: `maxsize` (optional, default: 32768) + +- `ssh` - SSH tunneling via "direct-tcpip" channels + - Server params: `key` (optional, required with pass), `pass` (optional), `pubkey` (optional, required if no pass) + - Client params: `pubkey`, `pass` (optional), `key` (optional, required if no pass) + +**Notes:** +- All passwords, keys and certificates must be provided as hex-encoded strings. +- When using `cert` for client-side `tls`/`utls`/`dtls`, default validation is disabled and a manual SPKI (SubjectPublicKeyInfo) hash comparison is performed against the provided certificate. This is certificate pinning and will fail if the server presents a different key. diff --git a/Taskfile.yml b/Taskfile.yml index ee195f0..d36c743 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -2,7 +2,7 @@ version: '3' tasks: default: - cmds: + cmds: - task --list deps: @@ -20,3 +20,154 @@ tasks: desc: Run tests cmds: - go test ./... + + build: + desc: Build binaries and libraries + cmds: + # Linux binaries and shared libraries + - env GOOS=linux GOARCH=amd64 go build -o build/netx_linux_x64 cmd/netx/*.go + - env GOOS=linux GOARCH=arm64 go build -o build/netx_linux_arm64 cmd/netx/*.go + # - env GOOS=linux GOARCH=amd64 CGO_ENABLED=1 CC=x86_64-linux-gnu-gcc go build -buildmode=c-shared -o build/libnetx_linux_x64.so cmd/netx/lib/main.go + # - env GOOS=linux GOARCH=arm64 CGO_ENABLED=1 CC=aarch64-linux-gnu-gcc go build -buildmode=c-shared -o build/libnetx_linux_arm64.so cmd/netx/lib/main.go + # # Windows binaries and shared libraries + - env GOOS=windows GOARCH=amd64 go build -o build/netx_windows_x64.exe cmd/netx/*.go + - env GOOS=windows GOARCH=arm64 go build -o build/netx_windows_arm64.exe cmd/netx/*.go + # - env GOOS=windows GOARCH=amd64 CGO_ENABLED=1 CC=x86_64-w64-mingw32-gcc go build -buildmode=c-shared -o build/libnetx_windows_x64.dll cmd/netx/lib/main.go + # # aarch64-w64-mingw32-gcc is experimental and not available + # # - env GOOS=windows GOARCH=arm64 CGO_ENABLED=1 CC=aarch64-w64-mingw32-gcc go build -buildmode=c-shared -o build/libnetx_windows_arm64.dll cmd/lib/main.go + # # macOS binaries + - env GOOS=darwin GOARCH=amd64 go build -o build/netx_macos_x64 cmd/netx/*.go + - env GOOS=darwin GOARCH=arm64 go build -o build/netx_macos_arm64 cmd/netx/*.go + # # Android shared libraries + # - env GOOS=android GOARCH=amd64 CGO_ENABLED=1 CC=$ANDROID_NDK_HOME/toolchains/llvm/prebuilt/linux-x86_64/bin/x86_64-linux-android26-clang go build -buildmode=c-shared -o build/libnetx_android_x64.so cmd/netx/lib/main.go + # - env GOOS=android GOARCH=arm64 CGO_ENABLED=1 CC=$ANDROID_NDK_HOME/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android26-clang go build -buildmode=c-shared -o build/libnetx_android_arm64.so cmd/netx/lib/main.go + sources: + - '**/*.go' + - go.mod + - go.sum + generates: + - build/netx_linux_x64 + - build/netx_linux_arm64 + - build/libnetx_linux_x64.so + - build/libnetx_linux_arm64.so + - build/netx_windows_x64.exe + - build/netx_windows_arm64.exe + - build/libnetx_windows_x64.dll + - build/netx_macos_x64 + - build/netx_macos_arm64 + - build/libnetx_android_x64.so + - build/libnetx_android_arm64.so + + build-apple-libs: + desc: Build Apple libraries (optional, requires mac toolchains) + cmds: + - env GOOS=darwin GOARCH=amd64 CGO_ENABLED=1 go build -buildmode=c-shared -o build/libnetx_macos_x64.dylib cmd/netx/lib/main.go + - env GOOS=darwin GOARCH=arm64 CGO_ENABLED=1 go build -buildmode=c-shared -o build/libnetx_macos_arm64.dylib cmd/netx/lib/main.go + - | + mkdir -p build + export CC=$(xcrun -find -sdk iphonesimulator clang) + export CXX=$(xcrun -find -sdk iphonesimulator clang++) + export SDKROOT=$(xcrun --sdk iphonesimulator --show-sdk-path) + export CFLAGS="-arch x86_64 -isysroot $SDKROOT -mios-simulator-version-min=10.0" + export LDFLAGS="-arch x86_64 -isysroot $SDKROOT -mios-simulator-version-min=10.0" + env GOOS=darwin GOARCH=amd64 CGO_ENABLED=1 go build -buildmode=c-archive -o build/libnetx_ios_x64.a cmd/netx/lib/main.go + - | + mkdir -p build + export CC=$(xcrun -find -sdk iphoneos clang) + export CXX=$(xcrun -find -sdk iphoneos clang++) + export SDKROOT=$(xcrun --sdk iphoneos --show-sdk-path) + export CFLAGS="-arch arm64 -isysroot $SDKROOT -mios-version-min=10.0" + export LDFLAGS="-arch arm64 -isysroot $SDKROOT -mios-version-min=10.0" + env GOOS=darwin GOARCH=arm64 CGO_ENABLED=1 go build -buildmode=c-archive -o build/libnetx_ios_arm64.a cmd/netx/lib/main.go + sources: + - '**/*.go' + - go.mod + - go.sum + generates: + - build/libnetx_macos_x64.dylib + - build/libnetx_macos_arm64.dylib + - build/libnetx_ios_x64.a + - build/libnetx_ios_arm64.a + + e2e:tun: + desc: Run CLI end-to-end tun tests locally (uses .e2e working dir). + cmds: + - | + set -euo pipefail + ROOT=$(pwd) + WORK=.e2e + mkdir -p "$WORK" + + echo "Building netx binary..." + go build -o "$WORK/netx" ./cmd/netx + chmod +x "$WORK/netx" + cd "$WORK" + + echo "Generating certs and keys..." + openssl req -x509 -newkey rsa:2048 -keyout server.key -out server.crt -days 1 -nodes -subj "/CN=localhost" >/dev/null 2>&1 + openssl rand -hex 32 > psk.hex + cp psk.hex aes.hex + echo y | ssh-keygen -t ed25519 -f ssh_server_key -N "" -C "e2e-server" >/dev/null 2>&1 + echo y | ssh-keygen -t ed25519 -f ssh_client_key -N "" -C "e2e-client" >/dev/null 2>&1 + + echo "Building echo servers and clients..." + go build -tags e2e -o tcp_echo "$ROOT/internal/tools/e2e/tcp_echo" + go build -tags e2e -o udp_echo "$ROOT/internal/tools/e2e/udp_echo" + go build -tags e2e -o tcp_client "$ROOT/internal/tools/e2e/tcp_client" + go build -tags e2e -o udp_client "$ROOT/internal/tools/e2e/udp_client" + + # Use isolated ports to avoid conflicts with external demos + TE=48080; UE=48081 + STLS=49000; SDTLS=49100; SDTLSP=49300; SAESCT=49400; SAESCU=49500; SFR=49600; STLSPSK=49200; SSSH=49700; SUTLS=49800 + CTLS=50000; CDTLS=50010; CDTLSP=50011; CAESCT=50002; CAESCU=50012; CFR=50003; CTLSPSK=50001; CSSH=50004; CUTLS=50005 + + echo "Starting echo servers..." + ./tcp_echo 127.0.0.1:${TE} > tcp_echo.log 2>&1 & + ./udp_echo 127.0.0.1:${UE} > udp_echo.log 2>&1 & + + echo "Starting server tunnels..." + ./netx tun --from "tcp+tls[cert=$(xxd -p server.crt | tr -d '\n'),key=$(xxd -p server.key | tr -d '\n')]://127.0.0.1:${STLS}" --to "tcp://127.0.0.1:${TE}" --log info > tls_server.log 2>&1 & + ./netx tun --from "udp+dtls[cert=$(xxd -p server.crt | tr -d '\n'),key=$(xxd -p server.key | tr -d '\n')]://127.0.0.1:${SDTLS}" --to "udp://127.0.0.1:${UE}" --log info > dtls_server.log 2>&1 & + ./netx tun --from "udp+dtlspsk[identity=i,key=$(cat psk.hex)]://127.0.0.1:${SDTLSP}" --to "udp://127.0.0.1:${UE}" --log info > dtlspsk_server.log 2>&1 & + ./netx tun --from "tcp+buffered[size=8192]+framed[maxsize=4096]+aesgcm[key=$(cat aes.hex)]://127.0.0.1:${SAESCT}" --to "tcp://127.0.0.1:${TE}" --log info > aesgcm_tcp_server.log 2>&1 & + ./netx tun --from "udp+aesgcm[key=$(cat aes.hex)]://127.0.0.1:${SAESCU}" --to "udp://127.0.0.1:${UE}" --log info > aesgcm_udp_server.log 2>&1 & + ./netx tun --from "tcp+framed[maxsize=4096]://127.0.0.1:${SFR}" --to "udp://127.0.0.1:${UE}" --log info > framed_tcp_server.log 2>&1 & + ./netx tun --from "tcp+tlspsk[identity=i,key=$(cat psk.hex)]://127.0.0.1:${STLSPSK}" --to "tcp://127.0.0.1:${TE}" --log info > tlspsk_server.log 2>&1 & + ./netx tun --from "tcp+ssh[key=$(xxd -p ssh_server_key | tr -d '\n'),pubkey=$(xxd -p ssh_client_key.pub | tr -d '\n')]://127.0.0.1:${SSSH}" --to "tcp://127.0.0.1:${TE}" --log info > ssh_server.log 2>&1 & + ./netx tun --from "tcp+tls[cert=$(xxd -p server.crt | tr -d '\n'),key=$(xxd -p server.key | tr -d '\n')]://127.0.0.1:${SUTLS}" --to "tcp://127.0.0.1:${TE}" --log info > utls_server.log 2>&1 & + + echo "Starting client tunnels..." + ./netx tun --from "tcp://127.0.0.1:${CTLS}" --to "tcp+tls[cert=$(xxd -p server.crt | tr -d '\n')]://127.0.0.1:${STLS}" --log info > tls_client.log 2>&1 & + ./netx tun --from "udp://127.0.0.1:${CDTLS}" --to "udp+dtls[cert=$(xxd -p server.crt | tr -d '\n')]://127.0.0.1:${SDTLS}" --log info > dtls_client.log 2>&1 & + ./netx tun --from "udp://127.0.0.1:${CDTLSP}" --to "udp+dtlspsk[identity=i,key=$(cat psk.hex)]://127.0.0.1:${SDTLSP}" --log info > dtlspsk_client.log 2>&1 & + ./netx tun --from "tcp://127.0.0.1:${CAESCT}" --to "tcp+buffered[size=8192]+framed[maxsize=4096]+aesgcm[key=$(cat aes.hex)]://127.0.0.1:${SAESCT}" --log info > aesgcm_tcp_client.log 2>&1 & + ./netx tun --from "udp://127.0.0.1:${CAESCU}" --to "udp+aesgcm[key=$(cat aes.hex)]://127.0.0.1:${SAESCU}" --log info > aesgcm_udp_client.log 2>&1 & + ./netx tun --from "udp://127.0.0.1:${CFR}" --to "tcp+framed[maxsize=4096]://127.0.0.1:${SFR}" --log info > framed_tcp_client.log 2>&1 & + ./netx tun --from "tcp://127.0.0.1:${CTLSPSK}" --to "tcp+tlspsk[identity=i,key=$(cat psk.hex)]://127.0.0.1:${STLSPSK}" --log info > tlspsk_client.log 2>&1 & + ./netx tun --from "tcp://127.0.0.1:${CSSH}" --to "tcp+ssh[pubkey=$(xxd -p ssh_server_key.pub | tr -d '\n'),key=$(xxd -p ssh_client_key | tr -d '\n')]://127.0.0.1:${SSSH}" --log info > ssh_client.log 2>&1 & + ./netx tun --from "tcp://127.0.0.1:${CUTLS}" --to "tcp+utls[cert=$(xxd -p server.crt | tr -d '\n'),hello=chrome]://127.0.0.1:${SUTLS}" --log info > utls_client.log 2>&1 & + + sleep 2 + + echo "Running tests..." + pass=0; fail=0 + run_tcp(){ name=$1; addr=$2; msg=$3; out=$(./tcp_client "$addr" "$msg" || true); if [ "$out" = "$msg" ]; then echo "PASS $name"; pass=$((pass+1)); else echo "FAIL $name -> got: $out"; fail=$((fail+1)); fi } + run_udp(){ name=$1; addr=$2; msg=$3; out=$(./udp_client "$addr" "$msg" || true); if [ "$out" = "$msg" ]; then echo "PASS $name"; pass=$((pass+1)); else echo "FAIL $name -> got: $out"; fail=$((fail+1)); fi } + run_tcp TLS 127.0.0.1:${CTLS} hello_tls + run_udp DTLS 127.0.0.1:${CDTLS} hello_dtls + run_udp DTLSPSK 127.0.0.1:${CDTLSP} hello_dtlspsk + run_tcp AESGCM_TCP 127.0.0.1:${CAESCT} hello_aesgcm_tcp + run_udp AESGCM_UDP 127.0.0.1:${CAESCU} hello_aesgcm_udp + run_udp FRAMED_TCP_BR 127.0.0.1:${CFR} hello_udp_over_tcp + run_tcp SSH 127.0.0.1:${CSSH} hello_ssh + run_tcp UTLS 127.0.0.1:${CUTLS} hello_utls + run_tcp TLSPSK 127.0.0.1:${CTLSPSK} hello_tlspsk + echo "RESULTS: pass=$pass fail=$fail" + + echo "Cleaning up..." + pkill netx || true + pkill tcp_echo || true + pkill udp_echo || true + + echo "Done." + [ "$fail" -eq 0 ] diff --git a/aesgcm_conn.go b/aesgcm_conn.go new file mode 100644 index 0000000..dc30881 --- /dev/null +++ b/aesgcm_conn.go @@ -0,0 +1,172 @@ +package netx + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/binary" + "errors" + "io" + "net" + "sync/atomic" + "time" +) + +type aesgcmConn struct { + net.Conn + aead cipher.AEAD + wiv [12]byte + riv [12]byte + + // sequence number for nonce derivation, incremented atomically + seq atomic.Uint64 + + maxPacketSize int +} + +type AESGCMOption func(*aesgcmConn) + +// WithMaxPacket sets the maximum ciphertext packet size accepted on Read. +// Default is 32KB. This should be >= 8 (seq) + plaintext + aead.Overhead(). +func WithAESGCMMaxPacket(size uint32) AESGCMOption { + return func(c *aesgcmConn) { + c.maxPacketSize = int(size) + } +} + +// NewAESGCMConn constructs a new AES-GCM wrapper around a packet-based net.Conn. +// Key must be 16, 24, or 32 bytes (AES-128/192/256). +// It encrypts each packet using AES-GCM. It assumes the underlying conn +// preserves packet boundaries; it does not perform additional framing. +// +// Packet layout (single datagram): +// +// [8-byte seq big-endian][GCM(ciphertext||tag)] +// +// Nonce derivation: 12-byte IV is required. For a packet with sequence S, +// nonce = IV with its last 8 bytes XORed with S (big-endian). This ensures +// per-packet unique nonces without transmitting the full nonce. +// Write IV is randomly generated on creation and sent to the peer in the +// passive handshake that is performed on creation to exchange random IVs. +func NewAESGCMConn(c net.Conn, key []byte, opts ...AESGCMOption) (net.Conn, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + a, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + agc := &aesgcmConn{ + Conn: c, + aead: a, + maxPacketSize: 32 * 1024} + for _, opt := range opts { + opt(agc) + } + if agc.maxPacketSize < 8+a.Overhead() { + return nil, errors.New("aesgcmConn: maxPacketSize too small") + } + if _, err := io.ReadFull(rand.Reader, agc.wiv[:]); err != nil { + return nil, err + } + + // Passive handshake (duplex): concurrently read peer IV while writing ours + handshakeDeadline := time.Now().Add(5 * time.Second) + _ = c.SetDeadline(handshakeDeadline) + defer func() { _ = c.SetDeadline(time.Time{}) }() // clear deadline after handshake + + // Start read of peer's 12-byte IV + readErrCh := make(chan error, 1) + go func() { + // io.ReadFull returns err if not enough bytes + _, err := io.ReadFull(c, agc.riv[:]) + readErrCh <- err + }() + + // Write our 12-byte IV + o := 0 + for o < len(agc.wiv) { + n, err := c.Write(agc.wiv[o:]) + if err != nil { + return nil, err + } + o += n + } + if o != len(agc.wiv) { + return nil, io.ErrShortWrite + } + + // Wait for read to complete + if err := <-readErrCh; err != nil { + return nil, err + } + + return agc, nil +} + +// Read reads and decrypts a single datagram from the underlying conn. +// If p is too small for the decrypted payload, io.ErrShortBuffer is returned. +func (c *aesgcmConn) Read(p []byte) (int, error) { + buf := make([]byte, c.maxPacketSize) + n, err := c.Conn.Read(buf) + if err != nil { + return 0, err + } + if n == c.maxPacketSize { + return 0, errors.New("aesgcmConn: packet may be truncated; increase maxPacketSize") + } + if n < 8+c.aead.Overhead() { + return 0, errors.New("aesgcmConn: packet too small") + } + + nonce := [12]byte{} + copy(nonce[:], c.riv[:]) + for i := range 8 { + nonce[4+i] ^= buf[i] + } + + buf, err = c.aead.Open(buf[8:8], nonce[:], buf[8:n], buf[:8]) + if err != nil { + return 0, err + } + + if len(buf) > len(p) { + return 0, io.ErrShortBuffer + } + + copy(p, buf) + return len(buf), nil +} + +// Write encrypts p as a single datagram and writes it to the underlying conn. +// It prepends an 8-byte sequence number used for nonce derivation. +func (c *aesgcmConn) Write(p []byte) (int, error) { + if len(p)+8+c.aead.Overhead() > c.maxPacketSize { + return 0, errors.New("aesgcmConn: packet may be too large; increase maxPacketSize") + } + buf := make([]byte, c.maxPacketSize) + + seq := c.seq.Add(1) - 1 + binary.BigEndian.PutUint64(buf[:8], seq) + + nonce := [12]byte{} + copy(nonce[:], c.wiv[:]) + for i := range 8 { + nonce[4+i] ^= buf[i] + } + + ct := c.aead.Seal(buf[8:8], nonce[:], p, buf[:8]) + buf = buf[:8+len(ct)] + + n, err := c.Conn.Write(buf) + if err != nil { + return 0, err + } + if n != len(buf) { + return 0, io.ErrShortWrite + } + + // Satisfy io.Writer contract: on success, return len(p) bytes written. + return len(p), nil +} diff --git a/aesgcm_conn_test.go b/aesgcm_conn_test.go new file mode 100644 index 0000000..2d8b610 --- /dev/null +++ b/aesgcm_conn_test.go @@ -0,0 +1,207 @@ +package netx_test + +import ( + "bytes" + "io" + "net" + "testing" + "time" + + netx "github.com/pedramktb/go-netx" +) + +// helper to create an AES-GCM protected pair over a framed connection +func newAESPair(t *testing.T) (client net.Conn, server net.Conn) { + t.Helper() + cr, sr := net.Pipe() + t.Cleanup(func() { _ = cr.Close(); _ = sr.Close() }) + + fc := netx.NewFramedConn(cr) + fs := netx.NewFramedConn(sr) + + key := bytes.Repeat([]byte{0x42}, 32) + + var ( + c net.Conn + s net.Conn + ec error + es error + done = make(chan struct{}, 2) + ) + go func() { c, ec = netx.NewAESGCMConn(fc, key); done <- struct{}{} }() + go func() { s, es = netx.NewAESGCMConn(fs, key); done <- struct{}{} }() + <-done + <-done + if ec != nil { + t.Fatalf("client aesgcm: %v", ec) + } + if es != nil { + t.Fatalf("server aesgcm: %v", es) + } + return c, s +} + +func TestAESGCM_Roundtrip(t *testing.T) { + c, s := newAESPair(t) + + msg := []byte("hello secret world") + + got := make([]byte, len(msg)) + done := make(chan error, 1) + go func() { + _, err := io.ReadFull(s, got) + done <- err + }() + time.Sleep(10 * time.Millisecond) + if _, err := c.Write(msg); err != nil { + t.Fatalf("write: %v", err) + } + select { + case err := <-done: + if err != nil { + t.Fatalf("readfull: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout") + } + if !bytes.Equal(got, msg) { + t.Fatalf("mismatch") + } + + // multiple sequential messages should also work + data2 := bytes.Repeat([]byte("x"), 1024) + go func() { _, _ = c.Write(data2) }() + buf := make([]byte, len(data2)) + if _, err := io.ReadFull(s, buf); err != nil { + t.Fatalf("readfull2: %v", err) + } + if !bytes.Equal(buf, data2) { + t.Fatalf("mismatch2") + } +} + +func TestAESGCM_EmptyPayload(t *testing.T) { + c, s := newAESPair(t) + // write an empty datagram concurrently to avoid net.Pipe blocking + doneW := make(chan error, 1) + go func() { + _, err := c.Write(nil) + doneW <- err + }() + // should deliver a zero-length read (keep-alive style) + buf := make([]byte, 8) + n, err := s.Read(buf) + if err != nil { + t.Fatalf("read empty: %v", err) + } + if n != 0 { + t.Fatalf("expected zero-length read, got %d", n) + } + select { + case err := <-doneW: + if err != nil { + t.Fatalf("write empty err: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatalf("write timeout") + } +} + +func TestAESGCM_ShortBufferDropsPacket(t *testing.T) { + c, s := newAESPair(t) + + first := bytes.Repeat([]byte("a"), 128) + second := []byte("ok") + + // send two packets back-to-back + go func() { _, _ = c.Write(first); _, _ = c.Write(second) }() + + // too-small buffer for the first packet should return io.ErrShortBuffer + small := make([]byte, 10) + if n, err := s.Read(small); err != io.ErrShortBuffer || n != 0 { + t.Fatalf("want io.ErrShortBuffer, got n=%d err=%v", n, err) + } + + // the next read should yield the second packet + buf := make([]byte, 16) + n, err := s.Read(buf) + if err != nil { + t.Fatalf("read second: %v", err) + } + if n != len(second) || !bytes.Equal(buf[:n], second) { + t.Fatalf("unexpected second packet: %q", buf[:n]) + } +} + +func TestAESGCM_MaxPacketWrite(t *testing.T) { + // choose a small max packet size so writes exceed it + // Overhead is 8 (seq) + 16 (GCM) + len(plaintext) + // set max to 48; max plaintext allowed ~= 24 + cr, sr := net.Pipe() + t.Cleanup(func() { _ = cr.Close(); _ = sr.Close() }) + fc := netx.NewFramedConn(cr) + fs := netx.NewFramedConn(sr) + key := bytes.Repeat([]byte{0x42}, 32) + var ( + c net.Conn + ec error + es error + done = make(chan struct{}, 2) + ) + go func() { c, ec = netx.NewAESGCMConn(fc, key, netx.WithAESGCMMaxPacket(48)); done <- struct{}{} }() + go func() { _, es = netx.NewAESGCMConn(fs, key, netx.WithAESGCMMaxPacket(48)); done <- struct{}{} }() + <-done + <-done + if ec != nil { + t.Fatalf("client: %v", ec) + } + if es != nil { + t.Fatalf("server: %v", es) + } + + big := bytes.Repeat([]byte("b"), 64) + if _, err := c.Write(big); err == nil { + t.Fatalf("expected write error due to max packet size") + } +} + +func TestAESGCM_DecryptErrorWrongKey(t *testing.T) { + cr, sr := net.Pipe() + t.Cleanup(func() { _ = cr.Close(); _ = sr.Close() }) + fc := netx.NewFramedConn(cr) + fs := netx.NewFramedConn(sr) + + keyA := bytes.Repeat([]byte{0x11}, 32) + keyB := bytes.Repeat([]byte{0x22}, 32) + + var ( + c net.Conn + s net.Conn + ec error + es error + done = make(chan struct{}, 2) + ) + go func() { c, ec = netx.NewAESGCMConn(fc, keyA); done <- struct{}{} }() + go func() { s, es = netx.NewAESGCMConn(fs, keyB); done <- struct{}{} }() + <-done + <-done + if ec != nil { + t.Fatalf("client: %v", ec) + } + if es != nil { + t.Fatalf("server: %v", es) + } + + // write a packet and expect read to fail + writeDone := make(chan error, 1) + go func() { + _, err := c.Write([]byte("test")) + writeDone <- err + }() + + buf := make([]byte, 16) + if _, err := s.Read(buf); err == nil { + t.Fatalf("expected decrypt error") + } + <-writeDone +} diff --git a/buffered_conn.go b/buffered_conn.go index 413a0ae..75bfdde 100644 --- a/buffered_conn.go +++ b/buffered_conn.go @@ -4,7 +4,6 @@ import ( "bufio" "errors" "net" - "time" ) type BufConn interface { @@ -13,32 +12,39 @@ type BufConn interface { } type bufConn struct { - bc net.Conn + net.Conn br *bufio.Reader bw *bufio.Writer } -type bufConnOption func(*bufConn) +type BufConnOption func(*bufConn) + +func WithBufSize(size uint32) BufConnOption { + return func(bc *bufConn) { + bc.br = bufio.NewReaderSize(bc.Conn, int(size)) + bc.bw = bufio.NewWriterSize(bc.Conn, int(size)) + } +} -func WithBufWriterSize(size int) bufConnOption { +func WithBufWriterSize(size uint32) BufConnOption { return func(bc *bufConn) { - bc.bw = bufio.NewWriterSize(bc.bc, size) + bc.bw = bufio.NewWriterSize(bc.Conn, int(size)) } } -func WithBufReaderSize(size int) bufConnOption { +func WithBufReaderSize(size uint32) BufConnOption { return func(bc *bufConn) { - bc.br = bufio.NewReaderSize(bc.bc, size) + bc.br = bufio.NewReaderSize(bc.Conn, int(size)) } } // NewBufConn wraps a net.Conn with buffered reader and writer. // By default, the buffer size is 4KB. Use WithBufWriterSize and WithBufReaderSize to customize the sizes. -func NewBufConn(c net.Conn, opts ...bufConnOption) BufConn { +func NewBufConn(c net.Conn, opts ...BufConnOption) BufConn { bc := &bufConn{ - bc: c, - br: bufio.NewReader(c), - bw: bufio.NewWriter(c), + Conn: c, + br: bufio.NewReader(c), + bw: bufio.NewWriter(c), } for _, opt := range opts { opt(bc) @@ -57,17 +63,12 @@ func (c *bufConn) Close() error { err = errors.Join(err, fErr) } } - if c.bc != nil { - if cErr := c.bc.Close(); cErr != nil { + if c.Conn != nil { + if cErr := c.Conn.Close(); cErr != nil { err = errors.Join(err, cErr) } } return err } -func (c *bufConn) LocalAddr() net.Addr { return c.bc.LocalAddr() } -func (c *bufConn) RemoteAddr() net.Addr { return c.bc.RemoteAddr() } -func (c *bufConn) SetDeadline(t time.Time) error { return c.bc.SetDeadline(t) } -func (c *bufConn) SetReadDeadline(t time.Time) error { return c.bc.SetReadDeadline(t) } -func (c *bufConn) SetWriteDeadline(t time.Time) error { return c.bc.SetWriteDeadline(t) } func (c *bufConn) Flush() error { return c.bw.Flush() } diff --git a/cmd/netx/main.go b/cmd/netx/main.go new file mode 100644 index 0000000..8e6eddb --- /dev/null +++ b/cmd/netx/main.go @@ -0,0 +1,70 @@ +package main + +import ( + "context" + "fmt" + "log/slog" + "os" + "os/signal" + "strings" + "syscall" + + "github.com/spf13/cobra" +) + +func main() { + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + var logLevel string + + cmd := &cobra.Command{ + Use: "netx [command]", + Short: "Small networking toolbox", + Long: "netx is a small networking toolbox.", + SilenceUsage: true, + SilenceErrors: true, + RunE: func(cmd *cobra.Command, args []string) error { + return cmd.Help() + }, + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + lvl, err := parseLogLevel(logLevel) + if err != nil { + return err + } + slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: lvl}))) + return nil + }, + } + + defaultHelp := cmd.HelpFunc() + cmd.SetHelpFunc(func(cmd *cobra.Command, args []string) { + defaultHelp(cmd, args) + fmt.Fprintln(cmd.OutOrStdout()) + fmt.Fprint(cmd.OutOrStdout(), uriFormat) + }) + + cmd.PersistentFlags().StringVar(&logLevel, "log", "info", "log level: debug|info|warn|error") + + cmd.AddCommand(tun(cancel)) + + if err := cmd.ExecuteContext(ctx); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func parseLogLevel(level string) (slog.Level, error) { + switch strings.ToLower(strings.TrimSpace(level)) { + case "", "info": + return slog.LevelInfo, nil + case "debug": + return slog.LevelDebug, nil + case "warn", "warning": + return slog.LevelWarn, nil + case "error": + return slog.LevelError, nil + default: + return 0, fmt.Errorf("invalid log level %q", level) + } +} diff --git a/cmd/netx/tun.go b/cmd/netx/tun.go new file mode 100644 index 0000000..7fe1663 --- /dev/null +++ b/cmd/netx/tun.go @@ -0,0 +1,102 @@ +package main + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net" + "time" + + netx "github.com/pedramktb/go-netx" + "github.com/pedramktb/go-netx/uri" + "github.com/spf13/cobra" +) + +const tunExample = ` netx tun \ + --from "tcp+tls[cert=$(cat server.crt | xxd -p),key=$(cat server.key | xxd -p)]://:9000" \ + --to "udp+aesgcm[key=00112233445566778899aabbccddeeff]://127.0.0.1:5555" +` + +func tun(cancel context.CancelFunc) *cobra.Command { + var from string + var to string + + if cancel == nil { + cancel = func() {} + } + + cmd := &cobra.Command{ + Use: "tun", + Short: "Relay between two endpoints with chainable transforms.", + Long: "tun relays between two endpoints with chainable transforms, this can be used for obfuscation tunnels, proxies, reverse proxies, etc.", + Example: tunExample, + SilenceUsage: true, + SilenceErrors: true, + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + if ctx == nil { + ctx = context.Background() + } + err := runTun(ctx, cancel, from, to) + if err != nil { + return errors.Join(err, cmd.Help()) + } + return nil + }, + } + + cmd.Flags().StringVar(&from, "from", "", "") + cmd.Flags().StringVar(&to, "to", "", "") + + _ = cmd.MarkFlagRequired("from") + _ = cmd.MarkFlagRequired("to") + + return cmd +} + +func runTun(ctx context.Context, cancel context.CancelFunc, from, to string) error { + var fromURI, toURI uri.URI + fromURI.Listener = true + if err := fromURI.UnmarshalText([]byte(from)); err != nil { + return fmt.Errorf("parse --from: %w", err) + } + if err := toURI.UnmarshalText([]byte(to)); err != nil { + return fmt.Errorf("parse --to: %w", err) + } + + ln, err := fromURI.Listen(ctx) + if err != nil { + return err + } + defer ln.Close() + + tm := netx.TunMaster[struct{}]{} + + tm.SetRoute(struct{}{}, func(ctx context.Context, conn net.Conn) (bool, context.Context, netx.Tun) { + pconn, err := toURI.Dial(ctx) + if err != nil { + slog.Error("dial peer", "err", err) + _ = conn.Close() + return false, ctx, netx.Tun{} + } + + return true, ctx, netx.Tun{Conn: conn, Peer: pconn} + }) + + go func() { + if err := tm.Serve(ctx, ln); err != nil && !errors.Is(err, netx.ErrServerClosed) { + slog.Error("serve error", "err", err) + cancel() + } + }() + + slog.Info("netx tun started", "listen", ln.Addr().String(), "from", from, "to", to) + + <-ctx.Done() + shutdownCtx, stop := context.WithTimeout(context.Background(), 3*time.Second) + defer stop() + _ = tm.Shutdown(shutdownCtx) + + return nil +} diff --git a/cmd/netx/uri.go b/cmd/netx/uri.go new file mode 100644 index 0000000..87e2d24 --- /dev/null +++ b/cmd/netx/uri.go @@ -0,0 +1,42 @@ +package main + +const uriFormat = `URI Format: + +[layer1param1key=layer1param1value,layer1param2key=layer1param2value,...]++...://
+ + Examples: + tcp+tls[cert=$(cat server.crt | xxd -p),key=$(cat server.key | xxd -p)]://:9000 + tcp+tls[cert=$(cat client.crt | xxd -p)]+buffered[size=8192]+framed[maxsize=4096]+aesgcm[key=00112233445566778899aabbccddeeff]://example.com:9443 + + Supported transports: + - tcp: TCP listener or dialer + - udp: UDP listener or dialer + + Supported layers: + - framed: length-prefixed frames for transports or layers that need packet semantics over streams. + params: maxsize (optional, defaults to 32768) + - buffered: buffered read/write for better performance when using framing. + params: size (optional, defaults to 4096) + - aesgcm: AES-GCM encryption. A passive 12-byte handshake exchanges IVs. + params: key, maxpacket (optional, defaults to 32768) + - ssh: SSH tunneling via "direct-tcpip" channels. + server params: key, pass (optional), pubkey (optional, required if no pass) + client options: pubkey, pass (optional), key (optional, required if no pass) + - tls: Transport Layer Security + server params: key, cert + client params: cert (optional, for SPKI pinning), servername (required if cert not provided) + - utls: TLS with client fingerprint camouflage via uTLS (github.com/refraction-networking/utls) + client params: cert (optional, for SPKI pinning), servername (required if cert not provided), hello (optional, e.g. chrome, firefox, ios, android, safari, edge, randomized) + - dtls: Datagram Transport Layer Security + server params: key, cert + client params: cert (optional, for SPKI pinning), servername (required if cert not provided) + - tlspsk: TLS with pre-shared key. Cipher is TLS_DHE_PSK_WITH_AES_256_CBC_SHA. + params: key + - dtlspsk: DTLS with pre-shared key. Cipher is TLS_PSK_WITH_AES_128_GCM_SHA256. + params: key + + Notes: + - All passwords, keys and certificates must be provided as hex-encoded strings. + - When using 'cert' for client-side TLS/uTLS/DTLS, default validation is disabled and a manual SPKI (SubjectPublicKeyInfo) hash comparison is performed + against the provided certificate. This is certificate pinning and will fail if the server presents a different key. + - SSH server must accept "direct-tcpip" channels (most do by default). +` diff --git a/dial.go b/dial.go new file mode 100644 index 0000000..5f0c7d5 --- /dev/null +++ b/dial.go @@ -0,0 +1,77 @@ +package netx + +import ( + "context" + "net" + "strings" + + pudp "github.com/pion/transport/v3/udp" +) + +type listenCfg struct { + net.ListenConfig + packet pudp.ListenConfig +} + +type ListenOption func(*listenCfg) + +func WithListenConfig(cfg net.ListenConfig) ListenOption { + return func(lc *listenCfg) { + lc.ListenConfig = cfg + } +} + +func WithPacketListenConfig(cfg pudp.ListenConfig) ListenOption { + return func(lc *listenCfg) { + lc.packet = cfg + } +} + +func Listen(ctx context.Context, network, addr string, opts ...ListenOption) (net.Listener, error) { + cfg := &listenCfg{} + for _, o := range opts { + o(cfg) + } + switch strings.Split(network, ":")[0] { + case "udp", "udp4", "udp6": + uaddr, err := net.ResolveUDPAddr(network, addr) + if err != nil { + return nil, err + } + return cfg.packet.Listen(network, uaddr) + // case "ip", "ip4", "ip6": + // iaddr, err := net.ResolveIPAddr(network, addr) + // if err != nil { + // return nil, err + // } + // return (&ip.ListenConfig{ + // Backlog: cfg.packet.Backlog, + // AcceptFilter: cfg.packet.AcceptFilter, + // ReadBufferSize: cfg.packet.ReadBufferSize, + // WriteBufferSize: cfg.packet.WriteBufferSize, + // Batch: cfg.packet.Batch, + // }).Listen(network, iaddr) + default: + return cfg.Listen(ctx, network, addr) + } +} + +type dialCfg struct { + net.Dialer +} + +type DialOption func(*dialCfg) + +func WithDialConfig(cfg net.Dialer) DialOption { + return func(dc *dialCfg) { + dc.Dialer = cfg + } +} + +func Dial(ctx context.Context, network, addr string, opts ...DialOption) (net.Conn, error) { + cfg := &dialCfg{} + for _, o := range opts { + o(cfg) + } + return cfg.DialContext(ctx, network, addr) +} diff --git a/framed_conn.go b/framed_conn.go index d60d3ad..b0181a2 100644 --- a/framed_conn.go +++ b/framed_conn.go @@ -6,34 +6,33 @@ import ( "io" "net" "sync" - "time" ) var ErrFrameTooLarge = errors.New("framedConn: frame too large") type framedConn struct { - bc net.Conn + net.Conn maxFrameSize int pending []byte rmu, wmu sync.Mutex } -type framedConnOption func(*framedConn) +type FramedConnOption func(*framedConn) -func WithMaxFrameSize(size int) framedConnOption { +func WithMaxFrameSize(size uint32) FramedConnOption { return func(c *framedConn) { - c.maxFrameSize = size + c.maxFrameSize = int(size) } } // NewFramedConn wraps a net.Conn with a simple length-prefixed framing protocol. // Each frame is prefixed with a 4-byte big-endian unsigned integer indicating the length of the frame. // If the frame size exceeds maxFrameSize, Read will return ErrFrameTooLarge. -// The default maxFrameSize is 16KB. -func NewFramedConn(c net.Conn, opts ...framedConnOption) net.Conn { +// The default maxFrameSize is 32KB. +func NewFramedConn(c net.Conn, opts ...FramedConnOption) net.Conn { fc := &framedConn{ - bc: c, - maxFrameSize: 1 << 14, // 16KB default max frame size + Conn: c, + maxFrameSize: 32 * 1024, // 32KB default max frame size } for _, opt := range opts { opt(fc) @@ -53,7 +52,7 @@ func (c *framedConn) Read(p []byte) (int, error) { } var hdr [4]byte - if _, err := io.ReadFull(c.bc, hdr[:]); err != nil { + if _, err := io.ReadFull(c.Conn, hdr[:]); err != nil { return 0, err } n := int(binary.BigEndian.Uint32(hdr[:])) @@ -66,12 +65,12 @@ func (c *framedConn) Read(p []byte) (int, error) { } if len(p) >= n { - _, err := io.ReadFull(c.bc, p[:n]) + _, err := io.ReadFull(c.Conn, p[:n]) return n, err } buf := make([]byte, n) - if _, err := io.ReadFull(c.bc, buf); err != nil { + if _, err := io.ReadFull(c.Conn, buf); err != nil { return 0, err } w := copy(p, buf) @@ -86,27 +85,20 @@ func (c *framedConn) Write(p []byte) (int, error) { var hdr [4]byte binary.BigEndian.PutUint32(hdr[:], uint32(len(p))) - if _, err := c.bc.Write(hdr[:]); err != nil { + if _, err := c.Conn.Write(hdr[:]); err != nil { return 0, err } if len(p) == 0 { return 0, nil } - if _, err := c.bc.Write(p); err != nil { + if _, err := c.Conn.Write(p); err != nil { return 0, err } // If the underlying layer is buffered and implements Flush, flush now to coalesce header+payload. - if fw, ok := c.bc.(BufConn); ok { + if fw, ok := c.Conn.(BufConn); ok { if err := fw.Flush(); err != nil { return 0, err } } return len(p), nil } - -func (c *framedConn) Close() error { return c.bc.Close() } -func (c *framedConn) LocalAddr() net.Addr { return c.bc.LocalAddr() } -func (c *framedConn) RemoteAddr() net.Addr { return c.bc.RemoteAddr() } -func (c *framedConn) SetDeadline(t time.Time) error { return c.bc.SetDeadline(t) } -func (c *framedConn) SetReadDeadline(t time.Time) error { return c.bc.SetReadDeadline(t) } -func (c *framedConn) SetWriteDeadline(t time.Time) error { return c.bc.SetWriteDeadline(t) } diff --git a/go.mod b/go.mod index 3482f60..299e482 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,27 @@ module github.com/pedramktb/go-netx go 1.25.1 + +require ( + github.com/pion/dtls/v3 v3.0.7 + github.com/pion/transport/v3 v3.0.8 + github.com/raff/tls-ext v1.0.0 + github.com/raff/tls-psk v1.0.0 + github.com/refraction-networking/utls v1.8.0 + github.com/spf13/cobra v1.8.1 + github.com/stretchr/testify v1.11.1 + golang.org/x/crypto v0.43.0 + golang.org/x/net v0.46.0 +) + +require ( + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/pion/logging v0.2.4 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/sys v0.37.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index e69de29..8917f5c 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,50 @@ +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q= +github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/transport/v3 v3.0.8 h1:oI3myyYnTKUSTthu/NZZ8eu2I5sHbxbUNNFW62olaYc= +github.com/pion/transport/v3 v3.0.8/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/raff/tls-ext v1.0.0 h1:72EP1QYiXxTpTt3zWLi6YefLDnXFHTvnxog/H6COwj4= +github.com/raff/tls-ext v1.0.0/go.mod h1:HEICLTE9Cp+MmIiJ9iZnNj4VYxkUKjdpEml9ersDBbs= +github.com/raff/tls-psk v1.0.0 h1:cLGFfZCxtkBpsie1TzACuYHJHEj0VYRN1dCv+lPRPxo= +github.com/raff/tls-psk v1.0.0/go.mod h1:SUNKszL9dnQq9lkqg7P34Qrg9FuCiHcTKRVqdIyHbF0= +github.com/refraction-networking/utls v1.8.0 h1:L38krhiTAyj9EeiQQa2sg+hYb4qwLCqdMcpZrRfbONE= +github.com/refraction-networking/utls v1.8.0/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= +golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= +golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q= +golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/tools/e2e/tcp_client/main.go b/internal/tools/e2e/tcp_client/main.go new file mode 100644 index 0000000..9ee9179 --- /dev/null +++ b/internal/tools/e2e/tcp_client/main.go @@ -0,0 +1,38 @@ +//go:build e2e +// +build e2e + +package main + +import ( + "bufio" + "fmt" + "net" + "os" + "time" +) + +func main() { + if len(os.Args) < 3 { + fmt.Println("usage: tcp_client host:port message") + os.Exit(2) + } + addr, msg := os.Args[1], os.Args[2] + c, err := net.DialTimeout("tcp", addr, 2*time.Second) + if err != nil { + fmt.Println("dial error:", err) + os.Exit(1) + } + defer c.Close() + _ = c.SetDeadline(time.Now().Add(3 * time.Second)) + if _, err := c.Write([]byte(msg)); err != nil { + fmt.Println("write error:", err) + os.Exit(1) + } + r := bufio.NewReader(c) + buf := make([]byte, len(msg)) + if _, err := r.Read(buf); err != nil { + fmt.Println("read error:", err) + os.Exit(1) + } + fmt.Print(string(buf)) +} diff --git a/internal/tools/e2e/tcp_echo/main.go b/internal/tools/e2e/tcp_echo/main.go new file mode 100644 index 0000000..824c5f2 --- /dev/null +++ b/internal/tools/e2e/tcp_echo/main.go @@ -0,0 +1,47 @@ +//go:build e2e +// +build e2e + +package main + +import ( + "bufio" + "io" + "log" + "net" + "os" +) + +func main() { + addr := "127.0.0.1:28080" + if len(os.Args) > 1 { + addr = os.Args[1] + } + ln, err := net.Listen("tcp", addr) + if err != nil { + log.Fatal(err) + } + log.Printf("tcp echo listening on %s", ln.Addr()) + for { + c, err := ln.Accept() + if err != nil { + log.Fatal(err) + } + go func(conn net.Conn) { + defer conn.Close() + r := bufio.NewReader(conn) + buf := make([]byte, 4096) + for { + n, err := r.Read(buf) + if n > 0 { + _, _ = conn.Write(buf[:n]) + } + if err != nil { + if err != io.EOF { + log.Println("read err:", err) + } + return + } + } + }(c) + } +} diff --git a/internal/tools/e2e/udp_client/main.go b/internal/tools/e2e/udp_client/main.go new file mode 100644 index 0000000..126a56f --- /dev/null +++ b/internal/tools/e2e/udp_client/main.go @@ -0,0 +1,42 @@ +//go:build e2e +// +build e2e + +package main + +import ( + "fmt" + "net" + "os" + "time" +) + +func main() { + if len(os.Args) < 3 { + fmt.Println("usage: udp_client host:port message") + os.Exit(2) + } + addr, msg := os.Args[1], os.Args[2] + raddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + fmt.Println("resolve error:", err) + os.Exit(1) + } + c, err := net.DialUDP("udp", nil, raddr) + if err != nil { + fmt.Println("dial error:", err) + os.Exit(1) + } + defer c.Close() + _ = c.SetDeadline(time.Now().Add(3 * time.Second)) + if _, err := c.Write([]byte(msg)); err != nil { + fmt.Println("write error:", err) + os.Exit(1) + } + buf := make([]byte, 65535) + n, _, err := c.ReadFrom(buf) + if err != nil { + fmt.Println("read error:", err) + os.Exit(1) + } + fmt.Print(string(buf[:n])) +} diff --git a/internal/tools/e2e/udp_echo/main.go b/internal/tools/e2e/udp_echo/main.go new file mode 100644 index 0000000..0a5a61e --- /dev/null +++ b/internal/tools/e2e/udp_echo/main.go @@ -0,0 +1,36 @@ +//go:build e2e +// +build e2e + +package main + +import ( + "log" + "net" + "os" + "time" +) + +func main() { + addr := "127.0.0.1:28081" + if len(os.Args) > 1 { + addr = os.Args[1] + } + a, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + log.Fatal(err) + } + conn, err := net.ListenUDP("udp", a) + if err != nil { + log.Fatal(err) + } + log.Printf("udp echo listening on %s", conn.LocalAddr()) + buf := make([]byte, 65535) + for { + n, ra, err := conn.ReadFromUDP(buf) + if err != nil { + log.Fatal(err) + } + conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + _, _ = conn.WriteToUDP(buf[:n], ra) + } +} diff --git a/ssh_conn.go b/ssh_conn.go new file mode 100644 index 0000000..d8e2c8f --- /dev/null +++ b/ssh_conn.go @@ -0,0 +1,72 @@ +package netx + +import ( + "errors" + "net" + "time" + + ssh "golang.org/x/crypto/ssh" +) + +type sshConn struct { + ssh.Channel + sshConn ssh.Conn + bc net.Conn +} + +func (s *sshConn) LocalAddr() net.Addr { return s.sshConn.LocalAddr() } +func (s *sshConn) RemoteAddr() net.Addr { return s.sshConn.RemoteAddr() } +func (s *sshConn) SetDeadline(t time.Time) error { return s.bc.SetDeadline(t) } +func (s *sshConn) SetReadDeadline(t time.Time) error { return s.bc.SetReadDeadline(t) } +func (s *sshConn) SetWriteDeadline(t time.Time) error { return s.bc.SetWriteDeadline(t) } +func (s *sshConn) Close() error { + return errors.Join(s.Channel.Close(), s.sshConn.Close()) +} +func (s *sshConn) CloseWrite() error { + err := s.Channel.CloseWrite() + if bcCloseWrite, ok := s.bc.(interface{ CloseWrite() error }); ok { + err = errors.Join(err, bcCloseWrite.CloseWrite()) + } + return err +} + +func NewSSHServerConn(bc net.Conn, cfg *ssh.ServerConfig) (net.Conn, error) { + svConn, sshChans, sshReqs, err := ssh.NewServerConn(bc, cfg) + if err != nil { + return nil, err + } + go ssh.DiscardRequests(sshReqs) + for newCh := range sshChans { + switch newCh.ChannelType() { + case "direct-tcpip": + ch, reqs, err := newCh.Accept() + if err != nil { + _ = svConn.Close() + return nil, err + } + go ssh.DiscardRequests(reqs) + return &sshConn{Channel: ch, sshConn: svConn, bc: bc}, nil + default: + _ = newCh.Reject(ssh.UnknownChannelType, "unsupported channel type") + return nil, errors.New("no supported ssh channel opened by client") + } + + } + _ = svConn.Close() + return nil, errors.New("no ssh channel opened by client") +} + +func NewSSHClientConn(bc net.Conn, cfg *ssh.ClientConfig) (net.Conn, error) { + clConn, _, sshReqs, err := ssh.NewClientConn(bc, "", cfg) + if err != nil { + return nil, err + } + go ssh.DiscardRequests(sshReqs) + ch, reqs, err := clConn.OpenChannel("direct-tcpip", nil) + if err != nil { + _ = clConn.Close() + return nil, err + } + go ssh.DiscardRequests(reqs) + return &sshConn{Channel: ch, sshConn: clConn, bc: bc}, nil +} diff --git a/tun_e2e_udp_tcp_test.go b/tun_udp_tcp_int_test.go similarity index 99% rename from tun_e2e_udp_tcp_test.go rename to tun_udp_tcp_int_test.go index 423df52..0fe1195 100644 --- a/tun_e2e_udp_tcp_test.go +++ b/tun_udp_tcp_int_test.go @@ -77,7 +77,7 @@ func newUDPPair(t *testing.T) (*net.UDPConn, *net.UDPConn) { return a, b } -func TestE2E_UDP_over_TCP_TunMasters(t *testing.T) { +func TestInt_UDP_over_TCP_TunMasters(t *testing.T) { t.Parallel() ctx := context.Background() logger := &memLogger{} @@ -182,7 +182,7 @@ func TestE2E_UDP_over_TCP_TunMasters(t *testing.T) { // - TLS route: expects a tls.Conn and forwards to server UDPTLS peer // - Plain route: handles non-TLS and forwards to server UDPPlain peer // Two client tunnels are created: one plain over framed stream, one framed over TLS. -func TestE2E_TunMasterRouting_PlainAndTLS(t *testing.T) { +func TestInt_TunMasterRouting_PlainAndTLS(t *testing.T) { t.Parallel() ctx := context.Background() logger := &memLogger{} diff --git a/uri/dial.go b/uri/dial.go new file mode 100644 index 0000000..d3c5111 --- /dev/null +++ b/uri/dial.go @@ -0,0 +1,44 @@ +package uri + +import ( + "context" + "fmt" + "net" + + "github.com/pedramktb/go-netx" +) + +type listener struct { + net.Listener + uri *URI +} + +func (l *listener) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + return l.uri.Layers.Wrap(c) +} + +func (u URI) Listen(ctx context.Context, opts ...netx.ListenOption) (net.Listener, error) { + if !u.Listener { + return nil, fmt.Errorf("uri: cannot listen on a non-listener URI") + } + l, err := netx.Listen(ctx, u.Scheme.Transport.String(), u.Addr, opts...) + if err != nil { + return nil, fmt.Errorf("error listening on %s://%s: %w", u.Scheme.Transport.String(), u.Addr, err) + } + return &listener{l, &u}, nil +} + +func (u URI) Dial(ctx context.Context, opts ...netx.DialOption) (net.Conn, error) { + if u.Listener { + return nil, fmt.Errorf("uri: cannot dial on a listener URI") + } + c, err := netx.Dial(ctx, u.Scheme.Transport.String(), u.Addr, opts...) + if err != nil { + return nil, fmt.Errorf("error dialing %s://%s: %w", u.Scheme.Transport.String(), u.Addr, err) + } + return u.Layers.Wrap(c) +} diff --git a/uri/layer.go b/uri/layer.go new file mode 100644 index 0000000..ca8ec17 --- /dev/null +++ b/uri/layer.go @@ -0,0 +1,581 @@ +package uri + +import ( + "bytes" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "net" + "strconv" + "strings" + + "github.com/pedramktb/go-netx" + "github.com/pion/dtls/v3" + dtlsnet "github.com/pion/dtls/v3/pkg/net" + tlswithpks "github.com/raff/tls-ext" + tlspks "github.com/raff/tls-psk" + utls "github.com/refraction-networking/utls" + "golang.org/x/crypto/ssh" +) + +type Layers struct { + Listener bool + Layers []Layer +} + +func (ls Layers) Wrap(conn net.Conn) (net.Conn, error) { + var err error + for _, l := range ls.Layers { + conn, err = l.Wrap(conn) + if err != nil { + return nil, fmt.Errorf("wrap %q: %w", l.String(), err) + } + } + return conn, nil +} + +func (ls Layers) String() string { + strs := make([]string, len(ls.Layers)) + for i, l := range ls.Layers { + strs[i] = l.String() + } + return strings.Join(strs, "+") +} + +func (ls Layers) MarshalText() ([]byte, error) { + return []byte(ls.String()), nil +} + +func (ls *Layers) UnmarshalText(text []byte) error { + parts := strings.Split(string(text), "+") + ls.Layers = make([]Layer, len(parts)) + for i := range parts { + ls.Layers[i].Listener = ls.Listener + if err := ls.Layers[i].UnmarshalText([]byte(parts[i])); err != nil { + return err + } + } + + return nil +} + +func (ls Layers) MarshalJSON() ([]byte, error) { + return json.Marshal(ls.Layers) +} + +func (ls *Layers) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &ls.Layers) +} + +type Layer struct { + Listener bool + Prot string + Params map[string]string + wrap func(net.Conn) (net.Conn, error) +} + +func (l Layer) Wrap(conn net.Conn) (net.Conn, error) { + if l.wrap == nil { + return conn, nil + } + return l.wrap(conn) +} + +func (l Layer) String() string { + pairs := make([]string, 0, len(l.Params)) + for k, v := range l.Params { + pairs = append(pairs, k+"="+v) + } + if len(pairs) > 0 { + return fmt.Sprintf("%s[%s]", l.Prot, strings.Join(pairs, ",")) + } + return l.Prot +} + +func (l Layer) MarshalText() ([]byte, error) { + return []byte(l.String()), nil +} + +func (l *Layer) UnmarshalText(text []byte) error { + str := string(text) + + l.Prot = strings.ToLower(strings.TrimSpace(str)) + l.Params = map[string]string{} + if idx := strings.Index(str, "["); idx != -1 { + if !strings.HasSuffix(str, "]") { + return fmt.Errorf("uri: missing ']' in layer %q", str) + } + l.Prot = strings.ToLower(strings.TrimSpace(str[:idx])) + for pair := range strings.SplitSeq(str[idx+1:len(str)-1], ",") { + kv := strings.SplitN(pair, "=", 2) + if len(kv) != 2 { + return fmt.Errorf("uri: invalid parameter %q", pair) + } + key := strings.ToLower(strings.TrimSpace(kv[0])) + value := strings.TrimSpace(kv[1]) + if key == "" { + return fmt.Errorf("uri: empty parameter key") + } + l.Params[key] = value + } + } + + switch l.Prot { + case "framed": + opts := []netx.FramedConnOption{} + for key, value := range l.Params { + switch key { + case "maxsize": + maxSize, err := strconv.ParseUint(value, 10, 31) + if err != nil { + return fmt.Errorf("uri: invalid framed maxsize parameter %q: %w", value, err) + } + opts = append(opts, netx.WithMaxFrameSize(uint32(maxSize))) + default: + return fmt.Errorf("uri: unknown framed parameter %q", key) + } + } + l.wrap = func(c net.Conn) (net.Conn, error) { + return netx.NewFramedConn(c, opts...), nil + } + case "buffered": + opts := []netx.BufConnOption{} + for key, value := range l.Params { + switch key { + case "size": + size, err := strconv.ParseUint(value, 10, 31) + if err != nil { + return fmt.Errorf("uri: invalid buffered size parameter %q: %w", value, err) + } + opts = append(opts, netx.WithBufSize(uint32(size))) + default: + return fmt.Errorf("uri: unknown buffered parameter %q", key) + } + } + l.wrap = func(c net.Conn) (net.Conn, error) { + return netx.NewBufConn(c, opts...), nil + } + case "aesgcm": + aeskey := []byte{} + opts := []netx.AESGCMOption{} + for key, value := range l.Params { + switch key { + case "key": + var err error + aeskey, err = hex.DecodeString(value) + if err != nil { + return fmt.Errorf("uri: invalid aesgcm key parameter: %w", err) + } + if len(aeskey) != 16 && len(aeskey) != 24 && len(aeskey) != 32 { + return fmt.Errorf("uri: invalid aesgcm key size %d", len(aeskey)) + } + case "maxpacket": + maxPacket, err := strconv.ParseUint(value, 10, 31) + if err != nil { + return fmt.Errorf("uri: invalid aesgcm maxpacket parameter %q: %w", value, err) + } + opts = append(opts, netx.WithAESGCMMaxPacket(uint32(maxPacket))) + default: + return fmt.Errorf("uri: unknown aesgcm parameter %q", key) + } + } + if len(aeskey) == 0 { + return fmt.Errorf("uri: missing aesgcm key parameter") + } + l.wrap = func(c net.Conn) (net.Conn, error) { + return netx.NewAESGCMConn(c, aeskey, opts...) + } + case "ssh": + var pass string + var sshkey ssh.Signer // Host key for server, private key for client + var pubkey ssh.PublicKey + for key, value := range l.Params { + switch key { + case "pass": + pass = value + case "key": + pemkey, err := hex.DecodeString(value) + if err != nil { + return fmt.Errorf("uri: invalid ssh key parameter: %w", err) + } + sshkey, err = ssh.ParsePrivateKey(pemkey) + if err != nil { + return fmt.Errorf("uri: invalid ssh private key: %w", err) + } + case "pubkey": + azkey, err := hex.DecodeString(value) + if err != nil { + return fmt.Errorf("uri: invalid ssh pubkey parameter: %w", err) + } + pubkey, _, _, _, err = ssh.ParseAuthorizedKey(azkey) + if err != nil { + return fmt.Errorf("uri: invalid ssh public key: %w", err) + } + default: + return fmt.Errorf("uri: unknown ssh parameter %q", key) + } + } + if l.Listener { + cfg := &ssh.ServerConfig{} + if sshkey == nil { + return fmt.Errorf("uri: ssh server requires key parameter") + } + cfg.AddHostKey(sshkey) + if pubkey != nil { + cfg.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + if bytes.Equal(key.Marshal(), pubkey.Marshal()) { + return nil, nil + } + return nil, fmt.Errorf("uri: ssh public key mismatch") + } + } + if pass != "" { + cfg.PasswordCallback = func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + if pass == string(password) { + return nil, nil + } + return nil, fmt.Errorf("uri: ssh password mismatch") + } + } + if cfg.PublicKeyCallback == nil && cfg.PasswordCallback == nil { + return fmt.Errorf("uri: ssh server requires pubkey or pass parameter") + } + l.wrap = func(c net.Conn) (net.Conn, error) { + return netx.NewSSHServerConn(c, cfg) + } + } else { + cfg := &ssh.ClientConfig{} + if pubkey == nil { + return fmt.Errorf("uri: ssh client requires pubkey parameter") + } + cfg.HostKeyCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error { + if bytes.Equal(key.Marshal(), pubkey.Marshal()) { + return nil + } + return fmt.Errorf("uri: ssh host key mismatch") + } + if sshkey != nil { + cfg.Auth = append(cfg.Auth, ssh.PublicKeys(sshkey)) + } + if pass != "" { + cfg.Auth = append(cfg.Auth, ssh.Password(pass)) + } + if len(cfg.Auth) == 0 { + return fmt.Errorf("uri: ssh client requires key or pass parameter") + } + l.wrap = func(c net.Conn) (net.Conn, error) { + return netx.NewSSHClientConn(c, cfg) + } + } + case "tls": + var certKey, cert []byte + cfg := &tls.Config{ + MinVersion: tls.VersionTLS13, + MaxVersion: tls.VersionTLS13, + } + for key, value := range l.Params { + switch key { + case "key": + var err error + certKey, err = hex.DecodeString(value) + if err != nil { + return fmt.Errorf("uri: invalid tls key parameter: %w", err) + } + case "cert": + var err error + cert, err = hex.DecodeString(value) + if err != nil { + return fmt.Errorf("uri: invalid tls cert parameter: %w", err) + } + case "servername": + cfg.ServerName = value + default: + return fmt.Errorf("uri: unknown tls parameter %q", key) + } + } + if l.Listener { + if cert == nil || certKey == nil { + return fmt.Errorf("uri: tls server requires cert and key parameters") + } + certificate, err := tls.X509KeyPair(cert, certKey) + if err != nil { + return fmt.Errorf("uri: invalid tls certificate: %w", err) + } + cfg.Certificates = []tls.Certificate{certificate} + l.wrap = func(c net.Conn) (net.Conn, error) { + return tls.Server(c, cfg), nil + } + } else { + if certKey != nil { + return fmt.Errorf("uri: tls client does not support key parameter") + } + if cert != nil { + var err error + cfg.InsecureSkipVerify = true + cfg.VerifyPeerCertificate, err = spkiVerifier(cert) + if err != nil { + return fmt.Errorf("uri: invalid tls cert parameter: %w", err) + } + } + if cfg.ServerName == "" && cert == nil { + return fmt.Errorf("uri: tls client requires servername or cert parameter") + } + l.wrap = func(c net.Conn) (net.Conn, error) { + return tls.Client(c, cfg), nil + } + } + case "utls": + if l.Listener { + return errors.New("uri: utls is exclusive to clients, use tls for servers instead") + } + var cert []byte + cfg := &utls.Config{ + MinVersion: tls.VersionTLS13, + MaxVersion: tls.VersionTLS13, + } + id := utls.HelloChrome_Auto + for key, value := range l.Params { + switch key { + case "cert": + var err error + cert, err = hex.DecodeString(value) + if err != nil { + return fmt.Errorf("uri: invalid utls cert parameter: %w", err) + } + case "servername": + cfg.ServerName = value + case "hello": + switch strings.ToLower(value) { + case "chrome": + id = utls.HelloChrome_Auto + case "firefox": + id = utls.HelloFirefox_Auto + case "ios": + id = utls.HelloIOS_Auto + case "android": + id = utls.HelloAndroid_11_OkHttp + case "safari": + id = utls.HelloSafari_Auto + case "edge": + id = utls.HelloEdge_Auto + case "randomized": + id = utls.HelloRandomizedALPN + case "randomizednoalpn": + id = utls.HelloRandomized + default: + return fmt.Errorf("unknown utls hello profile %q", value) + } + default: + return fmt.Errorf("uri: unknown utls parameter %q", key) + } + } + if cert != nil { + var err error + cfg.InsecureSkipVerify = true + cfg.VerifyPeerCertificate, err = spkiVerifier(cert) + if err != nil { + return fmt.Errorf("uri: invalid utls cert parameter: %w", err) + } + } + if cfg.ServerName == "" && cert == nil { + return fmt.Errorf("uri: utls client requires servername or cert parameter") + } + l.wrap = func(c net.Conn) (net.Conn, error) { + uc := utls.UClient(c, cfg, id) + return uc, uc.Handshake() + } + case "dtls": + var certKey, cert []byte + cfg := &dtls.Config{} + for key, value := range l.Params { + switch key { + case "key": + var err error + certKey, err = hex.DecodeString(value) + if err != nil { + return fmt.Errorf("uri: invalid dtls key parameter: %w", err) + } + case "cert": + var err error + cert, err = hex.DecodeString(value) + if err != nil { + return fmt.Errorf("uri: invalid dtls cert parameter: %w", err) + } + case "servername": + cfg.ServerName = value + default: + return fmt.Errorf("uri: unknown dtls parameter %q", key) + } + } + if l.Listener { + if cert == nil || certKey == nil { + return fmt.Errorf("uri: dtls server requires cert and key parameters") + } + certificate, err := tls.X509KeyPair(cert, certKey) + if err != nil { + return fmt.Errorf("uri: invalid dtls certificate: %w", err) + } + cfg.Certificates = []tls.Certificate{certificate} + l.wrap = func(c net.Conn) (net.Conn, error) { + return dtls.Server(dtlsnet.PacketConnFromConn(c), c.RemoteAddr(), cfg) + } + } else { + if certKey != nil { + return fmt.Errorf("uri: dtls client does not support key parameter") + } + if cert != nil { + var err error + cfg.InsecureSkipVerify = true + cfg.VerifyPeerCertificate, err = spkiVerifier(cert) + if err != nil { + return fmt.Errorf("uri: invalid dtls cert parameter: %w", err) + } + } + if cfg.ServerName == "" && cert == nil { + return fmt.Errorf("uri: dtls client requires servername or cert parameter") + } + l.wrap = func(c net.Conn) (net.Conn, error) { + return dtls.Client(dtlsnet.PacketConnFromConn(c), c.RemoteAddr(), cfg) + } + } + case "tlspsk": + var identity string + var psk []byte + for key, value := range l.Params { + switch key { + case "key": + var err error + psk, err = hex.DecodeString(value) + if err != nil { + return fmt.Errorf("uri: invalid tlspsk key parameter: %w", err) + } + case "identity": + identity = value + default: + return fmt.Errorf("uri: unknown tlspsk parameter %q", key) + } + } + if len(psk) == 0 { + return fmt.Errorf("uri: missing tlspsk key parameter") + } + if !l.Listener && identity == "" { + return fmt.Errorf("uri: tlspsk client requires identity parameter") + } + cfg := &tlswithpks.Config{ + MinVersion: tls.VersionTLS12, + MaxVersion: tls.VersionTLS12, + Extra: tlspks.PSKConfig{ + GetIdentity: func() string { return identity }, + GetKey: func(identity string) ([]byte, error) { return psk, nil }, + }, + CipherSuites: []uint16{tlspks.TLS_PSK_WITH_AES_256_CBC_SHA}, + InsecureSkipVerify: true, + } + if l.Listener { + // Provide dummy Certificates to make tlspsk happy on server side + cfg.Certificates = dummyCert() + l.wrap = func(c net.Conn) (net.Conn, error) { + return tlswithpks.Server(c, cfg), nil + } + } else { + l.wrap = func(c net.Conn) (net.Conn, error) { + return tlswithpks.Client(c, cfg), nil + } + } + case "dtlspsk": + var identity string + var psk []byte + for key, value := range l.Params { + switch key { + case "key": + var err error + psk, err = hex.DecodeString(value) + if err != nil { + return fmt.Errorf("uri: invalid dtlspsk key parameter: %w", err) + } + case "identity": + identity = value + default: + return fmt.Errorf("uri: unknown dtlspsk parameter %q", key) + } + } + if len(psk) == 0 { + return fmt.Errorf("uri: missing dtlspsk key parameter") + } + if !l.Listener && identity == "" { + return fmt.Errorf("uri: dtlspsk client requires identity parameter") + } + cfg := &dtls.Config{ + PSK: func(hint []byte) ([]byte, error) { + return psk, nil + }, + PSKIdentityHint: []byte(identity), + CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_GCM_SHA256}, + InsecureSkipVerify: true, + } + if l.Listener { + l.wrap = func(c net.Conn) (net.Conn, error) { + return dtls.Server(dtlsnet.PacketConnFromConn(c), c.RemoteAddr(), cfg) + } + } else { + l.wrap = func(c net.Conn) (net.Conn, error) { + return dtls.Client(dtlsnet.PacketConnFromConn(c), c.RemoteAddr(), cfg) + } + } + } + return nil +} + +func spkiVerifier(certPEM []byte) (func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error, error) { + block, _ := pem.Decode(certPEM) + if block == nil || block.Type != "CERTIFICATE" { + return nil, fmt.Errorf("uri: invalid PEM certificate") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("uri: parse x509 certificate: %w", err) + } + spkiHash := sha256.New().Sum(cert.RawSubjectPublicKeyInfo) + return func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + for _, rawCert := range rawCerts { + c, err := x509.ParseCertificate(rawCert) + if err != nil { + return fmt.Errorf("parse peer cert: %w", err) + } + if bytes.Equal(sha256.New().Sum(c.RawSubjectPublicKeyInfo), spkiHash) { + return nil + } + } + return fmt.Errorf("no matching SPKI found") + }, nil +} + +// dummyCert returns a self-signed certificate for use in tls-psk server mode. (ed25519) +func dummyCert() []tlswithpks.Certificate { + // Generated with: + // openssl req -x509 -newkey ed25519 -keyout key.pem -out cert.pem -days 100000 -nodes -subj "/CN=dummy" + certPEM := `-----BEGIN CERTIFICATE----- +MIIBNjCB6aADAgECAhRX020iAjrT4wTjwRdAJ+PPjpe33DAFBgMrZXAwEDEOMAwG +A1UEAwwFZHVtbXkwIBcNMjUwOTIxMTUxNzMwWhgPMjI5OTA3MDcxNTE3MzBaMBAx +DjAMBgNVBAMMBWR1bW15MCowBQYDK2VwAyEA/8RGhnpLT8uPAm8Ah0vEYWCskGrk +R3lqdOjspIidVmKjUzBRMB0GA1UdDgQWBBRMUX8P7I1KV1UxMjcJlIT42a72ozAf +BgNVHSMEGDAWgBRMUX8P7I1KV1UxMjcJlIT42a72ozAPBgNVHRMBAf8EBTADAQH/ +MAUGAytlcANBAEFf17f1XhfLek4D203mGz8BihBfXfeL6kADMMV+G2qpkqZPcnTI +NXPuT9B/6+hM7nD/vh7JKXTfSAEFo22rzwA= +-----END CERTIFICATE----- +` + keyPEM := `-----BEGIN PRIVATE KEY----- +MC4CAQAwBQYDK2VwBCIEIEsb9X3HHGBFSe5jKvqNmua6ZFplNaiBROtJ7ZZAJlRz +-----END PRIVATE KEY----- +` + cert, err := tlswithpks.X509KeyPair([]byte(certPEM), []byte(keyPEM)) + if err != nil { + panic("dummyCert: " + err.Error()) + } + return []tlswithpks.Certificate{cert} +} diff --git a/uri/scheme.go b/uri/scheme.go new file mode 100644 index 0000000..85dfb03 --- /dev/null +++ b/uri/scheme.go @@ -0,0 +1,43 @@ +package uri + +import ( + "fmt" + "strings" +) + +type Scheme struct { + Listener bool + Transport `json:"transport"` + Layers Layers `json:"layers"` +} + +func (s Scheme) String() string { + str := s.Transport.String() + for _, l := range s.Layers.Layers { + str += "+" + l.String() + } + return str +} + +func (s Scheme) MarshalText() ([]byte, error) { + return []byte(s.String()), nil +} + +func (s *Scheme) UnmarshalText(text []byte) error { + parts := strings.SplitN(string(text), "+", 2) + if len(parts) == 0 { + return fmt.Errorf("uri: empty scheme") + } + + if err := s.Transport.UnmarshalText([]byte(parts[0])); err != nil { + return err + } + + if len(parts) == 1 { + return nil + } + + s.Layers.Listener = s.Listener + + return s.Layers.UnmarshalText([]byte(parts[1])) +} diff --git a/uri/transport.go b/uri/transport.go new file mode 100644 index 0000000..a4b889c --- /dev/null +++ b/uri/transport.go @@ -0,0 +1,32 @@ +package uri + +import ( + "fmt" + "strings" +) + +type Transport string + +const ( + TransportTCP Transport = "tcp" + TransportUDP Transport = "udp" +) + +func (t Transport) String() string { + return string(t) +} + +func (t Transport) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + +func (t *Transport) UnmarshalText(text []byte) error { + str := strings.ToLower(strings.TrimSpace(string(text))) + switch Transport(str) { + case TransportTCP, TransportUDP: + *t = Transport(str) + return nil + default: + return fmt.Errorf("uri: unknown transport %q", str) + } +} diff --git a/uri/uri.go b/uri/uri.go new file mode 100644 index 0000000..3b54992 --- /dev/null +++ b/uri/uri.go @@ -0,0 +1,39 @@ +package uri + +import ( + "fmt" + "strings" +) + +type URI struct { + // This flag must be set if the URI is being applied to a listener (server side) + // The parser takes this into account when validating parameters + Listener bool + Scheme `json:"scheme"` + Addr string `json:"addr"` +} + +func (u URI) String() string { + return u.Scheme.String() + "://" + u.Addr +} + +func (u URI) MarshalText() ([]byte, error) { + return []byte(u.String()), nil +} + +func (u *URI) UnmarshalText(text []byte) error { + str := string(text) + parts := strings.SplitN(str, "://", 2) + if len(parts) < 2 { + return fmt.Errorf("uri: missing scheme delimiter in %q", str) + } + + u.Addr = strings.TrimSpace(parts[1]) + if u.Addr == "" { + return fmt.Errorf("uri: empty address in %q", str) + } + + u.Scheme.Listener = u.Listener + + return u.Scheme.UnmarshalText([]byte(parts[0])) +}