[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vuhung16au/hf-transformer-trove/blob/main/examples/basic1.2/02-zero-shot-classification.ipynb)
[![View on GitHub](https://img.shields.io/badge/View_on-GitHub-blue?logo=github)](https://github.com/vuhung16au/hf-transformer-trove/blob/main/examples/basic1.2/02-zero-shot-classification.ipynb)

# 02 - Zero-Shot Classification: Classify Without Training Data

## 🎯 Learning Objectives
By the end of this notebook, you will understand:
- What zero-shot classification is and when to use it
- How to use Hugging Face pipelines for zero-shot classification
- The underlying approach: Natural Language Inference (NLI) models such as BART and RoBERTa fine-tuned on MNLI
- How to work with different types of candidate labels
- Performance considerations and best practices
- Advanced techniques for improving zero-shot classification

## 📋 Prerequisites
- Basic understanding of machine learning concepts
- Familiarity with Python and text classification
- Knowledge of transformers (refer to [Notebook 01](../01_intro_hf_transformers.ipynb))
- Understanding of NLP fundamentals (refer to [NLP Learning Journey](https://github.com/vuhung16au/nlp-learning-journey))

## 📚 What We'll Cover
1. **Introduction**: Zero-shot classification concepts
2. **Basic Pipeline Usage**: Using the zero-shot classification pipeline
3. **Multiple Examples**: How the classifier behaves across different inputs
4. **Real-world Applications**: Customer support, content moderation, and news classification
5. **Visualization and Analysis**: Confidence distributions and throughput
6. **Advanced Techniques**: Hypothesis templates, multi-label classification, and confidence thresholds
7. **Performance Analysis**: Model comparison, best practices, and common pitfalls
8. **Summary and Best Practices**: Key takeaways

## What is Zero-Shot Classification?

**Zero-shot classification** is the ability to classify text into categories that the model has never seen during training. Instead of being trained on specific labels, the model uses its general understanding of language to determine which category best fits the input text.

### Key Advantages:
- 🚀 **No training required**: Start classifying immediately
- 🔄 **Flexible labels**: Change categories without retraining
- 💰 **Cost-effective**: No need for labeled training data
- ⚡ **Rapid prototyping**: Test ideas quickly

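### How it Works:
Zero-shot text classification is typically implemented in one of two ways:
1. **Natural Language Inference (NLI)**: The input text is treated as a *premise*, and each candidate label is turned into a *hypothesis* (the Hugging Face pipeline's default template is "This example is {}."). A model trained on textual entailment (e.g., MNLI) then scores how strongly the premise entails each hypothesis. This is the approach used in this notebook.
2. **Semantic similarity**: The text and each label (or label description) are embedded in a shared vector space, and the most similar label wins.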
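A minimal sketch of the NLI setup in plain Python: each candidate label becomes one (premise, hypothesis) pair, so an N-label problem needs N forward passes of the NLI model. The helper `build_nli_pairs` is hypothetical; only the default template string mirrors the one the Hugging Face pipeline uses.

```python
from typing import List, Tuple

def build_nli_pairs(text: str, labels: List[str],
                    template: str = "This example is {}.") -> List[Tuple[str, str]]:
    """Pair the input text (premise) with one hypothesis per candidate label."""
    if "{}" not in template:
        raise ValueError("template must contain '{}' as the label placeholder")
    # One (premise, hypothesis) pair per label -> one NLI forward pass per label
    return [(text, template.format(label)) for label in labels]

pairs = build_nli_pairs("This is a course about the Transformers library",
                        ["education", "politics", "business"])
for premise, hypothesis in pairs:
    print(f"premise={premise!r} | hypothesis={hypothesis!r}")
```

Because the model only ever sees these pairs, changing the label set or the template changes the task without any retraining.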

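In both cases the per-label scores $s_i$ are normalized into probabilities, typically with a softmax over the $N$ candidate labels:
$$P(\text{label}_j \mid \text{text}) = \frac{\exp(s_j)}{\sum_{i=1}^{N} \exp(s_i)}$$
where $s_i$ is the entailment logit for $\text{label}_i$ (NLI) or $\text{similarity}(\text{text}, \text{label}_i)$ (similarity-based). This is what the pipeline does in its default single-label mode, so the returned scores sum to 1.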
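A small NumPy sketch of the softmax above, assuming one raw score per candidate label (entailment logits for an NLI model, or similarity scores). The logit values are made up for illustration.

```python
import numpy as np

def single_label_scores(s: np.ndarray) -> np.ndarray:
    """Softmax over one score per candidate label; the result sums to 1."""
    shifted = s - np.max(s)  # subtracting the max avoids overflow without changing the result
    exp_s = np.exp(shifted)
    return exp_s / exp_s.sum()

demo_labels = ["education", "politics", "business"]
demo_logits = np.array([3.0, -1.0, 0.5])  # hypothetical entailment logits, one per label
probs = single_label_scores(demo_logits)
for label, p in sorted(zip(demo_labels, probs), key=lambda lp: -lp[1]):
    print(f"{label:10}: {p:.3f}")
```

Since the scores compete for a single unit of probability mass, adding an irrelevant label can still shift the scores of the others.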

## Setup and Installation

In [None]:
# Install required packages (uncomment and run if needed)
# !pip install transformers torch datasets tokenizers matplotlib seaborn plotly

# Import essential libraries
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import time
import warnings
from typing import List, Dict, Optional, Union
from collections import Counter

# Hugging Face imports
from transformers import (
    pipeline, 
    AutoTokenizer, 
    AutoModelForSequenceClassification,
    AutoConfig
)

warnings.filterwarnings('ignore')

# For Google Colab compatibility
try:
    from google.colab import userdata
    COLAB_AVAILABLE = True
except ImportError:
    COLAB_AVAILABLE = False

# Load environment variables from .env.local for local development
try:
    from dotenv import load_dotenv
    load_dotenv('.env.local', override=True)
    print("✅ Environment variables loaded from .env.local")
except ImportError:
    print("💡 python-dotenv not available. Using system environment variables.")

## Device Setup and Configuration

In [None]:
def get_device() -> torch.device:
    """
    Automatically detect and return the best available device.
    
    Priority: CUDA > MPS (Apple Silicon) > CPU
    
    Returns:
        torch.device: The optimal device for current hardware
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"🚀 Using CUDA GPU: {torch.cuda.get_device_name()}")
        print(f"   GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
    elif torch.backends.mps.is_available():
        device = torch.device("mps") 
        print("🍎 Using Apple MPS (Apple Silicon)")
    else:
        device = torch.device("cpu")
        print("💻 Using CPU (consider GPU for better performance)")
    
    return device

# Set up device and display system info
device = get_device()
print(f"\n=== System Information ===")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")

# Set visualization style
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = [10, 6]

## Part 1: Basic Zero-Shot Classification

Let's start with a basic example and then expand on it with educational explanations.

In [None]:
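# Create a zero-shot classification pipeline.
# facebook/bart-large-mnli is the pipeline's default model; naming it explicitly
# pins the checkpoint and avoids the "No model was supplied" warning.
print("📥 Loading zero-shot classification pipeline...")
start_time = time.time()

classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=device,  # run on the device detected above (CUDA, MPS, or CPU)
)

load_time = time.time() - start_time
print(f"✅ Pipeline loaded in {load_time:.2f} seconds")
print(f"📊 Model: {classifier.model.config.name_or_path}")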
print(f"🏷️  Model type: {classifier.model.config.model_type}")

In [None]:
# Basic zero-shot classification example
text = "This is a course about the Transformers library"
candidate_labels = ["education", "politics", "business"]

print("🔍 Performing Zero-Shot Classification")
print("="*50)
print(f"Text: '{text}'")
print(f"Candidate Labels: {candidate_labels}")
print("\n📈 Results:")

# Perform classification with timing
start_time = time.time()
result = classifier(
    text,
    candidate_labels=candidate_labels,
)
inference_time = time.time() - start_time

# Display results in an educational format
for label, score in zip(result['labels'], result['scores']):
    confidence = score * 100
    bar = '█' * int(confidence / 5)  # Visual bar representation
    print(f"  {label:12}: {confidence:5.1f}% {bar}")

print(f"\n⏱️  Inference time: {inference_time:.3f} seconds")
print(f"🏆 Predicted class: {result['labels'][0]} (confidence: {result['scores'][0]:.3f})")
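The pipeline returns a dict with `sequence`, `labels` (sorted from most to least likely), and `scores` aligned with `labels`. A minimal sketch that checks those properties and turns the result into a table; `summarize_result` is a hypothetical helper, and the example dict uses made-up placeholder scores (in this notebook you can pass the real `result` from the cell above).

```python
import pandas as pd

def summarize_result(result: dict) -> pd.DataFrame:
    """Check the shape of a single-label zero-shot result and return it as a table."""
    labels, scores = result["labels"], result["scores"]
    # Labels come back ordered from highest to lowest score
    if any(a < b for a, b in zip(scores, scores[1:])):
        raise ValueError("expected scores sorted in descending order")
    # With the default multi_label=False, the scores form one probability distribution
    if abs(sum(scores) - 1.0) > 1e-3:
        raise ValueError("expected single-label scores to sum to 1")
    return pd.DataFrame({"rank": range(1, len(labels) + 1), "label": labels, "score": scores})

# Placeholder result with the same keys the pipeline returns
example_result = {
    "sequence": "This is a course about the Transformers library",
    "labels": ["education", "business", "politics"],
    "scores": [0.84, 0.11, 0.05],
}
print(summarize_result(example_result).to_string(index=False))
```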

## Part 2: Testing with Multiple Examples

Let's test the zero-shot classifier with various examples to understand its behavior.

In [None]:
# Test with diverse examples
test_examples = [
    "This is a course about the Transformers library",
    "The stock market crashed today due to inflation concerns",
    "The president announced new economic policies",
    "Students are learning machine learning fundamentals",
    "The company's quarterly earnings exceeded expectations",
    "Educational institutions are adopting AI technologies"
]

labels = ["education", "politics", "business"]

print("🔍 Testing Zero-Shot Classification with Multiple Examples")
print("="*70)

results = []
for i, text in enumerate(test_examples, 1):
    print(f"\n📄 Example {i}: '{text}'")
    
    start_time = time.time()
    result = classifier(text, candidate_labels=labels)
    inference_time = time.time() - start_time
    
    predicted_label = result['labels'][0]
    confidence = result['scores'][0]
    
    print(f"   🏆 Prediction: {predicted_label} ({confidence:.3f})")
    print(f"   ⏱️  Time: {inference_time:.3f}s")
    
    # Show all scores
    print("   📊 All scores:")
    for label, score in zip(result['labels'], result['scores']):
        bar = '█' * int(score * 20)  # Visual representation
        print(f"      {label:12}: {score:.3f} {bar}")
    
    results.append({
        'text': text,
        'predicted_label': predicted_label,
        'confidence': confidence,
        'inference_time': inference_time
    })

# Summary statistics
avg_confidence = np.mean([r['confidence'] for r in results])
avg_time = np.mean([r['inference_time'] for r in results])
label_distribution = Counter([r['predicted_label'] for r in results])

print("\n" + "="*70)
print(f"📊 SUMMARY STATISTICS")
print(f"   Average confidence: {avg_confidence:.3f}")
print(f"   Average inference time: {avg_time:.3f}s")
print(f"   Label distribution: {dict(label_distribution)}")
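The per-example times above can include one-off costs on the first call (for example, GPU initialization), which inflates the average. A minimal timing sketch that runs untimed warm-up calls first; `time_calls` is a hypothetical helper that works with any callable, and the lambda below is an offline stand-in for the real pipeline.

```python
import time
from statistics import mean
from typing import Callable, Iterable, List

def time_calls(fn: Callable, inputs: Iterable, warmup: int = 1) -> List[float]:
    """Return per-input wall-clock times for fn(x), after `warmup` untimed calls."""
    inputs = list(inputs)
    if not inputs:
        return []
    for _ in range(warmup):
        fn(inputs[0])  # untimed: absorbs one-off startup costs
    durations = []
    for x in inputs:
        start = time.perf_counter()
        fn(x)
        durations.append(time.perf_counter() - start)
    return durations

# Offline stand-in; with this notebook's pipeline you could use
# time_calls(lambda t: classifier(t, candidate_labels=labels), test_examples)
durations = time_calls(lambda t: sorted(t.split()), ["a b c", "d e", "f"])
print(f"mean {mean(durations) * 1000:.3f} ms over {len(durations)} calls")
```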

## Part 3: Real-World Applications

Let's explore practical applications of zero-shot classification in different domains.

In [None]:
# Real-world application examples
applications = {
    "Customer Support": {
        "texts": [
            "I can't log into my account, password doesn't work",
            "My order arrived damaged, need replacement",
            "How do I cancel my subscription?",
            "Want to upgrade to premium plan",
            "App keeps crashing when I upload photos"
        ],
        "labels": ["login_issue", "shipping_problem", "account_management", "sales_inquiry", "technical_bug"]
    },
    
    "Content Moderation": {
        "texts": [
            "Amazing product! Highly recommend to everyone",
            "This is the worst service I've ever used",
            "When will the new update be released?",
            "Check out this discount link: special-deal.com",
            "The interface could be more user-friendly"
        ],
        "labels": ["positive_feedback", "negative_feedback", "question", "potential_spam", "suggestion"]
    },
    
    "News Classification": {
        "texts": [
            "Scientists discover new treatment for diabetes",
            "Stock market reaches all-time high today",
            "Local team wins championship after 20 years",
            "New smartphone model features AI camera",
            "Government announces climate change initiative"
        ],
        "labels": ["health", "finance", "sports", "technology", "politics"]
    }
}

def demonstrate_application(app_name, app_data):
    """Demonstrate zero-shot classification for a specific application."""
    print(f"\n📋 Application: {app_name}")
    print("-" * 50)
    
    texts = app_data["texts"]
    labels = app_data["labels"]
    
    app_results = []
    total_time = 0
    
    for i, text in enumerate(texts, 1):
        start_time = time.time()
        result = classifier(text, candidate_labels=labels)
        inference_time = time.time() - start_time
        total_time += inference_time
        
        predicted_label = result['labels'][0]
        confidence = result['scores'][0]
        
        print(f"\n  📝 {i}. '{text}'")
        print(f"     → {predicted_label} ({confidence:.3f})")
        
        app_results.append({
            'text': text,
            'predicted': predicted_label,
            'confidence': confidence
        })
    
    # Application summary
    avg_confidence = np.mean([r['confidence'] for r in app_results])
    print(f"\n  📊 Summary:")
    print(f"     Average confidence: {avg_confidence:.3f}")
    print(f"     Total processing time: {total_time:.2f}s")
    print(f"     Throughput: {len(texts)/total_time:.1f} examples/second")
    
    return app_results

print("🌍 REAL-WORLD ZERO-SHOT CLASSIFICATION APPLICATIONS")
print("="*60)

# Demonstrate each application
all_results = {}
for app_name, app_data in applications.items():
    all_results[app_name] = demonstrate_application(app_name, app_data)
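Several labels above are snake_case identifiers. With the pipeline's default template, `login_issue` becomes the hypothesis "This example is login_issue.", which is not natural English for an NLI model. A minimal sketch, assuming plain-phrase labels read better to the model (worth verifying on your own data): convert identifiers to phrases before classification and map predictions back afterwards. `humanize_labels` is a hypothetical helper.

```python
from typing import Dict, List, Tuple

def humanize_labels(labels: List[str]) -> Tuple[List[str], Dict[str, str]]:
    """Turn snake_case labels into plain phrases and keep a map back to the originals."""
    readable = [label.replace("_", " ") for label in labels]
    return readable, dict(zip(readable, labels))

readable, to_original = humanize_labels(["login_issue", "shipping_problem", "technical_bug"])
print(readable)  # ['login issue', 'shipping problem', 'technical bug']

# With the notebook's pipeline (not run here):
# result = classifier(text, candidate_labels=readable)
# predicted = to_original[result["labels"][0]]
```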

## Part 4: Visualization and Analysis

In [None]:
# Create visualizations of the classification results
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Zero-Shot Classification Analysis', fontsize=16, fontweight='bold')

# 1. Confidence distribution across all applications
all_confidences = []
for app_results in all_results.values():
    all_confidences.extend([r['confidence'] for r in app_results])

axes[0, 0].hist(all_confidences, bins=15, alpha=0.7, color='skyblue', edgecolor='black')
axes[0, 0].set_xlabel('Confidence Score')
axes[0, 0].set_ylabel('Frequency')
axes[0, 0].set_title('Overall Confidence Distribution')
axes[0, 0].axvline(np.mean(all_confidences), color='red', linestyle='--', 
                   label=f'Mean: {np.mean(all_confidences):.3f}')
axes[0, 0].legend()

# 2. Average confidence by application
app_names = list(all_results.keys())
app_confidences = [np.mean([r['confidence'] for r in results]) for results in all_results.values()]

bars = axes[0, 1].bar(app_names, app_confidences, color=['lightcoral', 'lightgreen', 'lightsalmon'])
axes[0, 1].set_ylabel('Average Confidence')
axes[0, 1].set_title('Average Confidence by Application')
axes[0, 1].tick_params(axis='x', labelrotation=15)

# Add value labels on bars
for bar, conf in zip(bars, app_confidences):
    axes[0, 1].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
                    f'{conf:.3f}', ha='center', va='bottom')

# 3. Label distribution for first application (Customer Support)
customer_labels = [r['predicted'] for r in all_results['Customer Support']]
label_counts = Counter(customer_labels)

axes[1, 0].pie(label_counts.values(), labels=label_counts.keys(), autopct='%1.1f%%')
axes[1, 0].set_title('Customer Support Label Distribution')

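# 4. Measured throughput vs. batch size (real pipeline calls; can take a while on CPU)
sample_texts = [t for app in applications.values() for t in app["texts"]]
news_labels = applications["News Classification"]["labels"]
batch_sizes = [1, 4, 8, 16]
throughput = []
for bs in batch_sizes:
    t0 = time.time()
    # batch_size controls how many (text, label) pairs go through the model together
    classifier(sample_texts, candidate_labels=news_labels, batch_size=bs)
    throughput.append(len(sample_texts) / (time.time() - t0))

axes[1, 1].plot(batch_sizes, throughput, marker='o', linewidth=2, markersize=6)
axes[1, 1].set_xlabel('Batch Size')
axes[1, 1].set_ylabel('Throughput (examples/sec)')
axes[1, 1].set_title('Measured Throughput vs Batch Size')
axes[1, 1].grid(True, alpha=0.3)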

plt.tight_layout()
plt.show()

# Print detailed statistics
print(f"\n📊 Overall Statistics:")
print(f"   Total examples processed: {len(all_confidences)}")
print(f"   Overall average confidence: {np.mean(all_confidences):.3f}")
print(f"   Confidence std deviation: {np.std(all_confidences):.3f}")
print(f"   Min confidence: {np.min(all_confidences):.3f}")
print(f"   Max confidence: {np.max(all_confidences):.3f}")

## Part 5: Advanced Techniques and Best Practices

In [None]:
def demonstrate_advanced_techniques():
    """Demonstrate advanced zero-shot classification techniques and best practices."""
    
    print("🔬 ADVANCED ZERO-SHOT CLASSIFICATION TECHNIQUES")
    print("="*60)
    
    # 1. Hypothesis Template Approach
    print("\n1️⃣ Using Hypothesis Templates for Better Performance")
    print("-" * 50)
    
    text = "The new AI model achieved 95% accuracy on the benchmark dataset"
    
    # Standard labels
    standard_labels = ["technology", "science", "business"]
    
    # Descriptive labels: full sentences used verbatim as the NLI hypothesis
    enhanced_labels = [
        "This text is about technology, software, or technical innovations",
        "This text is about scientific research, discoveries, or academic studies", 
        "This text is about business, finance, or commercial activities"
    ]
    
    print(f"Text: '{text}'\n")
    
    # Standard approach: the default template "This example is {}." wraps each label
    result_standard = classifier(text, candidate_labels=standard_labels)
    print("Standard Labels:")
    for label, score in zip(result_standard['labels'], result_standard['scores']):
        print(f"  {label:12}: {score:.3f}")
    
    # Descriptive approach: hypothesis_template="{}" passes each sentence through unchanged
    result_enhanced = classifier(text, candidate_labels=enhanced_labels, hypothesis_template="{}")
    print("\nDescriptive Labels (used as full hypotheses):")
    enhanced_mapping = dict(zip(enhanced_labels, standard_labels))
    for label, score in zip(result_enhanced['labels'], result_enhanced['scores']):
        original_label = enhanced_mapping.get(label, label)
        print(f"  {original_label:12}: {score:.3f}")
    
    # 2. Multi-label Classification
    print("\n\n2️⃣ Multi-label Classification Approach")
    print("-" * 50)
    
    multi_text = "The university announced a new AI research program with industry partnerships"
    potential_labels = ["education", "technology", "business", "research"]
    
    print(f"Text: '{multi_text}'\n")
    print("Scoring each label independently (multi_label=True):")
    
    # multi_label=True scores every label on its own (entailment vs. contradiction),
    # so a text can match several labels and the scores need not sum to 1
    multi_result = classifier(multi_text, candidate_labels=potential_labels, multi_label=True)
    
    relevant_labels = []
    for label, score in zip(multi_result['labels'], multi_result['scores']):
        is_relevant = score > 0.6  # Threshold for relevance
        print(f"  {label:12}: {score:.3f} {'✓' if is_relevant else ''}")
        if is_relevant:
            relevant_labels.append((label, score))
    
    print(f"\nRelevant labels (>0.6 threshold): {[label for label, _ in relevant_labels]}")
    
    # 3. Confidence Thresholding
    print("\n\n3️⃣ Confidence Thresholding for Quality Control")
    print("-" * 50)
    
    test_cases = [
        "Machine learning algorithms are transforming healthcare",  # Clear case
        "The weather today is quite nice for a walk",  # Ambiguous case
        "Quantum computing might revolutionize cryptography"  # Moderate case
    ]
    
    confidence_threshold = 0.7
    labels = ["technology", "health", "lifestyle", "science"]
    
    for i, text in enumerate(test_cases, 1):
        result = classifier(text, candidate_labels=labels)
        top_label = result['labels'][0]
        confidence = result['scores'][0]
        
        status = "✅ AUTO" if confidence >= confidence_threshold else "⚠️  REVIEW"
        print(f"\n  Case {i}: '{text[:40]}...'")
        print(f"    Prediction: {top_label} ({confidence:.3f}) {status}")
        
        if confidence < confidence_threshold:
            print(f"    → Requires manual review (confidence < {confidence_threshold})")

# Run advanced techniques demonstration
demonstrate_advanced_techniques()
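The pipeline's `multi_label=True` mode scores each label independently: for every (text, label) pair it applies a softmax over only the contradiction and entailment logits and keeps the entailment probability, ignoring "neutral". A NumPy sketch of that computation; the label ids 0 (contradiction) and 2 (entailment) are an assumption matching the `facebook/bart-large-mnli` config (the pipeline reads them from the model config), and the logits are made up.

```python
import numpy as np

def multi_label_scores(nli_logits: np.ndarray,
                       contradiction_id: int = 0,
                       entailment_id: int = 2) -> np.ndarray:
    """Score each label on its own: softmax over (contradiction, entailment), keep entailment.

    nli_logits has shape (num_labels, 3): one row of NLI logits per (text, label) pair.
    """
    pair = nli_logits[:, [contradiction_id, entailment_id]]
    pair = pair - pair.max(axis=1, keepdims=True)  # numerical stability
    exp_pair = np.exp(pair)
    return exp_pair[:, 1] / exp_pair.sum(axis=1)   # entailment probability per label

# Hypothetical logits for ["education", "technology", "business", "research"],
# columns ordered (contradiction, neutral, entailment)
demo_nli_logits = np.array([
    [-2.0, 0.5, 2.5],
    [-1.5, 0.0, 3.0],
    [1.0, 0.5, -0.5],
    [-2.5, 1.0, 2.0],
])
ml_scores = multi_label_scores(demo_nli_logits)
print(ml_scores.round(3), "sum =", round(float(ml_scores.sum()), 3))  # independent scores need not sum to 1
```

A two-way softmax is equivalent to a sigmoid of the entailment-minus-contradiction logit gap, which is why each label's score is unaffected by the other labels.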

## Part 6: Performance Analysis and Best Practices

In [None]:
def performance_analysis_and_best_practices():
    """Provide comprehensive performance analysis and best practices."""
    
    print("⚡ PERFORMANCE ANALYSIS & BEST PRACTICES")
    print("="*50)
    
    print("\n🚀 MODEL COMPARISON (rough per-text estimates; actual latency depends on")
    print("   hardware, text length, and the number of candidate labels)")
    print("-" * 40)
    
    model_comparison = pd.DataFrame({
        'Model': ['BART-Large-MNLI', 'RoBERTa-Large-MNLI', 'DistilBERT-MNLI', 'MobileBERT-MNLI'],
        'Parameters': ['400M', '355M', '67M', '25M'],
        'CPU Latency/Text': ['~2.5s', '~2.0s', '~0.8s', '~0.3s'],
        'GPU Latency/Text': ['~0.3s', '~0.25s', '~0.1s', '~0.05s'],
        'Accuracy': ['High', 'High', 'Med-High', 'Medium']
    })
    
    print(model_comparison.to_string(index=False))
    
    print("\n🎯 BEST PRACTICES CHECKLIST")
    print("-" * 40)
    
    best_practices = {
        "Label Design": [
            "✓ Use clear, descriptive labels",
            "✓ Make labels mutually exclusive",
            "✓ Consider hypothesis templates for complex cases",
            "✓ Test different label formulations"
        ],
        "Quality Control": [
            "✓ Set confidence thresholds (e.g., >0.7 for auto-processing)",
            "✓ Manually review low-confidence predictions",
            "✓ Monitor performance metrics over time",
            "✓ Validate on representative test sets"
        ],
        "Production Deployment": [
            "✓ Use batch processing for better throughput",
            "✓ Implement caching for repeated classifications",
            "✓ Monitor inference times and resource usage",
            "✓ Have fallback strategies for edge cases"
        ],
        "Model Selection": [
            "✓ Balance accuracy vs speed requirements",
            "✓ Consider domain-specific models when available",
            "✓ Test multiple models on your specific data",
            "✓ Use smaller models for real-time applications"
        ]
    }
    
    for category, practices in best_practices.items():
        print(f"\n📋 {category}:")
        for practice in practices:
            print(f"  {practice}")
    
    print("\n⚠️  COMMON PITFALLS TO AVOID")
    print("-" * 40)
    
    pitfalls = [
        "❌ Using too many labels (>15-20): accuracy tends to drop, and cost grows because each label needs its own forward pass",
        "❌ Labels that are too similar or overlapping",
        "❌ Ignoring confidence scores",
        "❌ Not validating on real data from your domain",
        "❌ Assuming zero-shot works perfectly without testing",
        "❌ Not considering model bias and limitations"
    ]
    
    for pitfall in pitfalls:
        print(f"  {pitfall}")
    
    print("\n🔧 OPTIMIZATION TIPS")
    print("-" * 20)
    
    optimization_tips = [
        "⚡ Batch similar texts together",
        "⚡ Cache results for repeated text-label pairs", 
        "⚡ Use GPU when available for faster inference",
        "⚡ Consider model quantization for deployment",
        "⚡ Implement async processing for better UX"
    ]
    
    for tip in optimization_tips:
        print(f"  {tip}")

# Run performance analysis
performance_analysis_and_best_practices()
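The checklist above recommends caching repeated classifications. A minimal sketch using `functools.lru_cache`, keyed on the text and a tuple of labels (tuples are hashable, lists are not). `make_cached_classifier` and `fake_classify` are hypothetical; the stand-in only records calls so the snippet runs offline.

```python
from functools import lru_cache
from typing import Callable, Tuple

def make_cached_classifier(classify_fn: Callable[[str, Tuple[str, ...]], dict],
                           maxsize: int = 10_000):
    """Wrap classify_fn(text, labels) in an LRU cache keyed on (text, labels)."""
    @lru_cache(maxsize=maxsize)
    def cached(text: str, labels: Tuple[str, ...]) -> Tuple[Tuple[str, float], ...]:
        result = classify_fn(text, labels)
        # Cache an immutable (label, score) tuple rather than the mutable result dict
        return tuple(zip(result["labels"], result["scores"]))
    return cached

# Offline stand-in that records every real call; with this notebook's pipeline you could pass
# lambda text, labels: classifier(text, candidate_labels=list(labels))
calls = []
def fake_classify(text, labels):
    calls.append(text)
    return {"labels": list(labels), "scores": [1.0 / len(labels)] * len(labels)}

cached = make_cached_classifier(fake_classify)
cached("How do I reset my password?", ("login issue", "billing"))
cached("How do I reset my password?", ("login issue", "billing"))  # served from the cache
print(len(calls), cached.cache_info().hits)  # prints: 1 1
```

Label order is part of the cache key, so pass labels in a consistent order (for example, sorted) to get the most cache hits.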

## 📋 Summary

### 🔑 Key Concepts Mastered
- **Zero-Shot Classification**: Understanding how to classify text without training data
- **Natural Language Inference**: Using NLI models as the foundation for zero-shot tasks  
- **Pipeline Usage**: Leveraging Hugging Face pipelines for rapid prototyping
- **Confidence Analysis**: Interpreting and using confidence scores for quality control
- **Real-World Applications**: Implementing solutions for practical business problems
- **Performance Optimization**: Balancing speed and accuracy for production use

### 📈 Best Practices Learned
- Design clear, descriptive labels for better classification accuracy
- Use hypothesis templates to improve model understanding in complex cases
- Monitor confidence scores for quality control and human review triggers
- Implement batch processing for better throughput in production systems
- Validate regularly on representative test data to maintain performance
- Choose models based on your specific speed vs accuracy requirements

### 🚀 Next Steps
- **Notebook 03**: Working with the Datasets library for more complex data handling
- **Notebook 05**: Fine-tuning models for improved performance on specific domains
- **Advanced Topics**: Explore few-shot learning and custom NLI model development
- **Production Deployment**: Implement zero-shot classification in real applications

Zero-shot classification is a powerful technique that opens up many possibilities for rapid text classification without the need for labeled training data. The concepts and techniques learned in this notebook provide a solid foundation for more advanced NLP applications and production deployments!

---

## About the Author

**Vu Hung Nguyen** - AI Engineer & Researcher

Connect with me:
- 🌐 **Website**: [vuhung16au.github.io](https://vuhung16au.github.io/)
- 💼 **LinkedIn**: [linkedin.com/in/nguyenvuhung](https://www.linkedin.com/in/nguyenvuhung/)
- 💻 **GitHub**: [github.com/vuhung16au](https://github.com/vuhung16au/)

*This notebook is part of the [HF Transformer Trove](https://github.com/vuhung16au/hf-transformer-trove) educational series.*