Skip to content

Commit ca103f8

Browse files
grokifyclaude
andcommitted
feat: add vault token set implementation
VaultTokenSet implements goauth's tokens.TokenSet interface backed by omnivault for persistent OAuth2 token storage: - Token serialization/deserialization - Field-based storage for multi-field providers - Support for access/refresh tokens and expiry tracking Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 8486113 commit ca103f8

1 file changed

Lines changed: 251 additions & 0 deletions

File tree

tokenset.go

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
package omnitoken
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"log/slog"
8+
"strings"
9+
"time"
10+
11+
"github.com/grokify/goauth/multiservice/tokens"
12+
"github.com/plexusone/omnivault/vault"
13+
"golang.org/x/oauth2"
14+
)
15+
16+
// VaultTokenSet implements goauth's tokens.TokenSet interface backed by an omnivault.
17+
// This allows goauth's multi-service token management to use vault storage.
18+
type VaultTokenSet struct {
19+
vault vault.Vault
20+
prefix string
21+
logger *slog.Logger
22+
}
23+
24+
// NewVaultTokenSet creates a new VaultTokenSet.
25+
func NewVaultTokenSet(v vault.Vault, prefix string, logger *slog.Logger) *VaultTokenSet {
26+
if prefix == "" {
27+
prefix = "tokens/"
28+
}
29+
if logger == nil {
30+
logger = slog.Default()
31+
}
32+
return &VaultTokenSet{
33+
vault: v,
34+
prefix: prefix,
35+
logger: logger,
36+
}
37+
}
38+
39+
// GetTokenInfo retrieves token info from the vault.
40+
func (ts *VaultTokenSet) GetTokenInfo(key string) (*tokens.TokenInfo, error) {
41+
key = tokens.FormatKey(key)
42+
path := ts.keyToPath(key)
43+
44+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
45+
defer cancel()
46+
47+
secret, err := ts.vault.Get(ctx, path)
48+
if err != nil {
49+
if isNotFoundError(err) {
50+
return nil, fmt.Errorf("%w: %s", ErrTokenNotFound, key)
51+
}
52+
return nil, fmt.Errorf("failed to get token info: %w", err)
53+
}
54+
55+
tokenInfo, err := parseTokenInfo(secret)
56+
if err != nil {
57+
return nil, fmt.Errorf("failed to parse token info: %w", err)
58+
}
59+
60+
return tokenInfo, nil
61+
}
62+
63+
// GetToken retrieves just the OAuth2 token from the vault.
64+
func (ts *VaultTokenSet) GetToken(key string) (*oauth2.Token, error) {
65+
tokenInfo, err := ts.GetTokenInfo(key)
66+
if err != nil {
67+
return nil, err
68+
}
69+
if tokenInfo.Token == nil {
70+
return nil, fmt.Errorf("%w: token is nil for key %s", ErrTokenNotFound, key)
71+
}
72+
return tokenInfo.Token, nil
73+
}
74+
75+
// SetTokenInfo stores token info in the vault.
76+
func (ts *VaultTokenSet) SetTokenInfo(key string, tokenInfo *tokens.TokenInfo) error {
77+
key = tokens.FormatKey(key)
78+
path := ts.keyToPath(key)
79+
80+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
81+
defer cancel()
82+
83+
secret, err := tokenInfoToSecret(tokenInfo)
84+
if err != nil {
85+
return fmt.Errorf("failed to serialize token info: %w", err)
86+
}
87+
88+
if err := ts.vault.Set(ctx, path, secret); err != nil {
89+
return fmt.Errorf("failed to store token info: %w", err)
90+
}
91+
92+
ts.logger.Debug("stored token info",
93+
"key", key,
94+
"service_key", tokenInfo.ServiceKey,
95+
"service_type", tokenInfo.ServiceType,
96+
)
97+
98+
return nil
99+
}
100+
101+
// DeleteToken removes a token from the vault.
102+
func (ts *VaultTokenSet) DeleteToken(key string) error {
103+
key = tokens.FormatKey(key)
104+
path := ts.keyToPath(key)
105+
106+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
107+
defer cancel()
108+
109+
if err := ts.vault.Delete(ctx, path); err != nil {
110+
return fmt.Errorf("failed to delete token: %w", err)
111+
}
112+
113+
ts.logger.Debug("deleted token", "key", key)
114+
return nil
115+
}
116+
117+
// ListTokens returns all token keys in the vault.
118+
func (ts *VaultTokenSet) ListTokens() ([]string, error) {
119+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
120+
defer cancel()
121+
122+
paths, err := ts.vault.List(ctx, ts.prefix)
123+
if err != nil {
124+
return nil, fmt.Errorf("failed to list tokens: %w", err)
125+
}
126+
127+
keys := make([]string, 0, len(paths))
128+
for _, path := range paths {
129+
key := ts.pathToKey(path)
130+
if key != "" {
131+
keys = append(keys, key)
132+
}
133+
}
134+
135+
return keys, nil
136+
}
137+
138+
// ExistsToken checks if a token exists in the vault.
139+
func (ts *VaultTokenSet) ExistsToken(key string) (bool, error) {
140+
key = tokens.FormatKey(key)
141+
path := ts.keyToPath(key)
142+
143+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
144+
defer cancel()
145+
146+
return ts.vault.Exists(ctx, path)
147+
}
148+
149+
// keyToPath converts a token key to a vault path.
150+
func (ts *VaultTokenSet) keyToPath(key string) string {
151+
return ts.prefix + sanitizeKey(key)
152+
}
153+
154+
// pathToKey converts a vault path to a token key.
155+
func (ts *VaultTokenSet) pathToKey(path string) string {
156+
if !strings.HasPrefix(path, ts.prefix) {
157+
return ""
158+
}
159+
return strings.TrimPrefix(path, ts.prefix)
160+
}
161+
162+
// sanitizeKey sanitizes a key for use in vault paths.
163+
func sanitizeKey(key string) string {
164+
key = strings.TrimSpace(key)
165+
key = strings.ReplaceAll(key, "/", "_")
166+
key = strings.ReplaceAll(key, "\\", "_")
167+
return key
168+
}
169+
170+
// parseTokenInfo parses a vault secret into TokenInfo.
171+
func parseTokenInfo(secret *vault.Secret) (*tokens.TokenInfo, error) {
172+
var tokenInfo tokens.TokenInfo
173+
174+
// First try to unmarshal from the Value field (JSON)
175+
if secret.Value != "" {
176+
if err := json.Unmarshal([]byte(secret.Value), &tokenInfo); err == nil {
177+
return &tokenInfo, nil
178+
}
179+
}
180+
181+
// Fall back to Fields for multi-field secrets
182+
if secret.Fields != nil {
183+
tokenInfo.ServiceKey = secret.Fields["service_key"]
184+
tokenInfo.ServiceType = secret.Fields["service_type"]
185+
186+
if tokenJSON := secret.Fields["token"]; tokenJSON != "" {
187+
var token oauth2.Token
188+
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
189+
return nil, fmt.Errorf("failed to parse token field: %w", err)
190+
}
191+
tokenInfo.Token = &token
192+
}
193+
194+
return &tokenInfo, nil
195+
}
196+
197+
return nil, fmt.Errorf("secret has no parseable token info")
198+
}
199+
200+
// tokenInfoToSecret converts TokenInfo to a vault secret.
201+
func tokenInfoToSecret(tokenInfo *tokens.TokenInfo) (*vault.Secret, error) {
202+
data, err := json.Marshal(tokenInfo)
203+
if err != nil {
204+
return nil, err
205+
}
206+
207+
// Also store as fields for providers that support multi-field
208+
fields := make(map[string]string)
209+
fields["service_key"] = tokenInfo.ServiceKey
210+
fields["service_type"] = tokenInfo.ServiceType
211+
212+
if tokenInfo.Token != nil {
213+
//nolint:gosec // G117: OAuth token stored in vault per RFC 6749
214+
tokenData, err := json.Marshal(tokenInfo.Token)
215+
if err != nil {
216+
return nil, fmt.Errorf("failed to serialize token: %w", err)
217+
}
218+
fields["token"] = string(tokenData)
219+
fields["access_token"] = tokenInfo.Token.AccessToken
220+
if tokenInfo.Token.RefreshToken != "" {
221+
fields["refresh_token"] = tokenInfo.Token.RefreshToken
222+
}
223+
if !tokenInfo.Token.Expiry.IsZero() {
224+
fields["expiry"] = tokenInfo.Token.Expiry.Format(time.RFC3339)
225+
}
226+
}
227+
228+
now := vault.Now()
229+
return &vault.Secret{
230+
Value: string(data),
231+
Fields: fields,
232+
Metadata: vault.Metadata{
233+
ModifiedAt: now,
234+
Tags: map[string]string{
235+
"type": "oauth2_token",
236+
"service_key": tokenInfo.ServiceKey,
237+
"service_type": tokenInfo.ServiceType,
238+
},
239+
},
240+
}, nil
241+
}
242+
243+
// isNotFoundError checks if an error indicates a secret was not found.
244+
func isNotFoundError(err error) bool {
245+
return err == vault.ErrSecretNotFound ||
246+
strings.Contains(err.Error(), "not found") ||
247+
strings.Contains(err.Error(), "does not exist")
248+
}
249+
250+
// Ensure VaultTokenSet implements tokens.TokenSet.
251+
var _ tokens.TokenSet = (*VaultTokenSet)(nil)

0 commit comments

Comments
 (0)