/
middleware.go
106 lines (93 loc) · 2.88 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
package casbinmw
import (
"net/http"
"net/url"
"github.com/casbin/casbin/v2"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
)
// Config configures the casbin authorization middleware built by
// MiddlewareWithConfig.
type Config struct {
	// Skipper is consulted first on every request; when it returns true the
	// request is passed to the next handler without any authorization.
	// A nil Skipper is replaced by DefaultConfig.Skipper.
	Skipper middleware.Skipper

	// BeforeFunc, if set, runs after the Skipper and before the permission
	// check.
	BeforeFunc middleware.BeforeFunc

	// GetURLPathFunc returns the path that is checked against the casbin
	// policy. When nil, MiddlewareWithConfig installs a default derived from
	// the request URL.
	GetURLPathFunc func(ctx echo.Context) string

	// SuccessHandler, if set, runs when access is granted, just before the
	// next handler is called.
	SuccessHandler func(echo.Context)

	// ErrorHandler, if set, runs when access is denied. It receives
	// echo.ErrForbidden and its return value becomes the middleware's result,
	// so it can be used to produce a custom error. It is not called when the
	// enforcer itself returns an error.
	ErrorHandler func(error, echo.Context) error

	// Enforcer is the casbin enforcer that evaluates each request.
	Enforcer *casbin.Enforcer

	// DataSource extracts the casbin subject from the request.
	DataSource DataSource
}
// DefaultConfig is the base configuration copied by Middleware. It only sets
// Skipper (to middleware.DefaultSkipper); MiddlewareWithConfig also falls back
// to DefaultConfig.Skipper when a Config has no Skipper.
var DefaultConfig = Config{Skipper: middleware.DefaultSkipper}
// DataSource supplies the casbin subject for an incoming request.
type DataSource interface {
	// GetSubject returns the subject to authorize for the request held in c.
	GetSubject(c echo.Context) string
}
// Middleware returns an Echo middleware built from a copy of DefaultConfig,
// with ce as the policy enforcer and ds as the source of request subjects.
func Middleware(ce *casbin.Enforcer, ds DataSource) echo.MiddlewareFunc {
	cfg := DefaultConfig
	cfg.Enforcer, cfg.DataSource = ce, ds
	return MiddlewareWithConfig(cfg)
}
// MiddlewareWithConfig returns an Echo middleware that authorizes every
// request with config.Enforcer before handing it to the next handler.
//
// A nil config.Skipper is replaced by DefaultConfig.Skipper, and a nil
// config.GetURLPathFunc by a function returning the request's already-decoded
// URL path (Request.URL.Path). For each request the middleware:
//   - calls next directly when config.Skipper returns true;
//   - runs config.BeforeFunc, if set;
//   - calls config.HasPermission with the path from config.GetURLPathFunc;
//   - on an enforcer error, returns a 500 *echo.HTTPError with a generic
//     public message and the cause stored as its internal error;
//   - when access is granted, runs config.SuccessHandler (if set), then next;
//   - when access is denied, returns config.ErrorHandler(echo.ErrForbidden, c)
//     if set, or echo.ErrForbidden otherwise.
func MiddlewareWithConfig(config Config) echo.MiddlewareFunc {
	if config.Skipper == nil {
		config.Skipper = DefaultConfig.Skipper
	}
	if config.GetURLPathFunc == nil {
		// Request.URL.Path is already parsed and percent-decoded by net/http.
		// It must not be fed through url.Parse again: that decodes escapes a
		// second time (%252F -> /), cuts the path at a decoded '?' or '#', and
		// reads a leading "//" as a host, so the path being authorized could
		// differ from the path that was actually requested.
		config.GetURLPathFunc = func(c echo.Context) string {
			return requestPath(c.Request().URL)
		}
	}
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}
			if config.BeforeFunc != nil {
				config.BeforeFunc(c)
			}
			path := config.GetURLPathFunc(c)
			ok, err := config.HasPermission(c, path)
			if err != nil {
				// Keep the enforcer's error text out of the response body; it
				// remains reachable through the HTTPError's Internal field.
				return echo.NewHTTPError(http.StatusInternalServerError).SetInternal(err)
			}
			if !ok {
				if config.ErrorHandler != nil {
					return config.ErrorHandler(echo.ErrForbidden, c)
				}
				return echo.ErrForbidden
			}
			if config.SuccessHandler != nil {
				config.SuccessHandler(c)
			}
			return next(c)
		}
	}
}

// requestPath returns the decoded path of u, or "" when u is nil (for example
// a hand-built request that never went through net/http parsing).
func requestPath(u *url.URL) string {
	if u == nil {
		return ""
	}
	return u.Path
}
// GetSubject returns the casbin subject for the request in c, as reported by
// the configured DataSource.
func (a *Config) GetSubject(c echo.Context) string {
	source := a.DataSource
	return source.GetSubject(c)
}
// HasPermission asks the casbin enforcer whether the request's subject may
// access urlPath with the request's HTTP method. The arguments are passed to
// Enforce in the order (subject, urlPath, method). It returns true when access
// is granted, false when it is denied, and any error reported by the enforcer.
func (a *Config) HasPermission(c echo.Context, urlPath string) (bool, error) {
	subject := a.GetSubject(c)
	method := c.Request().Method
	return a.Enforcer.Enforce(subject, urlPath, method)
}