Skip to content

Commit

Permalink
generateModule: Simplify implementation
Browse files Browse the repository at this point in the history
With support for only one file, generateModule can be simplified.
It no longer needs to instantiate a map (which receives only one entry).
Instead it can return the intended file path and contents of the file in
one go and let the caller handle merging it into the main files map.
  • Loading branch information
abhinav committed May 29, 2019
1 parent da6e7ce commit 3da0099
Showing 1 changed file with 42 additions and 42 deletions.
84 changes: 42 additions & 42 deletions gen/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,15 @@ func Generate(m *compile.Module, o *Options) error {
genBuilder := newGenerateServiceBuilder(importer)

generate := func(m *compile.Module) error {
moduleFiles, err := generateModule(m, importer, genBuilder, o)
path, contents, err := generateModule(m, importer, genBuilder, o)
if err != nil {
return generateError{Name: m.ThriftPath, Reason: err}
}
if err := mergeFiles(files, moduleFiles); err != nil {

if err := addFile(files, path, contents); err != nil {
return generateError{Name: m.ThriftPath, Reason: err}
}

return nil
}

Expand Down Expand Up @@ -200,72 +202,83 @@ func (i thriftPackageImporter) Package(file string) (string, error) {
}

func mergeFiles(dest, src map[string][]byte) error {
var errors []error
var err error
for path, contents := range src {
if _, ok := dest[path]; ok {
errors = append(errors, fmt.Errorf("file generation conflict: "+
"multiple sources are trying to write to %q", path))
}
dest[path] = contents
err = multierr.Append(err, addFile(dest, path, contents))
}
return multierr.Combine(errors...)
return err
}

// generateModule returns a mapping from filename to file contents of files that
// should be generated relative to o.OutputDir.
func generateModule(m *compile.Module, i thriftPackageImporter, builder *generateServiceBuilder, o *Options) (map[string][]byte, error) {
func addFile(dest map[string][]byte, path string, contents []byte) error {
if _, ok := dest[path]; ok {
return fmt.Errorf("file generation conflict: "+
"multiple sources are trying to write to %q", path)
}
dest[path] = contents
return nil
}

// generateModule generates the code for the given Thrift file and returns the
// path to the output file relative to OutputDir and the contents of the file.
func generateModule(
m *compile.Module,
i thriftPackageImporter,
builder *generateServiceBuilder,
o *Options,
) (outputFilepath string, contents []byte, err error) {
// packageRelPath is the path relative to outputDir into which we'll be
// writing the package for this Thrift file. For $thriftRoot/foo/bar.thrift,
// packageRelPath is foo/bar, and packageDir is $outputDir/foo/bar. All
// files for bar.thrift will be written to the $outputDir/foo/bar/ tree. The
// package will be importable via $importPrefix/foo/bar.
packageRelPath, err := i.RelativePackage(m.ThriftPath)
if err != nil {
return nil, err
return "", nil, err
}

// TODO(abg): Prefer top-level package name from `namespace go` directive.
packageName := filepath.Base(packageRelPath)

// Output file name defaults to the package name.
outputFilename := packageName + ".go"
if len(o.OutputFile) > 0 {
outputFilename = o.OutputFile
}
outputFilepath = filepath.Join(packageRelPath, outputFilename)

// importPath is the full import path for the top-level package generated
// for this Thrift file.
importPath, err := i.Package(m.ThriftPath)
if err != nil {
return nil, err
return "", nil, err
}

// Mapping of file names relative to packageRelPath to their contents.
// Note that we need to return a mapping relative to o.OutputDir so we
// will prepend $packageRelPath/ to all these paths.
files := make(map[string][]byte)

gopts := &GeneratorOptions{
g := NewGenerator(&GeneratorOptions{
Importer: i,
ImportPath: importPath,
PackageName: packageName,
NoZap: o.NoZap,
}
g := NewGenerator(gopts)
})

if len(m.Constants) > 0 {
for _, constantName := range sortStringKeys(m.Constants) {
if err := Constant(g, m.Constants[constantName]); err != nil {
return nil, err
return "", nil, err
}
}
}

if len(m.Types) > 0 {
for _, typeName := range sortStringKeys(m.Types) {
if err := TypeDefinition(g, m.Types[typeName]); err != nil {
return nil, err
return "", nil, err
}
}
}

if !o.NoEmbedIDL {
if err := embedIDL(g, i, m); err != nil {
return nil, err
return "", nil, err
}
}

Expand All @@ -283,32 +296,19 @@ func generateModule(m *compile.Module, i thriftPackageImporter, builder *generat
// root services, even though they have information about the
// whole service tree.
if _, err := builder.AddRootService(service); err != nil {
return nil, err
return "", nil, err
}
}

if err = Services(g, m.Services); err != nil {
return nil, fmt.Errorf(
"could not generate code for services %v", err)
return "", nil, fmt.Errorf("could not generate code for services %v", err)
}
}

// OutputFile name defaults to the package name.
name := packageName + ".go"
if len(o.OutputFile) > 0 {
name = o.OutputFile
}

buff := new(bytes.Buffer)
if err := g.Write(buff, nil); err != nil {
return nil, fmt.Errorf(
"could not write file %q: %v", name, err)
return "", nil, fmt.Errorf("could not write output for file %q: %v", outputFilename, err)
}
files[name] = buff.Bytes()

newFiles := make(map[string][]byte, len(files))
for path, contents := range files {
newFiles[filepath.Join(packageRelPath, path)] = contents
}
return newFiles, nil
return outputFilepath, buff.Bytes(), nil
}

0 comments on commit 3da0099

Please sign in to comment.