This repository has been archived by the owner on Mar 1, 2024. It is now read-only.

Wordlift #556

Open · wants to merge 19 commits into base: main
36 changes: 20 additions & 16 deletions llama_hub/wordlift/README.md
@@ -35,27 +35,31 @@ config_options = {
     'metadata_fields': ['<YOUR_METADATA_FIELDS>']
 }
 # Create an instance of the WordLiftLoader
-reader = WordLiftLoader(endpoint, headers, query, fields, config_options)
+async def main():
+    reader = WordLiftLoader(endpoint, headers, query, fields, config_options)
 
-# Load the data
-documents = reader.load_data()
+    # Load the data
+    documents = await reader.load_data()
 
-# Convert the documents
-converted_doc = []
-for doc in documents:
-    converted_doc_id = json.dumps(doc.doc_id)
-    converted_doc.append(Document(text=doc.text, doc_id=converted_doc_id,
-                         embedding=doc.embedding, doc_hash=doc.doc_hash, extra_info=doc.extra_info))
+    # Convert the documents
+    converted_doc = []
+    for doc in documents:
+        converted_doc_id = json.dumps(doc.doc_id)
+        converted_doc.append(Document(text=doc.text, doc_id=converted_doc_id,
+                             embedding=doc.embedding, doc_hash=doc.doc_hash, extra_info=doc.extra_info))
 
-# Create the index and query engine
-index = VectorStoreIndex.from_documents(converted_doc)
-query_engine = index.as_query_engine()
+    # Create the index and query engine
+    index = VectorStoreIndex.from_documents(converted_doc)
+    query_engine = index.as_query_engine()
 
-# Perform a query
-result = query_engine.query("<YOUR_QUERY>")
+    # Perform a query
+    result = query_engine.query("<YOUR_QUERY>")
 
-# Process the result as needed
-logging.info("Result: %s", result)
+    # Process the result as needed
+    logging.info("Result: %s", result)
 
+if __name__ == "__main__":
+    asyncio.run(main())  # Run the asyncio event loop
+
 ```
This loader is designed to be used as a way to load data from WordLift KGs into [LlamaIndex](https://github.com/emptycrown/llama-hub/tree/main/llama_hub/apify/actor#:~:text=load%20data%20into-,LlamaIndex,-and/or%20subsequently) and/or subsequently used as a Tool in a [LangChain](https://github.com/hwchase17/langchain) Agent.
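
For instance, here is a minimal sketch of exposing the query engine built above as a LangChain Tool. This is illustrative only: it assumes `query_engine` comes from the README snippet, that langchain is installed, and the tool name and description are placeholders.

```python
# Sketch only: assumes `query_engine` was built as in the snippet above and
# that langchain is installed; the tool name and description are placeholders.
from langchain.agents import Tool

wordlift_tool = Tool(
    name="wordlift-kg",
    func=lambda q: str(query_engine.query(q)),
    description="Answer questions using content loaded from a WordLift Knowledge Graph.",
)

# The tool can then be handed to an agent, e.g. initialize_agent([wordlift_tool], llm, ...).
```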
130 changes: 80 additions & 50 deletions llama_hub/wordlift/base.py
@@ -1,21 +1,25 @@
import requests
from bs4 import BeautifulSoup
from typing import List
from llama_index.readers.base import BaseReader
from llama_index.readers.schema.base import Document
import logging
import os
import re
import warnings
from typing import List
from urllib.parse import urlparse

DATA_KEY = 'data'
ERRORS_KEY = 'errors'
DEFAULT_PAGE = 0
DEFAULT_ROWS = 500
import requests
import asyncio
import urllib.parse
from bs4 import BeautifulSoup
from llama_index.readers.base import BaseReader
from llama_index.readers.schema.base import Document

DATA_KEY = "data"
ERRORS_KEY = "errors"
TIME_OUT = 10


class WordLiftLoaderError(Exception):
"""Base class for WordLiftLoader exceptions."""

pass


@@ -65,7 +69,7 @@ def __init__(self, endpoint, headers, query, fields, configure_options):
self.fields = fields
self.configure_options = configure_options

def fetch_data(self) -> dict:
async def fetch_data(self) -> dict:
"""
Fetches data from the WordLift GraphQL API.

@@ -77,16 +81,20 @@ def fetch_data(self) -> dict:
"""
try:
query = self.alter_query()
response = requests.post(self.endpoint, json={
"query": query}, headers=self.headers)
response = await asyncio.to_thread(
requests.post,
self.endpoint,
json={"query": query},
headers=self.headers,
)
Collaborator comment on lines +84 to +89:

What's the reason to run the request on another thread? or add async operations?
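
For reference, a minimal standalone sketch of the `asyncio.to_thread` pattern used in this hunk (hypothetical endpoint and query; requires Python 3.9+): it runs the blocking `requests.post` call in a worker thread, so an asyncio event loop is not blocked while the HTTP request waits on the network.

```python
# Standalone illustration of the pattern above; the endpoint and query are
# hypothetical. asyncio.to_thread offloads the blocking requests.post call to
# a worker thread, so other tasks on the event loop keep running meanwhile.
import asyncio

import requests


async def post_query(endpoint: str, query: str) -> dict:
    response = await asyncio.to_thread(
        requests.post,
        endpoint,
        json={"query": query},
        timeout=10,
    )
    response.raise_for_status()
    return response.json()


if __name__ == "__main__":
    data = asyncio.run(post_query("https://api.example.org/graphql", "{ __typename }"))
    print(data)
```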

response.raise_for_status()
data = response.json()
if ERRORS_KEY in data:
raise APICallError(data[ERRORS_KEY])
return data
except requests.exceptions.RequestException as e:
logging.error('Error connecting to the API:', exc_info=True)
raise APICallError('Error connecting to the API') from e
logging.error("Error connecting to the API:", exc_info=True)
raise APICallError("Error connecting to the API") from e

def transform_data(self, data: dict) -> List[Document]:
"""
@@ -104,62 +112,71 @@ def transform_data(self, data: dict) -> List[Document]:
try:
data = data[DATA_KEY][self.fields]
documents = []
text_fields = self.configure_options.get('text_fields', [])
metadata_fields = self.configure_options.get('metadata_fields', [])
text_fields = self.configure_options.get("text_fields", [])
metadata_fields = self.configure_options.get("metadata_fields", [])
for i in range(len(text_fields)):
if text_fields[i] == 'url' or text_fields[i] == 'address':
text_fields[i] = 'body'

for item in data:
if not all(key in item for key in text_fields):
logging.warning(
f"Skipping document due to missing text fields: {item}")
f"Skipping document due to missing text fields: {item}"
)
continue
row = {}
for key, value in item.items():
if key in text_fields or key in metadata_fields:
row[key] = value
else:
row[key] = clean_value(value)

text_parts = [
get_separated_value(row, field.split('.'))
get_separated_value(row, field.split("."))
for field in text_fields
if get_separated_value(row, field.split('.')) is not None
if get_separated_value(row, field.split(".")) is not None
]

text_parts = flatten_list(text_parts)
text = ' '.join(text_parts)
text = " ".join(text_parts)

extra_info = {}
for field in metadata_fields:
field_keys = field.split('.')
field_keys = field.split(".")
value = get_separated_value(row, field_keys)
if value is None:
logging.warning(f"Using default value for {field}")
value = "n.a"
if isinstance(value, list) and len(value) != 0:
value = value[0]
if is_url(value) and is_valid_html(value):
extra_info[field] = value
value = value.replace("\n", "")
else:
extra_info[field] = clean_value(value)

document = Document(text=text, extra_info=extra_info)
cleaned_value = clean_value(value)
cleaned_value = cleaned_value.replace("\n", "")
extra_info[field] = value
text = text.replace("\n", "")
plain_text = re.sub("<.*?>", "", text)
document = Document(text=plain_text, extra_info=extra_info)
documents.append(document)

return documents
except Exception as e:
logging.error('Error transforming data:', exc_info=True)
raise DataTransformError('Error transforming data') from e
logging.error("Error transforming data:", exc_info=True)
raise DataTransformError("Error transforming data") from e

def load_data(self) -> List[Document]:
async def load_data(self) -> List[Document]:
"""
Loads the data by fetching and transforming it.

Returns:
List[Document]: The list of loaded documents.
"""
try:
data = self.fetch_data()
data = await self.fetch_data()
documents = self.transform_data(data)
return documents
except (APICallError, DataTransformError) as e:
logging.error('Error loading data:', exc_info=True)
except (APICallError, DataTransformError):
logging.error("Error loading data:", exc_info=True)
raise

def alter_query(self):
@@ -170,7 +187,11 @@ def alter_query(self):
str: The altered GraphQL query with pagination arguments.
"""
from graphql import parse, print_ast
from graphql.language.ast import ArgumentNode, NameNode, IntValueNode
from graphql.language.ast import ArgumentNode, IntValueNode, NameNode

DEFAULT_PAGE = 0
DEFAULT_ROWS = 500

query = self.query
page = DEFAULT_PAGE
rows = DEFAULT_ROWS
@@ -179,14 +200,12 @@

field_node = ast.definitions[0].selection_set.selections[0]

if not any(arg.name.value == 'page' for arg in field_node.arguments):
if not any(arg.name.value == "page" for arg in field_node.arguments):
page_argument = ArgumentNode(
name=NameNode(value='page'),
value=IntValueNode(value=page)
name=NameNode(value="page"), value=IntValueNode(value=page)
)
rows_argument = ArgumentNode(
name=NameNode(value='rows'),
value=IntValueNode(value=rows)
name=NameNode(value="rows"), value=IntValueNode(value=rows)
)
field_node.arguments = field_node.arguments + \
(page_argument, rows_argument)
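
For illustration, a standalone sketch of what this pagination rewrite does, using a hypothetical `articles` query. It assumes graphql-core 3.2+ (where `arguments` is a tuple), and the defaults mirror the DEFAULT_PAGE and DEFAULT_ROWS values above.

```python
# Hypothetical query, for illustration only; mirrors the alter_query logic above.
from graphql import parse, print_ast
from graphql.language.ast import ArgumentNode, IntValueNode, NameNode

query = "{ articles { id headline } }"
ast = parse(query)
field_node = ast.definitions[0].selection_set.selections[0]

if not any(arg.name.value == "page" for arg in field_node.arguments):
    page_argument = ArgumentNode(name=NameNode(value="page"), value=IntValueNode(value="0"))
    rows_argument = ArgumentNode(name=NameNode(value="rows"), value=IntValueNode(value="500"))
    field_node.arguments = field_node.arguments + (page_argument, rows_argument)

# Prints the query with default pagination injected, roughly:
#   { articles(page: 0, rows: 500) { id headline } }
print(print_ast(ast))
```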
@@ -205,8 +224,8 @@ def is_url(text: str) -> bool:
bool: True if the text is a URL, False otherwise.
"""
try:
result = urlparse(text)
return all([result.scheme, result.netloc])
parsed_url = urllib.parse.urlparse(text)
return all([parsed_url.scheme, parsed_url.netloc])
except ValueError:
return False

@@ -220,17 +239,24 @@ def is_valid_html(content: str) -> bool:

if is_url(content):
try:
response = requests.get(content)
response = requests.get(content, timeout=TIME_OUT)
if response.status_code == 200:
html_content = response.text
return BeautifulSoup(html_content, 'html.parser').find('html') is not None
return (
BeautifulSoup(html_content, "html.parser").find(
"html") is not None
)
else:
return False
except (requests.exceptions.RequestException, requests.exceptions.ConnectionError):
except (
requests.exceptions.RequestException,
requests.exceptions.ConnectionError,
requests.exceptions.Timeout
):
# If there is a connection error or the URL doesn't resolve, skip it
return False

return BeautifulSoup(content, 'html.parser').find('html') is not None
return BeautifulSoup(content, "html.parser").find("html") is not None


@staticmethod
@@ -256,24 +282,28 @@ def clean_html(text: str) -> str:
if isinstance(text, str):
try:
if is_url(text):
response = requests.get(text)
response = requests.get(text, timeout=TIME_OUT)
if response.status_code == 200:
html_content = response.text
soup = BeautifulSoup(html_content, 'html.parser')
soup = BeautifulSoup(html_content, "lxml")
cleaned_text = soup.get_text()
else:
cleaned_text = ""
elif os.path.isfile(text):
with open(text, 'r') as file:
soup = BeautifulSoup(file, 'html.parser')
with open(text, "r") as file:
soup = BeautifulSoup(file, "lxml")
cleaned_text = soup.get_text()
else:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
soup = BeautifulSoup(text, 'html.parser')
soup = BeautifulSoup(text, "lxml")
cleaned_text = soup.get_text()
return cleaned_text
except (requests.exceptions.RequestException, requests.exceptions.ConnectionError):
except (
requests.exceptions.RequestException,
requests.exceptions.ConnectionError,
requests.exceptions.Timeout
):
# If there is a connection error or the URL doesn't resolve, skip it
return ""
