Skip to content

Commit

Permalink
Refactor AddrSpec and add AddressRewriter interface
Browse files Browse the repository at this point in the history
Changed AddrSpec in request_test.go and request.go to be a pointer rather than a value. Implemented a new interface, AddressRewriter, to make adjustments to the request address prior to initiating a dialing request. An error handling case was additionally added to the Dial function in socks5_test.go. Tests were modified to reflect these changes.
  • Loading branch information
ryanbekhen committed Dec 7, 2023
1 parent 24c9a8f commit 567bc2b
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 19 deletions.
50 changes: 36 additions & 14 deletions pkg/socks5/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@ import (
"fmt"
"io"
"net"
"strconv"
"strings"
"time"
)

type AddressRewriter interface {
Rewrite(request *Request) *AddrSpec
}

// AddrSpec is a SOCKS5 address specification
type AddrSpec struct {
FQDN string
Expand All @@ -27,13 +32,23 @@ func (a *AddrSpec) String() string {
return fmt.Sprintf("%s:%d", a.IP, a.Port)
}

// Address returns a string suitable to dial; prefer returning IP-based
// address, fallback to FQDN
func (a *AddrSpec) Address() string {
if 0 != len(a.IP) {
return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port))
}
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
}

// Request is a SOCKS5 request message
type Request struct {
Version uint8
Command Command
AuthContext *AuthContext
RemoteAddr *AddrSpec
DestAddr *AddrSpec
realAddr *AddrSpec
BufferConn io.Reader
Latency time.Duration
}
Expand Down Expand Up @@ -64,45 +79,52 @@ func NewRequest(bufferConn io.Reader) (*Request, error) {
}

func readAddressSpec(r io.Reader) (*AddrSpec, error) {
d := &AddrSpec{}

addrType := []byte{0}
if _, err := io.ReadAtLeast(r, addrType, 1); err != nil {
if _, err := r.Read(addrType); err != nil {
return nil, err
}

var addr AddrSpec
// Handle on a per-type basis
switch AddrType(addrType[0]) {
case AddressTypeIPv4:
addr.IP = make(net.IP, net.IPv4len)
if _, err := io.ReadAtLeast(r, addr.IP, net.IPv4len); err != nil {
addr := make([]byte, 4)
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
return nil, err
}
d.IP = addr

case AddressTypeIPv6:
addr.IP = make(net.IP, net.IPv6len)
if _, err := io.ReadAtLeast(r, addr.IP, net.IPv6len); err != nil {
addr := make([]byte, 16)
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
return nil, err
}
d.IP = addr

case AddressTypeDomain:
domainLen := []byte{0}
if _, err := io.ReadAtLeast(r, domainLen, 1); err != nil {
if _, err := r.Read(addrType); err != nil {
return nil, err
}

domain := make([]byte, domainLen[0])
if _, err := io.ReadAtLeast(r, domain, int(domainLen[0])); err != nil {
addrLen := int(addrType[0])
fqdn := make([]byte, addrLen)
if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil {
return nil, err
}
addr.FQDN = string(domain)
d.FQDN = string(fqdn)

default:
return nil, fmt.Errorf("unrecognized address type: %d", addrType[0])
}

// Read the port
port := []byte{0, 0}
if _, err := io.ReadAtLeast(r, port, 2); err != nil {
return nil, err
}
addr.Port = int(port[0])<<8 | int(port[1])
d.Port = (int(port[0]) << 8) | int(port[1])

return &addr, nil
return d, nil
}

func sendReply(conn io.Writer, reply uint8, addr *AddrSpec) error {
Expand Down
6 changes: 3 additions & 3 deletions pkg/socks5/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,16 @@ func (m *MockConn) SetWriteDeadline(t time.Time) error {
}

func Test_AddrSpec_String(t *testing.T) {
var a AddrSpec
a = AddrSpec{
var a *AddrSpec
a = &AddrSpec{
FQDN: "www.google.com",
IP: net.ParseIP("192.168.1.1"),
Port: 8080,
}

assert.Equal(t, "www.google.com (192.168.1.1):8080", a.String())

a = AddrSpec{
a = &AddrSpec{
FQDN: "",
IP: net.ParseIP("192.168.1.1"),
Port: 8080,
Expand Down
8 changes: 7 additions & 1 deletion pkg/socks5/socks5.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type Config struct {
Dial func(network, addr string) (net.Conn, error)
AfterRequest func(req *Request, conn net.Conn)
Resolver Resolver
Rewriter AddressRewriter
}

type Server struct {
Expand Down Expand Up @@ -197,6 +198,11 @@ func (s *Server) handleRequest(req *Request, conn net.Conn) error {
dest.IP = addr
}

req.realAddr = req.DestAddr
if s.config.Rewriter != nil {
req.realAddr = s.config.Rewriter.Rewrite(req)
}

switch req.Command {
case CommandConnect:
return s.handleConnect(conn, req)
Expand All @@ -222,7 +228,7 @@ func (s *Server) handleConnect(conn net.Conn, req *Request) error {
}

processStartTimestamp := time.Now()
dest, err := dial("tcp", req.DestAddr.String())
dest, err := dial("tcp", req.realAddr.Address())
req.Latency = time.Since(processStartTimestamp)

if err != nil {
Expand Down
11 changes: 10 additions & 1 deletion pkg/socks5/socks5_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/stretchr/testify/assert"
"io"
"net"
Expand Down Expand Up @@ -257,7 +258,12 @@ func TestRequest_Unreachable(t *testing.T) {
Dial: func(network, addr string) (net.Conn, error) {
// timeout
timeout := time.Duration(1) * time.Millisecond
return net.DialTimeout(network, addr, timeout)
conn, err := net.DialTimeout(network, addr, timeout)
if err != nil {
// Handle error here. For example, return an error or print a log.
return nil, fmt.Errorf("failed to dial: %w", err)
}
return conn, nil
},
},
}
Expand All @@ -279,6 +285,7 @@ func TestRequest_Unreachable(t *testing.T) {
req, err := NewRequest(buf)
assert.NoError(t, err)

req.realAddr = req.DestAddr
err = s.handleConnect(resp, req)
assert.Error(t, err)

Expand Down Expand Up @@ -317,6 +324,7 @@ func TestRequest_Refused(t *testing.T) {
req, err := NewRequest(buf)
assert.NoError(t, err)

req.realAddr = req.DestAddr
err = s.handleConnect(resp, req)
assert.Error(t, err)

Expand Down Expand Up @@ -359,6 +367,7 @@ func TestRequest_NetworkUnreachable(t *testing.T) {
req, err := NewRequest(buf)
assert.NoError(t, err)

req.realAddr = req.DestAddr
err = s.handleConnect(resp, req)
assert.Error(t, err)

Expand Down

0 comments on commit 567bc2b

Please sign in to comment.