Skip to content

Commit

Permalink
Fail the build on import cycles.
Browse files Browse the repository at this point in the history
I really dislike this limitation but protoc enforces it, and not enforcing
it makes interop with protoc.
  • Loading branch information
swankjesse committed Jul 1, 2020
1 parent aa8264f commit 0e69fae
Show file tree
Hide file tree
Showing 4 changed files with 400 additions and 0 deletions.
Expand Up @@ -22,6 +22,7 @@ import com.google.common.jimfs.Jimfs
import com.squareup.wire.testing.add
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
import kotlin.test.assertFailsWith

class LinkerTest {
private val fs = Jimfs.newFileSystem(Configuration.unix())
Expand Down Expand Up @@ -189,6 +190,106 @@ class LinkerTest {
assertThat(enumValueDeprecated!!.encodeMode).isNotNull()
}

@Test
fun singleFileImportCycle() {
fs.add("source-path/ouroboros.proto", """
|syntax = "proto2";
|import "ouroboros.proto";
|message Snake {
|}
""".trimMargin())
fs.add("proto-path/unused.proto", "")

val exception = assertFailsWith<SchemaException> {
loadAndLinkSchema()
}
assertThat(exception).hasMessage("""
|imports form a cycle:
| ouroboros.proto:
| import "ouroboros.proto";
""".trimMargin())
}

@Test
fun threeFileImportCycle() {
fs.add("source-path/paper.proto", """
|syntax = "proto2";
|import "rock.proto";
|message Paper {
|}
""".trimMargin())
fs.add("source-path/rock.proto", """
|syntax = "proto2";
|import "scissors.proto";
|message Rock {
|}
""".trimMargin())
fs.add("source-path/scissors.proto", """
|syntax = "proto2";
|import "paper.proto";
|message Scissors {
|}
""".trimMargin())
fs.add("proto-path/unused.proto", "")

val exception = assertFailsWith<SchemaException> {
loadAndLinkSchema()
}
assertThat(exception).hasMessage("""
|imports form a cycle:
| paper.proto:
| import "rock.proto";
| rock.proto:
| import "scissors.proto";
| scissors.proto:
| import "paper.proto";
""".trimMargin())
}

@Test
fun multipleCycleImportProblem() {
fs.add("source-path/a.proto", """
|syntax = "proto2";
|import "b.proto";
|import "d.proto";
|message A {
|}
""".trimMargin())
fs.add("source-path/b.proto", """
|syntax = "proto2";
|import "c.proto";
|message B {
|}
""".trimMargin())
fs.add("source-path/c.proto", """
|syntax = "proto2";
|import "a.proto";
|import "b.proto";
|message C {
|}
""".trimMargin())
fs.add("source-path/d.proto", """
|syntax = "proto2";
|message D {
|}
""".trimMargin())
fs.add("proto-path/unused.proto", "")

val exception = assertFailsWith<SchemaException> {
loadAndLinkSchema()
}
assertThat(exception).hasMessage("""
|imports form a cycle:
| a.proto:
| import "b.proto";
| b.proto:
| import "c.proto";
| c.proto:
| import "a.proto";
| import "b.proto";
""".trimMargin())
}

private fun loadAndLinkSchema(): Schema {
NewSchemaLoader(fs).use { loader ->
loader.initRoots(
Expand Down
Expand Up @@ -16,6 +16,7 @@
package com.squareup.wire.schema

import com.squareup.wire.schema.ProtoType.Companion.get
import com.squareup.wire.schema.internal.DagChecker
import com.squareup.wire.schema.internal.MutableQueue
import com.squareup.wire.schema.internal.isValidTag
import com.squareup.wire.schema.internal.mutableQueueOf
Expand Down Expand Up @@ -109,6 +110,8 @@ class Linker {
fileLinker.validate(syntaxRules)
}

checkForImportCycles()

if (errors.isNotEmpty()) {
throw SchemaException(errors)
}
Expand All @@ -133,6 +136,17 @@ class Linker {
return Schema(result)
}

private fun checkForImportCycles() {
val dagChecker = DagChecker(fileLinkers.keys) {
val fileLinker = fileLinkers[it] ?: return@DagChecker listOf<String>()
fileLinker.protoFile.imports + fileLinker.protoFile.publicImports
}
val cycles = dagChecker.check()
for (cycle in cycles) {
addError(importCycleMessageError(cycle))
}
}

/** Returns the type name for the scalar, relative or fully-qualified name [name]. */
fun resolveType(name: String): ProtoType {
return resolveType(name, false)
Expand Down Expand Up @@ -436,6 +450,29 @@ class Linker {
}
}

/** Returns an error message that describes cyclic imports in [files]. */
private fun importCycleMessageError(files: List<String>): String {
return buildString {
append("imports form a cycle:")

for (file in files) {
val fileLinker = fileLinkers[file] ?: continue

append("\n $file:")
for (import in fileLinker.protoFile.imports) {
if (import in files) {
append("\n import \"$import\";")
}
}
for (import in fileLinker.protoFile.publicImports) {
if (import in files) {
append("\n import public \"$import\";")
}
}
}
}
}

/** Returns a new linker that uses [context] to resolve type names and report errors. */
fun withContext(context: Any): Linker {
return Linker(this, context)
Expand Down
@@ -0,0 +1,116 @@
/*
* Copyright (C) 2020 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.squareup.wire.schema.internal

/**
* Checks whether a graph is a directed acyclic graph using [Tarjan's algorithm][tarjan].
*
* Note that all cycles are strongly connected components, but a strongly connected component is not
* strictly a cycle. In particular it may contain nodes that are mutually reachable from each other
* through multiple paths.
*
* [tarjan]: https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
*/
class DagChecker<N>(
private val nodes: Iterable<N>,
private val edges: (N) -> Iterable<N>
) {
private var nextDiscoveryId = 0
private val tags = nodes.associateWith { Tag(it) }
private val stack = mutableListOf<Tag<N>>()
private val result = mutableSetOf<List<N>>()

private val N.tag: Tag<N>?
get() = tags[this]

/**
* Returns a set of strongly connected components. Each strongly connected component is a list of
* nodes that are mutually reachable to each other.
*
* If the graph contains nodes that have self edges but are not strongly connected to any other
* nodes, those nodes will be single-element lists in the result.
*
* If the result is empty the graph is acyclic.
*/
fun check(): Set<List<N>> {
check(nextDiscoveryId == 0)

for (node in nodes) {
val tag = node.tag!!
if (tag.discoveryId == -1) {
tag.discoverDepthFirst()
}
}

return result
}

/**
* Traverse this node and all of the nodes it can reach. This returns the lowest discovery ID of
* the set of nodes strongly connected to this node.
*/
private fun Tag<N>.discoverDepthFirst(): Int {
discoveryId = nextDiscoveryId
lowestConnectedDiscoveryId = nextDiscoveryId
nextDiscoveryId++

val stackIndex = stack.size
stack += this
onStack = true

for (target in edges(node)) {
val t = target.tag ?: error("edge target not in graph: $node$target")

if (t.discoveryId == -1) {
// Traverse a new node. If in the process it received a lower discovery ID, it must be
// strongly connected to this node! Take that lower discovery ID.
lowestConnectedDiscoveryId = minOf(lowestConnectedDiscoveryId, t.discoverDepthFirst())
} else if (t.onStack) {
// Node a new node, but one we're in a cycle with. Take its discover ID if it's lower.
if (t == this) selfEdge = true
lowestConnectedDiscoveryId = minOf(lowestConnectedDiscoveryId, t.discoveryId)
}
}

// We've traversed all the edges. If our discovery ID is the lowest then we're the root of our
// strongly connected component. Include it in the result.
if (discoveryId == lowestConnectedDiscoveryId) {
val slice = stack.subList(stackIndex, stack.size)
val component = slice.toList()
slice.clear()

for (tag in component) {
tag.onStack = false
}

if (component.size > 1 || component.single().selfEdge) {
result += component.map { it.node }
}
}

return lowestConnectedDiscoveryId
}

private class Tag<N>(
var node: N
) {
var discoveryId = -1
var lowestConnectedDiscoveryId = -1
var onStack = false
var selfEdge = false
}
}

0 comments on commit 0e69fae

Please sign in to comment.