Skip to content

Commit

Permalink
chore: update middleware skipper (#887)
Browse files Browse the repository at this point in the history
* chore: update middleware skipper

* chore: update
  • Loading branch information
boojack committed Jan 1, 2023
1 parent 293f88e commit a797280
Show file tree
Hide file tree
Showing 13 changed files with 82 additions and 156 deletions.
20 changes: 0 additions & 20 deletions .github/workflows/backend-tests-default.yml

This file was deleted.

2 changes: 0 additions & 2 deletions .github/workflows/backend-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ on:
branches:
- main
- "release/*.*.*"
paths-ignore:
- "web/**"

jobs:
go-static-checks:
Expand Down
25 changes: 0 additions & 25 deletions .github/workflows/frontend-tests-default.yml

This file was deleted.

2 changes: 0 additions & 2 deletions .github/workflows/frontend-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ on:
branches:
- main
- "release/*.*.*"
paths:
- "web/**"

jobs:
eslint-checks:
Expand Down
4 changes: 2 additions & 2 deletions api/auth.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package api

type Signin struct {
type SignIn struct {
Username string `json:"username"`
Password string `json:"password"`
}

type Signup struct {
type SignUp struct {
Username string `json:"username"`
Password string `json:"password"`
Role Role `json:"role"`
Expand Down
65 changes: 19 additions & 46 deletions server/acl.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ import (

var (
userIDContextKey = "user-id"
sessionName = "memos_session"
)

func getUserIDContextKey() string {
return userIDContextKey
}

func setUserSession(ctx echo.Context, user *api.User) error {
sess, _ := session.Get("memos_session", ctx)
sess, _ := session.Get(sessionName, ctx)
sess.Options = &sessions.Options{
Path: "/",
MaxAge: 3600 * 24 * 30,
Expand All @@ -38,7 +39,7 @@ func setUserSession(ctx echo.Context, user *api.User) error {
}

func removeUserSession(ctx echo.Context) error {
sess, _ := session.Get("memos_session", ctx)
sess, _ := session.Get(sessionName, ctx)
sess.Options = &sessions.Options{
Path: "/",
MaxAge: 0,
Expand All @@ -57,61 +58,33 @@ func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
ctx := c.Request().Context()
path := c.Path()

// Skip auth.
if common.HasPrefixes(path, "/api/auth") {
if s.DefaultAuthSkipper(c) {
return next(c)
}

{
// If there is openId in query string and related user is found, then skip auth.
openID := c.QueryParam("openId")
if openID != "" {
userFind := &api.UserFind{
OpenID: &openID,
}
user, err := s.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user by open_id").SetInternal(err)
}
if user != nil {
// Stores userID into context.
c.Set(getUserIDContextKey(), user.ID)
return next(c)
}
sess, _ := session.Get(sessionName, c)
userIDValue := sess.Values[userIDContextKey]
if userIDValue != nil {
userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue))
userFind := &api.UserFind{
ID: &userID,
}
}

{
sess, _ := session.Get("memos_session", c)
userIDValue := sess.Values[userIDContextKey]
if userIDValue != nil {
userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue))
userFind := &api.UserFind{
ID: &userID,
}
user, err := s.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
}
if user != nil {
if user.RowStatus == api.Archived {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", user.Username))
}
c.Set(getUserIDContextKey(), userID)
user, err := s.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
}
if user != nil {
if user.RowStatus == api.Archived {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", user.Username))
}
c.Set(getUserIDContextKey(), userID)
}
}

if common.HasPrefixes(path, "/api/ping", "/api/status", "/api/user/:id", "/api/memo/all", "/api/memo/:memoId", "/api/memo/amount") && c.Request().Method == http.MethodGet {
if common.HasPrefixes(path, "/api/ping", "/api/status", "/api/user/:id", "/api/memo") && c.Request().Method == http.MethodGet {
return next(c)
}

if common.HasPrefixes(path, "/api/memo", "/api/tag", "/api/shortcut") && c.Request().Method == http.MethodGet {
if _, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil {
return next(c)
}
}

userID := c.Get(getUserIDContextKey())
if userID == nil {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
Expand Down
10 changes: 5 additions & 5 deletions server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
func (s *Server) registerAuthRoutes(g *echo.Group) {
g.POST("/auth/signin", func(c echo.Context) error {
ctx := c.Request().Context()
signin := &api.Signin{}
signin := &api.SignIn{}
if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
}
Expand Down Expand Up @@ -56,7 +56,7 @@ func (s *Server) registerAuthRoutes(g *echo.Group) {

g.POST("/auth/signup", func(c echo.Context) error {
ctx := c.Request().Context()
signup := &api.Signup{}
signup := &api.SignUp{}
if err := json.NewDecoder(c.Request().Body).Decode(signup); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signup request").SetInternal(err)
}
Expand Down Expand Up @@ -130,14 +130,14 @@ func (s *Server) registerAuthRoutes(g *echo.Group) {
return nil
})

g.POST("/auth/logout", func(c echo.Context) error {
g.POST("/auth/signout", func(c echo.Context) error {
ctx := c.Request().Context()
err := removeUserSession(c)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set logout session").SetInternal(err)
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set sign out session").SetInternal(err)
}
s.Collector.Collect(ctx, &metric.Metric{
Name: "user logout",
Name: "user signout",
})

return c.JSON(http.StatusOK, true)
Expand Down
45 changes: 40 additions & 5 deletions server/common.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,46 @@
package server

func composeResponse(data interface{}) interface{} {
type R struct {
Data interface{} `json:"data"`
}
import (
"github.com/labstack/echo/v4"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
)

type response struct {
Data interface{} `json:"data"`
}

return R{
func composeResponse(data interface{}) response {
return response{
Data: data,
}
}

func (server *Server) DefaultAuthSkipper(c echo.Context) bool {
ctx := c.Request().Context()
path := c.Path()

// Skip auth.
if common.HasPrefixes(path, "/api/auth") {
return true
}

// If there is openId in query string and related user is found, then skip auth.
openID := c.QueryParam("openId")
if openID != "" {
userFind := &api.UserFind{
OpenID: &openID,
}
user, err := server.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return false
}
if user != nil {
// Stores userID into context.
c.Set(getUserIDContextKey(), user.ID)
return true
}
}

return false
}
41 changes: 7 additions & 34 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"fmt"
"time"

"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store"

Expand Down Expand Up @@ -43,8 +41,12 @@ func NewServer(profile *profile.Profile) *Server {
`"status":${status},"error":"${error}"}` + "\n",
}))

e.Use(middleware.Gzip())

e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{
Skipper: s.OpenAPISkipper,
Skipper: func(c echo.Context) bool {
return s.DefaultAuthSkipper(c)
},
TokenLookup: "cookie:_csrf",
}))

Expand Down Expand Up @@ -92,35 +94,6 @@ func NewServer(profile *profile.Profile) *Server {
return s
}

func (server *Server) Run() error {
return server.e.Start(fmt.Sprintf(":%d", server.Profile.Port))
}

func (server *Server) OpenAPISkipper(c echo.Context) bool {
ctx := c.Request().Context()
path := c.Path()

// Skip auth.
if common.HasPrefixes(path, "/api/auth") {
return true
}

// If there is openId in query string and related user is found, then skip auth.
openID := c.QueryParam("openId")
if openID != "" {
userFind := &api.UserFind{
OpenID: &openID,
}
user, err := server.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return false
}
if user != nil {
// Stores userID into context.
c.Set(getUserIDContextKey(), user.ID)
return true
}
}

return false
func (s *Server) Run() error {
return s.e.Start(fmt.Sprintf(":%d", s.Profile.Port))
}
2 changes: 1 addition & 1 deletion server/shortcut.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
if !ok {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find shortcut")
}

shortcutFind := &api.ShortcutFind{
CreatorID: &userID,
}

list, err := s.Store.FindShortcutList(ctx, shortcutFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch shortcut list").SetInternal(err)
Expand Down
16 changes: 5 additions & 11 deletions server/tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"net/http"
"regexp"
"sort"
"strconv"

"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
Expand Down Expand Up @@ -49,19 +48,14 @@ func (s *Server) registerTagRoutes(g *echo.Group) {

g.GET("/tag", func(c echo.Context) error {
ctx := c.Request().Context()
tagFind := &api.TagFind{}
if userID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil {
tagFind.CreatorID = userID
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag")
}

if tagFind.CreatorID == 0 {
currentUserID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag")
}
tagFind.CreatorID = currentUserID
tagFind := &api.TagFind{
CreatorID: userID,
}

tagList, err := s.Store.FindTagList(ctx, tagFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find tag list").SetInternal(err)
Expand Down
2 changes: 1 addition & 1 deletion web/src/helpers/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ export function signup(username: string, password: string, role: UserRole) {
}

export function signout() {
return axios.post("/api/auth/logout");
return axios.post("/api/auth/signout");
}

export function createUser(userCreate: UserCreate) {
Expand Down
Loading

0 comments on commit a797280

Please sign in to comment.