Skip to content

Commit

Permalink
fix: improve StringBuilder elimination (#704)
Browse files Browse the repository at this point in the history
  • Loading branch information
skylot committed Jul 11, 2019
1 parent 0c5a83c commit a530371
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 90 deletions.
11 changes: 5 additions & 6 deletions jadx-core/src/main/java/jadx/core/dex/nodes/InsnNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,13 @@ public class InsnNode extends LineAttrNode {
protected int offset;

public InsnNode(InsnType type, int argsCount) {
this(type, argsCount == 0 ? Collections.emptyList() : new ArrayList<>(argsCount));
}

public InsnNode(InsnType type, List<InsnArg> args) {
this.insnType = type;
this.arguments = args;
this.offset = -1;

if (argsCount == 0) {
this.arguments = Collections.emptyList();
} else {
this.arguments = new ArrayList<>(argsCount);
}
}

public static InsnNode wrapArg(InsnArg arg) {
Expand Down
209 changes: 128 additions & 81 deletions jadx-core/src/main/java/jadx/core/dex/visitors/SimplifyVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.ArithNode;
import jadx.core.dex.instructions.ArithOp;
import jadx.core.dex.instructions.CallMthInterface;
import jadx.core.dex.instructions.ConstStringNode;
import jadx.core.dex.instructions.FilledNewArrayNode;
import jadx.core.dex.instructions.IfNode;
Expand All @@ -28,13 +27,18 @@
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.instructions.mods.ConstructorInsn;
import jadx.core.dex.instructions.mods.TernaryInsn;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.regions.conditions.IfCondition;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnList;
import jadx.core.utils.InsnRemover;

public class SimplifyVisitor extends AbstractVisitor {

Expand Down Expand Up @@ -232,77 +236,136 @@ private static void simplifyTernary(TernaryInsn insn) {
* Those chains are usually automatically generated by the Java compiler when you create String
* concatenations like <code>"text " + 1 + " text"</code>.
*/
@SuppressWarnings("InnerAssignment") // TODO
private static InsnNode convertInvoke(MethodNode mth, InvokeNode insn) {
MethodInfo callMth = insn.getCallMth();

// If this is a 'new StringBuilder(xxx).append(yyy).append(zzz).toString(),
// convert it to STRING_CONCAT pseudo instruction.
if (callMth.getDeclClass().getFullName().equals(Consts.CLASS_STRING_BUILDER)
&& callMth.getShortId().equals(Consts.MTH_TOSTRING_SIGNATURE)
&& insn.getArg(0).isInsnWrap()) {
try {
List<InsnNode> chain = flattenInsnChain(insn);
int constrIndex = -1; // RAF
// Case where new StringBuilder() is called with NO args (the entire
// string is created using .append() calls:
if (chain.size() > 1 && chain.get(0).getType() == InsnType.CONSTRUCTOR) {
constrIndex = 0;
} else if (chain.size() > 2 && chain.get(1).getType() == InsnType.CONSTRUCTOR) {
// RAF Case where the first string element is String arg to the
// new StringBuilder("xxx") constructor
constrIndex = 1;
} else if (chain.size() > 3 && chain.get(2).getType() == InsnType.CONSTRUCTOR) {
// RAF Case where the first string element is String.valueOf() arg
// to the new StringBuilder(String.valueOf(zzz)) constructor
constrIndex = 2;
&& callMth.getShortId().equals(Consts.MTH_TOSTRING_SIGNATURE)) {
InsnArg instanceArg = insn.getArg(0);
if (instanceArg.isInsnWrap()) {
// Convert 'new StringBuilder(xxx).append(yyy).append(zzz).toString() to STRING_CONCAT insn
List<InsnNode> callChain = flattenInsnChainUntil(insn, InsnType.CONSTRUCTOR);
return convertStringBuilderChain(mth, insn, callChain);
}
if (instanceArg.isRegister()) {
// Convert 'StringBuilder sb = new StringBuilder(xxx); sb.append(yyy); String str = sb.toString();'
List<InsnNode> useChain = collectUseChain(mth, insn, (RegisterArg) instanceArg);
return convertStringBuilderChain(mth, insn, useChain);
}
}
return null;
}

private static List<InsnNode> collectUseChain(MethodNode mth, InvokeNode insn, RegisterArg instanceArg) {
SSAVar sVar = instanceArg.getSVar();
if (sVar.isUsedInPhi() || sVar.getUseCount() == 0) {
return Collections.emptyList();
}
List<InsnNode> useChain = new ArrayList<>(sVar.getUseCount() + 1);
InsnNode assignInsn = sVar.getAssign().getParentInsn();
if (assignInsn == null) {
return Collections.emptyList();
}
useChain.add(assignInsn);
for (RegisterArg reg : sVar.getUseList()) {
InsnNode parentInsn = reg.getParentInsn();
if (parentInsn == null) {
return Collections.emptyList();
}
useChain.add(parentInsn);
}
int toStrIdx = InsnList.getIndex(useChain, insn);
if (useChain.size() - 1 != toStrIdx) {
return Collections.emptyList();
}
useChain.remove(toStrIdx);

// all insns must be in one block and sequential
BlockNode assignBlock = BlockUtils.getBlockByInsn(mth, assignInsn);
if (assignBlock == null) {
return Collections.emptyList();
}
List<InsnNode> blockInsns = assignBlock.getInstructions();
int assignIdx = InsnList.getIndex(blockInsns, assignInsn);
int chainSize = useChain.size();
int lastInsn = blockInsns.size() - assignIdx;
if (lastInsn < chainSize) {
return Collections.emptyList();
}
for (int i = 1; i < chainSize; i++) {
if (blockInsns.get(assignIdx + i) != useChain.get(i)) {
return Collections.emptyList();
}
}
return useChain;
}

private static InsnNode convertStringBuilderChain(MethodNode mth, InvokeNode toStrInsn, List<InsnNode> chain) {
try {
int chainSize = chain.size();
if (chainSize < 2) {
return null;
}
List<InsnArg> args = new ArrayList<>(chainSize);
InsnNode firstInsn = chain.get(0);
if (firstInsn.getType() != InsnType.CONSTRUCTOR) {
return null;
}
ConstructorInsn constrInsn = (ConstructorInsn) firstInsn;
if (constrInsn.getArgsCount() == 1) {
ArgType argType = constrInsn.getCallMth().getArgumentsTypes().get(0);
if (!argType.isObject()) {
return null;
}
args.add(constrInsn.getArg(0));
}
for (int i = 1; i < chainSize; i++) {
InsnNode chainInsn = chain.get(i);
InsnArg arg = getArgFromAppend(chainInsn);
if (arg == null) {
return null;
}
args.add(arg);
}
InsnNode concatInsn = new InsnNode(InsnType.STR_CONCAT, args);
concatInsn.setResult(toStrInsn.getResult());
concatInsn.copyAttributesFrom(toStrInsn);

if (constrIndex != -1) { // If we found a CONSTRUCTOR, is it a StringBuilder?
ConstructorInsn constr = (ConstructorInsn) chain.get(constrIndex);
if (constr.getClassType().getFullName().equals(Consts.CLASS_STRING_BUILDER)) {
int len = chain.size();
int argInd = 1;
InsnNode concatInsn = new InsnNode(InsnType.STR_CONCAT, len - 1);
InsnNode argInsn;
if (constrIndex > 0) { // There was an arg to the StringBuilder constr
InsnWrapArg iwa;
if (constrIndex == 2
&& (argInsn = chain.get(1)).getType() == InsnType.INVOKE
&& ((InvokeNode) argInsn).getCallMth().getName().compareTo("valueOf") == 0) {
// The argument of new StringBuilder() is a String.valueOf(chainElement0)
iwa = (InsnWrapArg) argInsn.getArg(0);
argInd = 3; // Cause for loop below to skip to after the constructor
} else {
InsnNode firstNode = chain.get(0);
if (firstNode instanceof ConstStringNode) {
ConstStringNode csn = (ConstStringNode) firstNode;
iwa = new InsnWrapArg(csn);
argInd = 2; // Cause for loop below to skip to after the constructor
} else {
return null;
}
}
concatInsn.addArg(iwa);
}
InsnRemover insnRemover = new InsnRemover(mth);
for (InsnNode insnNode : chain) {
insnRemover.addAndUnbind(insnNode);
}
insnRemover.perform();

for (; argInd < len; argInd++) { // Add the .append(xxx) arg string to concat
InsnNode node = chain.get(argInd);
MethodInfo method = ((CallMthInterface) node).getCallMth();
if (!(node.getArgsCount() < 2 && method.isConstructor() || method.getName().equals("append"))) {
// The chain contains other calls to StringBuilder methods than the constructor or append.
// We can't simplify such chains, therefore we leave them as they are.
return null;
}
// process only constructor and append() calls
concatInsn.addArg(node.getArg(1));
}
concatInsn.setResult(insn.getResult());
return concatInsn;
} // end of if constructor is for StringBuilder
} // end of if we found a constructor early in the chain
} catch (Exception e) {
LOG.warn("Can't convert string concatenation: {} insn: {}", mth, insn, e);
return concatInsn;
} catch (Exception e) {
LOG.warn("Can't convert string concatenation: {} insn: {}", mth, toStrInsn, e);
}
return null;
}

private static List<InsnNode> flattenInsnChainUntil(InsnNode insn, InsnType insnType) {
List<InsnNode> chain = new ArrayList<>();
InsnArg arg = insn.getArg(0);
while (arg.isInsnWrap()) {
InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn();
chain.add(wrapInsn);
if (wrapInsn.getType() == insnType
|| wrapInsn.getArgsCount() == 0) {
break;
}
arg = wrapInsn.getArg(0);
}
Collections.reverse(chain);
return chain;
}

private static InsnArg getArgFromAppend(InsnNode chainInsn) {
if (chainInsn.getType() == InsnType.INVOKE && chainInsn.getArgsCount() == 2) {
MethodInfo callMth = ((InvokeNode) chainInsn).getCallMth();
if (callMth.getDeclClass().getFullName().equals(Consts.CLASS_STRING_BUILDER)
&& callMth.getName().equals("append")) {
return chainInsn.getArg(1);
}
}
return null;
Expand Down Expand Up @@ -392,20 +455,4 @@ private static InsnNode convertFieldArith(MethodNode mth, InsnNode insn) {
}
return null;
}

private static List<InsnNode> flattenInsnChain(InsnNode insn) {
List<InsnNode> chain = new ArrayList<>();
InsnArg i = insn.getArg(0);
while (i.isInsnWrap()) {
InsnNode wrapInsn = ((InsnWrapArg) i).getWrapInsn();
chain.add(wrapInsn);
if (wrapInsn.getArgsCount() == 0) {
break;
}

i = wrapInsn.getArg(0);
}
Collections.reverse(chain);
return chain;
}
}
1 change: 1 addition & 0 deletions jadx-core/src/main/java/jadx/core/utils/BlockUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ public static InsnNode getLastInsn(@Nullable IBlock block) {
return insns.get(insns.size() - 1);
}

@Nullable
public static BlockNode getBlockByInsn(MethodNode mth, InsnNode insn) {
if (insn instanceof PhiInsn) {
return searchBlockWithPhi(mth, (PhiInsn) insn);
Expand Down
11 changes: 10 additions & 1 deletion jadx-core/src/main/java/jadx/core/utils/InsnRemover.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import java.util.Iterator;
import java.util.List;

import org.jetbrains.annotations.Nullable;

import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.InsnArg;
Expand All @@ -22,6 +24,7 @@ public class InsnRemover {

private final MethodNode mth;
private final List<InsnNode> toRemove;
@Nullable
private List<InsnNode> instrList;

public InsnRemover(MethodNode mth) {
Expand Down Expand Up @@ -53,7 +56,13 @@ public void perform() {
if (toRemove.isEmpty()) {
return;
}
removeAll(instrList, toRemove);
if (instrList == null) {
for (InsnNode remInsn : toRemove) {
remove(mth, remInsn);
}
} else {
removeAll(instrList, toRemove);
}
toRemove.clear();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package jadx.tests.integration;
package jadx.tests.integration.others;

import org.junit.jupiter.api.Test;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package jadx.tests.integration;
package jadx.tests.integration.others;

import org.junit.jupiter.api.Test;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package jadx.tests.integration.others;

import org.junit.jupiter.api.Test;

import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;

public class TestStringBuilderElimination3 extends IntegrationTest {

public static class TestCls {
public static String test(String a) {
StringBuilder sb = new StringBuilder();
sb.append("result = ");
sb.append(a);
return sb.toString();
}
}

@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();

assertThat(code, containsString("return \"result = \" + a;"));
assertThat(code, not(containsString("new StringBuilder()")));
}

public static class TestClsNegative {
private String f = "first";

public String test() {
StringBuilder sb = new StringBuilder();
sb.append("before = ");
sb.append(this.f);
updateF();
sb.append(", after = ");
sb.append(this.f);
return sb.toString();
}

private void updateF() {
this.f = "second";
}

public void check() {
assertThat(test(), is("before = first, after = second"));
}
}

@Test
public void testNegative() {
ClassNode cls = getClassNode(TestClsNegative.class);
String code = cls.getCode().toString();

assertThat(code, containsString("return sb.toString();"));
assertThat(code, containsString("new StringBuilder()"));
}
}

0 comments on commit a530371

Please sign in to comment.