Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions cmd/generate-bindings/solana/anchor-go/generator/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func (g *Generator) gen_IDLTypeDefTyStruct(

// TODO: optionality for complex enums is a nil interface.
uniqueFieldName := uniqueFieldNames[field.Name]
fieldsGroup.Add(genFieldWithName(field, uniqueFieldName, optionality)).
fieldsGroup.Add(g.genFieldWithName(field, uniqueFieldName, optionality)).
Add(func() Code {
tagMap := map[string]string{}
if IsOption(field.Ty) {
Expand All @@ -180,7 +180,7 @@ func (g *Generator) gen_IDLTypeDefTyStruct(
fieldsGroup.Line()
optionality := IsOption(field) || IsCOption(field)

fieldsGroup.Add(genFieldNamed(
fieldsGroup.Add(g.genFieldNamed(
FormatTupleItemName(fieldIndex),
field,
optionality,
Expand Down Expand Up @@ -223,7 +223,7 @@ func (g *Generator) gen_IDLTypeDefTyStruct(
// Declare MarshalWithEncoder:
// TODO:
code.Line().Line().Add(
gen_MarshalWithEncoder_struct(
g.gen_MarshalWithEncoder_struct(
g.idl,
withDiscriminator,
exportedAccountName,
Expand All @@ -234,7 +234,7 @@ func (g *Generator) gen_IDLTypeDefTyStruct(

// Declare UnmarshalWithDecoder
code.Line().Line().Add(
gen_UnmarshalWithDecoder_struct(
g.gen_UnmarshalWithDecoder_struct(
g.idl,
withDiscriminator,
exportedAccountName,
Expand Down Expand Up @@ -286,20 +286,19 @@ func generateUniqueFieldNames(fields []idl.IdlField) map[string]string {
return fieldNameMap
}

func genField(field idl.IdlField, pointer bool) Code {
return genFieldNamed(field.Name, field.Ty, pointer)
func (g *Generator) genField(field idl.IdlField, pointer bool) Code {
return g.genFieldNamed(field.Name, field.Ty, pointer)
}

// genFieldWithName generates a field with a custom field name (for handling duplicates)
func genFieldWithName(field idl.IdlField, fieldName string, pointer bool) Code {
return genFieldNamed(fieldName, field.Ty, pointer)
func (g *Generator) genFieldWithName(field idl.IdlField, fieldName string, pointer bool) Code {
return g.genFieldNamed(fieldName, field.Ty, pointer)
}

func genFieldNamed(name string, typ idltype.IdlType, pointer bool) Code {
func (g *Generator) genFieldNamed(name string, typ idltype.IdlType, pointer bool) Code {
st := newStatement()
st.Id(tools.ToCamelUpper(name)).
Add(func() Code {
if isComplexEnum(typ) {
if g.isComplexEnum(typ) {
return nil
}
if pointer {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,21 @@ func TestMarshalUnmarshalCodegen_matchesUniqueStructFieldNames(t *testing.T) {
fields := collidingNamedFields()
receiver := "CollideAccount"

marshalCode := gen_MarshalWithEncoder_struct(
g := &Generator{
idl: idlMinimal,
options: &GeneratorOptions{Package: "test"},
complexEnumRegistry: make(map[string]struct{}),
}

marshalCode := g.gen_MarshalWithEncoder_struct(
idlMinimal,
false,
receiver,
"",
fields,
true,
)
unmarshalCode := gen_UnmarshalWithDecoder_struct(
unmarshalCode := g.gen_UnmarshalWithDecoder_struct(
idlMinimal,
false,
receiver,
Expand Down
30 changes: 15 additions & 15 deletions cmd/generate-bindings/solana/anchor-go/generator/complex-enums.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,43 @@ import (
"github.com/gagliardetto/anchor-go/idl/idltype"
)

// typeRegistryComplexEnum contains all types that are a complex enum (and thus implemented as an interface).
var typeRegistryComplexEnum = make(map[string]struct{})

func isComplexEnum(envel idltype.IdlType) bool {
func (g *Generator) isComplexEnum(envel idltype.IdlType) bool {
switch vv := envel.(type) {
case *idltype.Defined:
_, ok := typeRegistryComplexEnum[vv.Name]
_, ok := g.complexEnumRegistry[vv.Name]
return ok
}
return false
}

func isOptionalComplexEnum(ty idltype.IdlType) bool {
func (g *Generator) registerComplexEnumType(name string) {
if g.complexEnumRegistry == nil {
g.complexEnumRegistry = make(map[string]struct{})
}
g.complexEnumRegistry[name] = struct{}{}
}

func (g *Generator) isOptionalComplexEnum(ty idltype.IdlType) bool {
switch v := ty.(type) {
case *idltype.Option:
return isComplexEnum(v.Option)
return g.isComplexEnum(v.Option)
case *idltype.COption:
return isComplexEnum(v.COption)
return g.isComplexEnum(v.COption)
}
return false
}

func register_TypeName_as_ComplexEnum(name string) {
typeRegistryComplexEnum[name] = struct{}{}
}

func registerComplexEnums(def idl.IdlTypeDef) {
func (g *Generator) registerComplexEnums(def idl.IdlTypeDef) {
switch vv := def.Ty.(type) {
case *idl.IdlTypeDefTyEnum:
enumTypeName := def.Name
if !vv.IsAllSimple() {
register_TypeName_as_ComplexEnum(enumTypeName)
g.registerComplexEnumType(enumTypeName)
}
case idl.IdlTypeDefTyEnum:
enumTypeName := def.Name
if !vv.IsAllSimple() {
register_TypeName_as_ComplexEnum(enumTypeName)
g.registerComplexEnumType(enumTypeName)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,31 @@ import (
// and gen_unmarshal_DefinedFieldsNamed to decide whether a field is routed to
// the specialized enum encoder/parser or falls through to the generic
// Encode/Decode path.
func complexEnumGuard(ty idltype.IdlType) bool {
return isComplexEnum(ty) ||
(IsArray(ty) && isComplexEnum(ty.(*idltype.Array).Type)) ||
(IsVec(ty) && isComplexEnum(ty.(*idltype.Vec).Vec)) ||
isOptionalComplexEnum(ty)
func complexEnumGuard(g *Generator, ty idltype.IdlType) bool {
return g.isComplexEnum(ty) ||
(IsArray(ty) && g.isComplexEnum(ty.(*idltype.Array).Type)) ||
(IsVec(ty) && g.isComplexEnum(ty.(*idltype.Vec).Vec)) ||
g.isOptionalComplexEnum(ty)
}

func newTestGenerator() *Generator {
return &Generator{
idl: &idl.Idl{},
options: &GeneratorOptions{Package: "test"},
complexEnumRegistry: make(map[string]struct{}),
}
}

func TestComplexEnumGuard_handlesOptionAndCOption(t *testing.T) {
const name = "Outcome"
register_TypeName_as_ComplexEnum(name)
t.Cleanup(func() { delete(typeRegistryComplexEnum, name) })
g := newTestGenerator()
g.registerComplexEnumType(name)

defined := &idltype.Defined{Name: name}

assert.True(t, complexEnumGuard(defined), "bare Defined")
assert.True(t, complexEnumGuard(&idltype.Option{Option: defined}), "Option<ComplexEnum>")
assert.True(t, complexEnumGuard(&idltype.COption{COption: defined}), "COption<ComplexEnum>")
assert.True(t, complexEnumGuard(g, defined), "bare Defined")
assert.True(t, complexEnumGuard(g, &idltype.Option{Option: defined}), "Option<ComplexEnum>")
assert.True(t, complexEnumGuard(g, &idltype.COption{COption: defined}), "COption<ComplexEnum>")
}

// TestComplexEnumGuard_rejectsNonComplexOptionals ensures the guard does NOT
Expand All @@ -40,20 +48,20 @@ func TestComplexEnumGuard_handlesOptionAndCOption(t *testing.T) {
// where .Option.(*idltype.Defined) would panic on a non-Defined inner type.
func TestComplexEnumGuard_rejectsNonComplexOptionals(t *testing.T) {
const complexName = "Outcome"
register_TypeName_as_ComplexEnum(complexName)
t.Cleanup(func() { delete(typeRegistryComplexEnum, complexName) })
g := newTestGenerator()
g.registerComplexEnumType(complexName)

nonComplex := &idltype.Defined{Name: "PlainStruct"}

assert.False(t, complexEnumGuard(&idltype.Option{Option: nonComplex}),
assert.False(t, complexEnumGuard(g, &idltype.Option{Option: nonComplex}),
"Option<NonComplexDefined> must not trigger the complex-enum path")
assert.False(t, complexEnumGuard(&idltype.COption{COption: nonComplex}),
assert.False(t, complexEnumGuard(g, &idltype.COption{COption: nonComplex}),
"COption<NonComplexDefined> must not trigger the complex-enum path")
assert.False(t, complexEnumGuard(&idltype.Option{Option: &idltype.U64{}}),
assert.False(t, complexEnumGuard(g, &idltype.Option{Option: &idltype.U64{}}),
"Option<U64> must not trigger the complex-enum path")
assert.False(t, complexEnumGuard(&idltype.COption{COption: &idltype.U8{}}),
assert.False(t, complexEnumGuard(g, &idltype.COption{COption: &idltype.U8{}}),
"COption<U8> must not trigger the complex-enum path")
assert.False(t, complexEnumGuard(&idltype.Option{Option: &idltype.Vec{Vec: &idltype.Defined{Name: complexName}}}),
assert.False(t, complexEnumGuard(g, &idltype.Option{Option: &idltype.Vec{Vec: &idltype.Defined{Name: complexName}}}),
"Option<Vec<ComplexEnum>> — nested containers not supported, must not match")
}

Expand All @@ -63,8 +71,8 @@ func TestComplexEnumGuard_rejectsNonComplexOptionals(t *testing.T) {
// instead of the generic Encode/Decode.
func TestComplexEnumCodegen_optionalComplexEnum(t *testing.T) {
const enumName = "Outcome"
register_TypeName_as_ComplexEnum(enumName)
t.Cleanup(func() { delete(typeRegistryComplexEnum, enumName) })
g := newTestGenerator()
g.registerComplexEnumType(enumName)

fields := idl.IdlDefinedFieldsNamed{
{Name: "id", Ty: &idltype.U64{}},
Expand All @@ -73,10 +81,10 @@ func TestComplexEnumCodegen_optionalComplexEnum(t *testing.T) {
{Name: "checksum", Ty: &idltype.U64{}},
}

marshalCode := gen_MarshalWithEncoder_struct(
marshalCode := g.gen_MarshalWithEncoder_struct(
&idl.Idl{}, false, "Report", "", fields, true,
)
unmarshalCode := gen_UnmarshalWithDecoder_struct(
unmarshalCode := g.gen_UnmarshalWithDecoder_struct(
&idl.Idl{}, false, "Report", "", fields,
)

Expand Down
10 changes: 6 additions & 4 deletions cmd/generate-bindings/solana/anchor-go/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import (
var Debug = false // Set to true to enable debug logging.

type Generator struct {
options *GeneratorOptions
idl *idl.Idl
options *GeneratorOptions
idl *idl.Idl
complexEnumRegistry map[string]struct{}
}

type GeneratorOptions struct {
Expand Down Expand Up @@ -62,13 +63,14 @@ func (g *Generator) Generate() (*Output, error) {
Files: make([]*OutputFile, 0),
}

g.complexEnumRegistry = make(map[string]struct{})

{
// Register complex enums.
{
// register complex enums:
// TODO: .types is the only place where we can find complex enums? (or enums in general?)
for _, typ := range g.idl.Types {
registerComplexEnums(typ)
g.registerComplexEnums(typ)
}
}
if len(g.idl.Docs) > 0 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ func (g *Generator) gen_instructions() (*OutputFile, error) {
// )
// }
checkNil := true
body.BlockFunc(func(g *Group) {
gen_marshal_DefinedFieldsNamed(
g,
body.BlockFunc(func(grp *Group) {
g.gen_marshal_DefinedFieldsNamed(
grp,
instruction.Args,
checkNil,
func(param idl.IdlField) *Statement {
Expand Down
10 changes: 5 additions & 5 deletions cmd/generate-bindings/solana/anchor-go/generator/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/gagliardetto/anchor-go/idl/idltype"
)

func gen_MarshalWithEncoder_struct(
func (g *Generator) gen_MarshalWithEncoder_struct(
idl_ *idl.Idl,
withDiscriminator bool,
receiverTypeName string,
Expand Down Expand Up @@ -45,7 +45,7 @@ func gen_MarshalWithEncoder_struct(
switch fields := fields.(type) {
case idl.IdlDefinedFieldsNamed:
uniqueFieldNames := generateUniqueFieldNames(fields)
gen_marshal_DefinedFieldsNamed(
g.gen_marshal_DefinedFieldsNamed(
body,
fields,
checkNil,
Expand All @@ -61,7 +61,7 @@ func gen_MarshalWithEncoder_struct(
case idl.IdlDefinedFieldsTuple:
convertedFields := tupleToFieldsNamed(fields)
uniqueFieldNames := generateUniqueFieldNames(convertedFields)
gen_marshal_DefinedFieldsNamed(
g.gen_marshal_DefinedFieldsNamed(
body,
convertedFields,
checkNil,
Expand Down Expand Up @@ -136,7 +136,7 @@ func gen_MarshalWithEncoder_struct(
return code
}

func gen_marshal_DefinedFieldsNamed(
func (g *Generator) gen_marshal_DefinedFieldsNamed(
body *Group,
fields idl.IdlDefinedFieldsNamed,
checkNil bool,
Expand All @@ -153,7 +153,7 @@ func gen_marshal_DefinedFieldsNamed(
body.Commentf("Serialize `%s`:", exportedArgName)
}

if isComplexEnum(field.Ty) || (IsArray(field.Ty) && isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && isComplexEnum(field.Ty.(*idltype.Vec).Vec)) || isOptionalComplexEnum(field.Ty) {
if g.isComplexEnum(field.Ty) || (IsArray(field.Ty) && g.isComplexEnum(field.Ty.(*idltype.Array).Type)) || (IsVec(field.Ty) && g.isComplexEnum(field.Ty.(*idltype.Vec).Vec)) || g.isOptionalComplexEnum(field.Ty) {
switch field.Ty.(type) {
case *idltype.Defined:
enumTypeName := field.Ty.(*idltype.Defined).Name
Expand Down
14 changes: 7 additions & 7 deletions cmd/generate-bindings/solana/anchor-go/generator/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeD
// Add comments for the enum type:
addComments(code, docs)
{
register_TypeName_as_ComplexEnum(name)
g.registerComplexEnumType(name)
containerName := formatEnumContainerName(enumTypeName)
interfaceMethodName := formatInterfaceMethodName(enumTypeName)

Expand Down Expand Up @@ -264,7 +264,7 @@ func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeD
case idl.IdlDefinedFieldsNamed:
for _, variantField := range fields {
optionality := IsOption(variantField.Ty) || IsCOption(variantField.Ty)
structGroup.Add(genField(variantField, optionality)).
structGroup.Add(g.genField(variantField, optionality)).
Add(func() Code {
tagMap := map[string]string{}
if IsOption(variantField.Ty) {
Expand All @@ -287,7 +287,7 @@ func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeD
for itemIndex, tupleItem := range fields {
optionality := IsOption(tupleItem) || IsCOption(tupleItem)
tupleItemName := FormatTupleItemName(itemIndex)
structGroup.Add(genFieldNamed(tupleItemName, tupleItem, optionality)).
structGroup.Add(g.genFieldNamed(tupleItemName, tupleItem, optionality)).
Add(func() Code {
tagMap := map[string]string{}
if IsOption(tupleItem) {
Expand Down Expand Up @@ -356,7 +356,7 @@ func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeD
case idl.IdlDefinedFieldsNamed:
// Declare MarshalWithEncoder:
code.Line().Line().Add(
gen_MarshalWithEncoder_struct(
g.gen_MarshalWithEncoder_struct(
g.idl,
false,
variantTypeNameComplex,
Expand All @@ -367,7 +367,7 @@ func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeD

// Declare UnmarshalWithDecoder
code.Line().Line().Add(
gen_UnmarshalWithDecoder_struct(
g.gen_UnmarshalWithDecoder_struct(
g.idl,
false,
variantTypeNameComplex,
Expand All @@ -379,7 +379,7 @@ func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeD
// TODO: handle tuples
// Declare MarshalWithEncoder:
code.Line().Line().Add(
gen_MarshalWithEncoder_struct(
g.gen_MarshalWithEncoder_struct(
g.idl,
false,
variantTypeNameComplex,
Expand All @@ -390,7 +390,7 @@ func (g *Generator) gen_complexEnum(name string, docs []string, typ idl.IdlTypeD

// Declare UnmarshalWithDecoder
code.Line().Line().Add(
gen_UnmarshalWithDecoder_struct(
g.gen_UnmarshalWithDecoder_struct(
g.idl,
false,
variantTypeNameComplex,
Expand Down
Loading