forked from deeplearning4j/deeplearning4j-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Nd4jEx4_Ops.java
107 lines (88 loc) · 5.12 KB
/
Nd4jEx4_Ops.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
/* *****************************************************************************
*
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.examples.quickstart;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sin;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Arrays;
/**
* --- Nd4j Example 4: Additional Operations with INDArrays ---
*
* In this example, we'll see ways to manipulate INDArray
*
* @author Alex Black
*/
public class Nd4jEx4_Ops {

    public static void main(String[] args) {
        /*
        ND4J defines a wide variety of operations. Here we'll see how to use some of them:
        - Elementwise operations: add, multiply, divide, subtract, etc
          add, mul, div, sub,
          INDArray.add(INDArray), INDArray.mul(INDArray), etc
        - Matrix multiplication: mmul
        - Row/column vector ops: addRowVector, mulColumnVector, etc
        - Element-wise transforms, like tanh, scalar max operations, etc
        */

        //First, let's see how in-place vs. copy operations work
        //Consider the calls: myArray.add(1.0) vs myArray.addi(1.0)
        // i.e., "add" vs. "addi" -> the "i" means in-place.
        //In practice: the in-place ops modify the original array; the others ("copy ops") make a copy
        INDArray originalArray = Nd4j.linspace(1, 15, 15).reshape('c', 3, 5); //As per example 3
        INDArray copyAdd = originalArray.add(1.0);
        System.out.println("Same object returned by add: " + (originalArray == copyAdd)); //false: add returns a new array
        System.out.println("Original array after originalArray.add(1.0):\n" + originalArray);
        System.out.println("copyAdd array:\n" + copyAdd);

        //Let's do the same thing with the in-place add operation:
        INDArray inPlaceAdd = originalArray.addi(1.0);
        System.out.println();
        System.out.println("Same object returned by addi: " + (originalArray == inPlaceAdd)); //addi returns the exact same Java object
        System.out.println("Original array after originalArray.addi(1.0):\n" + originalArray);
        //Bug fix: previously printed 'copyAdd' here, which showed the stale copy instead of
        //the in-place result this section is demonstrating.
        System.out.println("inPlaceAdd array:\n" + inPlaceAdd);

        //Let's recreate our original array for the next section, and create another one:
        originalArray = Nd4j.linspace(1, 15, 15).reshape('c', 3, 5);
        INDArray random = Nd4j.rand(3, 5); //See example 2; we have a 3x5 with uniform random (0 to 1) values

        //We can perform element-wise operations. Note that the array shapes must match here
        // add vs. addi works in exactly the same way as for scalars
        INDArray added = originalArray.add(random);
        System.out.println("\n\n\nRandom values:\n" + random);
        System.out.println("Original plus random values:\n" + added);

        //Matrix multiplication is easy:
        INDArray first = Nd4j.rand(3, 4);
        INDArray second = Nd4j.rand(4, 5);
        INDArray mmul = first.mmul(second);
        System.out.println("\n\n\nShape of mmul array: " + Arrays.toString(mmul.shape())); //3x5 output as expected

        //We can do row-wise ("for each row") and column-wise ("for each column") operations
        //Again, inplace vs. copy ops work the same way (i.e., addRowVector vs. addiRowVector)
        INDArray row = Nd4j.linspace(0, 4, 5);
        System.out.println("\n\n\nRow:\n" + row);
        INDArray mulRowVector = originalArray.mulRowVector(row); //For each row in 'originalArray', do an element-wise multiplication with the row vector
        System.out.println("Result of originalArray.mulRowVector(row)");
        System.out.println(mulRowVector);

        //Element-wise transforms are things like 'tanh' and scalar max values. These can be applied in a few ways:
        System.out.println("\n\n\n");
        System.out.println("Random array:\n" + random); //Again, note the limited printing precision, as per example 2
        System.out.println("Element-wise tanh on random array:\n" + Transforms.tanh(random));
        System.out.println("Element-wise power (x^3.0) on random array:\n" + Transforms.pow(random, 3.0));
        System.out.println("Element-wise scalar max (with scalar 0.5):\n" + Transforms.max(random, 0.5));

        //We can perform this in a more verbose way, too. Note the .dup() — the op would
        //otherwise be applied in-place to 'random', which we printed above:
        INDArray sinx = Nd4j.getExecutioner().exec(new Sin(random.dup()));
        System.out.println("Element-wise sin(x) operation:\n" + sinx);
    }
}