Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,82 @@ public static Tensor fromBlob(DoubleBuffer data, long[] shape) {
return new Tensor_float64(data, shape);
}

/**
*
* Creates a new Tensor instance with given data-type and all elements initialized to one.
*
* @param shape Tensor shape
* @param dtype Tensor data-type
*/
public static Tensor ones(long[] shape, DType dtype) {
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
int numElements = (int) numel(shape);
switch (dtype) {
case UINT8:
byte[] uInt8Data = new byte[numElements];
Arrays.fill(uInt8Data, (byte) 1);
return Tensor.fromBlobUnsigned(uInt8Data, shape);
case INT8:
byte[] int8Data = new byte[numElements];
Arrays.fill(int8Data, (byte) 1);
return Tensor.fromBlob(int8Data, shape);
case INT32:
int[] int32Data = new int[numElements];
Arrays.fill(int32Data, 1);
return Tensor.fromBlob(int32Data, shape);
case FLOAT:
float[] float32Data = new float[numElements];
Arrays.fill(float32Data, 1.0f);
return Tensor.fromBlob(float32Data, shape);
case INT64:
long[] int64Data = new long[numElements];
Arrays.fill(int64Data, 1L);
return Tensor.fromBlob(int64Data, shape);
case DOUBLE:
double[] float64Data = new double[numElements];
Arrays.fill(float64Data, 1.0);
return Tensor.fromBlob(float64Data, shape);
default:
throw new IllegalArgumentException(String.format("Tensor.ones() cannot be used with DType %s", dtype));
}
}

/**
*
* Creates a new Tensor instance with given data-type and all elements initialized to zero.
*
* @param shape Tensor shape
* @param dtype Tensor data-type
*/
public static Tensor zeros(long[] shape, DType dtype) {
checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL);
checkShape(shape);
int numElements = (int) numel(shape);
switch (dtype) {
case UINT8:
byte[] uInt8Data = new byte[numElements];
return Tensor.fromBlobUnsigned(uInt8Data, shape);
case INT8:
byte[] int8Data = new byte[numElements];
return Tensor.fromBlob(int8Data, shape);
case INT32:
int[] int32Data = new int[numElements];
return Tensor.fromBlob(int32Data, shape);
case FLOAT:
float[] float32Data = new float[numElements];
return Tensor.fromBlob(float32Data, shape);
case INT64:
long[] int64Data = new long[numElements];
return Tensor.fromBlob(int64Data, shape);
case DOUBLE:
double[] float64Data = new double[numElements];
return Tensor.fromBlob(float64Data, shape);
default:
throw new IllegalArgumentException(String.format("Tensor.zeros() cannot be used with DType %s", dtype));
}
}

@DoNotStrip private HybridData mHybridData;

private Tensor(long[] shape) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,32 @@ class TensorTest {
assertEquals(shape[i], deserShape[i])
}
}

@Test
fun testOnes_DTypeIsFloat() {
val shape = longArrayOf(2, 2)
val tensor = Tensor.ones(shape, DType.FLOAT)
val data = tensor.dataAsFloatArray
assertEquals(DType.FLOAT, tensor.dtype())
for (i in shape.indices) {
assertEquals(shape[i], tensor.shape[i])
}
for (i in data.indices) {
assertEquals(data[i], 1.0f, 1e-5.toFloat())
}
}

@Test
fun testZeros_DTypeIsFloat() {
val shape = longArrayOf(2, 2)
val tensor = Tensor.zeros(shape, DType.FLOAT)
val data = tensor.dataAsFloatArray
assertEquals(DType.FLOAT, tensor.dtype())
for (i in shape.indices) {
assertEquals(shape[i], tensor.shape[i])
}
for (i in data.indices) {
assertEquals(data[i], 0.0f, 1e-5.toFloat())
}
}
}
Loading