diff --git a/hotp.go b/hotp.go index a4b6b8d..e8f4503 100644 --- a/hotp.go +++ b/hotp.go @@ -59,7 +59,7 @@ func (t *hotp) URI(label string, issuer string) string { encodedLabel, encodedSecret, encodedIssuer, - t.otp.hashType, - t.otp.codeLength, + t.otp.HashType, + t.otp.CodeLength, t.Counter) } diff --git a/hotp_test.go b/hotp_test.go index f940fcc..fac54d8 100644 --- a/hotp_test.go +++ b/hotp_test.go @@ -115,6 +115,34 @@ func TestHTOPValidate(t *testing.T) { } } +func TestHTOPInvalidCodes(t *testing.T) { + var testCases = []struct { + Counter int + Input string + Expected bool + }{ + {0, "755223", false}, // off by one + {0, "aaaaaa", false}, + {99, "755223", false}, // off by one + } + + config := basicOTP.HOTPConfig{ + CodeLength: 6, + HashType: basicOTP.SHA1, + Secret: []byte("12345678901234567890"), // Sample secret, replace with actual secret + Counter: 0, + } + hopt := basicOTP.NewHTOP(config) + + for _, tc := range testCases { + t.Run(fmt.Sprintf("Count_%d", tc.Counter), func(t *testing.T) { + if hopt.Validate(tc.Input) != false { + t.Errorf("Failed to validate input Expected: %v, Got: %v", false, true) + } + }) + } +} + func TestHOPTURI(t *testing.T) { secretKey := []byte("Hello!") diff --git a/otp.go b/otp.go index d6163ff..ffb3588 100644 --- a/otp.go +++ b/otp.go @@ -24,15 +24,23 @@ const ( type OTP struct { hashFunc func() hash.Hash - hashType HashType + HashType HashType secret []byte - codeLength int + CodeLength int } // NewOTP creates a new instance of OTP based on the provided configuration. func NewOTP(secret []byte, hashType HashType, codeLength int) OTP { - var hashFunc func() hash.Hash + if len(secret) <= 0 { + panic("OTP requires a secret to be set") + } + + if codeLength == 0 { + codeLength = 6 // default in RFC 4226 + } + + var hashFunc func() hash.Hash switch hashType { case SHA1: hashFunc = sha1.New @@ -40,15 +48,16 @@ func NewOTP(secret []byte, hashType HashType, codeLength int) OTP { hashFunc = sha256.New case SHA512: hashFunc = sha512.New - default: + default: // if hashType is unknown, default to SHA1 hashFunc = sha1.New + hashType = SHA1 } return OTP{ secret: secret, hashFunc: hashFunc, - hashType: hashType, - codeLength: codeLength, + HashType: hashType, + CodeLength: codeLength, } } @@ -56,18 +65,15 @@ func NewOTP(secret []byte, hashType HashType, codeLength int) OTP { This is the base implenation, the input here can be used for TOP (Time based) or HOPT (incremental) */ func (o OTP) Generate(input int) string { - if input < 0 { - panic("Input must be < 0") - } hmac := hmac.New(o.hashFunc, []byte(o.secret)) buf := Itob(input) hmac.Write(buf) hmacData := hmac.Sum(nil) - code := truncate(hmacData, o.codeLength) + code := truncate(hmacData, o.CodeLength) - formatString := fmt.Sprintf("%%0%dd", o.codeLength) + formatString := fmt.Sprintf("%%0%dd", o.CodeLength) return fmt.Sprintf(formatString, code) } diff --git a/otp_test.go b/otp_test.go index 654b96f..749cafd 100644 --- a/otp_test.go +++ b/otp_test.go @@ -20,6 +20,24 @@ func TestOTPLength(t *testing.T) { } } +func TestOTPPanic(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } + }() + + basicOTP.NewOTP(nil, basicOTP.SHA1, 6) +} + +func TestOTPDefaultCodeLen(t *testing.T) { + otp := basicOTP.NewOTP([]byte("test"), basicOTP.SHA1, 0) + + if len(otp.Generate(123)) != 6 { + t.Error("OTP Default code length was not set to 6") + } +} + func TestOTPConsistency(t *testing.T) { secretKey := []byte("test") codeLength := 6 @@ -75,3 +93,13 @@ func TestOTPRandomness(t *testing.T) { t.Errorf("Different inputs produce the same TOTP. Input 1: %s, Input 2: %s", output1, output2) } } + +func TestNewOTPDefaultHash(t *testing.T) { + // Pass an invalid hashType + otp := basicOTP.NewOTP([]byte("mysecret"), basicOTP.HashType("fake-hash"), 6) + + // Verify that sha1.New is used as the default hash function + if otp.HashType != basicOTP.SHA1 { + t.Errorf("Expected default hash function to be sha1, got %v", otp.HashType) + } +} diff --git a/totp.go b/totp.go index 87de904..63a175d 100644 --- a/totp.go +++ b/totp.go @@ -10,7 +10,7 @@ import ( // TOTP represents a Time-based One-Time Password generator. type TOTP struct { otp OTP - timePeriod int + TimePeriod int } // TOTPConfig holds configuration parameters for TOTP generation. @@ -24,22 +24,14 @@ type TOTPConfig struct { // NewTOTP creates a new instance of TOTP based on the provided configuration. func NewTOTP(config TOTPConfig) *TOTP { - if len(config.Secret) <= 0 { - panic("a secret must be provided for TOTP") - } - if config.TimeInterval == 0 { // Set default time to 30 seconds, recommended in rfc6238 config.TimeInterval = 30 } - if config.CodeLength == 0 { - config.CodeLength = 6 - } - return &TOTP{ otp: NewOTP(config.Secret, config.HashType, config.CodeLength), - timePeriod: config.TimeInterval, + TimePeriod: config.TimeInterval, } } @@ -81,11 +73,11 @@ func (t *TOTP) URI(label string, issuer string) string { encodedLabel, encodedSecret, encodedIssuer, - t.otp.hashType, - t.otp.codeLength) + t.otp.HashType, + t.otp.CodeLength) } // timecode calculates the timecode based on the provided Unix timestamp. func (t *TOTP) timecode(unixTimeStamp int64) int { - return int(unixTimeStamp) / t.timePeriod + return int(unixTimeStamp) / t.TimePeriod } diff --git a/totp_test.go b/totp_test.go index 60dc10f..40496f1 100644 --- a/totp_test.go +++ b/totp_test.go @@ -140,7 +140,7 @@ func TestValidate(t *testing.T) { } } -func TesTOTPURI(t *testing.T) { +func TestTOTPURI(t *testing.T) { secretKey := []byte("Hello!") totpConfig := basicOTP.TOTPConfig{ @@ -159,3 +159,15 @@ func TesTOTPURI(t *testing.T) { } } + +func TestTOTPDefaultTimeInterval(t *testing.T) { + topt := basicOTP.NewTOTP(basicOTP.TOTPConfig{ + CodeLength: 6, + HashType: basicOTP.SHA1, + Secret: []byte("Hello!"), + }) + + if topt.TimePeriod != 30 { + t.Error("TOPT default time period was not set to 30 seconds") + } +}