diff --git a/pkg/lang/ir/compile.go b/pkg/lang/ir/compile.go index 0d1a9237f..a8473a423 100644 --- a/pkg/lang/ir/compile.go +++ b/pkg/lang/ir/compile.go @@ -62,9 +62,14 @@ func Compile(ctx context.Context, cachePrefix string, pub string) (*llb.Definiti DefaultGraph.Writer = w DefaultGraph.CachePrefix = cachePrefix DefaultGraph.PublicKeyPath = pub - state, err := DefaultGraph.Compile() + + uid, gid, err := getUIDGID() if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to get uid/gid") + } + state, err := DefaultGraph.Compile(uid, gid) + if err != nil { + return nil, errors.Wrap(err, "failed to compile") } // TODO(gaocegege): Support multi platform. def, err := state.Marshal(ctx, llb.LinuxAmd64) @@ -111,7 +116,10 @@ func (g Graph) Labels() (map[string]string, error) { return labels, nil } -func (g Graph) Compile() (llb.State, error) { +func (g Graph) Compile(uid, gid int) (llb.State, error) { + g.uid = uid + g.gid = gid + // TODO(gaocegege): Support more OS and langs. base := g.compileBase() aptStage := g.compileUbuntuAPT(base) diff --git a/pkg/lang/ir/conda.go b/pkg/lang/ir/conda.go index 08f618bd7..80e02c33d 100644 --- a/pkg/lang/ir/conda.go +++ b/pkg/lang/ir/conda.go @@ -35,7 +35,7 @@ func (g Graph) compileCondaChannel(root llb.State) llb.State { logrus.WithField("conda-channel", *g.CondaChannel).Debug("using custom connda channel") stage := root. File(llb.Mkfile(condarc, - 0644, []byte(*g.CondaChannel), llb.WithUIDGID(defaultUID, defaultGID)), llb.WithCustomName("[internal] settings conda channel")) + 0644, []byte(*g.CondaChannel), llb.WithUIDGID(g.uid, g.gid)), llb.WithCustomName("[internal] settings conda channel")) return stage } return root @@ -68,7 +68,7 @@ func (g Graph) compileCondaPackages(root llb.State) llb.State { root = llb.User("envd")(root) // Refer to https://github.com/moby/buildkit/blob/31054718bf775bf32d1376fe1f3611985f837584/frontend/dockerfile/dockerfile2llb/convert_runmount.go#L46 cache := root.File(llb.Mkdir("/cache", - 0755, llb.WithParents(true), llb.WithUIDGID(defaultUID, defaultGID)), + 0755, llb.WithParents(true), llb.WithUIDGID(g.uid, g.gid)), llb.WithCustomName("[internal] settings conda cache mount permissions")) run := root. Run(llb.Shlex(cmd), llb.WithCustomNamef("conda install %s", diff --git a/pkg/lang/ir/consts.go b/pkg/lang/ir/consts.go index 48aaae6d6..b971a1ad5 100644 --- a/pkg/lang/ir/consts.go +++ b/pkg/lang/ir/consts.go @@ -27,7 +27,4 @@ const ( index-url=%s %s ` - - defaultUID = 1000 - defaultGID = 1000 ) diff --git a/pkg/lang/ir/editor.go b/pkg/lang/ir/editor.go index 30d79b992..769d6da64 100644 --- a/pkg/lang/ir/editor.go +++ b/pkg/lang/ir/editor.go @@ -44,7 +44,7 @@ func (g Graph) compileVSCode() (*llb.State, error) { "/home/envd/.vscode-server/extensions/"+p.String(), &llb.CopyInfo{ CreateDestPath: true, - }, llb.WithUIDGID(defaultUID, defaultGID)), + }, llb.WithUIDGID(g.uid, g.gid)), llb.WithCustomNamef("install vscode plugin %s", p.String())) inputs = append(inputs, ext) } diff --git a/pkg/lang/ir/git.go b/pkg/lang/ir/git.go index 485cca58b..231cbc09c 100644 --- a/pkg/lang/ir/git.go +++ b/pkg/lang/ir/git.go @@ -38,6 +38,6 @@ func (g *Graph) compileGit(root llb.State) (llb.State, error) { content := fmt.Sprintf(templateGitConfig, g.GitConfig.Email, g.GitConfig.Name, g.GitConfig.Editor) installPath := "/home/envd/.gitconfig" gitStage := root.File(llb.Mkfile(installPath, - 0644, []byte(content), llb.WithUIDGID(defaultUID, defaultGID))) + 0644, []byte(content), llb.WithUIDGID(g.uid, g.gid))) return gitStage, nil } diff --git a/pkg/lang/ir/python.go b/pkg/lang/ir/python.go index 9c042167c..bd2f2570d 100644 --- a/pkg/lang/ir/python.go +++ b/pkg/lang/ir/python.go @@ -45,7 +45,7 @@ func (g Graph) compilePyPIPackages(root llb.State) llb.State { root = llb.User("envd")(root) // Refer to https://github.com/moby/buildkit/blob/31054718bf775bf32d1376fe1f3611985f837584/frontend/dockerfile/dockerfile2llb/convert_runmount.go#L46 cache := root.File(llb.Mkdir("/cache", - 0755, llb.WithParents(true), llb.WithUIDGID(defaultUID, defaultGID)), llb.WithCustomName("[internal] settings pip cache mount permissions")) + 0755, llb.WithParents(true), llb.WithUIDGID(g.uid, g.gid)), llb.WithCustomName("[internal] settings pip cache mount permissions")) run := root. Run(llb.Shlex(cmd), llb.WithCustomNamef("pip install %s", strings.Join(g.PyPIPackages, " "))) @@ -65,10 +65,10 @@ func (g Graph) compilePyPIIndex(root llb.State) llb.State { content := fmt.Sprintf(pypiConfigTemplate, *g.PyPIIndexURL, extraIndex) pypiMirror := root. File(llb.Mkdir(filepath.Dir(pypiIndexFilePath), - 0755, llb.WithParents(true), llb.WithUIDGID(defaultUID, defaultGID)), + 0755, llb.WithParents(true), llb.WithUIDGID(g.uid, g.gid)), llb.WithCustomName("[internal] settings PyPI index")). File(llb.Mkfile(pypiIndexFilePath, - 0644, []byte(content), llb.WithUIDGID(defaultUID, defaultGID)), + 0644, []byte(content), llb.WithUIDGID(g.uid, g.gid)), llb.WithCustomName("[internal] settings PyPI index")) return pypiMirror } diff --git a/pkg/lang/ir/shell.go b/pkg/lang/ir/shell.go index 491e9477e..b32d9f011 100644 --- a/pkg/lang/ir/shell.go +++ b/pkg/lang/ir/shell.go @@ -45,12 +45,12 @@ func (g Graph) compileZSH(root llb.State) (llb.State, error) { } zshStage := root. File(llb.Copy(llb.Local(flag.FlagCacheDir), "oh-my-zsh", ohMyZSHPath, - &llb.CopyInfo{CreateDestPath: true}, llb.WithUIDGID(defaultUID, defaultGID))). + &llb.CopyInfo{CreateDestPath: true}, llb.WithUIDGID(g.uid, g.gid))). File(llb.Mkfile(installPath, - 0644, []byte(m.InstallScript()), llb.WithUIDGID(defaultUID, defaultGID))) + 0644, []byte(m.InstallScript()), llb.WithUIDGID(g.uid, g.gid))) run := zshStage.Run(llb.Shlex(fmt.Sprintf("bash %s", installPath)), llb.WithCustomName("install oh-my-zsh")). File(llb.Mkfile(zshrcPath, - 0644, []byte(m.ZSHRC()), llb.WithUIDGID(defaultUID, defaultGID))) + 0644, []byte(m.ZSHRC()), llb.WithUIDGID(g.uid, g.gid))) return run, nil } diff --git a/pkg/lang/ir/system.go b/pkg/lang/ir/system.go index c7b001009..b49d84eab 100644 --- a/pkg/lang/ir/system.go +++ b/pkg/lang/ir/system.go @@ -96,12 +96,14 @@ func (g *Graph) compileBase() llb.State { logger.Debug("compile base image") var base llb.State - var groupID string = "1000" if g.CUDA == nil && g.CUDNN == nil { if g.Language.Name == "r" { base = llb.Image("docker.io/r-base:4.2.0") // r-base image already has GID 1000. - groupID = "1001" + // It is a trick, we actually use GID 1000 + if g.gid == 1000 { + g.gid = 1001 + } } else { base = llb.Image("docker.io/tensorchord/python:3.8-ubuntu20.04") } @@ -110,8 +112,10 @@ func (g *Graph) compileBase() llb.State { } // TODO(gaocegege): Refactor user to a seperate stage. res := base. - Run(llb.Shlex(fmt.Sprintf("groupadd -g %s envd", groupID)), llb.WithCustomName("[internal] create user group envd")). - Run(llb.Shlex(fmt.Sprintf("useradd -p \"\" -u %s -g envd -s /bin/sh -m envd", groupID)), llb.WithCustomName("[internal] create user envd")). + Run(llb.Shlex(fmt.Sprintf("groupadd -g %d envd", g.gid)), + llb.WithCustomName("[internal] create user group envd")). + Run(llb.Shlex(fmt.Sprintf("useradd -p \"\" -u %d -g envd -s /bin/sh -m envd", g.gid)), + llb.WithCustomName("[internal] create user envd")). Run(llb.Shlex("adduser envd sudo"), llb.WithCustomName("[internal] add user envd to sudoers")). Run(llb.Shlex("chown -R envd:envd /usr/local/lib"), @@ -131,6 +135,6 @@ func (g Graph) copySSHKey(root llb.State) (llb.State, error) { } run := root. File(llb.Mkfile(config.ContainerauthorizedKeysPath, - 0644, []byte(dat+" envd"), llb.WithUIDGID(defaultUID, defaultGID)), llb.WithCustomName("install ssh keys")) + 0644, []byte(dat+" envd"), llb.WithUIDGID(g.uid, g.gid)), llb.WithCustomName("install ssh keys")) return run, nil } diff --git a/pkg/lang/ir/types.go b/pkg/lang/ir/types.go index f403d03ee..6b2084815 100644 --- a/pkg/lang/ir/types.go +++ b/pkg/lang/ir/types.go @@ -22,6 +22,9 @@ import ( // A Graph contains the state, // such as its call stack and thread-local storage. type Graph struct { + uid int + gid int + OS string Language diff --git a/pkg/lang/ir/util.go b/pkg/lang/ir/util.go index fe2b1a811..30ee0a140 100644 --- a/pkg/lang/ir/util.go +++ b/pkg/lang/ir/util.go @@ -15,9 +15,12 @@ package ir import ( - "errors" "fmt" + "os/user" "regexp" + "strconv" + + "github.com/cockroachdb/errors" ) func parseLanguage(l string) (string, *string, error) { @@ -44,3 +47,18 @@ func parseLanguage(l string) (string, *string, error) { return "", nil, fmt.Errorf("language %s is not supported", language) } } + +func getUIDGID() (int, int, error) { + user, err := user.Current() + if err != nil { + return 0, 0, errors.Wrap(err, "failed to get uid/gid") + } + // Do not support windows yet. + if uid, err := strconv.Atoi(user.Uid); err != nil { + return 0, 0, errors.Wrap(err, "failed to get uid") + } else if gid, err := strconv.Atoi(user.Gid); err != nil { + return 0, 0, errors.Wrap(err, "failed to get gid") + } else { + return uid, gid, nil + } +}