diff --git a/core/src/fr/hammons/slinc/SetSizeArray.scala b/core/src/fr/hammons/slinc/SetSizeArray.scala index 91c6b9a..b19f967 100644 --- a/core/src/fr/hammons/slinc/SetSizeArray.scala +++ b/core/src/fr/hammons/slinc/SetSizeArray.scala @@ -44,6 +44,10 @@ class SetSizeArray[A, B <: Int] private[slinc] (private val array: Array[A]) value: A )(using 0 <= C =:= true, C < B =:= true): Unit = array(constValue[C]) = value + def zip[C](oArr: SetSizeArray[C, B]): SetSizeArray[(A, C), B] = + new SetSizeArray[(A, C), B](array.zip(oArr.array)) + def foreach(fn: A => Unit) = array.foreach(fn) + object SetSizeArray: class SetSizeArrayBuilderUnsafe[B <: Int]: def apply[A](array: Array[A]): SetSizeArray[A, B] = new SetSizeArray(array) diff --git a/core/src/fr/hammons/slinc/TypeDescriptor.scala b/core/src/fr/hammons/slinc/TypeDescriptor.scala index 50291da..38ab3e1 100644 --- a/core/src/fr/hammons/slinc/TypeDescriptor.scala +++ b/core/src/fr/hammons/slinc/TypeDescriptor.scala @@ -239,9 +239,20 @@ case class SetSizeArrayDescriptor( override val argumentTransition : (TransitionModule, ReadWriteModule, Allocator) ?=> ArgumentTransition[ Inner - ] = ??? + ] = arg => + val mem = summon[Allocator].allocate(this, 1) + summon[ReadWriteModule].write( + mem, + Bytes(0), + this, + arg + ) + mem.asAddress override val returnTransition - : (TransitionModule, ReadWriteModule) ?=> ReturnTransition[Inner] = ??? + : (TransitionModule, ReadWriteModule) ?=> ReturnTransition[Inner] = + obj => + val mem = summon[TransitionModule].addressReturn(obj) + summon[ReadWriteModule].read(mem, Bytes(0), this) type Inner = SetSizeArray[contained.Inner, ?] diff --git a/core/test/resources/native/test.c b/core/test/resources/native/test.c index 9721d26..8a7543a 100644 --- a/core/test/resources/native/test.c +++ b/core/test/resources/native/test.c @@ -139,3 +139,12 @@ EXPORTED struct_issue_175 i175_test(struct_issue_175 a, char left) { } return a; } + +EXPORTED int* i180_test(int my_array[5]) { + int i = 0; + while(i < 5) { + my_array[i] = my_array[i] * 2; + i++; + } + return my_array; +} \ No newline at end of file diff --git a/core/test/src/fr/hammons/slinc/BindingSpec.scala b/core/test/src/fr/hammons/slinc/BindingSpec.scala index ef8a049..45df3fe 100644 --- a/core/test/src/fr/hammons/slinc/BindingSpec.scala +++ b/core/test/src/fr/hammons/slinc/BindingSpec.scala @@ -59,6 +59,10 @@ trait BindingSpec(val slinc: Slinc) extends ScalaCheckSuite: left: CChar ): I175_Struct + def i180_test( + input: SetSizeArray[CInt, 5] + ): SetSizeArray[CInt, 5] + test("int_identity") { val test = FSet.instance[TestLib] @@ -186,3 +190,11 @@ trait BindingSpec(val slinc: Slinc) extends ScalaCheckSuite: union.set(double) val res = test.i175_test(I175_Struct(union), 0) assertEquals(res.union.get[CDouble], double / 2) + + property("issue 180 - can send and receive set size arrays to C functions"): + val test = FSet.instance[TestLib] + forAll(Gen.listOfN(5, Arbitrary.arbitrary[CInt])): list => + val arr = SetSizeArray.fromArrayUnsafe[5](list.toArray) + val retArr = test.i180_test(arr) + + retArr.zip(arr.map(_ * 2)).foreach(assertEquals(_, _)) diff --git a/core/test/src/fr/hammons/slinc/TransferSpec.scala b/core/test/src/fr/hammons/slinc/TransferSpec.scala index e8ed0be..1bf855a 100644 --- a/core/test/src/fr/hammons/slinc/TransferSpec.scala +++ b/core/test/src/fr/hammons/slinc/TransferSpec.scala @@ -10,6 +10,7 @@ import scala.concurrent.Await import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global import scala.reflect.ClassTag +import scala.util.chaining.* trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using ClassTag[ThreadException] @@ -28,7 +29,7 @@ trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using case class F(u: CUnion[(CInt, CFloat)]) derives Struct - case class G(arr: SetSizeArray[CLong, 2]) derives Struct + case class G(long: CLong, arr: SetSizeArray[CLong, 2]) derives Struct test("can read and write jvm ints") { Scope.global { @@ -162,9 +163,8 @@ trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using } } - test("varargs can be sent and retrieved"): + test("varargs can receive primitive types"): Scope.confined { - val vaListForVaList = VarArgsBuilder(4).build val vaList = VarArgsBuilder( 4.toByte, 5.toShort, @@ -172,10 +172,7 @@ trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using 7.toLong, 2f, 3d, - Null[Int], - A(1, 2), - CLong(4: Byte), - vaListForVaList + Null[Int] ).build assertEquals(vaList.get[Byte], 4.toByte, "byte assert") @@ -185,24 +182,118 @@ trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using assertEquals(vaList.get[Float], 2f, "float assert") assertEquals(vaList.get[Double], 3d, "double assert") assertEquals( - vaList.get[Ptr[Int]].mem.asAddress, - Null[Int].mem.asAddress, + vaList.get[Ptr[Int]], + Null[Int], "ptr assert" ) + } + + test("varargs can receive complex types".ignore): + Scope.confined { + val vaListForVaList = VarArgsBuilder(4).build + val vaList = VarArgsBuilder( + A(1, 2), + CLong(4), + A(3, 4), + SetSizeArray(1, 2, 3, 4), + vaListForVaList, + CUnion[(CInt, CFloat)].tap(_.set(5)), + // Null[Int], + A(3, 4) + ).build + assertEquals(vaList.get[A], A(1, 2), "struct assert") assertEquals(vaList.get[CLong], CLong(4: Byte), "alias assert") - assertEquals(vaList.get[VarArgs].get[CInt], 4) + assertEquals(vaList.get[A], A(3, 4)) + assertEquals( + vaList.get[SetSizeArray[CInt, 4]].toSeq, + Seq(1, 2, 3, 4), + "set size array assert" + ) + assertEquals( + vaListForVaList.get[VarArgs].get[Int], + 4 + ) + assertEquals( + vaList.get[CUnion[(CLongLong, CFloat)]].get[CLongLong], + 5L, + "cunion assert" + ) + // assertEquals( + // vaList.get[Ptr[Int]], + // Null[Int] + // ) + assertEquals( + vaList.get[A], + A(3, 4), + "struct assert 2" + ) } - test("varargs can be skipped"): + test("varargs can skip primitive types"): Scope.confined { val vaList = VarArgsBuilder( - 4.toByte, - 2f + 4: Byte, + 5: Short, + 6, + 7L, + 2f, + 3d, + Null[Int] ).build + val vaList2 = vaList.copy() + vaList.skip[Byte] - assertEquals(vaList.get[Float], 2f) + assertEquals(vaList.get[Short], 5: Short) + vaList.skip[Int] + assertEquals(vaList.get[Long], 7L) + vaList.skip[Float] + assertEquals(vaList.get[Double], 3d) + vaList.skip[Ptr[Int]] + + assertEquals(vaList2.get[Byte], 4: Byte) + vaList2.skip[Short] + assertEquals(vaList2.get[Int], 6) + vaList2.skip[Long] + assertEquals(vaList2.get[Float], 2f) + vaList2.skip[Double] + assertEquals(vaList2.get[Ptr[Int]], Null[Int]) + } + + test("varargs can skip complex types".ignore): + Scope.confined { + val vaListForVaList = VarArgsBuilder(4, 5, 6).build + val vaList = VarArgsBuilder( + A(1, 2), + CLong(4), + vaListForVaList, + CUnion[(CInt, CFloat)].tap(_.set(5)), + SetSizeArray(1, 2, 3, 4) + ).build + + val vaList2 = vaList.copy() + + assertEquals(vaList.get[A], A(1, 2), "struct assert") + vaList.skip[CLong] + val vaList3 = vaList.get[VarArgs] + assertEquals( + List(vaList3.get[Int], vaList3.get[Int], vaList3.get[Int]), + List(4, 5, 6), + "varargs assert" + ) + vaList.skip[CUnion[(CInt, CFloat)]] + assertEquals( + vaList.get[SetSizeArray[Int, 4]].toSeq, + Seq(1, 2, 3, 4), + "set size array assert" + ) + + vaList2.skip[A] + assertEquals(vaList2.get[CLong], CLong(4)) + vaList2.skip[VarArgs] + assertEquals(vaList2.get[CUnion[(CInt, CFloat)]].get[Int], 5) + vaList2.skip[SetSizeArray[Int, 4]] } test("varargs can be copied and reread"): @@ -373,7 +464,7 @@ trait TransferSpec[ThreadException <: Throwable](val slinc: Slinc)(using } test("can copy G to native memory and back"): - val g = G(SetSizeArray(CLong(1), CLong(2))) + val g = G(CLong(5), SetSizeArray(CLong(1), CLong(2))) Scope.confined { val ptr = Ptr.copy(g) diff --git a/j17/src/fr/hammons/slinc/Allocator17.scala b/j17/src/fr/hammons/slinc/Allocator17.scala index 9d93faa..868800f 100644 --- a/j17/src/fr/hammons/slinc/Allocator17.scala +++ b/j17/src/fr/hammons/slinc/Allocator17.scala @@ -122,6 +122,15 @@ object Allocator17: case ms: MemorySegment => ms case _ => throw Error("base of mem was not J17 MemorySegment!!") ) + case (ssad: SetSizeArrayDescriptor, s: SetSizeArray[?, ?]) => + LinkageModule17.tempScope(alloc ?=> + builder.vargFromAddress( + C_POINTER, + transitionModule17 + .methodArgument(ssad, s, alloc) + .asInstanceOf[Addressable] + ) + ) case (a, d) => throw Error( s"Unsupported type descriptor/data pairing for VarArgs: $a - $d" diff --git a/j17/src/fr/hammons/slinc/VarArgs17.scala b/j17/src/fr/hammons/slinc/VarArgs17.scala index f34182c..ebefb84 100644 --- a/j17/src/fr/hammons/slinc/VarArgs17.scala +++ b/j17/src/fr/hammons/slinc/VarArgs17.scala @@ -3,6 +3,7 @@ package fr.hammons.slinc import jdk.incubator.foreign.CLinker.VaList import jdk.incubator.foreign.CLinker.{C_INT, C_LONG_LONG, C_DOUBLE, C_POINTER} import jdk.incubator.foreign.SegmentAllocator +import jdk.incubator.foreign.GroupLayout import fr.hammons.slinc.modules.{ LinkageModule17, descriptorModule17, @@ -20,7 +21,8 @@ class VarArgs17(args: VaList) extends VarArgs: case LongDescriptor => Long.box(args.vargAsLong(C_LONG_LONG)) case FloatDescriptor => Float.box(args.vargAsDouble(C_DOUBLE).toFloat) case DoubleDescriptor => Double.box(args.vargAsDouble(C_DOUBLE)) - case PtrDescriptor => args.vargAsAddress(C_POINTER).nn + case PtrDescriptor | _: SetSizeArrayDescriptor | VaListDescriptor => + args.vargAsAddress(C_POINTER).nn case sd: StructDescriptor => LinkageModule17.tempScope(alloc ?=> args @@ -30,26 +32,34 @@ class VarArgs17(args: VaList) extends VarArgs: ) .nn ) - case AliasDescriptor(real) => get(real) - case VaListDescriptor => args.vargAsAddress(C_POINTER).nn - case CUnionDescriptor(possibleTypes) => get(possibleTypes.maxBy(_.size)) + case AliasDescriptor(real) => get(real) + case cud: CUnionDescriptor => + LinkageModule17.tempScope(alloc ?=> + args + .vargAsSegment( + descriptorModule17.toMemoryLayout(cud).asInstanceOf[GroupLayout], + alloc.base.asInstanceOf[SegmentAllocator] + ) + .nn + ) def get[A](using d: DescriptorOf[A]): A = transitionModule17.methodReturn[A](d.descriptor, get(d.descriptor)) private def skip(td: TypeDescriptor): Unit = td match - case ByteDescriptor => args.skip(C_INT) - case ShortDescriptor => args.skip(C_INT) - case IntDescriptor => args.skip(C_INT) - case LongDescriptor => args.skip(C_LONG_LONG) - case FloatDescriptor => args.skip(C_DOUBLE) - case DoubleDescriptor => args.skip(C_DOUBLE) - case PtrDescriptor => args.skip(C_POINTER) + case ByteDescriptor => args.skip(C_INT) + case ShortDescriptor => args.skip(C_INT) + case IntDescriptor => args.skip(C_INT) + case LongDescriptor => args.skip(C_LONG_LONG) + case FloatDescriptor => args.skip(C_DOUBLE) + case DoubleDescriptor => args.skip(C_DOUBLE) + case PtrDescriptor | _: SetSizeArrayDescriptor => args.skip(C_POINTER) case sd: StructDescriptor => args.skip(descriptorModule17.toGroupLayout(sd)) - case AliasDescriptor(real) => skip(real) - case VaListDescriptor => args.skip(C_POINTER) - case CUnionDescriptor(possibleTypes) => skip(possibleTypes.maxBy(_.size)) + case AliasDescriptor(real) => skip(real) + case VaListDescriptor => args.skip(C_POINTER) + case cud: CUnionDescriptor => + args.skip(descriptorModule17.toMemoryLayout(cud)) def skip[A](using dO: DescriptorOf[A]): Unit = skip(dO.descriptor) diff --git a/j17/src/fr/hammons/slinc/modules/DescriptorModule17.scala b/j17/src/fr/hammons/slinc/modules/DescriptorModule17.scala index ac298b5..7ab52be 100644 --- a/j17/src/fr/hammons/slinc/modules/DescriptorModule17.scala +++ b/j17/src/fr/hammons/slinc/modules/DescriptorModule17.scala @@ -8,7 +8,8 @@ import jdk.incubator.foreign.{ MemorySegment, GroupLayout, CLinker, - ValueLayout + ValueLayout, + SequenceLayout }, CLinker.C_POINTER import scala.collection.concurrent.TrieMap import fr.hammons.slinc.types.{arch, os, OS, Arch} @@ -26,11 +27,10 @@ given descriptorModule17: DescriptorModule with case FloatDescriptor => classOf[Float] case DoubleDescriptor => classOf[Double] case PtrDescriptor => classOf[MemoryAddress] - case _: StructDescriptor | _: CUnionDescriptor | - _: SetSizeArrayDescriptor => + case _: StructDescriptor | _: CUnionDescriptor => classOf[MemorySegment] - case VaListDescriptor => classOf[MemoryAddress] - case ad: AliasDescriptor[?] => toCarrierType(ad.real) + case VaListDescriptor | _: SetSizeArrayDescriptor => classOf[MemoryAddress] + case ad: AliasDescriptor[?] => toCarrierType(ad.real) def genLayoutList( layouts: Seq[MemoryLayout], @@ -123,6 +123,11 @@ given descriptorModule17: DescriptorModule with case CUnionDescriptor(possibleTypes) => MemoryLayout.unionLayout(possibleTypes.map(toMemoryLayout).toSeq*).nn + def toDowncallLayout(td: TypeDescriptor): MemoryLayout = toMemoryLayout( + td + ) match + case _: SequenceLayout => C_POINTER.nn + case o => o def toMemoryLayout(smd: StructMemberDescriptor): MemoryLayout = toMemoryLayout(smd.descriptor).withName(smd.name).nn diff --git a/j17/src/fr/hammons/slinc/modules/LinkageModule17.scala b/j17/src/fr/hammons/slinc/modules/LinkageModule17.scala index dc524cf..5ed0d1b 100644 --- a/j17/src/fr/hammons/slinc/modules/LinkageModule17.scala +++ b/j17/src/fr/hammons/slinc/modules/LinkageModule17.scala @@ -26,11 +26,11 @@ object LinkageModule17 extends LinkageModule: varArgs.view.map(_.use[DescriptorOf](d ?=> _ => d.descriptor)) val fdConstructor = descriptor.returnDescriptor match case None => FunctionDescriptor.ofVoid(_*) - case Some(value) => FunctionDescriptor.of(toMemoryLayout(value), _*) + case Some(value) => FunctionDescriptor.of(toDowncallLayout(value), _*) val fd = fdConstructor( descriptor.inputDescriptors.view - .map(toMemoryLayout) + .map(toDowncallLayout) .concat(variadicDescriptors.map(toMemoryLayout).map(CLinker.asVarArg)) .toSeq ) diff --git a/j19/src/fr/hammons/slinc/Allocator19.scala b/j19/src/fr/hammons/slinc/Allocator19.scala index 71575cc..c46dc9b 100644 --- a/j19/src/fr/hammons/slinc/Allocator19.scala +++ b/j19/src/fr/hammons/slinc/Allocator19.scala @@ -112,6 +112,16 @@ class Allocator19( case ms: MemorySegment => ms case _ => throw Error("Illegal datatype") ) + + case (ssad: SetSizeArrayDescriptor, s: SetSizeArray[?, ?]) => + LinkageModule19.tempScope(alloc ?=> + builder.addVarg( + ValueLayout.ADDRESS, + transitionModule19 + .methodArgument(ssad, s, alloc) + .asInstanceOf[Addressable] + ) + ) case (td, d) => throw Error(s"Unsupported datatype for $td - $d") diff --git a/j19/src/fr/hammons/slinc/VarArgs19.scala b/j19/src/fr/hammons/slinc/VarArgs19.scala index a161a49..dceb224 100644 --- a/j19/src/fr/hammons/slinc/VarArgs19.scala +++ b/j19/src/fr/hammons/slinc/VarArgs19.scala @@ -9,6 +9,7 @@ import fr.hammons.slinc.modules.{ } import scala.util.chaining.* import java.lang.foreign.MemorySegment +import java.lang.foreign.GroupLayout private class VarArgs19(vaList: VaList) extends VarArgs: @@ -25,7 +26,8 @@ private class VarArgs19(vaList: VaList) extends VarArgs: case LongDescriptor => vaList.skip(ValueLayout.JAVA_LONG) case FloatDescriptor | DoubleDescriptor => vaList.skip(ValueLayout.JAVA_DOUBLE) - case PtrDescriptor | VaListDescriptor => vaList.skip(ValueLayout.ADDRESS) + case PtrDescriptor | VaListDescriptor | _: SetSizeArrayDescriptor => + vaList.skip(ValueLayout.ADDRESS) case sd: StructDescriptor => vaList.skip(descriptorModule19.toGroupLayout(sd)) case cd: CUnionDescriptor => @@ -50,7 +52,7 @@ private class VarArgs19(vaList: VaList) extends VarArgs: Float.box(vaList.nextVarg(ValueLayout.JAVA_DOUBLE).toFloat) case DoubleDescriptor => Double.box(vaList.nextVarg(ValueLayout.JAVA_DOUBLE).toDouble) - case PtrDescriptor => + case PtrDescriptor | VaListDescriptor | _: SetSizeArrayDescriptor => vaList.nextVarg(ValueLayout.ADDRESS).nn case sd: StructDescriptor => LinkageModule19.tempScope(alloc ?=> @@ -62,9 +64,17 @@ private class VarArgs19(vaList: VaList) extends VarArgs: .nn ) case AliasDescriptor(real) => as(real) - case VaListDescriptor => vaList.nextVarg(ValueLayout.ADDRESS).nn - case CUnionDescriptor(possibleTypes) => - as(possibleTypes.maxBy(_.size)) + case cud: CUnionDescriptor => + val desc = + descriptorModule19.toMemoryLayout(cud).asInstanceOf[GroupLayout] + LinkageModule19.tempScope(alloc ?=> + vaList + .nextVarg( + descriptorModule19.toMemoryLayout(cud).asInstanceOf[GroupLayout], + alloc.asInstanceOf[Allocator19].segmentAllocator + ) + .nn + ) override def get[A](using d: DescriptorOf[A]): A = transitionModule19.methodReturn[A](d.descriptor, as(d.descriptor)) diff --git a/j19/src/fr/hammons/slinc/modules/DescriptorModule19.scala b/j19/src/fr/hammons/slinc/modules/DescriptorModule19.scala index 614ccd8..bc8250d 100644 --- a/j19/src/fr/hammons/slinc/modules/DescriptorModule19.scala +++ b/j19/src/fr/hammons/slinc/modules/DescriptorModule19.scala @@ -7,6 +7,7 @@ import java.lang.foreign.MemoryLayout import java.lang.foreign.MemoryAddress import java.lang.foreign.MemorySegment import java.lang.foreign.GroupLayout +import java.lang.foreign.SequenceLayout given descriptorModule19: DescriptorModule with private val sdt = TrieMap.empty[StructDescriptor, GroupLayout] @@ -79,6 +80,11 @@ given descriptorModule19: DescriptorModule with else Seq.empty ) + def toDowncallLayout(td: TypeDescriptor): MemoryLayout = + toMemoryLayout(td) match + case _: SequenceLayout => ValueLayout.ADDRESS.nn + case o => o + def toCarrierType(td: TypeDescriptor): Class[?] = td match case ByteDescriptor => classOf[Byte] case ShortDescriptor => classOf[Short] @@ -86,10 +92,9 @@ given descriptorModule19: DescriptorModule with case LongDescriptor => classOf[Long] case FloatDescriptor => classOf[Float] case DoubleDescriptor => classOf[Double] - case VaListDescriptor => classOf[MemoryAddress] - case PtrDescriptor => classOf[MemoryAddress] - case _: StructDescriptor | _: CUnionDescriptor | - _: SetSizeArrayDescriptor => + case VaListDescriptor | _: SetSizeArrayDescriptor | PtrDescriptor => + classOf[MemoryAddress] + case _: StructDescriptor | _: CUnionDescriptor => classOf[MemorySegment] case ad: AliasDescriptor[?] => toCarrierType(ad.real) diff --git a/j19/src/fr/hammons/slinc/modules/LinkageModule19.scala b/j19/src/fr/hammons/slinc/modules/LinkageModule19.scala index 4c6d335..507c388 100644 --- a/j19/src/fr/hammons/slinc/modules/LinkageModule19.scala +++ b/j19/src/fr/hammons/slinc/modules/LinkageModule19.scala @@ -39,7 +39,7 @@ object LinkageModule19 extends LinkageModule: case Some(value) => JFunctionDescriptor .of( - toMemoryLayout(value), + toDowncallLayout(value), argDescriptors* ) .nn @@ -47,7 +47,7 @@ object LinkageModule19 extends LinkageModule: val fd = fdGen( descriptor.inputDescriptors - .map(toMemoryLayout), + .map(toDowncallLayout), varargs.view .map(_.use[DescriptorOf](dc ?=> _ => dc.descriptor)) .map(toMemoryLayout)