Skip to content

Commit

Permalink
Some more type checking.
Browse files Browse the repository at this point in the history
  • Loading branch information
zzorn committed Mar 17, 2012
1 parent 4d21210 commit ac74a58
Show file tree
Hide file tree
Showing 17 changed files with 141 additions and 63 deletions.
6 changes: 3 additions & 3 deletions common/assets/testpackage/skycastle/Tree.funlang
Expand Up @@ -4,17 +4,17 @@ module Tree {
import skycastle.utils.Branch
import skycastle.utils.MathUtils.*

fun tree(height = 1): Model = {
fun tree(height = 1): List[Any] = {
module TreeType {
val Birch = 1
val Fir = 2
}

fun leafCalculator(x: Num): Leaf = foo(45*x)
fun leafCalculator(x: Num) = foo(45*x)
val rootHeight = height / 2
val funnyFunc = (x: Any = 3): Any => x^x
val referenceToFunnyFunc = funnyFunc
fun testRef(): Num = skycastle.utils.MathUtils.foo(rootHeight)
fun testRef() = skycastle.utils.MathUtils.foo(rootHeight)
return [
many(referenceToFunnyFunc)
many(x: Any => Branch.branch())
Expand Down
2 changes: 1 addition & 1 deletion common/assets/testpackage/skycastle/utils/Branch.funlang
@@ -1,6 +1,6 @@
module Branch {

fun branch(): Model = {
fun branch(): Num = {
return 3 //Tube()
}
}
9 changes: 5 additions & 4 deletions common/assets/testpackage/skycastle/utils/MathUtils.funlang
Expand Up @@ -2,10 +2,11 @@
module MathUtils {

fun lerp(a, b, t) = a * (1 - t) + b * t
fun lerpVec2(a: Vec2, b: Vec2, t): Vec2 = a * (1 - t) + b * t
fun lerpVec3(a: Vec3, b: Vec3, t): Vec3 = a * (1 - t) + b * t
fun lerpVec2(a: Vec2, b: Vec2, t) = a * (1 - t) + b * t
fun lerpVec3(a: Vec3, b: Vec3, t) = a * (1 - t) + b * t

fun foo(p: Any) = false
fun foo(p: Any): Bool = false
fun boolFunc(): Bool = true

val x = y
val y = 3
Expand All @@ -14,7 +15,7 @@ module MathUtils {
val z = x
}

fun testFun(p) = 3 + 4
fun testFun(p = 3): Any = 3 + 4
val testVal = 2

fun many(f: Any -> Any) {
Expand Down
Expand Up @@ -12,6 +12,7 @@ class BeanFactory {
def typeFor(name: String): TypeDef = {
name match {
case NumType.name => NumType
case BoolType.name => BoolType
case AnyType.name => AnyType
case NothingType.name => NothingType
case _ => SimpleType(Symbol(name), null)
Expand Down
48 changes: 40 additions & 8 deletions common/src/main/scala/org/skycastle/parser/ModuleLoader.scala
@@ -1,10 +1,10 @@
package org.skycastle.parser

import model._
import expressions.{FunExpr, Expr}
import model.defs.{ValDef, Parameter, FunDef, Def}
import model.expressions.Expr
import model.module.{Import, Module}
import model.refs.{Arg, Call, Ref}
import model.{TypeDef, Callable, FunType, SyntaxNode}
import org.skycastle.utils.StringUtils
import java.io.{FileFilter, File, FilenameFilter}
import org.skycastle.utils.LangUtils._
Expand Down Expand Up @@ -88,8 +88,9 @@ class ModuleLoader {
// Also checks for missing references
root.visitClasses(classOf[Ref]) { (context, ref) =>
context.getDef(ref.path.path) match {
case Some(definition) => ref.definition = definition
case Some(definition: ValueTyped) => ref.definition = definition
case None => addError("Could not resolve reference "+ref.path+" ", ref)
case x => addError("Cannot refer to a "+x.get.name.name+".", ref)
}
}

Expand All @@ -107,12 +108,19 @@ class ModuleLoader {


// Check for missing types and cyclic references
root.visitClasses(classOf[Expr]) { (context, exp) =>
root.visitClasses(classOf[ValueTyped]) { (context, exp) =>
if (exp.valueType == null) addError("Could not determine the type of the expression '" + exp + "'", exp)
}

// Check that function call parameter types and named parameters match the function definition they are calling
// Check for missing return types
root.visitClasses(classOf[FunDef]) { (context, exp: ReturnTyped) =>
if (exp.returnType == null) addError("Could not determine the return type of the function '" + exp + "'", exp)
}
root.visitClasses(classOf[FunExpr]) { (context, exp: ReturnTyped) =>
if (exp.returnType == null) addError("Could not determine the return type of the function expression '" + exp + "'", exp)
}

// Check that function call parameter types and named parameters match the function definition they are calling
// Check that reference is a function or func expr.
root.visitClasses(classOf[Call]) { (context: ResolverContext, call: Call) =>
call.functionDef match {
Expand Down Expand Up @@ -164,7 +172,7 @@ class ModuleLoader {
}
}

case otherDef: Def =>
case otherDef: Def with ValueTyped =>
// E.g. parameter of value with function object that has no named parameters
if (!otherDef.valueType.isInstanceOf[FunType]) {
addError("Can not call non-function type value '"+otherDef.valueType+"'", call)
Expand Down Expand Up @@ -192,12 +200,36 @@ class ModuleLoader {
}

case _ =>
addError("Can not invoke a function call on an expression ("+call.functionDef+") of type '"+call.functionDef.valueType+"' ", call)
addError("Can not invoke a function call on an expression ("+call.functionDef+") of type '"+(if (call.functionDef != null) call.functionDef.valueType else "[UnknownType]")+"' ", call)
}

}

errors

// Check that value expressions are of correct types
def checkTypes[T <: ReturnTyped](kind: Class[T], msg: String, actualCalc: T => TypeDef) {
root.visitClasses(kind) { (context, returnTyped: T) =>
if (returnTyped.declaredReturnType.isDefined) {
val expected: TypeDef = returnTyped.returnType
val actual: TypeDef = actualCalc(returnTyped)
if (!expected.isAssignableFrom(actual)) {
addError(msg + " does not correcpond to declared type, expected '"+expected+"', " +
"but got '"+actual+"'", returnTyped)
}
}
}
}

checkTypes[FunDef](classOf[FunDef], "Type of function expression", ref => ref.expression.valueType)
checkTypes[ValDef](classOf[ValDef], "Type of val expression", ref => ref.value.valueType)
checkTypes[FunExpr](classOf[FunExpr], "Type of function expression value", ref => ref.expression.valueType)
checkTypes[Parameter](classOf[Parameter], "Type of parameter default value", ref => ref.defaultValue.map(p => p.valueType).getOrElse(ref.returnType))



errors.reverse
}



}
12 changes: 6 additions & 6 deletions common/src/main/scala/org/skycastle/parser/ModuleParser.scala
Expand Up @@ -89,17 +89,17 @@ class ModuleParser(beanFactory: BeanFactory) extends LanguageParser[Module] {
private lazy val parameterWithDefault: PackratParser[Parameter] =
ident ~ ("=" ~> expression) ^^
{case name ~ defaultExp =>
new Parameter(Symbol(name), defaultExp.valueType, Some(defaultExp))}
new Parameter(Symbol(name), None, Some(defaultExp))}

private lazy val parameterWithType: PackratParser[Parameter] =
ident ~ typeTag ~ opt ("=" ~> expression) ^^
{case name ~ t ~ defaultExp =>
new Parameter(Symbol(name), t, defaultExp)}
new Parameter(Symbol(name), Some(t), defaultExp)}

private lazy val numParameter: PackratParser[Parameter] =
ident ^^
{case name =>
new Parameter(Symbol(name), NumType, None)}
new Parameter(Symbol(name), Some(NumType), None)}



Expand Down Expand Up @@ -137,8 +137,8 @@ class ModuleParser(beanFactory: BeanFactory) extends LanguageParser[Module] {
| list
| quotedString ^^ (x => StringExpr(x))
| "null" ^^ (x => NullExpr)
| "true" ^^ (x => TrueExpr)
| "false" ^^ (x => FalseExpr)
| TRUE ^^ (x => TrueExpr)
| FALSE ^^ (x => FalseExpr)
)


Expand Down Expand Up @@ -168,7 +168,7 @@ class ModuleParser(beanFactory: BeanFactory) extends LanguageParser[Module] {

private lazy val funExprParameter: PackratParser[Parameter] = parameterWithDefault | parameterWithType | funExprParameterWithoutTypeOrDefault
private lazy val funExprParameterWithoutTypeOrDefault: PackratParser[Parameter] =
ident ^^ {case name => new Parameter(Symbol(name), null, None)}
ident ^^ {case name => new Parameter(Symbol(name), None, None)}



Expand Down
Expand Up @@ -6,13 +6,18 @@ package org.skycastle.parser.model
*/
trait ReturnTyped extends ValueTyped {

def declaredReturnType: Option[TypeDef]

def returnType: TypeDef = returnType(Set(this))

def returnType(visited: Set[SyntaxNode]): TypeDef = {
val vt: TypeDef = valueType(visited)
if (vt == null) null
else if (!vt.isInstanceOf[FunType]) null
else vt.asInstanceOf[FunType].returnType
if (declaredReturnType.isDefined) declaredReturnType.get
else {
val vt: TypeDef = valueType(visited)
if (vt == null) null
else if (!vt.isInstanceOf[FunType]) null
else vt.asInstanceOf[FunType].returnType
}
}

}
48 changes: 36 additions & 12 deletions common/src/main/scala/org/skycastle/parser/model/TypeDef.scala
Expand Up @@ -25,8 +25,8 @@ case class SimpleType(typeName: Symbol, kind: Class[_]) extends TypeDef {
else if (other == this) this
else if (other.isInstanceOf[SimpleType]) {
val otherST = other.asInstanceOf[SimpleType]
if (otherST.kind.isAssignableFrom(kind)) otherST
else if (kind.isAssignableFrom(otherST.kind)) this
if (otherST.kind != null && otherST.kind.isAssignableFrom(kind)) otherST
else if (kind != null && kind.isAssignableFrom(otherST.kind)) this
else AnyType
}
else AnyType
Expand All @@ -37,8 +37,8 @@ case class SimpleType(typeName: Symbol, kind: Class[_]) extends TypeDef {
else if (other == this) this
else if (other.isInstanceOf[SimpleType]) {
val otherST = other.asInstanceOf[SimpleType]
if (otherST.kind.isAssignableFrom(kind)) this
else if (kind.isAssignableFrom(otherST.kind)) other
if (otherST.kind != null && otherST.kind.isAssignableFrom(kind)) this
else if (kind != null && kind.isAssignableFrom(otherST.kind)) other
else NothingType
}
else AnyType
Expand All @@ -58,29 +58,37 @@ case class FunType(parameterTypes: List[TypeDef], returnType: TypeDef) extends T

def mostSpecificCommonType(other: TypeDef): TypeDef = {
if (other == null) null
else if (returnType == null) null
else if (other == this) this
else if (other.isInstanceOf[FunType]) {
val otherFT = other.asInstanceOf[FunType]
if (otherFT.parameterTypes.size != parameterTypes.size) AnyType
if (otherFT.returnType == null) null
else {
val params = parameterTypes.zip(otherFT.parameterTypes).map(zipped => zipped._1.mostGeneralCommonSubType(zipped._2))
val ret = returnType.mostSpecificCommonType(otherFT.returnType)
FunType(params, ret)
if (otherFT.parameterTypes.size != parameterTypes.size) AnyType
else {
val params = parameterTypes.zip(otherFT.parameterTypes).map(zipped => zipped._1.mostGeneralCommonSubType(zipped._2))
val ret = returnType.mostSpecificCommonType(otherFT.returnType)
FunType(params, ret)
}
}
}
else AnyType
}

def mostGeneralCommonSubType(other: TypeDef): TypeDef = {
if (other == null) null
else if (returnType == null) null
else if (other == this) this
else if (other.isInstanceOf[FunType]) {
val otherFT = other.asInstanceOf[FunType]
if (otherFT.parameterTypes.size != parameterTypes.size) NothingType
if (otherFT.returnType == null) null
else {
val params = parameterTypes.zip(otherFT.parameterTypes).map(zipped => zipped._1.mostSpecificCommonType(zipped._2))
val ret = returnType.mostGeneralCommonSubType(otherFT.returnType)
FunType(params, ret)
if (otherFT.parameterTypes.size != parameterTypes.size) NothingType
else {
val params = parameterTypes.zip(otherFT.parameterTypes).map(zipped => zipped._1.mostSpecificCommonType(zipped._2))
val ret = returnType.mostGeneralCommonSubType(otherFT.returnType)
FunType(params, ret)
}
}
}
else NothingType
Expand Down Expand Up @@ -147,6 +155,22 @@ case object NumType extends SpecialType("Num") {
}


case object BoolType extends SpecialType("Bool") {

def mostSpecificCommonType(other: TypeDef): TypeDef = {
if (other == null) null
else if (other == this) this
else AnyType
}

def mostGeneralCommonSubType(other: TypeDef): TypeDef = {
if (other == null) null
else if (other == this) this
else NothingType
}
}


case object AnyType extends SpecialType("Any") {
def mostSpecificCommonType(other: TypeDef) = AnyType
def mostGeneralCommonSubType(other: TypeDef) = other
Expand Down
Expand Up @@ -6,7 +6,7 @@ import org.skycastle.parser.model._
/**
* Some kind of definitions
*/
trait Def extends ValueTyped with ReturnTyped {
trait Def extends SyntaxNode {

def name: Symbol

Expand Down
Expand Up @@ -9,8 +9,8 @@ import expressions.Expr
*/
case class FunDef(name: Symbol,
parameters: List[Parameter],
initialResultTypeDef: Option[TypeDef],
expression: Expr) extends Def with Callable {
declaredReturnType: Option[TypeDef],
expression: Expr) extends Def with ValueTyped with ReturnTyped with Callable {

private val paramsByName: Map[Symbol, Parameter] = parameters.map(p => p.name -> p).toMap

Expand All @@ -32,7 +32,7 @@ case class FunDef(name: Symbol,
expression.output(s, indent + 1)
}

override def subNodes = parameters.iterator ++ initialResultTypeDef.iterator ++ singleIt(expression)
override def subNodes = parameters.iterator ++ declaredReturnType.iterator ++ singleIt(expression)

def getMember(name: Symbol) = None

Expand All @@ -42,8 +42,8 @@ case class FunDef(name: Symbol,
def nameAndSignature = name.name + "("+parameters.mkString(", ")+"): " + (if (returnType == null) "[UnknownType]" else returnType.toString)

protected def determineValueType(visitedNodes: Set[SyntaxNode]): TypeDef = {
if (initialResultTypeDef.isDefined) {
FunType(parameters.map(p => p.valueType(visitedNodes)), initialResultTypeDef.get)
if (declaredReturnType.isDefined) {
FunType(parameters.map(p => p.valueType(visitedNodes)), declaredReturnType.get)
}
else {
val retType = expression.valueType(visitedNodes)
Expand Down
@@ -1,12 +1,12 @@
package org.skycastle.parser.model.defs

import org.skycastle.parser.model.expressions.Expr
import org.skycastle.parser.model.{Callable, SyntaxNode, Outputable, TypeDef}
import org.skycastle.parser.model._

/**
*
*/
case class Parameter(name: Symbol, initialTypeDef: TypeDef, defaultValue: Option[Expr]) extends Def {
case class Parameter(name: Symbol, declaredReturnType: Option[TypeDef], defaultValue: Option[Expr]) extends Def with ValueTyped with ReturnTyped {

def output(s: StringBuilder, indent: Int) {
s.append(name.name)
Expand All @@ -23,13 +23,13 @@ case class Parameter(name: Symbol, initialTypeDef: TypeDef, defaultValue: Option
}

protected def determineValueType(visitedNodes: Set[SyntaxNode]): TypeDef = {
if (initialTypeDef != null) initialTypeDef
if (declaredReturnType.isDefined) declaredReturnType.get
else if (defaultValue.isDefined) defaultValue.get.valueType(visitedNodes)
else null
}

def getMember(name: Symbol) = None

override def subNodes = singleIt(valueType) ++ defaultValue.iterator
override def subNodes = singleIt(valueType) ++ declaredReturnType.iterator ++ defaultValue.iterator

}

0 comments on commit ac74a58

Please sign in to comment.