Skip to content

Commit

Permalink
fixed error message to match function names in DataCutter and DataBal…
Browse files Browse the repository at this point in the history
…ancer (#256)
  • Loading branch information
leahmcguire committed Mar 29, 2019
1 parent 8e6e050 commit 7cd8aa2
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ case object DataBalancer {
class DataBalancer(uid: String = UID[DataBalancer]) extends Splitter(uid = uid) with DataBalancerParams {

@transient private lazy val log = LoggerFactory.getLogger(this.getClass)
@transient private[op] var summary: Option[DataBalancerSummary] = None

/**
* Computes the upSample and downSample proportions.
Expand Down Expand Up @@ -142,12 +141,12 @@ class DataBalancer(uid: String = UID[DataBalancer]) extends Splitter(uid = uid)
* @param data to prepare for model training. first column must be the label as a double
* @return balanced training set and a test set
*/
def validationPrepare(data: Dataset[Row]): Dataset[Row] = {
override def validationPrepare(data: Dataset[Row]): Dataset[Row] = {

if (summary.isEmpty) throw new RuntimeException("Cannot call prepare until examine has been called")
val dataPrep = super.validationPrepare(data)

val negativeData = data.filter(_.getDouble(0) == 0.0).persist()
val positiveData = data.filter(_.getDouble(0) == 1.0).persist()
val negativeData = dataPrep.filter(_.getDouble(0) == 0.0).persist()
val positiveData = dataPrep.filter(_.getDouble(0) == 1.0).persist()
val seed = getSeed

// If these conditions are met, that means that we have enough information to balance the data : upSample,
Expand All @@ -172,7 +171,7 @@ class DataBalancer(uid: String = UID[DataBalancer]) extends Splitter(uid = uid)
sampleBalancedData(
fraction = fraction,
seed = seed,
data = data,
data = dataPrep,
positiveData = positiveData,
negativeData = negativeData
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ class DataCutter(uid: String = UID[DataCutter]) extends Splitter(uid = uid) with

@transient private lazy val log = LoggerFactory.getLogger(this.getClass)

@transient private[op] var summary: Option[DataCutterSummary] = None

/**
* Function to set parameters before passing into the validation step
* eg - do data balancing or dropping based on the labels
Expand Down Expand Up @@ -104,11 +102,11 @@ class DataCutter(uid: String = UID[DataCutter]) extends Splitter(uid = uid) with
* @param data first column must be the label as a double
* @return Training set test set
*/
def validationPrepare(data: Dataset[Row]): Dataset[Row] = {
if (summary.isEmpty) throw new RuntimeException("Cannot call prepare until examine has been called")
override def validationPrepare(data: Dataset[Row]): Dataset[Row] = {
val dataPrep = super.validationPrepare(data)

val keep: Set[Double] = getLabelsToKeep.toSet
val dataUse = data.filter(r => keep.contains(r.getDouble(0)))
val dataUse = dataPrep.filter(r => keep.contains(r.getDouble(0)))

dataUse
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,24 +61,17 @@ case object DataSplitter {
*/
class DataSplitter(uid: String = UID[DataSplitter]) extends Splitter(uid = uid) {


/**
* Function to set parameters before passing into the validation step
* eg - do data balancing or dropping based on the labels
*
* @param data
* @return Parameters set in examining data
*/
override def preValidationPrepare(data: Dataset[Row]): Option[SplitterSummary] = Option(DataSplitterSummary())

/**
* Function to use to prepare the dataset for modeling
* eg - do data balancing or dropping based on the labels
*
* @param data
* @return Training set test set
*/
def validationPrepare(data: Dataset[Row]): Dataset[Row] = data
override def preValidationPrepare(data: Dataset[Row]): Option[SplitterSummary] = {
summary = Option(DataSplitterSummary())
summary
}

override def copy(extra: ParamMap): DataSplitter = {
val copy = new DataSplitter(uid)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ import scala.util.Try
*/
abstract class Splitter(val uid: String) extends SplitterParams {

@transient private[op] var summary: Option[SplitterSummary] = None

/**
* Function to use to create the training set and test set.
*
Expand All @@ -67,7 +69,10 @@ abstract class Splitter(val uid: String) extends SplitterParams {
* @param data
* @return Training set test set
*/
def validationPrepare(data: Dataset[Row]): Dataset[Row]
def validationPrepare(data: Dataset[Row]): Dataset[Row] = {
checkPreconditions()
data
}


/**
Expand All @@ -79,6 +84,10 @@ abstract class Splitter(val uid: String) extends SplitterParams {
*/
def preValidationPrepare(data: Dataset[Row]): Option[SplitterSummary]


protected def checkPreconditions(): Unit =
require(summary.nonEmpty, "Cannot call validationPrepare until preValidationPrepare has been called")

}

trait SplitterParams extends Params {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class DataBalancerTest extends FlatSpec with TestSparkContext with SplitterSumma
it should "throw an error if you try to prepare before examining" in {
val balancer = DataBalancer(sampleFraction = 0.1, maxTrainingSample = 2000, seed = 11L)
intercept[RuntimeException](balancer.validationPrepare(data)).getMessage shouldBe
"Cannot call prepare until examine has been called"
"requirement failed: Cannot call validationPrepare until preValidationPrepare has been called"
}

it should "remember that data is already balanced" in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class DataCutterTest extends FlatSpec with TestSparkContext with SplitterSummary
it should "throw an error when prepare is called before examine" in {
val dataCutter = DataCutter(seed = seed, minLabelFraction = 0.4)
intercept[RuntimeException](dataCutter.validationPrepare(randDF)).getMessage shouldBe
"Cannot call prepare until examine has been called"
"requirement failed: Cannot call validationPrepare until preValidationPrepare has been called"
}

it should "filter out all but the top N label categories" in {
Expand Down

0 comments on commit 7cd8aa2

Please sign in to comment.