diff --git a/src/main/java/io/quarkus/gizmo/AbstractSwitch.java b/src/main/java/io/quarkus/gizmo/AbstractSwitch.java new file mode 100644 index 0000000..0089509 --- /dev/null +++ b/src/main/java/io/quarkus/gizmo/AbstractSwitch.java @@ -0,0 +1,34 @@ +package io.quarkus.gizmo; + +import java.util.Objects; +import java.util.function.Consumer; + +abstract class AbstractSwitch extends BytecodeCreatorImpl implements Switch { + + protected static final Consumer EMPTY_BLOCK = bc -> { + }; + + protected boolean fallThrough; + protected Consumer defaultBlockConsumer; + + AbstractSwitch(BytecodeCreatorImpl enclosing) { + super(enclosing); + } + + @Override + public void fallThrough() { + fallThrough = true; + } + + @Override + public void defaultCase(Consumer defatultBlockConsumer) { + Objects.requireNonNull(defatultBlockConsumer); + this.defaultBlockConsumer = defatultBlockConsumer; + } + + @Override + public void doBreak(BytecodeCreator creator) { + creator.breakScope(this); + } + +} diff --git a/src/main/java/io/quarkus/gizmo/BytecodeCreator.java b/src/main/java/io/quarkus/gizmo/BytecodeCreator.java index e1fdf44..33333ff 100644 --- a/src/main/java/io/quarkus/gizmo/BytecodeCreator.java +++ b/src/main/java/io/quarkus/gizmo/BytecodeCreator.java @@ -1033,6 +1033,24 @@ default ResultHandle increment(ResultHandle toIncrement) { return add(toIncrement, load(1)); } + /** + * Create a new switch construct for a string value. + * + * @param value The string value to switch on + * @return the switch construct + */ + Switch.StringSwitch stringSwitch(ResultHandle value); + + /** + * Create a new switch construct for an enum constant. + * + * @param + * @param value The enum constant to switch on + * @param enumClass + * @return the switch construct + */ + > Switch.EnumSwitch enumSwitch(ResultHandle value, Class enumClass); + /** * Indicate that the scope is no longer in use. The scope may refuse additional instructions after this method * is called. diff --git a/src/main/java/io/quarkus/gizmo/BytecodeCreatorImpl.java b/src/main/java/io/quarkus/gizmo/BytecodeCreatorImpl.java index 8d65e6c..db4e310 100644 --- a/src/main/java/io/quarkus/gizmo/BytecodeCreatorImpl.java +++ b/src/main/java/io/quarkus/gizmo/BytecodeCreatorImpl.java @@ -801,6 +801,16 @@ public BytecodeCreator createScope() { operations.add(new BlockOperation(enclosed)); return enclosed; } + + /** + * Go the the top of the given scope. Unlike {@link #continueScope(BytecodeCreator)} this method does not verify if this + * bytecode creator is scoped within the given bytecode creator. + * + * @param scope + */ + void jumpTo(BytecodeCreator scope) { + operations.add(new JumpOperation(((BytecodeCreatorImpl) scope).top)); + } static void storeResultHandle(MethodVisitor methodVisitor, ResultHandle handle) { if (handle.getResultType() == ResultHandle.ResultType.UNUSED) { @@ -1318,6 +1328,23 @@ public ResultHandle bitwiseXor(ResultHandle a1, ResultHandle a2) { return emitBinaryArithmetic(Opcodes.IXOR, a1, a2); } + @Override + public Switch.StringSwitch stringSwitch(ResultHandle value) { + Objects.requireNonNull(value); + StringSwitchImpl stringSwitch = new StringSwitchImpl(value, this); + operations.add(new BlockOperation(stringSwitch)); + return stringSwitch; + } + + @Override + public > Switch.EnumSwitch enumSwitch(ResultHandle value, Class enumClass) { + Objects.requireNonNull(value); + Objects.requireNonNull(enumClass); + EnumSwitchImpl enumSwitch = new EnumSwitchImpl<>(value, enumClass, this); + operations.add(new BlockOperation(enumSwitch)); + return enumSwitch; + } + private ResultHandle emitBinaryArithmetic(int intOpcode, ResultHandle a1, ResultHandle a2) { Objects.requireNonNull(a1); Objects.requireNonNull(a2); diff --git a/src/main/java/io/quarkus/gizmo/EnumSwitchImpl.java b/src/main/java/io/quarkus/gizmo/EnumSwitchImpl.java new file mode 100644 index 0000000..19c6728 --- /dev/null +++ b/src/main/java/io/quarkus/gizmo/EnumSwitchImpl.java @@ -0,0 +1,193 @@ +package io.quarkus.gizmo; + +import static org.objectweb.asm.Opcodes.ACC_PRIVATE; +import static org.objectweb.asm.Opcodes.ACC_STATIC; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; +import java.util.function.Consumer; + +import org.objectweb.asm.Label; +import org.objectweb.asm.MethodVisitor; + +class EnumSwitchImpl> extends AbstractSwitch implements Switch.EnumSwitch { + + private final Map> ordinalToCaseBlocks; + + public EnumSwitchImpl(ResultHandle value, Class enumClass, BytecodeCreatorImpl enclosing) { + super(enclosing); + this.ordinalToCaseBlocks = new LinkedHashMap<>(); + + MethodDescriptor enumOrdinal = MethodDescriptor.ofMethod(enumClass, "ordinal", int.class); + ResultHandle ordinal = invokeVirtualMethod(enumOrdinal, value); + + // Generate the int[] switch table needed for binary compatibility + ResultHandle switchTable; + MethodCreatorImpl methodCreator = findMethodCreator(enclosing); + if (methodCreator != null) { + // Generate a static method that returns the switch table + char sep = '$'; + ClassCreator classCreator = methodCreator.getClassCreator(); + // $GIZMO_SWITCH_TABLE$org$acme$MyEnum() + StringBuilder methodName = new StringBuilder(); + methodName.append(sep).append("GIZMO_SWITCH_TABLE"); + for (String part : enumClass.getName().split("\\.")) { + methodName.append(sep).append(part); + } + MethodDescriptor gizmoSwitchTableDescriptor = MethodDescriptor.ofMethod(classCreator.getClassName(), + methodName.toString(), int[].class); + if (!classCreator.getExistingMethods() + .contains(gizmoSwitchTableDescriptor)) { + MethodCreator gizmoSwitchTable = classCreator.getMethodCreator(gizmoSwitchTableDescriptor) + .setModifiers(ACC_PRIVATE | ACC_STATIC); + gizmoSwitchTable.returnValue(generateSwitchTable(enumClass, gizmoSwitchTable, enumOrdinal)); + } + switchTable = invokeStaticMethod(gizmoSwitchTableDescriptor); + } else { + // This is suboptimal - the switch table is generated for each switch construct + switchTable = generateSwitchTable(enumClass, methodCreator, enumOrdinal); + } + ResultHandle effectiveOrdinal = readArrayValue(switchTable, ordinal); + + Set inputHandles = new HashSet<>(); + inputHandles.add(effectiveOrdinal); + + operations.add(new Operation() { + + @Override + void writeBytecode(MethodVisitor methodVisitor) { + E[] constants = enumClass.getEnumConstants(); + Map ordinalToLabel = new HashMap<>(); + List caseBlocks = new ArrayList<>(); + + BytecodeCreatorImpl defaultBlock = new BytecodeCreatorImpl(EnumSwitchImpl.this); + if (defaultBlockConsumer != null) { + defaultBlockConsumer.accept(defaultBlock); + } + + // Initialize the case blocks + for (Entry> caseEntry : ordinalToCaseBlocks.entrySet()) { + BytecodeCreatorImpl caseBlock = new BytecodeCreatorImpl(EnumSwitchImpl.this); + Consumer blockConsumer = caseEntry.getValue(); + blockConsumer.accept(caseBlock); + if (blockConsumer != EMPTY_BLOCK && !fallThrough) { + caseBlock.breakScope(EnumSwitchImpl.this); + } + caseBlock.findActiveResultHandles(inputHandles); + caseBlocks.add(caseBlock); + ordinalToLabel.put(caseEntry.getKey(), caseBlock.getTop()); + } + + int min = ordinalToLabel.keySet().stream().mapToInt(Integer::intValue).min().orElse(0); + int max = ordinalToLabel.keySet().stream().mapToInt(Integer::intValue).max().orElse(0); + + // Add empty blocks for missing ordinals + // This would be suboptimal for cases if there is a large number of missing ordinals + for (int i = 0; i < constants.length; i++) { + if (i >= min && i <= max) { + if (ordinalToLabel.get(i) == null) { + BytecodeCreatorImpl emptyCaseBlock = new BytecodeCreatorImpl(EnumSwitchImpl.this); + caseBlocks.add(emptyCaseBlock); + ordinalToLabel.put(i, emptyCaseBlock.getTop()); + } + } + } + + // Load the ordinal of the tested value + loadResultHandle(methodVisitor, effectiveOrdinal, EnumSwitchImpl.this, "I"); + + int[] ordinals = ordinalToLabel.keySet().stream().mapToInt(Integer::intValue).sorted().toArray(); + Label[] labels = new Label[ordinals.length]; + for (int i = 0; i < ordinals.length; i++) { + labels[i] = ordinalToLabel.get(ordinals[i]); + } + methodVisitor.visitTableSwitchInsn(min, max, defaultBlock.getTop(), labels); + + // Write the case blocks + for (BytecodeCreatorImpl caseBlock : caseBlocks) { + caseBlock.writeOperations(methodVisitor); + } + + // Write the default block + defaultBlock.writeOperations(methodVisitor); + } + + @Override + ResultHandle getTopResultHandle() { + return null; + } + + @Override + ResultHandle getOutgoingResultHandle() { + return null; + } + + @Override + Set getInputResultHandles() { + return inputHandles; + } + + }); + } + + @Override + public void caseOf(E value, Consumer caseBlockConsumer) { + Objects.requireNonNull(value); + Objects.requireNonNull(caseBlockConsumer); + addCaseBlock(value, caseBlockConsumer); + } + + @Override + public void caseOf(List values, Consumer caseBlockConsumer) { + Objects.requireNonNull(values); + Objects.requireNonNull(caseBlockConsumer); + for (Iterator it = values.iterator(); it.hasNext();) { + E e = it.next(); + if (it.hasNext()) { + addCaseBlock(e, EMPTY_BLOCK); + } else { + addCaseBlock(e, caseBlockConsumer); + } + } + } + + private void addCaseBlock(E value, Consumer caseBlockConsumer) { + int ordinal = value.ordinal(); + if (ordinalToCaseBlocks.containsKey(ordinal)) { + throw new IllegalArgumentException("A case block for the enum value " + value + " already exists"); + } + ordinalToCaseBlocks.put(ordinal, caseBlockConsumer); + } + + private MethodCreatorImpl findMethodCreator(BytecodeCreatorImpl enclosing) { + if (enclosing instanceof MethodCreatorImpl) { + return (MethodCreatorImpl) enclosing; + } + if (enclosing.getOwner() != null) { + return findMethodCreator(enclosing.getOwner()); + } + return null; + } + + private ResultHandle generateSwitchTable(Class enumClass, BytecodeCreator bytecodeCreator, + MethodDescriptor enumOrdinal) { + E[] constants = enumClass.getEnumConstants(); + ResultHandle switchTable = bytecodeCreator.newArray(int.class, constants.length); + for (int i = 0; i < constants.length; i++) { + ResultHandle currentConstant = bytecodeCreator + .readStaticField(FieldDescriptor.of(enumClass, constants[i].name(), enumClass)); + ResultHandle currentOrdinal = bytecodeCreator.invokeVirtualMethod(enumOrdinal, currentConstant); + bytecodeCreator.writeArrayValue(switchTable, i, currentOrdinal); + } + return switchTable; + } + +} \ No newline at end of file diff --git a/src/main/java/io/quarkus/gizmo/StringSwitchImpl.java b/src/main/java/io/quarkus/gizmo/StringSwitchImpl.java new file mode 100644 index 0000000..703774d --- /dev/null +++ b/src/main/java/io/quarkus/gizmo/StringSwitchImpl.java @@ -0,0 +1,144 @@ +package io.quarkus.gizmo; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; +import java.util.function.Consumer; + +import org.objectweb.asm.Label; +import org.objectweb.asm.MethodVisitor; + +class StringSwitchImpl extends AbstractSwitch implements Switch.StringSwitch { + + private final Map>>> hashToCaseBlocks; + + public StringSwitchImpl(ResultHandle value, BytecodeCreatorImpl enclosing) { + super(enclosing); + this.hashToCaseBlocks = new LinkedHashMap<>(); + ResultHandle strHash = invokeVirtualMethod(MethodDescriptor.ofMethod(Object.class, "hashCode", int.class), value); + + Set inputHandles = new HashSet<>(); + inputHandles.add(value); + inputHandles.add(strHash); + + operations.add(new Operation() { + + @Override + void writeBytecode(MethodVisitor methodVisitor) { + Map hashToLabel = new HashMap<>(); + List lookupBlocks = new ArrayList<>(); + Map caseBlocks = new LinkedHashMap<>(); + BytecodeCreatorImpl defaultBlock = new BytecodeCreatorImpl(StringSwitchImpl.this); + if (defaultBlockConsumer != null) { + defaultBlockConsumer.accept(defaultBlock); + } + + // Initialize the case blocks and lookup blocks + for (Entry>>> hashEntry : hashToCaseBlocks.entrySet()) { + BytecodeCreatorImpl lookupBlock = new BytecodeCreatorImpl(StringSwitchImpl.this); + for (Entry> caseEntry : hashEntry.getValue()) { + BytecodeCreatorImpl caseBlock = new BytecodeCreatorImpl(StringSwitchImpl.this); + Consumer blockConsumer = caseEntry.getValue(); + blockConsumer.accept(caseBlock); + if (blockConsumer != EMPTY_BLOCK && !fallThrough) { + caseBlock.breakScope(StringSwitchImpl.this); + } + caseBlock.findActiveResultHandles(inputHandles); + caseBlocks.put(caseEntry.getKey(), caseBlock); + BytecodeCreatorImpl isEqual = (BytecodeCreatorImpl) lookupBlock + .ifTrue(Gizmo.equals(lookupBlock, lookupBlock.load(caseEntry.getKey()), value)).trueBranch(); + isEqual.jumpTo(caseBlock); + } + hashToLabel.put(hashEntry.getKey(), lookupBlock.getTop()); + lookupBlock.findActiveResultHandles(inputHandles); + lookupBlocks.add(lookupBlock); + } + + // Load the hashCode of the tested value + loadResultHandle(methodVisitor, strHash, StringSwitchImpl.this, "I"); + + // The lookupswitch keys must be sorted in increasing numerical order + int[] keys = hashToCaseBlocks.keySet().stream().mapToInt(e -> e.intValue()).sorted().toArray(); + Label[] labels = new Label[keys.length]; + for (int i = 0; i < keys.length; i++) { + labels[i] = hashToLabel.get(keys[i]); + } + methodVisitor.visitLookupSwitchInsn(defaultBlock.getTop(), keys, labels); + + // Write the lookup blocks + for (BytecodeCreatorImpl lookupBlock : lookupBlocks) { + lookupBlock.writeOperations(methodVisitor); + } + + // Write the case blocks + for (BytecodeCreatorImpl caseBlock : caseBlocks.values()) { + caseBlock.writeOperations(methodVisitor); + } + + // Write the default block + defaultBlock.writeOperations(methodVisitor); + } + + @Override + ResultHandle getTopResultHandle() { + return null; + } + + @Override + ResultHandle getOutgoingResultHandle() { + return null; + } + + @Override + Set getInputResultHandles() { + return inputHandles; + } + + }); + } + + @Override + public void caseOf(String value, Consumer caseBlockConsumer) { + Objects.requireNonNull(value); + Objects.requireNonNull(caseBlockConsumer); + addCaseBlock(value, caseBlockConsumer); + } + + @Override + public void caseOf(List values, Consumer caseBlockConsumer) { + Objects.requireNonNull(values); + Objects.requireNonNull(caseBlockConsumer); + for (Iterator it = values.iterator(); it.hasNext();) { + String s = it.next(); + if (it.hasNext()) { + addCaseBlock(s, EMPTY_BLOCK); + } else { + addCaseBlock(s, caseBlockConsumer); + } + } + } + + private void addCaseBlock(String value, Consumer blockConsumer) { + int hashCode = value.hashCode(); + List>> caseBlocks = hashToCaseBlocks.get(hashCode); + if (caseBlocks == null) { + caseBlocks = new ArrayList<>(); + hashToCaseBlocks.put(hashCode, caseBlocks); + } else { + for (Entry> e : caseBlocks) { + if (e.getKey().equals(value)) { + throw new IllegalArgumentException("A case block for the string value " + value + " already exists"); + } + } + } + caseBlocks.add(Map.entry(value, blockConsumer)); + } + +} diff --git a/src/main/java/io/quarkus/gizmo/Switch.java b/src/main/java/io/quarkus/gizmo/Switch.java new file mode 100644 index 0000000..b861dcd --- /dev/null +++ b/src/main/java/io/quarkus/gizmo/Switch.java @@ -0,0 +1,107 @@ +package io.quarkus.gizmo; + +import java.util.List; +import java.util.function.Consumer; + +/** + * A switch statement. + *

+ * This construct is not thread-safe and should not be re-used. + * + * @param Constant type + */ +public interface Switch { + + /** + * Enables fall through. + *

+ * By default, the fall through is disabled. A case block is treated as a switch rule block; i.e. it's not necessary to add + * the break statement to prevent the fall through. However, if fall through is enabled then a case block is treated as a + * labeled statement group; i.e. it's necessary to add the break statement to prevent the fall through. + *

+ * For example, if fall through is disabled then: + * + *

+     * 
+     * StringSwitch s = method.stringSwitch(val);
+       s.caseOf(List.of("boom", "foo"), bc -> {...});
+     * 
+     * 
+ * + * is an equivalent of: + * + *
+     * switch (val) {
+     *     case "boom", "foo" -> // statements provided by the consumer
+     * }
+     * 
+ * + * However, if fall though is enabled then: + * + *
+     * 
+     * StringSwitch s = method.stringSwitch(val);
+     * s.fallThrough();
+     * s.caseOf(List.of("boom", "foo"), bc -> {...});
+     * 
+     * 
+ * + * is an equivalent of: + * + *
+     * switch (val) {
+     *     case "val1":
+     *     case "val2":
+     *         // statements provided by the consumer
+     * }
+     * 
+ */ + void fallThrough(); + + /** + * Adds a case block. + * + * @param value The value for the case label + * @param caseBlockConsumer The consumer used to define the case block + * @throws IllegalArgumentException If a case block for the specified value was already added + */ + void caseOf(T value, Consumer caseBlockConsumer); + + /** + * Adds multiple case labels for a single block. + * + * @param values + * @param caseBlockConsumer + * @throws IllegalArgumentException If a case block for the specified value was already added + */ + void caseOf(List values, Consumer caseBlockConsumer); + + /** + * Adds the default block. + * + * @param defatultBlockConsumer + */ + void defaultCase(Consumer defatultBlockConsumer); + + /** + * Writes bytecode into the provided {@link BytecodeCreator} to make it exit the + * switch, effectively issuing a Java 'break' statement. + * + * @param creator + * @see #fallThrough() + */ + void doBreak(BytecodeCreator creator); + + /** + * A switch for {@link String}. + */ + interface StringSwitch extends Switch { + } + + /** + * A switch for {@link Enum}. + */ + interface EnumSwitch> extends Switch { + } + +} diff --git a/src/test/java/io/quarkus/gizmo/SwitchTest.java b/src/test/java/io/quarkus/gizmo/SwitchTest.java new file mode 100644 index 0000000..aa65aa7 --- /dev/null +++ b/src/test/java/io/quarkus/gizmo/SwitchTest.java @@ -0,0 +1,382 @@ +/* + * Copyright 2022 Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.quarkus.gizmo; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; + +import java.util.List; +import java.util.function.Function; + +import org.junit.Test; + +import io.quarkus.gizmo.Switch.EnumSwitch; +import io.quarkus.gizmo.Switch.StringSwitch; + +public class SwitchTest { + + @SuppressWarnings("unchecked") + @Test + public void testStringSwitch() throws InstantiationException, IllegalAccessException, ClassNotFoundException { + TestClassLoader cl = new TestClassLoader(getClass().getClassLoader()); + try (ClassCreator creator = ClassCreator.builder().classOutput(cl).className("com.MyTest").interfaces(Function.class) + .build()) { + MethodCreator method = creator.getMethodCreator("apply", Object.class, Object.class); + AssignableResultHandle ret = method.createVariable(String.class); + // String ret; + // switch(arg) { + // case "boom", "foo" -> ret = "fooo"; + // case "bar" -> ret = "barr"; + // case "baz" -> ret = "bazz"; + // default -> ret = null; + // } + // return ret; + StringSwitch s = method.stringSwitch(method.getMethodParam(0)); + s.caseOf(List.of("boom", "foo"), bc -> { + bc.assign(ret, bc.load("foooboom")); + }); + s.caseOf("bar", bc -> { + bc.assign(ret, bc.load("barr")); + }); + s.caseOf("baz", bc -> { + bc.assign(ret, bc.load("bazz")); + }); + s.defaultCase(bc -> bc.assign(ret, bc.loadNull())); + + method.returnValue(ret); + } + Function myInterface = (Function) cl.loadClass("com.MyTest").newInstance(); + assertEquals("foooboom", myInterface.apply("boom")); + assertEquals("foooboom", myInterface.apply("foo")); + assertEquals("barr", myInterface.apply("bar")); + assertEquals("bazz", myInterface.apply("baz")); + assertNull(myInterface.apply("unknown")); + } + + @SuppressWarnings("unchecked") + @Test + public void testStringSwitchFallThrough() throws InstantiationException, IllegalAccessException, ClassNotFoundException { + TestClassLoader cl = new TestClassLoader(getClass().getClassLoader()); + try (ClassCreator creator = ClassCreator.builder().classOutput(cl).className("com.MyTest").interfaces(Function.class) + .build()) { + MethodCreator method = creator.getMethodCreator("apply", Object.class, Object.class); + AssignableResultHandle ret = method.createVariable(String.class); + // String ret; + // switch(arg) { + // case "boom": + // case "foo": + // ret = "fooo"; + // break; + // case "bar": + // ret = "barr" + // case "baz" + // ret = "bazz"; + // break; + // default: + // ret = null; + // } + // return ret; + StringSwitch s = method.stringSwitch(method.getMethodParam(0)); + s.fallThrough(); + s.caseOf(List.of("boom", "foo"), bc -> { + bc.assign(ret, bc.load("fooo")); + s.doBreak(bc); + }); + s.caseOf("bar", bc -> { + bc.assign(ret, bc.load("barr")); + }); + s.caseOf("baz", bc -> { + bc.assign(ret, bc.load("bazz")); + s.doBreak(bc); + }); + s.defaultCase(bc -> bc.assign(ret, bc.loadNull())); + + method.returnValue(ret); + } + Function myInterface = (Function) cl.loadClass("com.MyTest").newInstance(); + assertEquals("fooo", myInterface.apply("boom")); + assertEquals("fooo", myInterface.apply("foo")); + assertEquals("bazz", myInterface.apply("bar")); + assertEquals("bazz", myInterface.apply("baz")); + assertNull(myInterface.apply("unknown")); + } + + @SuppressWarnings("unchecked") + @Test + public void testStringSwitchWithHashCollision() + throws InstantiationException, IllegalAccessException, ClassNotFoundException { + // Test that a couple of string literals that share the same hash code + assertEquals("Aa".hashCode(), "BB".hashCode()); + + TestClassLoader cl = new TestClassLoader(getClass().getClassLoader()); + try (ClassCreator creator = ClassCreator.builder().classOutput(cl).className("com.MyTest").interfaces(Function.class) + .build()) { + MethodCreator method = creator.getMethodCreator("apply", Object.class, Object.class); + AssignableResultHandle ret = method.createVariable(String.class); + // String ret; + // switch(arg) { + // case "Aa": + // ret = "aa"; + // break; + // case "BB": + // ret = "bb" + // break; + // default: + // ret = null; + // } + // return ret; + StringSwitch s = method.stringSwitch(method.getMethodParam(0)); + s.fallThrough(); + s.caseOf("Aa", bc -> { + bc.assign(ret, bc.load("aa")); + s.doBreak(bc); + }); + s.caseOf("BB", bc -> { + bc.assign(ret, bc.load("bb")); + s.doBreak(bc); + }); + s.defaultCase(bc -> bc.assign(ret, bc.loadNull())); + + method.returnValue(ret); + } + Function myInterface = (Function) cl.loadClass("com.MyTest").newInstance(); + assertEquals("aa", myInterface.apply("Aa")); + assertEquals("bb", myInterface.apply("BB")); + assertNull(myInterface.apply("unknown")); + } + + @SuppressWarnings("unchecked") + @Test + public void testEmptyStringSwitch() throws InstantiationException, IllegalAccessException, ClassNotFoundException { + TestClassLoader cl = new TestClassLoader(getClass().getClassLoader()); + try (ClassCreator creator = ClassCreator.builder().classOutput(cl).className("com.MyTest").interfaces(Function.class) + .build()) { + MethodCreator method = creator.getMethodCreator("apply", Object.class, Object.class); + AssignableResultHandle ret = method.createVariable(String.class); + method.assign(ret, method.loadNull()); + method.stringSwitch(method.getMethodParam(0)); + method.returnValue(ret); + } + Function myInterface = (Function) cl.loadClass("com.MyTest").newInstance(); + assertNull(myInterface.apply("foo")); + } + + @Test + public void testStringSwitchDuplicateCase() throws InstantiationException, IllegalAccessException, ClassNotFoundException { + TestClassLoader cl = new TestClassLoader(getClass().getClassLoader()); + try (ClassCreator creator = ClassCreator.builder().classOutput(cl).className("com.MyTest").interfaces(Function.class) + .build()) { + MethodCreator method = creator.getMethodCreator("apply", Object.class, Object.class); + AssignableResultHandle ret = method.createVariable(String.class); + method.assign(ret, method.loadNull()); + StringSwitch s = method.stringSwitch(method.getMethodParam(0)); + try { + s.caseOf("foo", bc -> { + }); + s.caseOf("foo", bc -> { + }); + fail(); + } catch (IllegalArgumentException expected) { + } + try { + s.caseOf(List.of("foo"), bc -> { + }); + fail(); + } catch (IllegalArgumentException expected) { + } + try { + s.caseOf(List.of("bar", "baz", "bar"), bc -> { + }); + fail(); + } catch (IllegalArgumentException expected) { + } + method.returnValue(ret); + } + } + + @SuppressWarnings("unchecked") + @Test + public void testStringSwitchReturn() throws InstantiationException, IllegalAccessException, ClassNotFoundException { + TestClassLoader cl = new TestClassLoader(getClass().getClassLoader()); + try (ClassCreator creator = ClassCreator.builder().classOutput(cl).className("com.MyTest").interfaces(Function.class) + .build()) { + MethodCreator method = creator.getMethodCreator("apply", Object.class, Object.class); + // switch(arg) { + // case "boom": + // case "foo": + // return "fooo"; + // case "bar": + // return "barr" + // default: + // return null; + // } + StringSwitch s = method.stringSwitch(method.getMethodParam(0)); + s.fallThrough(); + s.caseOf(List.of("boom", "foo"), bc -> { + bc.returnValue(bc.load("fooo")); + }); + s.caseOf("bar", bc -> { + bc.returnValue(bc.load("barr")); + }); + s.defaultCase(bc -> bc.returnNull()); + } + Function myInterface = (Function) cl.loadClass("com.MyTest").newInstance(); + assertEquals("fooo", myInterface.apply("foo")); + assertEquals("barr", myInterface.apply("bar")); + assertNull(myInterface.apply("unknown")); + } + + @SuppressWarnings("unchecked") + @Test + public void testEnumSwitch() throws InstantiationException, IllegalAccessException, ClassNotFoundException { + TestClassLoader cl = new TestClassLoader(getClass().getClassLoader()); + try (ClassCreator creator = ClassCreator.builder().classOutput(cl).className("com.MyTest").interfaces(Function.class) + .build()) { + MethodCreator method = creator.getMethodCreator("apply", Object.class, Object.class); + // switch(status) { + // case ON, OFF -> return status.toString(); + // case UNKNOWN -> return "?"; + // default: -> return null; + // } + EnumSwitch s = method.enumSwitch(method.getMethodParam(0), Status.class); + s.caseOf(List.of(Status.ON, Status.OFF), bc -> { + bc.returnValue(Gizmo.toString(bc, method.getMethodParam(0))); + }); + s.caseOf(Status.UNKNOWN, bc -> { + bc.returnValue(bc.load("?")); + }); + s.defaultCase(bc -> bc.returnNull()); + } + Function myInterface = (Function) cl.loadClass("com.MyTest").newInstance(); + assertEquals("ON", myInterface.apply(Status.ON)); + assertEquals("OFF", myInterface.apply(Status.OFF)); + assertEquals("?", myInterface.apply(Status.UNKNOWN)); + } + + @SuppressWarnings("unchecked") + @Test + public void testEnumSwitchFallThrough() throws InstantiationException, IllegalAccessException, ClassNotFoundException { + TestClassLoader cl = new TestClassLoader(getClass().getClassLoader()); + try (ClassCreator creator = ClassCreator.builder().classOutput(cl).className("com.MyTest").interfaces(Function.class) + .build()) { + MethodCreator method = creator.getMethodCreator("apply", Object.class, Object.class); + AssignableResultHandle ret = method.createVariable(String.class); + // String ret; + // switch(status) { + // case ON: + // ret = "on"; + // case OFF: + // ret = "off"; + // default: + // ret = "??"; + // } + EnumSwitch s = method.enumSwitch(method.getMethodParam(0), Status.class); + s.fallThrough(); + s.caseOf(Status.ON, bc -> bc.assign(ret, bc.load("on"))); + s.caseOf(Status.OFF, bc -> bc.assign(ret, bc.load("off"))); + s.defaultCase(bc -> bc.assign(ret, bc.load("??"))); + method.returnValue(ret); + } + Function myInterface = (Function) cl.loadClass("com.MyTest").newInstance(); + assertEquals("??", myInterface.apply(Status.ON)); + assertEquals("??", myInterface.apply(Status.OFF)); + assertEquals("??", myInterface.apply(Status.UNKNOWN)); + } + + @SuppressWarnings("unchecked") + @Test + public void testEnumSwitchMissingConstant() throws InstantiationException, IllegalAccessException, ClassNotFoundException { + TestClassLoader cl = new TestClassLoader(getClass().getClassLoader()); + try (ClassCreator creator = ClassCreator.builder().classOutput(cl).className("com.MyTest").interfaces(Function.class) + .build()) { + MethodCreator method = creator.getMethodCreator("apply", Object.class, Object.class); + // switch(status) { + // case OFF: + // return status.toString(); + // case UNKNOWN: + // return "?"; + // default: + // return null; + // } + EnumSwitch s = method.enumSwitch(method.getMethodParam(0), Status.class); + s.caseOf(Status.OFF, bc -> { + bc.returnValue(bc.load("offf")); + }); + s.caseOf(Status.UNKNOWN, bc -> { + bc.returnValue(bc.load("?")); + }); + s.defaultCase(bc -> bc.returnNull()); + } + Function myInterface = (Function) cl.loadClass("com.MyTest").newInstance(); + assertEquals("offf", myInterface.apply(Status.OFF)); + assertEquals("?", myInterface.apply(Status.UNKNOWN)); + assertNull(myInterface.apply(Status.ON)); + } + + @SuppressWarnings("unchecked") + @Test + public void testEmptyEnumSwitch() throws InstantiationException, IllegalAccessException, ClassNotFoundException { + TestClassLoader cl = new TestClassLoader(getClass().getClassLoader()); + try (ClassCreator creator = ClassCreator.builder().classOutput(cl).className("com.MyTest").interfaces(Function.class) + .build()) { + MethodCreator method = creator.getMethodCreator("apply", Object.class, Object.class); + method.enumSwitch(method.getMethodParam(0), Status.class); + method.returnNull(); + } + Function myInterface = (Function) cl.loadClass("com.MyTest").newInstance(); + assertNull(myInterface.apply(Status.ON)); + } + + @Test + public void testEnumSwitchDuplicateCase() throws InstantiationException, IllegalAccessException, ClassNotFoundException { + TestClassLoader cl = new TestClassLoader(getClass().getClassLoader()); + try (ClassCreator creator = ClassCreator.builder().classOutput(cl).className("com.MyTest").interfaces(Function.class) + .build()) { + MethodCreator method = creator.getMethodCreator("apply", Object.class, Object.class); + EnumSwitch s = method.enumSwitch(method.getMethodParam(0), Status.class); + try { + s.caseOf(Status.ON, bc -> { + }); + s.caseOf(Status.ON, bc -> { + }); + fail(); + } catch (IllegalArgumentException expected) { + } + try { + s.caseOf(List.of(Status.ON), bc -> { + }); + fail(); + } catch (IllegalArgumentException expected) { + } + try { + s.caseOf(List.of(Status.ON, Status.OFF, Status.ON), bc -> { + }); + fail(); + } catch (IllegalArgumentException expected) { + } + method.returnNull(); + } + } + + public enum Status { + ON, + OFF, + UNKNOWN + } + +}