From d7e0870e4f81d9bcc0f71fee9e22f9c6cf9e07aa Mon Sep 17 00:00:00 2001 From: Marcus Hultman Date: Sun, 17 May 2020 19:59:01 +0200 Subject: [PATCH] implement import directive --- src/index.js | 265 ++++++++++++++++++----- test/integration/importdirective.spec.ts | 16 ++ test/protos/importdirective.proto | 11 + test/protos/imported.proto | 11 + 4 files changed, 253 insertions(+), 50 deletions(-) create mode 100644 test/integration/importdirective.spec.ts create mode 100644 test/protos/importdirective.proto create mode 100644 test/protos/imported.proto diff --git a/src/index.js b/src/index.js index 12159ef1..9d95ccb1 100644 --- a/src/index.js +++ b/src/index.js @@ -1,6 +1,7 @@ const plugin = require("google-protobuf/google/protobuf/compiler/plugin_pb"); const descriptorpb = require("google-protobuf/google/protobuf/descriptor_pb"); const fs = require("fs"); +const path = require("path"); const ts = require("typescript"); function createImport(identifier, moduleSpecifier) { @@ -12,7 +13,7 @@ function createImport(identifier, moduleSpecifier) { ); } -function createToObject(rootDescriptor, messageDescriptor) { +function createToObject(rootDescriptor, messageDescriptor, getNamedImport) { const properties = []; for (const fd of messageDescriptor.getFieldList()) { @@ -34,8 +35,10 @@ function createToObject(rootDescriptor, messageDescriptor) { "item", undefined, ts.createTypeReferenceNode( - ts.createIdentifier( - getTypeName(fd, rootDescriptor.getPackage()) + createTypeIdentifier( + fd, + rootDescriptor.getPackage(), + getNamedImport ), undefined ) @@ -106,7 +109,7 @@ function createNamespace(packageName, statements) { return statements; } -function createTypeLiteral(rootDescriptor, messageDescriptor) { +function createTypeLiteral(rootDescriptor, messageDescriptor, getNamedImport) { const members = []; for (const fieldDescriptor of messageDescriptor.getFieldList()) { @@ -117,7 +120,7 @@ function createTypeLiteral(rootDescriptor, messageDescriptor) { fieldDescriptor.getName(), ts.createToken(ts.SyntaxKind.QuestionToken), wrapRepeatedType( - getType(fieldDescriptor, rootDescriptor.getPackage()), + getType(fieldDescriptor, rootDescriptor.getPackage(), getNamedImport), fieldDescriptor ), undefined @@ -127,7 +130,12 @@ function createTypeLiteral(rootDescriptor, messageDescriptor) { return ts.createTypeLiteralNode(members); } -function createConstructor(rootDescriptor, messageDescriptor, pbIdentifier) { +function createConstructor( + rootDescriptor, + messageDescriptor, + pbIdentifier, + getNamedImport +) { const statements = []; const dataIdentifier = ts.createIdentifier("data"); @@ -135,7 +143,7 @@ function createConstructor(rootDescriptor, messageDescriptor, pbIdentifier) { ts.createArrayTypeNode( ts.createTypeReferenceNode(ts.createIdentifier("any"), undefined) ) /* any[] */, - createTypeLiteral(rootDescriptor, messageDescriptor), + createTypeLiteral(rootDescriptor, messageDescriptor, getNamedImport), ]); // Create super(); statement @@ -256,7 +264,7 @@ function wrapOptinalType(type, fieldDescriptor) { return type; } -function getType(fieldDescriptor, packageName) { +function getType(fieldDescriptor, packageName, getNamedImport) { switch (fieldDescriptor.getType()) { case descriptorpb.FieldDescriptorProto.Type.TYPE_DOUBLE: case descriptorpb.FieldDescriptorProto.Type.TYPE_FLOAT: @@ -280,18 +288,26 @@ function getType(fieldDescriptor, packageName) { return ts.createIdentifier("Uint8Array"); case descriptorpb.FieldDescriptorProto.Type.TYPE_MESSAGE: case descriptorpb.FieldDescriptorProto.Type.TYPE_ENUM: - return ts.createIdentifier(getTypeName(fieldDescriptor, packageName)); + return createTypeIdentifier(fieldDescriptor, packageName, getNamedImport); default: throw new Error("Unhandled type " + fieldDescriptor.getType()); } } -function getTypeName(fieldDescriptor, packageName) { +function createTypeIdentifier(fieldDescriptor, packageName, getNamedImport) { if (packageName == undefined) { throw new TypeError(); } - return normalizeTypeName(fieldDescriptor.getTypeName(), packageName); + const namedImport = getNamedImport(fieldDescriptor); + const normalized = normalizeTypeName( + fieldDescriptor.getTypeName(), + packageName + ); + + return namedImport + ? ts.createPropertyAccess(namedImport, normalized.replace(/^[^.]+./, "")) + : ts.createIdentifier(normalized); } function normalizeTypeName(name, packageName) { @@ -353,9 +369,14 @@ function toBinaryMethodName(fieldDescriptor, descriptor, isWriter = true) { } // Returns a get accessor for the field -function createGetter(rootDescriptor, fieldDescriptor, pbIdentifier) { +function createGetter( + rootDescriptor, + fieldDescriptor, + pbIdentifier, + getNamedImport +) { let type = wrapRepeatedType( - getType(fieldDescriptor, rootDescriptor.getPackage()), + getType(fieldDescriptor, rootDescriptor.getPackage(), getNamedImport), fieldDescriptor ); const getterType = wrapOptinalType(type, fieldDescriptor); @@ -372,7 +393,8 @@ function createGetter(rootDescriptor, fieldDescriptor, pbIdentifier) { fieldDescriptor, pbIdentifier, getterType, - rootDescriptor.getPackage() + rootDescriptor.getPackage(), + getNamedImport ) ), ], @@ -382,7 +404,13 @@ function createGetter(rootDescriptor, fieldDescriptor, pbIdentifier) { } // Returns the inner logic of the field accessor. -function createGetterCall(fieldDescriptor, pbIdentifier, type, packageName) { +function createGetterCall( + fieldDescriptor, + pbIdentifier, + type, + packageName, + getNamedImport +) { let calle = ts.createIdentifier("getFieldWithDefault"); let args = [ @@ -403,7 +431,7 @@ function createGetterCall(fieldDescriptor, pbIdentifier, type, packageName) { args.splice( 1, 0, - ts.createIdentifier(getTypeName(fieldDescriptor, packageName)) + createTypeIdentifier(fieldDescriptor, packageName, getNamedImport) ); args.pop(); } @@ -422,9 +450,14 @@ function createGetterCall(fieldDescriptor, pbIdentifier, type, packageName) { } // Returns a set accessor for the field -function createSetter(rootDescriptor, fieldDescriptor, pbIdentifier) { +function createSetter( + rootDescriptor, + fieldDescriptor, + pbIdentifier, + getNamedImport +) { let type = wrapRepeatedType( - getType(fieldDescriptor, rootDescriptor.getPackage()), + getType(fieldDescriptor, rootDescriptor.getPackage(), getNamedImport), fieldDescriptor ); const paramIdentifier = ts.createIdentifier("value"); @@ -472,7 +505,7 @@ function createSetter(rootDescriptor, fieldDescriptor, pbIdentifier) { * Returns the serialize method for the message class * TODO: Split this function into chunk functions */ -function createSerialize(rootDescriptor, fields, pbIdentifier) { +function createSerialize(rootDescriptor, fields, pbIdentifier, getNamedImport) { return ts.createMethod( undefined, undefined, @@ -546,11 +579,10 @@ function createSerialize(rootDescriptor, fields, pbIdentifier) { "item", undefined, ts.createTypeReferenceNode( - ts.createIdentifier( - getTypeName( - fieldDescriptor, - rootDescriptor.getPackage() - ) + createTypeIdentifier( + fieldDescriptor, + rootDescriptor.getPackage(), + getNamedImport ), undefined ) @@ -669,7 +701,12 @@ function createSerializeBinary(rootDescriptor, fields, pbIdentifier) { * Returns the deserialize method for the message class * TODO: Split this function into chunk functions */ -function createDeserialize(rootDescriptor, messageDescriptor, pbIdentifier) { +function createDeserialize( + rootDescriptor, + messageDescriptor, + pbIdentifier, + getNamedImport +) { return ts.createMethod( undefined, [ts.createModifier(ts.SyntaxKind.StaticKeyword)], @@ -802,8 +839,10 @@ function createDeserialize(rootDescriptor, messageDescriptor, pbIdentifier) { } else if (isMessage(fd)) { const readCall = ts.createCall( ts.createPropertyAccess( - ts.createIdentifier( - getTypeName(fd, rootDescriptor.getPackage()) + createTypeIdentifier( + fd, + rootDescriptor.getPackage(), + getNamedImport ), "deserialize" ), @@ -848,11 +887,10 @@ function createDeserialize(rootDescriptor, messageDescriptor, pbIdentifier) { fd.getNumber().toString() ), readCall, - ts.createIdentifier( - getTypeName( - fd, - rootDescriptor.getPackage() - ) + createTypeIdentifier( + fd, + rootDescriptor.getPackage(), + getNamedImport ), ] ) @@ -926,29 +964,56 @@ function createDeserialize(rootDescriptor, messageDescriptor, pbIdentifier) { } // Returns a class for the message descriptor -function createMessage(rootDescriptor, messageDescriptor, pbIdentifier) { +function createMessage( + rootDescriptor, + messageDescriptor, + pbIdentifier, + getNamedImport +) { const members = []; // Create constructor members.push( - createConstructor(rootDescriptor, messageDescriptor, pbIdentifier) + createConstructor( + rootDescriptor, + messageDescriptor, + pbIdentifier, + getNamedImport + ) ); // Create getter and setters for (const fieldDescriptor of messageDescriptor.getFieldList()) { - members.push(createGetter(rootDescriptor, fieldDescriptor, pbIdentifier)); - members.push(createSetter(rootDescriptor, fieldDescriptor, pbIdentifier)); + members.push( + createGetter( + rootDescriptor, + fieldDescriptor, + pbIdentifier, + getNamedImport + ) + ); + members.push( + createSetter( + rootDescriptor, + fieldDescriptor, + pbIdentifier, + getNamedImport + ) + ); } // Create toObject method - members.push(createToObject(rootDescriptor, messageDescriptor)); + members.push( + createToObject(rootDescriptor, messageDescriptor, getNamedImport) + ); // Create serialize method members.push( createSerialize( rootDescriptor, messageDescriptor.getFieldList(), - pbIdentifier + pbIdentifier, + getNamedImport ) ); // Create serializeBinary method @@ -962,7 +1027,12 @@ function createMessage(rootDescriptor, messageDescriptor, pbIdentifier) { // Create deserialize method members.push( - createDeserialize(rootDescriptor, messageDescriptor, pbIdentifier) + createDeserialize( + rootDescriptor, + messageDescriptor, + pbIdentifier, + getNamedImport + ) ); // Create message class @@ -1454,10 +1524,17 @@ function createServiceClient( ); } -function processMessageDescriptor(rootDescriptor, descriptor, pbIdentifier) { +function processMessageDescriptor( + rootDescriptor, + descriptor, + pbIdentifier, + getNamedImport +) { const statements = []; - statements.push(createMessage(rootDescriptor, descriptor, pbIdentifier)); + statements.push( + createMessage(rootDescriptor, descriptor, pbIdentifier, getNamedImport) + ); const namespacedStatements = []; @@ -1469,11 +1546,14 @@ function processMessageDescriptor(rootDescriptor, descriptor, pbIdentifier) { // Process nested messages if (descriptor.getNestedTypeList) { for (const nestedDescriptor of descriptor.getNestedTypeList()) { - namespacedStatements.push(...processMessageDescriptor( - rootDescriptor, - nestedDescriptor, - pbIdentifier - )); + namespacedStatements.push( + ...processMessageDescriptor( + rootDescriptor, + nestedDescriptor, + pbIdentifier, + getNamedImport + ) + ); } } @@ -1486,7 +1566,12 @@ function processMessageDescriptor(rootDescriptor, descriptor, pbIdentifier) { return statements; } -function processProtoDescriptor(rootDescriptor, descriptor, pbIdentifier) { +function processProtoDescriptor( + rootDescriptor, + descriptor, + pbIdentifier, + getNamedImport +) { const statements = []; // Process messages @@ -1496,7 +1581,8 @@ function processProtoDescriptor(rootDescriptor, descriptor, pbIdentifier) { ...processMessageDescriptor( rootDescriptor, messageDescriptor, - pbIdentifier + pbIdentifier, + getNamedImport ) ); } @@ -1505,6 +1591,33 @@ function processProtoDescriptor(rootDescriptor, descriptor, pbIdentifier) { return statements; } +function getExportPaths(prefix, descriptor) { + const exports = []; + if (descriptor.getMessageTypeList) { + for (const messageDescriptor of descriptor.getMessageTypeList()) { + const name = messageDescriptor.getName(); + exports.push( + [...prefix, name], + ...getExportPaths([...prefix, name], messageDescriptor) + ); + } + } + if (descriptor.getNestedTypeList) { + for (const nestedDescriptor of descriptor.getNestedTypeList()) { + const name = nestedDescriptor.getName(); + exports.push( + [...prefix, name], + ...getExportPaths([...prefix, name], nestedDescriptor) + ); + } + } + for (const enumDescriptor of descriptor.getEnumTypeList()) { + const name = enumDescriptor.getName(); + exports.push([...prefix, name]); + } + return exports; +} + function main() { const pbBuffer = fs.readFileSync(0); const pbVector = new Uint8Array(pbBuffer.length); @@ -1517,8 +1630,20 @@ function main() { const descriptors = codeGenRequest.getProtoFileList(); + // Exports (typeName -> { file, namedImport }) + const fileExports = {}; + for (const descriptor of descriptors) { - const name = descriptor.getName().replace(".proto", ".ts"); + const file = descriptor.getName(); + const packageName = descriptor.getPackage(); + for (const path of getExportPaths(packageName.split("."), descriptor)) { + fileExports["." + path.join(".")] = { file, namedImport: path[0] }; + } + } + + for (const descriptor of descriptors) { + const fileName = descriptor.getName(); + const name = fileName.replace(".proto", ".ts"); const codegenFile = new plugin.CodeGeneratorResponse.File(); const sf = ts.createSourceFile( @@ -1534,13 +1659,53 @@ function main() { const importStatements = []; + // Dependencies (file -> namedImport -> identifier) + const dependencies = {}; + // Create all messages recursively + // For imported types, assign a unique identifier to each typeName const statements = processProtoDescriptor( descriptor, descriptor, - pbIdentifier + pbIdentifier, + fieldDescriptor => { + const typeName = fieldDescriptor.getTypeName(); + if (!fileExports[typeName]) { + return; + } + const { file, namedImport } = fileExports[typeName]; + if (file === fileName) { + return; + } + if (!dependencies[file]) { + dependencies[file] = { }; + } + if (!dependencies[file][namedImport]) { + dependencies[file][namedImport] = ts.createUniqueName(namedImport); + } + return dependencies[file][namedImport]; + } ); + // Create all named imports from dependencies + for (const [file, namedImports] of Object.entries(dependencies)) { + const name = + "./" + + path.relative(path.dirname(fileName), file).replace(".proto", ""); + importStatements.push( + ts.createImportDeclaration( + undefined, + undefined, + ts.createNamedImports( + Object.entries(namedImports).map(([name, identifier]) => + ts.createImportSpecifier(ts.createIdentifier(name), identifier) + ) + ), + ts.createLiteral(name) + ) + ); + } + if (statements.length) { importStatements.push(createImport(pbIdentifier, "google-protobuf")); } diff --git a/test/integration/importdirective.spec.ts b/test/integration/importdirective.spec.ts new file mode 100644 index 00000000..9f4197aa --- /dev/null +++ b/test/integration/importdirective.spec.ts @@ -0,0 +1,16 @@ +import {importdirective as id1} from "../protos/imported"; +import {importdirective as id2} from "../protos/importdirective"; + +describe("Imported Proto", () => { + + it("should be serialized", () => { + const mymsg = new id2.a.Message({ + importedField: new id1.b.Imported({}) + }); + + const deserializedMessage = id2.a.Message.deserialize(mymsg.serialize()); + + expect(deserializedMessage.importedField instanceof id1.b.Imported).toBe(true); + }) + +}) diff --git a/test/protos/importdirective.proto b/test/protos/importdirective.proto new file mode 100644 index 00000000..36a5dd29 --- /dev/null +++ b/test/protos/importdirective.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package importdirective.a; + +import "test/protos/imported.proto"; + +message Message { + b.Imported importedField = 1; + b.Imported.SubMessage submessageField = 2; + b.Imported.SubMessage.MyEnum enumField = 3; +} \ No newline at end of file diff --git a/test/protos/imported.proto b/test/protos/imported.proto new file mode 100644 index 00000000..4b9546d2 --- /dev/null +++ b/test/protos/imported.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package importdirective.b; + +message Imported { + message SubMessage { + enum MyEnum { + VALUE = 0; + } + } +} \ No newline at end of file