-
Notifications
You must be signed in to change notification settings - Fork 281
/
jwtcache.go
164 lines (133 loc) · 3.49 KB
/
jwtcache.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
package cliutil
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sync"
"time"
"github.com/martinlindhe/base36"
"gopkg.in/square/go-jose.v2"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
// Sentinel errors returned by the JWTCache implementations in this file.
// Callers can compare against them with errors.Is (or ==).
var (
	// ErrExpired reports that a cached JWT is past its expiry time.
	ErrExpired = errors.New("expired")

	// ErrInvalid reports that a cached value could not be parsed as a JWT or
	// that its claims could not be decoded.
	ErrInvalid = errors.New("invalid")

	// ErrNotFound reports that nothing is cached for the requested key.
	ErrNotFound = errors.New("not found")
)
// A JWTCache loads and stores raw JWT strings indexed by a string key.
type JWTCache interface {
	// LoadJWT returns the JWT cached under key. The implementations in this
	// file return ErrNotFound when nothing is cached for key, and may return
	// the stored token together with ErrExpired or ErrInvalid.
	LoadJWT(key string) (rawJWT string, err error)

	// StoreJWT caches rawJWT under key, replacing any previous value.
	StoreJWT(key string, rawJWT string) error

	// DeleteJWT removes whatever is cached under key. The implementations in
	// this file do not treat a missing key as an error.
	DeleteJWT(key string) error
}
// A LocalJWTCache stores JWTs as files in the user's cache directory, one
// file per key.
type LocalJWTCache struct {
	dir string // directory that holds the cached token files
}
// NewLocalJWTCache creates a new LocalJWTCache rooted at
// <os.UserCacheDir()>/pomerium-cli/jwts, creating that directory (and any
// missing parents) if needed.
//
// Directories created here use mode 0700 (before umask), so other local users
// can neither list nor modify the cached tokens. The permissions of
// directories that already exist are left unchanged.
//
// It returns an error if the user cache directory cannot be determined or the
// cache directory cannot be created.
func NewLocalJWTCache() (*LocalJWTCache, error) {
	root, err := os.UserCacheDir()
	if err != nil {
		return nil, fmt.Errorf("error determining user cache directory: %w", err)
	}

	dir := filepath.Join(root, "pomerium-cli", "jwts")
	// The directory holds credentials, so keep it private to the owner.
	if err := os.MkdirAll(dir, 0o700); err != nil {
		return nil, fmt.Errorf("error creating user cache directory: %w", err)
	}

	return &LocalJWTCache{dir: dir}, nil
}
// DeleteJWT deletes a raw JWT from the local cache.
//
// Removing a key that has no cache file is not an error. Any other failure
// from os.Remove is returned unchanged.
func (cache *LocalJWTCache) DeleteJWT(key string) error {
	target := filepath.Join(cache.dir, cache.fileName(key))
	if removeErr := os.Remove(target); removeErr != nil && !os.IsNotExist(removeErr) {
		return removeErr
	}
	return nil
}
// LoadJWT loads a raw JWT from the local cache.
//
// It returns ErrNotFound when no cache file exists for key, and any other
// read error as-is. Otherwise it returns the stored token together with the
// result of checkExpiry. The token may therefore come back non-empty
// alongside ErrInvalid or ErrExpired.
func (cache *LocalJWTCache) LoadJWT(key string) (rawJWT string, err error) {
	contents, err := ioutil.ReadFile(filepath.Join(cache.dir, cache.fileName(key)))
	switch {
	case os.IsNotExist(err):
		return "", ErrNotFound
	case err != nil:
		return "", err
	}

	rawJWT = string(contents)
	return rawJWT, checkExpiry(rawJWT)
}
// StoreJWT stores a raw JWT in the local cache.
//
// The token is written to the key's cache file, replacing any previous
// contents. A newly created file gets mode 0600 (before umask). An existing
// file keeps its current permissions, per ioutil.WriteFile.
func (cache *LocalJWTCache) StoreJWT(key string, rawJWT string) error {
	target := filepath.Join(cache.dir, cache.fileName(key))
	return ioutil.WriteFile(target, []byte(rawJWT), 0o600)
}
// hash maps str to the base36 encoding of cryptutil.Hash("LocalJWTCache", str).
// NOTE(review): presumably used so raw cache keys never appear in file names;
// confirm against the callers of fileName.
func (cache *LocalJWTCache) hash(str string) string {
	return base36.EncodeBytes(cryptutil.Hash("LocalJWTCache", []byte(str)))
}
// fileName returns the base name of the cache file for key: the hashed key
// followed by a ".jwt" extension.
func (cache *LocalJWTCache) fileName(key string) string {
	const extension = ".jwt"
	return cache.hash(key) + extension
}
// A MemoryJWTCache stores JWTs in an in-memory map.
//
// Every method holds mu while touching entries. Create instances with
// NewMemoryJWTCache and share them by pointer. The struct contains a
// sync.Mutex and must not be copied after first use.
type MemoryJWTCache struct {
	mu      sync.Mutex        // guards entries
	entries map[string]string // cache key -> raw JWT
}
// NewMemoryJWTCache creates a new in-memory JWT cache whose entry map is
// allocated and empty, ready for use.
func NewMemoryJWTCache() *MemoryJWTCache {
	cache := new(MemoryJWTCache)
	cache.entries = map[string]string{}
	return cache
}
// DeleteJWT deletes a JWT from the in-memory map. Deleting a key that is not
// present is a no-op, and the method always returns nil.
func (cache *MemoryJWTCache) DeleteJWT(key string) error {
	cache.mu.Lock()
	// delete never panics (not even on a nil map), so an explicit unlock is safe.
	delete(cache.entries, key)
	cache.mu.Unlock()
	return nil
}
// LoadJWT loads a JWT from the in-memory map.
//
// It returns ErrNotFound when key is absent. Otherwise it returns the stored
// token together with the result of checkExpiry, which runs while the lock
// is held.
func (cache *MemoryJWTCache) LoadJWT(key string) (rawJWT string, err error) {
	cache.mu.Lock()
	defer cache.mu.Unlock()

	if stored, found := cache.entries[key]; found {
		return stored, checkExpiry(stored)
	}
	return "", ErrNotFound
}
// StoreJWT stores a JWT in the in-memory map, replacing any existing entry
// for key. It always returns nil.
//
// The entry map is allocated on first use, so a zero-value MemoryJWTCache
// (for example `var c MemoryJWTCache`) works instead of panicking on a write
// to a nil map.
func (cache *MemoryJWTCache) StoreJWT(key string, rawJWT string) error {
	cache.mu.Lock()
	defer cache.mu.Unlock()

	if cache.entries == nil {
		cache.entries = make(map[string]string)
	}
	cache.entries[key] = rawJWT
	return nil
}
// checkExpiry reports whether rawJWT is a parseable JWS whose "exp" claim has
// not yet passed.
//
// The signature is NOT verified. The payload is read with
// UnsafePayloadWithoutVerification only to decide whether a cached token is
// still worth using. It returns:
//   - ErrInvalid if jose.ParseSigned rejects rawJWT, if the payload cannot be
//     JSON-decoded with a numeric "exp", or if "exp" is outside the int64
//     range;
//   - ErrExpired if "exp" is missing or the current time is at or after it
//     (RFC 7519 §4.1.4);
//   - nil otherwise.
//
// "exp" is an RFC 7519 NumericDate, which may be non-integral or written in
// exponent form (e.g. 1.7e9). It is therefore decoded as a float64 and
// truncated to whole seconds rather than decoded directly into an int64.
func checkExpiry(rawJWT string) error {
	tok, err := jose.ParseSigned(rawJWT)
	if err != nil {
		return ErrInvalid
	}

	var claims struct {
		Expiry float64 `json:"exp"`
	}
	if err := json.Unmarshal(tok.UnsafePayloadWithoutVerification(), &claims); err != nil {
		return ErrInvalid
	}

	// Keep the float->int64 conversion below well defined.
	if claims.Expiry < -(1<<63) || claims.Expiry >= 1<<63 {
		return ErrInvalid
	}

	// A missing "exp" leaves Expiry at 0 (the Unix epoch), so such tokens are
	// reported as expired.
	expiresAt := time.Unix(int64(claims.Expiry), 0)
	if !time.Now().Before(expiresAt) {
		return ErrExpired
	}
	return nil
}