From da1ad37f46a99d864c1ba6a5c79585932d081332 Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Mon, 27 Feb 2023 08:38:17 -0800
Subject: [PATCH] Various fixes for end to end stable diffusion (#29)

Various fixes for end to end stable diffusion:

* Gather: support scalar indices by expanding them to rank 1, taking, then squeezing.
* Reshape: accept non-constant shape inputs and defer to relax.op.reshape.
* Shape: fold to a constant when the input shape is fully static.
* Conv: handle 1-D convolution and reshape the bias so it broadcasts correctly.
* Add an InstanceNormalization converter and register it and Conv in the convert map.
* relax.op.reshape: support 0 ("copy this dimension from the input") alongside a
  single inferred -1, with new unit tests.
---
 python/tvm/relax/frontend/onnx_frontend.py    | 133 ++++++++++++------
 src/relax/op/tensor/manipulate.cc             |  42 ++++--
 .../relax/frontend/test_onnx_frontend.py      |  81 +++++++----
 tests/python/relax/test_op_manipulate.py      |   8 +-
 4 files changed, 182 insertions(+), 82 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx_frontend.py b/python/tvm/relax/frontend/onnx_frontend.py
index 5f950a89076f..81058a95b7c3 100644
--- a/python/tvm/relax/frontend/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx_frontend.py
@@ -256,9 +256,21 @@ class Gather(OnnxOpConverter):
 
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
-        # TODO This assumes positive only indices.
+        # Unpack inputs
+        data = inputs[0]
+        indices = inputs[1]
+        # Indices must be rank 1, if we're given a scalar, expand it.
+        scalar_indices = False
+        if len(indices.struct_info.shape) == 0:
+            scalar_indices = True
+            indices = bb.normalize(relax.op.expand_dims(indices, axis=0))
+
         axis = attr.get("axis", 0)
-        return bb.emit_te(topi.take, inputs[0], inputs[1], axis)
+        out = relax.op.take(data, indices, axis)
+        # If indices were scalar, output dimension needs to be reduced.
+        if scalar_indices:
+            out = relax.op.squeeze(out, axis)
+        return out
 
 
 class Gemm(OnnxOpConverter):
@@ -296,30 +308,10 @@ class Reshape(OnnxOpConverter):
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
         data = inputs[0]
-        # TODO We assume new_shape is a constant, need to enable tensor input to reshape
-        # for full support.
-        if not isinstance(inputs[1], relax.Constant):
-            return inputs[0]
-        new_shape = inputs[1].data.numpy().tolist()
-
-        # Convert -1 dims in new_shape into positive equivalent.
-        if -1 in new_shape:
-            if new_shape.count(-1) != 1:
-                raise ValueError("Reshape with multiple -1 is not supported.")
-
-            data_shape = [dim.value for dim in data.struct_info.shape.values]
-            total_elements = _np.prod(data_shape)
-            new_product = 1
-            for dim in new_shape:
-                if dim > 0:
-                    new_product *= dim
-
-            # Replace -1 with positive equivalent
-            for i, dim in enumerate(new_shape):
-                if dim == -1:
-                    new_shape[i] = int(total_elements / new_product)
-
-        return bb.emit_te(topi.reshape, data, new_shape)
+        new_shape = inputs[1]
+        if isinstance(inputs[1], relax.Constant):
+            new_shape = inputs[1].data.numpy().tolist()
+        return relax.op.reshape(data, new_shape)
 
 
 class Gelu(OnnxOpConverter):
@@ -379,6 +371,12 @@ class Shape(OnnxOpConverter):
 
     @classmethod
     def _impl_v13(cls, bb, inputs, attr):
+        # See if we can extract a constant shape.
+        if all(["int" in i.dtype for i in inputs[0].struct_info.shape]):
+            # If so, return the shape as a constant.
+            data_shape = [i.value for i in inputs[0].struct_info.shape]
+            return relax.const(data_shape, "int64")
+        # Otherwise compute it dynamically.
         return relax.op.shape_of(inputs[0])
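
The new Gather lowering above follows numpy/ONNX take semantics: a rank-1 index tensor
keeps the gathered axis, while a scalar index drops it. That is why the converter expands
a scalar index to rank 1 before calling relax.op.take and squeezes the axis back out
afterwards. A minimal numpy sketch of the intended behavior (illustration only, not part
of the patch):

import numpy as np

data = np.arange(24).reshape(4, 3, 2)

# Rank-1 indices keep the gathered axis: shape (2, 3, 2).
assert np.take(data, np.array([0, 2]), axis=0).shape == (2, 3, 2)

# A scalar index drops the axis: shape (3, 2).
scalar = np.take(data, np.array(1), axis=0)
assert scalar.shape == (3, 2)

# What the converter emulates: promote to rank 1, take, then squeeze the axis.
emulated = np.squeeze(np.take(data, np.array([1]), axis=0), axis=0)
assert np.array_equal(emulated, scalar)
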
@@ -427,20 +425,44 @@ class Conv(OnnxOpConverter):
 
     @classmethod
     def _impl_v11(cls, bb, inputs, attr):
-        conv_out = bb.normalize(
-            relax.op.nn.conv2d(
-                data=inputs[0],
-                weight=inputs[1],
-                strides=attr.get("strides", 1),
-                padding=attr.get("pads", 0),
-                dilation=attr.get("dilation", 1),
-                groups=attr.get("group", 1),
-                data_layout="NCHW",
-                kernel_layout="OIHW",
+        ndim = len(inputs[0].struct_info.shape)
+        if ndim == 3:
+            conv_out = bb.emit_te(
+                topi.nn.conv1d,
+                inputs[0],
+                inputs[1],
+                attr.get("strides", 1),
+                attr.get("pads", 0),
+                attr.get("dilation", 1),
+                "NCHW",
+                "OIHW",
             )
-        )
+        elif ndim == 4:
+            conv_out = bb.normalize(
+                relax.op.nn.conv2d(
+                    data=inputs[0],
+                    weight=inputs[1],
+                    strides=attr.get("strides", 1),
+                    padding=attr.get("pads", 0),
+                    dilation=attr.get("dilation", 1),
+                    groups=attr.get("group", 1),
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                )
+            )
+        else:
+            raise NotImplementedError("Only 1d and 2d conv currently supported.")
+
 
         if inputs[2] is not None:
-            conv_out = relax.op.add(conv_out, inputs[2])
+            bias = relax.op.reshape(
+                inputs[2],
+                [1, -1]
+                + [
+                    1,
+                ]
+                * (ndim - 2),
+            )
+            conv_out = relax.op.add(conv_out, bias)
 
         return conv_out
@@ -1021,6 +1043,36 @@ def _impl_v12(cls, bb, inputs, attr):
         return bb.emit_te(topi.arange, start, limit, step)
 
 
+class InstanceNormalization(OnnxOpConverter):
+    """Converts an onnx InstanceNormalization node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v6(cls, bb, inputs, attr):
+        data = inputs[0]
+        scale = inputs[1]
+        B = inputs[2]
+        epsilon = attr.get("epsilon", 1e-05)
+        epsilon = relax.const(epsilon, dtype=data.struct_info.dtype)
+
+        ndim = len(data.struct_info.shape)
+        redux_axes = list(range(2, ndim))
+
+        mean = relax.op.mean(data, axis=redux_axes, keepdims=True)
+        var = relax.op.variance(data, axis=redux_axes, keepdims=True)
+        sqrt = relax.op.sqrt(var + epsilon)
+        out = relax.op.divide(relax.op.subtract(data, mean), sqrt)
+        broadcast_shape = [-1] + [
+            1,
+        ] * (ndim - 2)
+        if scale is not None:
+            scale = relax.op.reshape(scale, broadcast_shape)
+            out = relax.op.multiply(out, scale)
+        if B is not None:
+            B = relax.op.reshape(B, broadcast_shape)
+            out = relax.op.add(out, B)
+        return out
+
+
 def _get_convert_map():
     return {
         "MatMul": relay.frontend.onnx.MatMul,
@@ -1046,7 +1098,7 @@ def _get_convert_map():
         "Tanh": Tanh,
         "Sqrt": Sqrt,
         "Relu": Relu,
-        "Conv": relay.frontend.onnx.Conv,
+        "Conv": Conv,
         "Pow": Pow,
         "Erf": Erf,
         "CumSum": CumSum,
@@ -1065,7 +1117,7 @@ def _get_convert_map():
         "LayerNormalization": relay.frontend.onnx.LayerNormalization,
         "SkipLayerNormalization": relay.frontend.onnx.SkipLayerNormalization,
         "EmbedLayerNormalization": relay.frontend.onnx.EmbedLayerNormalization,
-        "InstanceNormalization": relay.frontend.onnx.InstanceNorm,
+        "InstanceNormalization": InstanceNormalization,
         # defs/reduction
         "ReduceMax": relay.frontend.onnx.ReduceMax,
         "ReduceMin": relay.frontend.onnx.ReduceMin,
@@ -1273,6 +1325,7 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
             attr["tvm_custom"]["name"] = i_name
             attr["tvm_custom"]["num_outputs"] = len(outputs)
 
+            print(op_name, node.name)
             op = self._convert_operator(op_name, inputs, attr, self.opset)
             # Create struct information for the new operator.
             op = self.bb.normalize(op)
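
The InstanceNormalization converter added above computes the normalization manually: each
(batch, channel) slice is normalized over the remaining spatial axes (redux_axes =
range(2, ndim)), and the per-channel scale and B are reshaped so they broadcast over those
axes. A small numpy reference of the same computation, written only to illustrate the math
(the function name below is not from the patch):

import numpy as np

def instance_norm_ref(x, scale, bias, epsilon=1e-5):
    # Reduce over every axis past batch and channel, like redux_axes = range(2, ndim).
    axes = tuple(range(2, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    normed = (x - mean) / np.sqrt(var + epsilon)
    # Reshape the [C] scale/bias so they broadcast over the spatial dims,
    # mirroring broadcast_shape = [-1] + [1] * (ndim - 2) in the converter.
    bshape = [-1] + [1] * (x.ndim - 2)
    return normed * scale.reshape(bshape) + bias.reshape(bshape)

x = np.random.randn(2, 3, 32, 32).astype("float32")
out = instance_norm_ref(x, np.ones(3, "float32"), np.zeros(3, "float32"))
assert out.shape == x.shape
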
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 9016580daa1f..40a47411ee8d 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -522,8 +522,9 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
   CHECK(array != nullptr) << "Reshape only expects the input new shape to be either an Expr or an "
                              "Array of PrimExprs. However, the given new shape is "
                           << shape;
+  // Identify which, if any dimensions are special values that must be computed.
   int dim_to_infer = -1;
-  PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1);
+  std::vector<int> zero_dims;
   for (int i = 0; i < static_cast<int>(array->size()); ++i) {
     const auto* _len = array->at(i).as<PrimExprNode>();
     CHECK(_len != nullptr) << "Reshape only expects the input new shape to be either an Expr or an "
@@ -534,7 +535,10 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
                                "integers. However, the give new shape is "
                             << shape;
     const auto* int_len = len.as<IntImmNode>();
-    if (int_len != nullptr && int_len->value == -1) {
+    if (int_len != nullptr && int_len->value == 0) {
+      // Note that this dimension should be copied from the original shape.
+      zero_dims.push_back(i);
+    } else if (int_len != nullptr && int_len->value == -1) {
       CHECK_EQ(dim_to_infer, -1) << "Reshape accepts at most one \"-1\" in the new shape. However, "
                                     "there are multiple \"-1\" in the given new shape "
                                  << shape;
@@ -544,15 +548,12 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
          << "Reshape requires all values in the new shape to be positive except a single \"-1\". "
             "However, the given new shape is "
          << shape;
-      // We expect any symbolic not to signal the intent of -1, and therefore do no check for
-      // symbolic value here.
-      new_shape_prod = new_shape_prod * len;
     }
   }
 
   Array<PrimExpr> array_ref = GetRef<Array<PrimExpr>>(array);
   // When there is no dimension to infer, just return the input array as ShapeExpr.
-  if (dim_to_infer == -1) {
+  if (dim_to_infer == -1 && zero_dims.empty()) {
     return ShapeExpr(array_ref);
   }
@@ -570,9 +571,32 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
         "to infer. However, the given input shape is "
      << data_sinfo->shape << " whose shape value is unknown.";
 
-  arith::Analyzer analyzer;
-  PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value());
-  array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod)));
+  // Set any 0 valued dimensions to match the corresponding input shape.
+  if (!zero_dims.empty()) {
+    for (int i : zero_dims) {
+      array_ref.Set(i, shape_sinfo->values.value()[i]);
+    }
+  }
+
+  // Set any -1 dimensions to complete the number of appropriate elements.
+  // Start by computing the shape product of all positive indices.
+  PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1);
+  for (int i = 0; i < static_cast<int>(array_ref.size()); ++i) {
+    PrimExpr new_dim = array_ref[i];
+    const auto* int_dim = new_dim.as<IntImmNode>();
+    // We expect any symbolic not to signal the intent of -1, and therefore do no check for
+    // symbolic value here.
+    if (int_dim == nullptr || int_dim->value > 0) {
+      new_shape_prod = new_shape_prod * new_dim;
+    }
+  }
+
+  // Assign appropriate value to -1 dimension.
+  if (dim_to_infer != -1) {
+    arith::Analyzer analyzer;
+    PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value());
+    array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod)));
+  }
   return ShapeExpr(array_ref);
 }
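
The manipulate.cc change above teaches relax.op.reshape the ONNX-style special dimensions:
a 0 copies the corresponding input dimension, and a single -1 is inferred from the remaining
element count. A short Python sketch of the same resolution logic (illustration only; the
real implementation works on PrimExprs and an arith::Analyzer):

def resolve_new_shape(old_shape, new_shape):
    # 0 means "copy this dimension from the input shape" (the zero_dims handling).
    resolved = [old_shape[i] if d == 0 else d for i, d in enumerate(new_shape)]
    if -1 in resolved:
        assert resolved.count(-1) == 1, "at most one -1 is allowed"
        known = 1
        for d in resolved:
            if d != -1:
                known *= d
        total = 1
        for d in old_shape:
            total *= d
        # The -1 dimension absorbs whatever is left over.
        resolved[resolved.index(-1)] = total // known
    return resolved

# These match the new cases in tests/python/relax/test_op_manipulate.py.
assert resolve_new_shape((2, 3, 4, 5), (2, 3, 0, 5)) == [2, 3, 4, 5]
assert resolve_new_shape((2, 3, 4, 5), (1, 3, 0, -1)) == [1, 3, 4, 10]
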
diff --git a/tests/python/relax/frontend/test_onnx_frontend.py b/tests/python/relax/frontend/test_onnx_frontend.py
index 7ac30f7c794f..d21d70394df7 100644
--- a/tests/python/relax/frontend/test_onnx_frontend.py
+++ b/tests/python/relax/frontend/test_onnx_frontend.py
@@ -286,24 +286,33 @@ def test_cast(from_type, to_type):
 
 
 def test_gather():
-    gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=0)
+    def _verify_gather(data_shape, indices, out_shape):
+        gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=0)
 
-    graph = helper.make_graph(
-        [gather_node],
-        "gather_test",
-        inputs=[
-            helper.make_tensor_value_info("data", TensorProto.FLOAT, [5, 4, 3, 2]),
-            helper.make_tensor_value_info("indices", TensorProto.INT32, [3]),
-        ],
-        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 4, 3, 2])],
-    )
+        if isinstance(indices, (list, tuple)):
+            indices_shape = [len(indices)]
+        else:
+            indices_shape = []
 
-    model = helper.make_model(graph, producer_name="gather_test")
-    input_values = {
-        "data": np.random.randn(5, 4, 3, 2).astype("float32"),
-        "indices": np.array([0, 1, 3]).astype("int32"),
-    }
-    check_correctness(model, inputs=input_values)
+        graph = helper.make_graph(
+            [gather_node],
+            "gather_test",
+            inputs=[
+                helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
+                helper.make_tensor_value_info("indices", TensorProto.INT32, indices_shape),
+            ],
+            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, out_shape)],
+        )
+
+        model = helper.make_model(graph, producer_name="gather_test")
+        input_values = {
+            "data": np.random.randn(*data_shape).astype("float32"),
+            "indices": np.array(indices).astype("int32"),
+        }
+        check_correctness(model, inputs=input_values)
+
+    _verify_gather([5, 4, 3, 2], [0, 1, 3], [3, 4, 3, 2])
+    _verify_gather([3], 0, [])
 
 
 @pytest.mark.parametrize("alpha", [None, 0.25])
@@ -339,7 +348,11 @@ def test_gemm(alpha, beta, useC):
 
 @pytest.mark.parametrize(
     "in_shape, shape, out_shape",
-    [([7, 32, 32, 8], [224, 256], [224, 256]), ([7, 32, 32, 8], [-1, 8192], [7, 8192])],
+    [
+        ([7, 32, 32, 8], [224, 256], [224, 256]),
+        ([7, 32, 32, 8], [-1, 8192], [7, 8192]),
+        ([7, 32, 32, 8], [0, 32, 32, 8], [7, 32, 32, 8]),
+    ],
 )
 def test_reshape(in_shape, shape, out_shape):
     reshape_node = helper.make_node("Reshape", ["data", "shape"], ["reshaped"])
@@ -515,21 +528,24 @@ def test_relu():
 
 
 def test_conv():
-    conv_node = helper.make_node("Conv", ["x", "w", "b"], ["y"])
-    nchw_shape = [3, 12, 32, 32]
-    graph = helper.make_graph(
-        [conv_node],
-        "conv_test",
-        inputs=[
-            helper.make_tensor_value_info("x", TensorProto.FLOAT, nchw_shape),
-            helper.make_tensor_value_info("w", TensorProto.FLOAT, [4, 12, 3, 3]),
-            helper.make_tensor_value_info("b", TensorProto.FLOAT, [4]),
-        ],
-        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 4, 30, 30])],
-    )
+    def _verify_conv(input_shape, weight_shape, output_shape):
+        bias_shape = [output_shape[1]]
+        conv_node = helper.make_node("Conv", ["x", "w", "b"], ["y"])
+        graph = helper.make_graph(
+            [conv_node],
+            "conv_test",
+            inputs=[
+                helper.make_tensor_value_info("x", TensorProto.FLOAT, input_shape),
+                helper.make_tensor_value_info("w", TensorProto.FLOAT, weight_shape),
+                helper.make_tensor_value_info("b", TensorProto.FLOAT, bias_shape),
+            ],
+            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)],
+        )
 
-    model = helper.make_model(graph, producer_name="conv_test")
-    check_correctness(model)
+        model = helper.make_model(graph, producer_name="conv_test")
+        check_correctness(model)
+
+    _verify_conv([3, 12, 32, 32], [4, 12, 3, 3], [3, 4, 30, 30])
 
 
 def test_pow():
@@ -645,6 +661,9 @@ def test_log():
 
 
 def test_instance_norm():
+    verify_ternary(
+        "InstanceNormalization", [1, 3, 32, 32], [3], [3], [1, 3, 32, 32], attrs={"epsilon": 1e-12}
+    )
     verify_ternary(
         "InstanceNormalization", [1, 32, 32], [32], [32], [1, 32, 32], attrs={"epsilon": 1e-12}
     )
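
The reworked test_gather now also covers the scalar-index case via _verify_gather([3], 0, []).
As a standalone reference, the ONNX graph that case constructs looks roughly like the sketch
below (independent of the check_correctness harness in this file); a scalar index makes Gather
drop the gathered axis, so the output is rank 0:

import numpy as np
import onnx
from onnx import TensorProto, helper

gather = helper.make_node("Gather", ["data", "indices"], ["y"], axis=0)
graph = helper.make_graph(
    [gather],
    "scalar_gather",
    inputs=[
        helper.make_tensor_value_info("data", TensorProto.FLOAT, [3]),
        helper.make_tensor_value_info("indices", TensorProto.INT32, []),
    ],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [])],
)
model = helper.make_model(graph, producer_name="scalar_gather")
onnx.checker.check_model(model)

# Expected result computed with numpy: a scalar index yields a rank-0 output.
expected = np.take(np.array([1.0, 2.0, 3.0], "float32"), np.array(0, "int32"), axis=0)
assert expected.shape == ()
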
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index abb414b4724c..b3cabe6430d2 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -178,6 +178,12 @@ def test_reshape_infer_struct_info_shape_var():
     _check_inference(
         bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32")
     )
+    _check_inference(
+        bb, relax.op.reshape(x0, (2, 3, 0, 5)), relax.TensorStructInfo((2, 3, 4, 5), "float32")
+    )
+    _check_inference(
+        bb, relax.op.reshape(x0, (1, 3, 0, -1)), relax.TensorStructInfo((1, 3, 4, 10), "float32")
+    )
     _check_inference(bb, relax.op.reshape(x0, ns0), relax.TensorStructInfo(ns0, "float32"))
     _check_inference(bb, relax.op.reshape(x0, ns1), relax.TensorStructInfo(ns1, "float32"))
     _check_inference(
@@ -277,8 +283,6 @@ def test_reshape_infer_struct_info_non_positive_new_shape():
     bb = relax.BlockBuilder()
     x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
 
-    with pytest.raises(TVMError):
-        bb.normalize(relax.op.reshape(x, (2, 0, 4, 5)))
     with pytest.raises(TVMError):
         bb.normalize(relax.op.reshape(x, (-2, -3, -4, -5)))
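
Taken together, these changes are aimed at importing real ONNX graphs (for example the
Stable Diffusion UNet and VAE) through the Relax frontend end to end. A rough usage sketch
of the importer follows; the from_onnx entry point is assumed to be the public function
defined in python/tvm/relax/frontend/onnx_frontend.py as of this patch, and the model path
is a placeholder:

import onnx
from tvm.relax.frontend.onnx_frontend import from_onnx  # entry point assumed from this file

# Load any ONNX model; "model.onnx" is a placeholder path.
model = onnx.load("model.onnx")
# Import the ONNX graph into a Relax IRModule using the converters patched above.
mod = from_onnx(model)
print(mod)
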