Skip to content

Commit

Permalink
Implementing SCI with sci-aws (#186)
Browse files Browse the repository at this point in the history
## What is this change?

This change adds the sci implementation for AWS.

## Why make this change?

We're expanding provider support following where the largest user base
lives.
  • Loading branch information
brandonjbjelland committed Aug 18, 2023
1 parent 0341db9 commit 6cacb6d
Show file tree
Hide file tree
Showing 7 changed files with 656 additions and 13 deletions.
86 changes: 86 additions & 0 deletions cmd/sci-aws/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package main

import (
"flag"
"fmt"
"log"
"net"
"strconv"

"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/substratusai/substratus/internal/sci"
awssci "github.com/substratusai/substratus/internal/sci/aws"
"google.golang.org/grpc"
"google.golang.org/grpc/health"
hv1 "google.golang.org/grpc/health/grpc_health_v1"
)

func main() {
// serve by default on port 10081
var port int
flag.IntVar(&port, "port", 10081, "port number to listen on")
flag.Parse()

// Create new AWS Server
s, err := NewServer()
if err != nil {
log.Fatalf("failed to create AWS server: %v", err)
}

gs := grpc.NewServer()
sci.RegisterControllerServer(gs, s)

// Setup Health Check
hs := health.NewServer()
hs.SetServingStatus("", hv1.HealthCheckResponse_SERVING)
hv1.RegisterHealthServer(gs, hs)

fmt.Printf("awssci server listening on port %v...", port)
lis, err := net.Listen("tcp", ":"+strconv.Itoa(port))
if err != nil {
log.Fatalf("failed to listen: %v", err)
}

if err := gs.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}

func NewServer() (*awssci.Server, error) {
sess, err := session.NewSession()
if err != nil {
return nil, fmt.Errorf("failed to create AWS session: %w", err)
}

clusterID, err := awssci.GetClusterID()
if err != nil {
return nil, fmt.Errorf("failed to get cluster ID: %w", err)
}

oidcProviderURL, err := awssci.GetOidcProviderUrl(sess, clusterID)
if err != nil {
return nil, fmt.Errorf("failed to get cluster OIDC provider URL: %w", err)
}

stsSvc := sts.New(sess)
accountId, err := awssci.GetAccountID(stsSvc)
if err != nil {
return nil, fmt.Errorf("failed to get account ID: %w", err)
}

oidcProviderARN := fmt.Sprintf("arn:aws:iam::%s:oidc-provider/%s", accountId, oidcProviderURL)

c := &awssci.Clients{
S3Client: s3.New(sess),
IAMClient: iam.New(sess),
}

return &awssci.Server{
Clients: *c,
OIDCProviderURL: oidcProviderURL,
OIDCProviderARN: oidcProviderARN,
}, nil
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ require (
require github.com/inconshreveable/mousetrap v1.0.1 // indirect

require (
github.com/aws/aws-sdk-go v1.44.321 // indirect
cloud.google.com/go v0.110.6 // indirect
cloud.google.com/go/compute v1.23.0 // indirect
github.com/evanphx/json-patch v4.12.0+incompatible // indirect
Expand All @@ -34,6 +35,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/google/s2a-go v0.1.4 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.2.5 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00 // indirect
Expand Down
17 changes: 17 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ cloud.google.com/go/storage v1.31.0/go.mod h1:81ams1PrhW16L4kF7qg+4mTq7SRs5HsbDT
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/aws/aws-sdk-go v1.44.321 h1:iXwFLxWjZPjYqjPq0EcCs46xX7oDLEELte1+BzgpKk8=
github.com/aws/aws-sdk-go v1.44.321/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
Expand Down Expand Up @@ -60,6 +62,10 @@ github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
Expand Down Expand Up @@ -136,10 +142,18 @@ github.com/imdario/mergo v0.3.6/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJ
github.com/inconshreveable/mousetrap v1.0.1 h1:U3uMjPSQEBMNp1lFxmllqCPM6P5u/Xq7Pgzkat/bFNc=
github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
github.com/jstemmer/go-junit-report v0.9.1 h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o=
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
Expand Down Expand Up @@ -171,6 +185,7 @@ github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00 h1:n6/
github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00/go.mod h1:Pm3mSP3c5uWn86xMLZ5Sa7JB9GsEZySvHYXCTK4E9q4=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU=
Expand Down Expand Up @@ -250,6 +265,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk=
golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
Expand Down
9 changes: 8 additions & 1 deletion install/kubernetes/aws/eks-cluster.yaml.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ iam:
wellKnownPolicies:
ebsCSIController: true
- metadata:
name: aws-manager
name: substratus
namespace: substratus
roleName: substratus
attachPolicy:
# https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-presigned-url.html
Version: "2012-10-17"
Expand Down Expand Up @@ -94,3 +95,9 @@ iam:
- "iam:UpdateAssumeRolePolicy"
Resource:
- "arn:aws:iam::${AWS_ACCOUNT_ID}:role/$${aws:userid}"
- Sid: "DescribeSubstratusCluster"
Effect: Allow
Action:
- "eks:DescribeCluster"
Resource:
- "arn:aws:eks:${AWS_REGION}:${AWS_ACCOUNT_ID}:cluster/${CLUSTER_NAME}"
12 changes: 0 additions & 12 deletions internal/awsmanager/manager.go

This file was deleted.

201 changes: 201 additions & 0 deletions internal/sci/aws/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
// Package sciaws provides an AWS implementation of the Substratus Cloud Interface (SCI)
package aws

import (
"context"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net/url"
"os"
"time"

"github.com/aws/aws-sdk-go/aws"
awsSdk "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/eks"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/substratusai/substratus/internal/sci"
)

type Server struct {
sci.UnimplementedControllerServer
OIDCProviderURL string
OIDCProviderARN string
Clients
}

type Clients struct {
S3Client *s3.S3
IAMClient *iam.IAM
}

func (s *Server) GetObjectMd5(ctx context.Context, req *sci.GetObjectMd5Request) (*sci.GetObjectMd5Response, error) {
// ensure the object is accessible
bucketName, objectName := req.GetBucketName(), req.GetObjectName()
input := &s3.HeadObjectInput{
Bucket: awsSdk.String(bucketName),
Key: awsSdk.String(objectName),
}
headResult, err := s.Clients.S3Client.HeadObject(input)
if err != nil {
return nil, err
}

// NOTE: AWS returns an MD5 checksum as an ETag except for multi-part uploads where it's an MD5 with a dash suffix.
if headResult.ETag == nil {
return nil, fmt.Errorf("object does not exist: %s", s3.ErrCodeNoSuchKey)
}

md5 := *headResult.ETag

return &sci.GetObjectMd5Response{
Md5Checksum: md5,
}, nil
}

func (s *Server) CreateSignedURL(ctx context.Context, req *sci.CreateSignedURLRequest) (*sci.CreateSignedURLResponse, error) {
bucketName, objectName, checksum := req.GetBucketName(),
req.GetObjectName(),
req.GetMd5Checksum()

// Convert hex MD5 to base64
data, err := hex.DecodeString(checksum)
if err != nil {
return nil, fmt.Errorf("failed to decode MD5 checksum: %w", err)
}
base64md5 := base64.StdEncoding.EncodeToString(data)

reqInput := &s3.PutObjectInput{
Bucket: awsSdk.String(bucketName),
Key: awsSdk.String(objectName),
ContentType: awsSdk.String("application/octet-stream"),
ContentMD5: awsSdk.String(base64md5),
}

expiration := time.Duration(req.GetExpirationSeconds()) * time.Second
putReq, _ := s.Clients.S3Client.PutObjectRequest(reqInput)
url, err := putReq.Presign(expiration)
if err != nil {
return nil, fmt.Errorf("failed to presign request: %w", err)
}
return &sci.CreateSignedURLResponse{Url: url}, nil
}

func (s *Server) BindIdentity(ctx context.Context, req *sci.BindIdentityRequest) (*sci.BindIdentityResponse, error) {
// Fetch the current trust policy
getRoleInput := &iam.GetRoleInput{
RoleName: awsSdk.String(req.Principal),
}
getRoleOutput, err := s.Clients.IAMClient.GetRole(getRoleInput)
if err != nil {
return nil, fmt.Errorf("failed to get the role: %v", err)
}

// URL decode the trust policy before decoding
decodedPolicy, err := url.QueryUnescape(*getRoleOutput.Role.AssumeRolePolicyDocument)
if err != nil {
return nil, fmt.Errorf("failed to decode trust policy: %v", err)
}

// Decode the current trust policy
var existingTrustPolicy map[string]interface{}
if err := json.Unmarshal([]byte(decodedPolicy), &existingTrustPolicy); err != nil {
return nil, fmt.Errorf("failed to unmarshal trust policy: %v", err)
}

subValue := fmt.Sprintf("system:serviceaccount:%s:%s", req.KubernetesNamespace, req.KubernetesServiceAccount)

// Check if the OIDC provider's trust relationship already exists
statements := existingTrustPolicy["Statement"].([]interface{})
alreadyExists := false
for _, stmt := range statements {
stmtMap := stmt.(map[string]interface{})
if principal, ok := stmtMap["Principal"].(map[string]interface{}); ok {
if federated, ok := principal["Federated"].(string); ok && federated == s.OIDCProviderARN {
condition := stmtMap["Condition"].(map[string]interface{})["StringEquals"].(map[string]interface{})
condition[fmt.Sprintf("%s:sub", s.OIDCProviderURL)] = subValue
alreadyExists = true
break
}
}
}

// Construct the new trust relationship
newTrustRelationship := map[string]interface{}{
"Effect": "Allow",
"Principal": map[string]interface{}{
"Federated": s.OIDCProviderARN,
},
"Action": "sts:AssumeRoleWithWebIdentity",
"Condition": map[string]interface{}{
"StringEquals": map[string]string{
fmt.Sprintf("%s:sub", s.OIDCProviderURL): subValue,
},
},
}
if !alreadyExists {
// Append the new trust relationship to the existing policy
existingTrustPolicy["Statement"] = append(statements, newTrustRelationship)
}

updatedTrustPolicy, err := json.Marshal(existingTrustPolicy)
if err != nil {
return nil, fmt.Errorf("failed to marshal updated trust policy: %v", err)
}

// Apply the updated policy
input := &iam.UpdateAssumeRolePolicyInput{
PolicyDocument: awsSdk.String(string(updatedTrustPolicy)),
RoleName: awsSdk.String(req.Principal),
}

_, err = s.Clients.IAMClient.UpdateAssumeRolePolicy(input)
if err != nil {
return nil, fmt.Errorf("failed to update trust policy: %v", err)
}

return &sci.BindIdentityResponse{}, nil
}

func GetAccountID(stsSvc *sts.STS) (string, error) {
result, err := stsSvc.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err == nil {
return *result.Account, nil
}

// Fall back to the environment variable if the sts:GetCallerIdentity call fails
envAccountID := os.Getenv("AWS_ACCOUNT_ID")
if envAccountID != "" {
return envAccountID, nil
}

return "", fmt.Errorf("failed to determine AWS account ID from both STS and environment variable")
}

func GetClusterID() (string, error) {
// Fall back to the environment variable
clusterID := os.Getenv("CLUSTER_NAME")
if clusterID != "" {
return clusterID, nil
}

return clusterID, nil
}

func GetOidcProviderUrl(sess *session.Session, clusterName string) (string, error) {
svc := eks.New(sess)
input := &eks.DescribeClusterInput{
Name: aws.String(clusterName),
}

result, err := svc.DescribeCluster(input)
if err != nil {
return "", err
}

return *result.Cluster.Identity.Oidc.Issuer, nil
}

0 comments on commit 6cacb6d

Please sign in to comment.