Skip to content

Commit

Permalink
allow multiple Swagger documents (#1022)
Browse files Browse the repository at this point in the history
* allow multiple Swagger documents (no breaking changes)
  • Loading branch information
h44z committed Oct 11, 2021
1 parent 9fb19d0 commit 18b2bd1
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 10 deletions.
7 changes: 7 additions & 0 deletions cmd/swag/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const (
parseInternalFlag = "parseInternal"
generatedTimeFlag = "generatedTime"
parseDepthFlag = "parseDepth"
instanceNameFlag = "instanceName"
)

var initFlags = []cli.Flag{
Expand Down Expand Up @@ -87,6 +88,11 @@ var initFlags = []cli.Flag{
Value: 100,
Usage: "Dependency parse depth",
},
&cli.StringFlag{
Name: instanceNameFlag,
Value: "",
Usage: "This parameter can be used to name different swagger document instances. It is optional.",
},
}

func initAction(c *cli.Context) error {
Expand All @@ -111,6 +117,7 @@ func initAction(c *cli.Context) error {
GeneratedTime: c.Bool(generatedTimeFlag),
CodeExampleFilesDir: c.String(codeExampleFilesFlag),
ParseDepth: c.Int(parseDepthFlag),
InstanceName: c.String(instanceNameFlag),
})
}

Expand Down
12 changes: 11 additions & 1 deletion gen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,18 @@ type Config struct {

// ParseDepth dependency parse depth
ParseDepth int

// InstanceName is used to get distinct names for different swagger documents in the
// same project. The default value is "swagger".
InstanceName string
}

// Build builds swagger json file for given searchDir and mainAPIFile. Returns json
func (g *Gen) Build(config *Config) error {
if config.InstanceName == "" {
config.InstanceName = swag.Name
}

searchDirs := strings.Split(config.SearchDir, ",")
for _, searchDir := range searchDirs {
if _, err := os.Stat(searchDir); os.IsNotExist(err) {
Expand Down Expand Up @@ -233,6 +241,7 @@ func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *spec.Swa
Title string
Description string
Version string
InstanceName string
}{
Timestamp: time.Now(),
GeneratedTime: config.GeneratedTime,
Expand All @@ -244,6 +253,7 @@ func (g *Gen) writeGoDoc(packageName string, output io.Writer, swagger *spec.Swa
Title: swagger.Info.Title,
Description: swagger.Info.Description,
Version: swagger.Info.Version,
InstanceName: config.InstanceName,
})
if err != nil {
return err
Expand Down Expand Up @@ -323,6 +333,6 @@ func (s *s) ReadDoc() string {
}
func init() {
swag.Register(swag.Name, &s{})
swag.Register({{ printf "%q" .InstanceName }}, &s{})
}
`
48 changes: 48 additions & 0 deletions gen/gen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os/exec"
"path/filepath"
"plugin"
"strings"
"testing"

"github.com/go-openapi/spec"
Expand Down Expand Up @@ -39,6 +40,53 @@ func TestGen_Build(t *testing.T) {
}
}

func TestGen_BuildInstanceName(t *testing.T) {
searchDir := "../testdata/simple"

config := &Config{
SearchDir: searchDir,
MainAPIFile: "./main.go",
OutputDir: "../testdata/simple/docs",
PropNamingStrategy: "",
}
assert.NoError(t, New().Build(config))

goSourceFile := filepath.Join(config.OutputDir, "docs.go")

// Validate default registration name
expectedCode, err := ioutil.ReadFile(goSourceFile)
if err != nil {
t.Fatal(err)
}
if !strings.Contains(string(expectedCode), "swag.Register(\"swagger\", &s{})") {
t.Fatal(errors.New("generated go code does not contain the correct default registration sequence"))
}

// Custom name
config.InstanceName = "custom"
assert.NoError(t, New().Build(config))
expectedCode, err = ioutil.ReadFile(goSourceFile)
if err != nil {
t.Fatal(err)
}
if !strings.Contains(string(expectedCode), "swag.Register(\"custom\", &s{})") {
t.Fatal(errors.New("generated go code does not contain the correct registration sequence"))
}

// cleanup
expectedFiles := []string{
filepath.Join(config.OutputDir, "docs.go"),
filepath.Join(config.OutputDir, "swagger.json"),
filepath.Join(config.OutputDir, "swagger.yaml"),
}
for _, expectedFile := range expectedFiles {
if _, err := os.Stat(expectedFile); os.IsNotExist(err) {
t.Fatal(err)
}
_ = os.Remove(expectedFile)
}
}

func TestGen_BuildSnakecase(t *testing.T) {
searchDir := "../testdata/simple2"
config := &Config{
Expand Down
35 changes: 27 additions & 8 deletions swagger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package swag

import (
"errors"
"fmt"
"sync"
)

Expand All @@ -10,7 +11,7 @@ const Name = "swagger"

var (
swaggerMu sync.RWMutex
swag Swagger
swags map[string]Swagger
)

// Swagger is a interface to read swagger document.
Expand All @@ -26,17 +27,35 @@ func Register(name string, swagger Swagger) {
panic("swagger is nil")
}

if swag != nil {
if swags == nil {
swags = make(map[string]Swagger)
}

if _, ok := swags[name]; ok {
panic("Register called twice for swag: " + name)
}
swag = swagger
swags[name] = swagger
}

// ReadDoc reads swagger document.
func ReadDoc() (string, error) {
if swag != nil {
return swag.ReadDoc(), nil
// ReadDoc reads swagger document. An optional name parameter can be passed to read a specific document.
// The default name is "swagger".
func ReadDoc(optionalName ...string) (string, error) {
swaggerMu.RLock()
defer swaggerMu.RUnlock()

if swags == nil {
return "", errors.New("no swag has yet been registered")
}

name := Name
if len(optionalName) != 0 && optionalName[0] != "" {
name = optionalName[0]
}

swag, ok := swags[name]
if !ok {
return "", fmt.Errorf("no swag named \"%s\" was registered", name)
}

return "", errors.New("not yet registered swag")
return swag.ReadDoc(), nil
}
26 changes: 25 additions & 1 deletion swagger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,36 @@ func TestRegister(t *testing.T) {
assert.Equal(t, doc, d)
}

func TestRegisterByName(t *testing.T) {
setup()
Register("another_name", &s{})
d, _ := ReadDoc("another_name")
assert.Equal(t, doc, d)
}

func TestRegisterMultiple(t *testing.T) {
setup()
Register(Name, &s{})
Register("another_name", &s{})
d1, _ := ReadDoc(Name)
d2, _ := ReadDoc("another_name")
assert.Equal(t, doc, d1)
assert.Equal(t, doc, d2)
}

func TestReadDocBeforeRegistered(t *testing.T) {
setup()
_, err := ReadDoc()
assert.Error(t, err)
}

func TestReadDocWithInvalidName(t *testing.T) {
setup()
Register(Name, &s{})
_, err := ReadDoc("invalid")
assert.Error(t, err)
}

func TestNilRegister(t *testing.T) {
setup()
var swagger Swagger
Expand All @@ -185,5 +209,5 @@ func TestCalledTwicelRegister(t *testing.T) {
}

func setup() {
swag = nil
swags = nil
}

0 comments on commit 18b2bd1

Please sign in to comment.