# Document Classification with pre-trained ml models

This notebook explores document classification using models for our Smart Document Classifier API.
- facebook/bart-large-mnli
- MoritzLaurer/mDeBERTa-v3-base-mnli-xnli

In [1]:
# Import additional libraries for model comparison
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings
import time  # Fix the time import
warnings.filterwarnings('ignore')

In [2]:
# Import necessary libraries
from transformers import pipeline, AutoTokenizer
import os
import json
from typing import List, Dict, Any
import time

In [92]:
# Define document categories for classification
DOCUMENT_CATEGORIES = [
    "Technical Documentation",
    "Business Proposal", 
    "Legal Document",
    "Academic Paper",
    "General Article",
    # "Other"
]

print(f"Document categories: {DOCUMENT_CATEGORIES}")
print(f"Total categories: {len(DOCUMENT_CATEGORIES)}")

Document categories: ['Technical Documentation', 'Business Proposal', 'Legal Document', 'Academic Paper', 'General Article']
Total categories: 5


In [93]:
def classify_document(model, text: str, categories: List[str] = DOCUMENT_CATEGORIES) -> Dict[str, Any]:
    """
    Classify a document using a zero-shot classification model by chunking the text.

    Args:
        model: The model to use for classification.
        text: Document text to classify.
        categories: List of possible categories.
        
    Returns:
        Dictionary with classification results.
    """
    try:
        # Step 1: Initialize tokenizer
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(model)
        
        max_chunk_tokens = 800 # We will use this as our new chunk size
        
        # IMPROVED: Tokenize the entire text once
        tokens = tokenizer.encode(text, add_special_tokens=False)
        full_token_count = len(tokens) # Keep track of the original token count

        # NEW: Initialize lists to store results from each chunk
        all_scores = {category: [] for category in categories}
        total_inference_time = 0
        
        # NEW: Chunk the tokens and process each chunk
        # This loop iterates through the list of tokens in steps of max_chunk_tokens
        for i in range(0, full_token_count, max_chunk_tokens):
            # NEW: Get the tokens for the current chunk
            chunk_tokens = tokens[i:i + max_chunk_tokens]
            
            # NEW: Decode the chunk of tokens back into text
            chunk_text = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
            
            # Perform classification for the current chunk
            start_time = time.time()
            result = classifier(chunk_text, categories, multi_label=False)
            total_inference_time += (time.time() - start_time)
            
            # NEW: Aggregate scores from this chunk
            for label, score in zip(result["labels"], result["scores"]):
                all_scores[label].append(score)

        # Step 5: Average the scores across all chunks
        # This calculates the average confidence score for each category
        averaged_scores = {
            category: sum(scores) / len(scores) if scores else 0
            for category, scores in all_scores.items()
        }
        
        # Find the category with the highest average score
        predicted_category = max(averaged_scores, key=averaged_scores.get)
        confidence_score = averaged_scores[predicted_category]

        # Format results
        classification_result = {
            "predicted_category": predicted_category,
            "confidence_score": round(confidence_score, 4),
            "all_scores": {label: round(score, 4) for label, score in averaged_scores.items()},
            "inference_time": round(total_inference_time, 3), # Total inference time is the sum of all chunks' inference times
            "model_used": model,
            "token_count": full_token_count,
            "was_truncated": False # The entire document is processed now, so truncation is not applicable
        }
        
        return classification_result
        
    except Exception as e:
        return {
            "error": str(e),
            "predicted_category": None,
            "confidence_score": 0.0
        }

In [94]:
# Test the classification with a sample document
sample_text = """
This document outlines the technical specifications for implementing a RESTful API 
using FastAPI framework. The API includes endpoints for document upload, processing, 
and classification. Key components include SQLAlchemy for database operations, 
Pydantic for data validation, and uvicorn as the ASGI server.
"""

# print("Testing document classification...")
# for model in models:
#     print(f"Loading {model} model...")
#     classifier = pipeline("zero-shot-classification", model=model)
#     print(f"{model} loaded successfully!")
#     print("-" * 30)
#     result = classify_document(model, sample_text)
#     print("\nClassification Result:")
#     print(json.dumps(result, indent=2))
#     print("-" * 50)

Testing document classification...
Loading MoritzLaurer/mDeBERTa-v3-base-mnli-xnli model...
MoritzLaurer/mDeBERTa-v3-base-mnli-xnli loaded successfully!
------------------------------

Classification Result:
{
  "predicted_category": "Technical Documentation",
  "confidence_score": 0.9506,
  "all_scores": {
    "Technical Documentation": 0.9506,
    "Business Proposal": 0.0105,
    "Legal Document": 0.0056,
    "Academic Paper": 0.0186,
    "General Article": 0.0147
  },
  "inference_time": 0.424,
  "model_used": "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
  "token_count": 72,
  "was_truncated": false
}
--------------------------------------------------
Loading facebook/bart-large-mnli model...
facebook/bart-large-mnli loaded successfully!
------------------------------

Classification Result:
{
  "predicted_category": "Technical Documentation",
  "confidence_score": 0.8648,
  "all_scores": {
    "Technical Documentation": 0.8648,
    "Business Proposal": 0.0195,
    "Legal Document": 0

In [95]:
from transformers import pipeline
classifier = pipeline("zero-shot-classification", model="MoritzLaurer/mDeBERTa-v3-base-mnli-xnli")

sample_text = """
This document outlines the technical specifications for implementing a RESTful API 
using FastAPI framework. The API includes endpoints for document upload, processing, 
and classification. Key components include SQLAlchemy for database operations, 
Pydantic for data validation, and uvicorn as the ASGI server.
"""
# result = classify_document(model, sample_text)
# result

{'predicted_category': 'Technical Documentation',
 'confidence_score': 0.9506,
 'all_scores': {'Technical Documentation': 0.9506,
  'Business Proposal': 0.0105,
  'Legal Document': 0.0056,
  'Academic Paper': 0.0186,
  'General Article': 0.0147},
 'inference_time': 0.172,
 'model_used': 'facebook/bart-large-mnli',
 'token_count': 66,
 'was_truncated': False}

In [96]:
import os
from PyPDF2 import PdfReader
from docx import Document  # For .docx files

def read_file_content(filepath):
    """
    Reads the content of a file based on its extension,
    supporting .txt, .docx, and .pdf formats.
    """
    filename, file_extension = os.path.splitext(filepath)
    file_extension = file_extension.lower()

    if file_extension == '.txt':
        # Standard Python file reader for plain text files
        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
            return f.read()
    elif file_extension == '.docx':
        # Use python-docx to read .docx files
        doc = Document(filepath)
        return '\n'.join([paragraph.text for paragraph in doc.paragraphs])
    elif file_extension == '.pdf':
        # Use PyPDF2 to read .pdf files
        reader = PdfReader(filepath)
        content = ""
        for page in reader.pages:
            content += page.extract_text()
        return content
    elif file_extension == '.doc':
        return "Error: Reading .doc files without external dependencies is not supported in this function."
    else:
        return f"Unsupported file type: {file_extension}"

In [97]:
# Create the ML classifier module for FastAPI integration
classifier_module_code = '''
"""
Document Classifier Module using BART-Large-MNLI
For Smart Document Classifier FastAPI Application
"""

from transformers import pipeline
from typing import Dict, Any, List, Optional
import logging
import time

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DocumentClassifier:
    """Document classifier using Facebook's BART-Large-MNLI model"""
    
    CATEGORIES = [
        "Technical Documentation",
        "Business Proposal", 
        "Legal Document",
        "Academic Paper",
        "General Article",
        "Other"
    ]
    
    def __init__(self):
        """Initialize the classifier"""
        self.classifier = None
        self.model_name = "facebook/bart-large-mnli"
        self.is_loaded = False
        
    def load_model(self):
        """Load the BART-Large-MNLI model"""
        try:
            logger.info(f"Loading {self.model_name} model...")
            self.classifier = pipeline(
                "zero-shot-classification", 
                model=self.model_name,
                device=-1  # Use CPU, change to 0 for GPU
            )
            self.is_loaded = True
            logger.info("Model loaded successfully!")
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise e
    
    def classify(self, text: str, categories: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Classify a document
        
        Args:
            text: Document text to classify
            categories: Optional custom categories (defaults to self.CATEGORIES)
            
        Returns:
            Classification results with confidence scores
        """
        if not self.is_loaded:
            self.load_model()
            
        if not text or not text.strip():
            return {
                "error": "Empty text provided",
                "predicted_category": "Other",
                "confidence_score": 0.0
            }
            
        categories = categories or self.CATEGORIES
        
        try:
            # Truncate text if too long (BART token limit ~1024)
            max_length = 800  # Conservative limit for better performance
            if len(text) > max_length:
                text = text[:max_length] + "..."
                logger.info(f"Text truncated to {max_length} characters")
            
            # Perform classification
            start_time = time.time()
            result = self.classifier(text, categories)
            inference_time = time.time() - start_time
            
            # Format results
            classification_result = {
                "predicted_category": result["labels"][0],
                "confidence_score": round(result["scores"][0], 4),
                "all_scores": {
                    label: round(score, 4) 
                    for label, score in zip(result["labels"], result["scores"])
                },
                "inference_time": round(inference_time, 3),
                "model_used": self.model_name,
                "text_length": len(text)
            }
            
            logger.info(f"Classification completed: {result['labels'][0]} ({result['scores'][0]:.4f})")
            return classification_result
            
        except Exception as e:
            logger.error(f"Classification failed: {str(e)}")
            return {
                "error": str(e),
                "predicted_category": "Other",
                "confidence_score": 0.0
            }

# Global classifier instance (singleton pattern)
_classifier_instance = None

def get_classifier() -> DocumentClassifier:
    """Get or create the global classifier instance"""
    global _classifier_instance
    if _classifier_instance is None:
        _classifier_instance = DocumentClassifier()
    return _classifier_instance

def classify_document_text(text: str) -> Dict[str, Any]:
    """
    Convenience function to classify document text
    
    Args:
        text: Document text to classify
        
    Returns:
        Classification results
    """
    classifier = get_classifier()
    return classifier.classify(text)
'''

# Write the module to a file
with open('../backend/ml_classifier.py', 'w') as f:
    f.write(classifier_module_code)

print("✅ ML classifier module created at: backend/ml_classifier.py")
print("📦 Module includes:")
print("   - DocumentClassifier class")
print("   - Singleton pattern for model loading")
print("   - Error handling and logging")
print("   - Performance optimizations")

✅ ML classifier module created at: backend/ml_classifier.py
📦 Module includes:
   - DocumentClassifier class
   - Singleton pattern for model loading
   - Error handling and logging
   - Performance optimizations


In [None]:
from transformers import AutoTokenizer
import re

# Load BART tokenizer to properly count tokens
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")

def smart_truncate_text(text: str, max_tokens: int = 800) -> str:
    """
    Intelligently truncate text for BART classification
    
    Strategies:
    1. Use actual tokenizer to count tokens, not characters
    2. Take beginning + end of document (sandwich approach)  
    3. Preserve sentence boundaries
    4. Include document structure clues (titles, headers)
    """
    
    # Strategy 1: Simple tokenizer-based truncation
    def tokenizer_truncate(text: str, max_tokens: int) -> str:
        tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) <= max_tokens:
            return text
        
        truncated_tokens = tokens[:max_tokens]
        return tokenizer.decode(truncated_tokens, skip_special_tokens=True)
    
    # Strategy 2: Sandwich approach (beginning + end)
    def sandwich_truncate(text: str, max_tokens: int) -> str:
        tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) <= max_tokens:
            return text
            
        # Take 60% from beginning, 40% from end
        start_tokens = int(max_tokens * 0.6)
        end_tokens = max_tokens - start_tokens
        
        beginning = tokenizer.decode(tokens[:start_tokens], skip_special_tokens=True)
        ending = tokenizer.decode(tokens[-end_tokens:], skip_special_tokens=True)
        
        return f"{beginning}\n\n[...document continues...]\n\n{ending}"
    
    # Strategy 3: Smart chunking (preserve sentences)
    def smart_chunk_truncate(text: str, max_tokens: int) -> str:
        # Split into sentences
        sentences = re.split(r'[.!?]+', text)
        
        result_tokens = []
        current_length = 0
        
        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue
                
            sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
            
            if current_length + len(sentence_tokens) <= max_tokens:
                result_tokens.extend(sentence_tokens)
                current_length += len(sentence_tokens)
            else:
                break
                
        if result_tokens:
            return tokenizer.decode(result_tokens, skip_special_tokens=True)
        else:
            # Fallback to simple truncation if first sentence is too long
            return tokenizer_truncate(text, max_tokens)
    
    # Choose strategy based on document characteristics
    token_count = len(tokenizer.encode(text, add_special_tokens=False))
    
    if token_count <= max_tokens:
        return text
    elif token_count > max_tokens * 3:  # Very long document
        return sandwich_truncate(text, max_tokens)
    else:  # Moderately long document
        return smart_chunk_truncate(text, max_tokens)

# Test the different approaches
test_text = """
# Technical Documentation: FastAPI Implementation Guide

This comprehensive guide covers the implementation of a FastAPI web application for document classification.

## Architecture Overview
The system uses a modular architecture with the following components:
- FastAPI framework for REST API endpoints
- SQLAlchemy ORM for database operations  
- Pydantic models for data validation
- BART-Large-MNLI for document classification
- Uvicorn ASGI server for deployment

## Implementation Details
The core application consists of several modules that work together to provide document classification capabilities.

### Database Layer
The database layer uses SQLAlchemy to manage document metadata and classification results.

### ML Classification
The machine learning component uses Facebook's BART-Large-MNLI model for zero-shot classification.

## Performance Considerations
For production deployment, consider the following optimizations:
- Model caching and singleton patterns
- Async processing for better throughput  
- Resource cleanup to prevent memory leaks
- Proper error handling and logging

## Conclusion
This FastAPI implementation provides a robust foundation for document classification tasks.
"""

print("Original text length:", len(test_text), "characters")
print("Original token count:", len(tokenizer.encode(test_text, add_special_tokens=False)), "tokens")
print()

# Test different truncation methods
truncated = smart_truncate_text(test_text, max_tokens=100)
print("Smart truncated length:", len(tokenizer.encode(truncated, add_special_tokens=False)), "tokens")
print("Smart truncated text:")
print(truncated)

In [None]:
# Demonstrate the problem with simple character truncation
sample_document = """
LEGAL AGREEMENT - TERMS OF SERVICE

IMPORTANT: Please read these terms carefully before using our service.

1. ACCEPTANCE OF TERMS
By accessing and using this service, you agree to be bound by the terms and conditions outlined in this agreement.

2. SERVICE DESCRIPTION  
Our document classification service uses artificial intelligence to automatically categorize uploaded documents into predefined categories.

3. USER OBLIGATIONS
Users must ensure that uploaded documents do not contain:
- Confidential or proprietary information
- Personal identifying information (PII) 
- Copyrighted material without permission
- Malicious code or harmful content

4. LIMITATION OF LIABILITY
In no event shall the company be liable for any indirect, incidental, special, consequential, or punitive damages.

5. TERMINATION
We reserve the right to terminate or suspend access to our service immediately, without prior notice.
"""

print("=== PROBLEM WITH SIMPLE CHARACTER TRUNCATION ===")
print()

# Current problematic approach (character-based)
simple_truncated = sample_document[:200] + "..."
print("Simple character truncation (200 chars):")
print(repr(simple_truncated))
print()

# Show what happens with tokenization
print("Tokens in simple truncated text:", len(tokenizer.encode(simple_truncated, add_special_tokens=False)))
print()

# Show what BART actually sees after tokenization
tokens = tokenizer.encode(simple_truncated, add_special_tokens=False)
decoded_back = tokenizer.decode(tokens, skip_special_tokens=True)
print("What BART actually processes:")
print(repr(decoded_back))
print()

print("=== ISSUES IDENTIFIED ===")
print("1. Cut off mid-sentence: 'By accessing and using this service, you agree to be bound by the terms...'")
print("2. Lost document type identifier: 'LEGAL AGREEMENT' is preserved, but context is lost")  
print("3. Character count ≠ token count: 203 characters ≈ ~50 tokens (varies by content)")
print("4. Classification might fail: Incomplete context about legal terms and obligations")

In [108]:
from transformers.utils import logging
import pandas as pd
import os
import time

# logging.set_verbosity_error()

# Dataset
dataset_path = "../data/Dataset"
sample_files = []

for f in os.listdir(dataset_path):
    sample_files.append(f)

# sample_files = [
#     "compujai.txt", # Expected: Technical
#     "ALERTA-Net- A Temporal Distance-Aware Recurrent Networks for Stock Movement and Volatility Prediction.pdf" # Expected: Academic
# ]

# Models
models = [
    "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
    "facebook/bart-large-mnli"
]


print("Testing classification on existing dataset documents:")
print("=" * 60)

# Pre-load all the classifiers to avoid redundant loading in the loop
classifiers = {model: pipeline("zero-shot-classification", model=model) for model in models}

# Create a single list to hold all results from all files and models
all_results = []

for filename in sample_files:
    filepath = os.path.join(dataset_path, filename)
    if os.path.exists(filepath):
        content = read_file_content(filepath)
        print(f"Current file: {filename}")
        
        # Now, iterate through each model for the current file
        for model_name in models:
            classifier = classifiers[model_name]
            
            # Start timer
            start_time = time.time()
            
            # Perform classification
            result = classify_document(model_name, content)
            
            # Calculate inference time
            inference_time = time.time() - start_time
            
            # Create a dictionary for this single result
            result_entry = {
                'filename': filename,
                'model': model_name,
                'predicted_category': result['predicted_category'],
                'confidence_score': result['confidence_score'],
                'all_scores': result['all_scores'],
                'inference_time': inference_time,
                'token_count': result['token_count'],
            }
            
            # Append this single result to the master list
            all_results.append(result_entry)
            print(f"Finishing current model {model_name}.")
        print("-" * 30)
            
    else:
        print(f"File not found: {filepath}")

# --- DataFrame Creation and Printing (Moved outside of all loops) ---
# Flatten the list of results and prepare data for the DataFrame
data_list = []
for result in all_results:
    row = {
        'filename': result['filename'],
        'model': result['model'],
        'Predicted Category': result['predicted_category'],
        'Confidence Score': result['confidence_score'],
        'Inference Time (s)': result['inference_time'],
        'Token Count': result['token_count'],
    }
    row.update({f"Score_{key}": value for key, value in result['all_scores'].items()})
    data_list.append(row)

# Create the DataFrame and set the MultiIndex
df = pd.DataFrame(data_list)
df = df.set_index(['filename', 'model'])

print("\nProcessing complete.")
print("=" * 100)
print("\nFinal Results Table:")
df

Testing classification on existing dataset documents:
Current file: Python Patterns .txt
Finishing current model MoritzLaurer/mDeBERTa-v3-base-mnli-xnli.
Finishing current model facebook/bart-large-mnli.
Finishing current file Python Patterns .txt.
Current file: Proposal for the Implementation of DAO for Enhanced Data Governance and Collaboritive Research in Genomic Sequencing.txt
Finishing current model MoritzLaurer/mDeBERTa-v3-base-mnli-xnli.
Finishing current model facebook/bart-large-mnli.
Finishing current file Proposal for the Implementation of DAO for Enhanced Data Governance and Collaboritive Research in Genomic Sequencing.txt.
Current file: How I use LLMs as a staff engineer _ sean goedecke.pdf
Finishing current model MoritzLaurer/mDeBERTa-v3-base-mnli-xnli.
Finishing current model facebook/bart-large-mnli.
Finishing current file How I use LLMs as a staff engineer _ sean goedecke.pdf.
Current file: Chat UI Pattern.txt
Finishing current model MoritzLaurer/mDeBERTa-v3-base-mnli-

Unnamed: 0_level_0,Unnamed: 1_level_0,Predicted Category,Confidence Score,Inference Time (s),Token Count,Truncated,Score_Technical Documentation,Score_Business Proposal,Score_Legal Document,Score_Academic Paper,Score_General Article
filename,model,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
Python Patterns .txt,MoritzLaurer/mDeBERTa-v3-base-mnli-xnli,Technical Documentation,0.5641,2.691173,995,False,0.5641,0.1,0.0708,0.1376,0.1275
Python Patterns .txt,facebook/bart-large-mnli,Technical Documentation,0.3802,2.947661,1110,False,0.3802,0.0908,0.1205,0.0619,0.3465
Proposal for the Implementation of DAO for Enhanced Data Governance and Collaboritive Research in Genomic Sequencing.txt,MoritzLaurer/mDeBERTa-v3-base-mnli-xnli,Legal Document,0.3165,1.292495,1472,False,0.127,0.1188,0.3165,0.3007,0.137
Proposal for the Implementation of DAO for Enhanced Data Governance and Collaboritive Research in Genomic Sequencing.txt,facebook/bart-large-mnli,Legal Document,0.2588,2.913074,1112,False,0.2166,0.1943,0.2588,0.1929,0.1374
How I use LLMs as a staff engineer _ sean goedecke.pdf,MoritzLaurer/mDeBERTa-v3-base-mnli-xnli,Legal Document,0.2561,1.530928,1800,False,0.188,0.1687,0.2561,0.1738,0.2133
How I use LLMs as a staff engineer _ sean goedecke.pdf,facebook/bart-large-mnli,General Article,0.2776,3.971126,1676,False,0.1927,0.2175,0.1704,0.1418,0.2776
Chat UI Pattern.txt,MoritzLaurer/mDeBERTa-v3-base-mnli-xnli,General Article,0.2469,0.880229,609,False,0.2183,0.1603,0.22,0.1545,0.2469
Chat UI Pattern.txt,facebook/bart-large-mnli,Technical Documentation,0.276,1.517312,540,False,0.276,0.1636,0.2092,0.1555,0.1956
Why is this CEO bragging.docx,MoritzLaurer/mDeBERTa-v3-base-mnli-xnli,Legal Document,0.217,3.356532,5637,False,0.1832,0.2136,0.217,0.1707,0.2155
Why is this CEO bragging.docx,facebook/bart-large-mnli,General Article,0.3488,11.385205,5003,False,0.1154,0.2574,0.1504,0.128,0.3488
