# Document Classification with BART-Large-MNLI

This notebook explores document classification using Facebook's BART-Large-MNLI model for our Smart Document Classifier API.

We'll use the **pipeline approach** (Option 1) as it's more suitable for our FastAPI integration:
- Simpler implementation
- Better error handling
- Built-in optimizations
- Easier to maintain

In [None]:
# Install required packages
!pip install transformers torch

In [None]:
# Import necessary libraries
from transformers import pipeline
import os
import json
from typing import List, Dict, Any
import time

In [3]:
# Initialize the zero-shot classification pipeline with BART-Large-MNLI
print("Loading BART-Large-MNLI model...")
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
print("Model loaded successfully!")

Loading BART-Large-MNLI model...


Device set to use mps:0


Model loaded successfully!


In [5]:
# Define document categories for classification
DOCUMENT_CATEGORIES = [
    "Technical Documentation",
    "Business Proposal", 
    "Legal Document",
    "Academic Paper",
    "General Article",
    "Other"
]

print(f"Document categories: {DOCUMENT_CATEGORIES}")
print(f"Total categories: {len(DOCUMENT_CATEGORIES)}")

Document categories: ['Technical Documentation', 'Business Proposal', 'Legal Document', 'Academic Paper', 'General Article', 'Other']
Total categories: 6


In [None]:
def classify_document(text: str, categories: List[str] = DOCUMENT_CATEGORIES) -> Dict[str, Any]:
    """
    Classify a document using BART-Large-MNLI zero-shot classification
    
    Args:
        text: Document text to classify
        categories: List of possible categories
        
    Returns:
        Dictionary with classification results
    """
    try:
        # IMPROVED: Use tokenizer-based truncation instead of character truncation.
        # Reuse the tokenizer the pipeline already loaded rather than reloading it on every call.
        tokenizer = classifier.tokenizer
        
        # Count actual tokens, not characters
        tokens = tokenizer.encode(text, add_special_tokens=False)
        max_tokens = 800  # Conservative limit for BART
        
        if len(tokens) > max_tokens:
            # Proper token-based truncation
            truncated_tokens = tokens[:max_tokens]
            text = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
            print(f"⚠️  Text truncated from {len(tokens)} to {max_tokens} tokens")
        
        # Perform classification
        start_time = time.time()
        result = classifier(text, categories)
        inference_time = time.time() - start_time
        print(f"📊 Classification result: {result['labels'][0]} ({result['scores'][0]:.4f})")
        
        # Format results
        classification_result = {
            "predicted_category": result["labels"][0],
            "confidence_score": round(result["scores"][0], 4),
            "all_scores": {
                label: round(score, 4) 
                for label, score in zip(result["labels"], result["scores"])
            },
            "inference_time": round(inference_time, 3),
            "model_used": "facebook/bart-large-mnli",
            "token_count": len(tokens),
            "was_truncated": len(tokens) > max_tokens
        }
        
        return classification_result
        
    except Exception as e:
        return {
            "error": str(e),
            "predicted_category": None,
            "confidence_score": 0.0
        }

print("✅ Updated document classification function with proper tokenizer-based truncation!")

✅ Updated document classification function with proper tokenizer-based truncation!


In [None]:
# Compare old vs new truncation methods
from transformers import AutoTokenizer

# Load tokenizer for proper token counting
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")

test_doc = """
LEGAL AGREEMENT - SOFTWARE LICENSE TERMS

IMPORTANT: Please read these terms and conditions carefully before using our software.

1. GRANT OF LICENSE
Subject to the terms and conditions of this Agreement, Company hereby grants to you a limited, non-exclusive, non-transferable license to use the software solely for your internal business purposes.

2. RESTRICTIONS
You may not: (a) modify, adapt, or create derivative works; (b) reverse engineer, decompile, or disassemble; (c) remove or alter any proprietary notices; (d) distribute, sublicense, or transfer the software.

3. INTELLECTUAL PROPERTY
All intellectual property rights in and to the software remain the exclusive property of Company. No rights are granted except as expressly set forth herein.

4. WARRANTY DISCLAIMER
THE SOFTWARE IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT.

5. LIMITATION OF LIABILITY
IN NO EVENT SHALL COMPANY BE LIABLE FOR ANY INDIRECT, INCIDENTAL, SPECIAL, CONSEQUENTIAL, OR PUNITIVE DAMAGES ARISING OUT OF OR RELATING TO THIS AGREEMENT.
"""

print("🔍 TRUNCATION COMPARISON")
print("=" * 50)
print(f"Original document: {len(test_doc)} characters, {len(tokenizer.encode(test_doc, add_special_tokens=False))} tokens")
print()

# OLD METHOD (Character-based - WRONG)
old_truncated = test_doc[:300] + "..."
print("❌ OLD METHOD (Character truncation at 300 chars):")
print(f"Result: {len(old_truncated)} chars, {len(tokenizer.encode(old_truncated, add_special_tokens=False))} tokens")
print(f"Preview: {repr(old_truncated[:100])}...")
print()

# NEW METHOD (Token-based - CORRECT)
tokens = tokenizer.encode(test_doc, add_special_tokens=False)
max_tokens = 80  # Small limit for demo
if len(tokens) > max_tokens:
    truncated_tokens = tokens[:max_tokens]
    new_truncated = tokenizer.decode(truncated_tokens, skip_special_tokens=True)
else:
    new_truncated = test_doc

print("✅ NEW METHOD (Token truncation at 80 tokens):")
print(f"Result: {len(new_truncated)} chars, {len(tokenizer.encode(new_truncated, add_special_tokens=False))} tokens")
print(f"Preview: {repr(new_truncated[:100])}...")
print()

print("🎯 KEY DIFFERENCES:")
print("1. Old method cuts at an arbitrary character; new method cuts on a token boundary")
print("2. Old method doesn't account for tokenization differences")
print("3. New method bounds the token count passed to the BART model")
print("4. New method keeps the length budget predictable regardless of content")

In [11]:
# Test the classification with a sample document
sample_text = """
This document outlines the technical specifications for implementing a RESTful API 
using FastAPI framework. The API includes endpoints for document upload, processing, 
and classification. Key components include SQLAlchemy for database operations, 
Pydantic for data validation, and uvicorn as the ASGI server.
"""

print("Testing document classification...")
result = classify_document(sample_text)
# print("\nClassification Result:")
# print(json.dumps(result, indent=2))

Testing document classification...
{'sequence': '\nThis document outlines the technical specifications for implementing a RESTful API \nusing FastAPI framework. The API includes endpoints for document upload, processing, \nand classification. Key components include SQLAlchemy for database operations, \nPydantic for data validation, and uvicorn as the ASGI server.\n', 'labels': ['Technical Documentation', 'General Article', 'Other', 'Academic Paper', 'Business Proposal', 'Legal Document'], 'scores': [0.8104618191719055, 0.07211217284202576, 0.06284535676240921, 0.022135350853204727, 0.018267560750246048, 0.014177760109305382]}

Classification Result:
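
The raw output above can be summarized a little further. With the pipeline's default single-label mode, the scores are normalized across the candidate labels, so the gap between the top two scores gives a rough sense of how decisive the prediction is. The sketch below uses the labels and scores copied from that output; the helper name `summarize_scores` and the `min_margin` threshold are hypothetical and not part of the notebook or the library.

```python
from typing import Dict, List


def summarize_scores(labels: List[str], scores: List[float],
                     min_margin: float = 0.2) -> Dict[str, object]:
    """Report the top label and how far ahead of the runner-up it is."""
    # Sort defensively; the pipeline already returns labels by descending score
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    (top_label, top_score), (_, second_score) = ranked[0], ranked[1]
    margin = top_score - second_score
    return {
        "top_label": top_label,
        "margin": round(margin, 4),
        "score_sum": round(sum(scores), 4),
        # min_margin is an illustrative threshold, not a tuned value
        "ambiguous": margin < min_margin,
    }


# Labels and scores copied from the classification output above
labels = ["Technical Documentation", "General Article", "Other",
          "Academic Paper", "Business Proposal", "Legal Document"]
scores = [0.8104618191719055, 0.07211217284202576, 0.06284535676240921,
          0.022135350853204727, 0.018267560750246048, 0.014177760109305382]

summary = summarize_scores(labels, scores)
print(summary)
```

A small margin could be treated as a cue to fall back to the "Other" category or to flag the document for review; any such threshold would need tuning on real documents.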


In [None]:
# Test with some documents from our dataset
dataset_path = "../ml/Dataset/"

# Read a few sample documents
sample_files = [
    "python_doc.txt",
    "How I use LLMs as a staff engineer.txt", 
    "compujai.txt"
]

print("Testing classification on existing dataset documents:")
print("=" * 60)

for filename in sample_files:
    filepath = os.path.join(dataset_path, filename)
    if os.path.exists(filepath):
        with open(filepath, 'r', encoding='utf-8') as f:
            content = f.read()
        
        print(f"\nFile: {filename}")
        print(f"Content preview: {content[:100]}...")
        
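        result = classify_document(content)
        if "error" in result:
            # The error dict has no 'inference_time' key, so report and skip
            print(f"Classification failed: {result['error']}")
            continue
        print(f"Predicted Category: {result['predicted_category']}")
        print(f"Confidence: {result['confidence_score']}")
        print(f"Inference Time: {result['inference_time']}s")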
    else:
        print(f"File not found: {filepath}")

## Creating ML Module for FastAPI Integration

Now let's create the classifier module that will be integrated into our FastAPI backend:

In [None]:
# Create the ML classifier module for FastAPI integration
classifier_module_code = '''
"""
Document Classifier Module using BART-Large-MNLI
For Smart Document Classifier FastAPI Application
"""

from transformers import pipeline
from typing import Dict, Any, List, Optional
import logging
import time

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DocumentClassifier:
    """Document classifier using Facebook's BART-Large-MNLI model"""
    
    CATEGORIES = [
        "Technical Documentation",
        "Business Proposal", 
        "Legal Document",
        "Academic Paper",
        "General Article",
        "Other"
    ]
    
    def __init__(self):
        """Initialize the classifier"""
        self.classifier = None
        self.model_name = "facebook/bart-large-mnli"
        self.is_loaded = False
        
    def load_model(self):
        """Load the BART-Large-MNLI model"""
        try:
            logger.info(f"Loading {self.model_name} model...")
            self.classifier = pipeline(
                "zero-shot-classification", 
                model=self.model_name,
                device=-1  # Use CPU, change to 0 for GPU
            )
            self.is_loaded = True
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise  # Re-raise with the original traceback
    
    def classify(self, text: str, categories: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Classify a document
        
        Args:
            text: Document text to classify
            categories: Optional custom categories (defaults to self.CATEGORIES)
            
        Returns:
            Classification results with confidence scores
        """
        if not self.is_loaded:
            self.load_model()
            
        if not text or not text.strip():
            return {
                "error": "Empty text provided",
                "predicted_category": "Other",
                "confidence_score": 0.0
            }
            
        categories = categories or self.CATEGORIES
        
        try:
            # Truncate by characters. For English text, 800 characters is typically far
            # below BART's 1024-token limit, but it is only a rough proxy for token count.
            max_length = 800  # Conservative character limit
            if len(text) > max_length:
                text = text[:max_length] + "..."
                logger.info(f"Text truncated to {max_length} characters")
            
            # Perform classification
            start_time = time.time()
            result = self.classifier(text, categories)
            inference_time = time.time() - start_time
            
            # Format results
            classification_result = {
                "predicted_category": result["labels"][0],
                "confidence_score": round(result["scores"][0], 4),
                "all_scores": {
                    label: round(score, 4) 
                    for label, score in zip(result["labels"], result["scores"])
                },
                "inference_time": round(inference_time, 3),
                "model_used": self.model_name,
                "text_length": len(text)
            }
            
            logger.info(f"Classification completed: {result['labels'][0]} ({result['scores'][0]:.4f})")
            return classification_result
            
        except Exception as e:
            logger.error(f"Classification failed: {str(e)}")
            return {
                "error": str(e),
                "predicted_category": "Other",
                "confidence_score": 0.0
            }

# Global classifier instance (singleton pattern)
_classifier_instance = None

def get_classifier() -> DocumentClassifier:
    """Get or create the global classifier instance"""
    global _classifier_instance
    if _classifier_instance is None:
        _classifier_instance = DocumentClassifier()
    return _classifier_instance

def classify_document_text(text: str) -> Dict[str, Any]:
    """
    Convenience function to classify document text
    
    Args:
        text: Document text to classify
        
    Returns:
        Classification results
    """
    classifier = get_classifier()
    return classifier.classify(text)
'''

# Write the module to a file
with open('../backend/ml_classifier.py', 'w', encoding='utf-8') as f:
    f.write(classifier_module_code)

print("✅ ML classifier module created at: ../backend/ml_classifier.py")
print("📦 Module includes:")
print("   - DocumentClassifier class")
print("   - Singleton pattern for model loading")
print("   - Error handling and logging")
print("   - Performance optimizations")

## Fixing Multiprocessing Resource Warning

The warning about leaked semaphore objects occurs because the transformers library uses multiprocessing resources that aren't properly cleaned up on shutdown. We've implemented proper resource management:

1. **Added cleanup methods** to the DocumentClassifier class
2. **Registered atexit handlers** to cleanup on process termination  
3. **Added FastAPI shutdown event** to cleanup ML resources
4. **Explicit resource management** with garbage collection and PyTorch cache clearing

This prevents the resource tracker warning: `resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown`
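
The cleanup code itself is not shown in this notebook. As a rough sketch of what steps 1, 2, and 4 might look like, the example below releases the pipeline reference, runs garbage collection, and clears the PyTorch device cache when PyTorch is installed. The class name `ManagedClassifier` and its `cleanup` method are hypothetical, and whether this silences the warning depends on where the semaphore actually originates.

```python
import atexit
import gc


class ManagedClassifier:
    """Hypothetical holder for the loaded pipeline, standing in for DocumentClassifier."""

    def __init__(self):
        self.classifier = None
        self.is_loaded = False

    def cleanup(self) -> None:
        # Drop the reference to the pipeline so its memory can be reclaimed
        self.classifier = None
        self.is_loaded = False
        gc.collect()

        # Clear device caches only if PyTorch is installed
        try:
            import torch
        except ImportError:
            return
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        mps_backend = getattr(torch.backends, "mps", None)
        mps_module = getattr(torch, "mps", None)
        if (mps_backend is not None and mps_backend.is_available()
                and mps_module is not None and hasattr(mps_module, "empty_cache")):
            mps_module.empty_cache()


managed = ManagedClassifier()
# Run cleanup automatically when the interpreter exits
atexit.register(managed.cleanup)
```

In a FastAPI app, the same `cleanup` method could also be called from the application's shutdown hook (step 3), so resources are released when the server stops rather than only at interpreter exit.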

## Better Text Truncation Strategies

The simple character truncation used in the module above (`text[:max_length]`) is problematic because:

1. **Character vs Token mismatch** - BART uses BPE tokenization, not character counting
2. **Loses important context** - May cut mid-sentence or lose document structure  
3. **Suboptimal for classification** - Beginning might not contain key classification signals
4. **No intelligent splitting** - Doesn't respect word/sentence boundaries

Let's implement better strategies:

In [None]:
from transformers import AutoTokenizer
import re

# Load BART tokenizer to properly count tokens
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")

def smart_truncate_text(text: str, max_tokens: int = 800) -> str:
    """
    Intelligently truncate text for BART classification
    
    Strategies:
    1. Use the actual tokenizer to count tokens, not characters
    2. Take beginning + end of very long documents (sandwich approach)
    3. Preserve sentence boundaries for moderately long documents

    Note: explicit handling of structure clues (titles, headers) is not implemented here.
    """
    
    # Strategy 1: Simple tokenizer-based truncation
    def tokenizer_truncate(text: str, max_tokens: int) -> str:
        tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) <= max_tokens:
            return text
        
        truncated_tokens = tokens[:max_tokens]
        return tokenizer.decode(truncated_tokens, skip_special_tokens=True)
    
    # Strategy 2: Sandwich approach (beginning + end)
    def sandwich_truncate(text: str, max_tokens: int) -> str:
        tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) <= max_tokens:
            return text
            
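        # Reserve room for the separator so the result stays close to max_tokens
        separator = "\n\n[...document continues...]\n\n"
        separator_tokens = len(tokenizer.encode(separator, add_special_tokens=False))
        budget = max(max_tokens - separator_tokens, 2)
        
        # Take 60% from beginning, 40% from end
        start_tokens = max(1, int(budget * 0.6))
        end_tokens = max(1, budget - start_tokens)
        
        beginning = tokenizer.decode(tokens[:start_tokens], skip_special_tokens=True)
        ending = tokenizer.decode(tokens[-end_tokens:], skip_special_tokens=True)
        
        return f"{beginning}{separator}{ending}"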
    
    # Strategy 3: Smart chunking (preserve sentences)
    def smart_chunk_truncate(text: str, max_tokens: int) -> str:
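        # Split after sentence-ending punctuation so the punctuation is kept
        sentences = re.split(r'(?<=[.!?])\s+', text.strip())
        
        kept_sentences = []
        current_length = 0
        
        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue
                
            # Prefix a space for every sentence after the first so the count closely
            # matches the joined text (BART's BPE folds a leading space into the token)
            prefix = " " if kept_sentences else ""
            sentence_tokens = tokenizer.encode(prefix + sentence, add_special_tokens=False)
            
            if current_length + len(sentence_tokens) <= max_tokens:
                kept_sentences.append(sentence)
                current_length += len(sentence_tokens)
            else:
                break
                
        if kept_sentences:
            # Join whole sentences with spaces instead of concatenating raw token lists
            return " ".join(kept_sentences)
        else:
            # Fallback to simple truncation if first sentence is too long
            return tokenizer_truncate(text, max_tokens)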
    
    # Choose strategy based on document characteristics
    token_count = len(tokenizer.encode(text, add_special_tokens=False))
    
    if token_count <= max_tokens:
        return text
    elif token_count > max_tokens * 3:  # Very long document
        return sandwich_truncate(text, max_tokens)
    else:  # Moderately long document
        return smart_chunk_truncate(text, max_tokens)

# Test the different approaches
test_text = """
# Technical Documentation: FastAPI Implementation Guide

This comprehensive guide covers the implementation of a FastAPI web application for document classification.

## Architecture Overview
The system uses a modular architecture with the following components:
- FastAPI framework for REST API endpoints
- SQLAlchemy ORM for database operations  
- Pydantic models for data validation
- BART-Large-MNLI for document classification
- Uvicorn ASGI server for deployment

## Implementation Details
The core application consists of several modules that work together to provide document classification capabilities.

### Database Layer
The database layer uses SQLAlchemy to manage document metadata and classification results.

### ML Classification
The machine learning component uses Facebook's BART-Large-MNLI model for zero-shot classification.

## Performance Considerations
For production deployment, consider the following optimizations:
- Model caching and singleton patterns
- Async processing for better throughput  
- Resource cleanup to prevent memory leaks
- Proper error handling and logging

## Conclusion
This FastAPI implementation provides a robust foundation for document classification tasks.
"""

print("Original text length:", len(test_text), "characters")
print("Original token count:", len(tokenizer.encode(test_text, add_special_tokens=False)), "tokens")
print()

# Test different truncation methods
truncated = smart_truncate_text(test_text, max_tokens=100)
print("Smart truncated length:", len(tokenizer.encode(truncated, add_special_tokens=False)), "tokens")
print("Smart truncated text:")
print(truncated)
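
The budget arithmetic behind the sandwich approach can be checked without loading a tokenizer. The sketch below works on plain lists of token IDs and shows that, when the separator is counted against the budget, the result never exceeds `max_tokens`. The function name `sandwich` and its parameters are hypothetical, and the integer IDs are placeholders rather than real BART tokens.

```python
from typing import List, Sequence


def sandwich(tokens: Sequence[int], max_tokens: int,
             marker: Sequence[int] = (), head_ratio: float = 0.6) -> List[int]:
    """Keep the head and tail of a token sequence within max_tokens, marker included."""
    if len(tokens) <= max_tokens:
        return list(tokens)

    budget = max_tokens - len(marker)
    if budget < 2:
        raise ValueError("max_tokens is too small for a head, a marker, and a tail")

    # Clamp so both head and tail keep at least one token
    head = min(max(1, int(budget * head_ratio)), budget - 1)
    tail = budget - head
    return list(tokens[:head]) + list(marker) + list(tokens[-tail:])


# Placeholder token IDs; -1 stands in for the separator's tokens
ids = list(range(1000))
out = sandwich(ids, max_tokens=100, marker=[-1, -1, -1])
print(len(out), out[:3], out[-3:])
```

With real text, decoding and re-encoding can shift the count by a token or two, so a small safety margin below the model limit is still worthwhile.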

In [None]:
# Demonstrate the problem with simple character truncation
sample_document = """
LEGAL AGREEMENT - TERMS OF SERVICE

IMPORTANT: Please read these terms carefully before using our service.

1. ACCEPTANCE OF TERMS
By accessing and using this service, you agree to be bound by the terms and conditions outlined in this agreement.

2. SERVICE DESCRIPTION  
Our document classification service uses artificial intelligence to automatically categorize uploaded documents into predefined categories.

3. USER OBLIGATIONS
Users must ensure that uploaded documents do not contain:
- Confidential or proprietary information
- Personal identifying information (PII) 
- Copyrighted material without permission
- Malicious code or harmful content

4. LIMITATION OF LIABILITY
In no event shall the company be liable for any indirect, incidental, special, consequential, or punitive damages.

5. TERMINATION
We reserve the right to terminate or suspend access to our service immediately, without prior notice.
"""

print("=== PROBLEM WITH SIMPLE CHARACTER TRUNCATION ===")
print()

# Current problematic approach (character-based)
simple_truncated = sample_document[:200] + "..."
print("Simple character truncation (200 chars):")
print(repr(simple_truncated))
print()

# Show what happens with tokenization
print("Tokens in simple truncated text:", len(tokenizer.encode(simple_truncated, add_special_tokens=False)))
print()

# Show what BART actually sees after tokenization
tokens = tokenizer.encode(simple_truncated, add_special_tokens=False)
decoded_back = tokenizer.decode(tokens, skip_special_tokens=True)
print("What BART actually processes:")
print(repr(decoded_back))
print()

print("=== ISSUES IDENTIFIED ===")
print("1. Cut off mid-sentence: 'By accessing and using this service, you agree to be bound by the terms...'")
print("2. Document type identifier 'LEGAL AGREEMENT' is preserved, but most of the supporting context is lost")  
print("3. Character count ≠ token count: 203 characters is roughly 50 tokens (varies by content)")
print("4. Classification might fail: Incomplete context about legal terms and obligations")