Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fuse emitting and printing of trees in the backend #4917

Merged
merged 3 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ final class ClosureLinkerBackend(config: LinkerBackendImpl.Config)
.withTrackAllGlobalRefs(true)
.withInternalModulePattern(m => OutputPatternsImpl.moduleName(config.outputPatterns, m.id))

new Emitter(emitterConfig)
new Emitter(emitterConfig, ClosureLinkerBackend.PostTransformer)
}

val symbolRequirements: SymbolRequirement = emitter.symbolRequirements
Expand Down Expand Up @@ -106,7 +106,8 @@ final class ClosureLinkerBackend(config: LinkerBackendImpl.Config)
sjsModule <- moduleSet.modules.headOption
} yield {
val closureChunk = logger.time("Closure: Create trees)") {
buildChunk(emitterResult.body(sjsModule.id))
val (trees, _) = emitterResult.body(sjsModule.id)
buildChunk(trees)
}

logger.time("Closure: Compiler pass") {
Expand Down Expand Up @@ -295,4 +296,11 @@ private object ClosureLinkerBackend {
Function.prototype.apply;
var NaN = 0.0/0.0, Infinity = 1.0/0.0, undefined = void 0;
"""

private object PostTransformer extends Emitter.PostTransformer[js.Tree] {
// Do not apply ClosureAstTransformer eagerly:
// The ASTs used by closure are highly mutable, so re-using them is non-trivial.
// Since closure is slow anyways, we haven't built the optimization.
def transformStats(trees: List[js.Tree], indent: Int): List[js.Tree] = trees
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import scala.concurrent._
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

import java.util.concurrent.atomic.AtomicInteger

import org.scalajs.logging.Logger

import org.scalajs.linker.interface.{IRFile, OutputDirectory, Report}
Expand All @@ -36,12 +38,19 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config)

import BasicLinkerBackend._

private[this] var totalModules = 0
private[this] val rewrittenModules = new AtomicInteger(0)

private[this] val emitter = {
val emitterConfig = Emitter.Config(config.commonConfig.coreSpec)
.withJSHeader(config.jsHeader)
.withInternalModulePattern(m => OutputPatternsImpl.moduleName(config.outputPatterns, m.id))

new Emitter(emitterConfig)
val postTransformer =
if (config.sourceMap) PostTransformerWithSourceMap
else PostTransformerWithoutSourceMap

new Emitter(emitterConfig, postTransformer)
}

val symbolRequirements: SymbolRequirement = emitter.symbolRequirements
Expand All @@ -61,31 +70,35 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config)
implicit ec: ExecutionContext): Future[Report] = {
verifyModuleSet(moduleSet)

// Reset stats.

totalModules = moduleSet.modules.size
rewrittenModules.set(0)

val emitterResult = logger.time("Emitter") {
emitter.emit(moduleSet, logger)
}

val skipContentCheck = !isFirstRun
isFirstRun = false

printedModuleSetCache.startRun(moduleSet)
val allChanged =
printedModuleSetCache.updateGlobal(emitterResult.header, emitterResult.footer)

val writer = new OutputWriter(output, config, skipContentCheck) {
protected def writeModuleWithoutSourceMap(moduleID: ModuleID, force: Boolean): Option[ByteBuffer] = {
val cache = printedModuleSetCache.getModuleCache(moduleID)
val changed = cache.update(emitterResult.body(moduleID))
val (printedTrees, changed) = emitterResult.body(moduleID)

if (force || changed || allChanged) {
printedModuleSetCache.incRewrittenModules()
rewrittenModules.incrementAndGet()

val jsFileWriter = new ByteArrayWriter(sizeHintFor(cache.getPreviousFinalJSFileSize()))

jsFileWriter.write(printedModuleSetCache.headerBytes)
jsFileWriter.writeASCIIString("'use strict';\n")

for (printedTree <- cache.printedTrees)
for (printedTree <- printedTrees)
jsFileWriter.write(printedTree.jsCode)

jsFileWriter.write(printedModuleSetCache.footerBytes)
Expand All @@ -99,10 +112,10 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config)

protected def writeModuleWithSourceMap(moduleID: ModuleID, force: Boolean): Option[(ByteBuffer, ByteBuffer)] = {
val cache = printedModuleSetCache.getModuleCache(moduleID)
val changed = cache.update(emitterResult.body(moduleID))
val (printedTrees, changed) = emitterResult.body(moduleID)

if (force || changed || allChanged) {
printedModuleSetCache.incRewrittenModules()
rewrittenModules.incrementAndGet()

val jsFileWriter = new ByteArrayWriter(sizeHintFor(cache.getPreviousFinalJSFileSize()))
val sourceMapWriter = new ByteArrayWriter(sizeHintFor(cache.getPreviousFinalSourceMapSize()))
Expand All @@ -120,7 +133,7 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config)
jsFileWriter.writeASCIIString("'use strict';\n")
smWriter.nextLine()

for (printedTree <- cache.printedTrees) {
for (printedTree <- printedTrees) {
jsFileWriter.write(printedTree.jsCode)
smWriter.insertFragment(printedTree.sourceMapFragment)
}
Expand All @@ -145,9 +158,15 @@ final class BasicLinkerBackend(config: LinkerBackendImpl.Config)
writer.write(moduleSet)
}.andThen { case _ =>
printedModuleSetCache.cleanAfterRun()
printedModuleSetCache.logStats(logger)
logStats(logger)
}
}

private def logStats(logger: Logger): Unit = {
// Message extracted in BasicLinkerBackendTest
logger.debug(
s"BasicBackend: total modules: $totalModules; re-written: ${rewrittenModules.get()}")
}
}

private object BasicLinkerBackend {
Expand All @@ -161,20 +180,6 @@ private object BasicLinkerBackend {

private val modules = new java.util.concurrent.ConcurrentHashMap[ModuleID, PrintedModuleCache]

private var totalModules = 0
private val rewrittenModules = new java.util.concurrent.atomic.AtomicInteger(0)

private var totalTopLevelTrees = 0
private var recomputedTopLevelTrees = 0

def startRun(moduleSet: ModuleSet): Unit = {
totalModules = moduleSet.modules.size
rewrittenModules.set(0)

totalTopLevelTrees = 0
recomputedTopLevelTrees = 0
}

def updateGlobal(header: String, footer: String): Boolean = {
if (header == lastHeader && footer == lastFooter) {
false
Expand All @@ -193,61 +198,30 @@ private object BasicLinkerBackend {
def headerNewLineCount: Int = _headerNewLineCountCache

def getModuleCache(moduleID: ModuleID): PrintedModuleCache = {
val result = modules.computeIfAbsent(moduleID, { _ =>
if (withSourceMaps) new PrintedModuleCacheWithSourceMaps
else new PrintedModuleCache
})

val result = modules.computeIfAbsent(moduleID, _ => new PrintedModuleCache)
result.startRun()
result
}

def incRewrittenModules(): Unit =
rewrittenModules.incrementAndGet()

def cleanAfterRun(): Unit = {
val iter = modules.entrySet().iterator()
while (iter.hasNext()) {
val moduleCache = iter.next().getValue()
if (moduleCache.cleanAfterRun()) {
totalTopLevelTrees += moduleCache.getTotalTopLevelTrees
recomputedTopLevelTrees += moduleCache.getRecomputedTopLevelTrees
} else {
if (!moduleCache.cleanAfterRun()) {
iter.remove()
}
}
}

def logStats(logger: Logger): Unit = {
/* These messages are extracted in BasicLinkerBackendTest to assert that
* we do not invalidate anything in a no-op second run.
*/
logger.debug(
s"BasicBackend: total top-level trees: $totalTopLevelTrees; re-computed: $recomputedTopLevelTrees")
logger.debug(
s"BasicBackend: total modules: $totalModules; re-written: ${rewrittenModules.get()}")
}
}

private final class PrintedTree(val jsCode: Array[Byte], val sourceMapFragment: SourceMapWriter.Fragment) {
var cachedUsed: Boolean = false
}

private sealed class PrintedModuleCache {
private var cacheUsed = false
private var changed = false
private var lastJSTrees: List[js.Tree] = Nil
private var printedTreesCache: List[PrintedTree] = Nil
private val cache = new java.util.IdentityHashMap[js.Tree, PrintedTree]

private var previousFinalJSFileSize: Int = 0
private var previousFinalSourceMapSize: Int = 0

private var recomputedTopLevelTrees = 0

def startRun(): Unit = {
cacheUsed = true
recomputedTopLevelTrees = 0
}

def getPreviousFinalJSFileSize(): Int = previousFinalJSFileSize
Expand All @@ -259,72 +233,42 @@ private object BasicLinkerBackend {
previousFinalSourceMapSize = finalSourceMapSize
}

def update(newJSTrees: List[js.Tree]): Boolean = {
val changed = !newJSTrees.corresponds(lastJSTrees)(_ eq _)
this.changed = changed
if (changed) {
printedTreesCache = newJSTrees.map(getOrComputePrintedTree(_))
lastJSTrees = newJSTrees
}
changed
}

private def getOrComputePrintedTree(tree: js.Tree): PrintedTree = {
val result = cache.computeIfAbsent(tree, { (tree: js.Tree) =>
recomputedTopLevelTrees += 1
computePrintedTree(tree)
})

result.cachedUsed = true
result
}

protected def computePrintedTree(tree: js.Tree): PrintedTree = {
val jsCodeWriter = new ByteArrayWriter()
val printer = new Printers.JSTreePrinter(jsCodeWriter)

printer.printStat(tree)

new PrintedTree(jsCodeWriter.toByteArray(), SourceMapWriter.Fragment.Empty)
def cleanAfterRun(): Boolean = {
val wasUsed = cacheUsed
cacheUsed = false
wasUsed
}
}

def printedTrees: List[PrintedTree] = printedTreesCache
private object PostTransformerWithoutSourceMap extends Emitter.PostTransformer[js.PrintedTree] {
def transformStats(trees: List[js.Tree], indent: Int): List[js.PrintedTree] = {
if (trees.isEmpty) {
Nil // Fast path
} else {
val jsCodeWriter = new ByteArrayWriter()
val printer = new Printers.JSTreePrinter(jsCodeWriter, indent)

def cleanAfterRun(): Boolean = {
if (cacheUsed) {
cacheUsed = false

if (changed) {
val iter = cache.entrySet().iterator()
while (iter.hasNext()) {
val printedTree = iter.next().getValue()
if (printedTree.cachedUsed)
printedTree.cachedUsed = false
else
iter.remove()
}
}
trees.map(printer.printStat(_))

true
} else {
false
js.PrintedTree(jsCodeWriter.toByteArray(), SourceMapWriter.Fragment.Empty) :: Nil
}
}

def getTotalTopLevelTrees: Int = lastJSTrees.size
def getRecomputedTopLevelTrees: Int = recomputedTopLevelTrees
}

private final class PrintedModuleCacheWithSourceMaps extends PrintedModuleCache {
override protected def computePrintedTree(tree: js.Tree): PrintedTree = {
val jsCodeWriter = new ByteArrayWriter()
val smFragmentBuilder = new SourceMapWriter.FragmentBuilder()
val printer = new Printers.JSTreePrinterWithSourceMap(jsCodeWriter, smFragmentBuilder)
private object PostTransformerWithSourceMap extends Emitter.PostTransformer[js.PrintedTree] {
def transformStats(trees: List[js.Tree], indent: Int): List[js.PrintedTree] = {
if (trees.isEmpty) {
Nil // Fast path
} else {
val jsCodeWriter = new ByteArrayWriter()
val smFragmentBuilder = new SourceMapWriter.FragmentBuilder()
val printer = new Printers.JSTreePrinterWithSourceMap(jsCodeWriter, smFragmentBuilder, indent)

printer.printStat(tree)
smFragmentBuilder.complete()
trees.map(printer.printStat(_))
smFragmentBuilder.complete()

new PrintedTree(jsCodeWriter.toByteArray(), smFragmentBuilder.result())
js.PrintedTree(jsCodeWriter.toByteArray(), smFragmentBuilder.result()) :: Nil
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ private[emitter] final class ClassEmitter(sjsGen: SJSGen) {

def buildClass(className: ClassName, isJSClass: Boolean, jsClassCaptures: Option[List[ParamDef]],
hasClassInitializer: Boolean,
superClass: Option[ClassIdent], storeJSSuperClass: Option[js.Tree], useESClass: Boolean,
superClass: Option[ClassIdent], storeJSSuperClass: List[js.Tree], useESClass: Boolean,
members: List[js.Tree])(
implicit moduleContext: ModuleContext,
globalKnowledge: GlobalKnowledge, pos: Position): WithGlobals[List[js.Tree]] = {
Expand Down Expand Up @@ -75,7 +75,7 @@ private[emitter] final class ClassEmitter(sjsGen: SJSGen) {
val createClassValueVar = genEmptyMutableLet(classValueIdent)

val entireClassDefWithGlobals = if (useESClass) {
genJSSuperCtor(superClass, storeJSSuperClass.isDefined).map { jsSuperClass =>
genJSSuperCtor(superClass, storeJSSuperClass.nonEmpty).map { jsSuperClass =>
List(classValueVar := js.ClassDef(Some(classValueIdent), Some(jsSuperClass), members))
}
} else {
Expand All @@ -86,7 +86,7 @@ private[emitter] final class ClassEmitter(sjsGen: SJSGen) {
entireClassDef <- entireClassDefWithGlobals
createStaticFields <- genCreateStaticFieldsOfJSClass(className)
} yield {
storeJSSuperClass.toList ::: entireClassDef ::: createStaticFields
storeJSSuperClass ::: entireClassDef ::: createStaticFields
}

jsClassCaptures.fold {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ import PolyfillableBuiltin._

private[emitter] object CoreJSLib {

def build(sjsGen: SJSGen, moduleContext: ModuleContext,
globalKnowledge: GlobalKnowledge): WithGlobals[Lib] = {
new CoreJSLibBuilder(sjsGen)(moduleContext, globalKnowledge).build()
def build[E](sjsGen: SJSGen, postTransform: List[Tree] => E, moduleContext: ModuleContext,
globalKnowledge: GlobalKnowledge): WithGlobals[Lib[E]] = {
new CoreJSLibBuilder(sjsGen)(moduleContext, globalKnowledge).build(postTransform)
}

/** A fully built CoreJSLib
Expand All @@ -52,10 +52,10 @@ private[emitter] object CoreJSLib {
* @param initialization Things that depend on Scala.js generated classes.
* These must have class definitions (but not static fields) available.
*/
final class Lib private[CoreJSLib] (
val preObjectDefinitions: List[Tree],
val postObjectDefinitions: List[Tree],
val initialization: List[Tree])
final class Lib[E] private[CoreJSLib] (
val preObjectDefinitions: E,
val postObjectDefinitions: E,
val initialization: E)

private class CoreJSLibBuilder(sjsGen: SJSGen)(
implicit moduleContext: ModuleContext, globalKnowledge: GlobalKnowledge) {
Expand Down Expand Up @@ -115,9 +115,11 @@ private[emitter] object CoreJSLib {
private val specializedArrayTypeRefs: List[NonArrayTypeRef] =
ClassRef(ObjectClass) :: orderedPrimRefsWithoutVoid

def build(): WithGlobals[Lib] = {
val lib = new Lib(buildPreObjectDefinitions(),
buildPostObjectDefinitions(), buildInitializations())
def build[E](postTransform: List[Tree] => E): WithGlobals[Lib[E]] = {
val lib = new Lib(
postTransform(buildPreObjectDefinitions()),
postTransform(buildPostObjectDefinitions()),
postTransform(buildInitializations()))
WithGlobals(lib, trackedGlobalRefs)
}

Expand Down