Skip to content

Commit

Permalink
Encode main class name to match outputs of the compiler (#2955)
Browse files Browse the repository at this point in the history
* Reproduce issue #2790

* Encode main class to match the outputs of the compiler

* Use scala.reflect.NameTransformer directly instead of reimplementing its logic

* Include package name encoding with exception of dots

* Make sure to create mirror class even if module is not a candidate for forwarders (match JVM codegen in scala compiler)

* Fix failing tools test
  • Loading branch information
WojciechMazur committed Oct 30, 2022
1 parent 76bec79 commit 02f38ff
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 72 deletions.
Expand Up @@ -88,7 +88,7 @@ abstract class NirGenPhase[G <: Global with Singleton](override val global: G)
(path, reflectiveInstBuf.toSeq)
}.toMap

val allRegularDefns = if (generatedStaticForwarderClasses.isEmpty) {
val allRegularDefns = if (generatedMirrorClasses.isEmpty) {
/* Fast path, applicable under -Xno-forwarders, as well as when all
* the `object`s of a compilation unit have a companion class.
*/
Expand Down Expand Up @@ -121,9 +121,9 @@ abstract class NirGenPhase[G <: Global with Singleton](override val global: G)
}.toSet

val staticForwarderDefns: List[nir.Defn] =
generatedStaticForwarderClasses
generatedMirrorClasses
.collect {
case (site, StaticForwarderClass(classDef, forwarders)) =>
case (site, MirrorClass(classDef, forwarders)) =>
val name = caseInsensitiveNameOf(classDef)
if (!generatedCaseInsensitiveNames.contains(name)) {
classDef +: forwarders
Expand Down Expand Up @@ -164,7 +164,7 @@ abstract class NirGenPhase[G <: Global with Singleton](override val global: G)
.parallel()
.forEach(generateIRFile)
} finally {
generatedStaticForwarderClasses.clear()
generatedMirrorClasses.clear()
}
}

Expand Down
Expand Up @@ -21,10 +21,9 @@ trait NirGenStat[G <: nsc.Global with Singleton] { self: NirGenPhase[G] =>
val reflectiveInstantiationInfo =
mutable.UnrolledBuffer.empty[ReflectiveInstantiationBuffer]

protected val generatedStaticForwarderClasses =
mutable.Map.empty[Symbol, StaticForwarderClass]
protected val generatedMirrorClasses = mutable.Map.empty[Symbol, MirrorClass]

protected case class StaticForwarderClass(
protected case class MirrorClass(
defn: nir.Defn.Class,
forwarders: Seq[nir.Defn.Define]
)
Expand Down Expand Up @@ -124,6 +123,7 @@ trait NirGenStat[G <: nsc.Global with Singleton] { self: NirGenPhase[G] =>
genReflectiveInstantiation(cd)
genClassFields(cd)
genMethods(cd)
genMirrorClass(cd)

buf += {
if (sym.isScalaModule) {
Expand Down Expand Up @@ -557,7 +557,6 @@ trait NirGenStat[G <: nsc.Global with Singleton] { self: NirGenPhase[G] =>
val methods = cd.impl.body.flatMap {
case dd: DefDef => genMethod(dd)
case _ => Nil

}
val forwarders = genStaticMethodForwarders(cd, methods)
buf ++= methods
Expand Down Expand Up @@ -1074,23 +1073,34 @@ trait NirGenStat[G <: nsc.Global with Singleton] { self: NirGenPhase[G] =>
): Seq[Defn] = {
val sym = td.symbol
if (!isCandidateForForwarders(sym)) Nil
else if (sym.isModuleClass) {
if (!sym.linkedClassOfClass.exists) {
val forwarders = genStaticForwardersFromModuleClass(Nil, sym)
if (forwarders.nonEmpty) {
val classDefn = Defn.Class(
attrs = Attrs.None,
name = Global.Top(genTypeName(sym).id.stripSuffix("$")),
parent = Some(Rt.Object.name),
traits = Nil
)(td.pos)
val forwarderClass = StaticForwarderClass(classDefn, forwarders)
generatedStaticForwarderClasses += sym -> forwarderClass
}
}
Nil
} else {
genStaticForwardersForClassOrInterface(existingMethods, sym)
else if (sym.isModuleClass) Nil
else genStaticForwardersForClassOrInterface(existingMethods, sym)
}

/** Create a mirror class for top level module that has no defined companion
* class. A mirror class is a class containing only static methods that
* forward to the corresponding method on the MODULE instance of the given
* Scala object. It will only be generated if there is no companion class: if
* there is, an attempt will instead be made to add the forwarder methods to
* the companion class.
*/
private def genMirrorClass(cd: ClassDef) = {
val sym = cd.symbol
// phase travel to pickler required for isNestedClass (looks at owner)
val isTopLevelModuleClass = exitingPickler {
sym.isModuleClass && !sym.isNestedClass
}
if (isTopLevelModuleClass && sym.companionClass == NoSymbol) {
val classDefn = Defn.Class(
attrs = Attrs.None,
name = Global.Top(genTypeName(sym).id.stripSuffix("$")),
parent = Some(Rt.Object.name),
traits = Nil
)(cd.pos)
generatedMirrorClasses += sym -> MirrorClass(
classDefn,
genStaticForwardersFromModuleClass(Nil, sym)
)
}
}

Expand Down
Expand Up @@ -57,7 +57,7 @@ class NirCodeGen(val settings: GenNIR.Settings)(using ctx: Context)
genCompilationUnit(ctx.compilationUnit)
} finally {
generatedDefns.clear()
generatedStaticForwarderClasses.clear()
generatedMirrorClasses.clear()
reflectiveInstantiationBuffers.clear()
}
}
Expand All @@ -84,7 +84,7 @@ class NirCodeGen(val settings: GenNIR.Settings)(using ctx: Context)
.groupMapReduce(buf => getFileFor(cunit, buf.name.top))(_.toSeq)(_ ++ _)
.foreach(genIRFile(_, _))

if (generatedStaticForwarderClasses.nonEmpty) {
if (generatedMirrorClasses.nonEmpty) {
// Ported from Scala.js
/* #4148 Add generated static forwarder classes, except those that
* would collide with regular classes on case insensitive file systems.
Expand All @@ -108,8 +108,8 @@ class NirCodeGen(val settings: GenNIR.Settings)(using ctx: Context)
case cls: Defn.Class => caseInsensitiveNameOf(cls)
}.toSet

for ((site, staticCls) <- generatedStaticForwarderClasses) {
val StaticForwarderClass(classDef, forwarders) = staticCls
for ((site, staticCls) <- generatedMirrorClasses) {
val MirrorClass(classDef, forwarders) = staticCls
val caseInsensitiveName = caseInsensitiveNameOf(classDef)
if (!generatedCaseInsensitiveNames.contains(caseInsensitiveName)) {
val file = getFileFor(cunit, classDef.name)
Expand Down
Expand Up @@ -26,10 +26,10 @@ trait NirGenStat(using Context) {
import positionsConversions.fromSpan

protected val generatedDefns = mutable.UnrolledBuffer.empty[nir.Defn]
protected val generatedStaticForwarderClasses =
mutable.Map.empty[Symbol, StaticForwarderClass]
protected val generatedMirrorClasses =
mutable.Map.empty[Symbol, MirrorClass]

protected case class StaticForwarderClass(
protected case class MirrorClass(
defn: nir.Defn.Class,
forwarders: Seq[nir.Defn.Define]
)
Expand Down Expand Up @@ -66,6 +66,7 @@ trait NirGenStat(using Context) {
genClassFields(td)
genMethods(td)
genReflectiveInstantiation(td)
genMirrorClass(td)
}

private def genClassAttrs(td: TypeDef): nir.Attrs = {
Expand Down Expand Up @@ -583,23 +584,35 @@ trait NirGenStat(using Context) {
): Seq[Defn] = {
val sym = td.symbol
if !isCandidateForForwarders(sym) then Nil
else if sym.isStaticModule then {
if !sym.linkedClass.exists then {
val forwarders = genStaticForwardersFromModuleClass(Nil, sym)
if (forwarders.nonEmpty) {
given pos: nir.Position = td.span
val classDefn = Defn.Class(
attrs = Attrs.None,
name = Global.Top(genTypeName(sym).id.stripSuffix("$")),
parent = Some(Rt.Object.name),
traits = Nil
)
val forwarderClass = StaticForwarderClass(classDefn, forwarders)
generatedStaticForwarderClasses += sym -> forwarderClass
}
}
Nil
} else genStaticForwardersForClassOrInterface(existingMethods, sym)
else if sym.isStaticModule then Nil
else genStaticForwardersForClassOrInterface(existingMethods, sym)
}

/** Create a mirror class for top level module that has no defined companion
* class. A mirror class is a class containing only static methods that
* forward to the corresponding method on the MODULE instance of the given
* Scala object. It will only be generated if there is no companion class: if
* there is, an attempt will instead be made to add the forwarder methods to
* the companion class.
*/
private def genMirrorClass(td: TypeDef): Unit = {
given pos: nir.Position = td.span
val sym = td.symbol
val isTopLevelModuleClass = sym.is(ModuleClass) &&
atPhase(flattenPhase) {
toDenot(sym).owner.is(PackageClass)
}
if isTopLevelModuleClass && sym.companionClass == NoSymbol then {
val classDefn = Defn.Class(
attrs = Attrs.None,
name = Global.Top(genTypeName(sym).id.stripSuffix("$")),
parent = Some(Rt.Object.name),
traits = Nil
)
generatedMirrorClasses += sym -> MirrorClass(
classDefn,
genStaticForwardersFromModuleClass(Nil, sym)
)
}
}
}
Expand Up @@ -17,8 +17,7 @@ private[scalanative] object ScalaNative {
/** Compute all globals that must be reachable based on given configuration.
*/
def entries(config: Config): Seq[Global] = {
val mainClass = Global.Top(config.mainClass)
val entry = mainClass.member(Rt.ScalaMainSig)
val entry = encodedMainClass(config).member(Rt.ScalaMainSig)
entry +: CodeGen.depends
}

Expand Down Expand Up @@ -176,4 +175,11 @@ private[scalanative] object ScalaNative {
}
}
}

private[scalanative] def encodedMainClass(config: Config): Global.Top = {
import scala.reflect.NameTransformer.encode
val encoded = config.mainClass.split('.').map(encode).mkString(".")
Global.Top(encoded)
}

}
4 changes: 2 additions & 2 deletions tools/src/main/scala/scala/scalanative/codegen/CodeGen.scala
Expand Up @@ -5,7 +5,7 @@ import java.io.File
import java.nio.file.{Path, Paths}
import scala.collection.mutable
import scala.scalanative.build.{Config, IncCompilationContext}
import scala.scalanative.build.core.ScalaNative.dumpDefns
import scala.scalanative.build.core.ScalaNative.{dumpDefns, encodedMainClass}
import scala.scalanative.io.VirtualDirectory
import scala.scalanative.nir._
import scala.scalanative.util.{Scope, partitionBy, procs}
Expand All @@ -23,7 +23,7 @@ object CodeGen {
implicit val meta: Metadata =
new Metadata(linked, proxies, config.compilerConfig.is32BitPlatform)

val generated = Generate(Global.Top(config.mainClass), defns ++ proxies)
val generated = Generate(encodedMainClass(config), defns ++ proxies)
val embedded = ResourceEmbedder(config)
val lowered = lower(generated ++ embedded)
dumpDefns(config, "lowered", lowered)
Expand Down
16 changes: 2 additions & 14 deletions tools/src/test/scala/scala/scalanative/NIRCompilerTest.scala
Expand Up @@ -33,20 +33,8 @@ class NIRCompilerTest extends AnyFlatSpec with Matchers with Inspectors {
val nirFiles =
compiler.compile(sourcesDir) filter (Files
.isRegularFile(_)) map (_.getFileName.toString)
val expectedNames =
Seq(
"A.class",
"A.nir",
"B.class",
"B.nir",
"C.class",
"C.nir",
"D.class",
"D.nir",
"E$.class",
"E$.nir",
"E.class"
)
val expectedNames = Seq("A", "B", "C", "D", "E", "E$")
.flatMap(name => Seq(s"$name.class", s"$name.nir"))
nirFiles should contain theSameElementsAs expectedNames
}
}
Expand Down
30 changes: 25 additions & 5 deletions tools/src/test/scala/scala/scalanative/linker/IssuesSpec.scala
Expand Up @@ -6,20 +6,40 @@ import scala.scalanative.LinkerSpec
import org.scalatest.matchers.should._

class IssuesSpec extends LinkerSpec with Matchers {
private val mainClass = "Test$"
private val mainClass = "Test"
private val sourceFile = "Test.scala"

private def testLinked(source: String)(fn: Result => Unit): Unit =
link("Test", sources = Map("Test.scala" -> source)) {
private def testLinked(source: String, mainClass: String = mainClass)(
fn: Result => Unit
): Unit =
link(mainClass, sources = Map("Test.scala" -> source)) {
case (_, result) => fn(result)
}

private def checkNoLinkageErrors(source: String) =
testLinked(source.stripMargin) { result =>
private def checkNoLinkageErrors(
source: String,
mainClass: String = mainClass
) =
testLinked(source.stripMargin, mainClass) { result =>
val erros = Check(result)
erros shouldBe empty
}

"Issue #2790" should "link main classes using encoded characters" in {
// All encoded character and an example of unciode encode character ';'
val packageName = "foo.`b~a-r`.`b;a;z`"
val mainClass = raw"Test-native~=<>!#%^&|*/+-:'?@;sc"
val fqcn = s"$packageName.$mainClass".replace("`", "")
checkNoLinkageErrors(
mainClass = fqcn,
source = s"""package $packageName
|object `$mainClass`{
| def main(args: Array[String]) = ()
|}
|""".stripMargin
)
}

"Issue #2880" should "handle lambas correctly" in checkNoLinkageErrors {
"""
|object Test {
Expand Down

0 comments on commit 02f38ff

Please sign in to comment.