In [None]:
import SExprLibrary._

/** Abstract syntax of the imperative expression language.
  *
  * The comment on each case gives the S-expression form that `Parser.parseE`
  * turns into it. `toString` renders an expression back into that form.
  */
sealed abstract class Expr {
  /** Renders this expression as S-expression text via `Printer.print`. */
  override def toString(): String = Printer.print(this)
}

/** A variable reference: a bare symbol `x`. */
case class Var(id: String) extends Expr
/** An integer literal. */
case class Num(i: Int) extends Expr
/** `(:= x e)`: assignment to variable `id`. */
case class Assgn(id: String, e: Expr) extends Expr
/** `(while c e)`: loop with condition `c` and body `e`. */
case class While(c: Expr, e: Expr) extends Expr
/** `(if c t e)`: conditional with then-branch `t` and else-branch `e`. */
case class If(c: Expr, t: Expr, e: Expr) extends Expr
/** `(write e)`: output the value of `e`. */
case class Write(e: Expr) extends Expr
/** `(block e1 ... en)`: a sequence of expressions (possibly empty). */
case class Block(es: List[Expr]) extends Expr

/** `(+ l r)` */
case class Add(l: Expr, r: Expr) extends Expr
/** `(- l r)` */
case class Sub(l: Expr, r: Expr) extends Expr
/** `(* l r)` */
case class Mul(l: Expr, r: Expr) extends Expr
/** `(/ l r)` */
case class Div(l: Expr, r: Expr) extends Expr
/** `(% l r)` */
case class Rem(l: Expr, r: Expr) extends Expr
/** `(<= l r)` */
case class Le(l: Expr, r: Expr) extends Expr

/** `(for x e1 e2 e3)`: a loop over variable `x`.
  * NOTE(review): the exact iteration semantics are defined by the compiler; confirm against the spec.
  */
case class For(x: String, e1: Expr, e2: Expr, e3: Expr) extends Expr


case class ParseException(string: String) extends RuntimeException

/** Converts S-expression text into [[Expr]] trees. */
object Parser {

  /** Reads `str` as an S-expression and converts it to an [[Expr]].
    *
    * @param str   source text
    * @param debug when greater than 0, the parsed expression is printed
    * @throws ParseException if the text cannot be read or is not a valid expression
    */
  def parse(str: String, debug: Int = 0): Expr =
    try {
      val expr = parseE(SExprReader.read(str))
      if (debug > 0) println(s"Parsed expression: $expr")
      expr
    } catch {
      // Reader failures are re-raised under the parser's own exception type.
      case ex: ReadException => throw ParseException(ex.string)
    }

  /** Converts one S-expression.
    *
    * Numbers and symbols are leaves. A list whose head is a symbol is a keyword form.
    * Anything else is rejected.
    */
  def parseE(sexpr: SExpr): Expr = sexpr match {
    case SNum(n)                   => Num(n)
    case SSym(id)                  => Var(id)
    case SList(SSym(head) :: args) => parseForm(head, args, sexpr)
    case _                         => throw ParseException("Cannot parse expression:" + sexpr)
  }

  /** Interprets the list form `(head args...)`.
    *
    * `whole` is the complete form and is used only in the error message.
    */
  private def parseForm(head: String, args: List[SExpr], whole: SExpr): Expr = (head, args) match {
    case (":=", SSym(id) :: rhs :: Nil)    => Assgn(id, parseE(rhs))
    case ("while", cond :: body :: Nil)    => While(parseE(cond), parseE(body))
    case ("if", cond :: thn :: els :: Nil) => If(parseE(cond), parseE(thn), parseE(els))
    case ("write", arg :: Nil)             => Write(parseE(arg))
    case ("block", items)                  => Block(parseEs(items))
    case ("+", lhs :: rhs :: Nil)          => Add(parseE(lhs), parseE(rhs))
    case ("-", lhs :: rhs :: Nil)          => Sub(parseE(lhs), parseE(rhs))
    case ("*", lhs :: rhs :: Nil)          => Mul(parseE(lhs), parseE(rhs))
    case ("/", lhs :: rhs :: Nil)          => Div(parseE(lhs), parseE(rhs))
    case ("%", lhs :: rhs :: Nil)          => Rem(parseE(lhs), parseE(rhs))
    case ("<=", lhs :: rhs :: Nil)         => Le(parseE(lhs), parseE(rhs))
    case ("for", SSym(x) :: a :: b :: body :: Nil) =>
      For(x, parseE(a), parseE(b), parseE(body))
    case _ => throw ParseException("Cannot parse expression:" + whole)
  }

  /** Converts each S-expression in order. Fails on the first one that does not parse. */
  def parseEs(sexprs: List[SExpr]): List[Expr] = sexprs.map(parseE)

}

/** Renders [[Expr]] trees back into S-expressions (the inverse of `Parser.parseE`). */
object Printer {

  /** Text form of `expr`. */
  def print(expr: Expr): String = unparse(expr).toString()

  /** Builds the list form `(head args...)`. */
  private def form(head: String, args: SExpr*): SExpr = SList(SSym(head) :: args.toList)

  /** Converts `expr` into the S-expression that parses back to it. */
  def unparse(expr: Expr): SExpr = expr match {
    case Num(n)               => SNum(n)
    case Var(x)               => SSym(x)
    case Assgn(x, rhs)        => form(":=", SSym(x), unparse(rhs))
    case While(cond, body)    => form("while", unparse(cond), unparse(body))
    case If(cond, thn, els)   => form("if", unparse(cond), unparse(thn), unparse(els))
    case Write(arg)           => form("write", unparse(arg))
    case Block(items)         => SList(SSym("block") :: unparseEs(items))
    case Add(lhs, rhs)        => form("+", unparse(lhs), unparse(rhs))
    case Sub(lhs, rhs)        => form("-", unparse(lhs), unparse(rhs))
    case Mul(lhs, rhs)        => form("*", unparse(lhs), unparse(rhs))
    case Div(lhs, rhs)        => form("/", unparse(lhs), unparse(rhs))
    case Rem(lhs, rhs)        => form("%", unparse(lhs), unparse(rhs))
    case Le(lhs, rhs)         => form("<=", unparse(lhs), unparse(rhs))
    case For(x, a, b, body)   => form("for", SSym(x), unparse(a), unparse(b), unparse(body))
  }

  /** Converts each expression, preserving order. */
  def unparseEs(exprs: List[Expr]): List[SExpr] = exprs.map(unparse)

}

case class InterpException(string: String) extends RuntimeException

// Stack Machine
/** A small stack machine over `Int` values, with a variable store and label-based jumps. */
object Machine {

  sealed abstract class Instr
  /** Push the constant `n`. */
  case class Const(n:Int) extends Instr
  /** Pop v2, then v1; push v1 + v2. */
  case object Plus extends Instr
  /** Pop v2, then v1; push v1 * v2. */
  case object Times extends Instr
  /** Pop v2, then v1; push v1 / v2, then v1 % v2 (remainder on top). Fails if v2 == 0. */
  case object Divrem extends Instr
  /** Pop v2, then v1; push 1 if v1 <= v2, else 0. */
  case object Lessequ extends Instr
  /** Discard the top value. */
  case object Pop extends Instr
  /** Duplicate the top value. */
  case object Dup extends Instr
  /** Exchange the top two values. */
  case object Swap extends Instr
  /** Push the value of variable `x` (0 if it was never stored). */
  case class Load(x:String) extends Instr
  /** Pop a value into variable `x`. */
  case class Store(x:String) extends Instr
  /** Pop a value and print it on its own line. */
  case object Print extends Instr
  /** Branch target; a no-op when executed. */
  case class Label(l:Int) extends Instr
  /** Jump to the first `Label(l)` in the program. */
  case class Branch(l:Int) extends Instr
  /** Pop a value; if it is 0, jump to the first `Label(l)`, otherwise fall through. */
  case class Branchz(l:Int) extends Instr

  type Program = List[Instr]

  type Stack = collection.mutable.Stack[Int]

  // see http://docs.scala-lang.org/overviews/collections/maps.html for details of Map class
  type VarStore = collection.mutable.Map[String,Int]

  /** Runs `prog` from index 0 until control falls off the end.
    *
    * @param prog  the program to execute
    * @param debug 0 = quiet; greater than 0 prints the code and the result;
    *              greater than 1 also traces every step with the resulting stack
    * @return the value popped from the stack when the program finishes
    * @throws InterpException on division by zero, a branch to a missing label,
    *                         or popping from an empty stack
    */
  def exec(prog: Program, debug: Int = 0): Int = {
    val stk: Stack = collection.mutable.Stack[Int]()
    val store: VarStore = collection.mutable.Map[String,Int]()
    // List indexing and length are O(n); copy once so each instruction fetch is O(1).
    val code: IndexedSeq[Instr] = prog.toIndexedSeq
    // Label -> index of its FIRST occurrence. A front-to-back scan for the label
    // would stop at the same place.
    val labelIndex: Map[Int, Int] =
      code.zipWithIndex.foldLeft(Map.empty[Int, Int]) {
        case (acc, (Label(l), i)) if !acc.contains(l) => acc + (l -> i)
        case (acc, _)                                 => acc
      }
    var pc = 0

    /** Pops one value. A malformed program surfaces as InterpException,
      * not as a collection exception.
      */
    def pop(): Int =
      if (stk.isEmpty) throw InterpException("stack underflow") else stk.pop()

    /** Pops v2, then v1, and pushes f(v1, v2). */
    def binop(f: (Int, Int) => Int): Unit = {
      val v2 = pop()
      val v1 = pop()
      stk.push(f(v1, v2))
    }

    /** Index of the first Label(l). The lookup happens lazily, so a missing label
      * only fails if that branch is taken.
      */
    def findLabel(l: Int): Int =
      labelIndex.getOrElse(l, throw InterpException("missing label " + l))

    /** Executes the instruction at pc and returns the next pc. */
    def step(): Int = code(pc) match {
      case Const(i) =>
        stk.push(i)
        pc + 1
      case Plus =>
        binop(_ + _)
        pc + 1
      case Times =>
        binop(_ * _)
        pc + 1
      case Divrem =>
        val v2 = pop()
        val v1 = pop()
        if (v2 == 0) throw InterpException("division by zero")
        stk.push(v1 / v2)
        stk.push(v1 % v2)
        pc + 1
      case Lessequ =>
        binop((v1, v2) => if (v1 <= v2) 1 else 0)
        pc + 1
      case Pop =>
        pop()
        pc + 1
      case Dup =>
        val v = pop()
        stk.push(v)
        stk.push(v)
        pc + 1
      case Swap =>
        val v2 = pop()
        val v1 = pop()
        stk.push(v2)
        stk.push(v1)
        pc + 1
      case Load(x) =>
        // Unassigned variables read as 0.
        stk.push(store.getOrElse(x, 0))
        pc + 1
      case Store(x) =>
        store(x) = pop()
        pc + 1
      case Print =>
        println(pop())
        pc + 1
      case Label(_) =>
        pc + 1
      case Branch(l) =>
        findLabel(l)
      case Branchz(l) =>
        if (pop() == 0) findLabel(l) else pc + 1
    }

    if (debug > 0) println("Machine code:" + prog)

    while (pc < code.length) {
      if (debug > 1) print("" + pc + "*" + code(pc))
      pc = step()
      if (debug > 1) println (":" + stk.mkString(" "))
    }
    val r = pop()

    if (debug > 0) println("Result:" + r)
    r
  }
}

/** Compiles [[Expr]] trees into [[Machine.Program]]s.
  *
  * Invariant kept by every case of `comp`: the generated code leaves exactly one
  * extra value (the expression's result) on top of the stack. The rest of the
  * stack is left as it was found.
  */
object Compile {
  import Machine._

  /** Next unused branch label; reset to 0 by `compile`. */
  var nextLabel: Int = 0

  /** Generates stack-machine code for `e`. Binary operands are emitted right-hand side first. */
  def comp (e:Expr) : Machine.Program = e match {
    case Var(x) => Load(x)::Nil
    case Num(i) => Const(i)::Nil
    // Dup keeps a copy, so an assignment evaluates to the assigned value.
    case Assgn(x,e) => comp(e) ::: Dup::Store(x)::Nil
    case While(c,b) => {
      val topLab = newLabel()
      val bottomLab = newLabel()
      Label(topLab) :: comp(c) :::
      Branchz(bottomLab) :: comp(b) :::  Pop ::  // throw away value of body
      Branch(topLab) :: Label(bottomLab) :: Const(0) :: Nil // overall expression evaluates to 0
    }
    case If(c,t,f) => {
      val falseLab = newLabel()
      val joinLab = newLabel()
      comp(c) ::: Branchz(falseLab) :: comp(t) :::
      Branch(joinLab) :: Label(falseLab) :: comp(f) :::
      Label(joinLab) :: Nil
    }
    // Print consumes one copy; the other remains as the expression's value.
    case Write(e) => comp(e) ::: Dup :: Print :: Nil
    case Block(es) => {
      // Every element's value except the last is discarded; an empty block evaluates to 0.
      def seq(es:List[Expr]) : List[Instr] = es match {
        case Nil => Const(0)::Nil
        case e::Nil => comp(e)
        case e::es => comp(e) ::: Pop :: seq(es)
      }
      seq(es)
    }
    case Add(e1,e2) => comp(e2) ::: comp(e1) ::: Plus::Nil
    // e1 - e2 is computed as e1 + (-1 * e2).
    case Sub(e1,e2) => comp(e2) ::: Const(-1) :: Times :: comp(e1) ::: Plus::Nil
    case Mul(e1,e2) => comp(e2) ::: comp(e1) ::: Times::Nil
    // Divrem leaves quotient then remainder (top); keep the one that is wanted.
    case Div(e1,e2) => comp(e2) ::: comp(e1) ::: Swap::Divrem::Pop::Nil
    case Rem(e1,e2) => comp(e2) ::: comp(e1) ::: Swap::Divrem::Swap::Pop::Nil
    case Le(e1,e2)  => comp(e2) ::: comp(e1) ::: Swap::Lessequ::Nil
    case For(x, e1, e2, e3) => {
      // (for x e1 e2 e3).
      // NOTE(review): the semantics are assumed here; confirm them against the language spec:
      //  - e1 and e2 are each evaluated once, in that order, before x is assigned;
      //  - x := e1, then e3 runs while x <= e2, with x incremented by 1 after each pass;
      //  - like `while`, the whole loop evaluates to 0.
      // The bound stays on the stack for the whole loop. x is re-read on every pass,
      // so assignments to x inside e3 affect iteration. Plus wraps on overflow, so a
      // bound of Int.MaxValue never terminates.
      val topLab = newLabel()
      val bottomLab = newLabel()
      comp(e1) ::: comp(e2) ::: Swap :: Store(x) ::     // x := e1; stack: bound
      Label(topLab) ::
      Dup :: Load(x) :: Swap :: Lessequ ::              // stack: bound, (x <= bound)
      Branchz(bottomLab) ::
      comp(e3) ::: Pop ::                               // run body, drop its value
      Load(x) :: Const(1) :: Plus :: Store(x) ::        // x := x + 1
      Branch(topLab) ::
      Label(bottomLab) :: Pop :: Const(0) :: Nil        // drop bound; result is 0
    }
  }

  /** Returns a fresh label number. */
  def newLabel() = {
    val next = nextLabel
    nextLabel = nextLabel + 1
    next
  }

  /** Compiles a whole program, restarting label numbering at 0. */
  def compile(e:Expr) = {
    nextLabel = 0
    comp(e)
  }
}


/** Parse -> compile -> execute pipeline for a program given as text. */
object Process {
  /** Runs the program in `s` and returns its result.
    *
    * Parser and interpreter errors are reported on stdout and then rethrown.
    *
    * @param s     program source text
    * @param debug debug level, passed to both the parser and the machine
    */
  def process(s: String, debug: Int = 0): Int =
    try Machine.exec(Compile.compile(Parser.parse(s, debug)), debug)
    catch {
      case ex: InterpException =>
        println("Interp Error:" + ex.string)
        throw ex
      case ex: ParseException =>
        println("Parser Error:" + ex.string)
        throw ex
    }
}


// The following code may be useful for stand-alone development and
// testing from the command line. (It is not useful when developing
// or testing within WebLab.)
object Imperative {
  import scala.io.Source

  /** Command-line entry point: `Imperative <source-file> [debug-level]`.
    *
    * Reads the whole file, then parses, compiles and runs it.
    */
  def main (argv: Array[String]): Unit = {
    if (argv.isEmpty)
      System.err.println("usage: Imperative <source-file> [debug-level]")
    else {
      val source = Source.fromFile(argv(0))
      // Close the file handle even if reading fails.
      val s = try source.getLines().mkString("\n") finally source.close()
      val d = if (argv.length > 1) argv(1).toInt else 0
      Process.process(s,d)
      ()
    }
  }
}
//
