Skip to content

Commit

Permalink
Search for extensions recursively in imports
Browse files Browse the repository at this point in the history
For #38
  • Loading branch information
thesamet committed Jan 17, 2021
1 parent fbd8b59 commit 76642ff
Showing 1 changed file with 14 additions and 14 deletions.
Expand Up @@ -25,7 +25,7 @@ private[compiler] object ResolvedFieldTransformation {
def apply(
currentFile: String,
ft: FieldTransformation,
extensions: Seq[FieldDescriptor]
extensions: Set[FieldDescriptor]
): ResolvedFieldTransformation = {
ResolvedFieldTransformation(
FieldTransformations.fieldMap(
Expand Down Expand Up @@ -68,8 +68,8 @@ private[compiler] object FieldTransformations {
else
matchContains(
currentFile,
fieldMap(currentFile, u.asInstanceOf[Message], Seq.empty),
fieldMap(currentFile, v.asInstanceOf[Message], Seq.empty)
fieldMap(currentFile, u.asInstanceOf[Message], Set.empty),
fieldMap(currentFile, v.asInstanceOf[Message], Set.empty)
)
}
}
Expand All @@ -81,7 +81,7 @@ private[compiler] object FieldTransformations {
): Seq[AuxFieldOptions] =
if (transforms.isEmpty) Seq.empty
else {
val extensions: Seq[FieldDescriptor] = fieldExtensionsForFile(f)
val extensions: Set[FieldDescriptor] = fieldExtensionsForFile(f)
def processFile: Seq[AuxFieldOptions] =
f.getMessageTypes().asScala.flatMap(processMessage(_)).toSeq

Expand Down Expand Up @@ -128,23 +128,23 @@ private[compiler] object FieldTransformations {
}
}

def fieldExtensionsForFile(f: FileDescriptor): Seq[FieldDescriptor] = {
(f +: f.getDependencies().asScala.toSeq).flatMap {
_.getExtensions().asScala.filter(
def fieldExtensionsForFile(f: FileDescriptor): Set[FieldDescriptor] = {
(f.getExtensions()
.asScala
.filter(
// Comparing the descriptors references directly will not work. The google.protobuf.FieldOptions
// we will get from `getContainingType` is the one we get from parsing the code generation request
// inputs, which are disjoint from the compilerplugin's FieldOptions.
_.getContainingType.getFullName == FieldOptions.getDescriptor().getFullName()
)
}
) ++ f.getDependencies().asScala.flatMap(fieldExtensionsForFile(_))).toSet
}

// Like m.getAllFields(), but also resolves unknown fields from extensions available in the scope
// of the message.
def fieldMap(
currentFile: String,
m: Message,
extensions: Seq[FieldDescriptor]
extensions: Set[FieldDescriptor]
): Map[FieldDescriptor, Any] = {
val unknownFields = for {
number <- m.getUnknownFields().asMap().keySet().asScala
Expand Down Expand Up @@ -199,7 +199,7 @@ private[compiler] object FieldTransformations {
def fieldByPath(
message: Message,
path: String,
extensions: Seq[FieldDescriptor]
extensions: Set[FieldDescriptor]
): String =
if (path.isEmpty()) throw new GeneratorException("Got an empty path")
else
Expand All @@ -212,7 +212,7 @@ private[compiler] object FieldTransformations {
message: Message,
path: List[String],
allPath: String,
extensions: Seq[FieldDescriptor]
extensions: Set[FieldDescriptor]
): Either[String, String] = {
for {
fieldName <- path.headOption.toRight("Got an empty path")
Expand Down Expand Up @@ -249,7 +249,7 @@ private[compiler] object FieldTransformations {
private[compiler] def interpolateStrings[T <: Message](
msg: T,
data: Message,
extensions: Seq[FieldDescriptor]
extensions: Set[FieldDescriptor]
): T = {
val b = msg.toBuilder()
for {
Expand Down Expand Up @@ -284,7 +284,7 @@ private[compiler] object FieldTransformations {
private[compiler] def interpolate(
value: String,
data: Message,
extensions: Seq[FieldDescriptor]
extensions: Set[FieldDescriptor]
): String =
replaceAll(value, FieldPath, m => fieldByPath(data, m.group(1), extensions))

Expand Down

0 comments on commit 76642ff

Please sign in to comment.