Skip to content

Commit

Permalink
FormatTokens: use to access tree head/last tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
kitbellew committed Dec 30, 2021
1 parent 331c78a commit f96e495
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class FormatOps(
val arguments = mutable.Map.empty[TokenHash, Tree]
val optional = Set.newBuilder[TokenHash]
def getHeadHash(tree: Tree): Option[TokenHash] =
tree.tokens.headOption.map { x => hash(tokens.after(x).left) }
tokens.getHeadOpt(tree).map(x => hash(x.left))
def add(tree: Tree): Unit =
getHeadHash(tree).foreach { x =>
if (!arguments.contains(x)) arguments += x -> tree
Expand Down Expand Up @@ -229,14 +229,14 @@ class FormatOps(
tree.parent.exists {
case InfixApp(ia) =>
(ia.op eq tree) || ia.rhs.headOption.forall { arg =>
(arg eq tree) && arg.tokens.headOption.contains(ft.right)
(arg eq tree) && tokens.tokenJustBeforeOpt(arg).contains(ft)
}
case _ => false
}
}

final def startsNewBlockOnRight(ft: FormatToken): Boolean =
ft.meta.rightOwner.tokens.headOption.contains(ft.right)
tokens.tokenBeforeOpt(ft.meta.rightOwner).contains(ft)

/** js.native is very special in Scala.js.
*
Expand Down Expand Up @@ -741,18 +741,16 @@ class FormatOps(
case Some(ia) => Some(findLeftInfix(ia).op)
case _ => findNextInfixInParent(app.all, fullInfix.all)
}
val endOfNextOp = nextOp.map(_.tokens.last)
val breakAfterClose = endOfNextOp.flatMap { tok =>
breakAfterComment(tokens(tok))
}
val endOfNextOp = nextOp.map(tokens.getLast)
val breakAfterClose = endOfNextOp.flatMap(breakAfterComment)

val nlSplit = Split(nlMod, 0)
.andPolicyOpt(breakAfterClose)
.withIndent(nlIndent)
.withPolicy(nlPolicy)
val singleLineSplit = Split(Space, 0)
.notIf(noSingleLine)
.withSingleLine(endOfNextOp.getOrElse(close))
.withSingleLine(endOfNextOp.fold(close)(_.left))
.andPolicyOpt(breakAfterClose)
.andPolicy(getSingleLineInfixPolicy(close))
Seq(singleLineSplit, nlSplit)
Expand Down Expand Up @@ -1105,7 +1103,9 @@ class FormatOps(
if app.args.length >= runner.optimizer.forceConfigStyleMinArgCount &&
distance(left, matching(left)) > maxDistance =>
forces += app
app.args.foreach { arg => clearQueues += hash(arg.tokens.head) }
app.args.foreach { arg =>
clearQueues += hash(tokens.getHead(arg).left)
}
case _ =>
}
(forces.result(), clearQueues.result())
Expand Down Expand Up @@ -1352,33 +1352,28 @@ class FormatOps(

// look for arrow before body, if any, else after params
def getFuncArrow(term: Term.FunctionTerm): Option[FormatToken] =
term.body.tokens.headOption
.map(tokenBefore)
.orElse {
val lastParam = term.params.lastOption
lastParam.flatMap(_.tokens.lastOption).map { x =>
val maybeArrow = tokens(nextNonComment(tokens(x)), 1)
if (maybeArrow.left.is[T.RightArrow]) maybeArrow
else tokens(nextNonComment(maybeArrow), 1)
}
}
tokens
.tokenBeforeOpt(term.body)
.orElse(tokens.tokenAfterOpt(term.params).map(getArrowAfter))
.orElse {
val headToken = tokens.after(term.tokens.head)
findFirst(headToken, term.tokens.last)(_.left.is[T.RightArrow])
findFirst(tokens.getHead(term), term.pos.end)(_.left.is[T.RightArrow])
}

// look for arrow before body, if any, else after cond/pat
def getCaseArrow(term: Case): FormatToken =
term.body.tokens.headOption.fold {
val endOfPat = tokens.getLast(term.cond.getOrElse(term.pat))
val maybeArrow = tokens(nextNonComment(endOfPat), 1)
if (maybeArrow.left.is[T.RightArrow]) maybeArrow
else tokens(nextNonComment(maybeArrow), 1)
}(tokenBefore)
tokens.tokenBeforeOpt(term.body).getOrElse {
getArrowAfter(tokens.tokenAfter(term.cond.getOrElse(term.pat)))
}

// look for arrow before body, if any, else after cond/pat
def getCaseArrow(term: TypeCase): FormatToken =
tokens(nextNonComment(tokens.getLast(term.pat)), 1)
next(tokens.tokenAfter(term.pat))

private def getArrowAfter(ft: FormatToken): FormatToken = {
val maybeArrow = next(ft)
if (maybeArrow.left.is[T.RightArrow]) maybeArrow
else next(nextNonComment(maybeArrow))
}

def getApplyArgs(
ft: FormatToken,
Expand Down Expand Up @@ -1510,7 +1505,7 @@ class FormatOps(
ok && (thisTree.parent match {
case `nextSelect` => style.includeNoParensInSelectChains
case Some(Term.Apply(fun, List(_)))
if nextNonComment(tokens.getLast(fun)).right.is[T.LeftBrace] =>
if tokens.tokenAfter(fun).right.is[T.LeftBrace] =>
style.includeCurlyBraceInSelectChains &&
!nextSelect.contains(lastApply) // exclude short curly
case Some(SplitCallIntoParts(`thisTree`, _)) => true
Expand Down Expand Up @@ -1591,7 +1586,7 @@ class FormatOps(
val nextFt = next(nextNonComment(next(openFt)))
getOpenNLByArgs(nextFt, argss.tail, penalty, policies)
} else {
val endPolicy = args.head.tokens.head match {
val endPolicy = tokens.getHead(args.head).left match {
case t: T.LeftBrace => Policy.End.After(t)
case t => Policy.End.On(t)
}
Expand Down Expand Up @@ -1647,7 +1642,7 @@ class FormatOps(
nlSplitFunc: Int => Split,
spaceIndents: Seq[Indent] = Seq.empty
)(implicit style: ScalafmtConfig): Seq[Split] = {
val bhead = body.tokens.head
def bheadFT = tokens.getHead(body)
val blastFT = tokens.getLastNonTrivial(body)
val blast = blastFT.left
val expire = nextNonCommentSameLine(blastFT).left
Expand Down Expand Up @@ -1695,7 +1690,7 @@ class FormatOps(
def hasStateColumn = spaceIndents.exists(_.hasStateColumn)
val (spaceSplit, nlSplit) = body match {
case t: Term.If if ifWithoutElse(t) || hasStateColumn =>
val thenBeg = tokens.after(t.thenp.tokens.head)
val thenBeg = tokens.getHead(t.thenp)
val thenHasLB = thenBeg.left.is[T.LeftBrace]
val end = if (thenHasLB) thenBeg else prevNonComment(prev(thenBeg))
getSplits(getSlbSplit(end.left))
Expand All @@ -1708,7 +1703,7 @@ class FormatOps(
if (!tokens.hasMatching(blast)) getSlbSplits()
else getSplits(getSpaceSplit(1))
case Term.ForYield(_, b) =>
nextNonComment(tokens(bhead)).right match {
nextNonComment(bheadFT).right match { // skipping `for`
case x @ LeftParenOrBrace() =>
val exclude = TokenRanges(TokenRange(x, matching(x)))
if (b.is[Term.Block])
Expand Down Expand Up @@ -1895,7 +1890,7 @@ class FormatOps(

def getMatchDot(tree: Term.Match): Option[FormatToken] =
if (dialect.allowMatchAsOperator) {
val ft = nextNonComment(tokens.getLast(tree.expr))
val ft = tokens.tokenAfter(tree.expr)
if (ft.right.is[T.Dot]) Some(ft) else None
} else None

Expand Down Expand Up @@ -2306,7 +2301,7 @@ class FormatOps(
style: ScalafmtConfig
): Option[OptionalBracesRegion] = {
def result(tree: Tree, cases: Seq[Tree]): Option[Seq[Split]] = {
val ok = cases.headOption.exists(_.tokens.head eq nft.right)
val ok = tokens.tokenJustBeforeOpt(cases).contains(nft)
if (ok) Some(getSplits(ft, tree, true)) else None
}
ft.meta.leftOwner match {
Expand Down Expand Up @@ -2407,7 +2402,7 @@ class FormatOps(
fileLine: FileLine,
style: ScalafmtConfig
): Option[Seq[Split]] =
if (head.tokens.headOption.contains(nft.right)) Some {
if (tokens.tokenJustBeforeOpt(head).contains(nft)) Some {
val forceNL = nlOnly || shouldBreakInOptionalBraces(ft)
getSplits(ft, tail.lastOption.getOrElse(head), forceNL)
}
Expand Down Expand Up @@ -2474,10 +2469,10 @@ class FormatOps(
isTreeMultiStatBlock(tree) && !tokenBefore(tree).left.is[T.LeftBrace]

private def isBlockStart(tree: Term.Block, ft: FormatToken): Boolean =
tree.stats.headOption.exists(_.tokens.headOption.contains(ft.right))
tokens.tokenJustBeforeOpt(tree.stats).contains(ft)

@inline private def treeLast(tree: Tree): Option[T] =
tree.tokens.lastOption.map(tokens(_).left)
tokens.getLastOpt(tree).map(_.left)
@inline private def blockLast(tree: Tree): Option[T] =
if (isTreeMultiStatBlock(tree)) treeLast(tree) else None
@inline private def blockLast(tree: Term.Block): Option[T] =
Expand Down Expand Up @@ -2551,7 +2546,8 @@ class FormatOps(

private object BlockImpl extends Factory {
def getBlocks(ft: FormatToken, nft: FormatToken, all: Boolean): Result = {
def ok(stat: Tree): Boolean = stat.tokens.headOption.contains(nft.right)
def ok(stat: Tree): Boolean =
tokens.tokenJustBeforeOpt(stat).contains(nft)
val leftOwner = ft.meta.leftOwner
findTreeWithParentSimple(nft.meta.rightOwner)(_ eq leftOwner) match {
case Some(t: Term.Block) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,18 @@ class FormatTokens(leftTok2tok: Map[TokenOps.TokenHash, Int])(
final def prevNonComment(curr: FormatToken): FormatToken =
findToken(curr, prev)(!_.left.is[Token.Comment]).fold(identity, identity)

def getHead(tree: Tree): FormatToken =
after(tree.tokens.head)
def getHeadOpt(tree: Tree): Option[FormatToken] =
tree.tokens.headOption.map(after)

def getLast(tree: Tree): FormatToken =
apply(TokenOps.findLastVisibleToken(tree.tokens))

def getLastOpt(tree: Tree): Option[FormatToken] =
TokenOps.findLastVisibleTokenOpt(tree.tokens).map(apply)

def getLastNonTrivial(tree: Tree): FormatToken =
apply(TokenOps.findLastNonTrivialToken(tree.tokens))

def getLastNonTrivialOpt(tree: Tree): Option[FormatToken] =
TokenOps.findLastNonTrivialTokenOpt(tree.tokens).map(apply)

Expand All @@ -146,13 +149,23 @@ class FormatTokens(leftTok2tok: Map[TokenOps.TokenHash, Int])(
@inline
def tokenAfter(trees: Seq[Tree]): FormatToken = tokenAfter(trees.last)

def tokenAfterOpt(tree: Tree): Option[FormatToken] =
getLastOpt(tree).map(nextNonComment)
def tokenAfterOpt(trees: Seq[Tree]): Option[FormatToken] =
trees.lastOption.flatMap(tokenAfterOpt)

/* the following methods return the last format token such that
* its `left` is before the parameter */
@inline
def justBefore(token: Token): FormatToken = apply(token, -1)
@inline
def tokenJustBefore(tree: Tree): FormatToken = justBefore(tree.tokens.head)

def tokenJustBeforeOpt(tree: Tree): Option[FormatToken] =
tree.tokens.headOption.map(justBefore)
def tokenJustBeforeOpt(trees: Seq[Tree]): Option[FormatToken] =
trees.headOption.flatMap(tokenJustBeforeOpt)

/* the following methods return the last format token such that
* its `left` is before the parameter and is not a comment */
@inline
Expand All @@ -162,6 +175,11 @@ class FormatTokens(leftTok2tok: Map[TokenOps.TokenHash, Int])(
@inline
def tokenBefore(trees: Seq[Tree]): FormatToken = tokenBefore(trees.head)

def tokenBeforeOpt(tree: Tree): Option[FormatToken] =
tokenJustBeforeOpt(tree).map(prevNonComment)
def tokenBeforeOpt(trees: Seq[Tree]): Option[FormatToken] =
trees.headOption.flatMap(tokenBeforeOpt)

@inline
def isBreakAfterRight(ft: FormatToken): Boolean =
next(ft).hasBreakOrEOF
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ class FormatWriter(formatOps: FormatOps) {
case _ => 2
}
val tiState =
locations(tokens(ti.tokens.head).meta.idx).state.prev
locations(tokens.getHead(ti).meta.idx).state.prev
val indent =
if (style.align.stripMargin) tiState.column + alignPipeOffset
else tiState.indentation + offset
Expand Down Expand Up @@ -1168,7 +1168,7 @@ class FormatWriter(formatOps: FormatOps) {
}

private def isEarlierLine(t: Tree)(implicit fl: FormatLocation): Boolean = {
val idx = tokens.after(t.tokens.head).meta.idx + 1
val idx = tokens.getHead(t).meta.idx + 1
idx <= fl.formatToken.meta.idx && // e.g., leading comments
locations(idx).leftLineId != fl.leftLineId
}
Expand Down Expand Up @@ -1221,8 +1221,9 @@ class FormatWriter(formatOps: FormatOps) {
val beg = mods.lastOption.fold(tokens.after(ptokens.head)) { m =>
tokens.next(tokens.tokenAfter(m))
}
val end = b.tokens.headOption
.fold(tokens.before(ptokens.last))(tokens.justBefore)
val end = tokens
.tokenJustBeforeOpt(b)
.getOrElse(tokens.before(ptokens.last))
getLineDiff(locations, beg, end) == 0
}
if (keepGoing) getAlignContainerParent(p) else p
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ class Router(formatOps: FormatOps) {
}
val defn = isDefnSite(rightOwner)
val defRhs = if (defn) defDefBody(rightOwner) else None
val beforeDefRhs = defRhs.flatMap(_.tokens.headOption.map(tokenBefore))
val beforeDefRhs = defRhs.flatMap(tokens.tokenJustBeforeOpt)
def getSplitsBeforeOpenParen(
src: Newlines.SourceHints,
indentLen: Int
Expand Down Expand Up @@ -906,8 +906,8 @@ class Router(formatOps: FormatOps) {
else
getAssignAtSingleArgCallSite(leftOwner).map { assign =>
val assignToken = assign.rhs match {
case b: Term.Block => b.tokens.head
case _ => assign.tokens.find(_.is[T.Equals]).get
case b: Term.Block => tokens.getHead(b)
case _ => tokens(assign.tokens.find(_.is[T.Equals]).get)
}
val breakToken = getOptimalTokenFor(assignToken)
val newlineAfterAssignDecision =
Expand Down Expand Up @@ -937,7 +937,7 @@ class Router(formatOps: FormatOps) {

val excludeBlocks =
if (isBracket) {
val excludeBeg = if (align) tokens(args.last.tokens.head) else tok
val excludeBeg = if (align) tokens.getHead(args.last) else tok
insideBlock[T.LeftBracket](excludeBeg, close)
} else if (
multipleArgs ||
Expand All @@ -951,7 +951,7 @@ class Router(formatOps: FormatOps) {
singleArgument && isExcludedTree(args(0))
}
)
parensTuple(args(0).tokens.last)
parensTuple(tokens.getLast(args(0)).left)
else insideBracesBlock(tok, close)

def singleLine(
Expand Down Expand Up @@ -1319,7 +1319,7 @@ class Router(formatOps: FormatOps) {
if style.newlines.avoidInResultType =>
val expire = returnType match {
case Type.Refine(_, headStat :: _) =>
tokens(headStat.tokens.head, -1).left
tokens.tokenJustBefore(headStat).left
case t => getLastNonTrivialToken(t)
}
Seq(Split(Space, 0).withPolicy(SingleLineBlock(expire, okSLC = true)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ final case class State(
if (ok) prev.getLineStartOwner(isComment) else None
} else {
def startsWithLeft(tree: meta.Tree): Boolean =
tree.tokens.headOption.contains(ft.left)
tokens.getHeadOpt(tree).contains(ft)
val ro = ft.meta.rightOwner
val owner =
if (startsWithLeft(ro)) Some(ro)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class RedundantBraces(ftoks: FormatTokens) extends FormatTokensRewrite.Rule {
f.parent.flatMap(okToReplaceFunctionInSingleArgApply).exists(_._2 eq f)

private def getOpeningParen(t: Term.Apply): Option[Token.LeftParen] =
ftoks.nextNonComment(ftoks(t.fun.tokens.last)).right match {
ftoks.tokenAfter(t.fun).right match {
case lp: Token.LeftParen => Some(lp)
case _ => None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class RedundantParens(ftoks: FormatTokens) extends FormatTokensRewrite.Rule {
}

private def breaksBeforeOp(ia: InfixApp): Boolean = {
val beforeOp = ftoks(ia.op.tokens.head, -1)
val beforeOp = ftoks.tokenJustBefore(ia.op)
ftoks.prevNonCommentSameLine(beforeOp).hasBreak || (ia.lhs match {
case InfixApp(lhsApp) if breaksBeforeOpAndNotEnclosed(lhsApp) => true
case _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,8 @@ private class RemoveScala3OptionalBraces(ftoks: FormatTokens)
case _: Token.KwIf => true
case _: Token.KwThen => true
case _: Token.KwElse =>
!TreeOps.isTreeMultiStatBlock(t.elsep) || {
val endOfCond = ftoks(t.cond.tokens.last)
ftoks.nextNonComment(endOfCond).right.is[Token.KwThen]
}
!TreeOps.isTreeMultiStatBlock(t.elsep) ||
ftoks.tokenAfter(t.cond).right.is[Token.KwThen]
case _: Token.RightParen => allowOldSyntax
case _ => false
}
Expand Down

0 comments on commit f96e495

Please sign in to comment.