In [None]:
import csv
import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration, TextStreamer

# Runtime flags read from the environment.
# SPACES_ZERO_GPU == "1" marks a ZeroGPU Hugging Face Space; any SPACE_ID value
# marks a Hugging Face Space in general.
IS_SPACES_ZERO = os.getenv("SPACES_ZERO_GPU", "0") == "1"
IS_SPACE = "SPACE_ID" in os.environ
# NOTE: the spelling of this flag is kept as-is because the model-loading code
# reads this exact name to choose the Google Drive checkpoint path.
IS_GDRVIE = True
LOW_MEMORY = os.environ.get("LOW_MEMORY", "0") == "1"

# Hugging Face access token (None when the variable is not set)
HF_TOKEN = os.getenv("HF_TOKEN")

# Prefer the GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")
print(f"Low memory mode: {LOW_MEMORY}")

# Define the model name
model_name = "Llama-3.2-11B-Vision-Instruct"
if IS_GDRVIE:
    # Load from a local copy stored in Google Drive; a local directory needs
    # no Hub authentication, so no extra keyword arguments are passed.
    model_path = "/content/drive/MyDrive/models/" + model_name
    _model_source = model_path
    _hub_kwargs = {}
else:
    # Load from the Hugging Face Hub repository, authenticating with HF_TOKEN.
    # `token` is the current keyword; `use_auth_token` is deprecated.
    model_name = "ruslanmv/" + model_name
    _model_source = model_name
    _hub_kwargs = {"token": HF_TOKEN}

# Both branches share one loading path so dtype/device settings cannot drift apart.
model = MllamaForConditionalGeneration.from_pretrained(
    _model_source,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    **_hub_kwargs,
)
processor = AutoProcessor.from_pretrained(_model_source, **_hub_kwargs)

# Tie the model weights to ensure the model is properly loaded
if hasattr(model, "tie_weights"):
    model.tie_weights()

# Stream LLM response generator
def stream_response(inputs):
    """Generate a reply for ``inputs`` and yield the decoded text of each sequence.

    A ``TextStreamer`` is attached so tokens are echoed to stdout while the
    model is generating.  ``model.generate`` itself only returns once
    generation has finished, so the values yielded here are complete decoded
    sequences (one per row of the returned batch), not incremental chunks.

    The sequences returned by ``generate`` start with the prompt tokens; these
    are sliced off before decoding so callers receive only the model's answer
    and never re-parse the prompt (which itself contains example tables).

    Args:
        inputs: Processor output (mapping with at least ``input_ids``) already
            placed on the model's device.

    Yields:
        str: The newly generated text of each sequence, special tokens removed.
    """
    streamer = TextStreamer(tokenizer=processor.tokenizer)
    # Number of prompt tokens that will prefix every returned sequence.
    prompt_length = inputs["input_ids"].shape[-1]
    sequences = model.generate(
        **inputs, max_new_tokens=2000, do_sample=True, streamer=streamer
    )
    for sequence in sequences:
        yield processor.decode(sequence[prompt_length:], skip_special_tokens=True)



# Predict function for Gradio app
def predict(message, image):
    """Query the vision model with one image and prompt, then save returned tables.

    The prompt is wrapped in a single user turn (an image placeholder followed
    by the text), rendered through the processor's chat template, and passed
    to the processor together with the image.  Every chunk produced by
    ``stream_response`` is concatenated and the complete text is handed to
    ``extract_and_save_tables``.

    Args:
        message: Instruction text for the model.
        image: Image given to the processor alongside the rendered prompt.

    Returns:
        The value returned by ``extract_and_save_tables`` for the full response.
    """
    # One user turn: image placeholder first, then the user-provided text
    image_part = {"type": "image"}
    text_part = {"type": "text", "text": message}
    conversation = [{"role": "user", "content": [image_part, text_part]}]

    # Render the chat template, then tensorize prompt + image on the target device
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    model_inputs = processor(image, prompt, return_tensors="pt").to(device)

    # The response is gathered in full before any table parsing happens
    generated_text = "".join(stream_response(model_inputs))
    return extract_and_save_tables(generated_text)

# Extract tables and save them to CSV
files_list = []


def _safe_table_filename(label):
    """Return ``table_<label>.csv`` with characters unsafe in file names replaced."""
    label = label.replace(":", "").strip()
    # The label comes from model output, so path separators, wildcard/quote
    # characters and non-printable characters are neutralised.
    unsafe = '/\\*?"<>|'
    cleaned = "".join(
        "_" if (ch in unsafe or not ch.isprintable()) else ch for ch in label
    )
    return "table_" + cleaned + ".csv"


def extract_and_save_tables(full_response):
    """Extract CSV tables from ``full_response`` and save each one to its own file.

    A table starts at a line beginning with ``"Table "``; the rest of that line
    (colons removed) becomes part of the file name ``table_<label>.csv``.
    Every following non-blank line is a row of that table until the next
    ``"Table "`` line or the end of the text.  Text before the first table
    heading is ignored.

    The module-level ``files_list`` is reset and then filled as a side effect.

    Args:
        full_response: Text produced by the model; ``None`` or an empty string
            yields no tables.

    Returns:
        list[str]: Names of the CSV files that exist on disk after saving,
        without duplicates, in order of first appearance.
    """
    global files_list
    files_list = []  # Reset files list before extraction

    def flush(name, rows):
        """Save one finished table and record its file if the save produced one."""
        if name is None:
            return
        save_table_to_csv(name, rows)
        # Only offer files that are really there, and list a repeated table
        # name once (a later table with the same name overwrites the file).
        if os.path.isfile(name) and name not in files_list:
            files_list.append(name)

    current_table_name = None
    current_table_rows = []

    for line in (full_response or "").splitlines():
        if line.startswith("Table "):
            # A new heading closes the table collected so far
            flush(current_table_name, current_table_rows)
            current_table_name = _safe_table_filename(line.split("Table ")[1])
            current_table_rows = []
        elif current_table_name is not None and line.strip():
            current_table_rows.append(line)

    # Save the last table
    flush(current_table_name, current_table_rows)

    return files_list  # Return the list of generated CSV files

def save_table_to_csv(table_name, table_rows):
    """Save a table to a CSV file.

    Each entry of ``table_rows`` is one line of comma-separated text.  Lines
    are parsed with the ``csv`` module rather than a plain ``split(",")`` so
    that a quoted field containing a comma (``"Smith, J",10``) stays a single
    cell.  Each line is parsed on its own so a stray quote cannot swallow the
    following rows.

    Args:
        table_name: Path of the CSV file to create or overwrite.
        table_rows: Iterable of strings, one per table row.

    Returns:
        bool: True when the file was written, False when writing failed
        (the error is printed, not raised).
    """
    try:
        with open(table_name, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)

            # Write each row to the CSV file
            for row in table_rows:
                writer.writerows(csv.reader([row]))
    except (OSError, csv.Error) as e:
        print(f"Error saving table {table_name}: {e}")
        return False
    print(f"Table saved as: {table_name}")
    return True

# Gradio interface
def gradio_app():
    """Build the table-extraction Gradio interface and launch it.

    The interface takes one uploaded image, asks the model (via ``predict``)
    to transcribe every table as CSV, and returns a status message plus the
    list of generated CSV files for download.  ``launch(debug=True)`` is
    called at the end of this function.
    """

    def process_image(image):
        """Run table extraction for one uploaded image.

        Args:
            image: The uploaded image, or None when nothing was uploaded.

        Returns:
            tuple: ``(status message, list of CSV file names or None)``.
        """
        # Submitting without an upload passes None; bail out with a clear
        # message instead of failing inside the processor.
        if image is None:
            return "Please upload an image before submitting.", None

        example = '''Table 1:
        header1,header2,header3
        value1,value2,value3

        Table 2:
        header1,header2,header3
        value1,value2,value3
        '''

        message = """Please extract all tables from the image and generate CSV files.
        Each table should be separated using the format table_n.csv, where n is the table number.
        You must use CSV format with commas as the delimiter. Do not use markdown format. Ensure you use the original table headers and content from the image.
        Only answer with the CSV content. Dont explain the tables.
        An example of the desired output is:
        """ + example

        files = predict(message, image)
        if not files:
            # Do not claim success when nothing was extracted
            return "No tables were found in the image.", None
        return "Tables extracted and saved as CSV files.", files

    # Input components
    image_input = gr.Image(type="pil", label="Upload Image")

    # Output components: a status line and the downloadable CSV files
    output_text = gr.Textbox(label="Extraction Status")
    file_output = gr.File(label="Download CSV files")

    # Gradio interface
    iface = gr.Interface(
        fn=process_image,
        inputs=[image_input],
        outputs=[output_text, file_output],
        title="Table Extractor and CSV Converter",
        description="Upload an image to extract tables and download CSV files.",
        allow_flagging="never"
    )

    iface.launch(debug=True)

# Build the interface and launch the Gradio app. This call sits at module top
# level, so it runs whenever the cell/script is executed or the module is imported.
gradio_app()
