From a4d2084d2753122139cbb0c144235068b9cfde5c Mon Sep 17 00:00:00 2001 From: alecerio Date: Wed, 24 Apr 2024 17:13:03 +0200 Subject: [PATCH 1/8] dialects: [dsp] add dsp dialect --- xdsl/dialects/dsp.py | 79 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 xdsl/dialects/dsp.py diff --git a/xdsl/dialects/dsp.py b/xdsl/dialects/dsp.py new file mode 100644 index 0000000000..6e4f214c5c --- /dev/null +++ b/xdsl/dialects/dsp.py @@ -0,0 +1,79 @@ +from typing import Annotated + +from xdsl.dialects.builtin import ( + AnyFloat, + IntegerAttr, + IntegerType, + SSAValue, + TensorType, +) +from xdsl.ir import ( + Attribute, + Dialect, +) +from xdsl.irdl import ( + ConstraintVar, + IRDLOperation, + irdl_op_definition, + operand_def, + opt_attr_def, + result_def, +) + + +@irdl_op_definition +class STFT(IRDLOperation): + """ + The Short-Time Fourier Transform (STFT) is a technique used in digital signal processing to analyze the frequency content of a signal over time. + It provides a time-frequency representation of a signal by computing the Fourier Transform over short, overlapping windows of the signal. + This analysis is useful for tasks such as audio analysis, speech processing, and image processing. + + X[m,k]=∑ x[n]⋅w[n-m]⋅e^(-j*(2π/N)*nk) + + Where x[n] is the signal, w[n] is the window + """ + + name = "onnx.STFT" + T = Annotated[AnyFloat, ConstraintVar("T")] + T2 = Annotated[IntegerType, ConstraintVar("T")] + + frame = operand_def(TensorType[T]) + n_frame = operand_def(TensorType[T2]) + res = result_def(TensorType[T]) + + frame_size = opt_attr_def(IntegerAttr, attr_name="frame_size") + hop_size = opt_attr_def(IntegerAttr, attr_name="hop_size") + + assembly_format = ( + "`(` $operand`)` attr-dict `:` `(` type($operand) `)` `->` type($res)" + ) + + def __init__( + self, + frame: SSAValue, + n_frame: SSAValue, + frame_size: Attribute, + hop_size: Attribute, + ): + super().__init__( + attributes={ + "frame_size": frame_size, + "hop_size": hop_size, + }, + operands=[ + frame, + n_frame, + ], + result_types=[frame.type], + ) + + def verify_(self) -> None: + pass + + +DSP = Dialect( + "dsp", + [ + STFT, + ], +) From 37fcbd947e6f5d97c25e50b24b53d024cd0a3542 Mon Sep 17 00:00:00 2001 From: alecerio Date: Wed, 24 Apr 2024 17:21:54 +0200 Subject: [PATCH 2/8] dialects: [dsp] register dsp dialect --- xdsl/tools/command_line_tool.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xdsl/tools/command_line_tool.py b/xdsl/tools/command_line_tool.py index 2e4cb1e6ff..4d8e7a7df0 100644 --- a/xdsl/tools/command_line_tool.py +++ b/xdsl/tools/command_line_tool.py @@ -69,6 +69,11 @@ def get_dmp(): return DMP + def get_dsp(): + from xdsl.dialects.dsp import DSP + + return DSP + def get_fir(): from xdsl.dialects.experimental.fir import FIR @@ -266,6 +271,7 @@ def get_x86(): "cmath": get_cmath, "comb": get_comb, "dmp": get_dmp, + "dsp": get_dsp, "fir": get_fir, "fsm": get_fsm, "func": get_func, From 6d9f541a9a22505acb2e5a1bd5bb66a497c9a3b7 Mon Sep 17 00:00:00 2001 From: alecerio Date: Wed, 24 Apr 2024 18:05:00 +0200 Subject: [PATCH 3/8] dialects: [dsp] add frame input dimensions test --- .../filecheck/dialects/onnx/onnx_invalid.mlir | 9 +++++++ xdsl/dialects/dsp.py | 26 ++++++++++++++++--- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/tests/filecheck/dialects/onnx/onnx_invalid.mlir b/tests/filecheck/dialects/onnx/onnx_invalid.mlir index 9ad92beed0..cf89e1f302 100644 --- a/tests/filecheck/dialects/onnx/onnx_invalid.mlir +++ b/tests/filecheck/dialects/onnx/onnx_invalid.mlir @@ -551,3 +551,12 @@ builtin.module { // CHECK: Operation does not verify: incorrect output shape: output dimension #0 should be equal to 4 %res_transpose = "onnx.Transpose"(%t0) {onnx_node_name = "/Transpose", "perm" = [1 : i64, 0 : i64]}: (tensor<3x4xf32>) -> tensor<3x3xf32> } + +// ----- + +builtin.module { + %t0 = "test.op"() : () -> (tensor<2x256xf32>) + %t1 = "arith.constant"() {"value" = 1 : i64}: () -> i64 + // CHECK: Operation does not verify: frame number of dimensions must be 1. Actual number of dimensions: 2 + %res_stft = "dsp.STFT"(%t0, %t1) {"frame_size" = 256 : i64, "hop_size" = 128 : i64}: (tensor<2x256xf32>, i64) -> tensor<2x128xf32> +} \ No newline at end of file diff --git a/xdsl/dialects/dsp.py b/xdsl/dialects/dsp.py index 6e4f214c5c..92e6f97624 100644 --- a/xdsl/dialects/dsp.py +++ b/xdsl/dialects/dsp.py @@ -19,6 +19,7 @@ opt_attr_def, result_def, ) +from xdsl.utils.exceptions import VerifyException @irdl_op_definition @@ -35,11 +36,10 @@ class STFT(IRDLOperation): name = "onnx.STFT" T = Annotated[AnyFloat, ConstraintVar("T")] - T2 = Annotated[IntegerType, ConstraintVar("T")] frame = operand_def(TensorType[T]) - n_frame = operand_def(TensorType[T2]) - res = result_def(TensorType[T]) + n_frame = operand_def(IntegerType) + output = result_def(TensorType[T]) frame_size = opt_attr_def(IntegerAttr, attr_name="frame_size") hop_size = opt_attr_def(IntegerAttr, attr_name="hop_size") @@ -68,7 +68,25 @@ def __init__( ) def verify_(self) -> None: - pass + if ( + not isinstance(frame_type := self.frame.type, TensorType) + or not isinstance(n_frame := self.n_frame.type, IntegerType) + or not isinstance(output_type := self.output.type, TensorType) + ): + assert ( + False + ), "dsp stft operation operands must be TensorType and IntegerType, the result must be of type TensorType" + + frame_shape = frame_type.get_shape() + output_shape = output_type.get_shape() + print(output_shape) + print(n_frame) + + n_dimensions_frame = len(frame_shape) + if n_dimensions_frame != 1: + raise VerifyException( + f"frame number of dimensions must be 1. Actual number of dimensions: {n_dimensions_frame}" + ) DSP = Dialect( From ada4f8da3043d7128e999c68ddc463ad2775d226 Mon Sep 17 00:00:00 2001 From: alecerio Date: Wed, 24 Apr 2024 18:06:09 +0200 Subject: [PATCH 4/8] dialects: [dsp] add frame input dimensions test --- xdsl/dialects/dsp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xdsl/dialects/dsp.py b/xdsl/dialects/dsp.py index 92e6f97624..a9ac7fcbde 100644 --- a/xdsl/dialects/dsp.py +++ b/xdsl/dialects/dsp.py @@ -82,6 +82,7 @@ def verify_(self) -> None: print(output_shape) print(n_frame) + # n_dimensions_frame = len(frame_shape) if n_dimensions_frame != 1: raise VerifyException( From f438e0274ec0c98cfb3a9b431b2a64b5c52b2afe Mon Sep 17 00:00:00 2001 From: alecerio Date: Wed, 24 Apr 2024 18:14:31 +0200 Subject: [PATCH 5/8] dialects: [dsp] fix assembly format --- xdsl/dialects/dsp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/dialects/dsp.py b/xdsl/dialects/dsp.py index a9ac7fcbde..7018f808b2 100644 --- a/xdsl/dialects/dsp.py +++ b/xdsl/dialects/dsp.py @@ -45,7 +45,7 @@ class STFT(IRDLOperation): hop_size = opt_attr_def(IntegerAttr, attr_name="hop_size") assembly_format = ( - "`(` $operand`)` attr-dict `:` `(` type($operand) `)` `->` type($res)" + "`(` $operand `)` attr-dict `:` `(` type($operand) `)` `->` type($res)" ) def __init__( From 7677c5417d099849ade3a8fce589313ba25f57ad Mon Sep 17 00:00:00 2001 From: alecerio Date: Wed, 24 Apr 2024 18:19:25 +0200 Subject: [PATCH 6/8] dialects: [dsp] remove dsp test from onnx folder --- tests/filecheck/dialects/onnx/onnx_invalid.mlir | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/filecheck/dialects/onnx/onnx_invalid.mlir b/tests/filecheck/dialects/onnx/onnx_invalid.mlir index cf89e1f302..1a6a68fff9 100644 --- a/tests/filecheck/dialects/onnx/onnx_invalid.mlir +++ b/tests/filecheck/dialects/onnx/onnx_invalid.mlir @@ -550,13 +550,4 @@ builtin.module { %t0 = "test.op"() : () -> (tensor<3x4xf32>) // CHECK: Operation does not verify: incorrect output shape: output dimension #0 should be equal to 4 %res_transpose = "onnx.Transpose"(%t0) {onnx_node_name = "/Transpose", "perm" = [1 : i64, 0 : i64]}: (tensor<3x4xf32>) -> tensor<3x3xf32> -} - -// ----- - -builtin.module { - %t0 = "test.op"() : () -> (tensor<2x256xf32>) - %t1 = "arith.constant"() {"value" = 1 : i64}: () -> i64 - // CHECK: Operation does not verify: frame number of dimensions must be 1. Actual number of dimensions: 2 - %res_stft = "dsp.STFT"(%t0, %t1) {"frame_size" = 256 : i64, "hop_size" = 128 : i64}: (tensor<2x256xf32>, i64) -> tensor<2x128xf32> } \ No newline at end of file From 58304defd2d09291313baff3fb0f7042a3f7e9b2 Mon Sep 17 00:00:00 2001 From: alecerio Date: Wed, 24 Apr 2024 18:26:05 +0200 Subject: [PATCH 7/8] dialects: [dsp] fix assembly format --- xdsl/dialects/dsp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xdsl/dialects/dsp.py b/xdsl/dialects/dsp.py index 7018f808b2..c927b95cbc 100644 --- a/xdsl/dialects/dsp.py +++ b/xdsl/dialects/dsp.py @@ -44,9 +44,7 @@ class STFT(IRDLOperation): frame_size = opt_attr_def(IntegerAttr, attr_name="frame_size") hop_size = opt_attr_def(IntegerAttr, attr_name="hop_size") - assembly_format = ( - "`(` $operand `)` attr-dict `:` `(` type($operand) `)` `->` type($res)" - ) + assembly_format = "`(` $frame `,` $n_frame `)` attr-dict `:` `(` type($frame) `,` type($n_frame) `)` `->` type($output)" def __init__( self, From 189d79c4af6a3538391564273527a3d622db1136 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Thu, 2 May 2024 14:01:05 +0100 Subject: [PATCH 8/8] add back newline --- tests/filecheck/dialects/onnx/onnx_invalid.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/filecheck/dialects/onnx/onnx_invalid.mlir b/tests/filecheck/dialects/onnx/onnx_invalid.mlir index 1a6a68fff9..9ad92beed0 100644 --- a/tests/filecheck/dialects/onnx/onnx_invalid.mlir +++ b/tests/filecheck/dialects/onnx/onnx_invalid.mlir @@ -550,4 +550,4 @@ builtin.module { %t0 = "test.op"() : () -> (tensor<3x4xf32>) // CHECK: Operation does not verify: incorrect output shape: output dimension #0 should be equal to 4 %res_transpose = "onnx.Transpose"(%t0) {onnx_node_name = "/Transpose", "perm" = [1 : i64, 0 : i64]}: (tensor<3x4xf32>) -> tensor<3x3xf32> -} \ No newline at end of file +}