Skip to content

Commit

Permalink
API Harmonization (#15)
Browse files Browse the repository at this point in the history
* Update README with latest API

* API Harmonization Work
- Fixed test cases parsing to follow json unmarshalling number to float64
- Update submodule to api harmonization tests
- URL and Domain validation
- Remove spaces from matching groups in URI
- Modify return type on `Verify`
- Add domain, nonce, timestamp optional parameters to `Verify`

* Add error message to error string
* Adding review comments changes
* Fix CreateEmpty test; Update submodule
* Update submodule
  • Loading branch information
theosirian committed May 9, 2022
1 parent 5d84159 commit a36f2a0
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 68 deletions.
15 changes: 14 additions & 1 deletion README.md
Expand Up @@ -65,7 +65,20 @@ can be done in a single call with verify:
var publicKey *ecdsa.PublicKey
var err error

publicKey, err = message.Verify(signature)
// Optional nonce variable to be matched against the
// built message struct being verified
var optionalNonce *string

// Optional timestamp variable to verify at any point
// in time, by default it will use `time.Now()`
var optionalTimestamp *time.Time

publicKey, err = message.Verify(signature, optionalNonce, optionalTimestamp)

// If you won't be using nonce matching and want
// to verify the message at current time, it's
// safe to pass `nil` in both arguments
publicKey, err = message.Verify(signature, nil, nil)
```

### Serialization of a SIWE Message
Expand Down
4 changes: 2 additions & 2 deletions errors.go
Expand Up @@ -9,11 +9,11 @@ type InvalidMessage struct{ string }
type InvalidSignature struct{ string }

func (m *ExpiredMessage) Error() string {
return "Expired Message"
return fmt.Sprintf("Expired Message: %s", m.string)
}

func (m *InvalidMessage) Error() string {
return "Invalid Message"
return fmt.Sprintf("Invalid Message: %s", m.string)
}

func (m *InvalidSignature) Error() string {
Expand Down
4 changes: 2 additions & 2 deletions message.go
Expand Up @@ -23,7 +23,7 @@ type Message struct {
notBefore *string

requestID *string
resources []string
resources []url.URL
}

func (m *Message) GetDomain() string {
Expand Down Expand Up @@ -102,6 +102,6 @@ func (m *Message) GetRequestID() *string {
return nil
}

func (m *Message) GetResources() []string {
func (m *Message) GetResources() []url.URL {
return m.resources
}
10 changes: 5 additions & 5 deletions regex.go
Expand Up @@ -5,12 +5,12 @@ import (
"regexp"
)

const _SIWE_DOMAIN = "^(?P<domain>([^?#]*)) wants you to sign in with your Ethereum account:\\n"
const _SIWE_DOMAIN = "(?P<domain>([^/?#]+)) wants you to sign in with your Ethereum account:\\n"
const _SIWE_ADDRESS = "(?P<address>0x[a-zA-Z0-9]{40})\\n\\n"
const _SIWE_STATEMENT = "((?P<statement>[^\\n]+)\\n)?\\n"
const _SIWE_URI = "(([^:?#]+):)?(([^?#]*))?([^?#]*)(\\?([^#]*))?(#(.*))"
const _RFC3986 = "(([^ :/?#]+):)?(//([^ /?#]*))?([^ ?#]*)(\\?([^ #]*))?(#(.*))?"

var _SIWE_URI_LINE = fmt.Sprintf("URI: (?P<uri>%s?)\\n", _SIWE_URI)
var _SIWE_URI_LINE = fmt.Sprintf("URI: (?P<uri>%s?)\\n", _RFC3986)

const _SIWE_VERSION = "Version: (?P<version>1)\\n"
const _SIWE_CHAIN_ID = "Chain ID: (?P<chainId>[0-9]+)\\n"
Expand All @@ -23,9 +23,9 @@ var _SIWE_NOT_BEFORE = fmt.Sprintf("(\\nNot Before: (?P<notBefore>%s))?", _SIWE_

const _SIWE_REQUEST_ID = "(\\nRequest ID: (?P<requestId>[-._~!$&'()*+,;=:@%a-zA-Z0-9]*))?"

var _SIWE_RESOURCES = fmt.Sprintf("(\\nResources:(?P<resources>(\\n- %s?)+))?$", _SIWE_URI)
var _SIWE_RESOURCES = fmt.Sprintf("(\\nResources:(?P<resources>(\\n- %s)+))?", _RFC3986)

var _SIWE_MESSAGE = regexp.MustCompile(fmt.Sprintf("%s%s%s%s%s%s%s%s%s%s%s%s",
var _SIWE_MESSAGE = regexp.MustCompile(fmt.Sprintf("^%s%s%s%s%s%s%s%s%s%s%s%s$",
_SIWE_DOMAIN,
_SIWE_ADDRESS,
_SIWE_STATEMENT,
Expand Down
123 changes: 103 additions & 20 deletions siwe.go
Expand Up @@ -13,28 +13,75 @@ import (
"github.com/ethereum/go-ethereum/crypto"
)

func InitMessage(domain, address, uri, version string, options map[string]interface{}) (*Message, error) {
validateURI, err := url.Parse(uri)
func buildAuthority(uri *url.URL) string {
authority := uri.Host
if uri.User != nil {
authority = fmt.Sprintf("%s@%s", uri.User.String(), authority)
}
return authority
}

func validateDomain(domain *string) (bool, error) {
if isEmpty(domain) {
return false, &InvalidMessage{"`domain` must not be empty"}
}

validateDomain, err := url.Parse(fmt.Sprintf("https://%s", *domain))
if err != nil {
return false, &InvalidMessage{"Invalid format for field `domain`"}
}

authority := buildAuthority(validateDomain)
if authority != *domain {
return false, &InvalidMessage{"Invalid format for field `domain`"}
}

return true, nil
}

func validateURI(uri *string) (*url.URL, error) {
if isEmpty(uri) {
return nil, &InvalidMessage{"`uri` must not be empty"}
}

validateURI, err := url.Parse(*uri)
if err != nil {
return nil, &InvalidMessage{"Invalid format for field `uri`"}
}

return validateURI, nil
}

// InitMessage creates a Message object with the provided parameters
func InitMessage(domain, address, uri, nonce string, options map[string]interface{}) (*Message, error) {
if ok, err := validateDomain(&domain); !ok {
return nil, err
}

if isEmpty(&address) {
return nil, &InvalidMessage{"`address` must not be empty"}
}

validateURI, err := validateURI(&uri)
if err != nil {
return nil, err
}

if isEmpty(&nonce) {
return nil, &InvalidMessage{"`nonce` must not be empty"}
}

var statement *string
if val, ok := options["statement"]; ok {
value := val.(string)
statement = &value
}

var nonce string
if val, ok := isStringAndNotEmpty(options, "nonce"); ok {
nonce = *val
} else {
return nil, &InvalidMessage{"Missing or empty `nonce` property"}
}

var chainId int
if val, ok := options["chainId"]; ok {
switch val.(type) {
case float64:
chainId = int(val.(float64))
case int:
chainId = val.(int)
case string:
Expand Down Expand Up @@ -87,21 +134,21 @@ func InitMessage(domain, address, uri, version string, options map[string]interf
requestID = val
}

var resources []string
var resources []url.URL
if val, ok := options["resources"]; ok {
switch val.(type) {
case []string:
resources = val.([]string)
case []url.URL:
resources = val.([]url.URL)
default:
return nil, &InvalidMessage{"`resources` must be a []string"}
return nil, &InvalidMessage{"`resources` must be a []url.URL"}
}
}

return &Message{
domain: domain,
address: common.HexToAddress(address),
uri: *validateURI,
version: version,
version: "1",

statement: statement,
nonce: nonce,
Expand Down Expand Up @@ -130,19 +177,45 @@ func parseMessage(message string) (map[string]interface{}, error) {
}
}

if _, ok := result["domain"]; !ok {
return nil, &InvalidMessage{"`domain` must not be empty"}
}
domain := result["domain"].(string)
if ok, err := validateDomain(&domain); !ok {
return nil, err
}

if _, ok := result["uri"]; !ok {
return nil, &InvalidMessage{"`domain` must not be empty"}
}
uri := result["uri"].(string)
if _, err := validateURI(&uri); err != nil {
return nil, err
}

originalAddress := result["address"].(string)
parsedAddress := common.HexToAddress(originalAddress)
if originalAddress != parsedAddress.String() {
return nil, &InvalidMessage{"Address must be in EIP-55 format"}
}

if val, ok := result["resources"]; ok {
result["resources"] = strings.Split(val.(string), "\n- ")[1:]
resources := strings.Split(val.(string), "\n- ")[1:]
validateResources := make([]url.URL, len(resources))
for i, resource := range resources {
validateResource, err := url.Parse(resource)
if err != nil {
return nil, &InvalidMessage{fmt.Sprintf("Invalid format for field `resources` at position %d", i)}
}
validateResources[i] = *validateResource
}
result["resources"] = validateResources
}

return result, nil
}

// ParseMessage returns a Message object by parsing an EIP-4361 formatted string
func ParseMessage(message string) (*Message, error) {
result, err := parseMessage(message)
if err != nil {
Expand All @@ -153,7 +226,7 @@ func ParseMessage(message string) (*Message, error) {
result["domain"].(string),
result["address"].(string),
result["uri"].(string),
result["version"].(string),
result["nonce"].(string),
result,
)

Expand All @@ -171,26 +244,29 @@ func (m *Message) eip191Hash() common.Hash {
return crypto.Keccak256Hash([]byte(msg))
}

// ValidNow validates the time constraints of the message at current time.
func (m *Message) ValidNow() (bool, error) {
return m.ValidAt(time.Now().UTC())
}

// ValidAt validates the time constraints of the message at a specific point in time.
func (m *Message) ValidAt(when time.Time) (bool, error) {
if m.expirationTime != nil {
if time.Now().UTC().After(*m.getExpirationTime()) {
if when.After(*m.getExpirationTime()) {
return false, &ExpiredMessage{"Message expired"}
}
}

if m.notBefore != nil {
if time.Now().UTC().Before(*m.getNotBefore()) {
if when.Before(*m.getNotBefore()) {
return false, &InvalidMessage{"Message not yet valid"}
}
}

return true, nil
}

// VerifyEIP191 validates the integrity of the object by matching it's signature.
func (m *Message) VerifyEIP191(signature string) (*ecdsa.PublicKey, error) {
if isEmpty(&signature) {
return nil, &InvalidSignature{"Signature cannot be empty"}
Expand Down Expand Up @@ -221,7 +297,8 @@ func (m *Message) VerifyEIP191(signature string) (*ecdsa.PublicKey, error) {
return pkey, nil
}

func (m *Message) Verify(signature string, nonce *string, timestamp *time.Time) (*ecdsa.PublicKey, error) {
// Verify validates time constraints and integrity of the object by matching it's signature.
func (m *Message) Verify(signature string, domain *string, nonce *string, timestamp *time.Time) (*ecdsa.PublicKey, error) {
var err error

if timestamp != nil {
Expand All @@ -234,6 +311,12 @@ func (m *Message) Verify(signature string, nonce *string, timestamp *time.Time)
return nil, err
}

if domain != nil {
if m.GetDomain() != *domain {
return nil, &InvalidSignature{"Message domain doesn't match"}
}
}

if nonce != nil {
if m.GetNonce() != *nonce {
return nil, &InvalidSignature{"Message nonce doesn't match"}
Expand Down Expand Up @@ -281,7 +364,7 @@ func (m *Message) prepareMessage() string {
if len(m.resources) > 0 {
resourcesArr := make([]string, len(m.resources))
for i, v := range m.resources {
resourcesArr[i] = fmt.Sprintf("- %s", v)
resourcesArr[i] = fmt.Sprintf("- %s", v.String())
}

resources := strings.Join(resourcesArr, "\n")
Expand Down

0 comments on commit a36f2a0

Please sign in to comment.