
# Zero-Shot (On-the-Fly) Learning with spaCy, sentence-transformers, Canonical Correlation Analysis, and Particle Swarm Optimization

In this notebook we'll define a workflow that embeds both sentences and words in the same embedding space. Doing so lets us compute similarity scores in both directions between a word embedding model and a transformer-based sentence embedding model. The use case: given a sentence and its embedding, which word embeddings are most similar (closest) to it? Those words can then be thought of as the "topics" of that sentence.

In a zero-shot context, we can restrict the candidate word embeddings to the "labels" we want to predict. For this example, we use the [ag_news](https://huggingface.co/datasets/ag_news) dataset and predict its four class labels: `["World", "Sports", "Business", "Sci/Tech"]`.
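
To make the "labels as word embeddings" idea concrete, here is a minimal sketch of the classification rule on toy data. It assumes the document and label vectors already live in the same space (the job CCA does later in the notebook); the function name `zero_shot_predict` and the 3-d vectors are made up for illustration.

```python
import numpy as np

def zero_shot_predict(doc_vecs, label_vecs):
    """Return, for each document, the index of the most similar label vector."""
    # L2-normalise rows so that a dot product equals cosine similarity
    docs = doc_vecs / np.linalg.norm(doc_vecs, axis=1, keepdims=True)
    labels = label_vecs / np.linalg.norm(label_vecs, axis=1, keepdims=True)
    sims = docs @ labels.T  # shape: (n_docs, n_labels)
    return sims.argmax(axis=1), sims

# Toy vectors standing in for embeddings that already share one space
label_vecs = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
doc_vecs = np.array([[0.9, 0.1, 0.0], [0.2, 0.8, 0.1], [0.6, 0.5, 0.0]])

pred, sims = zero_shot_predict(doc_vecs, label_vecs)
print(pred)
```

No labelled training data is involved: the only "supervision" is the choice of label words.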

To project both embedding types into the same space we'll use an approach called Canonical Correlation Analysis (CCA). You can think of CCA as Principal Component Analysis's evil twin: instead of finding an orthogonal linear representation that maximises variance, as PCA does, CCA finds a representation that maximises the correlation between two "views" of the same data and then embeds both views in the same space. That means word embeddings and document embeddings together in harmony.
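
For intuition, here is a small NumPy sketch of regularised linear CCA on synthetic data. It is one textbook formulation (whiten each view, then take an SVD of the cross-covariance), not necessarily how mvlearn implements it; the names `fit_cca` and `inv_sqrt` and the `reg` default are assumptions made for this example.

```python
import numpy as np

def inv_sqrt(mat):
    """Inverse matrix square root of a symmetric positive-definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    return vecs @ np.diag(1.0 / np.sqrt(vals)) @ vecs.T

def fit_cca(X, Y, n_components, reg=1e-3):
    """Return per-view (mean, projection) pairs and the canonical correlations."""
    mx, my = X.mean(axis=0), Y.mean(axis=0)
    Xc, Yc = X - mx, Y - my
    n = X.shape[0]
    # Regularised within-view covariances and the cross-covariance
    Cxx = Xc.T @ Xc / (n - 1) + reg * np.eye(X.shape[1])
    Cyy = Yc.T @ Yc / (n - 1) + reg * np.eye(Y.shape[1])
    Cxy = Xc.T @ Yc / (n - 1)
    # Whiten each view, then find the most correlated directions via SVD
    Wx, Wy = inv_sqrt(Cxx), inv_sqrt(Cyy)
    U, S, Vt = np.linalg.svd(Wx @ Cxy @ Wy)
    A = Wx @ U[:, :n_components]
    B = Wy @ Vt.T[:, :n_components]
    return (mx, A), (my, B), S[:n_components]

# Two synthetic "views" (6-d and 4-d) driven by the same 2-d latent signal
rng = np.random.default_rng(0)
latent = rng.normal(size=(500, 2))
X = latent @ rng.normal(size=(2, 6)) + 0.1 * rng.normal(size=(500, 6))
Y = latent @ rng.normal(size=(2, 4)) + 0.1 * rng.normal(size=(500, 4))

(mx, A), (my, B), corrs = fit_cca(X, Y, n_components=2)
Zx, Zy = (X - mx) @ A, (Y - my) @ B  # both views in the shared space
print(corrs)
```

Here `Zx` and `Zy` play the role that the projected sentence vectors and projected word vectors play in the rest of the notebook: once both views are projected, rows from either view can be compared directly.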

CCA does require some parameter tuning, so we'll try something a bit different and use Particle Swarm Optimization to locate the parameter set that gives us the least error on the test set.

**Covered in this notebook**
- Loading a model from huggingface into [spaCy](https://spacy.io/) v3
- Adding a custom document vector attribute in spaCy
- Loading word embeddings with [gensim](https://github.com/RaRe-Technologies/gensim-data)
- Canonical Correlation Analysis with [mvlearn](https://mvlearn.github.io/)
- Particle Swarm Optimization for CCA "grid search" with [optunity](https://optunity.readthedocs.io/en/latest/index.html)
- Zero-shot learning
- Error analysis for zero-shot learning

In [None]:
!pip install -U spacy[cuda110,transformers] mvlearn datasets gensim optunity scipy

In [None]:
import spacy
spacy.require_gpu()

True

In [None]:
from copy import copy
from pathlib import Path
from pprint import pprint

import gensim.downloader as api
import numpy as np
import optunity
from datasets import load_dataset
from mvlearn.embed import CCA
from sklearn.metrics import accuracy_score, classification_report
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import cupy as cp

In [None]:
from spacy_transformers.pipeline_component import DEFAULT_CONFIG
DEFAULT_CONFIG["transformer"]

{'max_batch_items': 4096,
 'model': {'@architectures': 'spacy-transformers.TransformerModel.v1',
  'get_spans': {'@span_getters': 'spacy-transformers.strided_spans.v1',
   'stride': 96,
   'window': 128},
  'name': 'roberta-base',
  'tokenizer_config': {'use_fast': True}},
 'set_extra_annotations': {'@annotation_setters': 'spacy-transformers.null_annotation_setter.v1'}}

In [None]:
# Create a config that uses a model from the huggingface model hub
# deepcopy so that editing the nested "model" dict doesn't mutate DEFAULT_CONFIG
from copy import deepcopy
config = deepcopy(DEFAULT_CONFIG["transformer"])
# Could use a larger or different model - we'll use this because it's faster
config["model"]["name"] = "sentence-transformers/stsb-distilbert-base"

# Create a document annotation that stores the mean-pooled document vector
spacy.tokens.Doc.set_extension("trf_vector", force=True, default=None)
def transformer_document_vector(docs, trf_data):
    doc_data = list(trf_data.doc_data)
    for doc, data in zip(docs, doc_data):
        # tensors[0] is the last hidden state; its final axis is the hidden width.
        # Flatten every other axis into rows and average them into one vector.
        hidden = data.tensors[0]
        doc._.trf_vector = hidden.reshape(-1, hidden.shape[-1]).mean(axis=0)

nlp = spacy.blank("en")
transformer = nlp.add_pipe("transformer", config=config)
nlp.get_pipe("transformer").set_extra_annotations = transformer_document_vector

transformer.model.initialize([nlp.make_doc("hello world")])
doc = nlp("This is a sentence.")
assert doc._.trf_vector is not None

In [None]:
# could use `word2vec-google-news-300` for potentially better results, but the model is huge
w2v_model = api.load("glove-wiki-gigaword-300")



In [None]:
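# gensim >= 4.0 exposes the frequency-ordered vocabulary as `index_to_key`
# (the `.vocab` dict was removed in 4.0; 3.x releases used `index2word`)
vocab = w2v_model.index_to_key[:20000]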
len(vocab)

20000

In [None]:
w2v_view = w2v_model[vocab]
st_view = cp.array([d._.trf_vector for d in tqdm(nlp.pipe(vocab), total=len(vocab))])
st_view = cp.asnumpy(st_view)

100%|██████████| 20000/20000 [00:07<00:00, 2668.39it/s]


In [None]:
agnews = load_dataset("ag_news")
N_TRAIN = 25000  # reduce time to embed
X_train_raw, y_train_raw = (
    agnews["train"]["text"][:N_TRAIN],
    agnews["train"]["label"][:N_TRAIN],
)
X_test_raw, y_test_raw = agnews["test"]["text"], agnews["test"]["label"]


X_train = cp.array(
    [d._.trf_vector for d in tqdm(nlp.pipe(X_train_raw), total=len(X_train_raw))]
)
X_test = cp.array(
    [d._.trf_vector for d in tqdm(nlp.pipe(X_test_raw), total=len(X_test_raw))]
)

# numpy for CCA
X_train = cp.asnumpy(X_train)
X_test = cp.asnumpy(X_test)

Using custom data configuration default



Downloading and preparing dataset ag_news/default (download: 29.88 MiB, generated: 30.23 MiB, post-processed: Unknown size, total: 60.10 MiB) to /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a...


Dataset ag_news downloaded and prepared to /root/.cache/huggingface/datasets/ag_news/default/0.0.0/fb5c5e74a110037311ef5e904583ce9f8b9fbc1354290f97b4929f01b3f48b1a. Subsequent calls will reuse this data.


100%|██████████| 25000/25000 [01:31<00:00, 274.31it/s]
100%|██████████| 7600/7600 [00:27<00:00, 277.39it/s]


## Optimizing CCA parameters with Particle Swarm Optimization

Below we'll use Particle Swarm Optimization (PSO) to search for good values of three settings: the CCA `n_components`, the CCA regularization strength (`regs`), and the vocabulary size used to fit CCA.

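In general, PSO minimizes a function through "intelligent brute force": a population of candidate parameter sets (the particles) moves through the search space, with each particle pulled toward the best point it has seen so far and toward the best point any particle has seen. Any time you want to find the min/max of a function over some bounded parameters, it's possible PSO could help you out.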
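
To show the mechanics, here is a bare-bones PSO in NumPy minimizing a simple 2-d quadratic. This is a generic textbook variant, not optunity's implementation; the function name `particle_swarm_minimize` and the `inertia`, `c_personal`, and `c_global` defaults are assumptions chosen for this sketch.

```python
import numpy as np

def particle_swarm_minimize(f, bounds, n_particles=30, n_iters=100, seed=0,
                            inertia=0.7, c_personal=1.5, c_global=1.5):
    """Minimise f over box bounds given as [(low, high), ...]."""
    rng = np.random.default_rng(seed)
    lo, hi = np.array(bounds, dtype=float).T
    pos = rng.uniform(lo, hi, size=(n_particles, len(bounds)))
    vel = np.zeros_like(pos)
    best_pos = pos.copy()                       # best point seen by each particle
    best_val = np.array([f(p) for p in pos])
    g = best_pos[best_val.argmin()].copy()      # best point seen by the swarm
    for _ in range(n_iters):
        r1, r2 = rng.random(pos.shape), rng.random(pos.shape)
        # keep some momentum, then pull toward personal and global bests
        vel = (inertia * vel
               + c_personal * r1 * (best_pos - pos)
               + c_global * r2 * (g - pos))
        pos = np.clip(pos + vel, lo, hi)        # stay inside the bounds
        vals = np.array([f(p) for p in pos])
        improved = vals < best_val
        best_pos[improved] = pos[improved]
        best_val[improved] = vals[improved]
        g = best_pos[best_val.argmin()].copy()
    return g, best_val.min()

# Toy objective with its minimum at (3, -1)
best, value = particle_swarm_minimize(
    lambda p: (p[0] - 3) ** 2 + (p[1] + 1) ** 2,
    bounds=[(-10, 10), (-10, 10)],
)
print(best, value)
```

The objective is only ever evaluated, never differentiated, which is why the same idea works for something like "zero-shot error as a function of CCA settings".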

In [None]:
# no "sci/tech" in vocab, so we'll just use "technology"
# these align with classes 0, 1, 2, 3
topics = ["global", "sports", "business", "technology"]

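def cca_optimize(n_components, regs, vocab_size):
    """Objective for the optimizer: zero-shot error rate for one CCA setting."""
    # the solver proposes floats, so cast the integer-valued parameters
    cca = CCA(n_components=int(n_components), regs=regs)
    vocab_size = int(vocab_size)
    # view 0 = transformer vectors, view 1 = word vectors, for the same words
    cca.fit([st_view[:vocab_size,:], w2v_view[:vocab_size,:]])
    sims = cosine_similarity(
        cca.transform_view(X_test, 0),
        cca.transform_view(w2v_model[topics], 1),
    )
    # predicted class = index of the most similar topic word
    max_sims = sims.argmax(axis=1).tolist()
    error = 1 - accuracy_score(y_test_raw, max_sims)
    return error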

In [None]:
n_components_bounds = [5, 300]
regs_bounds = [0.01, 1.0]
vocab_size_bounds = [5000, 20000]

# optunity's default solver is particle swarm optimization; we name it explicitly for clarity
pars, details, _ = optunity.minimize(
    cca_optimize,
    num_evals=250,
    solver_name="particle swarm",
    pmap=optunity.pmap,
    n_components=n_components_bounds,
    regs=regs_bounds,
    vocab_size=vocab_size_bounds,
)

In [None]:
print(pars, details.optimum)

{'n_components': 26.817687912734208, 'regs': 0.5694625256434449, 'vocab_size': 13391.526442307691} 0.47355263157894734


In [None]:
n_components = int(pars['n_components'])
regs = pars['regs']
vocab_size = int(pars['vocab_size'])

cca = CCA(n_components=n_components, regs=regs)
cca.fit([st_view[:vocab_size,:], w2v_view[:vocab_size,:]])

sims = cosine_similarity(
    cca.transform_view(X_test, 0),
    cca.transform_view(w2v_model[topics], 1),
)
max_sim = sims.argmax(axis=1).tolist()

print(classification_report(y_test_raw, max_sim))

              precision    recall  f1-score   support

           0       0.47      0.45      0.46      1900
           1       0.61      0.65      0.63      1900
           2       0.57      0.20      0.30      1900
           3       0.49      0.80      0.61      1900

    accuracy                           0.53      7600
   macro avg       0.54      0.53      0.50      7600
weighted avg       0.54      0.53      0.50      7600



## Exploring groups of responses

One other thing we can do is find the mean (centroid) vector for a group of responses, then see which word vectors are closest to that centroid. This gives us a way of understanding which keywords that set of sentences is about.

We'll use this to do some error analysis of responses.
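
The centroid idea can be sketched as follows on toy data. The function name `nearest_words`, the word list, and the 2-d vectors are made up for illustration; in this notebook the inputs would be the CCA-projected document and word vectors.

```python
import numpy as np

def nearest_words(doc_vecs, word_vecs, words, top_k=3):
    """Rank words by cosine similarity to the centroid of a group of documents."""
    centroid = doc_vecs.mean(axis=0)
    centroid = centroid / np.linalg.norm(centroid)
    word_unit = word_vecs / np.linalg.norm(word_vecs, axis=1, keepdims=True)
    sims = word_unit @ centroid                 # one similarity per word
    order = np.argsort(sims)[::-1][:top_k]      # most similar first
    return [(words[i], float(sims[i])) for i in order]

# Toy vocabulary and a group of three "documents" in a shared 2-d space
words = ["merger", "golf", "stocks", "goal"]
word_vecs = np.array([[1.0, 0.1], [0.1, 1.0], [0.9, 0.2], [0.2, 0.9]])
doc_vecs = np.array([[0.8, 0.3], [1.0, 0.0], [0.9, 0.1]])

top = nearest_words(doc_vecs, word_vecs, words, top_k=2)
print(top)
```

Averaging first gives one keyword list for the whole group, rather than one list per document.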

In [None]:
import pandas as pd

In [None]:
vectors_df = pd.DataFrame(X_test)
outcomes = pd.DataFrame(list(zip(y_test_raw, max_sim)), columns=['actual', 'predicted'])

In [None]:
actual_sports_predicted_business = outcomes.query("actual == 1 and predicted == 2")
misclassified_index = actual_sports_predicted_business.index

In [None]:
VOCAB_LIMIT = 5000

v_missed = cosine_similarity(
    cca.transform_view(X_test[misclassified_index,:], 0),
    cca.transform_view(w2v_view[:VOCAB_LIMIT,:], 1),
)

In [None]:
for doc_idx, v in zip(misclassified_index, v_missed):
    # indices of the 25 most similar vocabulary words, most similar first
    most_sim = v.argsort()[-25:][::-1]
    most_sim_vocab = [vocab[word_idx] for word_idx in most_sim]
    pprint(most_sim_vocab)
    print(X_test_raw[doc_idx])
    print("\n")

['------',
 'orleans',
 'soul',
 'homer',
 'mets',
 'legend',
 'victor',
 'sweet',
 'garcia',
 'latin',
 'blues',
 'spiritual',
 'rodriguez',
 'caribbean',
 'vital',
 'friendship',
 'antonio',
 'jesus',
 'nelson',
 'promises',
 'sisters',
 'diego',
 'castro',
 'loved',
 'insists']
Mighty Ortiz makes sure Sox can rest easy Just imagine what David Ortiz could do on a good night's rest. Ortiz spent the night before last with his baby boy, D'Angelo, who is barely 1 month old. He had planned on attending the Red Sox' Family Day at Fenway Park yesterday morning, but he had to sleep in. After all, Ortiz had a son at home, and he ...


['futures',
 'agreements',
 'charter',
 'resort',
 'securities',
 'golf',
 'ltd.',
 'merger',
 'bid',
 'offer',
 'commerce',
 'liverpool',
 'hotels',
 'wto',
 'sec',
 'signing',
 'contracts',
 'attract',
 'acquisition',
 'clubs',
 'manchester',
 'auction',
 'trading',
 'shipping',
 'investments']
Owners Seek Best Ballpark Deal for Expos (AP) AP - Trying to get t

## Error Analysis

Many of the misclassifications where `actual=sports, predicted=business` appear to be about trade deals and contracts, which makes a lot of sense since those are, in a sense, 'business' transactions.
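
To decide which `(actual, predicted)` pair deserves this kind of inspection, it can help to count every pair first. Below is a small NumPy sketch on toy labels; the helper name `confusion_counts` and the label lists are illustrative and are not taken from the results above.

```python
import numpy as np

def confusion_counts(actual, predicted, n_classes):
    """counts[a, p] = number of items with actual class a and predicted class p."""
    counts = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at accumulates correctly even when an (a, p) pair repeats
    np.add.at(counts, (np.asarray(actual), np.asarray(predicted)), 1)
    return counts

# Toy labels using the same 0-3 class encoding as ag_news
actual = [0, 1, 1, 2, 1, 3, 2]
predicted = [0, 2, 1, 2, 2, 3, 3]
counts = confusion_counts(actual, predicted, n_classes=4)

# Most frequent mistake: largest off-diagonal cell
off_diag = counts.copy()
np.fill_diagonal(off_diag, 0)
worst = np.unravel_index(off_diag.argmax(), off_diag.shape)
print(counts)
print(worst)
```

The row and column of the largest off-diagonal cell can then be plugged into a filter like the `outcomes.query(...)` call used earlier.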

## Wrap-up

I think we've only scratched the surface here. A big inspiration was [this blog post](https://few-shot-text-classification.fastforwardlabs.com/) on few-shot classification. We take a slightly different approach using CCA, but the end result is similar.

One thing to keep in mind with zero-shot learning: because this approach classifies by maximum similarity, it's important to choose topic words that really are similar to the documents you want to classify. If one topic word is a weak match, as appears to be the case with `business` (recall of 0.20 above), documents are more likely to score a higher similarity with a topic word from the wrong class.