diff --git a/llama_hub/wordlift/README.md b/llama_hub/wordlift/README.md
index 23e4c89237..c9f2a11135 100644
--- a/llama_hub/wordlift/README.md
+++ b/llama_hub/wordlift/README.md
@@ -35,27 +35,31 @@ config_options = {
     'metadata_fields': ['']
 }
 # Create an instance of the WordLiftLoader
-reader = WordLiftLoader(endpoint, headers, query, fields, config_options)
+async def main():
+    reader = WordLiftLoader(endpoint, headers, query, fields, config_options)
 
-# Load the data
-documents = reader.load_data()
+    # Load the data
+    documents = await reader.load_data()
 
-# Convert the documents
-converted_doc = []
-for doc in documents:
-    converted_doc_id = json.dumps(doc.doc_id)
-    converted_doc.append(Document(text=doc.text, doc_id=converted_doc_id,
-                         embedding=doc.embedding, doc_hash=doc.doc_hash, extra_info=doc.extra_info))
+    # Convert the documents
+    converted_doc = []
+    for doc in documents:
+        converted_doc_id = json.dumps(doc.doc_id)
+        converted_doc.append(Document(text=doc.text, doc_id=converted_doc_id,
+                             embedding=doc.embedding, doc_hash=doc.doc_hash, extra_info=doc.extra_info))
 
-# Create the index and query engine
-index = VectorStoreIndex.from_documents(converted_doc)
-query_engine = index.as_query_engine()
+    # Create the index and query engine
+    index = VectorStoreIndex.from_documents(converted_doc)
+    query_engine = index.as_query_engine()
 
-# Perform a query
-result = query_engine.query("")
+    # Perform a query
+    result = query_engine.query("")
 
-# Process the result as needed
-logging.info("Result: %s", result)
+    # Process the result as needed
+    logging.info("Result: %s", result)
+
+if __name__ == "__main__":
+    asyncio.run(main())  # Run the asyncio event loop
 ```
 
 This loader is designed to be used as a way to load data from WordLift KGs into [LlamaIndex](https://github.com/emptycrown/llama-hub/tree/main/llama_hub/apify/actor#:~:text=load%20data%20into-,LlamaIndex,-and/or%20subsequently) and/or subsequently used as a Tool in a [LangChain](https://github.com/hwchase17/langchain) Agent.
diff --git a/llama_hub/wordlift/base.py b/llama_hub/wordlift/base.py
index 40b2cdd2be..b28ec2d9c9 100644
--- a/llama_hub/wordlift/base.py
+++ b/llama_hub/wordlift/base.py
@@ -1,21 +1,25 @@
-import requests
-from bs4 import BeautifulSoup
-from typing import List
-from llama_index.readers.base import BaseReader
-from llama_index.readers.schema.base import Document
 import logging
 import os
+import re
 import warnings
+from typing import List
 from urllib.parse import urlparse
 
-DATA_KEY = 'data'
-ERRORS_KEY = 'errors'
-DEFAULT_PAGE = 0
-DEFAULT_ROWS = 500
+import requests
+import asyncio
+import urllib.parse
+from bs4 import BeautifulSoup
+from llama_index.readers.base import BaseReader
+from llama_index.readers.schema.base import Document
+
+DATA_KEY = "data"
+ERRORS_KEY = "errors"
+TIME_OUT = 10
 
 
 class WordLiftLoaderError(Exception):
     """Base class for WordLiftLoader exceptions."""
+
     pass
 
 
@@ -65,7 +69,7 @@ def __init__(self, endpoint, headers, query, fields, configure_options):
         self.fields = fields
         self.configure_options = configure_options
 
-    def fetch_data(self) -> dict:
+    async def fetch_data(self) -> dict:
         """
         Fetches data from the WordLift GraphQL API.
 
@@ -77,16 +81,20 @@ def fetch_data(self) -> dict:
         """
         try:
             query = self.alter_query()
-            response = requests.post(self.endpoint, json={
-                "query": query}, headers=self.headers)
+            response = await asyncio.to_thread(
+                requests.post,
+                self.endpoint,
+                json={"query": query},
+                headers=self.headers,
+            )
             response.raise_for_status()
             data = response.json()
             if ERRORS_KEY in data:
                 raise APICallError(data[ERRORS_KEY])
             return data
         except requests.exceptions.RequestException as e:
-            logging.error('Error connecting to the API:', exc_info=True)
-            raise APICallError('Error connecting to the API') from e
+            logging.error("Error connecting to the API:", exc_info=True)
+            raise APICallError("Error connecting to the API") from e
 
     def transform_data(self, data: dict) -> List[Document]:
         """
@@ -104,13 +112,17 @@ def transform_data(self, data: dict) -> List[Document]:
         try:
             data = data[DATA_KEY][self.fields]
             documents = []
-            text_fields = self.configure_options.get('text_fields', [])
-            metadata_fields = self.configure_options.get('metadata_fields', [])
+            text_fields = self.configure_options.get("text_fields", [])
+            metadata_fields = self.configure_options.get("metadata_fields", [])
+            for i in range(len(text_fields)):
+                if text_fields[i] == 'url' or text_fields[i] == 'address':
+                    text_fields[i] = 'body'
 
             for item in data:
                 if not all(key in item for key in text_fields):
                     logging.warning(
-                        f"Skipping document due to missing text fields: {item}")
+                        f"Skipping document due to missing text fields: {item}"
+                    )
                     continue
                 row = {}
                 for key, value in item.items():
@@ -118,36 +130,41 @@ def transform_data(self, data: dict) -> List[Document]:
                         row[key] = value
                     else:
                         row[key] = clean_value(value)
-
                 text_parts = [
-                    get_separated_value(row, field.split('.'))
+                    get_separated_value(row, field.split("."))
                     for field in text_fields
-                    if get_separated_value(row, field.split('.')) is not None
+                    if get_separated_value(row, field.split(".")) is not None
                 ]
 
                 text_parts = flatten_list(text_parts)
-                text = ' '.join(text_parts)
+                text = " ".join(text_parts)
 
                 extra_info = {}
                 for field in metadata_fields:
-                    field_keys = field.split('.')
+                    field_keys = field.split(".")
                     value = get_separated_value(row, field_keys)
+                    if value is None:
+                        logging.warning(f"Using default value for {field}")
+                        value = "n.a"
                     if isinstance(value, list) and len(value) != 0:
                         value = value[0]
                     if is_url(value) and is_valid_html(value):
-                        extra_info[field] = value
+                        extra_info[field] = value.replace("\n", "")
                     else:
-                        extra_info[field] = clean_value(value)
-
-                document = Document(text=text, extra_info=extra_info)
+                        cleaned_value = clean_value(value)
+                        cleaned_value = cleaned_value.replace("\n", "")
+                        extra_info[field] = cleaned_value
+                text = text.replace("\n", "")
+                plain_text = re.sub("<.*?>", "", text)
+                document = Document(text=plain_text, extra_info=extra_info)
                 documents.append(document)
 
             return documents
         except Exception as e:
-            logging.error('Error transforming data:', exc_info=True)
-            raise DataTransformError('Error transforming data') from e
+            logging.error("Error transforming data:", exc_info=True)
+            raise DataTransformError("Error transforming data") from e
 
-    def load_data(self) -> List[Document]:
+    async def load_data(self) -> List[Document]:
         """
         Loads the data by fetching and transforming it.
 
@@ -155,11 +172,11 @@ def load_data(self) -> List[Document]:
             List[Document]: The list of loaded documents.
""" try: - data = self.fetch_data() + data = await self.fetch_data() documents = self.transform_data(data) return documents - except (APICallError, DataTransformError) as e: - logging.error('Error loading data:', exc_info=True) + except (APICallError, DataTransformError): + logging.error("Error loading data:", exc_info=True) raise def alter_query(self): @@ -170,7 +187,11 @@ def alter_query(self): str: The altered GraphQL query with pagination arguments. """ from graphql import parse, print_ast - from graphql.language.ast import ArgumentNode, NameNode, IntValueNode + from graphql.language.ast import ArgumentNode, IntValueNode, NameNode + + DEFAULT_PAGE = 0 + DEFAULT_ROWS = 500 + query = self.query page = DEFAULT_PAGE rows = DEFAULT_ROWS @@ -179,14 +200,12 @@ def alter_query(self): field_node = ast.definitions[0].selection_set.selections[0] - if not any(arg.name.value == 'page' for arg in field_node.arguments): + if not any(arg.name.value == "page" for arg in field_node.arguments): page_argument = ArgumentNode( - name=NameNode(value='page'), - value=IntValueNode(value=page) + name=NameNode(value="page"), value=IntValueNode(value=page) ) rows_argument = ArgumentNode( - name=NameNode(value='rows'), - value=IntValueNode(value=rows) + name=NameNode(value="rows"), value=IntValueNode(value=rows) ) field_node.arguments = field_node.arguments + \ (page_argument, rows_argument) @@ -205,8 +224,8 @@ def is_url(text: str) -> bool: bool: True if the text is a URL, False otherwise. """ try: - result = urlparse(text) - return all([result.scheme, result.netloc]) + parsed_url = urllib.parse.urlparse(text) + return all([parsed_url.scheme, parsed_url.netloc]) except ValueError: return False @@ -220,17 +239,24 @@ def is_valid_html(content: str) -> bool: if is_url(content): try: - response = requests.get(content) + response = requests.get(content, timeout=TIME_OUT) if response.status_code == 200: html_content = response.text - return BeautifulSoup(html_content, 'html.parser').find('html') is not None + return ( + BeautifulSoup(html_content, "html.parser").find( + "html") is not None + ) else: return False - except (requests.exceptions.RequestException, requests.exceptions.ConnectionError): + except ( + requests.exceptions.RequestException, + requests.exceptions.ConnectionError, + requests.exceptions.Timeout + ): # If there is a connection error or the URL doesn't resolve, skip it return False - return BeautifulSoup(content, 'html.parser').find('html') is not None + return BeautifulSoup(content, "html.parser").find("html") is not None @staticmethod @@ -256,24 +282,28 @@ def clean_html(text: str) -> str: if isinstance(text, str): try: if is_url(text): - response = requests.get(text) + response = requests.get(text, timeout=TIME_OUT) if response.status_code == 200: html_content = response.text - soup = BeautifulSoup(html_content, 'html.parser') + soup = BeautifulSoup(html_content, "lxml") cleaned_text = soup.get_text() else: cleaned_text = "" elif os.path.isfile(text): - with open(text, 'r') as file: - soup = BeautifulSoup(file, 'html.parser') + with open(text, "r") as file: + soup = BeautifulSoup(file, "lxml") cleaned_text = soup.get_text() else: with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) - soup = BeautifulSoup(text, 'html.parser') + soup = BeautifulSoup(text, "lxml") cleaned_text = soup.get_text() return cleaned_text - except (requests.exceptions.RequestException, requests.exceptions.ConnectionError): + except ( + requests.exceptions.RequestException, + 
requests.exceptions.ConnectionError, + requests.exceptions.Timeout + ): # If there is a connection error or the URL doesn't resolve, skip it return ""
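
With `load_data()` now a coroutine that offloads the blocking `requests.post` call via `asyncio.to_thread`, callers can overlap several GraphQL fetches instead of issuing them one at a time. The sketch below is illustrative only and not part of this patch: the endpoint, API key, queries, and field names are placeholders, and the import path assumes the `llama_hub/wordlift` layout shown in the diff.

```python
import asyncio

from llama_hub.wordlift.base import WordLiftLoader

# Placeholder configuration -- substitute real values for your knowledge graph.
ENDPOINT = "<WORDLIFT_GRAPHQL_ENDPOINT>"
HEADERS = {"Authorization": "Bearer <YOUR_WORDLIFT_KEY>"}
QUERIES = ["<GRAPHQL_QUERY_1>", "<GRAPHQL_QUERY_2>"]
FIELDS = "<FIELDS>"
CONFIG_OPTIONS = {
    "text_fields": ["<TEXT_FIELD>"],
    "metadata_fields": ["<METADATA_FIELD>"],
}


async def load_all():
    loaders = [
        WordLiftLoader(ENDPOINT, HEADERS, query, FIELDS, CONFIG_OPTIONS)
        for query in QUERIES
    ]
    # Each load_data() call awaits asyncio.to_thread(requests.post, ...),
    # so the blocking HTTP requests run in worker threads and can overlap.
    results = await asyncio.gather(*(loader.load_data() for loader in loaders))
    # Flatten the per-query document lists into one list.
    return [doc for docs in results for doc in docs]


if __name__ == "__main__":
    documents = asyncio.run(load_all())
    print(f"Loaded {len(documents)} documents")
```

Because the work is thread-offloaded rather than natively async, concurrency is bounded by the default thread pool; swapping in an async HTTP client would lift that limit, but that is outside the scope of this change.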