Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import software.amazon.smithy.model.shapes.ShapeType
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.utils.StringUtils.lowerCase

/**
* Generates a shape type declaration based on the parameters provided.
Expand Down Expand Up @@ -252,7 +253,7 @@ class ShapeValueGenerator(
CodegenException("unknown member ${currShape.id}.${keyNode.value}")
}
memberShape = generator.model.expectShape(member.target)
writer.writeInline("\$L(", keyNode.value)
writer.writeInline("\$L(", lowerCase(keyNode.value))
generator.writeShapeValueInline(writer, memberShape, valueNode)
}
else -> throw CodegenException("unexpected shape type " + currShape.type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import software.amazon.smithy.swift.codegen.SwiftSettings.Companion.reservedKeyw
import software.amazon.smithy.swift.codegen.model.hasTrait
import software.amazon.smithy.swift.codegen.utils.toPascalCase
import software.amazon.smithy.utils.StringUtils
import software.amazon.smithy.utils.StringUtils.lowerCase
import java.util.logging.Logger

// PropertyBag keys
Expand Down Expand Up @@ -164,6 +165,11 @@ class SymbolVisitor(private val model: Model, swiftSettings: SwiftSettings) :
}

override fun toMemberName(shape: MemberShape): String {
val containingShape = model.expectShape(shape.container)
if (containingShape is UnionShape) {
val name = escaper.escapeMemberName(shape.memberName)
return if (!name.equals("sdkUnknown")) lowerCase(name) else name
}
return escaper.escapeMemberName(shape.memberName.decapitalize())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ abstract class MemberShapeEncodeXMLGenerator(
memberTarget: CollectionShape,
containerName: String
) {
val originalMemberName = member.memberName
val memberName = ctx.symbolProvider.toMemberName(member)
val resolvedMemberName = XMLNameTraitGenerator.construct(member, memberName)
val resolvedMemberName = XMLNameTraitGenerator.construct(member, originalMemberName)
val nestedContainer = "${memberName}Container"
writer.openBlock("if let $memberName = $memberName {", "}") {
if (member.hasTrait(XmlFlattenedTrait::class.java)) {
Expand Down Expand Up @@ -157,8 +158,9 @@ abstract class MemberShapeEncodeXMLGenerator(
}

fun renderMapMember(member: MemberShape, memberTarget: MapShape, containerName: String) {
val originalMemberName = member.memberName
val memberName = ctx.symbolProvider.toMemberName(member)
val resolvedMemberName = XMLNameTraitGenerator.construct(member, memberName)
val resolvedMemberName = XMLNameTraitGenerator.construct(member, originalMemberName)

writer.openBlock("if let $memberName = $memberName {", "}") {
if (member.hasTrait(XmlFlattenedTrait::class.java)) {
Expand Down Expand Up @@ -310,8 +312,9 @@ abstract class MemberShapeEncodeXMLGenerator(

fun renderTimestampMember(member: MemberShape, memberTarget: TimestampShape, containerName: String) {
val symbol = ctx.symbolProvider.toSymbol(memberTarget)
val originalMemberName = member.memberName
val memberName = ctx.symbolProvider.toMemberName(member)
val resolvedMemberName = XMLNameTraitGenerator.construct(member, memberName)
val resolvedMemberName = XMLNameTraitGenerator.construct(member, originalMemberName)
val format = determineTimestampFormat(member, memberTarget, defaultTimestampFormat)
val isBoxed = symbol.isBoxed()
val encodeLine = "try $containerName.encode(TimestampWrapper($memberName, format: .$format), forKey: Key(\"$resolvedMemberName\"))"
Expand All @@ -326,8 +329,9 @@ abstract class MemberShapeEncodeXMLGenerator(

fun renderScalarMember(member: MemberShape, memberTarget: Shape, containerName: String) {
val symbol = ctx.symbolProvider.toSymbol(memberTarget)
val originalMemberName = member.memberName
val memberName = ctx.symbolProvider.toMemberName(member)
val resolvedMemberName = XMLNameTraitGenerator.construct(member, memberName).toString()
val resolvedMemberName = XMLNameTraitGenerator.construct(member, originalMemberName).toString()
val isBoxed = symbol.isBoxed()
if (isBoxed) {
writer.openBlock("if let $memberName = $memberName {", "}") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ class HttpProtocolUnitTestRequestGeneratorTests {
decoder.nonConformingFloatDecodingStrategy = .convertFromString(positiveInfinity: "Infinity", negativeInfinity: "-Infinity", nan: "NaN")

let input = JsonUnionsInput(
contents: MyUnion.stringValue("foo")
contents: MyUnion.stringvalue("foo")

)
let encoder = JSONEncoder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ open class HttpProtocolUnitTestResponseGeneratorTests {
let actual = try JsonUnionsOutputResponse(httpResponse: httpResponse, decoder: decoder)
let expected = JsonUnionsOutputResponse(
contents: MyUnion.stringValue("foo")
contents: MyUnion.stringvalue("foo")
)
Expand Down
162 changes: 81 additions & 81 deletions smithy-swift-codegen/src/test/kotlin/UnionDecodeGeneratorTests.kt
Original file line number Diff line number Diff line change
Expand Up @@ -66,62 +66,62 @@ class UnionDecodeGeneratorTests {
"""
extension MyUnion: Codable, Reflection {
enum CodingKeys: String, CodingKey {
case blobValue
case booleanValue
case enumValue
case listValue
case mapValue
case numberValue
case blobvalue = "blobValue"
case booleanvalue = "booleanValue"
case enumvalue = "enumValue"
case listvalue = "listValue"
case mapvalue = "mapValue"
case numbervalue = "numberValue"
case sdkUnknown
case stringValue
case structureValue
case timestampValue
case stringvalue = "stringValue"
case structurevalue = "structureValue"
case timestampvalue = "timestampValue"
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
switch self {
case let .blobValue(blobValue):
if let blobValue = blobValue {
try container.encode(blobValue.base64EncodedString(), forKey: .blobValue)
case let .blobvalue(blobvalue):
if let blobvalue = blobvalue {
try container.encode(blobvalue.base64EncodedString(), forKey: .blobvalue)
}
case let .booleanValue(booleanValue):
if let booleanValue = booleanValue {
try container.encode(booleanValue, forKey: .booleanValue)
case let .booleanvalue(booleanvalue):
if let booleanvalue = booleanvalue {
try container.encode(booleanvalue, forKey: .booleanvalue)
}
case let .enumValue(enumValue):
if let enumValue = enumValue {
try container.encode(enumValue.rawValue, forKey: .enumValue)
case let .enumvalue(enumvalue):
if let enumvalue = enumvalue {
try container.encode(enumvalue.rawValue, forKey: .enumvalue)
}
case let .listValue(listValue):
if let listValue = listValue {
var listValueContainer = container.nestedUnkeyedContainer(forKey: .listValue)
for stringlist0 in listValue {
try listValueContainer.encode(stringlist0)
case let .listvalue(listvalue):
if let listvalue = listvalue {
var listvalueContainer = container.nestedUnkeyedContainer(forKey: .listvalue)
for stringlist0 in listvalue {
try listvalueContainer.encode(stringlist0)
}
}
case let .mapValue(mapValue):
if let mapValue = mapValue {
var mapValueContainer = container.nestedContainer(keyedBy: Key.self, forKey: .mapValue)
for (dictKey0, stringmap0) in mapValue {
try mapValueContainer.encode(stringmap0, forKey: Key(stringValue: dictKey0))
case let .mapvalue(mapvalue):
if let mapvalue = mapvalue {
var mapvalueContainer = container.nestedContainer(keyedBy: Key.self, forKey: .mapvalue)
for (dictKey0, stringmap0) in mapvalue {
try mapvalueContainer.encode(stringmap0, forKey: Key(stringValue: dictKey0))
}
}
case let .numberValue(numberValue):
if let numberValue = numberValue {
try container.encode(numberValue, forKey: .numberValue)
case let .numbervalue(numbervalue):
if let numbervalue = numbervalue {
try container.encode(numbervalue, forKey: .numbervalue)
}
case let .stringValue(stringValue):
if let stringValue = stringValue {
try container.encode(stringValue, forKey: .stringValue)
case let .stringvalue(stringvalue):
if let stringvalue = stringvalue {
try container.encode(stringvalue, forKey: .stringvalue)
}
case let .structureValue(structureValue):
if let structureValue = structureValue {
try container.encode(structureValue, forKey: .structureValue)
case let .structurevalue(structurevalue):
if let structurevalue = structurevalue {
try container.encode(structurevalue, forKey: .structurevalue)
}
case let .timestampValue(timestampValue):
if let timestampValue = timestampValue {
try container.encode(timestampValue.iso8601WithoutFractionalSeconds(), forKey: .timestampValue)
case let .timestampvalue(timestampvalue):
if let timestampvalue = timestampvalue {
try container.encode(timestampvalue.iso8601WithoutFractionalSeconds(), forKey: .timestampvalue)
}
case let .sdkUnknown(sdkUnknown):
try container.encode(sdkUnknown, forKey: .sdkUnknown)
Expand All @@ -130,72 +130,72 @@ class UnionDecodeGeneratorTests {

public init (from decoder: Decoder) throws {
let values = try decoder.container(keyedBy: CodingKeys.self)
let stringValueDecoded = try values.decodeIfPresent(String.self, forKey: .stringValue)
if let stringValue = stringValueDecoded {
self = .stringValue(stringValue)
let stringvalueDecoded = try values.decodeIfPresent(String.self, forKey: .stringvalue)
if let stringvalue = stringvalueDecoded {
self = .stringvalue(stringvalue)
return
}
let booleanValueDecoded = try values.decodeIfPresent(Bool.self, forKey: .booleanValue)
if let booleanValue = booleanValueDecoded {
self = .booleanValue(booleanValue)
let booleanvalueDecoded = try values.decodeIfPresent(Bool.self, forKey: .booleanvalue)
if let booleanvalue = booleanvalueDecoded {
self = .booleanvalue(booleanvalue)
return
}
let numberValueDecoded = try values.decodeIfPresent(Int.self, forKey: .numberValue)
if let numberValue = numberValueDecoded {
self = .numberValue(numberValue)
let numbervalueDecoded = try values.decodeIfPresent(Int.self, forKey: .numbervalue)
if let numbervalue = numbervalueDecoded {
self = .numbervalue(numbervalue)
return
}
let blobValueDecoded = try values.decodeIfPresent(Data.self, forKey: .blobValue)
if let blobValue = blobValueDecoded {
self = .blobValue(blobValue)
let blobvalueDecoded = try values.decodeIfPresent(Data.self, forKey: .blobvalue)
if let blobvalue = blobvalueDecoded {
self = .blobvalue(blobvalue)
return
}
let timestampValueDateString = try values.decodeIfPresent(String.self, forKey: .timestampValue)
var timestampValueDecoded: Date? = nil
if let timestampValueDateString = timestampValueDateString {
let timestampValueFormatter = DateFormatter.iso8601DateFormatterWithoutFractionalSeconds
timestampValueDecoded = timestampValueFormatter.date(from: timestampValueDateString)
let timestampvalueDateString = try values.decodeIfPresent(String.self, forKey: .timestampvalue)
var timestampvalueDecoded: Date? = nil
if let timestampvalueDateString = timestampvalueDateString {
let timestampvalueFormatter = DateFormatter.iso8601DateFormatterWithoutFractionalSeconds
timestampvalueDecoded = timestampvalueFormatter.date(from: timestampvalueDateString)
}
if let timestampValue = timestampValueDecoded {
self = .timestampValue(timestampValue)
if let timestampvalue = timestampvalueDecoded {
self = .timestampvalue(timestampvalue)
return
}
let enumValueDecoded = try values.decodeIfPresent(FooEnum.self, forKey: .enumValue)
if let enumValue = enumValueDecoded {
self = .enumValue(enumValue)
let enumvalueDecoded = try values.decodeIfPresent(FooEnum.self, forKey: .enumvalue)
if let enumvalue = enumvalueDecoded {
self = .enumvalue(enumvalue)
return
}
let listValueContainer = try values.decodeIfPresent([String?].self, forKey: .listValue)
var listValueDecoded0:[String]? = nil
if let listValueContainer = listValueContainer {
listValueDecoded0 = [String]()
for string0 in listValueContainer {
let listvalueContainer = try values.decodeIfPresent([String?].self, forKey: .listvalue)
var listvalueDecoded0:[String]? = nil
if let listvalueContainer = listvalueContainer {
listvalueDecoded0 = [String]()
for string0 in listvalueContainer {
if let string0 = string0 {
listValueDecoded0?.append(string0)
listvalueDecoded0?.append(string0)
}
}
}
if let listValue = listValueDecoded0 {
self = .listValue(listValue)
if let listvalue = listvalueDecoded0 {
self = .listvalue(listvalue)
return
}
let mapValueContainer = try values.decodeIfPresent([String: String?].self, forKey: .mapValue)
var mapValueDecoded0: [String:String]? = nil
if let mapValueContainer = mapValueContainer {
mapValueDecoded0 = [String:String]()
for (key0, string0) in mapValueContainer {
let mapvalueContainer = try values.decodeIfPresent([String: String?].self, forKey: .mapvalue)
var mapvalueDecoded0: [String:String]? = nil
if let mapvalueContainer = mapvalueContainer {
mapvalueDecoded0 = [String:String]()
for (key0, string0) in mapvalueContainer {
if let string0 = string0 {
mapValueDecoded0?[key0] = string0
mapvalueDecoded0?[key0] = string0
}
}
}
if let mapValue = mapValueDecoded0 {
self = .mapValue(mapValue)
if let mapvalue = mapvalueDecoded0 {
self = .mapvalue(mapvalue)
return
}
let structureValueDecoded = try values.decodeIfPresent(GreetingWithErrorsOutput.self, forKey: .structureValue)
if let structureValue = structureValueDecoded {
self = .structureValue(structureValue)
let structurevalueDecoded = try values.decodeIfPresent(GreetingWithErrorsOutput.self, forKey: .structurevalue)
if let structurevalue = structurevalueDecoded {
self = .structurevalue(structurevalue)
return
}
self = .sdkUnknown("")
Expand Down
Loading