diff --git a/cmd/generate-fix/internal/template_helpers.go b/cmd/generate-fix/internal/template_helpers.go index 36f6a3e29..4aa0d2122 100644 --- a/cmd/generate-fix/internal/template_helpers.go +++ b/cmd/generate-fix/internal/template_helpers.go @@ -6,6 +6,22 @@ import ( "github.com/quickfixgo/quickfix/datadictionary" ) +func checkIfDecimalImportRequiredForFields(fTypes []*datadictionary.FieldType) (ok bool, err error) { + var t string + for _, fType := range fTypes { + t, err = quickfixType(fType) + if err != nil { + return + } + + if t == "FIXDecimal" { + return true, nil + } + } + + return +} + func checkFieldDecimalRequired(f *datadictionary.FieldDef) (required bool, err error) { var globalType *datadictionary.FieldType if globalType, err = getGlobalFieldType(f); err != nil { @@ -92,8 +108,6 @@ func collectExtraImports(m *datadictionary.MessageDef) (imports []string, err er return } -func useFloatType() bool { return *useFloat } - func quickfixValueType(quickfixType string) (goType string, err error) { switch quickfixType { case "FIXString": diff --git a/cmd/generate-fix/internal/templates.go b/cmd/generate-fix/internal/templates.go index fe89f2d82..eb4ef4c11 100644 --- a/cmd/generate-fix/internal/templates.go +++ b/cmd/generate-fix/internal/templates.go @@ -16,16 +16,16 @@ var ( func init() { tmplFuncs := template.FuncMap{ - "toLower": strings.ToLower, - "requiredFields": requiredFields, - "beginString": beginString, - "routerBeginString": routerBeginString, - "importRootPath": getImportPathRoot, - "quickfixType": quickfixType, - "quickfixValueType": quickfixValueType, - "getGlobalFieldType": getGlobalFieldType, - "collectExtraImports": collectExtraImports, - "useFloatType": useFloatType, + "toLower": strings.ToLower, + "requiredFields": requiredFields, + "beginString": beginString, + "routerBeginString": routerBeginString, + "importRootPath": getImportPathRoot, + "quickfixType": quickfixType, + "quickfixValueType": quickfixValueType, + "getGlobalFieldType": getGlobalFieldType, + "collectExtraImports": collectExtraImports, + "checkIfDecimalImportRequiredForFields": checkIfDecimalImportRequiredForFields, } baseTemplate := template.Must(template.New("Base").Funcs(tmplFuncs).Parse(` @@ -284,7 +284,7 @@ package field import( "github.com/quickfixgo/quickfix" "{{ importRootPath }}/tag" -{{ if eq useFloatType false}} "github.com/shopspring/decimal" {{ end }} +{{ if checkIfDecimalImportRequiredForFields . }} "github.com/shopspring/decimal" {{ end }} "time" )