From a69682c6822b48ba68f9c6f89cb0e13512405996 Mon Sep 17 00:00:00 2001 From: yk Date: Fri, 26 Nov 2021 17:34:56 +0300 Subject: [PATCH 1/6] - Model - Repo --- go.mod | 43 +- internal/app/appauth/request.go | 109 -- internal/app/appauth/storage.go | 113 +- internal/app/auth_server.go | 4 +- internal/infrastructure/mongo/auth_request.go | 75 ++ internal/infrastructure/mongo/container.go | 1 + .../mongo/mongodoc/auth_request.go | 116 ++ internal/usecase/repo/auth_request.go | 16 + internal/usecase/repo/container.go | 1 + pkg/auth/builder.go | 102 ++ {internal/app/appauth => pkg/auth}/client.go | 4 +- pkg/auth/request.go | 138 +++ pkg/id/auth_request_gen.go | 297 +++++ pkg/id/auth_request_gen_test.go | 1011 +++++++++++++++++ pkg/id/gen.go | 2 + 15 files changed, 1821 insertions(+), 211 deletions(-) delete mode 100644 internal/app/appauth/request.go create mode 100644 internal/infrastructure/mongo/auth_request.go create mode 100644 internal/infrastructure/mongo/mongodoc/auth_request.go create mode 100644 internal/usecase/repo/auth_request.go create mode 100644 pkg/auth/builder.go rename {internal/app/appauth => pkg/auth}/client.go (97%) create mode 100644 pkg/auth/request.go create mode 100644 pkg/id/auth_request_gen.go create mode 100644 pkg/id/auth_request_gen_test.go diff --git a/go.mod b/go.mod index 00f51985..bf2d714c 100644 --- a/go.mod +++ b/go.mod @@ -69,47 +69,6 @@ require ( gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) -require ( - github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect - github.com/aws/aws-sdk-go v1.34.28 // indirect - github.com/caos/logging v0.0.2 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/felixge/httpsnoop v1.0.1 // indirect - github.com/go-stack/stack v1.8.0 // indirect - github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect - github.com/golang/protobuf v1.5.2 // indirect - github.com/golang/snappy v0.0.3 // indirect - github.com/google/go-cmp v0.5.6 // indirect - github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9 // indirect - github.com/googleapis/gax-go/v2 v2.0.5 // indirect - github.com/gorilla/handlers v1.5.1 // indirect - github.com/gorilla/schema v1.2.0 // indirect - github.com/gorilla/securecookie v1.1.1 // indirect - github.com/gorilla/websocket v1.4.2 // indirect - github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect - github.com/jmespath/go-jmespath v0.4.0 // indirect - github.com/jstemmer/go-junit-report v0.9.1 // indirect - github.com/mattn/go-colorable v0.1.8 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/russross/blackfriday/v2 v2.1.0 // indirect - github.com/sendgrid/rest v2.6.5+incompatible // indirect - github.com/valyala/bytebufferpool v1.0.0 // indirect - github.com/valyala/fasttemplate v1.2.1 // indirect - github.com/xdg-go/pbkdf2 v1.0.0 // indirect - github.com/xdg-go/scram v1.0.2 // indirect - github.com/xdg-go/stringprep v1.0.2 // indirect - github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect - go.opencensus.io v0.23.0 // indirect - golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 // indirect - golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420 // indirect - golang.org/x/oauth2 v0.0.0-20210628180205-a41e5a781914 // indirect - golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect - golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect - google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20210716133855-ce7ef5c701ea // indirect - google.golang.org/grpc v1.39.0 // indirect - google.golang.org/protobuf v1.27.1 // indirect - gopkg.in/alecthomas/kingpin.v2 v2.2.6 // indirect -) +require github.com/sendgrid/rest v2.6.5+incompatible // indirect go 1.17 diff --git a/internal/app/appauth/request.go b/internal/app/appauth/request.go deleted file mode 100644 index 261eb137..00000000 --- a/internal/app/appauth/request.go +++ /dev/null @@ -1,109 +0,0 @@ -package appauth - -import ( - "time" - - "github.com/caos/oidc/pkg/oidc" -) - -var essentialScopes = []string{"openid", "profile", "email"} - -type AuthRequest struct { - ID string - ClientID string - subject string - code string - state string - ResponseType oidc.ResponseType - scopes []string - audiences []string - RedirectURI string - Nonce string - CodeChallenge *oidc.CodeChallenge - createdAt time.Time - authorizedAt *time.Time -} - -func (a *AuthRequest) GetID() string { - return a.ID -} - -func (a *AuthRequest) GetACR() string { - return "" -} - -func (a *AuthRequest) GetAMR() []string { - return []string{ - "password", - } -} - -func (a *AuthRequest) GetAudience() []string { - return a.audiences -} - -func (a *AuthRequest) GetAuthTime() time.Time { - return a.createdAt -} - -func (a *AuthRequest) GetClientID() string { - return a.ClientID -} - -func (a *AuthRequest) GetCode() string { - return a.code -} - -func (a *AuthRequest) GetState() string { - return a.state -} - -func (a *AuthRequest) GetCodeChallenge() *oidc.CodeChallenge { - return a.CodeChallenge -} - -func (a *AuthRequest) GetNonce() string { - return a.Nonce -} - -func (a *AuthRequest) GetRedirectURI() string { - return a.RedirectURI -} - -func (a *AuthRequest) GetResponseType() oidc.ResponseType { - return a.ResponseType -} - -func (a *AuthRequest) GetScopes() []string { - return unique(append(a.scopes, essentialScopes...)) -} - -func (a *AuthRequest) SetCurrentScopes(scopes []string) { - a.scopes = unique(append(scopes, essentialScopes...)) -} - -func (a *AuthRequest) GetSubject() string { - return a.subject // return "auth0|60acc23af5de37006a5d8229" -} - -func (a *AuthRequest) Done() bool { - return a.authorizedAt != nil -} - -func (a *AuthRequest) Complete(sub string) { - a.subject = sub - now := time.Now() - a.authorizedAt = &now -} - -func unique(list []string) []string { - allKeys := make(map[string]struct{}) - var uniqueList []string - for _, item := range list { - if _, ok := allKeys[item]; !ok { - allKeys[item] = struct{}{} - uniqueList = append(uniqueList, item) - } - } - return uniqueList -} diff --git a/internal/app/appauth/storage.go b/internal/app/appauth/storage.go index b0cc12c7..a4f6d747 100644 --- a/internal/app/appauth/storage.go +++ b/internal/app/appauth/storage.go @@ -8,13 +8,14 @@ import ( "crypto/x509/pkix" "errors" "math/big" - mrand "math/rand" "sync" "time" "github.com/caos/oidc/pkg/oidc" "github.com/caos/oidc/pkg/op" - "github.com/oklog/ulid" + "github.com/reearth/reearth-backend/internal/usecase/repo" + "github.com/reearth/reearth-backend/pkg/auth" + "github.com/reearth/reearth-backend/pkg/id" "github.com/reearth/reearth-backend/pkg/user" "gopkg.in/square/go-jose.v2" ) @@ -24,7 +25,7 @@ type Storage struct { appConfig *StorageConfig getUserBySubject func(context.Context, string) (*user.User, error) clients map[string]op.Client - requests map[string]AuthRequest + requests repo.AuthRequest keySet jose.JSONWebKeySet key *rsa.PrivateKey sigKey jose.SigningKey @@ -58,9 +59,9 @@ var dummyName = pkix.Name{ PostalCode: []string{"1"}, } -func NewAuthStorage(cfg *StorageConfig, getUserBySubject func(context.Context, string) (*user.User, error)) op.Storage { +func NewAuthStorage(cfg *StorageConfig, request repo.AuthRequest, getUserBySubject func(context.Context, string) (*user.User, error)) op.Storage { - client := initLocalClient(cfg.Debug) + client := auth.NewLocalClient(cfg.Debug) name := dummyName if cfg.DN != nil { @@ -81,7 +82,7 @@ func NewAuthStorage(cfg *StorageConfig, getUserBySubject func(context.Context, s return &Storage{ appConfig: cfg, getUserBySubject: getUserBySubject, - requests: make(map[string]AuthRequest), + requests: request, key: key, sigKey: sigKey, keySet: keySet, @@ -134,13 +135,10 @@ func (s *Storage) Health(_ context.Context) error { return nil } -func (s *Storage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthRequest, _ string) (op.AuthRequest, error) { +func (s *Storage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, _ string) (op.AuthRequest, error) { s.lock.Lock() defer s.lock.Unlock() - ti := time.Now() - entropy := ulid.Monotonic(mrand.New(mrand.NewSource(ti.UnixNano())), 0) - audiences := []string{ s.appConfig.Domain, } @@ -148,73 +146,70 @@ func (s *Storage) CreateAuthRequest(_ context.Context, authReq *oidc.AuthRequest audiences = append(audiences, "http://localhost:8080") } - request := &AuthRequest{ - ID: ulid.MustNew(ulid.Timestamp(ti), entropy).String(), - ClientID: authReq.ClientID, - subject: "", - code: "", // Will be set after /authorize/callback success - state: authReq.State, - scopes: authReq.Scopes, - audiences: audiences, - ResponseType: authReq.ResponseType, - Nonce: authReq.Nonce, - RedirectURI: authReq.RedirectURI, - createdAt: time.Now().UTC(), - authorizedAt: nil, - } + var cc *oidc.CodeChallenge if authReq.CodeChallenge != "" { - request.CodeChallenge = &oidc.CodeChallenge{ + cc = &oidc.CodeChallenge{ Challenge: authReq.CodeChallenge, Method: authReq.CodeChallengeMethod, } } - - s.requests[request.ID] = *request + var request = auth.New(). + NewID(). + ClientID(authReq.ClientID). + State(authReq.State). + ResponseType(authReq.ResponseType). + Scopes(authReq.Scopes). + Audiences(audiences). + RedirectURI(authReq.RedirectURI). + Nonce(authReq.Nonce). + CodeChallenge(cc). + CreatedAt(time.Now().UTC()). + AuthorizedAt(nil). + MustBuild() + + if err := s.requests.Save(ctx, request); err != nil { + return nil, err + } return request, nil } -func (s *Storage) AuthRequestByID(_ context.Context, requestID string) (op.AuthRequest, error) { +func (s *Storage) AuthRequestByID(ctx context.Context, requestID string) (op.AuthRequest, error) { s.lock.Lock() defer s.lock.Unlock() if requestID == "" { return nil, errors.New("invalid id") } - request, exists := s.requests[requestID] - if !exists { - return nil, errors.New("not found") + reqId, err := id.AuthRequestIDFrom(requestID) + if err != nil { + return nil, err + } + request, err := s.requests.FindByID(ctx, reqId) + if err != nil { + return nil, err } - return &request, nil + return request, nil } -func (s *Storage) AuthRequestByCode(_ context.Context, code string) (op.AuthRequest, error) { +func (s *Storage) AuthRequestByCode(ctx context.Context, code string) (op.AuthRequest, error) { s.lock.Lock() defer s.lock.Unlock() if code == "" { return nil, errors.New("invalid code") } - for _, request := range s.requests { - if request.GetCode() == code { - return &request, nil - } - } - return nil, errors.New("invalid code") + return s.requests.FindByCode(ctx, code) } -func (s *Storage) AuthRequestBySubject(_ context.Context, subject string) (op.AuthRequest, error) { +func (s *Storage) AuthRequestBySubject(ctx context.Context, subject string) (op.AuthRequest, error) { s.lock.Lock() defer s.lock.Unlock() if subject == "" { return nil, errors.New("invalid subject") } - for _, request := range s.requests { - if request.GetSubject() == subject { - return &request, nil - } - } - return nil, errors.New("invalid subject") + + return s.requests.FindBySubject(ctx, subject) } func (s *Storage) SaveAuthCode(ctx context.Context, requestID, code string) error { @@ -223,8 +218,8 @@ func (s *Storage) SaveAuthCode(ctx context.Context, requestID, code string) erro if err != nil { return err } - request2 := request.(*AuthRequest) - request2.code = code + request2 := request.(*auth.Request) + request2.SetCode(code) err = s.updateRequest(ctx, requestID, *request2) return err } @@ -242,8 +237,8 @@ func (s *Storage) CreateAccessToken(_ context.Context, _ op.TokenRequest) (strin } func (s *Storage) CreateAccessAndRefreshTokens(_ context.Context, request op.TokenRequest, _ string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) { - authReq := request.(*AuthRequest) - return "id", authReq.ID, time.Now().UTC().Add(5 * time.Minute), nil + authReq := request.(*auth.Request) + return "id", authReq.GetID(), time.Now().UTC().Add(5 * time.Minute), nil } func (s *Storage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) { @@ -338,25 +333,31 @@ func (s *Storage) CompleteAuthRequest(ctx context.Context, requestId, sub string if err != nil { return err } - req := request.(*AuthRequest) + req := request.(*auth.Request) req.Complete(sub) err = s.updateRequest(ctx, requestId, *req) return err } -func (s *Storage) updateRequest(_ context.Context, requestID string, req AuthRequest) error { +func (s *Storage) updateRequest(ctx context.Context, requestID string, req auth.Request) error { s.lock.Lock() defer s.lock.Unlock() if requestID == "" { return errors.New("invalid id") } - _, exists := s.requests[requestID] - if !exists { - return errors.New("not found") + reqId, err := id.AuthRequestIDFrom(requestID) + if err != nil { + return err + } + + if _, err := s.requests.FindByID(ctx, reqId); err != nil { + return err } - s.requests[requestID] = req + if err := s.requests.Save(ctx, &req); err != nil { + return err + } return nil } diff --git a/internal/app/auth_server.go b/internal/app/auth_server.go index 5231c708..2a926dbc 100644 --- a/internal/app/auth_server.go +++ b/internal/app/auth_server.go @@ -9,13 +9,12 @@ import ( "strings" "github.com/caos/oidc/pkg/op" + "github.com/golang/gddo/httputil/header" "github.com/gorilla/mux" "github.com/labstack/echo/v4" "github.com/reearth/reearth-backend/internal/app/appauth" "github.com/reearth/reearth-backend/internal/usecase/interactor" "github.com/reearth/reearth-backend/internal/usecase/interfaces" - - "github.com/golang/gddo/httputil/header" ) var ( @@ -60,6 +59,7 @@ func authEndPoints(ctx context.Context, e *echo.Echo, r *echo.Group, cfg *Server Debug: cfg.Debug, DN: dn, }, + cfg.Repos.AuthRequest, userUsecase.GetUserBySubject, ) handler, err := op.NewOpenIDProvider( diff --git a/internal/infrastructure/mongo/auth_request.go b/internal/infrastructure/mongo/auth_request.go new file mode 100644 index 00000000..3e1a36e4 --- /dev/null +++ b/internal/infrastructure/mongo/auth_request.go @@ -0,0 +1,75 @@ +package mongo + +import ( + "context" + + "github.com/reearth/reearth-backend/pkg/auth" + "go.mongodb.org/mongo-driver/bson" + + "github.com/reearth/reearth-backend/internal/infrastructure/mongo/mongodoc" + "github.com/reearth/reearth-backend/internal/usecase/repo" + "github.com/reearth/reearth-backend/pkg/id" + "github.com/reearth/reearth-backend/pkg/log" +) + +type authRequestRepo struct { + client *mongodoc.ClientCollection +} + +func NewAuthRequest(client *mongodoc.Client) repo.AuthRequest { + r := &authRequestRepo{client: client.WithCollection("authRequest")} + r.init() + return r +} + +func (r *authRequestRepo) init() { + i := r.client.CreateIndex(context.Background(), []string{"code", "subject"}) + if len(i) > 0 { + log.Infof("mongo: %s: index created: %s", "authRequest", i) + } +} + +func (r *authRequestRepo) FindByID(ctx context.Context, id2 id.AuthRequestID) (*auth.Request, error) { + filter := bson.D{{Key: "id", Value: id.ID(id2).String()}} + return r.findOne(ctx, filter) +} + +func (r *authRequestRepo) FindByCode(ctx context.Context, s string) (*auth.Request, error) { + filter := bson.D{{Key: "code", Value: s}} + return r.findOne(ctx, filter) +} + +func (r *authRequestRepo) FindBySubject(ctx context.Context, s string) (*auth.Request, error) { + filter := bson.D{{Key: "subject", Value: s}} + return r.findOne(ctx, filter) +} + +func (r *authRequestRepo) Save(ctx context.Context, request *auth.Request) error { + doc, id1 := mongodoc.NewAuthRequest(request) + return r.client.SaveOne(ctx, id1, doc) +} + +func (r *authRequestRepo) Remove(ctx context.Context, requestID id.AuthRequestID) error { + return r.client.RemoveOne(ctx, requestID.String()) +} + +func (r *authRequestRepo) find(ctx context.Context, dst []*auth.Request, filter bson.D) ([]*auth.Request, error) { + c := mongodoc.AuthRequestConsumer{ + Rows: dst, + } + if err := r.client.Find(ctx, filter, &c); err != nil { + return nil, err + } + return c.Rows, nil +} + +func (r *authRequestRepo) findOne(ctx context.Context, filter bson.D) (*auth.Request, error) { + dst := make([]*auth.Request, 0, 1) + c := mongodoc.AuthRequestConsumer{ + Rows: dst, + } + if err := r.client.FindOne(ctx, filter, &c); err != nil { + return nil, err + } + return c.Rows[0], nil +} diff --git a/internal/infrastructure/mongo/container.go b/internal/infrastructure/mongo/container.go index a112483d..f7d73eaf 100644 --- a/internal/infrastructure/mongo/container.go +++ b/internal/infrastructure/mongo/container.go @@ -16,6 +16,7 @@ func InitRepos(ctx context.Context, c *repo.Container, mc *mongo.Client, databas client := mongodoc.NewClient(databaseName, mc) c.Asset = NewAsset(client) + c.AuthRequest = NewAuthRequest(client) c.Config = NewConfig(client) c.DatasetSchema = NewDatasetSchema(client) c.Dataset = NewDataset(client) diff --git a/internal/infrastructure/mongo/mongodoc/auth_request.go b/internal/infrastructure/mongo/mongodoc/auth_request.go new file mode 100644 index 00000000..589e9ceb --- /dev/null +++ b/internal/infrastructure/mongo/mongodoc/auth_request.go @@ -0,0 +1,116 @@ +package mongodoc + +import ( + "time" + + "github.com/caos/oidc/pkg/oidc" + "github.com/reearth/reearth-backend/pkg/auth" + "github.com/reearth/reearth-backend/pkg/id" + "go.mongodb.org/mongo-driver/bson" +) + +type AuthRequestDocument struct { + ID string + ClientID string + Subject string + Code string + State string + ResponseType string + Scopes []string + Audiences []string + RedirectURI string + Nonce string + CodeChallenge *CodeChallengeDocument + CreatedAt time.Time + AuthorizedAt *time.Time +} + +type CodeChallengeDocument struct { + Challenge string + Method string +} + +type AuthRequestConsumer struct { + Rows []*auth.Request +} + +func (a *AuthRequestConsumer) Consume(raw bson.Raw) error { + if raw == nil { + return nil + } + + var doc AuthRequestDocument + if err := bson.Unmarshal(raw, &doc); err != nil { + return err + } + request, err := doc.Model() + if err != nil { + return err + } + a.Rows = append(a.Rows, request) + return nil +} + +func NewAuthRequest(req *auth.Request) (*AuthRequestDocument, string) { + if req == nil { + return nil, "" + } + reqID := req.GetID() + var cc *CodeChallengeDocument + if req.GetCodeChallenge() != nil { + cc = &CodeChallengeDocument{ + Challenge: req.GetCodeChallenge().Challenge, + Method: string(req.GetCodeChallenge().Method), + } + } + return &AuthRequestDocument{ + ID: reqID, + ClientID: req.GetClientID(), + Subject: req.GetSubject(), + Code: req.GetCode(), + State: req.GetState(), + ResponseType: string(req.GetResponseType()), + Scopes: req.GetScopes(), + Audiences: req.GetAudience(), + RedirectURI: req.GetRedirectURI(), + Nonce: req.GetNonce(), + CodeChallenge: cc, + CreatedAt: req.CreatedAt(), + AuthorizedAt: req.AuthorizedAt(), + }, reqID +} + +func (d *AuthRequestDocument) Model() (*auth.Request, error) { + if d == nil { + return nil, nil + } + + uuid, err := id.AuthRequestIDFrom(d.ID) + if err != nil { + return nil, err + } + + var cc *oidc.CodeChallenge + if d.CodeChallenge != nil { + cc = &oidc.CodeChallenge{ + Challenge: d.CodeChallenge.Challenge, + Method: oidc.CodeChallengeMethod(d.CodeChallenge.Method), + } + } + var req = auth.New(). + ID(uuid). + ClientID(d.ClientID). + Subject(d.Subject). + Code(d.Code). + State(d.State). + ResponseType(oidc.ResponseType(d.ResponseType)). + Scopes(d.Scopes). + Audiences(d.Audiences). + RedirectURI(d.RedirectURI). + Nonce(d.Nonce). + CodeChallenge(cc). + CreatedAt(d.CreatedAt). + AuthorizedAt(d.AuthorizedAt). + MustBuild() + return req, nil +} diff --git a/internal/usecase/repo/auth_request.go b/internal/usecase/repo/auth_request.go new file mode 100644 index 00000000..378926bb --- /dev/null +++ b/internal/usecase/repo/auth_request.go @@ -0,0 +1,16 @@ +package repo + +import ( + "context" + + "github.com/reearth/reearth-backend/pkg/auth" + "github.com/reearth/reearth-backend/pkg/id" +) + +type AuthRequest interface { + FindByID(context.Context, id.AuthRequestID) (*auth.Request, error) + FindByCode(context.Context, string) (*auth.Request, error) + FindBySubject(context.Context, string) (*auth.Request, error) + Save(context.Context, *auth.Request) error + Remove(context.Context, id.AuthRequestID) error +} diff --git a/internal/usecase/repo/container.go b/internal/usecase/repo/container.go index f329feb8..40c6c1f9 100644 --- a/internal/usecase/repo/container.go +++ b/internal/usecase/repo/container.go @@ -2,6 +2,7 @@ package repo type Container struct { Asset Asset + AuthRequest AuthRequest Config Config DatasetSchema DatasetSchema Dataset Dataset diff --git a/pkg/auth/builder.go b/pkg/auth/builder.go new file mode 100644 index 00000000..ed91727a --- /dev/null +++ b/pkg/auth/builder.go @@ -0,0 +1,102 @@ +package auth + +import ( + "time" + + "github.com/caos/oidc/pkg/oidc" + "github.com/reearth/reearth-backend/pkg/id" +) + +type Builder struct { + r *Request +} + +func New() *Builder { + return &Builder{r: &Request{}} +} + +func (b *Builder) Build() (*Request, error) { + if id.ID(b.r.id).IsNil() { + return nil, id.ErrInvalidID + } + b.r.createdAt = time.Now() + return b.r, nil +} + +func (b *Builder) MustBuild() *Request { + r, err := b.Build() + if err != nil { + panic(err) + } + return r +} + +func (b *Builder) ID(id id.AuthRequestID) *Builder { + b.r.id = id + return b +} + +func (b *Builder) NewID() *Builder { + b.r.id = id.AuthRequestID(id.New()) + return b +} + +func (b *Builder) ClientID(id string) *Builder { + b.r.clientID = id + return b +} + +func (b *Builder) Subject(subject string) *Builder { + b.r.subject = subject + return b +} + +func (b *Builder) Code(code string) *Builder { + b.r.code = code + return b +} + +func (b *Builder) State(state string) *Builder { + b.r.state = state + return b +} + +func (b *Builder) ResponseType(rt oidc.ResponseType) *Builder { + b.r.responseType = rt + return b +} + +func (b *Builder) Scopes(scopes []string) *Builder { + b.r.scopes = scopes + return b +} + +func (b *Builder) Audiences(audiences []string) *Builder { + b.r.audiences = audiences + return b +} + +func (b *Builder) RedirectURI(redirectURI string) *Builder { + b.r.redirectURI = redirectURI + return b +} + +func (b *Builder) Nonce(nonce string) *Builder { + b.r.nonce = nonce + return b +} + +func (b *Builder) CodeChallenge(CodeChallenge *oidc.CodeChallenge) *Builder { + b.r.codeChallenge = CodeChallenge + return b +} + +func (b *Builder) CreatedAt(createdAt time.Time) *Builder { + b.r.createdAt = createdAt + return b +} + +func (b *Builder) AuthorizedAt(authorizedAt *time.Time) *Builder { + b.r.authorizedAt = authorizedAt + return b +} diff --git a/internal/app/appauth/client.go b/pkg/auth/client.go similarity index 97% rename from internal/app/appauth/client.go rename to pkg/auth/client.go index 7c6a05a0..2c4b2bcb 100644 --- a/internal/app/appauth/client.go +++ b/pkg/auth/client.go @@ -1,4 +1,4 @@ -package appauth +package auth import ( "fmt" @@ -24,7 +24,7 @@ type ConfClient struct { devMode bool } -func initLocalClient(devMode bool) op.Client { +func NewLocalClient(devMode bool) op.Client { return &ConfClient{ ID: "01FH69GFQ4DFCXS5XD91JK4HZ1", applicationType: op.ApplicationTypeWeb, diff --git a/pkg/auth/request.go b/pkg/auth/request.go new file mode 100644 index 00000000..19047fa1 --- /dev/null +++ b/pkg/auth/request.go @@ -0,0 +1,138 @@ +package auth + +import ( + "time" + + "github.com/caos/oidc/pkg/oidc" + "github.com/reearth/reearth-backend/pkg/id" +) + +var essentialScopes = []string{"openid", "profile", "email"} + +type Request struct { + id id.AuthRequestID + clientID string + subject string + code string + state string + responseType oidc.ResponseType + scopes []string + audiences []string + redirectURI string + nonce string + codeChallenge *oidc.CodeChallenge + createdAt time.Time + authorizedAt *time.Time +} + +func (a *Request) ID() id.AuthRequestID { + return a.id +} + +func (a *Request) GetID() string { + return a.id.String() +} + +func (a *Request) GetACR() string { + return "" +} + +func (a *Request) GetAMR() []string { + return []string{ + "password", + } +} + +func (a *Request) GetAudience() []string { + if a.audiences == nil { + return make([]string, 0) + } + + return a.audiences +} + +func (a *Request) GetAuthTime() time.Time { + return a.createdAt +} + +func (a *Request) GetClientID() string { + return a.clientID +} + +func (a *Request) GetCode() string { + return a.code +} + +func (a *Request) GetState() string { + return a.state +} + +func (a *Request) GetCodeChallenge() *oidc.CodeChallenge { + return a.codeChallenge +} + +func (a *Request) GetNonce() string { + return a.nonce +} + +func (a *Request) GetRedirectURI() string { + return a.redirectURI +} + +func (a *Request) GetResponseType() oidc.ResponseType { + return a.responseType +} + +func (a *Request) GetScopes() []string { + return unique(append(a.scopes, essentialScopes...)) +} + +func (a *Request) SetCurrentScopes(scopes []string) { + a.scopes = unique(append(scopes, essentialScopes...)) +} + +func (a *Request) GetSubject() string { + return a.subject +} + +func (a *Request) CreatedAt() time.Time { + return a.createdAt +} + +func (a *Request) SetCreatedAt(createdAt time.Time) { + a.createdAt = createdAt +} + +func (a *Request) AuthorizedAt() *time.Time { + return a.authorizedAt +} + +func (a *Request) SetAuthorizedAt(authorizedAt *time.Time) { + a.authorizedAt = authorizedAt +} + +func (a *Request) Done() bool { + return a.authorizedAt != nil +} + +func (a *Request) Complete(sub string) { + a.subject = sub + now := time.Now() + a.authorizedAt = &now +} + +func (a *Request) SetCode(code string) { + a.code = code +} + +func unique(list []string) []string { + allKeys := make(map[string]struct{}) + var uniqueList []string + for _, item := range list { + if _, ok := allKeys[item]; !ok { + allKeys[item] = struct{}{} + uniqueList = append(uniqueList, item) + } + } + return uniqueList +} diff --git a/pkg/id/auth_request_gen.go b/pkg/id/auth_request_gen.go new file mode 100644 index 00000000..76a36140 --- /dev/null +++ b/pkg/id/auth_request_gen.go @@ -0,0 +1,297 @@ +// Code generated by gen, DO NOT EDIT. + +package id + +import "encoding/json" + +// AuthRequestID is an ID for AuthRequest. +type AuthRequestID ID + +// NewAuthRequestID generates a new AuthRequestId. +func NewAuthRequestID() AuthRequestID { + return AuthRequestID(New()) +} + +// AuthRequestIDFrom generates a new AuthRequestID from a string. +func AuthRequestIDFrom(i string) (nid AuthRequestID, err error) { + var did ID + did, err = FromID(i) + if err != nil { + return + } + nid = AuthRequestID(did) + return +} + +// MustAuthRequestID generates a new AuthRequestID from a string, but panics if the string cannot be parsed. +func MustAuthRequestID(i string) AuthRequestID { + did, err := FromID(i) + if err != nil { + panic(err) + } + return AuthRequestID(did) +} + +// AuthRequestIDFromRef generates a new AuthRequestID from a string ref. +func AuthRequestIDFromRef(i *string) *AuthRequestID { + did := FromIDRef(i) + if did == nil { + return nil + } + nid := AuthRequestID(*did) + return &nid +} + +// AuthRequestIDFromRefID generates a new AuthRequestID from a ref of a generic ID. +func AuthRequestIDFromRefID(i *ID) *AuthRequestID { + if i == nil { + return nil + } + nid := AuthRequestID(*i) + return &nid +} + +// ID returns a domain ID. +func (d AuthRequestID) ID() ID { + return ID(d) +} + +// String returns a string representation. +func (d AuthRequestID) String() string { + return ID(d).String() +} + +// GoString implements fmt.GoStringer interface. +func (d AuthRequestID) GoString() string { + return "id.AuthRequestID(" + d.String() + ")" +} + +// RefString returns a reference of string representation. +func (d AuthRequestID) RefString() *string { + id := ID(d).String() + return &id +} + +// Ref returns a reference. +func (d AuthRequestID) Ref() *AuthRequestID { + d2 := d + return &d2 +} + +// Contains returns whether the id is contained in the slice. +func (d AuthRequestID) Contains(ids []AuthRequestID) bool { + for _, i := range ids { + if d.ID().Equal(i.ID()) { + return true + } + } + return false +} + +// CopyRef returns a copy of a reference. +func (d *AuthRequestID) CopyRef() *AuthRequestID { + if d == nil { + return nil + } + d2 := *d + return &d2 +} + +// IDRef returns a reference of a domain id. +func (d *AuthRequestID) IDRef() *ID { + if d == nil { + return nil + } + id := ID(*d) + return &id +} + +// StringRef returns a reference of a string representation. +func (d *AuthRequestID) StringRef() *string { + if d == nil { + return nil + } + id := ID(*d).String() + return &id +} + +// MarhsalJSON implements json.Marhsaler interface +func (d *AuthRequestID) MarhsalJSON() ([]byte, error) { + return json.Marshal(d.String()) +} + +// UnmarhsalJSON implements json.Unmarshaler interface +func (d *AuthRequestID) UnmarhsalJSON(bs []byte) (err error) { + var idstr string + if err = json.Unmarshal(bs, &idstr); err != nil { + return + } + *d, err = AuthRequestIDFrom(idstr) + return +} + +// MarshalText implements encoding.TextMarshaler interface +func (d *AuthRequestID) MarshalText() ([]byte, error) { + if d == nil { + return nil, nil + } + return []byte(d.String()), nil +} + +// UnmarshalText implements encoding.TextUnmarshaler interface +func (d *AuthRequestID) UnmarshalText(text []byte) (err error) { + *d, err = AuthRequestIDFrom(string(text)) + return +} + +// Ref returns true if a ID is nil or zero-value +func (d AuthRequestID) IsNil() bool { + return ID(d).IsNil() +} + +// AuthRequestIDToKeys converts IDs into a string slice. +func AuthRequestIDToKeys(ids []AuthRequestID) []string { + keys := make([]string, 0, len(ids)) + for _, i := range ids { + keys = append(keys, i.String()) + } + return keys +} + +// AuthRequestIDsFrom converts a string slice into a ID slice. +func AuthRequestIDsFrom(ids []string) ([]AuthRequestID, error) { + dids := make([]AuthRequestID, 0, len(ids)) + for _, i := range ids { + did, err := AuthRequestIDFrom(i) + if err != nil { + return nil, err + } + dids = append(dids, did) + } + return dids, nil +} + +// AuthRequestIDsFromID converts a generic ID slice into a ID slice. +func AuthRequestIDsFromID(ids []ID) []AuthRequestID { + dids := make([]AuthRequestID, 0, len(ids)) + for _, i := range ids { + dids = append(dids, AuthRequestID(i)) + } + return dids +} + +// AuthRequestIDsFromIDRef converts a ref of a generic ID slice into a ID slice. +func AuthRequestIDsFromIDRef(ids []*ID) []AuthRequestID { + dids := make([]AuthRequestID, 0, len(ids)) + for _, i := range ids { + if i != nil { + dids = append(dids, AuthRequestID(*i)) + } + } + return dids +} + +// AuthRequestIDsToID converts a ID slice into a generic ID slice. +func AuthRequestIDsToID(ids []AuthRequestID) []ID { + dids := make([]ID, 0, len(ids)) + for _, i := range ids { + dids = append(dids, i.ID()) + } + return dids +} + +// AuthRequestIDsToIDRef converts a ID ref slice into a generic ID ref slice. +func AuthRequestIDsToIDRef(ids []*AuthRequestID) []*ID { + dids := make([]*ID, 0, len(ids)) + for _, i := range ids { + dids = append(dids, i.IDRef()) + } + return dids +} + +// AuthRequestIDSet represents a set of AuthRequestIDs +type AuthRequestIDSet struct { + m map[AuthRequestID]struct{} + s []AuthRequestID +} + +// NewAuthRequestIDSet creates a new AuthRequestIDSet +func NewAuthRequestIDSet() *AuthRequestIDSet { + return &AuthRequestIDSet{} +} + +// Add adds a new ID if it does not exists in the set +func (s *AuthRequestIDSet) Add(p ...AuthRequestID) { + if s == nil || p == nil { + return + } + if s.m == nil { + s.m = map[AuthRequestID]struct{}{} + } + for _, i := range p { + if _, ok := s.m[i]; !ok { + if s.s == nil { + s.s = []AuthRequestID{} + } + s.m[i] = struct{}{} + s.s = append(s.s, i) + } + } +} + +// AddRef adds a new ID ref if it does not exists in the set +func (s *AuthRequestIDSet) AddRef(p *AuthRequestID) { + if s == nil || p == nil { + return + } + s.Add(*p) +} + +// Has checks if the ID exists in the set +func (s *AuthRequestIDSet) Has(p AuthRequestID) bool { + if s == nil || s.m == nil { + return false + } + _, ok := s.m[p] + return ok +} + +// Clear clears all stored IDs +func (s *AuthRequestIDSet) Clear() { + if s == nil { + return + } + s.m = nil + s.s = nil +} + +// All returns stored all IDs as a slice +func (s *AuthRequestIDSet) All() []AuthRequestID { + if s == nil { + return nil + } + return append([]AuthRequestID{}, s.s...) +} + +// Clone returns a cloned set +func (s *AuthRequestIDSet) Clone() *AuthRequestIDSet { + if s == nil { + return NewAuthRequestIDSet() + } + s2 := NewAuthRequestIDSet() + s2.Add(s.s...) + return s2 +} + +// Merge returns a merged set +func (s *AuthRequestIDSet) Merge(s2 *AuthRequestIDSet) *AuthRequestIDSet { + if s == nil { + return nil + } + s3 := s.Clone() + if s2 == nil { + return s3 + } + s3.Add(s2.s...) + return s3 +} diff --git a/pkg/id/auth_request_gen_test.go b/pkg/id/auth_request_gen_test.go new file mode 100644 index 00000000..5f84e759 --- /dev/null +++ b/pkg/id/auth_request_gen_test.go @@ -0,0 +1,1011 @@ +// Code generated by gen, DO NOT EDIT. + +package id + +import ( + "encoding/json" + "errors" + "testing" + + "github.com/oklog/ulid" + "github.com/stretchr/testify/assert" +) + +func TestNewAuthRequestID(t *testing.T) { + id := NewAuthRequestID() + assert.NotNil(t, id) + ulID, err := ulid.Parse(id.String()) + + assert.NotNil(t, ulID) + assert.Nil(t, err) +} + +func TestAuthRequestIDFrom(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + input string + expected struct { + result AuthRequestID + err error + } + }{ + { + name: "Fail:Not valid string", + input: "testMustFail", + expected: struct { + result AuthRequestID + err error + }{ + AuthRequestID{}, + ErrInvalidID, + }, + }, + { + name: "Fail:Not valid string", + input: "", + expected: struct { + result AuthRequestID + err error + }{ + AuthRequestID{}, + ErrInvalidID, + }, + }, + { + name: "success:valid string", + input: "01f2r7kg1fvvffp0gmexgy5hxy", + expected: struct { + result AuthRequestID + err error + }{ + AuthRequestID{ulid.MustParse("01f2r7kg1fvvffp0gmexgy5hxy")}, + nil, + }, + }, + } + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + result, err := AuthRequestIDFrom(tc.input) + assert.Equal(tt, tc.expected.result, result) + if err != nil { + assert.True(tt, errors.As(tc.expected.err, &err)) + } + }) + } +} + +func TestMustAuthRequestID(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + input string + shouldPanic bool + expected AuthRequestID + }{ + { + name: "Fail:Not valid string", + input: "testMustFail", + shouldPanic: true, + }, + { + name: "Fail:Not valid string", + input: "", + shouldPanic: true, + }, + { + name: "success:valid string", + input: "01f2r7kg1fvvffp0gmexgy5hxy", + shouldPanic: false, + expected: AuthRequestID{ulid.MustParse("01f2r7kg1fvvffp0gmexgy5hxy")}, + }, + } + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + + if tc.shouldPanic { + assert.Panics(tt, func() { MustBeID(tc.input) }) + return + } + result := MustAuthRequestID(tc.input) + assert.Equal(tt, tc.expected, result) + }) + } +} + +func TestAuthRequestIDFromRef(t *testing.T) { + testCases := []struct { + name string + input string + expected *AuthRequestID + }{ + { + name: "Fail:Not valid string", + input: "testMustFail", + expected: nil, + }, + { + name: "Fail:Not valid string", + input: "", + expected: nil, + }, + { + name: "success:valid string", + input: "01f2r7kg1fvvffp0gmexgy5hxy", + expected: &AuthRequestID{ulid.MustParse("01f2r7kg1fvvffp0gmexgy5hxy")}, + }, + } + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + result := AuthRequestIDFromRef(&tc.input) + assert.Equal(tt, tc.expected, result) + if tc.expected != nil { + assert.Equal(tt, *tc.expected, *result) + } + }) + } +} + +func TestAuthRequestIDFromRefID(t *testing.T) { + id := New() + + subId := AuthRequestIDFromRefID(&id) + + assert.NotNil(t, subId) + assert.Equal(t, subId.id, id.id) +} + +func TestAuthRequestID_ID(t *testing.T) { + id := New() + subId := AuthRequestIDFromRefID(&id) + + idOrg := subId.ID() + + assert.Equal(t, id, idOrg) +} + +func TestAuthRequestID_String(t *testing.T) { + id := New() + subId := AuthRequestIDFromRefID(&id) + + assert.Equal(t, subId.String(), id.String()) +} + +func TestAuthRequestID_GoString(t *testing.T) { + id := New() + subId := AuthRequestIDFromRefID(&id) + + assert.Equal(t, subId.GoString(), "id.AuthRequestID("+id.String()+")") +} + +func TestAuthRequestID_RefString(t *testing.T) { + id := New() + subId := AuthRequestIDFromRefID(&id) + + refString := subId.StringRef() + + assert.NotNil(t, refString) + assert.Equal(t, *refString, id.String()) +} + +func TestAuthRequestID_Ref(t *testing.T) { + id := New() + subId := AuthRequestIDFromRefID(&id) + + subIdRef := subId.Ref() + + assert.Equal(t, *subId, *subIdRef) +} + +func TestAuthRequestID_Contains(t *testing.T) { + id := NewAuthRequestID() + id2 := NewAuthRequestID() + assert.True(t, id.Contains([]AuthRequestID{id, id2})) + assert.False(t, id.Contains([]AuthRequestID{id2})) +} + +func TestAuthRequestID_CopyRef(t *testing.T) { + id := New() + subId := AuthRequestIDFromRefID(&id) + + subIdCopyRef := subId.CopyRef() + + assert.Equal(t, *subId, *subIdCopyRef) + assert.NotSame(t, subId, subIdCopyRef) +} + +func TestAuthRequestID_IDRef(t *testing.T) { + id := New() + subId := AuthRequestIDFromRefID(&id) + + assert.Equal(t, id, *subId.IDRef()) +} + +func TestAuthRequestID_StringRef(t *testing.T) { + id := New() + subId := AuthRequestIDFromRefID(&id) + + assert.Equal(t, *subId.StringRef(), id.String()) +} + +func TestAuthRequestID_MarhsalJSON(t *testing.T) { + id := New() + subId := AuthRequestIDFromRefID(&id) + + res, err := subId.MarhsalJSON() + exp, _ := json.Marshal(subId.String()) + + assert.Nil(t, err) + assert.Equal(t, exp, res) +} + +func TestAuthRequestID_UnmarhsalJSON(t *testing.T) { + jsonString := "\"01f3zhkysvcxsnzepyyqtq21fb\"" + + subId := &AuthRequestID{} + + err := subId.UnmarhsalJSON([]byte(jsonString)) + + assert.Nil(t, err) + assert.Equal(t, "01f3zhkysvcxsnzepyyqtq21fb", subId.String()) +} + +func TestAuthRequestID_MarshalText(t *testing.T) { + id := New() + subId := AuthRequestIDFromRefID(&id) + + res, err := subId.MarshalText() + + assert.Nil(t, err) + assert.Equal(t, []byte(id.String()), res) +} + +func TestAuthRequestID_UnmarshalText(t *testing.T) { + text := []byte("01f3zhcaq35403zdjnd6dcm0t2") + + subId := &AuthRequestID{} + + err := subId.UnmarshalText(text) + + assert.Nil(t, err) + assert.Equal(t, "01f3zhcaq35403zdjnd6dcm0t2", subId.String()) + +} + +func TestAuthRequestID_IsNil(t *testing.T) { + subId := AuthRequestID{} + + assert.True(t, subId.IsNil()) + + id := New() + subId = *AuthRequestIDFromRefID(&id) + + assert.False(t, subId.IsNil()) +} + +func TestAuthRequestIDToKeys(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + input []AuthRequestID + expected []string + }{ + { + name: "Empty slice", + input: make([]AuthRequestID, 0), + expected: make([]string, 0), + }, + { + name: "1 element", + input: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2")}, + expected: []string{"01f3zhcaq35403zdjnd6dcm0t2"}, + }, + { + name: "multiple elements", + input: []AuthRequestID{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"), + }, + expected: []string{ + "01f3zhcaq35403zdjnd6dcm0t1", + "01f3zhcaq35403zdjnd6dcm0t2", + "01f3zhcaq35403zdjnd6dcm0t3", + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + assert.Equal(tt, tc.expected, AuthRequestIDToKeys(tc.input)) + }) + } + +} + +func TestAuthRequestIDsFrom(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + input []string + expected struct { + res []AuthRequestID + err error + } + }{ + { + name: "Empty slice", + input: make([]string, 0), + expected: struct { + res []AuthRequestID + err error + }{ + res: make([]AuthRequestID, 0), + err: nil, + }, + }, + { + name: "1 element", + input: []string{"01f3zhcaq35403zdjnd6dcm0t2"}, + expected: struct { + res []AuthRequestID + err error + }{ + res: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2")}, + err: nil, + }, + }, + { + name: "multiple elements", + input: []string{ + "01f3zhcaq35403zdjnd6dcm0t1", + "01f3zhcaq35403zdjnd6dcm0t2", + "01f3zhcaq35403zdjnd6dcm0t3", + }, + expected: struct { + res []AuthRequestID + err error + }{ + res: []AuthRequestID{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"), + }, + err: nil, + }, + }, + { + name: "multiple elements", + input: []string{ + "01f3zhcaq35403zdjnd6dcm0t1", + "01f3zhcaq35403zdjnd6dcm0t2", + "01f3zhcaq35403zdjnd6dcm0t3", + }, + expected: struct { + res []AuthRequestID + err error + }{ + res: nil, + err: ErrInvalidID, + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + + if tc.expected.err != nil { + _, err := AuthRequestIDsFrom(tc.input) + assert.True(tt, errors.As(ErrInvalidID, &err)) + } else { + res, err := AuthRequestIDsFrom(tc.input) + assert.Equal(tt, tc.expected.res, res) + assert.Nil(tt, err) + } + + }) + } +} + +func TestAuthRequestIDsFromID(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + input []ID + expected []AuthRequestID + }{ + { + name: "Empty slice", + input: make([]ID, 0), + expected: make([]AuthRequestID, 0), + }, + { + name: "1 element", + input: []ID{MustBeID("01f3zhcaq35403zdjnd6dcm0t2")}, + expected: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2")}, + }, + { + name: "multiple elements", + input: []ID{ + MustBeID("01f3zhcaq35403zdjnd6dcm0t1"), + MustBeID("01f3zhcaq35403zdjnd6dcm0t2"), + MustBeID("01f3zhcaq35403zdjnd6dcm0t3"), + }, + expected: []AuthRequestID{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"), + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + + res := AuthRequestIDsFromID(tc.input) + assert.Equal(tt, tc.expected, res) + }) + } +} + +func TestAuthRequestIDsFromIDRef(t *testing.T) { + t.Parallel() + + id1 := MustBeID("01f3zhcaq35403zdjnd6dcm0t1") + id2 := MustBeID("01f3zhcaq35403zdjnd6dcm0t2") + id3 := MustBeID("01f3zhcaq35403zdjnd6dcm0t3") + + testCases := []struct { + name string + input []*ID + expected []AuthRequestID + }{ + { + name: "Empty slice", + input: make([]*ID, 0), + expected: make([]AuthRequestID, 0), + }, + { + name: "1 element", + input: []*ID{&id1}, + expected: []AuthRequestID{MustAuthRequestID(id1.String())}, + }, + { + name: "multiple elements", + input: []*ID{&id1, &id2, &id3}, + expected: []AuthRequestID{ + MustAuthRequestID(id1.String()), + MustAuthRequestID(id2.String()), + MustAuthRequestID(id3.String()), + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + + res := AuthRequestIDsFromIDRef(tc.input) + assert.Equal(tt, tc.expected, res) + }) + } +} + +func TestAuthRequestIDsToID(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + input []AuthRequestID + expected []ID + }{ + { + name: "Empty slice", + input: make([]AuthRequestID, 0), + expected: make([]ID, 0), + }, + { + name: "1 element", + input: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2")}, + expected: []ID{MustBeID("01f3zhcaq35403zdjnd6dcm0t2")}, + }, + { + name: "multiple elements", + input: []AuthRequestID{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"), + }, + expected: []ID{ + MustBeID("01f3zhcaq35403zdjnd6dcm0t1"), + MustBeID("01f3zhcaq35403zdjnd6dcm0t2"), + MustBeID("01f3zhcaq35403zdjnd6dcm0t3"), + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + + res := AuthRequestIDsToID(tc.input) + assert.Equal(tt, tc.expected, res) + }) + } +} + +func TestAuthRequestIDsToIDRef(t *testing.T) { + t.Parallel() + + id1 := MustBeID("01f3zhcaq35403zdjnd6dcm0t1") + subId1 := MustAuthRequestID(id1.String()) + id2 := MustBeID("01f3zhcaq35403zdjnd6dcm0t2") + subId2 := MustAuthRequestID(id2.String()) + id3 := MustBeID("01f3zhcaq35403zdjnd6dcm0t3") + subId3 := MustAuthRequestID(id3.String()) + + testCases := []struct { + name string + input []*AuthRequestID + expected []*ID + }{ + { + name: "Empty slice", + input: make([]*AuthRequestID, 0), + expected: make([]*ID, 0), + }, + { + name: "1 element", + input: []*AuthRequestID{&subId1}, + expected: []*ID{&id1}, + }, + { + name: "multiple elements", + input: []*AuthRequestID{&subId1, &subId2, &subId3}, + expected: []*ID{&id1, &id2, &id3}, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + + res := AuthRequestIDsToIDRef(tc.input) + assert.Equal(tt, tc.expected, res) + }) + } +} + +func TestNewAuthRequestIDSet(t *testing.T) { + AuthRequestIdSet := NewAuthRequestIDSet() + + assert.NotNil(t, AuthRequestIdSet) + assert.Empty(t, AuthRequestIdSet.m) + assert.Empty(t, AuthRequestIdSet.s) +} + +func TestAuthRequestIDSet_Add(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + input []AuthRequestID + expected *AuthRequestIDSet + }{ + { + name: "Empty slice", + input: make([]AuthRequestID, 0), + expected: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{}, + s: nil, + }, + }, + { + name: "1 element", + input: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1")}, + expected: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}}, + s: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1")}, + }, + }, + { + name: "multiple elements", + input: []AuthRequestID{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"), + }, + expected: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}, + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"): {}, + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"): {}, + }, + s: []AuthRequestID{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"), + }, + }, + }, + { + name: "multiple elements with duplication", + input: []AuthRequestID{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"), + }, + expected: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}, + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"): {}, + }, + s: []AuthRequestID{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"), + }, + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + + set := NewAuthRequestIDSet() + set.Add(tc.input...) + assert.Equal(tt, tc.expected, set) + }) + } +} + +func TestAuthRequestIDSet_AddRef(t *testing.T) { + t.Parallel() + + AuthRequestId := MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1") + + testCases := []struct { + name string + input *AuthRequestID + expected *AuthRequestIDSet + }{ + { + name: "Empty slice", + input: nil, + expected: &AuthRequestIDSet{ + m: nil, + s: nil, + }, + }, + { + name: "1 element", + input: &AuthRequestId, + expected: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}}, + s: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1")}, + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + + set := NewAuthRequestIDSet() + set.AddRef(tc.input) + assert.Equal(tt, tc.expected, set) + }) + } +} + +func TestAuthRequestIDSet_Has(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + input struct { + AuthRequestIDSet + AuthRequestID + } + expected bool + }{ + { + name: "Empty Set", + input: struct { + AuthRequestIDSet + AuthRequestID + }{AuthRequestIDSet: AuthRequestIDSet{}, AuthRequestID: MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1")}, + expected: false, + }, + { + name: "Set Contains the element", + input: struct { + AuthRequestIDSet + AuthRequestID + }{AuthRequestIDSet: AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}}, + s: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1")}, + }, AuthRequestID: MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1")}, + expected: true, + }, + { + name: "Set does not Contains the element", + input: struct { + AuthRequestIDSet + AuthRequestID + }{AuthRequestIDSet: AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}}, + s: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1")}, + }, AuthRequestID: MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2")}, + expected: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + assert.Equal(tt, tc.expected, tc.input.AuthRequestIDSet.Has(tc.input.AuthRequestID)) + }) + } +} + +func TestAuthRequestIDSet_Clear(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + input AuthRequestIDSet + expected AuthRequestIDSet + }{ + { + name: "Empty Set", + input: AuthRequestIDSet{}, + expected: AuthRequestIDSet{ + m: nil, + s: nil, + }, + }, + { + name: "Set Contains the element", + input: AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}}, + s: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1")}, + }, + expected: AuthRequestIDSet{ + m: nil, + s: nil, + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + set := tc.input + p := &set + p.Clear() + assert.Equal(tt, tc.expected, *p) + }) + } +} + +func TestAuthRequestIDSet_All(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + input *AuthRequestIDSet + expected []AuthRequestID + }{ + { + name: "Empty slice", + input: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{}, + s: nil, + }, + expected: make([]AuthRequestID, 0), + }, + { + name: "1 element", + input: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}}, + s: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1")}, + }, + expected: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1")}, + }, + { + name: "multiple elements", + input: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}, + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"): {}, + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"): {}, + }, + s: []AuthRequestID{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"), + }, + }, + expected: []AuthRequestID{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"), + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + + assert.Equal(tt, tc.expected, tc.input.All()) + }) + } +} + +func TestAuthRequestIDSet_Clone(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + input *AuthRequestIDSet + expected *AuthRequestIDSet + }{ + { + name: "nil set", + input: nil, + expected: NewAuthRequestIDSet(), + }, + { + name: "Empty set", + input: NewAuthRequestIDSet(), + expected: NewAuthRequestIDSet(), + }, + { + name: "1 element", + input: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}}, + s: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1")}, + }, + expected: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}}, + s: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1")}, + }, + }, + { + name: "multiple elements", + input: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}, + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"): {}, + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"): {}, + }, + s: []AuthRequestID{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"), + }, + }, + expected: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}, + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"): {}, + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"): {}, + }, + s: []AuthRequestID{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t3"), + }, + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + clone := tc.input.Clone() + assert.Equal(tt, tc.expected, clone) + assert.False(tt, tc.input == clone) + }) + } +} + +func TestAuthRequestIDSet_Merge(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + input struct { + a *AuthRequestIDSet + b *AuthRequestIDSet + } + expected *AuthRequestIDSet + }{ + { + name: "Empty Set", + input: struct { + a *AuthRequestIDSet + b *AuthRequestIDSet + }{ + a: &AuthRequestIDSet{}, + b: &AuthRequestIDSet{}, + }, + expected: &AuthRequestIDSet{}, + }, + { + name: "1 Empty Set", + input: struct { + a *AuthRequestIDSet + b *AuthRequestIDSet + }{ + a: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}}, + s: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1")}, + }, + b: &AuthRequestIDSet{}, + }, + expected: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}}, + s: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1")}, + }, + }, + { + name: "2 non Empty Set", + input: struct { + a *AuthRequestIDSet + b *AuthRequestIDSet + }{ + a: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}}, + s: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1")}, + }, + b: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"): {}}, + s: []AuthRequestID{MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2")}, + }, + }, + expected: &AuthRequestIDSet{ + m: map[AuthRequestID]struct{}{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"): {}, + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"): {}, + }, + s: []AuthRequestID{ + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t1"), + MustAuthRequestID("01f3zhcaq35403zdjnd6dcm0t2"), + }, + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(tt *testing.T) { + tt.Parallel() + + assert.Equal(tt, tc.expected, tc.input.a.Merge(tc.input.b)) + }) + } +} diff --git a/pkg/id/gen.go b/pkg/id/gen.go index 54d0703b..66d9c594 100644 --- a/pkg/id/gen.go +++ b/pkg/id/gen.go @@ -11,6 +11,7 @@ //go:generate go run github.com/reearth/reearth-backend/tools/cmd/gen --template=id.tmpl --output=user_gen.go --name=User //go:generate go run github.com/reearth/reearth-backend/tools/cmd/gen --template=id.tmpl --output=dataset_schema_field_gen.go --name=DatasetSchemaField //go:generate go run github.com/reearth/reearth-backend/tools/cmd/gen --template=id.tmpl --output=infobox_field_gen.go --name=InfoboxField +//go:generate go run github.com/reearth/reearth-backend/tools/cmd/gen --template=id.tmpl --output=auth_request_gen.go --name=AuthRequest // Testing //go:generate go run github.com/reearth/reearth-backend/tools/cmd/gen --template=id_test.tmpl --output=asset_gen_test.go --name=Asset @@ -26,5 +27,6 @@ //go:generate go run github.com/reearth/reearth-backend/tools/cmd/gen --template=id_test.tmpl --output=user_gen_test.go --name=User //go:generate go run github.com/reearth/reearth-backend/tools/cmd/gen --template=id_test.tmpl --output=dataset_schema_field_gen_test.go --name=DatasetSchemaField //go:generate go run github.com/reearth/reearth-backend/tools/cmd/gen --template=id_test.tmpl --output=infobox_field_gen_test.go --name=InfoboxField +//go:generate go run github.com/reearth/reearth-backend/tools/cmd/gen --template=id_test.tmpl --output=auth_request_gen_test.go --name=AuthRequest package id From 0198c371264753011b136b98585390657a90c646 Mon Sep 17 00:00:00 2001 From: yk Date: Mon, 29 Nov 2021 13:32:52 +0300 Subject: [PATCH 2/6] - fix go mod --- go.mod | 85 +++++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 63 insertions(+), 22 deletions(-) diff --git a/go.mod b/go.mod index bf2d714c..ef1d9101 100644 --- a/go.mod +++ b/go.mod @@ -6,21 +6,13 @@ require ( github.com/99designs/gqlgen v0.14.0 github.com/99designs/gqlgen-contrib v0.1.1-0.20200601100547-7a955d321bbd github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v0.2.0 - github.com/HdrHistogram/hdrhistogram-go v1.0.1 // indirect - github.com/agnivade/levenshtein v1.1.1 // indirect - github.com/alecthomas/units v0.0.0-20210912230133-d1bdfacee922 // indirect github.com/auth0/go-jwt-middleware v0.0.0-20200507191422-d30d7b9ece63 github.com/blang/semver v3.5.1+incompatible github.com/caos/oidc v0.15.11 - github.com/cpuguy83/go-md2man/v2 v2.0.1 // indirect github.com/dgrijalva/jwt-go v3.2.0+incompatible - github.com/dgryski/trifles v0.0.0-20200705224438-cafc02a1ee2b // indirect - github.com/fatih/color v1.12.0 // indirect - github.com/gedex/inflector v0.0.0-20170307190818-16278e9db813 // indirect github.com/goccy/go-yaml v1.9.2 github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f github.com/gorilla/mux v1.8.0 - github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/iancoleman/strcase v0.1.3 github.com/idubinskiy/schematyper v0.0.0-20190118213059-f71b40dac30d github.com/jarcoal/httpmock v1.0.8 @@ -28,47 +20,96 @@ require ( github.com/jonas-p/go-shp v0.1.1 github.com/kelseyhightower/envconfig v1.4.0 github.com/kennygrant/sanitize v1.2.4 - github.com/klauspost/compress v1.10.10 // indirect github.com/labstack/echo/v4 v4.2.1 github.com/labstack/gommon v0.3.0 - github.com/mattn/go-isatty v0.0.13 // indirect github.com/mitchellh/mapstructure v1.4.2 github.com/oklog/ulid v1.3.1 - github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/paulmach/go.geojson v1.4.0 github.com/pkg/errors v0.9.1 github.com/sendgrid/sendgrid-go v3.10.3+incompatible github.com/sirupsen/logrus v1.8.1 - github.com/smartystreets/assertions v1.1.1 // indirect github.com/spf13/afero v1.6.0 - github.com/stretchr/objx v0.2.0 // indirect github.com/stretchr/testify v1.7.0 - github.com/tidwall/pretty v1.0.1 // indirect github.com/twpayne/go-kml v1.5.2 github.com/uber/jaeger-client-go v2.25.0+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible - github.com/urfave/cli/v2 v2.3.0 // indirect github.com/vektah/dataloaden v0.3.0 github.com/vektah/gqlparser/v2 v2.2.0 go.mongodb.org/mongo-driver v1.5.1 go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo v0.0.0-20200707171851-ae0d272a2deb go.opentelemetry.io/contrib/instrumentation/go.mongodb.org/mongo-driver v0.7.0 go.opentelemetry.io/otel v0.7.0 - go.uber.org/atomic v1.7.0 // indirect - golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c // indirect - golang.org/x/mod v0.5.0 // indirect - golang.org/x/sys v0.0.0-20210915083310-ed5796bab164 // indirect golang.org/x/text v0.3.7 - golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect golang.org/x/tools v0.1.5 google.golang.org/api v0.51.0 gopkg.in/go-playground/colors.v1 v1.2.0 gopkg.in/h2non/gock.v1 v1.1.0 gopkg.in/square/go-jose.v2 v2.6.0 +) + +require ( + github.com/HdrHistogram/hdrhistogram-go v1.0.1 // indirect + github.com/agnivade/levenshtein v1.1.1 // indirect + github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect + github.com/alecthomas/units v0.0.0-20210912230133-d1bdfacee922 // indirect + github.com/aws/aws-sdk-go v1.34.28 // indirect + github.com/caos/logging v0.0.2 // indirect + github.com/cpuguy83/go-md2man/v2 v2.0.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/trifles v0.0.0-20200705224438-cafc02a1ee2b // indirect + github.com/fatih/color v1.12.0 // indirect + github.com/felixge/httpsnoop v1.0.1 // indirect + github.com/gedex/inflector v0.0.0-20170307190818-16278e9db813 // indirect + github.com/go-stack/stack v1.8.0 // indirect + github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect + github.com/golang/protobuf v1.5.2 // indirect + github.com/golang/snappy v0.0.3 // indirect + github.com/google/go-cmp v0.5.6 // indirect + github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9 // indirect + github.com/googleapis/gax-go/v2 v2.0.5 // indirect + github.com/gorilla/handlers v1.5.1 // indirect + github.com/gorilla/schema v1.2.0 // indirect + github.com/gorilla/securecookie v1.1.1 // indirect + github.com/gorilla/websocket v1.4.2 // indirect + github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect + github.com/hashicorp/golang-lru v0.5.4 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect + github.com/jstemmer/go-junit-report v0.9.1 // indirect + github.com/klauspost/compress v1.10.10 // indirect + github.com/mattn/go-colorable v0.1.8 // indirect + github.com/mattn/go-isatty v0.0.13 // indirect + github.com/opentracing/opentracing-go v1.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/sendgrid/rest v2.6.5+incompatible // indirect + github.com/smartystreets/assertions v1.1.1 // indirect + github.com/stretchr/objx v0.2.0 // indirect + github.com/tidwall/pretty v1.0.1 // indirect + github.com/urfave/cli/v2 v2.3.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasttemplate v1.2.1 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.0.2 // indirect + github.com/xdg-go/stringprep v1.0.2 // indirect + github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect + go.opencensus.io v0.23.0 // indirect + go.uber.org/atomic v1.7.0 // indirect + golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c // indirect + golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 // indirect + golang.org/x/mod v0.5.0 // indirect + golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420 // indirect + golang.org/x/oauth2 v0.0.0-20210628180205-a41e5a781914 // indirect + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect + golang.org/x/sys v0.0.0-20210915083310-ed5796bab164 // indirect + golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + google.golang.org/appengine v1.6.7 // indirect + google.golang.org/genproto v0.0.0-20210716133855-ce7ef5c701ea // indirect + google.golang.org/grpc v1.39.0 // indirect + google.golang.org/protobuf v1.27.1 // indirect + gopkg.in/alecthomas/kingpin.v2 v2.2.6 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) -require github.com/sendgrid/rest v2.6.5+incompatible // indirect - go 1.17 From 08397f1d188efd14d9de1dde7350e794cfdf22ed Mon Sep 17 00:00:00 2001 From: yk Date: Mon, 29 Nov 2021 21:02:19 +0300 Subject: [PATCH 3/6] - move auth storage logic to usecase\interactor --- internal/app/{auth.go => auth_client.go} | 0 internal/app/auth_server.go | 11 +++++------ .../appauth/storage.go => usecase/interactor/auth.go} | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) rename internal/app/{auth.go => auth_client.go} (100%) rename internal/{app/appauth/storage.go => usecase/interactor/auth.go} (99%) diff --git a/internal/app/auth.go b/internal/app/auth_client.go similarity index 100% rename from internal/app/auth.go rename to internal/app/auth_client.go diff --git a/internal/app/auth_server.go b/internal/app/auth_server.go index 2a926dbc..7bf90c33 100644 --- a/internal/app/auth_server.go +++ b/internal/app/auth_server.go @@ -12,7 +12,6 @@ import ( "github.com/golang/gddo/httputil/header" "github.com/gorilla/mux" "github.com/labstack/echo/v4" - "github.com/reearth/reearth-backend/internal/app/appauth" "github.com/reearth/reearth-backend/internal/usecase/interactor" "github.com/reearth/reearth-backend/internal/usecase/interfaces" ) @@ -39,9 +38,9 @@ func authEndPoints(ctx context.Context, e *echo.Echo, r *echo.Group, cfg *Server GrantTypeRefreshToken: true, } - var dn *appauth.AuthDNConfig = nil + var dn *interactor.AuthDNConfig = nil if cfg.Config.Auth.DN != nil { - dn = &appauth.AuthDNConfig{ + dn = &interactor.AuthDNConfig{ CommonName: cfg.Config.Auth.DN.CN, Organization: cfg.Config.Auth.DN.O, OrganizationalUnit: cfg.Config.Auth.DN.OU, @@ -53,8 +52,8 @@ func authEndPoints(ctx context.Context, e *echo.Echo, r *echo.Group, cfg *Server } } - storage := appauth.NewAuthStorage( - &appauth.StorageConfig{ + storage := interactor.NewAuthStorage( + &interactor.StorageConfig{ Domain: domain.String(), Debug: cfg.Debug, DN: dn, @@ -203,7 +202,7 @@ func login(ctx context.Context, cfg *ServerConfig, storage op.Storage, userUseca } // Complete the auth request && set the subject - err = storage.(*appauth.Storage).CompleteAuthRequest(ctx, request.AuthRequestID, user.GetAuthByProvider("auth0").Sub) + err = storage.(*interactor.Storage).CompleteAuthRequest(ctx, request.AuthRequestID, user.GetAuthByProvider("auth0").Sub) if err != nil { ec.Logger().Error("failed to complete the auth request !") return ec.Redirect(http.StatusFound, redirectURL(authRequest.GetRedirectURI(), !cfg.Debug, request.AuthRequestID, "invalid login")) diff --git a/internal/app/appauth/storage.go b/internal/usecase/interactor/auth.go similarity index 99% rename from internal/app/appauth/storage.go rename to internal/usecase/interactor/auth.go index a4f6d747..2605f44b 100644 --- a/internal/app/appauth/storage.go +++ b/internal/usecase/interactor/auth.go @@ -1,4 +1,4 @@ -package appauth +package interactor import ( "context" From a982067a952fdd9cfbe474eb6ed834bd53bdc2e2 Mon Sep 17 00:00:00 2001 From: yk Date: Mon, 29 Nov 2021 21:09:06 +0300 Subject: [PATCH 4/6] - fix lint issue --- internal/infrastructure/mongo/auth_request.go | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/internal/infrastructure/mongo/auth_request.go b/internal/infrastructure/mongo/auth_request.go index 3e1a36e4..bcd716b2 100644 --- a/internal/infrastructure/mongo/auth_request.go +++ b/internal/infrastructure/mongo/auth_request.go @@ -53,16 +53,6 @@ func (r *authRequestRepo) Remove(ctx context.Context, requestID id.AuthRequestID return r.client.RemoveOne(ctx, requestID.String()) } -func (r *authRequestRepo) find(ctx context.Context, dst []*auth.Request, filter bson.D) ([]*auth.Request, error) { - c := mongodoc.AuthRequestConsumer{ - Rows: dst, - } - if err := r.client.Find(ctx, filter, &c); err != nil { - return nil, err - } - return c.Rows, nil -} - func (r *authRequestRepo) findOne(ctx context.Context, filter bson.D) (*auth.Request, error) { dst := make([]*auth.Request, 0, 1) c := mongodoc.AuthRequestConsumer{ From 25f6d4156aae143a581531de036c6ab9c595fc3b Mon Sep 17 00:00:00 2001 From: yk Date: Thu, 2 Dec 2021 16:35:08 +0300 Subject: [PATCH 5/6] - fix PR comments --- internal/app/auth_server.go | 2 +- internal/infrastructure/mongo/auth_request.go | 7 +- .../mongo/mongodoc/auth_request.go | 6 +- internal/usecase/interactor/auth.go | 72 +++++++------------ pkg/auth/builder.go | 38 +++++----- pkg/auth/client.go | 42 +++++------ 6 files changed, 73 insertions(+), 94 deletions(-) diff --git a/internal/app/auth_server.go b/internal/app/auth_server.go index 7bf90c33..e32c6793 100644 --- a/internal/app/auth_server.go +++ b/internal/app/auth_server.go @@ -202,7 +202,7 @@ func login(ctx context.Context, cfg *ServerConfig, storage op.Storage, userUseca } // Complete the auth request && set the subject - err = storage.(*interactor.Storage).CompleteAuthRequest(ctx, request.AuthRequestID, user.GetAuthByProvider("auth0").Sub) + err = storage.(*interactor.AuthStorage).CompleteAuthRequest(ctx, request.AuthRequestID, user.GetAuthByProvider("auth0").Sub) if err != nil { ec.Logger().Error("failed to complete the auth request !") return ec.Redirect(http.StatusFound, redirectURL(authRequest.GetRedirectURI(), !cfg.Debug, request.AuthRequestID, "invalid login")) diff --git a/internal/infrastructure/mongo/auth_request.go b/internal/infrastructure/mongo/auth_request.go index bcd716b2..247e1f6e 100644 --- a/internal/infrastructure/mongo/auth_request.go +++ b/internal/infrastructure/mongo/auth_request.go @@ -3,13 +3,12 @@ package mongo import ( "context" - "github.com/reearth/reearth-backend/pkg/auth" - "go.mongodb.org/mongo-driver/bson" - "github.com/reearth/reearth-backend/internal/infrastructure/mongo/mongodoc" "github.com/reearth/reearth-backend/internal/usecase/repo" + "github.com/reearth/reearth-backend/pkg/auth" "github.com/reearth/reearth-backend/pkg/id" "github.com/reearth/reearth-backend/pkg/log" + "go.mongodb.org/mongo-driver/bson" ) type authRequestRepo struct { @@ -30,7 +29,7 @@ func (r *authRequestRepo) init() { } func (r *authRequestRepo) FindByID(ctx context.Context, id2 id.AuthRequestID) (*auth.Request, error) { - filter := bson.D{{Key: "id", Value: id.ID(id2).String()}} + filter := bson.D{{Key: "id", Value: id2.String()}} return r.findOne(ctx, filter) } diff --git a/internal/infrastructure/mongo/mongodoc/auth_request.go b/internal/infrastructure/mongo/mongodoc/auth_request.go index 589e9ceb..245e94c2 100644 --- a/internal/infrastructure/mongo/mongodoc/auth_request.go +++ b/internal/infrastructure/mongo/mongodoc/auth_request.go @@ -85,7 +85,7 @@ func (d *AuthRequestDocument) Model() (*auth.Request, error) { return nil, nil } - uuid, err := id.AuthRequestIDFrom(d.ID) + ulid, err := id.AuthRequestIDFrom(d.ID) if err != nil { return nil, err } @@ -97,8 +97,8 @@ func (d *AuthRequestDocument) Model() (*auth.Request, error) { Method: oidc.CodeChallengeMethod(d.CodeChallenge.Method), } } - var req = auth.New(). - ID(uuid). + var req = auth.NewRequest(). + ID(ulid). ClientID(d.ClientID). Subject(d.Subject). Code(d.Code). diff --git a/internal/usecase/interactor/auth.go b/internal/usecase/interactor/auth.go index 2605f44b..e93c2870 100644 --- a/internal/usecase/interactor/auth.go +++ b/internal/usecase/interactor/auth.go @@ -8,7 +8,6 @@ import ( "crypto/x509/pkix" "errors" "math/big" - "sync" "time" "github.com/caos/oidc/pkg/oidc" @@ -20,8 +19,7 @@ import ( "gopkg.in/square/go-jose.v2" ) -type Storage struct { - lock sync.Mutex +type AuthStorage struct { appConfig *StorageConfig getUserBySubject func(context.Context, string) (*user.User, error) clients map[string]op.Client @@ -79,7 +77,7 @@ func NewAuthStorage(cfg *StorageConfig, request repo.AuthRequest, getUserBySubje key, sigKey, keySet := initKeys(name) - return &Storage{ + return &AuthStorage{ appConfig: cfg, getUserBySubject: getUserBySubject, requests: request, @@ -131,14 +129,11 @@ func initKeys(name pkix.Name) (*rsa.PrivateKey, jose.SigningKey, jose.JSONWebKey } } -func (s *Storage) Health(_ context.Context) error { +func (s *AuthStorage) Health(_ context.Context) error { return nil } -func (s *Storage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, _ string) (op.AuthRequest, error) { - s.lock.Lock() - defer s.lock.Unlock() - +func (s *AuthStorage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, _ string) (op.AuthRequest, error) { audiences := []string{ s.appConfig.Domain, } @@ -153,7 +148,7 @@ func (s *Storage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthReque Method: authReq.CodeChallengeMethod, } } - var request = auth.New(). + var request = auth.NewRequest(). NewID(). ClientID(authReq.ClientID). State(authReq.State). @@ -173,10 +168,7 @@ func (s *Storage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthReque return request, nil } -func (s *Storage) AuthRequestByID(ctx context.Context, requestID string) (op.AuthRequest, error) { - s.lock.Lock() - defer s.lock.Unlock() - +func (s *AuthStorage) AuthRequestByID(ctx context.Context, requestID string) (op.AuthRequest, error) { if requestID == "" { return nil, errors.New("invalid id") } @@ -191,20 +183,14 @@ func (s *Storage) AuthRequestByID(ctx context.Context, requestID string) (op.Aut return request, nil } -func (s *Storage) AuthRequestByCode(ctx context.Context, code string) (op.AuthRequest, error) { - s.lock.Lock() - defer s.lock.Unlock() - +func (s *AuthStorage) AuthRequestByCode(ctx context.Context, code string) (op.AuthRequest, error) { if code == "" { return nil, errors.New("invalid code") } return s.requests.FindByCode(ctx, code) } -func (s *Storage) AuthRequestBySubject(ctx context.Context, subject string) (op.AuthRequest, error) { - s.lock.Lock() - defer s.lock.Unlock() - +func (s *AuthStorage) AuthRequestBySubject(ctx context.Context, subject string) (op.AuthRequest, error) { if subject == "" { return nil, errors.New("invalid subject") } @@ -212,7 +198,7 @@ func (s *Storage) AuthRequestBySubject(ctx context.Context, subject string) (op. return s.requests.FindBySubject(ctx, subject) } -func (s *Storage) SaveAuthCode(ctx context.Context, requestID, code string) error { +func (s *AuthStorage) SaveAuthCode(ctx context.Context, requestID, code string) error { request, err := s.AuthRequestByID(ctx, requestID) if err != nil { @@ -224,24 +210,21 @@ func (s *Storage) SaveAuthCode(ctx context.Context, requestID, code string) erro return err } -func (s *Storage) DeleteAuthRequest(_ context.Context, requestID string) error { - s.lock.Lock() - defer s.lock.Unlock() - +func (s *AuthStorage) DeleteAuthRequest(_ context.Context, requestID string) error { delete(s.clients, requestID) return nil } -func (s *Storage) CreateAccessToken(_ context.Context, _ op.TokenRequest) (string, time.Time, error) { +func (s *AuthStorage) CreateAccessToken(_ context.Context, _ op.TokenRequest) (string, time.Time, error) { return "id", time.Now().UTC().Add(5 * time.Hour), nil } -func (s *Storage) CreateAccessAndRefreshTokens(_ context.Context, request op.TokenRequest, _ string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) { +func (s *AuthStorage) CreateAccessAndRefreshTokens(_ context.Context, request op.TokenRequest, _ string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) { authReq := request.(*auth.Request) return "id", authReq.GetID(), time.Now().UTC().Add(5 * time.Minute), nil } -func (s *Storage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) { +func (s *AuthStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) { r, err := s.AuthRequestByID(ctx, refreshToken) if err != nil { return nil, err @@ -249,23 +232,23 @@ func (s *Storage) TokenRequestByRefreshToken(ctx context.Context, refreshToken s return r.(op.RefreshTokenRequest), err } -func (s *Storage) TerminateSession(_ context.Context, _, _ string) error { +func (s *AuthStorage) TerminateSession(_ context.Context, _, _ string) error { return errors.New("not implemented") } -func (s *Storage) GetSigningKey(_ context.Context, keyCh chan<- jose.SigningKey) { +func (s *AuthStorage) GetSigningKey(_ context.Context, keyCh chan<- jose.SigningKey) { keyCh <- s.sigKey } -func (s *Storage) GetKeySet(_ context.Context) (*jose.JSONWebKeySet, error) { +func (s *AuthStorage) GetKeySet(_ context.Context) (*jose.JSONWebKeySet, error) { return &s.keySet, nil } -func (s *Storage) GetKeyByIDAndUserID(_ context.Context, kid, _ string) (*jose.JSONWebKey, error) { +func (s *AuthStorage) GetKeyByIDAndUserID(_ context.Context, kid, _ string) (*jose.JSONWebKey, error) { return &s.keySet.Key(kid)[0], nil } -func (s *Storage) GetClientByClientID(_ context.Context, clientID string) (op.Client, error) { +func (s *AuthStorage) GetClientByClientID(_ context.Context, clientID string) (op.Client, error) { if clientID == "" { return nil, errors.New("invalid client id") @@ -279,15 +262,15 @@ func (s *Storage) GetClientByClientID(_ context.Context, clientID string) (op.Cl return client, nil } -func (s *Storage) AuthorizeClientIDSecret(_ context.Context, _ string, _ string) error { +func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, _ string, _ string) error { return nil } -func (s *Storage) SetUserinfoFromToken(ctx context.Context, userinfo oidc.UserInfoSetter, _, _, _ string) error { +func (s *AuthStorage) SetUserinfoFromToken(ctx context.Context, userinfo oidc.UserInfoSetter, _, _, _ string) error { return s.SetUserinfoFromScopes(ctx, userinfo, "", "", []string{}) } -func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo oidc.UserInfoSetter, subject, _ string, _ []string) error { +func (s *AuthStorage) SetUserinfoFromScopes(ctx context.Context, userinfo oidc.UserInfoSetter, subject, _ string, _ []string) error { request, err := s.AuthRequestBySubject(ctx, subject) if err != nil { @@ -308,11 +291,11 @@ func (s *Storage) SetUserinfoFromScopes(ctx context.Context, userinfo oidc.UserI return nil } -func (s *Storage) GetPrivateClaimsFromScopes(_ context.Context, _, _ string, _ []string) (map[string]interface{}, error) { +func (s *AuthStorage) GetPrivateClaimsFromScopes(_ context.Context, _, _ string, _ []string) (map[string]interface{}, error) { return map[string]interface{}{"private_claim": "test"}, nil } -func (s *Storage) SetIntrospectionFromToken(ctx context.Context, introspect oidc.IntrospectionResponse, _, subject, clientID string) error { +func (s *AuthStorage) SetIntrospectionFromToken(ctx context.Context, introspect oidc.IntrospectionResponse, _, subject, clientID string) error { if err := s.SetUserinfoFromScopes(ctx, introspect, subject, clientID, []string{}); err != nil { return err } @@ -324,11 +307,11 @@ func (s *Storage) SetIntrospectionFromToken(ctx context.Context, introspect oidc return nil } -func (s *Storage) ValidateJWTProfileScopes(_ context.Context, _ string, scope []string) ([]string, error) { +func (s *AuthStorage) ValidateJWTProfileScopes(_ context.Context, _ string, scope []string) ([]string, error) { return scope, nil } -func (s *Storage) CompleteAuthRequest(ctx context.Context, requestId, sub string) error { +func (s *AuthStorage) CompleteAuthRequest(ctx context.Context, requestId, sub string) error { request, err := s.AuthRequestByID(ctx, requestId) if err != nil { return err @@ -339,10 +322,7 @@ func (s *Storage) CompleteAuthRequest(ctx context.Context, requestId, sub string return err } -func (s *Storage) updateRequest(ctx context.Context, requestID string, req auth.Request) error { - s.lock.Lock() - defer s.lock.Unlock() - +func (s *AuthStorage) updateRequest(ctx context.Context, requestID string, req auth.Request) error { if requestID == "" { return errors.New("invalid id") } diff --git a/pkg/auth/builder.go b/pkg/auth/builder.go index ed91727a..9eed6e64 100644 --- a/pkg/auth/builder.go +++ b/pkg/auth/builder.go @@ -7,15 +7,15 @@ import ( "github.com/reearth/reearth-backend/pkg/id" ) -type Builder struct { +type RequestBuilder struct { r *Request } -func New() *Builder { - return &Builder{r: &Request{}} +func NewRequest() *RequestBuilder { + return &RequestBuilder{r: &Request{}} } -func (b *Builder) Build() (*Request, error) { +func (b *RequestBuilder) Build() (*Request, error) { if id.ID(b.r.id).IsNil() { return nil, id.ErrInvalidID } @@ -23,7 +23,7 @@ func (b *Builder) Build() (*Request, error) { return b.r, nil } -func (b *Builder) MustBuild() *Request { +func (b *RequestBuilder) MustBuild() *Request { r, err := b.Build() if err != nil { panic(err) @@ -31,72 +31,72 @@ func (b *Builder) MustBuild() *Request { return r } -func (b *Builder) ID(id id.AuthRequestID) *Builder { +func (b *RequestBuilder) ID(id id.AuthRequestID) *RequestBuilder { b.r.id = id return b } -func (b *Builder) NewID() *Builder { +func (b *RequestBuilder) NewID() *RequestBuilder { b.r.id = id.AuthRequestID(id.New()) return b } -func (b *Builder) ClientID(id string) *Builder { +func (b *RequestBuilder) ClientID(id string) *RequestBuilder { b.r.clientID = id return b } -func (b *Builder) Subject(subject string) *Builder { +func (b *RequestBuilder) Subject(subject string) *RequestBuilder { b.r.subject = subject return b } -func (b *Builder) Code(code string) *Builder { +func (b *RequestBuilder) Code(code string) *RequestBuilder { b.r.code = code return b } -func (b *Builder) State(state string) *Builder { +func (b *RequestBuilder) State(state string) *RequestBuilder { b.r.state = state return b } -func (b *Builder) ResponseType(rt oidc.ResponseType) *Builder { +func (b *RequestBuilder) ResponseType(rt oidc.ResponseType) *RequestBuilder { b.r.responseType = rt return b } -func (b *Builder) Scopes(scopes []string) *Builder { +func (b *RequestBuilder) Scopes(scopes []string) *RequestBuilder { b.r.scopes = scopes return b } -func (b *Builder) Audiences(audiences []string) *Builder { +func (b *RequestBuilder) Audiences(audiences []string) *RequestBuilder { b.r.audiences = audiences return b } -func (b *Builder) RedirectURI(redirectURI string) *Builder { +func (b *RequestBuilder) RedirectURI(redirectURI string) *RequestBuilder { b.r.redirectURI = redirectURI return b } -func (b *Builder) Nonce(nonce string) *Builder { +func (b *RequestBuilder) Nonce(nonce string) *RequestBuilder { b.r.nonce = nonce return b } -func (b *Builder) CodeChallenge(CodeChallenge *oidc.CodeChallenge) *Builder { +func (b *RequestBuilder) CodeChallenge(CodeChallenge *oidc.CodeChallenge) *RequestBuilder { b.r.codeChallenge = CodeChallenge return b } -func (b *Builder) CreatedAt(createdAt time.Time) *Builder { +func (b *RequestBuilder) CreatedAt(createdAt time.Time) *RequestBuilder { b.r.createdAt = createdAt return b } -func (b *Builder) AuthorizedAt(authorizedAt *time.Time) *Builder { +func (b *RequestBuilder) AuthorizedAt(authorizedAt *time.Time) *RequestBuilder { b.r.authorizedAt = authorizedAt return b } diff --git a/pkg/auth/client.go b/pkg/auth/client.go index 2c4b2bcb..93075710 100644 --- a/pkg/auth/client.go +++ b/pkg/auth/client.go @@ -8,8 +8,8 @@ import ( "github.com/caos/oidc/pkg/op" ) -type ConfClient struct { - ID string +type Client struct { + id string applicationType op.ApplicationType authMethod oidc.AuthMethod accessTokenType op.AccessTokenType @@ -25,8 +25,8 @@ type ConfClient struct { } func NewLocalClient(devMode bool) op.Client { - return &ConfClient{ - ID: "01FH69GFQ4DFCXS5XD91JK4HZ1", + return &Client{ + id: "01FH69GFQ4DFCXS5XD91JK4HZ1", applicationType: op.ApplicationTypeWeb, authMethod: oidc.AuthMethodNone, accessTokenType: op.AccessTokenTypeJWT, @@ -41,63 +41,63 @@ func NewLocalClient(devMode bool) op.Client { } } -func (c *ConfClient) GetID() string { - return c.ID +func (c *Client) GetID() string { + return c.id } -func (c *ConfClient) RedirectURIs() []string { +func (c *Client) RedirectURIs() []string { return c.redirectURIs } -func (c *ConfClient) PostLogoutRedirectURIs() []string { +func (c *Client) PostLogoutRedirectURIs() []string { return c.logoutRedirectURIs } -func (c *ConfClient) LoginURL(id string) string { +func (c *Client) LoginURL(id string) string { return fmt.Sprintf(c.loginURI, id) } -func (c *ConfClient) ApplicationType() op.ApplicationType { +func (c *Client) ApplicationType() op.ApplicationType { return c.applicationType } -func (c *ConfClient) AuthMethod() oidc.AuthMethod { +func (c *Client) AuthMethod() oidc.AuthMethod { return c.authMethod } -func (c *ConfClient) IDTokenLifetime() time.Duration { +func (c *Client) IDTokenLifetime() time.Duration { return c.idTokenLifetime } -func (c *ConfClient) AccessTokenType() op.AccessTokenType { +func (c *Client) AccessTokenType() op.AccessTokenType { return c.accessTokenType } -func (c *ConfClient) ResponseTypes() []oidc.ResponseType { +func (c *Client) ResponseTypes() []oidc.ResponseType { return c.responseTypes } -func (c *ConfClient) GrantTypes() []oidc.GrantType { +func (c *Client) GrantTypes() []oidc.GrantType { return c.grantTypes } -func (c *ConfClient) DevMode() bool { +func (c *Client) DevMode() bool { return c.devMode } -func (c *ConfClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { +func (c *Client) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { return func(scopes []string) []string { return scopes } } -func (c *ConfClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { +func (c *Client) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { return func(scopes []string) []string { return scopes } } -func (c *ConfClient) IsScopeAllowed(scope string) bool { +func (c *Client) IsScopeAllowed(scope string) bool { for _, clientScope := range c.allowedScopes { if clientScope == scope { return true @@ -106,10 +106,10 @@ func (c *ConfClient) IsScopeAllowed(scope string) bool { return false } -func (c *ConfClient) IDTokenUserinfoClaimsAssertion() bool { +func (c *Client) IDTokenUserinfoClaimsAssertion() bool { return false } -func (c *ConfClient) ClockSkew() time.Duration { +func (c *Client) ClockSkew() time.Duration { return c.clockSkew } From 6805ceaab63ee382d18676995a54f5401bd576ee Mon Sep 17 00:00:00 2001 From: yk Date: Thu, 2 Dec 2021 16:56:14 +0300 Subject: [PATCH 6/6] implement memory storage for authRequest repo --- .../infrastructure/memory/auth_request.go | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 internal/infrastructure/memory/auth_request.go diff --git a/internal/infrastructure/memory/auth_request.go b/internal/infrastructure/memory/auth_request.go new file mode 100644 index 00000000..daabf8c6 --- /dev/null +++ b/internal/infrastructure/memory/auth_request.go @@ -0,0 +1,75 @@ +package memory + +import ( + "context" + "sync" + + "github.com/reearth/reearth-backend/internal/usecase/repo" + "github.com/reearth/reearth-backend/pkg/auth" + "github.com/reearth/reearth-backend/pkg/id" + "github.com/reearth/reearth-backend/pkg/rerror" +) + +type AuthRequest struct { + lock sync.Mutex + data map[id.AuthRequestID]auth.Request +} + +func NewAuthRequest() repo.AuthRequest { + return &AuthRequest{ + data: map[id.AuthRequestID]auth.Request{}, + } +} + +func (r *AuthRequest) FindByID(_ context.Context, id id.AuthRequestID) (*auth.Request, error) { + r.lock.Lock() + defer r.lock.Unlock() + + d, ok := r.data[id] + if ok { + return &d, nil + } + return &auth.Request{}, rerror.ErrNotFound +} + +func (r *AuthRequest) FindByCode(_ context.Context, s string) (*auth.Request, error) { + r.lock.Lock() + defer r.lock.Unlock() + + for _, ar := range r.data { + if ar.GetCode() == s { + return &ar, nil + } + } + + return &auth.Request{}, rerror.ErrNotFound +} + +func (r *AuthRequest) FindBySubject(_ context.Context, s string) (*auth.Request, error) { + r.lock.Lock() + defer r.lock.Unlock() + + for _, ar := range r.data { + if ar.GetSubject() == s { + return &ar, nil + } + } + + return &auth.Request{}, rerror.ErrNotFound +} + +func (r *AuthRequest) Save(_ context.Context, request *auth.Request) error { + r.lock.Lock() + defer r.lock.Unlock() + + r.data[request.ID()] = *request + return nil +} + +func (r *AuthRequest) Remove(_ context.Context, requestID id.AuthRequestID) error { + r.lock.Lock() + defer r.lock.Unlock() + + delete(r.data, requestID) + return nil +}