diff --git a/src/OpenCvSharp/Modules/ml/ANN_MLP.cs b/src/OpenCvSharp/Modules/ml/ANN_MLP.cs index b14a79a7c..f86fa709d 100644 --- a/src/OpenCvSharp/Modules/ml/ANN_MLP.cs +++ b/src/OpenCvSharp/Modules/ml/ANN_MLP.cs @@ -1,4 +1,5 @@ using System; +using System.ComponentModel; namespace OpenCvSharp.ML { @@ -265,6 +266,53 @@ public double RpropDWMax #endregion #region Methods + + /// + /// Sets training method and common parameters. + /// + /// Default value is ANN_MLP::RPROP. See ANN_MLP::TrainingMethods. + /// passed to setRpropDW0 for ANN_MLP::RPROP and to setBackpropWeightScale for ANN_MLP::BACKPROP and to initialT for ANN_MLP::ANNEAL. + /// passed to setRpropDWMin for ANN_MLP::RPROP and to setBackpropMomentumScale for ANN_MLP::BACKPROP and to finalT for ANN_MLP::ANNEAL. + public virtual void SetTrainMethod(TrainingMethods method, double param1 = 0, double param2 = 0) + { + if (!Enum.IsDefined(typeof(TrainingMethods), method)) + throw new InvalidEnumArgumentException(nameof(method), (int)method, typeof(TrainingMethods)); + + NativeMethods.HandleException( + NativeMethods.ml_ANN_MLP_setTrainMethod(ptr, (int)method, param1, param2)); + + GC.KeepAlive(this); + } + + /// + /// Returns current training method + /// + /// + public virtual TrainingMethods GetTrainMethod() + { + NativeMethods.HandleException( + NativeMethods.ml_ANN_MLP_getTrainMethod(ptr, out var ret)); + GC.KeepAlive(this); + return (TrainingMethods) ret; + } + + /// + /// Initialize the activation function for each neuron. + /// Currently the default and the only fully supported activation function is ANN_MLP::SIGMOID_SYM. + /// + /// The type of activation function. See ANN_MLP::ActivationFunctions. + /// The first parameter of the activation function, \f$\alpha\f$. Default value is 0. + /// The second parameter of the activation function, \f$\beta\f$. Default value is 0. + public virtual void SetActivationFunction(ActivationFunctions type, double param1 = 0, double param2 = 0) + { + if (!Enum.IsDefined(typeof(ActivationFunctions), type)) + throw new InvalidEnumArgumentException(nameof(type), (int)type, typeof(ActivationFunctions)); + + NativeMethods.HandleException( + NativeMethods.ml_ANN_MLP_setActivationFunction(ptr, (int)type, param1, param2)); + + GC.KeepAlive(this); + } /// /// Integer vector specifying the number of neurons in each layer including the input and output layers. diff --git a/test/OpenCvSharp.Tests/ml/ANN_MLPTest.cs b/test/OpenCvSharp.Tests/ml/ANN_MLPTest.cs new file mode 100644 index 000000000..780f66817 --- /dev/null +++ b/test/OpenCvSharp.Tests/ml/ANN_MLPTest.cs @@ -0,0 +1,58 @@ +using System; +using OpenCvSharp.ML; +using Xunit; +using Xunit.Abstractions; + +namespace OpenCvSharp.Tests.ML +{ + public class ANN_MLPTest : TestBase + { + private readonly ITestOutputHelper testOutputHelper; + + public ANN_MLPTest(ITestOutputHelper testOutputHelper) + { + this.testOutputHelper = testOutputHelper; + } + + [Fact] + public void RunTest() + { + float[,] trainFeaturesData = + { + {0, 0}, + {0, 100}, + {100, 0}, + {100, 100}, + }; + using var trainFeatures = new Mat(4, 2, MatType.CV_32F, trainFeaturesData); + + float[] trainLabelsData = { 1, 0, 1, 0 }; + using var trainLabels = new Mat(4, 1, MatType.CV_32F, trainLabelsData); + + using var model = ANN_MLP.Create(); + model.SetActivationFunction(ANN_MLP.ActivationFunctions.SigmoidSym, 0.1, 0.1); + model.SetTrainMethod(ANN_MLP.TrainingMethods.BackProp, 0.1, 0.1); + //model.TermCriteria = new TermCriteria(CriteriaType.MaxIter | CriteriaType.Eps, 10000, 0.0001); + + using var layerSize = new Mat(3, 1, MatType.CV_32SC1); + layerSize.Set(0, 2); + layerSize.Set(1, 10); + layerSize.Set(2, 1); + model.SetLayerSizes(layerSize); + + bool trainSuccess = model.Train(trainFeatures, SampleTypes.RowSample, trainLabels); + Assert.True(trainSuccess); + Assert.True(model.IsTrained()); + + float[] testFeatureData = { 0, 0 }; + using var testFeature = new Mat(1, 2, MatType.CV_32F, testFeatureData); + + using var result = new Mat(); + var detectedClass = model.Predict(testFeature, result); + + // TODO + //Assert.Equal(-1, detectedClass); + } + } +} +