diff --git a/pkg/app/bootstrap.go b/pkg/app/bootstrap.go index 925bda9a6..35f4c18b9 100644 --- a/pkg/app/bootstrap.go +++ b/pkg/app/bootstrap.go @@ -89,7 +89,7 @@ func bootstrap(clicontext *cli.Context) error { var exists bool var newPrivateKeyName string for ok := true; ok; ok = exists { - newPrivateKeyName = filepath.Join(filepath.Dir(sshconfig.GetPrivateKey()), fmt.Sprintf("%s.pk", namesgenerator.GetRandomName(0))) + newPrivateKeyName = filepath.Join(filepath.Dir(sshconfig.GetPrivateKey()), fmt.Sprintf("envd_%s.pk", namesgenerator.GetRandomName(0))) exists, err = fileutil.FileExists(newPrivateKeyName) if err != nil { return err diff --git a/pkg/ssh/config/ssh_config.go b/pkg/ssh/config/ssh_config.go index b7c28cee1..7ca5ee56b 100644 --- a/pkg/ssh/config/ssh_config.go +++ b/pkg/ssh/config/ssh_config.go @@ -27,6 +27,8 @@ import ( "strings" "github.com/sirupsen/logrus" + + "github.com/tensorchord/envd/pkg/util/osutil" ) type ( @@ -302,7 +304,28 @@ func buildHostname(name string) string { // AddEntry adds an entry to the user's sshconfig func AddEntry(name, iface string, port int, privateKeyPath string) error { - return add(getSSHConfigPath(), buildHostname(name), iface, port, privateKeyPath) + err := add(getSSHConfigPath(), buildHostname(name), iface, port, privateKeyPath) + if err != nil { + return err + } + if osutil.IsWsl() { + logrus.Debug("Try adding entry to WSL's ssh-agent") + winSshConfig, err := osutil.GetWslHostSshConfig() + if err != nil { + return err + } + winKeyPath, err := osutil.CopyToWinEnvdHome(privateKeyPath, 0600) + if err != nil { + return err + } + // Add the entry to the WSL host SSH config + logrus.Debugf("Adding entry to WSL's ssh-agent: %s", winSshConfig) + err = add(winSshConfig, buildHostname(name), iface, port, winKeyPath) + if err != nil { + return err + } + } + return nil } func ReplaceKeyManagedByEnvd(oldKey string, newKey string) error { @@ -329,7 +352,49 @@ func ReplaceKeyManagedByEnvd(oldKey string, newKey string) error { if err != nil { return err } - return save(cfg, getSSHConfigPath()) + + err = save(cfg, getSSHConfigPath()) + if err != nil { + return err + } + + if osutil.IsWsl() { + winSshConfig, err := osutil.GetWslHostSshConfig() + if err != nil { + return err + } + cfg, err := getConfig(winSshConfig) + if err != nil { + return err + } + winNewKey, err := osutil.CopyToWinEnvdHome(newKey, 0600) + if err != nil { + return err + } + winOldKey, err := osutil.CopyToWinEnvdHome(oldKey, 0600) + if err != nil { + return err + } + logrus.Infof("Rewrite WSL ssh keys old: %s, new: %s", winOldKey, winNewKey) + for ih, h := range cfg.hosts { + for _, hn := range h.hostnames { + logrus.Info(h.hostnames) + if strings.HasSuffix(hn, ".envd") { + for ip, p := range h.params { + if p.keyword == identityFile && strings.Trim(p.args[0], "\"") == winOldKey { + logrus.Debug("Change key") + cfg.hosts[ih].params[ip].args[0] = winNewKey + } + } + } + } + } + err = save(cfg, winSshConfig) + if err != nil { + return err + } + } + return nil } func add(path, name, iface string, port int, privateKeyPath string) error { @@ -361,7 +426,22 @@ func add(path, name, iface string, port int, privateKeyPath string) error { // RemoveEntry removes the entry to the user's sshconfig if found func RemoveEntry(name string) error { - return remove(getSSHConfigPath(), buildHostname(name)) + err := remove(getSSHConfigPath(), buildHostname(name)) + if err != nil { + return err + } + if osutil.IsWsl() { + logrus.Debug("Try removing entry from WSL's ssh-agent") + winSshConfig, err := osutil.GetWslHostSshConfig() + if err != nil { + return err + } + err = remove(winSshConfig, buildHostname(name)) + if err != nil { + return err + } + } + return nil } // GetPort returns the corresponding SSH port for the dev env diff --git a/pkg/util/osutil/wsl.go b/pkg/util/osutil/wsl.go new file mode 100644 index 000000000..bd6e83d34 --- /dev/null +++ b/pkg/util/osutil/wsl.go @@ -0,0 +1,157 @@ +// Copyright 2022 The envd Authors +// Copyright 2022 mateors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package osutil + +import ( + "fmt" + "io" + "net" + "os" + "os/exec" + "path" + "path/filepath" + "strings" + + "github.com/cockroachdb/errors" + "github.com/sirupsen/logrus" +) + +func IsWsl() bool { + // Return false if meet error + cmd := exec.Command("cat", "/proc/version") + output, err := cmd.Output() + if err != nil { + logrus.Debugf("Error when check whether sys is WSL: %v", err) + return false + } + + return strings.Contains(strings.ToLower(string(output)), "microsoft") +} + +func GetWslHostSshConfig() (string, error) { + userCmd := exec.Command("wslvar", "USERPROFILE") + userOutput, err := userCmd.Output() + if err != nil { + return "", err + } + + cmd := exec.Command("wslpath", string(userOutput)) + output, err := cmd.Output() + if err != nil { + return "", err + } + outputPath := path.Join(strings.Trim(string(output), "\n"), ".ssh", "config") + logrus.Debugf("wsl sshconfig path: %s", outputPath) + return outputPath, nil +} + +func GetWslIp() (string, error) { + ip, err := getInterfaceIpv4Addr("eth0") + if err != nil { + return "", err + } + return ip, nil +} + +func GetWindowsEnvdConfigHome() (string, error) { + + userCmd := exec.Command("wslvar", "LOCALAPPDATA") + userOutput, err := userCmd.Output() + if err != nil { + return "", err + } + + cmd := exec.Command("wslpath", string(userOutput)) + output, err := cmd.Output() + if err != nil { + return "", err + } + envdDir := filepath.Join(strings.Trim(string(output), "\n"), "envd") + if err := os.MkdirAll(envdDir, 0755); err != nil { + return "", err + } + return envdDir, nil +} + +// from: https://gist.github.com/schwarzeni/f25031a3123f895ff3785970921e962c +func getInterfaceIpv4Addr(interfaceName string) (addr string, err error) { + var ( + ief *net.Interface + addrs []net.Addr + ipv4Addr net.IP + ) + if ief, err = net.InterfaceByName(interfaceName); err != nil { // get interface + return + } + if addrs, err = ief.Addrs(); err != nil { // get addresses + return + } + for _, addr := range addrs { // get ipv4 address + if ipv4Addr = addr.(*net.IPNet).IP.To4(); ipv4Addr != nil { + break + } + } + if ipv4Addr == nil { + return "", errors.New(fmt.Sprintf("interface %s don't have an ipv4 address\n", interfaceName)) + } + return ipv4Addr.String(), nil +} + +func CopyToWinEnvdHome(src string, permission os.FileMode) (string, error) { + // Return dst path in windows format + winhome, err := GetWindowsEnvdConfigHome() + if err != nil { + return "", err + } + filename := filepath.Base(src) + dst := filepath.Join(winhome, filename) + err = copy(src, dst, permission) + if err != nil { + return "", err + } + + envdDirWinCmd := exec.Command("wslpath", "-w", dst) + winDir, err := envdDirWinCmd.Output() + + if err != nil { + return "", err + } + return strings.Trim(string(winDir), "\n"), nil +} + +func copy(src, dst string, permission os.FileMode) error { + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + + out, err := os.Create(dst) + if err != nil { + return err + } + defer out.Close() + err = os.Chmod(dst, permission) + if err != nil { + return err + } + + _, err = io.Copy(out, in) + if err != nil { + return err + } + return out.Close() +}