forked from deeplearning4j/deeplearning4j-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Nd4jEx3_GettingAndSettingSubsets.java
106 lines (83 loc) · 4.88 KB
/
Nd4jEx3_GettingAndSettingSubsets.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
/* *****************************************************************************
*
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.examples.quickstart;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.Arrays;
/**
 * --- Nd4j Example 3: Getting and setting parts of INDArrays ---
 *
 * In this example, we'll see ways to obtain and manipulate subsets of INDArrays,
 * and how views (which share memory with the original array) differ from copies made with .dup().
 *
 * @author Alex Black
 */
public class Nd4jEx3_GettingAndSettingSubsets {

    /**
     * Runs the example, printing each intermediate array to standard output.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // Let's start by creating a 3x5 INDArray with manually specified values.
        // To do this, we start with the 15 evenly spaced values 1..15 (linspace) and perform a
        // 'reshape' operation to convert them to a 3x5 INDArray.
        INDArray originalArray = createOriginalArray();
        System.out.println("Original Array:");
        System.out.println(originalArray);

        // We can use getRow and getColumn operations to get a row or column respectively:
        INDArray firstRow = originalArray.getRow(0);
        INDArray lastColumn = originalArray.getColumn(4);
        System.out.println();
        System.out.println("First row:\n" + firstRow);
        System.out.println("Last column:\n" + lastColumn);
        // Careful when reading the printed output: a row and a column can print the same way.
        // NOTE(review): depending on the ND4J version, getRow/getColumn return either 2D
        // row/column vectors or rank-1 vectors; the shapes printed here show which one you got.
        System.out.println("Shapes: " + Arrays.toString(firstRow.shape()) + "\t" + Arrays.toString(lastColumn.shape()));

        // A key concept in ND4J is the idea of views: one INDArray may point to the same locations
        // in memory as other arrays. For example, getRow and getColumn are both views of originalArray.
        // Consequently, changes to one result in changes to the other:
        firstRow.addi(1.0);     // In-place addition: changes the values of both firstRow AND originalArray
        System.out.println("\n\n");
        System.out.println("firstRow, after addi operation:");
        System.out.println(firstRow);
        System.out.println("originalArray, after firstRow.addi(1.0) operation: (note it is modified, as firstRow is a view of originalArray)");
        System.out.println(originalArray);

        // Let's recreate our original array for the next section...
        originalArray = createOriginalArray();

        // We can select arbitrary subsets, using INDArray indexing:
        // All rows, first 3 columns (note that the interval here is columns 0 inclusive to 3 exclusive)
        INDArray first3Columns = originalArray.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 3));
        System.out.println("first 3 columns:\n" + first3Columns);
        // Again, this is also a view, so the in-place addition is visible in originalArray:
        first3Columns.addi(100);
        System.out.println("originalArray, after first3Columns.addi(100)");
        System.out.println(originalArray);

        // Let's recreate our original array for the next section...
        originalArray = createOriginalArray();

        // We can similarly set arbitrary subsets.
        // Let's set the 3rd column (index 2) to zeros:
        INDArray zerosColumn = Nd4j.zeros(3, 1);
        originalArray.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point(2)}, zerosColumn);     // All rows, column index 2
        System.out.println("\n\n\nOriginal array, after put operation:\n" + originalArray);

        // Let's recreate our original array for the next section...
        originalArray = createOriginalArray();

        // Sometimes we don't want a view that shares memory with the original array.
        // In that case, just add a .dup() operation at the end:
        // .dup() ('duplicate') creates a new and separate array.
        INDArray firstRowDup = originalArray.getRow(0).dup();   // A copy of the first row, i.e., firstRowDup is NOT a view of originalArray
        firstRowDup.addi(100);
        System.out.println("\n\n\n");
        System.out.println("firstRowDup, after .addi(100):\n" + firstRowDup);
        System.out.println("originalArray, after firstRowDup.addi(100): (note it is unmodified)\n" + originalArray);
    }

    /**
     * Creates a fresh 3x5 array holding the values 1 to 15, reshaped in 'c' (row-major) order.
     *
     * Every section of the example starts from this same array, so building it in one place
     * guarantees the sections really do start from identical data.
     *
     * @return a new 3x5 INDArray, not shared with any previously created array
     */
    private static INDArray createOriginalArray() {
        return Nd4j.linspace(1, 15, 15).reshape('c', 3, 5);
    }
}