Skip to content

Commit

Permalink
Merge pull request #299 from Yoitsumi/topic/argument-attributes
Browse files Browse the repository at this point in the history
Introduce function argument attributes
  • Loading branch information
densh committed Oct 8, 2016
2 parents 7c787e8 + 5e51aaa commit 74a80b7
Show file tree
Hide file tree
Showing 16 changed files with 165 additions and 45 deletions.
3 changes: 3 additions & 0 deletions nir/src/main/scala/scala/scalanative/nir/Arg.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package scala.scalanative.nir

case class Arg(ty: Type, passConvention: Option[PassConv] = None)
7 changes: 7 additions & 0 deletions nir/src/main/scala/scala/scalanative/nir/PassConv.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package scala.scalanative.nir

sealed abstract class PassConv
object PassConv {
final case class Byval(ty: Type) extends PassConv
final case class Sret(ty: Type) extends PassConv
}
10 changes: 10 additions & 0 deletions nir/src/main/scala/scala/scalanative/nir/Shows.scala
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,16 @@ object Shows {
case Type.Module(name) => sh"module $name"
}

implicit val showArg: Show[Arg] = Show {
case Arg(ty, None) => sh"$ty"
case Arg(ty, Some(passConv)) => sh"$passConv $ty"
}

implicit val showPassConvention: Show[PassConv] = Show {
case PassConv.Byval(ty) => sh"byval[$ty]"
case PassConv.Sret(ty) => sh"sret[$ty]"
}

implicit val showGlobal: Show[Global] = Show {
case Global.None => unreachable
case Global.Top(id) => sh"@$id"
Expand Down
8 changes: 8 additions & 0 deletions nir/src/main/scala/scala/scalanative/nir/Tags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,12 @@ object Tags {
final val UnitVal = 1 + GlobalVal
final val ConstVal = 1 + UnitVal
final val StringVal = 1 + ConstVal

// Argument Passing Conventions

final val PassConv = Val + 32

final val Byval = PassConv + 1
final val Sret = Byval + 1

}
2 changes: 1 addition & 1 deletion nir/src/main/scala/scala/scalanative/nir/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ object Type {
final object F64 extends F(64)

final case class Array(ty: Type, n: Int) extends Type
final case class Function(args: Seq[Type], ret: Type) extends Type
final case class Function(args: Seq[Arg], ret: Type) extends Type
final case class Struct(name: Global, tys: Seq[Type]) extends Type with Named

// high-level types
Expand Down
4 changes: 2 additions & 2 deletions nir/src/main/scala/scala/scalanative/nir/Versions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ object Versions {
* when 1.3-based release happens all of the code needs to recompiled with
* new version of the toolchain.
*/
final val compat: Int = 4
final val revision: Int = 5
final val compat: Int = 5
final val revision: Int = 6

/* Current public release version of Scala Native. */
final val current: String = "0.1-SNAPSHOT"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ final class BinaryDeserializer(_buffer: => ByteBuffer) {
case T.F32Type => Type.F32
case T.F64Type => Type.F64
case T.ArrayType => Type.Array(getType, getInt)
case T.FunctionType => Type.Function(getTypes, getType)
case T.FunctionType => Type.Function(getArgs, getType)
case T.StructType => Type.Struct(getGlobal, getTypes)

case T.UnitType => Type.Unit
Expand All @@ -269,6 +269,13 @@ final class BinaryDeserializer(_buffer: => ByteBuffer) {
case T.ModuleType => Type.Module(getGlobal)
}

private def getArgs(): Seq[Arg] = getSeq(getArg)
private def getArg(): Arg = Arg(getType, getOpt(getPassConvention))
private def getPassConvention(): PassConv = getInt match {
case T.Byval => PassConv.Byval(getType)
case T.Sret => PassConv.Sret(getType)
}

private def getVals(): Seq[Val] = getSeq(getVal)
private def getVal(): Val = getInt match {
case T.NoneVal => Val.None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ final class BinarySerializer(buffer: ByteBuffer) {
case Type.F64 => putInt(T.F64Type)
case Type.Array(ty, n) => putInt(T.ArrayType); putType(ty); putInt(n)
case Type.Function(args, ret) =>
putInt(T.FunctionType); putTypes(args); putType(ret)
putInt(T.FunctionType); putArgs(args); putType(ret)
case Type.Struct(n, tys) =>
putInt(T.StructType); putGlobal(n); putTypes(tys)

Expand All @@ -387,6 +387,16 @@ final class BinarySerializer(buffer: ByteBuffer) {
case Type.Module(n) => putInt(T.ModuleType); putGlobal(n)
}

private def putArgs(args: Seq[Arg]): Unit = putSeq(args)(putArg)
private def putArg(arg: Arg): Unit = {
putType(arg.ty)
putOpt(arg.passConvention)(putPassConvention)
}
private def putPassConvention(attr: PassConv): Unit = attr match {
case PassConv.Byval(ty) => putInt(T.Byval); putType(ty)
case PassConv.Sret(ty) => putInt(T.Sret); putType(ty)
}

private def putVals(values: Seq[Val]): Unit = putSeq(values)(putVal)
private def putVal(value: Val): Unit = value match {
case Val.None => putInt(T.NoneVal)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ abstract class NirCodeGen
val selfty = genType(sym.owner.tpe)
val retty = genType(res, retty = true)

Type.Function(Seq(selfty), retty)
Type.Function(Seq(Arg(selfty)), retty)

case sym: MethodSymbol =>
val params = sym.paramLists.flatten
Expand All @@ -390,7 +390,7 @@ abstract class NirCodeGen
if (sym.isClassConstructor) Type.Unit
else genType(sym.tpe.resultType, retty = true)

Type.Function(selfty ++: paramtys, retty)
Type.Function((selfty ++: paramtys).map(Arg(_)), retty)
}

def genParams(
Expand Down Expand Up @@ -952,7 +952,7 @@ abstract class NirCodeGen
lazy val jlClass = nir.Type.Class(jlClassName)
lazy val jlClassCtorName = jlClassName member "init_ptr"
lazy val jlClassCtorSig =
nir.Type.Function(Seq(jlClass, Type.Ptr), nir.Type.Unit)
nir.Type.Function(Seq(Arg(jlClass), Arg(Type.Ptr)), nir.Type.Unit)
lazy val jlClassCtor = nir.Val.Global(jlClassCtorName, nir.Type.Ptr)

def genBoxClass(type_ : Val, focus: Focus) = {
Expand Down Expand Up @@ -1316,7 +1316,7 @@ abstract class NirCodeGen
val (argsp, ctsp) = allargsp.splitAt(arity)
val ctsyms = ctsp.map(extractClassFromImplicitClassTag)
val cttys = ctsyms.map(ctsym => genType(ctsym.info))
val sig = Type.Function(cttys.init, cttys.last)
val sig = Type.Function(cttys.init.map(Arg(_)), cttys.last)

val args = mutable.UnrolledBuffer.empty[nir.Val]
var last = fun
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,21 @@ trait NirNameEncoding { self: NirCodeGen =>

private def mangledTypeInternal(ty: nir.Type): String = {
implicit lazy val showMangledType: Show[nir.Type] = Show {
case nir.Type.None => ""
case nir.Type.Void => "void"
case nir.Type.Vararg => "..."
case nir.Type.Ptr => "ptr"
case nir.Type.Bool => "bool"
case nir.Type.I8 => "i8"
case nir.Type.I16 => "i16"
case nir.Type.I32 => "i32"
case nir.Type.I64 => "i64"
case nir.Type.F32 => "f32"
case nir.Type.F64 => "f64"
case nir.Type.Array(ty, n) => sh"arr.$ty.$n"
case nir.Type.Function(args, ret) => sh"fun.${r(args :+ ret, sep = ".")}"
case nir.Type.Struct(name, _) => sh"struct.$name"
case nir.Type.None => ""
case nir.Type.Void => "void"
case nir.Type.Vararg => "..."
case nir.Type.Ptr => "ptr"
case nir.Type.Bool => "bool"
case nir.Type.I8 => "i8"
case nir.Type.I16 => "i16"
case nir.Type.I32 => "i32"
case nir.Type.I64 => "i64"
case nir.Type.F32 => "f32"
case nir.Type.F64 => "f64"
case nir.Type.Array(ty, n) => sh"arr.$ty.$n"
case nir.Type.Function(args, ret) =>
sh"fun.${r(args.map(_.ty) :+ ret, sep = ".")}"
case nir.Type.Struct(name, _) => sh"struct.$name"

case nir.Type.Nothing => "nothing"
case nir.Type.Unit => "unit"
Expand Down
9 changes: 5 additions & 4 deletions tools/src/main/scala/scala/scalanative/compiler/Pass.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,11 @@ trait Pass extends (Seq[Defn] => Seq[Defn]) {
private def txType(ty: Type): Type = {
val pre = hook(preType, ty, ty)
val post = pre match {
case Type.Array(ty, n) => Type.Array(txType(ty), n)
case Type.Function(tys, ty) => Type.Function(tys.map(txType), txType(ty))
case Type.Struct(n, tys) => Type.Struct(n, tys.map(txType))
case _ => pre
case Type.Array(ty, n) => Type.Array(txType(ty), n)
case Type.Function(args, ty) =>
Type.Function(args.map(a => a.copy(ty = txType(a.ty))), txType(ty))
case Type.Struct(n, tys) => Type.Struct(n, tys.map(txType))
case _ => pre
}

hook(postType, post, post)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package compiler
package codegen

import java.{lang => jl}

import scala.collection.mutable
import util.{unsupported, unreachable, sh, Show}
import util.Show.{Sequence => s, Indent => i, Unindent => ui, Repeat => r, Newline => nl}
import util.{Show, sh, unreachable, unsupported}
import util.Show.{Indent => i, Newline => nl, Repeat => r, Sequence => s, Unindent => ui}
import compiler.analysis.ControlFlow
import nir.Shows.brace
import nir._
Expand Down Expand Up @@ -84,9 +85,38 @@ class GenTextualLLVM(assembly: Seq[Defn]) extends GenShow(assembly) {

val isDecl = blocks.isEmpty
val keyword = if (isDecl) "declare" else "define"
val params =
if (isDecl) r(argtys, sep = ", ")
else r(blocks.head.params: Seq[Val], sep = ", ")

def showDefnArg(arg: Arg,
value: Val.Local): (Show.Result, Seq[Show.Result]) =
arg match {
case Arg(_, None) => (sh"${value: Val}", Seq.empty)
case Arg(Type.Ptr, Some(PassConv.Byval(pointee))) =>
val pointer = fresh()
(sh"$pointee* byval %$pointer",
Seq(sh"%${value.name} = bitcast $pointee* %$pointer to i8*"))
case Arg(Type.Ptr, Some(PassConv.Sret(pointee))) =>
val pointer = fresh()
(sh"$pointee* sret %$pointer",
Seq(sh"%${value.name} = bitcast $pointee* %$pointer to i8*"))
case x => unsupported(x)
}

def showDeclArg(arg: Arg): Show.Result = arg match {
case Arg(ty, None) => sh"$ty"
case Arg(Type.Ptr, Some(PassConv.Byval(pointee))) =>
sh"$pointee* byval"
case Arg(Type.Ptr, Some(PassConv.Sret(pointee))) =>
sh"$pointee* sret"
case x => unsupported(x)
}

val (params, preInstrs) =
if (isDecl) (r(argtys.map(showDeclArg), sep = ", "), Seq())
else {
val results =
(argtys zip blocks.head.params).map((showDefnArg _).tupled)
(r(results.map(_._1), sep = ", "), results.flatMap(_._2))
}
val postattrs: Seq[Attr] =
if (attrs.inline != Attr.MayInline) Seq(attrs.inline) else Seq()
val personality = if (attrs.isExtern || isDecl) s() else gxxpersonality
Expand All @@ -95,15 +125,22 @@ class GenTextualLLVM(assembly: Seq[Defn]) extends GenShow(assembly) {
else {
implicit val cfg = ControlFlow(blocks)
val showblocks = cfg.map { node =>
showBlock(node.block, node.pred, isEntry = node eq cfg.entry)
val isEntry = node eq cfg.entry
showBlock(node.block,
node.pred,
isEntry = isEntry,
if (isEntry) r(preInstrs.map(nl)) else s())
}
s(" ", brace(r(showblocks)))
}

sh"$keyword $retty @$name($params)$postattrs$personality$body"
}

def showBlock(block: Block, pred: Seq[ControlFlow.Edge], isEntry: Boolean)(
def showBlock(block: Block,
pred: Seq[ControlFlow.Edge],
isEntry: Boolean,
preInstructions: Show.Result)(
implicit cfg: ControlFlow.Graph): Show.Result = {
val Block(name, params, insts, cf) = block

Expand Down Expand Up @@ -150,7 +187,7 @@ class GenTextualLLVM(assembly: Seq[Defn]) extends GenShow(assembly) {
r(shows.map(s(_)))
}

sh"${nl("")}$label$prologue$body"
sh"${nl("")}$label$prologue$preInstructions${nl("")}$body"
}

implicit val showType: Show[Type] = Show {
Expand All @@ -171,6 +208,13 @@ class GenTextualLLVM(assembly: Seq[Defn]) extends GenShow(assembly) {
case ty => unsupported(ty)
}

implicit val showArg: Show[Arg] = Show {
case Arg(ty, None) => sh"$ty"
case Arg(Type.Ptr, Some(PassConv.Byval(pointee))) => sh"$pointee*"
case Arg(Type.Ptr, Some(PassConv.Sret(pointee))) => sh"$pointee*"
case arg => unsupported(arg)
}

def justVal(v: Val): Show.Result = v match {
case Val.True => "true"
case Val.False => "false"
Expand Down Expand Up @@ -225,6 +269,23 @@ class GenTextualLLVM(assembly: Seq[Defn]) extends GenShow(assembly) {
def isVoid(ty: Type): Boolean =
ty == Type.Void || ty == Type.Unit || ty == Type.Nothing

def showCallArgs(args: Seq[Arg],
vals: Seq[Val]): (Seq[Show.Result], Seq[Show.Result]) = {
val res = (args zip vals) map {
case (Arg(Type.Ptr, Some(PassConv.Byval(pointee))), v) =>
val bitcasted = fresh()
(Seq(sh"%$bitcasted = bitcast $v to $pointee*"),
sh"$pointee* %$bitcasted")
case (Arg(Type.Ptr, Some(PassConv.Sret(pointee))), v) =>
val bitcasted = fresh()
(Seq(sh"%$bitcasted = bitcast $v to $pointee*"),
sh"$pointee* %$bitcasted")
case (Arg(_, None), v) => (Seq(), sh"$v")
case _ => unsupported()
}
(res.flatMap(_._1), res.map(_._2))
}

insts.foreach { inst =>
val op = inst.op
val name = inst.name
Expand All @@ -234,14 +295,24 @@ class GenTextualLLVM(assembly: Seq[Defn]) extends GenShow(assembly) {
case Op.Call(ty, Val.Global(pointee, _), args) =>
val bind = if (isVoid(op.resty)) s() else sh"%$name = "

buf += sh"${bind}call $ty @$pointee(${r(args, sep = ", ")})"
val Type.Function(argtys, _) = ty

val (preinsts, argshows) = showCallArgs(argtys, args)

buf ++= preinsts
buf += sh"${bind}call ${ty: Type} @$pointee(${r(argshows, sep = ", ")})"

case Op.Call(ty, ptr, args) =>
val pointee = fresh()
val bind = if (isVoid(op.resty)) s() else sh"%$name = "

val Type.Function(argtys, _) = ty

val (preinsts, argshows) = showCallArgs(argtys, args)

buf ++= preinsts
buf += sh"%$pointee = bitcast $ptr to $ty*"
buf += sh"${bind}call $ty %$pointee(${r(args, sep = ", ")})"
buf += sh"${bind}call ${ty: Type} %$pointee(${r(argshows, sep = ", ")})"

case Op.Load(ty, ptr) =>
val pointee = fresh()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ object ClassLowering extends PassCompanion {
def apply(ctx: Ctx) = new ClassLowering()(ctx.top, ctx.fresh)

val allocName = Global.Top("scalanative_alloc")
val allocSig = Type.Function(Seq(Type.Ptr, Type.I64), Type.Ptr)
val allocSig = Type.Function(Seq(Arg(Type.Ptr), Arg(Type.I64)), Type.Ptr)
val alloc = Val.Global(allocName, allocSig)

override val injects = Seq(Defn.Declare(Attrs.None, allocName, allocSig))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ class MainInjection(entry: Global)(implicit fresh: Fresh) extends Pass {

override def preAssembly = {
case defns =>
val mainTy =
Type.Function(Seq(Type.Module(entry.top), ObjectArray), Type.Void)
val mainTy = Type.Function(
Seq(Arg(Type.Module(entry.top)), Arg(ObjectArray)),
Type.Void)
val main = Val.Global(entry, Type.Ptr)
val argc = Val.Local(fresh(), Type.I32)
val argv = Val.Local(fresh(), Type.Ptr)
Expand Down Expand Up @@ -47,11 +48,12 @@ object MainInjection extends PassCompanion {

val Rt = Type.Module(Global.Top("scala.scalanative.runtime.package$"))
val initName = Rt.name member "init_i32_ptr_class.ssnr.ObjectArray"
val initSig = Type.Function(Seq(Rt, Type.I32, Type.Ptr), ObjectArray)
val init = Val.Global(initName, initSig)
val initSig =
Type.Function(Seq(Arg(Rt), Arg(Type.I32), Arg(Type.Ptr)), ObjectArray)
val init = Val.Global(initName, initSig)

val mainName = Global.Top("main")
val mainSig = Type.Function(Seq(Type.I32, Type.Ptr), Type.I32)
val mainSig = Type.Function(Seq(Arg(Type.I32), Arg(Type.Ptr)), Type.I32)

override val depends = Seq(ObjectArray.name, Rt.name, init.name)
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class ModuleLowering(implicit top: Top, fresh: Fresh) extends Pass {
val initCall =
if (isStaticModule(name)) Seq()
else {
val initSig = Type.Function(Seq(Type.Class(name)), Type.Void)
val initSig = Type.Function(Seq(Arg(Type.Class(name))), Type.Void)
val init = Val.Global(name member "init", Type.Ptr)

Seq(Inst(Op.Call(initSig, init, Seq(alloc))))
Expand Down

0 comments on commit 74a80b7

Please sign in to comment.