# Zero-shot and few-shot classification (#979)
Adding zero-shot and few-shot classification to skorch by using Large Language Models (LLMs)

To see this in action, check out the accompanying [notebook](https://github.com/skorch-dev/skorch/blob/llm-classifier/notebooks/LLM_Classifier.ipynb).

## Description

This PR adds two new classes, `ZeroShotClassifier` and `FewShotClassifier`. The classes allow users to perform zero-shot and few-shot classification. At the same time, the classes are sklearn estimators/classifiers, so they can be used as normal sklearn models.

This opens up a few opportunities, such as grid searching for the best prompt or for the optimal number of examples to use in few-shot learning. Even if users don't plan on using these classes in the final application, they may still use them to figure out the best prompt and LLM to use.
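
Grid searching over the prompt can be sketched with plain sklearn machinery. The `LLMClassifierStub` below is a stand-in for a prompt-based classifier, not skorch's actual API; the `prompt` parameter name is an assumption for illustration.

```python
# Illustrative sketch: treating the prompt as a hyperparameter and grid
# searching over it. LLMClassifierStub stands in for a real LLM-backed
# classifier; all names here are assumptions, not skorch's API.
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import GridSearchCV

class LLMClassifierStub(BaseEstimator, ClassifierMixin):
    def __init__(self, prompt="Classify: {text}"):
        self.prompt = prompt

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        return self

    def predict(self, X):
        # A real LLM classifier would fill the prompt with each sample and
        # pick the most probable label; the stub just guesses one class.
        return np.full(len(X), self.classes_[0])

X = ["great product", "terrible service", "okay experience", "loved it"]
y = ["pos", "neg", "neg", "pos"]

search = GridSearchCV(
    LLMClassifierStub(),
    param_grid={"prompt": ["Classify: {text}", "Sentiment of '{text}':"]},
    cv=2,
)
search.fit(X, y)
print(search.best_params_["prompt"])
```

Because the classifier is a regular sklearn estimator, the same pattern extends to searching over the model name or, for few-shot learning, the number of examples.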

By integrating with Hugging Face transformers, this feature allows running the LLMs completely locally; no API is called with user data. This also allows us to do some nice tricks that are not easily achieved -- or may even be impossible -- with APIs, such as forcing the LLM to predict only the labels we provide, or returning the probabilities associated with each label. We also added some caching on top, which is useful if the labels share a long common prefix (say, `intent.support.email` and `intent.support.phone`).
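
The core idea behind restricting predictions to the provided labels can be sketched as follows: instead of free-form generation, score each candidate label's token sequence under the model and normalize. The `token_logprob` function below is a toy stand-in for a real LM forward pass; everything here is illustrative, not skorch's implementation.

```python
# Sketch: compute probabilities over exactly the allowed labels by scoring
# each label's token sequence and softmax-normalizing. token_logprob is a
# toy stand-in for an LLM forward pass.
import math

def token_logprob(prefix, token):
    # A real implementation would run a forward pass on `prefix` and read
    # the logit of `token`; here we return a dummy score.
    return -float(len(token)) / 4

def label_logprob(prompt, label_tokens):
    total, prefix = 0.0, prompt
    for tok in label_tokens:
        total += token_logprob(prefix, tok)
        prefix += tok  # scores for shared prefixes could be cached here
    return total

def predict_proba(prompt, labels):
    scores = {lbl: label_logprob(prompt, toks) for lbl, toks in labels.items()}
    z = sum(math.exp(s) for s in scores.values())
    return {lbl: math.exp(s) / z for lbl, s in scores.items()}

labels = {
    "intent.support.email": ["intent", ".support", ".phone"[:0] + ".email"],
    "intent.support.phone": ["intent", ".support", ".phone"],
}
probas = predict_proba("Classify the request:", labels)
print(probas)  # probabilities over exactly the allowed labels, summing to 1
```

The comment in `label_logprob` hints at where the prefix caching pays off: labels like `intent.support.email` and `intent.support.phone` share the `intent.support` prefix, so its score only needs to be computed once.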

## Similar libraries

This feature was inspired by https://github.com/iryna-kondr/scikit-llm, which has a similar approach but is not based on HF transformers. At the time of writing, it uses the OpenAI API or GPT4All. Apart from those differences, AFAICT, this package does not provide a `predict_proba`, doesn't allow grid searching over the prompt etc., doesn't prevent the LLM from making predictions that differ from the allowed labels, and doesn't do any caching. It is a good choice, though, if using OpenAI is desired or if users are interested in summarization, text embeddings, translation, or multi-label prediction.

## Implementation

@ottonemo and I paired on this PR.

Why add this feature to skorch, given that it doesn't really use anything from skorch? This is true, and it could also live in a standalone package, for instance. The argument for adding it here is that this feature is spiritually similar to what skorch already does, namely marrying (PyTorch) neural networks with scikit-learn. Since this feature has dependencies on PyTorch and sklearn anyway, it's not as if users would have to install a lot of useless skorch dependencies only to use this feature. And in the end, adding this feature to skorch should give it some visibility, which would otherwise not be easy to achieve.

Why have a new submodule `llm`? We decided on this because we may very well have more LLM-related features to add in the future. Otherwise, the code could probably also move to `hf.py`.

Why base this on top of HF transformers? This library provides the widest range of language models and thus immediately gives the user many options. Thanks to the ability to intercept logits after each forward call, we can implement some of the features described above, which wouldn't be possible with other options. Users will probably already have some familiarity with HF transformers.
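
The logit interception mentioned above can be sketched in a few lines: mask the vocabulary logits so that only tokens belonging to the allowed labels can be chosen. This numpy-only snippet illustrates the mechanism, not skorch's actual implementation.

```python
# Sketch: force the next token to come from the allowed label tokens by
# masking all other logits to -inf before taking the argmax.
import numpy as np

def mask_logits(logits, allowed_token_ids):
    masked = np.full_like(logits, -np.inf)
    masked[allowed_token_ids] = logits[allowed_token_ids]
    return masked

vocab_logits = np.array([0.1, 2.0, -1.0, 0.5, 1.5])
allowed = [0, 3]  # hypothetical token ids of the labels' next tokens
masked = mask_logits(vocab_logits, allowed)
next_token = int(np.argmax(masked))
print(next_token)  # → 3, guaranteed to be in `allowed`
```

In HF transformers, the same effect is achieved by hooking into the generation step (e.g. via a logits processor), which is exactly the kind of access an external API typically does not expose.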

We think that if this feature finds a lot of adoption and there is demand for it, we can still later add the option to use the OpenAI API or other such services. This may prevent us from offering certain features like `predict_proba`, but perhaps not all users are interested in that.

## Open tasks

This PR provides a first batch of features, but there are certainly ideas for more if this feature gets traction. The most obvious ones are:

1. Multi-label prediction: It should be rather straightforward to add multi-label predictions, i.e. allowing multiple labels to be predicted for a single sample. I would have to check how sklearn expects the data to be formatted for multi-label prediction.
2. More flexibility when it comes to formatting the prompt. Right now, the prompt is rather rigid. For instance, users cannot easily change how exactly the few-shot samples are added to the prompt. Perhaps we can do better there, e.g. by using Jinja. Another idea is to allow the `prompt` argument to be a function that returns a string.
3. Providing the same functionality as an sklearn transformer. This would allow users to put this in an sklearn `Pipeline` as part of a feature engineering step. As an example, let's say we deal with customer reviews in an e-commerce setting. This transformer could generate a feature that indicates whether a customer review mentions the size of an item of clothing, which may not be of interest in itself, but may be a useful feature for an overarching supervised learning task.
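
The transformer idea from point 3 could look roughly like the sketch below. `MentionsSizeTransformer` is hypothetical, and the keyword check stands in for the LLM call; the point is the sklearn plumbing, not the classification logic.

```python
# Hypothetical sketch: an sklearn transformer that turns an LLM judgment
# ("does this review mention clothing size?") into a numeric feature. The
# substring check stands in for an actual zero-shot LLM call.
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin

class MentionsSizeTransformer(BaseEstimator, TransformerMixin):
    def fit(self, X, y=None):
        return self

    def transform(self, X):
        # A real version would query a zero-shot classifier per sample.
        return np.array([["size" in text.lower() for text in X]], dtype=float).T

reviews = ["Runs small, size up!", "Fast shipping.", "The size chart is wrong."]
features = MentionsSizeTransformer().fit_transform(reviews)
print(features.ravel())  # → [1. 0. 1.]
```

Since it follows the transformer protocol, it could be dropped into a `Pipeline` or `FeatureUnion` alongside other feature engineering steps feeding an overarching supervised model.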

If this PR is merged as is, these open tasks should be added as issues to be worked on in the future.

## Coincidental changes

The base class of skorch errors was changed from `BaseException` to `Exception`. This way, a user can do `try ... except Exception`. Previously, they would have had to do `try ... except BaseException`, which is pretty uncommon. Few (if any) users should be negatively affected by this change, since `except BaseException` still works.
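
A minimal illustration of why the base class matters, with `SkorchError` defined locally as a stand-in for skorch's error class:

```python
# With Exception as the base class, the idiomatic catch-all works.
# SkorchError here is a local stand-in, not skorch's actual class.
class SkorchError(Exception):  # previously this derived from BaseException
    pass

caught = None
try:
    raise SkorchError("something went wrong")
except Exception as exc:  # would NOT catch a direct BaseException subclass
    caught = exc

print(type(caught).__name__)  # → SkorchError
```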


# Changes in this squashed commit

* [WIP] Initial working version

* Add a bunch of more tests

Most notably, for FewShotClassifier.

Add functionality to check prompt placeholders.

Use check_random_state in FewShotClassifier.

* Remove need for padding label ids

This is preferred because not all pretrained tokenizers define a pad
token, so using padding would put an extra burden on users.

* Loading model&tokenizer just from model name

Just pass the model name. Passing model and tokenizer explicitly is
still possible.

* Move prompts to their own module

* Use torch.tensor to initialize tensor

Not torch.LongTensor.

* Some fixes to _load_model_and_tokenizer

* Add check for all probs being 0

* Remove option to cache seq2seq model

There is a fundamental problem with caching seq2seq/encoder-decoder,
because the outcomes are not necessarily the same. Therefore, caching is
disabled for seq2seq and an error is raised.

A few more checks on the model/tokenizer are introduced to ensure that
valid arguments are passed.

* Some changes to the notebook

- Use fewer samples: 100 instead of 500, 500 was just a bit slow. The
  results are not affected a lot by this
- Add installation instructions for extra dependencies
- Disable caching for flan-t5

* Add checking for low probabilities

* Add test for wrong architecture raising

* Add docstrings, more checking

* More checking, testing of prompts, extend notebook

* Add more code and tests for verifying X and y

Also, fixed a bug that prevented usage of y if the ys are not strings. It's
not something we recommend doing (we warn about it), but it should be
possible.

* Fix the repr of the estimators

* Update notebook, adding a lot more explanations

* Rename notebook to conform with existing naming

* Add documentation

* Add entry to CHANGES.md

* Clean up notebook

* Fix a docstring

* Support BFloat

The `prod_cpu` op is not implemented for bfloat16, which is why we need
to convert to a supported float type. Since the sample will probably be
quite small, we don't need to specify a more concrete / memory-friendly
type.

* Add method to clear cache

* Fix a few tests after prompt was changed

Also: Add comment that changing the prompt may require adjusting the
tests.

* Add docs section about debugging, fix some typos

* Add section prompts for few-shot classification

* Fix a bug in get_examples for few-shot learning

It was possible for the same sample to be returned twice. This should now
be fixed. Tests were added.

Also added a docstring for random_state, which I forgot to add earlier.
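
The fix amounts to sampling without replacement; a sketch of the pattern, with illustrative names rather than skorch's internals:

```python
# Sketch: draw distinct few-shot examples via check_random_state.
# Permutation + slice guarantees no sample is returned twice.
from sklearn.utils import check_random_state

def get_examples(X, y, n_examples, random_state=None):
    rng = check_random_state(random_state)
    indices = rng.permutation(len(X))[:n_examples]
    return [(X[i], y[i]) for i in indices]

X = ["a", "b", "c", "d"]
y = ["neg", "pos", "neg", "pos"]
examples = get_examples(X, y, n_examples=3, random_state=0)
assert len(set(examples)) == 3  # all distinct
```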

* Add one more debugging tip to the docs

* Document when not to use this new feature

* Fix issue with rendering some docstrings

* Apply some linting

- long lines
- useless f-string

* Small amendment to docs about dealing with issues

* Allow prompts not to have all placeholders

Will only warn, not raise an error. This is because it may sometimes be
desired not to fill all the keys (especially the labels may not be
needed, depending on the prompt).

* Add Bonus section testing bloomz-1b1 on PIQA

In addition to the already comprehensive list of examples, it might be a
good idea to showcase how to handle tasks that are not simple
classification tasks, and how to deal with the problems that arise when
doing so.

---------

Co-authored-by: Marian Tietz <marian.tietz@ottogroup.com>
Committed by BenjaminBossan and ottonemo on Jun 23, 2023 (commit 7991e5c, 1 parent 5f1ec03).

## Changed files

Showing 14 changed files with 4,773 additions and 2 deletions.
- `CHANGES.md` (+2): new entry under "### Added":

  > Add support for [zero-shot and few-shot classification](https://skorch.readthedocs.io/en/latest/user/llm.html#using-language-models-as-zero-and-few-shot-classifiers) with the help of Large Language Models and the Hugging Face transformers library

- `docs/index.rst` (+1): added `user/LLM` to the User's Guide toctree, between `user/huggingface` and `user/FAQ`.
- `docs/llm.rst` (new, +5): API reference stub:

      skorch.llm
      ==========

      .. automodule:: skorch.llm
          :members:

- `docs/skorch.rst` (+1): added `llm` to the module toctree, between `history` and `net`.
- `docs/user/LLM.rst` (new, +428): the user guide for this feature (large diff, not rendered).
- `docs/user/REST.rst` (+1/-1): fixed a typo, "RNN sentiment classifer" → "RNN sentiment classifier".
- `docs/user/tutorials.rst` (+2): added the notebook to the tutorials list:

      * `Classifying with LLMs <https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/LLM_Classifier.ipynb>`_ - Using (Large) Language Models as zero-shot and few-shot classifiers `Run in Google Colab 💻 <https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/LLM_Classifier.ipynb>`_