diff --git a/cmd/envd/up.go b/cmd/envd/up.go index 8f333d7b7..204f57ed6 100644 --- a/cmd/envd/up.go +++ b/cmd/envd/up.go @@ -146,6 +146,16 @@ func up(clicontext *cli.Context) error { return err } + if gpu { + nvruntimeExists, err := dockerClient.GPUEnabled(clicontext.Context) + if err != nil { + return err + } + if !nvruntimeExists { + return errors.New("GPU is required but nvidia container runtime is not installed, please refer to https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker") + } + } + containerID, containerIP, err := dockerClient.StartEnvd(clicontext.Context, tag, ctr, buildContext, gpu, *ir.DefaultGraph, clicontext.Duration("timeout"), clicontext.StringSlice("volume")) diff --git a/pkg/docker/docker.go b/pkg/docker/docker.go index f75accf8e..ceca9ee94 100644 --- a/pkg/docker/docker.go +++ b/pkg/docker/docker.go @@ -55,6 +55,8 @@ type Client interface { Exec(ctx context.Context, cname string, cmd []string) error Destroy(ctx context.Context, name string) error List(ctx context.Context) ([]types.Container, error) + // GPUEnabled returns true if nvidia container runtime exists in docker daemon. + GPUEnabled(ctx context.Context) (bool, error) } type generalClient struct { @@ -70,6 +72,18 @@ func NewClient(ctx context.Context) (Client, error) { return generalClient{cli}, nil } +func (g generalClient) GPUEnabled(ctx context.Context) (bool, error) { + info, err := g.Info(ctx) + if err != nil { + return false, errors.Wrap(err, "failed to get docker info") + } + logrus.WithField("info", info).Debug("docker info") + if nv, ok := info.Runtimes["nvidia"]; ok { + return nv.Path != "", nil + } + return false, nil +} + func (g generalClient) WaitUntilRunning(ctx context.Context, name string, timeout time.Duration) error { logger := logrus.WithField("container", name)