Skip to content

Commit

Permalink
recognize any base64 variant in keygen, more error logging.
Browse files Browse the repository at this point in the history
  • Loading branch information
or-else committed Oct 15, 2018
1 parent 0b7bff9 commit 9512c58
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 16 deletions.
10 changes: 9 additions & 1 deletion keygen/keygen.go
Expand Up @@ -68,6 +68,10 @@ func generate(sequence, isRoot int, hmacSaltB64 string) int {
} else {
var err error
hmacSalt, err = base64.URLEncoding.DecodeString(hmacSaltB64)
if err != nil {
// Try standard base64 decoding
hmacSalt, err = base64.StdEncoding.DecodeString(hmacSaltB64)
}
if err != nil {
log.Println("Failed to decode HMAC salt", err)
return 1
Expand Down Expand Up @@ -109,6 +113,10 @@ func validate(apikey string, hmacSaltB64 string) int {
var strIsRoot string

hmacSalt, err := base64.URLEncoding.DecodeString(hmacSaltB64)
if err != nil {
// Try standard base64 decoding
hmacSalt, err = base64.StdEncoding.DecodeString(hmacSaltB64)
}
if err != nil {
log.Println("Failed to decode HMAC salt", err)
return 1
Expand All @@ -121,7 +129,7 @@ func validate(apikey string, hmacSaltB64 string) int {

data, err := base64.URLEncoding.DecodeString(apikey)
if err != nil {
log.Println("Failed to decode key as base64", err)
log.Println("Failed to decode key as base64-URL-encoded", err)
return 1
}

Expand Down
49 changes: 34 additions & 15 deletions server/session.go
Expand Up @@ -225,7 +225,7 @@ func (s *Session) dispatchRaw(raw []byte) {

if err := json.Unmarshal(raw, &msg); err != nil {
// Malformed message
log.Println("s.dispatch", err)
log.Println("s.dispatch", err, s.sid)
s.queueOut(ErrMalformed("", "", time.Now().UTC().Round(time.Millisecond)))
return
}
Expand Down Expand Up @@ -264,6 +264,7 @@ func (s *Session) dispatch(msg *ClientComMessage) {
checkVers := func(m *ClientComMessage, handler func(*ClientComMessage)) func(*ClientComMessage) {
return func(m *ClientComMessage) {
if s.ver == 0 {
log.Println("s.dispatch: {hi} is missing", s.sid)
s.queueOut(ErrCommandOutOfSequence(m.id, m.topic, m.timestamp))
return
}
Expand All @@ -275,6 +276,7 @@ func (s *Session) dispatch(msg *ClientComMessage) {
checkUser := func(m *ClientComMessage, handler func(*ClientComMessage)) func(*ClientComMessage) {
return func(m *ClientComMessage) {
if msg.from == "" {
log.Println("s.dispatch: authentication required", s.sid)
s.queueOut(ErrAuthRequired(m.id, m.topic, msg.timestamp))
return
}
Expand Down Expand Up @@ -363,17 +365,19 @@ func (s *Session) subscribe(msg *ClientComMessage) {
var err *ServerComMessage
expanded, err = s.expandTopicName(msg)
if err != nil {
log.Println("s.subscribe:", err, s.sid)
s.queueOut(err)
return
}
}

if sub := s.getSub(expanded); sub != nil {
log.Println("s.subscribe: already subscribed to topic=", expanded, "sid=", s.sid)
log.Println("s.subscribe: already subscribed to topic=", expanded, s.sid)
s.queueOut(InfoAlreadySubscribed(msg.id, msg.topic, msg.timestamp))
} else if globals.cluster.isRemoteTopic(expanded) {
// The topic is handled by a remote node. Forward message to it.
if err := globals.cluster.routeToTopic(msg, expanded, s); err != nil {
log.Println("s.subscribe:", err, s.sid)
s.queueOut(ErrClusterNodeUnreachable(msg.id, msg.topic, msg.timestamp))
}
} else {
Expand All @@ -390,6 +394,7 @@ func (s *Session) leave(msg *ClientComMessage) {
// Expand topic name
expanded, err := s.expandTopicName(msg)
if err != nil {
log.Println("s.leave:", err, s.sid)
s.queueOut(err)
return
}
Expand All @@ -412,6 +417,7 @@ func (s *Session) leave(msg *ClientComMessage) {
} else if globals.cluster.isRemoteTopic(expanded) {
// The topic is handled by a remote node. Forward message to it.
if err := globals.cluster.routeToTopic(msg, expanded, s); err != nil {
log.Println("s.leave:", err, s.sid)
s.queueOut(ErrClusterNodeUnreachable(msg.id, msg.topic, msg.timestamp))
}
} else if !msg.Leave.Unsub {
Expand All @@ -420,6 +426,7 @@ func (s *Session) leave(msg *ClientComMessage) {
} else {
// Session wants to unsubscribe from the topic it did not join
// FIXME(gene): allow topic to unsubscribe without joining first; send to hub to unsub
log.Println("s.leave:", "must attach first", s.sid)
s.queueOut(ErrAttachFirst(msg.id, msg.topic, msg.timestamp))
}
}
Expand All @@ -429,6 +436,7 @@ func (s *Session) publish(msg *ClientComMessage) {
// TODO(gene): Check for repeated messages with the same ID
expanded, err := s.expandTopicName(msg)
if err != nil {
log.Println("s.publish:", err, s.sid)
s.queueOut(err)
return
}
Expand All @@ -455,11 +463,13 @@ func (s *Session) publish(msg *ClientComMessage) {
} else if globals.cluster.isRemoteTopic(expanded) {
// The topic is handled by a remote node. Forward message to it.
if err := globals.cluster.routeToTopic(msg, expanded, s); err != nil {
log.Println("s.publish:", err, s.sid)
s.queueOut(ErrClusterNodeUnreachable(msg.id, msg.topic, msg.timestamp))
}
} else {
// Publish request received without attaching to topic first.
s.queueOut(ErrAttachFirst(msg.id, msg.topic, msg.timestamp))
log.Println("s.publish:", "must attach first", s.sid)
}
}

Expand All @@ -470,13 +480,15 @@ func (s *Session) hello(msg *ClientComMessage) {
if s.ver == 0 {
s.ver = parseVersion(msg.Hi.Version)
if s.ver == 0 {
log.Println("s.hello:", "failed to parse version", s.sid)
s.queueOut(ErrMalformed(msg.id, "", msg.timestamp))
return
}
// Check version compatibility
if versionCompare(s.ver, minSupportedVersionValue) < 0 {
s.ver = 0
s.queueOut(ErrVersionNotSupported(msg.id, "", msg.timestamp))
log.Println("s.hello:", "unsupported version", s.sid)
return
}
params = map[string]interface{}{"ver": currentVersion, "build": store.GetAdapterName() + ":" + buildstamp}
Expand All @@ -490,13 +502,15 @@ func (s *Session) hello(msg *ClientComMessage) {
LastSeen: msg.timestamp,
Lang: msg.Hi.Lang,
}); err != nil {
log.Println("s.hello:", "database error", err, s.sid)
s.queueOut(ErrUnknown(msg.id, "", msg.timestamp))
return
}
}
} else {
// Version cannot be changed mid-session.
s.queueOut(ErrCommandOutOfSequence(msg.id, "", msg.timestamp))
log.Println("s.hello:", "version cannot be changed", s.sid)
return
}

Expand Down Expand Up @@ -531,6 +545,7 @@ func (s *Session) acc(msg *ClientComMessage) {
if msg.Acc.Token != nil {
if !s.uid.IsZero() {
s.queueOut(ErrAlreadyAuthenticated(msg.Acc.Id, "", msg.timestamp))
log.Println("s.acc: got token while already authenticated", s.sid)
return
}

Expand All @@ -539,6 +554,7 @@ func (s *Session) acc(msg *ClientComMessage) {
if err != nil {
s.queueOut(decodeStoreError(err, msg.Acc.Id, "", msg.timestamp,
map[string]interface{}{"what": "auth"}))
log.Println("s.acc: invalid token", err, s.sid)
return
}
}
Expand All @@ -550,18 +566,20 @@ func (s *Session) acc(msg *ClientComMessage) {
// The session cannot authenticate with the new account because it's already authenticated.
if msg.Acc.Login && !s.uid.IsZero() {
s.queueOut(ErrAlreadyAuthenticated(msg.id, "", msg.timestamp))
log.Println("s.acc: login requested while already authenticated", s.sid)
return
}

if authhdl == nil {
// New accounts must have an authentication scheme
s.queueOut(ErrMalformed(msg.id, "", msg.timestamp))
log.Println("s.acc: unknown auth handler", s.sid)
return
}

// Check if login is unique.
if ok, err := authhdl.IsUnique(msg.Acc.Secret); !ok {
log.Println("auth: check unique failed", err)
log.Println("s.acc: auth secret is not unique", err, s.sid)
s.queueOut(decodeStoreError(err, msg.id, "", msg.timestamp,
map[string]interface{}{"what": "auth"}))
return
Expand All @@ -576,7 +594,7 @@ func (s *Session) acc(msg *ClientComMessage) {

if tags := normalizeTags(msg.Acc.Tags); tags != nil {
if !restrictedTagsEqual(tags, nil, globals.immutableTagNS) {
log.Println("Attempt to directly assign restricted tags")
log.Println("a.acc: attempt to directly assign restricted tags", s.sid)
msg := ErrPermissionDenied(msg.id, "", msg.timestamp)
msg.Ctrl.Params = map[string]interface{}{"what": "tags"}
s.queueOut(msg)
Expand All @@ -592,7 +610,7 @@ func (s *Session) acc(msg *ClientComMessage) {
cr := &creds[i]
vld := store.GetValidator(cr.Method)
if err := vld.PreCheck(cr.Value, cr.Params); err != nil {
log.Println("Failed credential pre-check", cr, err)
log.Println("a.acc: failed credential pre-check", cr, err, s.sid)
s.queueOut(decodeStoreError(err, msg.Acc.Id, "", msg.timestamp,
map[string]interface{}{"what": cr.Method}))
return
Expand Down Expand Up @@ -629,14 +647,14 @@ func (s *Session) acc(msg *ClientComMessage) {
}

if _, err := store.Users.Create(&user, private); err != nil {
log.Println("Failed to create user", err)
log.Println("a.acc: failed to create user", err, s.sid)
s.queueOut(ErrUnknown(msg.id, "", msg.timestamp))
return
}

rec, err := authhdl.AddRecord(&auth.Rec{Uid: user.Uid()}, msg.Acc.Secret)
if err != nil {
log.Println("auth: add record failed", err)
log.Println("s.acc: add auth record failed", err, s.sid)
// Attempt to delete incomplete user record
store.Users.Delete(user.Uid(), false)
s.queueOut(decodeStoreError(err, msg.id, "", msg.timestamp, nil))
Expand All @@ -646,7 +664,7 @@ func (s *Session) acc(msg *ClientComMessage) {
// When creating an account, the user must provide all required credentials.
// If any are missing, reject the request.
if len(creds) < len(globals.authValidators[rec.AuthLevel]) {
log.Println("missing credentials; have:", creds, "want:", globals.authValidators[rec.AuthLevel])
log.Println("s.acc: missing credentials; have:", creds, "want:", globals.authValidators[rec.AuthLevel], s.sid)
// Attempt to delete incomplete user record
store.Users.Delete(user.Uid(), false)
_, missing := stringSliceDelta(globals.authValidators[rec.AuthLevel], credentialMethods(creds))
Expand All @@ -665,7 +683,7 @@ func (s *Session) acc(msg *ClientComMessage) {
cr := &creds[i]
vld := store.GetValidator(cr.Method)
if err := vld.Request(user.Uid(), cr.Value, s.lang, cr.Response, tmpToken); err != nil {
log.Println("Failed to save or validate credential", err)
log.Println("s.acc: failed to save or validate credential", err, s.sid)
// Delete incomplete user record.
store.Users.Delete(user.Uid(), false)
s.queueOut(decodeStoreError(err, msg.id, "", msg.timestamp,
Expand Down Expand Up @@ -709,12 +727,12 @@ func (s *Session) acc(msg *ClientComMessage) {

if s.uid.IsZero() && rec == nil {
// Session is not authenticated and no token provided.
log.Println("acc failed: not a new account and not authenticated", s.sid)
log.Println("s.acc: not a new account and not authenticated", s.sid)
s.queueOut(ErrPermissionDenied(msg.id, "", msg.timestamp))
return
} else if msg.from != "" && rec != nil {
// Two UIDs: one from msg.from, one from token. Ambigous, reject.
log.Println("acc failed: both authenticated session and token", s.sid)
log.Println("s.acc: got both authenticated session and token", s.sid)
s.queueOut(ErrMalformed(msg.id, "", msg.timestamp))
return
}
Expand All @@ -727,7 +745,7 @@ func (s *Session) acc(msg *ClientComMessage) {
}
if msg.Acc.User != "" && msg.Acc.User != userId {
if s.authLvl != auth.LevelRoot {
log.Println("acc failed: attempt to change another's account by non-root", s.sid)
log.Println("s.acc: attempt to change another's account by non-root", s.sid)
s.queueOut(ErrPermissionDenied(msg.id, "", msg.timestamp))
return
}
Expand All @@ -740,6 +758,7 @@ func (s *Session) acc(msg *ClientComMessage) {
if uid.IsZero() || authLvl == auth.LevelNone {
// Either msg.Acc.User or msg.Acc.AuthLevel contains invalid data.
s.queueOut(ErrMalformed(msg.id, "", msg.timestamp))
log.Println("s.acc: either user id or auth level is missing", s.sid)
return
}

Expand All @@ -748,20 +767,20 @@ func (s *Session) acc(msg *ClientComMessage) {
// Request to update auth of an existing account. Only basic auth is currently supported
// TODO(gene): support adding new auth schemes
if err := authhdl.UpdateRecord(&auth.Rec{Uid: uid}, msg.Acc.Secret); err != nil {
log.Println("auth: failed to update secret", err)
log.Println("s.acc: failed to update auth secret", err, s.sid)
s.queueOut(decodeStoreError(err, msg.id, "", msg.timestamp, nil))
return
}
} else if msg.Acc.Scheme != "" {
// Invalid or unknown auth scheme
log.Println("auth: unknown auth scheme", msg.Acc.Scheme)
log.Println("s.acc: unknown auth scheme", msg.Acc.Scheme, s.sid)
s.queueOut(ErrMalformed(msg.id, "", msg.timestamp))
return
} else if len(msg.Acc.Cred) > 0 {
// Use provided credentials for validation.
validated, err := s.getValidatedGred(uid, authLvl, msg.Acc.Cred)
if err != nil {
log.Println("failed to get validated credentials", err)
log.Println("s.acc: failed to get validated credentials", err, s.sid)
s.queueOut(decodeStoreError(err, msg.id, "", msg.timestamp, nil))
return
}
Expand Down

0 comments on commit 9512c58

Please sign in to comment.