In [2]:
# Import necessary libraries
import PyPDF2
import pytesseract
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import os
import re
import matplotlib.pyplot as plt

In [3]:
# OCR Functions for Text Extraction
def extract_text_from_pdf(pdf_path):
    """Extract the text of every page in a PDF file.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        str: The text of all pages concatenated in page order. No separator
        is inserted between pages. A page for which the reader yields no
        text contributes nothing to the result.
    """
    with open(pdf_path, 'rb') as file:
        reader = PyPDF2.PdfReader(file)
        # `or ""` keeps a page without extractable text (None / empty) from
        # raising TypeError; join avoids repeated string re-allocation.
        return "".join(page.extract_text() or "" for page in reader.pages)

def extract_text_from_image(image_path):
    """Extract text from an image file using Tesseract OCR.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The value returned by ``pytesseract.image_to_string`` for the image.
    """
    # The context manager releases the underlying file handle once OCR is
    # done instead of leaving it open until garbage collection.
    with Image.open(image_path) as image:
        return pytesseract.image_to_string(image)

In [4]:
# Information Extraction Function
def extract_invoice_info(text):
    """Extract structured invoice fields from raw text using regular expressions.

    All label matching is case-insensitive. Free-text fields (contact name,
    address, city) are captured up to the end of their line so that the label
    of the following field is not swallowed into the value.

    Args:
        text: Raw invoice text (e.g. OCR output). ``None`` is treated as
            empty text.

    Returns:
        dict: One entry per field ('Invoice Number', 'Invoice Date',
        'Customer ID', 'Total Amount', 'Contact Name', 'Address', 'City',
        'Postal Code', 'Country', 'Phone', 'Fax'), each holding the captured
        string or "Not found". 'Products' holds a list of dicts with keys
        'Product ID', 'Product Name', 'Quantity' and 'Unit Price', or
        "Not found" when no product line matches.
    """
    text = text or ""

    def _field(pattern):
        # Run each pattern once and return its first capture group.
        match = re.search(pattern, text, re.IGNORECASE)
        return match.group(1).strip() if match else "Not found"

    # `[ \t]` instead of `\s` inside the free-text captures keeps the value on
    # a single line; the phone/fax number is captured as a group so the label
    # is dropped regardless of its letter case or spacing.
    field_patterns = {
        'Invoice Number': r'order id\s*:\s*(\d+)',
        'Invoice Date': r'order date\s*:\s*(\d{4}-\d{2}-\d{2})',
        'Customer ID': r'customer id\s*:\s*(\w+)',
        'Total Amount': r'total price\s*(\d+\.\d+)',
        'Contact Name': r'contact name\s*:\s*([A-Za-z \t]+)',
        'Address': r'address\s*:\s*([A-Za-z0-9 \t,]+)',
        'City': r'city\s*:\s*([A-Za-z \t]+)',
        'Postal Code': r'postal code\s*:\s*(\d{4,5}-\d{3})',
        'Country': r'country\s*:\s*([A-Za-z]+)',
        'Phone': r'phone\s*:\s*(\(?\d{2}\)?\s*\d{3}-\d{4})',
        'Fax': r'fax\s*:\s*(\(?\d{2}\)?\s*\d{3}-\d{4})',
    }
    info = {label: _field(pattern) for label, pattern in field_patterns.items()}

    # Product rows: <id> <name> <quantity> <unit price>
    product_rows = re.findall(r'(\d+)\s+([A-Za-z\s]+)\s+(\d+)\s+(\d+\.\d+)', text, re.IGNORECASE)
    products = [
        {
            'Product ID': product_id,
            'Product Name': name.strip(),
            'Quantity': quantity,
            'Unit Price': unit_price,
        }
        for product_id, name, quantity, unit_price in product_rows
    ]
    info['Products'] = products if products else "Not found"

    return info

In [5]:
# CNN Model Definition (Functional)
def create_cnn_model(num_classes, input_size=224):
    """Create a small CNN classifier as an ``nn.Sequential`` stack.

    The network takes single-channel (grayscale) images, applies two
    conv + ReLU + 2x2 max-pool stages, flattens, and classifies through a
    128-unit hidden layer with dropout.

    Args:
        num_classes: Number of output logits.
        input_size: Spatial size of the input images, either a single int
            for square images or a ``(height, width)`` pair. Defaults to 224,
            which yields the original ``32 * 56 * 56`` flattened feature size.

    Returns:
        nn.Sequential: The untrained model. The layer order is fixed, so
        state-dict keys do not depend on ``input_size``.

    Raises:
        ValueError: If the input is too small to survive the two poolings.
    """
    if isinstance(input_size, int):
        height = width = input_size
    else:
        height, width = input_size

    # The padded 3x3 convolutions keep the spatial size; each 2x2 max-pool
    # halves it with floor division.
    pooled_height = (height // 2) // 2
    pooled_width = (width // 2) // 2
    if pooled_height < 1 or pooled_width < 1:
        raise ValueError(f"input_size {input_size!r} is too small for two 2x2 poolings")

    model = nn.Sequential(
        nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),  # Grayscale input
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        nn.Flatten(),
        nn.Linear(32 * pooled_height * pooled_width, 128),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(128, num_classes)
    )
    return model

In [7]:
# CNN Training and Evaluation
def train_cnn(image_dir, num_epochs=10, batch_size=32, learning_rate=0.001,
              model_path="../models/cnn_model.pth"):
    """Train the CNN for document classification, evaluate it, and save it.

    Images are loaded with ``ImageFolder`` (one sub-directory per class),
    converted to 224x224 grayscale tensors, and split 80/20 into train and
    test subsets. After training, overall accuracy and the accuracy on the
    'invoice' class (when such a class exists) are printed, and the model's
    state dict is written to ``model_path``.

    Args:
        image_dir: Root directory of the class-per-folder image dataset.
        num_epochs: Number of passes over the training subset.
        batch_size: Batch size for both data loaders.
        learning_rate: Learning rate for the Adam optimizer.
        model_path: Destination file for the trained state dict. Its parent
            directory is created if it does not exist.

    Returns:
        tuple: ``(model, class_names, transform)`` — the trained model, the
        dataset's class list, and the preprocessing transform used.
    """
    # Define transforms
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # Load dataset using ImageFolder
    dataset = datasets.ImageFolder(image_dir, transform=transform)
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Create model
    class_names = dataset.classes
    model = create_cnn_model(len(class_names))
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    num_batches = len(train_loader)
    print("Training CNN for Document Classification:")
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # A tiny dataset can leave the training split empty; avoid dividing by zero.
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

    # Evaluation
    model.eval()
    correct = 0
    total = 0
    invoice_correct = 0
    invoice_total = 0
    invoice_idx = class_names.index('invoice') if 'invoice' in class_names else -1

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Index of the largest logit per sample is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Track accuracy for invoices specifically
            if invoice_idx != -1:
                invoice_mask = (labels == invoice_idx)
                invoice_total += invoice_mask.sum().item()
                invoice_correct += (predicted[invoice_mask] == labels[invoice_mask]).sum().item()

    accuracy = 100 * correct / total if total > 0 else 0
    invoice_accuracy = 100 * invoice_correct / invoice_total if invoice_total > 0 else 0

    print("CNN Analysis:")
    print(f"Overall Test Accuracy: {accuracy}%")
    print(f"Invoice Recognition Accuracy: {invoice_accuracy}%")
    print("Class Names:", class_names)

    # Save the model; create the target directory first so a missing folder
    # does not throw away a finished training run.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    torch.save(model.state_dict(), model_path)
    return model, class_names, transform

def predict_document_type(image_path, model, class_names, transform):
    """Predict the document type of a single image.

    Args:
        image_path: Path of the image file to classify.
        model: Trained classifier; called on a batch of one image.
        class_names: Sequence mapping a predicted class index to its label.
        transform: Callable turning the opened image into a tensor.

    Returns:
        The entry of ``class_names`` at the index of the largest model output.
    """
    # Close the file handle as soon as the tensor has been produced.
    with Image.open(image_path) as image:
        batch = transform(image).unsqueeze(0)  # Add batch dimension

    # Run on whatever device the model already lives on. Picking CUDA
    # unconditionally breaks for a model that was left on the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    batch = batch.to(device)

    model.eval()
    with torch.no_grad():
        output = model(batch)
        _, predicted = torch.max(output, 1)
    return class_names[predicted.item()]

In [8]:
# CNN Pipeline: Classification + Information Extraction
def cnn_pipeline(image_path, input_type, image_dir="../data/images"):
    """Run the entire CNN pipeline: classify document type and extract information.

    A saved model is loaded from "../models/cnn_model.pth" when present;
    otherwise a new model is trained on ``image_dir``. The document is then
    classified and, if it is an invoice, OCR'd and parsed into fields.

    Args:
        image_path: Path of the document image to process.
        input_type: 'image' is supported; 'pdf' is recognised but not
            implemented.
        image_dir: Dataset root used when a model has to be trained.

    Returns:
        tuple: ``(category, extracted_info)`` where ``extracted_info`` is the
        dict from ``extract_invoice_info`` for invoices and ``None`` otherwise.

    Raises:
        NotImplementedError: If ``input_type`` is 'pdf'.
        ValueError: If ``input_type`` is neither 'image' nor 'pdf'.
    """
    # Validate the request first so an unsupported input type fails
    # immediately instead of after loading or training a model.
    if input_type == 'pdf':
        # For PDFs, we need to convert to an image first (simplified for this example)
        raise NotImplementedError("PDF input requires conversion to image. Please provide an image file.")
    if input_type != 'image':
        raise ValueError("Unsupported input type. Use 'image' or 'pdf'.")

    # Step 1: Load a pre-trained model, or train a new one if none is saved
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        # Adjust based on your dataset. ImageFolder assigns label indices to
        # class folders in sorted order, so the list is sorted to line up
        # with the indices the saved model was trained on.
        class_names = sorted(['invoice', 'shipping_order', 'purchase_order', 'report'])
        model = create_cnn_model(len(class_names))
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        state_dict = torch.load("../models/cnn_model.pth", map_location=device)
        model.load_state_dict(state_dict)
        model.to(device)
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        print("Loaded pre-trained CNN model.")
    except FileNotFoundError:
        print("No pre-trained model found. Training a new CNN...")
        model, class_names, transform = train_cnn(image_dir, num_epochs=10)

    # Step 2: Classify the document type
    category = predict_document_type(image_path, model, class_names, transform)

    print("CNN Pipeline Analysis:")
    print(f"Predicted Category: {category}")

    # Step 3: Extract information if the document is an invoice
    if category.lower() == 'invoice':
        # Extract text using OCR
        text = extract_text_from_image(image_path)
        print("Extracted Text:\n", text)

        # Extract structured information
        extracted_info = extract_invoice_info(text)
        print("Extracted Information:", extracted_info)
        return category, extracted_info

    print("Document is not an invoice. Skipping information extraction.")
    return category, None

# Classify one sample document and, when it is an invoice, pull out its fields.
# `image_path` and `category` stay module-level because the next cell reads them.
image_path = "../data/images/invoice/0.png"  # Adjust this path to an actual image
input_type = 'image'
category, extracted_info = cnn_pipeline(image_path=image_path, input_type=input_type)

No pre-trained model found. Training a new CNN...


FileNotFoundError: [Errno 2] No such file or directory: '../data/images'

In [None]:
# Show the sample document in grayscale, titled with the predicted category
image = Image.open(image_path)
ax = plt.gca()
ax.imshow(image, cmap='gray')
ax.set_title(f"Predicted Category: {category}")
ax.axis('off')
plt.show()