diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index bed2b865..d272534a 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -1308,6 +1308,7 @@ pub fn compile( let mode = node.get_attribute_value("mode", Some("constant".to_string()))?; match mode.as_str() { "constant" => {} + "reflect" => {} _ => { return Err(CompileError::UnimplementedVariant { op: String::from("Pad"), @@ -1315,6 +1316,7 @@ pub fn compile( }) } } + context.insert("mode", &mode); let pads: Vec = node.get_attribute_value("pads", None)?; if pads.len() != input_shapes[0].rank() * 2 { diff --git a/wonnx/templates/matrix/pad.wgsl b/wonnx/templates/matrix/pad.wgsl index e4c6e313..f8631746 100644 --- a/wonnx/templates/matrix/pad.wgsl +++ b/wonnx/templates/matrix/pad.wgsl @@ -26,19 +26,30 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { var pad = false; {% for pad in pad_info %} - let id_{{ loop.index0 }} = d_{{ loop.index0 }} - - {{ pad.copy_start }}u; + var id_{{ loop.index0 }} = 0u; if (d_{{ loop.index0 }} < {{ pad.copy_start }}u) { - pad = true; - } - if (d_{{ loop.index0 }} > {{ pad.end_pad_start }}u) { - pad = true; + {% if mode == "reflect" %} + id_{{ loop.index0 }} = ({{ pad.copy_start }}u - d_{{ loop.index0 }}) % {{ i_shape[0][loop.index0] }}u; + {% else %} + id_{{ loop.index0 }} = d_{{ loop.index0 }} - {{ pad.copy_start }}u; + pad = true; + {% endif %} } + else if (d_{{ loop.index0 }} > {{ pad.end_pad_start }}u) { + {% if mode == "reflect" %} + id_{{ loop.index0 }} = 2u * {{ pad.end_pad_start }}u - d_{{ loop.index0 }}; + {% else %} + id_{{ loop.index0 }} = d_{{ loop.index0 }} - {{ pad.copy_start }}u; + pad = true; + {% endif %} + } else { + id_{{ loop.index0 }} = d_{{ loop.index0 }} - {{ pad.copy_start }}u; + } {% endfor %} if (pad) { - output_0.data[gidx] = {{ scalar_type }}({{ constant_value }}); + output_0.data[gidx] = {{ scalar_type }}({{ constant_value }}); } else { let index = {%- for chunk in i_chunks | first -%} diff --git a/wonnx/tests/matrix.rs b/wonnx/tests/matrix.rs index f7b4909d..e0cd2f5d 100644 --- a/wonnx/tests/matrix.rs +++ b/wonnx/tests/matrix.rs @@ -338,6 +338,78 @@ fn test_pad_complex() { assert_eq!(actual, &test_y); } +#[test] +fn test_pad_reflect() { + let mut input_data = HashMap::new(); + #[rustfmt::skip] + let data = [ + 1.0, 1.2, + 2.3, 3.4, + 4.5, 5.7, + ].to_vec(); + input_data.insert("X".to_string(), data.as_slice().into()); + + let model = model(graph( + vec![tensor("X", &[3, 2])], + vec![tensor("Y", &[3, 4])], + vec![], + vec![initializer_int64("pads", vec![0, 2, 0, 0], vec![4])], + vec![node(vec!["X", "pads"], vec!["Y"], "Pad", "Pad", vec![ + attribute("mode", "reflect"), + ])], + )); + + let session = + pollster::block_on(wonnx::Session::from_model(model)).expect("session did not create"); + let result = pollster::block_on(session.run(&input_data)).unwrap(); + + #[rustfmt::skip] + let test_y = vec![ + 1.0, 1.2, 1.0, 1.2, + 2.3, 3.4, 2.3, 3.4, + 4.5, 5.7, 4.5, 5.7, + ]; + let actual: &[_] = (&result["Y"]).try_into().unwrap(); + // No arithmetic is done, so `assert_eq!` can be used. + assert_eq!(actual, &test_y); +} + +#[test] +fn test_pad_reflect_complex() { + let mut input_data = HashMap::new(); + #[rustfmt::skip] + let data = [ + 1.0, 1.2, 1.3, + 2.3, 3.4, 4.5, + 4.5, 5.7, 6.8, + ].to_vec(); + input_data.insert("X".to_string(), data.as_slice().into()); + + let model = model(graph( + vec![tensor("X", &[3, 3])], + vec![tensor("Y", &[3, 7])], + vec![], + vec![initializer_int64("pads", vec![0, 2, 0, 2], vec![4])], + vec![node(vec!["X", "pads"], vec!["Y"], "Pad", "Pad", vec![ + attribute("mode", "reflect"), + ])], + )); + + let session = + pollster::block_on(wonnx::Session::from_model(model)).expect("session did not create"); + let result = pollster::block_on(session.run(&input_data)).unwrap(); + + #[rustfmt::skip] + let test_y = vec![ + 1.3, 1.2, 1.0, 1.2, 1.3, 1.2, 1.0, + 4.5, 3.4, 2.3, 3.4, 4.5, 3.4, 2.3, + 6.8, 5.7, 4.5, 5.7, 6.8, 5.7, 4.5, + ]; + let actual: &[_] = (&result["Y"]).try_into().unwrap(); + // No arithmetic is done, so `assert_eq!` can be used. + assert_eq!(actual, &test_y); +} + #[test] fn test_resize() { let _ = env_logger::builder().is_test(true).try_init();