diff --git a/gen/gen.go b/gen/gen.go index 39ed0489b..9adc0967e 100644 --- a/gen/gen.go +++ b/gen/gen.go @@ -405,68 +405,23 @@ var packageTemplate = `// Package {{.PackageName}} GENERATED BY THE COMMAND ABOV // {{ .Timestamp }}{{ end }} package {{.PackageName}} -import ( - "bytes" - "encoding/json" - "strings" - "text/template" +import "github.com/swaggo/swag" - "github.com/swaggo/swag" -) - -var doc = ` + "`{{ printDoc .Doc}}`" + ` +const docTemplate_{{ .InstanceName }} = ` + "`{{ printDoc .Doc}}`" + ` -type swaggerInfo struct { - Version string - Host string - BasePath string - Schemes []string - Title string - Description string -} - -// SwaggerInfo holds exported Swagger Info so clients can modify it -var SwaggerInfo = swaggerInfo{ +// SwaggerInfo_{{ .InstanceName }} holds exported Swagger Info so clients can modify it +var SwaggerInfo_{{ .InstanceName }} = &swag.Spec{ Version: {{ printf "%q" .Version}}, Host: {{ printf "%q" .Host}}, BasePath: {{ printf "%q" .BasePath}}, Schemes: []string{ {{ range $index, $schema := .Schemes}}{{if gt $index 0}},{{end}}{{printf "%q" $schema}}{{end}} }, Title: {{ printf "%q" .Title}}, Description: {{ printf "%q" .Description}}, -} - -type s struct{} - -func (s *s) ReadDoc() string { - sInfo := SwaggerInfo - sInfo.Description = strings.Replace(sInfo.Description, "\n", "\\n", -1) - - t, err := template.New("swagger_info").Funcs(template.FuncMap{ - "marshal": func(v interface{}) string { - a, _ := json.Marshal(v) - return string(a) - }, - "escape": func(v interface{}) string { - // escape tabs - str := strings.Replace(v.(string), "\t", "\\t", -1) - // replace " with \", and if that results in \\", replace that with \\\" - str = strings.Replace(str, "\"", "\\\"", -1) - return strings.Replace(str, "\\\\\"", "\\\\\\\"", -1) - }, - }).Parse(doc) - if err != nil { - return doc - } - - var tpl bytes.Buffer - if err := t.Execute(&tpl, sInfo); err != nil { - return doc - } - - return tpl.String() + InfoInstanceName: {{ printf "%q" .InstanceName }}, + SwaggerTemplate: docTemplate_{{ .InstanceName }}, } func init() { - swag.Register({{ printf "%q" .InstanceName }}, &s{}) + swag.Register(SwaggerInfo_{{ .InstanceName }}.InstanceName(), SwaggerInfo_{{ .InstanceName }}) } ` diff --git a/gen/gen_test.go b/gen/gen_test.go index 0447ec9f5..05d0ee833 100644 --- a/gen/gen_test.go +++ b/gen/gen_test.go @@ -95,7 +95,10 @@ func TestGen_BuildInstanceName(t *testing.T) { if err != nil { require.NoError(t, err) } - if !strings.Contains(string(expectedCode), "swag.Register(\"swagger\", &s{})") { + if !strings.Contains( + string(expectedCode), + "swag.Register(SwaggerInfo_swagger.InstanceName(), SwaggerInfo_swagger)", + ) { t.Fatal(errors.New("generated go code does not contain the correct default registration sequence")) } @@ -107,7 +110,10 @@ func TestGen_BuildInstanceName(t *testing.T) { if err != nil { require.NoError(t, err) } - if !strings.Contains(string(expectedCode), "swag.Register(\"custom\", &s{})") { + if !strings.Contains( + string(expectedCode), + "swag.Register(SwaggerInfo_custom.InstanceName(), SwaggerInfo_custom)", + ) { t.Fatal(errors.New("generated go code does not contain the correct registration sequence")) } diff --git a/spec.go b/spec.go new file mode 100644 index 000000000..9e0ec1ad0 --- /dev/null +++ b/spec.go @@ -0,0 +1,54 @@ +package swag + +import ( + "bytes" + "encoding/json" + "strings" + "text/template" +) + +// Spec holds exported Swagger Info so clients can modify it. +type Spec struct { + Version string + Host string + BasePath string + Schemes []string + Title string + Description string + InfoInstanceName string + SwaggerTemplate string +} + +// ReadDoc parses SwaggerTemplate into swagger document. +func (i *Spec) ReadDoc() string { + i.Description = strings.Replace(i.Description, "\n", "\\n", -1) + + t, err := template.New("swagger_info").Funcs(template.FuncMap{ + "marshal": func(v interface{}) string { + a, _ := json.Marshal(v) + return string(a) + }, + "escape": func(v interface{}) string { + // escape tabs + str := strings.Replace(v.(string), "\t", "\\t", -1) + // replace " with \", and if that results in \\", replace that with \\\" + str = strings.Replace(str, "\"", "\\\"", -1) + return strings.Replace(str, "\\\\\"", "\\\\\\\"", -1) + }, + }).Parse(i.SwaggerTemplate) + if err != nil { + return i.SwaggerTemplate + } + + var tpl bytes.Buffer + if err = t.Execute(&tpl, i); err != nil { + return i.SwaggerTemplate + } + + return tpl.String() +} + +// InstanceName returns Spec instance name. +func (i *Spec) InstanceName() string { + return i.InfoInstanceName +} diff --git a/spec_test.go b/spec_test.go new file mode 100644 index 000000000..00ab37007 --- /dev/null +++ b/spec_test.go @@ -0,0 +1,146 @@ +package swag + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestSpec_InstanceName(t *testing.T) { + type fields struct { + Version string + Host string + BasePath string + Schemes []string + Title string + Description string + InfoInstanceName string + SwaggerTemplate string + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "TestInstanceNameCorrect", + fields: fields{ + Version: "1.0", + Host: "localhost:8080", + BasePath: "/", + InfoInstanceName: "TestInstanceName1", + }, + want: "TestInstanceName1", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := &Spec{ + Version: tt.fields.Version, + Host: tt.fields.Host, + BasePath: tt.fields.BasePath, + Schemes: tt.fields.Schemes, + Title: tt.fields.Title, + Description: tt.fields.Description, + InfoInstanceName: tt.fields.InfoInstanceName, + SwaggerTemplate: tt.fields.SwaggerTemplate, + } + assert.Equal(t, tt.want, i.InstanceName()) + }) + } +} + +func TestSpec_ReadDoc(t *testing.T) { + type fields struct { + Version string + Host string + BasePath string + Schemes []string + Title string + Description string + InfoInstanceName string + SwaggerTemplate string + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "TestReadDocCorrect", + fields: fields{ + Version: "1.0", + Host: "localhost:8080", + BasePath: "/", + InfoInstanceName: "TestInstanceName", + SwaggerTemplate: `{ + "swagger": "2.0", + "info": { + "description": "{{escape .Description}}", + "title": "{{.Title}}", + "version": "{{.Version}}" + }, + "host": "{{.Host}}", + "basePath": "{{.BasePath}}", + }`, + }, + want: "{" + + "\n\t\t\t\"swagger\": \"2.0\"," + + "\n\t\t\t\"info\": {" + + "\n\t\t\t\t\"description\": \"\",\n\t\t\t\t\"" + + "title\": \"\"," + + "\n\t\t\t\t\"version\": \"1.0\"" + + "\n\t\t\t}," + + "\n\t\t\t\"host\": \"localhost:8080\"," + + "\n\t\t\t\"basePath\": \"/\"," + + "\n\t\t}", + }, + { + name: "TestReadDocMarshalTrigger", + fields: fields{ + Version: "1.0", + Host: "localhost:8080", + BasePath: "/", + InfoInstanceName: "TestInstanceName", + SwaggerTemplate: "{{ marshal .Version }}", + }, + want: "\"1.0\"", + }, + { + name: "TestReadDocParseError", + fields: fields{ + Version: "1.0", + Host: "localhost:8080", + BasePath: "/", + InfoInstanceName: "TestInstanceName", + SwaggerTemplate: "{{ ..Version }}", + }, + want: "{{ ..Version }}", + }, + { + name: "TestReadDocExecuteError", + fields: fields{ + Version: "1.0", + Host: "localhost:8080", + BasePath: "/", + InfoInstanceName: "TestInstanceName", + SwaggerTemplate: "{{ .Schemesa }}", + }, + want: "{{ .Schemesa }}", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := &Spec{ + Version: tt.fields.Version, + Host: tt.fields.Host, + BasePath: tt.fields.BasePath, + Schemes: tt.fields.Schemes, + Title: tt.fields.Title, + Description: tt.fields.Description, + InfoInstanceName: tt.fields.InfoInstanceName, + SwaggerTemplate: tt.fields.SwaggerTemplate, + } + assert.Equal(t, tt.want, i.ReadDoc()) + }) + } +}