<a href="https://colab.research.google.com/github/selwyn-mccracken/blogs/blob/master/07_Apply_labels_with_zero_shot_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Apply labels with zero-shot classification

This notebook shows how zero-shot classification can be used to perform text classification, labeling and topic modeling. txtai provides a light-weight wrapper around the zero-shot-classification pipeline in Hugging Face Transformers. This method works impressively well out of the box. Kudos to the Hugging Face team for the phenomenal work on zero-shot classification!

For each snippet of text, the examples in this notebook pick the best-matching label from a list of candidate labels.

[tldrstory](https://github.com/neuml/tldrstory) has a full-stack implementation of a zero-shot classification system using Streamlit, FastAPI and Hugging Face Transformers. There is also a [Medium article describing tldrstory](https://towardsdatascience.com/tldrstory-ai-powered-understanding-of-headlines-and-story-text-fc86abd702fc) and zero-shot classification.


# Install dependencies

Install `txtai` and all dependencies.

In [1]:
%%capture
!pip install git+https://github.com/neuml/txtai

# Create a Labels instance

The Labels instance is the main entrypoint for zero-shot classification. This is a light-weight wrapper around the zero-shot-classification pipeline in Hugging Face Transformers.

Besides the default model, other models can be found on the [Hugging Face model hub](https://huggingface.co/models?search=mnli).


In [2]:
%%capture

from txtai.pipeline import Labels

# Create labels model
labels = Labels()

# An alternate model can be used by passing its model path, as shown below
# labels = Labels("roberta-large-mnli")
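
For a sense of what the wrapper adds, the sketch below converts a result shaped like the output of the Transformers zero-shot-classification pipeline (a dict with `sequence`, `labels` and `scores`, ordered from best to worst match) into the `(index, score)` pairs that the examples in this notebook index into. The `to_index_scores` helper and the scores in `sample` are made up for illustration; this is not txtai's actual implementation.

```python
def to_index_scores(result, tags):
    """Map a pipeline-style result onto (tag index, score) pairs, best match first."""
    pairs = [(tags.index(label), score) for label, score in zip(result["labels"], result["scores"])]
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)

demo_tags = ["Baseball", "Football", "Hockey", "Basketball"]

# Hand-written stand-in for a pipeline result; the scores are illustrative, not real model output
sample = {
    "sequence": "What a stick save!",
    "labels": ["Hockey", "Baseball", "Basketball", "Football"],
    "scores": [0.91, 0.04, 0.03, 0.02],
}

pairs = to_index_scores(sample, demo_tags)

# Print the top (index, score) pair and the label it points to
print(pairs[0], demo_tags[pairs[0][0]])
```

Returning indexes rather than label strings is what lets the later cells write `tags[...[0][0]]` to look up the winning label.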

# Applying labels to text

The example below shows how a zero-shot classifier can be applied to arbitrary text. The default model for the zero-shot classification pipeline is *bart-large-mnli*.

Look at the results below. It's nothing short of amazing✨ how well it performs. Not all of these are simple, even for a human. For example, "intercepted" was purposely picked because it is more common in football than in basketball. The amount of knowledge stored in larger Transformer models continues to impress me.

In [4]:
data = ["Dodgers lose again, give up 3 HRs in a loss to the Giants",
        "Giants 5 Cardinals 4 final in extra innings",
        "Dodgers drop Game 2 against the Giants, 5-4",
        "Flyers 4 Lightning 1 final. 45 saves for the Lightning.",
        "Slashing, penalty, 2 minute power play coming up",
        "What a stick save!",
        "Leads the NFL in sacks with 9.5",
        "UCF 38 Temple 13",
        "With the 30 yard completion, down to the 10 yard line",
        "Drains the 3pt shot!!, 0:15 remaining in the game",
        "Intercepted! Drives down the court and shoots for the win",
        "Massive dunk!!! they are now up by 15 with 2 minutes to go"]

# List of labels
tags = ["Baseball", "Football", "Hockey", "Basketball"]

print("%-75s %s" % ("Text", "Label"))
print("-" * 100)

for text in data:
    print("%-75s %s" % (text, tags[labels(text, tags)[0][0]]))

Text                                                                        Label
----------------------------------------------------------------------------------------------------
Dodgers lose again, give up 3 HRs in a loss to the Giants                   Baseball
Giants 5 Cardinals 4 final in extra innings                                 Baseball
Dodgers drop Game 2 against the Giants, 5-4                                 Baseball
Flyers 4 Lightning 1 final. 45 saves for the Lightning.                     Hockey
Slashing, penalty, 2 minute power play coming up                            Hockey
What a stick save!                                                          Hockey
Leads the NFL in sacks with 9.5                                             Football
UCF 38 Temple 13                                                            Football
With the 30 yard completion, down to the 10 yard line                       Football
Drains the 3pt shot!!, 0:15 remaining in the game         
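
The loop above keeps only the index of the first result. Assuming, as that indexing suggests, that each call returns `(index, score)` pairs ordered from best to worst match, a small helper can also surface the score and fall back when no label is convincing. `best_label`, the `min_score` cutoff and the sample pairs are hypothetical and for illustration only.

```python
def best_label(results, tags, min_score=0.5, fallback="Unknown"):
    """Return (label, score) for the top result, or the fallback label if it scores too low."""
    if not results:
        return fallback, 0.0
    index, score = results[0]
    return (tags[index], score) if score >= min_score else (fallback, score)

demo_tags = ["Baseball", "Football", "Hockey", "Basketball"]

# Made-up (index, score) pairs standing in for the output of labels(text, tags)
confident = [(2, 0.93), (0, 0.04), (3, 0.02), (1, 0.01)]
uncertain = [(1, 0.31), (3, 0.29), (0, 0.22), (2, 0.18)]

print(best_label(confident, demo_tags))
print(best_label(uncertain, demo_tags))
```

With a real model this would be called as `best_label(labels(text, tags), tags)`; a sensible cutoff depends on the model and the number of labels, so 0.5 is only a placeholder.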

In [6]:
tags = ['Software engineering and Development','Agile delivery and Governance','Cyber security','Support and Operations','Data science','Strategy and Policy','User research and Design','Change and Transformation','ICT systems integration','Content and Publishing','Digital sourcing and ICT procurement','Marketing, Communications and Engagement','Training, Learning and Development']
text = """ 
Cyber Security Assessor SFIA SCTY Level 4, SFIA BURM Level 4_ASD’s CESAR package ensures we can identify more cyber threats, disrupt more cybercriminals offshore, build more partnerships with industry and government and protect more Australians. These additional measures include protections for critical infrastructure facilities, strengthening our partnerships with industry and boosting the provision of cyber security advice to families, older Australians and small businesses. ACSC is working with critical infrastructure owners and operators to understand and uplift their cyber security. The work will be informed and supported by the ACSC’s ongoing technical cyber security advice and guidance. There is an expectation that successful candidates will work 5 days per week (estimated 40 hour week). On boarding is in Canberra, noting there is be a requirement for short term occasional travel within Australia. The Cyber Security Assessor conducts independent comprehensive assessments of the management, operational, and technical security controls and control enhancements employed within or inherited by an information technology (IT) system to determine the overall effectiveness of the controls. The person will possess broad knowledge in: • Current industry methods for evaluating, implementing, and disseminating information technology (IT) security assessment, monitoring, detection, and remediation tools and procedures utilising standards-based concepts and capabilities, • Cyber security and privacy principles used to manage risks related to the use, processing, storage and transmission of information or data, • Cyber threats and vulnerabilities, and • Critical Information systems with information communication technology that were designed without security considerations. 
The person will possess skills in: • Performing risk assessments and review of systems, • Technical writing, including developing and editing assessment products, • Interpreting vulnerability scanner results to identify vulnerabilities, • Interfacing with customers, and • Preparing and presenting briefings. The Cyber Security Assessor’s major responsibilities include: • Develop security compliance processes and/or audits for external services, • Assess the effectiveness of security controls, • Perform security reviews and identify security gaps in security architecture resulting in recommendations for inclusion in the risk management strategy, • Verify that application software/network/ system security postures are implemented as stated, document deviations, and recommend required actions to correct those deviations, and • Participate in Risk Governance processes to provide security risk, mitigations and input on other technical risk. The role is full time onsite in our Canberra offices only.\nMust have current negative vetting level 1 clearance
"""

In [21]:
print("%-75s | %s" % (text, tags[labels(text, tags)[0][0]]))

 
Cyber Security Assessor SFIA SCTY Level 4, SFIA BURM Level 4_ASD’s CESAR package ensures we can identify more cyber threats, disrupt more cybercriminals offshore, build more partnerships with industry and government and protect more Australians. These additional measures include protections for critical infrastructure facilities, strengthening our partnerships with industry and boosting the provision of cyber security advice to families, older Australians and small businesses. ACSC is working with critical infrastructure owners and operators to understand and uplift their cyber security. The work will be informed and supported by the ACSC’s ongoing technical cyber security advice and guidance. There is an expectation that successful candidates will work 5 days per week (estimated 40 hour week). On boarding is in Canberra, noting there is be a requirement for short term occasional travel within Australia. The Cyber Security Assessor conducts independent comprehensive assessments of th

In [10]:
labels2 = Labels('valhalla/distilbart-mnli-12-3')


In [22]:
print("%-75s | %s" % (text, tags[labels2(text, tags)[0][0]]))

 
Cyber Security Assessor SFIA SCTY Level 4, SFIA BURM Level 4_ASD’s CESAR package ensures we can identify more cyber threats, disrupt more cybercriminals offshore, build more partnerships with industry and government and protect more Australians. These additional measures include protections for critical infrastructure facilities, strengthening our partnerships with industry and boosting the provision of cyber security advice to families, older Australians and small businesses. ACSC is working with critical infrastructure owners and operators to understand and uplift their cyber security. The work will be informed and supported by the ACSC’s ongoing technical cyber security advice and guidance. There is an expectation that successful candidates will work 5 days per week (estimated 40 hour week). On boarding is in Canberra, noting there is be a requirement for short term occasional travel within Australia. The Cyber Security Assessor conducts independent comprehensive assessments of th

In [14]:
labels3 = Labels('valhalla/distilbart-mnli-12-6') 


In [23]:
print("%-75s | %s" % (text, tags[labels3(text, tags)[0][0]]))

 
Cyber Security Assessor SFIA SCTY Level 4, SFIA BURM Level 4_ASD’s CESAR package ensures we can identify more cyber threats, disrupt more cybercriminals offshore, build more partnerships with industry and government and protect more Australians. These additional measures include protections for critical infrastructure facilities, strengthening our partnerships with industry and boosting the provision of cyber security advice to families, older Australians and small businesses. ACSC is working with critical infrastructure owners and operators to understand and uplift their cyber security. The work will be informed and supported by the ACSC’s ongoing technical cyber security advice and guidance. There is an expectation that successful candidates will work 5 days per week (estimated 40 hour week). On boarding is in Canberra, noting there is be a requirement for short term occasional travel within Australia. The Cyber Security Assessor conducts independent comprehensive assessments of th

In [18]:
labels4 = Labels('valhalla/distilbart-mnli-12-1') 


In [24]:
print("%-75s | %s" % (text, tags[labels4(text, tags)[0][0]]))

 
Cyber Security Assessor SFIA SCTY Level 4, SFIA BURM Level 4_ASD’s CESAR package ensures we can identify more cyber threats, disrupt more cybercriminals offshore, build more partnerships with industry and government and protect more Australians. These additional measures include protections for critical infrastructure facilities, strengthening our partnerships with industry and boosting the provision of cyber security advice to families, older Australians and small businesses. ACSC is working with critical infrastructure owners and operators to understand and uplift their cyber security. The work will be informed and supported by the ACSC’s ongoing technical cyber security advice and guidance. There is an expectation that successful candidates will work 5 days per week (estimated 40 hour week). On boarding is in Canberra, noting there is be a requirement for short term occasional travel within Australia. The Cyber Security Assessor conducts independent comprehensive assessments of th
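
Printing the full job description for each model makes the predicted label hard to find. One possible tidy-up, sketched below, shortens the text for display with `textwrap.shorten` and collects each model's top label in one pass. `compare_models` is a hypothetical helper, and the lambdas are stand-ins returning made-up `(index, score)` pairs so the snippet runs without downloading any model.

```python
import textwrap

def compare_models(models, text, tags, width=75):
    """Return a short preview of text and the top label chosen by each named classifier."""
    preview = textwrap.shorten(text, width=width, placeholder="...")
    # Each classifier is assumed to return (index, score) pairs, best match first
    return preview, {name: tags[model(text, tags)[0][0]] for name, model in models.items()}

demo_tags = ["Cyber security", "Data science", "Support and Operations"]
demo_text = """
Cyber Security Assessor conducts independent comprehensive assessments of the management,
operational, and technical security controls employed within an information technology system.
"""

# Stand-in classifiers with made-up scores; swap in real Labels instances to compare actual models
stand_ins = {
    "model-a": lambda text, tags: [(0, 0.88), (2, 0.08), (1, 0.04)],
    "model-b": lambda text, tags: [(0, 0.61), (1, 0.25), (2, 0.14)],
}

preview, results = compare_models(stand_ins, demo_text, demo_tags)
print(preview)
for name, label in results.items():
    print("%-10s | %s" % (name, label))
```

With the instances created in this notebook, `models` could be a dict such as `{"default": labels, "distilbart-mnli-12-3": labels2}`, passed together with the original `text` and `tags`.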

# Let's try emoji 😀

Does the model have knowledge of emoji? Check out the run below; it sure looks like it does! Notice that the labels are applied based on the perspective from which the information is presented.

In [25]:
tags = ["😀", "😡"]

print("%-75s %s" % ("Text", "Label"))
print("-" * 100)

for text in data:
    print("%-75s %s" % (text, tags[labels(text, tags)[0][0]]))

Text                                                                        Label
----------------------------------------------------------------------------------------------------
Dodgers lose again, give up 3 HRs in a loss to the Giants                   😡
Giants 5 Cardinals 4 final in extra innings                                 😀
Dodgers drop Game 2 against the Giants, 5-4                                 😡
Flyers 4 Lightning 1 final. 45 saves for the Lightning.                     😀
Slashing, penalty, 2 minute power play coming up                            😡
What a stick save!                                                          😀
Leads the NFL in sacks with 9.5                                             😀
UCF 38 Temple 13                                                            😀
With the 30 yard completion, down to the 10 yard line                       😀
Drains the 3pt shot!!, 0:15 remaining in the game                           😀
Intercepted! Drives down the court an