Skip to content

Commit

Permalink
Fix potential memory leak problem for gorilla. gorilla/context#32
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Jun 29, 2017
1 parent a4e94c3 commit 95a405f
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions gorilla/gorilla.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,20 @@ type Gorilla struct {
Store sessions.Store
}

var writer utils.ContextKey = "writer"
var (
writer utils.ContextKey = "gorilla_writer"
reader utils.ContextKey = "gorilla_reader"
)

func (gorilla Gorilla) getSession(req *http.Request) (*sessions.Session, error) {
if r, ok := req.Context().Value(reader).(*http.Request); ok {
return gorilla.Store.Get(r, gorilla.SessionName)
}
return gorilla.Store.Get(req, gorilla.SessionName)
}

func (gorilla Gorilla) saveSession(req *http.Request) {
if session, err := gorilla.Store.Get(req, gorilla.SessionName); err == nil {
if session, err := gorilla.getSession(req); err == nil {
if w, ok := req.Context().Value(writer).(http.ResponseWriter); ok {
session.Save(req, w)
}
Expand All @@ -37,8 +47,7 @@ func (gorilla Gorilla) saveSession(req *http.Request) {
func (gorilla Gorilla) Add(req *http.Request, key string, value interface{}) error {
defer gorilla.saveSession(req)

session, err := gorilla.Store.Get(req, gorilla.SessionName)

session, err := gorilla.getSession(req)
if err != nil {
return err
}
Expand All @@ -57,7 +66,7 @@ func (gorilla Gorilla) Add(req *http.Request, key string, value interface{}) err
func (gorilla Gorilla) Pop(req *http.Request, key string) string {
defer gorilla.saveSession(req)

if session, err := gorilla.Store.Get(req, gorilla.SessionName); err == nil {
if session, err := gorilla.getSession(req); err == nil {
if value, ok := session.Values[key]; ok {
delete(session.Values, key)
return fmt.Sprint(value)
Expand All @@ -68,7 +77,7 @@ func (gorilla Gorilla) Pop(req *http.Request, key string) string {

// Get value from session data
func (gorilla Gorilla) Get(req *http.Request, key string) string {
if session, err := gorilla.Store.Get(req, gorilla.SessionName); err == nil {
if session, err := gorilla.getSession(req); err == nil {
if value, ok := session.Values[key]; ok {
return fmt.Sprint(value)
}
Expand Down Expand Up @@ -115,7 +124,7 @@ func (gorilla Gorilla) PopLoad(req *http.Request, key string, result interface{}
func (gorilla Gorilla) Middleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
defer gorillaContext.Clear(req)

handler.ServeHTTP(w, req.WithContext(context.WithValue(req.Context(), writer, w)))
ctx := context.WithValue(context.WithValue(req.Context(), writer, w), reader, req)
handler.ServeHTTP(w, req.WithContext(ctx))
})
}

0 comments on commit 95a405f

Please sign in to comment.