Skip to content

Commit

Permalink
feat!: encrypt secret info with machine id
Browse files Browse the repository at this point in the history
- Add SSX_UNSAFE_MODE env var
- Add SSX_SECRET_KEY env var
  • Loading branch information
vimiix committed Jun 11, 2024
1 parent fd69945 commit 34fd274
Show file tree
Hide file tree
Showing 13 changed files with 202 additions and 34 deletions.
4 changes: 3 additions & 1 deletion cmd/ssx/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ ssx 100 pwd`,
root.Flags().StringVarP(&opt.Tag, "tag", "t", "", "search entry by tag")
root.Flags().StringVarP(&opt.IdentityFile, "keyfile", "k", "", "identity_file path")
root.Flags().StringVarP(&opt.JumpServers, "jump-server", "J", "", "jump servers, multiple jump hops may be specified separated by comma characters\nformat: [user1@]host1[:port1][,[user2@]host2[:port2]...]")
root.Flags().StringVarP(&opt.Command, "cmd", "c", "", "the command to execute\nssh connection will exit after the execution complete")
root.Flags().StringVarP(&opt.Command, "cmd", "c", "", "excute the command and exit")
root.Flags().DurationVar(&opt.Timeout, "timeout", 0, "timeout for connecting and executing command")
root.Flags().IntVarP(&opt.Port, "port", "p", 22, "port to connect to on the remote host")
root.Flags().BoolVar(&opt.Unsafe, "unsafe", false, "store host secret information with unsafe format")

root.PersistentFlags().BoolVarP(&printVersion, "version", "v", false, "print ssx version")
root.PersistentFlags().BoolVar(&logVerbose, "verbose", false, "output detail logs")
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ require (
require (
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/denisbrodbeck/machineid v1.0.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ github.com/containerd/console v1.0.4/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/denisbrodbeck/machineid v1.0.1 h1:geKr9qtkB876mXguW2X6TU4ZynleN6ezuMSRhl4D7AQ=
github.com/denisbrodbeck/machineid v1.0.1/go.mod h1:dJUwb7PTidGDeYyUBmXZ2GphQBbjJCrnectwCyxcUSI=
github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4=
github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
Expand Down
18 changes: 16 additions & 2 deletions internal/encrypt/encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,21 @@ import (
// Encrypt Generates the ciphertext for the given string.
// If the encryption fails, the original characters will be returned.
// If the passed string is empty, return empty directly.
func Encrypt(text string) string {
func Encrypt(text string, unsafe bool) string {
if text == "" {
return ""
}

curTime := time.Now().Format("01021504")
salt := md5encode(curTime)
key := salt[:8] + curTime
if !unsafe {
secretKey, err := utils.GetSecretKey()
if err != nil {
lg.Warn("failed to get secret key: %v", err)
}
key += secretKey
}

cipherText, err := aesEncrypt(text, key)
if err != nil {
Expand All @@ -41,7 +48,7 @@ func Encrypt(text string) string {
return base64.StdEncoding.EncodeToString([]byte(salt[:8] + shiftEncode(curTime) + cipherText))
}

func Decrypt(rawCipher string) string {
func Decrypt(rawCipher string, unsafe bool) string {
if rawCipher == "" {
return ""
}
Expand All @@ -54,6 +61,13 @@ func Decrypt(rawCipher string) string {

key := string(dec[:8]) + shiftDecode(string(dec[8:16]))
text := string(dec[16:])
if !unsafe {
secretKey, err := utils.GetSecretKey()
if err != nil {
lg.Warn("failed to get secret key: %v", err)
}
key += secretKey
}
res, err := aesDecrypt(text, key)
if err != nil {
lg.Debug("failed to decypt cipher '%s': %s", text, err)
Expand Down
4 changes: 2 additions & 2 deletions internal/encrypt/encrypt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestEncryptDecrypt(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := Decrypt(Encrypt(tt.text))
actual := Decrypt(Encrypt(tt.text, false), false)
assert.Equal(t, tt.text, actual)
})
}
Expand All @@ -35,7 +35,7 @@ func TestDecrypt(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := Decrypt(tt.cipher)
actual := Decrypt(tt.cipher, true)
assert.Equal(t, tt.expect, actual)
})
}
Expand Down
48 changes: 48 additions & 0 deletions internal/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ import (
"os"
"os/user"
"path/filepath"
"regexp"
"strings"

"github.com/denisbrodbeck/machineid"
"github.com/pkg/errors"
"github.com/vimiix/ssx/ssx/env"
)

// FileExists check given filename if exists
Expand Down Expand Up @@ -59,3 +64,46 @@ func ContainsI(s, sub string) bool {
strings.ToLower(sub),
)
}

type Address struct {
User string
Host string
Port string
}

var addrRegex = regexp.MustCompile(`^(?:(?P<user>[\w.-_]+)@)?(?P<host>[\w.-]+)(?::(?P<port>\d+))?(?:/(?P<path>[\w/.-]+))?$`)

func MatchAddress(addr string) (*Address, error) {
matches := addrRegex.FindStringSubmatch(addr)
if len(matches) == 0 {
return nil, errors.Errorf("invalid address: %q", addr)
}
username, host, port := matches[1], matches[2], matches[3]
addrObj := &Address{
User: username,
Host: host,
Port: port,
}
return addrObj, nil
}

func to16chars(s string) string {
if len(s) >= 16 {
return s[:16]
}
return s + strings.Repeat("=", 16-len(s))
}

// GetSecretKey get secret key from env, if not set returns machine id
// always returns 16 characters key
func GetSecretKey() (string, error) {
if os.Getenv(env.SSXSecretKey) != "" {
return to16chars(os.Getenv(env.SSXSecretKey)), nil
}
// ref: https://man7.org/linux/man-pages/man5/machine-id.5.html
machineID, err := machineid.ProtectedID("ssx")
if err != nil {
return "", err
}
return to16chars(machineID), nil
}
40 changes: 40 additions & 0 deletions internal/utils/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package utils

import (
"fmt"
"os"
"os/user"
"path"
"testing"

"github.com/stretchr/testify/assert"
"github.com/vimiix/ssx/ssx/env"
)

func TestFileExists(t *testing.T) {
Expand Down Expand Up @@ -68,3 +70,41 @@ func TestMaskString(t *testing.T) {
assert.Equal(t, tt.expect, actual)
}
}

func TestMatchAddress(t *testing.T) {
tests := []struct {
addr string
username string
host string
port string
}{
{"user@host:22", "user", "host", "22"},
{"host:2222", "", "host", "2222"},
{"host", "", "host", ""},
{"a.b@1.1.1.1", "a.b", "1.1.1.1", ""},
{"a_b@1.1.1.1", "a_b", "1.1.1.1", ""},
}
for _, tt := range tests {
t.Run(tt.addr, func(t *testing.T) {
addr, err := MatchAddress(tt.addr)
assert.NoError(t, err)
assert.Equal(t, tt.username, addr.User)
assert.Equal(t, tt.host, addr.Host)
assert.Equal(t, tt.port, addr.Port)
})
}
}

func TestGetSecretKeyShort(t *testing.T) {
os.Setenv(env.SSXSecretKey, "abc")
res, err := GetSecretKey()
assert.NoError(t, err)
assert.Equal(t, "abc=============", res)
}

func TestGetSecretKeyLong(t *testing.T) {
os.Setenv(env.SSXSecretKey, "abcdefghijklmnopqrstuvwxyz")
res, err := GetSecretKey()
assert.NoError(t, err)
assert.Equal(t, "abcdefghijklmnop", res)
}
8 changes: 4 additions & 4 deletions ssx/bbolt/bbolt.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ func NewRepo(file string) *Repo {
}

func encodeEntry(e *entry.Entry) ([]byte, error) {
e.Password = encrypt.Encrypt(e.Password)
e.Passphrase = encrypt.Encrypt(e.Passphrase)
e.Password = encrypt.Encrypt(e.Password, e.IsUnsafe())
e.Passphrase = encrypt.Encrypt(e.Passphrase, e.IsUnsafe())
return json.Marshal(e)
}

Expand All @@ -215,7 +215,7 @@ func decodeEntry(bs []byte) (*entry.Entry, error) {
if err := json.Unmarshal(bs, e); err != nil {
return nil, err
}
e.Password = encrypt.Decrypt(e.Password)
e.Passphrase = encrypt.Decrypt(e.Passphrase)
e.Password = encrypt.Decrypt(e.Password, e.IsUnsafe())
e.Passphrase = encrypt.Decrypt(e.Passphrase, e.IsUnsafe())
return e, nil
}
16 changes: 15 additions & 1 deletion ssx/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,16 @@ func (c *Client) Login(ctx context.Context) error {
lg.Debug("connecting to %s", c.entry.String())
cli, err := c.dial(ctx)
if err != nil {
return err
// try fix authentication
if c.entry.ID != 0 {
lg.Error("login failed with stored authentication, try login with interactive")
cli, err = c.tryLoginAgainWithEmptyPassword(ctx)
if err != nil {
return err
}
} else {
return err
}
}
c.cli = cli
if err := c.touchEntry(c.entry); err != nil {
Expand All @@ -247,6 +256,11 @@ func (c *Client) Login(ctx context.Context) error {
return nil
}

func (c *Client) tryLoginAgainWithEmptyPassword(ctx context.Context) (*ssh.Client, error) {
c.entry.ClearPassword()
return c.dial(ctx)
}

func (c *Client) dial(ctx context.Context) (*ssh.Client, error) {
if c.entry.Proxy != nil {
return dialThroughProxy(ctx, c.entry.Proxy, nil, c.entry)
Expand Down
18 changes: 18 additions & 0 deletions ssx/entry/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ const (
defaultPort = "22"
)

const (
ModeUninit = ""
ModeSafe = "safe"
ModeUnsafe = "unsafe"
)

// Entry represent a target server
type Entry struct {
ID uint64 `json:"id"`
Expand All @@ -42,13 +48,18 @@ type Entry struct {
KeyPath string `json:"key_path"`
Passphrase string `json:"passphrase"`
Password string `json:"password"`
SafeMode string `json:"safe_mode"`
Tags []string `json:"tags"`
Source string `json:"source"` // Data source, used to distinguish that it is from ssx stored or local ssh configuration
CreateAt time.Time `json:"create_at"`
UpdateAt time.Time `json:"update_at"`
Proxy *Proxy `json:"proxy"`
}

func (e *Entry) IsUnsafe() bool {
return e.SafeMode == ModeUninit || e.SafeMode == ModeUnsafe
}

func (e *Entry) String() string {
return fmt.Sprintf("%s@%s:%s", e.User, e.Host, e.Port)
}
Expand Down Expand Up @@ -83,6 +94,13 @@ func (e *Entry) Mask() {
}
}

func (e *Entry) ClearPassword() {
e.Password = ""
if e.Proxy != nil {
e.Proxy.ClearPassword()
}
}

func getConnectTimeout() time.Duration {
var defaultTimeout = time.Second * 10
val := os.Getenv(env.SSXConnectTimeout)
Expand Down
7 changes: 7 additions & 0 deletions ssx/entry/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,10 @@ func (p *Proxy) GenSSHConfig(ctx context.Context) (*ssh.ClientConfig, error) {
}
return cfg, nil
}

func (p *Proxy) ClearPassword() {
p.Password = ""
if p.Proxy != nil {
p.Proxy.ClearPassword()
}
}
16 changes: 16 additions & 0 deletions ssx/env/env.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
package env

import (
"os"
"strings"
)

const (
SSXDBPath = "SSX_DB_PATH"
SSXConnectTimeout = "SSX_CONNECT_TIMEOUT"
SSXImportSSHConfig = "SSX_IMPORT_SSH_CONFIG" // 设置了该环境变量的话,就会自动将 ~/.ssh/config 中的条目也加载
SSXUnsafeMode = "SSX_UNSAFE_MODE"
SSXSecretKey = "SSX_SECRET_KEY"
)

func IsUnsafeMode() bool {
switch strings.ToLower(os.Getenv(SSXUnsafeMode)) {
case "t", "true", "on", "1":
return true
default:
return false
}
}
Loading

0 comments on commit 34fd274

Please sign in to comment.