Skip to content

Commit

Permalink
Support models with Float timestep tensor input
Browse files Browse the repository at this point in the history
  • Loading branch information
saddam213 committed Nov 9, 2023
1 parent e56a7b6 commit 38f60b5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,18 @@ protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions mod
protected virtual IReadOnlyList<NamedOnnxValue> CreateUnetInputParams(IModelOptions model, DenseTensor<float> inputTensor, DenseTensor<float> promptEmbeddings, DenseTensor<float> guidanceEmbeddings, int timestep)
{
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.Unet);
var inputMetaData = _onnxModelService.GetInputMetadata(model, OnnxModelType.Unet);

// Some models support Long or Float, could be more but fornow just support these 2
var timesepMetaKey = inputNames[1];
var timestepMetaData = inputMetaData[timesepMetaKey];
var timestepNamedOnnxValue = timestepMetaData.ElementDataType == TensorElementType.Int64
? NamedOnnxValue.CreateFromTensor(timesepMetaKey, new DenseTensor<long>(new long[] { timestep }, new int[] { 1 }))
: NamedOnnxValue.CreateFromTensor(timesepMetaKey, new DenseTensor<float>(new float[] { timestep }, new int[] { 1 }));

return CreateInputParameters(
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 })),
timestepNamedOnnxValue,
NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings),
NamedOnnxValue.CreateFromTensor(inputNames[3], guidanceEmbeddings));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,18 @@ protected virtual DenseTensor<float> PerformGuidance(DenseTensor<float> noisePre
protected virtual IReadOnlyList<NamedOnnxValue> CreateUnetInputParams(IModelOptions model, DenseTensor<float> inputTensor, DenseTensor<float> promptEmbeddings, int timestep)
{
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.Unet);
var inputMetaData = _onnxModelService.GetInputMetadata(model, OnnxModelType.Unet);

// Some models support Long or Float, could be more but fornow just support these 2
var timesepMetaKey = inputNames[1];
var timestepMetaData = inputMetaData[timesepMetaKey];
var timestepNamedOnnxValue = timestepMetaData.ElementDataType == TensorElementType.Int64
? NamedOnnxValue.CreateFromTensor(timesepMetaKey, new DenseTensor<long>(new long[] { timestep }, new int[] { 1 }))
: NamedOnnxValue.CreateFromTensor(timesepMetaKey, new DenseTensor<float>(new float[] { timestep }, new int[] { 1 }));

return CreateInputParameters(
NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 })),
timestepNamedOnnxValue,
NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings));
}

Expand Down

0 comments on commit 38f60b5

Please sign in to comment.