Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fail the build on import cycles. #1612

Merged
merged 1 commit into from Jul 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -22,6 +22,7 @@ import com.google.common.jimfs.Jimfs
import com.squareup.wire.testing.add
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
import kotlin.test.assertFailsWith

class LinkerTest {
private val fs = Jimfs.newFileSystem(Configuration.unix())
Expand Down Expand Up @@ -189,6 +190,106 @@ class LinkerTest {
assertThat(enumValueDeprecated!!.encodeMode).isNotNull()
}

@Test
fun singleFileImportCycle() {
fs.add("source-path/ouroboros.proto", """
|syntax = "proto2";
|import "ouroboros.proto";
|message Snake {
|}
""".trimMargin())
fs.add("proto-path/unused.proto", "")

val exception = assertFailsWith<SchemaException> {
loadAndLinkSchema()
}
assertThat(exception).hasMessage("""
|imports form a cycle:
| ouroboros.proto:
| import "ouroboros.proto";
""".trimMargin())
}

@Test
fun threeFileImportCycle() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I love this one

fs.add("source-path/paper.proto", """
|syntax = "proto2";
|import "rock.proto";
|message Paper {
|}
""".trimMargin())
fs.add("source-path/rock.proto", """
|syntax = "proto2";
|import "scissors.proto";
|message Rock {
|}
""".trimMargin())
fs.add("source-path/scissors.proto", """
|syntax = "proto2";
|import "paper.proto";
|message Scissors {
|}
""".trimMargin())
fs.add("proto-path/unused.proto", "")

val exception = assertFailsWith<SchemaException> {
loadAndLinkSchema()
}
assertThat(exception).hasMessage("""
|imports form a cycle:
| paper.proto:
| import "rock.proto";
| rock.proto:
| import "scissors.proto";
| scissors.proto:
| import "paper.proto";
""".trimMargin())
}

@Test
fun multipleCycleImportProblem() {
fs.add("source-path/a.proto", """
|syntax = "proto2";
|import "b.proto";
|import "d.proto";
|message A {
|}
""".trimMargin())
fs.add("source-path/b.proto", """
|syntax = "proto2";
|import "c.proto";
|message B {
|}
""".trimMargin())
fs.add("source-path/c.proto", """
|syntax = "proto2";
|import "a.proto";
|import "b.proto";
|message C {
|}
""".trimMargin())
fs.add("source-path/d.proto", """
|syntax = "proto2";
|message D {
|}
""".trimMargin())
fs.add("proto-path/unused.proto", "")

val exception = assertFailsWith<SchemaException> {
loadAndLinkSchema()
}
assertThat(exception).hasMessage("""
|imports form a cycle:
| a.proto:
| import "b.proto";
| b.proto:
| import "c.proto";
| c.proto:
| import "a.proto";
| import "b.proto";
""".trimMargin())
}

private fun loadAndLinkSchema(): Schema {
NewSchemaLoader(fs).use { loader ->
loader.initRoots(
Expand Down
Expand Up @@ -16,6 +16,7 @@
package com.squareup.wire.schema

import com.squareup.wire.schema.ProtoType.Companion.get
import com.squareup.wire.schema.internal.DagChecker
import com.squareup.wire.schema.internal.MutableQueue
import com.squareup.wire.schema.internal.isValidTag
import com.squareup.wire.schema.internal.mutableQueueOf
Expand Down Expand Up @@ -109,6 +110,8 @@ class Linker {
fileLinker.validate(syntaxRules)
}

checkForImportCycles()

if (errors.isNotEmpty()) {
throw SchemaException(errors)
}
Expand All @@ -133,6 +136,17 @@ class Linker {
return Schema(result)
}

private fun checkForImportCycles() {
val dagChecker = DagChecker(fileLinkers.keys) {
val fileLinker = fileLinkers[it] ?: return@DagChecker listOf<String>()
fileLinker.protoFile.imports + fileLinker.protoFile.publicImports
}
val cycles = dagChecker.check()
for (cycle in cycles) {
addError(importCycleMessageError(cycle))
}
}

/** Returns the type name for the scalar, relative or fully-qualified name [name]. */
fun resolveType(name: String): ProtoType {
return resolveType(name, false)
Expand Down Expand Up @@ -436,6 +450,29 @@ class Linker {
}
}

/** Returns an error message that describes cyclic imports in [files]. */
private fun importCycleMessageError(files: List<String>): String {
return buildString {
append("imports form a cycle:")

for (file in files) {
val fileLinker = fileLinkers[file] ?: continue

append("\n $file:")
for (import in fileLinker.protoFile.imports) {
if (import in files) {
append("\n import \"$import\";")
}
}
for (import in fileLinker.protoFile.publicImports) {
if (import in files) {
append("\n import public \"$import\";")
}
}
}
}
}

/** Returns a new linker that uses [context] to resolve type names and report errors. */
fun withContext(context: Any): Linker {
return Linker(this, context)
Expand Down
@@ -0,0 +1,116 @@
/*
* Copyright (C) 2020 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.squareup.wire.schema.internal

/**
* Checks whether a graph is a directed acyclic graph using [Tarjan's algorithm][tarjan].
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ideas for interview question

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I definitely needed a lot of time to get this right, including reading the code sample from Wikipedia repeatedly. I would fail an interview that asks me to produce this!

*
* Note that all cycles are strongly connected components, but a strongly connected component is not
* strictly a cycle. In particular it may contain nodes that are mutually reachable from each other
* through multiple paths.
*
* [tarjan]: https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
*/
class DagChecker<N>(
private val nodes: Iterable<N>,
private val edges: (N) -> Iterable<N>
) {
private var nextDiscoveryId = 0
private val tags = nodes.associateWith { Tag(it) }
private val stack = mutableListOf<Tag<N>>()
private val result = mutableSetOf<List<N>>()

private val N.tag: Tag<N>?
get() = tags[this]

/**
* Returns a set of strongly connected components. Each strongly connected component is a list of
* nodes that are mutually reachable to each other.
*
* If the graph contains nodes that have self edges but are not strongly connected to any other
* nodes, those nodes will be single-element lists in the result.
*
* If the result is empty the graph is acyclic.
*/
fun check(): Set<List<N>> {
check(nextDiscoveryId == 0)

for (node in nodes) {
val tag = node.tag!!
if (tag.discoveryId == -1) {
tag.discoverDepthFirst()
}
}

return result
}

/**
* Traverse this node and all of the nodes it can reach. This returns the lowest discovery ID of
* the set of nodes strongly connected to this node.
*/
private fun Tag<N>.discoverDepthFirst(): Int {
discoveryId = nextDiscoveryId
lowestConnectedDiscoveryId = nextDiscoveryId
nextDiscoveryId++

val stackIndex = stack.size
stack += this
onStack = true

for (target in edges(node)) {
val t = target.tag ?: continue

if (t.discoveryId == -1) {
// Traverse a new node. If in the process it received a lower discovery ID, it must be
// strongly connected to this node! Take that lower discovery ID.
lowestConnectedDiscoveryId = minOf(lowestConnectedDiscoveryId, t.discoverDepthFirst())
} else if (t.onStack) {
// Node a new node, but one we're in a cycle with. Take its discover ID if it's lower.
if (t == this) selfEdge = true
lowestConnectedDiscoveryId = minOf(lowestConnectedDiscoveryId, t.discoveryId)
}
}

// We've traversed all the edges. If our discovery ID is the lowest then we're the root of our
// strongly connected component. Include it in the result.
if (discoveryId == lowestConnectedDiscoveryId) {
val slice = stack.subList(stackIndex, stack.size)
val component = slice.toList()
slice.clear()

for (tag in component) {
tag.onStack = false
}

if (component.size > 1 || component.single().selfEdge) {
result += component.map { it.node }
}
}

return lowestConnectedDiscoveryId
}

private class Tag<N>(
var node: N
) {
var discoveryId = -1
var lowestConnectedDiscoveryId = -1
var onStack = false
var selfEdge = false
}
}