Skip to content

Commit

Permalink
fix: add explicit cast for byte literal in method invoke (#719)
Browse files Browse the repository at this point in the history
  • Loading branch information
skylot committed Jul 30, 2019
1 parent 4629043 commit be9dae5
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,9 @@ public interface CallMthInterface {
MethodInfo getCallMth();

RegisterArg getInstanceArg();

/**
* Return offset to match method args from {@link #getCallMth()}
*/
int getFirstArgOffset();
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class InvokeNode extends InsnNode implements CallMthInterface {
private final MethodInfo mth;

public InvokeNode(MethodInfo mth, DecodedInstruction insn, InvokeType type, boolean isRange, int resReg) {
super(InsnType.INVOKE, mth.getArgsCount() + (type != InvokeType.STATIC ? 1 : 0));
super(InsnType.INVOKE, mth.getArgsCount() + (type == InvokeType.STATIC ? 0 : 1));
this.mth = mth;
this.type = type;

Expand Down Expand Up @@ -66,6 +66,11 @@ public RegisterArg getInstanceArg() {
return null;
}

@Override
public int getFirstArgOffset() {
return type == InvokeType.STATIC ? 0 : 1;
}

@Override
public InsnNode copy() {
return copyCommonParams(new InvokeNode(mth, type, getArgsCount()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ public boolean isSelf() {
return callType == CallType.SELF;
}

@Override
public int getFirstArgOffset() {
return 0;
}

@Override
public boolean isSame(InsnNode obj) {
if (this == obj) {
Expand Down
2 changes: 1 addition & 1 deletion jadx-core/src/main/java/jadx/core/dex/nodes/InsnNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ protected InsnArg removeArg(int index) {
return arg;
}

protected int getArgIndex(InsnArg arg) {
public int getArgIndex(InsnArg arg) {
int count = getArgsCount();
for (int i = 0; i < count; i++) {
if (arg == arguments.get(i)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import java.util.List;

import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.CallMthInterface;
import jadx.core.dex.instructions.ConstStringNode;
import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.InsnType;
Expand Down Expand Up @@ -203,6 +205,10 @@ private static boolean replaceArg(MethodNode mth, RegisterArg arg, InsnArg const
}
if (fieldNode != null) {
litArg.wrapInstruction(mth, new IndexInsnNode(InsnType.SGET, fieldNode.getFieldInfo(), 0));
} else {
if (needExplicitCast(useInsn, litArg)) {
litArg.add(AFlag.EXPLICIT_PRIMITIVE_TYPE);
}
}
} else {
if (!useInsn.replaceArg(arg, constArg.duplicate())) {
Expand All @@ -214,4 +220,19 @@ private static boolean replaceArg(MethodNode mth, RegisterArg arg, InsnArg const
}
return true;
}

private static boolean needExplicitCast(InsnNode insn, LiteralArg arg) {
if (insn instanceof CallMthInterface) {
CallMthInterface callInsn = (CallMthInterface) insn;
MethodInfo callMth = callInsn.getCallMth();
int offset = callInsn.getFirstArgOffset();
int argIndex = insn.getArgIndex(arg);
ArgType argType = callMth.getArgumentsTypes().get(argIndex - offset);
if (argType.isPrimitive()) {
arg.setType(argType);
return argType.equals(ArgType.BYTE);
}
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ public class TestInnerEnums extends IntegrationTest {
public static class TestCls {

public enum Numbers {
ONE(1, NumString.ONE), TWO(2, NumString.TWO);
ONE((byte) 1, NumString.ONE), TWO((byte) 2, NumString.TWO);

private final int num;
private final byte num;
private final NumString str;

public enum NumString {
Expand All @@ -33,7 +33,7 @@ public String getName() {
}
}

Numbers(int n, NumString str) {
Numbers(byte n, NumString str) {
this.num = n;
this.str = str;
}
Expand Down Expand Up @@ -63,7 +63,7 @@ public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();

assertThat(code, containsOne("ONE(1, NumString.ONE)"));
assertThat(code, containsOne("ONE((byte) 1, NumString.ONE)"));
assertThat(code, containsOne("ONE(\"one\")"));
}
}

0 comments on commit be9dae5

Please sign in to comment.