# Using LLMs as text classifiers with an sklearn interface

<table align="left"><td>
<a target="_blank" href="https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/llm-classifier-demo.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>  
</td><td>
<a target="_blank" href="https://github.com/skorch-dev/skorch/blob/master/notebooks/llm-classifier-demo.ipynb"><img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a></td></table>

The first part of the notebook requires Hugging Face `transformers` and `datasets` as additional dependencies. If you have not installed them yet, you can do so like this:

`python -m pip install transformers datasets`

In [1]:
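import subprocess
import sys

# Installation on Google Colab
try:
    import google.colab
    # Use the running interpreter so the packages land in the active environment
    subprocess.run([sys.executable, '-m', 'pip', 'install', 'skorch', 'transformers', 'datasets'], check=True)
except ImportError:
    pass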

## imports

In [2]:
import datasets
import numpy as np
import pandas as pd
import transformers
import torch
from sklearn.metrics import accuracy_score, log_loss
from sklearn.model_selection import GridSearchCV

In [3]:
transformers.logging.set_verbosity_error()
datasets.logging.set_verbosity_error()

In [4]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

## load data

Load sentiment dataset

In [5]:
imdb = datasets.load_dataset('imdb').shuffle(seed=0)

Limit the data to 100 samples, since zero shot and few shot learning mostly make sense when very few labeled samples are available.

In [6]:
X = imdb['train'][:100]['text']
y = imdb['train'][:100]['label']

In [7]:
labels = np.array(['negative', 'positive'])[y]
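
The indexing in the cell above maps each integer label to its class name. A minimal sketch with made-up label values (the toy list `y_toy` is an assumption for illustration and is not taken from the dataset):

```python
import numpy as np

# Hypothetical integer labels; 0 stands for negative, 1 for positive
y_toy = [1, 0, 0, 1]

# Indexing an array of class names with a list of integers picks one name per label
class_names = np.array(['negative', 'positive'])
labels_toy = class_names[y_toy]

print(labels_toy)
```

The string labels are what the grid searches further below receive as targets, while the integer labels `y` are used for the metrics computed directly on `predict_proba`.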

## zero shot classification

In [8]:
from skorch.llm import ZeroShotClassifier

### "train" zero shot classifier

For this notebook, we use a small LLM, `flan-t5-small`.

In [9]:
clf = ZeroShotClassifier(
    'google/flan-t5-small', generate_kwargs={'max_length': 512}, device=device, use_caching=False
)

In [10]:
%time clf.fit(X=None, y=['positive', 'negative']);

CPU times: user 2.93 s, sys: 590 ms, total: 3.52 s
Wall time: 3.27 s


In general, fitting is fast because the model weights are not updated. If the LLM is not cached locally, however, it is downloaded from Hugging Face first, which may take some time.

### evaluate

In [11]:
%time y_proba = clf.predict_proba(X)

CPU times: user 28.4 s, sys: 1.37 s, total: 29.7 s
Wall time: 7.85 s


In [12]:
log_loss(y, y_proba)

0.3767035707377107
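
To make the metric more concrete: log loss is the mean negative logarithm of the probability that the model assigned to the true class. The following is a simplified sketch with made-up probabilities; the helper `binary_log_loss` is hypothetical and does not handle probabilities of exactly 0, for which the logarithm is undefined.

```python
import math

def binary_log_loss(y_true, y_proba):
    """Mean negative log of the probability assigned to the true class."""
    losses = []
    for label, probs in zip(y_true, y_proba):
        # probs[label] is the predicted probability of the correct class
        losses.append(-math.log(probs[label]))
    return sum(losses) / len(losses)

# Hypothetical predictions for three samples: columns are P(negative), P(positive)
y_toy = [0, 1, 1]
proba_toy = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]]

print(binary_log_loss(y_toy, proba_toy))
```

Confident correct predictions contribute values close to 0, whereas the third toy sample, where the true class only receives a probability of 0.4, dominates the average.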

In [13]:
y_pred = y_proba.argmax(1)

In [14]:
accuracy_score(y, y_pred)

0.83

In [15]:
clf.predict(["A masterpiece, instant classic, 5 stars out of 5"])

array(['positive'], dtype='<U8')

### grid search the prompt

In [16]:
prompt0 = """You are a text classification assistant.

The text to classify:

```
{text}
```

Choose the label among the following possibilities with the highest probability.
Only return the label, nothing more:

{labels}

Your response:
"""

In [17]:
prompt1 = """Your task is to classify text.

Choose the label among the following possibilities with the highest probability.
Only return the label, nothing more:

{labels}

The text to classify:

```
{text}
```

Your response:
"""

In [18]:
params = {'prompt': [prompt0, prompt1]}

In [19]:
search = GridSearchCV(clf, param_grid=params, cv=2, scoring=['accuracy', 'neg_log_loss'], refit=False)

In [20]:
%time search.fit(X, labels)

CPU times: user 1min 53s, sys: 5.32 s, total: 1min 58s
Wall time: 33.7 s


grid search results:

In [21]:
pd.DataFrame(search.cv_results_)[['mean_test_accuracy', 'mean_test_neg_log_loss', 'param_prompt', 'mean_score_time']]

|   | mean_test_accuracy | mean_test_neg_log_loss | param_prompt | mean_score_time |
|---|---|---|---|---|
| 0 | 0.87 | -0.296063 | You are a text classification assistant.\n\nTh... | 6.847503 |
| 1 | 0.93 | -0.246425 | Your task is to classify text.\n\nChoose the l... | 6.929659 |


**Conclusion**: `prompt1` is performing better. Mean test accuracy of 93% and log loss of 0.25 are pretty good, given that we use zero shot and don't perform any fine-tuning.
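
Because the search uses several metrics together with `refit=False`, scikit-learn should not populate attributes such as `best_params_`, so the winner has to be read from `cv_results_`. A minimal sketch of that lookup, using a plain dict that mirrors the results table above (the short names `prompt0` and `prompt1` replace the full prompt strings for readability):

```python
# Values copied from the results table above, laid out like `search.cv_results_`
cv_results = {
    'param_prompt': ['prompt0', 'prompt1'],  # short names instead of the full prompt text
    'mean_test_accuracy': [0.87, 0.93],
    'mean_test_neg_log_loss': [-0.296063, -0.246425],
}

# For neg_log_loss, higher is better, so take the index of the maximum
scores = cv_results['mean_test_neg_log_loss']
best_idx = max(range(len(scores)), key=lambda i: scores[i])

print(cv_results['param_prompt'][best_idx], cv_results['mean_test_accuracy'][best_idx])
```

With the real `search.cv_results_`, the same indexing would return the full prompt string of the best candidate.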

## few shot classification

In [22]:
from skorch.llm import FewShotClassifier
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

### train few shot classifier

Instead of passing the model name to initialize the classifier, as in `clf = FewShotClassifier('google/flan-t5-small')`, it is also possible to pass the model and tokenizer explicitly. This is a good option if you need more control over them. In our case, it amounts to the same result.

In [23]:
model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-small').to(device)
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-small')

Use `max_samples` samples from the training data for few shot prompting.

In [24]:
clf = FewShotClassifier(
    model=model, tokenizer=tokenizer, max_samples=5, generate_kwargs={'max_length': 512}, use_caching=False
)

In [25]:
%time clf.fit(X[:5], labels[:5]);

CPU times: user 815 µs, sys: 42 µs, total: 857 µs
Wall time: 386 µs


Show what the prompt looks like:

In [26]:
print(clf.get_prompt(X[5]))

You are a text classification assistant.

Choose the label among the following possibilities with the highest probability.
Only return the label, nothing more:

['negative' 'positive']

Here are a few examples:

```
I watch lots of scary movies (or at least they try to be) and this has to be the worst if not 2nd worst movie I have ever had to make myself try to sit through. I never knew the depths of Masacism until I rented this piece of moldy cheese covered in a used latex contraceptive. I am a fan of Julian Sans, but this is worse than I would hope for him.<br /><br />On the other hand the story was promising and I was intrigued...for the first minute and a half while the credits rolled and I had yet to see what pain looked like first hand. Perhaps there are some viewers out there that enjoyed this and can point me in the right direction, but then again I know of those viewers who understand if not commemorate me, especially when we had to turn the video off, and that simply is NOT d

### evaluate

In [27]:
%time y_proba = clf.predict_proba(X)

CPU times: user 1min 2s, sys: 2.77 s, total: 1min 5s
Wall time: 16.4 s


In [28]:
log_loss(y, y_proba)

0.23828560002899238

In [29]:
y_pred = y_proba.argmax(1)

In [30]:
accuracy_score(y, y_pred)

0.91

In [31]:
clf.predict(["Even if paid $1000, I would not watch this movie again"])

array(['negative'], dtype='<U8')

### grid search best number of few shot samples

Note that grid search will split `X` and `y` for each run. Since the few shot samples are taken from `X` and `y`, they will differ between splits, which could have a big influence on the performance of the model. If you always want to have the same few shot samples in each split, you should craft your own prompt with those examples and then use it with `ZeroShotClassifier`. Just ensure that those examples are not part of the validation/test data!
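
One way to craft such a prompt is sketched below. The helper `build_fixed_few_shot_prompt` and the two example reviews are hypothetical, and the sketch assumes that the classifier later fills `{text}` and `{labels}` with `str.format`-style substitution, which is why braces inside the examples are doubled.

```python
def build_fixed_few_shot_prompt(examples):
    """Build a prompt template with hard-coded examples (hypothetical helper).

    `examples` is a list of (text, label) pairs. The returned string still
    contains the {text} and {labels} placeholders.
    """
    parts = [
        "You are a text classification assistant.",
        "",
        "Choose the label among the following possibilities with the highest probability.",
        "Only return the label, nothing more:",
        "",
        "{labels}",
        "",
        "Here are a few examples:",
        "",
    ]
    for text, label in examples:
        # Double any braces so that a later str.format call leaves the example intact
        safe_text = text.replace("{", "{{").replace("}", "}}")
        parts += [f"Text: {safe_text}", f"Label: {label}", ""]
    parts += ["The text to classify:", "", "{text}", "", "Your response:", ""]
    return "\n".join(parts)


# Made-up examples; in practice, pick them from data outside the validation/test set
fixed_examples = [
    ("Great acting and a clever plot {10/10}", "positive"),
    ("Dull, predictable and far too long", "negative"),
]
fixed_prompt = build_fixed_few_shot_prompt(fixed_examples)
print(fixed_prompt.format(text="I would watch it again", labels=['negative', 'positive']))
```

The resulting string could then be passed as the `prompt` argument of `ZeroShotClassifier`, the same parameter that the prompt grid search varied earlier.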

In [32]:
params = {'max_samples': [3, 5, 7]}

In [33]:
search = GridSearchCV(clf, param_grid=params, cv=2, scoring=['accuracy', 'neg_log_loss'], refit=False)

In [34]:
%time search.fit(X, labels)

CPU times: user 10min 17s, sys: 27 s, total: 10min 44s
Wall time: 2min 53s


In [35]:
pd.DataFrame(search.cv_results_)[['mean_test_accuracy', 'mean_test_neg_log_loss', 'param_max_samples', 'mean_score_time']]

|   | mean_test_accuracy | mean_test_neg_log_loss | param_max_samples | mean_score_time |
|---|---|---|---|---|
| 0 | 0.92 | -0.227133 | 3 | 14.748633 |
| 1 | 0.92 | -0.231064 | 5 | 37.663565 |
| 2 | 0.91 | -0.237118 | 7 | 33.848601 |


**Conclusion**: Compared to zero shot, there is no significant change in accuracy but a modest improvement in log loss. Using more than 3 samples doesn't seem to help, and it increases the scoring time.

## Testing MNLI

An established alternative for zero shot classification is natural language inference (NLI). We compare our results to https://huggingface.co/facebook/bart-large-mnli, which is the most used zero shot classifier on Hugging Face.
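
As a rough illustration of the NLI idea: each candidate label is turned into a hypothesis sentence, the model scores how strongly the text entails each hypothesis, and the entailment scores are normalized across labels. The sketch below uses made-up logits instead of a real model, and the hypothesis template is assumed to match the pipeline's default.

```python
import math

def softmax(values):
    """Numerically stable softmax over a list of floats."""
    largest = max(values)
    exps = [math.exp(v - largest) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

candidate_labels = ['negative', 'positive']
premise = "A masterpiece, instant classic, 5 stars out of 5"

# Each label becomes a hypothesis that the NLI model checks against the premise.
# Assumption: this template matches the default used by the pipeline.
hypotheses = [f"This example is {label}." for label in candidate_labels]

# Hypothetical entailment logits, one per (premise, hypothesis) pair
entailment_logits = [-1.5, 3.0]

scores = softmax(entailment_logits)
for hypothesis, score in zip(hypotheses, scores):
    print(f"{hypothesis!r}: {score:.3f}")
```

The label whose hypothesis receives the highest normalized score is the prediction.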

In [36]:
from transformers import pipeline

In [37]:
classifier = pipeline('zero-shot-classification', model='facebook/bart-large-mnli', device=device)

In [38]:
%%time
y_probas = []
for x in X:
    output = classifier(x, ['negative', 'positive'])
    if output['labels'] == ['negative', 'positive']:
        y_probas.append(output['scores'])
    else:
        y_probas.append(output['scores'][::-1])



CPU times: user 13.7 s, sys: 7.61 ms, total: 13.7 s
Wall time: 13.7 s
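
The loop above restores a fixed column order because the order of the returned labels can differ from sample to sample. An alternative sketch that looks up each score by its label and would also work for more than two classes (the helper `scores_in_class_order` and the toy output are hypothetical):

```python
def scores_in_class_order(output, classes):
    """Return the scores arranged in the order given by `classes`."""
    # Look up each score by its label instead of relying on the returned order
    score_by_label = dict(zip(output['labels'], output['scores']))
    return [score_by_label[c] for c in classes]

# Hypothetical pipeline output in which 'positive' is listed first
output_toy = {
    'sequence': 'some review text',
    'labels': ['positive', 'negative'],
    'scores': [0.97, 0.03],
}

print(scores_in_class_order(output_toy, ['negative', 'positive']))
```

Keeping the columns in the order negative, positive matters because `log_loss` and `argmax` are later compared against the integer labels in `y`.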


In [39]:
y_proba = np.vstack(y_probas)

In [40]:
accuracy_score(y, y_proba.argmax(1))

0.84

In [41]:
log_loss(y, y_proba)

0.3443705626436628

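**Conclusion**: This model is slower than the tested zero shot classifier (13.7 s vs. 7.85 s wall time for the same 100 samples) and less flexible (we cannot adjust the prompt or other parameters). Its scores (84% accuracy, log loss of 0.34) are close to those of the default zero shot prompt but worse than those of `prompt1` and of the few shot classifier.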

## Testing vanilla ML

Use a standard TF-IDF + logistic regression benchmark.

In [42]:
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

In [43]:
tfidf = Pipeline([
    ('tfidf', TfidfVectorizer()),
    ('clf', LogisticRegression()),
])

In [44]:
params = {'tfidf__max_features': [500, 1000, 2000], 'tfidf__ngram_range': [(1, 1), (1, 2), (2, 2), (1, 3)]}

In [45]:
search = GridSearchCV(
    tfidf, param_grid=params, cv=2, scoring=['accuracy', 'neg_log_loss'], refit=False
)

In [46]:
%time search.fit(X, y)

CPU times: user 1.38 s, sys: 0 ns, total: 1.38 s
Wall time: 1.38 s


The table is quite big, so let's look only at the 5 best log losses:

In [47]:
cols = ['mean_test_accuracy', 'mean_test_neg_log_loss', 'param_tfidf__max_features', 'param_tfidf__ngram_range']
pd.DataFrame(search.cv_results_)[cols].sort_values('mean_test_neg_log_loss', ascending=False).head()

|   | mean_test_accuracy | mean_test_neg_log_loss | param_tfidf__max_features | param_tfidf__ngram_range |
|---|---|---|---|---|
| 3 | 0.69 | -0.662397 | 500 | (1, 3) |
| 7 | 0.71 | -0.663959 | 1000 | (1, 3) |
| 1 | 0.68 | -0.664004 | 500 | (1, 2) |
| 5 | 0.7 | -0.664215 | 1000 | (1, 2) |
| 0 | 0.65 | -0.664609 | 500 | (1, 1) |


**Conclusion**: This classical model is much faster, even if we include the training time, because it is much smaller than an LLM. However, its scores are also much worse, given the small dataset. If speed is no concern, using an LLM classifier would thus be a good option for this task.