diff --git a/cmd/generate-bindings/solana/anchor-go/generator/accounts.go b/cmd/generate-bindings/solana/anchor-go/generator/accounts.go index 1140d03f..e64668bc 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/accounts.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/accounts.go @@ -154,7 +154,7 @@ func (g *Generator) gen_IDLTypeDefTyStruct( // TODO: optionality for complex enums is a nil interface. uniqueFieldName := uniqueFieldNames[field.Name] - fieldsGroup.Add(genFieldWithName(field, uniqueFieldName, optionality)). + fieldsGroup.Add(g.genFieldWithName(field, uniqueFieldName, optionality)). Add(func() Code { tagMap := map[string]string{} if IsOption(field.Ty) { @@ -180,7 +180,7 @@ func (g *Generator) gen_IDLTypeDefTyStruct( fieldsGroup.Line() optionality := IsOption(field) || IsCOption(field) - fieldsGroup.Add(genFieldNamed( + fieldsGroup.Add(g.genFieldNamed( FormatTupleItemName(fieldIndex), field, optionality, @@ -223,7 +223,7 @@ func (g *Generator) gen_IDLTypeDefTyStruct( // Declare MarshalWithEncoder: // TODO: code.Line().Line().Add( - gen_MarshalWithEncoder_struct( + g.gen_MarshalWithEncoder_struct( g.idl, withDiscriminator, exportedAccountName, @@ -234,7 +234,7 @@ func (g *Generator) gen_IDLTypeDefTyStruct( // Declare UnmarshalWithDecoder code.Line().Line().Add( - gen_UnmarshalWithDecoder_struct( + g.gen_UnmarshalWithDecoder_struct( g.idl, withDiscriminator, exportedAccountName, @@ -286,20 +286,19 @@ func generateUniqueFieldNames(fields []idl.IdlField) map[string]string { return fieldNameMap } -func genField(field idl.IdlField, pointer bool) Code { - return genFieldNamed(field.Name, field.Ty, pointer) +func (g *Generator) genField(field idl.IdlField, pointer bool) Code { + return g.genFieldNamed(field.Name, field.Ty, pointer) } -// genFieldWithName generates a field with a custom field name (for handling duplicates) -func genFieldWithName(field idl.IdlField, fieldName string, pointer bool) Code { - return genFieldNamed(fieldName, field.Ty, pointer) +func (g *Generator) genFieldWithName(field idl.IdlField, fieldName string, pointer bool) Code { + return g.genFieldNamed(fieldName, field.Ty, pointer) } -func genFieldNamed(name string, typ idltype.IdlType, pointer bool) Code { +func (g *Generator) genFieldNamed(name string, typ idltype.IdlType, pointer bool) Code { st := newStatement() st.Id(tools.ToCamelUpper(name)). Add(func() Code { - if isComplexEnum(typ) { + if g.isComplexEnum(typ) { return nil } if pointer { diff --git a/cmd/generate-bindings/solana/anchor-go/generator/collision_field_names_test.go b/cmd/generate-bindings/solana/anchor-go/generator/collision_field_names_test.go index d05d41a1..477cefa8 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/collision_field_names_test.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/collision_field_names_test.go @@ -42,7 +42,13 @@ func TestMarshalUnmarshalCodegen_matchesUniqueStructFieldNames(t *testing.T) { fields := collidingNamedFields() receiver := "CollideAccount" - marshalCode := gen_MarshalWithEncoder_struct( + g := &Generator{ + idl: idlMinimal, + options: &GeneratorOptions{Package: "test"}, + complexEnumRegistry: make(map[string]struct{}), + } + + marshalCode := g.gen_MarshalWithEncoder_struct( idlMinimal, false, receiver, @@ -50,7 +56,7 @@ func TestMarshalUnmarshalCodegen_matchesUniqueStructFieldNames(t *testing.T) { fields, true, ) - unmarshalCode := gen_UnmarshalWithDecoder_struct( + unmarshalCode := g.gen_UnmarshalWithDecoder_struct( idlMinimal, false, receiver, diff --git a/cmd/generate-bindings/solana/anchor-go/generator/complex-enums.go b/cmd/generate-bindings/solana/anchor-go/generator/complex-enums.go index eb787fa3..62e66642 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/complex-enums.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/complex-enums.go @@ -6,43 +6,43 @@ import ( "github.com/gagliardetto/anchor-go/idl/idltype" ) -// typeRegistryComplexEnum contains all types that are a complex enum (and thus implemented as an interface). -var typeRegistryComplexEnum = make(map[string]struct{}) - -func isComplexEnum(envel idltype.IdlType) bool { +func (g *Generator) isComplexEnum(envel idltype.IdlType) bool { switch vv := envel.(type) { case *idltype.Defined: - _, ok := typeRegistryComplexEnum[vv.Name] + _, ok := g.complexEnumRegistry[vv.Name] return ok } return false } -func isOptionalComplexEnum(ty idltype.IdlType) bool { +func (g *Generator) registerComplexEnumType(name string) { + if g.complexEnumRegistry == nil { + g.complexEnumRegistry = make(map[string]struct{}) + } + g.complexEnumRegistry[name] = struct{}{} +} + +func (g *Generator) isOptionalComplexEnum(ty idltype.IdlType) bool { switch v := ty.(type) { case *idltype.Option: - return isComplexEnum(v.Option) + return g.isComplexEnum(v.Option) case *idltype.COption: - return isComplexEnum(v.COption) + return g.isComplexEnum(v.COption) } return false } -func register_TypeName_as_ComplexEnum(name string) { - typeRegistryComplexEnum[name] = struct{}{} -} - -func registerComplexEnums(def idl.IdlTypeDef) { +func (g *Generator) registerComplexEnums(def idl.IdlTypeDef) { switch vv := def.Ty.(type) { case *idl.IdlTypeDefTyEnum: enumTypeName := def.Name if !vv.IsAllSimple() { - register_TypeName_as_ComplexEnum(enumTypeName) + g.registerComplexEnumType(enumTypeName) } case idl.IdlTypeDefTyEnum: enumTypeName := def.Name if !vv.IsAllSimple() { - register_TypeName_as_ComplexEnum(enumTypeName) + g.registerComplexEnumType(enumTypeName) } } } diff --git a/cmd/generate-bindings/solana/anchor-go/generator/complex_enums_test.go b/cmd/generate-bindings/solana/anchor-go/generator/complex_enums_test.go index 002a9f06..c8049cf1 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/complex_enums_test.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/complex_enums_test.go @@ -15,23 +15,31 @@ import ( // and gen_unmarshal_DefinedFieldsNamed to decide whether a field is routed to // the specialized enum encoder/parser or falls through to the generic // Encode/Decode path. -func complexEnumGuard(ty idltype.IdlType) bool { - return isComplexEnum(ty) || - (IsArray(ty) && isComplexEnum(ty.(*idltype.Array).Type)) || - (IsVec(ty) && isComplexEnum(ty.(*idltype.Vec).Vec)) || - isOptionalComplexEnum(ty) +func complexEnumGuard(g *Generator, ty idltype.IdlType) bool { + return g.isComplexEnum(ty) || + (IsArray(ty) && g.isComplexEnum(ty.(*idltype.Array).Type)) || + (IsVec(ty) && g.isComplexEnum(ty.(*idltype.Vec).Vec)) || + g.isOptionalComplexEnum(ty) +} + +func newTestGenerator() *Generator { + return &Generator{ + idl: &idl.Idl{}, + options: &GeneratorOptions{Package: "test"}, + complexEnumRegistry: make(map[string]struct{}), + } } func TestComplexEnumGuard_handlesOptionAndCOption(t *testing.T) { const name = "Outcome" - register_TypeName_as_ComplexEnum(name) - t.Cleanup(func() { delete(typeRegistryComplexEnum, name) }) + g := newTestGenerator() + g.registerComplexEnumType(name) defined := &idltype.Defined{Name: name} - assert.True(t, complexEnumGuard(defined), "bare Defined") - assert.True(t, complexEnumGuard(&idltype.Option{Option: defined}), "Option") - assert.True(t, complexEnumGuard(&idltype.COption{COption: defined}), "COption") + assert.True(t, complexEnumGuard(g, defined), "bare Defined") + assert.True(t, complexEnumGuard(g, &idltype.Option{Option: defined}), "Option") + assert.True(t, complexEnumGuard(g, &idltype.COption{COption: defined}), "COption") } // TestComplexEnumGuard_rejectsNonComplexOptionals ensures the guard does NOT @@ -40,20 +48,20 @@ func TestComplexEnumGuard_handlesOptionAndCOption(t *testing.T) { // where .Option.(*idltype.Defined) would panic on a non-Defined inner type. func TestComplexEnumGuard_rejectsNonComplexOptionals(t *testing.T) { const complexName = "Outcome" - register_TypeName_as_ComplexEnum(complexName) - t.Cleanup(func() { delete(typeRegistryComplexEnum, complexName) }) + g := newTestGenerator() + g.registerComplexEnumType(complexName) nonComplex := &idltype.Defined{Name: "PlainStruct"} - assert.False(t, complexEnumGuard(&idltype.Option{Option: nonComplex}), + assert.False(t, complexEnumGuard(g, &idltype.Option{Option: nonComplex}), "Option must not trigger the complex-enum path") - assert.False(t, complexEnumGuard(&idltype.COption{COption: nonComplex}), + assert.False(t, complexEnumGuard(g, &idltype.COption{COption: nonComplex}), "COption must not trigger the complex-enum path") - assert.False(t, complexEnumGuard(&idltype.Option{Option: &idltype.U64{}}), + assert.False(t, complexEnumGuard(g, &idltype.Option{Option: &idltype.U64{}}), "Option must not trigger the complex-enum path") - assert.False(t, complexEnumGuard(&idltype.COption{COption: &idltype.U8{}}), + assert.False(t, complexEnumGuard(g, &idltype.COption{COption: &idltype.U8{}}), "COption must not trigger the complex-enum path") - assert.False(t, complexEnumGuard(&idltype.Option{Option: &idltype.Vec{Vec: &idltype.Defined{Name: complexName}}}), + assert.False(t, complexEnumGuard(g, &idltype.Option{Option: &idltype.Vec{Vec: &idltype.Defined{Name: complexName}}}), "Option> — nested containers not supported, must not match") } @@ -63,8 +71,8 @@ func TestComplexEnumGuard_rejectsNonComplexOptionals(t *testing.T) { // instead of the generic Encode/Decode. func TestComplexEnumCodegen_optionalComplexEnum(t *testing.T) { const enumName = "Outcome" - register_TypeName_as_ComplexEnum(enumName) - t.Cleanup(func() { delete(typeRegistryComplexEnum, enumName) }) + g := newTestGenerator() + g.registerComplexEnumType(enumName) fields := idl.IdlDefinedFieldsNamed{ {Name: "id", Ty: &idltype.U64{}}, @@ -73,10 +81,10 @@ func TestComplexEnumCodegen_optionalComplexEnum(t *testing.T) { {Name: "checksum", Ty: &idltype.U64{}}, } - marshalCode := gen_MarshalWithEncoder_struct( + marshalCode := g.gen_MarshalWithEncoder_struct( &idl.Idl{}, false, "Report", "", fields, true, ) - unmarshalCode := gen_UnmarshalWithDecoder_struct( + unmarshalCode := g.gen_UnmarshalWithDecoder_struct( &idl.Idl{}, false, "Report", "", fields, ) diff --git a/cmd/generate-bindings/solana/anchor-go/generator/generator.go b/cmd/generate-bindings/solana/anchor-go/generator/generator.go index 4eddb47d..b95e1bc5 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/generator.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/generator.go @@ -12,8 +12,9 @@ import ( var Debug = false // Set to true to enable debug logging. type Generator struct { - options *GeneratorOptions - idl *idl.Idl + options *GeneratorOptions + idl *idl.Idl + complexEnumRegistry map[string]struct{} } type GeneratorOptions struct { @@ -62,13 +63,14 @@ func (g *Generator) Generate() (*Output, error) { Files: make([]*OutputFile, 0), } + g.complexEnumRegistry = make(map[string]struct{}) + { // Register complex enums. { - // register complex enums: // TODO: .types is the only place where we can find complex enums? (or enums in general?) for _, typ := range g.idl.Types { - registerComplexEnums(typ) + g.registerComplexEnums(typ) } } if len(g.idl.Docs) > 0 { diff --git a/cmd/generate-bindings/solana/anchor-go/generator/instructions.go b/cmd/generate-bindings/solana/anchor-go/generator/instructions.go index 4febce22..24a3a404 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/instructions.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/instructions.go @@ -135,9 +135,9 @@ func (g *Generator) gen_instructions() (*OutputFile, error) { // ) // } checkNil := true - body.BlockFunc(func(g *Group) { - gen_marshal_DefinedFieldsNamed( - g, + body.BlockFunc(func(grp *Group) { + g.gen_marshal_DefinedFieldsNamed( + grp, instruction.Args, checkNil, func(param idl.IdlField) *Statement { diff --git a/cmd/generate-bindings/solana/anchor-go/generator/marshal.go b/cmd/generate-bindings/solana/anchor-go/generator/marshal.go index be4ebae7..cc024ac6 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/marshal.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/marshal.go @@ -9,7 +9,7 @@ import ( "github.com/gagliardetto/anchor-go/idl/idltype" ) -func gen_MarshalWithEncoder_struct( +func (g *Generator) gen_MarshalWithEncoder_struct( idl_ *idl.Idl, withDiscriminator bool, receiverTypeName string, @@ -45,7 +45,7 @@ func gen_MarshalWithEncoder_struct( switch fields := fields.(type) { case idl.IdlDefinedFieldsNamed: uniqueFieldNames := generateUniqueFieldNames(fields) - gen_marshal_DefinedFieldsNamed( + g.gen_marshal_DefinedFieldsNamed( body, fields, checkNil, @@ -61,7 +61,7 @@ func gen_MarshalWithEncoder_struct( case idl.IdlDefinedFieldsTuple: convertedFields := tupleToFieldsNamed(fields) uniqueFieldNames := generateUniqueFieldNames(convertedFields) - gen_marshal_DefinedFieldsNamed( + g.gen_marshal_DefinedFieldsNamed( body, convertedFields, checkNil, @@ -136,7 +136,7 @@ func gen_MarshalWithEncoder_struct( return code } -func gen_marshal_DefinedFieldsNamed( +func (g *Generator) gen_marshal_DefinedFieldsNamed( body *Group, fields idl.IdlDefinedFieldsNamed, checkNil bool, @@ -153,7 +153,7 @@ func gen_marshal_DefinedFieldsNamed( body.Commentf("Serialize `%s`:", exportedArgName) } - if isComplexEnum(field.Ty) || (IsArray(field.Ty) && isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && isComplexEnum(field.Ty.(*idltype.Vec).Vec)) || isOptionalComplexEnum(field.Ty) { + if g.isComplexEnum(field.Ty) || (IsArray(field.Ty) && g.isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && g.isComplexEnum(field.Ty.(*idltype.Vec).Vec)) || g.isOptionalComplexEnum(field.Ty) { switch field.Ty.(type) { case *idltype.Defined: enumTypeName := field.Ty.(*idltype.Defined).Name diff --git a/cmd/generate-bindings/solana/anchor-go/generator/types.go b/cmd/generate-bindings/solana/anchor-go/generator/types.go index 23e5386d..489d6bec 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/types.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/types.go @@ -117,7 +117,7 @@ func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeD // Add comments for the enum type: addComments(code, docs) { - register_TypeName_as_ComplexEnum(name) + g.registerComplexEnumType(name) containerName := formatEnumContainerName(enumTypeName) interfaceMethodName := formatInterfaceMethodName(enumTypeName) @@ -264,7 +264,7 @@ func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeD case idl.IdlDefinedFieldsNamed: for _, variantField := range fields { optionality := IsOption(variantField.Ty) || IsCOption(variantField.Ty) - structGroup.Add(genField(variantField, optionality)). + structGroup.Add(g.genField(variantField, optionality)). Add(func() Code { tagMap := map[string]string{} if IsOption(variantField.Ty) { @@ -287,7 +287,7 @@ func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeD for itemIndex, tupleItem := range fields { optionality := IsOption(tupleItem) || IsCOption(tupleItem) tupleItemName := FormatTupleItemName(itemIndex) - structGroup.Add(genFieldNamed(tupleItemName, tupleItem, optionality)). + structGroup.Add(g.genFieldNamed(tupleItemName, tupleItem, optionality)). Add(func() Code { tagMap := map[string]string{} if IsOption(tupleItem) { @@ -356,7 +356,7 @@ func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeD case idl.IdlDefinedFieldsNamed: // Declare MarshalWithEncoder: code.Line().Line().Add( - gen_MarshalWithEncoder_struct( + g.gen_MarshalWithEncoder_struct( g.idl, false, variantTypeNameComplex, @@ -367,7 +367,7 @@ func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeD // Declare UnmarshalWithDecoder code.Line().Line().Add( - gen_UnmarshalWithDecoder_struct( + g.gen_UnmarshalWithDecoder_struct( g.idl, false, variantTypeNameComplex, @@ -379,7 +379,7 @@ func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeD // TODO: handle tuples // Declare MarshalWithEncoder: code.Line().Line().Add( - gen_MarshalWithEncoder_struct( + g.gen_MarshalWithEncoder_struct( g.idl, false, variantTypeNameComplex, @@ -390,7 +390,7 @@ func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeD // Declare UnmarshalWithDecoder code.Line().Line().Add( - gen_UnmarshalWithDecoder_struct( + g.gen_UnmarshalWithDecoder_struct( g.idl, false, variantTypeNameComplex, diff --git a/cmd/generate-bindings/solana/anchor-go/generator/types_test.go b/cmd/generate-bindings/solana/anchor-go/generator/types_test.go index fbd4083d..72712eb8 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/types_test.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/types_test.go @@ -40,13 +40,14 @@ func TestGenComplexEnum_ConsecutiveUppercase(t *testing.T) { // won't find the original "HTTPStatus" entry and returns nil. idlData := makeComplexEnumIDL("HTTPStatus") gen := &Generator{ - idl: idlData, - options: &GeneratorOptions{Package: "test"}, + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + complexEnumRegistry: make(map[string]struct{}), } // Register the complex enum as the generator normally would. for _, typ := range gen.idl.Types { - registerComplexEnums(typ) + gen.registerComplexEnums(typ) } outputFile, err := gen.genfile_types() @@ -62,12 +63,13 @@ func TestGenComplexEnum_SnakeCaseName(t *testing.T) { // "MyStatus", so ByName("MyStatus") won't find "my_status". idlData := makeComplexEnumIDL("my_status") gen := &Generator{ - idl: idlData, - options: &GeneratorOptions{Package: "test"}, + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + complexEnumRegistry: make(map[string]struct{}), } for _, typ := range gen.idl.Types { - registerComplexEnums(typ) + gen.registerComplexEnums(typ) } outputFile, err := gen.genfile_types() @@ -83,12 +85,13 @@ func TestGenComplexEnum_AlreadyCamelCase(t *testing.T) { // so ByName should find it. This should always work. idlData := makeComplexEnumIDL("MyStatus") gen := &Generator{ - idl: idlData, - options: &GeneratorOptions{Package: "test"}, + idl: idlData, + options: &GeneratorOptions{Package: "test"}, + complexEnumRegistry: make(map[string]struct{}), } for _, typ := range gen.idl.Types { - registerComplexEnums(typ) + gen.registerComplexEnums(typ) } outputFile, err := gen.genfile_types() diff --git a/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go b/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go index e358a161..035acf85 100644 --- a/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go +++ b/cmd/generate-bindings/solana/anchor-go/generator/unmarshal.go @@ -72,7 +72,7 @@ func formatEnumEncoderName(enumTypeName string) string { return "Encode" + enumTypeName } -func gen_UnmarshalWithDecoder_struct( +func (g *Generator) gen_UnmarshalWithDecoder_struct( idl_ *idl.Idl, withDiscriminator bool, receiverTypeName string, @@ -118,10 +118,10 @@ func gen_UnmarshalWithDecoder_struct( switch fields := fields.(type) { case idl.IdlDefinedFieldsNamed: - gen_unmarshal_DefinedFieldsNamed(body, fields, generateUniqueFieldNames(fields)) + g.gen_unmarshal_DefinedFieldsNamed(body, fields, generateUniqueFieldNames(fields)) case idl.IdlDefinedFieldsTuple: convertedFields := tupleToFieldsNamed(fields) - gen_unmarshal_DefinedFieldsNamed(body, convertedFields, generateUniqueFieldNames(convertedFields)) + g.gen_unmarshal_DefinedFieldsNamed(body, convertedFields, generateUniqueFieldNames(convertedFields)) case nil: // No fields, just an empty struct. // TODO: should we panic here? @@ -226,7 +226,7 @@ func tupleToFieldsNamed( return fields } -func gen_unmarshal_DefinedFieldsNamed( +func (g *Generator) gen_unmarshal_DefinedFieldsNamed( body *Group, fields idl.IdlDefinedFieldsNamed, uniqueFieldNames map[string]string, @@ -240,7 +240,7 @@ func gen_unmarshal_DefinedFieldsNamed( body.Commentf("Deserialize `%s`:", exportedArgName) } - if isComplexEnum(field.Ty) || (IsArray(field.Ty) && isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && isComplexEnum(field.Ty.(*idltype.Vec).Vec)) || isOptionalComplexEnum(field.Ty) { + if g.isComplexEnum(field.Ty) || (IsArray(field.Ty) && g.isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && g.isComplexEnum(field.Ty.(*idltype.Vec).Vec)) || g.isOptionalComplexEnum(field.Ty) { switch field.Ty.(type) { case *idltype.Defined: enumName := field.Ty.(*idltype.Defined).Name diff --git a/cmd/generate-bindings/solana/bindings_test.go b/cmd/generate-bindings/solana/bindings_test.go index bff2ad1f..4f5ce0cb 100644 --- a/cmd/generate-bindings/solana/bindings_test.go +++ b/cmd/generate-bindings/solana/bindings_test.go @@ -112,7 +112,7 @@ func TestWriteReportMethods(t *testing.T) { reply := ds.WriteReportFromUserData(runtime, datastorage.UserData{ Key: "testKey", Value: "testValue", - }, nil) + }, nil, nil) require.NoError(t, err, "WriteReportDataStorageUserData should not return an error") response, err := reply.Await() require.NoError(t, err, "Awaiting WriteReportDataStorageUserData reply should not return an error")