Skip to content

Commit

Permalink
Merge pull request #44 from uc-cdis/feat/multiple_paymodels
Browse files Browse the repository at this point in the history
HP-682 Feat/multiple paymodels
  • Loading branch information
mfshao committed May 16, 2022
2 parents 36042ae + 64f055f commit f611912
Show file tree
Hide file tree
Showing 13 changed files with 590 additions and 148 deletions.
1 change: 1 addition & 0 deletions .github/workflows/golang-ci-workflow.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on: push

jobs:
ci:
name: golang-ci
runs-on: ubuntu-latest
env:
COVERAGE_PROFILE_OUTPUT_LOCATION: "./profile.cov"
Expand Down
26 changes: 26 additions & 0 deletions hatchery/alb.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,29 @@ func (creds *CREDS) CreateLoadBalancer(userName string) (*elbv2.CreateLoadBalanc
}
return loadBalancer, targetGroup.TargetGroups[0].TargetGroupArn, listener, nil
}

func (creds *CREDS) terminateLoadBalancer(userName string) error {
svc := elbv2.New(session.Must(session.NewSession(&aws.Config{
Credentials: creds.creds,
Region: aws.String("us-east-1"),
})))
albName := truncateString(strings.ReplaceAll(userToResourceName(userName, "service")+os.Getenv("GEN3_ENDPOINT"), ".", "-")+"alb", 32)

getInput := &elbv2.DescribeLoadBalancersInput{
Names: []*string{aws.String(albName)},
}
result, err := svc.DescribeLoadBalancers(getInput)
if err != nil {
return err
}
if len(result.LoadBalancers) == 1 {
delInput := &elbv2.DeleteLoadBalancerInput{
LoadBalancerArn: result.LoadBalancers[0].LoadBalancerArn,
}
_, err := svc.DeleteLoadBalancer(delInput)
if err != nil {
return err
}
}
return nil
}
24 changes: 17 additions & 7 deletions hatchery/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,29 @@ type AppConfigInfo struct {

// TODO remove PayModel from config once DynamoDB contains all necessary data
type PayModel struct {
Name string `json:"name"`
User string `json:"user_id"`
AWSAccountId string `json:"aws_account_id"`
Region string `json:"region"`
Ecs string `json:"ecs"`
VpcId string `json:"vpcid"`
Subnet int `json:"subnet"`
Id string `json:"bmh_workspace_id"`
Name string `json:"workspace_type"`
User string `json:"user_id"`
AWSAccountId string `json:"account_id"`
Region string `json:"region"`
Ecs bool `json:"ecs"`
Subnet int `json:"subnet"`
HardLimit float32 `json:"hard-limit"`
SoftLimit float32 `json:"soft-limit"`
TotalUsage float32 `json:"total-usage"`
CurrentPayModel bool `json:"current_pay_model"`
}

type AllPayModels struct {
CurrentPayModel *PayModel `json:"current_pay_model"`
PayModels []PayModel `json:"all_pay_models"`
}

// HatcheryConfig is the root of all the configuration
type HatcheryConfig struct {
UserNamespace string `json:"user-namespace"`
DefaultPayModel PayModel `json:"default-pay-model"`
DisableLocalWS bool `json:"disable-local-ws"`
PayModels []PayModel `json:"pay-models"`
PayModelsDynamodbTable string `json:"pay-models-dynamodb-table"`
SubDir string `json:"sub-dir"`
Expand Down
1 change: 0 additions & 1 deletion hatchery/ec2.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ func (creds *CREDS) describeWorkspaceNetwork(userName string) (*NetworkInfo, err
}
Config.Logger.Printf("Create Security Group: %s", *newSecurityGroup.GroupId)

// TODO: Make this secure. Right now it's wide open
ingressRules := ec2.AuthorizeSecurityGroupIngressInput{
GroupId: newSecurityGroup.GroupId,
IpPermissions: []*ec2.IpPermission{
Expand Down
15 changes: 9 additions & 6 deletions hatchery/ecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ func (input *CreateTaskDefinitionInput) Environment() []*ecs.KeyValuePair {
}

// Create ECS cluster
// TODO: Evaluate if this is still this needed..
func (sess *CREDS) launchEcsCluster(userName string) (*ecs.Cluster, error) {
svc := sess.svc
clusterName := strings.ReplaceAll(os.Getenv("GEN3_ENDPOINT"), ".", "-") + "-cluster"
Expand Down Expand Up @@ -201,7 +200,7 @@ func (sess *CREDS) statusEcsWorkspace(ctx context.Context, userName string, acce
if err != nil {
return &status, err
}

// TODO: Check TransitGatewayAttachment is not in Deleting state (Can't create new one until it's deleted).
var taskDefName string
if len(service.Services) > 0 {
statusMessage = *service.Services[0].Status
Expand Down Expand Up @@ -330,7 +329,13 @@ func terminateEcsWorkspace(ctx context.Context, userName string, accessToken str
if err != nil {
return "", err
}
// TODO: Terminate ALB + target group here too

// Terminate load balancer
err = svc.terminateLoadBalancer(userName)
if err != nil {
return "", err
}

err = teardownTransitGateway(userName)
if err != nil {
return "", err
Expand All @@ -339,8 +344,6 @@ func terminateEcsWorkspace(ctx context.Context, userName string, accessToken str
}

func launchEcsWorkspace(ctx context.Context, userName string, hash string, accessToken string, payModel PayModel) error {
// TODO: Setup EBS volume as pd
// Must create volume using SDK too.. :(
roleARN := "arn:aws:iam::" + payModel.AWSAccountId + ":role/csoc_adminvm"
sess := session.Must(session.NewSession(&aws.Config{
// TODO: Make this configurable
Expand Down Expand Up @@ -486,6 +489,7 @@ func launchEcsWorkspace(ctx context.Context, userName string, hash string, acces
}
return err
}

err = setupTransitGateway(userName)
if err != nil {
return err
Expand All @@ -499,7 +503,6 @@ func launchEcsWorkspace(ctx context.Context, userName string, hash string, acces
}
return err
}

fmt.Printf("Launched ECS workspace service at %s for user %s\n", launchTask, userName)
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions hatchery/efs.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ func (creds *CREDS) createAccessPoint(FileSystemId string, userName string, svc
if err != nil {
return nil, err
}

ap := userToResourceName(userName, "service") + "-" + strings.ReplaceAll(os.Getenv("GEN3_ENDPOINT"), ".", "-") + "-accesspoint"
if len(exResult.AccessPoints) == 0 {
input := &efs.CreateAccessPointInput{
ClientToken: aws.String(fmt.Sprintf("ap-%s", userToResourceName(userName, "pod"))),
ClientToken: aws.String(ap),
FileSystemId: aws.String(FileSystemId),
PosixUser: &efs.PosixUser{
Gid: aws.Int64(100),
Expand Down
102 changes: 85 additions & 17 deletions hatchery/hatchery.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ func RegisterHatchery(mux *httptrace.ServeMux) {
mux.HandleFunc("/status", status)
mux.HandleFunc("/options", options)
mux.HandleFunc("/paymodels", paymodels)
mux.HandleFunc("/setpaymodel", setpaymodel)
mux.HandleFunc("/allpaymodels", allpaymodels)

// ECS functions
mux.HandleFunc("/create-ecs-cluster", createECSCluster)
Expand Down Expand Up @@ -55,16 +57,65 @@ func paymodels(w http.ResponseWriter, r *http.Request) {
return
}
userName := getCurrentUserName(r)
payModel, err := getPayModelForUser(userName)

payModel, err := getCurrentPayModel(userName)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if payModel == nil {
http.Error(w, err.Error(), http.StatusNotFound)
http.Error(w, "Current paymodel not set", http.StatusNotFound)
return
}
out, err := json.Marshal(payModel)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
out, err := json.Marshal(payModel)
fmt.Fprint(w, string(out))
}

func allpaymodels(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
http.Error(w, "Not Found", http.StatusNotFound)
return
}
userName := getCurrentUserName(r)

payModels, err := getPayModelsForUser(userName)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if payModels == nil {
http.Error(w, "No paymodel set", http.StatusNotFound)
return
}
out, err := json.Marshal(payModels)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
fmt.Fprint(w, string(out))
}

func setpaymodel(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "Not Found", http.StatusNotFound)
return
}
userName := getCurrentUserName(r)
id := r.URL.Query().Get("id")
if id == "" {
http.Error(w, "Missing ID argument", http.StatusBadRequest)
return
}
pm, err := setCurrentPaymodel(userName, id)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
out, err := json.Marshal(pm)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand All @@ -76,19 +127,35 @@ func status(w http.ResponseWriter, r *http.Request) {
userName := getCurrentUserName(r)
accessToken := getBearerToken(r)

payModel, err := getPayModelForUser(userName)
payModel, err := getCurrentPayModel(userName)
if err != nil {
Config.Logger.Printf(err.Error())
if err != NopaymodelsError {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
var result *WorkspaceStatus
if payModel != nil && payModel.Ecs == "true" {
result, err = statusEcs(r.Context(), userName, accessToken, payModel.AWSAccountId)
} else {

if payModel == nil {
result, err = statusK8sPod(r.Context(), userName, accessToken, payModel)
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
} else {
if payModel.Ecs {
result, err = statusEcs(r.Context(), userName, accessToken, payModel.AWSAccountId)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
} else {
result, err = statusK8sPod(r.Context(), userName, accessToken, payModel)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
}

out, err := json.Marshal(result)
Expand Down Expand Up @@ -154,18 +221,19 @@ func launch(w http.ResponseWriter, r *http.Request) {
}

userName := getCurrentUserName(r)
payModel, err := getPayModelForUser(userName)
payModel, err := getCurrentPayModel(userName)
if err != nil {
Config.Logger.Printf(err.Error())
}
if payModel == nil {
err = createLocalK8sPod(r.Context(), hash, userName, accessToken)
} else if payModel.Ecs == "true" {
} else if payModel.Ecs {
err = launchEcsWorkspace(r.Context(), userName, hash, accessToken, *payModel)
} else {
err = createExternalK8sPod(r.Context(), hash, userName, accessToken, *payModel)
}
if err != nil {
Config.Logger.Printf("error during launch: %-v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
Expand All @@ -179,11 +247,11 @@ func terminate(w http.ResponseWriter, r *http.Request) {
}
accessToken := getBearerToken(r)
userName := getCurrentUserName(r)
payModel, err := getPayModelForUser(userName)
payModel, err := getCurrentPayModel(userName)
if err != nil {
Config.Logger.Printf(err.Error())
}
if payModel != nil && payModel.Ecs == "true" {
if payModel != nil && payModel.Ecs {
svc, err := terminateEcsWorkspace(r.Context(), userName, accessToken, payModel.AWSAccountId)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -219,7 +287,7 @@ func getBearerToken(r *http.Request) string {
// TODO: NEED TO CALL THIS FUNCTION IF IT DOESN'T EXIST!!!
func createECSCluster(w http.ResponseWriter, r *http.Request) {
userName := getCurrentUserName(r)
payModel, err := getPayModelForUser(userName)
payModel, err := getCurrentPayModel(userName)
if payModel == nil {
http.Error(w, "Paymodel has not been setup for user", http.StatusNotFound)
return
Expand Down
10 changes: 7 additions & 3 deletions hatchery/iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ func (creds *CREDS) taskRole(userName string) (*string, error) {
Credentials: creds.creds,
Region: aws.String("us-east-1"),
})))
pm := Config.PayModelMap[userName]
pm, err := getCurrentPayModel(userName)
if err != nil {
return nil, err
}
policyArn := fmt.Sprintf("arn:aws:iam::%s:policy/%s", pm.AWSAccountId, fmt.Sprintf("ws-task-policy-%s", userName))
taskRoleInput := &iam.GetRoleInput{
RoleName: aws.String(userToResourceName(userName, "pod")),
Expand Down Expand Up @@ -96,8 +99,9 @@ func (creds *CREDS) taskRole(userName string) (*string, error) {
}

}
// https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task_execution_IAM_role.html
// The task execution role grants the Amazon ECS container and Fargate agents permission to make AWS API calls on your behalf.

// https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task_execution_IAM_role.html
// The task execution role grants the Amazon ECS container and Fargate agents permission to make AWS API calls on your behalf.
const ecsTaskExecutionRoleName = "ecsTaskExecutionRole"
const ecsTaskExecutionPolicyArn = "arn:aws:iam::aws:policy/service-role/AmazonECSTaskExecutionRolePolicy"
const ecsTaskExecutionRoleAssumeRolePolicyDocument = `{
Expand Down

0 comments on commit f611912

Please sign in to comment.