# Using LLMs as text classifiers with an sklearn interface

TODO:
- filter warnings
- google colab installs

<table align="left"><td>
<a target="_blank" href="https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/llm-classifier-demo.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>  
</td><td>
<a target="_blank" href="https://github.com/skorch-dev/skorch/blob/master/notebooks/llm-classifier-demo.ipynb"><img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a></td></table>

The first part of the notebook requires Hugging Face `transformers` and `datasets` as additional dependencies. If you have not already installed them, you can do so like this:

`python -m pip install transformers datasets`

In [1]:
import subprocess

# Installation on Google Colab
try:
    import google.colab
    subprocess.run(['python', '-m', 'pip', 'install', 'skorch', 'transformers', 'datasets'])
except ImportError:
    pass

## imports

In [2]:
import datasets
import numpy as np
import pandas as pd
import transformers
import torch
from sklearn.metrics import accuracy_score, log_loss
from sklearn.model_selection import GridSearchCV

In [3]:
transformers.logging.set_verbosity_warning()
datasets.logging.set_verbosity_warning()

In [4]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

## load data

Load sentiment dataset

In [5]:
imdb = datasets.load_dataset('imdb').shuffle(seed=0)

Found cached dataset imdb (/home/vinh/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /home/vinh/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-79aee49c9f40dc82.arrow
Loading cached shuffled indices for dataset at /home/vinh/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-5a09ddfc1bd0fbc8.arrow
Loading cached shuffled indices for dataset at /home/vinh/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-f131e6602007628b.arrow


Limit to 100 samples. Using zero/few shot learning mostly makes sense when there are very few labeled samples.

In [6]:
X = imdb['train'][:100]['text']
y = imdb['train'][:100]['label']

In [7]:
labels = np.array(['negative', 'positive'])[y]

## zero shot classification

In [8]:
from skorch.llm import ZeroShotClassifier

### "train" zero shot classifier

For this notebook, we use a small LLM, `flan-t5-small`.

In [9]:
clf = ZeroShotClassifier(
    'google/flan-t5-small', device=device, use_caching=False
)

In [10]:
%time clf.fit(X=None, y=['positive', 'negative']);

CPU times: user 2.92 s, sys: 615 ms, total: 3.53 s
Wall time: 3.38 s


In general, fitting is fast because no actual training takes place. If the LLM is not cached locally, however, it will be downloaded from Hugging Face, which may take some time.

### evaluate

In [11]:
%time y_proba = clf.predict_proba(X)

Token indices sequence length is longer than the specified maximum sequence length for this model (844 > 512). Running this sequence through the model will result in indexing errors


CPU times: user 26.9 s, sys: 1.21 s, total: 28.1 s
Wall time: 7.4 s


In [12]:
log_loss(y, y_proba)

0.2870120002577984

In [13]:
y_pred = y_proba.argmax(1)

In [14]:
accuracy_score(y, y_pred)

0.86
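To make the two metrics concrete, here is a minimal pure-Python sketch of how accuracy and log loss can be computed from a matrix of predicted probabilities. The three rows are made-up values for illustration, not model output, and the function names `accuracy` and `mean_log_loss` are hypothetical (the notebook itself relies on `sklearn.metrics`).

```python
import math

def accuracy(y_true, y_proba):
    # predicted class = index of the highest probability in each row
    y_pred = [max(range(len(row)), key=row.__getitem__) for row in y_proba]
    return sum(int(t == p) for t, p in zip(y_true, y_pred)) / len(y_true)

def mean_log_loss(y_true, y_proba):
    # average negative log of the probability assigned to the true class
    return -sum(math.log(row[t]) for t, row in zip(y_true, y_proba)) / len(y_true)

# made-up example: the third sample is misclassified
y_true = [0, 1, 1]
y_proba = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]

acc = accuracy(y_true, y_proba)
ll = mean_log_loss(y_true, y_proba)
print(acc, ll)
```

Unlike accuracy, log loss also penalizes predictions that are correct but not confident, which is why both metrics are reported throughout this notebook.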

In [15]:
clf.predict(["A masterpiece, instant classic, 5 stars out of 5"])

array(['positive'], dtype='<U8')

### grid search the prompt

In [16]:
prompt0 = """You are a text classification assistant.

The text to classify:

```
{text}
```

Choose the label among the following possibilities with the highest probability.
Only return the label, nothing more:

{labels}

Your response:
"""

In [17]:
prompt1 = """Your task is to classify text.

Choose the label among the following possibilities with the highest probability.
Only return the label, nothing more:

{labels}

The text to classify:

```
{text}
```

Your response:
"""

In [18]:
params = {'prompt': [prompt0, prompt1]}

In [19]:
search = GridSearchCV(clf, param_grid=params, cv=2, scoring=['accuracy', 'neg_log_loss'], refit=False)

In [20]:
%time search.fit(X, labels)

Token indices sequence length is longer than the specified maximum sequence length for this model (844 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (885 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (844 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (885 > 512). Running this sequence through the model will result in indexing errors


CPU times: user 1min 55s, sys: 5.6 s, total: 2min 1s
Wall time: 34.8 s


grid search results:

In [21]:
pd.DataFrame(search.cv_results_)[['mean_test_accuracy', 'mean_test_neg_log_loss', 'param_prompt', 'mean_score_time']]

| | mean_test_accuracy | mean_test_neg_log_loss | param_prompt | mean_score_time |
|---|---|---|---|---|
| 0 | 0.86 | -0.287012 | You are a text classification assistant.\n\nTh... | 7.035766 |
| 1 | 0.93 | -0.246949 | Your task is to classify text.\n\nChoose the l... | 7.055968 |


**Conclusion**: `prompt1` is performing better. Mean test accuracy of 93% and log loss of 0.25 are pretty good, given that we use zero shot and don't perform any fine-tuning.
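Note that scikit-learn reports `neg_log_loss`, i.e. the log loss multiplied by -1, so that a higher score is always better. The short sketch below uses the two mean scores from the grid search results above to pick the better prompt and recover the positive log loss. The dictionary layout and variable names are only for illustration.

```python
# mean CV scores copied from the grid search results above
results = {
    'prompt0': {'accuracy': 0.86, 'neg_log_loss': -0.287012},
    'prompt1': {'accuracy': 0.93, 'neg_log_loss': -0.246949},
}

# higher is better for both scorers, because the log loss is negated
best = max(results, key=lambda name: results[name]['neg_log_loss'])
best_log_loss = -results[best]['neg_log_loss']
print(best, round(best_log_loss, 2))
```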

## few shot classification

In [22]:
from skorch.llm import FewShotClassifier
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

### train few shot classifier

Instead of passing the model name to initialize the classifier, as in `clf = FewShotClassifier('google/flan-t5-small')`, it is also possible to pass the model and tokenizer explicitly. This is a good option if you need more control over them. In our case, it amounts to the same result.

In [23]:
model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-small').to(device)
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-small')

Use `max_samples` samples from the training data for few shot prompting.

In [24]:
clf = FewShotClassifier(
    model=model, tokenizer=tokenizer, max_samples=5, use_caching=False
)

In [25]:
%time clf.fit(X[:5], labels[:5]);

CPU times: user 824 µs, sys: 43 µs, total: 867 µs
Wall time: 451 µs


Show what the prompt looks like:

In [26]:
print(clf.get_prompt(X[5]))

You are a text classification assistant.

Choose the label among the following possibilities with the highest probability.
Only return the label, nothing more:

['negative', 'positive']

Here are a few examples:

```
I liked how this started out, featuring some decent special-effects especially for a film 50 years old. There was some pretty impressive scenery. However, the film bogs down fairly early on with some very dumb dialog as the males all try to flirt with Anne Francis "Altaira Morbius.")<br /><br />Viewing this in the '90s after a long absence, it was fun to see Francis again, an actress who has done mostly television shows since this film was released....and is still acting. It also was interesting to see a young-looking Leslie Nielsen ("Dr. John J. Adams"), who I wouldn't have recognized had it not been for this voice <br /><br />I watched half of this movie before the boredom came almost overwhelming and I had a strong desire to go to sleep. I appreciated them re-doing this
```

[output truncated]

### evaluate

In [27]:
%time y_proba = clf.predict_proba(X)

Token indices sequence length is longer than the specified maximum sequence length for this model (1512 > 512). Running this sequence through the model will result in indexing errors


CPU times: user 58.3 s, sys: 2.66 s, total: 1min
Wall time: 15.3 s


In [28]:
log_loss(y, y_proba)

0.23766533962376823

In [29]:
y_pred = y_proba.argmax(1)

In [30]:
accuracy_score(y, y_pred)

0.9

In [31]:
clf.predict(["Even if paid $1000, I would not watch this movie again"])

array(['negative'], dtype='<U8')

### grid search best number of few shot samples

Note that grid search will split `X` and `y` for each run. Since the few shot samples are taken from `X` and `y`, they will differ between splits, which could have a big influence on the performance of the model. If you always want to have the same few shot samples in each split, you should craft your own prompt with those examples and then use it with `ZeroShotClassifier`. Just ensure that those examples are not part of the validation/test data!
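A hand-crafted prompt with fixed examples could look like the sketch below. The two example reviews and the helper name `make_fixed_few_shot_prompt` are made up for illustration; the `{text}` and `{labels}` placeholders follow the prompts used in the zero shot section. Braces inside the examples are doubled on the assumption that the template is later filled in with `str.format`-style formatting.

```python
def make_fixed_few_shot_prompt(examples):
    """Build a prompt template with hard-coded few shot examples."""
    parts = [
        "You are a text classification assistant.\n",
        "Choose the label among the following possibilities with the highest probability.",
        "Only return the label, nothing more:\n",
        "{labels}\n",
        "Here are a few examples:\n",
    ]
    for text, label in examples:
        # double the braces so they survive a later str.format call
        safe = text.replace("{", "{{").replace("}", "}}")
        parts.append(f"Text: {safe}\nLabel: {label}\n")
    parts.append("The text to classify:\n")
    parts.append("{text}\n")
    parts.append("Your response:\n")
    return "\n".join(parts)

# made-up examples; they must not overlap with the validation/test data
examples = [
    ("A wonderful film, I loved every minute.", "positive"),
    ("Dull plot and {wooden} acting.", "negative"),
]
fixed_prompt = make_fixed_few_shot_prompt(examples)

# fill in the template once to check how the final prompt would look
print(fixed_prompt.format(text="Great movie!", labels=["negative", "positive"]))
```

The resulting string could then be passed as `prompt=fixed_prompt` to `ZeroShotClassifier`, in the same way as `prompt0` and `prompt1` earlier in this notebook.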

In [32]:
params = {'max_samples': [3, 5, 7]}

In [33]:
search = GridSearchCV(clf, param_grid=params, cv=2, scoring=['accuracy', 'neg_log_loss'], refit=False)

In [34]:
%time search.fit(X, labels)

CPU times: user 7min 5s, sys: 18.7 s, total: 7min 24s
Wall time: 1min 51s


In [35]:
pd.DataFrame(search.cv_results_)[['mean_test_accuracy', 'mean_test_neg_log_loss', 'param_max_samples', 'mean_score_time']]

| | mean_test_accuracy | mean_test_neg_log_loss | param_max_samples | mean_score_time |
|---|---|---|---|---|
| 0 | 0.92 | -0.210638 | 3 | 14.588259 |
| 1 | 0.92 | -0.228785 | 5 | 18.326905 |
| 2 | 0.91 | -0.229507 | 7 | 22.358228 |


**Conclusion**: Accuracy is on par with the best zero shot prompt (0.92 vs. 0.93), while log loss improves moderately (0.21 vs. 0.25). Using more than 3 samples doesn't help here and makes scoring slower.

## Debugging

Working with LLMs can be difficult because it is hard to know for certain if the prompt works well and if the LLM is capable of classifying the input. For this reason, skorch provides a few options to help identify these issues.

### Returning unnormalized probabilities

By default, the model will normalize the probabilities to sum to 1. This is what is expected when calling `predict_proba`. However, this can hide underlying issues. In theory, the LLM can predict any token from its vocabulary, so there is no guarantee that it will choose one of the provided labels. skorch forces the LLM to use one of the labels, but it also tracks how much probability is assigned, or not assigned, to these labels.

To give an example, for a given input, it's possible that the LLM predicts a probability of 10% that the label is 'negative' and 70% that it is 'positive'. By default, we normalize the probabilities to sum to 1, i.e. we return 0.125 and 0.875. The problem is that we would return the same normalized probabilities if the model predicted only 1% and 7%. But if the model predicts such low probabilities, there is probably something wrong and we would like to know about it.
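The arithmetic of this example can be checked with a few lines of plain Python. This is only an illustration of the normalization step, not skorch's actual implementation; the helper name `normalize` is hypothetical.

```python
def normalize(probs):
    """Rescale label probabilities so that they sum to 1."""
    total = sum(probs)
    return [p / total for p in probs]

confident = [0.10, 0.70]  # the labels cover 80% of the probability mass
doubtful = [0.01, 0.07]   # the labels cover only 8% of the probability mass

# both rows normalize to (approximately) the same values ...
print(normalize(confident))
print(normalize(doubtful))

# ... but the unnormalized sums reveal the difference
print(sum(confident), sum(doubtful))
```

After normalization the two cases are indistinguishable, whereas the raw sums immediately show that the second prediction puts very little probability on the labels.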

For this reason, we added the option to disable the normalization of probabilities. Let's check how well our zero-shot flan-t5 model is doing without normalization:

In [36]:
clf = ZeroShotClassifier('google/flan-t5-small', use_caching=False, probas_sum_to_1=False)

In [37]:
clf.fit(X=None, y=['positive', 'negative']);

In [38]:
y_proba = clf.predict_proba(X[:3])

In [39]:
y_proba

array([[0.55589342, 0.43614349],
       [0.56059146, 0.43085057],
       [0.94313842, 0.04362511]])

In [40]:
y_proba.sum(1)

array([0.99203691, 0.99144202, 0.98676353])

As you can see, the probabilities returned by flan-t5 are quite high. Without normalization, they still sum up to ~99%, which is very good.

Now let's take a look at an LLM that doesn't work well for this task, GPT2:

In [41]:
# note that since GPT2 is a decoder-only language model, we don't need to set use_caching=False
clf = ZeroShotClassifier('gpt2', probas_sum_to_1=False)

In [42]:
clf.fit(X=None, y=['positive', 'negative']);

In [43]:
y_proba = clf.predict_proba(X[:3])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [44]:
y_proba

array([[3.86091414e-13, 1.38600503e-12],
       [2.50074673e-13, 8.08236941e-13],
       [3.82718732e-13, 1.23716673e-12]])

As we can see, the probabilities are really low, but if we had normalized them, we might not have noticed:

In [45]:
# normalize probabilities to sum up to 1
y_proba / y_proba.sum(1, keepdims=True)

array([[0.21787269, 0.78212731],
       [0.23629588, 0.76370412],
       [0.23626284, 0.76373716]])

This means we should probably use a different LLM or tinker with the prompt until we get better results.

### Specific actions when probabilities are low

We provide further options to identify low probabilities in a way that does not require inspecting them manually. For this, `ZeroShotClassifier` and `FewShotClassifier` accept two arguments.

The first argument is called `error_low_prob`. It should be one of the following strings: `'ignore'`, `'warn'`, `'raise'`, or `'return_none'`.

By default, it is `'ignore'`, which means that nothing happens, no matter how low the predicted probabilities are. With `'warn'`, a warning is issued when the total probability of at least one predicted sample is too low. Use this option if you want to get the result but be alerted about possible problems.

By passing `error_low_prob='raise'`, an error will be raised as soon as a sample with low total probabilities is encountered. This is useful if you want inference to stop immediately, instead of waiting for all predictions to be made.

Finally, you can set `error_low_prob='return_none'`. In this case, nothing changes when calling `predict_proba`. When calling `predict`, however, the probabilities for the samples will be checked and if they're too low, the prediction will be replaced by `None`. This is useful if the predictions are generally good, but some examples are, for one reason or another, hard to predict.

The second argument, which should be used in conjunction with `error_low_prob`, is called `threshold_low_prob`. It is a float between 0 and 1 that sets the value below which a probability is considered "low". Note that this value is compared to the _sum of the probabilities of all labels_ for a given sample. So with `threshold_low_prob=0.1`, a sample with a probability of 0.05 for 'negative' and 0.2 for 'positive' would be fine because the total, 0.25, exceeds 0.1.
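The interplay of the two arguments can be sketched in plain Python. This is a simplified illustration, not skorch's code: the helpers `check_low_prob` and `predict`, the error message, and the default values are hypothetical, and this sketch checks all rows at once instead of stopping at the first low-probability sample.

```python
import warnings

def check_low_prob(y_proba, error_low_prob='ignore', threshold_low_prob=0.0):
    """Hypothetical helper: flag rows whose total probability is too low."""
    # the threshold is compared to the row sum, not to individual labels
    is_low = [sum(row) < threshold_low_prob for row in y_proba]
    if any(is_low):
        if error_low_prob == 'raise':
            raise ValueError("total probability below threshold")
        if error_low_prob == 'warn':
            warnings.warn("total probability below threshold")
    return is_low

def predict(y_proba, classes, error_low_prob='ignore', threshold_low_prob=0.0):
    """Hypothetical helper: pick the most likely label for each row."""
    is_low = check_low_prob(y_proba, error_low_prob, threshold_low_prob)
    preds = [classes[max(range(len(row)), key=row.__getitem__)] for row in y_proba]
    if error_low_prob == 'return_none':
        # replace predictions of low-probability rows by None
        preds = [None if low else pred for low, pred in zip(is_low, preds)]
    return preds

classes = ['negative', 'positive']
# first row sums to 0.25 (fine), second row sums to 0.03 (too low)
y_proba = [[0.05, 0.2], [0.01, 0.02]]
print(predict(y_proba, classes, 'return_none', threshold_low_prob=0.1))
```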

Let's see how this works in practice by using the option to raise an error and setting the threshold to 0.5:

In [46]:
# note that since GPT2 is a decoder-only language model, we don't need to set use_caching=False
clf = ZeroShotClassifier('gpt2', error_low_prob='raise', threshold_low_prob=0.5)

In [47]:
clf.fit(X=None, y=['positive', 'negative']);

In [48]:
try:
    clf.predict_proba(X[:3])
except Exception as exc:
    print("There was an error:", exc)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


There was an error: The sum of all probabilities is 0.000, which is below the minimum threshold of 0.500


As you can see, we indeed got an error, alerting us immediately to potential issues.

## Testing MNLI

An established alternative for zero shot classification is natural language inference (NLI). Here we compare our results to https://huggingface.co/facebook/bart-large-mnli, which is the most used zero shot classifier on Hugging Face.

In [49]:
from transformers import pipeline

In [50]:
classifier = pipeline('zero-shot-classification', model='facebook/bart-large-mnli', device=device)

In [51]:
%time preds = classifier(imdb['train'][:100]['text'], ['negative', 'positive'])

CPU times: user 12.8 s, sys: 3.92 ms, total: 12.8 s
Wall time: 12.8 s


In [52]:
y_proba = np.vstack([p['scores'] if p['labels'] == ['negative', 'positive'] else p['scores'][::-1] for p in preds])
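The reordering in the cell above is needed because the pipeline returns, for each input, a dict whose `labels` are sorted by descending score rather than in the order in which we passed them. The sketch below reproduces the same logic on mocked pipeline outputs; the texts and score values are made up for illustration.

```python
# mocked outputs in the format returned by the zero-shot pipeline
preds = [
    {'sequence': 'great film', 'labels': ['positive', 'negative'], 'scores': [0.9, 0.1]},
    {'sequence': 'awful film', 'labels': ['negative', 'positive'], 'scores': [0.8, 0.2]},
]

# bring every row into the fixed column order (negative, positive)
target_order = ['negative', 'positive']
y_proba = [
    p['scores'] if p['labels'] == target_order else p['scores'][::-1]
    for p in preds
]
print(y_proba)
```

Reversing the scores is only sufficient because there are exactly two labels; with more labels, the scores would have to be looked up by label name instead.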

In [53]:
accuracy_score(y, y_proba.argmax(1))

0.84

In [54]:
log_loss(y, y_proba)

0.3443705626436628

**Conclusion**: This model is slower than the zero shot classifier tested above (12.8 s vs. 7.4 s wall time for the same 100 samples), less flexible (we cannot adjust the prompt or other parameters), and it performs worse.

## Testing vanilla ML

Use a standard TF-IDF + logistic regression baseline.

In [55]:
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

In [56]:
tfidf = Pipeline([
    ('tfidf', TfidfVectorizer()),
    ('clf', LogisticRegression()),
])

In [57]:
params = {'tfidf__max_features': [500, 1000, 2000], 'tfidf__ngram_range': [(1, 1), (1, 2), (2, 2), (1, 3)]}

In [58]:
search = GridSearchCV(
    tfidf, param_grid=params, cv=2, scoring=['accuracy', 'neg_log_loss'], refit=False
)

In [59]:
%time search.fit(X, y)

CPU times: user 1.18 s, sys: 3.95 ms, total: 1.18 s
Wall time: 1.18 s


The table is quite big, so let's look at the 5 rows with the best log loss:

In [60]:
cols = ['mean_test_accuracy', 'mean_test_neg_log_loss', 'param_tfidf__max_features', 'param_tfidf__ngram_range']
pd.DataFrame(search.cv_results_)[cols].sort_values('mean_test_neg_log_loss', ascending=False).head()

| | mean_test_accuracy | mean_test_neg_log_loss | param_tfidf__max_features | param_tfidf__ngram_range |
|---|---|---|---|---|
| 3 | 0.69 | -0.662397 | 500 | (1, 3) |
| 7 | 0.71 | -0.663959 | 1000 | (1, 3) |
| 1 | 0.68 | -0.664004 | 500 | (1, 2) |
| 5 | 0.7 | -0.664215 | 1000 | (1, 2) |
| 0 | 0.65 | -0.664609 | 500 | (1, 1) |


**Conclusion**: This classical model is much faster, even if we include the training time, because it is much smaller than an LLM. However, its scores are also much worse, given the small dataset. If speed is no concern, using an LLM classifier would thus be a good option for this task.