[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/education-toolkit/blob/main/02_ml-demos-with-gradio.ipynb)



💡 **Welcome!**

This notebook provides a short walkthrough of few-shot text classification with [SetFit](https://github.com/huggingface/setfit).
This notebook can be found at [https://bit.ly/raj_setfit](https://bit.ly/raj_setfit).

In [1]:
!python -m pip install setfit

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting setfit
  Downloading setfit-0.3.0-py3-none-any.whl (21 kB)
Collecting evaluate==0.2.2
  Downloading evaluate-0.2.2-py3-none-any.whl (69 kB)
Collecting datasets==2.3.2
  Downloading datasets-2.3.2-py3-none-any.whl (362 kB)
Collecting sentence-transformers==2.2.2
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting multiprocess
  Downloading multiprocess-0.70.14-py37-none-any.whl (115 kB)
Collecting xxhash
  Downloading xxhash-3.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)

## Load Dataset

In [2]:
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer

In [3]:
# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

Downloading and preparing dataset sst2/default (download: 7.09 MiB, generated: 4.78 MiB, post-processed: Unknown size, total: 11.88 MiB) to /root/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5...


Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

Dataset sst2 downloaded and prepared to /root/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5. Subsequent calls will reuse this data.


## Simulate the few-shot regime by sampling 8 examples per class

In [4]:
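num_classes = 2
# Note: shuffle + select draws 16 random training examples, so the split is
# only approximately (not exactly) 8 per class.
train_dataset = dataset["train"].shuffle(seed=42).select(range(8 * num_classes))
eval_dataset = dataset["validation"]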

In [5]:
train_dataset['sentence']

['klein , charming in comedies like american pie and dead-on in election , ',
 'be fruitful ',
 'soulful and ',
 'the proud warrior that still lingers in the souls of these characters ',
 'covered earlier and much better ',
 'wise and powerful ',
 'a powerful and reasonably fulfilling gestalt ',
 'smart and newfangled ',
 'it too is a bomb . ',
 'guilty about it ',
 'while the importance of being earnest offers opportunities for occasional smiles and chuckles ',
 "stevens ' vibrant creative instincts ",
 'great artistic significance ',
 "what does n't this film have that an impressionable kid could n't stand to hear ? ",
 'working from a surprisingly sensitive script co-written by gianni romoli ... ',
 'eight crazy nights is a total misfire . ']
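
The sampling cell shuffles the training split and keeps the first 16 rows, which does not guarantee exactly 8 examples per class. If a strictly balanced few-shot set is wanted, one option is to pick indices per label first. The sketch below is a minimal, library-independent illustration; `sample_per_class` and `toy_labels` are hypothetical names introduced here, not part of SetFit or `datasets`.

```python
import random
from collections import defaultdict


def sample_per_class(labels, k, seed=42):
    """Return sorted indices of k randomly chosen examples for each label."""
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    rng = random.Random(seed)
    chosen = []
    for label in sorted(by_label):
        # random.sample raises ValueError if a class has fewer than k examples
        chosen.extend(rng.sample(by_label[label], k))
    return sorted(chosen)


# Toy labels standing in for dataset["train"]["label"] (hypothetical data)
toy_labels = [0, 1] * 50 + [1] * 20
indices = sample_per_class(toy_labels, k=8)
print(len(indices))
```

With the notebook's dataset, this could be applied as `dataset["train"].select(sample_per_class(dataset["train"]["label"], k=8))`. Doing so would change the training examples listed above, and therefore the results.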

## Load a SetFit model from Hub

In [6]:
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.


## Create Trainer

In [7]:
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_class=CosineSimilarityLoss,
    batch_size=16,
    num_iterations=20, # The number of text pairs to generate for contrastive learning
    column_mapping={"sentence": "text", "label": "label"} # Map dataset columns to text/label expected by trainer
)

## Train and evaluate

In [8]:
trainer.train()
metrics = trainer.evaluate()

Applying column mapping to training dataset
***** Running training *****
  Num examples = 640
  Num epochs = 1
  Total optimization steps = 40
  Total train batch size = 16


Applying column mapping to evaluation dataset


***** Running evaluation *****
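
The numbers in the training log follow from the trainer settings: 16 training examples with `num_iterations=20` give 640 sentence pairs, and a batch size of 16 gives 40 optimization steps. This is consistent with one positive (same label) and one negative (different label) pair being generated per example per iteration. The sketch below illustrates that arithmetic under this assumption. It is a simplified stand-in, not SetFit's actual pair-generation code, and `generate_pairs` and the toy data are hypothetical.

```python
import random


def generate_pairs(texts, labels, num_iterations, seed=42):
    """Build (text_a, text_b, similarity) pairs for contrastive training.

    For every example and every iteration, emit one positive pair (same
    label, target 1.0) and one negative pair (different label, target 0.0).
    """
    rng = random.Random(seed)
    pairs = []
    for _ in range(num_iterations):
        for i, (text, label) in enumerate(zip(texts, labels)):
            same = [j for j, other in enumerate(labels) if other == label and j != i]
            diff = [j for j, other in enumerate(labels) if other != label]
            pairs.append((text, texts[rng.choice(same)], 1.0))
            pairs.append((text, texts[rng.choice(diff)], 0.0))
    return pairs


# Toy stand-in for the 16 training sentences (hypothetical data)
texts = [f"example {i}" for i in range(16)]
labels = [i % 2 for i in range(16)]

pairs = generate_pairs(texts, labels, num_iterations=20)
num_pairs = len(pairs)      # 16 examples * 20 iterations * 2 pairs
steps = num_pairs // 16     # one step per batch of 16 pairs
print(num_pairs, steps)
```

Raising `num_iterations` therefore scales the number of pairs, and the training time, linearly.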


## Log in to Hugging Face Hub

In [9]:
from huggingface_hub import notebook_login
notebook_login()

In [None]:
!huggingface-cli whoami
!git config --global credential.helper store

rajistics
orgs: huggingface,spaces-explorers,demo-org,HF-test-lab,qualitydatalab,FinanceInc,inferenceendpoints,vendorabc


In [None]:
trainer.push_to_hub(repo_path_or_name="rajistics/my-setfit-model",use_auth_token=True)

Cloning https://huggingface.co/rajistics/my-setfit-model12 into local empty directory.


remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/rajistics/my-setfit-model12
   d6e8535..d2ec56b  main -> main

'https://huggingface.co/rajistics/my-setfit-model12/commit/d2ec56b70ea8dfa346857146234db328d92e3818'

## Download the model for local inference

In [None]:
modelt = SetFitModel.from_pretrained("rajistics/my-setfit-model")
# Run inference
preds = modelt(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"]) 
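
The model returns one class id per input sentence. In SST-2, label `0` is negative and label `1` is positive, so the ids can be mapped to readable names. The sketch below assumes `preds` is an iterable of integer-like values; `ID2LABEL`, `to_label_names`, and `example_preds` are hypothetical names, and the example values are placeholders rather than actual model output.

```python
# SST-2 label convention: 0 = negative, 1 = positive
ID2LABEL = {0: "negative", 1: "positive"}


def to_label_names(pred_ids):
    """Map integer class ids (plain ints, NumPy or tensor scalars) to label names."""
    return [ID2LABEL[int(p)] for p in pred_ids]


example_preds = [1, 0]  # hypothetical values, not actual model output
print(to_label_names(example_preds))
```

In the notebook, the same helper could be called as `to_label_names(preds)`.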


## Hugging Face Inference Endpoints

Example Endpoint: https://huggingface.co/philschmid/setfit-ag-news-endpoint

Sample request to send once the endpoint has been created:

In [None]:
import requests as r

ENDPOINT_URL = ""  # URL of your endpoint
HF_TOKEN = ""  # your Hugging Face access token

# Sample payload
regular_payload = {"inputs": "Coming to The Rescue Got a unique problem? Not to worry: you can find a financial planner for every specialized need"}

# HTTP headers for authorization
headers = {
    "Authorization": f"Bearer {HF_TOKEN}",
    "Content-Type": "application/json",
}

# Send the request with the payload defined above
response = r.post(ENDPOINT_URL, headers=headers, json=regular_payload)
classified = response.json()

print(classified)