/
cli.go
247 lines (222 loc) · 9.42 KB
/
cli.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
package cli
import (
	"crypto/x509"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"io"
	"net"
	"os"
	"strings"

	"github.com/square/certigo/cli/terminal"
	"github.com/square/certigo/lib"
	"github.com/square/certigo/starttls"
	"gopkg.in/alecthomas/kingpin.v2"
)
var (
	// Application and global flags.
	app     = kingpin.New("certigo", "A command-line utility to examine and validate certificates to help with debugging SSL/TLS issues.")
	verbose = app.Flag("verbose", "Print verbose").Short('v').Bool()

	// "dump" subcommand: print certificates read from files or stdin.
	dump         = app.Command("dump", "Display information about a certificate from a file or stdin.")
	dumpFiles    = dump.Arg("file", "Certificate file to dump (or stdin if not specified).").ExistingFiles()
	dumpType     = dump.Flag("format", "Format of given input (PEM, DER, JCEKS, PKCS12; heuristic if missing).").Short('f').String()
	dumpPassword = dump.Flag("password", "Password for PKCS12/JCEKS key stores (reads from TTY if missing).").Short('p').String()
	dumpPem      = dump.Flag("pem", "Write output as PEM blocks instead of human-readable format.").Short('m').Bool()
	dumpJSON     = dump.Flag("json", "Write output as machine-readable JSON format.").Short('j').Bool()

	// "connect" subcommand: dial a server (optionally via StartTLS or a
	// proxy) and print / verify the certificates it presents.
	connect                   = app.Command("connect", "Connect to a server and print its certificate(s).")
	connectTo                 = connect.Arg("server[:port]", "Hostname or IP to connect to, with optional port.").Required().String()
	connectName               = connect.Flag("name", "Override the server name used for Server Name Indication (SNI).").Short('n').String()
	connectCaPath             = connect.Flag("ca", "Path to CA bundle (system default if unspecified).").ExistingFile()
	connectCert               = connect.Flag("cert", "Client certificate chain for connecting to server (PEM).").ExistingFile()
	connectKey                = connect.Flag("key", "Private key for client certificate, if not in same file (PEM).").ExistingFile()
	connectStartTLS           = connect.Flag("start-tls", fmt.Sprintf("Enable StartTLS protocol; one of: %v.", starttls.Protocols)).Short('t').PlaceHolder("PROTOCOL").Enum(starttls.Protocols...)
	connectIdentity           = connect.Flag("identity", "With --start-tls, sets the DB user or SMTP EHLO name").Default("certigo").String()
	connectProxy              = connect.Flag("proxy", "Optional URI for HTTP(s) CONNECT proxy to dial connections with").URL()
	connectTimeout            = connect.Flag("timeout", "Timeout for connecting to remote server (can be '5m', '1s', etc).").Default("5s").Duration()
	connectPem                = connect.Flag("pem", "Write output as PEM blocks instead of human-readable format.").Short('m').Bool()
	connectJSON               = connect.Flag("json", "Write output as machine-readable JSON format.").Short('j').Bool()
	connectVerify             = connect.Flag("verify", "Verify certificate chain.").Bool()
	connectVerifyExpectedName = connect.Flag("expected-name", "Name expected in the server TLS certificate. Defaults to name from SNI or, if SNI not overridden, the hostname to connect to.").String()

	// "verify" subcommand: verify a chain read from a file or stdin
	// against a server name.
	verify         = app.Command("verify", "Verify a certificate chain from file/stdin against a name.")
	verifyFile     = verify.Arg("file", "Certificate file to verify (or stdin if not specified).").ExistingFile()
	verifyType     = verify.Flag("format", "Format of given input (PEM, DER, JCEKS, PKCS12; heuristic if missing).").Short('f').String()
	verifyPassword = verify.Flag("password", "Password for PKCS12/JCEKS key stores (reads from TTY if missing).").Short('p').String()
	verifyName     = verify.Flag("name", "Server name to verify certificate against.").Short('n').Required().String()
	verifyCaPath   = verify.Flag("ca", "Path to CA bundle (system default if unspecified).").ExistingFile()
	verifyJSON     = verify.Flag("json", "Write output as machine-readable JSON format.").Short('j').Bool()
)
// version is the certigo release string reported by --version.
const version = "1.13.0"
// Run parses args as a certigo command line, executes the selected
// subcommand (dump, connect or verify) and returns the process exit code:
//
//   - 0: success (a "no certificates found" warning still exits 0)
//   - 1: a verification was performed and failed (verify, or connect --verify)
//   - 2: any other error, reported on tty.Error()
//   - 3: an error occurred and could not even be written to tty.Error()
//
// Regular output, including JSON, is written to tty.Output().
func Run(args []string, tty terminal.Terminal) int {
	terminalWidth := tty.DetermineWidth()
	stdout := tty.Output()
	errOut := tty.Error()

	// printErr writes a diagnostic to errOut and returns the exit code Run
	// should use. If the diagnostic itself cannot be written there is not
	// much good left to do, so a distinct code (3) is returned instead of 2.
	printErr := func(format string, args ...interface{}) int {
		if _, err := fmt.Fprintf(errOut, format, args...); err != nil {
			return 3
		}
		return 2
	}

	app.HelpFlag.Short('h')
	app.Version(version)
	// Alias starttls to start-tls.
	// NOTE(review): this registers the alias on every Run call; repeated
	// calls in one process register it again — confirm kingpin tolerates it.
	connect.Flag("starttls", "").Hidden().EnumVar(connectStartTLS, starttls.Protocols...)
	// Use long help because many useful flags are under subcommands.
	app.UsageTemplate(kingpin.LongHelpTemplate)

	result := lib.SimpleResult{}
	command, err := app.Parse(args)
	if err != nil {
		return printErr("%s, try --help\n", err)
	}

	switch command {
	case dump.FullCommand(): // Dump certificate
		if dumpPassword != nil && *dumpPassword != "" {
			tty.SetDefaultPassword(*dumpPassword)
		}
		files, err := inputFiles(*dumpFiles)
		if err != nil {
			// Without this check an unreadable input was silently reported
			// as "no certificates found".
			return printErr("error: %s\n", strings.TrimSuffix(err.Error(), "\n"))
		}
		defer func() {
			for _, file := range files {
				file.Close()
			}
		}()
		if *dumpPem {
			err = lib.ReadAsPEMFromFiles(files, *dumpType, tty.ReadPassword, func(block *pem.Block, format string) error {
				// Emit the block without any PEM headers.
				block.Headers = nil
				return pem.Encode(stdout, block)
			})
		} else {
			err = lib.ReadAsX509FromFiles(files, *dumpType, tty.ReadPassword, func(cert *x509.Certificate, format string, err error) error {
				if err != nil {
					return fmt.Errorf("error parsing block: %s\n", strings.TrimSuffix(err.Error(), "\n"))
				}
				result.Certificates = append(result.Certificates, cert)
				result.Formats = append(result.Formats, format)
				return nil
			})
			// Whatever was collected is printed even when err is set; err
			// itself is reported after the output.
			if *dumpJSON {
				if jsonErr := writeJSON(stdout, result); jsonErr != nil {
					return printErr("error: %s\n", jsonErr)
				}
			} else {
				for i, cert := range result.Certificates {
					fmt.Fprintf(stdout, "** CERTIFICATE %d **\n", i+1)
					fmt.Fprintf(stdout, "Input Format: %s\n", result.Formats[i])
					fmt.Fprintf(stdout, "%s\n\n", lib.EncodeX509ToText(cert, terminalWidth, *verbose))
				}
			}
		}
		if err != nil {
			return printErr("error: %s\n", strings.TrimSuffix(err.Error(), "\n"))
		}
		if len(result.Certificates) == 0 && !*dumpPem {
			// Only a warning: the exit code stays 0.
			printErr("warning: no certificates found in input\n")
		}

	case connect.FullCommand(): // Get certs by connecting to a server
		// The flag pointers returned by kingpin are never nil, so the values
		// must be compared. --identity defaults to "certigo", so only a
		// non-default identity can be detected as explicitly supplied.
		if *connectStartTLS == "" && *connectIdentity != "certigo" {
			return printErr("error: --identity can only be used with --start-tls\n")
		}
		connState, cri, err := starttls.GetConnectionState(
			*connectStartTLS, *connectName, *connectTo, *connectIdentity,
			*connectCert, *connectKey, *connectProxy, *connectTimeout)
		if err != nil {
			return printErr("%s\n", strings.TrimSuffix(err.Error(), "\n"))
		}
		result.TLSConnectionState = connState
		result.CertificateRequestInfo = cri
		for _, cert := range connState.PeerCertificates {
			if *connectPem {
				if err := pem.Encode(stdout, lib.EncodeX509ToPEM(cert, nil)); err != nil {
					return printErr("error: %s\n", err)
				}
			} else {
				result.Certificates = append(result.Certificates, cert)
			}
		}

		// Determine what name the server's certificate should match.
		var expectedNameInCertificate string
		switch {
		case *connectVerifyExpectedName != "":
			// Use the explicitly provided name.
			expectedNameInCertificate = *connectVerifyExpectedName
		case *connectName != "":
			// Use the provided SNI.
			expectedNameInCertificate = *connectName
		default:
			// Use the host part of the connect target (IPv6-aware).
			expectedNameInCertificate = hostFromTarget(*connectTo)
		}
		verifyResult := lib.VerifyChain(connState.PeerCertificates, connState.OCSPResponse, expectedNameInCertificate, *connectCaPath)
		result.VerifyResult = &verifyResult

		if *connectJSON {
			if err := writeJSON(stdout, result); err != nil {
				return printErr("error: %s\n", err)
			}
		} else if !*connectPem {
			fmt.Fprintf(
				stdout, "%s\n\n",
				lib.EncodeTLSInfoToText(result.TLSConnectionState, result.CertificateRequestInfo))
			for i, cert := range result.Certificates {
				fmt.Fprintf(stdout, "** CERTIFICATE %d **\n", i+1)
				fmt.Fprintf(stdout, "%s\n\n", lib.EncodeX509ToText(cert, terminalWidth, *verbose))
			}
			lib.PrintVerifyResult(stdout, *result.VerifyResult)
		}
		if *connectVerify && len(result.VerifyResult.Error) > 0 {
			return 1
		}

	case verify.FullCommand():
		if verifyPassword != nil && *verifyPassword != "" {
			tty.SetDefaultPassword(*verifyPassword)
		}
		file, err := inputFile(*verifyFile)
		if err != nil {
			return printErr("%s\n", strings.TrimSuffix(err.Error(), "\n"))
		}
		defer file.Close()

		chain := []*x509.Certificate{}
		err = lib.ReadAsX509FromFiles([]*os.File{file}, *verifyType, tty.ReadPassword, func(cert *x509.Certificate, format string, err error) error {
			if err != nil {
				return err
			}
			chain = append(chain, cert)
			return nil
		})
		if err != nil {
			return printErr("error parsing block: %s\n", strings.TrimSuffix(err.Error(), "\n"))
		}

		verifyResult := lib.VerifyChain(chain, nil, *verifyName, *verifyCaPath)
		if *verifyJSON {
			if err := writeJSON(stdout, verifyResult); err != nil {
				return printErr("error: %s\n", err)
			}
		} else {
			lib.PrintVerifyResult(stdout, verifyResult)
		}
		if verifyResult.Error != "" {
			return 1
		}
	}
	return 0
}

// writeJSON marshals v and writes it to w followed by a newline, returning
// any marshalling or write error.
func writeJSON(w io.Writer, v interface{}) error {
	blob, err := json.Marshal(v)
	if err != nil {
		return err
	}
	_, err = fmt.Fprintln(w, string(blob))
	return err
}

// hostFromTarget returns the host part of a "host[:port]" connect target.
// Unlike splitting on ":", it handles bracketed IPv6 literals such as
// "[::1]:443" as well as bare IPv6 addresses such as "::1".
func hostFromTarget(target string) string {
	if host, _, err := net.SplitHostPort(target); err == nil {
		return host
	}
	// No port present: drop IPv6 brackets if any, otherwise use as-is.
	return strings.TrimSuffix(strings.TrimPrefix(target, "["), "]")
}
// inputFile opens fileName for reading, or returns os.Stdin when fileName is
// empty. On failure the returned error wraps the os.Open error, so callers
// can use errors.Is (e.g. with os.ErrNotExist).
func inputFile(fileName string) (*os.File, error) {
	if fileName == "" {
		return os.Stdin, nil
	}
	rawFile, err := os.Open(fileName)
	if err != nil {
		return nil, fmt.Errorf("unable to open file: %w", err)
	}
	return rawFile, nil
}
// inputFiles opens every named file for reading, preserving order. When
// fileNames is empty (nil or zero-length) it returns a single-element slice
// holding os.Stdin. If any file cannot be opened, the files opened so far
// are closed and an error wrapping the os.Open error is returned.
func inputFiles(fileNames []string) ([]*os.File, error) {
	if len(fileNames) == 0 {
		return []*os.File{os.Stdin}, nil
	}
	files := make([]*os.File, 0, len(fileNames))
	for _, name := range fileNames {
		rawFile, err := os.Open(name)
		if err != nil {
			// Don't leak descriptors for the files that did open.
			for _, opened := range files {
				opened.Close()
			}
			return nil, fmt.Errorf("unable to open file: %w", err)
		}
		files = append(files, rawFile)
	}
	return files, nil
}