|
| 1 | +package omnitoken |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "encoding/json" |
| 6 | + "fmt" |
| 7 | + "log/slog" |
| 8 | + "strings" |
| 9 | + "time" |
| 10 | + |
| 11 | + "github.com/grokify/goauth/multiservice/tokens" |
| 12 | + "github.com/plexusone/omnivault/vault" |
| 13 | + "golang.org/x/oauth2" |
| 14 | +) |
| 15 | + |
| 16 | +// VaultTokenSet implements goauth's tokens.TokenSet interface backed by an omnivault. |
| 17 | +// This allows goauth's multi-service token management to use vault storage. |
| 18 | +type VaultTokenSet struct { |
| 19 | + vault vault.Vault |
| 20 | + prefix string |
| 21 | + logger *slog.Logger |
| 22 | +} |
| 23 | + |
| 24 | +// NewVaultTokenSet creates a new VaultTokenSet. |
| 25 | +func NewVaultTokenSet(v vault.Vault, prefix string, logger *slog.Logger) *VaultTokenSet { |
| 26 | + if prefix == "" { |
| 27 | + prefix = "tokens/" |
| 28 | + } |
| 29 | + if logger == nil { |
| 30 | + logger = slog.Default() |
| 31 | + } |
| 32 | + return &VaultTokenSet{ |
| 33 | + vault: v, |
| 34 | + prefix: prefix, |
| 35 | + logger: logger, |
| 36 | + } |
| 37 | +} |
| 38 | + |
| 39 | +// GetTokenInfo retrieves token info from the vault. |
| 40 | +func (ts *VaultTokenSet) GetTokenInfo(key string) (*tokens.TokenInfo, error) { |
| 41 | + key = tokens.FormatKey(key) |
| 42 | + path := ts.keyToPath(key) |
| 43 | + |
| 44 | + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) |
| 45 | + defer cancel() |
| 46 | + |
| 47 | + secret, err := ts.vault.Get(ctx, path) |
| 48 | + if err != nil { |
| 49 | + if isNotFoundError(err) { |
| 50 | + return nil, fmt.Errorf("%w: %s", ErrTokenNotFound, key) |
| 51 | + } |
| 52 | + return nil, fmt.Errorf("failed to get token info: %w", err) |
| 53 | + } |
| 54 | + |
| 55 | + tokenInfo, err := parseTokenInfo(secret) |
| 56 | + if err != nil { |
| 57 | + return nil, fmt.Errorf("failed to parse token info: %w", err) |
| 58 | + } |
| 59 | + |
| 60 | + return tokenInfo, nil |
| 61 | +} |
| 62 | + |
| 63 | +// GetToken retrieves just the OAuth2 token from the vault. |
| 64 | +func (ts *VaultTokenSet) GetToken(key string) (*oauth2.Token, error) { |
| 65 | + tokenInfo, err := ts.GetTokenInfo(key) |
| 66 | + if err != nil { |
| 67 | + return nil, err |
| 68 | + } |
| 69 | + if tokenInfo.Token == nil { |
| 70 | + return nil, fmt.Errorf("%w: token is nil for key %s", ErrTokenNotFound, key) |
| 71 | + } |
| 72 | + return tokenInfo.Token, nil |
| 73 | +} |
| 74 | + |
| 75 | +// SetTokenInfo stores token info in the vault. |
| 76 | +func (ts *VaultTokenSet) SetTokenInfo(key string, tokenInfo *tokens.TokenInfo) error { |
| 77 | + key = tokens.FormatKey(key) |
| 78 | + path := ts.keyToPath(key) |
| 79 | + |
| 80 | + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) |
| 81 | + defer cancel() |
| 82 | + |
| 83 | + secret, err := tokenInfoToSecret(tokenInfo) |
| 84 | + if err != nil { |
| 85 | + return fmt.Errorf("failed to serialize token info: %w", err) |
| 86 | + } |
| 87 | + |
| 88 | + if err := ts.vault.Set(ctx, path, secret); err != nil { |
| 89 | + return fmt.Errorf("failed to store token info: %w", err) |
| 90 | + } |
| 91 | + |
| 92 | + ts.logger.Debug("stored token info", |
| 93 | + "key", key, |
| 94 | + "service_key", tokenInfo.ServiceKey, |
| 95 | + "service_type", tokenInfo.ServiceType, |
| 96 | + ) |
| 97 | + |
| 98 | + return nil |
| 99 | +} |
| 100 | + |
| 101 | +// DeleteToken removes a token from the vault. |
| 102 | +func (ts *VaultTokenSet) DeleteToken(key string) error { |
| 103 | + key = tokens.FormatKey(key) |
| 104 | + path := ts.keyToPath(key) |
| 105 | + |
| 106 | + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) |
| 107 | + defer cancel() |
| 108 | + |
| 109 | + if err := ts.vault.Delete(ctx, path); err != nil { |
| 110 | + return fmt.Errorf("failed to delete token: %w", err) |
| 111 | + } |
| 112 | + |
| 113 | + ts.logger.Debug("deleted token", "key", key) |
| 114 | + return nil |
| 115 | +} |
| 116 | + |
| 117 | +// ListTokens returns all token keys in the vault. |
| 118 | +func (ts *VaultTokenSet) ListTokens() ([]string, error) { |
| 119 | + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) |
| 120 | + defer cancel() |
| 121 | + |
| 122 | + paths, err := ts.vault.List(ctx, ts.prefix) |
| 123 | + if err != nil { |
| 124 | + return nil, fmt.Errorf("failed to list tokens: %w", err) |
| 125 | + } |
| 126 | + |
| 127 | + keys := make([]string, 0, len(paths)) |
| 128 | + for _, path := range paths { |
| 129 | + key := ts.pathToKey(path) |
| 130 | + if key != "" { |
| 131 | + keys = append(keys, key) |
| 132 | + } |
| 133 | + } |
| 134 | + |
| 135 | + return keys, nil |
| 136 | +} |
| 137 | + |
| 138 | +// ExistsToken checks if a token exists in the vault. |
| 139 | +func (ts *VaultTokenSet) ExistsToken(key string) (bool, error) { |
| 140 | + key = tokens.FormatKey(key) |
| 141 | + path := ts.keyToPath(key) |
| 142 | + |
| 143 | + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) |
| 144 | + defer cancel() |
| 145 | + |
| 146 | + return ts.vault.Exists(ctx, path) |
| 147 | +} |
| 148 | + |
| 149 | +// keyToPath converts a token key to a vault path. |
| 150 | +func (ts *VaultTokenSet) keyToPath(key string) string { |
| 151 | + return ts.prefix + sanitizeKey(key) |
| 152 | +} |
| 153 | + |
| 154 | +// pathToKey converts a vault path to a token key. |
| 155 | +func (ts *VaultTokenSet) pathToKey(path string) string { |
| 156 | + if !strings.HasPrefix(path, ts.prefix) { |
| 157 | + return "" |
| 158 | + } |
| 159 | + return strings.TrimPrefix(path, ts.prefix) |
| 160 | +} |
| 161 | + |
| 162 | +// sanitizeKey sanitizes a key for use in vault paths. |
| 163 | +func sanitizeKey(key string) string { |
| 164 | + key = strings.TrimSpace(key) |
| 165 | + key = strings.ReplaceAll(key, "/", "_") |
| 166 | + key = strings.ReplaceAll(key, "\\", "_") |
| 167 | + return key |
| 168 | +} |
| 169 | + |
| 170 | +// parseTokenInfo parses a vault secret into TokenInfo. |
| 171 | +func parseTokenInfo(secret *vault.Secret) (*tokens.TokenInfo, error) { |
| 172 | + var tokenInfo tokens.TokenInfo |
| 173 | + |
| 174 | + // First try to unmarshal from the Value field (JSON) |
| 175 | + if secret.Value != "" { |
| 176 | + if err := json.Unmarshal([]byte(secret.Value), &tokenInfo); err == nil { |
| 177 | + return &tokenInfo, nil |
| 178 | + } |
| 179 | + } |
| 180 | + |
| 181 | + // Fall back to Fields for multi-field secrets |
| 182 | + if secret.Fields != nil { |
| 183 | + tokenInfo.ServiceKey = secret.Fields["service_key"] |
| 184 | + tokenInfo.ServiceType = secret.Fields["service_type"] |
| 185 | + |
| 186 | + if tokenJSON := secret.Fields["token"]; tokenJSON != "" { |
| 187 | + var token oauth2.Token |
| 188 | + if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil { |
| 189 | + return nil, fmt.Errorf("failed to parse token field: %w", err) |
| 190 | + } |
| 191 | + tokenInfo.Token = &token |
| 192 | + } |
| 193 | + |
| 194 | + return &tokenInfo, nil |
| 195 | + } |
| 196 | + |
| 197 | + return nil, fmt.Errorf("secret has no parseable token info") |
| 198 | +} |
| 199 | + |
| 200 | +// tokenInfoToSecret converts TokenInfo to a vault secret. |
| 201 | +func tokenInfoToSecret(tokenInfo *tokens.TokenInfo) (*vault.Secret, error) { |
| 202 | + data, err := json.Marshal(tokenInfo) |
| 203 | + if err != nil { |
| 204 | + return nil, err |
| 205 | + } |
| 206 | + |
| 207 | + // Also store as fields for providers that support multi-field |
| 208 | + fields := make(map[string]string) |
| 209 | + fields["service_key"] = tokenInfo.ServiceKey |
| 210 | + fields["service_type"] = tokenInfo.ServiceType |
| 211 | + |
| 212 | + if tokenInfo.Token != nil { |
| 213 | + //nolint:gosec // G117: OAuth token stored in vault per RFC 6749 |
| 214 | + tokenData, err := json.Marshal(tokenInfo.Token) |
| 215 | + if err != nil { |
| 216 | + return nil, fmt.Errorf("failed to serialize token: %w", err) |
| 217 | + } |
| 218 | + fields["token"] = string(tokenData) |
| 219 | + fields["access_token"] = tokenInfo.Token.AccessToken |
| 220 | + if tokenInfo.Token.RefreshToken != "" { |
| 221 | + fields["refresh_token"] = tokenInfo.Token.RefreshToken |
| 222 | + } |
| 223 | + if !tokenInfo.Token.Expiry.IsZero() { |
| 224 | + fields["expiry"] = tokenInfo.Token.Expiry.Format(time.RFC3339) |
| 225 | + } |
| 226 | + } |
| 227 | + |
| 228 | + now := vault.Now() |
| 229 | + return &vault.Secret{ |
| 230 | + Value: string(data), |
| 231 | + Fields: fields, |
| 232 | + Metadata: vault.Metadata{ |
| 233 | + ModifiedAt: now, |
| 234 | + Tags: map[string]string{ |
| 235 | + "type": "oauth2_token", |
| 236 | + "service_key": tokenInfo.ServiceKey, |
| 237 | + "service_type": tokenInfo.ServiceType, |
| 238 | + }, |
| 239 | + }, |
| 240 | + }, nil |
| 241 | +} |
| 242 | + |
| 243 | +// isNotFoundError checks if an error indicates a secret was not found. |
| 244 | +func isNotFoundError(err error) bool { |
| 245 | + return err == vault.ErrSecretNotFound || |
| 246 | + strings.Contains(err.Error(), "not found") || |
| 247 | + strings.Contains(err.Error(), "does not exist") |
| 248 | +} |
| 249 | + |
| 250 | +// Ensure VaultTokenSet implements tokens.TokenSet. |
| 251 | +var _ tokens.TokenSet = (*VaultTokenSet)(nil) |
0 commit comments