/
config.go
79 lines (67 loc) · 1.55 KB
/
config.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
package main
import (
	"fmt"
	"io/fs"
	"net"
	"os/user"
	"regexp"

	"github.com/BurntSushi/toml"
)
type (
	// sshConnection holds the SSH settings of a connection, decoded from the
	// `ssh` sub-table of that connection's TOML table.
	sshConnection struct {
		Addr       string   `toml:"addr"`        // SSH server address, host[:port]
		User       string   `toml:"user"`        // SSH login user
		Identity   []string `toml:"identity"`    // identity file paths to try
		KnownHosts string   `toml:"known_hosts"` // known_hosts file path
	}

	// Connection is the decoded form of one named table in the config file.
	Connection struct {
		Addr   string        `toml:"addr"`   // database address, host[:port]
		Dbname string        `toml:"dbname"` // database name
		Ssh    sshConnection `toml:"ssh"`    // SSH settings for this connection
	}

	// config is the parsed configuration: the filesystem it was read from and
	// every connection keyed by its table name.
	config struct {
		fs          fs.FS
		Connections map[string]*Connection
	}
)
// hasPort matches an address that already ends in an explicit ":port". The
// host part must be either a bracketed IPv6 literal ("[::1]:5432") or free of
// colons ("db.internal:5432"), so bare IPv6 literals such as "::1" or
// "fe80::1" are not mistaken for host:port pairs.
var hasPort = regexp.MustCompile(`^(?:\[[^\]]*\]|[^:]*):\d+$`)

// clarifyKnownPort returns addr unchanged when it already carries a port and
// otherwise appends the default port kp (parseConfig passes 5432 for the
// database address and 22 for the SSH address).
//
// IPv6 literals without a port, bracketed ("[::1]") or bare ("::1"), come back
// in "[host]:port" form, as produced by net.JoinHostPort.
//
// NOTE(review): kp is an int16, so ports above 32767 cannot be expressed; the
// type is kept as-is for compatibility with existing callers.
func clarifyKnownPort(addr string, kp int16) string {
	if hasPort.MatchString(addr) {
		return addr
	}
	host := addr
	// Strip existing brackets so JoinHostPort does not produce "[[::1]]:port".
	if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	// JoinHostPort brackets the host whenever it contains a colon (IPv6).
	return net.JoinHostPort(host, fmt.Sprint(kp))
}
// parseConfig decodes the TOML file at path inside fsys into a config and
// fills in defaults for every connection.
//
// The file is decoded into a map of connection name to Connection. For each
// connection:
//   - addr is required; the default port 5432 is appended when it has none;
//   - dbname defaults to the connection name;
//   - ssh.addr is required; the default port 22 is appended when it has none;
//   - ssh.user defaults to the current OS user, when that can be looked up;
//   - ssh.identity defaults to ~/.ssh/id_rsa and ~/.ssh/id_ed25519 when unset;
//   - ssh.known_hosts defaults to ~/.ssh/known_hosts.
//
// The "~" in the default paths is left unexpanded here.
//
// It returns an error, wrapping the decoder's error, when the file cannot be
// decoded, and an error naming the offending connection when a required field
// is missing.
func parseConfig(fsys fs.FS, path string) (*config, error) {
	r := config{
		fs:          fsys,
		Connections: map[string]*Connection{},
	}
	if _, err := toml.DecodeFS(fsys, path, &r.Connections); err != nil {
		return nil, fmt.Errorf("parsing config %q: %w", path, err)
	}
	for name, conf := range r.Connections {
		// Map iteration order is random, so with several broken entries the
		// reported one can vary between runs; naming it keeps errors actionable.
		if conf.Addr == "" {
			return nil, fmt.Errorf("connection %q requires: `addr`", name)
		}
		conf.Addr = clarifyKnownPort(conf.Addr, 5432)
		if conf.Dbname == "" {
			conf.Dbname = name
		}
		if conf.Ssh.Addr == "" {
			return nil, fmt.Errorf("connection %q requires: `ssh.addr`", name)
		}
		conf.Ssh.Addr = clarifyKnownPort(conf.Ssh.Addr, 22)
		if conf.Ssh.User == "" {
			// Best effort: if the current user cannot be determined the field
			// stays empty instead of failing the whole parse.
			if u, _ := user.Current(); u != nil {
				conf.Ssh.User = u.Username
			}
		}
		// Only a nil slice gets the defaults. Presumably an explicit
		// `identity = []` decodes to a non-nil empty slice and is kept as-is
		// (NOTE(review): confirm against the toml decoder).
		if conf.Ssh.Identity == nil {
			conf.Ssh.Identity = []string{
				"~/.ssh/id_rsa",
				"~/.ssh/id_ed25519",
			}
		}
		if conf.Ssh.KnownHosts == "" {
			conf.Ssh.KnownHosts = "~/.ssh/known_hosts"
		}
	}
	return &r, nil
}