In [1]:
import gradio as gr
import pytesseract
from PIL import Image
import re
import tempfile
import matplotlib.pyplot as plt
import pandas as pd

# Field labels searched for in the OCR text, keyed by document type (the same
# names offered in the UI dropdown). The list order is also the column order of
# the results table and the order of bars/wedges in the chart.
FIELD_NAMES = {
    "Balance Sheet": [
        "Assets",
        "Liabilities",
        "Equity",
        "Current Assets",
    ],
    "Profit & Loss": [
        "Revenue",
        "Expenses",
        "Net Profit",
        "Operating Income",
    ],
    "Payslip": [
        "Employee Name",
        "Gross Salary",
        "Deductions",
        "Net Pay",
        "Bonus",
    ],
}

def extract_text_from_image(image):
    """Return the text that Tesseract OCR recognises in *image*.

    Tesseract is run with ``--psm 6`` (page segmentation mode "assume a single
    uniform block of text"), chosen for better recognition of text blocks.
    """
    tesseract_options = '--psm 6'
    recognised = pytesseract.image_to_string(image, config=tesseract_options)
    return recognised

def extract_data_based_on_terms(text, document_type, field_names=None):
    """Capture the first number that follows each expected field label in *text*.

    For every label configured for *document_type*, the text is searched
    case-insensitively for the label on word boundaries. The first run of
    digits/commas/periods that follows it on the same line is captured
    verbatim (no conversion happens here).

    An occurrence of a label that is only part of a longer configured label is
    ignored. For example, when looking for "Assets", the "Assets" inside
    "Current Assets" is skipped. Otherwise both fields would receive the
    Current Assets figure whenever that line appears first.

    Args:
        text: OCR output to search. ``None``/empty yields "null" for every field.
        document_type: key selecting which labels to look for.
        field_names: mapping of document type -> list of labels; defaults to
            the module-level ``FIELD_NAMES``.

    Returns:
        dict mapping each label to the captured number string, or to the
        string "null" when no usable occurrence is followed by a number.
        An unknown document type yields an empty dict.
    """
    if field_names is None:
        field_names = FIELD_NAMES
    terms = field_names.get(document_type, [])
    if not text:
        return {term: "null" for term in terms}

    def label_pattern(label):
        # Whole-word, case-insensitive match of the literal label text.
        return re.compile(r"\b" + re.escape(label) + r"\b", re.IGNORECASE)

    # '.' does not cross newlines, so the number must be on the label's line.
    number_after_label = re.compile(r".*?(\d[\d,\.]*)")

    extracted_data = {}
    for term in terms:
        term_re = label_pattern(term)
        # Longer labels containing this one as a whole word, e.g. "Current Assets" for "Assets".
        longer_labels = [
            other for other in terms
            if other != term and len(other) > len(term) and term_re.search(other)
        ]
        blocked_spans = [
            match.span()
            for other in longer_labels
            for match in label_pattern(other).finditer(text)
        ]

        value = "null"
        for occurrence in term_re.finditer(text):
            if any(start <= occurrence.start() < end for start, end in blocked_spans):
                continue  # this hit belongs to a longer label
            number = number_after_label.match(text, occurrence.end())
            if number:
                value = number.group(1)
                break
        extracted_data[term] = value
    return extracted_data

def clean_values(values):
    """Convert raw extracted number strings to floats.

    Commas (thousands separators) are stripped before conversion. Anything
    that cannot be converted becomes the string "null". This covers the
    "null" placeholder, malformed numbers such as "1.2.3", and items without
    a ``replace`` method such as ``None``.
    """
    def to_number(raw):
        try:
            return float(raw.replace(",", ""))
        except (ValueError, AttributeError):
            return "null"

    return [to_number(raw) for raw in values]

def _upload_path(upload):
    """Return a filesystem path for one Gradio upload.

    Depending on the Gradio version / ``gr.File(type=...)`` setting, uploads
    arrive either as plain path strings or as file wrappers exposing the path
    as ``.name``; accept both.
    """
    if isinstance(upload, str) or hasattr(upload, "__fspath__"):
        return upload
    return upload.name


def _build_table_html(cleaned_data, document_type):
    """Render the per-image values as an HTML table under a centred title."""
    max_length = max((len(v) for v in cleaned_data.values()), default=0)
    padded_data = {
        field: values + [""] * (max_length - len(values))
        for field, values in cleaned_data.items()
    }
    df = pd.DataFrame(padded_data)
    title_html = f"""
    <div style="text-align: center; font-size: 20px; font-weight: bold; margin-bottom: 10px;">
        {document_type} - Extracted Data Table
    </div>
    """
    return title_html + df.to_html(index=False, justify="center", classes="dataframe")


def _render_chart(summed_data, chart_type, document_type):
    """Draw *summed_data* as a bar or pie chart and return the path of a saved PNG."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        labels = list(summed_data.keys())
        totals = list(summed_data.values())
        colors = plt.cm.tab20.colors[:len(summed_data)]

        if chart_type == "Bar Chart":
            ax.bar(labels, totals, color=colors)
            ax.set_title(f"{document_type} Data Visualization (Bar Chart)")
            ax.set_ylabel('Total Value')
            ax.set_xlabel('Fields')
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        elif chart_type == "Pie Chart":
            if sum(totals) > 0:
                ax.pie(
                    totals,
                    labels=labels,
                    autopct='%1.1f%%',
                    startangle=140,
                    colors=colors,
                )
            else:
                # Wedge proportions are undefined when every total is zero.
                ax.text(0.5, 0.5, "No numeric values to plot",
                        ha="center", va="center", transform=ax.transAxes)
                ax.axis("off")
            ax.set_title(f"{document_type} Data Visualization (Pie Chart)")

        fig.tight_layout()

        # delete=False: the file must outlive this function so Gradio can serve it.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
            chart_path = tmpfile.name
        # Write after the handle is closed (reopening an open temp file is not portable).
        fig.savefig(chart_path)
    finally:
        plt.close(fig)  # always release the figure, even if drawing/saving failed
    return chart_path


def extract_and_visualize(images, chart_type, document_type):
    """OCR the uploaded images, tabulate the extracted fields and chart their totals.

    Args:
        images: uploaded image files (Gradio file wrappers or path strings).
        chart_type: "Bar Chart" or "Pie Chart"; any other value yields an
            empty figure.
        document_type: key of FIELD_NAMES selecting which fields to extract.

    Returns:
        ``(table_html, chart_path, data_as_text)``. On failure the first
        element is an error message and the other two are ``None``.
    """
    # Nothing uploaded: iterating None would raise, and an empty batch has no data.
    if not images:
        return "Error: please upload at least one image.", None, None
    fields = FIELD_NAMES.get(document_type, [])
    if not fields:
        # Also guarantees document_type is a known key before it is embedded in HTML.
        return "Error: unsupported document type.", None, None

    extracted_data = {field: [] for field in fields}
    for img_file in images:
        try:
            # The context manager closes the image's file handle once OCR is done.
            with Image.open(_upload_path(img_file)) as image:
                text = extract_text_from_image(image)
            values_for_image = extract_data_based_on_terms(text, document_type)
        except Exception as e:
            # UI boundary: report the first failure to the user instead of crashing the app.
            return f"Error processing image: {str(e)}", None, None
        for field, value in values_for_image.items():
            extracted_data[field].append(value)

    cleaned_data = {field: clean_values(values) for field, values in extracted_data.items()}
    # Only successfully parsed numbers contribute to totals; "null" entries are skipped.
    summed_data = {
        field: sum(v for v in values if isinstance(v, float))
        for field, values in cleaned_data.items()
    }

    table_html = _build_table_html(cleaned_data, document_type)
    data_as_text = "\n".join(
        f"{field}: {', '.join(map(str, values))}" for field, values in cleaned_data.items()
    )
    chart_path = _render_chart(summed_data, chart_type, document_type)
    return table_html, chart_path, data_as_text

def create_gradio_interface(share=True):
    """Build the Gradio Blocks UI for OCR extraction/visualisation and launch it.

    Args:
        share: forwarded to ``launch``. True (the previous hard-coded value)
            also publishes a temporary public gradio.live URL; pass False to
            keep the app reachable only locally.
    """
    with gr.Blocks() as interface:
        # \U0001F4AC is the speech-balloon emoji; the escape avoids the UTF-8/cp1252
        # mojibake ("ðŸ’¬") the literal previously suffered from.
        gr.Markdown("""
            <div style="text-align: center; font-size: 24px; font-weight: bold;">
                \U0001F4AC Extract and Analyze Data from Documents using OCR \U0001F4AC
            </div>
        """)

        with gr.Row():
            images = gr.File(label="Upload Images", file_types=["image"], file_count="multiple")
            # Choices come from FIELD_NAMES so the UI and the extractor cannot drift apart.
            document_type = gr.Dropdown(list(FIELD_NAMES), label="Select Document Type", value="Balance Sheet")
            chart_type = gr.Radio(["Bar Chart", "Pie Chart"], label="Chart Type", value="Bar Chart")

        # A single button for both text extraction and visualization
        submit_btn = gr.Button("Extract and Visualize")

        with gr.Row():
            output_table = gr.HTML(label="Extracted Data Table")
            output_chart = gr.Image(label="Visualization")
            output_text = gr.Textbox(label="Extracted Data as Text", interactive=False, lines=10)

        # Input order must match extract_and_visualize(images, chart_type, document_type).
        submit_btn.click(
            extract_and_visualize,
            inputs=[images, chart_type, document_type],
            outputs=[output_table, output_chart, output_text],  # table, chart, text
        )

    # Run the Gradio interface
    interface.launch(share=share)

# Start the Gradio interface when this file/cell is executed directly.
# In a notebook kernel __name__ is also "__main__", so running the cell launches the app.
if __name__ == "__main__":
    create_gradio_interface()


* Running on local URL:  http://127.0.0.1:7861
* Running on public URL: https://8fb1e42456e6897fdf.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
