Skip to content

Commit

Permalink
reconcile service account call SCI (#189)
Browse files Browse the repository at this point in the history
* Call sci.BindIdentity to allow a K8s Service account to impersonate an
identity
* Created a FakeSCIControllerClient to be able to run tests easily
  • Loading branch information
samos123 committed Aug 13, 2023
1 parent 426d439 commit 581b6e2
Show file tree
Hide file tree
Showing 16 changed files with 156 additions and 21 deletions.
14 changes: 9 additions & 5 deletions cmd/controllermanager/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,13 @@ func main() {
}
defer conn.Close()
// Create a client using the connection
gc := sci.NewControllerClient(conn)
sciClient := sci.NewControllerClient(conn)

if err = (&controller.ModelReconciler{
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Cloud: cld,
SCI: sciClient,
ParamsReconciler: &controller.ParamsReconciler{
Scheme: mgr.GetScheme(),
Client: mgr.GetClient(),
Expand All @@ -132,7 +133,7 @@ func main() {
Scheme: mgr.GetScheme(),
Client: mgr.GetClient(),
Cloud: cld,
SCI: gc,
SCI: sciClient,
NewObject: func() controller.BuildableObject { return &apiv1.Model{} },
Kind: "Model",
}).SetupWithManager(mgr); err != nil {
Expand All @@ -143,6 +144,7 @@ func main() {
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Cloud: cld,
SCI: sciClient,
}).SetupWithManager(mgr); err != nil {
setupLog.Error(err, "unable to create controller", "controller", "Server")
os.Exit(1)
Expand All @@ -151,7 +153,7 @@ func main() {
Scheme: mgr.GetScheme(),
Client: mgr.GetClient(),
Cloud: cld,
SCI: gc,
SCI: sciClient,
NewObject: func() controller.BuildableObject { return &apiv1.Server{} },
Kind: "Server",
}).SetupWithManager(mgr); err != nil {
Expand All @@ -162,6 +164,7 @@ func main() {
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Cloud: cld,
SCI: sciClient,
ParamsReconciler: &controller.ParamsReconciler{
Scheme: mgr.GetScheme(),
Client: mgr.GetClient(),
Expand All @@ -174,7 +177,7 @@ func main() {
Scheme: mgr.GetScheme(),
Client: mgr.GetClient(),
Cloud: cld,
SCI: gc,
SCI: sciClient,
NewObject: func() controller.BuildableObject { return &apiv1.Notebook{} },
Kind: "Notebook",
}).SetupWithManager(mgr); err != nil {
Expand All @@ -185,6 +188,7 @@ func main() {
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Cloud: cld,
SCI: sciClient,
ParamsReconciler: &controller.ParamsReconciler{
Scheme: mgr.GetScheme(),
Client: mgr.GetClient(),
Expand All @@ -197,7 +201,7 @@ func main() {
Scheme: mgr.GetScheme(),
Client: mgr.GetClient(),
Cloud: cld,
SCI: gc,
SCI: sciClient,
NewObject: func() controller.BuildableObject { return &apiv1.Dataset{} },
Kind: "Dataset",
}).SetupWithManager(mgr); err != nil {
Expand Down
11 changes: 8 additions & 3 deletions internal/cloud/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ type Cloud interface {
// ObjectArtifactURL returns the URL of the artifact that was stored for a given Object.
ObjectArtifactURL(Object) *BucketURL

// AssociateServiceAccount associates the given service account with a cloud
// identity (i.e. updates annotations).
AssociateServiceAccount(*corev1.ServiceAccount)
// AssociatePrincipal associates the given K8s service account with a cloud
// identity (i.e. updates cloud specific annotations on K8s SA)
AssociatePrincipal(*corev1.ServiceAccount)

// GetPrincipal returns the IAM Principal (GCP SA, AWS IAM Role) that should be used
// for a specific K8s Service Account. Returns the principal and whether the principal
// was already bound successfully to the service account.
GetPrincipal(*corev1.ServiceAccount) (string, bool)

// MountBucket mutates the given Pod metadata and Pod spec in order to append
// volumes mounts for a bucket.
Expand Down
1 change: 1 addition & 0 deletions internal/cloud/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type Common struct {
ClusterName string `env:"CLUSTER_NAME" validate:"required"`
ArtifactBucketURL *BucketURL `env:"ARTIFACT_BUCKET_URL,noinit" validate:"required"`
RegistryURL string `env:"REGISTRY_URL" validate:"required"`
Principal string `env:"PRINCIPAL" validate:"required"`
}

func (c *Common) ObjectBuiltImageURL(obj BuildableObject) string {
Expand Down
2 changes: 2 additions & 0 deletions internal/cloud/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ func TestCommon(t *testing.T) {
os.Setenv("CLUSTER_NAME", "my-cluster")
os.Setenv("ARTIFACT_BUCKET_URL", "gs://my-artifact-bucket")
os.Setenv("REGISTRY_URL", "gcr.io/my-project")
os.Setenv("PRINCIPAL", "dummy-value")

require.Error(t, validator.New().Struct(&common))
require.NoError(t, envconfig.Process(context.Background(), &common))
Expand All @@ -27,6 +28,7 @@ func TestCommon(t *testing.T) {
ClusterName: "my-cluster",
ArtifactBucketURL: &cloud.BucketURL{Scheme: "gs", Bucket: "my-artifact-bucket"},
RegistryURL: "gcr.io/my-project",
Principal: "dummy-value",
}, common)

require.Equal(t, "gcr.io/my-project/my-cluster-model-my-ns-my-model:latest", common.ObjectBuiltImageURL(&apiv1.Model{
Expand Down
22 changes: 19 additions & 3 deletions internal/cloud/gcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ import (
"k8s.io/utils/ptr"
)

const GCPName = "gcp"
const (
GCPName = "gcp"
GCPWorkloadIdentityLabel = "iam.gke.io/gcp-service-account"
)

type GCP struct {
Common
Expand Down Expand Up @@ -60,6 +63,10 @@ func (gcp *GCP) AutoConfigure(ctx context.Context) error {
}
}

if gcp.Principal != "" {
gcp.Principal = fmt.Sprintf("substratus@%s.iam.gserviceaccount.com", gcp.ProjectID)
}

return nil
}

Expand Down Expand Up @@ -116,11 +123,20 @@ func (gcp *GCP) MountBucket(podMetadata *metav1.ObjectMeta, podSpec *corev1.PodS
return fmt.Errorf("container not found: %s", req.Container)
}

func (gcp *GCP) AssociateServiceAccount(sa *corev1.ServiceAccount) {
func (gcp *GCP) GetPrincipal(sa *corev1.ServiceAccount) (string, bool) {
principalBound := true
if val, exist := sa.Annotations[GCPWorkloadIdentityLabel]; !exist || val != gcp.Principal {
principalBound = false
}
return gcp.Principal, principalBound
}

func (gcp *GCP) AssociatePrincipal(sa *corev1.ServiceAccount) {
if sa.Annotations == nil {
sa.Annotations = map[string]string{}
}
sa.Annotations["iam.gke.io/gcp-service-account"] = fmt.Sprintf("substratus-%s@%s.iam.gserviceaccount.com", sa.Name, gcp.ProjectID)
principal, _ := gcp.GetPrincipal(sa)
sa.Annotations[GCPWorkloadIdentityLabel] = principal
}

func (gcp *GCP) region() string {
Expand Down
48 changes: 48 additions & 0 deletions internal/cloud/gcp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package cloud_test

import (
"context"
"os"
"testing"

"github.com/go-playground/validator/v10"
"github.com/sethvargo/go-envconfig"
"github.com/stretchr/testify/require"
"github.com/substratusai/substratus/internal/cloud"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

func TestGCP(t *testing.T) {
var gcp cloud.GCP
expectedPrincipal := "substratus@my-project.iam.gserviceaccount.com"
os.Setenv("CLUSTER_NAME", "my-cluster")
os.Setenv("ARTIFACT_BUCKET_URL", "gs://my-artifact-bucket")
os.Setenv("REGISTRY_URL", "gcr.io/my-project")
os.Setenv("PRINCIPAL", "substratus@my-project.iam.gserviceaccount.com")
os.Setenv("PROJECT_ID", "my-project")
os.Setenv("CLUSTER_LOCATION", "us-central1")

require.Error(t, validator.New().Struct(&gcp))
require.NoError(t, envconfig.Process(context.Background(), &gcp))
require.NoError(t, validator.New().Struct(&gcp))

sa := corev1.ServiceAccount{}
actualPrincipal, bound := gcp.GetPrincipal(&sa)
require.Equal(t, actualPrincipal, expectedPrincipal)
require.Equal(t, bound, false)

gcp.AssociatePrincipal(&sa)
actualPrincipal, bound = gcp.GetPrincipal(&sa)
require.Equal(t, actualPrincipal, expectedPrincipal)
require.Equal(t, bound, true)

sa = corev1.ServiceAccount{
ObjectMeta: metav1.ObjectMeta{
Annotations: map[string]string{cloud.GCPWorkloadIdentityLabel: expectedPrincipal},
},
}
actualPrincipal, bound = gcp.GetPrincipal(&sa)
require.Equal(t, actualPrincipal, expectedPrincipal)
require.Equal(t, bound, true)
}
2 changes: 1 addition & 1 deletion internal/controller/build_reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (r *BuildReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl
defer log.Info("Done reconciling build")

// Service account used for building and pushing the image.
if result, err := reconcileCloudServiceAccount(ctx, r.Cloud, r.Client, &corev1.ServiceAccount{
if result, err := reconcileServiceAccount(ctx, r.Cloud, r.SCI, r.Client, &corev1.ServiceAccount{
ObjectMeta: metav1.ObjectMeta{
Name: containerBuilderServiceAccountName,
Namespace: obj.GetNamespace(),
Expand Down
4 changes: 3 additions & 1 deletion internal/controller/dataset_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
apiv1 "github.com/substratusai/substratus/api/v1"
"github.com/substratusai/substratus/internal/cloud"
"github.com/substratusai/substratus/internal/resources"
"github.com/substratusai/substratus/internal/sci"
)

// DatasetReconciler reconciles a Dataset object.
Expand All @@ -28,6 +29,7 @@ type DatasetReconciler struct {
*ParamsReconciler

Cloud cloud.Cloud
SCI sci.ControllerClient
}

func (r *DatasetReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
Expand Down Expand Up @@ -84,7 +86,7 @@ func (r *DatasetReconciler) reconcileData(ctx context.Context, dataset *apiv1.Da
// ServiceAccount for the loader job.
// Within the context of GCP, this ServiceAccount will need IAM permissions
// to write the GCS bucket containing training data.
if result, err := reconcileCloudServiceAccount(ctx, r.Cloud, r.Client, &corev1.ServiceAccount{
if result, err := reconcileServiceAccount(ctx, r.Cloud, r.SCI, r.Client, &corev1.ServiceAccount{
ObjectMeta: metav1.ObjectMeta{
Name: dataLoaderServiceAccountName,
Namespace: dataset.Namespace,
Expand Down
2 changes: 1 addition & 1 deletion internal/controller/dataset_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func testDatasetLoad(t *testing.T, dataset *apiv1.Dataset) {
err := k8sClient.Get(ctx, types.NamespacedName{Namespace: dataset.Namespace, Name: "data-loader"}, &sa)
assert.NoError(t, err, "getting the data loader serviceaccount")
}, timeout, interval, "waiting for the data loader serviceaccount to be created")
require.Equal(t, "substratus-data-loader@test-project-id.iam.gserviceaccount.com", sa.Annotations["iam.gke.io/gcp-service-account"])
require.Equal(t, "substratus@test-project-id.iam.gserviceaccount.com", sa.Annotations["iam.gke.io/gcp-service-account"])

// Test that a data loader builder Job gets created by the controller.
var loaderJob batchv1.Job
Expand Down
14 changes: 13 additions & 1 deletion internal/controller/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
apiv1 "github.com/substratusai/substratus/api/v1"
"github.com/substratusai/substratus/internal/cloud"
"github.com/substratusai/substratus/internal/controller"
"github.com/substratusai/substratus/internal/sci"
//+kubebuilder:scaffold:imports
)

Expand Down Expand Up @@ -82,6 +83,9 @@ func TestMain(m *testing.M) {
testCloud.ClusterLocation = "us-central1"
testCloud.ArtifactBucketURL = &cloud.BucketURL{Scheme: "gs", Bucket: "test-artifact-bucket"}
testCloud.RegistryURL = "registry.test"
testCloud.Principal = "substratus@test-project-id.iam.gserviceaccount.com"

sciClient := &sci.FakeCSIControllerClient{}

//runtimeMgr, err := controller.NewRuntimeManager(controller.GPUTypeNvidiaL4)
//requireNoError(err)
Expand All @@ -90,6 +94,7 @@ func TestMain(m *testing.M) {
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Cloud: testCloud,
SCI: sciClient,
ParamsReconciler: &controller.ParamsReconciler{
Scheme: mgr.GetScheme(),
Client: mgr.GetClient(),
Expand All @@ -100,6 +105,7 @@ func TestMain(m *testing.M) {
Scheme: mgr.GetScheme(),
Client: mgr.GetClient(),
Cloud: testCloud,
SCI: sciClient,
NewObject: func() controller.BuildableObject { return &apiv1.Model{} },
Kind: "Model",
}).SetupWithManager(mgr)
Expand All @@ -108,12 +114,14 @@ func TestMain(m *testing.M) {
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Cloud: testCloud,
SCI: sciClient,
}).SetupWithManager(mgr)
requireNoError(err)
err = (&controller.BuildReconciler{
Scheme: mgr.GetScheme(),
Client: mgr.GetClient(),
Cloud: testCloud,
SCI: sciClient,
NewObject: func() controller.BuildableObject { return &apiv1.Server{} },
Kind: "Server",
}).SetupWithManager(mgr)
Expand All @@ -122,6 +130,7 @@ func TestMain(m *testing.M) {
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Cloud: testCloud,
SCI: sciClient,
ParamsReconciler: &controller.ParamsReconciler{
Scheme: mgr.GetScheme(),
Client: mgr.GetClient(),
Expand All @@ -132,6 +141,7 @@ func TestMain(m *testing.M) {
Scheme: mgr.GetScheme(),
Client: mgr.GetClient(),
Cloud: testCloud,
SCI: sciClient,
NewObject: func() controller.BuildableObject { return &apiv1.Notebook{} },
Kind: "Notebook",
}).SetupWithManager(mgr)
Expand All @@ -140,6 +150,7 @@ func TestMain(m *testing.M) {
Client: mgr.GetClient(),
Scheme: mgr.GetScheme(),
Cloud: testCloud,
SCI: sciClient,
ParamsReconciler: &controller.ParamsReconciler{
Scheme: mgr.GetScheme(),
Client: mgr.GetClient(),
Expand All @@ -150,6 +161,7 @@ func TestMain(m *testing.M) {
Scheme: mgr.GetScheme(),
Client: mgr.GetClient(),
Cloud: testCloud,
SCI: sciClient,
NewObject: func() controller.BuildableObject { return &apiv1.Dataset{} },
Kind: "Dataset",
}).SetupWithManager(mgr)
Expand Down Expand Up @@ -196,7 +208,7 @@ func testContainerBuild(t *testing.T, obj testObject, kind string) {
err := k8sClient.Get(ctx, types.NamespacedName{Namespace: obj.GetNamespace(), Name: "container-builder"}, &sa)
assert.NoError(t, err, "getting the container builder serviceaccount")
}, timeout, interval, "waiting for the container builder serviceaccount to be created")
require.Equal(t, "substratus-container-builder@test-project-id.iam.gserviceaccount.com", sa.Annotations["iam.gke.io/gcp-service-account"])
require.Equal(t, "substratus@test-project-id.iam.gserviceaccount.com", sa.Annotations["iam.gke.io/gcp-service-account"])

// Test that a container builder Job gets created by the controller.
var builderJob batchv1.Job
Expand Down
4 changes: 3 additions & 1 deletion internal/controller/model_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
apiv1 "github.com/substratusai/substratus/api/v1"
"github.com/substratusai/substratus/internal/cloud"
"github.com/substratusai/substratus/internal/resources"
"github.com/substratusai/substratus/internal/sci"
)

// ModelReconciler reconciles a Model object.
Expand All @@ -32,6 +33,7 @@ type ModelReconciler struct {
*ParamsReconciler

Cloud cloud.Cloud
SCI sci.ControllerClient
}

type ModelReconcilerConfig struct {
Expand Down Expand Up @@ -78,7 +80,7 @@ func (r *ModelReconciler) reconcileModel(ctx context.Context, model *apiv1.Model
// Within the context of GCP, this ServiceAccount will need IAM permissions
// to read the GCS bucket containing the training data and read and write from
// the bucket that contains base model artifacts.
if result, err := reconcileCloudServiceAccount(ctx, r.Cloud, r.Client, &corev1.ServiceAccount{
if result, err := reconcileServiceAccount(ctx, r.Cloud, r.SCI, r.Client, &corev1.ServiceAccount{
ObjectMeta: metav1.ObjectMeta{
Name: modellerServiceAccountName,
Namespace: model.Namespace,
Expand Down
2 changes: 1 addition & 1 deletion internal/controller/model_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func testModelTrain(t *testing.T, model *apiv1.Model) {
err := k8sClient.Get(ctx, types.NamespacedName{Namespace: model.Namespace, Name: "modeller"}, &sa)
assert.NoError(t, err, "getting the model trainer serviceaccount")
}, timeout, interval, "waiting for the model trainer serviceaccount to be created")
require.Equal(t, "substratus-modeller@test-project-id.iam.gserviceaccount.com", sa.Annotations["iam.gke.io/gcp-service-account"])
require.Equal(t, "substratus@test-project-id.iam.gserviceaccount.com", sa.Annotations["iam.gke.io/gcp-service-account"])

// Test that a trainer Job gets created by the controller.
var job batchv1.Job
Expand Down

0 comments on commit 581b6e2

Please sign in to comment.