Skip to content

Commit

Permalink
ScalafmtRunner: always allow top-level terms
Browse files Browse the repository at this point in the history
The formatter's job is not to validate code but to format it.
  • Loading branch information
kitbellew committed Oct 31, 2022
1 parent e71daa6 commit 8266921
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 20 deletions.
Expand Up @@ -177,7 +177,7 @@ case class ScalafmtConfig(
)

def forSbt: ScalafmtConfig =
copy(runner = runner.forSbt, rewrite = rewrite.forSbt)
copy(rewrite = rewrite.forSbt)

private lazy val expandedFileOverride = Try {
val langPrefix = "lang:"
Expand Down Expand Up @@ -237,7 +237,7 @@ case class ScalafmtConfig(
private[scalafmt] lazy val docstringsWrapMaxColumn: Int =
docstrings.wrapMaxColumn.getOrElse(maxColumn)

@inline private[scalafmt] def dialect = runner.getDialect
@inline private[scalafmt] def dialect = runner.getDialectForParser

private[scalafmt] def getTrailingCommas = rewrite.trailingCommas.style

Expand Down
Expand Up @@ -28,17 +28,17 @@ case class ScalafmtRunner(
fatalWarnings: Boolean = false
) {
@inline private[scalafmt] def getDialect = dialect.dialect
private[scalafmt] lazy val getDialectForParser: Dialect =
getDialect.withAllowToplevelTerms(true).withAllowToplevelStatements(true)
@inline private[scalafmt] def dialectName = {
val name = dialect.name
if (dialectOverride.values.isEmpty) name else s"$name [with overrides]"
}
@inline private[scalafmt] def getParser = parser.parse

@inline private def topLevelDialect = dialect.copy(
dialect = getDialect
.withAllowToplevelTerms(true)
.withToplevelSeparator("")
)
@inline
def withDialect(dialect: sourcecode.Text[Dialect]): ScalafmtRunner =
withDialect(NamedDialect(dialect))

@inline
private[scalafmt] def withDialect(dialect: NamedDialect): ScalafmtRunner =
Expand All @@ -48,11 +48,8 @@ case class ScalafmtRunner(
private[scalafmt] def withParser(parser: ScalafmtParser): ScalafmtRunner =
copy(parser = parser)

def forSbt: ScalafmtRunner = withDialect(topLevelDialect)

private[scalafmt] def forCodeBlock: ScalafmtRunner = copy(
debug = false,
dialect = topLevelDialect,
eventCallback = null,
parser = ScalafmtParser.Source
)
Expand All @@ -64,7 +61,7 @@ case class ScalafmtRunner(
if (null != eventCallback) evts.foreach(eventCallback)

def parse(input: meta.inputs.Input): Parsed[_ <: Tree] =
getParser(input, getDialect)
getParser(input, getDialectForParser)

@inline def isDefaultDialect = dialect.name == NamedDialect.defaultName

Expand All @@ -85,7 +82,7 @@ object ScalafmtRunner {
maxStateVisits = 1000000
)

val sbt = default.forSbt
val sbt = default.withDialect(meta.dialects.Sbt)

implicit val encoder: ConfEncoder[ScalafmtRunner] =
generic.deriveEncoder
Expand Down
14 changes: 7 additions & 7 deletions scalafmt-tests/src/test/scala/org/scalafmt/util/ErrorTest.scala
Expand Up @@ -5,13 +5,13 @@ import org.scalafmt.Scalafmt
import munit.FunSuite

class ErrorTest extends FunSuite {
test("errors are caught") {
val nonSourceFile = Seq(
"class A {",
"val x = 1",
"println(1)"
)
nonSourceFile.foreach { original =>
private val nonSourceFile = Seq(
"class A {",
"val x + 1",
"println 1"
)
nonSourceFile.foreach { original =>
test(s"errors are caught: $original") {
Scalafmt.format(original, HasTests.unitTest40) match {
case _: Formatted.Success => fail("expected failure, got success")
case _ =>
Expand Down
Expand Up @@ -22,7 +22,7 @@ trait FormatAssertions {
): Unit =
assertFormatPreservesAst(filename, original, obtained)(
runner.getParser,
runner.getDialect
runner.getDialectForParser
)

def assertFormatPreservesAst[T <: Tree](
Expand Down

0 comments on commit 8266921

Please sign in to comment.