# Chapter 4: Toy Optimisations

As we saw in the previous chapter, the IR generated from the input program has many
opportunities for optimisation. In this chapter, we'll implement three optimisations:

1. Removing redundant reshapes
2. Reshaping constants at compile time
3. Eliminating operations whose results are not used

Let's take another look at our example input:

In [1]:
from xdsl.printer import Printer
from toy.compiler import parse_toy

# Toy source for the running example: a 2x3 literal `a`, a flat 6-element
# literal `b` that is re-declared with shape <2, 3> as `c`, and the
# element-wise sum of `a` and `c`.
example = """
def main() {
  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
  var b<6> = [1, 2, 3, 4, 5, 6];
  var c<2, 3> = b;
  var d = a + c;
  print(d);
}
"""

# Lower the Toy source to IR; later cells rewrite this `module` in place.
module = parse_toy(example)

# Dump the freshly generated IR, followed by a newline.
Printer().print_op(module)
print()

builtin.module {
  "toy.func"() ({
    %0 = "toy.constant"() {"value" = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
    %1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>
    %2 = "toy.constant"() {"value" = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
    %3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<6xf64>
    %4 = "toy.reshape"(%3) : (tensor<6xf64>) -> tensor<2x3xf64>
    %5 = "toy.add"(%1, %4) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64>
    "toy.print"(%5) : (tensor<2x3xf64>) -> ()
    "toy.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> ()} : () -> ()
}


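## Redundant Reshapes

In the IR above, `%3` reshapes the constant `%2`, and the result is then immediately reshaped again into `%4`. A reshape of a reshape can be replaced by a single reshape of the original input to the final type: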

In [2]:
from typing import cast
from xdsl.ir import OpResult
from xdsl.pattern_rewriter import (
    op_type_rewrite_pattern,
    RewritePattern,
    PatternRewriter,
    PatternRewriteWalker,
)


from toy.dialects import toy


class ReshapeReshapeOptPattern(RewritePattern):
    """
    Collapse a chain of two reshapes into a single reshape.

    Reshape(Reshape(x)) = Reshape(x)

    A reshape changes only the shape of its input, not the flat order of its
    elements. Reshaping twice is therefore the same as reshaping once, directly
    to the final type. The matched (outer) reshape is replaced by a new reshape
    that reads the inner reshape's input and keeps the outer result type.

    The inner reshape is not removed here. If nothing else uses it, it becomes
    dead code for a later dead-code-elimination pass.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: toy.ReshapeOp, rewriter: PatternRewriter):
        """
        Reshape(Reshape(x)) = Reshape(x)

        `op` is the outer reshape. It is rewritten only when its input is the
        result of another `toy.ReshapeOp`. In every other case the method
        returns without touching the IR.
        """
        # Look at the input of the current reshape.
        reshape_input = op.arg
        if not isinstance(reshape_input, OpResult):
            # Input was not produced by an operation, could be a function argument
            return

        reshape_input_op = reshape_input.op
        if not isinstance(reshape_input_op, toy.ReshapeOp):
            # Input defined by another reshape? If not, no match.
            return

        # Keep the outer reshape's result type: that is the shape users of
        # `op` see. `cast` only informs the type checker; it does no runtime check.
        t = cast(toy.TensorTypeF64, op.res.type)
        # Reshape the inner reshape's input straight to the final type.
        new_op = toy.ReshapeOp.from_input_and_type(reshape_input_op.arg, t)
        rewriter.replace_matched_op(new_op)


# Walk the whole module and apply ReshapeReshapeOptPattern wherever it
# matches. The walker rewrites `module` in place, so printing it afterwards
# shows the updated IR.
PatternRewriteWalker(
    ReshapeReshapeOptPattern()
).rewrite_module(module)
Printer().print_op(module)

builtin.module {
  "toy.func"() ({
    %0 = "toy.constant"() {"value" = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
    %1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>
    %2 = "toy.constant"() {"value" = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
    %3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<6xf64>
    %4 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<2x3xf64>
    %5 = "toy.add"(%1, %4) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64>
    "toy.print"(%5) : (tensor<2x3xf64>) -> ()
    "toy.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> ()} : () -> ()
}

This looks very similar to what we had before, but is subtly different. Importantly,
the reshape that assigns to %4 now takes %2 as input, instead of %3. %3 is no longer
used, and because a reshape has no observable side effects, we can remove it and avoid
doing the work altogether. xDSL's dead-code elimination pattern, `RemoveUnusedOperations`,
can clean this up:

In [3]:
from xdsl.transforms.dead_code_elimination import RemoveUnusedOperations

# Dead-code elimination: drop operations whose results are never used (here,
# the reshape that was bypassed by the previous rewrite).
PatternRewriteWalker(
    RemoveUnusedOperations()
).rewrite_module(module)

# Show the module again after the cleanup.
Printer().print_op(module)

builtin.module {
  "toy.func"() ({
    %0 = "toy.constant"() {"value" = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
    %1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>
    %2 = "toy.constant"() {"value" = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
    %3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<2x3xf64>
    %4 = "toy.add"(%1, %3) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64>
    "toy.print"(%4) : (tensor<2x3xf64>) -> ()
    "toy.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> ()} : () -> ()
}

## Fold Constant Reshaping

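One more opportunity for optimisation is to reshape constants at compile time instead
of at runtime: a reshape whose input is a `toy.constant` can be replaced by a new constant
that already has the target shape. We can do this with another custom `RewritePattern`: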

In [4]:
from xdsl.dialects.builtin import AnyFloatAttr, ArrayAttr, DenseIntOrFPElementsAttr
from xdsl.utils.hints import isa


class FoldConstantReshapeOptPattern(RewritePattern):
    """
    Fold a reshape of a constant into a new constant of the reshaped type.

    Reshape(Constant(v)) = Constant(v, with the reshape's result type)

    The constant's elements are reused in the same flat order; only the tensor
    type attached to them changes. The original constant is left in place and
    becomes dead code if the reshape was its only user.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: toy.ReshapeOp, rewriter: PatternRewriter):
        """
        Reshaping a constant can be done at compile time.

        `op` is rewritten only if all of these hold:
        - its input is produced by a `toy.ConstantOp`;
        - that constant's elements are float attributes;
        - `op`'s result type is a `toy.TensorTypeF64`.

        In every other case the method returns without touching the IR.
        """
        # Look at the input of the current reshape.
        reshape_input = op.arg
        if not isinstance(reshape_input, OpResult):
            # Input was not produced by an operation, could be a function argument
            return

        reshape_input_op = reshape_input.op
        if not isinstance(reshape_input_op, toy.ConstantOp):
            # Input defined by a constant? If not, no match.
            return

        # These checks narrow the types for the type checker. They are early
        # returns rather than `assert`s, so an unexpected constant is simply not
        # folded. An `assert` would crash the rewrite here, or be stripped under
        # `python -O` and let unchecked code run.
        result_type = op.res.type
        if not isa(result_type, toy.TensorTypeF64):
            return
        elements = reshape_input_op.value.data
        if not isa(elements, ArrayAttr[AnyFloatAttr]):
            return

        # Same flat list of elements, re-typed with the reshape's result shape.
        new_value = DenseIntOrFPElementsAttr.from_list(
            type=result_type, data=elements.data
        )
        new_op = toy.ConstantOp(new_value)
        rewriter.replace_matched_op(new_op)


# Replace every reshape of a constant with an already-reshaped constant
# (rewriting `module` in place), then print the result.
PatternRewriteWalker(
    FoldConstantReshapeOptPattern()
).rewrite_module(module)
Printer().print_op(module)

builtin.module {
  "toy.func"() ({
    %0 = "toy.constant"() {"value" = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
    %1 = "toy.constant"() {"value" = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
    %2 = "toy.constant"() {"value" = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
    %3 = "toy.constant"() {"value" = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
    %4 = "toy.add"(%1, %3) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64>
    "toy.print"(%4) : (tensor<2x3xf64>) -> ()
    "toy.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> ()} : () -> ()
}

In [5]:
# Folding left the original, un-reshaped constants without any users.
# Run dead-code elimination once more to drop them, then print the final IR.
PatternRewriteWalker(
    RemoveUnusedOperations()
).rewrite_module(module)
Printer().print_op(module)

builtin.module {
  "toy.func"() ({
    %0 = "toy.constant"() {"value" = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
    %1 = "toy.constant"() {"value" = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
    %2 = "toy.add"(%0, %1) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64>
    "toy.print"(%2) : (tensor<2x3xf64>) -> ()
    "toy.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> ()} : () -> ()
}

Now that we've applied all the optimisations we can at this level of abstraction, let's
go one level lower, towards RISC-V.