-
Notifications
You must be signed in to change notification settings - Fork 2
/
task_run.go
110 lines (92 loc) · 3.26 KB
/
task_run.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
package actions
import (
	"crypto/subtle"
	"fmt"
	"net"
	"net/http"
	"strings"

	"github.com/golang/glog"
	"github.com/gorilla/mux"
	"github.com/ottogroup/penelope/pkg/config"
	"github.com/ottogroup/penelope/pkg/http/impersonate"
	"github.com/ottogroup/penelope/pkg/secret"
	"github.com/ottogroup/penelope/pkg/tasks"
	"go.opencensus.io/trace"
)
// TaskRunHandler is an HTTP handler that validates an incoming request and,
// if it is allowed, starts the task named by the "task" route variable.
type TaskRunHandler struct {
	// tokenSourceProvider is passed through to tasks.RunTask for every run.
	tokenSourceProvider impersonate.TargetPrincipalForProjectProvider
	// credentialsProvider is passed through to tasks.RunTask for every run.
	credentialsProvider secret.SecretProvider
}
// NewTaskRunHandler builds a TaskRunHandler that forwards the given token
// source provider and credentials provider to every task it starts.
func NewTaskRunHandler(tokenSourceProvider impersonate.TargetPrincipalForProjectProvider, credentialsProvider secret.SecretProvider) *TaskRunHandler {
	handler := TaskRunHandler{
		tokenSourceProvider: tokenSourceProvider,
		credentialsProvider: credentialsProvider,
	}
	return &handler
}
// ServeHTTP validates the request and then starts the task named by the
// "task" route variable.
//
// Responses:
//   - 403 if validateRequest rejects the request (the reason is only logged),
//   - 400 if the route carries no "task" variable,
//   - 201 once the task has been started in a background goroutine.
func (g *TaskRunHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	_, span := trace.StartSpan(r.Context(), "TaskRunHandler.ServeHTTP")
	defer span.End()

	validationErr := validateRequest(r)
	if validationErr != nil {
		glog.Errorf("request forbidden: %s", validationErr)
		reason := "request forbidden"
		prepareResponse(w, reason, reason, http.StatusForbidden)
		return
	}

	routeVars := mux.Vars(r)
	task, found := routeVars["task"]
	if !found {
		reason := "Bad request missing parameter: task"
		prepareResponse(w, reason, reason, http.StatusBadRequest)
		return
	}

	// The task runs detached from this request: no context is passed, and
	// the handler answers immediately without waiting for it to finish.
	go tasks.RunTask(task, g.tokenSourceProvider, g.credentialsProvider)
	w.WriteHeader(http.StatusCreated)
}
// validateRequest checks whether r may trigger a task run. Two independent
// checks are applied, each only when its configuration is present:
//
//   - Shared-secret header: if TasksValidationHTTPHeaderName is configured,
//     TasksValidationHTTPHeaderValue must be configured as well, and the
//     request must carry that header with exactly that value.
//   - IP allowlist: if TasksValidationAllowedIPAddresses is configured (a
//     ';'-separated list), the client IP reported by getIP must match one of
//     its entries.
//
// It returns nil when every enabled check passes, otherwise an error that
// describes the first failing check.
func validateRequest(r *http.Request) error {
	if config.TasksValidationHTTPHeaderName.Exist() {
		if !config.TasksValidationHTTPHeaderValue.Exist() {
			return fmt.Errorf("value for HTTP header validation is not provided for: %s", config.TasksValidationHTTPHeaderName)
		}
		validationHeader := config.TasksValidationHTTPHeaderName.MustGet()
		headerValue := r.Header.Get(validationHeader)
		expected := config.TasksValidationHTTPHeaderValue.MustGet()
		// Constant-time comparison so response timing does not reveal how
		// much of the shared secret a caller has guessed correctly.
		if subtle.ConstantTimeCompare([]byte(expected), []byte(headerValue)) != 1 {
			return fmt.Errorf("value for header '%s' not provided or wrong: '%s'", validationHeader, headerValue)
		}
	}

	if config.TasksValidationAllowedIPAddresses.Exist() {
		ip, err := getIP(r)
		if err != nil {
			return fmt.Errorf("couldn't validate ip: %w", err)
		}
		if !ipInAllowlist(ip, config.TasksValidationAllowedIPAddresses.MustGet()) {
			return fmt.Errorf("invalid ip address: %s", ip)
		}
	}

	return nil
}

// ipInAllowlist reports whether ip matches an entry of the ';'-separated
// allowlist. Entries are whitespace-trimmed. An entry matches on exact text,
// or, when both sides parse as IPs, on address equality, so equivalent
// spellings such as "::1" and "0:0:0:0:0:0:0:1" are treated as the same.
func ipInAllowlist(ip, allowlist string) bool {
	clientIP := net.ParseIP(ip)
	for _, entry := range strings.Split(allowlist, ";") {
		entry = strings.TrimSpace(entry)
		if entry == ip {
			return true
		}
		if allowed := net.ParseIP(entry); allowed != nil && clientIP != nil && allowed.Equal(clientIP) {
			return true
		}
	}
	return false
}
// getIP returns the client IP address of r, trying these sources in order:
//
//  1. the X-Real-IP header,
//  2. the comma-separated X-Forwarded-For header, using the first entry that
//     parses as an IP address,
//  3. the host part of r.RemoteAddr.
//
// Whitespace around each candidate is ignored. X-Forwarded-For is normally
// written as "client, proxy1, proxy2", so untrimmed entries after the first
// would never parse. The returned string is the trimmed candidate text as
// found, not a re-formatted address.
//
// An error is returned if RemoteAddr cannot be split into host and port, or
// if no source yields a valid IP.
//
// NOTE(review): X-Real-IP and X-Forwarded-For can be set by the client
// unless a trusted proxy in front of this service overwrites them. An IP
// allowlist built on this value is only as strong as that proxy setup, so
// confirm it for the deployment.
func getIP(r *http.Request) (string, error) {
	if ip := strings.TrimSpace(r.Header.Get("X-REAL-IP")); net.ParseIP(ip) != nil {
		return ip, nil
	}

	for _, candidate := range strings.Split(r.Header.Get("X-FORWARDED-FOR"), ",") {
		if ip := strings.TrimSpace(candidate); net.ParseIP(ip) != nil {
			return ip, nil
		}
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return "", fmt.Errorf("parsing remote address %q: %w", r.RemoteAddr, err)
	}
	if net.ParseIP(host) != nil {
		return host, nil
	}
	return "", fmt.Errorf("no valid ip found")
}