/
cert.go
130 lines (105 loc) · 3.32 KB
/
cert.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package server
import (
"crypto/rand"
"fmt"
"time"
"github.com/owenthereal/upterm/upterm"
"golang.org/x/crypto/ssh"
"google.golang.org/protobuf/proto"
)
// errCertNotSignedByHost is returned by parseAuthRequestFromCert when a user
// certificate has no extensions, or lacks the upterm.SSHCertExtension entry.
var errCertNotSignedByHost = fmt.Errorf("ssh cert not signed by host")
// UserCertChecker authenticates public keys presented by SSH users.
// Certificates are unpacked with parseAuthRequestFromCert; any other key is
// handed to UserKeyFallback when one is configured.
type UserCertChecker struct {
	// UserKeyFallback, when non-nil, is consulted for keys that are not
	// *ssh.Certificate values. Its key and error are passed through unchanged.
	UserKeyFallback func(user string, key ssh.PublicKey) (ssh.PublicKey, error)
}

// Authenticate parses the auth request and public key carried by key when key
// is an *ssh.Certificate. For any other key it delegates to UserKeyFallback
// and returns a nil *AuthRequest; if no fallback is configured it returns an
// error instead.
func (c *UserCertChecker) Authenticate(user string, key ssh.PublicKey) (*AuthRequest, ssh.PublicKey, error) {
	if cert, isCert := key.(*ssh.Certificate); isCert {
		return parseAuthRequestFromCert(user, cert)
	}

	if c.UserKeyFallback == nil {
		return nil, nil, fmt.Errorf("public key not a cert")
	}

	fallbackKey, err := c.UserKeyFallback(user, key)
	return nil, fallbackKey, err
}
// parseAuthRequestFromCert validates a user certificate and extracts the
// AuthRequest embedded in its upterm.SSHCertExtension extension.
//
// The certificate must be of type ssh.UserCert, must pass
// (&ssh.CertChecker{}).CheckCert for principal, and must carry the extension,
// whose value is a proto-encoded AuthRequest.
//
// On success it returns the decoded AuthRequest and the public key parsed from
// AuthRequest.AuthorizedKey. On any failure it returns a nil AuthRequest, the
// certificate's SignatureKey, and a non-nil error.
func parseAuthRequestFromCert(principal string, cert *ssh.Certificate) (*AuthRequest, ssh.PublicKey, error) {
	sigKey := cert.SignatureKey

	if cert.CertType != ssh.UserCert {
		return nil, sigKey, fmt.Errorf("ssh: cert has type %d", cert.CertType)
	}

	// NOTE(review): a zero-value CertChecker checks principals, validity
	// window, critical options and that the signature verifies against
	// cert.SignatureKey, but it does not decide whether that signing key is
	// trusted. Confirm callers establish trust elsewhere.
	checker := &ssh.CertChecker{}
	if err := checker.CheckCert(principal, cert); err != nil {
		return nil, sigKey, err
	}

	// Indexing a nil map is safe and yields ok == false, so this single lookup
	// also covers a certificate that has no extensions at all.
	ext, ok := cert.Permissions.Extensions[upterm.SSHCertExtension]
	if !ok {
		return nil, sigKey, errCertNotSignedByHost
	}

	var auth AuthRequest
	if err := proto.Unmarshal([]byte(ext), &auth); err != nil {
		return nil, sigKey, err
	}

	// Parse into a separate variable so this error path still reports the
	// signature key, exactly like every other failure above.
	userKey, _, _, _, err := ssh.ParseAuthorizedKey(auth.AuthorizedKey)
	if err != nil {
		return nil, sigKey, fmt.Errorf("error parsing public key from auth request: %w", err)
	}

	return &auth, userKey, nil
}
// UserCertSigner produces short-lived SSH user certificates that embed a
// proto-encoded AuthRequest in the upterm.SSHCertExtension extension.
type UserCertSigner struct {
	// SessionID is recorded as the certificate's KeyId.
	SessionID string

	// User is the certificate's only valid principal.
	User string

	// AuthRequest is proto-marshaled into the certificate extension.
	AuthRequest *AuthRequest
}

// SignCert builds a user certificate for signer's public key, valid for one
// minute starting now, signs it with signer itself, and returns an ssh.Signer
// that presents that certificate.
func (g *UserCertSigner) SignCert(signer ssh.Signer) (ssh.Signer, error) {
	b, err := proto.Marshal(g.AuthRequest)
	if err != nil {
		return nil, fmt.Errorf("error marshaling auth request: %w", err)
	}

	// NOTE(review): ValidAfter is exactly "now". If the certificate is
	// verified on a machine whose clock lags this one, CheckCert will reject
	// it as not yet valid. Consider backdating if that happens in practice.
	at := time.Now()
	bt := at.Add(1 * time.Minute) // cert valid for 1 min
	cert := &ssh.Certificate{
		Key:             signer.PublicKey(),
		CertType:        ssh.UserCert,
		KeyId:           g.SessionID,
		ValidPrincipals: []string{g.User},
		ValidAfter:      uint64(at.Unix()),
		ValidBefore:     uint64(bt.Unix()),
		Permissions: ssh.Permissions{
			Extensions: map[string]string{upterm.SSHCertExtension: string(b)},
		},
	}

	// TODO: use a different key to sign; the certificate is currently signed
	// by the same key it certifies.
	if err := cert.SignCert(rand.Reader, signer); err != nil {
		return nil, fmt.Errorf("error signing user cert: %w", err)
	}

	cs, err := ssh.NewCertSigner(cert, signer)
	if err != nil {
		return nil, fmt.Errorf("error generating user cert signer: %w", err)
	}
	return cs, nil
}
// HostCertSigner produces SSH host certificates with KeyId "uptermd".
type HostCertSigner struct {
	// Hostnames become the certificate's ValidPrincipals.
	Hostnames []string
}

// SignCert builds a host certificate for signer's public key with principals
// s.Hostnames and no expiry (ValidAfter is left at zero and ValidBefore is
// ssh.CertTimeInfinity). It signs the certificate with signer itself and
// returns an ssh.Signer that presents it. Errors are wrapped with %w, matching
// UserCertSigner.SignCert, so callers can still unwrap the underlying cause.
func (s *HostCertSigner) SignCert(signer ssh.Signer) (ssh.Signer, error) {
	cert := &ssh.Certificate{
		Key:             signer.PublicKey(),
		CertType:        ssh.HostCert,
		KeyId:           "uptermd",
		ValidPrincipals: s.Hostnames,
		ValidBefore:     ssh.CertTimeInfinity,
	}

	if err := cert.SignCert(rand.Reader, signer); err != nil {
		return nil, fmt.Errorf("error signing host cert: %w", err)
	}

	cs, err := ssh.NewCertSigner(cert, signer)
	if err != nil {
		return nil, fmt.Errorf("error generating host cert signer: %w", err)
	}
	return cs, nil
}