In [43]:
"""
Table QA with just table documents.

Uses EmbeddingRetriever and TableReader.

See first half of https://haystack.deepset.ai/tutorials/15_tableqa
"""
import os
from pathlib import Path

from haystack import Document
from haystack.nodes import AzureConverter, EmbeddingRetriever, TableReader
from haystack.document_stores import InMemoryDocumentStore
from haystack.pipelines import Pipeline
from haystack.utils import print_answers

# Azure Document Intelligence key, read from the environment (never hardcode it).
# os.getenv returns None when the variable is unset.
AZURE_CONVERTER_KEY = os.getenv("AZURE_CONVERTER_KEY")

In [3]:
# Fail fast with a clear message rather than handing credential_key=None to Azure.
if not AZURE_CONVERTER_KEY:
    raise RuntimeError(
        "Set the AZURE_CONVERTER_KEY environment variable before running this notebook."
    )

converter = AzureConverter(
    endpoint="https://azureconverter.cognitiveservices.azure.com/",
    credential_key=AZURE_CONVERTER_KEY,
    # NOTE(review): save_json=True presumably persists the raw Azure response to disk -- confirm in Haystack docs.
    save_json=True,
)

# The PDF location can be overridden via the PDF_PATH environment variable so the
# notebook runs on other machines; the default is the original author's local copy.
PDF_PATH = Path(
    os.environ.get(
        "PDF_PATH",
        "/home/tomw/unifi-pdf-llm/data/test/Sasol Sustainability Report_2021_22Sep21_10h30_0_0 - short.pdf",
    )
)
if not PDF_PATH.is_file():
    raise FileNotFoundError(f"PDF not found: {PDF_PATH} (set PDF_PATH to the report's location)")

# Convert the PDF into Haystack Documents (tables and text) via the Azure service.
docs = converter.convert(file_path=PDF_PATH, meta=None)

In [10]:
# Filter to just table documents
# Filter to just table documents.
# Rebuild the list instead of calling docs.pop() inside an enumerate() loop.
# Popping shifts the remaining items left, so the element right after each
# removed one was never examined and could slip through the filter.
# Rebinding also makes the cell safe to re-run.
docs = [doc for doc in docs if doc.content_type != "text"]

In [11]:
# A new in-memory store each time this cell runs, populated with the filtered documents.
# NOTE(review): sentence-transformers models are usually paired with cosine similarity;
# consider InMemoryDocumentStore(similarity="cosine") -- confirm against Haystack docs.
document_store = InMemoryDocumentStore()
document_store.write_documents(documents=docs)

In [14]:
# Embedding model fine-tuned to embed tables as well as plain text.
TABLE_EMBEDDING_MODEL = "deepset/all-mpnet-base-v2-table"

retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model=TABLE_EMBEDDING_MODEL,
)

# Compute embeddings for the tables held in the document store.
document_store.update_embeddings(retriever=retriever)

config_sentence_transformers.json: 100%|██████████| 116/116 [00:00<00:00, 843kB/s]
modules.json: 100%|██████████| 349/349 [00:00<00:00, 2.66MB/s]
README.md: 100%|██████████| 2.42k/2.42k [00:00<00:00, 14.4MB/s]
sentence_bert_config.json: 100%|██████████| 53.0/53.0 [00:00<00:00, 419kB/s]
config.json: 100%|██████████| 576/576 [00:00<00:00, 4.12MB/s]
pytorch_model.bin: 100%|██████████| 438M/438M [01:10<00:00, 6.23MB/s] 
tokenizer_config.json: 100%|██████████| 345/345 [00:00<00:00, 2.13MB/s]
vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 1.36MB/s]
tokenizer.json: 100%|██████████| 711k/711k [00:00<00:00, 6.07MB/s]
special_tokens_map.json: 100%|██████████| 239/239 [00:00<00:00, 1.60MB/s]
1_Pooling/config.json: 100%|██████████| 190/190 [00:00<00:00, 1.48MB/s]
Batches: 100%|██████████| 1/1 [00:00<00:00,  1.17it/s]ocs/s]
Documents Processed: 10000 docs [00:00, 11565.42 docs/s]     


In [15]:
# Try the retriever on its own: ask for the single best-matching table.
retrieved_tables = retriever.retrieve(
    "How many employees were there in 2020?",
    top_k=1,
)

# The first result is the highest-scored table; show its content.
best_table = retrieved_tables[0]
print(best_table.content)

Batches: 100%|██████████| 1/1 [00:00<00:00, 16.39it/s]

                           Human Capital - Our people  \
0                                    Employee numbers   
1                                   Employee turnover   
2                                              Safety   
3                                Recordable Case Rate   
4                                          - Employee   
5                                  - Service provider   
6                     Lost Work Day Case Rate (LWDCR)   
7                                          - Employee   
8                                  - Service provider   
9           Employee and service provider fatalities*   
10                                         - Employee   
11                                 - Service provider   
12  Employee and service provider fatal injury fre...   
13  Total major and significant fires, explosions ...   
14             - Major fires, explosions and releases   
15       - Significant fires, explosions and releases   
16                             




In [53]:
# TODO: Try different TableReader models
# NOTE(review): the recorded run of this cell (In [53]) failed with an OSError because
# huggingface.co was unreachable and the model was not cached. The cells below that use
# `reader` have lower execution counts, so their outputs come from an earlier kernel
# state. Re-run top to bottom to confirm they reproduce.
TABLE_READER_MODEL = "google/tapas-large-finetuned-wtq"

reader = TableReader(
    model_name_or_path=TABLE_READER_MODEL,
    max_seq_len=512,
)

OSError: We couldn't connect to 'https://huggingface.co' to load this file, couldn't find it in the cached files and it looks like google/tapas-large-finetuned-wtq is not the path to a directory containing a file named config.json.
Checkout your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'.

In [35]:
# Try the TableReader on one table.
# In the recorded run it worked very well when the query closely mirrored the table's
# wording ("Employee numbers 2021?") and less well with a natural-language question.
# NOTE(review): index 1 held the "Human Capital - Our people" table in the recorded run;
# the index is fragile if the converter output changes.
table_doc = docs[1]

# (query, heading printed before running that query -- None means no heading)
for query, heading in [
    ("Employee numbers 2021?", None),
    ("How many employees were there in 2021?", "\nLess precise query:\n"),
]:
    if heading is not None:
        print(heading)
    prediction = reader.predict(query=query, documents=[table_doc])
    print(prediction['answers'])

  text = normalize_for_match(row[col_index].text)
  cell = row[col_index]


[<Answer {'answer': '28 949', 'type': 'extractive', 'score': 1.0, 'context': [['Human Capital - Our people', 'Footnote', '2021', '2020', '2019', '2018', 'Level of assurance 2021'], ['Employee numbers', '1', '28 949', '31 001', '31 429', '31 270', ''], ['Employee turnover', '1', '3 869', '1 936', '1780', '1 560', ''], ['Safety', '', '', '', '', '', ''], ['Recordable Case Rate', '2', '0,26', '0,27', '0,26', '0,27', 'Limited'], ['- Employee', '', '0,29', '0,31', '0,31', '0,30', ''], ['- Service provider', '', '0,22', '0,23', '0,22', '0,25', ''], ['Lost Work Day Case Rate (LWDCR)', '', '0,14', '0,11', '0,10', '0,11', ''], ['- Employee', '', '0,16', '0,14', '0,13', '0,15', ''], ['- Service provider', '', '0,11', '0,08', '0,08', '0,09', ''], ['Employee and service provider fatalities*', '', '2', '6', '3', '4', 'Limited'], ['- Employee', '', '1', '3', '2', '3', ''], ['- Service provider', '', '1', '3', '1', '1', ''], ['Employee and service provider fatal injury frequency rate', '', '0,001', '

  text = normalize_for_match(row[col_index].text)
  cell = row[col_index]


In [38]:
# Two-stage pipeline: the query goes to the retriever, whose output feeds the table reader.
table_qa_pipeline = Pipeline()
for node_component, node_name, node_inputs in (
    (retriever, "EmbeddingRetriever", ["Query"]),
    (reader, "TableReader", ["EmbeddingRetriever"]),
):
    table_qa_pipeline.add_node(component=node_component, name=node_name, inputs=node_inputs)

In [42]:
# Run the full retriever -> reader pipeline.
# NOTE(review): params={"top_k": 1} is not scoped to a node name, so it is presumably
# passed to every node that accepts top_k (retriever and reader) -- confirm in Haystack docs.
# In the recorded output the answer was '3', although the retrieved table's 2020 column
# lists 31 001 employees.
question = "How many employees were there in 2020?"
prediction = table_qa_pipeline.run(query=question, params={"top_k": 1})
print_answers(prediction, details="minimum")

Batches: 100%|██████████| 1/1 [00:00<00:00, 22.07it/s]
  text = normalize_for_match(row[col_index].text)
  cell = row[col_index]


'Query: How many employees were there in 2020?'
'Answers:'
[   {   'answer': '3',
        'context':                            Human Capital - Our people  \
0                                    Employee numbers   
1                                   Employee turnover   
2                                              Safety   
3                                Recordable Case Rate   
4                                          - Employee   
5                                  - Service provider   
6                     Lost Work Day Case Rate (LWDCR)   
7                                          - Employee   
8                                  - Service provider   
9           Employee and service provider fatalities*   
10                                         - Employee   
11                                 - Service provider   
12  Employee and service provider fatal injury fre...   
13  Total major and significant fires, explosions ...   
14             - Major fires, explosions and