Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ TAG := $(shell git rev-list --tags --max-count=1)
VERSION := $(shell git describe --tags ${TAG})
.PHONY: build check fmt lint test test-race vet test-cover-html help install proto admin-app compose-up-dev
.DEFAULT_GOAL := build
PROTON_COMMIT := "06620262d6c0e850b1e97ace2bebdbbfc5b5ed51"
PROTON_COMMIT := "0e66eea2643eb54aba01be2c2b7298e96c82d749"

admin-app:
@echo " > generating admin build"
Expand Down
59 changes: 59 additions & 0 deletions core/userpat/mocks/repository.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 53 additions & 0 deletions core/userpat/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,59 @@ func (s *Service) Delete(ctx context.Context, userID, id string) error {
return nil
}

// Regenerate creates a new secret and updates the expiry for an existing PAT.
// The scope (roles + projects) and policies are preserved. Expired PATs can be
// regenerated; if reviving an expired PAT, checks the active count limit.
func (s *Service) Regenerate(ctx context.Context, userID, id string, newExpiresAt time.Time) (patmodels.PAT, string, error) {
if !s.config.Enabled {
return patmodels.PAT{}, "", paterrors.ErrDisabled
}

pat, err := s.getOwnedPAT(ctx, userID, id)
if err != nil {
return patmodels.PAT{}, "", err
}

if err := s.ValidateExpiry(newExpiresAt); err != nil {
return patmodels.PAT{}, "", err
}

// If PAT is expired, regenerating revives it — check active count limit.
if pat.ExpiresAt.Before(time.Now()) {
count, err := s.repo.CountActive(ctx, pat.UserID, pat.OrgID)
if err != nil {
return patmodels.PAT{}, "", fmt.Errorf("counting active PATs: %w", err)
}
if count >= s.config.MaxPerUserPerOrg {
return patmodels.PAT{}, "", paterrors.ErrLimitExceeded
}
}
Comment on lines +165 to +174
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Treat expiry-at-now as expired when enforcing active-PAT limit.
CountActive considers active as expires_at > now, but Line 166 uses Before(now). A PAT expiring exactly now bypasses the limit check incorrectly.

🐛 Proposed fix
-	// If PAT is expired, regenerating revives it — check active count limit.
-	if pat.ExpiresAt.Before(time.Now()) {
+	// If PAT is expired (or expires exactly now), regenerating revives it — check active count limit.
+	now := time.Now()
+	if !pat.ExpiresAt.After(now) {
 		count, err := s.repo.CountActive(ctx, pat.UserID, pat.OrgID)
 		if err != nil {
 			return patmodels.PAT{}, "", fmt.Errorf("counting active PATs: %w", err)
 		}


patValue, secretHash, err := s.generatePAT()
if err != nil {
return patmodels.PAT{}, "", err
}

oldExpiresAt := pat.ExpiresAt
regenerated, err := s.repo.Regenerate(ctx, id, secretHash, newExpiresAt)
if err != nil {
return patmodels.PAT{}, "", fmt.Errorf("regenerating PAT: %w", err)
}

if err := s.enrichWithScope(ctx, &regenerated); err != nil {
return patmodels.PAT{}, "", fmt.Errorf("enriching PAT scope: %w", err)
}

if err := s.createAuditRecord(ctx, pkgAuditRecord.PATRegeneratedEvent, regenerated, time.Now().UTC(), map[string]any{
"expires_at": regenerated.ExpiresAt,
"old_expires_at": oldExpiresAt,
}); err != nil {
s.logger.Error("failed to create audit record for PAT regeneration", "pat_id", id, "error", err)
}

return regenerated, patValue, nil
}

// Update replaces a PAT's title, metadata, and scope (roles + projects).
// Scope changes use revoke-all + recreate pattern with a TOCTOU guard
// against concurrent Delete.
Expand Down
225 changes: 225 additions & 0 deletions core/userpat/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2144,3 +2144,228 @@ func TestService_Update(t *testing.T) {
})
}
}

func TestService_Regenerate(t *testing.T) {
futureExpiry := time.Now().Add(48 * time.Hour)

activePAT := models.PAT{
ID: "pat-1",
UserID: "user-1",
OrgID: "org-1",
Title: "my-token",
ExpiresAt: time.Now().Add(24 * time.Hour),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}

expiredPAT := models.PAT{
ID: "pat-2",
UserID: "user-1",
OrgID: "org-1",
Title: "expired-token",
ExpiresAt: time.Now().Add(-24 * time.Hour),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}

regeneratedPAT := models.PAT{
ID: "pat-1",
UserID: "user-1",
OrgID: "org-1",
Title: "my-token",
ExpiresAt: futureExpiry,
CreatedAt: activePAT.CreatedAt,
UpdatedAt: time.Now(),
}

tests := []struct {
name string
setup func() *userpat.Service
userID string
patID string
expiresAt time.Time
wantErr bool
wantErrIs error
}{
{
name: "should return ErrDisabled when PAT feature is disabled",
userID: "user-1",
patID: "pat-1",
expiresAt: futureExpiry,
setup: func() *userpat.Service {
return userpat.NewService(log.NewNoop(), nil, userpat.Config{
Enabled: false,
}, nil, nil, nil, nil)
},
wantErr: true,
wantErrIs: paterrors.ErrDisabled,
},
{
name: "should return ErrNotFound when PAT does not exist",
userID: "user-1",
patID: "pat-1",
expiresAt: futureExpiry,
setup: func() *userpat.Service {
repo := mocks.NewRepository(t)
repo.EXPECT().GetByID(mock.Anything, "pat-1").
Return(models.PAT{}, paterrors.ErrNotFound)
return userpat.NewService(log.NewNoop(), repo, defaultConfig, nil, nil, nil, nil)
},
wantErr: true,
wantErrIs: paterrors.ErrNotFound,
},
{
name: "should return ErrNotFound when PAT belongs to different user",
userID: "user-2",
patID: "pat-1",
expiresAt: futureExpiry,
setup: func() *userpat.Service {
repo := mocks.NewRepository(t)
repo.EXPECT().GetByID(mock.Anything, "pat-1").
Return(activePAT, nil)
return userpat.NewService(log.NewNoop(), repo, defaultConfig, nil, nil, nil, nil)
},
wantErr: true,
wantErrIs: paterrors.ErrNotFound,
},
{
name: "should return error when expiry is in the past",
userID: "user-1",
patID: "pat-1",
expiresAt: time.Now().Add(-1 * time.Hour),
setup: func() *userpat.Service {
repo := mocks.NewRepository(t)
repo.EXPECT().GetByID(mock.Anything, "pat-1").
Return(activePAT, nil)
return userpat.NewService(log.NewNoop(), repo, defaultConfig, nil, nil, nil, nil)
},
wantErr: true,
wantErrIs: paterrors.ErrExpiryInPast,
},
{
name: "should return ErrLimitExceeded when reviving expired PAT at limit",
userID: "user-1",
patID: "pat-2",
expiresAt: futureExpiry,
setup: func() *userpat.Service {
repo := mocks.NewRepository(t)
repo.EXPECT().GetByID(mock.Anything, "pat-2").
Return(expiredPAT, nil)
repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1").
Return(int64(50), nil)
return userpat.NewService(log.NewNoop(), repo, defaultConfig, nil, nil, nil, nil)
},
wantErr: true,
wantErrIs: paterrors.ErrLimitExceeded,
},
{
name: "should not check limit when regenerating active PAT",
userID: "user-1",
patID: "pat-1",
expiresAt: futureExpiry,
setup: func() *userpat.Service {
repo := mocks.NewRepository(t)
repo.EXPECT().GetByID(mock.Anything, "pat-1").
Return(activePAT, nil)
// No CountActive call expected — PAT is active
repo.EXPECT().Regenerate(mock.Anything, "pat-1", mock.Anything, mock.Anything).
Return(regeneratedPAT, nil)
orgSvc := mocks.NewOrganizationService(t)
orgSvc.On("GetRaw", mock.Anything, mock.Anything).
Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe()
policySvc := mocks.NewPolicyService(t)
policySvc.On("List", mock.Anything, mock.Anything).
Return([]policy.Policy{}, nil).Maybe()
auditRepo := mocks.NewAuditRecordRepository(t)
auditRepo.On("Create", mock.Anything, mock.Anything).
Return(auditmodels.AuditRecord{}, nil).Maybe()
return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, nil, policySvc, auditRepo)
},
wantErr: false,
},
{
name: "should return error when repo regenerate fails",
userID: "user-1",
patID: "pat-1",
expiresAt: futureExpiry,
setup: func() *userpat.Service {
repo := mocks.NewRepository(t)
repo.EXPECT().GetByID(mock.Anything, "pat-1").
Return(activePAT, nil)
repo.EXPECT().Regenerate(mock.Anything, "pat-1", mock.Anything, mock.Anything).
Return(models.PAT{}, errors.New("db error"))
return userpat.NewService(log.NewNoop(), repo, defaultConfig, nil, nil, nil, nil)
},
wantErr: true,
},
{
name: "should regenerate expired PAT successfully when under limit",
userID: "user-1",
patID: "pat-2",
expiresAt: futureExpiry,
setup: func() *userpat.Service {
repo := mocks.NewRepository(t)
repo.EXPECT().GetByID(mock.Anything, "pat-2").
Return(expiredPAT, nil)
repo.EXPECT().CountActive(mock.Anything, "user-1", "org-1").
Return(int64(10), nil)
repo.EXPECT().Regenerate(mock.Anything, "pat-2", mock.Anything, mock.Anything).
Return(regeneratedPAT, nil)
orgSvc := mocks.NewOrganizationService(t)
orgSvc.On("GetRaw", mock.Anything, mock.Anything).
Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe()
policySvc := mocks.NewPolicyService(t)
policySvc.On("List", mock.Anything, mock.Anything).
Return([]policy.Policy{}, nil).Maybe()
auditRepo := mocks.NewAuditRecordRepository(t)
auditRepo.On("Create", mock.Anything, mock.Anything).
Return(auditmodels.AuditRecord{}, nil).Maybe()
return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, nil, policySvc, auditRepo)
},
wantErr: false,
},
{
name: "should succeed even when audit record creation fails",
userID: "user-1",
patID: "pat-1",
expiresAt: futureExpiry,
setup: func() *userpat.Service {
repo := mocks.NewRepository(t)
repo.EXPECT().GetByID(mock.Anything, "pat-1").
Return(activePAT, nil)
repo.EXPECT().Regenerate(mock.Anything, "pat-1", mock.Anything, mock.Anything).
Return(regeneratedPAT, nil)
orgSvc := mocks.NewOrganizationService(t)
orgSvc.On("GetRaw", mock.Anything, mock.Anything).
Return(organization.Organization{ID: "org-1", Title: "Test Org"}, nil).Maybe()
policySvc := mocks.NewPolicyService(t)
policySvc.On("List", mock.Anything, mock.Anything).
Return([]policy.Policy{}, nil).Maybe()
auditRepo := mocks.NewAuditRecordRepository(t)
auditRepo.On("Create", mock.Anything, mock.Anything).
Return(auditmodels.AuditRecord{}, errors.New("audit db down"))
return userpat.NewService(log.NewNoop(), repo, defaultConfig, orgSvc, nil, policySvc, auditRepo)
},
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := tt.setup()
_, _, err := svc.Regenerate(context.Background(), tt.userID, tt.patID, tt.expiresAt)
if tt.wantErr {
if err == nil {
t.Fatal("Regenerate() expected error, got nil")
}
if tt.wantErrIs != nil && !errors.Is(err, tt.wantErrIs) {
t.Errorf("Regenerate() error = %v, want %v", err, tt.wantErrIs)
}
return
}
if err != nil {
t.Fatalf("Regenerate() unexpected error: %v", err)
}
})
}
}
1 change: 1 addition & 0 deletions core/userpat/userpat.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ type Repository interface {
GetBySecretHash(ctx context.Context, secretHash string) (models.PAT, error)
UpdateLastUsedAt(ctx context.Context, id string, at time.Time) error
Update(ctx context.Context, pat models.PAT) (models.PAT, error)
Regenerate(ctx context.Context, id, secretHash string, expiresAt time.Time) (models.PAT, error)
Delete(ctx context.Context, id string) error
}
1 change: 1 addition & 0 deletions internal/api/v1beta1connect/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,5 +406,6 @@ type UserPATService interface {
Get(ctx context.Context, userID, id string) (models.PAT, error)
Delete(ctx context.Context, userID, id string) error
Update(ctx context.Context, toUpdate models.PAT) (models.PAT, error)
Regenerate(ctx context.Context, userID, id string, newExpiresAt time.Time) (models.PAT, string, error)
ListAllowedRoles(ctx context.Context, scopes []string) ([]role.Role, error)
}
Loading
Loading