forked from deeplearning4j/deeplearning4j-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
WorkspacesExample.java
162 lines (135 loc) · 7.91 KB
/
WorkspacesExample.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
/* *****************************************************************************
*
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.examples.advanced.memoryoptimization;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
/**
 * Shows how to use ND4J memory workspaces for cyclic workloads. Only advanced users with very
 * unusual workloads need to delve into this; for the most part users needn't concern themselves
 * with how DL4J reuses memory internally to increase performance.
 *
 * <p>Background:
 *
 * <p>An ND4J workspace is a memory chunk that is allocated once and reused over and over.
 * Basically it gives you a way to avoid garbage collection of off-heap memory when you work with
 * cyclic workloads.
 *
 * <p>PLEASE NOTE: Workspaces are OPTIONAL. If you prefer the original GC-based memory management,
 * you can keep using it without any issues.
 *
 * <p>PLEASE NOTE: When working with workspaces, YOU are responsible for tracking scopes etc. You
 * must NOT access any INDArray that is attached to a workspace outside of that workspace.
 * Results will be unpredictable, up to and including JVM crashes.
 *
 * <p>Each workspace is tied to a JVM thread via its ID, so the same ID in different threads points
 * to different actual workspaces. Each workspace is created from a configuration; different
 * workspaces can either share the same configuration or have their own.
 *
 * @author raver119@gmail.com
 */
@SuppressWarnings("unused") //variable names clarify the example.
public class WorkspacesExample {
    private static final Logger log = org.slf4j.LoggerFactory.getLogger(WorkspacesExample.class);

    /** Size pre-allocated by the fixed-size workspace configurations in this example: 10 MB. */
    private static final long PREALLOCATED_BYTES = 10 * 1024L * 1024L;

    /** Runs each workspace demo in turn. */
    public static void main(String[] args) {
        // Config with 10 MB pre-allocated. STRICT disables over-allocation and NONE disables
        // reallocation over time.
        WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder()
                .initialSize(PREALLOCATED_BYTES)
                .policyAllocation(AllocationPolicy.STRICT)
                .policyLearning(LearningPolicy.NONE)
                .build();

        detachResultExample(initialConfig);
        learningWorkspaceExample();
        nestedWorkspacesExample(initialConfig);
        scopeOutOfWorkspacesExample(initialConfig);
        circularWorkspaceExample();
    }

    /** Allocates inside a workspace, then detaches a result so it can be used after the workspace closes. */
    private static void detachResultExample(WorkspaceConfiguration config) {
        INDArray result;
        // Activate the workspace registered under "SOME_ID" for this thread; try-with-resources
        // closes it when the block ends.
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(config, "SOME_ID")) {
            // Every INDArray created within this try block is allocated from this workspace's pool.
            INDArray array = Nd4j.rand(10, 10);
            // Easiest way to see whether an array is attached to a workspace. We expect TRUE printed here.
            log.info("Array attached? {}", array.isAttached());
            // The new mean array is attached to this workspace as well.
            INDArray mean = array.mean(1);
            // To use a result after its workspace closes, either leverage it into an outer
            // workspace or detach it, as done here.
            result = mean.detach();
        }
        // Since the array was detached, we expect FALSE printed here: result is now managed by the GC.
        log.info("Array attached? {}", result.isAttached());
    }

    /**
     * Uses a workspace that learns its required size during the first loop iteration instead of
     * being pre-allocated.
     */
    private static void learningWorkspaceExample() {
        WorkspaceConfiguration learningConfig = WorkspaceConfiguration.builder()
                .policyAllocation(AllocationPolicy.STRICT) // <-- this option disables over-allocation behavior
                .policyLearning(LearningPolicy.FIRST_LOOP) // <-- this option makes the workspace learn its size after the first loop
                .build();
        for (int x = 0; x < 10; x++) {
            try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(learningConfig, "OTHER_ID")) {
                INDArray array = Nd4j.create(100);
                /*
                 * During the first iteration the workspace spills every allocation into a separate
                 * memory chunk. Once the first iteration has finished, the workspace is allocated
                 * to fit all allocations required by the loop, so further iterations reuse
                 * workspace memory over and over again.
                 */
            }
        }
    }

    /** Nests two workspaces and migrates a result from the inner one into the outer one. */
    private static void nestedWorkspacesExample(WorkspaceConfiguration config) {
        try (MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(config, "SOME_ID")) {
            INDArray array = Nd4j.create(10, 10).assign(1.0f);
            INDArray sumRes;
            try (MemoryWorkspace ws2 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(config, "THIRD_ID")) {
                // PLEASE NOTE: memory from the parent workspace may be accessed ONLY while that
                // workspace hasn't been closed/reset yet.
                INDArray res = array.sum(1);
                // array is allocated in ws1 and res in ws2, but arrays can be migrated if/when needed.
                sumRes = res.leverageTo("SOME_ID");
            }
            // sumRes is still valid here because it was leveraged into ws1, which is still open.
            // Each row of the 10x10 array of ones sums to 10, so we expect 100 printed here.
            log.info("Sum: {}", sumRes.sumNumber().floatValue());
        }
    }

    /** Temporarily leaves all workspaces so that part of a computation is handled by the GC. */
    private static void scopeOutOfWorkspacesExample(WorkspaceConfiguration config) {
        try (MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(config, "SOME_ID")) {
            INDArray array1 = Nd4j.create(10, 10).assign(1.0f);
            INDArray array2;
            try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                // Anything allocated within this try block is managed by the GC.
                array2 = Nd4j.create(10, 10).assign(2.0f);
            }
            // array1 lives in ws1 while array2 is GC-managed, yet they can still be combined.
            // addi adds array2 into array1 in place: 100 elements of 1 + 2 = 3, so we expect 300 printed here.
            log.info("Sum: {}", array1.addi(array2).sumNumber().floatValue());
        }
    }

    /** Uses a fixed-size workspace that acts as a circular buffer. */
    private static void circularWorkspaceExample() {
        WorkspaceConfiguration circularConfig = WorkspaceConfiguration.builder()
                .initialSize(PREALLOCATED_BYTES)
                .policyAllocation(AllocationPolicy.STRICT)
                .policyLearning(LearningPolicy.NONE) // <-- this option disables workspace reallocation over time
                .policyReset(ResetPolicy.ENDOFBUFFER_REACHED) // <--- this option makes the workspace act as a circular buffer, beware.
                .build();
        for (int x = 0; x < 10; x++) {
            // Because this workspace is circular, memory allocated before the end of the buffer is
            // reached remains viable until the buffer is reset.
            try (MemoryWorkspace ws1 = Nd4j.getWorkspaceManager().getAndActivateWorkspace(circularConfig, "CIRCULAR_ID")) {
                INDArray array = Nd4j.create(100);
                // So you can use this array anywhere, as long as YOU are sure the buffer hasn't been reset.
                // In other words: it suits a producer/consumer pattern if you are in charge of the flow.
            }
        }
    }
}