From 3f1524193702b73dc6f6b847dbdff5dbfcc0d5a7 Mon Sep 17 00:00:00 2001
From: yqzhishen
Date: Sat, 16 Mar 2024 14:56:50 +0800
Subject: [PATCH] Add tensor caching system for DiffSinger

---
 .../DiffSinger/DiffSingerBasePhonemizer.cs    |  30 +-
 OpenUtau.Core/DiffSinger/DiffSingerCache.cs   | 311 ++++++++++++++++++
 OpenUtau.Core/DiffSinger/DiffSingerPitch.cs   |  15 +-
 .../DiffSinger/DiffSingerRenderer.cs          |  39 ++-
 OpenUtau.Core/DiffSinger/DiffSingerSinger.cs  |   7 +-
 .../DiffSinger/DiffSingerVariance.cs          |  30 +-
 OpenUtau.Core/DiffSinger/DiffSingerVocoder.cs |   3 +
 OpenUtau.Core/Util/Preferences.cs             |   1 +
 OpenUtau/Strings/Strings.axaml                |   1 +
 OpenUtau/Strings/Strings.zh-CN.axaml          |   1 +
 OpenUtau/ViewModels/PreferencesViewModel.cs   |   7 +
 OpenUtau/Views/PreferencesDialog.axaml        |  14 +-
 12 files changed, 430 insertions(+), 29 deletions(-)
 create mode 100644 OpenUtau.Core/DiffSinger/DiffSingerCache.cs

diff --git a/OpenUtau.Core/DiffSinger/DiffSingerBasePhonemizer.cs b/OpenUtau.Core/DiffSinger/DiffSingerBasePhonemizer.cs
index 61af0c478..d955756fb 100644
--- a/OpenUtau.Core/DiffSinger/DiffSingerBasePhonemizer.cs
+++ b/OpenUtau.Core/DiffSinger/DiffSingerBasePhonemizer.cs
@@ -2,12 +2,14 @@
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 
+using K4os.Hash.xxHash;
 using Serilog;
 using Microsoft.ML.OnnxRuntime;
 using Microsoft.ML.OnnxRuntime.Tensors;
 using OpenUtau.Api;
 using OpenUtau.Core.Ustx;
+using OpenUtau.Core.Util;
 
 namespace OpenUtau.Core.DiffSinger {
@@ -17,6 +19,8 @@ public abstract class DiffSingerBasePhonemizer : MachineLearningPhonemizer
         DsConfig dsConfig;
         string rootPath;
         float frameMs;
+        ulong linguisticHash;
+        ulong durationHash;
         InferenceSession linguisticModel;
         InferenceSession durationModel;
         IG2p g2p;
@@ -51,14 +55,18 @@ public abstract class DiffSingerBasePhonemizer : MachineLearningPhonemizer
             //Load models
             var linguisticModelPath = Path.Join(rootPath, dsConfig.linguistic);
             try {
-                linguisticModel = new InferenceSession(linguisticModelPath);
+                var linguisticModelBytes = File.ReadAllBytes(linguisticModelPath);
+                linguisticHash = XXH64.DigestOf(linguisticModelBytes);
+                linguisticModel = new InferenceSession(linguisticModelBytes);
             } catch (Exception e) {
                 Log.Error(e, $"failed to load linguistic model from {linguisticModelPath}");
                 return;
             }
             var durationModelPath = Path.Join(rootPath, dsConfig.dur);
             try {
-                durationModel = new InferenceSession(durationModelPath);
+                var durationModelBytes = File.ReadAllBytes(durationModelPath);
+                durationHash = XXH64.DigestOf(durationModelBytes);
+                durationModel = new InferenceSession(durationModelBytes);
             } catch (Exception e) {
                 Log.Error(e, $"failed to load duration model from {durationModelPath}");
                 return;
@@ -260,7 +268,14 @@ public abstract class DiffSingerBasePhonemizer : MachineLearningPhonemizer
                 new DenseTensor<Int64>(word_dur, new int[] { word_dur.Length }, false)
                 .Reshape(new int[] { 1, word_dur.Length })));
             Onnx.VerifyInputNames(linguisticModel, linguisticInputs);
-            var linguisticOutputs = linguisticModel.Run(linguisticInputs);
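+            // Build a cache entry keyed by the linguistic model hash and the serialized inputs;
+            // reuse stored outputs on a hit, otherwise run the model and save the results.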
+            var linguisticCache = Preferences.Default.DiffSingerTensorCache
+                ? new DiffSingerCache(linguisticHash, linguisticInputs)
+                : null;
+            var linguisticOutputs = linguisticCache?.Load();
+            if (linguisticOutputs is null) {
+                linguisticOutputs = linguisticModel.Run(linguisticInputs).Cast<NamedOnnxValue>().ToList();
+                linguisticCache?.Save(linguisticOutputs);
+            }
             Tensor<float> encoder_out = linguisticOutputs
                 .Where(o => o.Name == "encoder_out")
                 .First()
@@ -291,7 +306,14 @@ public abstract class DiffSingerBasePhonemizer : MachineLearningPhonemizer
                 durationInputs.Add(NamedOnnxValue.CreateFromTensor("spk_embed", spkEmbedTensor));
             }
             Onnx.VerifyInputNames(durationModel, durationInputs);
-            var durationOutputs = durationModel.Run(durationInputs);
+            var durationCache = Preferences.Default.DiffSingerTensorCache
+                ? new DiffSingerCache(durationHash, durationInputs)
+                : null;
+            var durationOutputs = durationCache?.Load();
+            if (durationOutputs is null) {
+                durationOutputs = durationModel.Run(durationInputs).Cast<NamedOnnxValue>().ToList();
+                durationCache?.Save(durationOutputs);
+            }
             List<double> durationFrames = durationOutputs.First().AsTensor<float>().Select(x=>(double)x).ToList();
 
             //Alignment
diff --git a/OpenUtau.Core/DiffSinger/DiffSingerCache.cs b/OpenUtau.Core/DiffSinger/DiffSingerCache.cs
new file mode 100644
index 000000000..83c30d8e0
--- /dev/null
+++ b/OpenUtau.Core/DiffSinger/DiffSingerCache.cs
@@ -0,0 +1,311 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using K4os.Hash.xxHash;
+using Microsoft.ML.OnnxRuntime;
+using Microsoft.ML.OnnxRuntime.Tensors;
+using Serilog;
+
+namespace OpenUtau.Core.DiffSinger {
+
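+    /// <summary>
+    /// Disk cache for ONNX inference results. The cache key is an XXH64 digest of the
+    /// model hash plus the serialized inputs (ordered by name), and each entry is stored
+    /// as ds-{hash:x16}.tensorcache in the common cache directory.
+    /// </summary>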
+    public class DiffSingerCache {
+        private const string FormatHeader = "TENSORCACHE";
+
+        private readonly ulong hash;
+        private readonly string filename;
+
+        public ulong Hash => hash;
+
+        public DiffSingerCache(ulong identifier, ICollection<NamedOnnxValue> inputs) {
+            using var stream = new MemoryStream();
+            using (var writer = new BinaryWriter(stream)) {
+                writer.Write(identifier);
+                foreach (var onnxValue in inputs.OrderBy(v => v.Name, StringComparer.InvariantCulture)) {
+                    SerializeNamedOnnxValue(writer, onnxValue);
+                }
+            }
+
+            hash = XXH64.DigestOf(stream.ToArray());
+            filename = $"ds-{hash:x16}.tensorcache";
+        }
+
+        public ICollection<NamedOnnxValue>? Load() {
+            var cachePath = Path.Join(PathManager.Inst.CachePath, filename);
+            if (!File.Exists(cachePath)) return null;
+
+            var result = new List<NamedOnnxValue>();
+            using var stream = new FileStream(cachePath, FileMode.Open, FileAccess.Read);
+            using var reader = new BinaryReader(stream);
+            // header
+            if (reader.ReadString() != FormatHeader) {
+                throw new InvalidDataException($"[TensorCache] Unexpected file header in {filename}.");
+            }
+            try {
+                // count
+                var count = reader.ReadInt32();
+                for (var i = 0; i < count; ++i) {
+                    // data
+                    result.Add(DeserializeNamedOnnxValue(reader));
+                }
+            } catch (Exception e) {
+                Log.Error(e,
+                    "[TensorCache] Exception encountered when deserializing cache file. Root exception message: {msg}", e.Message);
+                Delete();
+                return null;
+            }
+
+            return result;
+        }
+
+        public void Delete() {
+            var cachePath = Path.Join(PathManager.Inst.CachePath, filename);
+            if (File.Exists(cachePath)) {
+                File.Delete(cachePath);
+            }
+        }
+
+        public void Save(ICollection<NamedOnnxValue> outputs) {
+            var cachePath = Path.Join(PathManager.Inst.CachePath, filename);
+            using var stream = new FileStream(cachePath, FileMode.Create, FileAccess.Write);
+            using var writer = new BinaryWriter(stream);
+            // header
+            writer.Write(FormatHeader);
+            // count
+            writer.Write(outputs.Count);
+            foreach (var onnxValue in outputs) {
+                // data
+                SerializeNamedOnnxValue(writer, onnxValue);
+            }
+        }
+
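+        // On-disk layout per value: name (string), element type (int32), rank (int32),
+        // shape (rank x int32), element count (int32), then the raw element data
+        // (length-prefixed strings for string tensors, a flat byte copy otherwise).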
+        private static void SerializeNamedOnnxValue(BinaryWriter writer, NamedOnnxValue namedOnnxValue) {
+            if (namedOnnxValue.ValueType != OnnxValueType.ONNX_TYPE_TENSOR) {
+                throw new NotSupportedException(
+                    $"[TensorCache] The only supported ONNX value type is {OnnxValueType.ONNX_TYPE_TENSOR}. Got {namedOnnxValue.ValueType} instead."
+                );
+            }
+            // name
+            writer.Write(namedOnnxValue.Name);
+            var tensorBase = (TensorBase) namedOnnxValue.Value;
+            var elementType = tensorBase.GetTypeInfo().ElementType;
+            // dtype
+            writer.Write((int)elementType);
+            switch (elementType) {
+                case TensorElementType.Float: {
+                    var tensor = namedOnnxValue.AsTensor<float>();
+                    SerializeTensor(writer, tensor);
+                    break;
+                }
+                case TensorElementType.UInt8: {
+                    var tensor = namedOnnxValue.AsTensor<byte>();
+                    SerializeTensor(writer, tensor);
+                    break;
+                }
+                case TensorElementType.Int8: {
+                    var tensor = namedOnnxValue.AsTensor<sbyte>();
+                    SerializeTensor(writer, tensor);
+                    break;
+                }
+                case TensorElementType.UInt16: {
+                    var tensor = namedOnnxValue.AsTensor<ushort>();
+                    SerializeTensor(writer, tensor);
+                    break;
+                }
+                case TensorElementType.Int16: {
+                    var tensor = namedOnnxValue.AsTensor<short>();
+                    SerializeTensor(writer, tensor);
+                    break;
+                }
+                case TensorElementType.Int32: {
+                    var tensor = namedOnnxValue.AsTensor<int>();
+                    SerializeTensor(writer, tensor);
+                    break;
+                }
+                case TensorElementType.Int64: {
+                    var tensor = namedOnnxValue.AsTensor<long>();
+                    SerializeTensor(writer, tensor);
+                    break;
+                }
+                case TensorElementType.String: {
+                    var tensor = namedOnnxValue.AsTensor<string>();
+                    SerializeTensor(writer, tensor);
+                    break;
+                }
+                case TensorElementType.Bool: {
+                    var tensor = namedOnnxValue.AsTensor<bool>();
+                    SerializeTensor(writer, tensor);
+                    break;
+                }
+                case TensorElementType.Float16: {
+                    var tensor = namedOnnxValue.AsTensor<Float16>();
+                    SerializeTensor(writer, tensor);
+                    break;
+                }
+                case TensorElementType.Double: {
+                    var tensor = namedOnnxValue.AsTensor<double>();
+                    SerializeTensor(writer, tensor);
+                    break;
+                }
+                case TensorElementType.UInt32: {
+                    var tensor = namedOnnxValue.AsTensor<uint>();
+                    SerializeTensor(writer, tensor);
+                    break;
+                }
+                case TensorElementType.UInt64: {
+                    var tensor = namedOnnxValue.AsTensor<ulong>();
+                    SerializeTensor(writer, tensor);
+                    break;
+                }
+                case TensorElementType.BFloat16: {
+                    var tensor = namedOnnxValue.AsTensor<BFloat16>();
+                    SerializeTensor(writer, tensor);
+                    break;
+                }
+                case TensorElementType.Complex64:
+                case TensorElementType.Complex128:
+                case TensorElementType.DataTypeMax:
+                default:
+                    throw new NotSupportedException($"[TensorCache] Unsupported tensor element type: {elementType}.");
+            }
+        }
+
+        private static void SerializeTensor<T>(BinaryWriter writer, Tensor<T> tensor) {
+            if (tensor.IsReversedStride) {
+                throw new NotSupportedException("[TensorCache] Tensors in reversed strides are not supported.");
+            }
+            // rank
+            writer.Write(tensor.Rank);
+            // shape
+            foreach (var dim in tensor.Dimensions) {
+                writer.Write(dim);
+            }
+            // size
+            var size = (int)tensor.Length;
+            writer.Write(size);
+            if (typeof(T) == typeof(string)) {
+                // string tensor
+                // data
+                foreach (var element in tensor.ToArray()) {
+                    writer.Write(element!.ToString());
+                }
+            } else {
+                // numeric tensor
+                // data
+                var data = new byte[size * tensor.GetTypeInfo().TypeSize];
+                Buffer.BlockCopy(tensor.ToArray(), 0, data, 0, data.Length);
+                writer.Write(data);
+            }
+        }
+
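+        // Mirrors SerializeNamedOnnxValue: reads name, dtype, rank, shape and element count,
+        // then rebuilds a DenseTensor of the matching element type from the stored bytes.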
+        private static NamedOnnxValue DeserializeNamedOnnxValue(BinaryReader reader) {
+            // name
+            var name = reader.ReadString();
+            // dtype
+            var dtype = (TensorElementType)reader.ReadInt32();
+            // rank
+            var rank = reader.ReadInt32();
+            // shape
+            int[] shape = new int[rank];
+            for (var i = 0; i < rank; ++i) {
+                shape[i] = reader.ReadInt32();
+            }
+            // size
+            var size = reader.ReadInt32();
+            NamedOnnxValue namedOnnxValue;
+            switch (dtype) {
+                case TensorElementType.Float: {
+                    var tensor = DeserializeTensor<float>(reader, size, sizeof(float), shape);
+                    namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
+                    break;
+                }
+                case TensorElementType.UInt8: {
+                    var tensor = DeserializeTensor<byte>(reader, size, sizeof(byte), shape);
+                    namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
+                    break;
+                }
+                case TensorElementType.Int8: {
+                    var tensor = DeserializeTensor<sbyte>(reader, size, sizeof(sbyte), shape);
+                    namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
+                    break;
+                }
+                case TensorElementType.UInt16: {
+                    var tensor = DeserializeTensor<ushort>(reader, size, sizeof(ushort), shape);
+                    namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
+                    break;
+                }
+                case TensorElementType.Int16: {
+                    var tensor = DeserializeTensor<short>(reader, size, sizeof(short), shape);
+                    namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
+                    break;
+                }
+                case TensorElementType.Int32: {
+                    var tensor = DeserializeTensor<int>(reader, size, sizeof(int), shape);
+                    namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
+                    break;
+                }
+                case TensorElementType.Int64: {
+                    var tensor = DeserializeTensor<long>(reader, size, sizeof(long), shape);
+                    namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
+                    break;
+                }
+                case TensorElementType.String: {
+                    // string tensor
+                    Tensor<string> tensor = new DenseTensor<string>(size);
+                    for (var i = 0; i < size; ++i) {
+                        tensor[i] = reader.ReadString();
+                    }
+                    tensor = tensor.Reshape(shape);
+                    namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
+                    break;
+                }
+                case TensorElementType.Bool: {
+                    var tensor = DeserializeTensor<bool>(reader, size, sizeof(bool), shape);
+                    namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
+                    break;
+                }
+                case TensorElementType.Float16: {
+                    var tensor = DeserializeTensor<Float16>(reader, size, sizeof(ushort), shape);
+                    namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
+                    break;
+                }
+                case TensorElementType.Double: {
+                    var tensor = DeserializeTensor<double>(reader, size, sizeof(double), shape);
+                    namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
+                    break;
+                }
+                case TensorElementType.UInt32: {
+                    var tensor = DeserializeTensor<uint>(reader, size, sizeof(uint), shape);
+                    namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
+                    break;
+                }
+                case TensorElementType.UInt64: {
+                    var tensor = DeserializeTensor<ulong>(reader, size, sizeof(ulong), shape);
+                    namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
+                    break;
+                }
+                case TensorElementType.BFloat16: {
+                    var tensor = DeserializeTensor<BFloat16>(reader, size, sizeof(ushort), shape);
+                    namedOnnxValue = NamedOnnxValue.CreateFromTensor(name, tensor);
+                    break;
+                }
+                case TensorElementType.Complex64:
+                case TensorElementType.Complex128:
+                case TensorElementType.DataTypeMax:
+                default:
+                    throw new NotSupportedException($"[TensorCache] Unsupported tensor element type: {dtype}.");
+            }
+
+            return namedOnnxValue;
+        }
+
+        private static Tensor<T> DeserializeTensor<T>(BinaryReader reader, int size, int typeSize, ReadOnlySpan<int> shape)
+        {
+            var bytes = reader.ReadBytes(size * typeSize);
+            var data = new T[size];
+            Buffer.BlockCopy(bytes, 0, data, 0, bytes.Length);
+            Tensor<T> tensor = new DenseTensor<T>(data, shape);
+            return tensor;
+        }
+    }
+}
diff --git a/OpenUtau.Core/DiffSinger/DiffSingerPitch.cs b/OpenUtau.Core/DiffSinger/DiffSingerPitch.cs
index 503c6cecb..7327717a9 100644
--- a/OpenUtau.Core/DiffSinger/DiffSingerPitch.cs
+++ b/OpenUtau.Core/DiffSinger/DiffSingerPitch.cs
@@ -4,6 +4,7 @@
 using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Text;
 
+using K4os.Hash.xxHash;
 using Microsoft.ML.OnnxRuntime;
 using Microsoft.ML.OnnxRuntime.Tensors;
@@ -19,6 +20,7 @@ public class DsPitch : IDisposable
         string rootPath;
         DsConfig dsConfig;
         List<string> phonemes;
+        ulong linguisticHash;
         InferenceSession linguisticModel;
         InferenceSession pitchModel;
         IG2p g2p;
@@ -39,7 +41,9 @@ public DsPitch(string rootPath)
             phonemes = File.ReadLines(phonemesPath, Encoding.UTF8).ToList();
             //Load models
             var linguisticModelPath = Path.Join(rootPath, dsConfig.linguistic);
-            linguisticModel = Onnx.getInferenceSession(linguisticModelPath);
+            var linguisticModelBytes = File.ReadAllBytes(linguisticModelPath);
+            linguisticHash = XXH64.DigestOf(linguisticModelBytes);
+            linguisticModel = Onnx.getInferenceSession(linguisticModelBytes);
             var pitchModelPath = Path.Join(rootPath, dsConfig.pitch);
             pitchModel = Onnx.getInferenceSession(pitchModelPath);
             frameMs = 1000f * dsConfig.hop_size / dsConfig.sample_rate;
@@ -123,7 +127,14 @@ public DsPitch(string rootPath)
             }
 
             Onnx.VerifyInputNames(linguisticModel, linguisticInputs);
-            var linguisticOutputs = linguisticModel.Run(linguisticInputs);
+            var linguisticCache = Preferences.Default.DiffSingerTensorCache
+                ? new DiffSingerCache(linguisticHash, linguisticInputs)
+                : null;
+            var linguisticOutputs = linguisticCache?.Load();
+            if (linguisticOutputs is null) {
+                linguisticOutputs = linguisticModel.Run(linguisticInputs).Cast<NamedOnnxValue>().ToList();
+                linguisticCache?.Save(linguisticOutputs);
+            }
             Tensor<float> encoder_out = linguisticOutputs
                 .Where(o => o.Name == "encoder_out")
                 .First()
diff --git a/OpenUtau.Core/DiffSinger/DiffSingerRenderer.cs b/OpenUtau.Core/DiffSinger/DiffSingerRenderer.cs
index dd71f66ee..dd2106940 100644
--- a/OpenUtau.Core/DiffSinger/DiffSingerRenderer.cs
+++ b/OpenUtau.Core/DiffSinger/DiffSingerRenderer.cs
@@ -12,6 +12,7 @@
 using OpenUtau.Core.Render;
 using OpenUtau.Core.SignalChain;
 using OpenUtau.Core.Ustx;
+using OpenUtau.Core.Util;
 using Serilog;
 
 namespace OpenUtau.Core.DiffSinger {
@@ -292,28 +293,40 @@ public class DiffSingerRenderer : IRenderer {
                         .Reshape(new int[] { 1, tension.Length })));
                 }
             }
-            Tensor<float> mel;
-            lock(acousticModel){
-                if(cancellation.IsCancellationRequested) {
-                    return null;
+            Onnx.VerifyInputNames(acousticModel, acousticInputs);
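+            // Acoustic outputs are cached per acoustic model hash; inference only runs
+            // (under the model lock) when no cached entry exists.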
+            var acousticCache = Preferences.Default.DiffSingerTensorCache
+                ? new DiffSingerCache(singer.acousticHash, acousticInputs)
+                : null;
+            var acousticOutputs = acousticCache?.Load();
+            if (acousticOutputs is null) {
+                lock(acousticModel){
+                    if(cancellation.IsCancellationRequested) {
+                        return null;
+                    }
+                    acousticOutputs = acousticModel.Run(acousticInputs).Cast<NamedOnnxValue>().ToList();
                 }
-                Onnx.VerifyInputNames(acousticModel, acousticInputs);
-                var acousticOutputs = acousticModel.Run(acousticInputs);
-                mel = acousticOutputs.First().AsTensor<float>().Clone();
+                acousticCache?.Save(acousticOutputs);
             }
+            Tensor<float> mel = acousticOutputs.First().AsTensor<float>().Clone();
             //vocoder
             //waveform = session.run(['waveform'], {'mel': mel, 'f0': f0})[0]
             var vocoderInputs = new List<NamedOnnxValue>();
             vocoderInputs.Add(NamedOnnxValue.CreateFromTensor("mel", mel));
             vocoderInputs.Add(NamedOnnxValue.CreateFromTensor("f0",f0tensor));
-            Tensor<float> samplesTensor;
-            lock(vocoder){
-                if(cancellation.IsCancellationRequested) {
-                    return null;
+            var vocoderCache = Preferences.Default.DiffSingerTensorCache
+                ? new DiffSingerCache(vocoder.hash, vocoderInputs)
+                : null;
+            var vocoderOutputs = vocoderCache?.Load();
+            if (vocoderOutputs is null) {
+                lock(vocoder){
+                    if(cancellation.IsCancellationRequested) {
+                        return null;
+                    }
+                    vocoderOutputs = vocoder.session.Run(vocoderInputs).Cast<NamedOnnxValue>().ToList();
                 }
-                var vocoderOutputs = vocoder.session.Run(vocoderInputs);
-                samplesTensor = vocoderOutputs.First().AsTensor<float>();
+                vocoderCache?.Save(vocoderOutputs);
             }
+            Tensor<float> samplesTensor = vocoderOutputs.First().AsTensor<float>();
             //Check the size of samplesTensor
             int[] expectedShape = new int[] { 1, -1 };
             if(!DiffSingerUtils.ValidateShape(samplesTensor, expectedShape)){
diff --git a/OpenUtau.Core/DiffSinger/DiffSingerSinger.cs b/OpenUtau.Core/DiffSinger/DiffSingerSinger.cs
index db02db87d..12742d963 100644
--- a/OpenUtau.Core/DiffSinger/DiffSingerSinger.cs
+++ b/OpenUtau.Core/DiffSinger/DiffSingerSinger.cs
@@ -3,6 +3,7 @@
 using System.IO;
 using System.Linq;
 using System.Text;
+using K4os.Hash.xxHash;
 using OpenUtau.Classic;
 using OpenUtau.Core.Ustx;
 using Serilog;
@@ -44,6 +45,7 @@ class DiffSingerSinger : USinger {
 
         public List<string> phonemes = new List<string>();
         public DsConfig dsConfig;
+        public ulong acousticHash;
         public InferenceSession acousticSession = null;
         public DsVocoder vocoder = null;
         public DsPitch pitchPredictor = null;
@@ -126,7 +128,10 @@ class DiffSingerSinger : USinger {
 
         public InferenceSession getAcousticSession() {
             if (acousticSession is null) {
-                acousticSession = Onnx.getInferenceSession(Path.Combine(Location, dsConfig.acoustic));
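+                // Read the model bytes once so the same buffer feeds both the cache hash and the session.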
+                var acousticPath = Path.Combine(Location, dsConfig.acoustic);
+                var acousticBytes = File.ReadAllBytes(acousticPath);
+                acousticHash = XXH64.DigestOf(acousticBytes);
+                acousticSession = Onnx.getInferenceSession(acousticBytes);
             }
             return acousticSession;
         }
diff --git a/OpenUtau.Core/DiffSinger/DiffSingerVariance.cs b/OpenUtau.Core/DiffSinger/DiffSingerVariance.cs
index 017d04ea2..82b89131b 100644
--- a/OpenUtau.Core/DiffSinger/DiffSingerVariance.cs
+++ b/OpenUtau.Core/DiffSinger/DiffSingerVariance.cs
@@ -3,7 +3,7 @@
 using System.IO;
 using System.Linq;
 using System.Text;
-
+using K4os.Hash.xxHash;
 using Serilog;
 using Microsoft.ML.OnnxRuntime;
 using Microsoft.ML.OnnxRuntime.Tensors;
@@ -23,6 +23,8 @@ public class DsVariance : IDisposable{
         string rootPath;
         DsConfig dsConfig;
         List<string> phonemes;
+        ulong linguisticHash;
+        ulong varianceHash;
         InferenceSession linguisticModel;
         InferenceSession varianceModel;
         IG2p g2p;
@@ -43,9 +45,13 @@ public DsVariance(string rootPath)
             phonemes = File.ReadLines(phonemesPath, Encoding.UTF8).ToList();
             //Load models
             var linguisticModelPath = Path.Join(rootPath, dsConfig.linguistic);
-            linguisticModel = Onnx.getInferenceSession(linguisticModelPath);
+            var linguisticModelBytes = File.ReadAllBytes(linguisticModelPath);
+            linguisticHash = XXH64.DigestOf(linguisticModelBytes);
+            linguisticModel = Onnx.getInferenceSession(linguisticModelBytes);
             var varianceModelPath = Path.Join(rootPath, dsConfig.variance);
-            varianceModel = Onnx.getInferenceSession(varianceModelPath);
+            var varianceModelBytes = File.ReadAllBytes(varianceModelPath);
+            varianceHash = XXH64.DigestOf(varianceModelBytes);
+            varianceModel = Onnx.getInferenceSession(varianceModelBytes);
             frameMs = 1000f * dsConfig.hop_size / dsConfig.sample_rate;
             //Load g2p
             g2p = LoadG2p(rootPath);
@@ -119,7 +125,14 @@ public DsVariance(string rootPath)
             }
 
             Onnx.VerifyInputNames(linguisticModel, linguisticInputs);
-            var linguisticOutputs = linguisticModel.Run(linguisticInputs);
+            var linguisticCache = Preferences.Default.DiffSingerTensorCache
+                ? new DiffSingerCache(linguisticHash, linguisticInputs)
+                : null;
+            var linguisticOutputs = linguisticCache?.Load();
+            if (linguisticOutputs is null) {
+                linguisticOutputs = linguisticModel.Run(linguisticInputs).Cast<NamedOnnxValue>().ToList();
+                linguisticCache?.Save(linguisticOutputs);
+            }
             Tensor<float> encoder_out = linguisticOutputs
                 .Where(o => o.Name == "encoder_out")
                 .First()
@@ -183,7 +196,14 @@ public DsVariance(string rootPath)
                 varianceInputs.Add(NamedOnnxValue.CreateFromTensor("spk_embed", spkEmbedTensor));
             }
             Onnx.VerifyInputNames(varianceModel, varianceInputs);
-            var varianceOutputs = varianceModel.Run(varianceInputs);
+            var varianceCache = Preferences.Default.DiffSingerTensorCache
+                ? new DiffSingerCache(varianceHash, varianceInputs)
+                : null;
+            var varianceOutputs = varianceCache?.Load();
+            if (varianceOutputs is null) {
+                varianceOutputs = varianceModel.Run(varianceInputs).Cast<NamedOnnxValue>().ToList();
+                varianceCache?.Save(varianceOutputs);
+            }
             Tensor<float>? energy_pred = dsConfig.predict_energy
                 ? varianceOutputs
                     .Where(o => o.Name == "energy_pred")
diff --git a/OpenUtau.Core/DiffSinger/DiffSingerVocoder.cs b/OpenUtau.Core/DiffSinger/DiffSingerVocoder.cs
index 1a5fc35e7..1b1188fd6 100644
--- a/OpenUtau.Core/DiffSinger/DiffSingerVocoder.cs
+++ b/OpenUtau.Core/DiffSinger/DiffSingerVocoder.cs
@@ -1,11 +1,13 @@
 using System;
 using System.IO;
+using K4os.Hash.xxHash;
 using Microsoft.ML.OnnxRuntime;
 
 namespace OpenUtau.Core.DiffSinger {
     public class DsVocoder : IDisposable {
         public string Location;
         public DsVocoderConfig config;
+        public ulong hash;
         public InferenceSession session;
 
         public int num_mel_bins => config.num_mel_bins;
@@ -25,6 +27,7 @@ public class DsVocoder : IDisposable {
             catch (Exception ex) {
                 throw new Exception($"Error loading vocoder {name}. Please download vocoder from https://github.com/xunmengshe/OpenUtau/wiki/Vocoders");
             }
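+            // Hash the vocoder model bytes so cached waveforms are invalidated when the vocoder changes.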
+            hash = XXH64.DigestOf(model);
             session = Onnx.getInferenceSession(model);
         }
 
diff --git a/OpenUtau.Core/Util/Preferences.cs b/OpenUtau.Core/Util/Preferences.cs
index 0fabcc1fb..991c3d03d 100644
--- a/OpenUtau.Core/Util/Preferences.cs
+++ b/OpenUtau.Core/Util/Preferences.cs
@@ -144,6 +144,7 @@ public class SerializablePreferences {
         public int OnnxGpu = 0;
         public int DiffsingerSpeedup = 50;
         public int DiffSingerDepth = 1000;
+        public bool DiffSingerTensorCache = true;
         public string Language = string.Empty;
         public string SortingOrder = string.Empty;
         public List<string> RecentFiles = new List<string>();
diff --git a/OpenUtau/Strings/Strings.axaml b/OpenUtau/Strings/Strings.axaml
index 2b51b8f0f..a05068c69 100644
--- a/OpenUtau/Strings/Strings.axaml
+++ b/OpenUtau/Strings/Strings.axaml
@@ -333,6 +333,7 @@ Warning: this option removes custom presets.
   Use track color in UI
   Cache
   Clear cache on quit
+  DiffSinger Tensor Cache
   Preferences
   Note: please restart OpenUtau after changing this item.
   Off
diff --git a/OpenUtau/Strings/Strings.zh-CN.axaml b/OpenUtau/Strings/Strings.zh-CN.axaml
index 94ac8bd55..250751713 100644
--- a/OpenUtau/Strings/Strings.zh-CN.axaml
+++ b/OpenUtau/Strings/Strings.zh-CN.axaml
@@ -318,6 +318,7 @@
   在界面上使用音轨颜色
   缓存
   退出时清空缓存
+  DiffSinger 张量缓存
   使用偏好
   注意: 修改本项后请重启OpenUtau
diff --git a/OpenUtau/ViewModels/PreferencesViewModel.cs b/OpenUtau/ViewModels/PreferencesViewModel.cs
index 67f1a41c2..ea155e403 100644
--- a/OpenUtau/ViewModels/PreferencesViewModel.cs
+++ b/OpenUtau/ViewModels/PreferencesViewModel.cs
@@ -42,6 +42,7 @@ public class PreferencesViewModel : ViewModelBase {
         public List<int> DiffsingerSpeedupOptions { get; } = new List<int> { 1, 5, 10, 20, 50, 100 };
         [Reactive] public int DiffSingerDepth { get; set; }
         [Reactive] public int DiffsingerSpeedup { get; set; }
+        [Reactive] public bool DiffSingerTensorCache { get; set; }
         [Reactive] public bool HighThreads { get; set; }
         [Reactive] public int Theme { get; set; }
         [Reactive] public bool PenPlusDefault { get; set; }
@@ -141,6 +142,7 @@ public class LyricsHelperOption {
             OnnxGpu = OnnxGpuOptions.FirstOrDefault(x => x.deviceId == Preferences.Default.OnnxGpu, OnnxGpuOptions[0]);
             DiffSingerDepth = Preferences.Default.DiffSingerDepth;
             DiffsingerSpeedup = Preferences.Default.DiffsingerSpeedup;
+            DiffSingerTensorCache = Preferences.Default.DiffSingerTensorCache;
             Theme = Preferences.Default.Theme;
             PenPlusDefault = Preferences.Default.PenPlusDefault;
             DegreeStyle = Preferences.Default.DegreeStyle;
@@ -334,6 +336,11 @@ public class LyricsHelperOption {
                     Preferences.Default.DiffSingerDepth = index;
                     Preferences.Save();
                 });
+            this.WhenAnyValue(vm => vm.DiffSingerTensorCache)
+                .Subscribe(useCache => {
+                    Preferences.Default.DiffSingerTensorCache = useCache;
+                    Preferences.Save();
+                });
         }
 
         public void TestAudioOutputDevice() {
diff --git a/OpenUtau/Views/PreferencesDialog.axaml b/OpenUtau/Views/PreferencesDialog.axaml
index b5b4f81eb..483ed30ee 100644
--- a/OpenUtau/Views/PreferencesDialog.axaml
+++ b/OpenUtau/Views/PreferencesDialog.axaml
@@ -97,10 +97,16 @@
- - - - + + + + + + + + + +