Skip to content

Commit

Permalink
Merge pull request #287 from smarter/fix/main-class-detection
Browse files Browse the repository at this point in the history
Fix #102: Better main class detection
  • Loading branch information
eed3si9n committed May 28, 2017
2 parents 970e18e + f10c53c commit d253375
Show file tree
Hide file tree
Showing 17 changed files with 140 additions and 18 deletions.
5 changes: 5 additions & 0 deletions internal/compiler-bridge/src/main/scala/xsbt/API.scala
Expand Up @@ -46,8 +46,10 @@ final class API(val global: CallbackGlobal) extends Compat with GlobalHelpers {
extractUsedNames.extractAndReport(unit)

val classApis = traverser.allNonLocalClasses
val mainClasses = traverser.mainClasses

classApis.foreach(callback.api(sourceFile, _))
mainClasses.foreach(callback.mainClass(sourceFile, _))
}
}

Expand All @@ -56,6 +58,9 @@ final class API(val global: CallbackGlobal) extends Compat with GlobalHelpers {
def allNonLocalClasses: Set[ClassLike] = {
extractApi.allExtractedNonLocalClasses
}

def mainClasses: Set[String] = extractApi.mainClasses

def `class`(c: Symbol): Unit = {
extractApi.extractAllClassesOf(c.owner, c)
}
Expand Down
12 changes: 11 additions & 1 deletion internal/compiler-bridge/src/main/scala/xsbt/ExtractAPI.scala
Expand Up @@ -10,7 +10,7 @@ package xsbt
import java.io.File
import java.util.{ Arrays, Comparator }
import scala.tools.nsc.symtab.Flags
import scala.collection.mutable.{ HashMap, HashSet }
import scala.collection.mutable.{ HashMap, HashSet, ListBuffer }
import xsbti.api._

import scala.tools.nsc.Global
Expand Down Expand Up @@ -71,6 +71,7 @@ class ExtractAPI[GlobalType <: Global](
private[this] val emptyStringArray = new Array[String](0)

private[this] val allNonLocalClassesInSrc = new HashSet[xsbti.api.ClassLike]
private[this] val _mainClasses = new HashSet[String]

/**
* Implements a work-around for https://github.com/sbt/sbt/issues/823
Expand Down Expand Up @@ -600,6 +601,11 @@ class ExtractAPI[GlobalType <: Global](
allNonLocalClassesInSrc.toSet
}

def mainClasses: Set[String] = {
forceStructures()
_mainClasses.toSet
}

private def classLike(in: Symbol, c: Symbol): ClassLikeDef =
classLikeCache.getOrElseUpdate((in, c), mkClassLike(in, c))
private def mkClassLike(in: Symbol, c: Symbol): ClassLikeDef = {
Expand Down Expand Up @@ -641,6 +647,10 @@ class ExtractAPI[GlobalType <: Global](

allNonLocalClassesInSrc += classWithMembers

if (sym.isStatic && defType == DefinitionType.Module && definitions.hasJavaMainMethod(sym)) {
_mainClasses += name
}

val classDef = new xsbti.api.ClassLikeDef(
name,
acc,
Expand Down
Expand Up @@ -110,6 +110,20 @@ void generatedNonLocalClass(File source,
*/
void api(File sourceFile, xsbti.api.ClassLike classApi);

/**
* Register a class containing an entry point coming from a given source file.
*
* A class is an entry point if its bytecode contains a method with the
* following signature:
* <pre>
* public static void main(String[] args);
* </pre>
*
* @param sourceFile Source file where <code>className</code> is defined.
* @param className A class containing an entry point.
*/
void mainClass(File sourceFile, String className);

/**
* Register the use of a <code>name</code> from a given source class name.
*
Expand Down Expand Up @@ -158,4 +172,4 @@ void problem(String what,
* phase defined by <code>xsbt-analyzer</code> should be added.
*/
boolean enabled();
}
}
Expand Up @@ -30,5 +30,12 @@ public interface SourceInfo {
* @return The compiler reported problems.
*/
public Problem[] getUnreportedProblems();

/**
* Returns the main classes found in this compilation unit.
*
* @return The full name of the main classes, like "foo.bar.Main"
*/
public String[] getMainClasses();
}

Expand Up @@ -64,6 +64,8 @@ class TestCallback extends AnalysisCallback {
()
}

def mainClass(source: File, className: String): Unit = ()

override def enabled(): Boolean = true

def problem(category: String,
Expand Down
Expand Up @@ -22,14 +22,16 @@ object ClassToAPI {
def apply(c: Seq[Class[_]]): Seq[api.ClassLike] = process(c)._1

// (api, public inherited classes)
def process(classes: Seq[Class[_]]): (Seq[api.ClassLike], Set[(Class[_], Class[_])]) = {
def process(
classes: Seq[Class[_]]): (Seq[api.ClassLike], Seq[String], Set[(Class[_], Class[_])]) = {
val cmap = emptyClassMap
classes.foreach(toDefinitions(cmap)) // force recording of class definitions
cmap.lz.foreach(_.get()) // force thunks to ensure all inherited dependencies are recorded
val classApis = cmap.allNonLocalClasses.toSeq
val mainClasses = cmap.mainClasses.toSeq
val inDeps = cmap.inherited.toSet
cmap.clear()
(classApis, inDeps)
(classApis, mainClasses, inDeps)
}

// Avoiding implicit allocation.
Expand All @@ -55,7 +57,8 @@ object ClassToAPI {
private[sbt] val memo: mutable.Map[String, Seq[api.ClassLikeDef]],
private[sbt] val inherited: mutable.Set[(Class[_], Class[_])],
private[sbt] val lz: mutable.Buffer[xsbti.api.Lazy[_]],
private[sbt] val allNonLocalClasses: mutable.Set[api.ClassLike]
private[sbt] val allNonLocalClasses: mutable.Set[api.ClassLike],
private[sbt] val mainClasses: mutable.Set[String]
) {
def clear(): Unit = {
memo.clear()
Expand All @@ -67,6 +70,7 @@ object ClassToAPI {
new ClassMap(new mutable.HashMap,
new mutable.HashSet,
new mutable.ListBuffer,
new mutable.HashSet,
new mutable.HashSet)

def classCanonicalName(c: Class[_]): String =
Expand Down Expand Up @@ -115,6 +119,17 @@ object ClassToAPI {
val defsEmptyMembers = clsDef :: statDef :: Nil
cmap.memo(name) = defsEmptyMembers
cmap.allNonLocalClasses ++= defs

if (c.getMethods.exists(
meth =>
meth.getName == "main" &&
Modifier.isStatic(meth.getModifiers) &&
meth.getParameterTypes.length == 1 &&
meth.getParameterTypes.head == classOf[Array[String]] &&
meth.getReturnType == java.lang.Void.TYPE)) {
cmap.mainClasses += name
}

defsEmptyMembers
}

Expand Down
Expand Up @@ -80,8 +80,9 @@ class ClassToAPISpecification extends UnitSpec {
def readAPI(callback: AnalysisCallback,
source: File,
classes: Seq[Class[_]]): Set[(String, String)] = {
val (apis, inherits) = ClassToAPI.process(classes)
val (apis, mainClasses, inherits) = ClassToAPI.process(classes)
apis.foreach(callback.api(source, _))
mainClasses.foreach(callback.mainClass(source, _))
inherits.map {
case (from, to) => (from.getName, to.getName)
}
Expand Down
Expand Up @@ -149,6 +149,7 @@ private final class AnalysisCallback(
private[this] val usedNames = new HashMap[String, Set[UsedName]]
private[this] val unreporteds = new HashMap[File, ListBuffer[Problem]]
private[this] val reporteds = new HashMap[File, ListBuffer[Problem]]
private[this] val mainClasses = new HashMap[File, ListBuffer[String]]
private[this] val binaryDeps = new HashMap[File, Set[File]]
// source file to set of generated (class file, binary class name); only non local classes are stored here
private[this] val nonLocalClasses = new HashMap[File, Set[(File, String)]]
Expand Down Expand Up @@ -285,6 +286,11 @@ private final class AnalysisCallback(
}
}

def mainClass(sourceFile: File, className: String): Unit = {
mainClasses.getOrElseUpdate(sourceFile, ListBuffer.empty) += className
()
}

def usedName(className: String, name: String, useScopes: util.EnumSet[UseScope]) =
add(usedNames, className, UsedName(name, useScopes))

Expand Down Expand Up @@ -346,7 +352,9 @@ private final class AnalysisCallback(
val stamp = stampReader.source(src)
val classesInSrc = classNames.getOrElse(src, Set.empty).map(_._1)
val analyzedApis = classesInSrc.map(analyzeClass)
val info = SourceInfos.makeInfo(getOrNil(reporteds, src), getOrNil(unreporteds, src))
val info = SourceInfos.makeInfo(getOrNil(reporteds, src),
getOrNil(unreporteds, src),
getOrNil(mainClasses, src))
val binaries = binaryDeps.getOrElse(src, Nil: Iterable[File])
val localProds = localClasses.getOrElse(src, Nil: Iterable[File]) map { classFile =>
val classFileStamp = stampReader.product(classFile)
Expand Down
Expand Up @@ -26,9 +26,11 @@ object SourceInfos {
def empty: SourceInfos = make(Map.empty)
def make(m: Map[File, SourceInfo]): SourceInfos = new MSourceInfos(m)

val emptyInfo: SourceInfo = makeInfo(Nil, Nil)
def makeInfo(reported: Seq[Problem], unreported: Seq[Problem]): SourceInfo =
new UnderlyingSourceInfo(reported, unreported)
val emptyInfo: SourceInfo = makeInfo(Nil, Nil, Nil)
def makeInfo(reported: Seq[Problem],
unreported: Seq[Problem],
mainClasses: Seq[String]): SourceInfo =
new UnderlyingSourceInfo(reported, unreported, mainClasses)
def merge(infos: Traversable[SourceInfos]): SourceInfos = (SourceInfos.empty /: infos)(_ ++ _)
}

Expand All @@ -48,8 +50,10 @@ private final class MSourceInfos(val allInfos: Map[File, SourceInfo]) extends So
}

private final class UnderlyingSourceInfo(val reportedProblems: Seq[Problem],
val unreportedProblems: Seq[Problem])
val unreportedProblems: Seq[Problem],
val mainClasses: Seq[String])
extends SourceInfo {
override def getReportedProblems: Array[Problem] = reportedProblems.toArray
override def getUnreportedProblems: Array[Problem] = unreportedProblems.toArray
override def getMainClasses: Array[String] = mainClasses.toArray
}
Expand Up @@ -60,9 +60,9 @@ class TextAnalysisFormat(override val mappers: AnalysisMappers)
private implicit val analyzedClassFormat: Format[AnalyzedClass] =
AnalyzedClassFormats.analyzedClassFormat
private implicit def infoFormat: Format[SourceInfo] =
wrap[SourceInfo, (Seq[Problem], Seq[Problem])](
si => (si.getReportedProblems, si.getUnreportedProblems), {
case (a, b) => SourceInfos.makeInfo(a, b)
wrap[SourceInfo, (Seq[Problem], Seq[Problem], Seq[String])](
si => (si.getReportedProblems, si.getUnreportedProblems, si.getMainClasses), {
case (a, b, c) => SourceInfos.makeInfo(a, b, c)
})
private implicit def fileHashFormat: Format[FileHash] =
asProduct2((file: File, hash: Int) => new FileHash(file, hash))(h => (h.file, h.hash))
Expand Down
Expand Up @@ -71,7 +71,7 @@ trait BaseTextAnalysisFormatTest { self: Properties =>
val aClass = genClass("A").sample.get
val cClass = genClass("C").sample.get
val absent = EmptyStamp
val sourceInfos = SourceInfos.makeInfo(Nil, Nil)
val sourceInfos = SourceInfos.makeInfo(Nil, Nil, Nil)

var analysis = Analysis.empty
val products = NonLocalProduct("A", "A", f("A.class"), absent) ::
Expand Down Expand Up @@ -106,10 +106,14 @@ trait BaseTextAnalysisFormatTest { self: Properties =>
("Whole Analysis" |: left =? right)
}

private def mapInfos(a: SourceInfos): Map[File, (Seq[Problem], Seq[Problem])] =
private def mapInfos(a: SourceInfos): Map[File, (Seq[Problem], Seq[Problem], Seq[String])] =
a.allInfos.map {
case (f, infos) =>
f -> (infos.getReportedProblems.toList -> infos.getUnreportedProblems.toList)
f -> ((
infos.getReportedProblems.toList,
infos.getUnreportedProblems.toList,
infos.getMainClasses.toList
))
}

private def compareOutputs(left: Output, right: Output): Prop = {
Expand Down
Expand Up @@ -160,6 +160,12 @@ final class IncHandler(directory: File, scriptedLog: ManagedLogger)
p.checkClasses(i, srcFile, products)
case (p, other, _) => p.unrecognizedArguments("checkClasses", other)
},
"checkMainClasses" -> {
case (p, src :: products, i) =>
val srcFile = if (src endsWith ":") src dropRight 1 else src
p.checkMainClasses(i, srcFile, products)
case (p, other, _) => p.unrecognizedArguments("checkMainClasses", other)
},
"checkProducts" -> {
case (p, src :: products, i) =>
val srcFile = if (src endsWith ":") src dropRight 1 else src
Expand Down Expand Up @@ -336,6 +342,17 @@ case class ProjectStructure(
()
}

def checkMainClasses(i: IncInstance, src: String, expected: List[String]): Unit = {
val analysis = compile(i)
def mainClasses(src: String): Set[String] =
analysis.infos.get(baseDirectory / src).getMainClasses.toSet
def assertClasses(expected: Set[String], actual: Set[String]) =
assert(expected == actual, s"Expected $expected classes, got $actual")

assertClasses(expected.toSet, mainClasses(src))
()
}

def checkProducts(i: IncInstance, src: String, expected: List[String]): Unit = {
val analysis = compile(i)
def relativeClassDir(f: File): File = f.relativeTo(classesDir) getOrElse f
Expand Down
Expand Up @@ -139,8 +139,9 @@ final class AnalyzingJavaCompiler private[sbt] (

/** Read the API information from [[Class]] to analyze dependencies. */
def readAPI(source: File, classes: Seq[Class[_]]): Set[(String, String)] = {
val (apis, inherits) = ClassToAPI.process(classes)
val (apis, mainClasses, inherits) = ClassToAPI.process(classes)
apis.foreach(callback.api(source, _))
mainClasses.foreach(callback.mainClass(source, _))
inherits.map {
case (from, to) => (from.getName, to.getName)
}
Expand Down
@@ -0,0 +1,11 @@
package runjava;

class MainJava {
public static void main(String args[]) {
}

static public class StaticInner {
public static void main(String args[]) {
}
}
}
@@ -0,0 +1,6 @@
package runjava;

class NoMainJava {
public void main(String args[]) {
}
}
@@ -0,0 +1,13 @@
package runscala

object MainScala {
def main(args: Array[String]) {}

object StaticInner {
def main(args: Array[String]) {}
}
}

class NoMainScala {
def main(args: Array[String]) {}
}
4 changes: 4 additions & 0 deletions zinc/src/sbt-test/apiinfo/main-discovery/test
@@ -0,0 +1,4 @@
> compile
> checkMainClasses src/main/java/runjava/MainJava.java: runjava.MainJava runjava.MainJava.StaticInner
> checkMainClasses src/main/java/runjava/oMainJava.java:
> checkMainClasses src/main/scala/Hello.scala: runscala.MainScala runscala.MainScala.StaticInner

0 comments on commit d253375

Please sign in to comment.