From 231e76a599527e8ebdeb1f63029f7379106bcbe0 Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Tue, 3 Mar 2026 14:02:24 -0800 Subject: [PATCH] feat: add Attribute_Tensor to protobuf schema Add tensor oneof variant to Attribute message, enabling storage of tensor constants as node attributes. Required by zerfoo model builder. --- zerfoo.pb.go | 59 ++++++++++++++++++++++++++++++++++------------------ zerfoo.proto | 3 ++- 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/zerfoo.pb.go b/zerfoo.pb.go index 7dca33e..76c69eb 100644 --- a/zerfoo.pb.go +++ b/zerfoo.pb.go @@ -104,7 +104,7 @@ func (Tensor_DataType) EnumDescriptor() ([]byte, []int) { return file_zerfoo_proto_rawDescGZIP(), []int{9, 0} } -// Attribute represents a named, non-tensor parameter for a node. +// Attribute represents a named parameter for a node. type Attribute struct { state protoimpl.MessageState `protogen:"open.v1"` // Types that are valid to be assigned to Value: @@ -116,6 +116,7 @@ type Attribute struct { // *Attribute_Ints // *Attribute_Strings // *Attribute_B + // *Attribute_Tensor Value isAttribute_Value `protobuf_oneof:"value"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache @@ -221,6 +222,15 @@ func (x *Attribute) GetB() bool { return false } +func (x *Attribute) GetTensor() *Tensor { + if x != nil { + if x, ok := x.Value.(*Attribute_Tensor); ok { + return x.Tensor + } + } + return nil +} + type isAttribute_Value interface { isAttribute_Value() } @@ -253,6 +263,10 @@ type Attribute_B struct { B bool `protobuf:"varint,7,opt,name=b,proto3,oneof"` // Added boolean support } +type Attribute_Tensor struct { + Tensor *Tensor `protobuf:"bytes,8,opt,name=tensor,proto3,oneof"` // Added tensor support for constants +} + func (*Attribute_F) isAttribute_Value() {} func (*Attribute_I) isAttribute_Value() {} @@ -267,6 +281,8 @@ func (*Attribute_Strings) isAttribute_Value() {} func (*Attribute_B) isAttribute_Value() {} +func (*Attribute_Tensor) isAttribute_Value() {} + // Floats is a wrapper for repeated float values in attributes. type Floats struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -928,7 +944,7 @@ var File_zerfoo_proto protoreflect.FileDescriptor const file_zerfoo_proto_rawDesc = "" + "\n" + - "\fzerfoo.proto\x12\x03zmf\"\xc6\x01\n" + + "\fzerfoo.proto\x12\x03zmf\"\xed\x01\n" + "\tAttribute\x12\x0e\n" + "\x01f\x18\x01 \x01(\x02H\x00R\x01f\x12\x0e\n" + "\x01i\x18\x02 \x01(\x03H\x00R\x01i\x12\x0e\n" + @@ -936,7 +952,8 @@ const file_zerfoo_proto_rawDesc = "" + "\x06floats\x18\x04 \x01(\v2\v.zmf.FloatsH\x00R\x06floats\x12\x1f\n" + "\x04ints\x18\x05 \x01(\v2\t.zmf.IntsH\x00R\x04ints\x12(\n" + "\astrings\x18\x06 \x01(\v2\f.zmf.StringsH\x00R\astrings\x12\x0e\n" + - "\x01b\x18\a \x01(\bH\x00R\x01bB\a\n" + + "\x01b\x18\a \x01(\bH\x00R\x01b\x12%\n" + + "\x06tensor\x18\b \x01(\v2\v.zmf.TensorH\x00R\x06tensorB\a\n" + "\x05value\"\x1a\n" + "\x06Floats\x12\x10\n" + "\x03val\x18\x01 \x03(\x02R\x03val\"\x82\x02\n" + @@ -1054,23 +1071,24 @@ var file_zerfoo_proto_depIdxs = []int32{ 2, // 0: zmf.Attribute.floats:type_name -> zmf.Floats 4, // 1: zmf.Attribute.ints:type_name -> zmf.Ints 9, // 2: zmf.Attribute.strings:type_name -> zmf.Strings - 12, // 3: zmf.Graph.parameters:type_name -> zmf.Graph.ParametersEntry - 7, // 4: zmf.Graph.nodes:type_name -> zmf.Node - 11, // 5: zmf.Graph.inputs:type_name -> zmf.ValueInfo - 11, // 6: zmf.Graph.outputs:type_name -> zmf.ValueInfo - 3, // 7: zmf.Model.graph:type_name -> zmf.Graph - 5, // 8: zmf.Model.metadata:type_name -> zmf.Metadata - 13, // 9: zmf.Node.attributes:type_name -> zmf.Node.AttributesEntry - 0, // 10: zmf.Tensor.dtype:type_name -> zmf.Tensor.DataType - 8, // 11: zmf.Tensor.quant:type_name -> zmf.Quantization - 0, // 12: zmf.ValueInfo.dtype:type_name -> zmf.Tensor.DataType - 10, // 13: zmf.Graph.ParametersEntry.value:type_name -> zmf.Tensor - 1, // 14: zmf.Node.AttributesEntry.value:type_name -> zmf.Attribute - 15, // [15:15] is the sub-list for method output_type - 15, // [15:15] is the sub-list for method input_type - 15, // [15:15] is the sub-list for extension type_name - 15, // [15:15] is the sub-list for extension extendee - 0, // [0:15] is the sub-list for field type_name + 10, // 3: zmf.Attribute.tensor:type_name -> zmf.Tensor + 12, // 4: zmf.Graph.parameters:type_name -> zmf.Graph.ParametersEntry + 7, // 5: zmf.Graph.nodes:type_name -> zmf.Node + 11, // 6: zmf.Graph.inputs:type_name -> zmf.ValueInfo + 11, // 7: zmf.Graph.outputs:type_name -> zmf.ValueInfo + 3, // 8: zmf.Model.graph:type_name -> zmf.Graph + 5, // 9: zmf.Model.metadata:type_name -> zmf.Metadata + 13, // 10: zmf.Node.attributes:type_name -> zmf.Node.AttributesEntry + 0, // 11: zmf.Tensor.dtype:type_name -> zmf.Tensor.DataType + 8, // 12: zmf.Tensor.quant:type_name -> zmf.Quantization + 0, // 13: zmf.ValueInfo.dtype:type_name -> zmf.Tensor.DataType + 10, // 14: zmf.Graph.ParametersEntry.value:type_name -> zmf.Tensor + 1, // 15: zmf.Node.AttributesEntry.value:type_name -> zmf.Attribute + 16, // [16:16] is the sub-list for method output_type + 16, // [16:16] is the sub-list for method input_type + 16, // [16:16] is the sub-list for extension type_name + 16, // [16:16] is the sub-list for extension extendee + 0, // [0:16] is the sub-list for field type_name } func init() { file_zerfoo_proto_init() } @@ -1086,6 +1104,7 @@ func file_zerfoo_proto_init() { (*Attribute_Ints)(nil), (*Attribute_Strings)(nil), (*Attribute_B)(nil), + (*Attribute_Tensor)(nil), } file_zerfoo_proto_msgTypes[6].OneofWrappers = []any{} file_zerfoo_proto_msgTypes[7].OneofWrappers = []any{} diff --git a/zerfoo.proto b/zerfoo.proto index 468ef9b..436eae0 100644 --- a/zerfoo.proto +++ b/zerfoo.proto @@ -4,7 +4,7 @@ package zmf; option go_package = "github.com/zerfoo/zmf"; -// Attribute represents a named, non-tensor parameter for a node. +// Attribute represents a named parameter for a node. message Attribute { oneof value { float f = 1; @@ -14,6 +14,7 @@ message Attribute { Ints ints = 5; Strings strings = 6; bool b = 7; // Added boolean support + Tensor tensor = 8; // Added tensor support for constants } }