# Zero-shot and few-shot classification (#979)
Adding zero-shot and few-shot classification to skorch by using Large Language Models (LLMs). To see this in action, check out the accompanying [notebook](https://github.com/skorch-dev/skorch/blob/llm-classifier/notebooks/LLM_Classifier.ipynb).

## Description

This PR adds two new classes, `ZeroShotClassifier` and `FewShotClassifier`, which allow users to perform zero-shot and few-shot classification. At the same time, the classes are sklearn estimators/classifiers, so they can be used like normal sklearn models. This opens up a few opportunities, such as grid searching for the best prompt or for the optimal number of examples to use for few-shot learning. Even if users don't plan on using these classes in the final application, they may still use them to figure out the best prompt and LLM to use.

By integrating with Hugging Face transformers, this feature allows running the LLMs completely locally; no API is called with user data. This also allows us to do some nice tricks that are not easily achieved -- or may even be impossible -- with APIs, such as forcing the LLM to predict only the labels we provide, or returning the probabilities associated with each label. We also added some caching on top, which is useful if the labels share a long common prefix (say, `intent.support.email` and `intent.support.phone`).

## Similar libraries

This feature was inspired by https://github.com/iryna-kondr/scikit-llm, which takes a similar approach but is not based on HF transformers; at the time of writing, it uses the OpenAI API or GPT4All. Apart from those differences, AFAICT, that package does not provide a `predict_proba`, does not allow grid searching over the prompt etc., does not prevent the LLM from making predictions that differ from the allowed labels, and does not do any caching. It is a good choice, though, if using OpenAI is desired or if users are interested in summarization, text embeddings, translation, or multi-label prediction.
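The label-forcing, `predict_proba`, and prefix-caching ideas can be illustrated with a minimal, self-contained sketch. Everything here is hypothetical and not skorch's actual implementation: the toy `next_token_logprobs` stands in for a real LLM forward pass, and labels are pre-tokenized by hand. The point is that each label is scored by summing the log-probabilities of its tokens, so the "prediction" can only ever be one of the given labels, and per-prefix results can be cached so labels with a long common prefix reuse work:

```python
import math

# Toy stand-in for an LLM forward pass: given a token prefix, return
# log-probabilities for the next token over a small vocabulary. A real
# implementation would run the model and read out the logits instead.
def next_token_logprobs(prefix, vocab):
    target = ["intent", ".", "support", ".", "email"]  # tokens the toy model "prefers"
    pos = len(prefix)
    scores = {tok: 1.0 for tok in vocab}
    if pos < len(target):
        scores[target[pos]] = 5.0  # boost the "expected" next token
    log_z = math.log(sum(math.exp(s) for s in scores.values()))
    return {tok: s - log_z for tok, s in scores.items()}  # log-softmax

cache = {}   # maps token-prefix tuples to next-token log-probs
calls = 0    # counts how often the "model" is actually invoked

def cached_logprobs(prefix, vocab):
    global calls
    key = tuple(prefix)
    if key not in cache:
        calls += 1
        cache[key] = next_token_logprobs(prefix, vocab)
    return cache[key]

def score_label(label_tokens, vocab):
    # Sum log-probs of the label's own tokens; the model is never asked to
    # free-generate, so predictions cannot fall outside the label set.
    return sum(
        cached_logprobs(label_tokens[:i], vocab)[tok]
        for i, tok in enumerate(label_tokens)
    )

labels = {
    "intent.support.email": ["intent", ".", "support", ".", "email"],
    "intent.support.phone": ["intent", ".", "support", ".", "phone"],
}
vocab = sorted({t for toks in labels.values() for t in toks})

logps = {name: score_label(toks, vocab) for name, toks in labels.items()}
# Normalize over the label set to get a predict_proba-style distribution.
z = math.log(sum(math.exp(lp) for lp in logps.values()))
proba = {name: math.exp(lp - z) for name, lp in logps.items()}
```

Because both labels share the token prefix `intent . support .`, scoring the second label hits the cache for every shared prefix and only the final-token lookup differs, which is exactly the situation where the caching described above pays off.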
## Implementation

@ottonemo and I paired on this PR.

Why add this feature to skorch, given that it doesn't really use anything from skorch? This is true, and it could also be a standalone package, for instance. The argument for adding it here is that this feature is spiritually similar to what skorch already does, namely marrying (PyTorch) neural networks with scikit-learn. Since this feature depends on PyTorch and sklearn anyway, it's not as if users would have to install a lot of useless skorch dependencies only to use it. And in the end, adding this feature to skorch should give it some visibility, which would otherwise not be easy to achieve.

Why have a new submodule `llm`? We decided on this because we may very well have more LLM-related features to add in the future. Otherwise, the code could probably also move to `hf.py`.

Why base this on top of HF transformers? This library provides the widest range of language models and thus immediately gives users many options. Thanks to the ability to intercept the logits after each forward call, we can implement some of the features described above, which wouldn't be possible with other options. Users will probably already have some familiarity with HF transformers. If this feature finds a lot of adoption and there is demand for it, we can still later add the option to use the OpenAI API or other such services. This may prevent us from offering certain features like `predict_proba`, but perhaps not all users are interested in that.

## Open tasks

This PR provides a first batch of features, but there are certainly ideas to add more if it gets traction. The most obvious ones are:

1. Multi-label prediction: It should be rather straightforward to add multi-label predictions, i.e. allowing multiple labels to be predicted for a single sample. I would have to check how sklearn expects the data to be laid out for multi-label prediction.
2. More flexibility when it comes to formatting the prompt.
Right now, the prompt is rather rigid. For instance, users cannot easily change how exactly the few-shot samples are added to the prompt. Perhaps we can do better there, e.g. by using jinja. Another idea is to allow the `prompt` argument to be a function that returns a string.
3. Providing the same functionality as an sklearn transformer. This would allow users to put it into an sklearn `Pipeline` as part of a feature engineering step. As an example, let's say we deal with customer reviews in an e-commerce setting. This transformer could generate a feature that indicates whether a customer review mentions the size of an item of clothing, which may not be of interest in itself, but may be a useful feature for an overarching supervised learning task.

If this PR is merged as is, these open tasks should be added as issues to be worked on in the future.

## Coincidental changes

The base class of skorch errors was changed from `BaseException` to `Exception`. This way, a user can do `try ... except Exception`. Previously, they would have had to do `try ... except BaseException`, which is pretty uncommon. Probably not many (or any) users should be negatively affected by this change, since `except BaseException` still works.

# Changes in this squashed commit

* [WIP] Initial working version
* Add a bunch more tests. Most notably, for FewShotClassifier. Add functionality to check prompt placeholders. Use check_random_state in FewShotClassifier.
* Remove the need for padding label ids. This is preferred because not all pretrained tokenizers define a pad token, so using padding would put an extra burden on users.
* Load model & tokenizer just from the model name. Just pass the model name; passing model and tokenizer explicitly is still possible.
* Move prompts to their own module
* Use torch.tensor to initialize tensors, not torch.LongTensor.
* Some fixes to _load_model_and_tokenizer
* Add check for all probs being 0
* Remove the option to cache seq2seq models. There is a fundamental problem with caching seq2seq/encoder-decoder models, because the outcomes are not necessarily the same. Therefore, caching is disabled for seq2seq and an error is raised. A few more checks on the model/tokenizer are introduced to ensure that valid arguments are passed.
* Some changes to the notebook: use fewer samples, 100 instead of 500 (500 was just a bit slow and the results are not affected much by this); add installation instructions for extra dependencies; disable caching for flan-t5.
* Add checking for low probabilities
* Add test for wrong architecture raising
* Add docstrings, more checking
* More checking, testing of prompts, extend notebook
* Add more code and tests for verifying X and y. Also fixed a bug that prevented usage of y if the ys are not strings. It's not something we recommend (we warn about it), but it should be possible.
* Fix the repr of the estimators
* Update notebook, adding a lot more explanations
* Rename notebook to conform with existing naming
* Add documentation
* Add entry to CHANGES.md
* Clean up notebook
* Fix a docstring
* Support BFloat. The `prod_cpu` op is not implemented for bfloat16, which is why we need to convert to a supported float type. Since the sample will probably be quite small, we don't need to specify a more concrete / memory-friendly type.
* Add method to clear the cache
* Fix a few tests after the prompt was changed. Also: add a comment that changing the prompt may require adjusting the tests.
* Add docs section about debugging, fix some typos
* Add section on prompts for few-shot classification
* Fix a bug in get_examples for few-shot learning. It was possible that the same sample was returned twice; this should now be fixed, and tests were added. Also added a docstring for random_state, which I forgot to add earlier.
* Add one more debugging tip to the docs
* Document when not to use this new feature
* Fix issue with rendering some docstrings
* Apply some linting: long lines, useless f-string
* Small amendment to docs about dealing with issues
* Allow prompts not to have all placeholders. This will only warn, not raise an error, because it may sometimes be desired not to fill all the keys (especially the labels may not be needed, depending on the prompt).
* Add bonus section testing bloomz-1b1 on PIQA. In addition to the already comprehensive list of examples, it might be a good idea to showcase how to handle tasks that are not simply classification tasks, and how to deal with the problems that arise when doing so.

---------

Co-authored-by: Marian Tietz <marian.tietz@ottogroup.com>
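The "warn on missing placeholders, raise on unknown ones" behavior mentioned in the commit list could be sketched like this. This is an illustrative sketch, not skorch's actual code: the helper name and the set of expected placeholder keys are made up, and only `string.Formatter` from the standard library is used to extract placeholder names:

```python
import warnings
from string import Formatter

# Hypothetical set of placeholder keys a classifier prompt may use.
EXPECTED_KEYS = {"text", "labels", "examples"}

def check_prompt(prompt, expected_keys=frozenset(EXPECTED_KEYS)):
    """Validate prompt placeholders; warn on missing ones, raise on unknown ones."""
    # Formatter().parse yields (literal_text, field_name, format_spec, conversion);
    # field_name is None for plain text segments.
    found = {name for _, name, _, _ in Formatter().parse(prompt) if name}

    unknown = found - expected_keys
    if unknown:
        # Unknown placeholders would make str.format fail later, so this is an error.
        raise ValueError(f"Prompt contains unknown placeholders: {sorted(unknown)}")

    missing = expected_keys - found
    if missing:
        # Missing placeholders may be intentional (e.g. the labels are baked
        # directly into the prompt text), so only warn.
        warnings.warn(f"Prompt is missing placeholders: {sorted(missing)}")
    return found
```

For example, `check_prompt("Classify: {text}\nLabels: {labels}\nAnswer:")` would emit a warning about the missing `examples` key but still return normally, which matches the "only warn, not raise" intent described above.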