/
general.go
334 lines (282 loc) · 9.05 KB
/
general.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
// Package common contains general functions that are used by various packages and unit tests in ibmcloud-terratest-wrapper module
package common
import (
"crypto/rand"
"crypto/rsa"
"encoding/json"
"errors"
"fmt"
"golang.org/x/crypto/ssh"
"gopkg.in/yaml.v3"
"io"
"os"
"reflect"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// GetRequiredEnvVars returns a map containing required environment variables and their values
// Fails the test if any are missing
func GetRequiredEnvVars(t *testing.T, variableNames []string) map[string]string {
var missingVariables []string
envVars := make(map[string]string)
for _, variableName := range variableNames {
val, present := os.LookupEnv(variableName)
if present {
envVars[variableName] = val
} else {
missingVariables = append(missingVariables, variableName)
}
}
require.Empty(t, missingVariables, "The following environment variables must be set: %v", missingVariables)
return envVars
}
// GetBeforeAfterDiff takes a JSON string as input and returns a string with the differences
// between the "before" and "after" objects in the JSON.
//
// For example, given the JSON string:
//
// {"before": {"a": 1, "b": 2}, "after": {"a": 2, "b": 3}}
//
// the function would return the string:
//
// "Before: {"b": 2}\nAfter: {"a": 2, "b": 3}"
func GetBeforeAfterDiff(jsonString string) (string, error) {
// Parse the JSON string into a map
var jsonMap map[string]interface{}
err := json.Unmarshal([]byte(jsonString), &jsonMap)
if err != nil {
return "", errors.New("unable to parse JSON string")
}
// Get the "before" and "after" values from the map
before, beforeOk := jsonMap["before"]
after, afterOk := jsonMap["after"]
if !beforeOk || !afterOk {
return "", errors.New("missing 'before' or 'after' key in JSON")
}
// Check if the "before" and "after" values are objects
beforeObject, beforeOk := before.(map[string]interface{})
if !beforeOk {
return "", errors.New("'before' value is not an object")
}
afterObject, afterOk := after.(map[string]interface{})
if !afterOk {
return "", errors.New("'after' value is not an object")
}
// Find the differences between the two objects
diffsBefore := make(map[string]interface{})
for key, value := range beforeObject {
if !reflect.DeepEqual(afterObject[key], value) {
diffsBefore[key] = value
}
}
// Convert the diffs map to a JSON string
diffsJson, err := json.Marshal(diffsBefore)
if err != nil {
return "", errors.New("unable to convert diffs to JSON")
}
// Find the differences between the two objects
diffsAfter := make(map[string]interface{})
for key, value := range afterObject {
if !reflect.DeepEqual(beforeObject[key], value) {
diffsAfter[key] = value
}
}
// Convert the diffs map to a JSON string
diffsJson2, err := json.Marshal(diffsAfter)
if err != nil {
return "", errors.New("unable to convert diffs to JSON")
}
return "Before: " + string(diffsJson) + "\nAfter: " + string(diffsJson2), nil
}
// overwriting duplicate keys
func MergeMaps(maps ...map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{})
for _, m := range maps {
for k, v := range m {
result[k] = v
}
}
return result
}
// Adds value to map[key] only if value != compareValue
func ConditionalAdd(amap map[string]interface{}, key string, value string, compareValue string) {
if value != compareValue {
amap[key] = value
}
}
// ConvertArrayToJsonString is a helper function that will take an array of Golang data types, and return a string
// of the array formatted as a JSON array.
// Helpful to convert Golang arrays into a format that Terraform can consume.
func ConvertArrayToJsonString(arr interface{}) (string, error) {
// first marshal array into json compatible
json, jsonErr := json.Marshal(arr)
if jsonErr != nil {
return "", jsonErr
}
// take json array, wrap as one string, and escape any double quotes inside
s := string(json)
return s, nil
}
// IsArray is a simple helper function that will determine if a given Golang value is a slice or array.
func IsArray(v interface{}) bool {
theType := reflect.TypeOf(v).Kind()
if (theType == reflect.Slice) || (theType == reflect.Array) {
return true
}
return false
}
// StrArrayContains is a helper function that will check an array and see if a value is already present
func StrArrayContains(arr []string, val string) bool {
for _, arrVal := range arr {
if arrVal == val {
return true
}
}
return false
}
// IntArrayContains is a helper function that will check an array and see if an int value is already present
func IntArrayContains(arr []int, val int) bool {
for _, arrVal := range arr {
if arrVal == val {
return true
}
}
return false
}
// LoadMapFromYaml loads a YAML file into a map[string]interface{}.
// It returns the resulting map and any error encountered.
func LoadMapFromYaml(filePath string) (map[string]interface{}, error) {
data, err := os.ReadFile(filePath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("file not found: %w", err)
}
return nil, err
}
var result map[string]interface{}
err = yaml.Unmarshal(data, &result)
if err != nil {
return nil, fmt.Errorf("error parsing YAML data: %w", err)
}
return result, nil
}
// Generate an SSH RSA Keypair (4096 bits), and return the PublicKey in OpenSSH Authorized Key format.
// Used for tests to generate unique throw-away (but valid) SSH key to supply to test inputs.
// SPECIAL NOTE: the newline character at end of key will be trimmed and not included!
func GenerateSshRsaPublicKey() (string, error) {
// generate a new RSA key
newkey, keyerr := rsa.GenerateKey(rand.Reader, 4096)
if keyerr != nil {
return "", keyerr
}
// convert the RSA key into OpenSSH structure
pubKey, ssherr := ssh.NewPublicKey(&newkey.PublicKey)
if ssherr != nil {
return "", ssherr
}
// marshall public key into "authorized_key" format string (from binary)
pubKeyStr := string(ssh.MarshalAuthorizedKey(pubKey))
// trim all whitespace, including trailing newline
pubKeyStrTrim := strings.TrimSpace(pubKeyStr)
return pubKeyStrTrim, nil
}
// CopyFile copies a file from source to destination.
// Returns an error if the operation fails.
func CopyFile(source, destination string) error {
// Check path exists
if _, err := os.Stat(source); os.IsNotExist(err) {
return fmt.Errorf("source path %s does not exist: %w", source, err)
}
// Check if source is a symlink
srcInfo, err := os.Lstat(source)
if err != nil {
return fmt.Errorf("failed to stat source file: %w", err)
}
if srcInfo.Mode()&os.ModeSymlink != 0 {
// Source is a symlink
linkTarget, err := os.Readlink(source)
if err != nil {
return fmt.Errorf("failed to read symlink: %w", err)
}
return os.Symlink(linkTarget, destination)
}
src, err := os.Open(source)
if err != nil {
return fmt.Errorf("failed to open source file: %w", err)
}
defer src.Close()
dst, err := os.Create(destination)
if err != nil {
return fmt.Errorf("failed to create destination file: %w", err)
}
defer dst.Close()
_, err = io.Copy(dst, src)
if err != nil {
return fmt.Errorf("failed to copy file contents: %w", err)
}
// Set the permissions of the destination file to match the source file
if err := os.Chmod(destination, srcInfo.Mode()); err != nil {
return fmt.Errorf("failed to set destination file permissions: %w", err)
}
return nil
}
// CopyDirectory copies a directory from source to destination, with optional file filtering.
// src Source directory to copy from
// dst Destination directory to copy to
// fileFilter Optional function to filter files to copy
// Returns an error if the operation fails.
func CopyDirectory(src string, dst string, fileFilter ...func(string) bool) error {
// Check path exists
if _, err := os.Stat(src); os.IsNotExist(err) {
return fmt.Errorf("source path %s does not exist: %w", src, err)
}
// Check if source is a symlink
srcInfo, err := os.Lstat(src)
if err != nil {
return err
}
if err := os.MkdirAll(dst, srcInfo.Mode()); err != nil {
return err
}
directory, _ := os.Open(src)
objects, err := directory.Readdir(-1)
if err != nil {
return err
}
var filterFunc func(string) bool
if len(fileFilter) > 0 && fileFilter[0] != nil {
filterFunc = fileFilter[0]
} else {
// Default behavior: copy all files if no filter is provided
filterFunc = func(_ string) bool {
return true
}
}
for _, obj := range objects {
srcFile := src + "/" + obj.Name()
dstFile := dst + "/" + obj.Name()
if !filterFunc(srcFile) {
continue // Skip files that don't match the filter
}
if obj.IsDir() {
// Create sub-directories - recursively
if err = CopyDirectory(srcFile, dstFile, fileFilter...); err != nil {
return err
}
} else {
// Perform the file copy
if err = CopyFile(srcFile, dstFile); err != nil {
return err
}
}
}
return nil
}
// StringContainsIgnoreCase checks if a string contains a substring, ignoring case.
// Returns true if the string contains the substring, false otherwise.
func StringContainsIgnoreCase(s, substr string) bool {
s = strings.ToLower(s)
substr = strings.ToLower(substr)
return strings.Contains(s, substr)
}