diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala index f42cd5ad35..9a1d1c4c94 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatOps.scala @@ -121,7 +121,7 @@ class FormatOps( val arguments = mutable.Map.empty[TokenHash, Tree] val optional = Set.newBuilder[TokenHash] def getHeadHash(tree: Tree): Option[TokenHash] = - tree.tokens.headOption.map { x => hash(tokens.after(x).left) } + tokens.getHeadOpt(tree).map(x => hash(x.left)) def add(tree: Tree): Unit = getHeadHash(tree).foreach { x => if (!arguments.contains(x)) arguments += x -> tree @@ -229,14 +229,14 @@ class FormatOps( tree.parent.exists { case InfixApp(ia) => (ia.op eq tree) || ia.rhs.headOption.forall { arg => - (arg eq tree) && arg.tokens.headOption.contains(ft.right) + (arg eq tree) && tokens.tokenJustBeforeOpt(arg).contains(ft) } case _ => false } } final def startsNewBlockOnRight(ft: FormatToken): Boolean = - ft.meta.rightOwner.tokens.headOption.contains(ft.right) + tokens.tokenBeforeOpt(ft.meta.rightOwner).contains(ft) /** js.native is very special in Scala.js. * @@ -741,10 +741,8 @@ class FormatOps( case Some(ia) => Some(findLeftInfix(ia).op) case _ => findNextInfixInParent(app.all, fullInfix.all) } - val endOfNextOp = nextOp.map(_.tokens.last) - val breakAfterClose = endOfNextOp.flatMap { tok => - breakAfterComment(tokens(tok)) - } + val endOfNextOp = nextOp.map(tokens.getLast) + val breakAfterClose = endOfNextOp.flatMap(breakAfterComment) val nlSplit = Split(nlMod, 0) .andPolicyOpt(breakAfterClose) @@ -752,7 +750,7 @@ class FormatOps( .withPolicy(nlPolicy) val singleLineSplit = Split(Space, 0) .notIf(noSingleLine) - .withSingleLine(endOfNextOp.getOrElse(close)) + .withSingleLine(endOfNextOp.fold(close)(_.left)) .andPolicyOpt(breakAfterClose) .andPolicy(getSingleLineInfixPolicy(close)) Seq(singleLineSplit, nlSplit) @@ -1105,7 +1103,9 @@ class FormatOps( if app.args.length >= runner.optimizer.forceConfigStyleMinArgCount && distance(left, matching(left)) > maxDistance => forces += app - app.args.foreach { arg => clearQueues += hash(arg.tokens.head) } + app.args.foreach { arg => + clearQueues += hash(tokens.getHead(arg).left) + } case _ => } (forces.result(), clearQueues.result()) @@ -1352,33 +1352,28 @@ class FormatOps( // look for arrow before body, if any, else after params def getFuncArrow(term: Term.FunctionTerm): Option[FormatToken] = - term.body.tokens.headOption - .map(tokenBefore) - .orElse { - val lastParam = term.params.lastOption - lastParam.flatMap(_.tokens.lastOption).map { x => - val maybeArrow = tokens(nextNonComment(tokens(x)), 1) - if (maybeArrow.left.is[T.RightArrow]) maybeArrow - else tokens(nextNonComment(maybeArrow), 1) - } - } + tokens + .tokenBeforeOpt(term.body) + .orElse(tokens.tokenAfterOpt(term.params).map(getArrowAfter)) .orElse { - val headToken = tokens.after(term.tokens.head) - findFirst(headToken, term.tokens.last)(_.left.is[T.RightArrow]) + findFirst(tokens.getHead(term), term.pos.end)(_.left.is[T.RightArrow]) } // look for arrow before body, if any, else after cond/pat def getCaseArrow(term: Case): FormatToken = - term.body.tokens.headOption.fold { - val endOfPat = tokens.getLast(term.cond.getOrElse(term.pat)) - val maybeArrow = tokens(nextNonComment(endOfPat), 1) - if (maybeArrow.left.is[T.RightArrow]) maybeArrow - else tokens(nextNonComment(maybeArrow), 1) - }(tokenBefore) + tokens.tokenBeforeOpt(term.body).getOrElse { + getArrowAfter(tokens.tokenAfter(term.cond.getOrElse(term.pat))) + } // look for arrow before body, if any, else after cond/pat def getCaseArrow(term: TypeCase): FormatToken = - tokens(nextNonComment(tokens.getLast(term.pat)), 1) + next(tokens.tokenAfter(term.pat)) + + private def getArrowAfter(ft: FormatToken): FormatToken = { + val maybeArrow = next(ft) + if (maybeArrow.left.is[T.RightArrow]) maybeArrow + else next(nextNonComment(maybeArrow)) + } def getApplyArgs( ft: FormatToken, @@ -1510,7 +1505,7 @@ class FormatOps( ok && (thisTree.parent match { case `nextSelect` => style.includeNoParensInSelectChains case Some(Term.Apply(fun, List(_))) - if nextNonComment(tokens.getLast(fun)).right.is[T.LeftBrace] => + if tokens.tokenAfter(fun).right.is[T.LeftBrace] => style.includeCurlyBraceInSelectChains && !nextSelect.contains(lastApply) // exclude short curly case Some(SplitCallIntoParts(`thisTree`, _)) => true @@ -1591,7 +1586,7 @@ class FormatOps( val nextFt = next(nextNonComment(next(openFt))) getOpenNLByArgs(nextFt, argss.tail, penalty, policies) } else { - val endPolicy = args.head.tokens.head match { + val endPolicy = tokens.getHead(args.head).left match { case t: T.LeftBrace => Policy.End.After(t) case t => Policy.End.On(t) } @@ -1647,7 +1642,7 @@ class FormatOps( nlSplitFunc: Int => Split, spaceIndents: Seq[Indent] = Seq.empty )(implicit style: ScalafmtConfig): Seq[Split] = { - val bhead = body.tokens.head + def bheadFT = tokens.getHead(body) val blastFT = tokens.getLastNonTrivial(body) val blast = blastFT.left val expire = nextNonCommentSameLine(blastFT).left @@ -1695,7 +1690,7 @@ class FormatOps( def hasStateColumn = spaceIndents.exists(_.hasStateColumn) val (spaceSplit, nlSplit) = body match { case t: Term.If if ifWithoutElse(t) || hasStateColumn => - val thenBeg = tokens.after(t.thenp.tokens.head) + val thenBeg = tokens.getHead(t.thenp) val thenHasLB = thenBeg.left.is[T.LeftBrace] val end = if (thenHasLB) thenBeg else prevNonComment(prev(thenBeg)) getSplits(getSlbSplit(end.left)) @@ -1708,7 +1703,7 @@ class FormatOps( if (!tokens.hasMatching(blast)) getSlbSplits() else getSplits(getSpaceSplit(1)) case Term.ForYield(_, b) => - nextNonComment(tokens(bhead)).right match { + nextNonComment(bheadFT).right match { // skipping `for` case x @ LeftParenOrBrace() => val exclude = TokenRanges(TokenRange(x, matching(x))) if (b.is[Term.Block]) @@ -1895,7 +1890,7 @@ class FormatOps( def getMatchDot(tree: Term.Match): Option[FormatToken] = if (dialect.allowMatchAsOperator) { - val ft = nextNonComment(tokens.getLast(tree.expr)) + val ft = tokens.tokenAfter(tree.expr) if (ft.right.is[T.Dot]) Some(ft) else None } else None @@ -2306,7 +2301,7 @@ class FormatOps( style: ScalafmtConfig ): Option[OptionalBracesRegion] = { def result(tree: Tree, cases: Seq[Tree]): Option[Seq[Split]] = { - val ok = cases.headOption.exists(_.tokens.head eq nft.right) + val ok = tokens.tokenJustBeforeOpt(cases).contains(nft) if (ok) Some(getSplits(ft, tree, true)) else None } ft.meta.leftOwner match { @@ -2407,7 +2402,7 @@ class FormatOps( fileLine: FileLine, style: ScalafmtConfig ): Option[Seq[Split]] = - if (head.tokens.headOption.contains(nft.right)) Some { + if (tokens.tokenJustBeforeOpt(head).contains(nft)) Some { val forceNL = nlOnly || shouldBreakInOptionalBraces(ft) getSplits(ft, tail.lastOption.getOrElse(head), forceNL) } @@ -2474,10 +2469,10 @@ class FormatOps( isTreeMultiStatBlock(tree) && !tokenBefore(tree).left.is[T.LeftBrace] private def isBlockStart(tree: Term.Block, ft: FormatToken): Boolean = - tree.stats.headOption.exists(_.tokens.headOption.contains(ft.right)) + tokens.tokenJustBeforeOpt(tree.stats).contains(ft) @inline private def treeLast(tree: Tree): Option[T] = - tree.tokens.lastOption.map(tokens(_).left) + tokens.getLastOpt(tree).map(_.left) @inline private def blockLast(tree: Tree): Option[T] = if (isTreeMultiStatBlock(tree)) treeLast(tree) else None @inline private def blockLast(tree: Term.Block): Option[T] = @@ -2551,7 +2546,8 @@ class FormatOps( private object BlockImpl extends Factory { def getBlocks(ft: FormatToken, nft: FormatToken, all: Boolean): Result = { - def ok(stat: Tree): Boolean = stat.tokens.headOption.contains(nft.right) + def ok(stat: Tree): Boolean = + tokens.tokenJustBeforeOpt(stat).contains(nft) val leftOwner = ft.meta.leftOwner findTreeWithParentSimple(nft.meta.rightOwner)(_ eq leftOwner) match { case Some(t: Term.Block) => diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala index 43fc842eee..a390f9ec5d 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatTokens.scala @@ -125,15 +125,18 @@ class FormatTokens(leftTok2tok: Map[TokenOps.TokenHash, Int])( final def prevNonComment(curr: FormatToken): FormatToken = findToken(curr, prev)(!_.left.is[Token.Comment]).fold(identity, identity) + def getHead(tree: Tree): FormatToken = + after(tree.tokens.head) + def getHeadOpt(tree: Tree): Option[FormatToken] = + tree.tokens.headOption.map(after) + def getLast(tree: Tree): FormatToken = apply(TokenOps.findLastVisibleToken(tree.tokens)) - def getLastOpt(tree: Tree): Option[FormatToken] = TokenOps.findLastVisibleTokenOpt(tree.tokens).map(apply) def getLastNonTrivial(tree: Tree): FormatToken = apply(TokenOps.findLastNonTrivialToken(tree.tokens)) - def getLastNonTrivialOpt(tree: Tree): Option[FormatToken] = TokenOps.findLastNonTrivialTokenOpt(tree.tokens).map(apply) @@ -146,6 +149,11 @@ class FormatTokens(leftTok2tok: Map[TokenOps.TokenHash, Int])( @inline def tokenAfter(trees: Seq[Tree]): FormatToken = tokenAfter(trees.last) + def tokenAfterOpt(tree: Tree): Option[FormatToken] = + getLastOpt(tree).map(nextNonComment) + def tokenAfterOpt(trees: Seq[Tree]): Option[FormatToken] = + trees.lastOption.flatMap(tokenAfterOpt) + /* the following methods return the last format token such that * its `left` is before the parameter */ @inline @@ -153,6 +161,11 @@ class FormatTokens(leftTok2tok: Map[TokenOps.TokenHash, Int])( @inline def tokenJustBefore(tree: Tree): FormatToken = justBefore(tree.tokens.head) + def tokenJustBeforeOpt(tree: Tree): Option[FormatToken] = + tree.tokens.headOption.map(justBefore) + def tokenJustBeforeOpt(trees: Seq[Tree]): Option[FormatToken] = + trees.headOption.flatMap(tokenJustBeforeOpt) + /* the following methods return the last format token such that * its `left` is before the parameter and is not a comment */ @inline @@ -162,6 +175,11 @@ class FormatTokens(leftTok2tok: Map[TokenOps.TokenHash, Int])( @inline def tokenBefore(trees: Seq[Tree]): FormatToken = tokenBefore(trees.head) + def tokenBeforeOpt(tree: Tree): Option[FormatToken] = + tokenJustBeforeOpt(tree).map(prevNonComment) + def tokenBeforeOpt(trees: Seq[Tree]): Option[FormatToken] = + trees.headOption.flatMap(tokenBeforeOpt) + @inline def isBreakAfterRight(ft: FormatToken): Boolean = next(ft).hasBreakOrEOF diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala index ac5d7c0b81..7e64d1ded8 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/FormatWriter.scala @@ -573,7 +573,7 @@ class FormatWriter(formatOps: FormatOps) { case _ => 2 } val tiState = - locations(tokens(ti.tokens.head).meta.idx).state.prev + locations(tokens.getHead(ti).meta.idx).state.prev val indent = if (style.align.stripMargin) tiState.column + alignPipeOffset else tiState.indentation + offset @@ -1168,7 +1168,7 @@ class FormatWriter(formatOps: FormatOps) { } private def isEarlierLine(t: Tree)(implicit fl: FormatLocation): Boolean = { - val idx = tokens.after(t.tokens.head).meta.idx + 1 + val idx = tokens.getHead(t).meta.idx + 1 idx <= fl.formatToken.meta.idx && // e.g., leading comments locations(idx).leftLineId != fl.leftLineId } @@ -1221,8 +1221,9 @@ class FormatWriter(formatOps: FormatOps) { val beg = mods.lastOption.fold(tokens.after(ptokens.head)) { m => tokens.next(tokens.tokenAfter(m)) } - val end = b.tokens.headOption - .fold(tokens.before(ptokens.last))(tokens.justBefore) + val end = tokens + .tokenJustBeforeOpt(b) + .getOrElse(tokens.before(ptokens.last)) getLineDiff(locations, beg, end) == 0 } if (keepGoing) getAlignContainerParent(p) else p diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala index 752c217ad7..a64747816a 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Router.scala @@ -618,7 +618,7 @@ class Router(formatOps: FormatOps) { } val defn = isDefnSite(rightOwner) val defRhs = if (defn) defDefBody(rightOwner) else None - val beforeDefRhs = defRhs.flatMap(_.tokens.headOption.map(tokenBefore)) + val beforeDefRhs = defRhs.flatMap(tokens.tokenJustBeforeOpt) def getSplitsBeforeOpenParen( src: Newlines.SourceHints, indentLen: Int @@ -906,8 +906,8 @@ class Router(formatOps: FormatOps) { else getAssignAtSingleArgCallSite(leftOwner).map { assign => val assignToken = assign.rhs match { - case b: Term.Block => b.tokens.head - case _ => assign.tokens.find(_.is[T.Equals]).get + case b: Term.Block => tokens.getHead(b) + case _ => tokens(assign.tokens.find(_.is[T.Equals]).get) } val breakToken = getOptimalTokenFor(assignToken) val newlineAfterAssignDecision = @@ -937,7 +937,7 @@ class Router(formatOps: FormatOps) { val excludeBlocks = if (isBracket) { - val excludeBeg = if (align) tokens(args.last.tokens.head) else tok + val excludeBeg = if (align) tokens.getHead(args.last) else tok insideBlock[T.LeftBracket](excludeBeg, close) } else if ( multipleArgs || @@ -951,7 +951,7 @@ class Router(formatOps: FormatOps) { singleArgument && isExcludedTree(args(0)) } ) - parensTuple(args(0).tokens.last) + parensTuple(tokens.getLast(args(0)).left) else insideBracesBlock(tok, close) def singleLine( @@ -1319,7 +1319,7 @@ class Router(formatOps: FormatOps) { if style.newlines.avoidInResultType => val expire = returnType match { case Type.Refine(_, headStat :: _) => - tokens(headStat.tokens.head, -1).left + tokens.tokenJustBefore(headStat).left case t => getLastNonTrivialToken(t) } Seq(Split(Space, 0).withPolicy(SingleLineBlock(expire, okSLC = true))) diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/State.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/State.scala index 23a9c817a1..355be172c4 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/State.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/internal/State.scala @@ -222,7 +222,7 @@ final case class State( if (ok) prev.getLineStartOwner(isComment) else None } else { def startsWithLeft(tree: meta.Tree): Boolean = - tree.tokens.headOption.contains(ft.left) + tokens.getHeadOpt(tree).contains(ft) val ro = ft.meta.rightOwner val owner = if (startsWithLeft(ro)) Some(ro) diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala index bd7476d0c8..69d81de9b3 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantBraces.scala @@ -222,7 +222,7 @@ class RedundantBraces(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { f.parent.flatMap(okToReplaceFunctionInSingleArgApply).exists(_._2 eq f) private def getOpeningParen(t: Term.Apply): Option[Token.LeftParen] = - ftoks.nextNonComment(ftoks(t.fun.tokens.last)).right match { + ftoks.tokenAfter(t.fun).right match { case lp: Token.LeftParen => Some(lp) case _ => None } diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala index d5c1432e1b..d268f56fef 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RedundantParens.scala @@ -111,7 +111,7 @@ class RedundantParens(ftoks: FormatTokens) extends FormatTokensRewrite.Rule { } private def breaksBeforeOp(ia: InfixApp): Boolean = { - val beforeOp = ftoks(ia.op.tokens.head, -1) + val beforeOp = ftoks.tokenJustBefore(ia.op) ftoks.prevNonCommentSameLine(beforeOp).hasBreak || (ia.lhs match { case InfixApp(lhsApp) if breaksBeforeOpAndNotEnclosed(lhsApp) => true case _ => diff --git a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala index 5368bfca55..5e63a039cf 100644 --- a/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala +++ b/scalafmt-core/shared/src/main/scala/org/scalafmt/rewrite/RemoveScala3OptionalBraces.scala @@ -97,10 +97,8 @@ private class RemoveScala3OptionalBraces(ftoks: FormatTokens) case _: Token.KwIf => true case _: Token.KwThen => true case _: Token.KwElse => - !TreeOps.isTreeMultiStatBlock(t.elsep) || { - val endOfCond = ftoks(t.cond.tokens.last) - ftoks.nextNonComment(endOfCond).right.is[Token.KwThen] - } + !TreeOps.isTreeMultiStatBlock(t.elsep) || + ftoks.tokenAfter(t.cond).right.is[Token.KwThen] case _: Token.RightParen => allowOldSyntax case _ => false }