diff --git a/android/sdl_android/src/androidTest/java/com/smartdevicelink/test/transport/SdlPsmTests.java b/android/sdl_android/src/androidTest/java/com/smartdevicelink/test/transport/SdlPsmTests.java index 64b0cb7685..78db56c32b 100644 --- a/android/sdl_android/src/androidTest/java/com/smartdevicelink/test/transport/SdlPsmTests.java +++ b/android/sdl_android/src/androidTest/java/com/smartdevicelink/test/transport/SdlPsmTests.java @@ -3,12 +3,17 @@ import android.util.Log; import com.smartdevicelink.protocol.SdlPacket; +import com.smartdevicelink.protocol.SdlPacketFactory; import com.smartdevicelink.protocol.SdlProtocol; +import com.smartdevicelink.protocol.enums.ControlFrameTags; +import com.smartdevicelink.protocol.enums.SessionType; import com.smartdevicelink.test.TestValues; import com.smartdevicelink.transport.SdlPsm; import junit.framework.TestCase; +import org.junit.Assert; + import java.lang.reflect.Field; import java.lang.reflect.Method; @@ -18,12 +23,15 @@ */ public class SdlPsmTests extends TestCase { private static final String TAG = "SdlPsmTests"; - private static final int MAX_DATA_LENGTH = SdlProtocol.V1_V2_MTU_SIZE - SdlProtocol.V1_HEADER_SIZE; + private static final int MAX_DATA_LENGTH_V1 = SdlProtocol.V1_V2_MTU_SIZE - SdlProtocol.V1_HEADER_SIZE; + private static final int MAX_DATA_LENGTH_V2 = SdlProtocol.V1_V2_MTU_SIZE - SdlProtocol.V2_HEADER_SIZE; SdlPsm sdlPsm; Field frameType, dataLength, version, controlFrameInfo; Method transitionOnInput; byte rawByte = (byte) 0x0; + SdlPacket startServiceACK; + protected void setUp() throws Exception { super.setUp(); sdlPsm = new SdlPsm(); @@ -38,9 +46,44 @@ protected void setUp() throws Exception { dataLength.setAccessible(true); version.setAccessible(true); controlFrameInfo.setAccessible(true); + + startServiceACK = SdlPacketFactory.createStartSessionACK(SessionType.RPC, (byte) 0x01, (byte) 0x05, (byte) 0x05); + startServiceACK.putTag(ControlFrameTags.RPC.StartServiceACK.HASH_ID, "3bb34978fe3a"); + startServiceACK.putTag(ControlFrameTags.RPC.StartServiceACK.MTU, "150000"); + startServiceACK.putTag(ControlFrameTags.RPC.StartServiceACK.PROTOCOL_VERSION, "5.1.0"); } + public void testHappyPath() { + + + byte[] packetBytes = startServiceACK.constructPacket(); + + SdlPsm sdlPsm = new SdlPsm(); + boolean didTransition = false; + + for (byte packetByte : packetBytes) { + didTransition = sdlPsm.handleByte(packetByte); + assertTrue(didTransition); + } + + assertEquals(SdlPsm.FINISHED_STATE, sdlPsm.getState()); + SdlPacket parsedPacket = sdlPsm.getFormedPacket(); + assertNotNull(parsedPacket); + + assertEquals(startServiceACK.getVersion(), parsedPacket.getVersion()); + assertEquals(startServiceACK.getServiceType(), parsedPacket.getServiceType()); + assertEquals(startServiceACK.getFrameInfo(), parsedPacket.getFrameInfo()); + assertEquals(startServiceACK.getFrameType(), parsedPacket.getFrameType()); + assertEquals(startServiceACK.getDataSize(), parsedPacket.getDataSize()); + assertEquals(startServiceACK.getMessageId(), parsedPacket.getMessageId()); + + assertTrue(startServiceACK.getTag(ControlFrameTags.RPC.StartServiceACK.HASH_ID).equals(parsedPacket.getTag(ControlFrameTags.RPC.StartServiceACK.HASH_ID))); + assertTrue(startServiceACK.getTag(ControlFrameTags.RPC.StartServiceACK.MTU).equals(parsedPacket.getTag(ControlFrameTags.RPC.StartServiceACK.MTU))); + assertTrue(startServiceACK.getTag(ControlFrameTags.RPC.StartServiceACK.PROTOCOL_VERSION).equals(parsedPacket.getTag(ControlFrameTags.RPC.StartServiceACK.PROTOCOL_VERSION))); + + } + public void testGarbledControlFrame() { try { rawByte = 0x0; @@ -48,27 +91,46 @@ public void testGarbledControlFrame() { controlFrameInfo.set(sdlPsm, SdlPacket.FRAME_INFO_START_SERVICE); frameType.set(sdlPsm, SdlPacket.FRAME_TYPE_CONTROL); - dataLength.set(sdlPsm, MAX_DATA_LENGTH + 1); + dataLength.set(sdlPsm, MAX_DATA_LENGTH_V1 + 1); int STATE = (Integer) transitionOnInput.invoke(sdlPsm, rawByte, SdlPsm.DATA_SIZE_4_STATE); assertEquals(TestValues.MATCH, SdlPsm.ERROR_STATE, STATE); } catch (Exception e) { + Assert.fail(e.toString()); Log.e(TAG, e.toString()); } } - public void testMaximumControlFrame() { + public void testMaximumControlFrameForVersion1() { try { rawByte = 0x0; version.set(sdlPsm, 1); controlFrameInfo.set(sdlPsm, SdlPacket.FRAME_INFO_START_SERVICE); frameType.set(sdlPsm, SdlPacket.FRAME_TYPE_CONTROL); - dataLength.set(sdlPsm, MAX_DATA_LENGTH); + dataLength.set(sdlPsm, MAX_DATA_LENGTH_V1); int STATE = (Integer) transitionOnInput.invoke(sdlPsm, rawByte, SdlPsm.DATA_SIZE_4_STATE); assertEquals(TestValues.MATCH, SdlPsm.DATA_PUMP_STATE, STATE); } catch (Exception e) { + Assert.fail(e.toString()); + Log.e(TAG, e.toString()); + } + } + + public void testMaximumControlFrameForVersion2Plus() { + try { + rawByte = 0x0; + version.set(sdlPsm, 2); + controlFrameInfo.set(sdlPsm, SdlPacket.FRAME_INFO_START_SERVICE); + frameType.set(sdlPsm, SdlPacket.FRAME_TYPE_CONTROL); + + dataLength.set(sdlPsm, MAX_DATA_LENGTH_V2); + int STATE = (Integer) transitionOnInput.invoke(sdlPsm, rawByte, SdlPsm.DATA_SIZE_4_STATE); + + assertEquals(TestValues.MATCH, SdlPsm.MESSAGE_1_STATE, STATE); + } catch (Exception e) { + Assert.fail(e.toString()); Log.e(TAG, e.toString()); } } @@ -80,14 +142,117 @@ public void testOutOfMemoryDS4() { frameType.set(sdlPsm, SdlPacket.FRAME_TYPE_SINGLE); dataLength.set(sdlPsm, 2147483647); - int STATE = (Integer) transitionOnInput.invoke(sdlPsm, rawByte, SdlPsm.DATA_SIZE_4_STATE); + int STATE = (Integer) transitionOnInput.invoke(sdlPsm, rawByte, SdlPsm.MESSAGE_4_STATE); assertEquals(TestValues.MATCH, SdlPsm.ERROR_STATE, STATE); } catch (Exception e) { + Assert.fail(e.toString()); Log.e(TAG, e.toString()); } } + public void testNegativeDataSize() { + byte[] packetBytes = startServiceACK.constructPacket(); + + SdlPsm sdlPsm = new SdlPsm(); + boolean didTransition = false; + + for (byte packetByte : packetBytes) { + int state = sdlPsm.getState(); + switch (state) { + case SdlPsm.MESSAGE_4_STATE: + didTransition = sdlPsm.handleByte(packetByte); + assertFalse(didTransition); + assertEquals(SdlPsm.ERROR_STATE, sdlPsm.getState()); + return; + case SdlPsm.DATA_SIZE_1_STATE: + case SdlPsm.DATA_SIZE_2_STATE: + case SdlPsm.DATA_SIZE_3_STATE: + case SdlPsm.DATA_SIZE_4_STATE: + didTransition = sdlPsm.handleByte((byte) 0xFF); + assertTrue(didTransition); + break; + default: + didTransition = sdlPsm.handleByte(packetByte); + assertTrue(didTransition); + } + } + } + + public void testIncorrectVersion() { + SdlPacket startServiceACK = SdlPacketFactory.createStartSessionACK(SessionType.RPC, (byte) 0x01, (byte) 0x05, (byte) 0x06); + startServiceACK.putTag(ControlFrameTags.RPC.StartServiceACK.HASH_ID, "3bb34978fe3a"); + startServiceACK.putTag(ControlFrameTags.RPC.StartServiceACK.MTU, "150000"); + startServiceACK.putTag(ControlFrameTags.RPC.StartServiceACK.PROTOCOL_VERSION, "5.1.0"); + byte[] packetBytes = startServiceACK.constructPacket(); + + SdlPsm sdlPsm = new SdlPsm(); + boolean didTransition = sdlPsm.handleByte(packetBytes[0]); + assertFalse(didTransition); + } + + public void testIncorrectService() { + + byte[] packetBytes = startServiceACK.constructPacket(); + + SdlPsm sdlPsm = new SdlPsm(); + boolean didTransition = false; + + for (byte packetByte : packetBytes) { + int state = sdlPsm.getState(); + switch (state) { + case SdlPsm.SERVICE_TYPE_STATE: + didTransition = sdlPsm.handleByte((byte) 0xFF); + assertFalse(didTransition); + assertEquals(SdlPsm.ERROR_STATE, sdlPsm.getState()); + return; + default: + didTransition = sdlPsm.handleByte(packetByte); + assertTrue(didTransition); + } + } + } + + public void testRecovery() { + byte[] packetBytes = startServiceACK.constructPacket(); + byte[] processingBytes = new byte[packetBytes.length + 15]; + + System.arraycopy(packetBytes, 10, processingBytes, 0, 15); + System.arraycopy(packetBytes, 0, processingBytes, 15, packetBytes.length); + + + SdlPsm sdlPsm = new SdlPsm(); + boolean didTransition = false; + byte packetByte; + int state = SdlPsm.START_STATE, i = 0, limit = 0; + + while (state != SdlPsm.FINISHED_STATE && limit < 10) { + + packetByte = processingBytes[i]; + didTransition = sdlPsm.handleByte(packetByte); + state = sdlPsm.getState(); + if (!didTransition) { + assertEquals(SdlPsm.ERROR_STATE, state); + sdlPsm.reset(); + } else if (state == SdlPsm.FINISHED_STATE) { + break; + } + + if (i == processingBytes.length - 1) { + i = 0; + limit++; + } else { + i++; + } + } + + assertEquals(SdlPsm.FINISHED_STATE, sdlPsm.getState()); + SdlPacket parsedPacket = sdlPsm.getFormedPacket(); + assertNotNull(parsedPacket); + + } + + protected void tearDown() throws Exception { super.tearDown(); } diff --git a/base/src/main/java/com/smartdevicelink/transport/SdlPsm.java b/base/src/main/java/com/smartdevicelink/transport/SdlPsm.java index 99510b0619..dc8fd3bbc6 100644 --- a/base/src/main/java/com/smartdevicelink/transport/SdlPsm.java +++ b/base/src/main/java/com/smartdevicelink/transport/SdlPsm.java @@ -32,15 +32,18 @@ package com.smartdevicelink.transport; import com.smartdevicelink.protocol.SdlPacket; +import com.smartdevicelink.util.DebugTool; import static com.smartdevicelink.protocol.SdlProtocol.V1_HEADER_SIZE; import static com.smartdevicelink.protocol.SdlProtocol.V1_V2_MTU_SIZE; +import static com.smartdevicelink.protocol.SdlProtocol.V2_HEADER_SIZE; +import static com.smartdevicelink.protocol.SdlProtocol.V3_V4_MTU_SIZE; public class SdlPsm { - //private static final String TAG = "Sdl PSM"; - //Each state represents the byte that should be incoming + private static final String TAG = "Sdl PSM"; + //Each state represents the byte that should be incoming public static final int START_STATE = 0x0; public static final int SERVICE_TYPE_STATE = 0x02; public static final int CONTROL_FRAME_INFO_STATE = 0x03; @@ -83,8 +86,13 @@ public SdlPsm() { } public boolean handleByte(byte data) { - //Log.trace(TAG, data + " = incoming"); - state = transitionOnInput(data, state); + try { + state = transitionOnInput(data, state); + } catch (Exception e) { + DebugTool.logError(TAG, "Exception thrown while parsing byte - " + data, e); + state = ERROR_STATE; + return false; + } return state != ERROR_STATE; } @@ -93,18 +101,11 @@ private int transitionOnInput(byte rawByte, int state) { switch (state) { case START_STATE: version = (rawByte & (byte) VERSION_MASK) >> 4; - //Log.trace(TAG, "Version: " + version); - if (version == 0) { //It should never be 0 - return ERROR_STATE; - } encrypted = (1 == ((rawByte & (byte) ENCRYPTION_MASK) >> 3)); - - frameType = rawByte & (byte) FRAME_TYPE_MASK; - //Log.trace(TAG, rawByte + " = Frame Type: " + frameType); - if ((version < 1 || version > 5) //These are known versions supported by this library. - && frameType != SdlPacket.FRAME_TYPE_CONTROL) { + if ((version < 1 || version > 5)) { + //These are known versions supported by this library. return ERROR_STATE; } @@ -116,7 +117,16 @@ private int transitionOnInput(byte rawByte, int state) { case SERVICE_TYPE_STATE: serviceType = (int) (rawByte & 0xFF); - return CONTROL_FRAME_INFO_STATE; + switch (serviceType) { + case 0x00: //SessionType.CONTROL: + case 0x07: //SessionType.RPC: + case 0x0A: //SessionType.PCM (Audio): + case 0x0B: //SessionType.NAV (Video): + case 0x0F: //SessionType.BULK (Hybrid): + return CONTROL_FRAME_INFO_STATE; + default: + return ERROR_STATE; + } case CONTROL_FRAME_INFO_STATE: controlFrameInfo = (int) (rawByte & 0xFF); @@ -203,19 +213,34 @@ private int transitionOnInput(byte rawByte, int state) { default: return ERROR_STATE; } - if (version == 1) { //Version 1 packets will not have message id's - if (dataLength == 0) { - return FINISHED_STATE; //We are done if we don't have any payload - } - if (dataLength <= V1_V2_MTU_SIZE - V1_HEADER_SIZE) { // sizes from protocol/WiProProtocol.java - payload = new byte[dataLength]; - } else { - return ERROR_STATE; - } - dumpSize = dataLength; - return DATA_PUMP_STATE; - } else { - return MESSAGE_1_STATE; + switch (version) { + case 1: + //Version 1 packets will not have message id's + if (dataLength == 0) { + return FINISHED_STATE; //We are done if we don't have any payload + } + if (dataLength <= V1_V2_MTU_SIZE - V1_HEADER_SIZE) { // sizes from SDL protocol + payload = new byte[dataLength]; + } else { + return ERROR_STATE; + } + dumpSize = dataLength; + return DATA_PUMP_STATE; + case 2: + if (dataLength <= V1_V2_MTU_SIZE - V2_HEADER_SIZE) { + return MESSAGE_1_STATE; + } else { + return ERROR_STATE; + } + case 3: + case 4: + if (dataLength <= V3_V4_MTU_SIZE - V2_HEADER_SIZE) { + return MESSAGE_1_STATE; + } else { + return ERROR_STATE; + } + default: + return MESSAGE_1_STATE; } case MESSAGE_1_STATE: