# Document Classification with BART-Large-MNLI

This notebook explores document classification using Facebook's BART-Large-MNLI model for our Smart Document Classifier API.

We'll use the **pipeline approach** (Option 1) as it's more suitable for our FastAPI integration:
- Simpler implementation
- Better error handling
- Built-in optimizations
- Easier to maintain

In [None]:
# Install required packages
!pip install transformers torch

In [2]:
# Import necessary libraries
from transformers import pipeline
import os
import json
from typing import List, Dict, Any
import time



In [3]:
# Initialize the zero-shot classification pipeline with BART-Large-MNLI
print("Loading BART-Large-MNLI model...")
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
print("Model loaded successfully!")

Loading BART-Large-MNLI model...


Device set to use mps:0


Model loaded successfully!


In [5]:
# Define document categories for classification
DOCUMENT_CATEGORIES = [
    "Technical Documentation",
    "Business Proposal", 
    "Legal Document",
    "Academic Paper",
    "General Article",
    "Other"
]

print(f"Document categories: {DOCUMENT_CATEGORIES}")
print(f"Total categories: {len(DOCUMENT_CATEGORIES)}")

Document categories: ['Technical Documentation', 'Business Proposal', 'Legal Document', 'Academic Paper', 'General Article', 'Other']
Total categories: 6


In [10]:
def classify_document(text: str, categories: List[str] = DOCUMENT_CATEGORIES) -> Dict[str, Any]:
    """
    Classify a document using BART-Large-MNLI zero-shot classification
    
    Args:
        text: Document text to classify
        categories: List of possible categories
        
    Returns:
        Dictionary with classification results
    """
    try:
        # Truncate long text. This is a character cut, used as a rough proxy
        # for BART's token limit, so it may split the last word.
        max_length = 1000  # Characters, not tokens; adjust based on performance needs
        if len(text) > max_length:
            text = text[:max_length] + "..."
        
        # Perform classification
        start_time = time.time()
        result = classifier(text, categories)
        inference_time = time.time() - start_time
        print(result)  # Debug: show the raw pipeline output
        
        # Format results
        classification_result = {
            "predicted_category": result["labels"][0],
            "confidence_score": round(result["scores"][0], 4),
            "all_scores": {
                label: round(score, 4) 
                for label, score in zip(result["labels"], result["scores"])
            },
            "inference_time": round(inference_time, 3),
            "model_used": "facebook/bart-large-mnli"
        }
        
        return classification_result
        
    except Exception as e:
        return {
            "error": str(e),
            "predicted_category": None,
            "confidence_score": 0.0
        }

print("Document classification function created!")

Document classification function created!
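
The truncation in `classify_document` cuts at a fixed character offset, which can split a word in half before the text reaches the model. A minimal sketch of a gentler cut follows. `truncate_at_word` is a hypothetical helper (it is not part of this notebook or of `transformers`), and it still counts characters rather than tokens, so it only approximates BART's token limit.

```python
def truncate_at_word(text: str, max_chars: int = 1000, suffix: str = "...") -> str:
    """Cut text to roughly max_chars characters, preferring a whitespace boundary."""
    if len(text) <= max_chars:
        return text
    cut = text[:max_chars]
    # Only back up when the cut landed in the middle of a word
    if not text[max_chars].isspace() and not cut[-1].isspace():
        boundary = max(cut.rfind(" "), cut.rfind("\n"), cut.rfind("\t"))
        if boundary > 0:
            cut = cut[:boundary]
    # As in classify_document, the suffix is appended after the cut
    return cut.rstrip() + suffix


print(truncate_at_word("alpha beta gamma", max_chars=8))  # alpha...
```

Text with no whitespace before the limit is cut at the hard offset, which matches the notebook's current behavior.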


In [11]:
# Test the classification with a sample document
sample_text = """
This document outlines the technical specifications for implementing a RESTful API 
using FastAPI framework. The API includes endpoints for document upload, processing, 
and classification. Key components include SQLAlchemy for database operations, 
Pydantic for data validation, and uvicorn as the ASGI server.
"""

print("Testing document classification...")
result = classify_document(sample_text)
# print("\nClassification Result:")
# print(json.dumps(result, indent=2))

Testing document classification...
{'sequence': '\nThis document outlines the technical specifications for implementing a RESTful API \nusing FastAPI framework. The API includes endpoints for document upload, processing, \nand classification. Key components include SQLAlchemy for database operations, \nPydantic for data validation, and uvicorn as the ASGI server.\n', 'labels': ['Technical Documentation', 'General Article', 'Other', 'Academic Paper', 'Business Proposal', 'Legal Document'], 'scores': [0.8104618191719055, 0.07211217284202576, 0.06284535676240921, 0.022135350853204727, 0.018267560750246048, 0.014177760109305382]}

Classification Result:
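
The raw output above returns every label with a score, so the top label is always reported even when the model is unsure. One possible post-processing step is to fall back to a default category below a confidence threshold. This is a sketch under assumptions: `pick_category`, the `0.5` threshold, and the `"Other"` fallback are illustrative choices, not tuned values or part of the pipeline API. The sample scores are rounded from the output above.

```python
from typing import Any, Dict


def pick_category(result: Dict[str, Any], min_confidence: float = 0.5,
                  fallback: str = "Other") -> str:
    """Return the highest-scoring label, or fallback if its score is too low."""
    pairs = list(zip(result["labels"], result["scores"]))
    if not pairs:
        return fallback
    # Take the maximum explicitly instead of relying on the list order
    top_label, top_score = max(pairs, key=lambda pair: pair[1])
    return top_label if top_score >= min_confidence else fallback


sample = {
    "labels": ["Technical Documentation", "General Article", "Other",
               "Academic Paper", "Business Proposal", "Legal Document"],
    "scores": [0.8105, 0.0721, 0.0628, 0.0221, 0.0183, 0.0142],
}

print(pick_category(sample))                      # Technical Documentation
print(pick_category(sample, min_confidence=0.9))  # Other
```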


In [None]:
# Test with some documents from our dataset
dataset_path = "../ml/Dataset/"

# Read a few sample documents
sample_files = [
    "python_doc.txt",
    "How I use LLMs as a staff engineer.txt", 
    "compujai.txt"
]

print("Testing classification on existing dataset documents:")
print("=" * 60)

for filename in sample_files:
    filepath = os.path.join(dataset_path, filename)
    if os.path.exists(filepath):
        with open(filepath, 'r', encoding='utf-8') as f:
            content = f.read()
        
        print(f"\nFile: {filename}")
        print(f"Content preview: {content[:100]}...")
        
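        result = classify_document(content)
        if "error" in result:
            # The error result has no inference_time key, so report and move on
            print(f"Classification failed: {result['error']}")
            continue
        print(f"Predicted Category: {result['predicted_category']}")
        print(f"Confidence: {result['confidence_score']}")
        print(f"Inference Time: {result['inference_time']}s")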
    else:
        print(f"File not found: {filepath}")

## Creating ML Module for FastAPI Integration

Now let's create the classifier module that will be integrated into our FastAPI backend:

In [None]:
# Create the ML classifier module for FastAPI integration
classifier_module_code = '''
"""
Document Classifier Module using BART-Large-MNLI
For Smart Document Classifier FastAPI Application
"""

from transformers import pipeline
from typing import Dict, Any, List, Optional
import logging
import time

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DocumentClassifier:
    """Document classifier using Facebook's BART-Large-MNLI model"""
    
    CATEGORIES = [
        "Technical Documentation",
        "Business Proposal", 
        "Legal Document",
        "Academic Paper",
        "General Article",
        "Other"
    ]
    
    def __init__(self):
        """Initialize the classifier"""
        self.classifier = None
        self.model_name = "facebook/bart-large-mnli"
        self.is_loaded = False
        
    def load_model(self):
        """Load the BART-Large-MNLI model"""
        try:
            logger.info(f"Loading {self.model_name} model...")
            self.classifier = pipeline(
                "zero-shot-classification", 
                model=self.model_name,
                device=-1  # Use CPU, change to 0 for GPU
            )
            self.is_loaded = True
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise  # Re-raise with the original traceback
    
    def classify(self, text: str, categories: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Classify a document
        
        Args:
            text: Document text to classify
            categories: Optional custom categories (defaults to self.CATEGORIES)
            
        Returns:
            Classification results with confidence scores
        """
        if not self.is_loaded:
            self.load_model()
            
        if not text or not text.strip():
            return {
                "error": "Empty text provided",
                "predicted_category": "Other",
                "confidence_score": 0.0
            }
            
        categories = categories or self.CATEGORIES
        
        try:
            # Truncate by characters as a rough proxy for BART's ~1024-token limit
            max_length = 800  # Characters, not tokens; conservative for better performance
            if len(text) > max_length:
                text = text[:max_length] + "..."
                logger.info(f"Text truncated to {max_length} characters")
            
            # Perform classification
            start_time = time.time()
            result = self.classifier(text, categories)
            inference_time = time.time() - start_time
            
            # Format results
            classification_result = {
                "predicted_category": result["labels"][0],
                "confidence_score": round(result["scores"][0], 4),
                "all_scores": {
                    label: round(score, 4) 
                    for label, score in zip(result["labels"], result["scores"])
                },
                "inference_time": round(inference_time, 3),
                "model_used": self.model_name,
                "text_length": len(text)
            }
            
            logger.info(f"Classification completed: {result['labels'][0]} ({result['scores'][0]:.4f})")
            return classification_result
            
        except Exception as e:
            logger.error(f"Classification failed: {str(e)}")
            return {
                "error": str(e),
                "predicted_category": "Other",
                "confidence_score": 0.0
            }

# Global classifier instance (singleton pattern)
_classifier_instance = None

def get_classifier() -> DocumentClassifier:
    """Get or create the global classifier instance"""
    global _classifier_instance
    if _classifier_instance is None:
        _classifier_instance = DocumentClassifier()
    return _classifier_instance

def classify_document_text(text: str) -> Dict[str, Any]:
    """
    Convenience function to classify document text
    
    Args:
        text: Document text to classify
        
    Returns:
        Classification results
    """
    classifier = get_classifier()
    return classifier.classify(text)
'''

# Write the module to a file
module_path = '../backend/ml_classifier.py'
with open(module_path, 'w', encoding='utf-8') as f:
    f.write(classifier_module_code)

print(f"✅ ML classifier module created at: {module_path}")
print("📦 Module includes:")
print("   - DocumentClassifier class")
print("   - Singleton pattern for model loading")
print("   - Error handling and logging")
print("   - Performance optimizations")
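
The module above loads the model lazily on the first `classify` call. If requests are handled from several threads (for example, FastAPI runs plain `def` endpoints in a thread pool), two early requests could both see `is_loaded` as false and each load the model. One way to guard against that is a lock around the load, sketched below. `LazyResource` and `fake_loader` are hypothetical stand-ins so the example runs without downloading the model; in the module, the loader would be the `pipeline(...)` call.

```python
import threading
from typing import Any, Callable, List, Optional


class LazyResource:
    """Load an expensive object once, even when first used from several threads."""

    def __init__(self, loader: Callable[[], Any]) -> None:
        self._loader = loader
        self._value: Optional[Any] = None
        self._loaded = False
        self._lock = threading.Lock()

    def get(self) -> Any:
        # Fast path: skip the lock once the object is loaded
        if self._loaded:
            return self._value
        with self._lock:
            # Re-check inside the lock: another thread may have finished loading
            if not self._loaded:
                self._value = self._loader()
                self._loaded = True
        return self._value


load_calls: List[int] = []


def fake_loader() -> str:
    """Stand-in for the slow model load; records each invocation."""
    load_calls.append(1)
    return "model"


resource = LazyResource(fake_loader)
threads = [threading.Thread(target=resource.get) for _ in range(8)]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

print(len(load_calls))  # 1
```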


## Fixing Multiprocessing Resource Warning

On shutdown, Python's resource tracker can print this warning: `resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown`

The warning occurs because the transformers library uses multiprocessing resources that aren't properly cleaned up on shutdown. To prevent it, we've implemented explicit resource management:

1. **Added cleanup methods** to the `DocumentClassifier` class
2. **Registered `atexit` handlers** to clean up on process termination
3. **Added a FastAPI shutdown event** to release ML resources
4. **Added explicit resource management** with garbage collection and PyTorch cache clearing
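
The cleanup code itself is not shown in this notebook, so the following is only a sketch of what those steps might look like. `CleanableClassifier` is a hypothetical stand-in for `DocumentClassifier` that contains just the cleanup path. The PyTorch import is optional here so the example also runs without it installed.

```python
import atexit
import gc


class CleanableClassifier:
    """Hypothetical stand-in for DocumentClassifier, showing only cleanup."""

    def __init__(self) -> None:
        self.classifier = None
        self.is_loaded = False

    def cleanup(self) -> None:
        """Drop the pipeline reference and ask Python and PyTorch to free memory."""
        self.classifier = None
        self.is_loaded = False
        gc.collect()
        try:
            import torch
        except ImportError:
            return  # Nothing else to clear without PyTorch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Apple Silicon (MPS) cache, guarded because older torch versions lack it
        mps = getattr(torch, "mps", None)
        if (mps is not None and hasattr(mps, "empty_cache")
                and torch.backends.mps.is_available()):
            mps.empty_cache()


instance = CleanableClassifier()
# Run cleanup when the interpreter exits; calling it more than once is harmless
atexit.register(instance.cleanup)
```

In the FastAPI app, the same `cleanup` method could be called from a shutdown handler (for example, a function registered with `@app.on_event("shutdown")` or a lifespan context), so that resources are released before the process exits.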