/
utils.go
80 lines (68 loc) · 1.69 KB
/
utils.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
package main
import (
"context"
"errors"
"net/url"
"os"
"regexp"
"strings"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// SplitOrigins reads the comma-separated APP_ALLOWED_ORIGINS environment
// variable and returns the individual origins.
//
// Whitespace around each entry is trimmed and empty entries (for example from
// a trailing comma) are skipped. Every remaining entry must parse as an
// absolute http or https URI with a host; otherwise an error naming the
// offending entry is returned. An error is also returned when the variable is
// not set. On error the returned slice is nil.
func SplitOrigins() ([]string, error) {
	raw, ok := os.LookupEnv("APP_ALLOWED_ORIGINS")
	if !ok {
		return nil, errors.New("APP_ALLOWED_ORIGINS variable missing")
	}
	parts := strings.Split(raw, ",")
	origins := make([]string, 0, len(parts))
	for _, part := range parts {
		origin := strings.TrimSpace(part)
		if origin == "" {
			continue
		}
		uri, err := url.ParseRequestURI(origin)
		// Entries that fail to parse are rejected just like entries that parse
		// but are not http(s) origins with a host.
		if err != nil || (uri.Scheme != "https" && uri.Scheme != "http") || uri.Host == "" {
			return nil, errors.New("malformed uri: " + origin)
		}
		origins = append(origins, origin)
	}
	return origins, nil
}
// GetDebug reports whether debug mode is on, which is the case only when the
// APP_DEBUG environment variable holds exactly the string "true".
func GetDebug() bool {
	switch os.Getenv("APP_DEBUG") {
	case "true":
		return true
	default:
		return false
	}
}
// connectOrFail connects to the MongoDB deployment at uri, verifies the
// connection with a ping, and returns a handle to the database named in the
// URI together with the underlying client.
//
// The database name is extracted from the URI with getDbName. On any error
// both returned handles are nil; if the ping fails after connecting, the
// client is disconnected first so its resources are not leaked.
func connectOrFail(uri string) (*mongo.Database, *mongo.Client, error) {
	dbName, err := getDbName(uri)
	if err != nil {
		return nil, nil, err
	}
	ctx := context.Background()
	client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri))
	if err != nil {
		return nil, nil, err
	}
	if err := client.Ping(ctx, nil); err != nil {
		// Close the client's sockets and monitoring goroutines. The ping error
		// is the one worth reporting, so a disconnect error is deliberately ignored.
		_ = client.Disconnect(ctx)
		return nil, nil, err
	}
	return client.Database(dbName), client, nil
}
// mongoDBNamePattern matches a MongoDB connection string of the form
// mongodb://[user:pass@]host[:port][,host[:port]...]/[db][?options]
// (or the mongodb+srv:// variant) and captures the raw database segment that
// sits between the first '/' after the hosts and the optional '?' query.
// User info and hosts cannot contain a raw '/' or '?' (the connection string
// format requires them to be percent-encoded), so the first '/' after the
// scheme starts the database segment.
var mongoDBNamePattern = regexp.MustCompile(`^mongodb(?:\+srv)?://[^/?]*/([^?]*)`)

// getDbName extracts the default database name from a MongoDB connection
// string such as "mongodb+srv://user:pass@cluster.example.net/mydb?retryWrites=true".
//
// Both the mongodb:// and mongodb+srv:// schemes are accepted, with or without
// credentials, and the database segment is percent-decoded. An error is
// returned when the URI has no database segment, the segment is empty, or it
// contains an invalid percent-escape.
func getDbName(uri string) (string, error) {
	m := mongoDBNamePattern.FindStringSubmatch(uri)
	if m == nil {
		return "", errors.New("cannot extract db name")
	}
	name, err := url.PathUnescape(m[1])
	if err != nil {
		return "", err
	}
	if name == "" {
		// e.g. "mongodb://host/?authSource=admin": no default database given.
		return "", errors.New("cannot extract db name")
	}
	return name, nil
}