Skip to content

Commit

Permalink
Merge pull request #42 from gfpeltier/master
Browse files Browse the repository at this point in the history
Fix bit shift methods to respect byte order
  • Loading branch information
patrickfav committed Feb 27, 2021
2 parents 5e2bdb8 + e997250 commit d737c25
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 55 deletions.
12 changes: 10 additions & 2 deletions src/main/java/at/favre/lib/bytes/Bytes.java
Expand Up @@ -967,7 +967,11 @@ public Bytes not() {
* @see <a href="https://en.wikipedia.org/wiki/Bitwise_operation#Bit_shifts">Bit shifts</a>
*/
public Bytes leftShift(int shiftCount) {
return transform(new BytesTransformer.ShiftTransformer(shiftCount, BytesTransformer.ShiftTransformer.Type.LEFT_SHIFT));
return transform(new BytesTransformer.ShiftTransformer(
shiftCount,
BytesTransformer.ShiftTransformer.Type.LEFT_SHIFT,
byteOrder
));
}

/**
Expand All @@ -982,7 +986,11 @@ public Bytes leftShift(int shiftCount) {
* @see <a href="https://en.wikipedia.org/wiki/Bitwise_operation#Bit_shifts">Bit shifts</a>
*/
public Bytes rightShift(int shiftCount) {
return transform(new BytesTransformer.ShiftTransformer(shiftCount, BytesTransformer.ShiftTransformer.Type.RIGHT_SHIFT));
return transform(new BytesTransformer.ShiftTransformer(
shiftCount,
BytesTransformer.ShiftTransformer.Type.RIGHT_SHIFT,
byteOrder
));
}

/**
Expand Down
9 changes: 6 additions & 3 deletions src/main/java/at/favre/lib/bytes/BytesTransformer.java
Expand Up @@ -21,6 +21,7 @@

package at.favre.lib.bytes;

import java.nio.ByteOrder;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Objects;
Expand Down Expand Up @@ -131,10 +132,12 @@ public enum Type {

private final int shiftCount;
private final Type type;
private final ByteOrder byteOrder;

ShiftTransformer(int shiftCount, Type type) {
ShiftTransformer(int shiftCount, Type type, ByteOrder byteOrder) {
this.shiftCount = shiftCount;
this.type = Objects.requireNonNull(type, "passed shift type must not be null");
this.byteOrder = Objects.requireNonNull(byteOrder, "passed byteOrder type must not be null");
}

@Override
Expand All @@ -143,10 +146,10 @@ public byte[] transform(byte[] currentArray, boolean inPlace) {

switch (type) {
case RIGHT_SHIFT:
return Util.Byte.shiftRight(out, shiftCount);
return Util.Byte.shiftRight(out, shiftCount, byteOrder);
default:
case LEFT_SHIFT:
return Util.Byte.shiftLeft(out, shiftCount);
return Util.Byte.shiftLeft(out, shiftCount, byteOrder);
}
}

Expand Down
78 changes: 56 additions & 22 deletions src/main/java/at/favre/lib/bytes/Util.java
Expand Up @@ -291,25 +291,42 @@ static void reverse(byte[] array, int fromIndex, int toIndex) {
*
* @param byteArray to shift
* @param shiftBitCount how many bits to shift
* @param byteOrder endianness of given byte array
* @return shifted byte array
*/
static byte[] shiftLeft(byte[] byteArray, int shiftBitCount) {
static byte[] shiftLeft(byte[] byteArray, int shiftBitCount, ByteOrder byteOrder) {
final int shiftMod = shiftBitCount % 8;
final byte carryMask = (byte) ((1 << shiftMod) - 1);
final int offsetBytes = (shiftBitCount / 8);

int sourceIndex;
for (int i = 0; i < byteArray.length; i++) {
sourceIndex = i + offsetBytes;
if (sourceIndex >= byteArray.length) {
byteArray[i] = 0;
} else {
byte src = byteArray[sourceIndex];
byte dst = (byte) (src << shiftMod);
if (sourceIndex + 1 < byteArray.length) {
dst |= byteArray[sourceIndex + 1] >>> (8 - shiftMod) & carryMask;
if (byteOrder == ByteOrder.BIG_ENDIAN) {
for (int i = 0; i < byteArray.length; i++) {
sourceIndex = i + offsetBytes;
if (sourceIndex >= byteArray.length) {
byteArray[i] = 0;
} else {
byte src = byteArray[sourceIndex];
byte dst = (byte) (src << shiftMod);
if (sourceIndex + 1 < byteArray.length) {
dst |= byteArray[sourceIndex + 1] >>> (8 - shiftMod) & carryMask;
}
byteArray[i] = dst;
}
}
} else {
for (int i = byteArray.length - 1; i >= 0; i--) {
sourceIndex = i - offsetBytes;
if (sourceIndex < 0) {
byteArray[i] = 0;
} else {
byte src = byteArray[sourceIndex];
byte dst = (byte) (src << shiftMod);
if (sourceIndex - 1 >= 0) {
dst |= byteArray[sourceIndex - 1] >>> (8 - shiftMod) & carryMask;
}
byteArray[i] = dst;
}
byteArray[i] = dst;
}
}
return byteArray;
Expand All @@ -330,25 +347,42 @@ static byte[] shiftLeft(byte[] byteArray, int shiftBitCount) {
*
* @param byteArray to shift
* @param shiftBitCount how many bits to shift
* @param byteOrder endianness of given byte array
* @return shifted byte array
*/
static byte[] shiftRight(byte[] byteArray, int shiftBitCount) {
static byte[] shiftRight(byte[] byteArray, int shiftBitCount, ByteOrder byteOrder) {
final int shiftMod = shiftBitCount % 8;
final byte carryMask = (byte) (0xFF << (8 - shiftMod));
final int offsetBytes = (shiftBitCount / 8);

int sourceIndex;
for (int i = byteArray.length - 1; i >= 0; i--) {
sourceIndex = i - offsetBytes;
if (sourceIndex < 0) {
byteArray[i] = 0;
} else {
byte src = byteArray[sourceIndex];
byte dst = (byte) ((0xff & src) >>> shiftMod);
if (sourceIndex - 1 >= 0) {
dst |= byteArray[sourceIndex - 1] << (8 - shiftMod) & carryMask;
if (byteOrder == ByteOrder.BIG_ENDIAN) {
for (int i = byteArray.length - 1; i >= 0; i--) {
sourceIndex = i - offsetBytes;
if (sourceIndex < 0) {
byteArray[i] = 0;
} else {
byte src = byteArray[sourceIndex];
byte dst = (byte) ((0xff & src) >>> shiftMod);
if (sourceIndex - 1 >= 0) {
dst |= byteArray[sourceIndex - 1] << (8 - shiftMod) & carryMask;
}
byteArray[i] = dst;
}
}
} else {
for (int i = 0; i < byteArray.length; i++) {
sourceIndex = i + offsetBytes;
if (sourceIndex >= byteArray.length) {
byteArray[i] = 0;
} else {
byte src = byteArray[sourceIndex];
byte dst = (byte) ((0xff & src) >>> shiftMod);
if (sourceIndex + 1 < byteArray.length) {
dst |= byteArray[sourceIndex + 1] << (8 - shiftMod) & carryMask;
}
byteArray[i] = dst;
}
byteArray[i] = dst;
}
}
return byteArray;
Expand Down
3 changes: 2 additions & 1 deletion src/test/java/at/favre/lib/bytes/BytesTransformTest.java
Expand Up @@ -25,6 +25,7 @@

import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Comparator;
Expand Down Expand Up @@ -446,7 +447,7 @@ public void transformerInPlaceTest() {
assertTrue(new BytesTransformer.BitSwitchTransformer(0, true).supportInPlaceTransformation());
assertTrue(new BytesTransformer.BitWiseOperatorTransformer(new byte[]{}, BytesTransformer.BitWiseOperatorTransformer.Mode.XOR).supportInPlaceTransformation());
assertTrue(new BytesTransformer.NegateTransformer().supportInPlaceTransformation());
assertTrue(new BytesTransformer.ShiftTransformer(0, BytesTransformer.ShiftTransformer.Type.LEFT_SHIFT).supportInPlaceTransformation());
assertTrue(new BytesTransformer.ShiftTransformer(0, BytesTransformer.ShiftTransformer.Type.LEFT_SHIFT, ByteOrder.BIG_ENDIAN).supportInPlaceTransformation());
assertTrue(new BytesTransformer.ReverseTransformer().supportInPlaceTransformation());

assertFalse(new BytesTransformer.MessageDigestTransformer("SHA1").supportInPlaceTransformation());
Expand Down
73 changes: 46 additions & 27 deletions src/test/java/at/favre/lib/bytes/UtilByteTest.java
Expand Up @@ -25,6 +25,7 @@
import org.junit.Test;

import java.math.BigInteger;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.Random;

Expand Down Expand Up @@ -162,21 +163,34 @@ private static void testReverse(byte[] input, int fromIndex, int toIndex, byte[]
@Test
public void testLeftShift() {
byte[] test = new byte[]{0, 0, 1, 0};
assertArrayEquals(new byte[]{0, 1, 0, 0}, Util.Byte.shiftLeft(new byte[]{0, 0, -128, 0}, 1));
assertArrayEquals(new byte[]{0, 1, 0, 0}, Util.Byte.shiftLeft(new byte[]{0, 0, 64, 0}, 2));
assertArrayEquals(new byte[]{1, 1, 1, 0}, Util.Byte.shiftLeft(new byte[]{-128, -128, -128, -128}, 1));
assertArrayEquals(new byte[]{0, 0, 2, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 1));
assertArrayEquals(new byte[]{0, 0, 4, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 2));
assertArrayEquals(new byte[]{0, 0, 8, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 3));
assertArrayEquals(new byte[]{0, 1, 0, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 8));
assertArrayEquals(new byte[]{0, 2, 0, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 9));
assertArrayEquals(new byte[]{1, 0, 0, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 16));
assertArrayEquals(new byte[]{2, 0, 0, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 17));
assertArrayEquals(new byte[]{-128, 0, 0, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 23));
assertArrayEquals(new byte[]{0, 0, 0, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 24));
assertArrayEquals(new byte[]{0, 0, 0, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 24));

assertSame(test, Util.Byte.shiftLeft(test, 1));
assertArrayEquals(new byte[]{0, 1, 0, 0}, Util.Byte.shiftLeft(new byte[]{0, 0, -128, 0}, 1, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, 1, 0, 0}, Util.Byte.shiftLeft(new byte[]{0, 0, 64, 0}, 2, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{1, 1, 1, 0}, Util.Byte.shiftLeft(new byte[]{-128, -128, -128, -128}, 1, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, 0, 2, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 1, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, 0, 4, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 2, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, 0, 8, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 3, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, 1, 0, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 8, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, 2, 0, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 9, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{1, 0, 0, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 16, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{2, 0, 0, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 17, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{-128, 0, 0, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 23, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, 0, 0, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 24, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, 0, 0, 0}, Util.Byte.shiftLeft(Bytes.from(test).array(), 24, ByteOrder.BIG_ENDIAN));

assertSame(test, Util.Byte.shiftLeft(test, 1, ByteOrder.BIG_ENDIAN));

assertArrayEquals(new byte[]{0, 0, 1, 0}, Util.Byte.shiftLeft(new byte[]{0, 1, 0, 0}, 8, ByteOrder.LITTLE_ENDIAN));
assertArrayEquals(new byte[]{(byte) 0, 0, 1, 0}, Util.Byte.shiftLeft(new byte[]{(byte) 0, 1, 0, 1}, 8, ByteOrder.LITTLE_ENDIAN));
assertArrayEquals(new byte[]{(byte) 0x54, 0x1}, Util.Byte.shiftLeft(new byte[]{(byte) 0xAA, 0}, 1, ByteOrder.LITTLE_ENDIAN));
assertArrayEquals(new byte[]{(byte) 0xA8, 0x2}, Util.Byte.shiftLeft(new byte[]{(byte) 0xAA, 0}, 2, ByteOrder.LITTLE_ENDIAN));
assertArrayEquals(new byte[]{(byte) 0x50, 0x5}, Util.Byte.shiftLeft(new byte[]{(byte) 0xAA, 0}, 3, ByteOrder.LITTLE_ENDIAN));
assertArrayEquals(new byte[]{(byte) 0xA0, 0xA}, Util.Byte.shiftLeft(new byte[]{(byte) 0xAA, 0}, 4, ByteOrder.LITTLE_ENDIAN));
assertArrayEquals(new byte[]{(byte) 0x40, 0x15}, Util.Byte.shiftLeft(new byte[]{(byte) 0xAA, 0}, 5, ByteOrder.LITTLE_ENDIAN));
assertArrayEquals(new byte[]{(byte) 0x80, 0x2A}, Util.Byte.shiftLeft(new byte[]{(byte) 0xAA, 0}, 6, ByteOrder.LITTLE_ENDIAN));
assertArrayEquals(new byte[]{(byte) 0x00, 0x55}, Util.Byte.shiftLeft(new byte[]{(byte) 0xAA, 0}, 7, ByteOrder.LITTLE_ENDIAN));
assertArrayEquals(new byte[]{(byte) 0, (byte) 0xAA}, Util.Byte.shiftLeft(new byte[]{(byte) 0xAA, 0}, 8, ByteOrder.LITTLE_ENDIAN));
assertArrayEquals(new byte[]{0, 0, 0, 0x40}, Util.Byte.shiftLeft(new byte[]{1, 0, 0, 0}, 30, ByteOrder.LITTLE_ENDIAN));

}

@Test
Expand All @@ -187,7 +201,7 @@ public void testLeftShiftAgainstRefImpl() {
Bytes rnd = Bytes.random(4 + new Random().nextInt(14));

byte[] expected = Bytes.from(new BigInteger(rnd.array()).shiftLeft(shift).toByteArray()).resize(rnd.length(), BytesTransformer.ResizeTransformer.Mode.RESIZE_KEEP_FROM_MAX_LENGTH).array();
byte[] actual = Bytes.from(Util.Byte.shiftLeft(rnd.copy().array(), shift)).resize(rnd.length(), BytesTransformer.ResizeTransformer.Mode.RESIZE_KEEP_FROM_MAX_LENGTH).array();
byte[] actual = Bytes.from(Util.Byte.shiftLeft(rnd.copy().array(), shift, ByteOrder.BIG_ENDIAN)).resize(rnd.length(), BytesTransformer.ResizeTransformer.Mode.RESIZE_KEEP_FROM_MAX_LENGTH).array();

System.out.println("Original \t" + rnd.encodeBinary() + " << " + shift);
System.out.println("Expected \t" + Bytes.wrap(expected).encodeBinary());
Expand All @@ -202,18 +216,23 @@ public void testLeftShiftAgainstRefImpl() {
public void testRightShift() {
byte[] test = new byte[]{0, 0, 16, 0};
assertEquals(0b01111110, 0b11111101 >>> 1);
assertArrayEquals(new byte[]{0b00101101, (byte) 0b01111110}, Util.Byte.shiftRight(new byte[]{0b01011010, (byte) 0b11111101}, 1));
assertArrayEquals(new byte[]{0, -128, -128, -128}, Util.Byte.shiftRight(new byte[]{1, 1, 1, 1}, 1));
assertArrayEquals(new byte[]{0, -128, 66, 0}, Util.Byte.shiftRight(new byte[]{2, 1, 8, 2}, 2));
assertArrayEquals(new byte[]{0b00101101, (byte) 0b01111110}, Util.Byte.shiftRight(new byte[]{0b01011010, (byte) 0b11111101}, 1, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, -128, -128, -128}, Util.Byte.shiftRight(new byte[]{1, 1, 1, 1}, 1, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, -128, 66, 0}, Util.Byte.shiftRight(new byte[]{2, 1, 8, 2}, 2, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, -128, 66, 0}, new BigInteger(new byte[]{2, 1, 8, 2}).shiftRight(2).toByteArray());
assertArrayEquals(new byte[]{0, 0, 0, -128}, Util.Byte.shiftRight(Bytes.from(test).array(), 5));
assertArrayEquals(new byte[]{0, 0, 0, -128}, Util.Byte.shiftRight(new byte[]{0, 0, 1, 0}, 1));
assertArrayEquals(new byte[]{0, 0, 8, 0}, Util.Byte.shiftRight(Bytes.from(test).array(), 1));
assertArrayEquals(new byte[]{0, 0, 4, 0}, Util.Byte.shiftRight(Bytes.from(test).array(), 2));
assertArrayEquals(new byte[]{0, 0, 2, 0}, Util.Byte.shiftRight(Bytes.from(test).array(), 3));
assertArrayEquals(new byte[]{0, 0, 1, 0}, Util.Byte.shiftRight(Bytes.from(test).array(), 4));
assertArrayEquals(new byte[]{0, 0, 0, -128}, Util.Byte.shiftRight(Bytes.from(test).array(), 5, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, 0, 0, -128}, Util.Byte.shiftRight(new byte[]{0, 0, 1, 0}, 1, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, 0, 8, 0}, Util.Byte.shiftRight(Bytes.from(test).array(), 1, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, 0, 4, 0}, Util.Byte.shiftRight(Bytes.from(test).array(), 2, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, 0, 2, 0}, Util.Byte.shiftRight(Bytes.from(test).array(), 3, ByteOrder.BIG_ENDIAN));
assertArrayEquals(new byte[]{0, 0, 1, 0}, Util.Byte.shiftRight(Bytes.from(test).array(), 4, ByteOrder.BIG_ENDIAN));

assertSame(test, Util.Byte.shiftRight(test, 1, ByteOrder.BIG_ENDIAN));

assertSame(test, Util.Byte.shiftRight(test, 1));
assertArrayEquals(new byte[]{0, 0}, Util.Byte.shiftRight(new byte[]{1, 0}, 1, ByteOrder.LITTLE_ENDIAN));
assertArrayEquals(new byte[]{(byte) 0x80, 0}, Util.Byte.shiftRight(new byte[]{0, 0x02}, 2, ByteOrder.LITTLE_ENDIAN));
assertArrayEquals(new byte[]{(byte) 0x80, 0, 0, 0}, Util.Byte.shiftRight(new byte[]{0, 0, 0, 1}, 17, ByteOrder.LITTLE_ENDIAN));
assertArrayEquals(new byte[]{(byte) 1, 0, 0, 0}, Util.Byte.shiftRight(new byte[]{0, 0, 0, (byte) 0x80}, 31, ByteOrder.LITTLE_ENDIAN));
}

@Test
Expand All @@ -223,7 +242,7 @@ public void testRightShiftAgainstRefImpl() {
Bytes rnd = Bytes.random(4 + new Random().nextInt(12));
if (!rnd.bitAt(rnd.lengthBit() - 1)) { //only unsigned
byte[] expected = Bytes.from(new BigInteger(rnd.array()).shiftRight(shift).toByteArray()).resize(rnd.length(), BytesTransformer.ResizeTransformer.Mode.RESIZE_KEEP_FROM_MAX_LENGTH).array();
byte[] actual = Bytes.from(Util.Byte.shiftRight(rnd.copy().array(), shift)).resize(rnd.length(), BytesTransformer.ResizeTransformer.Mode.RESIZE_KEEP_FROM_MAX_LENGTH).array();
byte[] actual = Bytes.from(Util.Byte.shiftRight(rnd.copy().array(), shift, ByteOrder.BIG_ENDIAN)).resize(rnd.length(), BytesTransformer.ResizeTransformer.Mode.RESIZE_KEEP_FROM_MAX_LENGTH).array();

// System.out.println("Original \t" + rnd.encodeBinary() + " >> " + shift);
// System.out.println("Expected \t" + Bytes.wrap(expected).encodeBinary());
Expand Down

0 comments on commit d737c25

Please sign in to comment.