From a530371b6fddfc228090fd299fbd826f85ee7a62 Mon Sep 17 00:00:00 2001 From: Skylot Date: Thu, 11 Jul 2019 20:07:14 +0300 Subject: [PATCH] fix: improve StringBuilder elimination (#704) --- .../java/jadx/core/dex/nodes/InsnNode.java | 11 +- .../core/dex/visitors/SimplifyVisitor.java | 209 +++++++++++------- .../main/java/jadx/core/utils/BlockUtils.java | 1 + .../java/jadx/core/utils/InsnRemover.java | 11 +- .../TestStringBuilderElimination.java | 2 +- .../TestStringBuilderElimination2.java | 2 +- .../others/TestStringBuilderElimination3.java | 63 ++++++ 7 files changed, 209 insertions(+), 90 deletions(-) rename jadx-core/src/test/java/jadx/tests/integration/{ => others}/TestStringBuilderElimination.java (96%) rename jadx-core/src/test/java/jadx/tests/integration/{ => others}/TestStringBuilderElimination2.java (98%) create mode 100644 jadx-core/src/test/java/jadx/tests/integration/others/TestStringBuilderElimination3.java diff --git a/jadx-core/src/main/java/jadx/core/dex/nodes/InsnNode.java b/jadx-core/src/main/java/jadx/core/dex/nodes/InsnNode.java index d21bc53ae6a..42377eba594 100644 --- a/jadx-core/src/main/java/jadx/core/dex/nodes/InsnNode.java +++ b/jadx-core/src/main/java/jadx/core/dex/nodes/InsnNode.java @@ -31,14 +31,13 @@ public class InsnNode extends LineAttrNode { protected int offset; public InsnNode(InsnType type, int argsCount) { + this(type, argsCount == 0 ? Collections.emptyList() : new ArrayList<>(argsCount)); + } + + public InsnNode(InsnType type, List args) { this.insnType = type; + this.arguments = args; this.offset = -1; - - if (argsCount == 0) { - this.arguments = Collections.emptyList(); - } else { - this.arguments = new ArrayList<>(argsCount); - } } public static InsnNode wrapArg(InsnArg arg) { diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/SimplifyVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/SimplifyVisitor.java index 2eae2f5bf73..f60608a6aac 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/SimplifyVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/SimplifyVisitor.java @@ -15,7 +15,6 @@ import jadx.core.dex.info.MethodInfo; import jadx.core.dex.instructions.ArithNode; import jadx.core.dex.instructions.ArithOp; -import jadx.core.dex.instructions.CallMthInterface; import jadx.core.dex.instructions.ConstStringNode; import jadx.core.dex.instructions.FilledNewArrayNode; import jadx.core.dex.instructions.IfNode; @@ -28,6 +27,8 @@ import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnWrapArg; import jadx.core.dex.instructions.args.LiteralArg; +import jadx.core.dex.instructions.args.RegisterArg; +import jadx.core.dex.instructions.args.SSAVar; import jadx.core.dex.instructions.mods.ConstructorInsn; import jadx.core.dex.instructions.mods.TernaryInsn; import jadx.core.dex.nodes.BlockNode; @@ -35,6 +36,9 @@ import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.RootNode; import jadx.core.dex.regions.conditions.IfCondition; +import jadx.core.utils.BlockUtils; +import jadx.core.utils.InsnList; +import jadx.core.utils.InsnRemover; public class SimplifyVisitor extends AbstractVisitor { @@ -232,77 +236,136 @@ private static void simplifyTernary(TernaryInsn insn) { * Those chains are usually automatically generated by the Java compiler when you create String * concatenations like "text " + 1 + " text". */ - @SuppressWarnings("InnerAssignment") // TODO private static InsnNode convertInvoke(MethodNode mth, InvokeNode insn) { MethodInfo callMth = insn.getCallMth(); - // If this is a 'new StringBuilder(xxx).append(yyy).append(zzz).toString(), - // convert it to STRING_CONCAT pseudo instruction. if (callMth.getDeclClass().getFullName().equals(Consts.CLASS_STRING_BUILDER) - && callMth.getShortId().equals(Consts.MTH_TOSTRING_SIGNATURE) - && insn.getArg(0).isInsnWrap()) { - try { - List chain = flattenInsnChain(insn); - int constrIndex = -1; // RAF - // Case where new StringBuilder() is called with NO args (the entire - // string is created using .append() calls: - if (chain.size() > 1 && chain.get(0).getType() == InsnType.CONSTRUCTOR) { - constrIndex = 0; - } else if (chain.size() > 2 && chain.get(1).getType() == InsnType.CONSTRUCTOR) { - // RAF Case where the first string element is String arg to the - // new StringBuilder("xxx") constructor - constrIndex = 1; - } else if (chain.size() > 3 && chain.get(2).getType() == InsnType.CONSTRUCTOR) { - // RAF Case where the first string element is String.valueOf() arg - // to the new StringBuilder(String.valueOf(zzz)) constructor - constrIndex = 2; + && callMth.getShortId().equals(Consts.MTH_TOSTRING_SIGNATURE)) { + InsnArg instanceArg = insn.getArg(0); + if (instanceArg.isInsnWrap()) { + // Convert 'new StringBuilder(xxx).append(yyy).append(zzz).toString() to STRING_CONCAT insn + List callChain = flattenInsnChainUntil(insn, InsnType.CONSTRUCTOR); + return convertStringBuilderChain(mth, insn, callChain); + } + if (instanceArg.isRegister()) { + // Convert 'StringBuilder sb = new StringBuilder(xxx); sb.append(yyy); String str = sb.toString();' + List useChain = collectUseChain(mth, insn, (RegisterArg) instanceArg); + return convertStringBuilderChain(mth, insn, useChain); + } + } + return null; + } + + private static List collectUseChain(MethodNode mth, InvokeNode insn, RegisterArg instanceArg) { + SSAVar sVar = instanceArg.getSVar(); + if (sVar.isUsedInPhi() || sVar.getUseCount() == 0) { + return Collections.emptyList(); + } + List useChain = new ArrayList<>(sVar.getUseCount() + 1); + InsnNode assignInsn = sVar.getAssign().getParentInsn(); + if (assignInsn == null) { + return Collections.emptyList(); + } + useChain.add(assignInsn); + for (RegisterArg reg : sVar.getUseList()) { + InsnNode parentInsn = reg.getParentInsn(); + if (parentInsn == null) { + return Collections.emptyList(); + } + useChain.add(parentInsn); + } + int toStrIdx = InsnList.getIndex(useChain, insn); + if (useChain.size() - 1 != toStrIdx) { + return Collections.emptyList(); + } + useChain.remove(toStrIdx); + + // all insns must be in one block and sequential + BlockNode assignBlock = BlockUtils.getBlockByInsn(mth, assignInsn); + if (assignBlock == null) { + return Collections.emptyList(); + } + List blockInsns = assignBlock.getInstructions(); + int assignIdx = InsnList.getIndex(blockInsns, assignInsn); + int chainSize = useChain.size(); + int lastInsn = blockInsns.size() - assignIdx; + if (lastInsn < chainSize) { + return Collections.emptyList(); + } + for (int i = 1; i < chainSize; i++) { + if (blockInsns.get(assignIdx + i) != useChain.get(i)) { + return Collections.emptyList(); + } + } + return useChain; + } + + private static InsnNode convertStringBuilderChain(MethodNode mth, InvokeNode toStrInsn, List chain) { + try { + int chainSize = chain.size(); + if (chainSize < 2) { + return null; + } + List args = new ArrayList<>(chainSize); + InsnNode firstInsn = chain.get(0); + if (firstInsn.getType() != InsnType.CONSTRUCTOR) { + return null; + } + ConstructorInsn constrInsn = (ConstructorInsn) firstInsn; + if (constrInsn.getArgsCount() == 1) { + ArgType argType = constrInsn.getCallMth().getArgumentsTypes().get(0); + if (!argType.isObject()) { + return null; } + args.add(constrInsn.getArg(0)); + } + for (int i = 1; i < chainSize; i++) { + InsnNode chainInsn = chain.get(i); + InsnArg arg = getArgFromAppend(chainInsn); + if (arg == null) { + return null; + } + args.add(arg); + } + InsnNode concatInsn = new InsnNode(InsnType.STR_CONCAT, args); + concatInsn.setResult(toStrInsn.getResult()); + concatInsn.copyAttributesFrom(toStrInsn); - if (constrIndex != -1) { // If we found a CONSTRUCTOR, is it a StringBuilder? - ConstructorInsn constr = (ConstructorInsn) chain.get(constrIndex); - if (constr.getClassType().getFullName().equals(Consts.CLASS_STRING_BUILDER)) { - int len = chain.size(); - int argInd = 1; - InsnNode concatInsn = new InsnNode(InsnType.STR_CONCAT, len - 1); - InsnNode argInsn; - if (constrIndex > 0) { // There was an arg to the StringBuilder constr - InsnWrapArg iwa; - if (constrIndex == 2 - && (argInsn = chain.get(1)).getType() == InsnType.INVOKE - && ((InvokeNode) argInsn).getCallMth().getName().compareTo("valueOf") == 0) { - // The argument of new StringBuilder() is a String.valueOf(chainElement0) - iwa = (InsnWrapArg) argInsn.getArg(0); - argInd = 3; // Cause for loop below to skip to after the constructor - } else { - InsnNode firstNode = chain.get(0); - if (firstNode instanceof ConstStringNode) { - ConstStringNode csn = (ConstStringNode) firstNode; - iwa = new InsnWrapArg(csn); - argInd = 2; // Cause for loop below to skip to after the constructor - } else { - return null; - } - } - concatInsn.addArg(iwa); - } + InsnRemover insnRemover = new InsnRemover(mth); + for (InsnNode insnNode : chain) { + insnRemover.addAndUnbind(insnNode); + } + insnRemover.perform(); - for (; argInd < len; argInd++) { // Add the .append(xxx) arg string to concat - InsnNode node = chain.get(argInd); - MethodInfo method = ((CallMthInterface) node).getCallMth(); - if (!(node.getArgsCount() < 2 && method.isConstructor() || method.getName().equals("append"))) { - // The chain contains other calls to StringBuilder methods than the constructor or append. - // We can't simplify such chains, therefore we leave them as they are. - return null; - } - // process only constructor and append() calls - concatInsn.addArg(node.getArg(1)); - } - concatInsn.setResult(insn.getResult()); - return concatInsn; - } // end of if constructor is for StringBuilder - } // end of if we found a constructor early in the chain - } catch (Exception e) { - LOG.warn("Can't convert string concatenation: {} insn: {}", mth, insn, e); + return concatInsn; + } catch (Exception e) { + LOG.warn("Can't convert string concatenation: {} insn: {}", mth, toStrInsn, e); + } + return null; + } + + private static List flattenInsnChainUntil(InsnNode insn, InsnType insnType) { + List chain = new ArrayList<>(); + InsnArg arg = insn.getArg(0); + while (arg.isInsnWrap()) { + InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn(); + chain.add(wrapInsn); + if (wrapInsn.getType() == insnType + || wrapInsn.getArgsCount() == 0) { + break; + } + arg = wrapInsn.getArg(0); + } + Collections.reverse(chain); + return chain; + } + + private static InsnArg getArgFromAppend(InsnNode chainInsn) { + if (chainInsn.getType() == InsnType.INVOKE && chainInsn.getArgsCount() == 2) { + MethodInfo callMth = ((InvokeNode) chainInsn).getCallMth(); + if (callMth.getDeclClass().getFullName().equals(Consts.CLASS_STRING_BUILDER) + && callMth.getName().equals("append")) { + return chainInsn.getArg(1); } } return null; @@ -392,20 +455,4 @@ private static InsnNode convertFieldArith(MethodNode mth, InsnNode insn) { } return null; } - - private static List flattenInsnChain(InsnNode insn) { - List chain = new ArrayList<>(); - InsnArg i = insn.getArg(0); - while (i.isInsnWrap()) { - InsnNode wrapInsn = ((InsnWrapArg) i).getWrapInsn(); - chain.add(wrapInsn); - if (wrapInsn.getArgsCount() == 0) { - break; - } - - i = wrapInsn.getArg(0); - } - Collections.reverse(chain); - return chain; - } } diff --git a/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java b/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java index d5e450f144d..7cce3f5dfb2 100644 --- a/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java +++ b/jadx-core/src/main/java/jadx/core/utils/BlockUtils.java @@ -165,6 +165,7 @@ public static InsnNode getLastInsn(@Nullable IBlock block) { return insns.get(insns.size() - 1); } + @Nullable public static BlockNode getBlockByInsn(MethodNode mth, InsnNode insn) { if (insn instanceof PhiInsn) { return searchBlockWithPhi(mth, (PhiInsn) insn); diff --git a/jadx-core/src/main/java/jadx/core/utils/InsnRemover.java b/jadx-core/src/main/java/jadx/core/utils/InsnRemover.java index 0b2c661ef62..c66b637cdcc 100644 --- a/jadx-core/src/main/java/jadx/core/utils/InsnRemover.java +++ b/jadx-core/src/main/java/jadx/core/utils/InsnRemover.java @@ -4,6 +4,8 @@ import java.util.Iterator; import java.util.List; +import org.jetbrains.annotations.Nullable; + import jadx.core.dex.attributes.AFlag; import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.args.InsnArg; @@ -22,6 +24,7 @@ public class InsnRemover { private final MethodNode mth; private final List toRemove; + @Nullable private List instrList; public InsnRemover(MethodNode mth) { @@ -53,7 +56,13 @@ public void perform() { if (toRemove.isEmpty()) { return; } - removeAll(instrList, toRemove); + if (instrList == null) { + for (InsnNode remInsn : toRemove) { + remove(mth, remInsn); + } + } else { + removeAll(instrList, toRemove); + } toRemove.clear(); } diff --git a/jadx-core/src/test/java/jadx/tests/integration/TestStringBuilderElimination.java b/jadx-core/src/test/java/jadx/tests/integration/others/TestStringBuilderElimination.java similarity index 96% rename from jadx-core/src/test/java/jadx/tests/integration/TestStringBuilderElimination.java rename to jadx-core/src/test/java/jadx/tests/integration/others/TestStringBuilderElimination.java index 9d25c3d4beb..81a19aea95e 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/TestStringBuilderElimination.java +++ b/jadx-core/src/test/java/jadx/tests/integration/others/TestStringBuilderElimination.java @@ -1,4 +1,4 @@ -package jadx.tests.integration; +package jadx.tests.integration.others; import org.junit.jupiter.api.Test; diff --git a/jadx-core/src/test/java/jadx/tests/integration/TestStringBuilderElimination2.java b/jadx-core/src/test/java/jadx/tests/integration/others/TestStringBuilderElimination2.java similarity index 98% rename from jadx-core/src/test/java/jadx/tests/integration/TestStringBuilderElimination2.java rename to jadx-core/src/test/java/jadx/tests/integration/others/TestStringBuilderElimination2.java index e0abace6096..d478c030993 100644 --- a/jadx-core/src/test/java/jadx/tests/integration/TestStringBuilderElimination2.java +++ b/jadx-core/src/test/java/jadx/tests/integration/others/TestStringBuilderElimination2.java @@ -1,4 +1,4 @@ -package jadx.tests.integration; +package jadx.tests.integration.others; import org.junit.jupiter.api.Test; diff --git a/jadx-core/src/test/java/jadx/tests/integration/others/TestStringBuilderElimination3.java b/jadx-core/src/test/java/jadx/tests/integration/others/TestStringBuilderElimination3.java new file mode 100644 index 00000000000..85b437e26e5 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/others/TestStringBuilderElimination3.java @@ -0,0 +1,63 @@ +package jadx.tests.integration.others; + +import org.junit.jupiter.api.Test; + +import jadx.core.dex.nodes.ClassNode; +import jadx.tests.api.IntegrationTest; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; + +public class TestStringBuilderElimination3 extends IntegrationTest { + + public static class TestCls { + public static String test(String a) { + StringBuilder sb = new StringBuilder(); + sb.append("result = "); + sb.append(a); + return sb.toString(); + } + } + + @Test + public void test() { + ClassNode cls = getClassNode(TestCls.class); + String code = cls.getCode().toString(); + + assertThat(code, containsString("return \"result = \" + a;")); + assertThat(code, not(containsString("new StringBuilder()"))); + } + + public static class TestClsNegative { + private String f = "first"; + + public String test() { + StringBuilder sb = new StringBuilder(); + sb.append("before = "); + sb.append(this.f); + updateF(); + sb.append(", after = "); + sb.append(this.f); + return sb.toString(); + } + + private void updateF() { + this.f = "second"; + } + + public void check() { + assertThat(test(), is("before = first, after = second")); + } + } + + @Test + public void testNegative() { + ClassNode cls = getClassNode(TestClsNegative.class); + String code = cls.getCode().toString(); + + assertThat(code, containsString("return sb.toString();")); + assertThat(code, containsString("new StringBuilder()")); + } +}