
[WIP] Tidy-up and improve ergonomics with new interface and dataset #45

Closed
wants to merge 37 commits

Conversation

bytesnake
Member

@bytesnake bytesnake commented Sep 16, 2020

I think the linfa crate has collected enough algorithms to move towards a unified interface. The intention of this PR is to tidy up the project and implement unified traits for transformers, fittable models and incremental models. Furthermore, a Dataset struct is introduced, which wraps records, targets, labels and weights in a unified way.

The description is also WIP and will be updated over time

Traits for the learning process

This PR introduces traits for transformers, learnable models and incremental models. The trait implementations (see here) follow these conventions:

  • all three traits are implemented for a hyperparameter set, which governs the learning algorithm
  • Transformer: a transformer is an algorithm which does not expose its inner state. Common examples are kernel methods, which are unique to every given dataset. Support Vector Machines, on the other hand, may implement this trait as well and call fit and predict internally.
  • Fit: a fittable algorithm learns a model from the training dataset. The resulting model can predict targets for new data from the same domain as the input.
  • IncrementalFit: an incremental algorithm can update its inner state, based on a former model and new data
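
The three traits above could be sketched roughly as follows. This is a minimal sketch in Rust; all names, signatures and the toy `MeanParams` example are illustrative assumptions for this description, not the PR's exact API:

```rust
// Illustrative sketch only; names and signatures are assumptions.

/// Wraps records (features) together with targets.
pub struct Dataset<R, T> {
    pub records: R,
    pub targets: T,
}

/// A transformer maps a dataset to a new representation without
/// exposing inner state (e.g. a kernel transformation).
pub trait Transformer<D, O> {
    fn transform(&self, dataset: D) -> O;
}

/// A fittable algorithm, implemented on its hyperparameter set,
/// learns a model from a training dataset.
pub trait Fit<D> {
    type Object;
    fn fit(&self, dataset: &D) -> Self::Object;
}

/// An incremental algorithm updates a previously fitted model with new data.
pub trait IncrementalFit<M, D> {
    fn fit_with(&self, model: M, dataset: &D) -> M;
}

// Toy instance: the "hyperparameter set" of a mean predictor is empty.
pub struct MeanParams;
pub struct MeanModel {
    pub mean: f64,
}

impl Fit<Dataset<Vec<f64>, Vec<f64>>> for MeanParams {
    type Object = MeanModel;

    fn fit(&self, dataset: &Dataset<Vec<f64>, Vec<f64>>) -> MeanModel {
        let sum: f64 = dataset.targets.iter().sum();
        MeanModel {
            mean: sum / dataset.targets.len() as f64,
        }
    }
}
```

The key design point is that fit lives on the hyperparameter set (here `MeanParams`), not on the model, so configuring and training are cleanly separated.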

Dataset

A dataset contains a records field and may further contain targets, weights and labels. The Targets trait corresponds to any kind of target (f32 as well as String), while the Labels trait can be used to narrow the implementation down to comparable targets (at the moment implemented for bool, usize and String)
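
The Targets/Labels split could look roughly like this. This is a hypothetical sketch; the trait bodies and method names are my assumptions, not the actual implementation:

```rust
// Hypothetical sketch of the Targets/Labels split; names are assumptions.
use std::collections::HashSet;
use std::hash::Hash;

/// Any kind of target, continuous (f32) as well as discrete (String).
pub trait Targets {
    type Elem;
    fn as_slice(&self) -> &[Self::Elem];
}

/// Comparable, discrete targets that can be enumerated as a label set.
pub trait Labels: Targets
where
    Self::Elem: Eq + Hash + Clone,
{
    fn labels(&self) -> HashSet<Self::Elem> {
        self.as_slice().iter().cloned().collect()
    }
}

// Continuous targets only implement `Targets` ...
impl Targets for Vec<f32> {
    type Elem = f32;
    fn as_slice(&self) -> &[f32] {
        self
    }
}

// ... while boolean targets additionally implement `Labels`.
impl Targets for Vec<bool> {
    type Elem = bool;
    fn as_slice(&self) -> &[bool] {
        self
    }
}
impl Labels for Vec<bool> {}
```

Floating-point targets cannot implement Labels because f32 is not Eq, which is exactly the narrowing the description mentions.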

Tidying up the crate

  • remove top level dependencies to sub-crates
  • remove dependency on serde

Unresolved questions

  • the implementation of a dataset follows the model of ndarray, where everything is implemented for concrete types. On the one hand this creates some boilerplate code (for example, Vec<T>, &Vec<T> and &[T] have to be implemented separately), but I don't think that Rust is expressive enough at the moment to go another way
  • the name Records was chosen to distinguish it from the Data trait of ndarray, but it's not the most common name to describe the actual data in a dataset
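
To illustrate the first point, here is a sketch of the boilerplate under stated assumptions: the Records name follows the PR, but the observations method is invented for illustration. Each owned/borrowed form needs its own impl:

```rust
// Sketch of the concrete-type boilerplate: Vec<T>, &Vec<T> and &[T]
// each need a separate implementation (method name is illustrative).
pub trait Records {
    fn observations(&self) -> usize;
}

impl<T> Records for Vec<T> {
    fn observations(&self) -> usize {
        self.len()
    }
}

impl<T> Records for &Vec<T> {
    fn observations(&self) -> usize {
        self.len()
    }
}

impl<T> Records for &[T] {
    fn observations(&self) -> usize {
        self.len()
    }
}
```

A blanket impl over a common deref target would reduce this, but coherence rules and conflicting upstream impls make that hard in practice, which matches the "Rust is not expressive enough atm" remark.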

Example

Here is a simple example for a kernel transformation, followed by the training of a SVC model:

    // everything above 6.5 is considered a good wine
    let dataset = Dataset::new(data, targets)
        .map_targets(|x| **x > 6.5);

    // split into training and validation dataset
    let (train, valid) = dataset.split_with_ratio(0.1);
    
    // transform with RBF kernel
    let train_kernel = Kernel::params()
        .method(KernelMethod::Gaussian(80.0))
        .transform(&train);

    println!(
        "Fit SVM classifier with #{} training points",
        train.observations()
    );

    // fit a SVM with C value 7 and 0.6 for positive and negative classes
    let model = Svm::params()
        .pos_neg_weights(7., 0.6)
        .fit(&train_kernel);

    println!("{}", model);
    // A positive prediction indicates a good wine, a negative, a bad one
    fn tag_classes(x: &bool) -> String {
        if *x {
            "good".into()
        } else {
            "bad".into()
        }
    }

    // map targets for validation dataset
    let valid = valid.map_targets(tag_classes);

    // predict and map targets
    let pred = model.predict(&valid)
        .map_targets(|x| *x > 0.0)
        .map_targets(tag_classes);

    // create a confusion matrix
    let cm = pred.confusion_matrix(&valid);

    // Print the confusion matrix, this will print a table with four entries. On the diagonal are
    // the number of true-positive and true-negative predictions, off the diagonal are
    // false-positive and false-negative
    println!("{:?}", cm);

    // Calculate the accuracy and Matthews Correlation Coefficient (correlation between
    // predicted and true labels)
    println!("accuracy {}, MCC {}", cm.accuracy(), cm.mcc());

@bytesnake
Member Author

I don't have much time atm, but any feedback is welcome

NDarray has already a trait called `Data`, to avoid name collisions our
`Data` is now called `Records`
 * implement `Transformer` for `KernelParams`
 * move creation functions to `KernelParams`
 * use `Records` and `Float`
 * make `Fit` more generic over return object
 * implement one-class classification for Labels = ()
The phantom type distinguishes between different kind of predictions,
like boolean, probability or floating predictions
 * add builder pattern to kernel methods
 * wrestle with type system
 * split Targets into Targets and Labels
 * implement Fit and Predict traits for SVM regression
 * create a dataset from records and targets
 * use kernel method as transformer
 * fit a model with hyperparameters and given dataset
 * create a second validation dataset and populate targets with predict
 * evaluate with confusion matrix
@bytesnake bytesnake changed the title [WIP] Experiment with public interface and datasets [WIP] Tidy-up and improve ergonomics with new interface and dataset Oct 17, 2020
@paulkoerbitz
Collaborator

I've taken a short look tonight! This looks really awesome, thanks for this extensive amount of work! I'd like to spend a bit more time thinking about the names and visibility of some items. I'll write more comments tomorrow.

Thanks for the work!

@paulkoerbitz paulkoerbitz self-requested a review October 27, 2020 21:30
@bytesnake
Member Author

Moved to #55

@bytesnake bytesnake closed this Nov 2, 2020