-
Notifications
You must be signed in to change notification settings - Fork 281
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
httputil: add cookie chunker (#3775)
- Loading branch information
1 parent
472370e
commit 457fca0
Showing
2 changed files
with
218 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
package httputil | ||
|
||
import ( | ||
"errors" | ||
"net/http" | ||
"strconv" | ||
"strings" | ||
) | ||
|
||
// ErrCookieTooLarge indicates that a cookie is too large. | ||
var ErrCookieTooLarge = errors.New("cookie too large") | ||
|
||
const ( | ||
defaultCookieChunkerChunkSize = 3800 | ||
defaultCookieChunkerMaxChunks = 16 | ||
) | ||
|
||
type cookieChunkerConfig struct { | ||
chunkSize int | ||
maxChunks int | ||
} | ||
|
||
// A CookieChunkerOption customizes the cookie chunker. | ||
type CookieChunkerOption func(cfg *cookieChunkerConfig) | ||
|
||
// WithCookieChunkerChunkSize sets the chunk size for the cookie chunker. | ||
func WithCookieChunkerChunkSize(chunkSize int) CookieChunkerOption { | ||
return func(cfg *cookieChunkerConfig) { | ||
cfg.chunkSize = chunkSize | ||
} | ||
} | ||
|
||
// WithCookieChunkerMaxChunks sets the maximum number of chunks for the cookie chunker. | ||
func WithCookieChunkerMaxChunks(maxChunks int) CookieChunkerOption { | ||
return func(cfg *cookieChunkerConfig) { | ||
cfg.maxChunks = maxChunks | ||
} | ||
} | ||
|
||
func getCookieChunkerConfig(options ...CookieChunkerOption) *cookieChunkerConfig { | ||
cfg := new(cookieChunkerConfig) | ||
WithCookieChunkerChunkSize(defaultCookieChunkerChunkSize)(cfg) | ||
WithCookieChunkerMaxChunks(defaultCookieChunkerMaxChunks)(cfg) | ||
for _, option := range options { | ||
option(cfg) | ||
} | ||
return cfg | ||
} | ||
|
||
// A CookieChunker breaks up a large cookie into multiple pieces. | ||
type CookieChunker struct { | ||
cfg *cookieChunkerConfig | ||
} | ||
|
||
// NewCookieChunker creates a new CookieChunker. | ||
func NewCookieChunker(options ...CookieChunkerOption) *CookieChunker { | ||
return &CookieChunker{ | ||
cfg: getCookieChunkerConfig(options...), | ||
} | ||
} | ||
|
||
// SetCookie sets a chunked cookie. | ||
func (cc *CookieChunker) SetCookie(w http.ResponseWriter, cookie *http.Cookie) error { | ||
chunks := chunk(cookie.Value, cc.cfg.chunkSize) | ||
if len(chunks) > cc.cfg.maxChunks { | ||
return ErrCookieTooLarge | ||
} | ||
|
||
sizeCookie := *cookie | ||
sizeCookie.Value = strconv.Itoa(len(chunks)) | ||
http.SetCookie(w, &sizeCookie) | ||
for i, chunk := range chunks { | ||
chunkCookie := *cookie | ||
chunkCookie.Name += strconv.Itoa(i) | ||
chunkCookie.Value = chunk | ||
http.SetCookie(w, &chunkCookie) | ||
} | ||
return nil | ||
} | ||
|
||
// LoadCookie loads a chunked cookie. | ||
func (cc *CookieChunker) LoadCookie(r *http.Request, name string) (*http.Cookie, error) { | ||
sizeCookie, err := r.Cookie(name) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
size, err := strconv.Atoi(sizeCookie.Value) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if size > cc.cfg.maxChunks { | ||
return nil, ErrCookieTooLarge | ||
} | ||
|
||
var b strings.Builder | ||
for i := 0; i < size; i++ { | ||
chunkCookie, err := r.Cookie(name + strconv.Itoa(i)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
_, err = b.WriteString(chunkCookie.Value) | ||
if err != nil { | ||
return nil, err | ||
} | ||
} | ||
|
||
cookie := *sizeCookie | ||
cookie.Value = b.String() | ||
return &cookie, nil | ||
} | ||
|
||
func chunk(s string, size int) []string { | ||
ss := make([]string, 0, len(s)/size+1) | ||
for len(s) > 0 { | ||
if len(s) < size { | ||
size = len(s) | ||
} | ||
ss, s = append(ss, s[:size]), s[size:] | ||
} | ||
return ss | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
package httputil | ||
|
||
import ( | ||
"net/http" | ||
"net/http/cookiejar" | ||
"net/http/httptest" | ||
"strings" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestCookieChunker(t *testing.T) { | ||
t.Parallel() | ||
|
||
t.Run("chunk", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
cc := NewCookieChunker(WithCookieChunkerChunkSize(16)) | ||
srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
assert.NoError(t, cc.SetCookie(w, &http.Cookie{ | ||
Name: "example", | ||
Value: strings.Repeat("x", 77), | ||
})) | ||
})) | ||
defer srv1.Close() | ||
srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
cookie, err := cc.LoadCookie(r, "example") | ||
if assert.NoError(t, err) { | ||
assert.Equal(t, &http.Cookie{ | ||
Name: "example", | ||
Value: strings.Repeat("x", 77), | ||
}, cookie) | ||
} | ||
})) | ||
defer srv2.Close() | ||
|
||
jar, err := cookiejar.New(&cookiejar.Options{}) | ||
client := &http.Client{Jar: jar} | ||
require.NoError(t, err) | ||
res, err := client.Get(srv1.URL) | ||
if assert.NoError(t, err) { | ||
assert.Equal(t, []string{ | ||
"example=5", | ||
"example0=xxxxxxxxxxxxxxxx", | ||
"example1=xxxxxxxxxxxxxxxx", | ||
"example2=xxxxxxxxxxxxxxxx", | ||
"example3=xxxxxxxxxxxxxxxx", | ||
"example4=xxxxxxxxxxxxx", | ||
}, res.Header.Values("Set-Cookie")) | ||
} | ||
client.Get(srv2.URL) | ||
}) | ||
|
||
t.Run("set max error", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
cc := NewCookieChunker(WithCookieChunkerChunkSize(2), WithCookieChunkerMaxChunks(2)) | ||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
assert.Error(t, cc.SetCookie(w, &http.Cookie{ | ||
Name: "example", | ||
Value: strings.Repeat("x", 1024), | ||
})) | ||
})) | ||
defer srv.Close() | ||
http.Get(srv.URL) | ||
}) | ||
|
||
t.Run("load max error", func(t *testing.T) { | ||
t.Parallel() | ||
|
||
cc1 := NewCookieChunker(WithCookieChunkerChunkSize(64)) | ||
srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
assert.NoError(t, cc1.SetCookie(w, &http.Cookie{ | ||
Name: "example", | ||
Value: strings.Repeat("x", 1024), | ||
})) | ||
})) | ||
defer srv1.Close() | ||
|
||
cc2 := NewCookieChunker(WithCookieChunkerChunkSize(64), WithCookieChunkerMaxChunks(2)) | ||
srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
cookie, err := cc2.LoadCookie(r, "example") | ||
assert.Error(t, err) | ||
assert.Nil(t, cookie) | ||
})) | ||
defer srv2.Close() | ||
|
||
jar, err := cookiejar.New(&cookiejar.Options{}) | ||
require.NoError(t, err) | ||
client := &http.Client{Jar: jar} | ||
client.Get(srv1.URL) | ||
client.Get(srv2.URL) | ||
}) | ||
} |