Skip to content

Commit

Permalink
Merge pull request #1 from idada/master
Browse files Browse the repository at this point in the history
优化了一下代码
  • Loading branch information
tomasen committed Dec 2, 2015
2 parents 7f0462b + 74b63b7 commit 7577022
Showing 1 changed file with 114 additions and 107 deletions.
221 changes: 114 additions & 107 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@ import (

const (
// max open file should at least be
_MaxOpenfile uint64 = 1024 * 1024 * 1024
_MaxBackendAddrCacheCount int = 1024 * 1024
_DefaultPort string = "4043"
_MTU = 1500
_MaxOpenfile = uint64(1024 * 1024 * 1024)
_MaxBackendAddrCacheCount = 1024 * 1024
_DefaultPort = "4043"
_MTU = 1500
)

var (
_SecretPassphase string
_OpenSSL = openssl.New()
)

var (
_BackendAddrCacheMutex = new(sync.Mutex)
_BackendAddrCacheMutex sync.Mutex
_BackendAddrCache atomic.Value
_BufioReaderPool sync.Pool
)

type backendAddrMap map[string]string
Expand All @@ -38,146 +40,151 @@ func init() {
_BackendAddrCache.Store(make(backendAddrMap))
}

func readBackendAddrCache(key string) (string, bool) {
func decryptBackendAddr(line []byte) (string, error) {
// Try to check cache
m1 := _BackendAddrCache.Load().(backendAddrMap)

val, ok := m1[key]
return val, ok
addr, ok := m1[string(line)]
if ok {
return addr, nil
}
// Try to decrypt it (AES)
plaintext, err := _OpenSSL.DecryptString(_SecretPassphase, string(line))
if err != nil {
return "", err
}
addr = string(plaintext)
cacheBackendAddr(string(line), addr)
return addr, nil
}

func writeBackendAddrCache(key, val string) {
func cacheBackendAddr(key, val string) {
_BackendAddrCacheMutex.Lock()
defer _BackendAddrCacheMutex.Unlock()

m1 := _BackendAddrCache.Load().(backendAddrMap)
m2 := make(backendAddrMap) // create a new value
// double check
if _, ok := m1[key]; ok {
return
}

m2 := make(backendAddrMap)
// flush cache if there is way too many
if len(m1) < _MaxBackendAddrCacheCount {
// copy-on-write
for k, v := range m1 {
m2[k] = v // copy all data from the current object to the new one
}
}

m2[key] = val
_BackendAddrCache.Store(m2) // atomically replace the current object with the new one
}

// pipe upstream and downstream
func pipe(dst io.Writer, src io.Reader) {
defer func() {
if r := recover(); r != nil {
log.Println("Recovered in", r, ":", string(debug.Stack()))
}
}()

_, err := io.Copy(dst, src)
func main() {
runtime.GOMAXPROCS(runtime.NumCPU())
os.Setenv("GOTRACEBACK", "crash")

switch err {
case io.EOF:
err = nil
return
case nil:
return
var lim syscall.Rlimit
syscall.Getrlimit(syscall.RLIMIT_NOFILE, &lim)
if lim.Cur < _MaxOpenfile || lim.Max < _MaxOpenfile {
lim.Cur = _MaxOpenfile
lim.Max = _MaxOpenfile
syscall.Setrlimit(syscall.RLIMIT_NOFILE, &lim)
}
// log.Println("pipe:", n, err)

_SecretPassphase = os.Getenv("SECRET")

ListenAndServe()
}

// TCPServer is handler for all tcp queries
func TCPServer(l net.Listener) {
func ListenAndServe() {
l, err := net.Listen("tcp", ":"+_DefaultPort)
if err != nil {
log.Fatal(err)
}
defer l.Close()
for {
// Wait for a connection.
conn, err := l.Accept()
if err != nil {
log.Fatal(err)
}
// Handle the connection in a new goroutine.
// The loop then returns to accepting, so that
// multiple connections may be served concurrently.
go func(c net.Conn) {
defer func() {
if r := recover(); r != nil {
log.Println("Recovered in", r, ":", string(debug.Stack()))
}
}()
defer c.Close()

// TODO: get rid of bufio.Reader
// TODO: use binary protocol if first byte is 0x00

// Read first line
rdr := bufio.NewReader(c)
line, isPrefix, err := rdr.ReadLine()
if err != nil || isPrefix {
// handle error
log.Println(err)
c.Write([]byte{0x04})
return
}

// Try to check cache
addr, ok := readBackendAddrCache(string(line))
if !ok {
// Try to decrypt it (AES)
o := openssl.New()
plaintext, err := o.DecryptString(string(_SecretPassphase), string(line))
if err != nil {
c.Write([]byte{0x06})
return
}
addr = string(plaintext)
// Write to cache
writeBackendAddrCache(string(line), string(addr))
}
go handleConn(conn)
}
}

// TODO: check if addr is allowed

// Build tunnel
backend, err := net.Dial("tcp", addr)
if err != nil {
// handle error
switch err := err.(type) {
case net.Error:
if err.Timeout() {
c.Write([]byte{0x01})
log.Println(err)
return
}
}
log.Println(err)
c.Write([]byte{0x02})
return
}
defer backend.Close()
func handleConn(c net.Conn) {
defer func() {
c.Close()
if r := recover(); r != nil {
log.Println("Recovered in", r, ":", string(debug.Stack()))
}
}()

// Start transfering data
go pipe(c, backend)
pipe(backend, rdr)
// TODO: get rid of bufio.Reader
// TODO: use binary protocol if first byte is 0x00

}(conn)
// Read first line
rdr, ok := _BufioReaderPool.Get().(*bufio.Reader)
if ok {
rdr.Reset(c)
} else {
rdr = bufio.NewReader(c)
defer _BufioReaderPool.Put(rdr)
}
line, isPrefix, err := rdr.ReadLine()
if err != nil || isPrefix {
log.Println(err)
c.Write([]byte{0x04})
return
}
}

func main() {
runtime.GOMAXPROCS(runtime.NumCPU())
os.Setenv("GOTRACEBACK", "crash")

lim := syscall.Rlimit{}
syscall.Getrlimit(syscall.RLIMIT_NOFILE, &lim)
if lim.Cur < _MaxOpenfile || lim.Max < _MaxOpenfile {
lim.Cur = _MaxOpenfile
lim.Max = _MaxOpenfile
syscall.Setrlimit(syscall.RLIMIT_NOFILE, &lim)
// Try to check cache
addr, err := decryptBackendAddr(line)
if err != nil {
c.Write([]byte{0x06})
return
}

_SecretPassphase = os.Getenv("SECRET")
// TODO: check if addr is allowed

ln, err := net.Listen("tcp", ":"+_DefaultPort)
// Build tunnel
backend, err := net.Dial("tcp", addr)
if err != nil {
log.Fatal(err)
// handle error
switch err := err.(type) {
case net.Error:
if err.Timeout() {
c.Write([]byte{0x01})
log.Println(err)
return
}
}
log.Println(err)
c.Write([]byte{0x02})
return
}
defer backend.Close()

TCPServer(ln)
// Start transfering data
go pipe(c, backend)
pipe(backend, rdr)
}

// pipe upstream and downstream
func pipe(dst io.Writer, src io.Reader) {
defer func() {
if r := recover(); r != nil {
log.Println("Recovered in", r, ":", string(debug.Stack()))
}
}()

_, err := io.Copy(dst, src)

switch err {
case io.EOF:
err = nil
return
case nil:
return
}
// log.Println("pipe:", n, err)
}

0 comments on commit 7577022

Please sign in to comment.