Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/OpenCvSharp/Modules/ml/ANN_MLP.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.ComponentModel;

namespace OpenCvSharp.ML
{
Expand Down Expand Up @@ -265,6 +266,53 @@ public double RpropDWMax
#endregion

#region Methods

/// <summary>
/// Sets training method and common parameters.
/// </summary>
/// <param name="method">Default value is ANN_MLP::RPROP. See ANN_MLP::TrainingMethods.</param>
/// <param name="param1">passed to setRpropDW0 for ANN_MLP::RPROP and to setBackpropWeightScale for ANN_MLP::BACKPROP and to initialT for ANN_MLP::ANNEAL.</param>
/// <param name="param2">passed to setRpropDWMin for ANN_MLP::RPROP and to setBackpropMomentumScale for ANN_MLP::BACKPROP and to finalT for ANN_MLP::ANNEAL.</param>
public virtual void SetTrainMethod(TrainingMethods method, double param1 = 0, double param2 = 0)
{
if (!Enum.IsDefined(typeof(TrainingMethods), method))
throw new InvalidEnumArgumentException(nameof(method), (int)method, typeof(TrainingMethods));

NativeMethods.HandleException(
NativeMethods.ml_ANN_MLP_setTrainMethod(ptr, (int)method, param1, param2));

GC.KeepAlive(this);
}

/// <summary>
/// Returns current training method
/// </summary>
/// <returns></returns>
public virtual TrainingMethods GetTrainMethod()
{
NativeMethods.HandleException(
NativeMethods.ml_ANN_MLP_getTrainMethod(ptr, out var ret));
GC.KeepAlive(this);
return (TrainingMethods) ret;
}

/// <summary>
/// Initialize the activation function for each neuron.
/// Currently the default and the only fully supported activation function is ANN_MLP::SIGMOID_SYM.
/// </summary>
/// <param name="type">The type of activation function. See ANN_MLP::ActivationFunctions.</param>
/// <param name="param1">The first parameter of the activation function, \f$\alpha\f$. Default value is 0.</param>
/// <param name="param2">The second parameter of the activation function, \f$\beta\f$. Default value is 0.</param>
public virtual void SetActivationFunction(ActivationFunctions type, double param1 = 0, double param2 = 0)
{
if (!Enum.IsDefined(typeof(ActivationFunctions), type))
throw new InvalidEnumArgumentException(nameof(type), (int)type, typeof(ActivationFunctions));

NativeMethods.HandleException(
NativeMethods.ml_ANN_MLP_setActivationFunction(ptr, (int)type, param1, param2));

GC.KeepAlive(this);
}

/// <summary>
/// Integer vector specifying the number of neurons in each layer including the input and output layers.
Expand Down
58 changes: 58 additions & 0 deletions test/OpenCvSharp.Tests/ml/ANN_MLPTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
using System;
using OpenCvSharp.ML;
using Xunit;
using Xunit.Abstractions;

namespace OpenCvSharp.Tests.ML
{
public class ANN_MLPTest : TestBase
{
private readonly ITestOutputHelper testOutputHelper;

public ANN_MLPTest(ITestOutputHelper testOutputHelper)
{
this.testOutputHelper = testOutputHelper;
}

[Fact]
public void RunTest()
{
float[,] trainFeaturesData =
{
{0, 0},
{0, 100},
{100, 0},
{100, 100},
};
using var trainFeatures = new Mat(4, 2, MatType.CV_32F, trainFeaturesData);

float[] trainLabelsData = { 1, 0, 1, 0 };
using var trainLabels = new Mat(4, 1, MatType.CV_32F, trainLabelsData);

using var model = ANN_MLP.Create();
model.SetActivationFunction(ANN_MLP.ActivationFunctions.SigmoidSym, 0.1, 0.1);
model.SetTrainMethod(ANN_MLP.TrainingMethods.BackProp, 0.1, 0.1);
//model.TermCriteria = new TermCriteria(CriteriaType.MaxIter | CriteriaType.Eps, 10000, 0.0001);

using var layerSize = new Mat(3, 1, MatType.CV_32SC1);
layerSize.Set<int>(0, 2);
layerSize.Set<int>(1, 10);
layerSize.Set<int>(2, 1);
model.SetLayerSizes(layerSize);

bool trainSuccess = model.Train(trainFeatures, SampleTypes.RowSample, trainLabels);
Assert.True(trainSuccess);
Assert.True(model.IsTrained());

float[] testFeatureData = { 0, 0 };
using var testFeature = new Mat(1, 2, MatType.CV_32F, testFeatureData);

using var result = new Mat();
var detectedClass = model.Predict(testFeature, result);

// TODO
//Assert.Equal(-1, detectedClass);
}
}
}