diff --git a/lib/weka-stable-3.8.5.jar b/lib/weka-stable-3.8.5.jar
new file mode 100644
index 0000000..ccae073
Binary files /dev/null and b/lib/weka-stable-3.8.5.jar differ
diff --git a/pom.xml b/pom.xml
index 37f186e..de344f9 100644
--- a/pom.xml
+++ b/pom.xml
@@ -12,7 +12,8 @@
     1.8
     1.8
     8.2.0.0-SNAPSHOT
-    <weka.version>3.8.3.1</weka.version>
+
+    <weka.version>3.8.5</weka.version>
     1.0.25
     1.0.5
     1.1.3.4.0
@@ -74,15 +75,21 @@
       ${weka.version}
     -->
+
+
-      <groupId>pdm-ce</groupId>
-      <artifactId>pdm-ce</artifactId>
+      <groupId>nz.ac.waikato.cms.weka</groupId>
+      <artifactId>weka-stable</artifactId>
       <version>${weka.version}</version>
       <scope>system</scope>
-      <systemPath>${basedir}/lib/pdm-ce-${weka.version}.jar</systemPath>
+      <systemPath>${basedir}/lib/weka-stable-${weka.version}.jar</systemPath>
diff --git a/src/main/java/org/pentaho/di/trans/steps/pmi/BaseSupervisedPMIStepData.java b/src/main/java/org/pentaho/di/trans/steps/pmi/BaseSupervisedPMIStepData.java
index 265e011..6476537 100644
--- a/src/main/java/org/pentaho/di/trans/steps/pmi/BaseSupervisedPMIStepData.java
+++ b/src/main/java/org/pentaho/di/trans/steps/pmi/BaseSupervisedPMIStepData.java
@@ -1237,9 +1237,10 @@ protected static void establishOutputRowMeta( RowMetaInterface outRowMeta, Varia
     if ( stepMeta.getOutputIRMetrics() ) {
       String classLabels = classArffMeta.getNominalVals();
       if ( !Const.isEmpty( classLabels ) ) {
-        TreeSet<String> ts = new TreeSet<>( ArffMeta.stringToVals( classLabels ) );
+        // TreeSet<String> ts = new TreeSet<>( ArffMeta.stringToVals( classLabels ) );
+        ArrayList<String> preOrdered = new ArrayList<>( ArffMeta.stringToVals( classLabels ) );
         //String[] labels = classLabels.split( "," );
-        for ( String label : ts ) {
+        for ( String label : preOrdered ) {
           label = label.trim();

           vm = ValueMetaFactory
@@ -1283,9 +1284,10 @@ protected static void establishOutputRowMeta( RowMetaInterface outRowMeta, Varia
     if ( stepMeta.getOutputAUCMetrics() ) {
       String classLabels = classArffMeta.getNominalVals();
       if ( !Const.isEmpty( classLabels ) ) {
-        TreeSet<String> ts = new TreeSet<>( ArffMeta.stringToVals( classLabels ) );
+        //TreeSet<String> ts = new TreeSet<>( ArffMeta.stringToVals( classLabels ) );
         // String[] labels = classLabels.split( "," );
-        for ( String label : ts ) {
+        ArrayList<String> preOrdered = new ArrayList<>( ArffMeta.stringToVals( classLabels ) );
+        for ( String label : preOrdered ) {
           label = label.trim();

           vm =
diff --git a/src/main/java/org/pentaho/di/trans/steps/pmi/PMILifecycleListener.java b/src/main/java/org/pentaho/di/trans/steps/pmi/PMILifecycleListener.java
index c3743a7..a728e18 100644
--- a/src/main/java/org/pentaho/di/trans/steps/pmi/PMILifecycleListener.java
+++ b/src/main/java/org/pentaho/di/trans/steps/pmi/PMILifecycleListener.java
@@ -42,7 +42,8 @@ public class PMILifecycleListener implements KettleLifecycleListener {

     // TODO replace this by some code that somehow locates the pdm jar file in plugins/steps/pmi/lib
     // This allows the Spark engine to locate the main weka.jar file for use in the Spark execution environment
-    System.setProperty( "weka.jar.filename", "pdm-ce-3.8.3.1.jar" );
+    //System.setProperty( "weka.jar.filename", "pdm-ce-3.8.3.1.jar" );
+    System.setProperty( "weka.jar.filename", "weka-stable-3.8.5.jar" );

     // check that the required packages are installed (and possibly install if not)
     try {
diff --git a/src/main/java/org/pentaho/di/trans/steps/pmi/PMIScoring.java b/src/main/java/org/pentaho/di/trans/steps/pmi/PMIScoring.java
index df3f5fb..daba087 100644
--- a/src/main/java/org/pentaho/di/trans/steps/pmi/PMIScoring.java
+++ b/src/main/java/org/pentaho/di/trans/steps/pmi/PMIScoring.java
@@ -167,84 +167,15 @@ private PMIScoringModel setModel( String modelFileName ) throws KettleException

     Object[] r = getRow();

-    if ( r == null ) {
-      if ( !m_meta.getEvaluateRatherThanScore() && m_data.getModel().isBatchPredictor() && !m_meta
-          .getFileNameFromField() && m_batch.size() > 0 ) {
-        try {
-          outputBatchRows( true );
-        } catch ( Exception ex ) {
-          throw new KettleException(
-              BaseMessages.getString( PMIScoringMeta.PKG, "PMIScoring.Error.ProblemWhileGettingPredictionsForBatch" ),
-              ex ); //$NON-NLS-1$
-        }
-      }
-
-      if ( m_meta.getEvaluateRatherThanScore() && m_data.getModel().isSupervisedLearningModel() ) {
-        // generate the output row
-        try {
-          if ( m_data.getModel().isBatchPredictor() ) {
-            outputBatchRows( true );
-          } else {
-            Object[] outputRow = m_data.evaluateForRow( getInputRowMeta(), m_data.getOutputRowMeta(), null, m_meta );
-            putRow( m_data.getOutputRowMeta(), outputRow );
-          }
-        } catch ( Exception ex ) {
-          throw new KettleException(
-              BaseMessages.getString( PMIScoringMeta.PKG, "PMIScoring.Error.ProblemWhileGettingPredictionsForBatch" ),
-              ex ); //$NON-NLS-1$
-        }
-      }
-
-      // see if we have an incremental model that is to be saved somewhere.
-      if ( !m_meta.getFileNameFromField() && m_meta.getUpdateIncrementalModel() ) {
-        if ( !Const.isEmpty( m_meta.getSavedModelFileName() ) ) {
-          // try and save that sucker...
-          try {
-            String modName = environmentSubstitute( m_meta.getSavedModelFileName() );
-            File updatedModelFile = null;
-            if ( modName.startsWith( "file:" ) ) {
-              try {
-                modName = modName.replace( " ", "%20" );
-                updatedModelFile = new File( new java.net.URI( modName ) );
-              } catch ( Exception ex ) {
-                throw new KettleException(
-                    BaseMessages.getString( PMIScoringMeta.PKG, "PMIScoring.Error.MalformedURIForUpdatedModelFile" ),
-                    ex );
-              }
-            } else {
-              updatedModelFile = new File( modName );
-            }
-            PMIScoringData.saveSerializedModel( m_data.getModel(), updatedModelFile );
-          } catch ( Exception ex ) {
-            throw new KettleException(
-                BaseMessages.getString( PMIScoringMeta.PKG, "PMIScoring.Error.ProblemSavingUpdatedModelToFile" ),
-                ex ); //$NON-NLS-1$
-          }
-        }
-      }
-
-      if ( m_meta.getFileNameFromField() ) {
-        // clear the main model
-        m_data.getModel().done();
-        m_data.setModel( null );
-        m_data.setDefaultModel( null );
-        if ( m_modelCache != null ) {
-          m_modelCache.clear();
-        }
-      } else {
-        m_data.getModel().done();
-        m_data.setModel( null );
-        m_data.setDefaultModel( null );
-      }
-
-      setOutputDone();
-      return false;
-    }
-
     // Handle the first row
     if ( first ) {
       first = false;
+      if (r == null) {
+        setOutputDone();
+        return false;
+      }
+
       m_data.setOutputRowMeta( getInputRowMeta().clone() );
       if ( m_meta.getFileNameFromField() ) {
         RowMetaInterface inputRowMeta = getInputRowMeta();
@@ -378,6 +309,80 @@ private PMIScoringModel setModel( String modelFileName ) throws KettleException
       }
     } // end (if first)

+    if ( r == null ) {
+      if ( !m_meta.getEvaluateRatherThanScore() && m_data.getModel().isBatchPredictor() && !m_meta
+          .getFileNameFromField() && m_batch.size() > 0 ) {
+        try {
+          outputBatchRows( true );
+        } catch ( Exception ex ) {
+          throw new KettleException(
+              BaseMessages.getString( PMIScoringMeta.PKG, "PMIScoring.Error.ProblemWhileGettingPredictionsForBatch" ),
+              ex ); //$NON-NLS-1$
+        }
+      }
+
+      if ( m_meta.getEvaluateRatherThanScore() && m_data.getModel().isSupervisedLearningModel() ) {
+        // generate the output row
+        try {
+          if ( m_data.getModel().isBatchPredictor() ) {
+            outputBatchRows( true );
+          } else {
+            Object[] outputRow = m_data.evaluateForRow( getInputRowMeta(), m_data.getOutputRowMeta(), null, m_meta );
+            putRow( m_data.getOutputRowMeta(), outputRow );
+          }
+        } catch ( Exception ex ) {
+          throw new KettleException(
+              BaseMessages.getString( PMIScoringMeta.PKG, "PMIScoring.Error.ProblemWhileGettingPredictionsForBatch" ),
+              ex ); //$NON-NLS-1$
+        }
+      }
+
+      // see if we have an incremental model that is to be saved somewhere.
+      if ( !m_meta.getFileNameFromField() && m_meta.getUpdateIncrementalModel() ) {
+        if ( !Const.isEmpty( m_meta.getSavedModelFileName() ) ) {
+          // try and save that sucker...
+          try {
+            String modName = environmentSubstitute( m_meta.getSavedModelFileName() );
+            File updatedModelFile = null;
+            if ( modName.startsWith( "file:" ) ) {
+              try {
+                modName = modName.replace( " ", "%20" );
+                updatedModelFile = new File( new java.net.URI( modName ) );
+              } catch ( Exception ex ) {
+                throw new KettleException(
+                    BaseMessages.getString( PMIScoringMeta.PKG, "PMIScoring.Error.MalformedURIForUpdatedModelFile" ),
+                    ex );
+              }
+            } else {
+              updatedModelFile = new File( modName );
+            }
+            PMIScoringData.saveSerializedModel( m_data.getModel(), updatedModelFile );
+          } catch ( Exception ex ) {
+            throw new KettleException(
+                BaseMessages.getString( PMIScoringMeta.PKG, "PMIScoring.Error.ProblemSavingUpdatedModelToFile" ),
+                ex ); //$NON-NLS-1$
+          }
+        }
+      }
+
+      if ( m_meta.getFileNameFromField() ) {
+        // clear the main model
+        m_data.getModel().done();
+        m_data.setModel( null );
+        m_data.setDefaultModel( null );
+        if ( m_modelCache != null ) {
+          m_modelCache.clear();
+        }
+      } else {
+        m_data.getModel().done();
+        m_data.setModel( null );
+        m_data.setDefaultModel( null );
+      }
+
+      setOutputDone();
+      return false;
+    }
+
     // Make prediction for row using model
     try {
       if ( m_meta.getFileNameFromField() ) {
diff --git a/src/main/java/org/pentaho/di/ui/trans/steps/pmi/BaseSupervisedPMIStepDialog.java b/src/main/java/org/pentaho/di/ui/trans/steps/pmi/BaseSupervisedPMIStepDialog.java
index 5a30061..d641a82 100644
--- a/src/main/java/org/pentaho/di/ui/trans/steps/pmi/BaseSupervisedPMIStepDialog.java
+++ b/src/main/java/org/pentaho/di/ui/trans/steps/pmi/BaseSupervisedPMIStepDialog.java
@@ -540,6 +540,8 @@ protected void setData( BaseSupervisedPMIStepMeta meta ) {
     String[] schemeOpts = m_scheme.getSchemeOptions();
     if ( schemeOpts != null && schemeOpts.length > 0 ) {
       meta.setSchemeCommandLineOptions( Utils.joinOptions( schemeOpts ) );
+    } else {
+      meta.setSchemeCommandLineOptions( "" );
     }

     if ( m_incrementalRowCacheField != null ) {
diff --git a/src/main/java/org/pentaho/di/ui/trans/steps/pmi/GOEDialog.java b/src/main/java/org/pentaho/di/ui/trans/steps/pmi/GOEDialog.java
index 48bc30a..d5a254e 100644
--- a/src/main/java/org/pentaho/di/ui/trans/steps/pmi/GOEDialog.java
+++ b/src/main/java/org/pentaho/di/ui/trans/steps/pmi/GOEDialog.java
@@ -316,7 +316,7 @@ protected static void setValuesOnObject( Object objectToEdit, Map 0 ) {
         if ( category == null || category.length() == 0 || !category.equalsIgnoreCase( m_propertyGroupingCategory ) ) {
           continue;
         }
diff --git a/src/main/java/org/pentaho/pmi/engines/KerasScheme.java b/src/main/java/org/pentaho/pmi/engines/KerasScheme.java
index ad20bbc..34a8e36 100644
--- a/src/main/java/org/pentaho/pmi/engines/KerasScheme.java
+++ b/src/main/java/org/pentaho/pmi/engines/KerasScheme.java
@@ -45,7 +45,8 @@ public abstract class KerasScheme {
     s_excludedSchemes = Arrays.asList( "Naive Bayes", "Naive Bayes incremental", "Naive Bayes multinomial",
         "Decision tree classifier", "Decision tree regressor", "Random forest classifier", "Random forest regressor",
         "Gradient boosted trees",
-        "Support vector regressor", "Multi-layer perceptron classifier", "Multi-layer perceptron regressor",
regressor", "Support vector classifier", "Logistic regression", "Linear regression", + "Multi-layer perceptron classifier", "Multi-layer perceptron regressor", "Extreme gradient boosting classifier", "Extreme gradient boosting regressor", "Multi-layer perceptron classifier", "Multi-layer perceptron regressor" );