diff --git a/apps/consumers.py b/apps/consumers.py index 6c2a07bc5a8..ae4c78c96c6 100644 --- a/apps/consumers.py +++ b/apps/consumers.py @@ -66,9 +66,9 @@ async def _respond_to_event(self, text_data): output_stream = await AppViewSet().run_app_internal_async(self.app_id, self._session_id, request_uuid, request, self.preview) async for output in output_stream: if 'errors' in output or 'session' in output: - await self.send(text_data=output) + await self.send(text_data=json.dumps(output)) else: - await self.send(text_data="{\"output\":" + output + '}') + await self.send(text_data=json.dumps({'output': output})) await self.send(text_data=json.dumps({'event': 'done'})) except Exception as e: diff --git a/apps/handlers/app_runnner.py b/apps/handlers/app_runnner.py index a132ad77e58..689f1aafef5 100644 --- a/apps/handlers/app_runnner.py +++ b/apps/handlers/app_runnner.py @@ -219,9 +219,9 @@ async def stream_output(): await asyncio.sleep(0.0001) if not metadata_sent: metadata_sent = True - yield json.dumps({'session': {'id': app_session['uuid']}, 'csp': csp, 'template': template}) + '\n' + yield {'session': {'id': app_session['uuid']}, 'csp': csp, 'template': template} output = next(output_iter) - yield json.dumps(output) + '\n' + yield output except StopIteration: pass except Exception as e: diff --git a/apps/tasks.py b/apps/tasks.py index 64aedb41771..4761aa3020f 100644 --- a/apps/tasks.py +++ b/apps/tasks.py @@ -114,6 +114,8 @@ def resync_data_entry_task(datasource: DataSource, entry_data: DataSourceEntry): def delete_data_source_task(datasource): datasource_type = datasource.type + if datasource_type.is_external_datasource: + return datasource_entry_handler_cls = DataSourceTypeFactory.get_datasource_type_handler( datasource_type, ) diff --git a/client/src/components/datasource/AddDataSourceModal.jsx b/client/src/components/datasource/AddDataSourceModal.jsx index e1988ecddb9..4b0a743e035 100644 --- a/client/src/components/datasource/AddDataSourceModal.jsx +++ 
b/client/src/components/datasource/AddDataSourceModal.jsx @@ -136,17 +136,21 @@ export function AddDataSourceModal({ .post("/api/datasources", { name: dataSourceName, type: dataSourceType.id, + config: dataSourceType.is_external_datasource ? formData : {}, }) .then((response) => { - const dataSource = response.data; - setDataSources([...dataSources, dataSource]); - axios() - .post(`/api/datasources/${dataSource.uuid}/add_entry`, { - entry_data: formData, - }) - .then((response) => { - dataSourceAddedCb(dataSource); - }); + // External data sources do not support adding entries + if (!dataSourceType.is_external_datasource) { + const dataSource = response.data; + setDataSources([...dataSources, dataSource]); + axios() + .post(`/api/datasources/${dataSource.uuid}/add_entry`, { + entry_data: formData, + }) + .then((response) => { + dataSourceAddedCb(dataSource); + }); + } }); handleCancelCb(); enqueueSnackbar( diff --git a/client/src/pages/data.jsx b/client/src/pages/data.jsx index 2b4c71b3543..9b55649556c 100644 --- a/client/src/pages/data.jsx +++ b/client/src/pages/data.jsx @@ -9,6 +9,7 @@ import { Chip, Grid, Stack, + Tooltip, } from "@mui/material"; import { TextareaAutosize } from "@mui/base"; @@ -16,6 +17,7 @@ import { TextareaAutosize } from "@mui/base"; import DeleteOutlineOutlinedIcon from "@mui/icons-material/DeleteOutlineOutlined"; import AddOutlinedIcon from "@mui/icons-material/AddOutlined"; import SyncOutlinedIcon from "@mui/icons-material/SyncOutlined"; +import SettingsEthernetIcon from "@mui/icons-material/SettingsEthernet"; import PeopleOutlineOutlinedIcon from "@mui/icons-material/PeopleOutlineOutlined"; import PersonOutlineOutlinedIcon from "@mui/icons-material/PersonOutlineOutlined"; @@ -192,51 +194,64 @@ export default function DataPage() { { title: "Action", key: "operation", - render: (record) => ( - - { - setModalTitle("Add New Data Entry"); - setSelectedDataSource(record); - setAddDataSourceModalOpen(true); - }} - > - - - { - 
setDeleteId(record); - setDeleteModalTitle("Delete Data Source"); - setDeleteModalMessage( -
- Are you sure you want to delete{" "} - {record.name} ? -
, - ); - setDeleteConfirmationModalOpen(true); - }} - > - -
- {profileFlags.IS_ORGANIZATION_MEMBER && record.isUserOwned && ( + render: (record) => { + return ( + + {!record?.type?.is_external_datasource && ( + { + setModalTitle("Add New Data Entry"); + setSelectedDataSource(record); + setAddDataSourceModalOpen(true); + }} + > + + + )} + {record?.type?.is_external_datasource && ( + + + + + + + + )} { - setModalTitle("Share Datasource"); - setSelectedDataSource(record); - setShareDataSourceModalOpen(true); + setDeleteId(record); + setDeleteModalTitle("Delete Data Source"); + setDeleteModalMessage( +
+ Are you sure you want to delete{" "} + {record.name} ? +
, + ); + setDeleteConfirmationModalOpen(true); }} > - {record.visibility === 0 ? ( - - ) : ( - - )} +
- )} -
- ), + {profileFlags.IS_ORGANIZATION_MEMBER && record.isUserOwned && ( + { + setModalTitle("Share Datasource"); + setSelectedDataSource(record); + setShareDataSourceModalOpen(true); + }} + > + {record.visibility === 0 ? ( + + ) : ( + + )} + + )} +
+ ); + }, }, ]; diff --git a/common/blocks/data/store/vectorstore/weaviate.py b/common/blocks/data/store/vectorstore/weaviate.py index 7678bf79e4c..4e7ef425664 100644 --- a/common/blocks/data/store/vectorstore/weaviate.py +++ b/common/blocks/data/store/vectorstore/weaviate.py @@ -80,6 +80,10 @@ class WeaviateConfiguration(BaseModel): weaviate_rw_api_key: Optional[str] = None embeddings_rate_limit: Optional[int] = 3000 default_batch_size: Optional[int] = 20 + username: Optional[str] = None + password: Optional[str] = None + api_key: Optional[str] = None + additional_headers: Optional[dict] = {} class Weaviate(VectorStoreInterface): @@ -130,8 +134,8 @@ def check_batch_result(results: Optional[List[Dict[str, Any]]]) -> None: json.dumps(result['result']['errors']), ), ) - - headers = {} + + headers = dict(configuration.additional_headers or {}) if configuration.openai_key is not None: headers['X-OpenAI-Api-Key'] = configuration.openai_key if configuration.cohere_api_key is not None: @@ -144,10 +148,25 @@ def check_batch_result(results: Optional[List[Dict[str, Any]]]) -> None: headers['authorization'] = 'Bearer ' + \ configuration.weaviate_rw_api_key - self._client = weaviate.Client( - url=configuration.url, - additional_headers=headers, - ) + if configuration.username is not None and configuration.password is not None: + self._client = weaviate.Client( + url=configuration.url, + auth_client_secret=weaviate.AuthClientPassword( + username=configuration.username, password=configuration.password), + additional_headers=headers, + ) + elif configuration.api_key is not None: + self._client = weaviate.Client( + url=configuration.url, + auth_client_secret=weaviate.AuthApiKey( + api_key=configuration.api_key), + additional_headers=headers, + ) + else: + self._client = weaviate.Client( + url=configuration.url, + additional_headers=headers, + ) self.client.batch.configure( batch_size=DEFAULT_BATCH_SIZE, @@ -234,6 +253,7 @@ def similarity_search(self, index_name: str, document_query: 
DocumentQuery, **kw properties = [document_query.page_content_key] for key in document_query.metadata.get('additional_properties', []): properties.append(key) + additional_metadata_properties = document_query.metadata.get('metadata_properties', ['id', 'certainty', 'distance']) if kwargs.get('search_distance'): nearText['certainty'] = kwargs.get('search_distance') @@ -254,7 +274,7 @@ def similarity_search(self, index_name: str, document_query: DocumentQuery, **kw query_obj = query_obj.with_where(whereFilter) query_response = query_obj.with_near_text(nearText).with_limit( document_query.limit, - ).with_additional(['id', 'certainty', 'distance']).do() + ).with_additional(additional_metadata_properties).do() except Exception as e: logger.error('Error in similarity search: %s' % e) raise e diff --git a/common/utils/models.py b/common/utils/models.py new file mode 100644 index 00000000000..07c1d3cb14c --- /dev/null +++ b/common/utils/models.py @@ -0,0 +1,38 @@ +from typing import Optional + +import orjson as json +from pydantic import BaseModel + +class Config(BaseModel): + """ + Base class for config type models stored in the database. Supports optional encryption. 
+ """ + config_type: str + is_encrypted: bool = False + data: str = '' + + def to_dict(self, encrypt_fn): + return { + 'config_type': self.config_type, + 'is_encrypted': self.is_encrypted, + 'data': self.get_data(encrypt_fn), + } + + def from_dict(self, dict_data, decrypt_fn): + self.config_type = dict_data.get('config_type') + self.is_encrypted = dict_data.get('is_encrypted') + self.set_data(dict_data.get('data'), decrypt_fn) + + # Use the data from the dict to populate the fields + self.__dict__.update(json.loads(self.data)) + + return self.dict(exclude={'is_encrypted', 'config_type', 'data'}) + + def get_data(self, encrypt_fn): + data = self.json(exclude={'is_encrypted', 'config_type', 'data'}) + return encrypt_fn(data).decode('utf-8') if self.is_encrypted else data + + def set_data(self, data, decrypt_fn): + self.data = data + if self.is_encrypted: + self.data = decrypt_fn(data) diff --git a/datasources/apis.py b/datasources/apis.py index 457d64fb7c7..522d355ecbb 100644 --- a/datasources/apis.py +++ b/datasources/apis.py @@ -117,11 +117,28 @@ def post(self, request): datasource_type = get_object_or_404( DataSourceType, id=request.data['type'], ) - datasource = DataSource.objects.create( + + datasource = DataSource( name=request.data['name'], owner=owner, type=datasource_type, ) + # If this is an external data source, then we need to save the config in datasource object + if datasource_type.is_external_datasource: + datasource_type_cls = DataSourceTypeFactory.get_datasource_type_handler(datasource.type) + if not datasource_type_cls: + logger.error( + f'No handler found for data source type {datasource.type}', + ) + return DRFResponse({'errors': ['No handler found for data source type']}, status=400) + + datasource_handler: DataSourceProcessor = datasource_type_cls(datasource) + if not datasource_handler: + logger.error(f'Error while creating handler for data source {datasource.name}') + return DRFResponse({'errors': ['Error while creating handler for data source 
type']}, status=400) + config = datasource_type_cls.process_validate_config(request.data['config'], datasource) + datasource.config = config + datasource.save() return DRFResponse(DataSourceSerializer(instance=datasource).data, status=201) @@ -147,6 +164,9 @@ def add_entry(self, request, uid): datasource = get_object_or_404( DataSource, uuid=uuid.UUID(uid), owner=request.user, ) + if datasource and datasource.type.is_external_datasource: + return DRFResponse({'errors': ['Cannot add entry to external data source']}, status=400) + entry_data = request.data['entry_data'] entry_metadata = dict(map(lambda x: (f'md_{x}', request.data['entry_metadata'][x]), request.data['entry_metadata'].keys())) if 'entry_metadata' in request.data else { } diff --git a/datasources/handlers/databases/__init__.py b/datasources/handlers/databases/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/datasources/handlers/databases/weaviate.py b/datasources/handlers/databases/weaviate.py new file mode 100644 index 00000000000..63f9b3d2d79 --- /dev/null +++ b/datasources/handlers/databases/weaviate.py @@ -0,0 +1,137 @@ +import json +import logging + +from typing import Dict, List +from typing import Optional + +from pydantic import Field + +from common.blocks.base.schema import BaseSchema as _Schema +from common.blocks.data.store.vectorstore import Document, DocumentQuery +from common.blocks.data.store.vectorstore.weaviate import Weaviate, WeaviateConfiguration, generate_where_filter +from common.utils.models import Config +from datasources.handlers.datasource_type_interface import DataSourceEntryItem +from datasources.handlers.datasource_type_interface import DataSourceSchema +from datasources.handlers.datasource_type_interface import DataSourceProcessor +from datasources.models import DataSource + + +logger = logging.getLogger(__name__) + +# This is a Python class to establish a connection with Weaviate. +# It accepts the following parameters: +# 1. 
weaviate_url: URL of the Weaviate instance. It is a mandatory field. +# 2. username: Your username for the Weaviate instance. This is an optional field. +# 3. password: Corresponding password for the above username. This is an optional field. +# 4. api_key: Your Weaviate API key. This is also an optional field. +# 5. additional_headers: Any additional headers that need to be passed in the request. This is optional, and should be passed as a JSON string. The default value is '{}'. +class WeaviateConnection(_Schema): + weaviate_url: str = Field(description='Weaviate URL') + username: Optional[str] = Field(description='Weaviate username') + password: Optional[str] = Field(description='Weaviate password') + api_key: Optional[str] = Field(description='Weaviate API key') + additional_headers: Optional[str] = Field(description='Weaviate headers. Please enter a JSON string.', widget='textarea', default='{}') + +# This class is a definition of the Weaviate database schema. +# `index_name`: This is a required string attribute representing the name of the weaviate index. +# `content_property_name`: This is a required string attribute representing the name of the weaviate content property to search. +# `additional_properties`: This is an optional attribute represented as a list of strings. +# It's used to specify any additional properties for the Weaviate document, +# and defaults to an empty list. +# `connection`: This is optional and holds the Weaviate connection settings (a `WeaviateConnection`). +# It inherits structure and behaviour from the `DataSourceSchema` class. 
+class WeaviateDatabaseSchema(DataSourceSchema): + index_name: str = Field(description='Weaviate index name') + content_property_name: str = Field(description='Weaviate content property name') + additional_properties: Optional[List[str]] = Field(description='Weaviate additional properties', default=[]) + connection: Optional[WeaviateConnection] = Field(description='Weaviate connection string') + + +class WeaviateConnectionConfiguration(Config): + config_type = 'weaviate_connection' + is_encrypted = True + weaviate_config: Optional[Dict] + +# This class helps to manage and interact with a Weaviate Data Source. +# It inherits from the DataSourceProcessor class and operates on a WeaviateDatabaseSchema. +class WeaviateDataSource(DataSourceProcessor[WeaviateDatabaseSchema]): + + # Initializer for the class. + # It requires a datasource object as input, checks if it has a 'data' configuration, and sets up Weaviate Database Configuration. + def __init__(self, datasource: DataSource): + self.datasource = datasource + if self.datasource.config and 'data' in self.datasource.config: + config_dict = WeaviateConnectionConfiguration().from_dict(self.datasource.config, self.datasource.profile.decrypt_value) + self._configuration = WeaviateDatabaseSchema(**config_dict['weaviate_config']) + self._weviate_client = Weaviate(**WeaviateConfiguration( + url=self._configuration.connection.weaviate_url, + username=self._configuration.connection.username, + password=self._configuration.connection.password, + api_key=self._configuration.connection.api_key, + additional_headers=json.loads(self._configuration.connection.additional_headers) if self._configuration.connection.additional_headers else {}, + ).dict()) + + # This static method returns the name of the datasource class as 'Weaviate'. + @staticmethod + def name() -> str: + return 'Weaviate' + + # This static method returns the slug for the datasource as 'weaviate'. 
+ @staticmethod + def slug() -> str: + return 'weaviate' + + # This static method takes a dictionary for configuration and a DataSource object as inputs. + # The config is wrapped in a WeaviateConnectionConfiguration and returned in its dict form, using the datasource profile's encrypt_value as the encrypt function. + @staticmethod + def process_validate_config(config_data: dict, datasource: DataSource) -> dict: + return WeaviateConnectionConfiguration(weaviate_config=config_data).to_dict(encrypt_fn=datasource.profile.encrypt_value) + + # This static method returns the provider slug for the datasource connector. + @staticmethod + def provider_slug() -> str: + return 'weaviate' + + def validate_and_process(self, data: dict) -> List[DataSourceEntryItem]: + raise NotImplementedError + + def get_data_documents(self, data: dict) -> List[Document]: + raise NotImplementedError + + def add_entry(self, data: dict) -> Optional[DataSourceEntryItem]: + raise NotImplementedError + + """ + This function performs similarity search on documents by using 'near text' concept of Weaviate where it tries to fetch documents in which concepts match with the given query. 
+ """ + def similarity_search(self, query: str, **kwargs) -> List[dict]: + index_name = self._configuration.index_name + additional_properties = self._configuration.additional_properties + + result = self._weviate_client.similarity_search( + index_name=index_name, + document_query=DocumentQuery( + query=query, + page_content_key=self._configuration.content_property_name, + limit=kwargs.get('limit', 2), + metadata={'additional_properties' : additional_properties, 'metadata_properties' : ['distance']}, + search_filters=kwargs.get('search_filters', None), + ), + **kwargs + ) + return result + + def hybrid_search(self, query: str, **kwargs) -> List[dict]: + raise NotImplementedError + + def delete_entry(self, data: dict) -> None: + raise NotImplementedError + + def resync_entry(self, data: dict) -> Optional[DataSourceEntryItem]: + raise NotImplementedError + + def delete_all_entries(self) -> None: + raise NotImplementedError + + def get_entry_text(self, data: dict) -> tuple: + return None, self._configuration.json() \ No newline at end of file diff --git a/datasources/migrations/0002_datasource_config_and_more.py b/datasources/migrations/0002_datasource_config_and_more.py new file mode 100644 index 00000000000..129f3abe19f --- /dev/null +++ b/datasources/migrations/0002_datasource_config_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.1 on 2023-09-20 17:30 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('datasources', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='datasource', + name='config', + field=models.JSONField(default=dict, help_text='Config for the data source'), + ), + migrations.AddField( + model_name='datasourcetype', + name='is_external_datasource', + field=models.BooleanField(default=False, help_text='Is this an external data source?'), + ), + ] diff --git a/datasources/models.py b/datasources/models.py index 073d7190ae4..7fe984ece01 100644 --- 
a/datasources/models.py +++ b/datasources/models.py @@ -44,6 +44,9 @@ class DataSourceType(models.Model): description = models.TextField( default='', blank=True, help_text='Description of the data source type', ) + is_external_datasource = models.BooleanField( + default=False, help_text='Is this an external data source?', + ) def __str__(self): return self.name @@ -71,6 +74,9 @@ class DataSource(models.Model): visibility = models.PositiveSmallIntegerField( default=DataSourceVisibility.PRIVATE, choices=DataSourceVisibility.choices, help_text='Visibility of the data source', ) + config = models.JSONField( + default=dict, help_text='Config for the data source', + ) created_at = models.DateTimeField( help_text='Time when the data source was created', default=now, ) @@ -80,6 +86,10 @@ class DataSource(models.Model): def __str__(self): return self.name + ' (' + self.type.name + ')' + ' - ' + str(self.owner) + + @property + def profile(self): + return Profile.objects.get(user=self.owner) class DataSourceEntry(models.Model): diff --git a/datasources/serializers.py b/datasources/serializers.py index d59f9787938..5099d5d7c0c 100644 --- a/datasources/serializers.py +++ b/datasources/serializers.py @@ -41,7 +41,7 @@ class Meta: model = DataSourceType fields = [ 'id', 'name', 'description', - 'entry_config_schema', 'entry_config_ui_schema', 'sync_config' + 'entry_config_schema', 'entry_config_ui_schema', 'sync_config', 'is_external_datasource' ] diff --git a/processors/apis.py b/processors/apis.py index c85136b257c..c42a55531b6 100644 --- a/processors/apis.py +++ b/processors/apis.py @@ -435,7 +435,7 @@ async def stream_output(): while True: await asyncio.sleep(0.0001) output = next(output_iter) - yield json.dumps({'output': output['processor']}) + '\n' + yield {'output': output['processor']} except StopIteration: coordinator_ref.stop() except Exception as e: diff --git a/processors/providers/promptly/datasource_search.py b/processors/providers/promptly/datasource_search.py 
index fbe9a6f9f49..94d2486464b 100644 --- a/processors/providers/promptly/datasource_search.py +++ b/processors/providers/promptly/datasource_search.py @@ -32,6 +32,8 @@ class Document(ApiProcessorSchema): ) source: Optional[str] = Field(description='Source of the document') metadata: DocumentMetadata = Field(description='Metadata of the document') + additional_properties: Optional[dict] = {} + class DataSourceSearchOutput(ApiProcessorSchema): @@ -108,6 +110,7 @@ def process(self) -> DataSourceSearchOutput: certainty=document.metadata['certainty'] if 'certainty' in document.metadata else 0.0, distance=document.metadata['distance'], ), + additional_properties=document.metadata ), ) answer_text += f'Content: {document.page_content} \n\nSource: {source} \n\n\n\n'