-
Notifications
You must be signed in to change notification settings - Fork 23
/
package.scala
40 lines (32 loc) · 1.57 KB
/
package.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
package io.picnicml.doddlemodel
import breeze.linalg.{DenseMatrix, DenseVector, unique}
import io.picnicml.doddlemodel.CrossScalaCompat.floatOrdering
import io.picnicml.doddlemodel.data.Feature.FeatureIndex
package object data {

  /** Alias for a dense vector of `Float` values. */
  type RealVector = DenseVector[Float]

  /** Alias for a dense vector of `Int` values (e.g. the input of [[numberOfUniqueGroups]]). */
  type IntVector = DenseVector[Int]

  /** Alias for a dense `Float` matrix named `Simplex`; its layout is defined by its users (not visible here). */
  type Simplex = DenseMatrix[Float]

  /** Alias for the dense `Float` matrix used as the feature part of a dataset. */
  type Features = DenseMatrix[Float]

  /** Alias for the dense `Float` vector used as the target part of a dataset. */
  type Target = DenseVector[Float]

  /** A [[Features]] matrix paired with its [[FeatureIndex]]. */
  type FeaturesWithIndex = (Features, FeatureIndex)

  /** A [[Features]] matrix paired with its [[Target]] vector. */
  type Dataset = (Features, Target)

  /** A [[Features]] matrix, its [[Target]] vector and the associated [[FeatureIndex]]. */
  type DatasetWithIndex = (Features, Target, FeatureIndex)

  /** Loads the Boston dataset; delegates to `ResourceDatasetLoaders.loadBostonDataset`. */
  def loadBostonDataset: DatasetWithIndex = ResourceDatasetLoaders.loadBostonDataset

  /** Loads the breast cancer dataset; delegates to `ResourceDatasetLoaders.loadBreastCancerDataset`. */
  def loadBreastCancerDataset: DatasetWithIndex = ResourceDatasetLoaders.loadBreastCancerDataset

  /** Loads the iris dataset; delegates to `ResourceDatasetLoaders.loadIrisDataset`. */
  def loadIrisDataset: DatasetWithIndex = ResourceDatasetLoaders.loadIrisDataset

  /** Loads the high-school test dataset; delegates to `ResourceDatasetLoaders.loadHighSchoolTestDataset`. */
  def loadHighSchoolTestDataset: DatasetWithIndex = ResourceDatasetLoaders.loadHighSchoolTestDataset

  /**
   * Counts the distinct group labels in `groups`.
   *
   * The distinct labels must be exactly `0, 1, ..., k - 1` for some `k`, i.e. every index in
   * `[0, k)` has to occur at least once. An empty vector yields `0`.
   *
   * @param groups a group label for each element
   * @return the number of distinct labels `k`
   * @throws IllegalArgumentException if the distinct labels are not exactly `0 until k`
   */
  def numberOfUniqueGroups(groups: IntVector): Int = {
    val distinctGroups = unique(groups)
    val groupCount = distinctGroups.length
    val ordered = distinctGroups.toArray.sorted
    // Distinct labels in ascending order must each equal their own position: 0, 1, ..., k - 1.
    val contiguousFromZero = ordered.indices.forall(position => ordered(position) == position)
    require(contiguousFromZero,
      "Invalid encoding of groups, all group indices in [0, numGroups) have to exist")
    groupCount
  }

  /**
   * Counts the distinct class values in the target vector `y`.
   *
   * There must be at least two distinct values, and they must be exactly
   * `0.0, 1.0, ..., k - 1` (compared numerically against the integers `0 until k`).
   *
   * @param y the target vector
   * @return the number of distinct classes `k`
   * @throws IllegalArgumentException if fewer than two distinct values are present, or if the
   *                                  distinct values are not exactly `0 until k`
   */
  def numberOfTargetClasses(y: Target): Int = {
    val distinctClasses = unique(y)
    val classCount = distinctClasses.length
    require(classCount >= 2,
      "Target variable must be comprised of at least two categories")
    // Sorting relies on the implicit Float ordering imported at the top of the file.
    val ordered = distinctClasses.toArray.sorted
    // Float-vs-Int `==` widens the index to Float, so 2.0f matches position 2.
    val contiguousFromZero = ordered.indices.forall(position => ordered(position) == position)
    require(contiguousFromZero,
      "Invalid encoding of categories in the target variable")
    classCount
  }
}