Various fixes for end to end stable diffusion (apache#29)
Josh Fromm committed Feb 28, 2023
1 parent 70d30bf commit da1ad37
Showing 4 changed files with 182 additions and 82 deletions.
133 changes: 93 additions & 40 deletions python/tvm/relax/frontend/onnx_frontend.py
@@ -256,9 +256,21 @@ class Gather(OnnxOpConverter):

@classmethod
def _impl_v13(cls, bb, inputs, attr):
# TODO This assumes positive only indices.
# Unpack inputs
data = inputs[0]
indices = inputs[1]
# Indices must be rank 1; if we're given a scalar, expand it.
scalar_indices = False
if len(indices.struct_info.shape) == 0:
scalar_indices = True
indices = bb.normalize(relax.op.expand_dims(indices, axis=0))

axis = attr.get("axis", 0)
return bb.emit_te(topi.take, inputs[0], inputs[1], axis)
out = relax.op.take(data, indices, axis)
# If indices were scalar, output dimension needs to be reduced.
if scalar_indices:
out = relax.op.squeeze(out, axis)
return out
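
The scalar-index handling mirrors ONNX/NumPy Gather semantics, where a rank-0 index removes the gathered axis. A minimal NumPy sketch of that behavior (illustrative only, not part of the commit):

import numpy as np

data = np.arange(24, dtype="float32").reshape(4, 3, 2)

# Rank-1 indices keep the gathered axis: result shape is (2, 3, 2).
vector_out = np.take(data, np.array([0, 2]), axis=0)
assert vector_out.shape == (2, 3, 2)

# A rank-0 (scalar) index drops the gathered axis: result shape is (3, 2).
# The converter reproduces this by expanding the scalar to rank 1,
# taking, then squeezing the gathered axis back out.
scalar_out = np.take(data, np.array(1), axis=0)
expanded_then_squeezed = np.take(data, np.array([1]), axis=0).squeeze(axis=0)
assert scalar_out.shape == expanded_then_squeezed.shape == (3, 2)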


class Gemm(OnnxOpConverter):
@@ -296,30 +308,10 @@ class Reshape(OnnxOpConverter):
@classmethod
def _impl_v13(cls, bb, inputs, attr):
data = inputs[0]
# TODO We assume new_shape is a constant, need to enable tensor input to reshape
# for full support.
if not isinstance(inputs[1], relax.Constant):
return inputs[0]
new_shape = inputs[1].data.numpy().tolist()

# Convert -1 dims in new_shape into positive equivalent.
if -1 in new_shape:
if new_shape.count(-1) != 1:
raise ValueError("Reshape with multiple -1 is not supported.")

data_shape = [dim.value for dim in data.struct_info.shape.values]
total_elements = _np.prod(data_shape)
new_product = 1
for dim in new_shape:
if dim > 0:
new_product *= dim

# Replace -1 with positive equivalent
for i, dim in enumerate(new_shape):
if dim == -1:
new_shape[i] = int(total_elements / new_product)

return bb.emit_te(topi.reshape, data, new_shape)
new_shape = inputs[1]
if isinstance(inputs[1], relax.Constant):
new_shape = inputs[1].data.numpy().tolist()
return relax.op.reshape(data, new_shape)


class Gelu(OnnxOpConverter):
Expand Down Expand Up @@ -379,6 +371,12 @@ class Shape(OnnxOpConverter):

@classmethod
def _impl_v13(cls, bb, inputs, attr):
# See if we can extract a constant shape.
if all(["int" in i.dtype for i in inputs[0].struct_info.shape]):
# If so, return the shape as a constant.
data_shape = [i.value for i in inputs[0].struct_info.shape]
return relax.const(data_shape, "int64")
# Otherwise compute it dynamically.
return relax.op.shape_of(inputs[0])
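
The intent of this change is to fold the shape to a constant when every dimension is statically known and only fall back to a runtime shape computation otherwise. A small Python sketch of that decision (hypothetical helper, not TVM API):

def shape_as_constant(static_shape):
    """Return a concrete list when all dims are known ints (suitable for
    relax.const), or None to indicate a runtime shape_of is needed."""
    if all(isinstance(dim, int) for dim in static_shape):
        return list(static_shape)
    return None

assert shape_as_constant([1, 3, 224, 224]) == [1, 3, 224, 224]
assert shape_as_constant([1, "seq_len", 768]) is None  # symbolic dim stays dynamic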


@@ -427,20 +425,44 @@ class Conv(OnnxOpConverter):

@classmethod
def _impl_v11(cls, bb, inputs, attr):
conv_out = bb.normalize(
relax.op.nn.conv2d(
data=inputs[0],
weight=inputs[1],
strides=attr.get("strides", 1),
padding=attr.get("pads", 0),
dilation=attr.get("dilation", 1),
groups=attr.get("group", 1),
data_layout="NCHW",
kernel_layout="OIHW",
ndim = len(inputs[0].struct_info.shape)
if ndim == 3:
conv_out = bb.emit_te(
topi.nn.conv1d,
inputs[0],
inputs[1],
attr.get("strides", 1),
attr.get("pads", 0),
attr.get("dilation", 1),
"NCHW",
"OIHW",
)
)
elif ndim == 4:
conv_out = bb.normalize(
relax.op.nn.conv2d(
data=inputs[0],
weight=inputs[1],
strides=attr.get("strides", 1),
padding=attr.get("pads", 0),
dilation=attr.get("dilation", 1),
groups=attr.get("group", 1),
data_layout="NCHW",
kernel_layout="OIHW",
)
)
else:
raise NotImplementedError("Only 1d and 2d conv currently supported.")

if inputs[2] is not None:
conv_out = relax.op.add(conv_out, inputs[2])
bias = relax.op.reshape(
inputs[2],
[1, -1]
+ [
1,
]
* (ndim - 2),
)
conv_out = relax.op.add(conv_out, bias)

return conv_out
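
The bias path reshapes the per-channel (C,) bias so it broadcasts over the channel axis of the NCW/NCHW output rather than the trailing spatial axis. A small NumPy illustration with assumed shapes (not part of the commit):

import numpy as np

conv_out = np.zeros((2, 4, 30), dtype="float32")  # NCW output of a 1-d conv
bias = np.arange(4, dtype="float32")               # one value per channel

# [1, -1] + [1] * (ndim - 2) reshapes the bias to (1, 4, 1) here, aligning it
# with the channel axis; adding the raw (4,) bias would broadcast against W instead.
ndim = conv_out.ndim
bias_shape = [1, -1] + [1] * (ndim - 2)
biased = conv_out + bias.reshape(bias_shape)
assert biased.shape == conv_out.shape
assert np.allclose(biased[:, 1, :], 1.0)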

@@ -1021,6 +1043,36 @@ def _impl_v12(cls, bb, inputs, attr):
return bb.emit_te(topi.arange, start, limit, step)


class InstanceNormalization(OnnxOpConverter):
"""Converts an onnx InstanceNormalization node into an equivalent Relax expression."""

@classmethod
def _impl_v6(cls, bb, inputs, attr):
data = inputs[0]
scale = inputs[1]
B = inputs[2]
epsilon = attr.get("epsilon", 1e-05)
epsilon = relax.const(epsilon, dtype=data.struct_info.dtype)

ndim = len(data.struct_info.shape)
redux_axes = list(range(2, ndim))

mean = relax.op.mean(data, axis=redux_axes, keepdims=True)
var = relax.op.variance(data, axis=redux_axes, keepdims=True)
sqrt = relax.op.sqrt(var + epsilon)
out = relax.op.divide(relax.op.subtract(data, mean), sqrt)
broadcast_shape = [-1] + [
1,
] * (ndim - 2)
if scale is not None:
scale = relax.op.reshape(scale, broadcast_shape)
out = relax.op.multiply(out, scale)
if B is not None:
B = relax.op.reshape(B, broadcast_shape)
out = relax.op.add(out, B)
return out
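
A NumPy reference of the computation the converter emits: normalize each (N, C) slice over its spatial axes, then apply the per-channel scale and bias. This is an illustrative sketch, not the converter itself:

import numpy as np

def instance_norm_reference(data, scale, bias, epsilon=1e-5):
    # Reduce over every axis except batch (0) and channel (1).
    redux_axes = tuple(range(2, data.ndim))
    mean = data.mean(axis=redux_axes, keepdims=True)
    var = data.var(axis=redux_axes, keepdims=True)
    out = (data - mean) / np.sqrt(var + epsilon)
    # Broadcast scale and bias over the channel axis.
    broadcast_shape = [1, -1] + [1] * (data.ndim - 2)
    return out * scale.reshape(broadcast_shape) + bias.reshape(broadcast_shape)

x = np.random.randn(1, 3, 8, 8).astype("float32")
y = instance_norm_reference(x, np.ones(3, "float32"), np.zeros(3, "float32"))
assert y.shape == x.shape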


def _get_convert_map():
return {
"MatMul": relay.frontend.onnx.MatMul,
@@ -1046,7 +1098,7 @@ def _get_convert_map():
"Tanh": Tanh,
"Sqrt": Sqrt,
"Relu": Relu,
"Conv": relay.frontend.onnx.Conv,
"Conv": Conv,
"Pow": Pow,
"Erf": Erf,
"CumSum": CumSum,
@@ -1065,7 +1117,7 @@ def _get_convert_map():
"LayerNormalization": relay.frontend.onnx.LayerNormalization,
"SkipLayerNormalization": relay.frontend.onnx.SkipLayerNormalization,
"EmbedLayerNormalization": relay.frontend.onnx.EmbedLayerNormalization,
"InstanceNormalization": relay.frontend.onnx.InstanceNorm,
"InstanceNormalization": InstanceNormalization,
# defs/reduction
"ReduceMax": relay.frontend.onnx.ReduceMax,
"ReduceMin": relay.frontend.onnx.ReduceMin,
@@ -1273,6 +1325,7 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
attr["tvm_custom"]["name"] = i_name
attr["tvm_custom"]["num_outputs"] = len(outputs)

print(op_name, node.name)
op = self._convert_operator(op_name, inputs, attr, self.opset)
# Create struct information for the new operator.
op = self.bb.normalize(op)
42 changes: 33 additions & 9 deletions src/relax/op/tensor/manipulate.cc
@@ -522,8 +522,9 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
CHECK(array != nullptr) << "Reshape only expects the input new shape to be either an Expr or an "
"Array of PrimExprs. However, the given new shape is "
<< shape;
// Identify which dimensions, if any, are special values that must be computed.
int dim_to_infer = -1;
PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1);
std::vector<int> zero_dims;
for (int i = 0; i < static_cast<int>(array->size()); ++i) {
const auto* _len = array->at(i).as<PrimExprNode>();
CHECK(_len != nullptr) << "Reshape only expects the input new shape to be either an Expr or an "
@@ -534,7 +535,10 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
"integers. However, the give new shape is "
<< shape;
const auto* int_len = len.as<IntImmNode>();
if (int_len != nullptr && int_len->value == -1) {
if (int_len != nullptr && int_len->value == 0) {
// Note that this dimension should be copied from the original shape.
zero_dims.push_back(i);
} else if (int_len != nullptr && int_len->value == -1) {
CHECK_EQ(dim_to_infer, -1) << "Reshape accepts at most one \"-1\" in the new shape. However, "
"there are multiple \"-1\" in the given new shape "
<< shape;
@@ -544,15 +548,12 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
<< "Reshape requires all values in the new shape to be positive except a single \"-1\". "
"However, the given new shape is "
<< shape;
// We expect symbolic values not to signal the intent of -1, and therefore do not check
// for symbolic values here.
new_shape_prod = new_shape_prod * len;
}
}

Array<PrimExpr> array_ref = GetRef<Array<PrimExpr>>(array);
// When there is no dimension to infer, just return the input array as ShapeExpr.
if (dim_to_infer == -1) {
if (dim_to_infer == -1 && zero_dims.empty()) {
return ShapeExpr(array_ref);
}

@@ -570,9 +571,32 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
"to infer. However, the given input shape is "
<< data_sinfo->shape << " whose shape value is unknown.";

arith::Analyzer analyzer;
PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value());
array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod)));
// Set any 0-valued dimensions to match the corresponding input dimension.
if (!zero_dims.empty()) {
for (int i : zero_dims) {
array_ref.Set(i, shape_sinfo->values.value()[i]);
}
}

// Set any -1 dimension so that the total number of elements matches the input.
// Start by computing the product of all positive dimensions in the new shape.
PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1);
for (int i = 0; i < static_cast<int>(array_ref.size()); ++i) {
PrimExpr new_dim = array_ref[i];
const auto* int_dim = new_dim.as<IntImmNode>();
// We expect symbolic values not to signal the intent of -1, and therefore do not check
// for symbolic values here.
if (int_dim == nullptr || int_dim->value > 0) {
new_shape_prod = new_shape_prod * new_dim;
}
}

// Assign appropriate value to -1 dimension.
if (dim_to_infer != -1) {
arith::Analyzer analyzer;
PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value());
array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod)));
}
return ShapeExpr(array_ref);
}
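
In the new-shape semantics this helper implements, a 0 copies the corresponding input dimension and a single -1 is inferred from the remaining elements. A small Python sketch of those rules (illustrative only; the helper name is hypothetical):

def resolve_new_shape(input_shape, new_shape):
    # 0 copies the corresponding input dimension.
    resolved = [input_shape[i] if dim == 0 else dim for i, dim in enumerate(new_shape)]
    if -1 in resolved:
        assert resolved.count(-1) == 1, "at most one -1 is allowed"
        known = 1
        for dim in resolved:
            if dim > 0:
                known *= dim
        total = 1
        for dim in input_shape:
            total *= dim
        # The -1 dimension absorbs whatever is left over.
        resolved[resolved.index(-1)] = total // known
    return resolved

assert resolve_new_shape([7, 32, 32, 8], [0, 32, 32, 8]) == [7, 32, 32, 8]
assert resolve_new_shape([7, 32, 32, 8], [-1, 8192]) == [7, 8192]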

81 changes: 50 additions & 31 deletions tests/python/relax/frontend/test_onnx_frontend.py
@@ -286,24 +286,33 @@ def test_cast(from_type, to_type):


def test_gather():
gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=0)
def _verify_gather(data_shape, indices, out_shape):
gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=0)

graph = helper.make_graph(
[gather_node],
"gather_test",
inputs=[
helper.make_tensor_value_info("data", TensorProto.FLOAT, [5, 4, 3, 2]),
helper.make_tensor_value_info("indices", TensorProto.INT32, [3]),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 4, 3, 2])],
)
if isinstance(indices, (list, tuple)):
indices_shape = [len(indices)]
else:
indices_shape = []

model = helper.make_model(graph, producer_name="gather_test")
input_values = {
"data": np.random.randn(5, 4, 3, 2).astype("float32"),
"indices": np.array([0, 1, 3]).astype("int32"),
}
check_correctness(model, inputs=input_values)
graph = helper.make_graph(
[gather_node],
"gather_test",
inputs=[
helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
helper.make_tensor_value_info("indices", TensorProto.INT32, indices_shape),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, out_shape)],
)

model = helper.make_model(graph, producer_name="gather_test")
input_values = {
"data": np.random.randn(*data_shape).astype("float32"),
"indices": np.array(indices).astype("int32"),
}
check_correctness(model, inputs=input_values)

_verify_gather([5, 4, 3, 2], [0, 1, 3], [3, 4, 3, 2])
_verify_gather([3], 0, [])


@pytest.mark.parametrize("alpha", [None, 0.25])
@@ -339,7 +348,11 @@ def test_gemm(alpha, beta, useC):

@pytest.mark.parametrize(
"in_shape, shape, out_shape",
[([7, 32, 32, 8], [224, 256], [224, 256]), ([7, 32, 32, 8], [-1, 8192], [7, 8192])],
[
([7, 32, 32, 8], [224, 256], [224, 256]),
([7, 32, 32, 8], [-1, 8192], [7, 8192]),
([7, 32, 32, 8], [0, 32, 32, 8], [7, 32, 32, 8]),
],
)
def test_reshape(in_shape, shape, out_shape):
reshape_node = helper.make_node("Reshape", ["data", "shape"], ["reshaped"])
@@ -515,21 +528,24 @@ def test_relu():


def test_conv():
conv_node = helper.make_node("Conv", ["x", "w", "b"], ["y"])
nchw_shape = [3, 12, 32, 32]
graph = helper.make_graph(
[conv_node],
"conv_test",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, nchw_shape),
helper.make_tensor_value_info("w", TensorProto.FLOAT, [4, 12, 3, 3]),
helper.make_tensor_value_info("b", TensorProto.FLOAT, [4]),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 4, 30, 30])],
)
def _verify_conv(input_shape, weight_shape, output_shape):
bias_shape = [output_shape[1]]
conv_node = helper.make_node("Conv", ["x", "w", "b"], ["y"])
graph = helper.make_graph(
[conv_node],
"conv_test",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, input_shape),
helper.make_tensor_value_info("w", TensorProto.FLOAT, weight_shape),
helper.make_tensor_value_info("b", TensorProto.FLOAT, bias_shape),
],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)],
)

model = helper.make_model(graph, producer_name="conv_test")
check_correctness(model)
model = helper.make_model(graph, producer_name="conv_test")
check_correctness(model)

_verify_conv([3, 12, 32, 32], [4, 12, 3, 3], [3, 4, 30, 30])


def test_pow():
@@ -645,6 +661,9 @@ def test_log():


def test_instance_norm():
verify_ternary(
"InstanceNormalization", [1, 3, 32, 32], [3], [3], [1, 3, 32, 32], attrs={"epsilon": 1e-12}
)
verify_ternary(
"InstanceNormalization", [1, 32, 32], [32], [32], [1, 32, 32], attrs={"epsilon": 1e-12}
)