Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support validities in templates #534

Merged
merged 6 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions internal/templates/funcmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,20 @@ package templates
import (
"errors"
"text/template"
"time"

"github.com/Masterminds/sprig/v3"
"go.step.sm/crypto/jose"
)

// GetFuncMap returns the list of functions provided by sprig. It changes the
// function "fail" to set the given string, this way we can report template
// errors directly to the template without having the wrapper that text/template
// adds.
// GetFuncMap returns the list of functions provided by sprig. It adds the
// function "toTime" and changes the function "fail".
//
// The "toTime" function receives a time or a Unix epoch and formats it to
// RFC3339 in UTC. The "fail" function sets the provided message, so that
// template errors are reported directly to the template without having the
// wrapper that text/template adds.
//
//
// sprig "env" and "expandenv" functions are removed to avoid the leak of
// information.
Expand All @@ -22,5 +28,31 @@ func GetFuncMap(failMessage *string) template.FuncMap {
*failMessage = msg
return "", errors.New(msg)
}
m["toTime"] = toTime
return m
}

func toTime(v any) string {
var t time.Time
switch date := v.(type) {
case time.Time:
t = date
case *time.Time:
t = *date
case int64:
t = time.Unix(date, 0)
case float64: // from json
t = time.Unix(int64(date), 0)
case int:
t = time.Unix(int64(date), 0)
case int32:
t = time.Unix(int64(date), 0)
case jose.NumericDate:
t = date.Time()
case *jose.NumericDate:
t = date.Time()
default:
t = time.Now()
}
return t.UTC().Format(time.RFC3339)
}
53 changes: 53 additions & 0 deletions internal/templates/funcmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ package templates
import (
"errors"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/jose"
)

func Test_GetFuncMap_fail(t *testing.T) {
Expand All @@ -20,3 +25,51 @@ func Test_GetFuncMap_fail(t *testing.T) {
t.Errorf("fail() message = \"%s\", want \"the fail message\"", failMesage)
}
}

func TestGetFuncMap_toTime(t *testing.T) {
now := time.Now()
numericDate := jose.NewNumericDate(now)
expected := now.UTC().Format(time.RFC3339)
loc, err := time.LoadLocation("America/Los_Angeles")
require.NoError(t, err)

type args struct {
v any
}
tests := []struct {
name string
args args
want string
}{
{"time", args{now}, expected},
{"time pointer", args{&now}, expected},
{"time UTC", args{now.UTC()}, expected},
{"time with location", args{now.In(loc)}, expected},
{"unix", args{now.Unix()}, expected},
{"unix int", args{int(now.Unix())}, expected},
{"unix int32", args{int32(now.Unix())}, expected},
{"unix float64", args{float64(now.Unix())}, expected},
{"unix float64", args{float64(now.Unix()) + 0.999}, expected},
{"jose.NumericDate", args{*numericDate}, expected},
{"jose.NumericDate pointer", args{numericDate}, expected},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var failMesage string
fns := GetFuncMap(&failMesage)
fn := fns["toTime"].(func(any) string)
assert.Equal(t, tt.want, fn(tt.args.v))
})
}

t.Run("default", func(t *testing.T) {
var failMesage string
fns := GetFuncMap(&failMesage)
fn := fns["toTime"].(func(any) string)
want := time.Now()
got, err := time.Parse(time.RFC3339, fn(nil))
require.NoError(t, err)
assert.WithinDuration(t, want, got, time.Second)
assert.Equal(t, time.UTC, got.Location())
})
}
16 changes: 12 additions & 4 deletions sshutil/certificate.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/rand"
"encoding/binary"
"encoding/json"
"time"

"github.com/pkg/errors"
"go.step.sm/crypto/randutil"
Expand All @@ -20,8 +21,8 @@ type Certificate struct {
Type CertType `json:"type"`
KeyID string `json:"keyId"`
Principals []string `json:"principals"`
ValidAfter uint64 `json:"-"`
ValidBefore uint64 `json:"-"`
ValidAfter time.Time `json:"validAfter"`
ValidBefore time.Time `json:"validBefore"`
CriticalOptions map[string]string `json:"criticalOptions"`
Extensions map[string]string `json:"extensions"`
Reserved []byte `json:"reserved"`
Expand Down Expand Up @@ -62,8 +63,8 @@ func (c *Certificate) GetCertificate() *ssh.Certificate {
CertType: uint32(c.Type),
KeyId: c.KeyID,
ValidPrincipals: c.Principals,
ValidAfter: c.ValidAfter,
ValidBefore: c.ValidBefore,
ValidAfter: toValidity(c.ValidAfter),
ValidBefore: toValidity(c.ValidBefore),
Permissions: ssh.Permissions{
CriticalOptions: c.CriticalOptions,
Extensions: c.Extensions,
Expand Down Expand Up @@ -124,3 +125,10 @@ func CreateCertificate(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certifica

return cert, nil
}

func toValidity(t time.Time) uint64 {
if t.IsZero() {
return 0
}
return uint64(t.Unix())
}
85 changes: 65 additions & 20 deletions sshutil/certificate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (
"io"
"reflect"
"testing"
"time"

"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ssh"
)

Expand Down Expand Up @@ -71,6 +73,7 @@ func mustGeneratePublicKey(t *testing.T) ssh.PublicKey {
}

func TestNewCertificate(t *testing.T) {
now := time.Now().Truncate(time.Second)
key := mustGeneratePublicKey(t)
cr := CertificateRequest{
Key: key,
Expand Down Expand Up @@ -100,8 +103,8 @@ func TestNewCertificate(t *testing.T) {
Type: UserCert,
KeyID: "jane@doe.com",
Principals: []string{"jane"},
ValidAfter: 0,
ValidBefore: 0,
ValidAfter: time.Time{},
ValidBefore: time.Time{},
CriticalOptions: nil,
Extensions: map[string]string{
"permit-X11-forwarding": "",
Expand All @@ -121,8 +124,8 @@ func TestNewCertificate(t *testing.T) {
Type: HostCert,
KeyID: "foobar",
Principals: []string{"foo.internal", "bar.internal"},
ValidAfter: 0,
ValidBefore: 0,
ValidAfter: time.Time{},
ValidBefore: time.Time{},
CriticalOptions: nil,
Extensions: nil,
Reserved: nil,
Expand All @@ -136,8 +139,8 @@ func TestNewCertificate(t *testing.T) {
Type: HostCert,
KeyID: `foobar", "criticalOptions": {"foo": "bar"},"foo":"`,
Principals: []string{"foo.internal", "bar.internal"},
ValidAfter: 0,
ValidBefore: 0,
ValidAfter: time.Time{},
ValidBefore: time.Time{},
CriticalOptions: nil,
Extensions: nil,
Reserved: nil,
Expand All @@ -159,8 +162,8 @@ func TestNewCertificate(t *testing.T) {
Type: UserCert,
KeyID: "john@doe.com",
Principals: []string{"john", "john@doe.com"},
ValidAfter: 0,
ValidBefore: 0,
ValidAfter: time.Time{},
ValidBefore: time.Time{},
CriticalOptions: nil,
Extensions: map[string]string{
"login@github.com": "john",
Expand All @@ -174,15 +177,47 @@ func TestNewCertificate(t *testing.T) {
SignatureKey: nil,
Signature: nil,
}, false},
{"file with dates", args{cr, []Option{WithTemplateFile("./testdata/date.tpl", TemplateData{
TypeKey: UserCert,
KeyIDKey: "john@doe.com",
PrincipalsKey: []string{"john", "john@doe.com"},
ExtensionsKey: DefaultExtensions(UserCert),
InsecureKey: TemplateData{
"User": map[string]interface{}{"username": "john"},
},
WebhooksKey: TemplateData{
"Test": map[string]interface{}{"validity": "16h"},
},
})}}, &Certificate{
Nonce: nil,
Key: key,
Serial: 0,
Type: UserCert,
KeyID: "john@doe.com",
Principals: []string{"john", "john@doe.com"},
ValidAfter: now,
ValidBefore: now.Add(16 * time.Hour),
CriticalOptions: nil,
Extensions: map[string]string{
"permit-X11-forwarding": "",
"permit-agent-forwarding": "",
"permit-port-forwarding": "",
"permit-pty": "",
"permit-user-rc": "",
},
Reserved: nil,
SignatureKey: nil,
Signature: nil,
}, false},
{"base64", args{cr, []Option{WithTemplateBase64(base64.StdEncoding.EncodeToString([]byte(DefaultTemplate)), CreateTemplateData(HostCert, "foo.internal", nil))}}, &Certificate{
Nonce: nil,
Key: key,
Serial: 0,
Type: HostCert,
KeyID: "foo.internal",
Principals: nil,
ValidAfter: 0,
ValidBefore: 0,
ValidAfter: time.Time{},
ValidBefore: time.Time{},
CriticalOptions: nil,
Extensions: nil,
Reserved: nil,
Expand All @@ -203,6 +238,15 @@ func TestNewCertificate(t *testing.T) {
t.Errorf("NewCertificate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil && tt.want != nil {
if assert.WithinDuration(t, tt.want.ValidAfter, got.ValidAfter, 2*time.Second) {
tt.want.ValidAfter = got.ValidAfter
}
if assert.WithinDuration(t, tt.want.ValidBefore, got.ValidBefore, 2*time.Second) {
tt.want.ValidBefore = got.ValidBefore
}

}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewCertificate() = %v, want %v", got, tt.want)
}
Expand All @@ -212,6 +256,7 @@ func TestNewCertificate(t *testing.T) {

func TestCertificate_GetCertificate(t *testing.T) {
key := mustGeneratePublicKey(t)
now := time.Now()

type fields struct {
Nonce []byte
Expand All @@ -220,8 +265,8 @@ func TestCertificate_GetCertificate(t *testing.T) {
Type CertType
KeyID string
Principals []string
ValidAfter uint64
ValidBefore uint64
ValidAfter time.Time
ValidBefore time.Time
CriticalOptions map[string]string
Extensions map[string]string
Reserved []byte
Expand All @@ -240,8 +285,8 @@ func TestCertificate_GetCertificate(t *testing.T) {
Type: UserCert,
KeyID: "key-id",
Principals: []string{"john"},
ValidAfter: 1111,
ValidBefore: 2222,
ValidAfter: now,
ValidBefore: now.Add(time.Hour),
CriticalOptions: map[string]string{"foo": "bar"},
Extensions: map[string]string{"login@github.com": "john"},
Reserved: []byte("reserved"),
Expand All @@ -254,8 +299,8 @@ func TestCertificate_GetCertificate(t *testing.T) {
CertType: ssh.UserCert,
KeyId: "key-id",
ValidPrincipals: []string{"john"},
ValidAfter: 1111,
ValidBefore: 2222,
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
Permissions: ssh.Permissions{
CriticalOptions: map[string]string{"foo": "bar"},
Extensions: map[string]string{"login@github.com": "john"},
Expand All @@ -269,8 +314,8 @@ func TestCertificate_GetCertificate(t *testing.T) {
Type: HostCert,
KeyID: "key-id",
Principals: []string{"foo.internal", "bar.internal"},
ValidAfter: 1111,
ValidBefore: 2222,
ValidAfter: time.Time{},
ValidBefore: time.Time{},
CriticalOptions: map[string]string{"foo": "bar"},
Extensions: nil,
Reserved: []byte("reserved"),
Expand All @@ -283,8 +328,8 @@ func TestCertificate_GetCertificate(t *testing.T) {
CertType: ssh.HostCert,
KeyId: "key-id",
ValidPrincipals: []string{"foo.internal", "bar.internal"},
ValidAfter: 1111,
ValidBefore: 2222,
ValidAfter: 0,
ValidBefore: 0,
Permissions: ssh.Permissions{
CriticalOptions: map[string]string{"foo": "bar"},
Extensions: nil,
Expand Down
8 changes: 8 additions & 0 deletions sshutil/testdata/date.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"type": "{{ .Type }}",
"keyId": "{{ .KeyID }}",
"principals": {{ toJson .Principals }},
"extensions": {{ toJson .Extensions }},
"validAfter": {{ now | toJson }},
"validBefore": {{ now | dateModify .Webhooks.Test.validity | toJson }}
}
7 changes: 7 additions & 0 deletions x509util/certificate.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/rand"
"crypto/x509"
"encoding/json"
"time"

"github.com/pkg/errors"
)
Expand All @@ -23,6 +24,8 @@ type Certificate struct {
IPAddresses MultiIP `json:"ipAddresses"`
URIs MultiURL `json:"uris"`
SANs []SubjectAlternativeName `json:"sans"`
NotBefore time.Time `json:"notBefore"`
NotAfter time.Time `json:"notAfter"`
Extensions []Extension `json:"extensions"`
KeyUsage KeyUsage `json:"keyUsage"`
ExtKeyUsage ExtKeyUsage `json:"extKeyUsage"`
Expand Down Expand Up @@ -165,6 +168,10 @@ func (c *Certificate) GetCertificate() *x509.Certificate {
e.Set(cert)
}

// Validity bounds.
cert.NotBefore = c.NotBefore
cert.NotAfter = c.NotAfter

// Others.
c.SerialNumber.Set(cert)
c.SignatureAlgorithm.Set(cert)
Expand Down
Loading
Loading