# Toy

This is an xDSL version of the Toy compiler, as described in the [MLIR tutorial](https://mlir.llvm.org/docs/Tutorials/Toy/)

In [4]:
import sys
# To import local version of xdsl
sys.path.insert(0, '../src')

## Chapter 1: Toy Language and AST

This tutorial will be illustrated with a toy language that we’ll call “Toy” (naming is hard…). Toy is a tensor-based language that allows you to define functions, perform some math computation, and print results.

In [5]:
example_0 = '''
# User defined generic function that operates on unknown shaped arguments.
def multiply_transpose(a, b) {
  return transpose(a) * transpose(b);
}

def main() {
  # Define a variable `a` with shape <2, 3>, initialized with the literal value.
  var a = [[1, 2, 3], [4, 5, 6]];
  var b<2, 3> = [1, 2, 3, 4, 5, 6];

  # This call will specialize `multiply_transpose` with <2, 3> for both
  # arguments and deduce a return type of <3, 2> in initialization of `c`.
  var c = multiply_transpose(a, b);

  # A second call to `multiply_transpose` with <2, 3> for both arguments will
  # reuse the previously specialized and inferred version and return <3, 2>.
  var d = multiply_transpose(b, a);

  # A new call with <3, 2> (instead of <2, 3>) for both dimensions will
  # trigger another specialization of `multiply_transpose`.
  var e = multiply_transpose(b, c);

  # Finally, calling into `multiply_transpose` with incompatible shape will
  # trigger a shape inference error.
  var f = multiply_transpose(transpose(a), c);
}
'''

In [6]:
from toy.Tokenizer import tokenize, Token

[f'{t.line},{t.col}: {t.text} {t.__class__.__name__}' for t in tokenize('example.toy', example_0)]
''

''

In [7]:
from toy.AST import *

In [8]:
print(VarDeclExprAST(None, 'a', VarType(shape=[1,2]), NumberExprAST(None, 1)).dump())

VarDeclExprAST
  name: a
  type: VarType(shape=[1, 2])
  expr: NumberExprAST
    val: 1


In [9]:
from toy.Parser import Parser


parser = Parser('example_0', example_0)

print(parser.parseModule().dump())

ModuleAST
  functions: [
    FunctionAST
      proto: PrototypeAST
        name: multiply_transpose
        args: [
          VariableExprAST
            name: a
          VariableExprAST
            name: b
        ]
      body: [
        ReturnExprAST
          expr: BinaryExprAST
            op: *
            lhs: CallExprAST
              callee: transpose
              args: [
                VariableExprAST
                  name: a
              ]
            rhs: CallExprAST
              callee: transpose
              args: [
                VariableExprAST
                  name: b
              ]
      ]
    FunctionAST
      proto: PrototypeAST
        name: main
        args: [
        ]
      body: [
        VarDeclExprAST
          name: a
          type: VarType(shape=[])
          expr: LiteralExprAST
            values: [
              LiteralExprAST
                values: [
                  NumberExprAST
                    val: 1.0
                  NumberExprAST

In [10]:
isinstance(1, int)

True

Unlike the MLIR tutorial, we'll leverage one of the parsers in the rich Python ecosystem to parse the syntax of the language for us, as opposed to writing a custom parser.

In [8]:
# To define the grammar
from arpeggio import Optional, ZeroOrMore, OneOrMore, EOF
from arpeggio import RegExMatch as _
from arpeggio import StrMatch

# To define the parser
from arpeggio import ParserPython

# To define the AST traversal
from arpeggio import PTNodeVisitor

In [9]:
def return_stmt(): return ['return', ZeroOrMore(expr)]
def number():
    return _(r'\d*\.\d*|\d+')
def tensor_literal():
    return ['[', literal_list, ']'], number
def literal_list():
    return tensor_literal, [tensor_literal, ',', literal_list]
def paren_expr():
    return ['(', expr, ')']
def print_expr():
    return ['print', '(', expr, OneOrMore(expr), ')']
def identifier_expr():
    return [identifier, Optional(['(', expr, ')'])]
def primary():
    return print_expr, identifier_expr, number, paren_expr, tensor_literal
def op():
    return '+', '-', '*'
def bin_op_rhs():
    return OneOrMore([op, expr])
def expr():       
    return [primary, Optional(bin_op_rhs)]
def _type():
    return ['<', number, OneOrMore([',', number]), '>']
def decl():
    return ['var', identifier, Optional(_type), '=', expr]

def block():
    return ['{', ZeroOrMore(expr, ';'), '}']
def block_expr():
    return decl, 'return', expr
def prototype():
    return ['def', identifier, '(', ZeroOrMore(identifier, ','), ')']
def function():
    return [prototype, block]

def number_expression():
    pass
def identifier():
    return _(r'[\w]+')

def module():
    return [ZeroOrMore(function), EOF]

# def factor():     return Optional(["+","-"]), [number, ("(", expression, ")")]
# def term():       return factor, ZeroOrMore(["*","/"], factor)
# def expression(): return term, ZeroOrMore(["+", "-"], term)
# def calc():       return OneOrMore(expression), EOF

In [10]:
parser = ParserPython(module, debug=True)

New rule: identifier -> RegExMatch
Rule identifier founded in cache.
New rule: prototype -> OrderedChoice
Rule expr founded in cache.
CrossRef usage: expr
Rule expr founded in cache.
CrossRef usage: expr
New rule: print_expr -> OrderedChoice
Rule identifier founded in cache.
Rule expr founded in cache.
CrossRef usage: expr
New rule: identifier_expr -> OrderedChoice
New rule: number -> RegExMatch
Rule expr founded in cache.
CrossRef usage: expr
New rule: paren_expr -> OrderedChoice
Rule tensor_literal founded in cache.
CrossRef usage: tensor_literal
Rule tensor_literal founded in cache.
CrossRef usage: tensor_literal
Rule literal_list founded in cache.
CrossRef usage: literal_list
New rule: literal_list -> Sequence
Rule number founded in cache.
New rule: tensor_literal -> Sequence
New rule: primary -> Sequence
New rule: op -> Sequence
Rule expr founded in cache.
CrossRef usage: expr
New rule: bin_op_rhs -> OneOrMore
New rule: expr -> OrderedChoice
New rule: block -> OrderedChoice
New ru

In [288]:
bla = parser.parse(example_0)

AttributeError: 'Parser' object has no attribute 'parse'

In [39]:
bla

[  'return' [0], expression 'hello' [7] ]

In [4]:

class ToyVisitor(PTNodeVisitor):
    def visit_return(self, node, children):
        