From 019b97237ae498b7599620ce6f40bfe0f7e929c3 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 9 Jan 2024 11:59:39 +0100 Subject: [PATCH] Fix default values --- onnx_array_api/light_api/_op_vars.py | 163 +++++++++++++-------------- 1 file changed, 81 insertions(+), 82 deletions(-) diff --git a/onnx_array_api/light_api/_op_vars.py b/onnx_array_api/light_api/_op_vars.py index 64d0d2d..4f30dbe 100644 --- a/onnx_array_api/light_api/_op_vars.py +++ b/onnx_array_api/light_api/_op_vars.py @@ -10,8 +10,10 @@ def BitShift(self, direction: str = "") -> "Var": return self.make_node("BitShift", *self.vars_, direction=direction) def CenterCropPad(self, axes: Optional[List[int]] = None) -> "Var": - axes = axes or [] - return self.make_node("CenterCropPad", *self.vars_, axes=axes) + kwargs = {} + if axes is not None: + kwargs["axes"] = axes + return self.make_node("CenterCropPad", *self.vars_, **kwargs) def Clip( self, @@ -27,12 +29,14 @@ def Col2Im( pads: Optional[List[int]] = None, strides: Optional[List[int]] = None, ) -> "Var": - dilations = dilations or [] - pads = pads or [] - strides = strides or [] - return self.make_node( - "Col2Im", *self.vars_, dilations=dilations, pads=pads, strides=strides - ) + kwargs = {} + if dilations is not None: + kwargs["dilations"] = dilations + if pads is not None: + kwargs["pads"] = pads + if strides is not None: + kwargs["strides"] = strides + return self.make_node("Col2Im", *self.vars_, **kwargs) def Compress(self, axis: int = 0) -> "Var": return self.make_node("Compress", *self.vars_, axis=axis) @@ -71,19 +75,17 @@ def ConvInteger( pads: Optional[List[int]] = None, strides: Optional[List[int]] = None, ) -> "Var": - dilations = dilations or [] - kernel_shape = kernel_shape or [] - pads = pads or [] - strides = strides or [] + kwargs = {} + if dilations is not None: + kwargs["dilations"] = dilations + if kernel_shape is not None: + kwargs["kernel_shape"] = kernel_shape + if pads is not None: + kwargs["pads"] = pads + if strides is not None: + kwargs["strides"] = strides return self.make_node( - "ConvInteger", - *self.vars_, - auto_pad=auto_pad, - dilations=dilations, - group=group, - kernel_shape=kernel_shape, - pads=pads, - strides=strides, + "ConvInteger", *self.vars_, auto_pad=auto_pad, group=group, **kwargs ) def ConvTranspose( @@ -97,23 +99,21 @@ def ConvTranspose( pads: Optional[List[int]] = None, strides: Optional[List[int]] = None, ) -> "Var": - dilations = dilations or [] - kernel_shape = kernel_shape or [] - output_padding = output_padding or [] - output_shape = output_shape or [] - pads = pads or [] - strides = strides or [] - return self.make_node( - "ConvTranspose", - *self.vars_, - auto_pad=auto_pad, - dilations=dilations, - group=group, - kernel_shape=kernel_shape, - output_padding=output_padding, - output_shape=output_shape, - pads=pads, - strides=strides, + kwargs = {} + if dilations is not None: + kwargs["dilations"] = dilations + if kernel_shape is not None: + kwargs["kernel_shape"] = kernel_shape + if pads is not None: + kwargs["pads"] = pads + if strides is not None: + kwargs["strides"] = strides + if output_padding is not None: + kwargs["output_padding"] = output_padding + if output_shape is not None: + kwargs["output_shape"] = output_shape + return self.make_node( + "ConvTranspose", *self.vars_, auto_pad=auto_pad, group=group, **kwargs ) def CumSum(self, exclusive: int = 0, reverse: int = 0) -> "Var": @@ -135,19 +135,17 @@ def DeformConv( pads: Optional[List[int]] = None, strides: Optional[List[int]] = None, ) -> "Var": - dilations = dilations or [] - kernel_shape = kernel_shape or [] - pads = pads or [] - strides = strides or [] + kwargs = {} + if dilations is not None: + kwargs["dilations"] = dilations + if kernel_shape is not None: + kwargs["kernel_shape"] = kernel_shape + if pads is not None: + kwargs["pads"] = pads + if strides is not None: + kwargs["strides"] = strides return self.make_node( - "DeformConv", - *self.vars_, - dilations=dilations, - group=group, - kernel_shape=kernel_shape, - offset_group=offset_group, - pads=pads, - strides=strides, + "DeformConv", *self.vars_, group=group, offset_group=offset_group, **kwargs ) def DequantizeLinear(self, axis: int = 1) -> "Var": @@ -204,12 +202,11 @@ def MatMulInteger( def MaxRoiPool( self, pooled_shape: Optional[List[int]] = None, spatial_scale: float = 1.0 ) -> "Var": - pooled_shape = pooled_shape or [] + kwargs = {} + if pooled_shape is not None: + kwargs["pooled_shape"] = pooled_shape return self.make_node( - "MaxRoiPool", - *self.vars_, - pooled_shape=pooled_shape, - spatial_scale=spatial_scale, + "MaxRoiPool", *self.vars_, spatial_scale=spatial_scale, **kwargs ) def MaxUnpool( @@ -218,16 +215,14 @@ def MaxUnpool( pads: Optional[List[int]] = None, strides: Optional[List[int]] = None, ) -> "Var": - kernel_shape = kernel_shape or [] - pads = pads or [] - strides = strides or [] - return self.make_node( - "MaxUnpool", - *self.vars_, - kernel_shape=kernel_shape, - pads=pads, - strides=strides, - ) + kwargs = {} + if kernel_shape is not None: + kwargs["kernel_shape"] = kernel_shape + if pads is not None: + kwargs["pads"] = pads + if strides is not None: + kwargs["strides"] = strides + return self.make_node("MaxUnpool", *self.vars_, **kwargs) def MelWeightMatrix(self, output_datatype: int = 1) -> "Var": return self.make_node( @@ -267,19 +262,17 @@ def QLinearConv( pads: Optional[List[int]] = None, strides: Optional[List[int]] = None, ) -> "Var": - dilations = dilations or [] - kernel_shape = kernel_shape or [] - pads = pads or [] - strides = strides or [] + kwargs = {} + if kernel_shape is not None: + kwargs["kernel_shape"] = kernel_shape + if pads is not None: + kwargs["pads"] = pads + if strides is not None: + kwargs["strides"] = strides + if dilations is not None: + kwargs["dilations"] = dilations return self.make_node( - "QLinearConv", - *self.vars_, - auto_pad=auto_pad, - dilations=dilations, - group=group, - kernel_shape=kernel_shape, - pads=pads, - strides=strides, + "QLinearConv", *self.vars_, auto_pad=auto_pad, group=group, **kwargs ) def QLinearMatMul( @@ -303,7 +296,9 @@ def RandomNormal( seed: float = 0.0, shape: Optional[List[int]] = None, ) -> "Var": - shape = shape or [] + kwargs = {} + if shape is not None: + kwargs["shape"] = shape return self.make_node( "RandomNormal", *self.vars_, @@ -311,7 +306,7 @@ def RandomNormal( mean=mean, scale=scale, seed=seed, - shape=shape, + **kwargs, ) def RandomUniform( @@ -322,7 +317,9 @@ def RandomUniform( seed: float = 0.0, shape: Optional[List[int]] = None, ) -> "Var": - shape = shape or [] + kwargs = {} + if shape is not None: + kwargs["shape"] = shape return self.make_node( "RandomUniform", *self.vars_, @@ -330,7 +327,7 @@ def RandomUniform( high=high, low=low, seed=seed, - shape=shape, + **kwargs, ) def Range( @@ -437,12 +434,13 @@ def Resize( mode: str = "nearest", nearest_mode: str = "round_prefer_floor", ) -> "Var": - axes = axes or [] + kwargs = {} + if axes is not None: + kwargs["axes"] = axes return self.make_node( "Resize", *self.vars_, antialias=antialias, - axes=axes, coordinate_transformation_mode=coordinate_transformation_mode, cubic_coeff_a=cubic_coeff_a, exclude_outside=exclude_outside, @@ -450,6 +448,7 @@ def Resize( keep_aspect_ratio_policy=keep_aspect_ratio_policy, mode=mode, nearest_mode=nearest_mode, + **kwargs, ) def RoiAlign(