Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: update middleware skipper #887

Merged
merged 2 commits into from
Jan 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions .github/workflows/backend-tests-default.yml

This file was deleted.

2 changes: 0 additions & 2 deletions .github/workflows/backend-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ on:
branches:
- main
- "release/*.*.*"
paths-ignore:
- "web/**"

jobs:
go-static-checks:
Expand Down
25 changes: 0 additions & 25 deletions .github/workflows/frontend-tests-default.yml

This file was deleted.

2 changes: 0 additions & 2 deletions .github/workflows/frontend-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ on:
branches:
- main
- "release/*.*.*"
paths:
- "web/**"

jobs:
eslint-checks:
Expand Down
4 changes: 2 additions & 2 deletions api/auth.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package api

type Signin struct {
type SignIn struct {
Username string `json:"username"`
Password string `json:"password"`
}

type Signup struct {
type SignUp struct {
Username string `json:"username"`
Password string `json:"password"`
Role Role `json:"role"`
Expand Down
65 changes: 19 additions & 46 deletions server/acl.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ import (

var (
userIDContextKey = "user-id"
sessionName = "memos_session"
)

func getUserIDContextKey() string {
return userIDContextKey
}

func setUserSession(ctx echo.Context, user *api.User) error {
sess, _ := session.Get("memos_session", ctx)
sess, _ := session.Get(sessionName, ctx)
sess.Options = &sessions.Options{
Path: "/",
MaxAge: 3600 * 24 * 30,
Expand All @@ -38,7 +39,7 @@ func setUserSession(ctx echo.Context, user *api.User) error {
}

func removeUserSession(ctx echo.Context) error {
sess, _ := session.Get("memos_session", ctx)
sess, _ := session.Get(sessionName, ctx)
sess.Options = &sessions.Options{
Path: "/",
MaxAge: 0,
Expand All @@ -57,61 +58,33 @@ func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
ctx := c.Request().Context()
path := c.Path()

// Skip auth.
if common.HasPrefixes(path, "/api/auth") {
if s.DefaultAuthSkipper(c) {
return next(c)
}

{
// If there is openId in query string and related user is found, then skip auth.
openID := c.QueryParam("openId")
if openID != "" {
userFind := &api.UserFind{
OpenID: &openID,
}
user, err := s.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user by open_id").SetInternal(err)
}
if user != nil {
// Stores userID into context.
c.Set(getUserIDContextKey(), user.ID)
return next(c)
}
sess, _ := session.Get(sessionName, c)
userIDValue := sess.Values[userIDContextKey]
if userIDValue != nil {
userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue))
userFind := &api.UserFind{
ID: &userID,
}
}

{
sess, _ := session.Get("memos_session", c)
userIDValue := sess.Values[userIDContextKey]
if userIDValue != nil {
userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue))
userFind := &api.UserFind{
ID: &userID,
}
user, err := s.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
}
if user != nil {
if user.RowStatus == api.Archived {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", user.Username))
}
c.Set(getUserIDContextKey(), userID)
user, err := s.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
}
if user != nil {
if user.RowStatus == api.Archived {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", user.Username))
}
c.Set(getUserIDContextKey(), userID)
}
}

if common.HasPrefixes(path, "/api/ping", "/api/status", "/api/user/:id", "/api/memo/all", "/api/memo/:memoId", "/api/memo/amount") && c.Request().Method == http.MethodGet {
if common.HasPrefixes(path, "/api/ping", "/api/status", "/api/user/:id", "/api/memo") && c.Request().Method == http.MethodGet {
return next(c)
}

if common.HasPrefixes(path, "/api/memo", "/api/tag", "/api/shortcut") && c.Request().Method == http.MethodGet {
if _, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil {
return next(c)
}
}

userID := c.Get(getUserIDContextKey())
if userID == nil {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
Expand Down
10 changes: 5 additions & 5 deletions server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
func (s *Server) registerAuthRoutes(g *echo.Group) {
g.POST("/auth/signin", func(c echo.Context) error {
ctx := c.Request().Context()
signin := &api.Signin{}
signin := &api.SignIn{}
if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
}
Expand Down Expand Up @@ -56,7 +56,7 @@ func (s *Server) registerAuthRoutes(g *echo.Group) {

g.POST("/auth/signup", func(c echo.Context) error {
ctx := c.Request().Context()
signup := &api.Signup{}
signup := &api.SignUp{}
if err := json.NewDecoder(c.Request().Body).Decode(signup); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signup request").SetInternal(err)
}
Expand Down Expand Up @@ -130,14 +130,14 @@ func (s *Server) registerAuthRoutes(g *echo.Group) {
return nil
})

g.POST("/auth/logout", func(c echo.Context) error {
g.POST("/auth/signout", func(c echo.Context) error {
ctx := c.Request().Context()
err := removeUserSession(c)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set logout session").SetInternal(err)
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set sign out session").SetInternal(err)
}
s.Collector.Collect(ctx, &metric.Metric{
Name: "user logout",
Name: "user signout",
})

return c.JSON(http.StatusOK, true)
Expand Down
45 changes: 40 additions & 5 deletions server/common.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,46 @@
package server

func composeResponse(data interface{}) interface{} {
type R struct {
Data interface{} `json:"data"`
}
import (
"github.com/labstack/echo/v4"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
)

type response struct {
Data interface{} `json:"data"`
}

return R{
func composeResponse(data interface{}) response {
return response{
Data: data,
}
}

func (server *Server) DefaultAuthSkipper(c echo.Context) bool {
ctx := c.Request().Context()
path := c.Path()

// Skip auth.
if common.HasPrefixes(path, "/api/auth") {
return true
}

// If there is openId in query string and related user is found, then skip auth.
openID := c.QueryParam("openId")
if openID != "" {
userFind := &api.UserFind{
OpenID: &openID,
}
user, err := server.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return false
}
if user != nil {
// Stores userID into context.
c.Set(getUserIDContextKey(), user.ID)
return true
}
}

return false
}
41 changes: 7 additions & 34 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"fmt"
"time"

"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store"

Expand Down Expand Up @@ -43,8 +41,12 @@ func NewServer(profile *profile.Profile) *Server {
`"status":${status},"error":"${error}"}` + "\n",
}))

e.Use(middleware.Gzip())

e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{
Skipper: s.OpenAPISkipper,
Skipper: func(c echo.Context) bool {
return s.DefaultAuthSkipper(c)
},
TokenLookup: "cookie:_csrf",
}))

Expand Down Expand Up @@ -92,35 +94,6 @@ func NewServer(profile *profile.Profile) *Server {
return s
}

func (server *Server) Run() error {
return server.e.Start(fmt.Sprintf(":%d", server.Profile.Port))
}

func (server *Server) OpenAPISkipper(c echo.Context) bool {
ctx := c.Request().Context()
path := c.Path()

// Skip auth.
if common.HasPrefixes(path, "/api/auth") {
return true
}

// If there is openId in query string and related user is found, then skip auth.
openID := c.QueryParam("openId")
if openID != "" {
userFind := &api.UserFind{
OpenID: &openID,
}
user, err := server.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return false
}
if user != nil {
// Stores userID into context.
c.Set(getUserIDContextKey(), user.ID)
return true
}
}

return false
func (s *Server) Run() error {
return s.e.Start(fmt.Sprintf(":%d", s.Profile.Port))
}
2 changes: 1 addition & 1 deletion server/shortcut.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
if !ok {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find shortcut")
}

shortcutFind := &api.ShortcutFind{
CreatorID: &userID,
}

list, err := s.Store.FindShortcutList(ctx, shortcutFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch shortcut list").SetInternal(err)
Expand Down
16 changes: 5 additions & 11 deletions server/tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"net/http"
"regexp"
"sort"
"strconv"

"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
Expand Down Expand Up @@ -49,19 +48,14 @@ func (s *Server) registerTagRoutes(g *echo.Group) {

g.GET("/tag", func(c echo.Context) error {
ctx := c.Request().Context()
tagFind := &api.TagFind{}
if userID, err := strconv.Atoi(c.QueryParam("creatorId")); err == nil {
tagFind.CreatorID = userID
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag")
}

if tagFind.CreatorID == 0 {
currentUserID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag")
}
tagFind.CreatorID = currentUserID
tagFind := &api.TagFind{
CreatorID: userID,
}

tagList, err := s.Store.FindTagList(ctx, tagFind)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find tag list").SetInternal(err)
Expand Down
2 changes: 1 addition & 1 deletion web/src/helpers/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ export function signup(username: string, password: string, role: UserRole) {
}

export function signout() {
return axios.post("/api/auth/logout");
return axios.post("/api/auth/signout");
}

export function createUser(userCreate: UserCreate) {
Expand Down
Loading