// Forked from amit-davidson/Chronos.
// File: testutils.go

package ssaUtils
import (
"go/constant"
"path/filepath"
"runtime"
"testing"
"github.com/pdufour/Chronos/domain"
"github.com/pdufour/Chronos/utils"
"github.com/stretchr/testify/require"
"golang.org/x/tools/go/ssa"
)
// FindGA returns the first element of GuardedAccesses for which
// validationFunc reports true, or nil if no element matches (including when
// GuardedAccesses is empty). Elements are tested in slice order and the
// search stops at the first match.
func FindGA(GuardedAccesses []*domain.GuardedAccess, validationFunc func(value *domain.GuardedAccess) bool) *domain.GuardedAccess {
	for _, ga := range GuardedAccesses {
		if validationFunc(ga) {
			return ga
		}
	}
	return nil
}
// FindMultipleGA collects every element of GuardedAccesses that
// validationFunc accepts, preserving their original order. When nothing
// matches, it returns an empty but non-nil slice.
func FindMultipleGA(GuardedAccesses []*domain.GuardedAccess, validationFunc func(value *domain.GuardedAccess) bool) []*domain.GuardedAccess {
	matches := []*domain.GuardedAccess{}
	for i := range GuardedAccesses {
		if candidate := GuardedAccesses[i]; validationFunc(candidate) {
			matches = append(matches, candidate)
		}
	}
	return matches
}
// GetConstString returns the Go string held by the SSA constant v.
// It delegates to constant.StringVal, which expects a String (or Unknown)
// constant value, so callers should pass only constants known to be strings.
func GetConstString(v *ssa.Const) string {
	val := v.Value
	return constant.StringVal(val)
}
// GetGlobalString returns the name of the SSA global v, as reported by its
// Name method.
func GetGlobalString(v *ssa.Global) string {
	name := v.Name()
	return name
}
// IsGARead reports whether ga's operation kind is domain.GuardAccessRead.
func IsGARead(ga *domain.GuardedAccess) bool {
	switch ga.OpKind {
	case domain.GuardAccessRead:
		return true
	default:
		return false
	}
}
// IsGAWrite reports whether ga's operation kind is domain.GuardAccessWrite.
func IsGAWrite(ga *domain.GuardedAccess) bool {
	switch ga.OpKind {
	case domain.GuardAccessWrite:
		return true
	default:
		return false
	}
}
// FindGAWithFail returns the first element of GuardedAccesses accepted by
// validationFunc (see FindGA). It fails the test immediately via
// require.NotNil when no element matches, so the returned value is never
// nil when the function returns.
func FindGAWithFail(t *testing.T, GuardedAccesses []*domain.GuardedAccess, validationFunc func(value *domain.GuardedAccess) bool) *domain.GuardedAccess {
	// Mark this as a test helper so a failure is attributed to the caller's
	// line rather than to this file.
	t.Helper()
	res := FindGA(GuardedAccesses, validationFunc)
	require.NotNil(t, res, "no guarded access matched the validation function")
	return res
}
// FindMultipleGAWithFail returns all elements of GuardedAccesses accepted by
// validationFunc (see FindMultipleGA). It fails the test immediately when the
// number of matches differs from expectedAmount.
func FindMultipleGAWithFail(t *testing.T, GuardedAccesses []*domain.GuardedAccess, validationFunc func(value *domain.GuardedAccess) bool, expectedAmount int) []*domain.GuardedAccess {
	// Mark this as a test helper so a failure is attributed to the caller's
	// line rather than to this file.
	t.Helper()
	res := FindMultipleGA(GuardedAccesses, validationFunc)
	require.Equal(t, expectedAmount, len(res), "unexpected number of matching guarded accesses")
	return res
}
// LoadMain loads the package at filePath and returns its "main" function
// together with the package itself. Any load or preprocessing error fails
// the test immediately.
//
// Before loading, it replaces domain.GoroutineCounter,
// domain.GuardedAccessCounter and domain.PosIDCounter with fresh counters
// (presumably so IDs assigned during analysis are reproducible per test).
//
// The module root passed to LoadPackage and InitPreProcess is the parent of
// the directory that contains this source file, found via runtime.Caller.
//
// NOTE(review): ssa.Package.Func returns nil when the package has no "main"
// function, so the returned *ssa.Function may be nil. Callers that need it
// should check.
func LoadMain(t *testing.T, filePath string) (*ssa.Function, *ssa.Package) {
	// Mark this as a test helper so a failure is attributed to the caller's
	// line rather than to this file.
	t.Helper()

	domain.GoroutineCounter = utils.NewCounter()
	domain.GuardedAccessCounter = utils.NewCounter()
	domain.PosIDCounter = utils.NewCounter()

	// Caller(0) reports this very file. Going two directories up yields
	// the module root that the loader expects.
	_, ex, _, ok := runtime.Caller(0)
	require.True(t, ok, "runtime.Caller could not determine the test utility's location")
	modulePath := filepath.Dir(filepath.Dir(ex))

	ssaProg, ssaPkg, err := LoadPackage(filePath, modulePath)
	require.NoError(t, err)
	f := ssaPkg.Func("main")

	err = InitPreProcess(ssaProg, modulePath)
	require.NoError(t, err)
	return f, ssaPkg
}
// EqualDifferentOrder reports whether a and b contain the same guarded
// accesses regardless of order. Only the ID field is compared, and an ID
// must occur the same number of times in both slices.
func EqualDifferentOrder(a, b []*domain.GuardedAccess) bool {
	if len(a) != len(b) {
		return false
	}

	// remaining counts how many occurrences of each ID from a are still
	// unmatched. An entry is removed as soon as its count would reach zero,
	// so every stored count is at least 1.
	remaining := make(map[int]int, len(a))
	for _, ga := range a {
		remaining[ga.ID]++
	}

	for _, ga := range b {
		count, present := remaining[ga.ID]
		if !present {
			return false
		}
		if count == 1 {
			delete(remaining, ga.ID)
		} else {
			remaining[ga.ID] = count - 1
		}
	}
	return len(remaining) == 0
}