Skip to content

Commit

Permalink
feat: slurm support (determined-ai#98)
Browse files Browse the repository at this point in the history
This change adds a dispatcher resource manager, which gives Determined the ability to run on top of Slurm.

Co-authored-by: rcorujo <90728398+rcorujo@users.noreply.github.com>
Co-authored-by: Phillip Gaisford <phillip.gaisford@hpe.com>
Co-authored-by: phillip-gaisford <98362331+phillip-gaisford@users.noreply.github.com>
Co-authored-by: Jerry J. Harrow <84593277+jerryharrow@users.noreply.github.com>
Co-authored-by: Jagadeesh Madagundi <jagadeesh545@gmail.com>
Co-authored-by: Philip Norman <philipnrmn@users.noreply.github.com>
  • Loading branch information
7 people authored and dzhu committed Apr 25, 2023
1 parent 6017dbc commit f3e862e
Show file tree
Hide file tree
Showing 29 changed files with 3,818 additions and 4 deletions.
8 changes: 8 additions & 0 deletions agent/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ require (
sigs.k8s.io/yaml v1.2.0 // indirect
)

require github.hpe.com/hpe/hpc-ard-launcher-go/launcher v0.1.2 // indirect

replace github.com/determined-ai/determined/master => ../master

replace github.com/determined-ai/determined/proto => ../proto

// Determined AI's CircleCI doesn't have access to "github.hpe.com/hpe/hpc-ard-launcher-go",
// so the build will fail in CircleCI. Therefore, we had to do a "git clone" of the
// launcher repo to store a local copy. We make use of the "replace" directive to use the
// local copy and not try to pull it from GitHub.
replace github.hpe.com/hpe/hpc-ard-launcher-go/launcher => ../hpc-ard-launcher-go/launcher
1 change: 0 additions & 1 deletion agent/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,6 @@ github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+
github.com/onsi/ginkgo v1.10.2/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.10.3/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.11.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.11.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
github.com/onsi/ginkgo v1.14.2 h1:8mVmC9kjFFmA8H4pKMUhcblgifdkOIXPvbhN1T36q1M=
github.com/onsi/ginkgo v1.14.2/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY=
Expand Down
29 changes: 29 additions & 0 deletions hpc-ard-launcher-go/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# hpc-ard-launcher-go

This repo is the home of the Capsules (hpc-ard-capsules-core) dispatch server Go client.

The code found here is generated automatically using OpenAPI tools from the Capsules REST API specification. It can be built with the following command, executed in the hpc-ard-capsules-core project:

```
mvn -pl com.cray.analytics.capsules:capsules-dispatch-client clean generate-sources -P go-client
```
To install the package to your Go environment:

If you use ssh to interact with github.hpe.com, add the following to your ~/.gitconfig:
```
[url "ssh://git@github.hpe.com/"]
insteadOf = https://github.hpe.com/
```
Then:
```
% export GOPRIVATE=github.hpe.com/hpe/hpc-ard-launcher-go
% go get github.hpe.com/hpe/hpc-ard-launcher-go/launcher
```
Import the launcher package into your Go program as follows:
```
import (
<other imports go here>
"github.hpe.com/hpe/hpc-ard-launcher-go/launcher"
)
```
5 changes: 5 additions & 0 deletions hpc-ard-launcher-go/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module github.hpe.com/hpe/hpc-ard-launcher-go

go 1.13

require golang.org/x/oauth2 v0.0.0-20210218202405-ba52d332ba99
362 changes: 362 additions & 0 deletions hpc-ard-launcher-go/go.sum

Large diffs are not rendered by default.

14 changes: 13 additions & 1 deletion master/determined.code-workspace
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,17 @@
{
"path": "../examples"
}
]
],
"launch": {
"version": "0.2.0",
"configurations": [
{
"name": "Attach to Process",
"type": "go",
"request": "attach",
"mode": "local",
"processId": 0
}
]
}
}
6 changes: 6 additions & 0 deletions master/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ require (
)

require (
github.hpe.com/hpe/hpc-ard-launcher-go/launcher v0.1.2
go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho v0.29.0
go.opentelemetry.io/otel v1.6.1
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.6.1
Expand Down Expand Up @@ -194,3 +195,8 @@ require (

replace github.com/determined-ai/determined/proto => ../proto

// Determined AI's CircleCI doesn't have access to "github.hpe.com/hpe/hpc-ard-launcher-go",
// so the build will fail in CircleCI. Therefore, we had to do a "git clone" of the
// launcher repo to store a local copy. We make use of the "replace" directive to use the
// local copy and not try to pull it from GitHub.
replace github.hpe.com/hpe/hpc-ard-launcher-go/launcher => ../hpc-ard-launcher-go/launcher
3 changes: 3 additions & 0 deletions master/internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,9 @@ func readRMPreemptionStatus(config *Config, rpName string) bool {
return config.ResourceManager.AgentRM.Scheduler.GetPreemption()
case config.ResourceManager.KubernetesRM != nil:
return config.ResourceManager.KubernetesRM.GetPreemption()
case config.ResourceManager.DispatcherRM != nil:
// TODO: Determine if this needs to be enabled for DispatcherRM
return false
default:
panic("unexpected resource configuration")
}
Expand Down
127 changes: 127 additions & 0 deletions master/internal/config/dispatcher_resource_manager_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package config

import (
"encoding/json"

"github.com/determined-ai/determined/master/pkg/device"
"github.com/determined-ai/determined/master/pkg/model"
)

// DispatcherResourceManagerConfig is the object that stores the values of
// the "resource_manager" section of "tools/devcluster.yaml" for the
// dispatcher (Slurm) resource manager.
type DispatcherResourceManagerConfig struct {
// Read from the "master_host" and "master_port" keys.
// NOTE(review): presumably the address at which launched jobs reach the
// Determined master -- confirm against the consumer of this config.
MasterHost string `json:"master_host"`
MasterPort int `json:"master_port"`
// Launcher endpoint, read from the "host", "port" and "protocol" keys.
LauncherHost string `json:"host"`
LauncherPort int `json:"port"`
LauncherProtocol string `json:"protocol"`
// Resource-manager-wide slot type; a per-partition override takes
// precedence (see ResolveSlotType). Nil when not set.
SlotType *device.Type `json:"slot_type"`
// Read from the "auth_file" key.
// NOTE(review): presumably a path to launcher credentials -- confirm.
LauncherAuthFile string `json:"auth_file"`
// Resource-manager-wide network interfaces; per-partition overrides take
// precedence (see ResolveRendezvousNetworkInterface and
// ResolveProxyNetworkInterface).
RendezvousNetworkInterface string `json:"rendezvous_network_interface"`
ProxyNetworkInterface string `json:"proxy_network_interface"`
// Configuration parameters that are proxies for launcher.conf
// and will be applied there by the init script.
UserName string `json:"user_name"`
GroupName string `json:"group_name"`
SingularityImageRoot string `json:"singularity_image_root"`
JobStorageRoot string `json:"job_storage_root"`
Path string `json:"path"`
LdLibraryPath string `json:"ld_library_path"`
// UnmarshalJSON defaults this to true when the "tres_supported" key is absent.
TresSupported bool `json:"tres_supported"`

// Optional security settings; nil when the "security" key is absent.
Security *DispatcherSecurityConfig `json:"security"`
// Per-partition overrides, keyed by partition name.
PartitionOverrides map[string]DispatcherPartitionOverrideConfigs `json:"partition_overrides"`
}

// DispatcherSecurityConfig configures security-related options for the
// dispatcher resource manager. It currently carries only TLS client settings.
type DispatcherSecurityConfig struct {
TLS model.TLSClientConfig `json:"tls"`
}

// Validate performs validation. No checks are implemented yet, so every
// configuration is accepted and nil is always returned.
func (c DispatcherResourceManagerConfig) Validate() error {
return nil
}

// defaultDispatcherResourceManagerConfig holds the values a
// DispatcherResourceManagerConfig starts from before JSON is decoded into it
// (see UnmarshalJSON). Only TresSupported has a non-zero default.
var defaultDispatcherResourceManagerConfig = DispatcherResourceManagerConfig{
TresSupported: true,
}

// UnmarshalJSON implements the json.Unmarshaler interface. The receiver is
// reset to defaultDispatcherResourceManagerConfig before decoding, so keys
// that are absent from data keep their default values.
func (c *DispatcherResourceManagerConfig) UnmarshalJSON(data []byte) error {
	// plain is a defined pointer type with no methods, so decoding through it
	// does not re-enter this UnmarshalJSON.
	type plain *DispatcherResourceManagerConfig

	*c = defaultDispatcherResourceManagerConfig
	if err := json.Unmarshal(data, plain(c)); err != nil {
		return err
	}
	return nil
}

// ResolveSlotType resolves the slot type for the given partition. A slot type
// set in that partition's overrides takes precedence; otherwise the
// resource-manager-wide SlotType is returned, which may itself be nil. No
// inference is performed here.
func (c DispatcherResourceManagerConfig) ResolveSlotType(partition string) *device.Type {
	// Index the map directly instead of scanning every entry. A missing key
	// (or a nil map) yields the zero value, whose SlotType is nil.
	if slotType := c.PartitionOverrides[partition].SlotType; slotType != nil {
		return slotType
	}
	return c.SlotType
}

// ResolveRendezvousNetworkInterface resolves the rendezvous network interface
// for the given partition. A value set in that partition's overrides takes
// precedence; otherwise the resource-manager-wide RendezvousNetworkInterface
// is returned.
func (c DispatcherResourceManagerConfig) ResolveRendezvousNetworkInterface(
	partition string,
) string {
	// Index the map directly instead of scanning every entry. A missing key
	// (or a nil map) yields the zero value, whose pointer field is nil.
	if iface := c.PartitionOverrides[partition].RendezvousNetworkInterface; iface != nil {
		return *iface
	}
	return c.RendezvousNetworkInterface
}

// ResolveProxyNetworkInterface resolves the proxy network interface for the
// given partition. A value set in that partition's overrides takes
// precedence; otherwise the resource-manager-wide ProxyNetworkInterface is
// returned.
func (c DispatcherResourceManagerConfig) ResolveProxyNetworkInterface(partition string) string {
	// Index the map directly instead of scanning every entry. A missing key
	// (or a nil map) yields the zero value, whose pointer field is nil.
	if iface := c.PartitionOverrides[partition].ProxyNetworkInterface; iface != nil {
		return *iface
	}
	return c.ProxyNetworkInterface
}

// ResolveTaskContainerDefaults returns the task container defaults configured
// in the given partition's overrides. It returns nil when the partition has no
// overrides or does not override the task container defaults; no fallback to
// any other configuration is applied here.
func (c DispatcherResourceManagerConfig) ResolveTaskContainerDefaults(
	partition string,
) *model.TaskContainerDefaultsConfig {
	// Index the map directly instead of scanning every entry. A missing key
	// (or a nil map) yields the zero value, whose pointer field is nil.
	return c.PartitionOverrides[partition].TaskContainerDefaultsConfig
}

// DispatcherPartitionOverrideConfigs describes per-partition overrides.
// Every field is a pointer: nil means the partition does not override that
// setting, and the Resolve* methods of DispatcherResourceManagerConfig then
// return their non-partition value.
type DispatcherPartitionOverrideConfigs struct {
//nolint:lll // I honestly don't know how to break this line within Go's grammar.
RendezvousNetworkInterface *string `json:"rendezvous_network_interface"`
ProxyNetworkInterface *string `json:"proxy_network_interface"`
SlotType *device.Type `json:"slot_type"`
TaskContainerDefaultsConfig *model.TaskContainerDefaultsConfig `json:"task_container_defaults"`
}
4 changes: 3 additions & 1 deletion master/internal/config/resource_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ func (r *ResourceConfig) ResolveResource() error {
AgentRM: &AgentResourceManagerConfig{},
}
}
if r.ResourceManager.AgentRM == nil && r.ResourceManager.KubernetesRM == nil {
if r.ResourceManager.AgentRM == nil &&
r.ResourceManager.KubernetesRM == nil &&
r.ResourceManager.DispatcherRM == nil {
r.ResourceManager.AgentRM = &AgentResourceManagerConfig{}
}
if r.ResourcePools == nil &&
Expand Down
3 changes: 2 additions & 1 deletion master/internal/config/resource_manager_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ const defaultResourcePoolName = "default"
type ResourceManagerConfig struct {
AgentRM *AgentResourceManagerConfig `union:"type,agent" json:"-"`
KubernetesRM *KubernetesResourceManagerConfig `union:"type,kubernetes" json:"-"`
DispatcherRM *DispatcherResourceManagerConfig `union:"type,slurm" json:"-"`
}

// MarshalJSON implements the json.Marshaler interface.
Expand All @@ -37,7 +38,7 @@ func (r *ResourceManagerConfig) UnmarshalJSON(data []byte) error {
}

// Fill in the default config.
if r.AgentRM == nil && r.KubernetesRM == nil {
if r.AgentRM == nil && r.KubernetesRM == nil && r.DispatcherRM == nil {
r.AgentRM = &AgentResourceManagerConfig{
Scheduler: &SchedulerConfig{
FittingPolicy: defaultFitPolicy,
Expand Down
119 changes: 119 additions & 0 deletions master/internal/db/postgres_resource_managers_dispatcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package db

import (
"context"
"fmt"

"github.com/uptrace/bun"

"github.com/determined-ai/determined/master/internal/sproto"
"github.com/determined-ai/determined/master/pkg/model"
)

// Dispatch is the Determined-persisted representation for dispatch existence.
// Rows are stored in the resourcemanagers_dispatcher_dispatches table.
type Dispatch struct {
bun.BaseModel `bun:"table:resourcemanagers_dispatcher_dispatches"`

// DispatchID is the key used by DispatchByID and DeleteDispatch.
DispatchID string `bun:"dispatch_id"`
// ResourceID is stored in the resource_id column.
ResourceID sproto.ResourcesID `bun:"resource_id"`
// AllocationID links the dispatch to an allocation; see
// ListDispatchesByAllocationID and ListDispatchesByJobID.
AllocationID model.AllocationID `bun:"allocation_id"`
// ImpersonatedUser is stored in the impersonated_user column.
// NOTE(review): presumably the user the dispatch was launched as -- confirm.
ImpersonatedUser string `bun:"impersonated_user"`
}

// InsertDispatch stores r as a new dispatch row. Any database error is
// returned wrapped with context.
func InsertDispatch(ctx context.Context, r *Dispatch) error {
	if _, err := Bun().NewInsert().Model(r).Exec(ctx); err != nil {
		return fmt.Errorf("inserting dispatch: %w", err)
	}
	return nil
}

// DispatchByID loads the dispatch whose dispatch_id equals id. A failed scan
// is returned as a wrapped error together with a nil dispatch.
func DispatchByID(ctx context.Context, id string) (*Dispatch, error) {
	var dispatch Dispatch

	query := Bun().NewSelect().Model(&dispatch).Where("dispatch_id = ?", id)
	if err := query.Scan(ctx); err != nil {
		return nil, fmt.Errorf("scanning dispatch by ID (%s): %w", id, err)
	}
	return &dispatch, nil
}

// ListDispatchesByJobID returns the dispatches belonging to the given job.
// Dispatches are joined to allocations and then to tasks, and the result is
// filtered on job_id.
func ListDispatchesByJobID(ctx context.Context, jobID string) ([]*Dispatch, error) {
	dispatches := make([]*Dispatch, 0)

	query := Bun().NewSelect().Model(&dispatches).
		Join("join allocations on allocations.allocation_id = dispatch.allocation_id").
		Join("join tasks on tasks.task_id = allocations.task_id").
		Where("job_id = ?", jobID)
	if err := query.Scan(ctx); err != nil {
		return nil, fmt.Errorf("scanning dispatch by job ID (%s): %w", jobID, err)
	}
	return dispatches, nil
}

// ListAllDispatches returns every dispatch in the DB by calling ListDispatches
// with an option that leaves the query unfiltered.
func ListAllDispatches(ctx context.Context) ([]*Dispatch, error) {
	unfiltered := func(q *bun.SelectQuery) (*bun.SelectQuery, error) { return q, nil }
	return ListDispatches(ctx, unfiltered)
}

// ListDispatchesByAllocationID returns the dispatches whose allocation_id
// equals id.
func ListDispatchesByAllocationID(ctx context.Context, id model.AllocationID) ([]*Dispatch, error) {
	byAllocation := func(q *bun.SelectQuery) (*bun.SelectQuery, error) {
		return q.Where("allocation_id = ?", id), nil
	}
	return ListDispatches(ctx, byAllocation)
}

// ListDispatches returns the dispatches selected by a query that opts has
// customized. opts receives the base select query and returns the query to
// run, or an error that aborts the listing.
func ListDispatches(
	ctx context.Context,
	opts func(*bun.SelectQuery) (*bun.SelectQuery, error),
) ([]*Dispatch, error) {
	var dispatches []*Dispatch

	query, err := opts(Bun().NewSelect().Model(&dispatches))
	if err != nil {
		return nil, fmt.Errorf("building dispatch model query: %w", err)
	}

	if scanErr := query.Scan(ctx); scanErr != nil {
		return nil, fmt.Errorf("scanning dispatch models: %w", scanErr)
	}
	return dispatches, nil
}

// DeleteDispatch deletes the specified dispatch and returns the number deleted.
func DeleteDispatch(
ctx context.Context,
id string,
) (int64, error) {
return DeleteDispatches(ctx, func(q *bun.DeleteQuery) *bun.DeleteQuery {
return q.Where("dispatch_id = ?", id)
})
}

// DeleteDispatches deletes all dispatches matched by the query that opts
// builds and returns the number deleted. opts receives the base delete query
// and returns the query to execute.
//
// An error is returned, with a zero count, if the delete fails or if the
// number of affected rows cannot be determined.
func DeleteDispatches(
	ctx context.Context,
	opts func(*bun.DeleteQuery) *bun.DeleteQuery,
) (int64, error) {
	var ds []*Dispatch
	res, err := opts(Bun().NewDelete().Model(&ds)).Exec(ctx)
	if err != nil {
		return 0, fmt.Errorf("delete dispatch exec: %w", err)
	}

	// Surface a failure to read the affected-row count rather than silently
	// reporting whatever count came back with it.
	count, err := res.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("delete dispatch rows affected: %w", err)
	}
	return count, nil
}

0 comments on commit f3e862e

Please sign in to comment.