forked from DrmagicE/gmqtt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
auth.go
179 lines (161 loc) · 3.89 KB
/
auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
package auth
import (
	"crypto/md5"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/hex"
	"errors"
	"fmt"
	"hash"
	"io/ioutil"
	"os"
	"path"
	"sync"

	"go.uber.org/zap"
	"golang.org/x/crypto/bcrypt"
	"gopkg.in/yaml.v2"

	"github.com/DrmagicE/gmqtt/config"
	"github.com/DrmagicE/gmqtt/plugin/admin"
	"github.com/DrmagicE/gmqtt/server"
)
// Compile-time assertion that *Auth implements server.Plugin.
var _ server.Plugin = (*Auth)(nil)

// Name is the plugin name. It is the key the constructor and the default
// configuration are registered under, and the key of this plugin's section
// in config.Plugins (see New).
const Name = "auth"

// init registers the plugin constructor (New) and its default configuration
// (DefaultConfig) under Name.
func init() {
	server.RegisterPlugin(Name, New)
	config.RegisterDefaultPluginConfig(Name, &DefaultConfig)
}
// New is the plugin constructor registered in init. It builds an Auth from
// this plugin's section of cfg.Plugins (keyed by Name). The config directory
// cfg.ConfigDir is kept so that Load can resolve a relative password-file path
// against it.
//
// It returns an error, instead of panicking, when the plugin section is missing
// or is not a non-nil *Config.
func New(cfg config.Config) (server.Plugin, error) {
	pluginCfg, ok := cfg.Plugins[Name].(*Config)
	if !ok || pluginCfg == nil {
		return nil, fmt.Errorf("%s: missing or invalid plugin configuration (got %T)", Name, cfg.Plugins[Name])
	}
	a := &Auth{
		config:  pluginCfg,
		indexer: admin.NewIndexer(),
		pwdDir:  cfg.ConfigDir,
	}
	a.saveFile = a.saveFileHandler
	return a, nil
}
// log is the plugin's logger. It is assigned in Load.
var log *zap.Logger

// Auth provides the username/password authentication for gmqtt.
// The authentication data is persisted in the password file named by
// Config.PasswordFile.
type Auth struct {
	// config is this plugin's section of the server configuration.
	config *Config
	// pwdDir is the directory that a relative Config.PasswordFile is resolved
	// against in Load (set from the server's config directory in New).
	pwdDir string
	// mu guards indexer.
	mu sync.RWMutex
	// indexer stores the accounts (*Account values) keyed by username.
	indexer *admin.Indexer
	// saveFile persists the account data to the password file.
	saveFile func() error
}
// generatePassword returns the form of password that is stored for the
// configured hash type (a.config.Hash):
//   - Plain: the password unchanged.
//   - MD5, SHA256: the lowercase hex encoding of the digest.
//   - Bcrypt: a bcrypt hash (randomly salted, so it differs on every call).
//
// An unknown hash type is reported as an error rather than a panic.
func (a *Auth) generatePassword(password string) (hashedPassword string, err error) {
	var h hash.Hash
	switch a.config.Hash {
	case Plain:
		return password, nil
	case MD5:
		h = md5.New()
	case SHA256:
		h = sha256.New()
	case Bcrypt:
		// NOTE(review): bcrypt.MinCost makes offline brute force of a leaked
		// password file cheap. The cost is embedded in each hash and also decides
		// how expensive verification on every client CONNECT is, so raising it
		// (e.g. to bcrypt.DefaultCost) is a deliberate trade-off to confirm.
		pwd, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.MinCost)
		return string(pwd), err
	default:
		return "", fmt.Errorf("invalid hash type: %q", a.config.Hash)
	}
	_, err = h.Write([]byte(password))
	if err != nil {
		return "", err
	}
	return hex.EncodeToString(h.Sum(nil)), nil
}
// mustEmbedUnimplementedAccountServiceServer is intentionally empty.
// NOTE(review): the name matches the marker method that protoc-gen-go-grpc adds
// to generated service interfaces. Presumably it lets *Auth satisfy
// AccountServiceServer (passed to RegisterAccountServiceServer in registerAPI)
// without embedding UnimplementedAccountServiceServer; confirm against the
// generated code.
func (a *Auth) mustEmbedUnimplementedAccountServiceServer() {}
// validate reports whether password matches the stored credential of username.
//
// It returns (false, nil) when username is unknown. With the Bcrypt hash type
// the stored value is checked with bcrypt.CompareHashAndPassword. For every
// other hash type, password is converted with generatePassword (which also
// handles unknown hash types) and compared to the stored value in constant time.
func (a *Auth) validate(username, password string) (permitted bool, err error) {
	// Copy the stored password while still holding the read lock, so the read
	// cannot race with a writer that modifies the indexed element under a.mu.
	a.mu.RLock()
	elem := a.indexer.GetByID(username)
	var stored string
	if elem != nil {
		stored = elem.Value.(*Account).Password
	}
	a.mu.RUnlock()
	if elem == nil {
		return false, nil
	}

	if a.config.Hash == Bcrypt {
		return bcrypt.CompareHashAndPassword([]byte(stored), []byte(password)) == nil, nil
	}

	candidate, err := a.generatePassword(password)
	if err != nil {
		return false, err
	}
	// The password comes from the client. A plain == comparison would leak,
	// through timing, how many leading bytes matched.
	return subtle.ConstantTimeCompare([]byte(stored), []byte(candidate)) == 1, nil
}
// registerAPI registers a as the account service on the server's API registrar
// and then registers the account service HTTP handler. It returns the error
// from the HTTP handler registration.
// NOTE(review): it is a package variable rather than a function, presumably so
// tests can stub it out; confirm.
var registerAPI = func(service server.Server, a *Auth) error {
	registrar := service.APIRegistrar()
	RegisterAccountServiceServer(registrar, a)
	return registrar.RegisterHTTPHandler(RegisterAccountServiceHandlerFromEndpoint)
}
// Load registers the account API, initialises the plugin logger, and loads the
// accounts from the password file into the in-memory index.
//
// A relative a.config.PasswordFile is resolved against a.pwdDir. The file is
// created if it does not exist. Loading fails when the file is not valid YAML,
// or when any entry is null, has an empty username, or repeats a username. In
// those cases the index is left unchanged.
func (a *Auth) Load(service server.Server) error {
	if err := registerAPI(service, a); err != nil {
		return fmt.Errorf("register account API: %w", err)
	}
	log = server.LoggerWithField(zap.String("plugin", Name))

	pwdFile := a.config.PasswordFile
	if !path.IsAbs(pwdFile) {
		pwdFile = path.Join(a.pwdDir, pwdFile)
	}
	// O_CREATE: a missing password file is created empty instead of failing.
	f, err := os.OpenFile(pwdFile, os.O_CREATE|os.O_RDONLY, 0666)
	if err != nil {
		return err
	}
	defer f.Close()
	b, err := ioutil.ReadAll(f)
	if err != nil {
		return err
	}
	var acts []*Account
	err = yaml.Unmarshal(b, &acts)
	if err != nil {
		return fmt.Errorf("parse password file %q: %w", pwdFile, err)
	}
	log.Info("authentication data loaded",
		zap.String("hash", a.config.Hash),
		zap.Int("account_nums", len(acts)),
		zap.String("password_file", pwdFile))

	// Validate every entry before touching the index.
	seen := make(map[string]struct{}, len(acts))
	for _, v := range acts {
		// A null list item decodes to a nil *Account; reject it like an empty username.
		if v == nil || v.Username == "" {
			return errors.New("detect empty username in password file")
		}
		if _, dup := seen[v.Username]; dup {
			return fmt.Errorf("detect duplicated username in password file: %s", v.Username)
		}
		seen[v.Username] = struct{}{}
	}

	a.mu.Lock()
	defer a.mu.Unlock()
	for _, v := range acts {
		a.indexer.Set(v.Username, v)
	}
	return nil
}
// Unload does nothing and always returns nil.
func (a *Auth) Unload() error { return nil }
// Name returns the plugin name, i.e. the package constant Name ("auth").
func (a *Auth) Name() string { return Name }