-
Notifications
You must be signed in to change notification settings - Fork 8
/
context.go
132 lines (115 loc) · 3.31 KB
/
context.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package provider
import (
"context"
"errors"
"net/http"
"net/url"
"strings"
)
// valueKey is the unexported type used for context keys declared in this
// file, so that they cannot collide with keys defined by other packages.
type valueKey int

// issuer is the context key under which the IssuerInterceptor stores the
// issuer string and from which IssuerFromContext reads it.
// It is a constant (rather than a variable) so the key cannot be reassigned
// at runtime, which would break the link between the setter and the getter.
const issuer valueKey = 1
// IssuerInterceptor is an HTTP middleware helper that holds the function used
// to compute the issuer which is stored in the context of each request.
type IssuerInterceptor struct {
// issuerFromRequest computes the issuer string for an incoming request.
issuerFromRequest IssuerFromRequest
}
// NewIssuerInterceptor builds an IssuerInterceptor around the given
// IssuerFromRequest (e.g. the function obtained from StaticIssuer or
// IssuerFromHost). The interceptor uses that function to put the issuer
// into the context of every request it handles.
func NewIssuerInterceptor(issuerFromRequest IssuerFromRequest) *IssuerInterceptor {
	interceptor := new(IssuerInterceptor)
	interceptor.issuerFromRequest = issuerFromRequest
	return interceptor
}
// Handler wraps next so that the issuer is stored in the request context
// before next is invoked.
func (i *IssuerInterceptor) Handler(next http.Handler) http.Handler {
	withIssuer := func(w http.ResponseWriter, r *http.Request) {
		i.setIssuerCtx(w, r, next)
	}
	return http.HandlerFunc(withIssuer)
}
// HandlerFunc is the http.HandlerFunc flavour of Handler: it wraps next so
// that the issuer is stored in the request context before next is invoked.
func (i *IssuerInterceptor) HandlerFunc(next http.HandlerFunc) http.HandlerFunc {
	// An http.HandlerFunc already satisfies http.Handler, so the wrapping
	// logic of Handler can be reused and exposed as a plain function.
	return i.Handler(next).ServeHTTP
}
// IssuerFromContext returns the issuer stored in ctx by an IssuerInterceptor.
// If no issuer is present (or the stored value is not a string), the empty
// string is returned.
func IssuerFromContext(ctx context.Context) string {
	if stored, ok := ctx.Value(issuer).(string); ok {
		return stored
	}
	return ""
}
// setIssuerCtx computes the issuer for r, attaches it to a copy of the
// request's context under the issuer key and passes the updated request on
// to next.
func (i *IssuerInterceptor) setIssuerCtx(w http.ResponseWriter, r *http.Request, next http.Handler) {
	parent := r.Context()
	resolved := i.issuerFromRequest(r)
	next.ServeHTTP(w, r.WithContext(context.WithValue(parent, issuer, resolved)))
}
// Sentinel errors returned by the issuer validation helpers in this file.
var (
// ErrInvalidIssuerPath is returned by ValidateIssuerPath when the URL has a fragment or a query.
ErrInvalidIssuerPath = errors.New("no fragments or query allowed for issuer")
// ErrInvalidIssuerNoIssuer is returned by ValidateIssuer when the issuer is the empty string.
ErrInvalidIssuerNoIssuer = errors.New("missing issuer")
// ErrInvalidIssuerURL is returned when the issuer (or issuer path) cannot be parsed as a URL.
ErrInvalidIssuerURL = errors.New("invalid url for issuer")
// ErrInvalidIssuerMissingHost is returned by ValidateIssuer when the parsed URL has no host.
ErrInvalidIssuerMissingHost = errors.New("host for issuer missing")
// ErrInvalidIssuerHTTPS is returned by ValidateIssuer when the scheme is not `https`
// and the insecure `http` exception does not apply.
ErrInvalidIssuerHTTPS = errors.New("scheme for issuer must be `https`")
)
// IssuerFromRequest is a function that returns the issuer string to use for
// the given HTTP request.
type IssuerFromRequest func(r *http.Request) string
// IssuerFromHost returns a factory for an IssuerFromRequest that derives the
// issuer from the Host of each request, followed by the given path.
//
// The factory takes allowInsecure and first validates path: it returns
// ErrInvalidIssuerURL if path cannot be parsed as a URL, or the error from
// ValidateIssuerPath if it contains a fragment or query. On success the
// resulting function builds the issuer via dynamicIssuer for every request.
func IssuerFromHost(path string) func(bool) (IssuerFromRequest, error) {
	return func(allowInsecure bool) (IssuerFromRequest, error) {
		parsed, parseErr := url.Parse(path)
		if parseErr != nil {
			return nil, ErrInvalidIssuerURL
		}
		if pathErr := ValidateIssuerPath(parsed); pathErr != nil {
			return nil, pathErr
		}
		fromHost := func(r *http.Request) string {
			return dynamicIssuer(r.Host, path, allowInsecure)
		}
		return fromHost, nil
	}
}
// StaticIssuer returns a factory for an IssuerFromRequest that always yields
// the given issuer, regardless of the request.
//
// The factory takes allowInsecure and validates the issuer with
// ValidateIssuer; the validation error is returned unchanged on failure.
func StaticIssuer(issuer string) func(bool) (IssuerFromRequest, error) {
	return func(allowInsecure bool) (IssuerFromRequest, error) {
		validationErr := ValidateIssuer(issuer, allowInsecure)
		if validationErr != nil {
			return nil, validationErr
		}
		fixed := func(*http.Request) string {
			return issuer
		}
		return fixed, nil
	}
}
// ValidateIssuer checks that issuer is a usable issuer URL.
//
// It returns, in this order of precedence:
//   - ErrInvalidIssuerNoIssuer if issuer is empty,
//   - ErrInvalidIssuerURL if it cannot be parsed as a URL,
//   - ErrInvalidIssuerMissingHost if the URL has no host,
//   - ErrInvalidIssuerHTTPS if the scheme is not "https" and the insecure
//     exception (see devLocalAllowed) does not apply,
//   - otherwise the result of ValidateIssuerPath.
func ValidateIssuer(issuer string, allowInsecure bool) error {
	if issuer == "" {
		return ErrInvalidIssuerNoIssuer
	}
	u, parseErr := url.Parse(issuer)
	if parseErr != nil {
		return ErrInvalidIssuerURL
	}
	switch {
	case u.Host == "":
		return ErrInvalidIssuerMissingHost
	case u.Scheme != "https" && !devLocalAllowed(u, allowInsecure):
		return ErrInvalidIssuerHTTPS
	}
	return ValidateIssuerPath(u)
}
// ValidateIssuerPath checks that the issuer URL carries neither a fragment
// nor a query component and returns ErrInvalidIssuerPath if it does.
//
// The query check inspects the raw query string instead of the parsed values:
// url.URL.Query silently discards malformed pairs (e.g. "a;b" or "%zz"), so
// counting the parsed values would let such URLs pass. ForceQuery covers a
// URL that ends in a bare "?".
func ValidateIssuerPath(issuer *url.URL) error {
	if issuer.Fragment != "" || issuer.RawQuery != "" || issuer.ForceQuery {
		return ErrInvalidIssuerPath
	}
	return nil
}
// devLocalAllowed reports whether a non-https issuer URL may be accepted:
// this is only the case when allowInsecure is set and the scheme is "http".
// The URL is not inspected at all when allowInsecure is false.
func devLocalAllowed(url *url.URL, allowInsecure bool) bool {
	return allowInsecure && url.Scheme == "http"
}
// dynamicIssuer assembles an issuer URL from a host (passed as issuer) and a
// path. The scheme is "http" when allowInsecure is set and "https" otherwise.
// A non-empty path that does not start with "/" gets one prepended; an empty
// path is left empty.
func dynamicIssuer(issuer, path string, allowInsecure bool) string {
	if path != "" && !strings.HasPrefix(path, "/") {
		path = "/" + path
	}
	scheme := "http"
	if !allowInsecure {
		scheme = "https"
	}
	return scheme + "://" + issuer + path
}