
Commit

All tests from Peng's page have been implemented and pass without errors.
sramirez committed Jun 1, 2017
1 parent ed272db commit 89c3b00
Showing 2 changed files with 68 additions and 46 deletions.
66 changes: 62 additions & 4 deletions src/test/scala/org/apache/spark/ml/feature/ITSelectorSuite.scala
@@ -8,7 +8,7 @@ import TestHelper._


/**
- * Test infomartion theoretic feature selection
+ * Test information theoretic feature selection on datasets from Peng's webpage
*
* @author Sergio Ramirez
*/
@@ -21,20 +21,78 @@ class ITSelectorSuite extends FunSuite with BeforeAndAfterAll {
sqlContext = new SQLContext(SPARK_CTX)
}

- /** Do entropy based binning of cars data from UC Irvine repository. */
+ /** Do mRMR feature selection on COLON data. */
test("Run ITFS on colon data (nPart = 10, nfeat = 10)") {

- val df = readColonData(sqlContext)
+ val df = readCSVData(sqlContext, "test_colon_s3.csv")
val cols = df.columns
val pad = 2
val allVectorsDense = true
- val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head, 10, 10, allVectorsDense, pad)
+ val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head,
+   10, 10, allVectorsDense, pad)

assertResult("512, 764, 1324, 1380, 1411, 1422, 1581, 1670, 1671, 1971") {
model.selectedFeatures.mkString(", ")
}
}

/** Do mRMR feature selection on LEUKEMIA data. */
test("Run ITFS on leukemia data (nPart = 10, nfeat = 10)") {

val df = readCSVData(sqlContext, "test_leukemia_s3.csv")
val cols = df.columns
val pad = 2
val allVectorsDense = true
val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head,
10, 10, allVectorsDense, pad)

assertResult("1084, 1719, 1774, 1822, 2061, 2294, 3192, 4387, 4787, 6795") {
model.selectedFeatures.mkString(", ")
}
}

/** Do mRMR feature selection on LUNG data. */
test("Run ITFS on lung data (nPart = 10, nfeat = 10)") {

val df = readCSVData(sqlContext, "test_lung_s3.csv")
val cols = df.columns
val pad = 2
val allVectorsDense = true
val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head,
10, 10, allVectorsDense, pad)

assertResult("18, 22, 29, 125, 132, 150, 166, 242, 243, 269") {
model.selectedFeatures.mkString(", ")
}
}

/** Do mRMR feature selection on LYMPHOMA data. */
test("Run ITFS on lymphoma data (nPart = 10, nfeat = 10)") {

val df = readCSVData(sqlContext, "test_lymphoma_s3.csv")
val cols = df.columns
val pad = 2
val allVectorsDense = true
val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head,
10, 10, allVectorsDense, pad)

assertResult("236, 393, 759, 2747, 2818, 2841, 2862, 3014, 3702, 3792") {
model.selectedFeatures.mkString(", ")
}
}

/** Do mRMR feature selection on NCI data. */
test("Run ITFS on nci data (nPart = 10, nfeat = 10)") {

val df = readCSVData(sqlContext, "test_nci9_s3.csv")
val cols = df.columns
val pad = 2
val allVectorsDense = true
val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head,
10, 10, allVectorsDense, pad)

assertResult("443, 755, 1369, 1699, 3483, 5641, 6290, 7674, 9399, 9576") {
model.selectedFeatures.mkString(", ")
}
}
}
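For context (not part of this commit), a minimal sketch of how the selector exercised by these tests might be configured without the TestHelper wrapper. The estimator name InfoThSelector and the setters setSelectCriterion, setNPartitions and setNumTopFeatures are assumed from this library's usual API and do not appear in the diff above; the real helper also indexes the label and applies padding, which this sketch omits. The suite lives in org.apache.spark.ml.feature, so the types resolve without extra imports.

// Sketch only: direct configuration of the selector used by the tests above.
// Assumption: InfoThSelector exposes setSelectCriterion / setNPartitions / setNumTopFeatures.
val raw = readCSVData(sqlContext, "test_colon_s3.csv")   // first column is the label
val assembled = new VectorAssembler()
  .setInputCols(raw.columns.drop(1))
  .setOutputCol("features")
  .transform(raw)

val selector = new InfoThSelector()
  .setSelectCriterion("mrmr")       // minimum Redundancy Maximum Relevance
  .setNPartitions(10)               // nPart = 10, as in the tests above
  .setNumTopFeatures(10)            // nfeat = 10, as in the tests above
  .setFeaturesCol("features")
  .setLabelCol(raw.columns.head)    // assumes the label is already numeric; the helper indexes it
  .setOutputCol("selectedFeatures")

val model = selector.fit(assembled)
println(model.selectedFeatures.mkString(", "))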
48 changes: 6 additions & 42 deletions src/test/scala/org/apache/spark/ml/feature/TestHelper.scala
@@ -29,15 +29,15 @@ object TestHelper {
final val INDEX_SUFFIX: String = "_IDX"

/**
- * @return the discretizer fit to the data given the specified features to bin and label use as target.
+ * @return the feature selector fit to the data, given the specified features and the label to use as target.
*/

def createSelectorModel(sqlContext: SQLContext, dataframe: Dataset[_], inputCols: Array[String],
labelColumn: String,
nPartitions: Int = 100,
numTopFeatures: Int = 20,
allVectorsDense: Boolean = true,
- padded: Int = 0): InfoThSelectorModel = {
+ padded: Int = 0 /* padding applied when the minimum value is negative */): InfoThSelectorModel = {
val featureAssembler = new VectorAssembler()
.setInputCols(inputCols)
.setOutputCol("features")
@@ -73,7 +73,7 @@

/**
* The label column will have null values replaced with MISSING values in this case.
- * @return the discretizer fit to the data given the specified features to bin and label use as target.
+ * @return the feature selector fit to the data, given the specified features and the label to use as target.
*/
def getSelectorModel(sqlContext: SQLContext, dataframe: DataFrame, inputCols: Array[String],
labelColumn: String,
@@ -121,53 +121,17 @@
sc
}

- /** @return standard iris dataset from UCI repo.
-   */
- /*def readColonData(sqlContext: SQLContext): DataFrame = {
-   val data = SPARK_CTX.textFile(FILE_PREFIX + "iris.data")
-   val nullable = true
-   val schema = (0 until 9712).map(i => StructField("var" + i, DoubleType, nullable)).toList :+
-     StructField("colontype", StringType, nullable)
-   // ints and dates must be read as doubles
-   val rows = data.map(line => line.split(",").map(elem => elem.trim))
-     .map(x => {Row.fromSeq(Seq(asDouble(x(0)), asDouble(x(1)), asDouble(x(2)), asDouble(x(3)), asString(x(4))))})
-   sqlContext.createDataFrame(rows, schema)
- }
- /** @return standard iris dataset from UCI repo.
+ /** @return standard csv data from the repo.
*/
- def readColonData2(sqlContext: SQLContext): DataFrame = {
-   val data = SPARK_CTX.textFile(FILE_PREFIX + "iris.data")
-   val nullable = true
-   val schema = StructType(List(
-     StructField("features", new VectorUDT, nullable),
-     StructField("class", DoubleType, nullable)
-   ))
-   val rows = data.map{line =>
-     val split = line.split(",").map(elem => elem.trim)
-     val features = Vectors.dense(split.drop(1).map(_.toDouble))
-     val label = split.head.toDouble
-     (features, label)
-   }
-   val asd = sqlContext.createDataFrame(rows, schema)
- }*/


- def readColonData(sqlContext: SQLContext): DataFrame = {
+ def readCSVData(sqlContext: SQLContext, file: String): DataFrame = {
val df = sqlContext.read
.format("com.databricks.spark.csv")
.option("header", "true") // Use first line of all files as header
.option("inferSchema", "true") // Automatically infer data types
- .load(FILE_PREFIX + "test_colon_s3.csv")
+ .load(FILE_PREFIX + file)
df
}
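A side note on readCSVData: it relies on the external com.databricks.spark.csv package. If the project is built against Spark 2.0 or later (not confirmed by this diff), the built-in CSV source could do the same job; a minimal sketch under that assumption:

// Sketch only, assuming a Spark 2.0+ build: the built-in csv source
// replaces the external com.databricks.spark.csv dependency.
def readCSVDataBuiltIn(sqlContext: SQLContext, file: String): DataFrame = {
  sqlContext.read
    .option("header", "true")      // use the first line as the header
    .option("inferSchema", "true") // infer column types automatically
    .csv(FILE_PREFIX + file)       // DataFrameReader.csv is available from Spark 2.0
}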




/** @return dataset with 3 double columns. The first is the label column and contains null values.
*/
def readNullLabelTestData(sqlContext: SQLContext): DataFrame = {
