<a href="https://colab.research.google.com/github/younes2808/Sci2XML/blob/main/app/fullApp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Frontend <--> Classifier <--> API

>[Frontend <--> Classifier <--> API](#scrollTo=T2R0JzBQ1izB)

>[Frontend - Streamlit](#scrollTo=33kAj9VFcYj-)

>>[Requirements](#scrollTo=wxhHhxVAvu7m)

>>[Load VLM](#scrollTo=gZDphOMig6w0)

>>[Load ML](#scrollTo=3ugtlL1RZAwq)

>>[Grobid](#scrollTo=TnmL9TPQIXhO)

>>[The streamlit app (Includes classifier code)](#scrollTo=crvJjF3_vx4X)

>>[Starting streamlit app in another thread and hosting it using localtunnel.](#scrollTo=Tf2SDnNVv2zn)

>[API](#scrollTo=TMxQtt40nL5M)

>>[Imports](#scrollTo=mCCISnjCal2b)

>>[Unichart](#scrollTo=SbQFCfD0WaoY)

>>[Sumen](#scrollTo=oJA3-KFbWeoy)

>>[Run this cell to start API](#scrollTo=SR7BiGlQ5JS4)

>>[Test call to API](#scrollTo=Xna7W9ARifAF)

>>[To kill/stop API](#scrollTo=ouOLpu9l5CRG)



To run the program:
1. Connect to a T4 (GPU) runtime.
2. Upload the ML model file (a ".pkl" file of about 50 MB, so the upload takes some time), together with modules/css.html and images/image.png.
3. Run the "Requirements" cell (under Frontend - Streamlit) and the "Load ML" cell.
4. Run the Grobid cells.
5. Run "The Streamlit app" cell.
  *   This writes the Streamlit code and the classifier code to the file app.py.
6. Run the "Starting Streamlit app..." cell.
  *   The output of this cell shows a URL/IP. This is the password for the localtunnel site where the Streamlit frontend is hosted. To get the URL of the localtunnel/Streamlit site, open (double-click) the file "streamliturl.txt" in the file browser on the left and use the URL it contains.
7. Run the API cells.
  *   The Flask API is now running.
8. Open the localtunnel site, upload the XML and PDF files, and click the "send to classifier" button.
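
Step 2 is easy to miss, and a missing file only surfaces later as an error in another cell. A minimal pre-flight sketch, assuming the files are uploaded to the Colab working directory (`/content`) under the names used elsewhere in this notebook; the helper name `missing_uploads` is hypothetical and not part of the app:

```python
from pathlib import Path

# File names taken from this notebook: the model file loaded in "Load ML",
# the CSS file read by the Streamlit app, and the image listed in step 2.
REQUIRED_UPLOADS = (
    "best_model_densenet169_sentence.pkl",
    "modules/css.html",
    "images/image.png",
)


def missing_uploads(base_dir="/content", required=REQUIRED_UPLOADS):
    """Return the required files that do not exist under base_dir."""
    base = Path(base_dir)
    return [name for name in required if not (base / name).is_file()]


missing = missing_uploads()
if missing:
    print("Upload these files before running the next cells:", missing)
else:
    print("All required files are present.")
```

Running this before step 3 gives an explicit list of what still has to be uploaded.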

Flow:
1. Streamlit ->[XML+PDF]-> Classifier
2. Classifier:
  *   Classifies each non-textual element
  *   For each element: Classifier -> [Image of element] -> API
  *   API redirects to correct function and returns NL
  *   Classifier adds response to altered XML file
  *   Classifier adds response to array in frontend
  *   Frontend updates view of array when a new element is processed
3. Classifier -> [Altered XML] -> Streamlit
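
The routing in step 2 can be summarised as a lookup from the class label to an API endpoint. The sketch below mirrors the label groups and endpoint names used in the classifier code of this notebook (`parseChart`, `parseFigure`, `parseFormula`); the function name `endpoint_for` is illustrative and does not exist in the app:

```python
# Label groups as they appear in the classifier code of this notebook.
CHART_LABELS = {"bar_chart", "diagram", "graph", "pie_chart"}
FIGURE_LABELS = {"flow_chart", "growth_chart"}
SKIPPED_LABELS = {"just_image", "table", "text_sentence"}


def endpoint_for(label):
    """Return the API endpoint for a class label, or None if the element is skipped."""
    label = label.lower()
    if label in SKIPPED_LABELS:
        return None
    if label in CHART_LABELS:
        return "parseChart"
    if label in FIGURE_LABELS:
        return "parseFigure"
    if "formula" in label:
        return "parseFormula"
    return None  # unknown label: nothing is written back to the XML


for label in ("bar_chart", "flow_chart", "table", "formula"):
    print(label, "->", endpoint_for(label))
```

Elements that map to `None` are left untouched in the XML, which matches the "abort" branches in the classifier.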


# Frontend - Streamlit
https://discuss.streamlit.io/t/how-to-launch-streamlit-app-from-google-colab-notebook/42399

## Requirements

In [1]:
!pip install -q streamlit
!npm install localtunnel
#!pip install lmdeploy
!pip install pdf2image
!apt-get install poppler-utils
!pip install -U skorch
!pip install streamlit-pdf-viewer
!pip install st-annotated-text
!pip install stqdm


## Load ML

In [32]:
from skorch import NeuralNetClassifier
import torch.nn as nn
import torch
import multiprocessing as mp
from skorch.dataset import ValidSplit
from skorch.callbacks import LRScheduler, Checkpoint
from skorch.callbacks import Freezer, EarlyStopping
import torchvision

def loadML():
  """
  Load the ML model.

  Parameters:
  None

  Returns:
  densenet: The ML model.
  """
  print("\n--- Loading ML ---")


  n_classes = 9
  batch_size = 128
  num_workers = mp.cpu_count()

  # callback functions for models

  # DenseNet169
  # callback for Reduce on Plateau scheduler
  lr_scheduler = LRScheduler(policy='ReduceLROnPlateau',
                                      factor=0.5, patience=1)
  # callback for saving the best on validation accuracy model
  checkpoint = Checkpoint(f_params='best_model_densenet169.pkl',
                                  monitor='valid_acc_best')
  # callback for freezing all layers of the model except the classifier head
  freezer = Freezer(lambda x: not x.startswith('model.classifier'))
  # callback for early stopping
  early_stopping = EarlyStopping(patience=5)

  # ... (import other necessary libraries) ...
  class DenseNet169(nn.Module):
      def __init__(self, output_features, num_units=512, drop=0.5,
                  num_units1=512, drop1=0.5):
          super().__init__()
          model = torchvision.models.densenet169(pretrained=True)
          n_inputs = model.classifier.in_features
          model.classifier = nn.Sequential(
                                  nn.Linear(n_inputs, num_units),
                                  nn.ReLU(),
                                  nn.Dropout(p=drop),
                                  nn.Linear(num_units, num_units1),
                                  nn.ReLU(),
                                  nn.Dropout(p=drop1),
                                  nn.Linear(num_units1, output_features))
          self.model = model

      def forward(self, x):
          return self.model(x)
  # NeuralNetClassifier based on DenseNet169 with custom parameters
  densenet = NeuralNetClassifier(
      # pretrained DenseNet169 + custom classifier
      module=DenseNet169,
      module__output_features=n_classes,
      # criterion
      criterion=nn.CrossEntropyLoss,
      # batch_size = 128
      batch_size=batch_size,
      # number of epochs to train
      max_epochs=5,
      # optimizer Adam used
      optimizer=torch.optim.Adam,
      optimizer__lr = 0.001,
      optimizer__weight_decay=1e-6,
      # shuffle dataset while loading
      iterator_train__shuffle=True,
      # load in parallel
      iterator_train__num_workers=num_workers,
      # stratified hold-out split: 1/5 of the loaded dataset is used for validation
      train_split=ValidSplit(cv=5, stratified=True),
      # callbacks declared earlier
      callbacks=[lr_scheduler, checkpoint, freezer, early_stopping],
      # use GPU or CPU
      device="cuda:0" if torch.cuda.is_available() else "cpu"
  )

  densenet.initialize()  # Initialize the model before loading parameters
  densenet.load_params(f_params='best_model_densenet169_sentence.pkl')
  # Load the saved model
  return densenet

ML = loadML()

## Grobid

In [2]:
!wget https://github.com/kermitt2/grobid/archive/0.8.1.zip
!unzip 0.8.1.zip

Streaming output truncated to the last 5000 lines.
  inflating: grobid-0.8.1/grobid-trainer/resources/dataset/entities/chemistry/corpus/chebi-onto.xml.generated  
  inflating: grobid-0.8.1/grobid-trainer/resources/dataset/entities/chemistry/corpus/chebi-onto.xml.generated2  
  inflating: grobid-0.8.1/grobid-trainer/resources/dataset/entities/chemistry/corpus/chemistry-types.xml  
  inflating: grobid-0.8.1/grobid-trainer/resources/dataset/entities/chemistry/corpus/chemistry-types.xml.new  
  inflating: grobid-0.8.1/grobid-trainer/resources/dataset/entities/chemistry/corpus/chemistry-types.xml.orig  
  inflating: grobid-0.8.1/grobid-trainer/resources/dataset/entities/chemistry/corpus/epochem-chem-onto.xml  
  inflating: grobid-0.8.1/grobid-trainer/resources/dataset/entities/chemistry/corpus/epochem-compounds.xml  
  inflating: grobid-0.8.1/grobid-trainer/resources/dataset/entities/chemistry/corpus/epochem-new-classes.xml  
  inflating: grobid-0.8.1/grobid-trainer/re

In [3]:
!pwd
!ls
%cd grobid-0.8.1
!pwd
!ls

/content
0.8.1.zip  grobid-0.8.1  node_modules  package.json  package-lock.json	sample_data
/content/grobid-0.8.1
/content/grobid-0.8.1
build.gradle  Dockerfile.crf	gradle.properties  grobid-core	   grobid-trainer  Readme.md
CHANGELOG.md  Dockerfile.delft	gradlew		   grobid-home	   LICENSE	   readthedocs.yml
doc	      gradle		gradlew.bat	   grobid-service  mkdocs.yml	   settings.gradle


In [4]:
!ls
!./gradlew clean install

build.gradle  Dockerfile.crf	gradle.properties  grobid-core	   grobid-trainer  Readme.md
CHANGELOG.md  Dockerfile.delft	gradlew		   grobid-home	   LICENSE	   readthedocs.yml
doc	      gradle		gradlew.bat	   grobid-service  mkdocs.yml	   settings.gradle
Downloading https://services.gradle.org/distributions/gradle-7.2-bin.zip
..........10%...........20%...........30%...........40%...........50%...........60%...........70%...........80%...........90%...........100%

Welcome to Gradle 7.2!

Here are the highlights of this release:
 - Toolchain support for Scala
 - More cache hits when Java source files have platform-specific line endings
 - More resilient remote HTTP build cache behavior

For more details see https://docs.gradle.org/7.2/release-notes.html

Starting a Gradle Daemon (subsequent builds will be faster)


> Task :grobid-core:compileJava
Note: Some input files use or override a deprecated API.
Note: Recompile with -Xlint:deprecation for details.
Note: Some input files

In [5]:
!ls
!./gradlew run &>/content/gradlelog.txt &
# Check gradlelog.txt to see when it is ready. Should be > 46 lines when ready.

build	      Dockerfile.crf	 gradlew      grobid-service  Readme.md
build.gradle  Dockerfile.delft	 gradlew.bat  grobid-trainer  readthedocs.yml
CHANGELOG.md  gradle		 grobid-core  LICENSE	      settings.gradle
doc	      gradle.properties  grobid-home  mkdocs.yml


In [6]:
!ls
%cd ..
!ls

build	      Dockerfile.crf	 gradlew      grobid-service  Readme.md
build.gradle  Dockerfile.delft	 gradlew.bat  grobid-trainer  readthedocs.yml
CHANGELOG.md  gradle		 grobid-core  LICENSE	      settings.gradle
doc	      gradle.properties  grobid-home  mkdocs.yml
/content
0.8.1.zip  gradlelog.txt  grobid-0.8.1	node_modules  package.json  package-lock.json  sample_data


In [8]:
import time
time.sleep(50) # To ensure that Grobid server is up and running...

!curl http://172.28.0.12:8070/api/isalive
import socket
print(f"\nServer address: {socket.gethostbyname(socket.gethostname())}:8070")

true
Server address: 172.28.0.12:8070
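
A fixed `time.sleep(50)` either waits too long or not long enough, depending on how quickly the server starts. A possible alternative is to poll the `/api/isalive` endpoint used above until it answers `true`. This is a sketch and not part of the notebook: `http_probe`, `wait_until_alive` and the `probe` parameter are hypothetical names, and the timeout values are arbitrary.

```python
import time
import urllib.error
import urllib.request


def http_probe(url):
    """Return True if a GET on url succeeds and the body is 'true'."""
    try:
        with urllib.request.urlopen(url, timeout=5) as resp:
            return resp.read().decode().strip() == "true"
    except (urllib.error.URLError, OSError):
        return False


def wait_until_alive(url, timeout=300.0, interval=5.0, probe=http_probe):
    """Poll url until probe(url) is True; return False if timeout seconds pass first."""
    deadline = time.monotonic() + timeout
    while True:
        if probe(url):
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)
```

In this notebook the call would look like `wait_until_alive("http://172.28.0.12:8070/api/isalive")`; the `probe` argument exists only so that the loop can be exercised without a running server.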


## The streamlit app (Includes classifier code)
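
Two details of the cell below are easier to see in isolation: the regex that decides whether a GROBID `<formula>` string is treated as a real formula, and the factor 2.775 that scales GROBID coordinates to image pixels. The sketch assumes that GROBID reports each coordinate group as `page,x,y,width,height` in PDF points (1/72 inch) and that pdf2image renders at its default of 200 DPI, which would give a factor of 200/72 ≈ 2.78, close to the constant in the code. The helper names `looks_like_formula` and `crop_box` are hypothetical.

```python
import re

# Same pattern as in classify(): at least 3 characters, and not made up
# solely of "(" or solely of ")".
FORMULA_PATTERN = re.compile(r"^(?!\(+$)(?!\)+$).{3,}$")


def looks_like_formula(text):
    """Return True if text passes the formula filter."""
    return FORMULA_PATTERN.match(text) is not None


def crop_box(coords, dpi=200):
    """Convert one 'page,x,y,width,height' group to (page_index, pixel_box).

    pixel_box is (left, upper, right, lower), the order PIL's Image.crop expects.
    """
    page, x, y, w, h = (float(v) for v in coords.split(",")[:5])
    scale = dpi / 72  # assumption: coordinates are given in PDF points
    box = (x * scale, y * scale, (x + w) * scale, (y + h) * scale)
    return int(page) - 1, box


print(looks_like_formula("E = mc^2"))  # True
print(looks_like_formula("((("))       # False: only opening parentheses
print(crop_box("1,72,144,36,18"))
```

Deriving the factor from the DPI keeps the crop correct if the pages are ever rendered at a different resolution.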

In [45]:
%%writefile app.py

import streamlit as st
import requests, json
from PIL import Image
import io
from io import StringIO
import time

#import sys
#sys.stdout = open("streamlitlog", "w")

##### CLASSIFIER ######
# Load modules:

import pandas as pd
from bs4 import BeautifulSoup

from pdf2image import convert_from_path, convert_from_bytes
from pdf2image.exceptions import (
    PDFInfoNotInstalledError,
    PDFPageCountError,
    PDFSyntaxError
)

from PIL import Image, ImageDraw
import os
import json
import time
import requests
import io
import re

#import sys
#sys.stdout = open("classifierlog", "w")

apiURL = "http://172.28.0.12:8000/"


def openXMLfile(XMLfile, PDFfile):
    """
    Parses the XML file with BeautifulSoup and renders the PDF pages as images.

    Parameters:
    XMLfile: The XML file as stringio object.
    PDFfile: The PDF file as bytes object.

    Returns:
    images: The pages as images from the PDF file.
    figures: The figures from the XML file.
    formulas: The formulas from the XML file.
    """

    print("\n----- Opening XML and PDF file... -------")

    #stringio = StringIO(XMLfile.getvalue().decode("utf-8"), newline=None)
    #XMLfile = stringio.read()

    PDFfile = PDFfile.getvalue()

    global Bs_data
    st.session_state.Bs_data = BeautifulSoup(XMLfile, "xml")
    #Bs_data = BeautifulSoup(data, "xml")
    Bs_data = st.session_state.Bs_data

    figures = Bs_data.find_all('figure')

    print("Figures:")
    print(figures)
    st.session_state.metrics["figuresGrobid"] = len(figures)

    formulas = Bs_data.find_all('formula')

    print("Formulas:")
    print(formulas)
    st.session_state.metrics["formulasGrobid"] = len(formulas)

    #images = convert_from_path(pathToPDF, poppler_path='C:\\Program Files\\Release-24.08.0-0\\poppler-24.08.0\\Library\\bin')
    #images = convert_from_path(pathToPDF)
    images = convert_from_bytes(PDFfile)

    for i in range(0, len(images)):
        print("--- Image nr ", i+1)

    return images, figures, formulas


def addToXMLfile(type, name, newContent):
    """
    Adds a new element to the XML file. When a non-textual element has been processed it should be placed back into the XML file at the correct location.

    Parameters:
    type: The type of the element. (figure or formula)
    name: The name of the element.
    newContent: The new content to be added to the XML file.

    Returns:
    None
    """
    print("\n-- Adding to XML file... --")
    parentTag = st.session_state.Bs_data.find(type, {"xml:id": name})
    print("parentTag: ", parentTag)
    if parentTag is None:
        print("Could not find tag to place element back into...")
        return
    textWithoutTag = parentTag.find_all(string=True, recursive=False)
    print("findall", textWithoutTag)

    if (len(textWithoutTag) == 0):
        print("Probably a figure...")
        parentTag.append(newContent["preferred"])
    else:
        print("Probably a formula...")
        for text in textWithoutTag:
            if (text in parentTag.contents):
                # print(parentTag.contents.index(text))
                parentTag.contents[parentTag.contents.index(text)].replace_with(newContent["preferred"])

    print(parentTag)


def saveXMLfile(pathToXML):
    """
    FOR TESTING! Saves the XML file.

    Parameters:
    pathToXML: The path to the XML file.

    Returns:
    Bs_data: The parsed XML document (a BeautifulSoup object).
    """
    print("\n----- Saving XML file... -----")
    with open(pathToXML, "w", encoding="utf-8") as file:
        file.write(str(Bs_data))
    return Bs_data


def classify(XMLtype, image, elementNr, pagenr, regex):
    """
    Classifies a given element as either a formula, table, chart or figure.

    Parameters:
    XMLtype: the type of element. (figure or formula)
    image: the image to be sent to the VLM model.
    elementNr: the number of the element.
    pagenr: the page number of the element.
    regex: the formula string to be matched against regex.

    Returns:
    None
    """
    print("\n -- Classifier... --")


    ## Redirecting to correct endpoint in API...

    subtype = "unknown"

    ## API request header:
    headers = {'Content-type': 'application/json', 'Accept': 'text/plain'}
    APIresponse = ""


    ## For formulas:
    if (XMLtype == "formula"):
      pattern = r"^(?!\(+$)(?!\)+$).{3,}$"
      ## ^ and $ ensure that the whole string must match.
      ## (?!\(+$) and (?!\)+$) are negative lookaheads that reject strings consisting only of "(" or only of ")".
      ## .{3,} matches any character at least three times, so the string must be at least 3 characters long.
      if (re.match(pattern, regex)):
          print("YES: ", "Formula: ", elementNr, " ->", regex)
          st.session_state.metrics["formulas"] += 1
          subtype = "formula"
          print("Redirecting to formulaParser")
          ##### APIresponse = API.call("127.0.0.1/formulaParser") #####

          img_byte_arr = io.BytesIO()
          image.save(img_byte_arr, format='PNG')
          img_byte_arr = img_byte_arr.getvalue()

          APIresponse = requests.post(apiURL+"parseFormula", files={'image': img_byte_arr})
          APIresponse = APIresponse.json()
          APIresponse["element_number"] = elementNr
          APIresponse["page_number"] = pagenr

          print("Response from formulaParser: --> ", APIresponse["preferred"])
      else:
          print("NO: ", "Formula: ", elementNr, " ->", regex)
          print("The formula is NOT identified as an actual formula. Aborting...")
          return


    ## For figures:
    else:

      ## When VLM is local:
      #figureClass = callVLM(VLM, image, query)
      ## When VLM is via API:
      img_byte_arr = io.BytesIO()
      image.save(img_byte_arr, format='PNG')
      img_byte_arr = img_byte_arr.getvalue()
      #files = {"image": ("image1.png", img_byte_arr), "query": ("query.txt", query)}
      files = {"image": ("image1.png", img_byte_arr)}
      response = requests.post(apiURL+"callClassifier", files=files)
      print(response.status_code)
      response = response.json()
      print(response)
      figureClass = response["ClassifierResponse"]

      print("Classifier - ML: This image is a -> ", figureClass, " <-    Sending it over to the correct API endpoint")

      ## For 'other':
      if (figureClass.lower() in ["just_image", "table", "text_sentence"]):
        print("Identified as other/unknown. Aborting...")
        return

      ## For charts:
      if (figureClass.lower() in ['bar_chart', 'diagram', 'graph', 'pie_chart']):
          print("Redirecting to chartParser. Image identified as ", figureClass.lower())
          subtype = figureClass.lower()
          st.session_state.metrics["chart"] += 1
          ##### APIresponse = API.call("127.0.0.1/chartParser") #####
          img_byte_arr = io.BytesIO()
          image.save(img_byte_arr, format='PNG')
          img_byte_arr = img_byte_arr.getvalue()

          APIresponse = requests.post(apiURL+"parseChart", files={'image': img_byte_arr})
          APIresponse = APIresponse.json()
          APIresponse["element_number"] = elementNr
          APIresponse["page_number"] = pagenr

          print("Response from chartParser: --> ", APIresponse["preferred"])

      ## For figures:
      if (figureClass.lower() in ['flow_chart', 'growth_chart']):
          print("Redirecting to figureParser. Image identified as ", figureClass.lower())
          subtype = figureClass.lower()
          st.session_state.metrics["figures"] += 1
          ##### APIresponse = API.call("127.0.0.1/figureParser") #####
          img_byte_arr = io.BytesIO()
          image.save(img_byte_arr, format='PNG')
          img_byte_arr = img_byte_arr.getvalue()

          APIresponse = requests.post(apiURL+"parseFigure", files={'image': img_byte_arr})
          APIresponse = APIresponse.json()
          APIresponse["element_number"] = elementNr
          APIresponse["page_number"] = pagenr

          print("Response from figureParser: --> ", APIresponse["preferred"])

      ## For formulas
      if ("formula" in figureClass.lower()):
        print("Redirecting to formulaParser")
        subtype = "formula"
        st.session_state.metrics["formulas"] += 1
        ##### APIresponse = API.call("127.0.0.1/formulaParser") #####
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format='PNG')
        img_byte_arr = img_byte_arr.getvalue()

        APIresponse = requests.post(apiURL+"parseFormula", files={'image': img_byte_arr})
        APIresponse = APIresponse.json()
        APIresponse["element_number"] = elementNr
        APIresponse["page_number"] = pagenr

        print("Response from formulaParser: --> ", APIresponse["preferred"])


    ## If subtype is unknown its better to abort and not add anything back into the XML.
    if (subtype == "unknown"):
      print("Identified as other/unknown. Aborting...")
      return


    print("Received response about image nr ", elementNr, ". Will now paste response back into the XML-file.")
    if (XMLtype == "figure"):
      addToXMLfile(XMLtype, "fig_" + str(elementNr), APIresponse)
    elif (XMLtype == "formula"):
      addToXMLfile(XMLtype, "formula_" + str(elementNr), APIresponse)

    ## This writes directly to screen. Is used for testing, should only be added to array instead.
    #st.write(f"Received response about {XMLtype}. It was a {subtype}. APIresponse: {APIresponse}")

    ## Adds to arrays:
    processClassifierResponse(APIresponse)

def processClassifierResponse(APIresponse):
    """
    Processes the response from the classifier and adds it to the correct array.

    Parameters:
    APIresponse: The response from the classifier as a dict. Ex: "{'NL': 'some NL', 'element_type': 'figure', 'preferred': 'some NL', 'element_number': 1, 'page_number': 1}"

    Returns:
    None
    """
    print("Adding to array...")

    #elements = []
    #st.session_state.elements.append(APIresponse)
    element = APIresponse


    #for element in stqdm(elements):
    if element['element_type'] == 'formula':
        st.session_state.formulas_results_array.append(element)
        st.subheader(f"Page {element.get('page_number', 'N/A')}: Formula #{element.get('element_number', 'N/A')}")
        st.markdown(rf"$$ {element.get('formula', 'N/A')} $$")
        st.text(f"{element.get('NL', 'No description available.')}")

    elif element['element_type'] == "figure":
        st.session_state.figures_results_array.append(element)
        st.subheader(f"Page {element.get('page_number', 'N/A')}: Figure #{element.get('element_number', 'N/A')}")
        st.text(f"{element.get('NL', 'No description available.')}")

    elif element['element_type'] == "chart":
        st.session_state.charts_results_array.append(element)
        st.subheader(f"Page {element.get('page_number', 'N/A')}: Chart #{element.get('element_number', 'N/A')}")
        st.text(f"{element.get('NL', 'No description available.')}")

    elif element['element_type'] == "table":
        st.session_state.tables_results_array.append(element)
        st.subheader(f"Page {element.get('page_number', 'N/A')}: Table #{element.get('element_number', 'N/A')}")
        st.text(f"{element.get('NL', 'No description available.')}")


def processFigures(figures, images):
    """
    Crops the figures from the PDF file into images and sends them to the classifier (ML model) for classification.

    Parameters:
    figures: The figures from the XML file.
    images: The pages as images from the PDF file.

    Returns:
    None
    """
    print("\n-------- Cropping Figures --------")
    figurnr = 0
    for figure in figures:
        # print("---")
        # print(figure.get("coords"))
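        # Use the last coordinate group; skip figures that have no coords attribute.
        coords = figure.get("coords")
        if not coords:
            print("Figure nr ", figurnr, " has no coords. Skipping...")
            figurnr += 1
            continue
        coords = coords.split(";")[-1]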

        imgside = images[int(coords.split(",")[0])-1]

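        # Scale factor from XML coordinates to image pixels. Presumably PDF points (72 per inch)
        # to the 200 DPI default of pdf2image: 200/72 is approximately 2.78.
        const = 2.775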

        x=float(coords.split(",")[1])
        y=float(coords.split(",")[2])
        x2=float(coords.split(",")[3])
        y2=float(coords.split(",")[4])

        imgFigur = imgside.crop((x*const,y*const,(x+x2)*const,(y+y2)*const))

        print("\n ---------- Cropping image/figure nr ", figurnr, ". Sending it to ML for classification. ----------")

        ## Saving cropped image to file. Should not be done except for testing.
        # filename = "./MathFormulaImgs/MathFormulafigur" + str(figurnr) + ".png"
        # imgFigur.save(filename)

        ## SENDING TO CLASSIFICATION...

        classify("figure", imgFigur, figurnr, int(coords.split(",")[0])-1, None)

        figurnr+=1
        print("----------")


def processFormulas(formulas, images, mode):
    """
    Crops the formulas from the PDF file into images and sends them to the classifier for classification.

    Parameters:
    formulas: The formulas from the XML file.
    images: The pages as images from the PDF file.
    mode: The mode to be used for classification. (VLM or regex)

    Returns:
    None
    """
    print("\n-------- Cropping Formulas ---------")
    formulanr = 0
    for formula in formulas:
        # print("---")

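        # Use the last coordinate group; skip formulas that have no coords attribute.
        coords = formula.get("coords")
        if not coords:
            print("Formula nr ", formulanr, " has no coords. Skipping...")
            formulanr += 1
            continue
        coords = coords.split(";")[-1]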

        imgside = images[int(coords.split(",")[0])-1]

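        # Scale factor from XML coordinates to image pixels. Presumably PDF points (72 per inch)
        # to the 200 DPI default of pdf2image: 200/72 is approximately 2.78.
        const = 2.775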

        x=float(coords.split(",")[1])
        y=float(coords.split(",")[2])
        x2=float(coords.split(",")[3])
        y2=float(coords.split(",")[4])

        imgFormula = imgside.crop((x*const,y*const,(x+x2)*const,(y+y2)*const))

        print("\n ---------- Cropping image/formula nr ", formulanr, ". Sending it to classifier for classification. ----------")

        ## Saving cropped image to file. Should not be done except for testing.
        # filename = "./MathFormulaImgs/MathFormulaformel" + str(formulanr) + ".png"
        # imgFormula.save(filename)

        ## SENDING TO CLASSIFICATION...

        if (mode == "VLM"):
          # classify() takes no VLM query argument, so this mode cannot be dispatched from here.
          raise NotImplementedError("mode='VLM' is not supported by classify(); use mode='regex'.")
        elif (mode == "regex"):
          classify("formula", imgFormula, formulanr, int(coords.split(",")[0])-1, formula.text)

        formulanr+=1
        print("----------")




#----------------------- ##### FRONTEND ##### -----------------------#


#"""
#    This script is a Streamlit-based application that processes PDF files using the GROBID API
#    and provides options to view annotated results for figures and formulas or raw XML data.
#
#    Modules Used:
#    - Streamlit: For building the web interface.
#    - Requests: For making HTTP requests to the GROBID API.
#    - xml.etree.ElementTree: For parsing the XML response from GROBID.
#    - annotated_text: For highlighting elements like figures and formulas.
#    - streamlit_pdf_viewer: For displaying annotated PDFs.
#"""

import streamlit as st
import logging
import os
import math
import sys
import time
import requests
import xml.etree.ElementTree as ET
from streamlit_pdf_viewer import pdf_viewer
from annotated_text import annotated_text, annotation
from stqdm import stqdm
import xml.dom.minidom as minidom

# Configure logging to store logs in a file
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("app.log"),  # Log to a file named 'app.log'
        logging.StreamHandler(sys.stdout)  # Also log to console
    ]
)

def main():
    """
    Main function to handle the Streamlit application logic.
    """

    try:
        st.set_page_config(layout="wide") # Configure the page layout to be wide
        logging.info("Streamlit page configuration set successfully.")
    except Exception as e:
        logging.error(f"Failed to set Streamlit page configuration: {e}", exc_info=True)

    try:
        css_path = os.path.join("/content/modules", "css.html")
        with open(css_path, "r") as f:
            css_content = f.read()
        logging.info(f"CSS file '{css_path}' read successfully.")

        st.markdown(f"{css_content}", unsafe_allow_html=True)
        logging.info("CSS applied to the Streamlit app successfully.")

    except FileNotFoundError:
        logging.error(f"CSS file not found: {css_path}", exc_info=True)
        st.error("CSS file not found. Please check the 'modules' directory.")

    except Exception as e:
        logging.error(f"An error occurred while applying CSS: {e}", exc_info=True)
        st.error("An unexpected error occurred while applying CSS.")

    def process_classifier(xml_input, pdf_file):
        logging.info(f"XML received in classifier:\n{xml_input}")
        logging.info(f"PDF received in classifier:\n{pdf_file}")

        print("------ Starting test run ------")

        ##  Metrics used for benchmarking:
        global metrics
        st.session_state.metrics = {}
        st.session_state.metrics["figuresGrobid"] = 0
        st.session_state.metrics["formulasGrobid"] = 0
        st.session_state.metrics["chart"] = 0
        st.session_state.metrics["formulas"] = 0
        st.session_state.metrics["figures"] = 0
        st.session_state.metrics["tables"] = 0
        st.session_state.metrics["codelistings"] = 0

        if "elements" not in st.session_state or len(st.session_state.elements) != 0:
          st.session_state.elements = []

        # Ensure arrays exist in session state
        if "formulas_results_array" not in st.session_state or len(st.session_state.formulas_results_array) != 0:
            st.session_state.formulas_results_array = []
        if "figures_results_array" not in st.session_state or len(st.session_state.figures_results_array) != 0:
            st.session_state.figures_results_array = []
        if "charts_results_array" not in st.session_state or len(st.session_state.charts_results_array) != 0:
            st.session_state.charts_results_array = []
        if "tables_results_array" not in st.session_state or len(st.session_state.tables_results_array) != 0:
            st.session_state.tables_results_array = []


        images, figures, formulas = openXMLfile(xml_input, pdf_file)
        processFigures(figures, images)
        processFormulas(formulas, images, mode="regex")



        # Convert to string with XML declaration
        xml_string = str(st.session_state.Bs_data)


        # Prettify using minidom
        #parsed_xml = minidom.parseString(xml_string)
        #pretty_xml = parsed_xml.toprettyxml(indent="    ", encoding="utf-8").decode()

        st.session_state.interpreted_xml_text = xml_string

        logging.info("Generated XML:\n" + st.session_state.interpreted_xml_text)

    def process_pdf(file, grobid_url="http://172.28.0.12:8070/api/processFulltextDocument", params=None):
        """
        Process a PDF file using the GROBID API and return the response content.

        Parameters:
        file: The PDF file to process.
        grobid_url (str): The URL of the GROBID API endpoint.
        params (dict): Additional parameters for the GROBID request.

        Returns:
        str: The XML content returned by the GROBID API, or None if an error occurred.
        """
        files = {'input': file}
        try:
            # Send request to GROBID
            logging.info(f"Sending file {file} to GROBID")

            response = requests.post(grobid_url, files=files, data=params)  # Use 'data' for form-data
            response.raise_for_status()  # Raise an exception on 4xx/5xx responses

            logging.info(f"Received response from GROBID (status code {response.status_code}).")

            # Check if coordinates are missing in the response
            if 'coords' not in response.text:
                logging.warning("No coordinates found in PDF file. Please check GROBID settings.")
                st.warning("No coordinates found in PDF file. Please check GROBID settings.")

            return response.text  # Return XML or JSON

        except requests.exceptions.RequestException as e:
            logging.error(f"Error while communicating with GROBID: {e}", exc_info=True)
            return None  # Return None on error

    def parse_coords_for_figures(xml_content):
        """
        Extract and parse the 'coords' attribute for <figure> and <formula> elements
        from the GROBID XML output while counting the number of occurrences.

        Parameters:
        xml_content (str): The XML content returned by the GROBID API.

        Returns:
        tuple: A tuple containing:
            - List of annotations with details like page, coordinates, and color.
            - Count of formulas.
            - Count of figures.
        """
        annotations = []

        try:
            logging.info("Parsing PDF file to XML.")

            # Parse the XML content
            namespace = {"tei": "http://www.tei-c.org/ns/1.0"}  # Define the XML namespace
            root = ET.fromstring(xml_content)

            # Find all <figure> and <formula> elements in the XML
            figures = root.findall(".//tei:figure", namespace)
            formulas = root.findall(".//tei:formula", namespace)

            st.session_state.count_figures = len(figures)  # Count figures
            st.session_state.count_formulas = len(formulas)  # Count formulas

            logging.info(f"Found {st.session_state.count_figures} figures and {st.session_state.count_formulas} formulas in PDF file.")

            for figure in figures:
                coords = figure.attrib.get("coords", None)  # Get the 'coords' attribute
                if coords:
                    for group in coords.split(';'):
                        try:
                            values = list(map(float, group.split(',')))
                            if len(values) >= 5:
                                # The fourth and fifth values are used as width and height
                                page, x, y, width, height = values[:5]
                                annotations.append({
                                    "page": int(page),
                                    "x": x,
                                    "y": y,
                                    "width": width,
                                    "height": height,
                                    "color": "#CC0000"
                                })
                        except ValueError as e:
                            logging.warning(f"Error parsing figure group '{group}': {e}")
                            st.warning(f"Error parsing figure group '{group}': {e}")

            for formula in formulas:
                coords = formula.attrib.get("coords", None)  # Get the 'coords' attribute
                if coords:
                    for group in coords.split(';'):
                        try:
                            values = list(map(float, group.split(',')))
                            if len(values) >= 5:
                                # The fourth and fifth values are used as width and height
                                page, x, y, width, height = values[:5]
                                annotations.append({
                                    "page": int(page),
                                    "x": x,
                                    "y": y,
                                    "width": width,
                                    "height": height,
                                    "color": "#0000FF"
                                })
                        except ValueError as e:
                            logging.warning(f"Error parsing formula group '{group}': {e}")
                            st.warning(f"Error parsing formula group '{group}': {e}")

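        except ET.ParseError as e:
            logging.error(f"Error parsing XML: {e}", exc_info=True)
            st.error(f"Error parsing XML: {e}")
            # Fall back to zero counts so the return below never reads unset session state
            st.session_state.count_figures = 0
            st.session_state.count_formulas = 0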

        logging.info(f"Extraction completed: {len(annotations)} annotations found.")
        return annotations, st.session_state.count_formulas, st.session_state.count_figures

    def update_xml():
        """
        Update the XML content in the session state based on the user's input in the text area.

        Updates:
        st.session_state.xml_text (str): The updated XML content from the text area (st.session_state.xml_editor).

        Parameters & Returns:
        None
        """
        try:
            st.session_state.xml_text = st.session_state.xml_editor  # Update xml_text with the current content in text area
            logging.info(f"Variable xml_text successfully set to the current content in text area.")
        except Exception as e:
            logging.error(f"An error occurred while setting variable xml_text to the current content in text area: {e}", exc_info=True)

    def update_interpreted_xml():
        """
        Update the interpreted XML content in the session state based on the user's input in the text area.

        Updates:
        st.session_state.interpreted_xml_text (str): The updated XML content from the text area (st.session_state.interpreted_xml_editor).

        Parameters & Returns:
        None
        """
        try:
            st.session_state.interpreted_xml_text = st.session_state.interpreted_xml_editor  # Update interpreted_xml_text with the current content in text area
            logging.info("Variable interpreted_xml_text successfully set to the current content in text area.")
        except Exception as e:
            logging.error(f"An error occurred while setting variable interpreted_xml_text to the current content in text area: {e}", exc_info=True)

    # Title and logo on the page
    st.image("images/Sci2XML_logo.png")

    # Declare variable
    if 'pdf_ref' not in st.session_state:
        st.session_state.pdf_ref = None
        logging.info("Session state: 'pdf_ref' was missing and has been initialized to None.")  # Log initialization

    try:
        # Access the uploaded ref via a key
        uploaded_pdf = st.file_uploader("Upload a PDF file", type="pdf", key='pdf', accept_multiple_files=False, label_visibility="collapsed")
        logging.info("Setting variable uploaded_pdf to be the uploaded PDF file.")
    except Exception as e:
        uploaded_pdf = None  # Keep the name defined so the check below does not raise NameError
        logging.error(f"An error occurred while setting the variable uploaded_pdf to be the uploaded PDF file: {e}", exc_info=True)

    if uploaded_pdf:
        @st.fragment
        def pdf_upload():
            logging.info("A new PDF file was uploaded.")

            # Reset interpretation results visibility when a new file is uploaded
            if "pdf_ref" in st.session_state and uploaded_pdf != st.session_state.pdf_ref:
                logging.info("Uploaded PDF differs from the stored reference. Resetting interpretation results.")
                st.session_state.show_interpretation_results = False
                st.session_state.xml_text = None
                st.session_state.interpreted_xml_text = None

            # Backup uploaded file
            st.session_state.pdf_ref = uploaded_pdf
            logging.info("Stored uploaded PDF in session state ('pdf_ref').")

            # Reset pdf_ref when no file is uploaded
            if not st.session_state.pdf:
                logging.warning("No PDF file found in session state. Resetting 'pdf_ref' to None.")
                st.session_state.pdf_ref = None

            # Process binary data if a file is present
            if st.session_state.pdf_ref:
                try:
                    logging.info(f"Extracted binary data from uploaded PDF ({len(st.session_state.pdf_ref.getvalue())} bytes).")
                except Exception as e:
                    logging.error(f"Failed to retrieve binary data from uploaded PDF: {e}", exc_info=True)
                    st.error("An error occurred while reading the uploaded PDF file.")

                # Parameters for GROBID
                params = {
                    "consolidateHeader": 1,
                    "consolidateCitations": 1,
                    "consolidateFunders": 1,
                    "includeRawAffiliations": 1,
                    "includeRawCitations": 1,
                    "segmentSentences": 1,
                    "teiCoordinates": ["ref", "s", "biblStruct", "persName", "figure", "formula", "head", "note", "title", "affiliation"]
                }

                result = None  # Ensure result is always defined

                # Process file as soon as it's uploaded
                with st.status(label=None, expanded=False, state="running") as status:
                    result = process_pdf(st.session_state.pdf_ref, params=params)

                    if result is not None and result.startswith("Error when processing file"):
                        st.error(result)
                    else:
                        if result:
                            st.session_state.rectangles, st.session_state.count_formulas, st.session_state.count_figures = parse_coords_for_figures(result)
                        else:
                            st.session_state.rectangles = []
                            st.session_state.count_formulas = 0
                            st.session_state.count_figures = 0
                            result = ""
                    status.update(label="Complete!", state="complete", expanded=False)

                # Initialize the xml_text in session_state if not already set
                if "xml_text" not in st.session_state or st.session_state.xml_text is None:
                    st.session_state.xml_text = result  # Initial XML content from GROBID
                    logging.info(f"Initialized xml_text from GROBID result ({len(result)} characters).")

                if "show_grobid_results" not in st.session_state:
                    st.session_state.show_grobid_results = True  # Set session state flag
                    logging.info("Session state: 'show_grobid_results' set to True.")

        pdf_upload()

        if st.session_state.show_grobid_results:
            # Layout container to maintain column structure
            with st.container():
                col1, col2, col3 = st.columns([0.4, 0.2, 0.4])  # col1 and col3 have equal width, with col2 between them

                with col1:
                    @st.fragment
                    def grobid_results_view():
                        """
                        Render the GROBID results in either PDF View with annotations or raw XML View.
                        """
                        st.header("GROBID Results", divider="gray")  # Always renders first

                        if 'grobid_results_view_option' not in st.session_state:
                            st.session_state.grobid_results_view_option = "PDF"

                        st.session_state.grobid_results_view_option = st.radio("Select View", ["PDF", "XML"], horizontal=True, key='view_toggle', label_visibility="collapsed")

                        # PDF View with annotations
                        if st.session_state.grobid_results_view_option == "PDF":
                            pdf_viewer(input=st.session_state.pdf_ref.getvalue(), height=725, annotations=st.session_state.rectangles, render_text=True, annotation_outline_size=2)
                            if st.session_state.count_formulas > 0 and st.session_state.count_figures > 0:
                                annotated_text(
                                    annotation("Formulas", "", background="#0000FF", color="#FFFFFF"), " ",
                                    annotation("Figures", "", background="#CC0000", color="#FFFFFF")
                                )
                            elif st.session_state.count_formulas > 0:
                                annotated_text(
                                    annotation("Formulas", "", background="#0000FF", color="#FFFFFF")
                                )
                            elif st.session_state.count_figures > 0:
                                annotated_text(
                                    annotation("Figures", "", background="#CC0000", color="#FFFFFF")
                                )

                        # XML View with raw content
                        elif st.session_state.grobid_results_view_option == "XML":
                            # Text area bound to session_state with on_change callback
                            st.text_area(
                                "Edit GROBID XML File",
                                value=st.session_state.xml_text,  # Initial content from session state
                                height=725,
                                key="xml_editor",  # Key for the text area
                                on_change=update_xml,  # Update xml_text when changes are made
                                label_visibility="collapsed" # Hide the label properly
                            )

                    grobid_results_view()

                    @st.fragment
                    def classify2():
                        if st.button("Process file"):
                            with col3:
                                if 'results_placeholder' not in st.session_state or st.session_state.results_placeholder is None:
                                    st.header("Interpretation Results", divider="gray")  # Always stays aligned with col1
                                    st.session_state.results_placeholder = st.empty()
                                    print("results_placeholder created")
                                else:
                                    st.session_state.results_placeholder.empty()
                                    print("results_placeholder emptied")

                                with st.session_state.results_placeholder.container():
                                    if 'interpretation_results_view_option' not in st.session_state:
                                        st.session_state.interpretation_results_view_option = "XML"

                                st.session_state.results_placeholder.empty()

                                # Create a placeholder for the container
                                container_placeholder = st.empty()

                                with container_placeholder.container(height=725, border=True):
                                    process_classifier(st.session_state.xml_text, st.session_state.pdf_ref)  # Use PDF file and updated XML file from session state

                                container_placeholder.empty()

                                with col3:
                                    with st.session_state.results_placeholder.container():
                                        @st.fragment
                                        def interpretation_results_view():
                                            """
                                            Render different interpretation results based on user selection.
                                            """
                                            st.session_state.interpretation_results_view_option = st.radio("Select Non-Textual Element", ["XML", "Formulas", "Figures", "Charts", "Table"], horizontal=True, key='interpretation_toggle', label_visibility="collapsed")

                                            if st.session_state.interpretation_results_view_option == "XML":
                                                st.text_area(
                                                    "Edit Interpreted XML File",
                                                    value=st.session_state.interpreted_xml_text,  # Initial content from session state
                                                    height=725,
                                                    key="interpreted_xml_editor",  # Key for the text area
                                                    on_change=update_interpreted_xml,  # Update interpreted_xml_text when changes are made
                                                    label_visibility="collapsed" # Hide the label properly
                                                )

                                            elif st.session_state.interpretation_results_view_option == "Formulas":
                                                with st.container(height=725, border=True):
                                                    if len(st.session_state.formulas_results_array) > 0:
                                                        for formula in st.session_state.formulas_results_array:  # Use session state variable
                                                            st.subheader(f"Page {formula.get('page_number', 'N/A')}: Formula #{formula.get('element_number', 'N/A')}")
                                                            st.markdown(rf"$$ {formula.get('formula', 'N/A')} $$")
                                                            st.text(f"{formula.get('NL', 'No description available.')}")
                                                    else:
                                                        st.warning("No formulas detected in PDF file.")

                                            elif st.session_state.interpretation_results_view_option == "Figures":
                                                with st.container(height=725, border=True):
                                                    if len(st.session_state.figures_results_array) > 0:
                                                        for figure in st.session_state.figures_results_array:  # Use session state variable
                                                            st.subheader(f"Page {figure.get('page_number', 'N/A')}: Figure #{figure.get('element_number', 'N/A')}")
                                                            st.text(f"{figure.get('NL', 'No description available.')}")
                                                    else:
                                                        st.warning("No figures detected in PDF file.")

                                            elif st.session_state.interpretation_results_view_option == "Charts":
                                                with st.container(height=725, border=True):
                                                    if len(st.session_state.charts_results_array) > 0:
                                                        for chart in st.session_state.charts_results_array:  # Use session state variable
                                                            st.subheader(f"Page {chart.get('page_number', 'N/A')}: Chart #{chart.get('element_number', 'N/A')}")
                                                            st.text(f"{chart.get('NL', 'No description available.')}")
                                                    else:
                                                        st.warning("No charts detected in PDF file.")

                                            elif st.session_state.interpretation_results_view_option == "Table":
                                                with st.container(height=725, border=True):
                                                    if len(st.session_state.tables_results_array) > 0:
                                                        for table in st.session_state.tables_results_array:  # Use session state variable
                                                            st.subheader(f"Page {table.get('page_number', 'N/A')}: Table #{table.get('element_number', 'N/A')}")
                                                            st.text(f"{table.get('NL', 'No description available.')}")
                                                    else:
                                                        st.warning("No tables detected in PDF file.")

                                            st.download_button(
                                                label="Download XML",
                                                data=(st.session_state.interpreted_xml_text or "").encode("utf-8"),  # Convert text to bytes; empty if there is no result yet
                                                file_name="interpreted_results.xml",
                                                mime="application/xml"
                                            )

                                        interpretation_results_view()

                    classify2()
    else:
        # Prompt user to upload a PDF file
        st.write("Upload a PDF file to analyze it in GROBID")

if __name__ == '__main__':
    try:
        logging.info("Calling main function")
        main()
    except Exception as e:
        logging.error(f"Unhandled exception in main: {e}", exc_info=True)




## Starting streamlit app in another thread and hosting it using localtunnel.
This cell looks up the machine's public IP address, which is the password for the localtunnel site. The command redirects the address to "/content/passwfile.txt", so open that file if the address does not appear in the cell output. To get the URL for the localtunnel/Streamlit site, open the file "streamliturl.txt" on the left and use the URL there.

Example:

URL: https://blue-swans-appear.loca.lt

Password: 35.240.250.237


PS: Sometimes the output of this cell is redirected to either logs.txt, streamlit.txt or APIlog.
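
Instead of opening the two files by hand, the URL and password can be read programmatically. This is a minimal sketch, assuming the files are the ones written by the cell below (`/content/streamliturl.txt` and `/content/passwfile.txt`) and that the localtunnel output contains an `https://` URL. The helper `read_tunnel_info` is hypothetical and not part of the notebook.

```python
import re
from pathlib import Path


def read_tunnel_info(url_file, password_file):
    """Return (url, password) parsed from the two files; a missing value is None."""
    url_text = Path(url_file).read_text()
    password_text = Path(password_file).read_text()

    # Take the first https:// URL found in the localtunnel output
    url_match = re.search(r"https://\S+", url_text)
    # Take the last IPv4-looking token, since curl may write progress output first
    ip_matches = re.findall(r"\b\d{1,3}(?:\.\d{1,3}){3}\b", password_text)

    url = url_match.group(0) if url_match else None
    password = ip_matches[-1] if ip_matches else None
    return url, password
```

In Colab this would be called as `read_tunnel_info("/content/streamliturl.txt", "/content/passwfile.txt")` once the tunnel has started.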

In [46]:
import threading

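# Save this machine's public IP (the localtunnel password) to a file
!curl ipv4.icanhazip.com &>/content/passwfile.txt
print("Starting frontend")
def startStreamlit():
  print("Start...")
  # Run Streamlit in the background, logging to logs.txt
  !streamlit run app.py &>/content/logs.txt &
  # Expose port 8501 through localtunnel; its output (with the URL) goes to streamliturl.txt
  !npx localtunnel --port 8501 &>/content/streamliturl.txt
# Daemon thread so the tunnel does not block the notebook
t1 = threading.Thread(target=startStreamlit, daemon=True)
t1.start()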
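
Because Streamlit starts in a background thread, the cell returns before the server is ready. One way to check readiness is to poll the port. This is a minimal sketch, assuming Streamlit listens on port 8501 (the port passed to localtunnel above). `wait_for_port` is a hypothetical helper, not part of the notebook.

```python
import socket
import time


def wait_for_port(port, host="127.0.0.1", timeout=30.0, interval=0.5):
    """Return True once a TCP connection to host:port succeeds, False after timeout seconds."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            time.sleep(interval)  # Server is not accepting connections yet; retry
    return False
```

Calling `wait_for_port(8501)` before opening the localtunnel URL avoids hitting the tunnel while the app is still starting.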

# API
Reference for running Flask in Colab: https://colab.research.google.com/github/srivatsan88/YouTubeLI/blob/master/Running_Flask_in_Colab.ipynb#scrollTo=w0Vbn3kFz3V2

## Imports

In [16]:
import io
import socket
import threading
from io import BytesIO, StringIO

import albumentations as A
import albumentations.pytorch  # Explicit import so A.pytorch.transforms.ToTensorV2 is available in callML
import nest_asyncio
import numpy as np
import torch
from flask import Flask, jsonify, make_response, request
from PIL import Image
from transformers import AutoProcessor, DonutProcessor, VisionEncoderDecoderModel

# The lmdeploy imports are commented out here; callVLM relies on load_image being available
#from lmdeploy import pipeline, TurbomindEngineConfig
#from lmdeploy.vl import load_image

nest_asyncio.apply()

# Use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Unichart

In [17]:
# Load UniChart model
print("Loading UniChart model...")
unichart_model = VisionEncoderDecoderModel.from_pretrained("ahmed-masry/unichart-base-960").to(device)
unichart_processor = DonutProcessor.from_pretrained("ahmed-masry/unichart-base-960")
print("UniChart model loaded successfully!")


In [18]:
def generate_unichart_response(image, prompt):
    """Run UniChart on a PIL image with the given task prompt and return the decoded answer."""
    pixel_values = unichart_processor(image, return_tensors="pt").pixel_values.to(device)
    decoder_input_ids = unichart_processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
    outputs = unichart_model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=unichart_model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=unichart_processor.tokenizer.pad_token_id,
        eos_token_id=unichart_processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=4,
        bad_words_ids=[[unichart_processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    response = unichart_processor.batch_decode(outputs.sequences)[0]
    response = response.replace(unichart_processor.tokenizer.eos_token, "").replace(unichart_processor.tokenizer.pad_token, "").strip()
    return response.split("<s_answer>")[1].strip() if "<s_answer>" in response else response

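def parse_table_data(table_data):
    """Parse a flat table string into a list of row dicts.

    Rows are separated by "&" and cells by "|"; the first row holds the headers.
    Returns an empty list if any row has fewer cells than there are headers.
    """
    rows = table_data.split("&")
    headers = [header.strip() for header in rows[0].split("|")]
    parsed_data = []
    try:
        for row in rows[1:]:
            values = row.split("|")
            parsed_data.append({headers[i]: values[i].strip() for i in range(len(headers))})
    except IndexError:
        # A row had fewer cells than headers; discard the whole table
        parsed_data = []
    return parsed_data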
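
To illustrate the string format `parse_table_data` expects, here is a small self-contained sketch. The sample strings are invented for illustration and are not actual UniChart output; the function body is repeated so the example runs on its own.

```python
def parse_table_data(table_data):
    rows = table_data.split("&")
    headers = rows[0].split("|")
    parsed_data = []
    try:
        for row in rows[1:]:
            values = row.split("|")
            parsed_data.append({headers[i].strip(): values[i].strip() for i in range(len(headers))})
    except IndexError:
        parsed_data = []
    return parsed_data


# Hypothetical input: rows separated by "&", cells by "|", first row is the header
sample = "Year | Sales & 2020 | 10 & 2021 | 15"
print(parse_table_data(sample))
# [{'Year': '2020', 'Sales': '10'}, {'Year': '2021', 'Sales': '15'}]

# A row with fewer cells than headers discards the whole table
print(parse_table_data("Year | Sales & 2020"))
# []
```

Note that one malformed row empties the entire result, so a caller cannot tell a chart without data from a partly garbled table.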


## Sumen

In [19]:
# Load Sumen model
print("Loading Sumen OCR model...")
sumen_model = VisionEncoderDecoderModel.from_pretrained("hoang-quoc-trung/sumen-base").to(device)
sumen_processor = AutoProcessor.from_pretrained("hoang-quoc-trung/sumen-base")
print("Sumen model loaded successfully!")


In [20]:
def run_sumen_ocr(image):
    """Run Sumen OCR on a PIL image and return the recognized LaTeX string."""
    pixel_values = sumen_processor.image_processor(image, return_tensors="pt").pixel_values.to(device)
    task_prompt = sumen_processor.tokenizer.bos_token
    decoder_input_ids = sumen_processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    with torch.no_grad():
        outputs = sumen_model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids.to(device),
            max_length=sumen_model.decoder.config.max_length,
            pad_token_id=sumen_processor.tokenizer.pad_token_id,
            eos_token_id=sumen_processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=4,
            bad_words_ids=[[sumen_processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )
    clean_latex = sumen_processor.tokenizer.batch_decode(outputs.sequences)[0]
    return clean_latex.replace("<s>", "").replace("</s>", "").strip()

## Run this cell to start API

In [21]:
import sys
sys.stdout = open("APIlog", "w")

print(socket.gethostbyname(socket.gethostname()))

app = Flask(__name__)


@app.route("/")
def hello():
    return "I am alive!"

@app.route('/parseFormula', methods=['POST'])
def handle_formula():
    print("-- You have reached endpoint for formula --")

    file = request.files['image']

    ## PROCESS IMAGE

    processedFormulaLaTex, processedFormulaNL = processFormula(file)

    return jsonify({'element_type':"formula", 'formula': processedFormulaLaTex, "NL": processedFormulaNL, "preferred": processedFormulaLaTex})

@app.route('/parseChart', methods=['POST'])
def handle_chart():
    print("-- You have reached endpoint for chart --")

    file = request.files['image']

    ## PROCESS IMAGE
    processedChartCSV, processedChartNL = processChart(file)


    return jsonify({'element_type':"chart", 'NL': processedChartNL, "csv": processedChartCSV, "preferred": processedChartNL})

@app.route('/parseFigure', methods=['POST'])
def handle_figure():
    print("-- You have reached endpoint for figure --")

    file = request.files['image']

    ## PROCESS IMAGE

    processedFigureNL = processFigure(file)

    return jsonify({'element_type':"figure", 'NL': processedFigureNL, "preferred": processedFigureNL})

@app.route('/parseTable', methods=['POST'])
def handle_table():
    print("-- You have reached endpoint for table --")

    file = request.files['image']

    ## PROCESS IMAGE
    processedTableCSV, processedTableNL = processTable(file)

    return jsonify({'element_type':"table", 'NL': processedTableNL, "csv": processedTableCSV, "preferred": processedTableCSV})



def processFormula(file):
    """Run Sumen OCR on an uploaded formula image and return (LaTeX, NL description)."""
    from flask import abort  # Local import: abort is not imported in the Imports cell

    print("Processing formula...")
    if file.filename == '':
        # Abort with a JSON 400; a returned response would be unpacked as data by the caller
        abort(make_response(jsonify({"error": "No selected file"}), 400))
    image = Image.open(BytesIO(file.read())).convert('RGB')
    latex_code = run_sumen_ocr(image)

    NLdata = "some NL"  # Placeholder: no NL description is generated for formulas yet
    return latex_code, NLdata

def processChart(file):
    """Run UniChart on an uploaded chart image and return (parsed table rows, summary)."""
    from flask import abort  # Local import: abort is not imported in the Imports cell

    print("Processing chart...")
    if file.filename == '':
        # Abort with a JSON 400; a returned response would be unpacked as data by the caller
        abort(make_response(jsonify({"error": "No selected file"}), 400))
    image = Image.open(BytesIO(file.read())).convert('RGB')
    summary = generate_unichart_response(image, "<summarize_chart><s_answer>")
    table_data = generate_unichart_response(image, "<extract_data_table><s_answer>")
    structured_table_data = parse_table_data(table_data)

    return structured_table_data, summary

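def processFigure(image):
    """Placeholder: returns a fixed NL string until figure interpretation is implemented."""
    print("Processing figure...")
    # TODO: send the image to a VLM
    NLdata = "some NL"
    return NLdata

def processTable(image):
    """Placeholder: returns fixed CSV and NL data until table parsing is implemented."""
    print("Processing table...")
    # TODO: send the image to OCR, a VLM, or a table parser
    CSVdata = ["some CSV data stuff", "22"]
    NLdata = "some NL"
    return CSVdata, NLdata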




def callVLM(pipe, image, query):
  """
  Calls the VLM model.

  Parameters:
  pipe: The VLM model.
  image: The image to be classified.
  query: The query to be used for classification.

  Returns:
  response.text: The response from the VLM model.
  """
  print("\n- Calling VLM -")
  image = load_image(image)  # load_image comes from lmdeploy.vl
  response = pipe((query, image))
  return response.text



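def callML(model, image):
  """
  Classifies an image with the ML model.

  Parameters:
  model: The ML classifier; its predict method is called on the image tensor.
  image: The PIL image to be classified.

  Returns:
  predicted_class_name: The name of the predicted class.
  """
  image = image.convert("RGB")  # Ensure the image is in RGB format

  img_size = 224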


  # Define the same transformations used during training
  data_transforms = A.Compose([
      A.Resize(img_size, img_size),
      A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
      A.pytorch.transforms.ToTensorV2()
  ])

  # Apply transformations
  transformed_image = data_transforms(image=np.array(image))["image"]

  # Add a batch dimension
  transformed_image = transformed_image.unsqueeze(0)

  # Move the image to the appropriate device (GPU or CPU)
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
  transformed_image = transformed_image.to(device)

  # Make prediction
  predicted_class = model.predict(transformed_image)

  # Get the class name
  class_names = ['just_image', 'bar_chart', 'diagram', 'flow_chart', 'graph',
                'growth_chart', 'pie_chart', 'table', 'text_sentence']
  predicted_class_name = class_names[predicted_class[0]]

  print(f"Predicted class: {predicted_class_name}")
  return predicted_class_name

#VLM = loadVLM()

@app.route("/loadVLM")
def load_vlm():
    print("API endpoint: Loading VLM...")
    global VLM
    VLM = loadVLM()
    return "API endpoint: Loading VLM..."

@app.route('/callVLM', methods=['POST'])
def call_vlm():
    print("-- You have reached endpoint for classifier VLM --")

    image = request.files['image']
    image = Image.open(image)

    # The query text arrives as a file part, so read and decode it.
    query = request.files['query'].read().decode("utf-8")

    ## PROCESS IMAGE
    response = callVLM(VLM, image, query)

    return jsonify({'VLMresponse':response})

@app.route('/callClassifier', methods=['POST'])
def call_ml():
    print("-- You have reached endpoint for classifier ML --")

    image = request.files['image']
    image = Image.open(image)

    ## PROCESS IMAGE
    response = callML(ML, image)

    return jsonify({'ClassifierResponse':response})


@app.route("/test2")
def test2():
    # Simple GET endpoint to check that the API is reachable.
    print("API endpoint: test2")
    g = 2
    print("..", g)
    return "API endpoint: test2 "+str(g)

@app.route('/test', methods=['POST'])
def test_function():
    text = request.get_json()['text']
    print(text)
    predictions = "predd"
    sentiment = "senttttt"
    return jsonify({'predictions': predictions, 'sentiment': sentiment})

port = 8000
threading.Thread(target=app.run, kwargs={'host':'0.0.0.0','port':port}).start()

## Test call to API
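
The classifier endpoint can also be exercised over HTTP rather than by calling `callML` directly. The following is a minimal sketch, assuming the API cell above is running and that `/callClassifier` expects a multipart file field named `image` and answers with a `ClassifierResponse` key, as in the route definition. It uses only the standard library instead of `requests`, and the helper names `encode_multipart` and `classify_image` are made up for illustration.

```python
import json
import urllib.request
import uuid


def encode_multipart(field, filename, data, content_type="image/png"):
    """Build a multipart/form-data body holding a single file field.

    Returns the encoded body and the matching Content-Type header value.
    """
    boundary = uuid.uuid4().hex
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    return head + data + tail, f"multipart/form-data; boundary={boundary}"


def classify_image(base_url, png_bytes, timeout=30):
    """POST PNG bytes to /callClassifier and return the predicted class name."""
    body, content_type = encode_multipart("image", "element.png", png_bytes)
    req = urllib.request.Request(
        base_url.rstrip("/") + "/callClassifier",
        data=body,
        headers={"Content-Type": content_type},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))["ClassifierResponse"]
```

With the API listening on port 8000, a call such as `classify_image("http://127.0.0.1:8000", open("chart3.png", "rb").read())` should return one of the class names listed in `callML`. The same request can be sent with `requests.post(url, files={'image': img_bytes})`, which is the form used elsewhere in this notebook.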

In [38]:
import requests
from PIL import Image
import io

# Alternative query:
#query = "with one word, classify this as either a table, figure, diagram, chart, formula, image or other"
query = "Answer with only one word (Yes OR No), is this a formula?"
image = Image.open("chart3.png")

response = callML(ML, image)
print("!!!Response!!!: ", response)

print("-1--\n---", image.format, "---")
"""
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format='PNG')
img_byte_arr = img_byte_arr.getvalue()

APIresponse = requests.post("http://172.28.0.12:8002/"+"parseChart", files={'image': img_byte_arr})
APIresponse = APIresponse.json()
#APIresponse["element_number"] = elementNr
#APIresponse["page_number"] = pagenr

print("Response from chartParser: --> ", APIresponse["preferred"])

#print("YES: ", "Formula: ", elementNr, " ->", regex)
#st.session_state.metrics["formulas"] += 1
subtype = "formula"
print("Redirecting to formulaParser")
##### APIresponse = API.call("127.0.0.1/formulaParser") #####

img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format='PNG')
img_byte_arr = img_byte_arr.getvalue()

APIresponse = requests.post("http://172.28.0.12:8000/"+"callClassifier", files={'image': img_byte_arr})
APIresponse = APIresponse.json()
#APIresponse["element_number"] = elementNr
#APIresponse["page_number"] = pagenr

print("Response from classifier: --> ", APIresponse["ClassifierResponse"])
"""


## To kill/stop API

Warning: running the cell below often kills the whole session :(
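
A gentler option might be to keep a handle on the server object so it can be shut down from Python instead of killing the port. The sketch below shows the pattern with the standard library's `wsgiref.simple_server.make_server` in place of `app.run`, which is a swap from what this notebook does. `StoppableServer` and `demo_app` are hypothetical names; `demo_app` only stands in for the Flask `app`, which is itself a WSGI callable and could be passed in the same way.

```python
import http.client
import threading
from wsgiref.simple_server import WSGIRequestHandler, make_server


def demo_app(environ, start_response):
    """Stand-in WSGI app; in the notebook this would be the Flask `app`."""
    start_response("200 OK", [("Content-Type", "text/plain")])
    return [b"ok"]


class QuietHandler(WSGIRequestHandler):
    """Request handler that skips the per-request log line."""

    def log_message(self, format, *args):
        pass


class StoppableServer:
    """Runs a WSGI app in a background thread and can be stopped cleanly."""

    def __init__(self, wsgi_app, host="127.0.0.1", port=0):
        # port=0 lets the OS pick a free port; pass 8000 to mirror the notebook.
        self._server = make_server(host, port, wsgi_app, handler_class=QuietHandler)
        self.port = self._server.server_port
        self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)

    def start(self):
        self._thread.start()

    def stop(self):
        self._server.shutdown()      # makes serve_forever() return
        self._server.server_close()  # releases the listening socket
        self._thread.join()


server = StoppableServer(demo_app)
server.start()

# Send one request to confirm the server is up, then stop it.
conn = http.client.HTTPConnection("127.0.0.1", server.port, timeout=10)
conn.request("GET", "/")
reply = conn.getresponse().read()
conn.close()

server.stop()
print(reply)
```

Note that the `wsgiref` server handles one request at a time, so this is only meant for experiments in the notebook; whether it avoids the session crash seen with `fuser -k` has not been verified here.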

In [None]:
#!fuser -k 8000/tcp