Skip to content

Commit

Permalink
add testcases
Browse files Browse the repository at this point in the history
  • Loading branch information
yanqingmen committed Jul 7, 2015
2 parents 4d382a8 + e99ab0d commit 0fc47f5
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 2 deletions.
6 changes: 4 additions & 2 deletions java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,16 @@ public Booster(Iterable<Entry<String, Object>> params, DMatrix[] dMatrixs) throw
* @throws org.dmlc.xgboost4j.util.XGBoostError
*/
public Booster(Iterable<Entry<String, Object>> params, String modelPath) throws XGBoostError {
long[] out = new long[1];
init(null);
if(modelPath == null) {
throw new NullPointerException("modelPath : null");
}
loadModel(modelPath);
setParam("seed","0");
setParams(params);
}




private void init(DMatrix[] dMatrixs) throws XGBoostError {
Expand Down
3 changes: 3 additions & 0 deletions java/xgboost4j/src/main/java/org/dmlc/xgboost4j/DMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ public static enum SparseType {
* @throws org.dmlc.xgboost4j.util.XGBoostError
*/
public DMatrix(String dataPath) throws XGBoostError {
if(dataPath == null) {
throw new NullPointerException("dataPath: null");
}
long[] out = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
handle = out[0];
Expand Down
108 changes: 108 additions & 0 deletions java/xgboost4j/src/test/java/org/dmlc/xgboost4j/BoosterTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import junit.framework.TestCase;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XGBoostError;
import org.junit.Test;

/**
* test cases for Booster
* @author hzx
*/
public class BoosterTest {
public static class EvalError implements IEvaluation {
private static final Log logger = LogFactory.getLog(EvalError.class);

String evalMetric = "custom_error";

public EvalError() {
}

@Override
public String getMetric() {
return evalMetric;
}

@Override
public float eval(float[][] predicts, DMatrix dmat) {
float error = 0f;
float[] labels;
try {
labels = dmat.getLabel();
} catch (XGBoostError ex) {
logger.error(ex);
return -1f;
}
int nrow = predicts.length;
for(int i=0; i<nrow; i++) {
if(labels[i]==0f && predicts[i][0]>0) {
error++;
}
else if(labels[i]==1f && predicts[i][0]<=0) {
error++;
}
}

return error/labels.length;
}
}

@Test
public void testBoosterBasic() throws XGBoostError {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

//set params
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("eta", 1.0);
put("max_depth", 2);
put("silent", 1);
put("objective", "binary:logistic");
}
};
Iterable<Entry<String, Object>> param = paramMap.entrySet();

//set watchList
List<Entry<String, DMatrix>> watchs = new ArrayList<>();
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));

//set round
int round = 2;

//train a boost model
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);

//predict raw output
float[][] predicts = booster.predict(testMat, true);

//eval
IEvaluation eval = new EvalError();
//error must be less than 0.1
TestCase.assertTrue(eval.eval(predicts, testMat)<0.1f);
}
}
102 changes: 102 additions & 0 deletions java/xgboost4j/src/test/java/org/dmlc/xgboost4j/DMatrixTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j;

import java.util.Arrays;
import java.util.Random;
import junit.framework.TestCase;
import org.dmlc.xgboost4j.util.XGBoostError;
import org.junit.Test;

/**
* test cases for DMatrix
* @author hzx
*/
public class DMatrixTest {

@Test
public void testCreateFromFile() throws XGBoostError {
//create DMatrix from file
DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test");
//get label
float[] labels = dmat.getLabel();
//check length
TestCase.assertTrue(dmat.rowNum()==labels.length);
//set weights
float[] weights = Arrays.copyOf(labels, labels.length);
dmat.setWeight(weights);
float[] dweights = dmat.getWeight();
TestCase.assertTrue(Arrays.equals(weights, dweights));
}

@Test
public void testCreateFromCSR() throws XGBoostError {
//create Matrix from csr format sparse Matrix and labels
/**
* sparse matrix
* 1 0 2 3 0
* 4 0 2 3 5
* 3 1 2 5 0
*/
float[] data = new float[] {1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
int[] colIndex = new int[] {0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
long[] rowHeaders = new long[] {0, 3, 7, 11};
DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
//check row num
System.out.println(dmat1.rowNum());
TestCase.assertTrue(dmat1.rowNum()==3);
//test set label
float[] label1 = new float[] {1, 0, 1};
dmat1.setLabel(label1);
float[] label2 = dmat1.getLabel();
TestCase.assertTrue(Arrays.equals(label1, label2));
}

@Test
public void testCreateFromDenseMatrix() throws XGBoostError {
//create DMatrix from 10*5 dense matrix
int nrow = 10;
int ncol = 5;
float[] data0 = new float[nrow*ncol];
//put random nums
Random random = new Random();
for(int i=0; i<nrow*ncol; i++) {
data0[i] = random.nextFloat();
}

//create label
float[] label0 = new float[nrow];
for(int i=0; i<nrow; i++) {
label0[i] = random.nextFloat();
}

DMatrix dmat0 = new DMatrix(data0, nrow, ncol);
dmat0.setLabel(label0);

//check
TestCase.assertTrue(dmat0.rowNum()==10);
TestCase.assertTrue(dmat0.getLabel().length==10);

//set weights for each instance
float[] weights = new float[nrow];
for(int i=0; i<nrow; i++) {
weights[i] = random.nextFloat();
}
dmat0.setWeight(weights);

TestCase.assertTrue(Arrays.equals(weights, dmat0.getWeight()));
}
}

0 comments on commit 0fc47f5

Please sign in to comment.