diff --git a/src/OpenCvSharp/Modules/ml/ANN_MLP.cs b/src/OpenCvSharp/Modules/ml/ANN_MLP.cs
index b14a79a7c..f86fa709d 100644
--- a/src/OpenCvSharp/Modules/ml/ANN_MLP.cs
+++ b/src/OpenCvSharp/Modules/ml/ANN_MLP.cs
@@ -1,4 +1,5 @@
using System;
+using System.ComponentModel;
namespace OpenCvSharp.ML
{
@@ -265,6 +266,53 @@ public double RpropDWMax
#endregion
#region Methods
+
+ ///
+ /// Sets training method and common parameters.
+ ///
+ /// Default value is ANN_MLP::RPROP. See ANN_MLP::TrainingMethods.
+ /// passed to setRpropDW0 for ANN_MLP::RPROP and to setBackpropWeightScale for ANN_MLP::BACKPROP and to initialT for ANN_MLP::ANNEAL.
+ /// passed to setRpropDWMin for ANN_MLP::RPROP and to setBackpropMomentumScale for ANN_MLP::BACKPROP and to finalT for ANN_MLP::ANNEAL.
+ public virtual void SetTrainMethod(TrainingMethods method, double param1 = 0, double param2 = 0)
+ {
+ if (!Enum.IsDefined(typeof(TrainingMethods), method))
+ throw new InvalidEnumArgumentException(nameof(method), (int)method, typeof(TrainingMethods));
+
+ NativeMethods.HandleException(
+ NativeMethods.ml_ANN_MLP_setTrainMethod(ptr, (int)method, param1, param2));
+
+ GC.KeepAlive(this);
+ }
+
+ ///
+ /// Returns current training method
+ ///
+ ///
+ public virtual TrainingMethods GetTrainMethod()
+ {
+ NativeMethods.HandleException(
+ NativeMethods.ml_ANN_MLP_getTrainMethod(ptr, out var ret));
+ GC.KeepAlive(this);
+ return (TrainingMethods) ret;
+ }
+
+ ///
+ /// Initialize the activation function for each neuron.
+ /// Currently the default and the only fully supported activation function is ANN_MLP::SIGMOID_SYM.
+ ///
+ /// The type of activation function. See ANN_MLP::ActivationFunctions.
+ /// The first parameter of the activation function, \f$\alpha\f$. Default value is 0.
+ /// The second parameter of the activation function, \f$\beta\f$. Default value is 0.
+ public virtual void SetActivationFunction(ActivationFunctions type, double param1 = 0, double param2 = 0)
+ {
+ if (!Enum.IsDefined(typeof(ActivationFunctions), type))
+ throw new InvalidEnumArgumentException(nameof(type), (int)type, typeof(ActivationFunctions));
+
+ NativeMethods.HandleException(
+ NativeMethods.ml_ANN_MLP_setActivationFunction(ptr, (int)type, param1, param2));
+
+ GC.KeepAlive(this);
+ }
///
/// Integer vector specifying the number of neurons in each layer including the input and output layers.
diff --git a/test/OpenCvSharp.Tests/ml/ANN_MLPTest.cs b/test/OpenCvSharp.Tests/ml/ANN_MLPTest.cs
new file mode 100644
index 000000000..780f66817
--- /dev/null
+++ b/test/OpenCvSharp.Tests/ml/ANN_MLPTest.cs
@@ -0,0 +1,58 @@
+using System;
+using OpenCvSharp.ML;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace OpenCvSharp.Tests.ML
+{
+ public class ANN_MLPTest : TestBase
+ {
+ private readonly ITestOutputHelper testOutputHelper;
+
+ public ANN_MLPTest(ITestOutputHelper testOutputHelper)
+ {
+ this.testOutputHelper = testOutputHelper;
+ }
+
+ [Fact]
+ public void RunTest()
+ {
+ float[,] trainFeaturesData =
+ {
+ {0, 0},
+ {0, 100},
+ {100, 0},
+ {100, 100},
+ };
+ using var trainFeatures = new Mat(4, 2, MatType.CV_32F, trainFeaturesData);
+
+ float[] trainLabelsData = { 1, 0, 1, 0 };
+ using var trainLabels = new Mat(4, 1, MatType.CV_32F, trainLabelsData);
+
+ using var model = ANN_MLP.Create();
+ model.SetActivationFunction(ANN_MLP.ActivationFunctions.SigmoidSym, 0.1, 0.1);
+ model.SetTrainMethod(ANN_MLP.TrainingMethods.BackProp, 0.1, 0.1);
+ //model.TermCriteria = new TermCriteria(CriteriaType.MaxIter | CriteriaType.Eps, 10000, 0.0001);
+
+ using var layerSize = new Mat(3, 1, MatType.CV_32SC1);
+ layerSize.Set(0, 2);
+ layerSize.Set(1, 10);
+ layerSize.Set(2, 1);
+ model.SetLayerSizes(layerSize);
+
+ bool trainSuccess = model.Train(trainFeatures, SampleTypes.RowSample, trainLabels);
+ Assert.True(trainSuccess);
+ Assert.True(model.IsTrained());
+
+ float[] testFeatureData = { 0, 0 };
+ using var testFeature = new Mat(1, 2, MatType.CV_32F, testFeatureData);
+
+ using var result = new Mat();
+ var detectedClass = model.Predict(testFeature, result);
+
+ // TODO
+ //Assert.Equal(-1, detectedClass);
+ }
+ }
+}
+