

# Zero-shot text classification with SSTuning

First of all, install the dependencies

In [None]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.34.1-py3-none-any.whl (7.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m50.2 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)
  Downloading huggingface_hub-0.18.0-py3-none-any.whl (301 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m30.0 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.15,>=0.14 (from transformers)
  Downloading tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m57.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m72.7 MB/s[0m eta [36m0:00:00[0m
Col

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch, string, random

# Load the model and tokenizer
For English tasks, please use the following models:
* DAMO-NLP-SG/zero-shot-classify-SSTuning-base
* DAMO-NLP-SG/zero-shot-classify-SSTuning-large
* DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT

For non-English tasks, please use:
* DAMO-NLP-SG/zero-shot-classify-SSTuning-XLM-R

In [None]:
# Checkpoint to load. The trailing "# @param [...]" annotation enumerates the
# available checkpoints (presumably rendered as a Colab form dropdown - keep it
# on the same line as the assignment).
model_name = "DAMO-NLP-SG/zero-shot-classify-SSTuning-XLM-R" # @param ["DAMO-NLP-SG/zero-shot-classify-SSTuning-base", "DAMO-NLP-SG/zero-shot-classify-SSTuning-large", "DAMO-NLP-SG/zero-shot-classify-SSTuning-ALBERT", "DAMO-NLP-SG/zero-shot-classify-SSTuning-XLM-R"]

# Tokenizer for the selected checkpoint; the helper functions below read its
# pad_token and sep_token when building the prompt.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sequence-classification model for the same checkpoint.
model = AutoModelForSequenceClassification.from_pretrained(model_name)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

Downloading (…)tencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

## Create some helper functions to process the data

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Option letters 'A'..'Z' used to tag each candidate label in the prompt.
list_ABC = list(string.ascii_uppercase)

def add_prefix(text, list_label, shuffle=False):
    """Build the classification prompt: lettered label options, separator, text.

    Args:
        text: The input text to classify.
        list_label: Candidate labels, at most 20 strings. The caller's list is
            not modified.
        shuffle: If True, randomly permute the option slots (labels and padding
            together) before they are lettered.

    Returns:
        A tuple ``(prompt, list_label_new)``. ``prompt`` has the form
        ``"(A) <option> (B) <option> ... <sep_token> <text>"`` and
        ``list_label_new`` is the 20-element option list in prompt order.

    Raises:
        ValueError: If more than 20 labels are supplied.
    """
    # The prompt is always padded to this many option slots.
    num_options = 20

    if len(list_label) > num_options:
        # Previously the extra labels were silently appended beyond the
        # 20 padded slots; fail loudly instead.
        raise ValueError(
            f'At most {num_options} labels are supported, got {len(list_label)}.'
        )

    # Append a period to each label that does not already end in '.' or '!'
    # (the original authors note this improves accuracy). str.endswith is used
    # instead of x[-1] so that an empty label does not raise IndexError.
    list_label = [x if x.endswith(('.', '!')) else x + '.' for x in list_label]

    # Fill the remaining option slots with the tokenizer's padding token.
    list_label_new = list_label + [tokenizer.pad_token] * (num_options - len(list_label))

    if shuffle:
        # In-place shuffle of the freshly built list (not the caller's list).
        random.shuffle(list_label_new)

    # Tag every option with its letter: "(A) first (B) second ...".
    s_option = ' '.join(
        f'({letter}) {option}' for letter, option in zip(list_ABC, list_label_new)
    )

    return f'{s_option} {tokenizer.sep_token} {text}', list_label_new

In [None]:
def check_text(model, text, list_label, shuffle=False):
    """Classify ``text`` against ``list_label`` and print the outcome.

    Builds the prompt with ``add_prefix``, runs a single forward pass and
    prints the prompt, the per-option probabilities, the predicted option and
    its probability. Nothing is returned.

    Args:
        model: Model whose forward call returns an object with a ``logits``
            attribute. It is moved to ``device`` and switched to eval mode.
        text: The input text to classify.
        list_label: Candidate labels, passed through to ``add_prefix``.
        shuffle: Passed to ``add_prefix``. When False only the first
            ``len(list_label)`` logits are scored; when True all logits are
            kept because the real labels may sit in any option slot.
    """
    # Build the "(A) ... (B) ... <sep> text" prompt.
    text, list_label_new = add_prefix(text, list_label, shuffle=shuffle)
    print('input text:   ', text)

    # Set the model to evaluation mode and move it to the appropriate device.
    model.to(device).eval()

    # Tokenize a batch of one prompt, truncated to 512 tokens.
    encoding = tokenizer([text], truncation=True, max_length=512)

    # Convert every encoded field to a tensor on the target device.
    item = {key: torch.tensor(val).to(device) for key, val in encoding.items()}

    # Pure inference: disable autograd so no computation graph is built and
    # no activations are kept alive for a backward pass.
    with torch.no_grad():
        logits = model(**item).logits

    # Without shuffling the real labels occupy the leading slots, so the
    # padding slots are dropped before normalising.
    logits = logits if shuffle else logits[:, 0:len(list_label)]

    # Convert logits to probabilities using softmax.
    probs = torch.nn.functional.softmax(logits, dim=-1).tolist()

    # Index of the highest-scoring option.
    predictions = torch.argmax(logits, dim=-1).item()

    # Round the probabilities to five decimal places for display.
    probabilities = [round(x, 5) for x in probs[0]]

    print('probabilities:',probabilities)
    print(f'prediction:    {predictions} => ({list_ABC[predictions]}) {list_label_new[predictions]}')
    print(f'probability:   {round(probabilities[predictions]*100,2)}%')

# Inference

## Sentiment Analysis

Provide the input and the list of labels.
You can use original labels or convert the labels to sentences.

In [None]:
# Review to classify.
text = "I love this place! The food is always so fresh and delicious. The staff is always friendly, as well."

# Plain-word labels (spelling of "positive" corrected; the model reads the
# label text, so a misspelled label is a different input).
list_label = ["negative","positive"]
# Alternative: phrase the labels as sentences.
# list_label = ["It's terrible.","It's great."]

Process the input and do inference

In [None]:
# Score the review against the sentiment labels, keeping the option order fixed.
check_text(model, text=text, list_label=list_label, shuffle=False)

input text:    (A) negative. (B) positve. (C) <pad> (D) <pad> (E) <pad> (F) <pad> (G) <pad> (H) <pad> (I) <pad> (J) <pad> (K) <pad> (L) <pad> (M) <pad> (N) <pad> (O) <pad> (P) <pad> (Q) <pad> (R) <pad> (S) <pad> (T) <pad> </s> I love this place! The food is always so fresh and delicious. The staff is always friendly, as well.
probabilities: [0.00466, 0.99534]
prediction:    1 => (B) positve.
probability:   99.53%


## Topic Classification

Provide the input and the list of labels.

In [None]:
# News snippet to classify. The source text contained "dwindling\band"; inside
# a normal Python string "\b" is the backspace escape, which corrupted the
# input to "dwindling<BS>and". It is written here as the intended two words.
text = "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling band of ultra-cynics, are seeing green again."

# Labels phrased as sentences; the plain-word alternative is kept for reference.
# list_label = ["politics","sports","business","technology"]
list_label = ["This text is about politics.", "This text is about sports.", "This text is about business.", "This text is about technology."]

In [None]:
# Score the news snippet against the topic labels, keeping the option order fixed.
check_text(model, text=text, list_label=list_label, shuffle=False)

input text:    (A) This text is about politics. (B) This text is about sports. (C) This text is about business. (D) This text is about technology. (E) <pad> (F) <pad> (G) <pad> (H) <pad> (I) <pad> (J) <pad> (K) <pad> (L) <pad> (M) <pad> (N) <pad> (O) <pad> (P) <pad> (Q) <pad> (R) <pad> (S) <pad> (T) <pad> </s> Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindlingand of ultra-cynics, are seeing green again.
probabilities: [0.4364, 0.09742, 0.22859, 0.23759]
prediction:    0 => (A) This text is about politics.
probability:   43.64%


## Multilingual for zero-shot-classify-SSTuning-XLM-R

In [None]:
# Chinese text with Chinese labels
list_label = ["太糟糕了!", "太棒了!"]
text = "我喜欢这个地方！ 食物总是那么新鲜可口。 工作人员也总是很友好。"

check_text(model, text=text, list_label=list_label, shuffle=False)

input text:    (A) 太糟糕了! (B) 太棒了! (C) <pad> (D) <pad> (E) <pad> (F) <pad> (G) <pad> (H) <pad> (I) <pad> (J) <pad> (K) <pad> (L) <pad> (M) <pad> (N) <pad> (O) <pad> (P) <pad> (Q) <pad> (R) <pad> (S) <pad> (T) <pad> </s> 我喜欢这个地方！ 食物总是那么新鲜可口。 工作人员也总是很友好。
probabilities: [0.00144, 0.99856]
prediction:    1 => (B) 太棒了!
probability:   99.86%


In [None]:
# Cross-lingual: English text scored against Chinese labels
list_label = ["太糟糕了!", "太棒了!"]
text = "I love this place! The food is always so fresh and delicious. The staff is always friendly, as well."

check_text(model, text=text, list_label=list_label, shuffle=False)

input text:    (A) 太糟糕了! (B) 太棒了! (C) <pad> (D) <pad> (E) <pad> (F) <pad> (G) <pad> (H) <pad> (I) <pad> (J) <pad> (K) <pad> (L) <pad> (M) <pad> (N) <pad> (O) <pad> (P) <pad> (Q) <pad> (R) <pad> (S) <pad> (T) <pad> </s> I love this place! The food is always so fresh and delicious. The staff is always friendly, as well.
probabilities: [0.00223, 0.99777]
prediction:    1 => (B) 太棒了!
probability:   99.78%


In [None]:
# Spanish text with Spanish labels
list_label = ["Es terrible.", "Es genial."]
text = "¡Amo este lugar! La comida es siempre tan fresca y deliciosa. El personal siempre es amable, también."

check_text(model, text=text, list_label=list_label, shuffle=False)

input text:    (A) Es terrible. (B) Es genial. (C) <pad> (D) <pad> (E) <pad> (F) <pad> (G) <pad> (H) <pad> (I) <pad> (J) <pad> (K) <pad> (L) <pad> (M) <pad> (N) <pad> (O) <pad> (P) <pad> (Q) <pad> (R) <pad> (S) <pad> (T) <pad> </s> ¡Amo este lugar! La comida es siempre tan fresca y deliciosa. El personal siempre es amable, también.
probabilities: [0.00053, 0.99947]
prediction:    1 => (B) Es genial.
probability:   99.95%


In [None]:
# Thai text with Thai labels
list_label = ["มันน่ากลัว.", "มันยอดเยี่ยมมาก"]
text = "ฉันรักที่นี่! อาหารสดและอร่อยอยู่เสมอ พนักงานก็เป็นมิตรเช่นกัน"

check_text(model, text=text, list_label=list_label, shuffle=False)

input text:    (A) มันน่ากลัว. (B) มันยอดเยี่ยมมาก. (C) <pad> (D) <pad> (E) <pad> (F) <pad> (G) <pad> (H) <pad> (I) <pad> (J) <pad> (K) <pad> (L) <pad> (M) <pad> (N) <pad> (O) <pad> (P) <pad> (Q) <pad> (R) <pad> (S) <pad> (T) <pad> </s> ฉันรักที่นี่! อาหารสดและอร่อยอยู่เสมอ พนักงานก็เป็นมิตรเช่นกัน
probabilities: [0.09556, 0.90444]
prediction:    1 => (B) มันยอดเยี่ยมมาก.
probability:   90.44%


In [None]:
# Cross-lingual: Thai text scored against Spanish labels
list_label = ["Es terrible.", "Es genial."]
text = "ฉันรักที่นี่! อาหารสดและอร่อยอยู่เสมอ พนักงานก็เป็นมิตรเช่นกัน"

check_text(model, text=text, list_label=list_label, shuffle=False)

input text:    (A) Es terrible. (B) Es genial. (C) <pad> (D) <pad> (E) <pad> (F) <pad> (G) <pad> (H) <pad> (I) <pad> (J) <pad> (K) <pad> (L) <pad> (M) <pad> (N) <pad> (O) <pad> (P) <pad> (Q) <pad> (R) <pad> (S) <pad> (T) <pad> </s> ฉันรักที่นี่! อาหารสดและอร่อยอยู่เสมอ พนักงานก็เป็นมิตรเช่นกัน
probabilities: [0.00248, 0.99752]
prediction:    1 => (B) Es genial.
probability:   99.75%
