-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add basic common classes for generic ML library.
- Loading branch information
Juergen Reuter
committed
Mar 10, 2020
1 parent
33a4361
commit aeeff0a
Showing
22 changed files
with
1,510 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
four-in-a-row/java/org/soundpaint/ml/common/AbstractCostFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
/* | ||
* @(#)AbstractCostFunction.java 1.00 20/03/06 | ||
* | ||
* Copyright (C) 2020 Jürgen Reuter | ||
* | ||
* This program is free software: you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation, either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* This program is distributed in the hope that it will be useful, | ||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
* GNU General Public License for more details. | ||
* | ||
* You should have received a copy of the GNU General Public License | ||
* along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
*/ | ||
package org.soundpaint.ml.common; | ||
|
||
import java.util.List; | ||
|
||
public abstract class AbstractCostFunction implements CostFunction | ||
{ | ||
private final String id; | ||
protected final List<Double> targetValues; | ||
protected final int size; | ||
protected final double reverseSize; | ||
|
||
private AbstractCostFunction() | ||
{ | ||
throw new UnsupportedOperationException("unsupported default constructor"); | ||
} | ||
|
||
public AbstractCostFunction(final String id, final List<Double> targetValues) | ||
{ | ||
if (id == null) { throw new NullPointerException("id"); } | ||
if (targetValues == null) { | ||
throw new NullPointerException("targetValues"); | ||
} | ||
if (targetValues.size() == 0) { | ||
throw new IllegalArgumentException("empty targetValues"); | ||
} | ||
for (final Double targetValue : targetValues) { | ||
if (targetValue == null) { | ||
throw new IllegalArgumentException("list of target values contains null"); | ||
} | ||
} | ||
this.id = id; | ||
this.targetValues = targetValues; | ||
size = targetValues.size(); | ||
reverseSize = 1.0 / size; | ||
} | ||
|
||
public String getId() | ||
{ | ||
return id; | ||
} | ||
|
||
public String toString() | ||
{ | ||
return "cost function " + id; | ||
} | ||
} | ||
|
||
/* | ||
* Local Variables: | ||
* coding:utf-8 | ||
* mode:Java | ||
* End: | ||
*/ |
75 changes: 75 additions & 0 deletions
75
four-in-a-row/java/org/soundpaint/ml/common/ActivationFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* | ||
* @(#)ActivationFunction.java 1.00 20/03/04 | ||
* | ||
* Copyright (C) 2020 Jürgen Reuter | ||
* | ||
* This program is free software: you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation, either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* This program is distributed in the hope that it will be useful, | ||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
* GNU General Public License for more details. | ||
* | ||
* You should have received a copy of the GNU General Public License | ||
* along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
*/ | ||
package org.soundpaint.ml.common; | ||
|
||
import java.util.function.Function; | ||
|
||
public enum ActivationFunction implements Function<Double, Double> | ||
{ | ||
THRESHOLD("threshold") { | ||
public Double apply(final Double input) { | ||
return 1.0 / (1.0 + Math.exp(-input)); | ||
} | ||
}, | ||
|
||
SIGMOID("sigmoid") { | ||
public Double apply(final Double input) { | ||
return 1.0 / (1.0 + Math.exp(-input)); | ||
} | ||
}, | ||
|
||
HYPERBOLIC_TANGENT("hyperbolic tangent") { | ||
public Double apply(final Double input) { | ||
return Math.tanh(input); | ||
} | ||
}, | ||
|
||
RECTIFIED_LINEAR_UNIT("rectified linear unit") { | ||
public Double apply(final Double input) { | ||
return input < 0.0 ? 0.0 : input; | ||
} | ||
}; | ||
|
||
private final String id; | ||
|
||
private ActivationFunction() { | ||
throw new UnsupportedOperationException("unsupported default constructor"); | ||
} | ||
|
||
private ActivationFunction(final String id) { | ||
this.id = id; | ||
} | ||
|
||
public String getId() | ||
{ | ||
return id; | ||
} | ||
|
||
public String toString() | ||
{ | ||
return "activation function " + id; | ||
} | ||
} | ||
|
||
/* | ||
* Local Variables: | ||
* coding:utf-8 | ||
* mode:Java | ||
* End: | ||
*/ |
66 changes: 66 additions & 0 deletions
66
four-in-a-row/java/org/soundpaint/ml/common/ActivationOperation.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/* | ||
* @(#)ActivationOperation.java 1.00 20/03/09 | ||
* | ||
* Copyright (C) 2020 Jürgen Reuter | ||
* | ||
* This program is free software: you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation, either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* This program is distributed in the hope that it will be useful, | ||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
* GNU General Public License for more details. | ||
* | ||
* You should have received a copy of the GNU General Public License | ||
* along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
*/ | ||
package org.soundpaint.ml.common; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
public class ActivationOperation extends Operation<Double, Double> | ||
{ | ||
private final ActivationFunction function; | ||
private final List<Double> inputValues; | ||
|
||
private static String getFunctionId(final ActivationFunction function) | ||
{ | ||
if (function == null) { | ||
throw new NullPointerException("function"); | ||
} | ||
return function.getId() + "op"; | ||
} | ||
|
||
public ActivationOperation(final ActivationFunction function, | ||
final Node<Double, Double> x) | ||
{ | ||
super(getFunctionId(function), | ||
new ArrayList<Node<Double, Double>>(List.of(x))); | ||
this.function = function; | ||
inputValues = new ArrayList<Double>(); | ||
} | ||
|
||
public Double compute(final List<Double> operands) | ||
{ | ||
if (operands == null) { | ||
throw new NullPointerException("operands"); | ||
} | ||
if (operands.size() != 1) { | ||
throw new IllegalArgumentException("require 1 operand, got: " + | ||
operands.size()); | ||
} | ||
inputValues.clear(); | ||
inputValues.add(operands.get(0)); | ||
return function.apply(operands.get(0)); | ||
} | ||
} | ||
|
||
/* | ||
* Local Variables: | ||
* coding:utf-8 | ||
* mode:Java | ||
* End: | ||
*/ |
56 changes: 56 additions & 0 deletions
56
four-in-a-row/java/org/soundpaint/ml/common/AddOperation.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
/* | ||
* @(#)AddOperation.java 1.00 20/03/08 | ||
* | ||
* Copyright (C) 2020 Jürgen Reuter | ||
* | ||
* This program is free software: you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation, either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* This program is distributed in the hope that it will be useful, | ||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
* GNU General Public License for more details. | ||
* | ||
* You should have received a copy of the GNU General Public License | ||
* along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
*/ | ||
package org.soundpaint.ml.common; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
public class AddOperation extends Operation<Double, Double> | ||
{ | ||
private final List<Double> inputValues; | ||
|
||
public AddOperation(final Node<Double, Double> x, | ||
final Node<Double, Double> y) | ||
{ | ||
super("addop", new ArrayList<Node<Double, Double>>(List.of(x, y))); | ||
inputValues = new ArrayList<Double>(); | ||
} | ||
|
||
public Double compute(final List<Double> operands) | ||
{ | ||
if (operands == null) { | ||
throw new NullPointerException("operands"); | ||
} | ||
if (operands.size() != 2) { | ||
throw new IllegalArgumentException("require 2 operands, got: " + | ||
operands.size()); | ||
} | ||
inputValues.clear(); | ||
inputValues.add(operands.get(0)); | ||
inputValues.add(operands.get(1)); | ||
return operands.get(0) + operands.get(1); | ||
} | ||
} | ||
|
||
/* | ||
* Local Variables: | ||
* coding:utf-8 | ||
* mode:Java | ||
* End: | ||
*/ |
34 changes: 34 additions & 0 deletions
34
four-in-a-row/java/org/soundpaint/ml/common/CostFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/* | ||
* @(#)CostFunction.java 1.00 20/03/06 | ||
* | ||
* Copyright (C) 2020 Jürgen Reuter | ||
* | ||
* This program is free software: you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation, either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* This program is distributed in the hope that it will be useful, | ||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
* GNU General Public License for more details. | ||
* | ||
* You should have received a copy of the GNU General Public License | ||
* along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
*/ | ||
package org.soundpaint.ml.common; | ||
|
||
import java.util.List; | ||
import java.util.function.Function; | ||
|
||
public interface CostFunction extends Function<List<Perceptron>, Double> | ||
{ | ||
String getId(); | ||
} | ||
|
||
/* | ||
* Local Variables: | ||
* coding:utf-8 | ||
* mode:Java | ||
* End: | ||
*/ |
53 changes: 53 additions & 0 deletions
53
four-in-a-row/java/org/soundpaint/ml/common/CrossEntropyCostFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
/* | ||
* @(#)CrossEntropyFunction.java 1.00 20/03/06 | ||
* | ||
* Copyright (C) 2020 Jürgen Reuter | ||
* | ||
* This program is free software: you can redistribute it and/or modify | ||
* it under the terms of the GNU General Public License as published by | ||
* the Free Software Foundation, either version 3 of the License, or | ||
* (at your option) any later version. | ||
* | ||
* This program is distributed in the hope that it will be useful, | ||
* but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
* GNU General Public License for more details. | ||
* | ||
* You should have received a copy of the GNU General Public License | ||
* along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
*/ | ||
package org.soundpaint.ml.common; | ||
|
||
import java.util.List; | ||
|
||
public class CrossEntropyCostFunction extends AbstractCostFunction | ||
{ | ||
public CrossEntropyCostFunction(final String id, | ||
final List<Double> targetValues) | ||
{ | ||
super(id, targetValues); | ||
} | ||
|
||
public Double apply(final List<Perceptron> outputLayerPerceptrons) { | ||
if (outputLayerPerceptrons.size() != size) { | ||
throw new IllegalArgumentException("vector size mismatch:" + | ||
size + " != " + | ||
outputLayerPerceptrons.size()); | ||
} | ||
double cost = 0.0; | ||
for (int index = 0; index < size; index++) { | ||
final Perceptron perceptron = outputLayerPerceptrons.get(index); | ||
final double y = targetValues.get(index); | ||
final double a = perceptron.getOutput(); | ||
cost += y * Math.log(a) + (1.0 - y) * Math.log(1.0 - a); | ||
} | ||
return - cost * reverseSize; | ||
} | ||
} | ||
|
||
/* | ||
* Local Variables: | ||
* coding:utf-8 | ||
* mode:Java | ||
* End: | ||
*/ |
Oops, something went wrong.