Skip to content

Commit

Permalink
Restrict argument type of concat to be a concrete type for type saf…
Browse files Browse the repository at this point in the history
…ety.
  • Loading branch information
tarao committed Nov 27, 2023
1 parent 67ed1ee commit da9dbf3
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ object ArrayRecord extends ArrayRecord.Extensible[EmptyTuple] {
other: R2,
)(using c: Concat[R, R2]): c.Out =
withPotentialTypingError {
summon[typing.Concrete[R2]]
val vec = record
.__fields
.toVector
Expand Down
22 changes: 22 additions & 0 deletions modules/core/src/main/scala/com/github/tarao/record4s/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.github.tarao.record4s
import scala.annotation.nowarn

import Record.newMapRecord
import typing.Concrete
import typing.Record.{Concat, Lookup, Select, Unselect}

@nowarn("msg=unused local")
Expand Down Expand Up @@ -215,6 +216,27 @@ object Macros {
}
}

def derivedTypingConcreteImple[T: Type](using
Quotes,
): Expr[Concrete[T]] = withInternal {
import quotes.reflect.*
import internal.*

if (TypeRepr.of[T].dealias.typeSymbol.isTypeParam)
errorAndAbort(
Seq(
s"A concrete type expected but type variable ${Type.show[T]} is given.",
"Did you forget to make the method inline?",
).mkString("\n"),
)
else
'{
Concrete
.instance
.asInstanceOf[Concrete[T]]
}
}

private def typeNameOfImpl[T: Type](using Quotes): Expr[String] = {
import quotes.reflect.*

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ object Record {
inline def concat[R2: RecordLike, RR <: %](
other: R2,
)(using Concat.Aux[R, R2, RR]): RR = withPotentialTypingError {
summon[typing.Concrete[R2]]
newMapRecord[RR](
record
.__iterable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ trait MaybeError {
type Msg <: String
}

final class Concrete[T] private {}
object Concrete {
private[record4s] val instance = new Concrete[Nothing]

transparent inline given [T]: Concrete[T] =
${ Macros.derivedTypingConcreteImple }
}

private inline def showTypingError(using err: typing.MaybeError): Unit = {
import scala.compiletime.{constValue, erasedValue, error}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,22 @@ class TypeErrorSpec extends helper.UnitSpec {
typing.Record.Concat.Aux[R, % { val name: String }, RR],
): RR = record ++ %(email = email)
""" shouldNot typeCheck

def checkErrors(errs: List[Error]): Unit = {
errs should not be empty
errs.head.kind shouldBe ErrorKind.Typer
val _ = errs.exists(
_.message.startsWith(
"A concrete type expected",
),
) shouldBe true
}

checkErrors(typeCheckErrors("""
def concat[R1 <: %, R2 <: %, RR <: %](r1: R1, r2: R2)(using
typing.Record.Concat.Aux[R1, R2, RR],
): RR = r1 ++ r2
"""))
}
}

Expand Down Expand Up @@ -149,16 +165,38 @@ class TypeErrorSpec extends helper.UnitSpec {

it("should detect wrong usage") {
"""
def addEmail[R, RR <: %](record: ArrayRecord[R], email: String)(using
def addEmail[R, RR <: ProductRecord](record: ArrayRecord[R], email: String)(using
typing.ArrayRecord.Concat.Aux[R, Nothing, RR],
): RR = record ++ ArrayRecord(email = email)
""" shouldNot typeCheck

"""
def addEmail[R, RR <: %](record: ArrayRecord[R], email: String)(using
def addEmail[R, RR <: ProductRecord](record: ArrayRecord[R], email: String)(using
typing.ArrayRecord.Concat.Aux[R, ArrayRecord[("name", String) *: EmptyTuple], RR],
): RR = record ++ ArrayRecord(email = email)
""" shouldNot typeCheck

"""
def concat[R1, R2, RR <: ProductRecord](r1: ArrayRecord[R1], r2: ArrayRecord[R2])(using
typing.ArrayRecord.Concat.Aux[R1, R2, RR],
): RR = r1 ++ r2
""" shouldNot typeCheck

def checkErrors(errs: List[Error]): Unit = {
errs should not be empty
errs.head.kind shouldBe ErrorKind.Typer
val _ = errs.exists(
_.message.startsWith(
"A concrete type expected",
),
) shouldBe true
}

checkErrors(typeCheckErrors("""
def concat[R1, R2 <: %, RR <: ProductRecord](r1: ArrayRecord[R1], r2: R2)(using
typing.ArrayRecord.Concat.Aux[R1, R2, RR],
): RR = r1 ++ r2
"""))
}
}

Expand Down

0 comments on commit da9dbf3

Please sign in to comment.