-
Notifications
You must be signed in to change notification settings - Fork 0
/
conn.go
145 lines (122 loc) · 2.87 KB
/
conn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package http
import (
	"bufio"
	"bytes"
	"crypto/sha1"
	"encoding/base64"
	"errors"
	"io"
	"net"
	"net/http"
	"net/http/httputil"
	"strings"
	"sync"
	"time"

	"github.com/go-gost/core/logger"
)
// obfsHTTPConn wraps a net.Conn on the server side and disguises the stream
// behind an HTTP/1.1 WebSocket upgrade exchange. The handshake runs lazily on
// the first Read or Write; afterwards data passes straight through to the
// embedded connection.
type obfsHTTPConn struct {
	net.Conn
	// rbuf holds client data received together with the handshake request;
	// Read drains it before reading from the embedded connection.
	rbuf bytes.Buffer
	// wbuf holds the serialized handshake response when sending it was
	// deferred because the client had already sent data; the next Write
	// appends its payload and flushes both together.
	wbuf bytes.Buffer
	// handshaked records that the handshake completed successfully; it is
	// checked and set under handshakeMutex.
	handshaked     bool
	handshakeMutex sync.Mutex
	// header holds extra headers to include in the handshake response (may
	// be nil). NOTE(review): probably shared between connections by whoever
	// constructs this type — confirm, and treat it as read-only if so.
	header http.Header
	// logger receives trace dumps of the request/response and handshake errors.
	logger logger.Logger
}
// Handshake runs the server-side handshake until it succeeds once.
//
// Callers are serialized by handshakeMutex. After a successful handshake every
// later call returns nil immediately. After a failure the handshaked flag stays
// false, so the next call attempts the handshake again and returns its error.
func (c *obfsHTTPConn) Handshake() error {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	if c.handshaked {
		return nil
	}
	err := c.handshake()
	c.handshaked = err == nil
	return err
}
// handshake performs the server side of the obfuscated HTTP handshake. It is
// called by Handshake with handshakeMutex held.
//
// It reads one HTTP request from the underlying connection and keeps any
// client payload that came with it in c.rbuf, so Read can return it.
//
// A request that is not a GET with "Upgrade: websocket" is answered with
// 400 Bad Request, and an error is returned. Otherwise a 101 Switching
// Protocols response carrying Sec-WebSocket-Accept is produced. If payload
// data was received, the 101 response is serialized into c.wbuf instead of
// being sent, so the next Write sends it together with the first outgoing data.
func (c *obfsHTTPConn) handshake() (err error) {
	br := bufio.NewReader(c.Conn)
	r, err := http.ReadRequest(br)
	if err != nil {
		return err
	}
	if c.logger.IsLevelEnabled(logger.TraceLevel) {
		dump, _ := httputil.DumpRequest(r, false)
		c.logger.Trace(string(dump))
	}

	if err = c.bufferRequestPayload(r, br); err != nil {
		c.logger.Error(err)
		return err
	}

	resp := http.Response{
		StatusCode: http.StatusOK,
		ProtoMajor: 1,
		ProtoMinor: 1,
		// Work on a copy: c.header may be shared with other connections, and
		// the Set calls below would otherwise mutate it concurrently.
		// Clone returns nil for a nil header.
		Header: c.header.Clone(),
	}
	if resp.Header == nil {
		resp.Header = http.Header{}
	}
	// HTTP dates must be expressed in GMT (RFC 7231 section 7.1.1.1).
	resp.Header.Set("Date", time.Now().UTC().Format(http.TimeFormat))

	// The Upgrade token is compared case-insensitively (RFC 6455 section 4.2.1).
	if r.Method != http.MethodGet || !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
		resp.StatusCode = http.StatusBadRequest
		c.traceResponse(&resp)
		// Best effort: the handshake has already failed, and that failure is
		// what the caller needs to see.
		_ = resp.Write(c.Conn)
		return errors.New("bad request")
	}

	resp.StatusCode = http.StatusSwitchingProtocols
	resp.Header.Set("Connection", "Upgrade")
	resp.Header.Set("Upgrade", "websocket")
	resp.Header.Set("Sec-WebSocket-Accept", c.computeAcceptKey(r.Header.Get("Sec-WebSocket-Key")))
	c.traceResponse(&resp)

	if c.rbuf.Len() > 0 {
		// The client already sent data: hold the response header back so the
		// next Write emits it together with the first payload.
		return resp.Write(&c.wbuf)
	}
	return resp.Write(c.Conn)
}

// bufferRequestPayload stores in c.rbuf the client data that arrived with r:
// first the request body (when a positive Content-Length was declared), then
// any bytes br has already read past the request. Without the second step,
// data the client sent right after the body would be lost when br is dropped.
func (c *obfsHTTPConn) bufferRequestPayload(r *http.Request, br *bufio.Reader) error {
	if r.ContentLength > 0 {
		if _, err := io.Copy(&c.rbuf, r.Body); err != nil {
			return err
		}
	}
	// Peeking at most Buffered() bytes is served from br's buffer without
	// reading from the connection.
	b, err := br.Peek(br.Buffered())
	if err != nil {
		return err
	}
	if len(b) > 0 {
		_, err = c.rbuf.Write(b)
	}
	return err
}

// traceResponse logs a dump of resp's status line and headers when trace
// logging is enabled.
func (c *obfsHTTPConn) traceResponse(resp *http.Response) {
	if !c.logger.IsLevelEnabled(logger.TraceLevel) {
		return
	}
	dump, _ := httputil.DumpResponse(resp, false)
	c.logger.Trace(string(dump))
}
// Read reads from the connection, running the handshake first if needed.
// Bytes captured during the handshake (c.rbuf) are returned before any new
// data is read from the embedded connection. A handshake error is returned
// with a zero count.
func (c *obfsHTTPConn) Read(b []byte) (int, error) {
	if err := c.Handshake(); err != nil {
		return 0, err
	}
	if c.rbuf.Len() == 0 {
		return c.Conn.Read(b)
	}
	return c.rbuf.Read(b)
}
// Write sends b to the peer, running the handshake first if needed. A
// handshake error is returned with a zero count.
//
// If the handshake left a serialized response header in c.wbuf, b is appended
// to it and both go out in a single write. The returned count covers only b
// and is len(b) even if that combined write fails. bytes.Buffer.WriteTo keeps
// any unsent bytes in c.wbuf, so they are flushed by the next Write rather
// than being dropped.
func (c *obfsHTTPConn) Write(b []byte) (int, error) {
	if err := c.Handshake(); err != nil {
		return 0, err
	}
	if c.wbuf.Len() == 0 {
		return c.Conn.Write(b)
	}
	c.wbuf.Write(b)
	_, err := c.wbuf.WriteTo(c.Conn)
	return len(b), err
}
// keyGUID is the fixed GUID defined by RFC 6455 that is appended to the
// client's Sec-WebSocket-Key before hashing.
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// computeAcceptKey derives the Sec-WebSocket-Accept value for challengeKey:
// the standard-alphabet base64 encoding of SHA-1(challengeKey || keyGUID).
func (c *obfsHTTPConn) computeAcceptKey(challengeKey string) string {
	input := make([]byte, 0, len(challengeKey)+len(keyGUID))
	input = append(input, challengeKey...)
	input = append(input, keyGUID...)
	digest := sha1.Sum(input)
	return base64.StdEncoding.EncodeToString(digest[:])
}