Skip to content

Commit

Permalink
java-files: Added files support for Spring and Micronaut services and…
Browse files Browse the repository at this point in the history
… OkHttp client
  • Loading branch information
nislovskaya committed Feb 29, 2024
1 parent 062e587 commit 83474c4
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 39 deletions.
4 changes: 2 additions & 2 deletions codegen/java/client/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ func NewGenerator(jsonlib string, client string, packages *Packages) *Generator
var clientGenerator ClientGenerator = nil
switch client {
case OkHttp:
types = models.NewTypes(jsonlib, "byte[]", "Reader")
types = models.NewTypes(jsonlib, "byte[]", "Reader", "byte[]", "Reader")
clientGenerator = NewOkHttpGenerator(types, modelsGenerator, packages)
break
case Micronaut:
types = models.NewTypes(jsonlib, "byte[]", "Reader")
types = models.NewTypes(jsonlib, "byte[]", "Reader", "byte[]", "Reader")
clientGenerator = NewMicronautGenerator(types, modelsGenerator, packages)
break
default:
Expand Down
5 changes: 4 additions & 1 deletion codegen/java/client/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ func operationParameters(operation *spec.NamedOperation, types *types.Types) []s

func appendParams(types *types.Types, params []string, namedParams []spec.NamedParam) []string {
for _, param := range namedParams {
params = append(params, fmt.Sprintf("%s %s", types.Java(&param.Type.Definition), param.Name.CamelCase()))
if param.Type.Definition.String() == spec.TypeFile {
params = append(params, "String fileName")
}
params = append(params, fmt.Sprintf("%s %s", types.ParamJavaType(&param), param.Name.CamelCase()))
}
return params
}
25 changes: 19 additions & 6 deletions codegen/java/client/okhttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ func (g *OkHttpGenerator) createRequest(w *writer.Writer, operation *spec.NamedO
if operation.Body.IsBodyFormData() {
w.Line(`var body = new MultipartBodyBuilder(MultipartBody.FORM);`)
for _, param := range operation.Body.FormData {
w.Line(`body.addFormDataPart("%s", %s);`, param.Name.SnakeCase(), param.Name.CamelCase())
if param.Type.Definition.String() == spec.TypeFile {
w.Line(`body.addFormDataPart("%s", fileName, %s);`, param.Name.Source, param.Name.CamelCase())
} else {
w.Line(`body.addFormDataPart("%s", %s);`, param.Name.SnakeCase(), param.Name.CamelCase())
}
}
requestBody = "body.build()"
w.EmptyLine()
Expand Down Expand Up @@ -203,6 +207,9 @@ func (g *OkHttpGenerator) successResponse(response *spec.OperationResponse) stri
if response.Body.IsBinary() {
return responseCreate(response, "response.body().charStream()")
}
if response.Body.IsFile() {
return responseCreate(response, "response.body().charStream()")
}
return responseCreate(response, "")
}

Expand All @@ -217,6 +224,9 @@ func (g *OkHttpGenerator) errorResponse(response *spec.Response) string {
if response.Body.IsBinary() {
responseBody = "response.body().charStream()"
}
if response.Body.IsFile() {
responseBody = "response.body().charStream()"
}
return fmt.Sprintf(`throw new %s(%s);`, errorExceptionClassName(response), responseBody)
}

Expand Down Expand Up @@ -309,15 +319,14 @@ func (g *OkHttpGenerator) multipartBodyBuilder() *generator.CodeFile {
w := writer.New(g.Packages.Utils, `MultipartBodyBuilder`)
w.Lines(`
import java.io.File;
import java.net.URLConnection;
import java.util.List;
import okhttp3.*;
public class [[.ClassName]] {
private final MediaType contentType;
private final MultipartBody.Builder multipartBodyBuilder;
public MultipartBodyBuilder(MediaType contentType) {
this.contentType = contentType;
public [[.ClassName]](MediaType contentType) {
this.multipartBodyBuilder = new MultipartBody.Builder().setType(contentType);
}
Expand All @@ -334,11 +343,15 @@ public class [[.ClassName]] {
}
public void addFormDataPart(String fieldName, File file) {
this.multipartBodyBuilder.addFormDataPart(fieldName, file.getName(), RequestBody.create(file, this.contentType));
this.multipartBodyBuilder.addFormDataPart(fieldName, file.getName(), RequestBody.create(file, getFileContentType(file.getName())));
}
public void addFormDataPart(String fieldName, String fileName, byte[] file) {
this.multipartBodyBuilder.addFormDataPart(fieldName, fileName, RequestBody.create(file, this.contentType));
this.multipartBodyBuilder.addFormDataPart(fieldName, fileName, RequestBody.create(file, getFileContentType(fileName)));
}
private MediaType getFileContentType(String fileName) {
return MediaType.parse(URLConnection.getFileNameMap().getContentTypeFor(fileName));
}
public MultipartBody build() {
Expand Down
10 changes: 5 additions & 5 deletions codegen/java/models/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type Generator interface {
}

func NewGenerator(jsonlib string, packages *Packages) Generator {
types := NewTypes(jsonlib, "", "")
types := NewTypes(jsonlib, "", "", "", "")
if jsonlib == Jackson {
return NewJacksonGenerator(types, packages)
}
Expand All @@ -30,12 +30,12 @@ func NewGenerator(jsonlib string, packages *Packages) Generator {
panic(fmt.Sprintf(`Unsupported jsonlib: %s`, jsonlib))
}

func NewTypes(jsonlib, requestBinaryType, responseBinaryType string) *types.Types {
func NewTypes(jsonlib, requestBinaryType, responseBinaryType, requestFileType, responseFileType string) *types.Types {
if jsonlib == Jackson {
return types.NewTypes("JsonNode", requestBinaryType, responseBinaryType)
return types.NewTypes("JsonNode", requestBinaryType, responseBinaryType, requestFileType, responseFileType)
}
if jsonlib == Moshi {
return types.NewTypes("Map<String, Object>", requestBinaryType, responseBinaryType)
return types.NewTypes("Map<String, Object>", requestBinaryType, responseBinaryType, requestFileType, responseFileType)
}
panic(fmt.Sprintf(`Unsupported binary types: %s, %s`, requestBinaryType, responseBinaryType))
panic(fmt.Sprintf(`Unsupported binary types: %s, %s, %s, %s`, requestBinaryType, responseBinaryType, requestFileType, responseFileType))
}
4 changes: 2 additions & 2 deletions codegen/java/service/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ func NewGenerator(jsonlib, server string, packages *Packages) *Generator {
var serverGenerator ServerGenerator = nil
switch server {
case Spring:
types = models.NewTypes(jsonlib, "Resource", "Resource")
types = models.NewTypes(jsonlib, "Resource", "Resource", "MultipartFile", "Resource")
serverGenerator = NewSpringGenerator(types, modelsGenerator, packages)
break
case Micronaut:
types = models.NewTypes(jsonlib, "byte[]", "byte[]")
types = models.NewTypes(jsonlib, "byte[]", "byte[]", "CompletedFileUpload", "StreamedFile")
serverGenerator = NewMicronautGenerator(types, modelsGenerator, packages)
break
default:
Expand Down
5 changes: 2 additions & 3 deletions codegen/java/service/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ package service

import (
"fmt"
"strings"

"generator"
"java/types"
"java/writer"
"spec"
"strings"
)

func (g *Generator) ServicesInterfaces(version *spec.Version) []generator.CodeFile {
Expand Down Expand Up @@ -67,7 +66,7 @@ func operationParameters(operation *spec.NamedOperation, types *types.Types) []s

func appendParams(types *types.Types, params []string, namedParams []spec.NamedParam) []string {
for _, param := range namedParams {
params = append(params, fmt.Sprintf("%s %s", types.Java(&param.Type.Definition), param.Name.CamelCase()))
params = append(params, fmt.Sprintf("%s %s", types.ParamJavaType(&param), param.Name.CamelCase()))
}
return params
}
20 changes: 14 additions & 6 deletions codegen/java/service/micronaut.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@ package service

import (
"fmt"
"strings"

"generator"

"github.com/pinzolo/casee"
"java/models"
"java/packages"
"java/types"
"java/writer"
"spec"
"strings"
)

var Micronaut = "micronaut"
Expand Down Expand Up @@ -40,7 +38,10 @@ func (g *MicronautGenerator) ServicesControllers(version *spec.Version) []genera
}

func (g *MicronautGenerator) FilesImports() []string {
return nil
return []string{
`io.micronaut.http.multipart.CompletedFileUpload`,
`io.micronaut.http.server.types.files.StreamedFile`,
}
}

func (g *MicronautGenerator) ServiceImports() []string {
Expand Down Expand Up @@ -93,6 +94,7 @@ func (g *MicronautGenerator) errorHandler(w *writer.Writer, errors spec.ErrorRes
func (g *MicronautGenerator) serviceController(api *spec.Api) *generator.CodeFile {
w := writer.New(g.Packages.Controllers(api.InHttp.InVersion), controllerName(api))
w.Imports.Add(g.ServiceImports()...)
w.Imports.Add(g.FilesImports()...)
w.Imports.Add(`io.micronaut.core.annotation.Nullable`)
w.Imports.Star(g.Packages.ContentType)
w.Imports.Star(g.Packages.Json)
Expand Down Expand Up @@ -187,6 +189,8 @@ func (g *MicronautGenerator) responseContentType(response *spec.Response) string
return `MediaType.APPLICATION_JSON`
case spec.BodyBinary:
return `MediaType.APPLICATION_OCTET_STREAM`
case spec.BodyFile:
return ""
default:
panic(fmt.Sprintf("Unknown Content Type"))
}
Expand Down Expand Up @@ -229,7 +233,11 @@ func (g *MicronautGenerator) processResponse(w *writer.Writer, response *spec.Re
bodyVar = "bodyJson"
}
w.Line(`logger.info("Completed request with status code: HttpStatus.%s");`, response.Name.UpperCase())
w.Line(`return HttpResponse.status(HttpStatus.%s).body(%s).contentType(%s);`, response.Name.UpperCase(), bodyVar, g.responseContentType(response))
if response.Body.IsFile() {
w.Line(`return HttpResponse.status(HttpStatus.%s).body(%s);`, response.Name.UpperCase(), bodyVar)
} else {
w.Line(`return HttpResponse.status(HttpStatus.%s).body(%s).contentType(%s);`, response.Name.UpperCase(), bodyVar, g.responseContentType(response))
}
}
}

Expand Down Expand Up @@ -405,7 +413,7 @@ func generateMicronautMethodParam(namedParams []spec.NamedParam, paramAnnotation

if namedParams != nil && len(namedParams) > 0 {
for _, param := range namedParams {
paramType := fmt.Sprintf(`%s %s`, types.Java(&param.Type.Definition), param.Name.CamelCase())
paramType := fmt.Sprintf(`%s %s`, types.ParamJavaType(&param), param.Name.CamelCase())
if param.Type.Definition.IsNullable() || (!isSupportDefaulted && param.DefinitionDefault.Default != nil) {
paramType = fmt.Sprintf(`@Nullable %s`, paramType)
}
Expand Down
21 changes: 14 additions & 7 deletions codegen/java/service/spring.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ func (g *SpringGenerator) FilesImports() []string {
return []string{
`org.springframework.core.io.Resource`,
`org.springframework.core.io.InputStreamResource`,
`org.springframework.web.multipart.MultipartFile`,
`java.net.URLConnection`,
}
}

Expand Down Expand Up @@ -93,8 +95,7 @@ func (g *SpringGenerator) errorHandler(w *writer.Writer, errors spec.ErrorRespon
func (g *SpringGenerator) serviceController(api *spec.Api) *generator.CodeFile {
w := writer.New(g.Packages.Controllers(api.InHttp.InVersion), controllerName(api))
w.Imports.Add(g.ServiceImports()...)
w.Imports.Add(`org.springframework.core.io.Resource`)
w.Imports.Add(`org.springframework.core.io.InputStreamResource`)
w.Imports.Add(g.FilesImports()...)
w.Imports.Add(`javax.servlet.http.HttpServletRequest`)
w.Imports.Star(g.Packages.ContentType)
w.Imports.Star(g.Packages.Json)
Expand Down Expand Up @@ -137,7 +138,7 @@ func (g *SpringGenerator) controllerMethod(w *writer.Writer, operation *spec.Nam

func responseEntityType(operation *spec.NamedOperation) string {
for _, response := range operation.Responses {
if response.Body.IsBinary() {
if response.Body.IsBinary() || response.Body.IsFile() {
return "Resource"
}
}
Expand Down Expand Up @@ -190,6 +191,8 @@ func (g *SpringGenerator) responseContentType(response *spec.Response) string {
return `MediaType.APPLICATION_JSON_VALUE`
case spec.BodyBinary:
return `MediaType.APPLICATION_OCTET_STREAM_VALUE`
case spec.BodyFile:
return `URLConnection.getFileNameMap().getContentTypeFor(fileName)`
default:
panic(fmt.Sprintf("Unknown Content Type"))
}
Expand Down Expand Up @@ -232,6 +235,10 @@ func (g *SpringGenerator) processResponse(w *writer.Writer, response *spec.Respo
bodyVar = "bodyJson"
}
w.Line(`HttpHeaders headers = new HttpHeaders();`)
if response.Body.IsFile() {
w.Line(`String fileName = %s.getFilename();`, bodyVar)
w.Line(`headers.add(CONTENT_DISPOSITION, "attachment; filename=" + fileName);`)
}
w.Line(`headers.add(CONTENT_TYPE, %s);`, g.responseContentType(response))
w.Line(`logger.info("Completed request with status code: {}", HttpStatus.%s);`, response.Name.UpperCase())
w.Line(`return new ResponseEntity<>(%s, headers, HttpStatus.%s);`, bodyVar, response.Name.UpperCase())
Expand Down Expand Up @@ -265,9 +272,9 @@ import javax.servlet.http.HttpServletRequest;
public class ContentType {
public static void check(HttpServletRequest request, MediaType expectedContentType) {
var requestContentType = request.getHeader("Content-Type");
if (requestContentType == null || !requestContentType.contains(expectedContentType.toString())) {
throw new ContentTypeMismatchException(expectedContentType.toString(), requestContentType);
var contentType = request.getHeader("Content-Type");
if (contentType == null || !contentType.contains(expectedContentType.toString())) {
throw new ContentTypeMismatchException(expectedContentType.toString(), contentType);
}
}
}
Expand Down Expand Up @@ -367,7 +374,7 @@ func generateSpringMethodParam(namedParams []spec.NamedParam, paramAnnotationNam
if namedParams != nil && len(namedParams) > 0 {
for _, param := range namedParams {
paramAnnotation := getSpringParameterAnnotation(paramAnnotationName, &param)
paramType := fmt.Sprintf(`%s %s`, types.Java(&param.Type.Definition), param.Name.CamelCase())
paramType := fmt.Sprintf(`%s %s`, types.ParamJavaType(&param), param.Name.CamelCase())
dateFormatAnnotation := dateFormatSpringAnnotation(&param.Type.Definition)
if dateFormatAnnotation != "" {
params = append(params, fmt.Sprintf(`%s %s %s`, paramAnnotation, dateFormatAnnotation, paramType))
Expand Down
34 changes: 27 additions & 7 deletions codegen/java/types/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,35 @@ const EmptyType = `void`
type Types struct {
RawJsonType string
BinaryType BinaryType
FileType FileType
}

type BinaryType struct {
RequestType string
ResponseType string
}

func NewTypes(rawJsonType, requestBinaryType, responseBinaryType string) *Types {
return &Types{RawJsonType: rawJsonType, BinaryType: BinaryType{RequestType: requestBinaryType, ResponseType: responseBinaryType}}
type FileType struct {
RequestType string
ResponseType string
}

func NewTypes(rawJsonType, requestBinaryType, responseBinaryType, requestFileType, responseFileType string) *Types {
return &Types{
RawJsonType: rawJsonType,
BinaryType: BinaryType{RequestType: requestBinaryType, ResponseType: responseBinaryType},
FileType: FileType{RequestType: requestFileType, ResponseType: responseFileType},
}
}

func (types *Types) RequestBodyJavaType(body *spec.RequestBody) string {
switch body.Kind() {
case spec.BodyBinary:
return types.BinaryType.RequestType
case spec.BodyText:
return TextType
case spec.BodyEmpty:
return EmptyType
case spec.BodyBinary:
return types.BinaryType.RequestType
case spec.BodyJson:
return types.Java(&body.Type.Definition)
default:
Expand All @@ -39,19 +49,29 @@ func (types *Types) RequestBodyJavaType(body *spec.RequestBody) string {

func (types *Types) ResponseBodyJavaType(body *spec.ResponseBody) string {
switch body.Kind() {
case spec.BodyBinary:
return types.BinaryType.ResponseType
case spec.BodyText:
return TextType
case spec.BodyEmpty:
return EmptyType
case spec.BodyBinary:
return types.BinaryType.ResponseType
case spec.BodyFile:
return types.FileType.ResponseType
case spec.BodyJson:
return types.Java(&body.Type.Definition)
default:
panic(fmt.Sprintf("Unknown response body kind: %v", body.Kind()))
}
}

func (types *Types) ParamJavaType(param *spec.NamedParam) string {
if param.Type.Definition.String() == spec.TypeFile {
return types.FileType.RequestType
} else {
return types.Java(&param.Type.Definition)
}
}

func (t *Types) Java(typ *spec.TypeDef) string {
javaType, _ := t.javaType(typ, false)
return javaType
Expand Down Expand Up @@ -127,7 +147,7 @@ func (t *Types) plainJavaType(typ string, referenceTypesOnly bool) (string, bool
case spec.TypeJson:
return t.RawJsonType, true
case spec.TypeEmpty:
return "void", false
return EmptyType, false
default:
return typ, true
}
Expand Down

0 comments on commit 83474c4

Please sign in to comment.