forked from google/agi
-
Notifications
You must be signed in to change notification settings - Fork 0
/
memory.go
252 lines (224 loc) · 7.96 KB
/
memory.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
// Copyright (C) 2018 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dependencygraph2
import (
"context"
"math/bits"
"github.com/google/gapid/core/math/interval"
"github.com/google/gapid/gapis/api"
"github.com/google/gapid/gapis/memory"
)
// MemoryAccess is one memory access attributed to a node of the dependency
// graph.
type MemoryAccess struct {
	// Node is the node the access is attributed to.
	Node NodeID
	// Pool is the memory pool that was accessed.
	Pool memory.PoolID
	// Span is the accessed address range within Pool.
	Span interval.U64Span
	// Mode is a bit set of AccessMode flags (e.g. ACCESS_PLAIN_READ,
	// ACCESS_DEP_WRITE) describing the kind of access.
	Mode AccessMode
	// Deps lists the nodes whose earlier writes this access read; it is
	// empty (or nil) when the access has no such dependency.
	Deps []NodeID
}
// MemWatcher is notified of command boundaries and of the memory that
// commands read and write, and hands back the resulting per-node memory
// accesses from OnEndCmd.
type MemWatcher interface {
	// Command and sub-command boundaries. OnEndCmd returns the accesses
	// gathered for the command, keyed by the node they belong to.
	OnBeginCmd(ctx context.Context, cmdCtx CmdContext)
	OnEndCmd(ctx context.Context, cmdCtx CmdContext) map[NodeID][]MemoryAccess
	OnBeginSubCmd(ctx context.Context, cmdCtx CmdContext, subCmdCtx CmdContext)
	OnEndSubCmd(ctx context.Context, cmdCtx CmdContext)

	// Slice-level reads and writes performed by the current command.
	OnReadSlice(ctx context.Context, cmdCtx CmdContext, s memory.Slice)
	OnWriteSlice(ctx context.Context, cmdCtx CmdContext, s memory.Slice)

	// Memory observations of the current command. The memWatcher
	// implementation pairs nodes[i] with obs[i].
	OnReadObs(ctx context.Context, cmdCtx CmdContext, obs []api.CmdObservation, nodes []NodeID)
	OnWriteObs(ctx context.Context, cmdCtx CmdContext, obs []api.CmdObservation, nodes []NodeID)
}
// NewMemWatcher returns a memWatcher whose internal maps are allocated and
// empty, ready to receive notifications.
func NewMemWatcher() *memWatcher {
	w := new(memWatcher)
	w.pendingAccesses = make(map[memory.PoolID]*memoryAccessList)
	w.memoryWrites = make(map[memory.PoolID]*memoryWriteList)
	w.nodeAccesses = make(map[NodeID][]MemoryAccess)
	return w
}
// memWatcher implements MemWatcher. It buffers the slice accesses of the
// current command, resolves them against the latest known writer of each
// memory span when flushed, and collects the resulting MemoryAccess records
// per node.
type memWatcher struct {
	// pendingAccesses holds, per pool, the slice accesses reported since the
	// last Flush. Flush consumes it and replaces it with an empty map.
	pendingAccesses map[memory.PoolID]*memoryAccessList
	// memoryWrites records, per pool, which node most recently wrote each
	// span. It is never reset, so it carries over from command to command.
	memoryWrites map[memory.PoolID]*memoryWriteList
	// nodeAccesses holds the committed accesses per node. OnEndCmd returns
	// it and starts a new map.
	nodeAccesses map[NodeID][]MemoryAccess
	// isPostFence is set once the first write observation of the current
	// command has been handled (pending accesses are flushed at that point).
	// OnBeginCmd clears it.
	isPostFence bool
	// stats collects diagnostics gathered while resolving reads.
	stats struct {
		// The distribution of the number of relevant writes for each read
		RelevantWriteDist Distribution
	}
}
// OnWriteSlice buffers a write of the address range covered by slice in the
// pending-access list of the slice's pool. The access is only resolved into a
// MemoryAccess on the next Flush.
func (b *memWatcher) OnWriteSlice(ctx context.Context, cmdCtx CmdContext, slice memory.Slice) {
	pool := slice.Pool()
	start := slice.Base()
	span := interval.U64Span{Start: start, End: start + slice.Size()}

	list, found := b.pendingAccesses[pool]
	if !found {
		// First access to this pool since the last Flush: start its list.
		b.pendingAccesses[pool] = &memoryAccessList{memoryAccess{ACCESS_WRITE, span}}
		return
	}
	list.AddWrite(span)
}
// OnReadSlice buffers a read of the address range covered by slice in the
// pending-access list of the slice's pool. The access is only resolved into a
// MemoryAccess (and its dependencies) on the next Flush.
func (b *memWatcher) OnReadSlice(ctx context.Context, cmdCtx CmdContext, slice memory.Slice) {
	pool := slice.Pool()
	start := slice.Base()
	span := interval.U64Span{Start: start, End: start + slice.Size()}

	list, found := b.pendingAccesses[pool]
	if !found {
		// First access to this pool since the last Flush: start its list.
		b.pendingAccesses[pool] = &memoryAccessList{memoryAccess{ACCESS_READ, span}}
		return
	}
	list.AddRead(span)
}
// OnWriteObs records each write observation obs[idx] on behalf of node
// nodeIDs[idx]. nodeIDs must be at least as long as obs.
func (b *memWatcher) OnWriteObs(ctx context.Context, cmdCtx CmdContext, obs []api.CmdObservation, nodeIDs []NodeID) {
	for idx := range obs {
		b.addObs(ctx, cmdCtx, obs[idx], true, nodeIDs[idx])
	}
}
// OnReadObs records each read observation obs[idx] on behalf of node
// nodeIDs[idx]. nodeIDs must be at least as long as obs. Like write
// observations, these end up recorded as writes by addObs; unlike them, they
// never trigger the post-fence flush.
func (b *memWatcher) OnReadObs(ctx context.Context, cmdCtx CmdContext, obs []api.CmdObservation, nodeIDs []NodeID) {
	for idx := range obs {
		b.addObs(ctx, cmdCtx, obs[idx], false, nodeIDs[idx])
	}
}
// Flush commits the pending memory accesses accumulated so far.
//
// Every access buffered in b.pendingAccesses is attributed to the node of
// cmdCtx (cmdCtx.nodeID) and appended, as one MemoryAccess each, to
// b.nodeAccesses for that node. For each buffered access:
//   - the ACCESS_PLAIN_READ / ACCESS_PLAIN_WRITE bits are always kept;
//   - if it carries ACCESS_DEP_READ, the nodes that last wrote the accessed
//     span (according to b.memoryWrites) become its Deps, and
//     ACCESS_DEP_READ is kept only when there is at least one such node;
//   - if it carries ACCESS_DEP_WRITE and its pool is not 0, b.memoryWrites is
//     updated so that later reads see this node as the writer, and
//     ACCESS_DEP_WRITE is kept only when applyMemWrite reports the write as
//     new for this node.
//
// A span that is both read and written is resolved read-first, so the read
// depends on the previous writers rather than on this node itself.
//
// NOTE(review): pool 0 is presumably the application pool; confirm why its
// writes are not recorded in b.memoryWrites here.
//
// b.pendingAccesses is replaced with an empty map at the end. ctx is unused.
func (b *memWatcher) Flush(ctx context.Context, cmdCtx CmdContext) {
	nodeID := cmdCtx.nodeID
	memAccesses := b.nodeAccesses[nodeID]
	// DO NOT REMOVE! Optimization: manually set the final slice capacity,
	// to avoid numerous realloc. This has a perceptible impact on big
	// captures where it can save several seconds of computation.
	// Compute the final size of memAccesses at the end of `Flush`: the loop
	// below appends exactly one entry per pending access.
	memAccessesCap := len(memAccesses)
	for _, acc := range b.pendingAccesses {
		memAccessesCap += len(*acc)
	}
	// Ensure that memAccesses has sufficient capacity
	if memAccessesCap > cap(memAccesses) {
		// Round up to the next power of 2. memAccessesCap >= 1 here (it is
		// strictly greater than a capacity), so the subtraction cannot wrap.
		memAccessesCap = 1 << uint(bits.Len(uint(memAccessesCap-1)))
		newMemAccesses := make([]MemoryAccess, len(memAccesses), memAccessesCap)
		copy(newMemAccesses, memAccesses)
		memAccesses = newMemAccesses
	}
	// Iterate over this command's node pending memory accesses, and create the
	// list of nodes it depends on, i.e. the nodes that are the latest to have
	// written in the memory locations that this node reads. Also, update
	// b.memoryWrites with the writes that this node performs.
	for poolID, accessList := range b.pendingAccesses {
		for _, access := range *accessList {
			writeNodes := []NodeID{}
			// Start from the plain bits only; the dependency bits are added
			// back below only when a real dependency exists.
			mode := access.mode & (ACCESS_PLAIN_READ | ACCESS_PLAIN_WRITE)
			if access.mode&ACCESS_DEP_READ != 0 {
				// Must run before applyMemWrite so that a read-modify-write
				// depends on the previous writers, not on this node.
				writeNodes = applyMemRead(b.memoryWrites, poolID, access.span)
				b.stats.RelevantWriteDist.Add(uint64(len(writeNodes)))
				// There is a relevant dependency read only if this node reads
				// the write of at least one other node.
				if len(writeNodes) > 0 {
					mode |= ACCESS_DEP_READ
				}
			}
			if access.mode&ACCESS_DEP_WRITE != 0 && poolID != 0 {
				// There is a dependency write only if this node writes to a
				// memory location that was latest written to by another node.
				if applyMemWrite(b.memoryWrites, poolID, nodeID, access.span) {
					mode |= ACCESS_DEP_WRITE
				}
			}
			memAccesses = append(memAccesses, MemoryAccess{
				Node: nodeID,
				Pool: poolID,
				Span: access.span,
				Mode: mode,
				Deps: writeNodes,
			})
		}
	}
	b.nodeAccesses[nodeID] = memAccesses
	b.pendingAccesses = make(map[memory.PoolID]*memoryAccessList)
}
// NodeAccesses returns the accesses committed since the last OnEndCmd, keyed
// by node. The map returned is the watcher's own storage, not a copy.
func (b *memWatcher) NodeAccesses() map[NodeID][]MemoryAccess {
	committed := b.nodeAccesses
	return committed
}
// OnBeginCmd returns the watcher to its pre-fence state for the new command,
// so the command's first write observation will flush its pending accesses
// before being recorded.
func (b *memWatcher) OnBeginCmd(ctx context.Context, cmdCtx CmdContext) {
	const preFence = false
	b.isPostFence = preFence
}
// OnEndCmd commits the command's remaining pending accesses (via Flush). It
// then returns every access recorded since the previous OnEndCmd, keyed by
// node. Ownership of the returned map passes to the caller: the watcher
// switches to a fresh map. memoryWrites is kept, so later commands still
// depend on the writes made here.
func (b *memWatcher) OnEndCmd(ctx context.Context, cmdCtx CmdContext) map[NodeID][]MemoryAccess {
	b.Flush(ctx, cmdCtx)
	// Flush already leaves pendingAccesses as a fresh empty map, so only
	// nodeAccesses needs to be swapped out here; allocating another pending
	// map would be wasted work on every command.
	acc := b.nodeAccesses
	b.nodeAccesses = make(map[NodeID][]MemoryAccess)
	return acc
}
// OnBeginSubCmd commits the accesses gathered so far to cmdCtx's node before
// the sub-command starts. subCmdCtx is not consulted.
func (b *memWatcher) OnBeginSubCmd(ctx context.Context, cmdCtx CmdContext, subCmdCtx CmdContext) {
	parent := cmdCtx
	b.Flush(ctx, parent)
}
// OnEndSubCmd commits the accesses gathered during the sub-command to the
// node of cmdCtx, so they are not mixed with whatever follows.
func (b *memWatcher) OnEndSubCmd(ctx context.Context, cmdCtx CmdContext) {
	current := cmdCtx
	b.Flush(ctx, current)
}
// addObs records the observation obs as a write performed by nodeID. It
// updates memoryWrites and replaces nodeAccesses[nodeID] with that single
// access.
//
// The first write observation of a command (isWrite while not yet
// post-fence) first flushes the command's pending accesses and marks the
// watcher post-fence. This commits the command's own accesses before any
// write observation is applied.
func (b *memWatcher) addObs(ctx context.Context, cmdCtx CmdContext, obs api.CmdObservation, isWrite bool, nodeID NodeID) {
	crossingFence := isWrite && !b.isPostFence
	if crossingFence {
		b.Flush(ctx, cmdCtx)
		b.isPostFence = true
	}

	span := obs.Range.Span()
	applyMemWrite(b.memoryWrites, obs.Pool, nodeID, span)
	b.nodeAccesses[nodeID] = []MemoryAccess{{
		Node: nodeID,
		Pool: obs.Pool,
		Span: span,
		// An observation is always a *write* memory access in the
		// dependency graph context, in particular: a read observation
		// reflects a write by the application to some memory, which the
		// driver consumes with a readObs.
		Mode: ACCESS_WRITE,
	}}
}
// applyMemWrite records in wmap that node n wrote span s of pool p.
//
// It returns true when pool p had no recorded writes yet, or when the entry
// that interval.Replace yields for s was attributed to a node other than n.
// In that case the entry is re-attributed to n. It returns false when that
// entry already belonged to n, i.e. n is re-writing memory it had written.
// NOTE(review): the exact merge behaviour depends on memoryWriteList's
// interval implementation, which is defined elsewhere.
func applyMemWrite(wmap map[memory.PoolID]*memoryWriteList,
	p memory.PoolID, n NodeID, s interval.U64Span) bool {
	writes, found := wmap[p]
	if !found {
		wmap[p] = &memoryWriteList{memoryWrite{node: n, span: s}}
		return true
	}
	entry := &(*writes)[interval.Replace(writes, s)]
	if entry.node == n {
		return false
	}
	entry.node = n
	return true
}
// applyMemRead returns the list of nodes for which there is a memoryWrite in
// wmap ("writeMap") that intersects with (p, s): these nodes are the ones on
// which this read at (p, s) depends.
//
// Each node appears at most once, in the order its first intersecting write
// appears in the pool's write list. Deduplicating in list order, rather than
// by ranging over a map whose iteration order Go randomizes, makes the result
// (and therefore MemoryAccess.Deps) deterministic from run to run. The
// returned slice is never nil; it is empty when nothing intersects.
func applyMemRead(wmap map[memory.PoolID]*memoryWriteList,
	p memory.PoolID, s interval.U64Span) []NodeID {
	writes, ok := wmap[p]
	if !ok {
		return []NodeID{}
	}
	i, c := interval.Intersect(writes, s)
	overlapping := (*writes)[i : i+c]
	writeNodes := make([]NodeID, 0, len(overlapping))
	seen := make(map[NodeID]struct{}, len(overlapping))
	for _, w := range overlapping {
		if _, dup := seen[w.node]; dup {
			continue
		}
		seen[w.node] = struct{}{}
		writeNodes = append(writeNodes, w.node)
	}
	return writeNodes
}