-
Notifications
You must be signed in to change notification settings - Fork 0
/
nvidia.go
61 lines (50 loc) · 1.24 KB
/
nvidia.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
package nvidia
import (
"errors"
"os"
"os/exec"
"github.com/NVIDIA/nvidia-docker/src/cuda"
"github.com/NVIDIA/nvidia-docker/src/nvml"
)
// Names and device-node paths used by this package.
const (
// DockerPlugin is the identifier string "nvidia-docker"
// (presumably the Docker plugin name — confirm against callers).
DockerPlugin = "nvidia-docker"
// DeviceCtl is the path of the NVIDIA control device node.
DeviceCtl = "/dev/nvidiactl"
// DeviceUVM is the path of the NVIDIA UVM device node.
DeviceUVM = "/dev/nvidia-uvm"
// DeviceUVMTools is the path of the UVM tools device node. It is treated
// as optional: GetControlDevicePaths omits it when the node does not exist.
DeviceUVMTools = "/dev/nvidia-uvm-tools"
)
// Init prepares the process environment and then initializes NVML.
//
// It sets CUDA_DISABLE_UNIFIED_MEMORY and CUDA_CACHE_DISABLE to "1", removes
// CUDA_VISIBLE_DEVICES from the environment, and finally calls nvml.Init.
// The first error encountered is returned as-is and the remaining steps are
// skipped.
func Init() error {
	// Variables that must be forced to "1", applied in this order.
	forced := []string{"CUDA_DISABLE_UNIFIED_MEMORY", "CUDA_CACHE_DISABLE"}
	for _, name := range forced {
		if err := os.Setenv(name, "1"); err != nil {
			return err
		}
	}

	// NOTE(review): presumably cleared so that no device is hidden from the
	// libraries initialized below — confirm.
	if err := os.Unsetenv("CUDA_VISIBLE_DEVICES"); err != nil {
		return err
	}

	return nvml.Init()
}
// Shutdown delegates to nvml.Shutdown and returns its error unchanged.
// It is presumably the counterpart of Init — confirm against the nvml package.
func Shutdown() error {
return nvml.Shutdown()
}
// LoadUVM runs `nvidia-modprobe -u -c=0` and reports whether the command
// succeeded.
//
// NOTE(review): per nvidia-modprobe's options, -u targets the Unified Memory
// kernel module and -c=0 requests the device file with minor number 0 —
// confirm against the installed nvidia-modprobe version.
//
// It returns nil on success. On failure the error keeps the original hint
// about nvidia-modprobe and appends the underlying cause in parentheses, so
// that a missing binary can be told apart from a command that ran but exited
// with a failure status.
func LoadUVM() error {
	if err := exec.Command("nvidia-modprobe", "-u", "-c=0").Run(); err != nil {
		// Keep the historical message as a prefix for anything that matches
		// on it; only the cause is appended.
		return errors.New("Could not load UVM kernel module. Is nvidia-modprobe installed? (" + err.Error() + ")")
	}
	return nil
}
// GetDriverVersion returns the driver version string reported by
// nvml.GetDriverVersion, along with any error from that call.
func GetDriverVersion() (string, error) {
return nvml.GetDriverVersion()
}
// GetCUDAVersion returns the version string reported by the cuda package.
// Note that it delegates to cuda.GetDriverVersion, whose result and error are
// passed through unchanged.
func GetCUDAVersion() (string, error) {
return cuda.GetDriverVersion()
}
// GetControlDevicePaths lists the control device nodes used by this package.
//
// DeviceCtl and DeviceUVM are always included. DeviceUVMTools is probed with
// os.Stat: when the node does not exist, it is left out and the error is nil.
// In every other case DeviceUVMTools is appended and the Stat error (nil on
// success) is returned alongside the list.
func GetControlDevicePaths() ([]string, error) {
	paths := []string{DeviceCtl, DeviceUVM}

	_, statErr := os.Stat(DeviceUVMTools)
	if !os.IsNotExist(statErr) {
		paths = append(paths, DeviceUVMTools)
		return paths, statErr
	}

	// The tools node is optional; its absence is not an error.
	return paths, nil
}