To let the active learner work with different models, query strategies, and stopping criteria, we provide classification abstractions that expose a unified interface.


The classifier interface is simple and scikit-learn-like, except that it operates on :py:class:`Datasets<>` objects. Call the ``fit()`` method with a training set as its argument to train your classifier, and use ``predict()`` to obtain predictions.

.. literalinclude:: ../small_text/classifiers/
   :pyobject: Classifier


The following example trains a classifier on a tiny toy dataset.

.. testcode::

   import numpy as np
   from small_text.classifiers import ConfidenceEnhancedLinearSVC, SklearnClassifier
   from small_text.data import SklearnDataset

   # this is a linear SVM which has been extended to return confidence estimates
   model = ConfidenceEnhancedLinearSVC()
   num_classes = 2
   clf = SklearnClassifier(model, num_classes)

   x = np.array([
       [0, 0],
       [0, 0.5],
       [0.5, 1],
       [1, 1]
   ])
   y = np.array([0, 0, 1, 1])
   train_set = SklearnDataset(x, y)

   # Train the classifier on the toy dataset
   clf.fit(train_set)

   # Generate predictions on the train set
   # (Only for the purpose of demonstration;
   #  usually you would be more interested in obtaining predictions on new, unseen data.)
   y_train_pred = clf.predict(train_set)
   print(y_train_pred)


.. testoutput::

   [0 0 1 1]


To configure the active learner with a classifier, a factory object is required, because a new classifier object is created at each iteration (unless explicitly configured otherwise). Creating a new instance requires knowing which arguments to pass to the constructor, and the factory stores exactly this information. If all constructors took zero arguments, we would not need factories here.

.. testcode::

   from small_text.classifiers import ConfidenceEnhancedLinearSVC
   from small_text.classifiers.factories import SklearnClassifierFactory

   clf_template = ConfidenceEnhancedLinearSVC()
   num_classes = 2

   clf_factory = SklearnClassifierFactory(clf_template, num_classes)
   clf = clf_factory.new()

This also means that any classifier parameters, e.g. for multi-label classification, are managed by the factory:

.. testcode::

   from small_text.classifiers import ConfidenceEnhancedLinearSVC
   from small_text.classifiers.factories import SklearnClassifierFactory

   clf_template = ConfidenceEnhancedLinearSVC()
   num_classes = 2
   classifier_kwargs = {'multi_label': True}

   clf_factory = SklearnClassifierFactory(clf_template, num_classes, kwargs=classifier_kwargs)
   clf = clf_factory.new()
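
The reason a factory is needed can be illustrated with a minimal, library-independent sketch in plain Python. It is not how small-text implements its factories. ``ToyClassifier``, ``ToyClassifierFactory``, and the ``new()`` method are hypothetical names chosen for illustration, and only the ``kwargs`` argument mirrors the example above. The sketch shows a factory that remembers the constructor arguments once and then builds a fresh, independent instance on every call.

.. code-block:: python

   class ToyClassifier:
       """Hypothetical stand-in for a classifier whose constructor takes arguments."""

       def __init__(self, num_classes, multi_label=False):
           self.num_classes = num_classes
           self.multi_label = multi_label


   class ToyClassifierFactory:
       """Hypothetical factory that stores constructor arguments for later use."""

       def __init__(self, num_classes, kwargs=None):
           self.num_classes = num_classes
           # Copy the keyword arguments so later changes by the caller have no effect.
           self.kwargs = dict(kwargs) if kwargs is not None else {}

       def new(self):
           # Every call builds a new, independent classifier instance.
           return ToyClassifier(self.num_classes, **self.kwargs)


   factory = ToyClassifierFactory(2, kwargs={'multi_label': True})
   clf_a = factory.new()
   clf_b = factory.new()

Here ``clf_a`` and ``clf_b`` are distinct objects that share the same configuration, which is the property needed when a new classifier is created at each iteration.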