/
upgrade.go
110 lines (94 loc) · 2.88 KB
/
upgrade.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
package spdy
import (
"bufio"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync/atomic"
)
// connWrapper adapts a hijacked connection so reads are served from the
// bufio.Reader that accompanied it, instead of straight from the socket.
// This keeps any bytes already sitting in that reader from being skipped.
// All other net.Conn methods (Write, deadlines, addresses) fall through to
// the embedded Conn.
type connWrapper struct {
	net.Conn

	// closed is set to 1 by Close and read atomically by Read.
	closed int32
	// bufReader is the buffered reader Read consumes from. UpgradeResponse
	// passes the Reader half of the bufio.ReadWriter returned by Hijack.
	bufReader *bufio.Reader
}

// Read reads from bufReader until Close has been called. After that it
// reports io.EOF without touching the reader.
func (w *connWrapper) Read(b []byte) (n int, err error) {
	isClosed := atomic.LoadInt32(&w.closed) == 1
	if !isClosed {
		return w.bufReader.Read(b)
	}
	return 0, io.EOF
}

// Close closes the underlying connection, then marks the wrapper closed so
// later Reads return io.EOF. The error from the underlying Close is returned.
func (w *connWrapper) Close() error {
	// The deferred store runs after w.Conn.Close() has been evaluated. This
	// keeps the original ordering: close the socket first, then set the flag.
	defer atomic.StoreInt32(&w.closed, 1)
	return w.Conn.Close()
}
// UpgradeResponse upgrades an HTTP response to a SPDY multi-stream connection.
//
// Steps:
//   - validate the request's upgrade headers;
//   - add Connection/Upgrade headers and write 101 Switching Protocols;
//   - hijack the underlying connection and pass it to Server.
//
// If header validation fails, or w does not implement http.Hijacker, nothing
// is written to w and the caller can still send an error response. Failures
// after WriteHeader (hijack or SPDY setup) return an error with a nil
// Connection. If SPDY setup fails, the hijacked connection is closed.
func UpgradeResponse(w http.ResponseWriter, req *http.Request) (Connection, error) {
	if !isValidUpgradeHeader(req.Header) {
		// NOTE(review): %#v dumps every request header, including any
		// Authorization/Cookie values. Consider trimming this if the error
		// ends up in logs.
		return nil, fmt.Errorf("missing upgrade headers in request: %#v", req.Header)
	}
	hijacker, ok := w.(http.Hijacker)
	if !ok {
		return nil, fmt.Errorf("unable to hijack response")
	}

	w.Header().Add(headerConnection, headerUpgrade)
	w.Header().Add(headerUpgrade, headerSpdy31)
	w.WriteHeader(http.StatusSwitchingProtocols)

	conn, bufrw, err := hijacker.Hijack()
	if err != nil {
		return nil, fmt.Errorf("error hijacking response: %w", err)
	}

	// Read through the hijack's bufio.Reader so bytes the HTTP server already
	// buffered from the client are not lost.
	connWithBuf := &connWrapper{
		Conn:      conn,
		bufReader: bufrw.Reader,
	}
	spdyConn, err := Server(connWithBuf)
	if err != nil {
		// After Hijack the http.Server no longer manages conn. If we drop it
		// here without closing, the socket leaks. Closing is best-effort: the
		// Server error is the one worth reporting. (If Server already closed
		// conn on failure, this second Close just returns an ignored error.)
		_ = connWithBuf.Close()
		return nil, fmt.Errorf("error creating SPDY server connection: %w", err)
	}
	return spdyConn, nil
}
// negotiateProtocol picks the protocol both sides support. It walks
// clientProtocols in order, so the client's preference wins, and returns the
// first entry that also appears in serverProtocols. If there is no overlap it
// returns the empty string.
func negotiateProtocol(clientProtocols, serverProtocols []string) string {
	for _, want := range clientProtocols {
		for _, have := range serverProtocols {
			if want != have {
				continue
			}
			return want
		}
	}
	return ""
}
// commaSeparatedHeaderValues flattens the values of a (possibly repeated) HTTP
// list header into individual elements.
//
// Each value is split on ','. Surrounding optional whitespace (SP or HTAB,
// per RFC 7230 §3.2.3) is removed from each element, and empty elements are
// dropped. The result keeps header order and is nil when no non-empty
// elements are found.
func commaSeparatedHeaderValues(header []string) []string {
	var values []string
	for _, line := range header {
		for _, element := range strings.Split(line, ",") {
			// Trim tabs as well as spaces. Trimming only " " leaves a leading
			// "\t" on tokens such as "a,\tb", and those then never match.
			if v := strings.Trim(element, " \t"); v != "" {
				values = append(values, v)
			}
		}
	}
	return values
}
// Handshake performs a subprotocol negotiation.
func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string) (string, error) {
clientProtocols := commaSeparatedHeaderValues(req.Header[http.CanonicalHeaderKey(headerProtocolVersion)])
if len(clientProtocols) == 0 {
return "", fmt.Errorf("unable to upgrade: %s is required", headerProtocolVersion)
}
if len(serverProtocols) == 0 {
panic(fmt.Errorf("unable to upgrade: serverProtocols is required"))
}
negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols)
if len(negotiatedProtocol) == 0 {
for i := range serverProtocols {
w.Header().Add(headerAcceptedProtocolVersions, serverProtocols[i])
}
err := fmt.Errorf("unable to upgrade: unable to negotiate protocol: client supports %v, server accepts %v", clientProtocols, serverProtocols)
http.Error(w, err.Error(), http.StatusForbidden)
return "", err
}
w.Header().Add(headerProtocolVersion, negotiatedProtocol)
return negotiatedProtocol, nil
}